pull/55/head
Kye 1 year ago
parent d6b62b1ec1
commit 358d1ea146

@ -26,3 +26,10 @@ from swarms.structs.workflow import Workflow
from swarms.swarms.dialogue_simulator import DialogueSimulator
from swarms.swarms.autoscaler import AutoScaler
from swarms.swarms.orchestrate import Orchestrator
# agents
from swarms.swarms.profitpilot import ProfitPilot
from swarms.aot import AoTAgent
from swarms.agents.multi_modal_agent import MultiModalVisualAgent
from swarms.agents.omni_modal_agent import OmniModalAgent

@ -1,8 +1,10 @@
# from swarms.workers.multi_modal_workers.multi_modal_agent import MultiModalVisualAgent
from swarms.workers.multi_modal_workers.multi_modal_agent import MultiModalVisualAgent
class MultiModalVisualAgent:
def __init__(self, agent: MultiModalVisualAgent):
def __init__(
self,
agent: MultiModalVisualAgent
):
self.agent = agent
def _run(self, text: str) -> str:

@ -1,5 +1,3 @@
from langchain.tools import tool
from swarms.workers.multi_modal_workers.omni_agent.omni_chat import chat_huggingface
@ -13,7 +11,6 @@ class OmniModalAgent:
self.api_endpoint = api_endpoint
self.api_type = api_type
@tool
def chat(self, data):
"""Chat with omni-modality model that uses huggingface to query for a specific model at run time. Translate text to speech, create images and more"""
messages = data.get("messages")

@ -17,7 +17,6 @@ from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import Chroma
from pydantic import BaseModel, Field
llm = ChatOpenAI(temperature=0.9)

@ -1559,11 +1559,11 @@ class MultiModalVisualAgent:
def clear_memory(self):
    """Clear the agent's memory by delegating to ``self.memory.clear()``."""
    self.memory.clear()
if __name__ == '__main__':
    # Ensure the checkpoint directory exists. exist_ok=True avoids the
    # check-then-create race of os.path.exists() followed by os.mkdir().
    os.makedirs("checkpoints", exist_ok=True)

    parser = argparse.ArgumentParser()
    # --load is a comma-separated list of entries such as
    # "ImageCaptioning_cuda:0" (see the default below).
    parser.add_argument('--load', type=str, default="ImageCaptioning_cuda:0,Text2Image_cuda:0")
    args = parser.parse_args()

    # Build {first '_'-separated field: second field}, e.g.
    # {"ImageCaptioning": "cuda:0"}, splitting each entry only once.
    load_dict = {}
    for entry in args.load.split(','):
        if not entry.strip():
            # Tolerate stray or trailing commas instead of crashing.
            continue
        parts = entry.split('_')
        if len(parts) < 2:
            # Report malformed entries as a usage error rather than
            # letting a bare IndexError escape.
            parser.error(f"invalid --load entry {entry!r}; expected something like 'ImageCaptioning_cuda:0'")
        load_dict[parts[0].strip()] = parts[1].strip()

    agent = MultiModalVisualAgent(load_dict=load_dict)
# if __name__ == '__main__':
# if not os.path.exists("checkpoints"):
# os.mkdir("checkpoints")
# parser = argparse.ArgumentParser()
# parser.add_argument('--load', type=str, default="ImageCaptioning_cuda:0,Text2Image_cuda:0")
# args = parser.parse_args()
# load_dict = {e.split('_')[0].strip(): e.split('_')[1].strip() for e in args.load.split(',')}
# agent = MultiModalVisualAgent(load_dict=load_dict)
Loading…
Cancel
Save