From 8f91a7dd9baf8697cb2119fc559dce24586dd11f Mon Sep 17 00:00:00 2001 From: Kye Date: Wed, 24 Apr 2024 18:17:37 -0400 Subject: [PATCH] [CLEANUP] --- example.py | 9 +- playground/agents/complete_agent.py | 134 +++++ auto_docs.py => scripts/auto_docs.py | 0 swarms/models/__init__.py | 2 +- swarms/structs/agent.py | 109 ++-- swarms/structs/yaml_model.py | 10 +- swarms/telemetry/sys_info.py | 3 +- swarms/telemetry/user_utils.py | 2 - swarms/tools/__init__.py | 5 + swarms/tools/openai_func_calling_schema.py | 41 ++ swarms/tools/py_func_to_openai_func_str.py | 547 +++++++++++++++++++++ 11 files changed, 822 insertions(+), 40 deletions(-) create mode 100644 playground/agents/complete_agent.py rename auto_docs.py => scripts/auto_docs.py (100%) create mode 100644 swarms/tools/openai_func_calling_schema.py create mode 100644 swarms/tools/py_func_to_openai_func_str.py diff --git a/example.py b/example.py index 35413816..c6539dbe 100644 --- a/example.py +++ b/example.py @@ -15,7 +15,14 @@ agent = Agent( verbose=True, stopping_token="", interactive=True, + state_save_file_type="json", + saved_state_path="transcript_generator.json", ) # Run the Agent on a task -agent("Generate a transcript for a youtube video on what swarms are!") +# out = agent("Generate a transcript for a youtube video on what swarms are!") +check = agent.save_state( + "transcript_generator.json", + "Generate a transcript for a youtube video on what swarms are!", +) +print(check) diff --git a/playground/agents/complete_agent.py b/playground/agents/complete_agent.py new file mode 100644 index 00000000..03d7a088 --- /dev/null +++ b/playground/agents/complete_agent.py @@ -0,0 +1,134 @@ +from swarms import Agent, Anthropic, tool, ChromaDB +import subprocess +from pydantic import BaseModel + + +# Initilaize the chromadb client +chromadb = ChromaDB( + metric="cosine", + output="results", + docs_folder="docs", +) + + +# Create a schema for the code revision tool +class CodeRevisionSchema(BaseModel): + code: str = None + revision: str = None + + +# iNitialize the schema +tool_schema = CodeRevisionSchema( + code="print('Hello, World!')", + revision="print('What is 2+2')", +) + + +# Model +llm = Anthropic( + temperature=0.1, +) + + +# Tools +@tool +def terminal( + code: str, +): + """ + Run code in the terminal. + + Args: + code (str): The code to run in the terminal. + + Returns: + str: The output of the code. + """ + out = subprocess.run( + code, shell=True, capture_output=True, text=True + ).stdout + return str(out) + + +@tool +def browser(query: str): + """ + Search the query in the browser with the `browser` tool. + + Args: + query (str): The query to search in the browser. + + Returns: + str: The search results. + """ + import webbrowser + + url = f"https://www.google.com/search?q={query}" + webbrowser.open(url) + return f"Searching for {query} in the browser." + + +@tool +def create_file(file_path: str, content: str): + """ + Create a file using the file editor tool. + + Args: + file_path (str): The path to the file. + content (str): The content to write to the file. + + Returns: + str: The result of the file creation operation. + """ + with open(file_path, "w") as file: + file.write(content) + return f"File {file_path} created successfully." + + +@tool +def file_editor(file_path: str, mode: str, content: str): + """ + Edit a file using the file editor tool. + + Args: + file_path (str): The path to the file. + mode (str): The mode to open the file in. + content (str): The content to write to the file. + + Returns: + str: The result of the file editing operation. + """ + with open(file_path, mode) as file: + file.write(content) + return f"File {file_path} edited successfully." + + +# Agent +agent = Agent( + agent_name="Devin", + system_prompt=( + "Autonomous agent that can interact with humans and other" + " agents. Be Helpful and Kind. Use the tools provided to" + " assist the user. Return all code in markdown format." + ), + llm=llm, + max_loops="auto", + autosave=True, + dashboard=False, + streaming_on=True, + verbose=True, + stopping_token="", + interactive=True, + tools=[terminal, browser, file_editor, create_file], + long_term_memory=chromadb, + output_type=tool_schema, # or dict, or str + metadata_output_type="json", + # List of schemas that the agent can handle + list_tool_schemas=[tool_schema], + function_calling_format_type="OpenAI", + function_calling_type="json", # or soon yaml +) + +# Run the agent +out = agent.run("Create a new file for a plan to take over the world.") +print(out) diff --git a/auto_docs.py b/scripts/auto_docs.py similarity index 100% rename from auto_docs.py rename to scripts/auto_docs.py diff --git a/swarms/models/__init__.py b/swarms/models/__init__.py index 40e72831..02e0e7c2 100644 --- a/swarms/models/__init__.py +++ b/swarms/models/__init__.py @@ -31,7 +31,7 @@ from swarms.models.popular_llms import ( ) from swarms.models.popular_llms import OctoAIChat from swarms.models.qwen import QwenVLMultiModal # noqa: E402 - +from swarms.models.popular_llms import ReplicateChat as Replicate from swarms.models.sampling_params import SamplingParams, SamplingType from swarms.models.together import TogetherLLM # noqa: E402 from swarms.models.types import ( # noqa: E402 diff --git a/swarms/structs/agent.py b/swarms/structs/agent.py index 304fa7ad..928d3619 100644 --- a/swarms/structs/agent.py +++ b/swarms/structs/agent.py @@ -33,6 +33,8 @@ from swarms.tools.pydantic_to_json import ( multi_pydantic_to_functions, ) from swarms.structs.schemas import Step, ManySteps +from swarms.telemetry.user_utils import get_user_device_data +from swarms.structs.yaml_model import YamlModel # Utils @@ -228,6 +230,7 @@ class Agent: function_calling_format_type: Optional[str] = "OpenAI", list_tool_schemas: Optional[List[BaseModel]] = None, metadata_output_type: str = "json", + state_save_file_type: str = "json", *args, **kwargs, ): @@ -295,6 +298,7 @@ class Agent: self.function_calling_format_type = function_calling_format_type self.list_tool_schemas = list_tool_schemas self.metadata_output_type = metadata_output_type + self.state_save_file_type = state_save_file_type # The max_loops will be set dynamically if the dynamic_loop if self.dynamic_loops: @@ -903,7 +907,7 @@ class Agent: if self.autosave: logger.info("Autosaving agent state.") - self.save_state(self.saved_state_path) + self.save_state(self.saved_state_path, task) # Apply the cleaner function to the response if self.output_cleaner is not None: @@ -1131,7 +1135,7 @@ class Agent: def graceful_shutdown(self): """Gracefully shutdown the system saving the state""" print(colored("Shutting down the system...", "red")) - return self.save_state("flow_state.json") + return self.save_state(f"{self.agent_name}.json") def run_with_timeout(self, task: str, timeout: int = 60) -> str: """Run the loop but stop if it takes longer than the timeout""" @@ -1236,7 +1240,10 @@ class Agent: except Exception as error: print(colored(f"Error saving agent to YAML: {error}", "red")) - def save_state(self, file_path: str) -> None: + def get_llm_parameters(self): + return str(vars(self.llm)) + + def save_state(self, file_path: str, task: str = None) -> None: """ Saves the current state of the agent to a JSON file, including the llm parameters. @@ -1247,58 +1254,96 @@ class Agent: >>> agent.save_state('saved_flow.json') """ try: - logger.info(f"Saving agent state to: {file_path}") + logger.info( + f"Saving Agent {self.agent_name} state to: {file_path}" + ) state = { "agent_id": str(self.id), "agent_name": self.agent_name, "agent_description": self.agent_description, + "LLM": str(self.get_llm_parameters()), "system_prompt": self.system_prompt, - "sop": self.sop, - "short_memory": ( - self.short_memory.return_history_as_string() - ), + "short_memory": self.short_memory.return_history_as_string(), "loop_interval": self.loop_interval, "retry_attempts": self.retry_attempts, "retry_interval": self.retry_interval, "interactive": self.interactive, "dashboard": self.dashboard, - "dynamic_temperature": (self.dynamic_temperature_enabled), + "dynamic_temperature": self.dynamic_temperature_enabled, "autosave": self.autosave, "saved_state_path": self.saved_state_path, "max_loops": self.max_loops, + "StepCache": self.step_cache, + "Task": task, + "Stopping Token": self.stopping_token, + "Dynamic Loops": self.dynamic_loops, + "tools": self.tools, + "sop": self.sop, + "sop_list": self.sop_list, + "context_length": self.context_length, + "user_name": self.user_name, + "self_healing_enabled": self.self_healing_enabled, + "code_interpreter": self.code_interpreter, + "multi_modal": self.multi_modal, + "pdf_path": self.pdf_path, + "list_of_pdf": self.list_of_pdf, + "tokenizer": self.tokenizer, + "long_term_memory": self.long_term_memory, + "preset_stopping_token": self.preset_stopping_token, + "traceback": self.traceback, + "traceback_handlers": self.traceback_handlers, + "streaming_on": self.streaming_on, + "docs": self.docs, + "docs_folder": self.docs_folder, + "verbose": self.verbose, + "parser": self.parser, + "best_of_n": self.best_of_n, + "callback": self.callback, + "metadata": self.metadata, + "callbacks": self.callbacks, + # "logger_handler": self.logger_handler, + "search_algorithm": self.search_algorithm, + "logs_to_filename": self.logs_to_filename, + "evaluator": self.evaluator, + "output_json": self.output_json, + "stopping_func": self.stopping_func, + "custom_loop_condition": self.custom_loop_condition, + "sentiment_threshold": self.sentiment_threshold, + "custom_exit_command": self.custom_exit_command, + "sentiment_analyzer": self.sentiment_analyzer, + "limit_tokens_from_string": self.limit_tokens_from_string, + # "custom_tools_prompt": self.custom_tools_prompt, + "tool_schema": self.tool_schema, + "output_type": self.output_type, + "function_calling_type": self.function_calling_type, + "output_cleaner": self.output_cleaner, + "function_calling_format_type": self.function_calling_format_type, + "list_tool_schemas": self.list_tool_schemas, + "metadata_output_type": self.metadata_output_type, + "user_meta_data": get_user_device_data(), } - with open(file_path, "w") as f: - json.dump(state, f, indent=4) + # Save as JSON + if self.state_save_file_type == "json": + with open(file_path, "w") as f: + json.dump(state, f, indent=4) + + # Save as YAML + elif self.state_save_file_type == "yaml": + out = YamlModel(input_dict=state).to_yaml() + with open(self.saved_state_path, "w") as f: + f.write(out) + # Log the saved state saved = colored(f"Saved agent state to: {file_path}", "green") print(saved) except Exception as error: print(colored(f"Error saving agent state: {error}", "red")) - def state_to_str(self): + def state_to_str(self, task: str): """Transform the JSON into a string""" try: - state = { - "agent_id": str(self.id), - "agent_name": self.agent_name, - "agent_description": self.agent_description, - "system_prompt": self.system_prompt, - "sop": self.sop, - "short_memory": ( - self.short_memory.return_history_as_string() - ), - "loop_interval": self.loop_interval, - "retry_attempts": self.retry_attempts, - "retry_interval": self.retry_interval, - "interactive": self.interactive, - "dashboard": self.dashboard, - "dynamic_temperature": (self.dynamic_temperature_enabled), - "autosave": self.autosave, - "saved_state_path": self.saved_state_path, - "max_loops": self.max_loops, - } - out = str(state) + out = self.save_state(self.saved_state_path, task) return out except Exception as error: print( diff --git a/swarms/structs/yaml_model.py b/swarms/structs/yaml_model.py index 5e242867..69d6a231 100644 --- a/swarms/structs/yaml_model.py +++ b/swarms/structs/yaml_model.py @@ -1,4 +1,4 @@ -from pydantic import BaseModel +from pydantic import BaseModel, Field import yaml import json from swarms.utils.loguru_logger import logger @@ -120,11 +120,17 @@ class YamlModel(BaseModel): >>> user.save_to_yaml('user.yaml') """ + input_dict: Dict[str, Any] = Field( + None, + title="Data", + description="The data to be serialized to YAML.", + ) + def to_yaml(self): """ Serialize the Pydantic model instance to a YAML string. """ - return yaml.safe_dump(self.dict(), sort_keys=False) + return yaml.safe_dump(self.input_dict, sort_keys=False) def from_yaml(self, cls, yaml_str: str): """ diff --git a/swarms/telemetry/sys_info.py b/swarms/telemetry/sys_info.py index ae59e792..7ddf809e 100644 --- a/swarms/telemetry/sys_info.py +++ b/swarms/telemetry/sys_info.py @@ -88,11 +88,10 @@ def get_package_mismatches(file_path="pyproject.toml"): def system_info(): - swarms_verison = get_swarms_verison() return { "Python Version": get_python_version(), "Pip Version": get_pip_version(), - "Swarms Version": swarms_verison, + # "Swarms Version": swarms_verison, "OS Version and Architecture": get_os_version(), "CPU Info": get_cpu_info(), "RAM Info": get_ram_info(), diff --git a/swarms/telemetry/user_utils.py b/swarms/telemetry/user_utils.py index e38a1648..9da52a4c 100644 --- a/swarms/telemetry/user_utils.py +++ b/swarms/telemetry/user_utils.py @@ -3,7 +3,6 @@ import platform import socket import uuid -from swarms.telemetry.check_update import check_for_package from swarms.telemetry.sys_info import system_info @@ -83,6 +82,5 @@ def get_user_device_data(): "Machine ID": get_machine_id(), "System Info": get_system_info(), "UniqueID": generate_unique_identifier(), - "Swarms [Version]": check_for_package("swarms"), } return data diff --git a/swarms/tools/__init__.py b/swarms/tools/__init__.py index 1c723993..7c243efb 100644 --- a/swarms/tools/__init__.py +++ b/swarms/tools/__init__.py @@ -21,6 +21,10 @@ from swarms.tools.pydantic_to_json import ( function_to_str, functions_to_str, ) +from swarms.tools.openai_func_calling_schema import ( + OpenAIFunctionCallSchema, +) + __all__ = [ "scrape_tool_func_docs", @@ -43,4 +47,5 @@ __all__ = [ "multi_pydantic_to_functions", "function_to_str", "functions_to_str", + "OpenAIFunctionCallSchema", ] diff --git a/swarms/tools/openai_func_calling_schema.py b/swarms/tools/openai_func_calling_schema.py new file mode 100644 index 00000000..ade30143 --- /dev/null +++ b/swarms/tools/openai_func_calling_schema.py @@ -0,0 +1,41 @@ +from pydantic import BaseModel, Field +from typing import List + + +class FunctionSchema(BaseModel): + name: str = Field( + ..., + title="Name", + description="The name of the function.", + ) + description: str = Field( + ..., + title="Description", + description="The description of the function.", + ) + parameters: BaseModel = Field( + ..., + title="Parameters", + description="The parameters of the function.", + ) + + +class OpenAIFunctionCallSchema(BaseModel): + """ + Represents the schema for an OpenAI function call. + + Attributes: + type (str): The type of the function. + function (List[FunctionSchema]): The function to call. + """ + + type: str = Field( + "function", + title="Type", + description="The type of the function.", + ) + function: List[FunctionSchema] = Field( + ..., + title="Function", + description="The function to call.", + ) diff --git a/swarms/tools/py_func_to_openai_func_str.py b/swarms/tools/py_func_to_openai_func_str.py new file mode 100644 index 00000000..fce670fc --- /dev/null +++ b/swarms/tools/py_func_to_openai_func_str.py @@ -0,0 +1,547 @@ +import functools +import inspect +import json +from logging import getLogger +from typing import ( + Any, + Callable, + Dict, + ForwardRef, + List, + Optional, + Set, + Tuple, + Type, + TypeVar, + Union, + get_args, +) + +from pydantic import BaseModel, Field +from pydantic.version import VERSION as PYDANTIC_VERSION +from typing_extensions import Annotated, Literal, get_args, get_origin + +T = TypeVar("T") + +__all__ = ( + "JsonSchemaValue", + "model_dump", + "model_dump_json", + "type2schema", + "evaluate_forwardref", +) + +PYDANTIC_V1 = PYDANTIC_VERSION.startswith("1.") + +logger = getLogger(__name__) + + +if not PYDANTIC_V1: + from pydantic import TypeAdapter + from pydantic._internal._typing_extra import ( + eval_type_lenient as evaluate_forwardref, + ) + from pydantic.json_schema import JsonSchemaValue + + def type2schema(t: Any) -> JsonSchemaValue: + """Convert a type to a JSON schema + + Args: + t (Type): The type to convert + + Returns: + JsonSchemaValue: The JSON schema + """ + return TypeAdapter(t).json_schema() + + def model_dump(model: BaseModel) -> Dict[str, Any]: + """Convert a pydantic model to a dict + + Args: + model (BaseModel): The model to convert + + Returns: + Dict[str, Any]: The dict representation of the model + + """ + return model.model_dump() + + def model_dump_json(model: BaseModel) -> str: + """Convert a pydantic model to a JSON string + + Args: + model (BaseModel): The model to convert + + Returns: + str: The JSON string representation of the model + """ + return model.model_dump_json() + + +# Remove this once we drop support for pydantic 1.x +else: # pragma: no cover + from pydantic import schema_of + from pydantic.typing import ( + evaluate_forwardref as evaluate_forwardref, # type: ignore[no-redef] + ) + + JsonSchemaValue = Dict[str, Any] # type: ignore[misc] + + def type2schema(t: Any) -> JsonSchemaValue: + """Convert a type to a JSON schema + + Args: + t (Type): The type to convert + + Returns: + JsonSchemaValue: The JSON schema + """ + if PYDANTIC_V1: + if t is None: + return {"type": "null"} + elif get_origin(t) is Union: + return {"anyOf": [type2schema(tt) for tt in get_args(t)]} + elif get_origin(t) in [Tuple, tuple]: + prefixItems = [type2schema(tt) for tt in get_args(t)] + return { + "maxItems": len(prefixItems), + "minItems": len(prefixItems), + "prefixItems": prefixItems, + "type": "array", + } + + d = schema_of(t) + if "title" in d: + d.pop("title") + if "description" in d: + d.pop("description") + + return d + + def model_dump(model: BaseModel) -> Dict[str, Any]: + """Convert a pydantic model to a dict + + Args: + model (BaseModel): The model to convert + + Returns: + Dict[str, Any]: The dict representation of the model + + """ + return model.dict() + + def model_dump_json(model: BaseModel) -> str: + """Convert a pydantic model to a JSON string + + Args: + model (BaseModel): The model to convert + + Returns: + str: The JSON string representation of the model + """ + return model.json() + + +def get_typed_annotation(annotation: Any, globalns: Dict[str, Any]) -> Any: + """Get the type annotation of a parameter. + + Args: + annotation: The annotation of the parameter + globalns: The global namespace of the function + + Returns: + The type annotation of the parameter + """ + if isinstance(annotation, str): + annotation = ForwardRef(annotation) + annotation = evaluate_forwardref(annotation, globalns, globalns) + return annotation + + +def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature: + """Get the signature of a function with type annotations. + + Args: + call: The function to get the signature for + + Returns: + The signature of the function with type annotations + """ + signature = inspect.signature(call) + globalns = getattr(call, "__globals__", {}) + typed_params = [ + inspect.Parameter( + name=param.name, + kind=param.kind, + default=param.default, + annotation=get_typed_annotation(param.annotation, globalns), + ) + for param in signature.parameters.values() + ] + typed_signature = inspect.Signature(typed_params) + return typed_signature + + +def get_typed_return_annotation(call: Callable[..., Any]) -> Any: + """Get the return annotation of a function. + + Args: + call: The function to get the return annotation for + + Returns: + The return annotation of the function + """ + signature = inspect.signature(call) + annotation = signature.return_annotation + + if annotation is inspect.Signature.empty: + return None + + globalns = getattr(call, "__globals__", {}) + return get_typed_annotation(annotation, globalns) + + +def get_param_annotations( + typed_signature: inspect.Signature, +) -> Dict[str, Union[Annotated[Type[Any], str], Type[Any]]]: + """Get the type annotations of the parameters of a function + + Args: + typed_signature: The signature of the function with type annotations + + Returns: + A dictionary of the type annotations of the parameters of the function + """ + return { + k: v.annotation + for k, v in typed_signature.parameters.items() + if v.annotation is not inspect.Signature.empty + } + + +class Parameters(BaseModel): + """Parameters of a function as defined by the OpenAI API""" + + type: Literal["object"] = "object" + properties: Dict[str, JsonSchemaValue] + required: List[str] + + +class Function(BaseModel): + """A function as defined by the OpenAI API""" + + description: Annotated[ + str, Field(description="Description of the function") + ] + name: Annotated[str, Field(description="Name of the function")] + parameters: Annotated[ + Parameters, Field(description="Parameters of the function") + ] + + +class ToolFunction(BaseModel): + """A function under tool as defined by the OpenAI API.""" + + type: Literal["function"] = "function" + function: Annotated[Function, Field(description="Function under tool")] + + +def get_parameter_json_schema( + k: str, v: Any, default_values: Dict[str, Any] +) -> JsonSchemaValue: + """Get a JSON schema for a parameter as defined by the OpenAI API + + Args: + k: The name of the parameter + v: The type of the parameter + default_values: The default values of the parameters of the function + + Returns: + A Pydanitc model for the parameter + """ + + def type2description( + k: str, v: Union[Annotated[Type[Any], str], Type[Any]] + ) -> str: + # handles Annotated + if hasattr(v, "__metadata__"): + retval = v.__metadata__[0] + if isinstance(retval, str): + return retval + else: + raise ValueError( + f"Invalid description {retval} for parameter {k}, should be a string." + ) + else: + return k + + schema = type2schema(v) + if k in default_values: + dv = default_values[k] + schema["default"] = dv + + schema["description"] = type2description(k, v) + + return schema + + +def get_required_params(typed_signature: inspect.Signature) -> List[str]: + """Get the required parameters of a function + + Args: + signature: The signature of the function as returned by inspect.signature + + Returns: + A list of the required parameters of the function + """ + return [ + k + for k, v in typed_signature.parameters.items() + if v.default == inspect.Signature.empty + ] + + +def get_default_values( + typed_signature: inspect.Signature, +) -> Dict[str, Any]: + """Get default values of parameters of a function + + Args: + signature: The signature of the function as returned by inspect.signature + + Returns: + A dictionary of the default values of the parameters of the function + """ + return { + k: v.default + for k, v in typed_signature.parameters.items() + if v.default != inspect.Signature.empty + } + + +def get_parameters( + required: List[str], + param_annotations: Dict[ + str, Union[Annotated[Type[Any], str], Type[Any]] + ], + default_values: Dict[str, Any], +) -> Parameters: + """Get the parameters of a function as defined by the OpenAI API + + Args: + required: The required parameters of the function + hints: The type hints of the function as returned by typing.get_type_hints + + Returns: + A Pydantic model for the parameters of the function + """ + return Parameters( + properties={ + k: get_parameter_json_schema(k, v, default_values) + for k, v in param_annotations.items() + if v is not inspect.Signature.empty + }, + required=required, + ) + + +def get_missing_annotations( + typed_signature: inspect.Signature, required: List[str] +) -> Tuple[Set[str], Set[str]]: + """Get the missing annotations of a function + + Ignores the parameters with default values as they are not required to be annotated, but logs a warning. + Args: + typed_signature: The signature of the function with type annotations + required: The required parameters of the function + + Returns: + A set of the missing annotations of the function + """ + all_missing = { + k + for k, v in typed_signature.parameters.items() + if v.annotation is inspect.Signature.empty + } + missing = all_missing.intersection(set(required)) + unannotated_with_default = all_missing.difference(missing) + return missing, unannotated_with_default + + +def get_openai_function_schema( + function: Callable[..., Any], + *, + name: Optional[str] = None, + description: str, +) -> Dict[str, Any]: + """Get a JSON schema for a function as defined by the OpenAI API + + Args: + f: The function to get the JSON schema for + name: The name of the function + description: The description of the function + + Returns: + A JSON schema for the function + + Raises: + TypeError: If the function is not annotated + + Examples: + + ```python + def f(a: Annotated[str, "Parameter a"], b: int = 2, c: Annotated[float, "Parameter c"] = 0.1) -> None: + pass + + get_function_schema(f, description="function f") + + # {'type': 'function', + # 'function': {'description': 'function f', + # 'name': 'f', + # 'parameters': {'type': 'object', + # 'properties': {'a': {'type': 'str', 'description': 'Parameter a'}, + # 'b': {'type': 'int', 'description': 'b'}, + # 'c': {'type': 'float', 'description': 'Parameter c'}}, + # 'required': ['a']}}} + ``` + + """ + typed_signature = get_typed_signature(function) + required = get_required_params(typed_signature) + default_values = get_default_values(typed_signature) + param_annotations = get_param_annotations(typed_signature) + return_annotation = get_typed_return_annotation(function) + missing, unannotated_with_default = get_missing_annotations( + typed_signature, required + ) + + if return_annotation is None: + logger.warning( + f"The return type of the function '{function.__name__}' is not annotated. Although annotating it is " + + "optional, the function should return either a string, a subclass of 'pydantic.BaseModel'." + ) + + if unannotated_with_default != set(): + unannotated_with_default_s = [ + f"'{k}'" for k in sorted(unannotated_with_default) + ] + logger.warning( + f"The following parameters of the function '{function.__name__}' with default values are not annotated: " + + f"{', '.join(unannotated_with_default_s)}." + ) + + if missing != set(): + missing_s = [f"'{k}'" for k in sorted(missing)] + raise TypeError( + f"All parameters of the function '{function.__name__}' without default values must be annotated. " + + f"The annotations are missing for the following parameters: {', '.join(missing_s)}" + ) + + fname = name if name else function.__name__ + + parameters = get_parameters( + required, param_annotations, default_values=default_values + ) + + function = ToolFunction( + function=Function( + description=description, + name=fname, + parameters=parameters, + ) + ) + + return model_dump(function) + +def test(a: int = 1, b: int = 2): + return a + b + +# +def get_load_param_if_needed_function( + t: Any, +) -> Optional[Callable[[Dict[str, Any], Type[BaseModel]], BaseModel]]: + """Get a function to load a parameter if it is a Pydantic model + + Args: + t: The type annotation of the parameter + + Returns: + A function to load the parameter if it is a Pydantic model, otherwise None + + """ + if get_origin(t) is Annotated: + return get_load_param_if_needed_function(get_args(t)[0]) + + def load_base_model( + v: Dict[str, Any], t: Type[BaseModel] + ) -> BaseModel: + return t(**v) + + return ( + load_base_model + if isinstance(t, type) and issubclass(t, BaseModel) + else None + ) + + +def load_basemodels_if_needed( + func: Callable[..., Any] +) -> Callable[..., Any]: + """A decorator to load the parameters of a function if they are Pydantic models + + Args: + func: The function with annotated parameters + + Returns: + A function that loads the parameters before calling the original function + + """ + # get the type annotations of the parameters + typed_signature = get_typed_signature(func) + param_annotations = get_param_annotations(typed_signature) + + # get functions for loading BaseModels when needed based on the type annotations + kwargs_mapping_with_nones = { + k: get_load_param_if_needed_function(t) + for k, t in param_annotations.items() + } + + # remove the None values + kwargs_mapping = { + k: f for k, f in kwargs_mapping_with_nones.items() if f is not None + } + + # a function that loads the parameters before calling the original function + @functools.wraps(func) + def _load_parameters_if_needed(*args: Any, **kwargs: Any) -> Any: + # load the BaseModels if needed + for k, f in kwargs_mapping.items(): + kwargs[k] = f(kwargs[k], param_annotations[k]) + + # call the original function + return func(*args, **kwargs) + + @functools.wraps(func) + async def _a_load_parameters_if_needed( + *args: Any, **kwargs: Any + ) -> Any: + # load the BaseModels if needed + for k, f in kwargs_mapping.items(): + kwargs[k] = f(kwargs[k], param_annotations[k]) + + # call the original function + return await func(*args, **kwargs) + + if inspect.iscoroutinefunction(func): + return _a_load_parameters_if_needed + else: + return _load_parameters_if_needed + + +def serialize_to_str(x: Any) -> str: + if isinstance(x, str): + return x + elif isinstance(x, BaseModel): + return model_dump_json(x) + else: + return json.dumps(x)