You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
314 lines
8.8 KiB
314 lines
8.8 KiB
import time
|
|
|
|
from loguru import logger
|
|
from swarms import Agent
|
|
|
|
from experimental.airflow_swarm import (
|
|
AirflowDAGSwarm,
|
|
NodeType,
|
|
Conversation,
|
|
)
|
|
|
|
# Configure logger: clear loguru's existing sinks, then echo every record at
# DEBUG level and above through print(). end="" presumably avoids a doubled
# newline, since loguru's formatted message is expected to end with one.
logger.remove()
logger.add(
    lambda formatted: print(formatted, end=""),
    level="DEBUG",
)
|
|
|
|
|
|
def test_swarm_initialization():
    """Test basic swarm initialization and configuration.

    Builds an ``AirflowDAGSwarm`` with an ID, a name and an initial message,
    then checks that the ID and name are stored, that no nodes or edges exist
    yet, and that the initial message appears in the conversation history.

    Returns:
        True if every check passes; False if an assertion fails or any other
        exception is raised. The outcome is printed either way.
    """
    try:
        swarm = AirflowDAGSwarm(
            dag_id="test_dag",
            name="Test DAG",
            initial_message="Test message",
        )
        assert swarm.dag_id == "test_dag", "DAG ID not set correctly"
        assert swarm.name == "Test DAG", "Name not set correctly"
        assert (
            len(swarm.nodes) == 0
        ), "Nodes should be empty on initialization"
        assert (
            len(swarm.edges) == 0
        ), "Edges should be empty on initialization"

        # Test initial message
        history = swarm.get_conversation_history()
        assert (
            "Test message" in history
        ), "Initial message not set correctly"

        print("✅ Swarm initialization test passed")
        return True
    except AssertionError as e:
        print(f"❌ Swarm initialization test failed: {str(e)}")
        return False
    except Exception as e:
        # Report unexpected errors (e.g. from the constructor) as a failure
        # instead of letting them propagate out of this function.
        print(
            f"❌ Swarm initialization test failed with unexpected error: {str(e)}"
        )
        return False
|
|
|
|
|
|
def test_node_addition():
    """Check that agent and callable nodes can be registered on a swarm.

    Adds one ``NodeType.AGENT`` node (concurrent, with a query) and one
    ``NodeType.CALLABLE`` node (sequential, with a positional argument) and
    verifies that ``add_node`` returns each ID and stores it in
    ``swarm.nodes``.

    Returns:
        True on success; False when an assertion fails or any other
        exception is raised. The outcome is printed in every case.
    """
    try:
        swarm = AirflowDAGSwarm(dag_id="test_dag")

        # --- Agent-backed node -------------------------------------------
        test_agent = Agent(
            agent_name="Test-Agent",
            system_prompt="Test prompt",
            model_name="gpt-4o-mini",
            max_loops=1,
        )
        returned_agent_id = swarm.add_node(
            "test_agent",
            test_agent,
            NodeType.AGENT,
            query="Test query",
            concurrent=True,
        )
        assert returned_agent_id == "test_agent", "Agent node ID not returned correctly"
        assert "test_agent" in swarm.nodes, "Agent node not added to nodes dict"

        # --- Plain-function node; 42 is supplied through ``args`` ----------
        def test_callable(x: int, conversation: Conversation) -> str:
            return f"Test output {x}"

        returned_callable_id = swarm.add_node(
            "test_callable",
            test_callable,
            NodeType.CALLABLE,
            args=[42],
            concurrent=False,
        )
        assert returned_callable_id == "test_callable", "Callable node ID not returned correctly"
        assert "test_callable" in swarm.nodes, "Callable node not added to nodes dict"

        print("✅ Node addition test passed")
        return True
    except AssertionError as exc:
        print(f"❌ Node addition test failed: {exc}")
        return False
    except Exception as exc:
        print(f"❌ Node addition test failed with unexpected error: {exc}")
        return False
