parent 55184d629f
commit bef34141d2
@@ -0,0 +1,76 @@
from __future__ import annotations
import json
import uuid
from abc import ABC, abstractmethod
from attr import define, field, Factory
from marshmallow import class_registry
from marshmallow.exceptions import RegistryError


@define
class BaseArtifact(ABC):
    id: str = field(default=Factory(lambda: uuid.uuid4().hex), kw_only=True)
    name: str = field(default=Factory(lambda self: self.id, takes_self=True), kw_only=True)
    value: any = field()
    type: str = field(default=Factory(lambda self: self.__class__.__name__, takes_self=True), kw_only=True)

    @classmethod
    def value_to_bytes(cls, value: any) -> bytes:
        if isinstance(value, bytes):
            return value
        else:
            return str(value).encode()

    @classmethod
    def value_to_dict(cls, value: any) -> dict:
        if isinstance(value, dict):
            dict_value = value
        else:
            dict_value = json.loads(value)

        return {k: v for k, v in dict_value.items()}

    @classmethod
    def from_dict(cls, artifact_dict: dict) -> BaseArtifact:
        from griptape.schemas import (
            TextArtifactSchema,
            InfoArtifactSchema,
            ErrorArtifactSchema,
            BlobArtifactSchema,
            CsvRowArtifactSchema,
            ListArtifactSchema
        )

        class_registry.register("TextArtifact", TextArtifactSchema)
        class_registry.register("InfoArtifact", InfoArtifactSchema)
        class_registry.register("ErrorArtifact", ErrorArtifactSchema)
        class_registry.register("BlobArtifact", BlobArtifactSchema)
        class_registry.register("CsvRowArtifact", CsvRowArtifactSchema)
        class_registry.register("ListArtifact", ListArtifactSchema)

        try:
            return class_registry.get_class(artifact_dict["type"])().load(artifact_dict)
        except RegistryError:
            raise ValueError("Unsupported artifact type")

    @classmethod
    def from_json(cls, artifact_str: str) -> BaseArtifact:
        return cls.from_dict(json.loads(artifact_str))

    def __str__(self):
        return json.dumps(self.to_dict())

    def to_json(self) -> str:
        return json.dumps(self.to_dict())

    @abstractmethod
    def to_text(self) -> str:
        ...

    @abstractmethod
    def to_dict(self) -> dict:
        ...

    @abstractmethod
    def __add__(self, other: BaseArtifact) -> BaseArtifact:
        ...
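`BaseArtifact` fixes the contract every artifact follows: `value_to_bytes`/`value_to_dict` normalize raw task output, `from_dict`/`from_json` rebuild a concrete artifact through the registered marshmallow schemas, and subclasses only have to supply `to_text`, `to_dict`, and `__add__`. A minimal sketch of that contract, assuming the `swarms` package from this diff is importable; the `TextArtifact` subclass here is purely illustrative (the diff only registers griptape's `TextArtifactSchema`), not the project's real class:

```python
from swarms.artifacts.base import BaseArtifact  # import path used by the error_artifact module below


class TextArtifact(BaseArtifact):
    """Hypothetical subclass used only to illustrate the abstract contract."""

    def to_text(self) -> str:
        return str(self.value)

    def to_dict(self) -> dict:
        # The real subclasses in this diff delegate to marshmallow schemas instead.
        return {"id": self.id, "name": self.name, "type": self.type, "value": self.value}

    def __add__(self, other: BaseArtifact) -> "TextArtifact":
        return TextArtifact(self.to_text() + other.to_text())


artifact = TextArtifact("hello")
print(artifact.to_json())                           # JSON built from to_dict()
print(BaseArtifact.value_to_bytes(artifact.value))  # b'hello'
print(BaseArtifact.value_to_dict('{"a": 1}'))       # {'a': 1}
```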
@@ -0,0 +1,20 @@
from __future__ import annotations
from attr import define, field
from swarms.artifacts.base import BaseArtifact


@define(frozen=True)
class ErrorArtifact(BaseArtifact):
    value: str = field(converter=str)

    def __add__(self, other: ErrorArtifact) -> ErrorArtifact:
        return ErrorArtifact(self.value + other.value)

    def to_text(self) -> str:
        return self.value

    def to_dict(self) -> dict:
        from griptape.schemas import ErrorArtifactSchema

        return dict(ErrorArtifactSchema().dump(self))
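`ErrorArtifact` is the sentinel the workflow code later in this diff checks for: a task whose `execute()` returns one stops the run. A quick sketch of its behaviour, assuming the package is importable (the diff references both `swarms.artifacts.error_artifact` and `swarms.artifacts.error_artifacts`, so the exact module path may differ):

```python
from swarms.artifacts.error_artifact import ErrorArtifact  # module path as imported elsewhere in this diff

first = ErrorArtifact("rate limit exceeded; ")
second = ErrorArtifact("retry later")

combined = first + second                    # __add__ concatenates the underlying messages
print(combined.to_text())                    # rate limit exceeded; retry later
print(isinstance(combined, ErrorArtifact))   # True -> a workflow run halts when it sees this
```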
@@ -1,140 +1,248 @@
# from __future__ import annotations

# import concurrent.futures as futures
# import logging
# import uuid
# from abc import ABC, abstractmethod
# from graphlib import TopologicalSorter
# from logging import Logger
# from typing import Optional, Union

# from rich.logging import RichHandler

# # from swarms.artifacts.error_artifact import ErrorArtifact
# from swarms.artifacts.main import Artifact as ErrorArtifact
# from swarms.structs.task import BaseTask


# #@shapeless
# class Workflow(ABC):
#     def __init__(
#         self,
#         id: str = uuid.uuid4().hex,
#         model = None,
#         custom_logger: Optional[Logger] = None,
#         logger_level: int = logging.INFO,
#         futures_executor: futures.Executor = futures.ThreadPoolExecutor()
#     ):
#         self.id = id
#         self.model = model
#         self.custom_logger = custom_logger
#         self.logger_level = logger_level

#         self.futures_executor = futures_executor
#         self._execution_args = ()
#         self._logger = None

#         [task.preprocess(self) for task in self.tasks]

#         self.model.structure = self

#     @property
#     def execution_args(self) -> tuple:
#         return self._execution_args

#     @property
#     def logger(self) -> Logger:
#         if self.custom_logger:
#             return self.custom_logger
#         else:
#             if self._logger is None:
#                 self._logger = logging.getLogger(self.LOGGER_NAME)

#                 self._logger.propagate = False
#                 self._logger.level = self.logger_level

#                 self._logger.handlers = [
#                     RichHandler(
#                         show_time=True,
#                         show_path=False
#                     )
#                 ]
#             return self._logger

#     def is_finished(self) -> bool:
#         return all(s.is_finished() for s in self.tasks)

#     def is_executing(self) -> bool:
#         return any(s for s in self.tasks if s.is_executing())

#     def find_task(self, task_id: str) -> Optional[BaseTask]:
#         return next((task for task in self.tasks if task.id == task_id), None)

#     def add_tasks(self, *tasks: BaseTask) -> list[BaseTask]:
#         return [self.add_task(s) for s in tasks]

#     def context(self, task: BaseTask) -> dict[str, any]:
#         return {
#             "args": self.execution_args,
#             "structure": self,
#         }

#     @abstractmethod
#     def add(self, task: BaseTask) -> BaseTask:
#         task.preprocess(self)
#         self.tasks.append(task)
#         return task

#     @abstractmethod
#     def run(self, *args) -> Union[BaseTask, list[BaseTask]]:
#         self._execution_args = args
#         ordered_tasks = self.order_tasks()
#         exit_loop = False

#         while not self.is_finished() and not exit_loop:
#             futures_list = {}

#             for task in ordered_tasks:
#                 if task.can_execute():
#                     future = self.futures_executor.submit(task.execute)
#                     futures_list[future] = task

#             # Wait for all tasks to complete
#             for future in futures.as_completed(futures_list):
#                 if isinstance(future.result(), ErrorArtifact):
#                     exit_loop = True
#                     break

#         self._execution_args = ()

#         return self.output_tasks()

#     def context(self, task: BaseTask) -> dict[str, any]:
#         context = super().context(task)

#         context.update(
#             {
#                 "parent_outputs": {parent.id: parent.output.to_text() if parent.output else "" for parent in task.parents},
#                 "parents": {parent.id: parent for parent in task.parents},
#                 "children": {child.id: child for child in task.children}
#             }
#         )

#         return context

#     def output_tasks(self) -> list[BaseTask]:
#         return [task for task in self.tasks if not task.children]

#     def to_graph(self) -> dict[str, set[str]]:
#         graph: dict[str, set[str]] = {}

#         for key_task in self.tasks:
#             graph[key_task.id] = set()

#             for value_task in self.tasks:
#                 if key_task.id in value_task.child_ids:
#                     graph[key_task.id].add(value_task.id)

#         return graph

#     def order_tasks(self) -> list[BaseTask]:
#         return [self.find_task(task_id) for task_id in TopologicalSorter(self.to_graph()).static_order()]


from __future__ import annotations

from typing import Any, Dict, List, Optional

from swarms.artifacts.error_artifacts import ErrorArtifact
from swarms.structs.task import BaseTask


class StringTask(BaseTask):
    def __init__(
        self,
        task
    ):
        super().__init__()
        self.task = task

    def execute(self) -> Any:
        prompt = self.task.replace("{{ parent_input }}", self.parents[0].output if self.parents else "")
        response = self.structure.llm.run(prompt)
        self.output = response
        return response


class Workflow:
    """
    Workflows are ideal for prescriptive processes that need to be executed sequentially.
    They string together multiple tasks of varying types, and can use Short-Term Memory
    or pass specific arguments downstream.

    ```
    llm = LLM()
    workflow = Workflow(llm)

    workflow.add("What's the weather in Miami")
    workflow.add("Provide details for {{ parent_output }}")
    workflow.add("Summarize the above information: {{ parent_output }}")

    workflow.run()
    ```
    """
    def __init__(
        self,
        llm
    ):
        self.llm = llm
        self.tasks: List[BaseTask] = []

    def add(
        self,
        task: BaseTask
    ) -> BaseTask:
        task = StringTask(task)

        if self.last_task():
            self.last_task().add_child(task)
        else:
            task.structure = self
            self.tasks.append(task)

        return task

    def first_task(self) -> Optional[BaseTask]:
        return self.tasks[0] if self.tasks else None

    def last_task(self) -> Optional[BaseTask]:
        return self.tasks[-1] if self.tasks else None

    def run(self, *args) -> BaseTask:
        self._execution_args = args

        [task.reset() for task in self.tasks]

        self.__run_from_task(self.first_task())

        self._execution_args = ()

        return self.last_task()

    def context(self, task: BaseTask) -> Dict[str, Any]:
        context = super().context(task)

        context.update(
            {
                "parent_output": task.parents[0].output.to_text() \
                    if task.parents and task.parents[0].output else None,
                "parent": task.parents[0] if task.parents else None,
                "child": task.children[0] if task.children else None
            }
        )

        return context

    def __run_from_task(self, task: Optional[BaseTask]) -> None:
        if task is None:
            return
        else:
            if isinstance(task.execute(), ErrorArtifact):
                return
            else:
                self.__run_from_task(next(iter(task.children), None))
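The new `Workflow.run` drops the old topological-sort/thread-pool execution: it resets every task, then walks the parent-to-child chain recursively and stops early if any `execute()` returns an `ErrorArtifact`. A self-contained sketch of that control flow; `StubTask`, its canned outputs, and the `"error"` sentinel are illustrative stand-ins, not the `BaseTask`/`ErrorArtifact` classes from this diff:

```python
from typing import List, Optional


class StubTask:
    """Minimal stand-in for a task: at most one child is followed, as in the sequential Workflow."""

    def __init__(self, name: str, fails: bool = False):
        self.name = name
        self.fails = fails
        self.children: List["StubTask"] = []

    def add_child(self, task: "StubTask") -> "StubTask":
        self.children.append(task)
        return task

    def execute(self) -> str:
        print(f"executing {self.name}")
        # Returning a sentinel "error" mirrors the isinstance(..., ErrorArtifact) check above.
        return "error" if self.fails else f"{self.name} done"


def run_from_task(task: Optional[StubTask]) -> None:
    # Same recursion as Workflow.__run_from_task: stop on None or on an error result.
    if task is None:
        return
    if task.execute() == "error":
        return
    run_from_task(next(iter(task.children), None))


first = StubTask("weather lookup")
second = first.add_child(StubTask("details", fails=True))
second.add_child(StubTask("summary"))  # never reached: "details" fails first

run_from_task(first)  # prints "executing weather lookup", then "executing details", then stops
```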