You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
swarms/tests/structs/test_airflow_swarm.py

314 lines
8.8 KiB

import time
from loguru import logger
from swarms import Agent
from experimental.airflow_swarm import (
AirflowDAGSwarm,
NodeType,
Conversation,
)
# Configure logger: drop every existing loguru handler, then send all
# records at DEBUG level and above through print() without adding newlines
# (each formatted loguru message already ends with one).
logger.remove()
logger.add(
    lambda message: print(message, end=""),
    level="DEBUG",
)
def test_swarm_initialization() -> bool:
    """Test basic swarm initialization and configuration.

    Builds an ``AirflowDAGSwarm`` with an explicit id, name and initial
    message. It then checks three things:

    * the id and name attributes are stored;
    * the node and edge collections start empty;
    * the initial message appears in the conversation history.

    Returns:
        True if every check passed. False (after printing the reason) on
        an assertion failure or any other exception, so ``run_all_tests``
        can continue with the remaining tests.
        NOTE(review): pytest only warns about a non-None return value, so a
        False result here is not reported as a pytest failure.
    """
    try:
        swarm = AirflowDAGSwarm(
            dag_id="test_dag",
            name="Test DAG",
            initial_message="Test message",
        )
        assert swarm.dag_id == "test_dag", "DAG ID not set correctly"
        assert swarm.name == "Test DAG", "Name not set correctly"
        assert (
            len(swarm.nodes) == 0
        ), "Nodes should be empty on initialization"
        assert (
            len(swarm.edges) == 0
        ), "Edges should be empty on initialization"

        # Test initial message
        conv_json = swarm.get_conversation_history()
        assert (
            "Test message" in conv_json
        ), "Initial message not set correctly"

        print("✅ Swarm initialization test passed")
        return True
    except AssertionError as e:
        print(f"❌ Swarm initialization test failed: {e}")
        return False
    except Exception as e:
        # Any other error (e.g. raised by the constructor) is a failure of
        # this test, not a reason to abort the whole test run.
        print(
            f"❌ Swarm initialization test failed with unexpected error: {e}"
        )
        return False
def test_node_addition():
    """Exercise ``add_node`` with an agent node and a callable node.

    For each node kind, the id passed to ``add_node`` must come back as the
    return value and must be present in ``swarm.nodes``.

    Returns:
        True when every check holds. False (after printing the reason) on
        an assertion failure or any other exception.
        NOTE(review): pytest only warns about a non-None return value, so a
        False result here is not reported as a pytest failure.
    """
    try:
        swarm = AirflowDAGSwarm(dag_id="test_dag")

        # --- agent node (added with query="Test query", concurrent=True) ---
        agent = Agent(
            agent_name="Test-Agent",
            system_prompt="Test prompt",
            model_name="gpt-4o-mini",
            max_loops=1,
        )
        returned_agent_id = swarm.add_node(
            "test_agent",
            agent,
            NodeType.AGENT,
            query="Test query",
            concurrent=True,
        )
        assert returned_agent_id == "test_agent", (
            "Agent node ID not returned correctly"
        )
        assert "test_agent" in swarm.nodes, (
            "Agent node not added to nodes dict"
        )

        # --- callable node (added with args=[42], concurrent=False) ---
        def test_callable(x: int, conversation: Conversation) -> str:
            return f"Test output {x}"

        returned_callable_id = swarm.add_node(
            "test_callable",
            test_callable,
            NodeType.CALLABLE,
            args=[42],
            concurrent=False,
        )
        assert returned_callable_id == "test_callable", (
            "Callable node ID not returned correctly"
        )
        assert "test_callable" in swarm.nodes, (
            "Callable node not added to nodes dict"
        )

        print("✅ Node addition test passed")
        return True
    except AssertionError as err:
        print(f"❌ Node addition test failed: {err!s}")
        return False
    except Exception as err:
        print(
            f"❌ Node addition test failed with unexpected error: {err!s}"
        )
        return False
