pull/55/head
Kye 1 year ago
parent 55184d629f
commit bef34141d2

@ -0,0 +1,76 @@
from __future__ import annotations
import json
import uuid
from abc import ABC, abstractmethod
from attr import define, field, Factory
from marshmallow import class_registry
from marshmallow.exceptions import RegistryError
@define
class BaseArtifact(ABC):
id: str = field(default=Factory(lambda: uuid.uuid4().hex), kw_only=True)
name: str = field(default=Factory(lambda self: self.id, takes_self=True), kw_only=True)
value: any = field()
type: str = field(default=Factory(lambda self: self.__class__.__name__, takes_self=True), kw_only=True)
@classmethod
def value_to_bytes(cls, value: any) -> bytes:
if isinstance(value, bytes):
return value
else:
return str(value).encode()
@classmethod
def value_to_dict(cls, value: any) -> dict:
if isinstance(value, dict):
dict_value = value
else:
dict_value = json.loads(value)
return {k: v for k, v in dict_value.items()}
@classmethod
def from_dict(cls, artifact_dict: dict) -> BaseArtifact:
from griptape.schemas import (
TextArtifactSchema,
InfoArtifactSchema,
ErrorArtifactSchema,
BlobArtifactSchema,
CsvRowArtifactSchema,
ListArtifactSchema
)
class_registry.register("TextArtifact", TextArtifactSchema)
class_registry.register("InfoArtifact", InfoArtifactSchema)
class_registry.register("ErrorArtifact", ErrorArtifactSchema)
class_registry.register("BlobArtifact", BlobArtifactSchema)
class_registry.register("CsvRowArtifact", CsvRowArtifactSchema)
class_registry.register("ListArtifact", ListArtifactSchema)
try:
return class_registry.get_class(artifact_dict["type"])().load(artifact_dict)
except RegistryError:
raise ValueError("Unsupported artifact type")
@classmethod
def from_json(cls, artifact_str: str) -> BaseArtifact:
return cls.from_dict(json.loads(artifact_str))
def __str__(self):
return json.dumps(self.to_dict())
def to_json(self) -> str:
return json.dumps(self.to_dict())
@abstractmethod
def to_text(self) -> str:
...
@abstractmethod
def to_dict(self) -> dict:
...
@abstractmethod
def __add__(self, other: BaseArtifact) -> BaseArtifact:
...

@ -0,0 +1,20 @@
from __future__ import annotations
from attr import define, field
from swarms.artifacts.base import BaseArtifact
@define(frozen=True)
class ErrorArtifact(BaseArtifact):
value: str = field(converter=str)
def __add__(self, other: ErrorArtifact) -> ErrorArtifact:
return ErrorArtifact(self.value + other.value)
def to_text(self) -> str:
return self.value
def to_dict(self) -> dict:
from griptape.schemas import ErrorArtifactSchema
return dict(ErrorArtifactSchema().dump(self))

