From a1e78118c99f11efa83b80d3fe70c76cfa807b31 Mon Sep 17 00:00:00 2001 From: Kye Date: Tue, 27 Jun 2023 07:39:39 -0400 Subject: [PATCH] clean up --- swarms/agents/agents.py | 147 ++++++++++++++++++++++++++++++++++++++-- 1 file changed, 141 insertions(+), 6 deletions(-) diff --git a/swarms/agents/agents.py b/swarms/agents/agents.py index 2c646616..fc4600ae 100644 --- a/swarms/agents/agents.py +++ b/swarms/agents/agents.py @@ -1,10 +1,143 @@ from __future__ import annotations +from enum import Enum +from typing import Callable, Tuple + +from langchain.agents.agent import AgentExecutor +from langchain.agents.tools import BaseTool, Tool + + +class ToolScope(Enum): + GLOBAL = "global" + SESSION = "session" + +SessionGetter = Callable[[], Tuple[str, AgentExecutor]] + + +def tool( + name: str, + description: str, + scope: ToolScope = ToolScope.GLOBAL, +): + def decorator(func): + func.name = name + func.description = description + func.is_tool = True + func.scope = scope + return func + + return decorator + + +class ToolWrapper: + def __init__(self, name: str, description: str, scope: ToolScope, func): + self.name = name + self.description = description + self.scope = scope + self.func = func + + def is_global(self) -> bool: + return self.scope == ToolScope.GLOBAL + + def is_per_session(self) -> bool: + return self.scope == ToolScope.SESSION + + def to_tool( + self, + get_session: SessionGetter = lambda: [], + ) -> BaseTool: + func = self.func + if self.is_per_session(): + func = lambda *args, **kwargs: self.func( + *args, **kwargs, get_session=get_session + ) + + return Tool( + name=self.name, + description=self.description, + func=func, + ) + + +class BaseToolSet: + def tool_wrappers(cls) -> list[ToolWrapper]: + methods = [ + getattr(cls, m) for m in dir(cls) if hasattr(getattr(cls, m), "is_tool") + ] + return [ToolWrapper(m.name, m.description, m.scope, m) for m in methods] + + +#=====================================> +from typing import Optional + +from langchain.agents import load_tools +from langchain.agents.tools import BaseTool +from langchain.llms.base import BaseLLM + + +class ToolsFactory: + @staticmethod + def from_toolset( + toolset: BaseToolSet, + only_global: Optional[bool] = False, + only_per_session: Optional[bool] = False, + get_session: SessionGetter = lambda: [], + ) -> list[BaseTool]: + tools = [] + for wrapper in toolset.tool_wrappers(): + if only_global and not wrapper.is_global(): + continue + if only_per_session and not wrapper.is_per_session(): + continue + tools.append(wrapper.to_tool(get_session=get_session)) + return tools + + @staticmethod + def create_global_tools( + toolsets: list[BaseToolSet], + ) -> list[BaseTool]: + tools = [] + for toolset in toolsets: + tools.extend( + ToolsFactory.from_toolset( + toolset=toolset, + only_global=True, + ) + ) + return tools + + @staticmethod + def create_per_session_tools( + toolsets: list[BaseToolSet], + get_session: SessionGetter = lambda: [], + ) -> list[BaseTool]: + tools = [] + for toolset in toolsets: + tools.extend( + ToolsFactory.from_toolset( + toolset=toolset, + only_per_session=True, + get_session=get_session, + ) + ) + return tools + + @staticmethod + def create_global_tools_from_names( + toolnames: list[str], + llm: Optional[BaseLLM], + ) -> list[BaseTool]: + return load_tools(toolnames, llm=llm) + +#=====================================> +#=====================================> + + + ################ -from core.prompts.input import EVAL_PREFIX, EVAL_SUFFIX -from core.tools.base import BaseToolSet -from core.tools.factory import ToolsFactory +# from core.prompts.input import EVAL_PREFIX, EVAL_SUFFIX +from swarms.prompts.prompts import EVAL_PREFIX, EVAL_SUFFIX ############ @@ -312,6 +445,7 @@ from langchain.schema import ( from langchain.tools.base import BaseTool # from core.prompts.input import EVAL_TOOL_RESPONSE +from swarms.prompts.prompts import EVAL_TOOL_RESPONSE from swarms.prompts.prompts import EVAL_FORMAT_INSTRUCTIONS @@ -453,7 +587,7 @@ from tenacity import ( ) from env import settings -from ansi import ANSI, Color, Style + def _create_retry_decorator(llm: ChatOpenAI) -> Callable[[Any], Any]: @@ -785,7 +919,6 @@ class ChatOpenAI(BaseChatModel, BaseModel): from typing import Dict, Optional -from celery import Task from langchain.agents.agent import AgentExecutor from langchain.callbacks.base import CallbackManager @@ -879,7 +1012,9 @@ from typing import Dict from langchain.schema import BaseOutputParser -from core.prompts.input import EVAL_FORMAT_INSTRUCTIONS +# from core.prompts.input import EVAL_FORMAT_INSTRUCTIONS + +from swarms.prompts.prompts import EVAL_FORMAT_INSTRUCTIONS class EvalOutputParser(BaseOutputParser):