|
|
|
|
|
|
def test_edge_addition():
    """Test adding edges between nodes.

    Registers two callable nodes, connects ``node1 -> node2`` and checks the
    adjacency recorded in ``swarm.edges["node1"]``. Then checks that adding
    an edge to a node that was never registered raises ``ValueError``.

    Returns:
        True if every check passes; False if an assertion fails or any other
        exception is raised. The outcome is printed either way.
    """
    try:
        swarm = AirflowDAGSwarm(dag_id="test_dag")

        # Add two nodes
        def node1_fn(conversation: Conversation) -> str:
            return "Node 1 output"

        def node2_fn(conversation: Conversation) -> str:
            return "Node 2 output"

        swarm.add_node("node1", node1_fn, NodeType.CALLABLE)
        swarm.add_node("node2", node2_fn, NodeType.CALLABLE)

        # Add edge between them
        swarm.add_edge("node1", "node2")

        assert "node2" in swarm.edges["node1"], "Edge not added correctly"
        assert len(swarm.edges["node1"]) == 1, "Incorrect number of edges"

        # An edge to an unknown node must be rejected. The failure is raised
        # from ``else`` so the try body holds only the call under test, and
        # an explicit raise is used rather than ``assert False``.
        try:
            swarm.add_edge("node1", "non_existent")
        except ValueError:
            pass
        else:
            raise AssertionError(
                "Should raise ValueError for non-existent node"
            )

        print("✅ Edge addition test passed")
        return True
    except AssertionError as e:
        print(f"❌ Edge addition test failed: {str(e)}")
        return False
    except Exception as e:
        # E.g. a lookup error on swarm.edges["node1"], or add_edge raising
        # something other than ValueError: report it instead of propagating.
        print(f"❌ Edge addition test failed with unexpected error: {str(e)}")
        return False
|
|
|
|
|
|
def test_execution_order():
    """Test that nodes are executed in the correct order based on dependencies.

    Builds the chain ``node1 -> node2 -> node3`` from sequential callable
    nodes that record their own name when invoked. After ``swarm.run()``,
    checks that the recorded order matches the chain.

    Returns:
        True if the order is correct; False if an assertion fails or any
        other exception is raised. The outcome is printed either way.
    """
    try:
        swarm = AirflowDAGSwarm(dag_id="test_dag")
        execution_order = []

        def node1(conversation: Conversation) -> str:
            execution_order.append("node1")
            return "Node 1 output"

        def node2(conversation: Conversation) -> str:
            execution_order.append("node2")
            return "Node 2 output"

        def node3(conversation: Conversation) -> str:
            execution_order.append("node3")
            return "Node 3 output"

        # Add nodes
        swarm.add_node(
            "node1", node1, NodeType.CALLABLE, concurrent=False
        )
        swarm.add_node(
            "node2", node2, NodeType.CALLABLE, concurrent=False
        )
        swarm.add_node(
            "node3", node3, NodeType.CALLABLE, concurrent=False
        )

        # Add edges to create a chain: node1 -> node2 -> node3
        swarm.add_edge("node1", "node2")
        swarm.add_edge("node2", "node3")

        # Execute
        swarm.run()

        # Check execution order
        assert execution_order == [
            "node1",
            "node2",
            "node3",
        ], "Incorrect execution order"
        print("✅ Execution order test passed")
        return True
    except AssertionError as e:
        print(f"❌ Execution order test failed: {str(e)}")
        return False
    except Exception as e:
        # Report unexpected errors (e.g. raised by swarm.run()) as a failure
        # instead of letting them propagate out of this function.
        print(f"❌ Execution order test failed with unexpected error: {str(e)}")
        return False
|
|
|
|
|
|
def test_concurrent_execution():
    """Test concurrent execution of nodes.

    Registers two independent callable nodes with ``concurrent=True``, each
    sleeping 0.5 s. Checks that ``swarm.run()`` finishes in under 0.8 s,
    i.e. noticeably faster than the ~1 s a sequential run would take.

    Returns:
        True if the run is fast enough; False if an assertion fails or any
        other exception is raised. The outcome is printed either way.
    """
    try:
        swarm = AirflowDAGSwarm(dag_id="test_dag")

        def slow_node1(conversation: Conversation) -> str:
            time.sleep(0.5)
            return "Slow node 1 output"

        def slow_node2(conversation: Conversation) -> str:
            time.sleep(0.5)
            return "Slow node 2 output"

        # Add nodes with concurrent=True
        swarm.add_node(
            "slow1", slow_node1, NodeType.CALLABLE, concurrent=True
        )
        swarm.add_node(
            "slow2", slow_node2, NodeType.CALLABLE, concurrent=True
        )

        # Measure with a monotonic, high-resolution clock. time.time() is
        # wall-clock time and can jump when the system clock is adjusted,
        # which would skew a sub-second duration measurement.
        start_time = time.perf_counter()
        swarm.run()
        execution_time = time.perf_counter() - start_time

        # Should take ~0.5s for concurrent execution, not ~1s
        assert (
            execution_time < 0.8
        ), "Concurrent execution took too long"
        print("✅ Concurrent execution test passed")
        return True
    except AssertionError as e:
        print(f"❌ Concurrent execution test failed: {str(e)}")
        return False
    except Exception as e:
        # Report unexpected errors (e.g. raised by swarm.run()) as a failure
        # instead of letting them propagate out of this function.
        print(
            f"❌ Concurrent execution test failed with unexpected error: {str(e)}"
        )
        return False
