[FEAT][NonlinearWorkflow]

pull/343/head
Kye 1 year ago
parent 00e4df95c1
commit e4ee2cfab9

@ -1,11 +1,11 @@
# disable_logging()
from swarms.utils.disable_logging import disable_logging
disable_logging()
from swarms.agents import * # noqa: E402, F403
from swarms.swarms import * # noqa: E402, F403
from swarms.structs import * # noqa: E402, F403
from swarms.models import * # noqa: E402, F403
from swarms.telemetry import * # noqa: E402, F403
from swarms.utils import * # noqa: E402, F403
from swarms.prompts import * # noqa: E402, F403
# from swarms.cli import * # noqa: E402, F403
# from swarms.cli import * # noqa: E402, F403

@ -1,10 +1,13 @@
from swarms.structs.agent import Agent
from swarms.structs.autoscaler import AutoScaler
from swarms.structs.base_swarm import AbstractSwarm
from swarms.structs.concurrent_workflow import ConcurrentWorkflow
from swarms.structs.conversation import Conversation
from swarms.structs.groupchat import GroupChat, GroupChatManager
from swarms.structs.model_parallizer import ModelParallelizer
from swarms.structs.multi_agent_collab import MultiAgentCollaboration
from swarms.structs.nonlinear_workflow import NonlinearWorkflow
from swarms.structs.recursive_workflow import RecursiveWorkflow
from swarms.structs.schemas import (
Artifact,
ArtifactUpload,
@ -21,7 +24,6 @@ from swarms.structs.utils import (
find_token_in_text,
parse_tasks,
)
from swarms.structs.concurrent_workflow import ConcurrentWorkflow
__all__ = [
"Agent",
@ -45,4 +47,6 @@ __all__ = [
"extract_key_from_json",
"extract_tokens_from_text",
"ConcurrentWorkflow",
"RecursiveWorkflow",
"NonlinearWorkflow",
]

@ -0,0 +1,87 @@
from swarms.structs.task import Task
from swarms.structs.base import BaseStruct
class NonlinearWorkflow(BaseStruct):
"""
Represents a Directed Acyclic Graph (DAG) workflow.
Attributes:
tasks (dict): A dictionary mapping task names to Task objects.
edges (dict): A dictionary mapping task names to a list of dependencies.
Methods:
add(task: Task, *dependencies: str): Adds a task to the workflow with its dependencies.
run(): Executes the workflow by running tasks in topological order.
Examples:
>>> from swarms.models import OpenAIChat
>>> from swarms.structs import NonlinearWorkflow, Task
>>> llm = OpenAIChat(openai_api_key="")
>>> task = Task(llm, "What's the weather in miami")
>>> workflow = NonlinearWorkflow()
>>> workflow.add(task)
>>> workflow.run()
"""
def __init__(self, stopping_token: str = "<DONE>"):
self.tasks = {}
self.edges = {}
self.stopping_token = stopping_token
def add(self, task: Task, *dependencies: str):
"""
Adds a task to the workflow with its dependencies.
Args:
task (Task): The task to be added.
dependencies (str): Variable number of dependency task names.
Raises:
AssertionError: If the task is None.
Returns:
None
"""
assert task is not None, "Task cannot be None"
self.tasks[task.name] = task
self.edges[task.name] = list(dependencies)
def run(self):
"""
Executes the workflow by running tasks in topological order.
Raises:
Exception: If a circular dependency is detected.
Returns:
None
"""
try:
# Create a copy of the edges
edges = self.edges.copy()
while edges:
# Get all tasks with no dependencies
ready_tasks = [
task for task, deps in edges.items() if not deps
]
if not ready_tasks:
raise Exception("Circular dependency detected")
# Run all ready tasks
for task in ready_tasks:
result = self.tasks[task].execute()
if result == self.stopping_token:
return
del edges[task]
# Remove dependencies on the ready tasks
for deps in edges.values():
for task in ready_tasks:
if task in deps:
deps.remove(task)
except Exception as error:
print(f"[ERROR][NonlinearWorkflow] {error}")
raise error

@ -0,0 +1,56 @@
from typing import List
from swarms.structs.base import BaseStruct
from swarms.structs.task import Task
class RecursiveWorkflow(BaseStruct):
"""
RecursiveWorkflow class for running a task recursively until a stopping condition is met.
Args:
task (Task): The task to execute.
stop_token (Any): The token that indicates when to stop the workflow.
Attributes:
task (Task): The task to execute.
stop_token (Any): The token that indicates when to stop the workflow.
Examples:
>>> from swarms.models import OpenAIChat
>>> from swarms.structs import RecursiveWorkflow, Task
>>> llm = OpenAIChat(openai_api_key="")
>>> task = Task(llm, "What's the weather in miami")
>>> workflow = RecursiveWorkflow()
>>> workflow.add(task)
>>> workflow.run()
"""
def __init__(self, stop_token: str = "<DONE>"):
self.stop_token = stop_token
self.tasks = List[Task]
assert (
self.stop_token is not None
), "stop_token cannot be None"
def add(self, task: Task):
assert task is not None, "task cannot be None"
return self.tasks.appennd(task)
def run(self):
"""
Executes the tasks in the workflow until the stop token is encountered.
Returns:
None
"""
try:
for task in self.tasks:
while True:
result = task.execute()
if self.stop_token in result:
break
except Exception as error:
print(f"[ERROR][RecursiveWorkflow] {error}")
raise error

@ -0,0 +1,49 @@
import pytest
from swarms.structs import NonlinearWorkflow, Task
from swarms.models import OpenAIChat
class TestNonlinearWorkflow:
def test_add_task(self):
llm = OpenAIChat(openai_api_key="")
task = Task(llm, "What's the weather in miami")
workflow = NonlinearWorkflow()
workflow.add(task)
assert task.name in workflow.tasks
assert task.name in workflow.edges
def test_run_without_tasks(self):
workflow = NonlinearWorkflow()
# No exception should be raised
workflow.run()
def test_run_with_single_task(self):
llm = OpenAIChat(openai_api_key="")
task = Task(llm, "What's the weather in miami")
workflow = NonlinearWorkflow()
workflow.add(task)
# No exception should be raised
workflow.run()
def test_run_with_circular_dependency(self):
llm = OpenAIChat(openai_api_key="")
task1 = Task(llm, "What's the weather in miami")
task2 = Task(llm, "What's the weather in new york")
workflow = NonlinearWorkflow()
workflow.add(task1, task2.name)
workflow.add(task2, task1.name)
with pytest.raises(
Exception, match="Circular dependency detected"
):
workflow.run()
def test_run_with_stopping_token(self):
llm = OpenAIChat(openai_api_key="")
task1 = Task(llm, "What's the weather in miami")
task2 = Task(llm, "What's the weather in new york")
workflow = NonlinearWorkflow(stopping_token="stop")
workflow.add(task1)
workflow.add(task2)
# Assuming that task1's execute method returns "stop"
# No exception should be raised
workflow.run()

@ -0,0 +1,72 @@
import pytest
from unittest.mock import Mock, create_autospec
from swarms.models import OpenAIChat
from swarms.structs import RecursiveWorkflow, Task
def test_add():
workflow = RecursiveWorkflow(stop_token="<DONE>")
task = Mock(spec=Task)
workflow.add(task)
assert task in workflow.tasks
def test_run():
workflow = RecursiveWorkflow(stop_token="<DONE>")
agent1 = create_autospec(OpenAIChat)
agent2 = create_autospec(OpenAIChat)
task1 = Task("What's the weather in miami", agent1)
task2 = Task("What's the weather in miami", agent2)
workflow.add(task1)
workflow.add(task2)
agent1.execute.return_value = "Not done"
agent2.execute.return_value = "<DONE>"
workflow.run()
assert agent1.execute.call_count >= 1
assert agent2.execute.call_count == 1
def test_run_no_tasks():
workflow = RecursiveWorkflow(stop_token="<DONE>")
# No tasks are added to the workflow
# This should not raise any errors
workflow.run()
def test_run_stop_token_not_in_result():
workflow = RecursiveWorkflow(stop_token="<DONE>")
agent = create_autospec(OpenAIChat)
task = Task("What's the weather in miami", agent)
workflow.add(task)
agent.execute.return_value = "Not done"
# If the stop token is never found in the result, the workflow could run forever.
# To prevent this, we'll set a maximum number of iterations.
max_iterations = 1000
for _ in range(max_iterations):
try:
workflow.run()
except RecursionError:
pytest.fail(
"RecursiveWorkflow.run caused a RecursionError"
)
assert agent.execute.call_count == max_iterations
def test_run_stop_token_in_result():
workflow = RecursiveWorkflow(stop_token="<DONE>")
agent = create_autospec(OpenAIChat)
task = Task("What's the weather in miami", agent)
workflow.add(task)
agent.execute.return_value = "<DONE>"
workflow.run()
# If the stop token is found in the result, the workflow should stop running the task.
assert agent.execute.call_count == 1
Loading…
Cancel
Save