def test_edge_addition() -> bool:
    """Test adding edges between nodes.

    Adds two callable nodes and connects node1 -> node2. It then checks:

    * the edge is recorded under ``swarm.edges["node1"]``, and is the only
      one there;
    * adding an edge to an unknown node raises ``ValueError``.

    Returns:
        True if every check passed. False (after printing the reason) on
        an assertion failure or any other exception.
        NOTE(review): pytest only warns about a non-None return value, so a
        False result here is not reported as a pytest failure.
    """
    try:
        swarm = AirflowDAGSwarm(dag_id="test_dag")

        # Add two nodes
        def node1_fn(conversation: Conversation) -> str:
            return "Node 1 output"

        def node2_fn(conversation: Conversation) -> str:
            return "Node 2 output"

        swarm.add_node("node1", node1_fn, NodeType.CALLABLE)
        swarm.add_node("node2", node2_fn, NodeType.CALLABLE)

        # Add edge between them
        swarm.add_edge("node1", "node2")
        assert (
            "node2" in swarm.edges["node1"]
        ), "Edge not added correctly"
        assert (
            len(swarm.edges["node1"]) == 1
        ), "Incorrect number of edges"

        # An edge to a non-existent node must be rejected with ValueError.
        # The try/except/else uses an explicit raise rather than
        # `assert False`: a bare assert is removed under `python -O`, which
        # would make this check pass silently.
        try:
            swarm.add_edge("node1", "non_existent")
        except ValueError:
            pass
        else:
            raise AssertionError(
                "Should raise ValueError for non-existent node"
            )

        print("✅ Edge addition test passed")
        return True
    except AssertionError as e:
        print(f"❌ Edge addition test failed: {e}")
        return False
    except Exception as e:
        # e.g. a KeyError from swarm.edges["node1"] if the edge was never
        # recorded; report it instead of aborting the whole run.
        print(f"❌ Edge addition test failed with unexpected error: {e}")
        return False
def test_execution_order() -> bool:
    """Test that nodes are executed in the correct order based on dependencies.

    Three callable nodes added with ``concurrent=False`` are chained
    node1 -> node2 -> node3. Each node appends its name to a shared list
    when it runs. After ``run()``, that list must equal the chain exactly,
    which also catches a node being skipped or run twice.

    Returns:
        True if the order matched. False (after printing the reason) on an
        assertion failure or any other exception.
        NOTE(review): pytest only warns about a non-None return value, so a
        False result here is not reported as a pytest failure.
    """
    try:
        swarm = AirflowDAGSwarm(dag_id="test_dag")
        execution_order = []

        def node1(conversation: Conversation) -> str:
            execution_order.append("node1")
            return "Node 1 output"

        def node2(conversation: Conversation) -> str:
            execution_order.append("node2")
            return "Node 2 output"

        def node3(conversation: Conversation) -> str:
            execution_order.append("node3")
            return "Node 3 output"

        # Add nodes
        swarm.add_node(
            "node1", node1, NodeType.CALLABLE, concurrent=False
        )
        swarm.add_node(
            "node2", node2, NodeType.CALLABLE, concurrent=False
        )
        swarm.add_node(
            "node3", node3, NodeType.CALLABLE, concurrent=False
        )

        # Add edges to create a chain: node1 -> node2 -> node3
        swarm.add_edge("node1", "node2")
        swarm.add_edge("node2", "node3")

        # Execute
        swarm.run()

        # Check execution order; show the observed order on failure.
        expected = ["node1", "node2", "node3"]
        assert execution_order == expected, (
            f"Incorrect execution order: expected {expected}, "
            f"got {execution_order}"
        )

        print("✅ Execution order test passed")
        return True
    except AssertionError as e:
        print(f"❌ Execution order test failed: {e}")
        return False
    except Exception as e:
        # run() or add_edge() raising is a test failure, not a runner crash.
        print(f"❌ Execution order test failed with unexpected error: {e}")
        return False
def test_concurrent_execution() -> bool:
    """Test concurrent execution of nodes.

    Two independent callable nodes each sleep 0.5 s and are added with
    ``concurrent=True``. If they overlap, ``run()`` should finish in about
    0.5 s rather than the ~1 s a sequential run needs; the limit is 0.8 s.
    The test also checks that both nodes actually ran. Without that, a
    ``run()`` that skipped them would trivially meet the time limit.

    Returns:
        True if both nodes ran within the limit. False (after printing the
        reason) on an assertion failure or any other exception.
        NOTE(review): pytest only warns about a non-None return value, so a
        False result here is not reported as a pytest failure.
    """
    try:
        swarm = AirflowDAGSwarm(dag_id="test_dag")
        # Names of nodes that finished. Assumes the swarm runs nodes in this
        # process (threads/async) so the closures share this list — confirm
        # if the executor ever changes.
        completed = []

        def slow_node1(conversation: Conversation) -> str:
            time.sleep(0.5)
            completed.append("slow1")
            return "Slow node 1 output"

        def slow_node2(conversation: Conversation) -> str:
            time.sleep(0.5)
            completed.append("slow2")
            return "Slow node 2 output"

        # Add nodes with concurrent=True
        swarm.add_node(
            "slow1", slow_node1, NodeType.CALLABLE, concurrent=True
        )
        swarm.add_node(
            "slow2", slow_node2, NodeType.CALLABLE, concurrent=True
        )

        # Measure execution time. perf_counter() is monotonic and
        # high-resolution; time.time() follows the wall clock and can jump
        # if the system clock is adjusted mid-measurement.
        start_time = time.perf_counter()
        swarm.run()
        execution_time = time.perf_counter() - start_time

        assert sorted(completed) == ["slow1", "slow2"], (
            f"Expected both nodes to run, got {completed}"
        )
        # Should take ~0.5s for concurrent execution, not ~1s
        assert execution_time < 0.8, (
            f"Concurrent execution took too long: {execution_time:.2f}s"
        )

        print("✅ Concurrent execution test passed")
        return True
    except AssertionError as e:
        print(f"❌ Concurrent execution test failed: {e}")
        return False
    except Exception as e:
        print(
            f"❌ Concurrent execution test failed with unexpected error: {e}"
        )
        return False
