Former-commit-id: c41e5edd5a3919381e84e15bdc6d21077682954e
pull/160/head
Kye 2 years ago
parent 3832e249a2
commit ab010e32ef

@ -15,7 +15,8 @@ from langchain.chat_models import ChatOpenAI
from .EvalOutputParser import EvalOutputParser
class AgentBuilder:
class AgentSetup:
def __init__(self, toolsets: list[BaseToolSet] = [], openai_api_key: str = None, serpapi_api_key: str = None, bing_search_url: str = None, bing_subscription_key: str = None):
self.llm: BaseChatModel = None
self.parser: BaseOutputParser = None
@ -28,7 +29,7 @@ class AgentBuilder:
if not self.openai_api_key:
raise ValueError("OpenAI key is missing, it should either be set as an environment variable or passed as a parameter")
def build_llm(self, callback_manager: BaseCallbackManager = None, openai_api_key: str = None):
def setup_llm(self, callback_manager: BaseCallbackManager = None, openai_api_key: str = None):
if openai_api_key is None:
openai_api_key = os.getenv('OPENAI_API_KEY')
if openai_api_key is None:
@ -36,16 +37,15 @@ class AgentBuilder:
self.llm = ChatOpenAI(openai_api_key=openai_api_key, temperature=0.5, callback_manager=callback_manager, verbose=True)
def build_parser(self):
def setup_parser(self):
self.parser = EvalOutputParser()
def build_global_tools(self):
def setup_global_tools(self):
if self.llm is None:
raise ValueError("LLM must be initialized before tools")
toolnames = ["wikipedia"]
if self.serpapi_api_key:
toolnames.append("serpapi")

@ -1,22 +1,20 @@
from typing import Dict, Optional
import os
import logging
from typing import Dict, Optional
from celery import Task
from langchain.agents.agent import AgentExecutor
from langchain.callbacks.manager import CallbackManager
from langchain.chains.conversation.memory import ConversationBufferMemory
from langchain.memory.chat_memory import BaseChatMemory
from swarms.tools.main import BaseToolSet, ToolsFactory
from .AgentBuilder import AgentBuilder
from .Calback import EVALCallbackHandler, ExecutionTracingCallbackHandler
from swarms.prompts.prompts import EVAL_PREFIX, EVAL_SUFFIX
from swarms.agents.utils.AgentBuilder import AgentSetup
from swarms.agents.utils.EvalOutputParser import EVALCallbackHandler, ExecutionTracingCallbackHandler
callback_manager_instance = CallbackManager(EVALCallbackHandler())
class AgentManager:
class AgentCreator:
def __init__(self, toolsets: list[BaseToolSet] = []):
if not isinstance(toolsets, list):
raise TypeError("Toolsets must be a list")
@ -38,9 +36,8 @@ class AgentManager:
def create_executor(self, session: str, execution: Optional[Task] = None, openai_api_key: str = None) -> AgentExecutor:
try:
builder = AgentBuilder(self.toolsets)
builder.build_parser()
builder = AgentSetup(self.toolsets)
builder.setup_parser()
callbacks = []
eval_callback = EVALCallbackHandler()
@ -52,15 +49,13 @@ class AgentManager:
execution_callback.set_parser(builder.get_parser())
callbacks.append(execution_callback)
#llm init
callback_manager = CallbackManager(callbacks)
builder.build_llm(callback_manager, openai_api_key)
builder.setup_llm(callback_manager, openai_api_key)
if builder.llm is None:
raise ValueError('LLM not created')
builder.build_global_tools()
builder.setup_global_tools()
#agent init
agent = builder.get_agent()
if not agent:
raise ValueError("Agent not created")
@ -77,9 +72,6 @@ class AgentManager:
for tool in tools:
tool.callback_manager = callback_manager
# Ensure the 'agent' key is present in the values dictionary
# values = {'agent': agent, 'tools': tools}
executor = AgentExecutor.from_agent_and_tools(
agent=agent,
tools=tools,
@ -98,7 +90,7 @@ class AgentManager:
raise e
@staticmethod
def create(toolsets: list[BaseToolSet]) -> "AgentManager":
def create(toolsets: list[BaseToolSet]) -> "AgentCreator":
if not isinstance(toolsets, list):
raise TypeError("Toolsets must be a list")
return AgentManager(toolsets=toolsets)
return AgentCreator(toolsets=toolsets)
Loading…
Cancel
Save