You cannot select more than 25 topics.
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
					
					
						
							314 lines
						
					
					
						
							8.8 KiB
						
					
					
				
			
		
		
	
	
							314 lines
						
					
					
						
							8.8 KiB
						
					
					
				| import time
 | |
| 
 | |
| from loguru import logger
 | |
| from swarms import Agent
 | |
| 
 | |
| from experimental.airflow_swarm import (
 | |
|     AirflowDAGSwarm,
 | |
|     NodeType,
 | |
|     Conversation,
 | |
| )
 | |
| 
 | |
# Configure logger: drop loguru's default handler and forward every record of
# DEBUG level or above to print(), which writes it without an extra newline.
def _print_log_message(message):
    print(message, end="")


logger.remove()
logger.add(_print_log_message, level="DEBUG")
 | |
| 
 | |
| 
 | |
def test_swarm_initialization():
    """Test basic swarm initialization and configuration.

    Builds a swarm with an explicit DAG id, name and initial message, then
    checks that those attributes are stored, that the swarm starts with no
    nodes and no edges, and that the initial message appears in the
    conversation history.

    Returns:
        bool: True if every check passed, False otherwise. Exceptions other
        than assertion failures (e.g. raised by the constructor) are printed
        and reported as a failure instead of propagating, so the caller can
        continue with the remaining tests.
    """
    try:
        swarm = AirflowDAGSwarm(
            dag_id="test_dag",
            name="Test DAG",
            initial_message="Test message",
        )
        assert swarm.dag_id == "test_dag", "DAG ID not set correctly"
        assert swarm.name == "Test DAG", "Name not set correctly"
        assert (
            len(swarm.nodes) == 0
        ), "Nodes should be empty on initialization"
        assert (
            len(swarm.edges) == 0
        ), "Edges should be empty on initialization"

        # Test initial message
        conv_json = swarm.get_conversation_history()
        assert (
            "Test message" in conv_json
        ), "Initial message not set correctly"
        print("✅ Swarm initialization test passed")
        return True
    except AssertionError as e:
        print(f"❌ Swarm initialization test failed: {e}")
        return False
    except Exception as e:
        # An error outside the assertions is still a failed test, not a
        # reason to abort the whole test run.
        print(
            f"❌ Swarm initialization test failed with unexpected error: {e}"
        )
        return False
 | |
| 
 | |
| 
 | |
def test_node_addition():
    """Check that agent nodes and callable nodes can be registered.

    First an ``Agent`` is added as a concurrent node with a query, then a
    plain function is added as a sequential node with positional args. For
    each one, the id returned by ``add_node`` and the node's presence in
    ``swarm.nodes`` are verified.

    Returns:
        bool: True when all checks pass; False on an assertion failure or on
        any other exception (either way a message is printed).
    """
    try:
        dag_swarm = AirflowDAGSwarm(dag_id="test_dag")

        # --- agent node ---------------------------------------------------
        test_agent = Agent(
            agent_name="Test-Agent",
            system_prompt="Test prompt",
            model_name="gpt-4o-mini",
            max_loops=1,
        )
        returned_agent_id = dag_swarm.add_node(
            "test_agent",
            test_agent,
            NodeType.AGENT,
            query="Test query",
            concurrent=True,
        )
        assert returned_agent_id == "test_agent", (
            "Agent node ID not returned correctly"
        )
        assert "test_agent" in dag_swarm.nodes, (
            "Agent node not added to nodes dict"
        )

        # --- callable node ------------------------------------------------
        def test_callable(x: int, conversation: Conversation) -> str:
            return f"Test output {x}"

        returned_callable_id = dag_swarm.add_node(
            "test_callable",
            test_callable,
            NodeType.CALLABLE,
            args=[42],
            concurrent=False,
        )
        assert returned_callable_id == "test_callable", (
            "Callable node ID not returned correctly"
        )
        assert "test_callable" in dag_swarm.nodes, (
            "Callable node not added to nodes dict"
        )
    except AssertionError as failure:
        print(f"❌ Node addition test failed: {str(failure)}")
        return False
    except Exception as unexpected:
        print(
            f"❌ Node addition test failed with unexpected error: {str(unexpected)}"
        )
        return False

    print("✅ Node addition test passed")
    return True
 | |
| 
 | |
| 
 | |
