base tools classes clean up refactorization make it much more seamless

pull/39/head
Kye 1 year ago
parent 267253283e
commit 1daa3c2c39

@ -1,26 +1,21 @@
#--------------------------------------> AUTO GPT TOOLS
# General
import os
import pandas as pd
from langchain.agents.agent_toolkits.pandas.base import create_pandas_dataframe_agent
from langchain.docstore.document import Document
import asyncio import asyncio
import os
# Tools # Tools
from contextlib import contextmanager from contextlib import contextmanager
from typing import Optional from typing import Optional
import pandas as pd
from langchain.agents import tool from langchain.agents import tool
from langchain.agents.agent_toolkits.pandas.base import create_pandas_dataframe_agent
from langchain.docstore.document import Document
ROOT_DIR = "./data/" ROOT_DIR = "./data/"
from langchain.tools import BaseTool, DuckDuckGoSearchRun from langchain.chains.qa_with_sources.loading import BaseCombineDocumentsChain
from langchain.text_splitter import RecursiveCharacterTextSplitter from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.tools import BaseTool, DuckDuckGoSearchRun
from pydantic import Field from pydantic import Field
from langchain.chains.qa_with_sources.loading import BaseCombineDocumentsChain
@contextmanager @contextmanager

@ -2,21 +2,21 @@ from __future__ import annotations
from enum import Enum from enum import Enum
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any, Callable, Optional, Type, Union from typing import Any, Callable, Optional, Type, Union, Tuple
from pydantic import BaseModel from pydantic import BaseModel
from swarms.utils.logger import logger from swarms.utils.logger import logger
class ToolScope(Enum): class ToolScope(Enum):
GLOBAL = "global" GLOBAL = "global"
SESSION = "session" SESSION = "session"
class ToolException(Exception): class ToolException(Exception):
pass pass
class BaseTool(ABC): class BaseTool(ABC):
name: str name: str
description: str description: str
@ -25,12 +25,14 @@ class BaseTool(ABC):
def run(self, *args: Any, **kwargs: Any) -> Any: def run(self, *args: Any, **kwargs: Any) -> Any:
pass pass
@abstractmethod
async def arun(self, *args: Any, **kwargs: Any) -> Any: async def arun(self, *args: Any, **kwargs: Any) -> Any:
pass pass
def __call__(self, *args: Any, **kwargs: Any) -> Any: def __call__(self, *args: Any, **kwargs: Any) -> Any:
return self.run(*args, **kwargs) return self.run(*args, **kwargs)
class Tool(BaseTool): class Tool(BaseTool):
def __init__(self, name: str, description: str, func: Callable[..., Any]): def __init__(self, name: str, description: str, func: Callable[..., Any]):
self.name = name self.name = name
@ -49,6 +51,7 @@ class Tool(BaseTool):
except ToolException as e: except ToolException as e:
raise e raise e
class StructuredTool(BaseTool): class StructuredTool(BaseTool):
def __init__( def __init__(
self, self,
@ -74,49 +77,9 @@ class StructuredTool(BaseTool):
except ToolException as e: except ToolException as e:
raise e raise e
def tool(
name: Optional[str] = None,
description: Optional[str] = None,
args_schema: Optional[Type[BaseModel]] = None,
return_direct: bool = False,
infer_schema: bool = True
) -> Callable:
def decorator(func: Callable[..., Any]) -> Union[Tool, StructuredTool]:
nonlocal name, description
if name is None:
name = func.__name__
if description is None:
description = func.__doc__ or ""
if args_schema or infer_schema:
if args_schema is None:
args_schema = BaseModel
return StructuredTool(name, description, args_schema, func)
else:
return Tool(name, description, func)
return decorator
SessionGetter = Callable[[], Tuple[str, AgentExecutor]] SessionGetter = Callable[[], Tuple[str, AgentExecutor]]
def tool(
name: str,
description: str,
scope: ToolScope = ToolScope.GLOBAL,
):
def decorator(func):
func.name = name
func.description = description
func.is_tool = True
func.scope = scope
return func
return decorator
class ToolWrapper: class ToolWrapper:
def __init__(self, name: str, description: str, scope: ToolScope, func): def __init__(self, name: str, description: str, scope: ToolScope, func):
@ -131,52 +94,27 @@ class ToolWrapper:
def is_per_session(self) -> bool: def is_per_session(self) -> bool:
return self.scope == ToolScope.SESSION return self.scope == ToolScope.SESSION
def to_tool( def to_tool(self, get_session: SessionGetter = lambda: []) -> BaseTool:
self,
get_session: SessionGetter = lambda: [],
) -> BaseTool:
func = self.func
if self.is_per_session(): if self.is_per_session():
def func(*args, **kwargs): self.func = lambda *args, **kwargs: self.func(*args, **kwargs, get_session=get_session)
return self.func(*args, **kwargs, get_session=get_session)
return Tool( return Tool(name=self.name, description=self.description, func=self.func)
name=self.name,
description=self.description,
func=func,
)
class BaseToolSet: class BaseToolSet:
def tool_wrappers(cls) -> list[ToolWrapper]: def tool_wrappers(cls) -> list[ToolWrapper]:
methods = [ methods = [getattr(cls, m) for m in dir(cls) if hasattr(getattr(cls, m), "is_tool")]
getattr(cls, m) for m in dir(cls) if hasattr(getattr(cls, m), "is_tool")
]
return [ToolWrapper(m.name, m.description, m.scope, m) for m in methods] return [ToolWrapper(m.name, m.description, m.scope, m) for m in methods]
class ToolCreator(ABC):
@abstractmethod
def create_tools(self, toolsets: list[BaseToolSet]) -> list[BaseTool]:
pass
class ToolsFactory:
@staticmethod
def from_toolset(
toolset: BaseToolSet,
only_global: Optional[bool] = False,
only_per_session: Optional[bool] = False,
get_session: SessionGetter = lambda: [],
) -> list[BaseTool]:
tools = []
for wrapper in toolset.tool_wrappers():
if only_global and not wrapper.is_global():
continue
if only_per_session and not wrapper.is_per_session():
continue
tools.append(wrapper.to_tool(get_session=get_session))
return tools
@staticmethod class GlobalToolsCreator(ToolCreator):
def create_global_tools( def create_tools(self, toolsets: list[BaseToolSet]) -> list[BaseTool]:
toolsets: list[BaseToolSet],
) -> list[BaseTool]:
tools = [] tools = []
for toolset in toolsets: for toolset in toolsets:
tools.extend( tools.extend(
@ -187,11 +125,9 @@ class ToolsFactory:
) )
return tools return tools
@staticmethod
def create_per_session_tools( class SessionToolsCreator(ToolCreator):
toolsets: list[BaseToolSet], def create_tools(self, toolsets: list[BaseToolSet], get_session: SessionGetter = lambda: []) -> list[BaseTool]:
get_session: SessionGetter = lambda: [],
) -> list[BaseTool]:
tools = [] tools = []
for toolset in toolsets: for toolset in toolsets:
tools.extend( tools.extend(
@ -203,10 +139,23 @@ class ToolsFactory:
) )
return tools return tools
class ToolsFactory:
@staticmethod @staticmethod
def create_global_tools_from_names( def from_toolset(toolset: BaseToolSet, only_global: Optional[bool] = False, only_per_session: Optional[bool] = False, get_session: SessionGetter = lambda: []) -> list[BaseTool]:
toolnames: list[str], tools = []
llm: Optional[BaseLLM], for wrapper in toolset.tool_wrappers():
) -> list[BaseTool]: if only_global and not wrapper.is_global():
return load_tools(toolnames, llm=llm) continue
if only_per_session and not wrapper.is_per_session():
continue
tools.append(wrapper.to_tool(get_session=get_session))
return tools
@staticmethod
def create_tools(tool_creator: ToolCreator, toolsets: list[BaseToolSet], get_session: SessionGetter = lambda: []):
return tool_creator.create_tools(toolsets, get_session)
@staticmethod
def create_global_tools_from_names(toolnames: list[str], llm: Optional[BaseLLM]) -> list[BaseTool]:
return load_tools(toolnames, llm=llm)

Loading…
Cancel
Save