Former-commit-id: c41e5edd5a3919381e84e15bdc6d21077682954e
pull/160/head
Kye 2 years ago
parent 3832e249a2
commit ab010e32ef

@ -15,7 +15,8 @@ from langchain.chat_models import ChatOpenAI
from .EvalOutputParser import EvalOutputParser from .EvalOutputParser import EvalOutputParser
class AgentBuilder:
class AgentSetup:
def __init__(self, toolsets: list[BaseToolSet] = [], openai_api_key: str = None, serpapi_api_key: str = None, bing_search_url: str = None, bing_subscription_key: str = None): def __init__(self, toolsets: list[BaseToolSet] = [], openai_api_key: str = None, serpapi_api_key: str = None, bing_search_url: str = None, bing_subscription_key: str = None):
self.llm: BaseChatModel = None self.llm: BaseChatModel = None
self.parser: BaseOutputParser = None self.parser: BaseOutputParser = None
@ -28,7 +29,7 @@ class AgentBuilder:
if not self.openai_api_key: if not self.openai_api_key:
raise ValueError("OpenAI key is missing, it should either be set as an environment variable or passed as a parameter") raise ValueError("OpenAI key is missing, it should either be set as an environment variable or passed as a parameter")
def build_llm(self, callback_manager: BaseCallbackManager = None, openai_api_key: str = None): def setup_llm(self, callback_manager: BaseCallbackManager = None, openai_api_key: str = None):
if openai_api_key is None: if openai_api_key is None:
openai_api_key = os.getenv('OPENAI_API_KEY') openai_api_key = os.getenv('OPENAI_API_KEY')
if openai_api_key is None: if openai_api_key is None:
@ -36,16 +37,15 @@ class AgentBuilder:
self.llm = ChatOpenAI(openai_api_key=openai_api_key, temperature=0.5, callback_manager=callback_manager, verbose=True) self.llm = ChatOpenAI(openai_api_key=openai_api_key, temperature=0.5, callback_manager=callback_manager, verbose=True)
def build_parser(self): def setup_parser(self):
self.parser = EvalOutputParser() self.parser = EvalOutputParser()
def build_global_tools(self): def setup_global_tools(self):
if self.llm is None: if self.llm is None:
raise ValueError("LLM must be initialized before tools") raise ValueError("LLM must be initialized before tools")
toolnames = ["wikipedia"] toolnames = ["wikipedia"]
if self.serpapi_api_key: if self.serpapi_api_key:
toolnames.append("serpapi") toolnames.append("serpapi")

@ -1,22 +1,20 @@
from typing import Dict, Optional import os
import logging import logging
from typing import Dict, Optional
from celery import Task from celery import Task
from langchain.agents.agent import AgentExecutor from langchain.agents.agent import AgentExecutor
from langchain.callbacks.manager import CallbackManager from langchain.callbacks.manager import CallbackManager
from langchain.chains.conversation.memory import ConversationBufferMemory from langchain.chains.conversation.memory import ConversationBufferMemory
from langchain.memory.chat_memory import BaseChatMemory from langchain.memory.chat_memory import BaseChatMemory
from swarms.tools.main import BaseToolSet, ToolsFactory from swarms.tools.main import BaseToolSet, ToolsFactory
from .AgentBuilder import AgentBuilder from swarms.prompts.prompts import EVAL_PREFIX, EVAL_SUFFIX
from .Calback import EVALCallbackHandler, ExecutionTracingCallbackHandler
from swarms.agents.utils.AgentBuilder import AgentSetup
from swarms.agents.utils.EvalOutputParser import EVALCallbackHandler, ExecutionTracingCallbackHandler
callback_manager_instance = CallbackManager(EVALCallbackHandler()) callback_manager_instance = CallbackManager(EVALCallbackHandler())
class AgentCreator:
class AgentManager:
def __init__(self, toolsets: list[BaseToolSet] = []): def __init__(self, toolsets: list[BaseToolSet] = []):
if not isinstance(toolsets, list): if not isinstance(toolsets, list):
raise TypeError("Toolsets must be a list") raise TypeError("Toolsets must be a list")
@ -38,9 +36,8 @@ class AgentManager:
def create_executor(self, session: str, execution: Optional[Task] = None, openai_api_key: str = None) -> AgentExecutor: def create_executor(self, session: str, execution: Optional[Task] = None, openai_api_key: str = None) -> AgentExecutor:
try: try:
builder = AgentBuilder(self.toolsets) builder = AgentSetup(self.toolsets)
builder.build_parser() builder.setup_parser()
callbacks = [] callbacks = []
eval_callback = EVALCallbackHandler() eval_callback = EVALCallbackHandler()
@ -52,15 +49,13 @@ class AgentManager:
execution_callback.set_parser(builder.get_parser()) execution_callback.set_parser(builder.get_parser())
callbacks.append(execution_callback) callbacks.append(execution_callback)
#llm init
callback_manager = CallbackManager(callbacks) callback_manager = CallbackManager(callbacks)
builder.build_llm(callback_manager, openai_api_key) builder.setup_llm(callback_manager, openai_api_key)
if builder.llm is None: if builder.llm is None:
raise ValueError('LLM not created') raise ValueError('LLM not created')
builder.build_global_tools() builder.setup_global_tools()
#agent init
agent = builder.get_agent() agent = builder.get_agent()
if not agent: if not agent:
raise ValueError("Agent not created") raise ValueError("Agent not created")
@ -77,9 +72,6 @@ class AgentManager:
for tool in tools: for tool in tools:
tool.callback_manager = callback_manager tool.callback_manager = callback_manager
# Ensure the 'agent' key is present in the values dictionary
# values = {'agent': agent, 'tools': tools}
executor = AgentExecutor.from_agent_and_tools( executor = AgentExecutor.from_agent_and_tools(
agent=agent, agent=agent,
tools=tools, tools=tools,
@ -98,7 +90,7 @@ class AgentManager:
raise e raise e
@staticmethod @staticmethod
def create(toolsets: list[BaseToolSet]) -> "AgentManager": def create(toolsets: list[BaseToolSet]) -> "AgentCreator":
if not isinstance(toolsets, list): if not isinstance(toolsets, list):
raise TypeError("Toolsets must be a list") raise TypeError("Toolsets must be a list")
return AgentManager(toolsets=toolsets) return AgentCreator(toolsets=toolsets)
Loading…
Cancel
Save