pull/58/head
Kye 1 year ago
parent a92c3ee7d9
commit 932bc1488e

@ -41,33 +41,33 @@ from transformers import load_tool
# # response = chat_huggingface(messages, api_key, api_type, api_endpoint)
# # return response
# class Step:
# def __init__(
# self,
# task: str,
# id: int,
# dep: List[int],
# args: Dict[str, str],
# tool: BaseTool
# ):
# self.task = task
# self.id = id
# self.dep = dep
# self.args = args
# self.tool = tool
class Step:
def __init__(
self,
task: str,
id: int,
dep: List[int],
args: Dict[str, str],
tool: BaseTool
):
self.task = task
self.id = id
self.dep = dep
self.args = args
self.tool = tool
# class Plan:
# def __init__(
# self,
# steps: List[Step]
# ):
# self.steps = steps
class Plan:
def __init__(
self,
steps: List[Step]
):
self.steps = steps
# def __str__(self) -> str:
# return str([str(step) for step in self.steps])
def __str__(self) -> str:
return str([str(step) for step in self.steps])
# def __repr(self) -> str:
# return str(self)
def __repr(self) -> str:
return str(self)
@ -128,10 +128,11 @@ class OmniModalAgent:
self.chat_planner = load_chat_planner(llm)
self.response_generator = load_response_generator(llm)
self.task_executor: TaskExecutor
self.task_executor = TaskExecutor
def run(self, input: str) -> str:
"""Run the OmniAgent"""
plan = self.chat_planner.plan(
inputs={
"input": input,

Loading…
Cancel
Save