def test_conversation_handling() -> bool:
    """Test conversation management within the swarm.

    Creates a swarm with an initial message and adds two user messages via
    ``add_user_message``. All three texts must then appear in
    ``get_conversation_history()``.

    Returns:
        True if every message was found. False (after printing the reason)
        on an assertion failure or any other exception.
        NOTE(review): pytest only warns about a non-None return value, so a
        False result here is not reported as a pytest failure.
    """
    try:
        swarm = AirflowDAGSwarm(
            dag_id="test_dag", initial_message="Initial test message"
        )

        # Test adding user messages
        swarm.add_user_message("Test message 1")
        swarm.add_user_message("Test message 2")

        history = swarm.get_conversation_history()
        assert (
            "Initial test message" in history
        ), "Initial message not in history"
        assert (
            "Test message 1" in history
        ), "First message not in history"
        assert (
            "Test message 2" in history
        ), "Second message not in history"

        print("✅ Conversation handling test passed")
        return True
    except AssertionError as e:
        print(f"❌ Conversation handling test failed: {e}")
        return False
    except Exception as e:
        # Errors from the swarm's conversation API count as a failure of
        # this test instead of aborting the whole run.
        print(
            f"❌ Conversation handling test failed with unexpected error: {e}"
        )
        return False
def test_error_handling():
    """Check that a node raising an exception does not crash ``swarm.run``.

    A single callable node raises ``ValueError("Test error")``. The value
    returned by ``run()`` must contain both "Error" and "Test error".

    Returns:
        True if run() returned normally and both checks held. False (after
        printing why) if anything raised: a failed assertion, run() itself,
        or a TypeError from testing membership in an unexpected result type.
        NOTE(review): pytest only warns about a non-None return value, so a
        False result here is not reported as a pytest failure.
    """
    try:
        swarm = AirflowDAGSwarm(dag_id="test_dag")

        def failing_node(conversation: Conversation) -> str:
            raise ValueError("Test error")

        swarm.add_node("failing", failing_node, NodeType.CALLABLE)

        # run() is expected to report the node failure in its output
        # instead of propagating the exception.
        run_output = swarm.run()

        expectations = (
            ("Error", "Error not captured in execution result"),
            ("Test error", "Specific error message not captured"),
        )
        for needle, failure_msg in expectations:
            assert needle in run_output, failure_msg

        print("✅ Error handling test passed")
        return True
    except Exception as err:
        print(f"❌ Error handling test failed: {err!s}")
        return False
def run_all_tests() -> bool:
    """Run all test functions and report results.

    Each test is called in turn and counted as passed when it returns a
    truthy value, or returns ``None`` (pytest style: no exception means
    success). A test that raises is recorded as failed and the exception is
    printed, so one crashing test cannot abort the run or suppress the
    summary.

    Returns:
        True if every test passed, False otherwise. Previously this returned
        None; callers that ignored the result are unaffected.
    """
    tests = [
        test_swarm_initialization,
        test_node_addition,
        test_edge_addition,
        test_execution_order,
        test_concurrent_execution,
        test_conversation_handling,
        test_error_handling,
    ]

    results = []
    for test in tests:
        print(f"\nRunning {test.__name__}...")
        try:
            result = test()
        except Exception as e:
            # Some tests only catch AssertionError. Without this guard, any
            # other error would end the loop before the summary is printed.
            print(f"❌ {test.__name__} raised an unexpected error: {e!r}")
            result = False
        # Normalize to bool so sum() below cannot fail on a None result.
        results.append(result is None or bool(result))

    total = len(results)
    passed = sum(results)

    print("\n=== Test Results ===")
    print(f"Total tests: {total}")
    print(f"Passed: {passed}")
    print(f"Failed: {total - passed}")
    print("==================")

    return passed == total
if __name__ == "__main__":
    # Only when executed as a script (not when imported, e.g. by pytest):
    # run the ad-hoc suite and print its summary.
    run_all_tests()