main
Kye 2 years ago
parent 7f48e29fb7
commit a1e78118c9

@ -1,10 +1,143 @@
from __future__ import annotations from __future__ import annotations
from enum import Enum
from typing import Callable, Tuple
from langchain.agents.agent import AgentExecutor
from langchain.agents.tools import BaseTool, Tool
class ToolScope(Enum):
GLOBAL = "global"
SESSION = "session"
SessionGetter = Callable[[], Tuple[str, AgentExecutor]]
def tool(
name: str,
description: str,
scope: ToolScope = ToolScope.GLOBAL,
):
def decorator(func):
func.name = name
func.description = description
func.is_tool = True
func.scope = scope
return func
return decorator
class ToolWrapper:
def __init__(self, name: str, description: str, scope: ToolScope, func):
self.name = name
self.description = description
self.scope = scope
self.func = func
def is_global(self) -> bool:
return self.scope == ToolScope.GLOBAL
def is_per_session(self) -> bool:
return self.scope == ToolScope.SESSION
def to_tool(
self,
get_session: SessionGetter = lambda: [],
) -> BaseTool:
func = self.func
if self.is_per_session():
func = lambda *args, **kwargs: self.func(
*args, **kwargs, get_session=get_session
)
return Tool(
name=self.name,
description=self.description,
func=func,
)
class BaseToolSet:
def tool_wrappers(cls) -> list[ToolWrapper]:
methods = [
getattr(cls, m) for m in dir(cls) if hasattr(getattr(cls, m), "is_tool")
]
return [ToolWrapper(m.name, m.description, m.scope, m) for m in methods]
#=====================================>
from typing import Optional
from langchain.agents import load_tools
from langchain.agents.tools import BaseTool
from langchain.llms.base import BaseLLM
class ToolsFactory:
@staticmethod
def from_toolset(
toolset: BaseToolSet,
only_global: Optional[bool] = False,
only_per_session: Optional[bool] = False,
get_session: SessionGetter = lambda: [],
) -> list[BaseTool]:
tools = []
for wrapper in toolset.tool_wrappers():
if only_global and not wrapper.is_global():
continue
if only_per_session and not wrapper.is_per_session():
continue
tools.append(wrapper.to_tool(get_session=get_session))
return tools
@staticmethod
def create_global_tools(
toolsets: list[BaseToolSet],
) -> list[BaseTool]:
tools = []
for toolset in toolsets:
tools.extend(
ToolsFactory.from_toolset(
toolset=toolset,
only_global=True,
)
)
return tools
@staticmethod
def create_per_session_tools(
toolsets: list[BaseToolSet],
get_session: SessionGetter = lambda: [],
) -> list[BaseTool]:
tools = []
for toolset in toolsets:
tools.extend(
ToolsFactory.from_toolset(
toolset=toolset,
only_per_session=True,
get_session=get_session,
)
)
return tools
@staticmethod
def create_global_tools_from_names(
toolnames: list[str],
llm: Optional[BaseLLM],
) -> list[BaseTool]:
return load_tools(toolnames, llm=llm)
#=====================================>
#=====================================>
################ ################
from core.prompts.input import EVAL_PREFIX, EVAL_SUFFIX # from core.prompts.input import EVAL_PREFIX, EVAL_SUFFIX
from core.tools.base import BaseToolSet from swarms.prompts.prompts import EVAL_PREFIX, EVAL_SUFFIX
from core.tools.factory import ToolsFactory
############ ############
@ -312,6 +445,7 @@ from langchain.schema import (
from langchain.tools.base import BaseTool from langchain.tools.base import BaseTool
# from core.prompts.input import EVAL_TOOL_RESPONSE # from core.prompts.input import EVAL_TOOL_RESPONSE
from swarms.prompts.prompts import EVAL_TOOL_RESPONSE
from swarms.prompts.prompts import EVAL_FORMAT_INSTRUCTIONS from swarms.prompts.prompts import EVAL_FORMAT_INSTRUCTIONS
@ -453,7 +587,7 @@ from tenacity import (
) )
from env import settings from env import settings
from ansi import ANSI, Color, Style
def _create_retry_decorator(llm: ChatOpenAI) -> Callable[[Any], Any]: def _create_retry_decorator(llm: ChatOpenAI) -> Callable[[Any], Any]:
@ -785,7 +919,6 @@ class ChatOpenAI(BaseChatModel, BaseModel):
from typing import Dict, Optional from typing import Dict, Optional
from celery import Task
from langchain.agents.agent import AgentExecutor from langchain.agents.agent import AgentExecutor
from langchain.callbacks.base import CallbackManager from langchain.callbacks.base import CallbackManager
@ -879,7 +1012,9 @@ from typing import Dict
from langchain.schema import BaseOutputParser from langchain.schema import BaseOutputParser
from core.prompts.input import EVAL_FORMAT_INSTRUCTIONS # from core.prompts.input import EVAL_FORMAT_INSTRUCTIONS
from swarms.prompts.prompts import EVAL_FORMAT_INSTRUCTIONS
class EvalOutputParser(BaseOutputParser): class EvalOutputParser(BaseOutputParser):

Loading…
Cancel
Save