base tools classes clean up refactorization make it much more seamless

pull/39/head
Kye 1 year ago
parent 267253283e
commit 1daa3c2c39

@ -1,26 +1,21 @@
#--------------------------------------> AUTO GPT TOOLS
# General
import os
import pandas as pd
from langchain.agents.agent_toolkits.pandas.base import create_pandas_dataframe_agent
from langchain.docstore.document import Document
import asyncio
import os
# Tools
from contextlib import contextmanager
from typing import Optional
import pandas as pd
from langchain.agents import tool
from langchain.agents.agent_toolkits.pandas.base import create_pandas_dataframe_agent
from langchain.docstore.document import Document
ROOT_DIR = "./data/"
from langchain.tools import BaseTool, DuckDuckGoSearchRun
from langchain.chains.qa_with_sources.loading import BaseCombineDocumentsChain
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.tools import BaseTool, DuckDuckGoSearchRun
from pydantic import Field
from langchain.chains.qa_with_sources.loading import BaseCombineDocumentsChain
@contextmanager

@ -2,21 +2,21 @@ from __future__ import annotations
from enum import Enum
from abc import ABC, abstractmethod
from typing import Any, Callable, Optional, Type, Union
from typing import Any, Callable, Optional, Type, Union, Tuple
from pydantic import BaseModel
from swarms.utils.logger import logger
class ToolScope(Enum):
GLOBAL = "global"
SESSION = "session"
class ToolException(Exception):
pass
class BaseTool(ABC):
name: str
description: str
@ -25,12 +25,14 @@ class BaseTool(ABC):
def run(self, *args: Any, **kwargs: Any) -> Any:
pass
@abstractmethod
async def arun(self, *args: Any, **kwargs: Any) -> Any:
pass
def __call__(self, *args: Any, **kwargs: Any) -> Any:
return self.run(*args, **kwargs)
class Tool(BaseTool):
def __init__(self, name: str, description: str, func: Callable[..., Any]):
self.name = name
@ -49,6 +51,7 @@ class Tool(BaseTool):
except ToolException as e:
raise e
class StructuredTool(BaseTool):
def __init__(
self,
@ -74,49 +77,9 @@ class StructuredTool(BaseTool):
except ToolException as e:
raise e
def tool(
name: Optional[str] = None,
description: Optional[str] = None,
args_schema: Optional[Type[BaseModel]] = None,
return_direct: bool = False,
infer_schema: bool = True
) -> Callable:
def decorator(func: Callable[..., Any]) -> Union[Tool, StructuredTool]:
nonlocal name, description
if name is None:
name = func.__name__
if description is None:
description = func.__doc__ or ""
if args_schema or infer_schema:
if args_schema is None:
args_schema = BaseModel
return StructuredTool(name, description, args_schema, func)
else:
return Tool(name, description, func)
return decorator
SessionGetter = Callable[[], Tuple[str, AgentExecutor]]
def tool(
name: str,
description: str,
scope: ToolScope = ToolScope.GLOBAL,
):
def decorator(func):
func.name = name
func.description = description
func.is_tool = True
func.scope = scope
return func
return decorator
class ToolWrapper:
def __init__(self, name: str, description: str, scope: ToolScope, func):
@ -131,52 +94,27 @@ class ToolWrapper:
def is_per_session(self) -> bool:
return self.scope == ToolScope.SESSION
def to_tool(
self,
get_session: SessionGetter = lambda: [],
) -> BaseTool:
func = self.func
def to_tool(self, get_session: SessionGetter = lambda: []) -> BaseTool:
if self.is_per_session():
def func(*args, **kwargs):
return self.func(*args, **kwargs, get_session=get_session)
self.func = lambda *args, **kwargs: self.func(*args, **kwargs, get_session=get_session)
return Tool(
name=self.name,
description=self.description,
func=func,
)
return Tool(name=self.name, description=self.description, func=self.func)
class BaseToolSet:
def tool_wrappers(cls) -> list[ToolWrapper]:
methods = [
getattr(cls, m) for m in dir(cls) if hasattr(getattr(cls, m), "is_tool")
]
methods = [getattr(cls, m) for m in dir(cls) if hasattr(getattr(cls, m), "is_tool")]
return [ToolWrapper(m.name, m.description, m.scope, m) for m in methods]
class ToolCreator(ABC):
@abstractmethod
def create_tools(self, toolsets: list[BaseToolSet]) -> list[BaseTool]:
pass
class ToolsFactory:
@staticmethod
def from_toolset(
toolset: BaseToolSet,
only_global: Optional[bool] = False,
only_per_session: Optional[bool] = False,
get_session: SessionGetter = lambda: [],
) -> list[BaseTool]:
tools = []
for wrapper in toolset.tool_wrappers():
if only_global and not wrapper.is_global():
continue
if only_per_session and not wrapper.is_per_session():
continue
tools.append(wrapper.to_tool(get_session=get_session))
return tools
@staticmethod
def create_global_tools(
toolsets: list[BaseToolSet],
) -> list[BaseTool]:
class GlobalToolsCreator(ToolCreator):
def create_tools(self, toolsets: list[BaseToolSet]) -> list[BaseTool]:
tools = []
for toolset in toolsets:
tools.extend(
@ -187,11 +125,9 @@ class ToolsFactory:
)
return tools
@staticmethod
def create_per_session_tools(
toolsets: list[BaseToolSet],
get_session: SessionGetter = lambda: [],
) -> list[BaseTool]:
class SessionToolsCreator(ToolCreator):
def create_tools(self, toolsets: list[BaseToolSet], get_session: SessionGetter = lambda: []) -> list[BaseTool]:
tools = []
for toolset in toolsets:
tools.extend(
@ -203,10 +139,23 @@ class ToolsFactory:
)
return tools
class ToolsFactory:
@staticmethod
def create_global_tools_from_names(
toolnames: list[str],
llm: Optional[BaseLLM],
) -> list[BaseTool]:
return load_tools(toolnames, llm=llm)
def from_toolset(toolset: BaseToolSet, only_global: Optional[bool] = False, only_per_session: Optional[bool] = False, get_session: SessionGetter = lambda: []) -> list[BaseTool]:
tools = []
for wrapper in toolset.tool_wrappers():
if only_global and not wrapper.is_global():
continue
if only_per_session and not wrapper.is_per_session():
continue
tools.append(wrapper.to_tool(get_session=get_session))
return tools
@staticmethod
def create_tools(tool_creator: ToolCreator, toolsets: list[BaseToolSet], get_session: SessionGetter = lambda: []):
return tool_creator.create_tools(toolsets, get_session)
@staticmethod
def create_global_tools_from_names(toolnames: list[str], llm: Optional[BaseLLM]) -> list[BaseTool]:
return load_tools(toolnames, llm=llm)

Loading…
Cancel
Save