From 22829d0e4223180e1274d770f0a6d0962b30576d Mon Sep 17 00:00:00 2001 From: Kye Date: Fri, 15 Sep 2023 14:03:51 -0400 Subject: [PATCH] `Workflow` Former-commit-id: bef34141d2232bb74063c84e6de92ffafa15efc5 --- swarms/artifacts/base.py | 76 +++++++ swarms/artifacts/error_artifact.py | 20 ++ swarms/structs/task.py | 133 ++++++++++++ swarms/structs/workflow.py | 324 +++++++++++++++++++---------- 4 files changed, 445 insertions(+), 108 deletions(-) create mode 100644 swarms/artifacts/base.py create mode 100644 swarms/artifacts/error_artifact.py diff --git a/swarms/artifacts/base.py b/swarms/artifacts/base.py new file mode 100644 index 00000000..b1d5a1f5 --- /dev/null +++ b/swarms/artifacts/base.py @@ -0,0 +1,76 @@ +from __future__ import annotations +import json +import uuid +from abc import ABC, abstractmethod +from attr import define, field, Factory +from marshmallow import class_registry +from marshmallow.exceptions import RegistryError + + +@define +class BaseArtifact(ABC): + id: str = field(default=Factory(lambda: uuid.uuid4().hex), kw_only=True) + name: str = field(default=Factory(lambda self: self.id, takes_self=True), kw_only=True) + value: any = field() + type: str = field(default=Factory(lambda self: self.__class__.__name__, takes_self=True), kw_only=True) + + @classmethod + def value_to_bytes(cls, value: any) -> bytes: + if isinstance(value, bytes): + return value + else: + return str(value).encode() + + @classmethod + def value_to_dict(cls, value: any) -> dict: + if isinstance(value, dict): + dict_value = value + else: + dict_value = json.loads(value) + + return {k: v for k, v in dict_value.items()} + + @classmethod + def from_dict(cls, artifact_dict: dict) -> BaseArtifact: + from griptape.schemas import ( + TextArtifactSchema, + InfoArtifactSchema, + ErrorArtifactSchema, + BlobArtifactSchema, + CsvRowArtifactSchema, + ListArtifactSchema + ) + + class_registry.register("TextArtifact", TextArtifactSchema) + class_registry.register("InfoArtifact", InfoArtifactSchema) + class_registry.register("ErrorArtifact", ErrorArtifactSchema) + class_registry.register("BlobArtifact", BlobArtifactSchema) + class_registry.register("CsvRowArtifact", CsvRowArtifactSchema) + class_registry.register("ListArtifact", ListArtifactSchema) + + try: + return class_registry.get_class(artifact_dict["type"])().load(artifact_dict) + except RegistryError: + raise ValueError("Unsupported artifact type") + + @classmethod + def from_json(cls, artifact_str: str) -> BaseArtifact: + return cls.from_dict(json.loads(artifact_str)) + + def __str__(self): + return json.dumps(self.to_dict()) + + def to_json(self) -> str: + return json.dumps(self.to_dict()) + + @abstractmethod + def to_text(self) -> str: + ... + + @abstractmethod + def to_dict(self) -> dict: + ... + + @abstractmethod + def __add__(self, other: BaseArtifact) -> BaseArtifact: + ... \ No newline at end of file diff --git a/swarms/artifacts/error_artifact.py b/swarms/artifacts/error_artifact.py new file mode 100644 index 00000000..68851540 --- /dev/null +++ b/swarms/artifacts/error_artifact.py @@ -0,0 +1,20 @@ +from __future__ import annotations +from attr import define, field +from swarms.artifacts.base import BaseArtifact + + +@define(frozen=True) +class ErrorArtifact(BaseArtifact): + value: str = field(converter=str) + + def __add__(self, other: ErrorArtifact) -> ErrorArtifact: + return ErrorArtifact(self.value + other.value) + + def to_text(self) -> str: + return self.value + + def to_dict(self) -> dict: + from griptape.schemas import ErrorArtifactSchema + + return dict(ErrorArtifactSchema().dump(self)) + \ No newline at end of file diff --git a/swarms/structs/task.py b/swarms/structs/task.py index 8618234a..ee42e122 100644 --- a/swarms/structs/task.py +++ b/swarms/structs/task.py @@ -2,11 +2,144 @@ from __future__ import annotations import json import pprint +import uuid +from abc import ABC, abstractmethod +from enum import Enum from typing import Any, Optional from artifacts.main import Artifact from pydantic import BaseModel, Field, StrictStr, conlist +from swarms.artifacts.error_artifact import ErrorArtifact + + +class BaseTask(ABC): + class State(Enum): + PENDING = 1 + EXECUTING = 2 + FINISHED = 3 + + def __init__(self): + self.id = uuid.uuid4().hex + self.state = self.State.PENDING + self.parent_ids = [] + self.child_ids = [] + self.output = None + self.structure = None + + @property + @abstractmethod + def input(self): + pass + + @property + def parents(self): + return [self.structure.find_task(parent_id) for parent_id in self.parent_ids] + + @property + def children(self): + return [self.structure.find_task(child_id) for child_id in self.child_ids] + + def __rshift__(self, child): + return self.add_child(child) + + def __lshift__(self, child): + return self.add_parent(child) + + def preprocess(self, structure): + self.structure = structure + return self + + def add_child(self, child): + if self.structure: + child.structure = self.structure + elif child.structure: + self.structure = child.structure + + if child not in self.structure.tasks: + self.structure.tasks.append(child) + + if self not in self.structure.tasks: + self.structure.tasks.append(self) + + if child.id not in self.child_ids: + self.child_ids.append(child.id) + + if self.id not in child.parent_ids: + child.parent_ids.append(self.id) + + return child + + def add_parent(self, parent): + if self.structure: + parent.structure = self.structure + elif parent.structure: + self.structure = parent.structure + + if parent not in self.structure.tasks: + self.structure.tasks.append(parent) + + if self not in self.structure.tasks: + self.structure.tasks.append(self) + + if parent.id not in self.parent_ids: + self.parent_ids.append(parent.id) + + if self.id not in parent.child_ids: + parent.child_ids.append(self.id) + + return parent + + def is_pending(self): + return self.state == self.State.PENDING + + def is_finished(self): + return self.state == self.State.FINISHED + + def is_executing(self): + return self.state == self.State.EXECUTING + + def before_run(self): + pass + + def after_run(self): + pass + + def execute(self): + try: + self.state = self.State.EXECUTING + self.before_run() + self.output = self.run() + self.after_run() + except Exception as e: + self.output = ErrorArtifact(str(e)) + finally: + self.state = self.State.FINISHED + return self.output + + def can_execute(self): + return self.state == self.State.PENDING and all(parent.is_finished() for parent in self.parents) + + def reset(self): + self.state = self.State.PENDING + self.output = None + return self + + @abstractmethod + def run(self): + pass + + + + + + + + + + + + class Task(BaseModel): input: Optional[StrictStr] = Field( diff --git a/swarms/structs/workflow.py b/swarms/structs/workflow.py index 4bcc7ff1..14f30f72 100644 --- a/swarms/structs/workflow.py +++ b/swarms/structs/workflow.py @@ -1,140 +1,248 @@ -from __future__ import annotations +# from __future__ import annotations + +# import concurrent.futures as futures +# import logging +# import uuid +# from abc import ABC, abstractmethod +# from graphlib import TopologicalSorter +# from logging import Logger +# from typing import Optional, Union + +# from rich.logging import RichHandler + +# # from swarms.artifacts.error_artifact import ErrorArtifact +# from swarms.artifacts.main import Artifact as ErrorArtifact +# from swarms.structs.task import BaseTask + + +# #@shapeless +# class Workflow(ABC): +# def __init__( +# self, +# id: str = uuid.uuid4().hex, +# model = None, +# custom_logger: Optional[Logger] = None, +# logger_level: int = logging.INFO, +# futures_executor: futures.Executor = futures.ThreadPoolExecutor() +# ): +# self.id = id +# self.model = model +# self.custom_logger = custom_logger +# self.logger_level = logger_level + +# self.futures_executor = futures_executor +# self._execution_args = () +# self._logger = None + +# [task.preprocess(self) for task in self.tasks] + +# self.model.structure = self + +# @property +# def execution_args(self) -> tuple: +# return self._execution_args -import concurrent.futures as futures -import logging -import uuid -from abc import ABC, abstractmethod -from graphlib import TopologicalSorter -from logging import Logger -from typing import Optional, Union +# @property +# def logger(self) -> Logger: +# if self.custom_logger: +# return self.custom_logger +# else: +# if self._logger is None: +# self._logger = logging.getLogger(self.LOGGER_NAME) -from rich.logging import RichHandler +# self._logger.propagate = False +# self._logger.level = self.logger_level -from swarms.artifacts.error_artifact import ErrorArtifact -from swarms.structs.task import BaseTask +# self._logger.handlers = [ +# RichHandler( +# show_time=True, +# show_path=False +# ) +# ] +# return self._logger + +# def is_finished(self) -> bool: +# return all(s.is_finished() for s in self.tasks) + +# def is_executing(self) -> bool: +# return any(s for s in self.tasks if s.is_executing()) + +# def find_task(self, task_id: str) -> Optional[BaseTask]: +# return next((task for task in self.tasks if task.id == task_id), None) + +# def add_tasks(self, *tasks: BaseTask) -> list[BaseTask]: +# return [self.add_task(s) for s in tasks] + +# def context(self, task: BaseTask) -> dict[str, any]: +# return { +# "args": self.execution_args, +# "structure": self, +# } + +# @abstractmethod +# def add(self, task: BaseTask) -> BaseTask: +# task.preprocess(self) +# self.tasks.append(task) +# return task + +# @abstractmethod +# def run(self, *args) -> Union[BaseTask, list[BaseTask]]: +# self._execution_args = args +# ordered_tasks = self.order_tasks() +# exit_loop = False + +# while not self.is_finished() and not exit_loop: +# futures_list = {} + +# for task in ordered_tasks: +# if task.can_execute(): +# future = self.futures_executor.submit(task.execute) +# futures_list[future] = task + +# # Wait for all tasks to complete +# for future in futures.as_completed(futures_list): +# if isinstance(future.result(), ErrorArtifact): +# exit_loop = True +# break + +# self._execution_args = () + +# return self.output_tasks() + +# def context(self, task: BaseTask) -> dict[str, any]: +# context = super().context(task) + +# context.update( +# { +# "parent_outputs": {parent.id: parent.output.to_text() if parent.output else "" for parent in task.parents}, +# "parents": {parent.id: parent for parent in task.parents}, +# "children": {child.id: child for child in task.children} +# } +# ) + +# return context + +# def output_tasks(self) -> list[BaseTask]: +# return [task for task in self.tasks if not task.children] + +# def to_graph(self) -> dict[str, set[str]]: +# graph: dict[str, set[str]] = {} + +# for key_task in self.tasks: +# graph[key_task.id] = set() + +# for value_task in self.tasks: +# if key_task.id in value_task.child_ids: +# graph[key_task.id].add(value_task.id) + +# return graph +# def order_tasks(self) -> list[BaseTask]: +# return [self.find_task(task_id) for task_id in TopologicalSorter(self.to_graph()).static_order()] -#@shapeless -class Workflow(ABC): + +from __future__ import annotations + +from typing import Any, Dict, List, Optional + +from swarms.artifacts.error_artifacts import ErrorArtifact +from swarms.structs.task import BaseTask + +class StringTask(BaseTask): def __init__( - self, - id: str = uuid.uuid4().hex, - model = None, - custom_logger: Optional[Logger] = None, - logger_level: int = logging.INFO, - futures_executor: futures.Executor = futures.ThreadPoolExecutor() + self, + task ): - self.id = id - self.model = model - self.custom_logger = custom_logger - self.logger_level = logger_level + super().__init__() + self.task = task + + def execute(self) -> Any: + prompt = self.task_string.replace("{{ parent_input }}", self.parents[0].output if self.parents else "") + response = self.structure.llm.run(prompt) + self.output = response + return response + + + +class Workflow: + """ + Workflows are ideal for prescriptive processes that need to be executed sequentially. + They string together multiple tasks of varying types, and can use Short-Term Memory + or pass specific arguments downstream. + - self.futures_executor = futures_executor - self._execution_args = () - self._logger = None - [task.preprocess(self) for task in self.tasks] + ``` + llm = LLM() + workflow = Workflow(llm) - self.model.structure = self + workflow.add("What's the weather in miami") + workflow.add("Provide detauls for {{ parent_output }}") + workflow.add("Summarize the above information: {{ parent_output}}) - @property - def execution_args(self) -> tuple: - return self._execution_args + workflow.run() - @property - def logger(self) -> Logger: - if self.custom_logger: - return self.custom_logger + + """ + def __init__( + self, + llm + ): + self.llm = llm + self.tasks: List[BaseTask] = [] + + def add( + self, + task: BaseTask + ) -> BaseTask: + task = StringTask(task) + + if self.last_task(): + self.last_task().add_child(task) else: - if self._logger is None: - self._logger = logging.getLogger(self.LOGGER_NAME) - - self._logger.propagate = False - self._logger.level = self.logger_level - - self._logger.handlers = [ - RichHandler( - show_time=True, - show_path=False - ) - ] - return self._logger - - def is_finished(self) -> bool: - return all(s.is_finished() for s in self.tasks) - - def is_executing(self) -> bool: - return any(s for s in self.tasks if s.is_executing()) - - def find_task(self, task_id: str) -> Optional[BaseTask]: - return next((task for task in self.tasks if task.id == task_id), None) - - def add_tasks(self, *tasks: BaseTask) -> list[BaseTask]: - return [self.add_task(s) for s in tasks] - - def context(self, task: BaseTask) -> dict[str, any]: - return { - "args": self.execution_args, - "structure": self, - } - - @abstractmethod - def add(self, task: BaseTask) -> BaseTask: - task.preprocess(self) - self.tasks.append(task) + task.structure = self + self.tasks.append(task) return task - @abstractmethod - def run(self, *args) -> Union[BaseTask, list[BaseTask]]: - self._execution_args = args - ordered_tasks = self.order_tasks() - exit_loop = False + def first_task(self) -> Optional[BaseTask]: + return self.tasks[0] if self.tasks else None + + def last_task(self) -> Optional[BaseTask]: + return self.tasks[-1] if self.tasks else None - while not self.is_finished() and not exit_loop: - futures_list = {} + def run(self, *args) -> BaseTask: + self._execution_args = args - for task in ordered_tasks: - if task.can_execute(): - future = self.futures_executor.submit(task.execute) - futures_list[future] = task + [task.reset() for task in self.tasks] - # Wait for all tasks to complete - for future in futures.as_completed(futures_list): - if isinstance(future.result(), ErrorArtifact): - exit_loop = True - break + self.__run_from_task(self.first_task()) self._execution_args = () - return self.output_tasks() + return self.last_task() - def context(self, task: BaseTask) -> dict[str, any]: + + def context(self, task: BaseTask) -> Dict[str, Any]: context = super().context(task) context.update( { - "parent_outputs": {parent.id: parent.output.to_text() if parent.output else "" for parent in task.parents}, - "parents": {parent.id: parent for parent in task.parents}, - "children": {child.id: child for child in task.children} + "parent_output": task.parents[0].output.to_text() \ + if task.parents and task.parents[0].output else None, + "parent": task.parents[0] if task.parents else None, + "child": task.children[0] if task.children else None } ) - return context - def output_tasks(self) -> list[BaseTask]: - return [task for task in self.tasks if not task.children] - - def to_graph(self) -> dict[str, set[str]]: - graph: dict[str, set[str]] = {} - - for key_task in self.tasks: - graph[key_task.id] = set() - - for value_task in self.tasks: - if key_task.id in value_task.child_ids: - graph[key_task.id].add(value_task.id) + + def __run_from_task(self, task: Optional[BaseTask]) -> None: + if task is None: + return + else: + if isinstance(task.execute(), ErrorArtifact): + return + else: + self.__run_from_task(next(iter(task.children), None)) - return graph - def order_tasks(self) -> list[BaseTask]: - return [self.find_task(task_id) for task_id in TopologicalSorter(self.to_graph()).static_order()] \ No newline at end of file