You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
178 lines
5.6 KiB
178 lines
5.6 KiB
4 weeks ago
|
import asyncio
|
||
|
import time
|
||
|
from swarms.structs.agent import Agent
|
||
|
from swarms.structs.multi_process_workflow import MultiProcessWorkflow
|
||
|
|
||
|
|
||
|
def create_test_agent(name: str) -> Agent:
    """Build an Agent configured for the workflow tests in this file.

    Args:
        name: Used as the agent's ``agent_name`` and embedded in its
            system prompt.

    Returns:
        An ``Agent`` using the "gpt-4o-mini" model, limited to one loop.
    """
    # Collect the keyword settings first so the configuration reads as data.
    agent_settings = {
        "agent_name": name,
        "system_prompt": f"Test prompt for {name}",
        "model_name": "gpt-4o-mini",
        "max_loops": 1,
    }
    return Agent(**agent_settings)
|
||
|
|
||
|
|
||
|
def test_initialization():
    """Construct a workflow and print the configuration it reports.

    Creates three test agents, builds a ``MultiProcessWorkflow`` with
    ``max_workers=2``, and prints its ``max_workers``, agent count and
    ``autosave`` attributes. Any exception is reported on stdout and then
    re-raised so the caller sees the failure.
    """
    print("\n=== Testing Workflow Initialization ===")
    try:
        agent_pool = []
        for index in range(3):
            agent_pool.append(create_test_agent(f"agent{index}"))
        workflow = MultiProcessWorkflow(max_workers=2, agents=agent_pool)

        # Each attribute is read at print time, in this order, so a missing
        # attribute fails after the preceding lines were already shown.
        print("✓ Created workflow with configuration:")
        print(f" - Max workers: {workflow.max_workers}")
        print(f" - Number of agents: {len(workflow.agents)}")
        print(f" - Autosave: {workflow.autosave}")
        print("✓ Initialization test passed")
    except Exception as error:
        print(f"✗ Initialization test failed: {str(error)}")
        raise
|
||
|
|
||
|
|
||
|
def test_execute_task():
    """Run one task through ``execute_task`` and print what comes back.

    Uses a workflow holding a single test agent and otherwise default
    settings. Any exception is reported on stdout and then re-raised.
    """
    print("\n=== Testing Task Execution ===")
    try:
        sole_agent = create_test_agent("test_agent")
        workflow = MultiProcessWorkflow(agents=[sole_agent])

        test_task = "Return this message with timestamp"
        outcome = workflow.execute_task(test_task)

        print("✓ Task executed successfully")
        print(f" - Input task: {test_task}")
        print(f" - Result: {outcome}")
        print("✓ Task execution test passed")
    except Exception as error:
        print(f"✗ Task execution test failed: {str(error)}")
        raise
|
||
|
|
||
|
|
||
|
def test_parallel_run():
    """Run one task through ``workflow.run`` and print the results.

    Builds three test agents and a workflow with ``max_workers=2``. Any
    exception is reported on stdout and then re-raised.
    """
    print("\n=== Testing Parallel Run ===")
    try:
        agent_pool = []
        for index in range(3):
            agent_pool.append(create_test_agent(f"agent{index}"))
        workflow = MultiProcessWorkflow(max_workers=2, agents=agent_pool)

        test_task = "Process this in parallel"
        outputs = workflow.run(test_task)

        print("✓ Parallel execution completed")
        # NOTE(review): unlike the async test, the number of results is not
        # printed here - presumably run() may not return a sized value;
        # confirm before adding a len() call.
        print(f" - Results: {outputs}")
        print("✓ Parallel run test passed")
    except Exception as error:
        print(f"✗ Parallel run test failed: {str(error)}")
        raise
|
||
|
|
||
|
|
||
|
async def test_async_run():
    """Await ``workflow.async_run`` on one task and print the results.

    Builds three test agents and a workflow with ``max_workers=2``, then
    prints both the number of results and the results themselves. Any
    exception (including one from taking ``len`` of the results) is
    reported on stdout and then re-raised.
    """
    print("\n=== Testing Async Run ===")
    try:
        agent_pool = []
        for index in range(3):
            agent_pool.append(create_test_agent(f"agent{index}"))
        workflow = MultiProcessWorkflow(max_workers=2, agents=agent_pool)

        test_task = "Process this asynchronously"
        outputs = await workflow.async_run(test_task)

        print("✓ Async execution completed")
        print(f" - Number of results: {len(outputs)}")
        print(f" - Results: {outputs}")
        print("✓ Async run test passed")
    except Exception as error:
        print(f"✗ Async run test failed: {str(error)}")
        raise
