You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
swarms/tests/structs/test_multiprocess.py

178 lines
5.6 KiB

import asyncio
import time
from swarms.structs.agent import Agent
from swarms.structs.multi_process_workflow import MultiProcessWorkflow
def create_test_agent(name: str) -> Agent:
    """Build a minimal Agent for use by the workflow tests.

    Args:
        name: Used as the agent's name and embedded in its system prompt.

    Returns:
        An Agent configured with the "gpt-4o-mini" model and ``max_loops=1``.
    """
    settings = {
        "agent_name": name,
        "system_prompt": f"Test prompt for {name}",
        "model_name": "gpt-4o-mini",
        "max_loops": 1,
    }
    return Agent(**settings)
def test_initialization():
    """Check that a workflow reports the configuration it was built with.

    Creates three test agents and a workflow capped at two workers, prints
    the resulting configuration, and asserts that the worker limit and the
    agent count match what was passed to the constructor.

    Raises:
        Exception: Any failure (including a failed assertion) is reported on
            stdout and then re-raised to the caller.
    """
    print("\n=== Testing Workflow Initialization ===")
    try:
        agents = [create_test_agent(f"agent{i}") for i in range(3)]
        workflow = MultiProcessWorkflow(max_workers=2, agents=agents)
        print("✓ Created workflow with configuration:")
        print(f" - Max workers: {workflow.max_workers}")
        print(f" - Number of agents: {len(workflow.agents)}")
        print(f" - Autosave: {workflow.autosave}")
        # Printing alone can never fail the test, so pin the values down.
        # NOTE(review): assumes the constructor stores max_workers and agents
        # unchanged -- confirm against MultiProcessWorkflow.__init__.
        assert (
            workflow.max_workers == 2
        ), f"expected max_workers == 2, got {workflow.max_workers}"
        assert (
            len(workflow.agents) == 3
        ), f"expected 3 agents, got {len(workflow.agents)}"
        print("✓ Initialization test passed")
    except Exception as e:
        print(f"✗ Initialization test failed: {str(e)}")
        raise
def test_execute_task():
    """Send one task through a single-agent workflow and print the outcome.

    Raises:
        Exception: Re-raised after a failure message is printed.
    """
    print("\n=== Testing Task Execution ===")
    try:
        solo_agent = create_test_agent("test_agent")
        wf = MultiProcessWorkflow(agents=[solo_agent])
        prompt = "Return this message with timestamp"
        outcome = wf.execute_task(prompt)
        print("✓ Task executed successfully")
        print(f" - Input task: {prompt}")
        print(f" - Result: {outcome}")
        print("✓ Task execution test passed")
    except Exception as exc:
        reason = str(exc)
        print(f"✗ Task execution test failed: {reason}")
        raise
def test_parallel_run():
    """Call ``run`` on a three-agent, two-worker workflow and print the output.

    Raises:
        Exception: Re-raised after a failure message is printed.
    """
    print("\n=== Testing Parallel Run ===")
    try:
        roster = []
        for idx in range(3):
            roster.append(create_test_agent(f"agent{idx}"))
        wf = MultiProcessWorkflow(max_workers=2, agents=roster)
        prompt = "Process this in parallel"
        output = wf.run(prompt)
        print("✓ Parallel execution completed")
        # NOTE(review): the result count is intentionally not printed here;
        # presumably run() may return something without len() -- confirm.
        print(f" - Results: {output}")
        print("✓ Parallel run test passed")
    except Exception as exc:
        reason = str(exc)
        print(f"✗ Parallel run test failed: {reason}")
        raise
async def test_async_run():
    """Await ``async_run`` on a three-agent workflow and print what it returns.

    Being a coroutine, this must be awaited (or driven by an event loop).

    Raises:
        Exception: Re-raised after a failure message is printed.
    """
    print("\n=== Testing Async Run ===")
    try:
        agent_names = [f"agent{idx}" for idx in range(3)]
        roster = [create_test_agent(label) for label in agent_names]
        wf = MultiProcessWorkflow(max_workers=2, agents=roster)
        prompt = "Process this asynchronously"
        output = await wf.async_run(prompt)
        print("✓ Async execution completed")
        result_count = len(output)
        print(f" - Number of results: {result_count}")
        print(f" - Results: {output}")
        print("✓ Async run test passed")
    except Exception as exc:
        reason = str(exc)
        print(f"✗ Async run test failed: {reason}")
        raise
