You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
swarms/tests/structs/flow.py

224 lines
7.7 KiB

import json
import os
from unittest.mock import MagicMock, patch
import pytest
from dotenv import load_dotenv
from swarms.models import OpenAIChat
from swarms.structs.flow import Flow, stop_when_repeats
# Load variables from a local .env file (if present) into the process
# environment, then read the OpenAI key; it is None when the variable is unset.
load_dotenv()
openai_api_key = os.environ.get("OPENAI_API_KEY")
# Mocks and Fixtures
@pytest.fixture
def mocked_llm():
    """Provide a stand-in LLM that echoes its prompt back unchanged.

    The tests in this module assert that the flow returns the task text
    unchanged and drive failures through ``llm.side_effect``; both require a
    ``MagicMock`` rather than a live ``OpenAIChat`` client, which would also
    need network access and a valid ``OPENAI_API_KEY``.
    """

    def _echo(*args, **kwargs):
        # Echo the first positional argument (the prompt). The keyword
        # fallback is a guess at how Flow might pass it by name — TODO confirm.
        if args:
            return args[0]
        return kwargs.get("task", kwargs.get("prompt"))

    return MagicMock(side_effect=_echo)
@pytest.fixture
def basic_flow(mocked_llm):
    """Build a Flow around the ``mocked_llm`` fixture, capped at five loops."""
    flow = Flow(llm=mocked_llm, max_loops=5)
    return flow
@pytest.fixture
def flow_with_condition(mocked_llm):
    """Build a five-loop Flow that also has ``stop_when_repeats`` installed."""
    return Flow(
        llm=mocked_llm,
        max_loops=5,
        stopping_condition=stop_when_repeats,
    )
# Basic Tests
def test_stop_when_repeats():
    """``stop_when_repeats`` accepts a 'Stop' phrase and rejects a neutral one."""
    should_stop = stop_when_repeats("Please Stop now")
    assert should_stop
    should_continue = stop_when_repeats("Continue the process")
    assert not should_continue
def test_flow_initialization(basic_flow):
assert basic_flow.max_loops == 5
assert basic_flow.stopping_condition is None
assert basic_flow.loop_interval == 1
assert basic_flow.retry_attempts == 3
assert basic_flow.retry_interval == 1
assert basic_flow.feedback == []
assert basic_flow.memory == []
assert basic_flow.task is None
assert basic_flow.stopping_token == "<DONE>"
assert not basic_flow.interactive
def test_provide_feedback(basic_flow):
feedback = "Test feedback"
basic_flow.provide_feedback(feedback)
assert feedback in basic_flow.feedback
@patch('time.sleep', return_value=None) # to speed up tests
def test_run_without_stopping_condition(mocked_sleep, basic_flow):
response = basic_flow.run("Test task")
assert response == "Test task" # since our mocked llm doesn't modify the response
@patch('time.sleep', return_value=None) # to speed up tests
def test_run_with_stopping_condition(mocked_sleep, flow_with_condition):
response = flow_with_condition.run("Stop")
assert response == "Stop"
@patch("time.sleep", return_value=None)  # skip real sleeps so the test stays fast
def test_run_with_exception(mocked_sleep, basic_flow):
    """An exception raised by the LLM surfaces from run()."""
    failure = Exception("Test Exception")
    basic_flow.llm.side_effect = failure
    with pytest.raises(Exception, match="Test Exception"):
        basic_flow.run("Test task")
def test_bulk_run(basic_flow):
inputs = [{"task": "Test1"}, {"task": "Test2"}]
responses = basic_flow.bulk_run(inputs)
assert responses == ["Test1", "Test2"]
# Tests involving file IO
def test_save_and_load(basic_flow, tmp_path):
    """Memory saved by one Flow is restored verbatim by load() on another.

    The second Flow reuses ``basic_flow.llm``. The bare module-level name
    ``mocked_llm`` is the fixture *function*, not its value, because this
    test does not request that fixture as a parameter.
    """
    file_path = tmp_path / "memory.json"
    basic_flow.memory.append(["Test1", "Test2"])
    basic_flow.save(file_path)

    new_flow = Flow(llm=basic_flow.llm, max_loops=5)
    new_flow.load(file_path)
    assert new_flow.memory == [["Test1", "Test2"]]
# Environment variable mock test
def test_env_variable_handling(monkeypatch):
    """A value set through ``monkeypatch.setenv`` is visible via ``os.getenv``."""
    name, value = "API_KEY", "test_key"
    monkeypatch.setenv(name, value)
    assert os.getenv(name) == value
# TODO: Add more tests, especially edge cases and exception cases. Implement parametrized tests for varied inputs.
# Test initializing the flow with different stopping conditions
def test_flow_with_custom_stopping_condition(mocked_llm):
    """A user-supplied stopping_condition is stored on the Flow and callable."""

    def mentions_terminate(text):
        return "terminate" in text.lower()

    flow = Flow(
        llm=mocked_llm,
        max_loops=5,
        stopping_condition=mentions_terminate,
    )
    check = flow.stopping_condition
    assert check("Please terminate now")
    assert not check("Continue the process")
