pull/450/head
Kye 9 months ago
parent 01bfc71eb0
commit 8f91a7dd9b

@ -15,7 +15,14 @@ agent = Agent(
verbose=True, verbose=True,
stopping_token="<DONE>", stopping_token="<DONE>",
interactive=True, interactive=True,
state_save_file_type="json",
saved_state_path="transcript_generator.json",
) )
# Run the Agent on a task # Run the Agent on a task
agent("Generate a transcript for a youtube video on what swarms are!") # out = agent("Generate a transcript for a youtube video on what swarms are!")
check = agent.save_state(
"transcript_generator.json",
"Generate a transcript for a youtube video on what swarms are!",
)
print(check)

@ -0,0 +1,134 @@
from swarms import Agent, Anthropic, tool, ChromaDB
import subprocess
from pydantic import BaseModel
# Initilaize the chromadb client
chromadb = ChromaDB(
metric="cosine",
output="results",
docs_folder="docs",
)
# Create a schema for the code revision tool
class CodeRevisionSchema(BaseModel):
code: str = None
revision: str = None
# iNitialize the schema
tool_schema = CodeRevisionSchema(
code="print('Hello, World!')",
revision="print('What is 2+2')",
)
# Model
llm = Anthropic(
temperature=0.1,
)
# Tools
@tool
def terminal(
code: str,
):
"""
Run code in the terminal.
Args:
code (str): The code to run in the terminal.
Returns:
str: The output of the code.
"""
out = subprocess.run(
code, shell=True, capture_output=True, text=True
).stdout
return str(out)
@tool
def browser(query: str):
"""
Search the query in the browser with the `browser` tool.
Args:
query (str): The query to search in the browser.
Returns:
str: The search results.
"""
import webbrowser
url = f"https://www.google.com/search?q={query}"
webbrowser.open(url)
return f"Searching for {query} in the browser."
@tool
def create_file(file_path: str, content: str):
"""
Create a file using the file editor tool.
Args:
file_path (str): The path to the file.
content (str): The content to write to the file.
Returns:
str: The result of the file creation operation.
"""
with open(file_path, "w") as file:
file.write(content)
return f"File {file_path} created successfully."
@tool
def file_editor(file_path: str, mode: str, content: str):
"""
Edit a file using the file editor tool.
Args:
file_path (str): The path to the file.
mode (str): The mode to open the file in.
content (str): The content to write to the file.
Returns:
str: The result of the file editing operation.
"""
with open(file_path, mode) as file:
file.write(content)
return f"File {file_path} edited successfully."
# Agent
agent = Agent(
agent_name="Devin",
system_prompt=(
"Autonomous agent that can interact with humans and other"
" agents. Be Helpful and Kind. Use the tools provided to"
" assist the user. Return all code in markdown format."
),
llm=llm,
max_loops="auto",
autosave=True,
dashboard=False,
streaming_on=True,
verbose=True,
stopping_token="<DONE>",
interactive=True,
tools=[terminal, browser, file_editor, create_file],
long_term_memory=chromadb,
output_type=tool_schema, # or dict, or str
metadata_output_type="json",
# List of schemas that the agent can handle
list_tool_schemas=[tool_schema],
function_calling_format_type="OpenAI",
function_calling_type="json", # or soon yaml
)
# Run the agent
out = agent.run("Create a new file for a plan to take over the world.")
print(out)

