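"""OmniModalAgent: a HuggingGPT-style multi-modal agent.

An LLM plans a sequence of tasks, a task executor runs them with
Hugging Face tools, and a response generator turns the results into
a natural-language reply.
"""
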
from typing import Dict, List, Optional

from langchain.base_language import BaseLanguageModel
from langchain.tools.base import BaseTool

# NOTE: "repsonse_generator" is the actual (misspelled) module name upstream
# in langchain_experimental, so the import must keep that spelling.
from langchain_experimental.autonomous_agents.hugginggpt.repsonse_generator import (
    load_response_generator,
)
from langchain_experimental.autonomous_agents.hugginggpt.task_executor import (
    TaskExecutor,
)
from langchain_experimental.autonomous_agents.hugginggpt.task_planner import (
    load_chat_planner,
)
from transformers import load_tool

from swarms.agents.message import Message


class Step:
    """A single step in a plan: a task, its id, the ids of the steps it
    depends on, its arguments, and the tool that executes it."""

    def __init__(
        self, task: str, id: int, dep: List[int], args: Dict[str, str], tool: BaseTool
    ):
        self.task = task
        self.id = id
        self.dep = dep
        self.args = args
        self.tool = tool


class Plan:
    """An ordered list of steps produced by the planner."""

    def __init__(self, steps: List[Step]):
        self.steps = steps

    def __str__(self) -> str:
        return str([str(step) for step in self.steps])

    def __repr__(self) -> str:
        return str(self)


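# Illustrative sketch (not part of the runtime path): how a Plan could be
# assembled by hand from Steps. The "<output of step 1>" placeholder is
# hypothetical; in practice plans come from the chat planner, not by hand.
#
#   captioner = load_tool("image-captioning")
#   translator = load_tool("translation")
#   plan = Plan(
#       steps=[
#           Step("caption the image", 1, [], {"image": "photo.png"}, captioner),
#           Step("translate it", 2, [1], {"text": "<output of step 1>"}, translator),
#       ]
#   )

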
class OmniModalAgent:
    """
    OmniModalAgent

    LLM -> Plans -> Tasks -> Tools -> Response

    Architecture:
    1. LLM: the language model
    2. Chat Planner: turns the input into a plan of tasks
    3. Task Executor: runs each task with the right tool
    4. Tools: the Hugging Face tools available to the executor

    Args:
        llm (BaseLanguageModel): Language Model

    Returns:
        str: response

    Usage:
        from swarms import OmniModalAgent, OpenAIChat

        llm = OpenAIChat()
        agent = OmniModalAgent(llm)
        response = agent.run("Hello, how are you? Create an image of how you are doing!")
    """

    def __init__(
        self,
        llm: BaseLanguageModel,
        # tools: List[BaseTool]
    ):
        self.llm = llm

        # Load the Hugging Face tools the task executor can call.
        # Each load_tool call may download weights from the Hugging Face Hub,
        # so the first run can take a while.
        print("Loading tools...")
        self.tools = [
            load_tool(tool_name)
            for tool_name in [
                "document-question-answering",
                "image-captioning",
                "image-question-answering",
                "image-segmentation",
                "speech-to-text",
                "summarization",
                "text-classification",
                "text-question-answering",
                "translation",
                "huggingface-tools/text-to-image",
                "huggingface-tools/text-to-video",
                "text-to-speech",
                "huggingface-tools/text-download",
                "huggingface-tools/image-transformation",
            ]
        ]

        self.chat_planner = load_chat_planner(llm)
        self.response_generator = load_response_generator(llm)
        # self.task_executor is created per request in run()
        self.history = []

    def run(self, input: str) -> str:
        """Run the OmniModalAgent: plan, execute the tasks, and respond."""
        # Ask the planner to break the input down into tool-using tasks.
        plan = self.chat_planner.plan(
            inputs={
                "input": input,
                "hf_tools": self.tools,
            }
        )
        # Execute every task in the plan.
        self.task_executor = TaskExecutor(plan)
        self.task_executor.run()

        # Summarize the task results into a natural-language response.
        response = self.response_generator.generate(
            {"task_execution": self.task_executor}
        )

        return response

    def chat(self, msg: Optional[str] = None, streaming: bool = False):
        """
        Run chat

        Args:
            msg (str, optional): Message to send to the agent. Defaults to None.
            streaming (bool, optional): Whether to stream the response. Defaults to False.

        Returns:
            str: Response from the agent

        Usage:
        --------------
        agent = OmniModalAgent(llm)
        agent.chat("Hello")

        """

        # add the user's message to the history
        self.history.append(Message("User", msg))

        # process the message
        try:
            response = self.run(msg)

            # add the agent's response to the history
            self.history.append(Message("Agent", response))

            # stream the response token by token if requested
            if streaming:
                return self._stream_response(response)
            else:
                return response

        except Exception as error:
            error_message = f"Error processing message: {str(error)}"

            # add the error to the history
            self.history.append(Message("Agent", error_message))

            return error_message

    def _stream_response(self, response: Optional[str] = None):
        """
        Yield the response token by token (word by word)

        Usage:
        --------------
        for token in self._stream_response(response):
            print(token)

        """
        for token in response.split():
            yield token
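

# --- Usage sketch -------------------------------------------------------
# A minimal example mirroring the class docstring above. It assumes the
# swarms package exposes OpenAIChat and that an OpenAI API key is available
# in the environment; adjust to whatever BaseLanguageModel you actually use.
if __name__ == "__main__":
    from swarms import OpenAIChat

    llm = OpenAIChat()
    agent = OmniModalAgent(llm)
    print(agent.run("Hello, how are you? Create an image of how you are doing!"))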