"""Tests for the AgentMatrix operations in swarms.structs.matrix_swarm.

The tests can be collected by pytest or executed directly as a script, in
which case ``run_all_tests`` drives them and the process exit status
reflects whether any test failed.
"""

import json
import os
import sys
import tempfile

from swarms.structs.matrix_swarm import AgentMatrix, AgentOutput
from swarms import Agent


def create_test_matrix(rows: int, cols: int) -> AgentMatrix:
    """Helper function to create a test agent matrix.

    Args:
        rows: Number of rows of agents to create.
        cols: Number of agents per row.

    Returns:
        An AgentMatrix whose agent at position (i, j) is named
        ``TestAgent-{i}-{j}``.
    """
    agents = [
        [
            Agent(
                agent_name=f"TestAgent-{i}-{j}",
                system_prompt="Test prompt",
            )
            for j in range(cols)
        ]
        for i in range(rows)
    ]
    return AgentMatrix(agents)


def _expect_value_error(operation) -> None:
    """Call ``operation`` and fail unless it raises ValueError.

    An explicit ``raise AssertionError`` is used instead of ``assert False``
    so the check still fires when Python runs with ``-O`` (which strips
    assert statements). Any exception other than ValueError propagates.
    """
    try:
        operation()
    except ValueError:
        return
    raise AssertionError("Should raise ValueError")


def test_init() -> None:
    """Test AgentMatrix initialization."""
    # Valid initialization
    matrix = create_test_matrix(2, 2)
    assert isinstance(matrix, AgentMatrix)
    assert len(matrix.agents) == 2
    assert len(matrix.agents[0]) == 2

    # Invalid initialization: non-agent elements
    _expect_value_error(lambda: AgentMatrix([[1, 2], [3, 4]]))

    # Invalid initialization: empty matrix
    _expect_value_error(lambda: AgentMatrix([]))


def test_transpose() -> None:
    """Test matrix transpose operation."""
    matrix = create_test_matrix(2, 3)
    transposed = matrix.transpose()

    assert len(transposed.agents) == 3  # Original cols become rows
    assert len(transposed.agents[0]) == 2  # Original rows become cols

    # Verify agent positions
    for i in range(2):
        for j in range(3):
            assert (
                matrix.agents[i][j].agent_name
                == transposed.agents[j][i].agent_name
            )


def test_add() -> None:
    """Test matrix addition."""
    matrix1 = create_test_matrix(2, 2)
    matrix2 = create_test_matrix(2, 2)

    result = matrix1.add(matrix2)
    assert len(result.agents) == 2
    assert len(result.agents[0]) == 2

    # Incompatible dimensions must be rejected
    matrix3 = create_test_matrix(2, 3)
    _expect_value_error(lambda: matrix1.add(matrix3))


def test_scalar_multiply() -> None:
    """Test scalar multiplication.

    Checks that each row grows by a factor of ``scalar`` and that every
    original agent name appears ``scalar`` times consecutively.
    """
    matrix = create_test_matrix(2, 2)
    scalar = 3

    result = matrix.scalar_multiply(scalar)

    assert len(result.agents) == 2
    assert len(result.agents[0]) == 2 * scalar

    # Verify agent duplication
    for i, result_row in enumerate(result.agents):
        # Step through the row one group of duplicates at a time.
        for j in range(0, len(result.agents[0]), scalar):
            original_agent = matrix.agents[i][j // scalar]
            for k in range(scalar):
                assert (
                    result_row[j + k].agent_name
                    == original_agent.agent_name
                )


def test_multiply() -> None:
    """Test matrix multiplication."""
    matrix1 = create_test_matrix(2, 3)
    matrix2 = create_test_matrix(3, 2)
    inputs = ["test query 1", "test query 2"]

    result = matrix1.multiply(matrix2, inputs)

    assert len(result) == 2  # Number of rows in first matrix
    assert len(result[0]) == 2  # Number of columns in second matrix

    # Verify output structure
    for row in result:
        for output in row:
            assert isinstance(output, AgentOutput)
            assert isinstance(output.input_query, str)
            assert isinstance(output.metadata, dict)


def test_subtract() -> None:
    """Test matrix subtraction."""
    matrix1 = create_test_matrix(2, 2)
    matrix2 = create_test_matrix(2, 2)

    result = matrix1.subtract(matrix2)

    assert len(result.agents) == 2
    assert len(result.agents[0]) == 2


def test_identity() -> None:
    """Test identity matrix creation."""
    matrix = create_test_matrix(3, 3)
    identity = matrix.identity(3)

    assert len(identity.agents) == 3
    assert len(identity.agents[0]) == 3

    for i in range(3):
        # Diagonal elements come from the original matrix
        assert (
            identity.agents[i][i].agent_name
            == matrix.agents[i][i].agent_name
        )
        # Non-diagonal elements are zero agents
        for j in range(3):
            if i != j:
                assert identity.agents[i][j].agent_name.startswith(
                    "Zero-Agent"
                )


def test_determinant() -> None:
    """Test determinant calculation."""
    # 1x1 matrix
    matrix1 = create_test_matrix(1, 1)
    det1 = matrix1.determinant()
    assert det1 is not None

    # 2x2 matrix
    matrix2 = create_test_matrix(2, 2)
    det2 = matrix2.determinant()
    assert det2 is not None

    # Non-square matrix must be rejected
    matrix3 = create_test_matrix(2, 3)
    _expect_value_error(matrix3.determinant)


def test_save_to_file(tmp_path) -> None:
    """Test saving matrix to file.

    Args:
        tmp_path: Directory in which the JSON file is written (pytest's
            ``tmp_path`` fixture, or any path accepted by ``os.path.join``).
    """
    matrix = create_test_matrix(2, 2)
    file_path = os.path.join(tmp_path, "test_matrix.json")

    matrix.save_to_file(file_path)

    assert os.path.exists(file_path)

    # Verify file contents
    with open(file_path, "r") as f:
        data = json.load(f)
    assert "agents" in data
    assert "outputs" in data
    assert len(data["agents"]) == 2
    assert len(data["agents"][0]) == 2


def _run_save_to_file_in_temp_dir() -> None:
    """Run test_save_to_file without pytest by supplying a temp directory."""
    with tempfile.TemporaryDirectory() as tmp_dir:
        test_save_to_file(tmp_dir)


def run_all_tests() -> int:
    """Run all test functions and report each result on stdout.

    Returns:
        The number of tests that failed (0 when everything passed).
    """
    test_functions = [
        test_init,
        test_transpose,
        test_add,
        test_scalar_multiply,
        test_multiply,
        test_subtract,
        test_identity,
        test_determinant,
    ]
    test_cases = [(func.__name__, func) for func in test_functions]
    # test_save_to_file needs a directory argument, so it is run through a
    # wrapper but still reported under its own name.
    test_cases.append(
        (test_save_to_file.__name__, _run_save_to_file_in_temp_dir)
    )

    failures = 0
    for name, test_func in test_cases:
        try:
            test_func()
            print(f"✅ {name} passed")
        except AssertionError as e:
            failures += 1
            print(f"❌ {name} failed: {str(e)}")
        except Exception as e:
            # Keep going so one crashing test does not hide the others.
            failures += 1
            print(f"❌ {name} failed with exception: {str(e)}")
    return failures


if __name__ == "__main__":
    # Exit non-zero when any test failed so callers/CI can detect it.
    sys.exit(1 if run_all_tests() else 0)