|
|
|
|
|
|
def test_conversation_handling():
    """Test conversation management within the swarm.

    Seeds the swarm with an initial message and adds two user messages via
    ``add_user_message``. Then checks that all three texts appear in
    ``get_conversation_history()``.

    Returns:
        True if every message is found; False if an assertion fails or any
        other exception is raised. The outcome is printed either way.
    """
    try:
        swarm = AirflowDAGSwarm(
            dag_id="test_dag", initial_message="Initial test message"
        )

        # Test adding user messages
        swarm.add_user_message("Test message 1")
        swarm.add_user_message("Test message 2")

        history = swarm.get_conversation_history()
        assert (
            "Initial test message" in history
        ), "Initial message not in history"
        assert (
            "Test message 1" in history
        ), "First message not in history"
        assert (
            "Test message 2" in history
        ), "Second message not in history"

        print("✅ Conversation handling test passed")
        return True
    except AssertionError as e:
        print(f"❌ Conversation handling test failed: {str(e)}")
        return False
    except Exception as e:
        # Report unexpected errors (e.g. from the swarm's message API) as a
        # failure instead of letting them propagate out of this function.
        print(
            f"❌ Conversation handling test failed with unexpected error: {str(e)}"
        )
        return False
|
|
|
|
|
|
def test_error_handling():
    """Verify that a failing node does not make ``swarm.run()`` raise.

    Registers a single callable node that raises ``ValueError("Test error")``.
    Checks that ``run()`` returns normally and that its result mentions both
    ``"Error"`` and the original message text.

    Returns:
        True on success; False if any exception, including a failed
        assertion, occurs. The outcome is printed in every case.
    """
    try:
        swarm = AirflowDAGSwarm(dag_id="test_dag")

        def failing_node(conversation: Conversation) -> str:
            raise ValueError("Test error")

        swarm.add_node("failing", failing_node, NodeType.CALLABLE)

        # run() is expected to capture the node's exception and surface it
        # in its return value rather than raising.
        outcome = swarm.run()

        expectations = (
            ("Error", "Error not captured in execution result"),
            ("Test error", "Specific error message not captured"),
        )
        for needle, failure_message in expectations:
            assert needle in outcome, failure_message

        print("✅ Error handling test passed")
        return True
    except Exception as exc:
        print(f"❌ Error handling test failed: {exc}")
        return False
|
|
|
|
|
|
def run_all_tests():
    """Run all test functions, print a summary, and report overall success.

    Each test is expected to return True on success. A test that raises
    instead of returning is counted as failed, and its error is printed, so
    one broken test cannot abort the rest of the run.

    Returns:
        True if every test passed, False otherwise.
    """
    tests = [
        test_swarm_initialization,
        test_node_addition,
        test_edge_addition,
        test_execution_order,
        test_concurrent_execution,
        test_conversation_handling,
        test_error_handling,
    ]

    results = []
    for test in tests:
        print(f"\nRunning {test.__name__}...")
        try:
            outcome = test()
        except Exception as e:
            print(f"❌ {test.__name__} raised an unexpected error: {e}")
            outcome = False
        # Only an explicit True counts as a pass.
        results.append(outcome is True)

    total = len(results)
    passed = sum(results)
    print("\n=== Test Results ===")
    print(f"Total tests: {total}")
    print(f"Passed: {passed}")
    print(f"Failed: {total - passed}")
    print("==================")
    return passed == total


if __name__ == "__main__":
    # Propagate the outcome as the process exit status so CI/shells notice
    # failures (0 = all passed, 1 = at least one failed).
    raise SystemExit(0 if run_all_tests() else 1)
|