def test_edge_addition():
    """Test adding edges between nodes.

    Adds two callable nodes, links ``node1 -> node2`` and checks that the
    edge is recorded exactly once under ``node1``. Then verifies that
    linking to a node that was never added raises ``ValueError``.

    Returns:
        bool: True if all checks pass, False otherwise. Exceptions other
        than assertion failures (for instance a ``KeyError`` from
        ``swarm.edges["node1"]`` if no edge list was created) are printed
        and reported as a failure instead of aborting the test run.
    """
    try:
        swarm = AirflowDAGSwarm(dag_id="test_dag")

        # Add two nodes
        def node1_fn(conversation: Conversation) -> str:
            return "Node 1 output"

        def node2_fn(conversation: Conversation) -> str:
            return "Node 2 output"

        swarm.add_node("node1", node1_fn, NodeType.CALLABLE)
        swarm.add_node("node2", node2_fn, NodeType.CALLABLE)

        # Add edge between them
        swarm.add_edge("node1", "node2")

        assert (
            "node2" in swarm.edges["node1"]
        ), "Edge not added correctly"
        assert (
            len(swarm.edges["node1"]) == 1
        ), "Incorrect number of edges"

        # Test adding edge with non-existent node. The "no exception" case is
        # handled in the ``else`` clause with an explicit raise, rather than
        # an ``assert False`` inside the ``try``, so it cannot be stripped by
        # ``python -O`` and is kept apart from the call expected to raise.
        try:
            swarm.add_edge("node1", "non_existent")
        except ValueError:
            pass
        else:
            raise AssertionError(
                "Should raise ValueError for non-existent node"
            )

        print("✅ Edge addition test passed")
        return True
    except AssertionError as e:
        print(f"❌ Edge addition test failed: {e}")
        return False
    except Exception as e:
        print(f"❌ Edge addition test failed with unexpected error: {e}")
        return False
 | |
| 
 | |
| 
 | |
def test_execution_order():
    """Test that nodes are executed in the correct order based on dependencies.

    Builds the chain ``node1 -> node2 -> node3`` of sequential
    (``concurrent=False``) callable nodes. Each node appends its own name to
    a shared list when invoked, and the test checks that ``run()`` visited
    them in dependency order.

    Returns:
        bool: True if the observed order matched, False otherwise.
        Exceptions raised while building or running the DAG are printed and
        reported as a failure instead of aborting the test run.
    """
    try:
        swarm = AirflowDAGSwarm(dag_id="test_dag")
        execution_order = []

        def node1(conversation: Conversation) -> str:
            execution_order.append("node1")
            return "Node 1 output"

        def node2(conversation: Conversation) -> str:
            execution_order.append("node2")
            return "Node 2 output"

        def node3(conversation: Conversation) -> str:
            execution_order.append("node3")
            return "Node 3 output"

        # Add nodes
        swarm.add_node(
            "node1", node1, NodeType.CALLABLE, concurrent=False
        )
        swarm.add_node(
            "node2", node2, NodeType.CALLABLE, concurrent=False
        )
        swarm.add_node(
            "node3", node3, NodeType.CALLABLE, concurrent=False
        )

        # Add edges to create a chain: node1 -> node2 -> node3
        swarm.add_edge("node1", "node2")
        swarm.add_edge("node2", "node3")

        # Execute
        swarm.run()

        # Check execution order; include the observed order for diagnosis.
        expected_order = ["node1", "node2", "node3"]
        assert (
            execution_order == expected_order
        ), f"Incorrect execution order: {execution_order}"
        print("✅ Execution order test passed")
        return True
    except AssertionError as e:
        print(f"❌ Execution order test failed: {e}")
        return False
    except Exception as e:
        print(f"❌ Execution order test failed with unexpected error: {e}")
        return False
 | |
| 
 | |
| 
 | |
def test_concurrent_execution():
    """Test concurrent execution of nodes.

    Two independent callable nodes, each sleeping ``node_delay`` seconds,
    are registered with ``concurrent=True``. Run one after another they need
    about ``2 * node_delay``; run in parallel, about ``node_delay``. The
    measured wall time must lie between (roughly) ``node_delay`` and
    ``max_parallel_time``:
    - the lower bound catches a ``run()`` that returns before any node has
      finished sleeping (e.g. nothing was executed);
    - the upper bound catches nodes that were run sequentially.

    Returns:
        bool: True if the timing checks pass, False otherwise. Exceptions
        raised while building or running the DAG are printed and reported as
        a failure instead of aborting the test run.
    """
    node_delay = 0.5  # seconds each node sleeps
    max_parallel_time = 0.8  # clearly below the ~1.0 s a sequential run needs
    # time.sleep() sleeps at least the requested time; the 10% slack only
    # absorbs clock-resolution differences between sleep and perf_counter.
    min_expected_time = node_delay * 0.9

    try:
        swarm = AirflowDAGSwarm(dag_id="test_dag")

        def slow_node1(conversation: Conversation) -> str:
            time.sleep(node_delay)
            return "Slow node 1 output"

        def slow_node2(conversation: Conversation) -> str:
            time.sleep(node_delay)
            return "Slow node 2 output"

        # Add nodes with concurrent=True
        swarm.add_node(
            "slow1", slow_node1, NodeType.CALLABLE, concurrent=True
        )
        swarm.add_node(
            "slow2", slow_node2, NodeType.CALLABLE, concurrent=True
        )

        # perf_counter is monotonic and high-resolution, unlike time.time(),
        # which can jump if the system clock is adjusted mid-measurement.
        start_time = time.perf_counter()
        swarm.run()
        execution_time = time.perf_counter() - start_time

        assert execution_time >= min_expected_time, (
            f"Nodes do not appear to have run "
            f"(finished in {execution_time:.2f}s)"
        )
        # Should take ~node_delay for concurrent execution, not ~2*node_delay
        assert execution_time < max_parallel_time, (
            f"Concurrent execution took too long ({execution_time:.2f}s)"
        )
        print("✅ Concurrent execution test passed")
        return True
    except AssertionError as e:
        print(f"❌ Concurrent execution test failed: {e}")
        return False
    except Exception as e:
        print(
            f"❌ Concurrent execution test failed with unexpected error: {e}"
        )
        return False
 | |
