You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
75 lines
2.1 KiB
75 lines
2.1 KiB
from unittest.mock import Mock, create_autospec
|
|
|
|
import pytest
|
|
|
|
from swarm_models import OpenAIChat
|
|
from swarms.structs import RecursiveWorkflow, Task
|
|
|
|
|
|
def test_add():
|
|
workflow = RecursiveWorkflow(stop_token="<DONE>")
|
|
task = Mock(spec=Task)
|
|
workflow.add(task)
|
|
assert task in workflow.tasks
|
|
|
|
|
|
def test_run():
|
|
workflow = RecursiveWorkflow(stop_token="<DONE>")
|
|
agent1 = create_autospec(OpenAIChat)
|
|
agent2 = create_autospec(OpenAIChat)
|
|
task1 = Task("What's the weather in miami", agent1)
|
|
task2 = Task("What's the weather in miami", agent2)
|
|
workflow.add(task1)
|
|
workflow.add(task2)
|
|
|
|
agent1.execute.return_value = "Not done"
|
|
agent2.execute.return_value = "<DONE>"
|
|
|
|
workflow.run()
|
|
|
|
assert agent1.execute.call_count >= 1
|
|
assert agent2.execute.call_count == 1
|
|
|
|
|
|
def test_run_no_tasks():
|
|
workflow = RecursiveWorkflow(stop_token="<DONE>")
|
|
# No tasks are added to the workflow
|
|
# This should not raise any errors
|
|
workflow.run()
|
|
|
|
|
|
def test_run_stop_token_not_in_result():
|
|
workflow = RecursiveWorkflow(stop_token="<DONE>")
|
|
agent = create_autospec(OpenAIChat)
|
|
task = Task("What's the weather in miami", agent)
|
|
workflow.add(task)
|
|
|
|
agent.execute.return_value = "Not done"
|
|
|
|
# If the stop token is never found in the result, the workflow could run forever.
|
|
# To prevent this, we'll set a maximum number of iterations.
|
|
max_iterations = 1000
|
|
for _ in range(max_iterations):
|
|
try:
|
|
workflow.run()
|
|
except RecursionError:
|
|
pytest.fail(
|
|
"RecursiveWorkflow.run caused a RecursionError"
|
|
)
|
|
|
|
assert agent.execute.call_count == max_iterations
|
|
|
|
|
|
def test_run_stop_token_in_result():
|
|
workflow = RecursiveWorkflow(stop_token="<DONE>")
|
|
agent = create_autospec(OpenAIChat)
|
|
task = Task("What's the weather in miami", agent)
|
|
workflow.add(task)
|
|
|
|
agent.execute.return_value = "<DONE>"
|
|
|
|
workflow.run()
|
|
|
|
# If the stop token is found in the result, the workflow should stop running the task.
|
|
assert agent.execute.call_count == 1
|