From f79f7bbdd827838fbaf5f270a9e28aed327d32f6 Mon Sep 17 00:00:00 2001 From: evelynmitchell Date: Sun, 3 Dec 2023 17:07:26 -0700 Subject: [PATCH] remove pydantic from tools.py --- swarms/tools/tool.py | 76 +++++--------------------------------------- 1 file changed, 8 insertions(+), 68 deletions(-) diff --git a/swarms/tools/tool.py b/swarms/tools/tool.py index ba7752bd..6248d7fd 100644 --- a/swarms/tools/tool.py +++ b/swarms/tools/tool.py @@ -29,13 +29,7 @@ from langchain.callbacks.manager import ( ) from langchain.load.serializable import Serializable -from pydantic import ( - model_validator, BaseModel, - Extra, - Field, - create_model, - validate_arguments, -) + from langchain.schema.runnable import ( Runnable, RunnableConfig, @@ -47,62 +41,9 @@ class SchemaAnnotationError(TypeError): """Raised when 'args_schema' is missing or has an incorrect type annotation.""" -def _create_subset_model( - name: str, model: BaseModel, field_names: list -) -> Type[BaseModel]: - """Create a pydantic model with only a subset of model's fields.""" - fields = {} - for field_name in field_names: - field = model.__fields__[field_name] - fields[field_name] = (field.outer_type_, field.field_info) - return create_model(name, **fields) # type: ignore - - -def _get_filtered_args( - inferred_model: Type[BaseModel], - func: Callable, -) -> dict: - """Get the arguments from a function's signature.""" - schema = inferred_model.schema()["properties"] - valid_keys = signature(func).parameters - return { - k: schema[k] - for k in valid_keys - if k not in ("run_manager", "callbacks") - } - - -class _SchemaConfig: - """Configuration for the pydantic model.""" - extra: Any = Extra.forbid - arbitrary_types_allowed: bool = True -def create_schema_from_function( - model_name: str, - func: Callable, -) -> Type[BaseModel]: - """Create a pydantic schema from a function's signature. - Args: - model_name: Name to assign to the generated pydandic schema - func: Function to generate the schema from - Returns: - A pydantic model with the same arguments as the function - """ - # https://docs.pydantic.dev/latest/usage/validation_decorator/ - validated = validate_arguments(func, config=_SchemaConfig) # type: ignore - inferred_model = validated.model # type: ignore - if "run_manager" in inferred_model.__fields__: - del inferred_model.__fields__["run_manager"] - if "callbacks" in inferred_model.__fields__: - del inferred_model.__fields__["callbacks"] - # Pydantic adds placeholder virtual fields we need to strip - valid_properties = _get_filtered_args(inferred_model, func) - return _create_subset_model( - f"{model_name}Schema", inferred_model, list(valid_properties) - ) - class ToolException(Exception): """An optional exception that tool throws when execution error occurs. @@ -130,7 +71,7 @@ class BaseTool(RunnableSerializable[Union[str, Dict], Any]): if args_schema_type is not None: if ( args_schema_type is None - or args_schema_type == BaseModel + # or args_schema_type == BaseModel ): # Throw errors for common mis-annotations. # TODO: Use get_args / get_origin and fully @@ -167,8 +108,9 @@ class ChildTool(BaseTool): verbose: bool = False """Whether to log the tool's progress.""" - callbacks: Callbacks = Field(default=None, exclude=True) + callbacks: """Callbacks to be called during tool execution.""" + # TODO: I don't know how to remove Field here callback_manager: Optional[BaseCallbackManager] = Field( default=None, exclude=True ) @@ -195,7 +137,7 @@ class ChildTool(BaseTool): # Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-config for more information. class Config(Serializable.Config): """Configuration for this pydantic object.""" - + model_config = {} arbitrary_types_allowed = True @property @@ -215,7 +157,8 @@ class ChildTool(BaseTool): # --- Runnable --- @property - def input_schema(self) -> Type[BaseModel]: + # TODO + def input_schema(self): """The tool's input schema.""" if self.args_schema is not None: return self.args_schema @@ -277,7 +220,6 @@ class ChildTool(BaseTool): } return tool_input - @model_validator() @classmethod def raise_deprecation(cls, values: Dict) -> Dict: """Raise deprecation warning if callback_manager is used.""" @@ -673,9 +615,7 @@ class StructuredTool(BaseTool): """Tool that can operate on any number of inputs.""" description: str = "" - args_schema: Type[BaseModel] = Field( - ..., description="The tool schema." - ) + """The input arguments' schema.""" func: Optional[Callable[..., Any]] """The function to run when the tool is called."""