You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
109 lines
2.9 KiB
109 lines
2.9 KiB
from typing import List, Dict, Any, Union
|
|
from concurrent.futures import Executor, ThreadPoolExecutor, as_completed
|
|
from graphlib import TopologicalSorter
|
|
|
|
|
|
class Task:
|
|
def __init__(
|
|
self,
|
|
id: str,
|
|
parents: List["Task"] = None,
|
|
children: List["Task"] = None
|
|
):
|
|
self.id = id
|
|
self.parents = parents
|
|
self.children = children
|
|
|
|
def can_execute(self):
|
|
raise NotImplementedError
|
|
|
|
def execute(self):
|
|
raise NotImplementedError
|
|
|
|
|
|
class NonLinearWorkflow:
|
|
"""
|
|
NonLinearWorkflow constructs a non sequential DAG of tasks to be executed by agents
|
|
|
|
|
|
Architecture:
|
|
NonLinearWorkflow = Task + Agent + Executor
|
|
|
|
ASCII Diagram:
|
|
+-------------------+
|
|
| NonLinearWorkflow |
|
|
+-------------------+
|
|
| |
|
|
| |
|
|
| |
|
|
| |
|
|
| |
|
|
| |
|
|
| |
|
|
| |
|
|
| |
|
|
| |
|
|
+-------------------+
|
|
|
|
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
agents,
|
|
iters_per_task
|
|
):
|
|
"""A workflow is a collection of tasks that can be executed in parallel or sequentially."""
|
|
super().__init__()
|
|
self.executor = ThreadPoolExecutor()
|
|
self.agents = agents
|
|
self.tasks = []
|
|
|
|
def add(self, task: Task):
|
|
"""Add a task to the workflow"""
|
|
assert isinstance(
|
|
task,
|
|
Task
|
|
), "Input must be an nstance of Task"
|
|
self.tasks.append(task)
|
|
return task
|
|
|
|
def run(self):
|
|
"""Run the workflow"""
|
|
ordered_tasks = self.ordered_tasks
|
|
exit_loop = False
|
|
|
|
while not self.is_finished() and not exit_loop:
|
|
futures_list = {}
|
|
|
|
for task in ordered_tasks:
|
|
if task.can_execute:
|
|
future = self.executor.submit(self.agents.run, task.task_string)
|
|
futures_list[future] = task
|
|
|
|
for future in as_completed(futures_list):
|
|
if isinstance(future.result(), Exception):
|
|
exit_loop = True
|
|
break
|
|
return self.output_tasks()
|
|
|
|
def output_tasks(self) -> List[Task]:
|
|
"""Output tasks from the workflow"""
|
|
return [task for task in self.tasks if not task.children]
|
|
|
|
def to_graph(self) -> Dict[str, set[str]]:
|
|
"""Convert the workflow to a graph"""
|
|
graph = {
|
|
task.id: set(child.id for child in task.children) for task in self.tasks
|
|
}
|
|
return graph
|
|
|
|
def order_tasks(self) -> List[Task]:
|
|
"""Order the tasks USING TOPOLOGICAL SORTING"""
|
|
task_order = TopologicalSorter(
|
|
self.to_graph()
|
|
).static_order()
|
|
return [
|
|
self.find_task(task_id) for task_id in task_order
|
|
]
|