diff --git a/swarms/agents/omni_modal_agent.py b/swarms/agents/omni_modal_agent.py index 80eac09b..dbb1bd15 100644 --- a/swarms/agents/omni_modal_agent.py +++ b/swarms/agents/omni_modal_agent.py @@ -12,6 +12,8 @@ from langchain_experimental.autonomous_agents.hugginggpt.task_planner import ( load_chat_planner, ) from transformers import load_tool +from swarms.agents.message import Message + class Step: def __init__( @@ -103,7 +105,10 @@ class OmniModalAgent: # self.task_executor = TaskExecutor - def run(self, input: str) -> str: + def run( + self, + input: str + ) -> str: """Run the OmniAgent""" plan = self.chat_planner.plan( inputs={ @@ -119,5 +124,83 @@ class OmniModalAgent: ) return response + + def chat( + self, + msg: str = None, + streaming: bool = False + ): + """ + Run chat + + Args: + msg (str, optional): Message to send to the agent. Defaults to None. + language (str, optional): Language to use. Defaults to None. + streaming (bool, optional): Whether to stream the response. Defaults to False. + + Returns: + str: Response from the agent + + Usage: + -------------- + agent = MultiModalAgent() + agent.chat("Hello") + + """ + + #add users message to the history + self.history.append( + Message( + "User", + msg + ) + ) + + #process msg + try: + response = self.agent.run(msg) + + #add agent's response to the history + self.history.append( + Message( + "Agent", + response + ) + ) + + #if streaming is = True + if streaming: + return self._stream_response(response) + else: + response + + except Exception as error: + error_message = f"Error processing message: {str(error)}" + + #add error to history + self.history.append( + Message( + "Agent", + error_message + ) + ) + + return error_message + + def _stream_response( + self, + response: str = None + ): + """ + Yield the response token by token (word by word) + + Usage: + -------------- + for token in _stream_response(response): + print(token) + + """ + for token in response.split(): + yield token diff --git a/tests/agents/omni_modal.py b/tests/agents/omni_modal.py index e8a602ba..3c3c79df 100644 --- a/tests/agents/omni_modal.py +++ b/tests/agents/omni_modal.py @@ -10,7 +10,11 @@ from langchain_experimental.autonomous_agents.hugginggpt.task_planner import ( load_chat_planner, ) from transformers import load_tool -from swarms.agents import OmniModalAgent # Replace `your_module_name` with the appropriate module name + +from swarms.agents import ( + OmniModalAgent, # Replace `your_module_name` with the appropriate module name +) + # Mock objects or set up fixtures for dependent classes or external methods @pytest.fixture