|
||
|
|
||
|
|
||
|
def test_batched_run():
    """Submit five tasks via ``batched_run`` with a batch size of two.

    Uses two test agents and a workflow with ``max_workers=2``. Prints the
    task count, the batch size and the returned results. Any exception is
    reported on stdout and then re-raised.
    """
    print("\n=== Testing Batched Run ===")
    try:
        agent_pool = [
            create_test_agent(f"agent{index}") for index in (0, 1)
        ]
        workflow = MultiProcessWorkflow(max_workers=2, agents=agent_pool)

        task_list = []
        for number in range(5):
            task_list.append(f"Batch task {number}")
        outputs = workflow.batched_run(task_list, batch_size=2)

        print("✓ Batch execution completed")
        print(f" - Number of tasks: {len(task_list)}")
        print(" - Batch size: 2")
        print(f" - Results: {outputs}")
        print("✓ Batched run test passed")
    except Exception as error:
        print(f"✗ Batched run test failed: {str(error)}")
        raise
|
||
|
|
||
|
|
||
|
def test_concurrent_run():
    """Submit four tasks via ``concurrent_run`` and print the results.

    Uses two test agents and a workflow with ``max_workers=2``. Any
    exception is reported on stdout and then re-raised.
    """
    print("\n=== Testing Concurrent Run ===")
    try:
        agent_pool = [
            create_test_agent(f"agent{index}") for index in (0, 1)
        ]
        workflow = MultiProcessWorkflow(max_workers=2, agents=agent_pool)

        task_list = []
        for number in range(4):
            task_list.append(f"Concurrent task {number}")
        outputs = workflow.concurrent_run(task_list)

        print("✓ Concurrent execution completed")
        print(f" - Number of tasks: {len(task_list)}")
        print(f" - Results: {outputs}")
        print("✓ Concurrent run test passed")
    except Exception as error:
        print(f"✗ Concurrent run test failed: {str(error)}")
        raise
|
||
|
|
||
|
|
||
|
def test_error_handling():
    """Call ``execute_task`` on a workflow that was given no agents.

    The workflow is built with ``agents=None``. The test counts as passed
    when ``execute_task`` returns without raising; whatever it returns is
    printed. If an exception does escape, it is reported on stdout and
    then re-raised.
    """
    print("\n=== Testing Error Handling ===")
    try:
        # No agents are supplied, which is the condition under test.
        workflow = MultiProcessWorkflow(max_workers=2, agents=None)
        probe_task = "This should handle the error gracefully"
        outcome = workflow.execute_task(probe_task)

        print("✓ Error handled gracefully")
        print(f" - Result when no agents: {outcome}")
        print("✓ Error handling test passed")
    except Exception as error:
        print(f"✗ Error handling test failed: {str(error)}")
        raise
|
||
|
|
||
|
|
||
|
async def run_all_tests():
    """Run every test in this file in sequence and report the elapsed time.

    The synchronous tests are called directly, then the async test is
    awaited. On success the total duration (rounded to two decimals) is
    printed; on the first failure a failure banner and the error message
    are printed and the exception is re-raised.

    Raises:
        Exception: Whatever the first failing test raised.
    """
    print("\n=== Starting MultiProcessWorkflow Test Suite ===")
    # perf_counter() is monotonic, so the measured duration cannot be
    # skewed (or go negative) if the system wall clock is adjusted while
    # the suite is running, which time.time() does not guarantee.
    start_time = time.perf_counter()

    try:
        # Run synchronous tests
        test_initialization()
        test_execute_task()
        test_parallel_run()
        test_batched_run()
        test_concurrent_run()
        test_error_handling()

        # Run async test
        await test_async_run()

        duration = round(time.perf_counter() - start_time, 2)
        print("\n=== Test Suite Completed Successfully ===")
        print(f"Time taken: {duration} seconds")

    except Exception as e:
        print("\n=== Test Suite Failed ===")
        print(f"Error: {str(e)}")
        raise
|
||
|
|
||
|
|
||
|
# Script entry point: only when executed directly (not on import), run the
# async suite runner to completion via asyncio.run().
if __name__ == "__main__":
    asyncio.run(run_all_tests())
|