parent
00e4df95c1
commit
e4ee2cfab9
@ -1,11 +1,11 @@
|
||||
# disable_logging()
|
||||
from swarms.utils.disable_logging import disable_logging
|
||||
|
||||
disable_logging()
|
||||
|
||||
from swarms.agents import * # noqa: E402, F403
|
||||
from swarms.swarms import * # noqa: E402, F403
|
||||
from swarms.structs import * # noqa: E402, F403
|
||||
from swarms.models import * # noqa: E402, F403
|
||||
from swarms.telemetry import * # noqa: E402, F403
|
||||
from swarms.utils import * # noqa: E402, F403
|
||||
from swarms.prompts import * # noqa: E402, F403
|
||||
|
||||
# from swarms.cli import * # noqa: E402, F403
|
||||
# from swarms.cli import * # noqa: E402, F403
|
@ -0,0 +1,87 @@
|
||||
from swarms.structs.task import Task
|
||||
from swarms.structs.base import BaseStruct
|
||||
|
||||
|
||||
class NonlinearWorkflow(BaseStruct):
|
||||
"""
|
||||
Represents a Directed Acyclic Graph (DAG) workflow.
|
||||
|
||||
Attributes:
|
||||
tasks (dict): A dictionary mapping task names to Task objects.
|
||||
edges (dict): A dictionary mapping task names to a list of dependencies.
|
||||
|
||||
Methods:
|
||||
add(task: Task, *dependencies: str): Adds a task to the workflow with its dependencies.
|
||||
run(): Executes the workflow by running tasks in topological order.
|
||||
|
||||
Examples:
|
||||
>>> from swarms.models import OpenAIChat
|
||||
>>> from swarms.structs import NonlinearWorkflow, Task
|
||||
>>> llm = OpenAIChat(openai_api_key="")
|
||||
>>> task = Task(llm, "What's the weather in miami")
|
||||
>>> workflow = NonlinearWorkflow()
|
||||
>>> workflow.add(task)
|
||||
>>> workflow.run()
|
||||
|
||||
"""
|
||||
def __init__(self, stopping_token: str = "<DONE>"):
|
||||
self.tasks = {}
|
||||
self.edges = {}
|
||||
self.stopping_token = stopping_token
|
||||
|
||||
def add(self, task: Task, *dependencies: str):
|
||||
"""
|
||||
Adds a task to the workflow with its dependencies.
|
||||
|
||||
Args:
|
||||
task (Task): The task to be added.
|
||||
dependencies (str): Variable number of dependency task names.
|
||||
|
||||
Raises:
|
||||
AssertionError: If the task is None.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
assert task is not None, "Task cannot be None"
|
||||
self.tasks[task.name] = task
|
||||
self.edges[task.name] = list(dependencies)
|
||||
|
||||
def run(self):
|
||||
"""
|
||||
Executes the workflow by running tasks in topological order.
|
||||
|
||||
Raises:
|
||||
Exception: If a circular dependency is detected.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
try:
|
||||
# Create a copy of the edges
|
||||
edges = self.edges.copy()
|
||||
|
||||
while edges:
|
||||
# Get all tasks with no dependencies
|
||||
ready_tasks = [
|
||||
task for task, deps in edges.items() if not deps
|
||||
]
|
||||
|
||||
if not ready_tasks:
|
||||
raise Exception("Circular dependency detected")
|
||||
|
||||
# Run all ready tasks
|
||||
for task in ready_tasks:
|
||||
result = self.tasks[task].execute()
|
||||
if result == self.stopping_token:
|
||||
return
|
||||
del edges[task]
|
||||
|
||||
# Remove dependencies on the ready tasks
|
||||
for deps in edges.values():
|
||||
for task in ready_tasks:
|
||||
if task in deps:
|
||||
deps.remove(task)
|
||||
except Exception as error:
|
||||
print(f"[ERROR][NonlinearWorkflow] {error}")
|
||||
raise error
|
@ -0,0 +1,56 @@
|
||||
from typing import List
|
||||
|
||||
from swarms.structs.base import BaseStruct
|
||||
from swarms.structs.task import Task
|
||||
|
||||
|
||||
class RecursiveWorkflow(BaseStruct):
|
||||
"""
|
||||
RecursiveWorkflow class for running a task recursively until a stopping condition is met.
|
||||
|
||||
Args:
|
||||
task (Task): The task to execute.
|
||||
stop_token (Any): The token that indicates when to stop the workflow.
|
||||
|
||||
Attributes:
|
||||
task (Task): The task to execute.
|
||||
stop_token (Any): The token that indicates when to stop the workflow.
|
||||
|
||||
Examples:
|
||||
>>> from swarms.models import OpenAIChat
|
||||
>>> from swarms.structs import RecursiveWorkflow, Task
|
||||
>>> llm = OpenAIChat(openai_api_key="")
|
||||
>>> task = Task(llm, "What's the weather in miami")
|
||||
>>> workflow = RecursiveWorkflow()
|
||||
>>> workflow.add(task)
|
||||
>>> workflow.run()
|
||||
"""
|
||||
|
||||
def __init__(self, stop_token: str = "<DONE>"):
|
||||
self.stop_token = stop_token
|
||||
self.tasks = List[Task]
|
||||
|
||||
assert (
|
||||
self.stop_token is not None
|
||||
), "stop_token cannot be None"
|
||||
|
||||
def add(self, task: Task):
|
||||
assert task is not None, "task cannot be None"
|
||||
return self.tasks.appennd(task)
|
||||
|
||||
def run(self):
|
||||
"""
|
||||
Executes the tasks in the workflow until the stop token is encountered.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
try:
|
||||
for task in self.tasks:
|
||||
while True:
|
||||
result = task.execute()
|
||||
if self.stop_token in result:
|
||||
break
|
||||
except Exception as error:
|
||||
print(f"[ERROR][RecursiveWorkflow] {error}")
|
||||
raise error
|
@ -0,0 +1,49 @@
|
||||
import pytest
|
||||
from swarms.structs import NonlinearWorkflow, Task
|
||||
from swarms.models import OpenAIChat
|
||||
|
||||
|
||||
class TestNonlinearWorkflow:
|
||||
def test_add_task(self):
|
||||
llm = OpenAIChat(openai_api_key="")
|
||||
task = Task(llm, "What's the weather in miami")
|
||||
workflow = NonlinearWorkflow()
|
||||
workflow.add(task)
|
||||
assert task.name in workflow.tasks
|
||||
assert task.name in workflow.edges
|
||||
|
||||
def test_run_without_tasks(self):
|
||||
workflow = NonlinearWorkflow()
|
||||
# No exception should be raised
|
||||
workflow.run()
|
||||
|
||||
def test_run_with_single_task(self):
|
||||
llm = OpenAIChat(openai_api_key="")
|
||||
task = Task(llm, "What's the weather in miami")
|
||||
workflow = NonlinearWorkflow()
|
||||
workflow.add(task)
|
||||
# No exception should be raised
|
||||
workflow.run()
|
||||
|
||||
def test_run_with_circular_dependency(self):
|
||||
llm = OpenAIChat(openai_api_key="")
|
||||
task1 = Task(llm, "What's the weather in miami")
|
||||
task2 = Task(llm, "What's the weather in new york")
|
||||
workflow = NonlinearWorkflow()
|
||||
workflow.add(task1, task2.name)
|
||||
workflow.add(task2, task1.name)
|
||||
with pytest.raises(
|
||||
Exception, match="Circular dependency detected"
|
||||
):
|
||||
workflow.run()
|
||||
|
||||
def test_run_with_stopping_token(self):
|
||||
llm = OpenAIChat(openai_api_key="")
|
||||
task1 = Task(llm, "What's the weather in miami")
|
||||
task2 = Task(llm, "What's the weather in new york")
|
||||
workflow = NonlinearWorkflow(stopping_token="stop")
|
||||
workflow.add(task1)
|
||||
workflow.add(task2)
|
||||
# Assuming that task1's execute method returns "stop"
|
||||
# No exception should be raised
|
||||
workflow.run()
|
@ -0,0 +1,72 @@
|
||||
import pytest
|
||||
from unittest.mock import Mock, create_autospec
|
||||
from swarms.models import OpenAIChat
|
||||
from swarms.structs import RecursiveWorkflow, Task
|
||||
|
||||
|
||||
def test_add():
|
||||
workflow = RecursiveWorkflow(stop_token="<DONE>")
|
||||
task = Mock(spec=Task)
|
||||
workflow.add(task)
|
||||
assert task in workflow.tasks
|
||||
|
||||
|
||||
def test_run():
|
||||
workflow = RecursiveWorkflow(stop_token="<DONE>")
|
||||
agent1 = create_autospec(OpenAIChat)
|
||||
agent2 = create_autospec(OpenAIChat)
|
||||
task1 = Task("What's the weather in miami", agent1)
|
||||
task2 = Task("What's the weather in miami", agent2)
|
||||
workflow.add(task1)
|
||||
workflow.add(task2)
|
||||
|
||||
agent1.execute.return_value = "Not done"
|
||||
agent2.execute.return_value = "<DONE>"
|
||||
|
||||
workflow.run()
|
||||
|
||||
assert agent1.execute.call_count >= 1
|
||||
assert agent2.execute.call_count == 1
|
||||
|
||||
|
||||
def test_run_no_tasks():
|
||||
workflow = RecursiveWorkflow(stop_token="<DONE>")
|
||||
# No tasks are added to the workflow
|
||||
# This should not raise any errors
|
||||
workflow.run()
|
||||
|
||||
|
||||
def test_run_stop_token_not_in_result():
|
||||
workflow = RecursiveWorkflow(stop_token="<DONE>")
|
||||
agent = create_autospec(OpenAIChat)
|
||||
task = Task("What's the weather in miami", agent)
|
||||
workflow.add(task)
|
||||
|
||||
agent.execute.return_value = "Not done"
|
||||
|
||||
# If the stop token is never found in the result, the workflow could run forever.
|
||||
# To prevent this, we'll set a maximum number of iterations.
|
||||
max_iterations = 1000
|
||||
for _ in range(max_iterations):
|
||||
try:
|
||||
workflow.run()
|
||||
except RecursionError:
|
||||
pytest.fail(
|
||||
"RecursiveWorkflow.run caused a RecursionError"
|
||||
)
|
||||
|
||||
assert agent.execute.call_count == max_iterations
|
||||
|
||||
|
||||
def test_run_stop_token_in_result():
|
||||
workflow = RecursiveWorkflow(stop_token="<DONE>")
|
||||
agent = create_autospec(OpenAIChat)
|
||||
task = Task("What's the weather in miami", agent)
|
||||
workflow.add(task)
|
||||
|
||||
agent.execute.return_value = "<DONE>"
|
||||
|
||||
workflow.run()
|
||||
|
||||
# If the stop token is found in the result, the workflow should stop running the task.
|
||||
assert agent.execute.call_count == 1
|
Loading…
Reference in new issue