@ -2,11 +2,144 @@ from __future__ import annotations
import json
import pprint
import uuid
from abc import ABC, abstractmethod
from enum import Enum
from typing import Any, Optional
from artifacts.main import Artifact
from pydantic import BaseModel, Field, StrictStr, conlist
from swarms.artifacts.error_artifact import ErrorArtifact
class BaseTask(ABC):
class State(Enum):
PENDING = 1
EXECUTING = 2
FINISHED = 3
def __init__(self):
self.id = uuid.uuid4().hex
self.state = self.State.PENDING
self.parent_ids = []
self.child_ids = []
self.output = None
self.structure = None
@property
@abstractmethod
def input(self):
pass
@property
def parents(self):
return [self.structure.find_task(parent_id) for parent_id in self.parent_ids]
@property
def children(self):
return [self.structure.find_task(child_id) for child_id in self.child_ids]
def __rshift__(self, child):
return self.add_child(child)
def __lshift__(self, child):
return self.add_parent(child)
def preprocess(self, structure):
self.structure = structure
return self
def add_child(self, child):
if self.structure:
child.structure = self.structure
elif child.structure:
self.structure = child.structure
if child not in self.structure.tasks:
self.structure.tasks.append(child)
if self not in self.structure.tasks:
self.structure.tasks.append(self)
if child.id not in self.child_ids:
self.child_ids.append(child.id)
if self.id not in child.parent_ids:
child.parent_ids.append(self.id)
return child
def add_parent(self, parent):
if self.structure:
parent.structure = self.structure
elif parent.structure:
self.structure = parent.structure
if parent not in self.structure.tasks:
self.structure.tasks.append(parent)
if self not in self.structure.tasks:
self.structure.tasks.append(self)
if parent.id not in self.parent_ids:
self.parent_ids.append(parent.id)
if self.id not in parent.child_ids:
parent.child_ids.append(self.id)
return parent
def is_pending(self):
return self.state == self.State.PENDING
def is_finished(self):
return self.state == self.State.FINISHED
def is_executing(self):
return self.state == self.State.EXECUTING
def before_run(self):
pass
def after_run(self):
pass
def execute(self):
try:
self.state = self.State.EXECUTING
self.before_run()
self.output = self.run()
self.after_run()
except Exception as e:
self.output = ErrorArtifact(str(e))
finally:
self.state = self.State.FINISHED
return self.output
def can_execute(self):
return self.state == self.State.PENDING and all(parent.is_finished() for parent in self.parents)
def reset(self):
self.state = self.State.PENDING
self.output = None
return self
@abstractmethod
def run(self):
pass
class Task(BaseModel):
input: Optional[StrictStr] = Field(

@ -1,140 +1,248 @@
from __future__ import annotations
# from __future__ import annotations
# import concurrent.futures as futures
# import logging
# import uuid
# from abc import ABC, abstractmethod
# from graphlib import TopologicalSorter
# from logging import Logger
# from typing import Optional, Union
# from rich.logging import RichHandler
# # from swarms.artifacts.error_artifact import ErrorArtifact
# from swarms.artifacts.main import Artifact as ErrorArtifact
# from swarms.structs.task import BaseTask
# #@shapeless
# class Workflow(ABC):
# def __init__(
# self,
# id: str = uuid.uuid4().hex,
# model = None,
# custom_logger: Optional[Logger] = None,
# logger_level: int = logging.INFO,
# futures_executor: futures.Executor = futures.ThreadPoolExecutor()
# ):
# self.id = id
# self.model = model
# self.custom_logger = custom_logger
# self.logger_level = logger_level
# self.futures_executor = futures_executor
# self._execution_args = ()
# self._logger = None
# [task.preprocess(self) for task in self.tasks]
# self.model.structure = self
# @property
# def execution_args(self) -> tuple:
# return self._execution_args
import concurrent.futures as futures
import logging
import uuid
from abc import ABC, abstractmethod
from graphlib import TopologicalSorter
from logging import Logger
from typing import Optional, Union
# @property
# def logger(self) -> Logger:
# if self.custom_logger:
# return self.custom_logger
# else:
# if self._logger is None:
# self._logger = logging.getLogger(self.LOGGER_NAME)
from rich.logging import RichHandler
# self._logger.propagate = False
# self._logger.level = self.logger_level
from swarms.artifacts.error_artifact import ErrorArtifact
from swarms.structs.task import BaseTask
# self._logger.handlers = [
# RichHandler(
# show_time=True,
# show_path=False
# )
# ]
# return self._logger
# def is_finished(self) -> bool:
# return all(s.is_finished() for s in self.tasks)
# def is_executing(self) -> bool:
# return any(s for s in self.tasks if s.is_executing())
# def find_task(self, task_id: str) -> Optional[BaseTask]:
# return next((task for task in self.tasks if task.id == task_id), None)
# def add_tasks(self, *tasks: BaseTask) -> list[BaseTask]:
# return [self.add_task(s) for s in tasks]
# def context(self, task: BaseTask) -> dict[str, any]:
# return {
# "args": self.execution_args,
# "structure": self,
# }
# @abstractmethod
# def add(self, task: BaseTask) -> BaseTask:
# task.preprocess(self)
# self.tasks.append(task)
# return task
# @abstractmethod
# def run(self, *args) -> Union[BaseTask, list[BaseTask]]:
# self._execution_args = args
# ordered_tasks = self.order_tasks()
# exit_loop = False
# while not self.is_finished() and not exit_loop:
# futures_list = {}
# for task in ordered_tasks:
# if task.can_execute():
# future = self.futures_executor.submit(task.execute)
# futures_list[future] = task
# # Wait for all tasks to complete
# for future in futures.as_completed(futures_list):
# if isinstance(future.result(), ErrorArtifact):
# exit_loop = True
# break
# self._execution_args = ()
# return self.output_tasks()
# def context(self, task: BaseTask) -> dict[str, any]:
# context = super().context(task)
# context.update(
# {
# "parent_outputs": {parent.id: parent.output.to_text() if parent.output else "" for parent in task.parents},
# "parents": {parent.id: parent for parent in task.parents},
# "children": {child.id: child for child in task.children}
# }
# )
# return context
# def output_tasks(self) -> list[BaseTask]:
# return [task for task in self.tasks if not task.children]
# def to_graph(self) -> dict[str, set[str]]:
# graph: dict[str, set[str]] = {}
# for key_task in self.tasks:
# graph[key_task.id] = set()
# for value_task in self.tasks:
# if key_task.id in value_task.child_ids:
# graph[key_task.id].add(value_task.id)
# return graph
# def order_tasks(self) -> list[BaseTask]:
# return [self.find_task(task_id) for task_id in TopologicalSorter(self.to_graph()).static_order()]
#@shapeless
class Workflow(ABC):
from __future__ import annotations
from typing import Any, Dict, List, Optional
from swarms.artifacts.error_artifacts import ErrorArtifact
from swarms.structs.task import BaseTask
class StringTask(BaseTask):
def __init__(
self,
id: str = uuid.uuid4().hex,
model = None,
custom_logger: Optional[Logger] = None,
logger_level: int = logging.INFO,
futures_executor: futures.Executor = futures.ThreadPoolExecutor()
task
):
self.id = id
self.model = model
self.custom_logger = custom_logger
self.logger_level = logger_level
super().__init__()
self.task = task
self.futures_executor = futures_executor
self._execution_args = ()
self._logger = None
def execute(self) -> Any:
prompt = self.task_string.replace("{{ parent_input }}", self.parents[0].output if self.parents else "")
response = self.structure.llm.run(prompt)
self.output = response
return response
[task.preprocess(self) for task in self.tasks]
self.model.structure = self
@property
def execution_args(self) -> tuple:
return self._execution_args
class Workflow:
"""
Workflows are ideal for prescriptive processes that need to be executed sequentially.
They string together multiple tasks of varying types, and can use Short-Term Memory
or pass specific arguments downstream.
@property
def logger(self) -> Logger:
if self.custom_logger:
return self.custom_logger
else:
if self._logger is None:
self._logger = logging.getLogger(self.LOGGER_NAME)
self._logger.propagate = False
self._logger.level = self.logger_level
self._logger.handlers = [
RichHandler(
show_time=True,
show_path=False
)
]
return self._logger
```
llm = LLM()
workflow = Workflow(llm)
def is_finished(self) -> bool:
return all(s.is_finished() for s in self.tasks)
workflow.add("What's the weather in miami")
workflow.add("Provide detauls for {{ parent_output }}")
workflow.add("Summarize the above information: {{ parent_output}})
def is_executing(self) -> bool:
return any(s for s in self.tasks if s.is_executing())
workflow.run()
def find_task(self, task_id: str) -> Optional[BaseTask]:
return next((task for task in self.tasks if task.id == task_id), None)
def add_tasks(self, *tasks: BaseTask) -> list[BaseTask]:
return [self.add_task(s) for s in tasks]
"""
def __init__(
self,
llm
):
self.llm = llm
self.tasks: List[BaseTask] = []
def context(self, task: BaseTask) -> dict[str, any]:
return {
"args": self.execution_args,
"structure": self,
}
def add(
self,
task: BaseTask
) -> BaseTask:
task = StringTask(task)
@abstractmethod
def add(self, task: BaseTask) -> BaseTask:
task.preprocess(self)
if self.last_task():
self.last_task().add_child(task)
else:
task.structure = self
self.tasks.append(task)
return task
@abstractmethod
def run(self, *args) -> Union[BaseTask, list[BaseTask]]:
self._execution_args = args
ordered_tasks = self.order_tasks()
exit_loop = False
def first_task(self) -> Optional[BaseTask]:
return self.tasks[0] if self.tasks else None
while not self.is_finished() and not exit_loop:
futures_list = {}
def last_task(self) -> Optional[BaseTask]:
return self.tasks[-1] if self.tasks else None
for task in ordered_tasks:
if task.can_execute():
future = self.futures_executor.submit(task.execute)
futures_list[future] = task
def run(self, *args) -> BaseTask:
self._execution_args = args
[task.reset() for task in self.tasks]
# Wait for all tasks to complete
for future in futures.as_completed(futures_list):
if isinstance(future.result(), ErrorArtifact):
exit_loop = True
break
self.__run_from_task(self.first_task())
self._execution_args = ()
return self.output_tasks()
return self.last_task()
def context(self, task: BaseTask) -> dict[str, any]:
def context(self, task: BaseTask) -> Dict[str, Any]:
context = super().context(task)
context.update(
{
"parent_outputs": {parent.id: parent.output.to_text() if parent.output else "" for parent in task.parents},
"parents": {parent.id: parent for parent in task.parents},
"children": {child.id: child for child in task.children}
"parent_output": task.parents[0].output.to_text() \
if task.parents and task.parents[0].output else None,
"parent": task.parents[0] if task.parents else None,
"child": task.children[0] if task.children else None
}
)
return context
def output_tasks(self) -> list[BaseTask]:
return [task for task in self.tasks if not task.children]
def to_graph(self) -> dict[str, set[str]]:
graph: dict[str, set[str]] = {}
for key_task in self.tasks:
graph[key_task.id] = set()
for value_task in self.tasks:
if key_task.id in value_task.child_ids:
graph[key_task.id].add(value_task.id)
def __run_from_task(self, task: Optional[BaseTask]) -> None:
if task is None:
return
else:
if isinstance(task.execute(), ErrorArtifact):
return
else:
self.__run_from_task(next(iter(task.children), None))
return graph
def order_tasks(self) -> list[BaseTask]:
return [self.find_task(task_id) for task_id in TopologicalSorter(self.to_graph()).static_order()]
Loading…
Cancel
Save