# Test calling the flow directly
def test_flow_call(basic_flow):
response = basic_flow("Test call")
assert response == "Test call"
# Test formatting the prompt
def test_format_prompt(basic_flow):
formatted_prompt = basic_flow.format_prompt("Hello {name}", name="John")
assert formatted_prompt == "Hello John"
# Test with max loops
@patch('time.sleep', return_value=None)
def test_max_loops(mocked_sleep, basic_flow):
basic_flow.max_loops = 3
response = basic_flow.run("Looping")
assert response == "Looping"
# Test stopping token
@patch('time.sleep', return_value=None)
def test_stopping_token(mocked_sleep, basic_flow):
basic_flow.stopping_token = "Terminate"
response = basic_flow.run("Loop until Terminate")
assert response == "Loop until Terminate"
# Test interactive mode
def test_interactive_mode(basic_flow):
basic_flow.interactive = True
assert basic_flow.interactive
# Test bulk run with varied inputs
def test_bulk_run_varied_inputs(basic_flow):
inputs = [{"task": "Test1"}, {"task": "Test2"}, {"task": "Stop now"}]
responses = basic_flow.bulk_run(inputs)
assert responses == ["Test1", "Test2", "Stop now"]
# Test loading non-existent file
def test_load_non_existent_file(basic_flow, tmp_path):
    """load() raises FileNotFoundError when the path does not exist."""
    missing = tmp_path / "non_existent.json"
    with pytest.raises(FileNotFoundError):
        basic_flow.load(missing)
# Test saving with different memory data
def test_save_different_memory(basic_flow, tmp_path):
file_path = tmp_path / "memory.json"
basic_flow.memory.append(["Task1", "Task2", "Task3"])
basic_flow.save(file_path)
with open(file_path, 'r') as f:
data = json.load(f)
assert data == [["Task1", "Task2", "Task3"]]
# Test the stopping condition check
def test_check_stopping_condition(flow_with_condition):
assert flow_with_condition._check_stopping_condition("Stop this process")
assert not flow_with_condition._check_stopping_condition("Continue the task")
# Test without providing max loops (default value should be 5)
def test_default_max_loops(mocked_llm):
    """Omitting max_loops leaves the Flow at the expected default of 5."""
    assert Flow(llm=mocked_llm).max_loops == 5
# Test creating flow from llm and template
def test_from_llm_and_template(mocked_llm):
    """The from_llm_and_template constructor returns a Flow instance."""
    created = Flow.from_llm_and_template(mocked_llm, "Test template")
    assert isinstance(created, Flow)
# Mocking the OpenAIChat for testing
@patch('swarms.models.OpenAIChat', autospec=True)
def test_mocked_openai_chat(MockedOpenAIChat):
    """Flow.run invokes the (mocked) LLM instance it was constructed with."""
    llm = MockedOpenAIChat(openai_api_key=openai_api_key)
    llm.return_value = MagicMock()
    flow = Flow(llm=llm, max_loops=5)
    # Skip real sleeps between loops, as the other run() tests do.
    with patch("time.sleep", return_value=None):
        flow.run("Mocked run")
    # The class assertion is trivially true (this test instantiated the mock
    # above); the instance assertion is what verifies Flow actually used it.
    assert MockedOpenAIChat.called
    assert llm.called
# Test retry attempts
@patch('time.sleep', return_value=None)
def test_retry_attempts(mocked_sleep, basic_flow):
basic_flow.retry_attempts = 2
basic_flow.llm.side_effect = [Exception("Test Exception"), "Valid response"]
response = basic_flow.run("Test retry")
assert response == "Valid response"
# Test different loop intervals
@patch('time.sleep', return_value=None)
def test_different_loop_intervals(mocked_sleep, basic_flow):
basic_flow.loop_interval = 2
response = basic_flow.run("Test loop interval")
assert response == "Test loop interval"
# Test different retry intervals
@patch('time.sleep', return_value=None)
def test_different_retry_intervals(mocked_sleep, basic_flow):
basic_flow.retry_interval = 2
response = basic_flow.run("Test retry interval")
assert response == "Test retry interval"
# Test invoking the flow with additional kwargs
@patch('time.sleep', return_value=None)
def test_flow_call_with_kwargs(mocked_sleep, basic_flow):
response = basic_flow("Test call", param1="value1", param2="value2")
assert response == "Test call"
# Test initializing the flow with all parameters
def test_flow_initialization_all_params(mocked_llm):
    """Every explicitly supplied constructor argument is reflected on the Flow.

    Extra keyword arguments (param1/param2) must be accepted without error.
    """
    flow = Flow(
        llm=mocked_llm,
        max_loops=10,
        stopping_condition=stop_when_repeats,
        loop_interval=2,
        retry_attempts=4,
        retry_interval=2,
        interactive=True,
        param1="value1",
        param2="value2",
    )
    assert flow.max_loops == 10
    # Previously unchecked even though it is passed above.
    assert flow.stopping_condition is stop_when_repeats
    assert flow.loop_interval == 2
    assert flow.retry_attempts == 4
    assert flow.retry_interval == 2
    assert flow.interactive
# Test the stopping token is in the response
@patch('time.sleep', return_value=None)
def test_stopping_token_in_response(mocked_sleep, basic_flow):
response = basic_flow.run("Test stopping token")
assert basic_flow.stopping_token in response