From 1daa3c2c399ad600ec8baa90a2781b5b4e38362d Mon Sep 17 00:00:00 2001 From: Kye Date: Thu, 3 Aug 2023 10:34:18 -0400 Subject: [PATCH] base tools classes clean up refactorization make it much more seamless --- swarms/agents/tools/autogpt.py | 19 ++--- swarms/agents/tools/base.py | 127 ++++++++++----------------------- 2 files changed, 45 insertions(+), 101 deletions(-) diff --git a/swarms/agents/tools/autogpt.py b/swarms/agents/tools/autogpt.py index 24ee50c4..011bcd8f 100644 --- a/swarms/agents/tools/autogpt.py +++ b/swarms/agents/tools/autogpt.py @@ -1,26 +1,21 @@ -#--------------------------------------> AUTO GPT TOOLS - -# General -import os -import pandas as pd - -from langchain.agents.agent_toolkits.pandas.base import create_pandas_dataframe_agent -from langchain.docstore.document import Document import asyncio +import os # Tools from contextlib import contextmanager from typing import Optional + +import pandas as pd from langchain.agents import tool +from langchain.agents.agent_toolkits.pandas.base import create_pandas_dataframe_agent +from langchain.docstore.document import Document ROOT_DIR = "./data/" -from langchain.tools import BaseTool, DuckDuckGoSearchRun +from langchain.chains.qa_with_sources.loading import BaseCombineDocumentsChain from langchain.text_splitter import RecursiveCharacterTextSplitter - +from langchain.tools import BaseTool, DuckDuckGoSearchRun from pydantic import Field -from langchain.chains.qa_with_sources.loading import BaseCombineDocumentsChain - @contextmanager diff --git a/swarms/agents/tools/base.py b/swarms/agents/tools/base.py index a7e3a7aa..6e307f44 100644 --- a/swarms/agents/tools/base.py +++ b/swarms/agents/tools/base.py @@ -2,21 +2,21 @@ from __future__ import annotations from enum import Enum from abc import ABC, abstractmethod -from typing import Any, Callable, Optional, Type, Union +from typing import Any, Callable, Optional, Type, Union, Tuple from pydantic import BaseModel - from swarms.utils.logger import logger + class ToolScope(Enum): GLOBAL = "global" SESSION = "session" - class ToolException(Exception): pass + class BaseTool(ABC): name: str description: str @@ -25,12 +25,14 @@ class BaseTool(ABC): def run(self, *args: Any, **kwargs: Any) -> Any: pass + @abstractmethod async def arun(self, *args: Any, **kwargs: Any) -> Any: pass def __call__(self, *args: Any, **kwargs: Any) -> Any: return self.run(*args, **kwargs) + class Tool(BaseTool): def __init__(self, name: str, description: str, func: Callable[..., Any]): self.name = name @@ -49,6 +51,7 @@ class Tool(BaseTool): except ToolException as e: raise e + class StructuredTool(BaseTool): def __init__( self, @@ -74,49 +77,9 @@ class StructuredTool(BaseTool): except ToolException as e: raise e -def tool( - name: Optional[str] = None, - description: Optional[str] = None, - args_schema: Optional[Type[BaseModel]] = None, - return_direct: bool = False, - infer_schema: bool = True -) -> Callable: - def decorator(func: Callable[..., Any]) -> Union[Tool, StructuredTool]: - nonlocal name, description - - if name is None: - name = func.__name__ - if description is None: - description = func.__doc__ or "" - - if args_schema or infer_schema: - if args_schema is None: - args_schema = BaseModel - - return StructuredTool(name, description, args_schema, func) - else: - return Tool(name, description, func) - - return decorator - - SessionGetter = Callable[[], Tuple[str, AgentExecutor]] -def tool( - name: str, - description: str, - scope: ToolScope = ToolScope.GLOBAL, -): - def decorator(func): - func.name = name - func.description = description - func.is_tool = True - func.scope = scope - return func - - return decorator - class ToolWrapper: def __init__(self, name: str, description: str, scope: ToolScope, func): @@ -131,52 +94,27 @@ class ToolWrapper: def is_per_session(self) -> bool: return self.scope == ToolScope.SESSION - def to_tool( - self, - get_session: SessionGetter = lambda: [], - ) -> BaseTool: - func = self.func + def to_tool(self, get_session: SessionGetter = lambda: []) -> BaseTool: if self.is_per_session(): - def func(*args, **kwargs): - return self.func(*args, **kwargs, get_session=get_session) + self.func = lambda *args, **kwargs: self.func(*args, **kwargs, get_session=get_session) - return Tool( - name=self.name, - description=self.description, - func=func, - ) + return Tool(name=self.name, description=self.description, func=self.func) class BaseToolSet: def tool_wrappers(cls) -> list[ToolWrapper]: - methods = [ - getattr(cls, m) for m in dir(cls) if hasattr(getattr(cls, m), "is_tool") - ] + methods = [getattr(cls, m) for m in dir(cls) if hasattr(getattr(cls, m), "is_tool")] return [ToolWrapper(m.name, m.description, m.scope, m) for m in methods] - -class ToolsFactory: - @staticmethod - def from_toolset( - toolset: BaseToolSet, - only_global: Optional[bool] = False, - only_per_session: Optional[bool] = False, - get_session: SessionGetter = lambda: [], - ) -> list[BaseTool]: - tools = [] - for wrapper in toolset.tool_wrappers(): - if only_global and not wrapper.is_global(): - continue - if only_per_session and not wrapper.is_per_session(): - continue - tools.append(wrapper.to_tool(get_session=get_session)) - return tools +class ToolCreator(ABC): + @abstractmethod + def create_tools(self, toolsets: list[BaseToolSet]) -> list[BaseTool]: + pass - @staticmethod - def create_global_tools( - toolsets: list[BaseToolSet], - ) -> list[BaseTool]: + +class GlobalToolsCreator(ToolCreator): + def create_tools(self, toolsets: list[BaseToolSet]) -> list[BaseTool]: tools = [] for toolset in toolsets: tools.extend( @@ -187,11 +125,9 @@ class ToolsFactory: ) return tools - @staticmethod - def create_per_session_tools( - toolsets: list[BaseToolSet], - get_session: SessionGetter = lambda: [], - ) -> list[BaseTool]: + +class SessionToolsCreator(ToolCreator): + def create_tools(self, toolsets: list[BaseToolSet], get_session: SessionGetter = lambda: []) -> list[BaseTool]: tools = [] for toolset in toolsets: tools.extend( @@ -203,10 +139,23 @@ class ToolsFactory: ) return tools + +class ToolsFactory: + @staticmethod + def from_toolset(toolset: BaseToolSet, only_global: Optional[bool] = False, only_per_session: Optional[bool] = False, get_session: SessionGetter = lambda: []) -> list[BaseTool]: + tools = [] + for wrapper in toolset.tool_wrappers(): + if only_global and not wrapper.is_global(): + continue + if only_per_session and not wrapper.is_per_session(): + continue + tools.append(wrapper.to_tool(get_session=get_session)) + return tools + + @staticmethod + def create_tools(tool_creator: ToolCreator, toolsets: list[BaseToolSet], get_session: SessionGetter = lambda: []): + return tool_creator.create_tools(toolsets, get_session) + @staticmethod - def create_global_tools_from_names( - toolnames: list[str], - llm: Optional[BaseLLM], - ) -> list[BaseTool]: + def create_global_tools_from_names(toolnames: list[str], llm: Optional[BaseLLM]) -> list[BaseTool]: return load_tools(toolnames, llm=llm) - \ No newline at end of file