Source code for nv_ingest_api.util.introspection.class_inspect
# SPDX-FileCopyrightText: Copyright (c) 2024-25, NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import inspect
from typing import Optional, Type, Union, Callable
from pydantic import BaseModel
[docs]
def find_pydantic_config_schema(
    actor_class: Type,
    base_class_to_find: Type,
    param_name: str = "config",
) -> Optional[Type[BaseModel]]:
    """
    Find the Pydantic model that annotates a parameter of an actor's ``__init__``.

    The actor may be passed either as a class or as a proxy object (a
    non-class instance). For a proxy, the class to inspect is the first entry
    in ``actor_class.__class__.__mro__`` that is a strict subclass of
    ``base_class_to_find``. The MRO of the resolved class is then walked until
    an ``__init__`` is found whose ``param_name`` parameter is annotated with a
    strict subclass of ``pydantic.BaseModel``.

    Parameters
    ----------
    actor_class : Type
        The actor class or proxy object to inspect.
    base_class_to_find : Type
        The base class (e.g., RaySource, RayStage) used to resolve the real
        actor class from a proxy. Ignored when ``actor_class`` is already a
        class.
    param_name : str, optional
        The name of the ``__init__`` parameter to inspect for the Pydantic
        schema, by default "config".

    Returns
    -------
    Optional[Type[BaseModel]]
        The Pydantic BaseModel subclass if found, otherwise None. ``BaseModel``
        itself is never returned.
    """
    # 1. Find the actual class to inspect, handling proxy objects.
    if inspect.isclass(actor_class):
        cls_to_inspect = actor_class
    else:
        cls_to_inspect = next(
            (
                base
                for base in actor_class.__class__.__mro__
                if inspect.isclass(base)
                and issubclass(base, base_class_to_find)
                and base is not base_class_to_find
            ),
            None,
        )

    # Identity check rather than truthiness: a class can be falsy when its
    # metaclass defines __bool__/__len__, and that must not read as "not found".
    if cls_to_inspect is None:
        return None

    # 2. Walk the MRO of the real class to find the __init__ with the typed parameter.
    for cls in cls_to_inspect.__mro__:
        init = cls.__init__
        if param_name not in getattr(init, "__annotations__", {}):
            continue
        try:
            parameter = inspect.signature(init).parameters.get(param_name)
            if parameter is None:
                continue
            annotation = parameter.annotation
            # Explicit class guard (mirrors the callable variant) so that
            # non-class annotations are skipped deliberately rather than via
            # an exception raised by issubclass().
            if (
                inspect.isclass(annotation)
                and annotation is not BaseModel
                and issubclass(annotation, BaseModel)
            ):
                return annotation  # Found the schema
        except (ValueError, TypeError):
            # This __init__ is not inspectable (e.g., a C-extension), or the
            # annotation still could not be used with issubclass(); continue
            # up the MRO.
            continue
    return None
[docs]
def find_pydantic_config_schema_for_callable(
    callable_fn: Callable,
    param_name: str = "stage_config",
) -> Optional[Type[BaseModel]]:
    """
    Look up the Pydantic model annotating one parameter of a callable.

    The callable's signature is inspected and the annotation of the parameter
    named ``param_name`` is returned when it is a strict subclass of
    ``pydantic.BaseModel``.

    Parameters
    ----------
    callable_fn : Callable
        The callable whose signature is inspected.
    param_name : str, optional
        The name of the parameter expected to carry the Pydantic schema,
        by default "stage_config".

    Returns
    -------
    Optional[Type[BaseModel]]
        The annotated BaseModel subclass, or None when the parameter is
        missing, is not annotated with a BaseModel subclass, or the signature
        cannot be inspected.
    """
    try:
        parameter = inspect.signature(callable_fn).parameters.get(param_name)
        if not parameter:
            return None

        candidate = parameter.annotation
        # The bare BaseModel is not a usable schema; only subclasses qualify.
        if candidate is BaseModel:
            return None
        # Anything without an MRO is not a class and cannot be a schema.
        if not hasattr(candidate, "__mro__"):
            return None
        if issubclass(candidate, BaseModel):
            return candidate
        return None
    except (ValueError, TypeError):
        # Function signature is not inspectable
        return None
[docs]
def find_pydantic_config_schema_unified(
    target: Union[Type, Callable],
    base_class_to_find: Optional[Type] = None,
    param_name: str = "config",
) -> Optional[Type[BaseModel]]:
    """
    Unified function to find the Pydantic schema for either classes or callables.

    Dispatch rules:

    * A callable that is not a class is handed to
      ``find_pydantic_config_schema_for_callable``.
    * Anything else (a class, or a non-callable object such as a proxy) is
      handed to ``find_pydantic_config_schema``, which requires
      ``base_class_to_find``; without it, None is returned.

    Parameters
    ----------
    target : Union[Type, Callable]
        The class, proxy object, or callable to inspect.
    base_class_to_find : Optional[Type], optional
        The specific base class to look for when resolving actor classes from
        proxies. Only used for class inspection.
    param_name : str, optional
        The name of the parameter to inspect for the Pydantic schema, by
        default "config". This value is forwarded unchanged on both paths, so
        callers inspecting a callable whose schema parameter is named
        "stage_config" must pass ``param_name="stage_config"`` explicitly.

    Returns
    -------
    Optional[Type[BaseModel]]
        The Pydantic BaseModel class if found, otherwise None.
    """
    # NOTE(review): a proxy *instance* that defines __call__ is routed to the
    # callable path here, not the class path - confirm this is intended.
    if callable(target) and not inspect.isclass(target):
        return find_pydantic_config_schema_for_callable(target, param_name)

    # Every remaining target is a class or a plain object. The previous
    # `hasattr(target, "__class__")` test was effectively always true, so the
    # trailing `else` branch was unreachable and has been dropped.
    if base_class_to_find is None:
        # Class/proxy inspection cannot proceed without a base class.
        return None
    return find_pydantic_config_schema(target, base_class_to_find, param_name)