import functools import inspect import json from logging import getLogger from typing import ( Any, Callable, Dict, ForwardRef, List, Optional, Set, Tuple, Type, TypeVar, Union, get_args, ) from pydantic import BaseModel, Field from pydantic.version import VERSION as PYDANTIC_VERSION from typing_extensions import Annotated, Literal, get_args, get_origin T = TypeVar("T") __all__ = ( "JsonSchemaValue", "model_dump", "model_dump_json", "type2schema", "evaluate_forwardref", ) PYDANTIC_V1 = PYDANTIC_VERSION.startswith("1.") logger = getLogger(__name__) if not PYDANTIC_V1: from pydantic import TypeAdapter from pydantic._internal._typing_extra import ( eval_type_lenient as evaluate_forwardref, ) from pydantic.json_schema import JsonSchemaValue def type2schema(t: Any) -> JsonSchemaValue: """Convert a type to a JSON schema Args: t (Type): The type to convert Returns: JsonSchemaValue: The JSON schema """ return TypeAdapter(t).json_schema() def model_dump(model: BaseModel) -> Dict[str, Any]: """Convert a pydantic model to a dict Args: model (BaseModel): The model to convert Returns: Dict[str, Any]: The dict representation of the model """ return model.model_dump() def model_dump_json(model: BaseModel) -> str: """Convert a pydantic model to a JSON string Args: model (BaseModel): The model to convert Returns: str: The JSON string representation of the model """ return model.model_dump_json() # Remove this once we drop support for pydantic 1.x else: # pragma: no cover from pydantic import schema_of from pydantic.typing import ( evaluate_forwardref as evaluate_forwardref, # type: ignore[no-redef] ) JsonSchemaValue = Dict[str, Any] # type: ignore[misc] def type2schema(t: Any) -> JsonSchemaValue: """Convert a type to a JSON schema Args: t (Type): The type to convert Returns: JsonSchemaValue: The JSON schema """ if PYDANTIC_V1: if t is None: return {"type": "null"} elif get_origin(t) is Union: return { "anyOf": [type2schema(tt) for tt in get_args(t)] } elif get_origin(t) in [Tuple, tuple]: prefixItems = [type2schema(tt) for tt in get_args(t)] return { "maxItems": len(prefixItems), "minItems": len(prefixItems), "prefixItems": prefixItems, "type": "array", } d = schema_of(t) if "title" in d: d.pop("title") if "description" in d: d.pop("description") return d def model_dump(model: BaseModel) -> Dict[str, Any]: """Convert a pydantic model to a dict Args: model (BaseModel): The model to convert Returns: Dict[str, Any]: The dict representation of the model """ return model.dict() def model_dump_json(model: BaseModel) -> str: """Convert a pydantic model to a JSON string Args: model (BaseModel): The model to convert Returns: str: The JSON string representation of the model """ return model.json() def get_typed_annotation( annotation: Any, globalns: Dict[str, Any] ) -> Any: """Get the type annotation of a parameter. Args: annotation: The annotation of the parameter globalns: The global namespace of the function Returns: The type annotation of the parameter """ if isinstance(annotation, str): annotation = ForwardRef(annotation) annotation = evaluate_forwardref( annotation, globalns, globalns ) return annotation def get_typed_signature( call: Callable[..., Any] ) -> inspect.Signature: """Get the signature of a function with type annotations. Args: call: The function to get the signature for Returns: The signature of the function with type annotations """ signature = inspect.signature(call) globalns = getattr(call, "__globals__", {}) typed_params = [ inspect.Parameter( name=param.name, kind=param.kind, default=param.default, annotation=get_typed_annotation( param.annotation, globalns ), ) for param in signature.parameters.values() ] typed_signature = inspect.Signature(typed_params) return typed_signature def get_typed_return_annotation(call: Callable[..., Any]) -> Any: """Get the return annotation of a function. Args: call: The function to get the return annotation for Returns: The return annotation of the function """ signature = inspect.signature(call) annotation = signature.return_annotation if annotation is inspect.Signature.empty: return None globalns = getattr(call, "__globals__", {}) return get_typed_annotation(annotation, globalns) def get_param_annotations( typed_signature: inspect.Signature, ) -> Dict[str, Union[Annotated[Type[Any], str], Type[Any]]]: """Get the type annotations of the parameters of a function Args: typed_signature: The signature of the function with type annotations Returns: A dictionary of the type annotations of the parameters of the function """ return { k: v.annotation for k, v in typed_signature.parameters.items() if v.annotation is not inspect.Signature.empty } class Parameters(BaseModel): """Parameters of a function as defined by the OpenAI API""" type: Literal["object"] = "object" properties: Dict[str, JsonSchemaValue] required: List[str] class Function(BaseModel): """A function as defined by the OpenAI API""" description: Annotated[ str, Field(description="Description of the function") ] name: Annotated[str, Field(description="Name of the function")] parameters: Annotated[ Parameters, Field(description="Parameters of the function") ] class ToolFunction(BaseModel): """A function under tool as defined by the OpenAI API.""" type: Literal["function"] = "function" function: Annotated[ Function, Field(description="Function under tool") ] def get_parameter_json_schema( k: str, v: Any, default_values: Dict[str, Any] ) -> JsonSchemaValue: """Get a JSON schema for a parameter as defined by the OpenAI API Args: k: The name of the parameter v: The type of the parameter default_values: The default values of the parameters of the function Returns: A Pydanitc model for the parameter """ def type2description( k: str, v: Union[Annotated[Type[Any], str], Type[Any]] ) -> str: # handles Annotated if hasattr(v, "__metadata__"): retval = v.__metadata__[0] if isinstance(retval, str): return retval else: raise ValueError( f"Invalid description {retval} for parameter {k}, should be a string." ) else: return k schema = type2schema(v) if k in default_values: dv = default_values[k] schema["default"] = dv schema["description"] = type2description(k, v) return schema def get_required_params( typed_signature: inspect.Signature, ) -> List[str]: """Get the required parameters of a function Args: signature: The signature of the function as returned by inspect.signature Returns: A list of the required parameters of the function """ return [ k for k, v in typed_signature.parameters.items() if v.default == inspect.Signature.empty ] def get_default_values( typed_signature: inspect.Signature, ) -> Dict[str, Any]: """Get default values of parameters of a function Args: signature: The signature of the function as returned by inspect.signature Returns: A dictionary of the default values of the parameters of the function """ return { k: v.default for k, v in typed_signature.parameters.items() if v.default != inspect.Signature.empty } def get_parameters( required: List[str], param_annotations: Dict[ str, Union[Annotated[Type[Any], str], Type[Any]] ], default_values: Dict[str, Any], ) -> Parameters: """Get the parameters of a function as defined by the OpenAI API Args: required: The required parameters of the function hints: The type hints of the function as returned by typing.get_type_hints Returns: A Pydantic model for the parameters of the function """ return Parameters( properties={ k: get_parameter_json_schema(k, v, default_values) for k, v in param_annotations.items() if v is not inspect.Signature.empty }, required=required, ) def get_missing_annotations( typed_signature: inspect.Signature, required: List[str] ) -> Tuple[Set[str], Set[str]]: """Get the missing annotations of a function Ignores the parameters with default values as they are not required to be annotated, but logs a warning. Args: typed_signature: The signature of the function with type annotations required: The required parameters of the function Returns: A set of the missing annotations of the function """ all_missing = { k for k, v in typed_signature.parameters.items() if v.annotation is inspect.Signature.empty } missing = all_missing.intersection(set(required)) unannotated_with_default = all_missing.difference(missing) return missing, unannotated_with_default def get_openai_function_schema_from_func( function: Callable[..., Any], *, name: Optional[str] = None, description: str = None, ) -> Dict[str, Any]: """Get a JSON schema for a function as defined by the OpenAI API Args: f: The function to get the JSON schema for name: The name of the function description: The description of the function Returns: A JSON schema for the function Raises: TypeError: If the function is not annotated Examples: ```python def f(a: Annotated[str, "Parameter a"], b: int = 2, c: Annotated[float, "Parameter c"] = 0.1) -> None: pass get_function_schema(f, description="function f") # {'type': 'function', # 'function': {'description': 'function f', # 'name': 'f', # 'parameters': {'type': 'object', # 'properties': {'a': {'type': 'str', 'description': 'Parameter a'}, # 'b': {'type': 'int', 'description': 'b'}, # 'c': {'type': 'float', 'description': 'Parameter c'}}, # 'required': ['a']}}} ``` """ typed_signature = get_typed_signature(function) required = get_required_params(typed_signature) default_values = get_default_values(typed_signature) param_annotations = get_param_annotations(typed_signature) return_annotation = get_typed_return_annotation(function) missing, unannotated_with_default = get_missing_annotations( typed_signature, required ) if return_annotation is None: logger.warning( f"The return type of the function '{function.__name__}' is not annotated. Although annotating it is " + "optional, the function should return either a string, a subclass of 'pydantic.BaseModel'." ) if unannotated_with_default != set(): unannotated_with_default_s = [ f"'{k}'" for k in sorted(unannotated_with_default) ] logger.warning( f"The following parameters of the function '{function.__name__}' with default values are not annotated: " + f"{', '.join(unannotated_with_default_s)}." ) if missing != set(): missing_s = [f"'{k}'" for k in sorted(missing)] raise TypeError( f"All parameters of the function '{function.__name__}' without default values must be annotated. " + f"The annotations are missing for the following parameters: {', '.join(missing_s)}" ) fname = name if name else function.__name__ parameters = get_parameters( required, param_annotations, default_values=default_values ) function = ToolFunction( function=Function( description=description, name=fname, parameters=parameters, ) ) return model_dump(function) # def get_load_param_if_needed_function( t: Any, ) -> Optional[Callable[[Dict[str, Any], Type[BaseModel]], BaseModel]]: """Get a function to load a parameter if it is a Pydantic model Args: t: The type annotation of the parameter Returns: A function to load the parameter if it is a Pydantic model, otherwise None """ if get_origin(t) is Annotated: return get_load_param_if_needed_function(get_args(t)[0]) def load_base_model( v: Dict[str, Any], t: Type[BaseModel] ) -> BaseModel: return t(**v) return ( load_base_model if isinstance(t, type) and issubclass(t, BaseModel) else None ) def load_basemodels_if_needed( func: Callable[..., Any] ) -> Callable[..., Any]: """A decorator to load the parameters of a function if they are Pydantic models Args: func: The function with annotated parameters Returns: A function that loads the parameters before calling the original function """ # get the type annotations of the parameters typed_signature = get_typed_signature(func) param_annotations = get_param_annotations(typed_signature) # get functions for loading BaseModels when needed based on the type annotations kwargs_mapping_with_nones = { k: get_load_param_if_needed_function(t) for k, t in param_annotations.items() } # remove the None values kwargs_mapping = { k: f for k, f in kwargs_mapping_with_nones.items() if f is not None } # a function that loads the parameters before calling the original function @functools.wraps(func) def _load_parameters_if_needed(*args: Any, **kwargs: Any) -> Any: # load the BaseModels if needed for k, f in kwargs_mapping.items(): kwargs[k] = f(kwargs[k], param_annotations[k]) # call the original function return func(*args, **kwargs) @functools.wraps(func) async def _a_load_parameters_if_needed( *args: Any, **kwargs: Any ) -> Any: # load the BaseModels if needed for k, f in kwargs_mapping.items(): kwargs[k] = f(kwargs[k], param_annotations[k]) # call the original function return await func(*args, **kwargs) if inspect.iscoroutinefunction(func): return _a_load_parameters_if_needed else: return _load_parameters_if_needed def serialize_to_str(x: Any) -> str: if isinstance(x, str): return x elif isinstance(x, BaseModel): return model_dump_json(x) else: return json.dumps(x)