parent
e397b8b19e
commit
d4176729f3
@ -0,0 +1,134 @@
|
|||||||
|
import pytest
|
||||||
|
from unittest.mock import Mock, patch
|
||||||
|
from swarms.structs.swarm_arange import SwarmRearrange
|
||||||
|
from swarms import Agent
|
||||||
|
from swarm_models import OpenAIChat
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_agent():
|
||||||
|
"""Create a mock agent for testing."""
|
||||||
|
return Mock(spec=Agent)
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def swarm_rearrange(mock_agent):
|
||||||
|
"""Create a SwarmRearrange instance with mock agent."""
|
||||||
|
return SwarmRearrange(
|
||||||
|
id="test_id",
|
||||||
|
name="TestSwarm",
|
||||||
|
description="Test swarm for testing",
|
||||||
|
swarms=[mock_agent],
|
||||||
|
flow="Agent1 -> Agent2",
|
||||||
|
max_loops=2,
|
||||||
|
verbose=True
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_initialization(swarm_rearrange):
|
||||||
|
"""Test SwarmRearrange initialization."""
|
||||||
|
assert swarm_rearrange.id == "test_id"
|
||||||
|
assert swarm_rearrange.name == "TestSwarm"
|
||||||
|
assert swarm_rearrange.description == "Test swarm for testing"
|
||||||
|
assert len(swarm_rearrange.swarms) == 1
|
||||||
|
assert swarm_rearrange.flow == "Agent1 -> Agent2"
|
||||||
|
assert swarm_rearrange.max_loops == 2
|
||||||
|
assert swarm_rearrange.verbose is True
|
||||||
|
|
||||||
|
def test_reliability_checks_empty_swarms():
|
||||||
|
"""Test reliability checks with empty swarms."""
|
||||||
|
with pytest.raises(ValueError, match="No swarms found in the swarm."):
|
||||||
|
SwarmRearrange(swarms=[], flow="test")
|
||||||
|
|
||||||
|
def test_reliability_checks_empty_flow():
|
||||||
|
"""Test reliability checks with empty flow."""
|
||||||
|
with pytest.raises(ValueError, match="No flow found in the swarm."):
|
||||||
|
SwarmRearrange(swarms=[Mock()], flow="")
|
||||||
|
|
||||||
|
def test_reliability_checks_invalid_max_loops():
|
||||||
|
"""Test reliability checks with invalid max_loops."""
|
||||||
|
with pytest.raises(ValueError, match="Max loops must be a positive integer."):
|
||||||
|
SwarmRearrange(swarms=[Mock()], flow="test", max_loops=0)
|
||||||
|
|
||||||
|
def test_add_swarm(swarm_rearrange, mock_agent):
|
||||||
|
"""Test adding a new swarm."""
|
||||||
|
new_agent = Mock(spec=Agent)
|
||||||
|
swarm_rearrange.add_swarm(new_agent)
|
||||||
|
assert len(swarm_rearrange.swarms) == 2
|
||||||
|
assert new_agent in swarm_rearrange.swarms.values()
|
||||||
|
|
||||||
|
def test_remove_swarm(swarm_rearrange, mock_agent):
|
||||||
|
"""Test removing a swarm."""
|
||||||
|
swarm_rearrange.remove_swarm(mock_agent.name)
|
||||||
|
assert len(swarm_rearrange.swarms) == 0
|
||||||
|
assert mock_agent.name not in swarm_rearrange.swarms
|
||||||
|
|
||||||
|
def test_add_swarms(swarm_rearrange):
|
||||||
|
"""Test adding multiple swarms."""
|
||||||
|
new_agents = [Mock(spec=Agent) for _ in range(3)]
|
||||||
|
swarm_rearrange.add_swarms(new_agents)
|
||||||
|
assert len(swarm_rearrange.swarms) == 4
|
||||||
|
for agent in new_agents:
|
||||||
|
assert agent in swarm_rearrange.swarms.values()
|
||||||
|
|
||||||
|
def test_track_history(swarm_rearrange, mock_agent):
|
||||||
|
"""Test tracking swarm history."""
|
||||||
|
result = "Test result"
|
||||||
|
swarm_rearrange.track_history(mock_agent.name, result)
|
||||||
|
assert result in swarm_rearrange.swarm_history[mock_agent.name]
|
||||||
|
|
||||||
|
def test_set_custom_flow(swarm_rearrange):
|
||||||
|
"""Test setting custom flow."""
|
||||||
|
new_flow = "Agent1, Agent2 -> Agent3"
|
||||||
|
swarm_rearrange.set_custom_flow(new_flow)
|
||||||
|
assert swarm_rearrange.flow == new_flow
|
||||||
|
|
||||||
|
def test_context_manager(swarm_rearrange):
|
||||||
|
"""Test context manager functionality."""
|
||||||
|
with swarm_rearrange as db:
|
||||||
|
assert db == swarm_rearrange
|
||||||
|
# Verify cleanup was performed
|
||||||
|
assert not swarm_rearrange.session.is_open()
|
||||||
|
|
||||||
|
def test_error_handling(swarm_rearrange):
|
||||||
|
"""Test error handling in various operations."""
|
||||||
|
# Test invalid flow pattern
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
swarm_rearrange.set_custom_flow("Invalid -> Flow -> Pattern")
|
||||||
|
|
||||||
|
# Test removing non-existent swarm
|
||||||
|
with pytest.raises(KeyError):
|
||||||
|
swarm_rearrange.remove_swarm("NonExistentSwarm")
|
||||||
|
|
||||||
|
def test_thread_safety(swarm_rearrange):
|
||||||
|
"""Test thread safety of operations."""
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
|
||||||
|
def add_swarm_thread():
|
||||||
|
for i in range(10):
|
||||||
|
new_agent = Mock(spec=Agent)
|
||||||
|
new_agent.name = f"Agent{i}"
|
||||||
|
swarm_rearrange.add_swarm(new_agent)
|
||||||
|
time.sleep(0.1)
|
||||||
|
|
||||||
|
def remove_swarm_thread():
|
||||||
|
for i in range(10):
|
||||||
|
try:
|
||||||
|
swarm_rearrange.remove_swarm(f"Agent{i}")
|
||||||
|
except KeyError:
|
||||||
|
pass
|
||||||
|
time.sleep(0.1)
|
||||||
|
|
||||||
|
# Create and start threads
|
||||||
|
threads = [
|
||||||
|
threading.Thread(target=add_swarm_thread),
|
||||||
|
threading.Thread(target=remove_swarm_thread)
|
||||||
|
]
|
||||||
|
|
||||||
|
for thread in threads:
|
||||||
|
thread.start()
|
||||||
|
|
||||||
|
for thread in threads:
|
||||||
|
thread.join()
|
||||||
|
|
||||||
|
# Verify no data corruption occurred
|
||||||
|
assert isinstance(swarm_rearrange.swarms, dict)
|
||||||
|
assert isinstance(swarm_rearrange.swarm_history, dict)
|
Loading…
Reference in new issue