Former-commit-id: a7b019039f
group-chat
Kye 1 year ago
parent e176cbc946
commit 4393f70504

@ -27,7 +27,6 @@ transformers = "*"
openai = "*" openai = "*"
langchain = "0.0.240" langchain = "0.0.240"
asyncio = "*" asyncio = "*"
simpleaichat = "*"
nest_asyncio = "*" nest_asyncio = "*"
pegasusx = "*" pegasusx = "*"
einops = "*" einops = "*"
@ -46,6 +45,7 @@ tenacity = "*"
celery = "*" celery = "*"
redis = "*" redis = "*"
Pillow = "*" Pillow = "*"
shapeless="*"
[tool.poetry.dev-dependencies] [tool.poetry.dev-dependencies]
# Add development dependencies here # Add development dependencies here

@ -17,7 +17,7 @@ faiss-cpu
openai openai
google-generativeai google-generativeai
duckduckgo-search duckduckgo-search
shapeless
mkdocs mkdocs
mkdocs-material mkdocs-material

@ -0,0 +1,14 @@
from shapeless import shapeless
@shapeless
class Task:
def __init__(
self,
id,
):
self.id = id
def forward(self):
pass

@ -0,0 +1,133 @@
from __future__ import annotations
import logging
import uuid
from abc import ABC, abstractmethod
from logging import Logger
from typing import Optional, Union, TYPE_CHECKING, Callable, Type
from rich.logging import RichHandler
import concurrent.futures as futures
from graphlib import TopologicalSorter
class Workflow(ABC):
def __init__(
self,
id: str = uuid.uuid4().hex,
model = None,
custom_logger: Optional[Logger] = None,
logger_level: int = logging.INFO,
futures_executor: futures.Executor = futures.ThreadPoolExecutor()
):
self.id = id
self.model = model
self.custom_logger = custom_logger
self.logger_level = logger_level
self.futures_executor = futures_executor
self._execution_args = ()
self._logger = None
[task.preprocess(self) for task in self.tasks]
self.model.structure = self
@property
def execution_args(self) -> tuple:
return self._execution_args
@property
def logger(self) -> Logger:
if self.custom_logger:
return self.custom_logger
else:
if self._logger is None:
self._logger = logging.getLogger(self.LOGGER_NAME)
self._logger.propagate = False
self._logger.level = self.logger_level
self._logger.handlers = [
RichHandler(
show_time=True,
show_path=False
)
]
return self._logger
def is_finished(self) -> bool:
return all(s.is_finished() for s in self.tasks)
def is_executing(self) -> bool:
return any(s for s in self.tasks if s.is_executing())
def find_task(self, task_id: str) -> Optional[BaseTask]:
return next((task for task in self.tasks if task.id == task_id), None)
def add_tasks(self, *tasks: BaseTask) -> list[BaseTask]:
return [self.add_task(s) for s in tasks]
def context(self, task: BaseTask) -> dict[str, any]:
return {
"args": self.execution_args,
"structure": self,
}
@abstractmethod
def add_task(self, task: BaseTask) -> BaseTask:
task.preprocess(self)
self.tasks.append(task)
return task
@abstractmethod
def run(self, *args) -> Union[BaseTask, list[BaseTask]]:
self._execution_args = args
ordered_tasks = self.order_tasks()
exit_loop = False
while not self.is_finished() and not exit_loop:
futures_list = {}
for task in ordered_tasks:
if task.can_execute():
future = self.futures_executor.submit(task.execute)
futures_list[future] = task
# Wait for all tasks to complete
for future in futures.as_completed(futures_list):
if isinstance(future.result(), ErrorArtifact):
exit_loop = True
break
self._execution_args = ()
return self.output_tasks()
def context(self, task: BaseTask) -> dict[str, any]:
context = super().context(task)
context.update(
{
"parent_outputs": {parent.id: parent.output.to_text() if parent.output else "" for parent in task.parents},
"parents": {parent.id: parent for parent in task.parents},
"children": {child.id: child for child in task.children}
}
)
return context
def output_tasks(self) -> list[BaseTask]:
return [task for task in self.tasks if not task.children]
def to_graph(self) -> dict[str, set[str]]:
graph: dict[str, set[str]] = {}
for key_task in self.tasks:
graph[key_task.id] = set()
for value_task in self.tasks:
if key_task.id in value_task.child_ids:
graph[key_task.id].add(value_task.id)
return graph
def order_tasks(self) -> list[BaseTask]:
return [self.find_task(task_id) for task_id in TopologicalSorter(self.to_graph()).static_order()]
Loading…
Cancel
Save