remove pydantic from tools.py

pull/250/head
evelynmitchell 1 year ago
parent cfb08bb11a
commit f79f7bbdd8

@ -29,13 +29,7 @@ from langchain.callbacks.manager import (
) )
from langchain.load.serializable import Serializable from langchain.load.serializable import Serializable
from pydantic import (
model_validator, BaseModel,
Extra,
Field,
create_model,
validate_arguments,
)
from langchain.schema.runnable import ( from langchain.schema.runnable import (
Runnable, Runnable,
RunnableConfig, RunnableConfig,
@ -47,61 +41,8 @@ class SchemaAnnotationError(TypeError):
"""Raised when 'args_schema' is missing or has an incorrect type annotation.""" """Raised when 'args_schema' is missing or has an incorrect type annotation."""
def _create_subset_model(
name: str, model: BaseModel, field_names: list
) -> Type[BaseModel]:
"""Create a pydantic model with only a subset of model's fields."""
fields = {}
for field_name in field_names:
field = model.__fields__[field_name]
fields[field_name] = (field.outer_type_, field.field_info)
return create_model(name, **fields) # type: ignore
def _get_filtered_args(
inferred_model: Type[BaseModel],
func: Callable,
) -> dict:
"""Get the arguments from a function's signature."""
schema = inferred_model.schema()["properties"]
valid_keys = signature(func).parameters
return {
k: schema[k]
for k in valid_keys
if k not in ("run_manager", "callbacks")
}
class _SchemaConfig:
"""Configuration for the pydantic model."""
extra: Any = Extra.forbid
arbitrary_types_allowed: bool = True
def create_schema_from_function(
model_name: str,
func: Callable,
) -> Type[BaseModel]:
"""Create a pydantic schema from a function's signature.
Args:
model_name: Name to assign to the generated pydandic schema
func: Function to generate the schema from
Returns:
A pydantic model with the same arguments as the function
"""
# https://docs.pydantic.dev/latest/usage/validation_decorator/
validated = validate_arguments(func, config=_SchemaConfig) # type: ignore
inferred_model = validated.model # type: ignore
if "run_manager" in inferred_model.__fields__:
del inferred_model.__fields__["run_manager"]
if "callbacks" in inferred_model.__fields__:
del inferred_model.__fields__["callbacks"]
# Pydantic adds placeholder virtual fields we need to strip
valid_properties = _get_filtered_args(inferred_model, func)
return _create_subset_model(
f"{model_name}Schema", inferred_model, list(valid_properties)
)
class ToolException(Exception): class ToolException(Exception):
@ -130,7 +71,7 @@ class BaseTool(RunnableSerializable[Union[str, Dict], Any]):
if args_schema_type is not None: if args_schema_type is not None:
if ( if (
args_schema_type is None args_schema_type is None
or args_schema_type == BaseModel # or args_schema_type == BaseModel
): ):
# Throw errors for common mis-annotations. # Throw errors for common mis-annotations.
# TODO: Use get_args / get_origin and fully # TODO: Use get_args / get_origin and fully
@ -167,8 +108,9 @@ class ChildTool(BaseTool):
verbose: bool = False verbose: bool = False
"""Whether to log the tool's progress.""" """Whether to log the tool's progress."""
callbacks: Callbacks = Field(default=None, exclude=True) callbacks:
"""Callbacks to be called during tool execution.""" """Callbacks to be called during tool execution."""
# TODO: I don't know how to remove Field here
callback_manager: Optional[BaseCallbackManager] = Field( callback_manager: Optional[BaseCallbackManager] = Field(
default=None, exclude=True default=None, exclude=True
) )
@ -195,7 +137,7 @@ class ChildTool(BaseTool):
# Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-config for more information. # Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-config for more information.
class Config(Serializable.Config): class Config(Serializable.Config):
"""Configuration for this pydantic object.""" """Configuration for this pydantic object."""
model_config = {}
arbitrary_types_allowed = True arbitrary_types_allowed = True
@property @property
@ -215,7 +157,8 @@ class ChildTool(BaseTool):
# --- Runnable --- # --- Runnable ---
@property @property
def input_schema(self) -> Type[BaseModel]: # TODO
def input_schema(self):
"""The tool's input schema.""" """The tool's input schema."""
if self.args_schema is not None: if self.args_schema is not None:
return self.args_schema return self.args_schema
@ -277,7 +220,6 @@ class ChildTool(BaseTool):
} }
return tool_input return tool_input
@model_validator()
@classmethod @classmethod
def raise_deprecation(cls, values: Dict) -> Dict: def raise_deprecation(cls, values: Dict) -> Dict:
"""Raise deprecation warning if callback_manager is used.""" """Raise deprecation warning if callback_manager is used."""
@ -673,9 +615,7 @@ class StructuredTool(BaseTool):
"""Tool that can operate on any number of inputs.""" """Tool that can operate on any number of inputs."""
description: str = "" description: str = ""
args_schema: Type[BaseModel] = Field(
..., description="The tool schema."
)
"""The input arguments' schema.""" """The input arguments' schema."""
func: Optional[Callable[..., Any]] func: Optional[Callable[..., Any]]
"""The function to run when the tool is called.""" """The function to run when the tool is called."""

Loading…
Cancel
Save