| 
 | |
| 
 | |
def test_conversation_handling():
    """Test conversation management within the swarm.

    Creates a swarm with an initial message, appends two user messages via
    ``add_user_message``, and checks that all three texts appear in the
    value returned by ``get_conversation_history``.

    Returns:
        bool: True if every message was found, False otherwise. Exceptions
        other than assertion failures are printed and reported as a failure
        instead of aborting the test run.
    """
    try:
        swarm = AirflowDAGSwarm(
            dag_id="test_dag", initial_message="Initial test message"
        )

        # Test adding user messages
        swarm.add_user_message("Test message 1")
        swarm.add_user_message("Test message 2")

        history = swarm.get_conversation_history()

        # (text that must appear in the history, failure message)
        expected_messages = (
            ("Initial test message", "Initial message not in history"),
            ("Test message 1", "First message not in history"),
            ("Test message 2", "Second message not in history"),
        )
        for text, failure_message in expected_messages:
            assert text in history, failure_message

        print("✅ Conversation handling test passed")
        return True
    except AssertionError as e:
        print(f"❌ Conversation handling test failed: {e}")
        return False
    except Exception as e:
        print(
            f"❌ Conversation handling test failed with unexpected error: {e}"
        )
        return False
 | |
| 
 | |
| 
 | |
def test_error_handling():
    """Test error handling in node execution.

    Registers a callable node that always raises ``ValueError("Test error")``
    and checks two things:
    - ``run()`` does not propagate the exception;
    - the returned result mentions both ``"Error"`` and the original message.

    Returns:
        bool: True if the failure was captured in the result, False
        otherwise. Assertion failures and exceptions that escape setup or
        ``run()`` are printed with distinct messages so the two failure modes
        can be told apart.
    """
    try:
        swarm = AirflowDAGSwarm(dag_id="test_dag")

        def failing_node(conversation: Conversation) -> str:
            raise ValueError("Test error")

        swarm.add_node("failing", failing_node, NodeType.CALLABLE)

        # Execute should not raise an exception
        result = swarm.run()

        # Without this check a None result would surface as an opaque
        # TypeError from the membership tests below.
        assert result is not None, "run() returned no result"
        assert (
            "Error" in result
        ), "Error not captured in execution result"
        assert (
            "Test error" in result
        ), "Specific error message not captured"

        print("✅ Error handling test passed")
        return True
    except AssertionError as e:
        print(f"❌ Error handling test failed: {e}")
        return False
    except Exception as e:
        # Reaching here means an exception escaped (e.g. run() re-raised the
        # node's error) instead of being captured in the result.
        print(
            f"❌ Error handling test failed with unexpected error: "
            f"{type(e).__name__}: {e}"
        )
        return False
 | |
| 
 | |
| 
 | |
def run_all_tests(tests=None):
    """Run test functions and print a pass/fail summary.

    Args:
        tests: Optional iterable of zero-argument callables. Each one should
            return a truthy value on success and a falsy value on failure.
            Defaults to every test function defined in this module. A test
            that raises instead of returning is counted as failed, and the
            remaining tests still run.

    Returns:
        bool: True if every test passed (vacuously True for an empty
        iterable), False otherwise.
    """
    if tests is None:
        tests = [
            test_swarm_initialization,
            test_node_addition,
            test_edge_addition,
            test_execution_order,
            test_concurrent_execution,
            test_conversation_handling,
            test_error_handling,
        ]

    results = []
    for test in tests:
        test_name = getattr(test, "__name__", repr(test))
        print(f"\nRunning {test_name}...")
        try:
            # bool() so that sum() below works even if a test returns None.
            result = bool(test())
        except Exception as e:
            # Keep going: one crashing test must not hide the others' results.
            print(f"❌ {test_name} raised {type(e).__name__}: {e}")
            result = False
        results.append(result)

    total = len(results)
    passed = sum(results)
    print("\n=== Test Results ===")
    print(f"Total tests: {total}")
    print(f"Passed: {passed}")
    print(f"Failed: {total - passed}")
    print("==================")
    return passed == total


if __name__ == "__main__":
    import sys

    # Exit non-zero on any failure so CI and shell callers can detect it.
    sys.exit(0 if run_all_tests() else 1)
 |