@ -31,7 +31,7 @@ from swarms.models.popular_llms import (
) )
from swarms.models.popular_llms import OctoAIChat from swarms.models.popular_llms import OctoAIChat
from swarms.models.qwen import QwenVLMultiModal # noqa: E402 from swarms.models.qwen import QwenVLMultiModal # noqa: E402
from swarms.models.popular_llms import ReplicateChat as Replicate
from swarms.models.sampling_params import SamplingParams, SamplingType from swarms.models.sampling_params import SamplingParams, SamplingType
from swarms.models.together import TogetherLLM # noqa: E402 from swarms.models.together import TogetherLLM # noqa: E402
from swarms.models.types import ( # noqa: E402 from swarms.models.types import ( # noqa: E402

@ -33,6 +33,8 @@ from swarms.tools.pydantic_to_json import (
multi_pydantic_to_functions, multi_pydantic_to_functions,
) )
from swarms.structs.schemas import Step, ManySteps from swarms.structs.schemas import Step, ManySteps
from swarms.telemetry.user_utils import get_user_device_data
from swarms.structs.yaml_model import YamlModel
# Utils # Utils
@ -228,6 +230,7 @@ class Agent:
function_calling_format_type: Optional[str] = "OpenAI", function_calling_format_type: Optional[str] = "OpenAI",
list_tool_schemas: Optional[List[BaseModel]] = None, list_tool_schemas: Optional[List[BaseModel]] = None,
metadata_output_type: str = "json", metadata_output_type: str = "json",
state_save_file_type: str = "json",
*args, *args,
**kwargs, **kwargs,
): ):
@ -295,6 +298,7 @@ class Agent:
self.function_calling_format_type = function_calling_format_type self.function_calling_format_type = function_calling_format_type
self.list_tool_schemas = list_tool_schemas self.list_tool_schemas = list_tool_schemas
self.metadata_output_type = metadata_output_type self.metadata_output_type = metadata_output_type
self.state_save_file_type = state_save_file_type
# The max_loops will be set dynamically if the dynamic_loop # The max_loops will be set dynamically if the dynamic_loop
if self.dynamic_loops: if self.dynamic_loops:
@ -903,7 +907,7 @@ class Agent:
if self.autosave: if self.autosave:
logger.info("Autosaving agent state.") logger.info("Autosaving agent state.")
self.save_state(self.saved_state_path) self.save_state(self.saved_state_path, task)
# Apply the cleaner function to the response # Apply the cleaner function to the response
if self.output_cleaner is not None: if self.output_cleaner is not None:
@ -1131,7 +1135,7 @@ class Agent:
def graceful_shutdown(self): def graceful_shutdown(self):
"""Gracefully shutdown the system saving the state""" """Gracefully shutdown the system saving the state"""
print(colored("Shutting down the system...", "red")) print(colored("Shutting down the system...", "red"))
return self.save_state("flow_state.json") return self.save_state(f"{self.agent_name}.json")
def run_with_timeout(self, task: str, timeout: int = 60) -> str: def run_with_timeout(self, task: str, timeout: int = 60) -> str:
"""Run the loop but stop if it takes longer than the timeout""" """Run the loop but stop if it takes longer than the timeout"""
@ -1236,7 +1240,10 @@ class Agent:
except Exception as error: except Exception as error:
print(colored(f"Error saving agent to YAML: {error}", "red")) print(colored(f"Error saving agent to YAML: {error}", "red"))
def save_state(self, file_path: str) -> None: def get_llm_parameters(self):
return str(vars(self.llm))
def save_state(self, file_path: str, task: str = None) -> None:
""" """
Saves the current state of the agent to a JSON file, including the llm parameters. Saves the current state of the agent to a JSON file, including the llm parameters.
@ -1247,58 +1254,96 @@ class Agent:
>>> agent.save_state('saved_flow.json') >>> agent.save_state('saved_flow.json')
""" """
try: try:
logger.info(f"Saving agent state to: {file_path}") logger.info(
f"Saving Agent {self.agent_name} state to: {file_path}"
)
state = { state = {
"agent_id": str(self.id), "agent_id": str(self.id),
"agent_name": self.agent_name, "agent_name": self.agent_name,
"agent_description": self.agent_description, "agent_description": self.agent_description,
"LLM": str(self.get_llm_parameters()),
"system_prompt": self.system_prompt, "system_prompt": self.system_prompt,
"sop": self.sop, "short_memory": self.short_memory.return_history_as_string(),
"short_memory": (
self.short_memory.return_history_as_string()
),
"loop_interval": self.loop_interval, "loop_interval": self.loop_interval,
"retry_attempts": self.retry_attempts, "retry_attempts": self.retry_attempts,
"retry_interval": self.retry_interval, "retry_interval": self.retry_interval,
"interactive": self.interactive, "interactive": self.interactive,
"dashboard": self.dashboard, "dashboard": self.dashboard,
"dynamic_temperature": (self.dynamic_temperature_enabled), "dynamic_temperature": self.dynamic_temperature_enabled,
"autosave": self.autosave, "autosave": self.autosave,
"saved_state_path": self.saved_state_path, "saved_state_path": self.saved_state_path,
"max_loops": self.max_loops, "max_loops": self.max_loops,
"StepCache": self.step_cache,
"Task": task,
"Stopping Token": self.stopping_token,
"Dynamic Loops": self.dynamic_loops,
"tools": self.tools,
"sop": self.sop,
"sop_list": self.sop_list,
"context_length": self.context_length,
"user_name": self.user_name,
"self_healing_enabled": self.self_healing_enabled,
"code_interpreter": self.code_interpreter,
"multi_modal": self.multi_modal,
"pdf_path": self.pdf_path,
"list_of_pdf": self.list_of_pdf,
"tokenizer": self.tokenizer,
"long_term_memory": self.long_term_memory,
"preset_stopping_token": self.preset_stopping_token,
"traceback": self.traceback,
"traceback_handlers": self.traceback_handlers,
"streaming_on": self.streaming_on,
"docs": self.docs,
"docs_folder": self.docs_folder,
"verbose": self.verbose,
"parser": self.parser,
"best_of_n": self.best_of_n,
"callback": self.callback,
"metadata": self.metadata,
"callbacks": self.callbacks,
# "logger_handler": self.logger_handler,
"search_algorithm": self.search_algorithm,
"logs_to_filename": self.logs_to_filename,
"evaluator": self.evaluator,
"output_json": self.output_json,
"stopping_func": self.stopping_func,
"custom_loop_condition": self.custom_loop_condition,
"sentiment_threshold": self.sentiment_threshold,
"custom_exit_command": self.custom_exit_command,
"sentiment_analyzer": self.sentiment_analyzer,
"limit_tokens_from_string": self.limit_tokens_from_string,
# "custom_tools_prompt": self.custom_tools_prompt,
"tool_schema": self.tool_schema,
"output_type": self.output_type,
"function_calling_type": self.function_calling_type,
"output_cleaner": self.output_cleaner,
"function_calling_format_type": self.function_calling_format_type,
"list_tool_schemas": self.list_tool_schemas,
"metadata_output_type": self.metadata_output_type,
"user_meta_data": get_user_device_data(),
} }
# Save as JSON
if self.state_save_file_type == "json":
with open(file_path, "w") as f: with open(file_path, "w") as f:
json.dump(state, f, indent=4) json.dump(state, f, indent=4)
# Save as YAML
elif self.state_save_file_type == "yaml":
out = YamlModel(input_dict=state).to_yaml()
with open(self.saved_state_path, "w") as f:
f.write(out)
# Log the saved state
saved = colored(f"Saved agent state to: {file_path}", "green") saved = colored(f"Saved agent state to: {file_path}", "green")
print(saved) print(saved)
except Exception as error: except Exception as error:
print(colored(f"Error saving agent state: {error}", "red")) print(colored(f"Error saving agent state: {error}", "red"))
def state_to_str(self): def state_to_str(self, task: str):
"""Transform the JSON into a string""" """Transform the JSON into a string"""
try: try:
state = { out = self.save_state(self.saved_state_path, task)
"agent_id": str(self.id),
"agent_name": self.agent_name,
"agent_description": self.agent_description,
"system_prompt": self.system_prompt,
"sop": self.sop,
"short_memory": (
self.short_memory.return_history_as_string()
),
"loop_interval": self.loop_interval,
"retry_attempts": self.retry_attempts,
"retry_interval": self.retry_interval,
"interactive": self.interactive,
"dashboard": self.dashboard,
"dynamic_temperature": (self.dynamic_temperature_enabled),
"autosave": self.autosave,
"saved_state_path": self.saved_state_path,
"max_loops": self.max_loops,
}
out = str(state)
return out return out
except Exception as error: except Exception as error:
print( print(

@ -1,4 +1,4 @@
from pydantic import BaseModel from pydantic import BaseModel, Field
import yaml import yaml
import json import json
from swarms.utils.loguru_logger import logger from swarms.utils.loguru_logger import logger
@ -120,11 +120,17 @@ class YamlModel(BaseModel):
>>> user.save_to_yaml('user.yaml') >>> user.save_to_yaml('user.yaml')
""" """
input_dict: Dict[str, Any] = Field(
None,
title="Data",
description="The data to be serialized to YAML.",
)
def to_yaml(self): def to_yaml(self):
""" """
Serialize the Pydantic model instance to a YAML string. Serialize the Pydantic model instance to a YAML string.
""" """
return yaml.safe_dump(self.dict(), sort_keys=False) return yaml.safe_dump(self.input_dict, sort_keys=False)
def from_yaml(self, cls, yaml_str: str): def from_yaml(self, cls, yaml_str: str):
""" """

@ -88,11 +88,10 @@ def get_package_mismatches(file_path="pyproject.toml"):
def system_info(): def system_info():
swarms_verison = get_swarms_verison()
return { return {
"Python Version": get_python_version(), "Python Version": get_python_version(),
"Pip Version": get_pip_version(), "Pip Version": get_pip_version(),
"Swarms Version": swarms_verison, # "Swarms Version": swarms_verison,
"OS Version and Architecture": get_os_version(), "OS Version and Architecture": get_os_version(),
"CPU Info": get_cpu_info(), "CPU Info": get_cpu_info(),
"RAM Info": get_ram_info(), "RAM Info": get_ram_info(),

@ -3,7 +3,6 @@ import platform
import socket import socket
import uuid import uuid
from swarms.telemetry.check_update import check_for_package
from swarms.telemetry.sys_info import system_info from swarms.telemetry.sys_info import system_info
@ -83,6 +82,5 @@ def get_user_device_data():
"Machine ID": get_machine_id(), "Machine ID": get_machine_id(),
"System Info": get_system_info(), "System Info": get_system_info(),
"UniqueID": generate_unique_identifier(), "UniqueID": generate_unique_identifier(),
"Swarms [Version]": check_for_package("swarms"),
} }
return data return data

@ -21,6 +21,10 @@ from swarms.tools.pydantic_to_json import (
function_to_str, function_to_str,
functions_to_str, functions_to_str,
) )
from swarms.tools.openai_func_calling_schema import (
OpenAIFunctionCallSchema,
)
__all__ = [ __all__ = [
"scrape_tool_func_docs", "scrape_tool_func_docs",
@ -43,4 +47,5 @@ __all__ = [
"multi_pydantic_to_functions", "multi_pydantic_to_functions",
"function_to_str", "function_to_str",
"functions_to_str", "functions_to_str",
"OpenAIFunctionCallSchema",
] ]

@ -0,0 +1,41 @@
from pydantic import BaseModel, Field
from typing import List
class FunctionSchema(BaseModel):
name: str = Field(
...,
title="Name",
description="The name of the function.",
)
description: str = Field(
...,
title="Description",
description="The description of the function.",
)
parameters: BaseModel = Field(
...,
title="Parameters",
description="The parameters of the function.",
)
class OpenAIFunctionCallSchema(BaseModel):
"""
Represents the schema for an OpenAI function call.
Attributes:
type (str): The type of the function.
function (List[FunctionSchema]): The function to call.
"""
type: str = Field(
"function",
title="Type",
description="The type of the function.",
)
function: List[FunctionSchema] = Field(
...,
title="Function",
description="The function to call.",
)

@ -0,0 +1,547 @@
import functools
import inspect
import json
from logging import getLogger
from typing import (
Any,
Callable,
Dict,
ForwardRef,
List,
Optional,
Set,
Tuple,
Type,
TypeVar,
Union,
get_args,
)
from pydantic import BaseModel, Field
from pydantic.version import VERSION as PYDANTIC_VERSION
from typing_extensions import Annotated, Literal, get_args, get_origin
T = TypeVar("T")
__all__ = (
"JsonSchemaValue",
"model_dump",
"model_dump_json",
"type2schema",
"evaluate_forwardref",
)
PYDANTIC_V1 = PYDANTIC_VERSION.startswith("1.")
logger = getLogger(__name__)
if not PYDANTIC_V1:
from pydantic import TypeAdapter
from pydantic._internal._typing_extra import (
eval_type_lenient as evaluate_forwardref,
)
from pydantic.json_schema import JsonSchemaValue
def type2schema(t: Any) -> JsonSchemaValue:
"""Convert a type to a JSON schema
Args:
t (Type): The type to convert
Returns:
JsonSchemaValue: The JSON schema
"""
return TypeAdapter(t).json_schema()
def model_dump(model: BaseModel) -> Dict[str, Any]:
"""Convert a pydantic model to a dict
Args:
model (BaseModel): The model to convert
Returns:
Dict[str, Any]: The dict representation of the model
"""
return model.model_dump()
def model_dump_json(model: BaseModel) -> str:
"""Convert a pydantic model to a JSON string
Args:
model (BaseModel): The model to convert
Returns:
str: The JSON string representation of the model
"""
return model.model_dump_json()
# Remove this once we drop support for pydantic 1.x
else: # pragma: no cover
from pydantic import schema_of
from pydantic.typing import (
evaluate_forwardref as evaluate_forwardref, # type: ignore[no-redef]
)
JsonSchemaValue = Dict[str, Any] # type: ignore[misc]
def type2schema(t: Any) -> JsonSchemaValue:
"""Convert a type to a JSON schema
Args:
t (Type): The type to convert
Returns:
JsonSchemaValue: The JSON schema
"""
if PYDANTIC_V1:
if t is None:
return {"type": "null"}
elif get_origin(t) is Union:
return {"anyOf": [type2schema(tt) for tt in get_args(t)]}
elif get_origin(t) in [Tuple, tuple]:
prefixItems = [type2schema(tt) for tt in get_args(t)]
return {
"maxItems": len(prefixItems),
"minItems": len(prefixItems),
"prefixItems": prefixItems,
"type": "array",
}
d = schema_of(t)
if "title" in d:
d.pop("title")
if "description" in d:
d.pop("description")
return d
def model_dump(model: BaseModel) -> Dict[str, Any]:
"""Convert a pydantic model to a dict
Args:
model (BaseModel): The model to convert
Returns:
Dict[str, Any]: The dict representation of the model
"""
return model.dict()
def model_dump_json(model: BaseModel) -> str:
"""Convert a pydantic model to a JSON string
Args:
model (BaseModel): The model to convert
Returns:
str: The JSON string representation of the model
"""
return model.json()
def get_typed_annotation(annotation: Any, globalns: Dict[str, Any]) -> Any:
"""Get the type annotation of a parameter.
Args:
annotation: The annotation of the parameter
globalns: The global namespace of the function
Returns:
The type annotation of the parameter
"""
if isinstance(annotation, str):
annotation = ForwardRef(annotation)
annotation = evaluate_forwardref(annotation, globalns, globalns)
return annotation
def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature:
"""Get the signature of a function with type annotations.
Args:
call: The function to get the signature for
Returns:
The signature of the function with type annotations
"""
signature = inspect.signature(call)
globalns = getattr(call, "__globals__", {})
typed_params = [
inspect.Parameter(
name=param.name,
kind=param.kind,
default=param.default,
annotation=get_typed_annotation(param.annotation, globalns),
)
for param in signature.parameters.values()
]
typed_signature = inspect.Signature(typed_params)
return typed_signature
def get_typed_return_annotation(call: Callable[..., Any]) -> Any:
"""Get the return annotation of a function.
Args:
call: The function to get the return annotation for
Returns:
The return annotation of the function
"""
signature = inspect.signature(call)
annotation = signature.return_annotation
if annotation is inspect.Signature.empty:
return None
globalns = getattr(call, "__globals__", {})
return get_typed_annotation(annotation, globalns)
def get_param_annotations(
typed_signature: inspect.Signature,
) -> Dict[str, Union[Annotated[Type[Any], str], Type[Any]]]:
"""Get the type annotations of the parameters of a function
Args:
typed_signature: The signature of the function with type annotations
Returns:
A dictionary of the type annotations of the parameters of the function
"""
return {
k: v.annotation
for k, v in typed_signature.parameters.items()
if v.annotation is not inspect.Signature.empty
}
class Parameters(BaseModel):
"""Parameters of a function as defined by the OpenAI API"""
type: Literal["object"] = "object"
properties: Dict[str, JsonSchemaValue]
required: List[str]
class Function(BaseModel):
"""A function as defined by the OpenAI API"""
description: Annotated[
str, Field(description="Description of the function")
]
name: Annotated[str, Field(description="Name of the function")]
parameters: Annotated[
Parameters, Field(description="Parameters of the function")
]
class ToolFunction(BaseModel):
"""A function under tool as defined by the OpenAI API."""
type: Literal["function"] = "function"
function: Annotated[Function, Field(description="Function under tool")]
def get_parameter_json_schema(
k: str, v: Any, default_values: Dict[str, Any]
) -> JsonSchemaValue:
"""Get a JSON schema for a parameter as defined by the OpenAI API
Args:
k: The name of the parameter
v: The type of the parameter
default_values: The default values of the parameters of the function
Returns:
A Pydanitc model for the parameter
"""
def type2description(
k: str, v: Union[Annotated[Type[Any], str], Type[Any]]
) -> str:
# handles Annotated
if hasattr(v, "__metadata__"):
retval = v.__metadata__[0]
if isinstance(retval, str):
return retval
else:
raise ValueError(
f"Invalid description {retval} for parameter {k}, should be a string."
)
else:
return k
schema = type2schema(v)
if k in default_values:
dv = default_values[k]
schema["default"] = dv
schema["description"] = type2description(k, v)
return schema
def get_required_params(typed_signature: inspect.Signature) -> List[str]:
"""Get the required parameters of a function
Args:
signature: The signature of the function as returned by inspect.signature
Returns:
A list of the required parameters of the function
"""
return [
k
for k, v in typed_signature.parameters.items()
if v.default == inspect.Signature.empty
]
def get_default_values(
typed_signature: inspect.Signature,
) -> Dict[str, Any]:
"""Get default values of parameters of a function
Args:
signature: The signature of the function as returned by inspect.signature
Returns:
A dictionary of the default values of the parameters of the function
"""
return {
k: v.default
for k, v in typed_signature.parameters.items()
if v.default != inspect.Signature.empty
}
def get_parameters(
required: List[str],
param_annotations: Dict[
str, Union[Annotated[Type[Any], str], Type[Any]]
],
default_values: Dict[str, Any],
) -> Parameters:
"""Get the parameters of a function as defined by the OpenAI API
Args:
required: The required parameters of the function
hints: The type hints of the function as returned by typing.get_type_hints
Returns:
A Pydantic model for the parameters of the function
"""
return Parameters(
properties={
k: get_parameter_json_schema(k, v, default_values)
for k, v in param_annotations.items()
if v is not inspect.Signature.empty
},
required=required,
)
def get_missing_annotations(
typed_signature: inspect.Signature, required: List[str]
) -> Tuple[Set[str], Set[str]]:
"""Get the missing annotations of a function
Ignores the parameters with default values as they are not required to be annotated, but logs a warning.
Args:
typed_signature: The signature of the function with type annotations
required: The required parameters of the function
Returns:
A set of the missing annotations of the function
"""
all_missing = {
k
for k, v in typed_signature.parameters.items()
if v.annotation is inspect.Signature.empty
}
missing = all_missing.intersection(set(required))
unannotated_with_default = all_missing.difference(missing)
return missing, unannotated_with_default
def get_openai_function_schema(
function: Callable[..., Any],
*,
name: Optional[str] = None,
description: str,
) -> Dict[str, Any]:
"""Get a JSON schema for a function as defined by the OpenAI API
Args:
f: The function to get the JSON schema for
name: The name of the function
description: The description of the function
Returns:
A JSON schema for the function
Raises:
TypeError: If the function is not annotated
Examples:
```python
def f(a: Annotated[str, "Parameter a"], b: int = 2, c: Annotated[float, "Parameter c"] = 0.1) -> None:
pass
get_function_schema(f, description="function f")
# {'type': 'function',
# 'function': {'description': 'function f',
# 'name': 'f',
# 'parameters': {'type': 'object',
# 'properties': {'a': {'type': 'str', 'description': 'Parameter a'},
# 'b': {'type': 'int', 'description': 'b'},
# 'c': {'type': 'float', 'description': 'Parameter c'}},
# 'required': ['a']}}}
```
"""
typed_signature = get_typed_signature(function)
required = get_required_params(typed_signature)
default_values = get_default_values(typed_signature)
param_annotations = get_param_annotations(typed_signature)
return_annotation = get_typed_return_annotation(function)
missing, unannotated_with_default = get_missing_annotations(
typed_signature, required
)
if return_annotation is None:
logger.warning(
f"The return type of the function '{function.__name__}' is not annotated. Although annotating it is "
+ "optional, the function should return either a string, a subclass of 'pydantic.BaseModel'."
)
if unannotated_with_default != set():
unannotated_with_default_s = [
f"'{k}'" for k in sorted(unannotated_with_default)
]
logger.warning(
f"The following parameters of the function '{function.__name__}' with default values are not annotated: "
+ f"{', '.join(unannotated_with_default_s)}."
)
if missing != set():
missing_s = [f"'{k}'" for k in sorted(missing)]
raise TypeError(
f"All parameters of the function '{function.__name__}' without default values must be annotated. "
+ f"The annotations are missing for the following parameters: {', '.join(missing_s)}"
)
fname = name if name else function.__name__
parameters = get_parameters(
required, param_annotations, default_values=default_values
)
function = ToolFunction(
function=Function(
description=description,
name=fname,
parameters=parameters,
)
)
return model_dump(function)
def test(a: int = 1, b: int = 2):
return a + b
#
def get_load_param_if_needed_function(
t: Any,
) -> Optional[Callable[[Dict[str, Any], Type[BaseModel]], BaseModel]]:
"""Get a function to load a parameter if it is a Pydantic model
Args:
t: The type annotation of the parameter
Returns:
A function to load the parameter if it is a Pydantic model, otherwise None
"""
if get_origin(t) is Annotated:
return get_load_param_if_needed_function(get_args(t)[0])
def load_base_model(
v: Dict[str, Any], t: Type[BaseModel]
) -> BaseModel:
return t(**v)
return (
load_base_model
if isinstance(t, type) and issubclass(t, BaseModel)
else None
)
def load_basemodels_if_needed(
func: Callable[..., Any]
) -> Callable[..., Any]:
"""A decorator to load the parameters of a function if they are Pydantic models
Args:
func: The function with annotated parameters
Returns:
A function that loads the parameters before calling the original function
"""
# get the type annotations of the parameters
typed_signature = get_typed_signature(func)
param_annotations = get_param_annotations(typed_signature)
# get functions for loading BaseModels when needed based on the type annotations
kwargs_mapping_with_nones = {
k: get_load_param_if_needed_function(t)
for k, t in param_annotations.items()
}
# remove the None values
kwargs_mapping = {
k: f for k, f in kwargs_mapping_with_nones.items() if f is not None
}
# a function that loads the parameters before calling the original function
@functools.wraps(func)
def _load_parameters_if_needed(*args: Any, **kwargs: Any) -> Any:
# load the BaseModels if needed
for k, f in kwargs_mapping.items():
kwargs[k] = f(kwargs[k], param_annotations[k])
# call the original function
return func(*args, **kwargs)
@functools.wraps(func)
async def _a_load_parameters_if_needed(
*args: Any, **kwargs: Any
) -> Any:
# load the BaseModels if needed
for k, f in kwargs_mapping.items():
kwargs[k] = f(kwargs[k], param_annotations[k])
# call the original function
return await func(*args, **kwargs)
if inspect.iscoroutinefunction(func):
return _a_load_parameters_if_needed
else:
return _load_parameters_if_needed
def serialize_to_str(x: Any) -> str:
if isinstance(x, str):
return x
elif isinstance(x, BaseModel):
return model_dump_json(x)
else:
return json.dumps(x)
Loading…
Cancel
Save