parent
01bfc71eb0
commit
8f91a7dd9b
@ -0,0 +1,134 @@
|
|||||||
|
from swarms import Agent, Anthropic, tool, ChromaDB
|
||||||
|
import subprocess
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
# Initilaize the chromadb client
|
||||||
|
chromadb = ChromaDB(
|
||||||
|
metric="cosine",
|
||||||
|
output="results",
|
||||||
|
docs_folder="docs",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Create a schema for the code revision tool
|
||||||
|
class CodeRevisionSchema(BaseModel):
|
||||||
|
code: str = None
|
||||||
|
revision: str = None
|
||||||
|
|
||||||
|
|
||||||
|
# iNitialize the schema
|
||||||
|
tool_schema = CodeRevisionSchema(
|
||||||
|
code="print('Hello, World!')",
|
||||||
|
revision="print('What is 2+2')",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Model
|
||||||
|
llm = Anthropic(
|
||||||
|
temperature=0.1,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Tools
|
||||||
|
@tool
|
||||||
|
def terminal(
|
||||||
|
code: str,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Run code in the terminal.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
code (str): The code to run in the terminal.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The output of the code.
|
||||||
|
"""
|
||||||
|
out = subprocess.run(
|
||||||
|
code, shell=True, capture_output=True, text=True
|
||||||
|
).stdout
|
||||||
|
return str(out)
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
def browser(query: str):
|
||||||
|
"""
|
||||||
|
Search the query in the browser with the `browser` tool.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query (str): The query to search in the browser.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The search results.
|
||||||
|
"""
|
||||||
|
import webbrowser
|
||||||
|
|
||||||
|
url = f"https://www.google.com/search?q={query}"
|
||||||
|
webbrowser.open(url)
|
||||||
|
return f"Searching for {query} in the browser."
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
def create_file(file_path: str, content: str):
|
||||||
|
"""
|
||||||
|
Create a file using the file editor tool.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_path (str): The path to the file.
|
||||||
|
content (str): The content to write to the file.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The result of the file creation operation.
|
||||||
|
"""
|
||||||
|
with open(file_path, "w") as file:
|
||||||
|
file.write(content)
|
||||||
|
return f"File {file_path} created successfully."
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
def file_editor(file_path: str, mode: str, content: str):
|
||||||
|
"""
|
||||||
|
Edit a file using the file editor tool.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_path (str): The path to the file.
|
||||||
|
mode (str): The mode to open the file in.
|
||||||
|
content (str): The content to write to the file.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The result of the file editing operation.
|
||||||
|
"""
|
||||||
|
with open(file_path, mode) as file:
|
||||||
|
file.write(content)
|
||||||
|
return f"File {file_path} edited successfully."
|
||||||
|
|
||||||
|
|
||||||
|
# Agent
|
||||||
|
agent = Agent(
|
||||||
|
agent_name="Devin",
|
||||||
|
system_prompt=(
|
||||||
|
"Autonomous agent that can interact with humans and other"
|
||||||
|
" agents. Be Helpful and Kind. Use the tools provided to"
|
||||||
|
" assist the user. Return all code in markdown format."
|
||||||
|
),
|
||||||
|
llm=llm,
|
||||||
|
max_loops="auto",
|
||||||
|
autosave=True,
|
||||||
|
dashboard=False,
|
||||||
|
streaming_on=True,
|
||||||
|
verbose=True,
|
||||||
|
stopping_token="<DONE>",
|
||||||
|
interactive=True,
|
||||||
|
tools=[terminal, browser, file_editor, create_file],
|
||||||
|
long_term_memory=chromadb,
|
||||||
|
output_type=tool_schema, # or dict, or str
|
||||||
|
metadata_output_type="json",
|
||||||
|
# List of schemas that the agent can handle
|
||||||
|
list_tool_schemas=[tool_schema],
|
||||||
|
function_calling_format_type="OpenAI",
|
||||||
|
function_calling_type="json", # or soon yaml
|
||||||
|
)
|
||||||
|
|
||||||
|
# Run the agent
|
||||||
|
out = agent.run("Create a new file for a plan to take over the world.")
|
||||||
|
print(out)
|
@ -0,0 +1,41 @@
|
|||||||
|
from pydantic import BaseModel, Field
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
|
||||||
|
class FunctionSchema(BaseModel):
|
||||||
|
name: str = Field(
|
||||||
|
...,
|
||||||
|
title="Name",
|
||||||
|
description="The name of the function.",
|
||||||
|
)
|
||||||
|
description: str = Field(
|
||||||
|
...,
|
||||||
|
title="Description",
|
||||||
|
description="The description of the function.",
|
||||||
|
)
|
||||||
|
parameters: BaseModel = Field(
|
||||||
|
...,
|
||||||
|
title="Parameters",
|
||||||
|
description="The parameters of the function.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAIFunctionCallSchema(BaseModel):
|
||||||
|
"""
|
||||||
|
Represents the schema for an OpenAI function call.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
type (str): The type of the function.
|
||||||
|
function (List[FunctionSchema]): The function to call.
|
||||||
|
"""
|
||||||
|
|
||||||
|
type: str = Field(
|
||||||
|
"function",
|
||||||
|
title="Type",
|
||||||
|
description="The type of the function.",
|
||||||
|
)
|
||||||
|
function: List[FunctionSchema] = Field(
|
||||||
|
...,
|
||||||
|
title="Function",
|
||||||
|
description="The function to call.",
|
||||||
|
)
|
@ -0,0 +1,547 @@
|
|||||||
|
import functools
|
||||||
|
import inspect
|
||||||
|
import json
|
||||||
|
from logging import getLogger
|
||||||
|
from typing import (
|
||||||
|
Any,
|
||||||
|
Callable,
|
||||||
|
Dict,
|
||||||
|
ForwardRef,
|
||||||
|
List,
|
||||||
|
Optional,
|
||||||
|
Set,
|
||||||
|
Tuple,
|
||||||
|
Type,
|
||||||
|
TypeVar,
|
||||||
|
Union,
|
||||||
|
get_args,
|
||||||
|
)
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
from pydantic.version import VERSION as PYDANTIC_VERSION
|
||||||
|
from typing_extensions import Annotated, Literal, get_args, get_origin
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
__all__ = (
|
||||||
|
"JsonSchemaValue",
|
||||||
|
"model_dump",
|
||||||
|
"model_dump_json",
|
||||||
|
"type2schema",
|
||||||
|
"evaluate_forwardref",
|
||||||
|
)
|
||||||
|
|
||||||
|
PYDANTIC_V1 = PYDANTIC_VERSION.startswith("1.")
|
||||||
|
|
||||||
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
if not PYDANTIC_V1:
|
||||||
|
from pydantic import TypeAdapter
|
||||||
|
from pydantic._internal._typing_extra import (
|
||||||
|
eval_type_lenient as evaluate_forwardref,
|
||||||
|
)
|
||||||
|
from pydantic.json_schema import JsonSchemaValue
|
||||||
|
|
||||||
|
def type2schema(t: Any) -> JsonSchemaValue:
|
||||||
|
"""Convert a type to a JSON schema
|
||||||
|
|
||||||
|
Args:
|
||||||
|
t (Type): The type to convert
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
JsonSchemaValue: The JSON schema
|
||||||
|
"""
|
||||||
|
return TypeAdapter(t).json_schema()
|
||||||
|
|
||||||
|
def model_dump(model: BaseModel) -> Dict[str, Any]:
|
||||||
|
"""Convert a pydantic model to a dict
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model (BaseModel): The model to convert
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict[str, Any]: The dict representation of the model
|
||||||
|
|
||||||
|
"""
|
||||||
|
return model.model_dump()
|
||||||
|
|
||||||
|
def model_dump_json(model: BaseModel) -> str:
|
||||||
|
"""Convert a pydantic model to a JSON string
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model (BaseModel): The model to convert
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The JSON string representation of the model
|
||||||
|
"""
|
||||||
|
return model.model_dump_json()
|
||||||
|
|
||||||
|
|
||||||
|
# Remove this once we drop support for pydantic 1.x
|
||||||
|
else: # pragma: no cover
|
||||||
|
from pydantic import schema_of
|
||||||
|
from pydantic.typing import (
|
||||||
|
evaluate_forwardref as evaluate_forwardref, # type: ignore[no-redef]
|
||||||
|
)
|
||||||
|
|
||||||
|
JsonSchemaValue = Dict[str, Any] # type: ignore[misc]
|
||||||
|
|
||||||
|
def type2schema(t: Any) -> JsonSchemaValue:
|
||||||
|
"""Convert a type to a JSON schema
|
||||||
|
|
||||||
|
Args:
|
||||||
|
t (Type): The type to convert
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
JsonSchemaValue: The JSON schema
|
||||||
|
"""
|
||||||
|
if PYDANTIC_V1:
|
||||||
|
if t is None:
|
||||||
|
return {"type": "null"}
|
||||||
|
elif get_origin(t) is Union:
|
||||||
|
return {"anyOf": [type2schema(tt) for tt in get_args(t)]}
|
||||||
|
elif get_origin(t) in [Tuple, tuple]:
|
||||||
|
prefixItems = [type2schema(tt) for tt in get_args(t)]
|
||||||
|
return {
|
||||||
|
"maxItems": len(prefixItems),
|
||||||
|
"minItems": len(prefixItems),
|
||||||
|
"prefixItems": prefixItems,
|
||||||
|
"type": "array",
|
||||||
|
}
|
||||||
|
|
||||||
|
d = schema_of(t)
|
||||||
|
if "title" in d:
|
||||||
|
d.pop("title")
|
||||||
|
if "description" in d:
|
||||||
|
d.pop("description")
|
||||||
|
|
||||||
|
return d
|
||||||
|
|
||||||
|
def model_dump(model: BaseModel) -> Dict[str, Any]:
|
||||||
|
"""Convert a pydantic model to a dict
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model (BaseModel): The model to convert
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict[str, Any]: The dict representation of the model
|
||||||
|
|
||||||
|
"""
|
||||||
|
return model.dict()
|
||||||
|
|
||||||
|
def model_dump_json(model: BaseModel) -> str:
|
||||||
|
"""Convert a pydantic model to a JSON string
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model (BaseModel): The model to convert
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The JSON string representation of the model
|
||||||
|
"""
|
||||||
|
return model.json()
|
||||||
|
|
||||||
|
|
||||||
|
def get_typed_annotation(annotation: Any, globalns: Dict[str, Any]) -> Any:
|
||||||
|
"""Get the type annotation of a parameter.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
annotation: The annotation of the parameter
|
||||||
|
globalns: The global namespace of the function
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The type annotation of the parameter
|
||||||
|
"""
|
||||||
|
if isinstance(annotation, str):
|
||||||
|
annotation = ForwardRef(annotation)
|
||||||
|
annotation = evaluate_forwardref(annotation, globalns, globalns)
|
||||||
|
return annotation
|
||||||
|
|
||||||
|
|
||||||
|
def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature:
|
||||||
|
"""Get the signature of a function with type annotations.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
call: The function to get the signature for
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The signature of the function with type annotations
|
||||||
|
"""
|
||||||
|
signature = inspect.signature(call)
|
||||||
|
globalns = getattr(call, "__globals__", {})
|
||||||
|
typed_params = [
|
||||||
|
inspect.Parameter(
|
||||||
|
name=param.name,
|
||||||
|
kind=param.kind,
|
||||||
|
default=param.default,
|
||||||
|
annotation=get_typed_annotation(param.annotation, globalns),
|
||||||
|
)
|
||||||
|
for param in signature.parameters.values()
|
||||||
|
]
|
||||||
|
typed_signature = inspect.Signature(typed_params)
|
||||||
|
return typed_signature
|
||||||
|
|
||||||
|
|
||||||
|
def get_typed_return_annotation(call: Callable[..., Any]) -> Any:
|
||||||
|
"""Get the return annotation of a function.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
call: The function to get the return annotation for
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The return annotation of the function
|
||||||
|
"""
|
||||||
|
signature = inspect.signature(call)
|
||||||
|
annotation = signature.return_annotation
|
||||||
|
|
||||||
|
if annotation is inspect.Signature.empty:
|
||||||
|
return None
|
||||||
|
|
||||||
|
globalns = getattr(call, "__globals__", {})
|
||||||
|
return get_typed_annotation(annotation, globalns)
|
||||||
|
|
||||||
|
|
||||||
|
def get_param_annotations(
|
||||||
|
typed_signature: inspect.Signature,
|
||||||
|
) -> Dict[str, Union[Annotated[Type[Any], str], Type[Any]]]:
|
||||||
|
"""Get the type annotations of the parameters of a function
|
||||||
|
|
||||||
|
Args:
|
||||||
|
typed_signature: The signature of the function with type annotations
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A dictionary of the type annotations of the parameters of the function
|
||||||
|
"""
|
||||||
|
return {
|
||||||
|
k: v.annotation
|
||||||
|
for k, v in typed_signature.parameters.items()
|
||||||
|
if v.annotation is not inspect.Signature.empty
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class Parameters(BaseModel):
|
||||||
|
"""Parameters of a function as defined by the OpenAI API"""
|
||||||
|
|
||||||
|
type: Literal["object"] = "object"
|
||||||
|
properties: Dict[str, JsonSchemaValue]
|
||||||
|
required: List[str]
|
||||||
|
|
||||||
|
|
||||||
|
class Function(BaseModel):
|
||||||
|
"""A function as defined by the OpenAI API"""
|
||||||
|
|
||||||
|
description: Annotated[
|
||||||
|
str, Field(description="Description of the function")
|
||||||
|
]
|
||||||
|
name: Annotated[str, Field(description="Name of the function")]
|
||||||
|
parameters: Annotated[
|
||||||
|
Parameters, Field(description="Parameters of the function")
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class ToolFunction(BaseModel):
|
||||||
|
"""A function under tool as defined by the OpenAI API."""
|
||||||
|
|
||||||
|
type: Literal["function"] = "function"
|
||||||
|
function: Annotated[Function, Field(description="Function under tool")]
|
||||||
|
|
||||||
|
|
||||||
|
def get_parameter_json_schema(
|
||||||
|
k: str, v: Any, default_values: Dict[str, Any]
|
||||||
|
) -> JsonSchemaValue:
|
||||||
|
"""Get a JSON schema for a parameter as defined by the OpenAI API
|
||||||
|
|
||||||
|
Args:
|
||||||
|
k: The name of the parameter
|
||||||
|
v: The type of the parameter
|
||||||
|
default_values: The default values of the parameters of the function
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A Pydanitc model for the parameter
|
||||||
|
"""
|
||||||
|
|
||||||
|
def type2description(
|
||||||
|
k: str, v: Union[Annotated[Type[Any], str], Type[Any]]
|
||||||
|
) -> str:
|
||||||
|
# handles Annotated
|
||||||
|
if hasattr(v, "__metadata__"):
|
||||||
|
retval = v.__metadata__[0]
|
||||||
|
if isinstance(retval, str):
|
||||||
|
return retval
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Invalid description {retval} for parameter {k}, should be a string."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return k
|
||||||
|
|
||||||
|
schema = type2schema(v)
|
||||||
|
if k in default_values:
|
||||||
|
dv = default_values[k]
|
||||||
|
schema["default"] = dv
|
||||||
|
|
||||||
|
schema["description"] = type2description(k, v)
|
||||||
|
|
||||||
|
return schema
|
||||||
|
|
||||||
|
|
||||||
|
def get_required_params(typed_signature: inspect.Signature) -> List[str]:
|
||||||
|
"""Get the required parameters of a function
|
||||||
|
|
||||||
|
Args:
|
||||||
|
signature: The signature of the function as returned by inspect.signature
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list of the required parameters of the function
|
||||||
|
"""
|
||||||
|
return [
|
||||||
|
k
|
||||||
|
for k, v in typed_signature.parameters.items()
|
||||||
|
if v.default == inspect.Signature.empty
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def get_default_values(
|
||||||
|
typed_signature: inspect.Signature,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""Get default values of parameters of a function
|
||||||
|
|
||||||
|
Args:
|
||||||
|
signature: The signature of the function as returned by inspect.signature
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A dictionary of the default values of the parameters of the function
|
||||||
|
"""
|
||||||
|
return {
|
||||||
|
k: v.default
|
||||||
|
for k, v in typed_signature.parameters.items()
|
||||||
|
if v.default != inspect.Signature.empty
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_parameters(
|
||||||
|
required: List[str],
|
||||||
|
param_annotations: Dict[
|
||||||
|
str, Union[Annotated[Type[Any], str], Type[Any]]
|
||||||
|
],
|
||||||
|
default_values: Dict[str, Any],
|
||||||
|
) -> Parameters:
|
||||||
|
"""Get the parameters of a function as defined by the OpenAI API
|
||||||
|
|
||||||
|
Args:
|
||||||
|
required: The required parameters of the function
|
||||||
|
hints: The type hints of the function as returned by typing.get_type_hints
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A Pydantic model for the parameters of the function
|
||||||
|
"""
|
||||||
|
return Parameters(
|
||||||
|
properties={
|
||||||
|
k: get_parameter_json_schema(k, v, default_values)
|
||||||
|
for k, v in param_annotations.items()
|
||||||
|
if v is not inspect.Signature.empty
|
||||||
|
},
|
||||||
|
required=required,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_missing_annotations(
|
||||||
|
typed_signature: inspect.Signature, required: List[str]
|
||||||
|
) -> Tuple[Set[str], Set[str]]:
|
||||||
|
"""Get the missing annotations of a function
|
||||||
|
|
||||||
|
Ignores the parameters with default values as they are not required to be annotated, but logs a warning.
|
||||||
|
Args:
|
||||||
|
typed_signature: The signature of the function with type annotations
|
||||||
|
required: The required parameters of the function
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A set of the missing annotations of the function
|
||||||
|
"""
|
||||||
|
all_missing = {
|
||||||
|
k
|
||||||
|
for k, v in typed_signature.parameters.items()
|
||||||
|
if v.annotation is inspect.Signature.empty
|
||||||
|
}
|
||||||
|
missing = all_missing.intersection(set(required))
|
||||||
|
unannotated_with_default = all_missing.difference(missing)
|
||||||
|
return missing, unannotated_with_default
|
||||||
|
|
||||||
|
|
||||||
|
def get_openai_function_schema(
|
||||||
|
function: Callable[..., Any],
|
||||||
|
*,
|
||||||
|
name: Optional[str] = None,
|
||||||
|
description: str,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""Get a JSON schema for a function as defined by the OpenAI API
|
||||||
|
|
||||||
|
Args:
|
||||||
|
f: The function to get the JSON schema for
|
||||||
|
name: The name of the function
|
||||||
|
description: The description of the function
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A JSON schema for the function
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TypeError: If the function is not annotated
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
|
||||||
|
```python
|
||||||
|
def f(a: Annotated[str, "Parameter a"], b: int = 2, c: Annotated[float, "Parameter c"] = 0.1) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
get_function_schema(f, description="function f")
|
||||||
|
|
||||||
|
# {'type': 'function',
|
||||||
|
# 'function': {'description': 'function f',
|
||||||
|
# 'name': 'f',
|
||||||
|
# 'parameters': {'type': 'object',
|
||||||
|
# 'properties': {'a': {'type': 'str', 'description': 'Parameter a'},
|
||||||
|
# 'b': {'type': 'int', 'description': 'b'},
|
||||||
|
# 'c': {'type': 'float', 'description': 'Parameter c'}},
|
||||||
|
# 'required': ['a']}}}
|
||||||
|
```
|
||||||
|
|
||||||
|
"""
|
||||||
|
typed_signature = get_typed_signature(function)
|
||||||
|
required = get_required_params(typed_signature)
|
||||||
|
default_values = get_default_values(typed_signature)
|
||||||
|
param_annotations = get_param_annotations(typed_signature)
|
||||||
|
return_annotation = get_typed_return_annotation(function)
|
||||||
|
missing, unannotated_with_default = get_missing_annotations(
|
||||||
|
typed_signature, required
|
||||||
|
)
|
||||||
|
|
||||||
|
if return_annotation is None:
|
||||||
|
logger.warning(
|
||||||
|
f"The return type of the function '{function.__name__}' is not annotated. Although annotating it is "
|
||||||
|
+ "optional, the function should return either a string, a subclass of 'pydantic.BaseModel'."
|
||||||
|
)
|
||||||
|
|
||||||
|
if unannotated_with_default != set():
|
||||||
|
unannotated_with_default_s = [
|
||||||
|
f"'{k}'" for k in sorted(unannotated_with_default)
|
||||||
|
]
|
||||||
|
logger.warning(
|
||||||
|
f"The following parameters of the function '{function.__name__}' with default values are not annotated: "
|
||||||
|
+ f"{', '.join(unannotated_with_default_s)}."
|
||||||
|
)
|
||||||
|
|
||||||
|
if missing != set():
|
||||||
|
missing_s = [f"'{k}'" for k in sorted(missing)]
|
||||||
|
raise TypeError(
|
||||||
|
f"All parameters of the function '{function.__name__}' without default values must be annotated. "
|
||||||
|
+ f"The annotations are missing for the following parameters: {', '.join(missing_s)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
fname = name if name else function.__name__
|
||||||
|
|
||||||
|
parameters = get_parameters(
|
||||||
|
required, param_annotations, default_values=default_values
|
||||||
|
)
|
||||||
|
|
||||||
|
function = ToolFunction(
|
||||||
|
function=Function(
|
||||||
|
description=description,
|
||||||
|
name=fname,
|
||||||
|
parameters=parameters,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return model_dump(function)
|
||||||
|
|
||||||
|
def test(a: int = 1, b: int = 2):
|
||||||
|
return a + b
|
||||||
|
|
||||||
|
#
|
||||||
|
def get_load_param_if_needed_function(
|
||||||
|
t: Any,
|
||||||
|
) -> Optional[Callable[[Dict[str, Any], Type[BaseModel]], BaseModel]]:
|
||||||
|
"""Get a function to load a parameter if it is a Pydantic model
|
||||||
|
|
||||||
|
Args:
|
||||||
|
t: The type annotation of the parameter
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A function to load the parameter if it is a Pydantic model, otherwise None
|
||||||
|
|
||||||
|
"""
|
||||||
|
if get_origin(t) is Annotated:
|
||||||
|
return get_load_param_if_needed_function(get_args(t)[0])
|
||||||
|
|
||||||
|
def load_base_model(
|
||||||
|
v: Dict[str, Any], t: Type[BaseModel]
|
||||||
|
) -> BaseModel:
|
||||||
|
return t(**v)
|
||||||
|
|
||||||
|
return (
|
||||||
|
load_base_model
|
||||||
|
if isinstance(t, type) and issubclass(t, BaseModel)
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def load_basemodels_if_needed(
|
||||||
|
func: Callable[..., Any]
|
||||||
|
) -> Callable[..., Any]:
|
||||||
|
"""A decorator to load the parameters of a function if they are Pydantic models
|
||||||
|
|
||||||
|
Args:
|
||||||
|
func: The function with annotated parameters
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A function that loads the parameters before calling the original function
|
||||||
|
|
||||||
|
"""
|
||||||
|
# get the type annotations of the parameters
|
||||||
|
typed_signature = get_typed_signature(func)
|
||||||
|
param_annotations = get_param_annotations(typed_signature)
|
||||||
|
|
||||||
|
# get functions for loading BaseModels when needed based on the type annotations
|
||||||
|
kwargs_mapping_with_nones = {
|
||||||
|
k: get_load_param_if_needed_function(t)
|
||||||
|
for k, t in param_annotations.items()
|
||||||
|
}
|
||||||
|
|
||||||
|
# remove the None values
|
||||||
|
kwargs_mapping = {
|
||||||
|
k: f for k, f in kwargs_mapping_with_nones.items() if f is not None
|
||||||
|
}
|
||||||
|
|
||||||
|
# a function that loads the parameters before calling the original function
|
||||||
|
@functools.wraps(func)
|
||||||
|
def _load_parameters_if_needed(*args: Any, **kwargs: Any) -> Any:
|
||||||
|
# load the BaseModels if needed
|
||||||
|
for k, f in kwargs_mapping.items():
|
||||||
|
kwargs[k] = f(kwargs[k], param_annotations[k])
|
||||||
|
|
||||||
|
# call the original function
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
|
||||||
|
@functools.wraps(func)
|
||||||
|
async def _a_load_parameters_if_needed(
|
||||||
|
*args: Any, **kwargs: Any
|
||||||
|
) -> Any:
|
||||||
|
# load the BaseModels if needed
|
||||||
|
for k, f in kwargs_mapping.items():
|
||||||
|
kwargs[k] = f(kwargs[k], param_annotations[k])
|
||||||
|
|
||||||
|
# call the original function
|
||||||
|
return await func(*args, **kwargs)
|
||||||
|
|
||||||
|
if inspect.iscoroutinefunction(func):
|
||||||
|
return _a_load_parameters_if_needed
|
||||||
|
else:
|
||||||
|
return _load_parameters_if_needed
|
||||||
|
|
||||||
|
|
||||||
|
def serialize_to_str(x: Any) -> str:
|
||||||
|
if isinstance(x, str):
|
||||||
|
return x
|
||||||
|
elif isinstance(x, BaseModel):
|
||||||
|
return model_dump_json(x)
|
||||||
|
else:
|
||||||
|
return json.dumps(x)
|
Loading…
Reference in new issue