parent
00e4df95c1
commit
e4ee2cfab9
@ -1,11 +1,11 @@
|
|||||||
# disable_logging()
|
from swarms.utils.disable_logging import disable_logging
|
||||||
|
|
||||||
|
disable_logging()
|
||||||
|
|
||||||
from swarms.agents import * # noqa: E402, F403
|
from swarms.agents import * # noqa: E402, F403
|
||||||
from swarms.swarms import * # noqa: E402, F403
|
|
||||||
from swarms.structs import * # noqa: E402, F403
|
from swarms.structs import * # noqa: E402, F403
|
||||||
from swarms.models import * # noqa: E402, F403
|
from swarms.models import * # noqa: E402, F403
|
||||||
from swarms.telemetry import * # noqa: E402, F403
|
from swarms.telemetry import * # noqa: E402, F403
|
||||||
from swarms.utils import * # noqa: E402, F403
|
from swarms.utils import * # noqa: E402, F403
|
||||||
from swarms.prompts import * # noqa: E402, F403
|
from swarms.prompts import * # noqa: E402, F403
|
||||||
|
# from swarms.cli import * # noqa: E402, F403
|
||||||
# from swarms.cli import * # noqa: E402, F403
|
|
@ -0,0 +1,87 @@
|
|||||||
|
from swarms.structs.task import Task
|
||||||
|
from swarms.structs.base import BaseStruct
|
||||||
|
|
||||||
|
|
||||||
|
class NonlinearWorkflow(BaseStruct):
|
||||||
|
"""
|
||||||
|
Represents a Directed Acyclic Graph (DAG) workflow.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
tasks (dict): A dictionary mapping task names to Task objects.
|
||||||
|
edges (dict): A dictionary mapping task names to a list of dependencies.
|
||||||
|
|
||||||
|
Methods:
|
||||||
|
add(task: Task, *dependencies: str): Adds a task to the workflow with its dependencies.
|
||||||
|
run(): Executes the workflow by running tasks in topological order.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> from swarms.models import OpenAIChat
|
||||||
|
>>> from swarms.structs import NonlinearWorkflow, Task
|
||||||
|
>>> llm = OpenAIChat(openai_api_key="")
|
||||||
|
>>> task = Task(llm, "What's the weather in miami")
|
||||||
|
>>> workflow = NonlinearWorkflow()
|
||||||
|
>>> workflow.add(task)
|
||||||
|
>>> workflow.run()
|
||||||
|
|
||||||
|
"""
|
||||||
|
def __init__(self, stopping_token: str = "<DONE>"):
|
||||||
|
self.tasks = {}
|
||||||
|
self.edges = {}
|
||||||
|
self.stopping_token = stopping_token
|
||||||
|
|
||||||
|
def add(self, task: Task, *dependencies: str):
|
||||||
|
"""
|
||||||
|
Adds a task to the workflow with its dependencies.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task (Task): The task to be added.
|
||||||
|
dependencies (str): Variable number of dependency task names.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
AssertionError: If the task is None.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
None
|
||||||
|
"""
|
||||||
|
assert task is not None, "Task cannot be None"
|
||||||
|
self.tasks[task.name] = task
|
||||||
|
self.edges[task.name] = list(dependencies)
|
||||||
|
|
||||||
|
def run(self):
|
||||||
|
"""
|
||||||
|
Executes the workflow by running tasks in topological order.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
Exception: If a circular dependency is detected.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
None
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Create a copy of the edges
|
||||||
|
edges = self.edges.copy()
|
||||||
|
|
||||||
|
while edges:
|
||||||
|
# Get all tasks with no dependencies
|
||||||
|
ready_tasks = [
|
||||||
|
task for task, deps in edges.items() if not deps
|
||||||
|
]
|
||||||
|
|
||||||
|
if not ready_tasks:
|
||||||
|
raise Exception("Circular dependency detected")
|
||||||
|
|
||||||
|
# Run all ready tasks
|
||||||
|
for task in ready_tasks:
|
||||||
|
result = self.tasks[task].execute()
|
||||||
|
if result == self.stopping_token:
|
||||||
|
return
|
||||||
|
del edges[task]
|
||||||
|
|
||||||
|
# Remove dependencies on the ready tasks
|
||||||
|
for deps in edges.values():
|
||||||
|
for task in ready_tasks:
|
||||||
|
if task in deps:
|
||||||
|
deps.remove(task)
|
||||||
|
except Exception as error:
|
||||||
|
print(f"[ERROR][NonlinearWorkflow] {error}")
|
||||||
|
raise error
|
@ -0,0 +1,56 @@
|
|||||||
|
from typing import List
|
||||||
|
|
||||||
|
from swarms.structs.base import BaseStruct
|
||||||
|
from swarms.structs.task import Task
|
||||||
|
|
||||||
|
|
||||||
|
class RecursiveWorkflow(BaseStruct):
|
||||||
|
"""
|
||||||
|
RecursiveWorkflow class for running a task recursively until a stopping condition is met.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task (Task): The task to execute.
|
||||||
|
stop_token (Any): The token that indicates when to stop the workflow.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
task (Task): The task to execute.
|
||||||
|
stop_token (Any): The token that indicates when to stop the workflow.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> from swarms.models import OpenAIChat
|
||||||
|
>>> from swarms.structs import RecursiveWorkflow, Task
|
||||||
|
>>> llm = OpenAIChat(openai_api_key="")
|
||||||
|
>>> task = Task(llm, "What's the weather in miami")
|
||||||
|
>>> workflow = RecursiveWorkflow()
|
||||||
|
>>> workflow.add(task)
|
||||||
|
>>> workflow.run()
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, stop_token: str = "<DONE>"):
|
||||||
|
self.stop_token = stop_token
|
||||||
|
self.tasks = List[Task]
|
||||||
|
|
||||||
|
assert (
|
||||||
|
self.stop_token is not None
|
||||||
|
), "stop_token cannot be None"
|
||||||
|
|
||||||
|
def add(self, task: Task):
|
||||||
|
assert task is not None, "task cannot be None"
|
||||||
|
return self.tasks.appennd(task)
|
||||||
|
|
||||||
|
def run(self):
|
||||||
|
"""
|
||||||
|
Executes the tasks in the workflow until the stop token is encountered.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
None
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
for task in self.tasks:
|
||||||
|
while True:
|
||||||
|
result = task.execute()
|
||||||
|
if self.stop_token in result:
|
||||||
|
break
|
||||||
|
except Exception as error:
|
||||||
|
print(f"[ERROR][RecursiveWorkflow] {error}")
|
||||||
|
raise error
|
@ -0,0 +1,49 @@
|
|||||||
|
import pytest
|
||||||
|
from swarms.structs import NonlinearWorkflow, Task
|
||||||
|
from swarms.models import OpenAIChat
|
||||||
|
|
||||||
|
|
||||||
|
class TestNonlinearWorkflow:
|
||||||
|
def test_add_task(self):
|
||||||
|
llm = OpenAIChat(openai_api_key="")
|
||||||
|
task = Task(llm, "What's the weather in miami")
|
||||||
|
workflow = NonlinearWorkflow()
|
||||||
|
workflow.add(task)
|
||||||
|
assert task.name in workflow.tasks
|
||||||
|
assert task.name in workflow.edges
|
||||||
|
|
||||||
|
def test_run_without_tasks(self):
|
||||||
|
workflow = NonlinearWorkflow()
|
||||||
|
# No exception should be raised
|
||||||
|
workflow.run()
|
||||||
|
|
||||||
|
def test_run_with_single_task(self):
|
||||||
|
llm = OpenAIChat(openai_api_key="")
|
||||||
|
task = Task(llm, "What's the weather in miami")
|
||||||
|
workflow = NonlinearWorkflow()
|
||||||
|
workflow.add(task)
|
||||||
|
# No exception should be raised
|
||||||
|
workflow.run()
|
||||||
|
|
||||||
|
def test_run_with_circular_dependency(self):
|
||||||
|
llm = OpenAIChat(openai_api_key="")
|
||||||
|
task1 = Task(llm, "What's the weather in miami")
|
||||||
|
task2 = Task(llm, "What's the weather in new york")
|
||||||
|
workflow = NonlinearWorkflow()
|
||||||
|
workflow.add(task1, task2.name)
|
||||||
|
workflow.add(task2, task1.name)
|
||||||
|
with pytest.raises(
|
||||||
|
Exception, match="Circular dependency detected"
|
||||||
|
):
|
||||||
|
workflow.run()
|
||||||
|
|
||||||
|
def test_run_with_stopping_token(self):
|
||||||
|
llm = OpenAIChat(openai_api_key="")
|
||||||
|
task1 = Task(llm, "What's the weather in miami")
|
||||||
|
task2 = Task(llm, "What's the weather in new york")
|
||||||
|
workflow = NonlinearWorkflow(stopping_token="stop")
|
||||||
|
workflow.add(task1)
|
||||||
|
workflow.add(task2)
|
||||||
|
# Assuming that task1's execute method returns "stop"
|
||||||
|
# No exception should be raised
|
||||||
|
workflow.run()
|
@ -0,0 +1,72 @@
|
|||||||
|
import pytest
|
||||||
|
from unittest.mock import Mock, create_autospec
|
||||||
|
from swarms.models import OpenAIChat
|
||||||
|
from swarms.structs import RecursiveWorkflow, Task
|
||||||
|
|
||||||
|
|
||||||
|
def test_add():
|
||||||
|
workflow = RecursiveWorkflow(stop_token="<DONE>")
|
||||||
|
task = Mock(spec=Task)
|
||||||
|
workflow.add(task)
|
||||||
|
assert task in workflow.tasks
|
||||||
|
|
||||||
|
|
||||||
|
def test_run():
|
||||||
|
workflow = RecursiveWorkflow(stop_token="<DONE>")
|
||||||
|
agent1 = create_autospec(OpenAIChat)
|
||||||
|
agent2 = create_autospec(OpenAIChat)
|
||||||
|
task1 = Task("What's the weather in miami", agent1)
|
||||||
|
task2 = Task("What's the weather in miami", agent2)
|
||||||
|
workflow.add(task1)
|
||||||
|
workflow.add(task2)
|
||||||
|
|
||||||
|
agent1.execute.return_value = "Not done"
|
||||||
|
agent2.execute.return_value = "<DONE>"
|
||||||
|
|
||||||
|
workflow.run()
|
||||||
|
|
||||||
|
assert agent1.execute.call_count >= 1
|
||||||
|
assert agent2.execute.call_count == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_run_no_tasks():
|
||||||
|
workflow = RecursiveWorkflow(stop_token="<DONE>")
|
||||||
|
# No tasks are added to the workflow
|
||||||
|
# This should not raise any errors
|
||||||
|
workflow.run()
|
||||||
|
|
||||||
|
|
||||||
|
def test_run_stop_token_not_in_result():
|
||||||
|
workflow = RecursiveWorkflow(stop_token="<DONE>")
|
||||||
|
agent = create_autospec(OpenAIChat)
|
||||||
|
task = Task("What's the weather in miami", agent)
|
||||||
|
workflow.add(task)
|
||||||
|
|
||||||
|
agent.execute.return_value = "Not done"
|
||||||
|
|
||||||
|
# If the stop token is never found in the result, the workflow could run forever.
|
||||||
|
# To prevent this, we'll set a maximum number of iterations.
|
||||||
|
max_iterations = 1000
|
||||||
|
for _ in range(max_iterations):
|
||||||
|
try:
|
||||||
|
workflow.run()
|
||||||
|
except RecursionError:
|
||||||
|
pytest.fail(
|
||||||
|
"RecursiveWorkflow.run caused a RecursionError"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert agent.execute.call_count == max_iterations
|
||||||
|
|
||||||
|
|
||||||
|
def test_run_stop_token_in_result():
|
||||||
|
workflow = RecursiveWorkflow(stop_token="<DONE>")
|
||||||
|
agent = create_autospec(OpenAIChat)
|
||||||
|
task = Task("What's the weather in miami", agent)
|
||||||
|
workflow.add(task)
|
||||||
|
|
||||||
|
agent.execute.return_value = "<DONE>"
|
||||||
|
|
||||||
|
workflow.run()
|
||||||
|
|
||||||
|
# If the stop token is found in the result, the workflow should stop running the task.
|
||||||
|
assert agent.execute.call_count == 1
|
Loading…
Reference in new issue