def test_batched_run():
    """Submit five tasks via ``batched_run`` in batches of two and print results.

    Raises:
        Exception: Re-raised after a failure message is printed.
    """
    print("\n=== Testing Batched Run ===")
    try:
        roster = []
        for idx in range(2):
            roster.append(create_test_agent(f"agent{idx}"))
        wf = MultiProcessWorkflow(max_workers=2, agents=roster)
        job_list = []
        for idx in range(5):
            job_list.append(f"Batch task {idx}")
        output = wf.batched_run(job_list, batch_size=2)
        print("✓ Batch execution completed")
        job_count = len(job_list)
        print(f" - Number of tasks: {job_count}")
        print(" - Batch size: 2")
        print(f" - Results: {output}")
        print("✓ Batched run test passed")
    except Exception as exc:
        reason = str(exc)
        print(f"✗ Batched run test failed: {reason}")
        raise
def test_concurrent_run():
    """Submit four tasks via ``concurrent_run`` and print the results.

    Raises:
        Exception: Re-raised after a failure message is printed.
    """
    print("\n=== Testing Concurrent Run ===")
    try:
        agent_names = [f"agent{idx}" for idx in range(2)]
        roster = [create_test_agent(label) for label in agent_names]
        wf = MultiProcessWorkflow(max_workers=2, agents=roster)
        job_list = [f"Concurrent task {idx}" for idx in range(4)]
        output = wf.concurrent_run(job_list)
        print("✓ Concurrent execution completed")
        job_count = len(job_list)
        print(f" - Number of tasks: {job_count}")
        print(f" - Results: {output}")
        print("✓ Concurrent run test passed")
    except Exception as exc:
        reason = str(exc)
        print(f"✗ Concurrent run test failed: {reason}")
        raise
def test_error_handling():
    """Call ``execute_task`` on a workflow built with ``agents=None``.

    The test succeeds only if the call returns without an exception
    propagating; the returned value is printed for inspection.

    Raises:
        Exception: Re-raised after a failure message is printed.
    """
    print("\n=== Testing Error Handling ===")
    try:
        # No agents are supplied on purpose, to provoke the error path.
        wf = MultiProcessWorkflow(max_workers=2, agents=None)
        prompt = "This should handle the error gracefully"
        outcome = wf.execute_task(prompt)
        print("✓ Error handled gracefully")
        print(f" - Result when no agents: {outcome}")
        print("✓ Error handling test passed")
    except Exception as exc:
        reason = str(exc)
        print(f"✗ Error handling test failed: {reason}")
        raise
async def run_all_tests():
    """Run every test in this module in sequence and report the elapsed time.

    The synchronous tests run first, in a fixed order, followed by the
    asynchronous one. The first failure aborts the suite.

    Raises:
        Exception: Re-raised after a suite-failure banner is printed.
    """
    print("\n=== Starting MultiProcessWorkflow Test Suite ===")
    started = time.time()
    try:
        # Synchronous tests, executed in this exact order.
        sync_checks = (
            test_initialization,
            test_execute_task,
            test_parallel_run,
            test_batched_run,
            test_concurrent_run,
            test_error_handling,
        )
        for check in sync_checks:
            check()
        # The only coroutine-based test runs last.
        await test_async_run()
        elapsed = round(time.time() - started, 2)
        print("\n=== Test Suite Completed Successfully ===")
        print(f"Time taken: {elapsed} seconds")
    except Exception as exc:
        print("\n=== Test Suite Failed ===")
        reason = str(exc)
        print(f"Error: {reason}")
        raise
if __name__ == "__main__":
    # Script entry point: drive the async suite runner on a fresh event loop.
    suite = run_all_tests()
    asyncio.run(suite)