From 3c3d05fa1c6afc57ca095e52028a29bde95a73b8 Mon Sep 17 00:00:00 2001 From: Kye Date: Thu, 12 Oct 2023 11:55:38 -0400 Subject: [PATCH] tools --- swarms/agents/conversabe_agent.py | 2 +- swarms/agents/meta_prompter.py | 2 +- .../groundingdino/datasets/transforms.py | 2 +- .../models/GroundingDINO/backbone/backbone.py | 4 +- .../backbone/swin_transformer.py | 2 +- .../models/GroundingDINO/bertwarper.py | 10 +- .../models/GroundingDINO/utils.py | 2 +- .../agents/models/groundingdino/util/utils.py | 2 +- .../models/groundingdino/util/vl_utils.py | 2 +- .../segment_anything/modeling/mask_decoder.py | 2 +- .../segment_anything/utils/amg.py | 4 +- swarms/agents/multi_modal_visual_agent.py | 4 +- .../omni_agent/omni_chat.py | 2 +- swarms/embeddings/openai.py | 8 +- swarms/models/chat_openai.py | 4 +- swarms/tools/__init__.py | 1 + swarms/tools/autogpt.py | 2 +- swarms/tools/developer.py | 8 +- swarms/tools/tool.py | 845 ++++++++++++++++++ swarms/tools/tool_registry.py | 45 + swarms/utils/main.py | 2 +- 21 files changed, 923 insertions(+), 32 deletions(-) create mode 100644 swarms/tools/tool.py create mode 100644 swarms/tools/tool_registry.py diff --git a/swarms/agents/conversabe_agent.py b/swarms/agents/conversabe_agent.py index 1ef2f647..35808c4b 100644 --- a/swarms/agents/conversabe_agent.py +++ b/swarms/agents/conversabe_agent.py @@ -977,7 +977,7 @@ class ConversableAgent(Agent): ) elif lang in ["python", "Python"]: if code.startswith("# filename: "): - filename = code[11: code.find("\n")].strip() + filename = code[11 : code.find("\n")].strip() else: filename = None exitcode, logs, image = self.run_code( diff --git a/swarms/agents/meta_prompter.py b/swarms/agents/meta_prompter.py index 10bfc6a1..7488a25c 100644 --- a/swarms/agents/meta_prompter.py +++ b/swarms/agents/meta_prompter.py @@ -108,7 +108,7 @@ class MetaPrompterAgent: def get_new_instructions(self, meta_output): """Get New Instructions from the meta_output""" delimiter = "Instructions: " - new_instructions = meta_output[meta_output.find(delimiter) + len(delimiter):] + new_instructions = meta_output[meta_output.find(delimiter) + len(delimiter) :] return new_instructions def run(self, task: str): diff --git a/swarms/agents/models/groundingdino/datasets/transforms.py b/swarms/agents/models/groundingdino/datasets/transforms.py index 5d6d2cfd..c34a1453 100644 --- a/swarms/agents/models/groundingdino/datasets/transforms.py +++ b/swarms/agents/models/groundingdino/datasets/transforms.py @@ -38,7 +38,7 @@ def crop(image, target, region): if "masks" in target: # FIXME should we update the area here if there are no boxes? - target["masks"] = target["masks"][:, i: i + h, j: j + w] + target["masks"] = target["masks"][:, i : i + h, j : j + w] fields.append("masks") # remove elements for which the boxes or masks that have zero area diff --git a/swarms/agents/models/groundingdino/models/GroundingDINO/backbone/backbone.py b/swarms/agents/models/groundingdino/models/GroundingDINO/backbone/backbone.py index 91e74de4..a56f369e 100644 --- a/swarms/agents/models/groundingdino/models/GroundingDINO/backbone/backbone.py +++ b/swarms/agents/models/groundingdino/models/GroundingDINO/backbone/backbone.py @@ -159,7 +159,7 @@ class Backbone(BackboneBase): ), "Only resnet50 and resnet101 are available." assert return_interm_indices in [[0, 1, 2, 3], [1, 2, 3], [3]] num_channels_all = [256, 512, 1024, 2048] - num_channels = num_channels_all[4 - len(return_interm_indices):] + num_channels = num_channels_all[4 - len(return_interm_indices) :] super().__init__(backbone, train_backbone, num_channels, return_interm_indices) @@ -224,7 +224,7 @@ def build_backbone(args): use_checkpoint=use_checkpoint, ) - bb_num_channels = backbone.num_features[4 - len(return_interm_indices):] + bb_num_channels = backbone.num_features[4 - len(return_interm_indices) :] else: raise NotImplementedError("Unknown backbone {}".format(args.backbone)) diff --git a/swarms/agents/models/groundingdino/models/GroundingDINO/backbone/swin_transformer.py b/swarms/agents/models/groundingdino/models/GroundingDINO/backbone/swin_transformer.py index b476627e..1a74ca36 100644 --- a/swarms/agents/models/groundingdino/models/GroundingDINO/backbone/swin_transformer.py +++ b/swarms/agents/models/groundingdino/models/GroundingDINO/backbone/swin_transformer.py @@ -649,7 +649,7 @@ class SwinTransformer(nn.Module): qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, - drop_path=dpr[sum(depths[:i_layer]): sum(depths[: i_layer + 1])], + drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])], norm_layer=norm_layer, # downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, downsample=downsamplelist[i_layer], diff --git a/swarms/agents/models/groundingdino/models/GroundingDINO/bertwarper.py b/swarms/agents/models/groundingdino/models/GroundingDINO/bertwarper.py index 7a46aa70..2ad9c020 100644 --- a/swarms/agents/models/groundingdino/models/GroundingDINO/bertwarper.py +++ b/swarms/agents/models/groundingdino/models/GroundingDINO/bertwarper.py @@ -221,9 +221,9 @@ def generate_masks_with_special_tokens(tokenized, special_tokens_list, tokenizer position_ids[row, col] = 0 else: attention_mask[ - row, previous_col + 1: col + 1, previous_col + 1: col + 1 + row, previous_col + 1 : col + 1, previous_col + 1 : col + 1 ] = True - position_ids[row, previous_col + 1: col + 1] = torch.arange( + position_ids[row, previous_col + 1 : col + 1] = torch.arange( 0, col - previous_col, device=input_ids.device ) @@ -273,13 +273,13 @@ def generate_masks_with_special_tokens_and_transfer_map( position_ids[row, col] = 0 else: attention_mask[ - row, previous_col + 1: col + 1, previous_col + 1: col + 1 + row, previous_col + 1 : col + 1, previous_col + 1 : col + 1 ] = True - position_ids[row, previous_col + 1: col + 1] = torch.arange( + position_ids[row, previous_col + 1 : col + 1] = torch.arange( 0, col - previous_col, device=input_ids.device ) c2t_maski = torch.zeros((num_token), device=input_ids.device).bool() - c2t_maski[previous_col + 1: col] = True + c2t_maski[previous_col + 1 : col] = True cate_to_token_mask_list[row].append(c2t_maski) previous_col = col diff --git a/swarms/agents/models/groundingdino/models/GroundingDINO/utils.py b/swarms/agents/models/groundingdino/models/GroundingDINO/utils.py index 9488f827..2bb3e9b8 100644 --- a/swarms/agents/models/groundingdino/models/GroundingDINO/utils.py +++ b/swarms/agents/models/groundingdino/models/GroundingDINO/utils.py @@ -76,7 +76,7 @@ def gen_encoder_output_proposals( proposals = [] _cur = 0 for lvl, (H_, W_) in enumerate(spatial_shapes): - mask_flatten_ = memory_padding_mask[:, _cur: (_cur + H_ * W_)].view( + mask_flatten_ = memory_padding_mask[:, _cur : (_cur + H_ * W_)].view( N_, H_, W_, 1 ) valid_H = torch.sum(~mask_flatten_[:, :, 0, 0], 1) diff --git a/swarms/agents/models/groundingdino/util/utils.py b/swarms/agents/models/groundingdino/util/utils.py index 90af343d..7a0815ef 100644 --- a/swarms/agents/models/groundingdino/util/utils.py +++ b/swarms/agents/models/groundingdino/util/utils.py @@ -619,7 +619,7 @@ def get_phrases_from_posmap( ): assert isinstance(posmap, torch.Tensor), "posmap must be torch.Tensor" if posmap.dim() == 1: - posmap[0: left_idx + 1] = False + posmap[0 : left_idx + 1] = False posmap[right_idx:] = False non_zero_idx = posmap.nonzero(as_tuple=True)[0].tolist() token_ids = [tokenized["input_ids"][i] for i in non_zero_idx] diff --git a/swarms/agents/models/groundingdino/util/vl_utils.py b/swarms/agents/models/groundingdino/util/vl_utils.py index 44ff4d5e..4fd8592c 100644 --- a/swarms/agents/models/groundingdino/util/vl_utils.py +++ b/swarms/agents/models/groundingdino/util/vl_utils.py @@ -41,7 +41,7 @@ def create_positive_map_from_span(tokenized, token_span, max_text_len=256): positive_map[j, beg_pos] = 1 break else: - positive_map[j, beg_pos: end_pos + 1].fill_(1) + positive_map[j, beg_pos : end_pos + 1].fill_(1) return positive_map / (positive_map.sum(-1)[:, None] + 1e-6) diff --git a/swarms/agents/models/segment_anything/segment_anything/modeling/mask_decoder.py b/swarms/agents/models/segment_anything/segment_anything/modeling/mask_decoder.py index 35170835..f94bee1f 100644 --- a/swarms/agents/models/segment_anything/segment_anything/modeling/mask_decoder.py +++ b/swarms/agents/models/segment_anything/segment_anything/modeling/mask_decoder.py @@ -139,7 +139,7 @@ class MaskDecoder(nn.Module): # Run the transformer hs, src = self.transformer(src, pos_src, tokens) iou_token_out = hs[:, 0, :] - mask_tokens_out = hs[:, 1: (1 + self.num_mask_tokens), :] + mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :] # Upscale mask embeddings and predict masks using the mask tokens src = src.transpose(1, 2).view(b, c, h, w) diff --git a/swarms/agents/models/segment_anything/segment_anything/utils/amg.py b/swarms/agents/models/segment_anything/segment_anything/utils/amg.py index cb67232a..be064071 100644 --- a/swarms/agents/models/segment_anything/segment_anything/utils/amg.py +++ b/swarms/agents/models/segment_anything/segment_anything/utils/amg.py @@ -101,7 +101,7 @@ def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]: ), "Batched iteration must have inputs of all the same size." n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0) for b in range(n_batches): - yield [arg[b * batch_size: (b + 1) * batch_size] for arg in args] + yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args] def mask_to_rle_pytorch(tensor: torch.Tensor) -> List[Dict[str, Any]]: @@ -142,7 +142,7 @@ def rle_to_mask(rle: Dict[str, Any]) -> np.ndarray: idx = 0 parity = False for count in rle["counts"]: - mask[idx: idx + count] = parity + mask[idx : idx + count] = parity idx += count parity ^= True mask = mask.reshape(w, h) diff --git a/swarms/agents/multi_modal_visual_agent.py b/swarms/agents/multi_modal_visual_agent.py index b0172431..68941ef0 100644 --- a/swarms/agents/multi_modal_visual_agent.py +++ b/swarms/agents/multi_modal_visual_agent.py @@ -207,12 +207,12 @@ def blend_gt2pt(old_image, new_image, sigma=0.15, steps=100): kernel[steps:-steps, :steps] = left kernel[steps:-steps, -steps:] = right - pt_gt_img = easy_img[pos_h: pos_h + old_size[1], pos_w: pos_w + old_size[0]] + pt_gt_img = easy_img[pos_h : pos_h + old_size[1], pos_w : pos_w + old_size[0]] gaussian_gt_img = ( kernel * gt_img_array + (1 - kernel) * pt_gt_img ) # gt img with blur img gaussian_gt_img = gaussian_gt_img.astype(np.int64) - easy_img[pos_h: pos_h + old_size[1], pos_w: pos_w + old_size[0]] = gaussian_gt_img + easy_img[pos_h : pos_h + old_size[1], pos_w : pos_w + old_size[0]] = gaussian_gt_img gaussian_img = Image.fromarray(easy_img) return gaussian_img diff --git a/swarms/agents/multi_modal_workers/omni_agent/omni_chat.py b/swarms/agents/multi_modal_workers/omni_agent/omni_chat.py index 18d87578..2198af25 100644 --- a/swarms/agents/multi_modal_workers/omni_agent/omni_chat.py +++ b/swarms/agents/multi_modal_workers/omni_agent/omni_chat.py @@ -317,7 +317,7 @@ def find_json(s): s = s.replace("'", '"') start = s.find("{") end = s.rfind("}") - res = s[start: end + 1] + res = s[start : end + 1] res = res.replace("\n", "") return res diff --git a/swarms/embeddings/openai.py b/swarms/embeddings/openai.py index 2eba8c71..230dade9 100644 --- a/swarms/embeddings/openai.py +++ b/swarms/embeddings/openai.py @@ -347,7 +347,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings): disallowed_special=self.disallowed_special, ) for j in range(0, len(token), self.embedding_ctx_length): - tokens.append(token[j: j + self.embedding_ctx_length]) + tokens.append(token[j : j + self.embedding_ctx_length]) indices.append(i) batched_embeddings: List[List[float]] = [] @@ -366,7 +366,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings): for i in _iter: response = embed_with_retry( self, - input=tokens[i: i + _chunk_size], + input=tokens[i : i + _chunk_size], **self._invocation_params, ) batched_embeddings.extend(r["embedding"] for r in response["data"]) @@ -428,7 +428,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings): disallowed_special=self.disallowed_special, ) for j in range(0, len(token), self.embedding_ctx_length): - tokens.append(token[j: j + self.embedding_ctx_length]) + tokens.append(token[j : j + self.embedding_ctx_length]) indices.append(i) batched_embeddings: List[List[float]] = [] @@ -436,7 +436,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings): for i in range(0, len(tokens), _chunk_size): response = await async_embed_with_retry( self, - input=tokens[i: i + _chunk_size], + input=tokens[i : i + _chunk_size], **self._invocation_params, ) batched_embeddings.extend(r["embedding"] for r in response["data"]) diff --git a/swarms/models/chat_openai.py b/swarms/models/chat_openai.py index 7ffc9136..380623c3 100644 --- a/swarms/models/chat_openai.py +++ b/swarms/models/chat_openai.py @@ -458,7 +458,7 @@ class BaseOpenAI(BaseLLM): ) params["max_tokens"] = self.max_tokens_for_prompt(prompts[0]) sub_prompts = [ - prompts[i: i + self.batch_size] + prompts[i : i + self.batch_size] for i in range(0, len(prompts), self.batch_size) ] return sub_prompts @@ -469,7 +469,7 @@ class BaseOpenAI(BaseLLM): """Create the LLMResult from the choices and prompts.""" generations = [] for i, _ in enumerate(prompts): - sub_choices = choices[i * self.n: (i + 1) * self.n] + sub_choices = choices[i * self.n : (i + 1) * self.n] generations.append( [ Generation( diff --git a/swarms/tools/__init__.py b/swarms/tools/__init__.py index 54785578..c803f404 100644 --- a/swarms/tools/__init__.py +++ b/swarms/tools/__init__.py @@ -7,3 +7,4 @@ # from swarms.tools.requests import RequestsGet # from swarms.tools.developer import Terminal, CodeEditor +from swarms.tools.tool import tool diff --git a/swarms/tools/autogpt.py b/swarms/tools/autogpt.py index 50732063..1755d259 100644 --- a/swarms/tools/autogpt.py +++ b/swarms/tools/autogpt.py @@ -129,7 +129,7 @@ class WebpageQATool(BaseTool): results = [] # TODO: Handle this with a MapReduceChain for i in range(0, len(web_docs), 4): - input_docs = web_docs[i: i + 4] + input_docs = web_docs[i : i + 4] window_result = self.qa_chain( {"input_documents": input_docs, "question": question}, return_only_outputs=True, diff --git a/swarms/tools/developer.py b/swarms/tools/developer.py index 062f463b..04e4b30a 100644 --- a/swarms/tools/developer.py +++ b/swarms/tools/developer.py @@ -306,7 +306,7 @@ class WriteCommand: @staticmethod def from_str(command: str) -> "WriteCommand": filepath = command.split(WriteCommand.separator)[0] - return WriteCommand(filepath, command[len(filepath) + 1:]) + return WriteCommand(filepath, command[len(filepath) + 1 :]) class CodeWriter: @@ -433,7 +433,7 @@ class ReadCommand: if self.start == self.end: code = code[self.start - 1] else: - code = "".join(code[self.start - 1: self.end]) + code = "".join(code[self.start - 1 : self.end]) return code @staticmethod @@ -590,9 +590,9 @@ class PatchCommand: lines[self.start.line] = ( lines[self.start.line][: self.start.col] + self.content - + lines[self.end.line][self.end.col:] + + lines[self.end.line][self.end.col :] ) - lines = lines[: self.start.line + 1] + lines[self.end.line + 1:] + lines = lines[: self.start.line + 1] + lines[self.end.line + 1 :] after = self.write_lines(lines) diff --git a/swarms/tools/tool.py b/swarms/tools/tool.py new file mode 100644 index 00000000..127328c7 --- /dev/null +++ b/swarms/tools/tool.py @@ -0,0 +1,845 @@ +"""Base implementation for tools or skills.""" +from __future__ import annotations + +import asyncio +import inspect +import warnings +from abc import abstractmethod +from functools import partial +from inspect import signature +from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple, Type, Union + +from langchain.callbacks.base import BaseCallbackManager +from langchain.callbacks.manager import ( + AsyncCallbackManager, + AsyncCallbackManagerForToolRun, + CallbackManager, + CallbackManagerForToolRun, + Callbacks, +) + +from langchain.load.serializable import Serializable +from langchain.pydantic_v1 import ( + BaseModel, + Extra, + Field, + create_model, + root_validator, + validate_arguments, +) +from langchain.schema.runnable import Runnable, RunnableConfig, RunnableSerializable + + +class SchemaAnnotationError(TypeError): + """Raised when 'args_schema' is missing or has an incorrect type annotation.""" + + +def _create_subset_model( + name: str, model: BaseModel, field_names: list +) -> Type[BaseModel]: + """Create a pydantic model with only a subset of model's fields.""" + fields = {} + for field_name in field_names: + field = model.__fields__[field_name] + fields[field_name] = (field.outer_type_, field.field_info) + return create_model(name, **fields) # type: ignore + + +def _get_filtered_args( + inferred_model: Type[BaseModel], + func: Callable, +) -> dict: + """Get the arguments from a function's signature.""" + schema = inferred_model.schema()["properties"] + valid_keys = signature(func).parameters + return {k: schema[k] for k in valid_keys if k not in ("run_manager", "callbacks")} + + +class _SchemaConfig: + """Configuration for the pydantic model.""" + + extra: Any = Extra.forbid + arbitrary_types_allowed: bool = True + + +def create_schema_from_function( + model_name: str, + func: Callable, +) -> Type[BaseModel]: + """Create a pydantic schema from a function's signature. + Args: + model_name: Name to assign to the generated pydandic schema + func: Function to generate the schema from + Returns: + A pydantic model with the same arguments as the function + """ + # https://docs.pydantic.dev/latest/usage/validation_decorator/ + validated = validate_arguments(func, config=_SchemaConfig) # type: ignore + inferred_model = validated.model # type: ignore + if "run_manager" in inferred_model.__fields__: + del inferred_model.__fields__["run_manager"] + if "callbacks" in inferred_model.__fields__: + del inferred_model.__fields__["callbacks"] + # Pydantic adds placeholder virtual fields we need to strip + valid_properties = _get_filtered_args(inferred_model, func) + return _create_subset_model( + f"{model_name}Schema", inferred_model, list(valid_properties) + ) + + +class ToolException(Exception): + """An optional exception that tool throws when execution error occurs. + + When this exception is thrown, the agent will not stop working, + but will handle the exception according to the handle_tool_error + variable of the tool, and the processing result will be returned + to the agent as observation, and printed in red on the console. + """ + + pass + + +class BaseTool(RunnableSerializable[Union[str, Dict], Any]): + """Interface LangChain tools must implement.""" + + def __init_subclass__(cls, **kwargs: Any) -> None: + """Create the definition of the new tool class.""" + super().__init_subclass__(**kwargs) + + args_schema_type = cls.__annotations__.get("args_schema", None) + + if args_schema_type is not None: + if args_schema_type is None or args_schema_type == BaseModel: + # Throw errors for common mis-annotations. + # TODO: Use get_args / get_origin and fully + # specify valid annotations. + typehint_mandate = """ +class ChildTool(BaseTool): + ... + args_schema: Type[BaseModel] = SchemaClass + ...""" + name = cls.__name__ + raise SchemaAnnotationError( + f"Tool definition for {name} must include valid type annotations" + f" for argument 'args_schema' to behave as expected.\n" + f"Expected annotation of 'Type[BaseModel]'" + f" but got '{args_schema_type}'.\n" + f"Expected class looks like:\n" + f"{typehint_mandate}" + ) + + name: str + """The unique name of the tool that clearly communicates its purpose.""" + description: str + """Used to tell the model how/when/why to use the tool. + + You can provide few-shot examples as a part of the description. + """ + args_schema: Optional[Type[BaseModel]] = None + """Pydantic model class to validate and parse the tool's input arguments.""" + return_direct: bool = False + """Whether to return the tool's output directly. Setting this to True means + + that after the tool is called, the AgentExecutor will stop looping. + """ + verbose: bool = False + """Whether to log the tool's progress.""" + + callbacks: Callbacks = Field(default=None, exclude=True) + """Callbacks to be called during tool execution.""" + callback_manager: Optional[BaseCallbackManager] = Field(default=None, exclude=True) + """Deprecated. Please use callbacks instead.""" + tags: Optional[List[str]] = None + """Optional list of tags associated with the tool. Defaults to None + These tags will be associated with each call to this tool, + and passed as arguments to the handlers defined in `callbacks`. + You can use these to eg identify a specific instance of a tool with its use case. + """ + metadata: Optional[Dict[str, Any]] = None + """Optional metadata associated with the tool. Defaults to None + This metadata will be associated with each call to this tool, + and passed as arguments to the handlers defined in `callbacks`. + You can use these to eg identify a specific instance of a tool with its use case. + """ + + handle_tool_error: Optional[ + Union[bool, str, Callable[[ToolException], str]] + ] = False + """Handle the content of the ToolException thrown.""" + + class Config(Serializable.Config): + """Configuration for this pydantic object.""" + + arbitrary_types_allowed = True + + @property + def is_single_input(self) -> bool: + """Whether the tool only accepts a single input.""" + keys = {k for k in self.args if k != "kwargs"} + return len(keys) == 1 + + @property + def args(self) -> dict: + if self.args_schema is not None: + return self.args_schema.schema()["properties"] + else: + schema = create_schema_from_function(self.name, self._run) + return schema.schema()["properties"] + + # --- Runnable --- + + @property + def input_schema(self) -> Type[BaseModel]: + """The tool's input schema.""" + if self.args_schema is not None: + return self.args_schema + else: + return create_schema_from_function(self.name, self._run) + + def invoke( + self, + input: Union[str, Dict], + config: Optional[RunnableConfig] = None, + **kwargs: Any, + ) -> Any: + config = config or {} + return self.run( + input, + callbacks=config.get("callbacks"), + tags=config.get("tags"), + metadata=config.get("metadata"), + run_name=config.get("run_name"), + **kwargs, + ) + + async def ainvoke( + self, + input: Union[str, Dict], + config: Optional[RunnableConfig] = None, + **kwargs: Any, + ) -> Any: + config = config or {} + return await self.arun( + input, + callbacks=config.get("callbacks"), + tags=config.get("tags"), + metadata=config.get("metadata"), + run_name=config.get("run_name"), + **kwargs, + ) + + # --- Tool --- + + def _parse_input( + self, + tool_input: Union[str, Dict], + ) -> Union[str, Dict[str, Any]]: + """Convert tool input to pydantic model.""" + input_args = self.args_schema + if isinstance(tool_input, str): + if input_args is not None: + key_ = next(iter(input_args.__fields__.keys())) + input_args.validate({key_: tool_input}) + return tool_input + else: + if input_args is not None: + result = input_args.parse_obj(tool_input) + return {k: v for k, v in result.dict().items() if k in tool_input} + return tool_input + + @root_validator() + def raise_deprecation(cls, values: Dict) -> Dict: + """Raise deprecation warning if callback_manager is used.""" + if values.get("callback_manager") is not None: + warnings.warn( + "callback_manager is deprecated. Please use callbacks instead.", + DeprecationWarning, + ) + values["callbacks"] = values.pop("callback_manager", None) + return values + + @abstractmethod + def _run( + self, + *args: Any, + **kwargs: Any, + ) -> Any: + """Use the tool. + + Add run_manager: Optional[CallbackManagerForToolRun] = None + to child implementations to enable tracing, + """ + + async def _arun( + self, + *args: Any, + **kwargs: Any, + ) -> Any: + """Use the tool asynchronously. + + Add run_manager: Optional[AsyncCallbackManagerForToolRun] = None + to child implementations to enable tracing, + """ + return await asyncio.get_running_loop().run_in_executor( + None, + partial(self._run, **kwargs), + *args, + ) + + def _to_args_and_kwargs(self, tool_input: Union[str, Dict]) -> Tuple[Tuple, Dict]: + # For backwards compatibility, if run_input is a string, + # pass as a positional argument. + if isinstance(tool_input, str): + return (tool_input,), {} + else: + return (), tool_input + + def run( + self, + tool_input: Union[str, Dict], + verbose: Optional[bool] = None, + start_color: Optional[str] = "green", + color: Optional[str] = "green", + callbacks: Callbacks = None, + *, + tags: Optional[List[str]] = None, + metadata: Optional[Dict[str, Any]] = None, + run_name: Optional[str] = None, + **kwargs: Any, + ) -> Any: + """Run the tool.""" + parsed_input = self._parse_input(tool_input) + if not self.verbose and verbose is not None: + verbose_ = verbose + else: + verbose_ = self.verbose + callback_manager = CallbackManager.configure( + callbacks, + self.callbacks, + verbose_, + tags, + self.tags, + metadata, + self.metadata, + ) + # TODO: maybe also pass through run_manager is _run supports kwargs + new_arg_supported = signature(self._run).parameters.get("run_manager") + run_manager = callback_manager.on_tool_start( + {"name": self.name, "description": self.description}, + tool_input if isinstance(tool_input, str) else str(tool_input), + color=start_color, + name=run_name, + **kwargs, + ) + try: + tool_args, tool_kwargs = self._to_args_and_kwargs(parsed_input) + observation = ( + self._run(*tool_args, run_manager=run_manager, **tool_kwargs) + if new_arg_supported + else self._run(*tool_args, **tool_kwargs) + ) + except ToolException as e: + if not self.handle_tool_error: + run_manager.on_tool_error(e) + raise e + elif isinstance(self.handle_tool_error, bool): + if e.args: + observation = e.args[0] + else: + observation = "Tool execution error" + elif isinstance(self.handle_tool_error, str): + observation = self.handle_tool_error + elif callable(self.handle_tool_error): + observation = self.handle_tool_error(e) + else: + raise ValueError( + f"Got unexpected type of `handle_tool_error`. Expected bool, str " + f"or callable. Received: {self.handle_tool_error}" + ) + run_manager.on_tool_end( + str(observation), color="red", name=self.name, **kwargs + ) + return observation + except (Exception, KeyboardInterrupt) as e: + run_manager.on_tool_error(e) + raise e + else: + run_manager.on_tool_end( + str(observation), color=color, name=self.name, **kwargs + ) + return observation + + async def arun( + self, + tool_input: Union[str, Dict], + verbose: Optional[bool] = None, + start_color: Optional[str] = "green", + color: Optional[str] = "green", + callbacks: Callbacks = None, + *, + tags: Optional[List[str]] = None, + metadata: Optional[Dict[str, Any]] = None, + run_name: Optional[str] = None, + **kwargs: Any, + ) -> Any: + """Run the tool asynchronously.""" + parsed_input = self._parse_input(tool_input) + if not self.verbose and verbose is not None: + verbose_ = verbose + else: + verbose_ = self.verbose + callback_manager = AsyncCallbackManager.configure( + callbacks, + self.callbacks, + verbose_, + tags, + self.tags, + metadata, + self.metadata, + ) + new_arg_supported = signature(self._arun).parameters.get("run_manager") + run_manager = await callback_manager.on_tool_start( + {"name": self.name, "description": self.description}, + tool_input if isinstance(tool_input, str) else str(tool_input), + color=start_color, + name=run_name, + **kwargs, + ) + try: + # We then call the tool on the tool input to get an observation + tool_args, tool_kwargs = self._to_args_and_kwargs(parsed_input) + observation = ( + await self._arun(*tool_args, run_manager=run_manager, **tool_kwargs) + if new_arg_supported + else await self._arun(*tool_args, **tool_kwargs) + ) + except ToolException as e: + if not self.handle_tool_error: + await run_manager.on_tool_error(e) + raise e + elif isinstance(self.handle_tool_error, bool): + if e.args: + observation = e.args[0] + else: + observation = "Tool execution error" + elif isinstance(self.handle_tool_error, str): + observation = self.handle_tool_error + elif callable(self.handle_tool_error): + observation = self.handle_tool_error(e) + else: + raise ValueError( + f"Got unexpected type of `handle_tool_error`. Expected bool, str " + f"or callable. Received: {self.handle_tool_error}" + ) + await run_manager.on_tool_end( + str(observation), color="red", name=self.name, **kwargs + ) + return observation + except (Exception, KeyboardInterrupt) as e: + await run_manager.on_tool_error(e) + raise e + else: + await run_manager.on_tool_end( + str(observation), color=color, name=self.name, **kwargs + ) + return observation + + def __call__(self, tool_input: str, callbacks: Callbacks = None) -> str: + """Make tool callable.""" + return self.run(tool_input, callbacks=callbacks) + + +class Tool(BaseTool): + """Tool that takes in function or coroutine directly.""" + + description: str = "" + func: Optional[Callable[..., str]] + """The function to run when the tool is called.""" + coroutine: Optional[Callable[..., Awaitable[str]]] = None + """The asynchronous version of the function.""" + + # --- Runnable --- + + async def ainvoke( + self, + input: Union[str, Dict], + config: Optional[RunnableConfig] = None, + **kwargs: Any, + ) -> Any: + if not self.coroutine: + # If the tool does not implement async, fall back to default implementation + return await asyncio.get_running_loop().run_in_executor( + None, partial(self.invoke, input, config, **kwargs) + ) + + return await super().ainvoke(input, config, **kwargs) + + # --- Tool --- + + @property + def args(self) -> dict: + """The tool's input arguments.""" + if self.args_schema is not None: + return self.args_schema.schema()["properties"] + # For backwards compatibility, if the function signature is ambiguous, + # assume it takes a single string input. + return {"tool_input": {"type": "string"}} + + def _to_args_and_kwargs(self, tool_input: Union[str, Dict]) -> Tuple[Tuple, Dict]: + """Convert tool input to pydantic model.""" + args, kwargs = super()._to_args_and_kwargs(tool_input) + # For backwards compatibility. The tool must be run with a single input + all_args = list(args) + list(kwargs.values()) + if len(all_args) != 1: + raise ToolException( + f"Too many arguments to single-input tool {self.name}." + f" Args: {all_args}" + ) + return tuple(all_args), {} + + def _run( + self, + *args: Any, + run_manager: Optional[CallbackManagerForToolRun] = None, + **kwargs: Any, + ) -> Any: + """Use the tool.""" + if self.func: + new_argument_supported = signature(self.func).parameters.get("callbacks") + return ( + self.func( + *args, + callbacks=run_manager.get_child() if run_manager else None, + **kwargs, + ) + if new_argument_supported + else self.func(*args, **kwargs) + ) + raise NotImplementedError("Tool does not support sync") + + async def _arun( + self, + *args: Any, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + **kwargs: Any, + ) -> Any: + """Use the tool asynchronously.""" + if self.coroutine: + new_argument_supported = signature(self.coroutine).parameters.get( + "callbacks" + ) + return ( + await self.coroutine( + *args, + callbacks=run_manager.get_child() if run_manager else None, + **kwargs, + ) + if new_argument_supported + else await self.coroutine(*args, **kwargs) + ) + else: + return await asyncio.get_running_loop().run_in_executor( + None, partial(self._run, run_manager=run_manager, **kwargs), *args + ) + + # TODO: this is for backwards compatibility, remove in future + def __init__( + self, name: str, func: Optional[Callable], description: str, **kwargs: Any + ) -> None: + """Initialize tool.""" + super(Tool, self).__init__( + name=name, func=func, description=description, **kwargs + ) + + @classmethod + def from_function( + cls, + func: Optional[Callable], + name: str, # We keep these required to support backwards compatibility + description: str, + return_direct: bool = False, + args_schema: Optional[Type[BaseModel]] = None, + coroutine: Optional[ + Callable[..., Awaitable[Any]] + ] = None, # This is last for compatibility, but should be after func + **kwargs: Any, + ) -> Tool: + """Initialize tool from a function.""" + if func is None and coroutine is None: + raise ValueError("Function and/or coroutine must be provided") + return cls( + name=name, + func=func, + coroutine=coroutine, + description=description, + return_direct=return_direct, + args_schema=args_schema, + **kwargs, + ) + + +class StructuredTool(BaseTool): + """Tool that can operate on any number of inputs.""" + + description: str = "" + args_schema: Type[BaseModel] = Field(..., description="The tool schema.") + """The input arguments' schema.""" + func: Optional[Callable[..., Any]] + """The function to run when the tool is called.""" + coroutine: Optional[Callable[..., Awaitable[Any]]] = None + """The asynchronous version of the function.""" + + # --- Runnable --- + + async def ainvoke( + self, + input: Union[str, Dict], + config: Optional[RunnableConfig] = None, + **kwargs: Any, + ) -> Any: + if not self.coroutine: + # If the tool does not implement async, fall back to default implementation + return await asyncio.get_running_loop().run_in_executor( + None, partial(self.invoke, input, config, **kwargs) + ) + + return await super().ainvoke(input, config, **kwargs) + + # --- Tool --- + + @property + def args(self) -> dict: + """The tool's input arguments.""" + return self.args_schema.schema()["properties"] + + def _run( + self, + *args: Any, + run_manager: Optional[CallbackManagerForToolRun] = None, + **kwargs: Any, + ) -> Any: + """Use the tool.""" + if self.func: + new_argument_supported = signature(self.func).parameters.get("callbacks") + return ( + self.func( + *args, + callbacks=run_manager.get_child() if run_manager else None, + **kwargs, + ) + if new_argument_supported + else self.func(*args, **kwargs) + ) + raise NotImplementedError("Tool does not support sync") + + async def _arun( + self, + *args: Any, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + **kwargs: Any, + ) -> str: + """Use the tool asynchronously.""" + if self.coroutine: + new_argument_supported = signature(self.coroutine).parameters.get( + "callbacks" + ) + return ( + await self.coroutine( + *args, + callbacks=run_manager.get_child() if run_manager else None, + **kwargs, + ) + if new_argument_supported + else await self.coroutine(*args, **kwargs) + ) + return await asyncio.get_running_loop().run_in_executor( + None, + partial(self._run, run_manager=run_manager, **kwargs), + *args, + ) + + @classmethod + def from_function( + cls, + func: Optional[Callable] = None, + coroutine: Optional[Callable[..., Awaitable[Any]]] = None, + name: Optional[str] = None, + description: Optional[str] = None, + return_direct: bool = False, + args_schema: Optional[Type[BaseModel]] = None, + infer_schema: bool = True, + **kwargs: Any, + ) -> StructuredTool: + """Create tool from a given function. + + A classmethod that helps to create a tool from a function. + + Args: + func: The function from which to create a tool + coroutine: The async function from which to create a tool + name: The name of the tool. Defaults to the function name + description: The description of the tool. Defaults to the function docstring + return_direct: Whether to return the result directly or as a callback + args_schema: The schema of the tool's input arguments + infer_schema: Whether to infer the schema from the function's signature + **kwargs: Additional arguments to pass to the tool + + Returns: + The tool + + Examples: + + .. code-block:: python + + def add(a: int, b: int) -> int: + \"\"\"Add two numbers\"\"\" + return a + b + tool = StructuredTool.from_function(add) + tool.run(1, 2) # 3 + """ + + if func is not None: + source_function = func + elif coroutine is not None: + source_function = coroutine + else: + raise ValueError("Function and/or coroutine must be provided") + name = name or source_function.__name__ + description = description or source_function.__doc__ + if description is None: + raise ValueError( + "Function must have a docstring if description not provided." + ) + + # Description example: + # search_api(query: str) - Searches the API for the query. + sig = signature(source_function) + description = f"{name}{sig} - {description.strip()}" + _args_schema = args_schema + if _args_schema is None and infer_schema: + _args_schema = create_schema_from_function(f"{name}Schema", source_function) + return cls( + name=name, + func=func, + coroutine=coroutine, + args_schema=_args_schema, + description=description, + return_direct=return_direct, + **kwargs, + ) + + +def tool( + *args: Union[str, Callable, Runnable], + return_direct: bool = False, + args_schema: Optional[Type[BaseModel]] = None, + infer_schema: bool = True, +) -> Callable: + """Make tools out of functions, can be used with or without arguments. + + Args: + *args: The arguments to the tool. + return_direct: Whether to return directly from the tool rather + than continuing the agent loop. + args_schema: optional argument schema for user to specify + infer_schema: Whether to infer the schema of the arguments from + the function's signature. This also makes the resultant tool + accept a dictionary input to its `run()` function. + + Requires: + - Function must be of type (str) -> str + - Function must have a docstring + + Examples: + .. code-block:: python + + @tool + def search_api(query: str) -> str: + # Searches the API for the query. + return + + @tool("search", return_direct=True) + def search_api(query: str) -> str: + # Searches the API for the query. + return + """ + + def _make_with_name(tool_name: str) -> Callable: + def _make_tool(dec_func: Union[Callable, Runnable]) -> BaseTool: + if isinstance(dec_func, Runnable): + runnable = dec_func + + if runnable.input_schema.schema().get("type") != "object": + raise ValueError("Runnable must have an object schema.") + + async def ainvoke_wrapper( + callbacks: Optional[Callbacks] = None, **kwargs: Any + ) -> Any: + return await runnable.ainvoke(kwargs, {"callbacks": callbacks}) + + def invoke_wrapper( + callbacks: Optional[Callbacks] = None, **kwargs: Any + ) -> Any: + return runnable.invoke(kwargs, {"callbacks": callbacks}) + + coroutine = ainvoke_wrapper + func = invoke_wrapper + schema: Optional[Type[BaseModel]] = runnable.input_schema + description = repr(runnable) + elif inspect.iscoroutinefunction(dec_func): + coroutine = dec_func + func = None + schema = args_schema + description = None + else: + coroutine = None + func = dec_func + schema = args_schema + description = None + + if infer_schema or args_schema is not None: + return StructuredTool.from_function( + func, + coroutine, + name=tool_name, + description=description, + return_direct=return_direct, + args_schema=schema, + infer_schema=infer_schema, + ) + # If someone doesn't want a schema applied, we must treat it as + # a simple string->string function + if func.__doc__ is None: + raise ValueError( + "Function must have a docstring if " + "description not provided and infer_schema is False." + ) + return Tool( + name=tool_name, + func=func, + description=f"{tool_name} tool", + return_direct=return_direct, + coroutine=coroutine, + ) + + return _make_tool + + if len(args) == 2 and isinstance(args[0], str) and isinstance(args[1], Runnable): + return _make_with_name(args[0])(args[1]) + elif len(args) == 1 and isinstance(args[0], str): + # if the argument is a string, then we use the string as the tool name + # Example usage: @tool("search", return_direct=True) + return _make_with_name(args[0]) + elif len(args) == 1 and callable(args[0]): + # if the argument is a function, then we use the function name as the tool name + # Example usage: @tool + return _make_with_name(args[0].__name__)(args[0]) + elif len(args) == 0: + # if there are no arguments, then we use the function name as the tool name + # Example usage: @tool(return_direct=True) + def _partial(func: Callable[[str], str]) -> BaseTool: + return _make_with_name(func.__name__)(func) + + return _partial + else: + raise ValueError("Too many arguments for tool decorator") diff --git a/swarms/tools/tool_registry.py b/swarms/tools/tool_registry.py new file mode 100644 index 00000000..42f3b556 --- /dev/null +++ b/swarms/tools/tool_registry.py @@ -0,0 +1,45 @@ +from .tool import Tool +from typing import Dict, Callable, Any, List + +ToolBuilder = Callable[[Any], Tool] +FuncToolBuilder = Callable[[], ToolBuilder] + + +class ToolsRegistry: + def __init__(self) -> None: + self.tools: Dict[str, FuncToolBuilder] = {} + + def register(self, tool_name: str, tool: FuncToolBuilder): + print(f"will register {tool_name}") + self.tools[tool_name] = tool + + def build(self, tool_name, config) -> Tool: + ret = self.tools[tool_name]()(config) + if isinstance(ret, Tool): + return ret + raise ValueError( + "Tool builder {} did not return a Tool instance".format(tool_name) + ) + + def list_tools(self) -> List[str]: + return list(self.tools.keys()) + + +tools_registry = ToolsRegistry() + + +def register(tool_name): + def decorator(tool: FuncToolBuilder): + tools_registry.register(tool_name, tool) + return tool + + return decorator + + +def build_tool(tool_name: str, config: Any) -> Tool: + print(f"will build {tool_name}") + return tools_registry.build(tool_name, config) + + +def list_tools() -> List[str]: + return tools_registry.list_tools() diff --git a/swarms/utils/main.py b/swarms/utils/main.py index f76c369f..3fa4b2ea 100644 --- a/swarms/utils/main.py +++ b/swarms/utils/main.py @@ -365,7 +365,7 @@ class FileHandler: try: if url.startswith(os.environ.get("SERVER", "http://localhost:8000")): local_filepath = url[ - len(os.environ.get("SERVER", "http://localhost:8000")) + 1: + len(os.environ.get("SERVER", "http://localhost:8000")) + 1 : ] local_filename = Path("file") / local_filepath.split("/")[-1] src = self.path / local_filepath