diff --git a/swarms/swarms.py b/swarms/swarms.py index e4e21cf5..1aa78a2c 100644 --- a/swarms/swarms.py +++ b/swarms/swarms.py @@ -1,10 +1,26 @@ from swarms.tools.agent_tools import * from swarms.agents.workers.worker_agent import WorkerNode from swarms.agents.boss.boss_agent import BossNode +from swarms.agents.workers.omni_agent import OmniWorkerAgent + + class Swarms: - def __init__(self, openai_api_key): + def __init__(self, + openai_api_key, + omni_api_key=None, + omni_api_endpoint=None, + omni_api_type=None + ): self.openai_api_key = openai_api_key + self.omni_api_key = omni_api_key + self.omni_api_endpoint = omni_api_endpoint + self.omni_api_key = omni_api_type + + if omni_api_key and omni_api_endpoint and omni_api_type: + self.omni_worker_agent = OmniWorkerAgent(omni_api_key, omni_api_endpoint, omni_api_type) + else: + self.omni_worker_agent = None def initialize_llm(self): # Initialize language model @@ -19,7 +35,11 @@ class Swarms: ReadFileTool(root_dir=ROOT_DIR), process_csv, WebpageQATool(qa_chain=load_qa_with_sources_chain(llm)), + # self.omni_worker_agent.chat # Add the OmniWorkerAgent's chat method as a tool ] + + if self.omni_worker_agent: + tools.append(self.omni_worker_agent.chat) #add omniworker agent class return tools def initialize_vectorstore(self):