|
|
|
@ -29,13 +29,7 @@ from langchain.callbacks.manager import (
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
from langchain.load.serializable import Serializable
|
|
|
|
|
from pydantic import (
|
|
|
|
|
model_validator, BaseModel,
|
|
|
|
|
Extra,
|
|
|
|
|
Field,
|
|
|
|
|
create_model,
|
|
|
|
|
validate_arguments,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
from langchain.schema.runnable import (
|
|
|
|
|
Runnable,
|
|
|
|
|
RunnableConfig,
|
|
|
|
@ -47,61 +41,8 @@ class SchemaAnnotationError(TypeError):
|
|
|
|
|
"""Raised when 'args_schema' is missing or has an incorrect type annotation."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _create_subset_model(
|
|
|
|
|
name: str, model: BaseModel, field_names: list
|
|
|
|
|
) -> Type[BaseModel]:
|
|
|
|
|
"""Create a pydantic model with only a subset of model's fields."""
|
|
|
|
|
fields = {}
|
|
|
|
|
for field_name in field_names:
|
|
|
|
|
field = model.__fields__[field_name]
|
|
|
|
|
fields[field_name] = (field.outer_type_, field.field_info)
|
|
|
|
|
return create_model(name, **fields) # type: ignore
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _get_filtered_args(
|
|
|
|
|
inferred_model: Type[BaseModel],
|
|
|
|
|
func: Callable,
|
|
|
|
|
) -> dict:
|
|
|
|
|
"""Get the arguments from a function's signature."""
|
|
|
|
|
schema = inferred_model.schema()["properties"]
|
|
|
|
|
valid_keys = signature(func).parameters
|
|
|
|
|
return {
|
|
|
|
|
k: schema[k]
|
|
|
|
|
for k in valid_keys
|
|
|
|
|
if k not in ("run_manager", "callbacks")
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class _SchemaConfig:
|
|
|
|
|
"""Configuration for the pydantic model."""
|
|
|
|
|
|
|
|
|
|
extra: Any = Extra.forbid
|
|
|
|
|
arbitrary_types_allowed: bool = True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def create_schema_from_function(
|
|
|
|
|
model_name: str,
|
|
|
|
|
func: Callable,
|
|
|
|
|
) -> Type[BaseModel]:
|
|
|
|
|
"""Create a pydantic schema from a function's signature.
|
|
|
|
|
Args:
|
|
|
|
|
model_name: Name to assign to the generated pydandic schema
|
|
|
|
|
func: Function to generate the schema from
|
|
|
|
|
Returns:
|
|
|
|
|
A pydantic model with the same arguments as the function
|
|
|
|
|
"""
|
|
|
|
|
# https://docs.pydantic.dev/latest/usage/validation_decorator/
|
|
|
|
|
validated = validate_arguments(func, config=_SchemaConfig) # type: ignore
|
|
|
|
|
inferred_model = validated.model # type: ignore
|
|
|
|
|
if "run_manager" in inferred_model.__fields__:
|
|
|
|
|
del inferred_model.__fields__["run_manager"]
|
|
|
|
|
if "callbacks" in inferred_model.__fields__:
|
|
|
|
|
del inferred_model.__fields__["callbacks"]
|
|
|
|
|
# Pydantic adds placeholder virtual fields we need to strip
|
|
|
|
|
valid_properties = _get_filtered_args(inferred_model, func)
|
|
|
|
|
return _create_subset_model(
|
|
|
|
|
f"{model_name}Schema", inferred_model, list(valid_properties)
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ToolException(Exception):
|
|
|
|
@ -130,7 +71,7 @@ class BaseTool(RunnableSerializable[Union[str, Dict], Any]):
|
|
|
|
|
if args_schema_type is not None:
|
|
|
|
|
if (
|
|
|
|
|
args_schema_type is None
|
|
|
|
|
or args_schema_type == BaseModel
|
|
|
|
|
# or args_schema_type == BaseModel
|
|
|
|
|
):
|
|
|
|
|
# Throw errors for common mis-annotations.
|
|
|
|
|
# TODO: Use get_args / get_origin and fully
|
|
|
|
@ -167,8 +108,9 @@ class ChildTool(BaseTool):
|
|
|
|
|
verbose: bool = False
|
|
|
|
|
"""Whether to log the tool's progress."""
|
|
|
|
|
|
|
|
|
|
callbacks: Callbacks = Field(default=None, exclude=True)
|
|
|
|
|
callbacks:
|
|
|
|
|
"""Callbacks to be called during tool execution."""
|
|
|
|
|
# TODO: I don't know how to remove Field here
|
|
|
|
|
callback_manager: Optional[BaseCallbackManager] = Field(
|
|
|
|
|
default=None, exclude=True
|
|
|
|
|
)
|
|
|
|
@ -195,7 +137,7 @@ class ChildTool(BaseTool):
|
|
|
|
|
# Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-config for more information.
|
|
|
|
|
class Config(Serializable.Config):
|
|
|
|
|
"""Configuration for this pydantic object."""
|
|
|
|
|
|
|
|
|
|
model_config = {}
|
|
|
|
|
arbitrary_types_allowed = True
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
@ -215,7 +157,8 @@ class ChildTool(BaseTool):
|
|
|
|
|
# --- Runnable ---
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def input_schema(self) -> Type[BaseModel]:
|
|
|
|
|
# TODO
|
|
|
|
|
def input_schema(self):
|
|
|
|
|
"""The tool's input schema."""
|
|
|
|
|
if self.args_schema is not None:
|
|
|
|
|
return self.args_schema
|
|
|
|
@ -277,7 +220,6 @@ class ChildTool(BaseTool):
|
|
|
|
|
}
|
|
|
|
|
return tool_input
|
|
|
|
|
|
|
|
|
|
@model_validator()
|
|
|
|
|
@classmethod
|
|
|
|
|
def raise_deprecation(cls, values: Dict) -> Dict:
|
|
|
|
|
"""Raise deprecation warning if callback_manager is used."""
|
|
|
|
@ -673,9 +615,7 @@ class StructuredTool(BaseTool):
|
|
|
|
|
"""Tool that can operate on any number of inputs."""
|
|
|
|
|
|
|
|
|
|
description: str = ""
|
|
|
|
|
args_schema: Type[BaseModel] = Field(
|
|
|
|
|
..., description="The tool schema."
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
"""The input arguments' schema."""
|
|
|
|
|
func: Optional[Callable[..., Any]]
|
|
|
|
|
"""The function to run when the tool is called."""
|
|
|
|
|