diff --git a/tests/structs/test_swarm_rearrange.py b/tests/structs/test_swarm_rearrange.py new file mode 100644 index 00000000..86ca345e --- /dev/null +++ b/tests/structs/test_swarm_rearrange.py @@ -0,0 +1,134 @@ +import pytest +from unittest.mock import Mock, patch +from swarms.structs.swarm_arange import SwarmRearrange +from swarms import Agent +from swarm_models import OpenAIChat + +@pytest.fixture +def mock_agent(): + """Create a mock agent for testing.""" + return Mock(spec=Agent) + +@pytest.fixture +def swarm_rearrange(mock_agent): + """Create a SwarmRearrange instance with mock agent.""" + return SwarmRearrange( + id="test_id", + name="TestSwarm", + description="Test swarm for testing", + swarms=[mock_agent], + flow="Agent1 -> Agent2", + max_loops=2, + verbose=True + ) + +def test_initialization(swarm_rearrange): + """Test SwarmRearrange initialization.""" + assert swarm_rearrange.id == "test_id" + assert swarm_rearrange.name == "TestSwarm" + assert swarm_rearrange.description == "Test swarm for testing" + assert len(swarm_rearrange.swarms) == 1 + assert swarm_rearrange.flow == "Agent1 -> Agent2" + assert swarm_rearrange.max_loops == 2 + assert swarm_rearrange.verbose is True + +def test_reliability_checks_empty_swarms(): + """Test reliability checks with empty swarms.""" + with pytest.raises(ValueError, match="No swarms found in the swarm."): + SwarmRearrange(swarms=[], flow="test") + +def test_reliability_checks_empty_flow(): + """Test reliability checks with empty flow.""" + with pytest.raises(ValueError, match="No flow found in the swarm."): + SwarmRearrange(swarms=[Mock()], flow="") + +def test_reliability_checks_invalid_max_loops(): + """Test reliability checks with invalid max_loops.""" + with pytest.raises(ValueError, match="Max loops must be a positive integer."): + SwarmRearrange(swarms=[Mock()], flow="test", max_loops=0) + +def test_add_swarm(swarm_rearrange, mock_agent): + """Test adding a new swarm.""" + new_agent = Mock(spec=Agent) + swarm_rearrange.add_swarm(new_agent) + assert len(swarm_rearrange.swarms) == 2 + assert new_agent in swarm_rearrange.swarms.values() + +def test_remove_swarm(swarm_rearrange, mock_agent): + """Test removing a swarm.""" + swarm_rearrange.remove_swarm(mock_agent.name) + assert len(swarm_rearrange.swarms) == 0 + assert mock_agent.name not in swarm_rearrange.swarms + +def test_add_swarms(swarm_rearrange): + """Test adding multiple swarms.""" + new_agents = [Mock(spec=Agent) for _ in range(3)] + swarm_rearrange.add_swarms(new_agents) + assert len(swarm_rearrange.swarms) == 4 + for agent in new_agents: + assert agent in swarm_rearrange.swarms.values() + +def test_track_history(swarm_rearrange, mock_agent): + """Test tracking swarm history.""" + result = "Test result" + swarm_rearrange.track_history(mock_agent.name, result) + assert result in swarm_rearrange.swarm_history[mock_agent.name] + +def test_set_custom_flow(swarm_rearrange): + """Test setting custom flow.""" + new_flow = "Agent1, Agent2 -> Agent3" + swarm_rearrange.set_custom_flow(new_flow) + assert swarm_rearrange.flow == new_flow + +def test_context_manager(swarm_rearrange): + """Test context manager functionality.""" + with swarm_rearrange as db: + assert db == swarm_rearrange + # Verify cleanup was performed + assert not swarm_rearrange.session.is_open() + +def test_error_handling(swarm_rearrange): + """Test error handling in various operations.""" + # Test invalid flow pattern + with pytest.raises(ValueError): + swarm_rearrange.set_custom_flow("Invalid -> Flow -> Pattern") + + # Test removing non-existent swarm + with pytest.raises(KeyError): + swarm_rearrange.remove_swarm("NonExistentSwarm") + +def test_thread_safety(swarm_rearrange): + """Test thread safety of operations.""" + import threading + import time + + def add_swarm_thread(): + for i in range(10): + new_agent = Mock(spec=Agent) + new_agent.name = f"Agent{i}" + swarm_rearrange.add_swarm(new_agent) + time.sleep(0.1) + + def remove_swarm_thread(): + for i in range(10): + try: + swarm_rearrange.remove_swarm(f"Agent{i}") + except KeyError: + pass + time.sleep(0.1) + + # Create and start threads + threads = [ + threading.Thread(target=add_swarm_thread), + threading.Thread(target=remove_swarm_thread) + ] + + for thread in threads: + thread.start() + + for thread in threads: + thread.join() + + # Verify no data corruption occurred + assert isinstance(swarm_rearrange.swarms, dict) + assert isinstance(swarm_rearrange.swarm_history, dict) \ No newline at end of file