You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
swarms/tests/structs/test_recursive_workflow.py

73 lines
2.0 KiB

5 months ago
from unittest.mock import Mock, create_autospec
import pytest
from swarms.models import OpenAIChat
from swarms.structs import RecursiveWorkflow, Task
def test_add():
    """A task handed to ``add`` must afterwards be found in ``workflow.tasks``."""
    wf = RecursiveWorkflow(stop_token="<DONE>")
    # A spec'd Mock stands in for a real Task; only membership is checked.
    mock_task = Mock(spec=Task)
    wf.add(mock_task)
    assert mock_task in wf.tasks
def test_run():
    """Run a two-task workflow in which only the second agent emits the stop token.

    The first agent's ``execute`` is stubbed to return ``"Not done"`` and the
    second's to return ``"<DONE>"``. After ``run()`` the first agent must have
    been executed at least once and the second exactly once.
    """
    wf = RecursiveWorkflow(stop_token="<DONE>")
    first_agent = create_autospec(OpenAIChat)
    second_agent = create_autospec(OpenAIChat)

    # Both tasks share the same prompt; they differ only in the backing agent.
    prompt = "What's the weather in miami"
    tasks = [Task(prompt, agent) for agent in (first_agent, second_agent)]
    for task in tasks:
        wf.add(task)

    first_agent.execute.return_value = "Not done"
    second_agent.execute.return_value = "<DONE>"

    wf.run()

    assert first_agent.execute.call_count >= 1
    assert second_agent.execute.call_count == 1
def test_run_no_tasks():
    """Calling ``run()`` on a workflow with no tasks added must not raise.

    There is no explicit assertion: the test passes as long as ``run()``
    returns without an exception.
    """
    empty_wf = RecursiveWorkflow(stop_token="<DONE>")
    empty_wf.run()
def test_run_stop_token_not_in_result():
    """Repeatedly run a workflow whose agent never returns the stop token.

    ``run()`` is invoked ``max_iterations`` times. The test fails explicitly if
    any invocation raises ``RecursionError``, and finally expects the agent's
    ``execute`` to have been called exactly ``max_iterations`` times.
    """
    wf = RecursiveWorkflow(stop_token="<DONE>")
    mock_agent = create_autospec(OpenAIChat)
    wf.add(Task("What's the weather in miami", mock_agent))
    mock_agent.execute.return_value = "Not done"

    # The stop token never appears in the result, so the number of run() calls
    # is capped here instead of relying on the workflow to stop by itself.
    # NOTE(review): this only bounds how often run() is called; it assumes each
    # run() call returns after a single execute -- confirm against
    # RecursiveWorkflow.run.
    max_iterations = 1000
    try:
        for _ in range(max_iterations):
            wf.run()
    except RecursionError:
        pytest.fail("RecursiveWorkflow.run caused a RecursionError")

    assert mock_agent.execute.call_count == max_iterations
def test_run_stop_token_in_result():
    """When the agent returns the stop token, ``execute`` is called only once."""
    wf = RecursiveWorkflow(stop_token="<DONE>")
    mock_agent = create_autospec(OpenAIChat)
    wf.add(Task("What's the weather in miami", mock_agent))
    mock_agent.execute.return_value = "<DONE>"

    wf.run()

    # The stubbed result is the stop token itself, so a single execution is
    # expected for this task.
    assert mock_agent.execute.call_count == 1