swarms/swarms/structs/multi_process_workflow.py


from functools import wraps
from multiprocessing import Manager, Pool, cpu_count
from time import sleep
from typing import Sequence
from swarms.structs.agent import Agent
from swarms.structs.base_workflow import BaseWorkflow
from swarms.utils.loguru_logger import logger

# Retry on failure
def retry_on_failure(max_retries: int = 3, delay: int = 5):
    """
    Decorator that retries a function a specified number of times on failure.

    Args:
        max_retries (int): The maximum number of retries (default: 3).
        delay (int): The delay in seconds between retries (default: 5).

    Returns:
        The result of the function if it succeeds within the maximum number
        of retries, otherwise None.
    """

    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            for _ in range(max_retries):
                try:
                    return func(*args, **kwargs)
                except Exception as error:
                    logger.error(
                        f"Error: {str(error)}, retrying in"
                        f" {delay} seconds..."
                    )
                    sleep(delay)
            return None

        return wrapper

    return decorator
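
# A minimal usage sketch for the decorator. `fetch_remote_config` is a
# hypothetical function, not part of this module; each failing call is
# logged and retried after `delay` seconds, and None is returned once the
# retries are exhausted.
#
#     @retry_on_failure(max_retries=2, delay=1)
#     def fetch_remote_config():
#         raise ConnectionError("transient network failure")
#
#     fetch_remote_config()  # logs two errors, then returns None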

class MultiProcessWorkflow(BaseWorkflow):
    """
    Initialize a MultiProcessWorkflow object.

    Args:
        max_workers (int): The maximum number of workers to use for parallel processing.
        autosave (bool): Flag indicating whether to automatically save the workflow.
        agents (Sequence[Agent]): The agents that will each run the task in parallel.
        *args: Additional positional arguments.
        **kwargs: Additional keyword arguments.

    Example:
        >>> from swarms.structs.agent import Agent
        >>> from swarms.structs.multi_process_workflow import MultiProcessWorkflow
        >>>
        >>> # Create agents (assumes each Agent is already configured with a model)
        >>> agents = [
        ...     Agent(agent_name="Writer"),
        ...     Agent(agent_name="Editor"),
        ... ]
        >>>
        >>> # Create a workflow with the agents
        >>> workflow = MultiProcessWorkflow(agents=agents)
        >>>
        >>> # Run a task across all agents in parallel
        >>> results = workflow.run("Draft a short product announcement")
        >>>
        >>> # Print the results
        >>> print(results)
    """

    def __init__(
        self,
        max_workers: int = 5,
        autosave: bool = True,
        agents: Sequence[Agent] = None,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.autosave = autosave
        self.agents = agents
        # Fall back to the CPU count when no worker limit is given
        self.max_workers = max_workers or cpu_count()

        # Log
        logger.info(
            "Initialized MultiProcessWorkflow with"
            f" {self.max_workers} max workers and autosave set to"
            f" {self.autosave}"
        )

        # Log the agents
        if self.agents is not None:
            for agent in self.agents:
                logger.info(f"Agent: {agent.agent_name}")

    def execute_task(self, task: str, *args, **kwargs):
        """Execute a task and handle exceptions.

        Args:
            task (str): The task to execute.
            *args: Additional positional arguments for the task execution.
            **kwargs: Additional keyword arguments for the task execution.

        Returns:
            Any: The result of the last agent's execution, or None if no
            agents are set or an error occurs.
        """
        result = None
        try:
            if self.agents is not None:
                # Run the task through every agent; keep the last result
                for agent in self.agents:
                    result = agent.run(task, *args, **kwargs)
            return result
        except Exception as e:
            logger.error(
                "An error occurred during execution of task"
                f" {task}: {str(e)}"
            )
            return None
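
    # A direct-call sketch (hypothetical, assumes `workflow.agents` is set):
    #
    #     workflow.execute_task("Summarize this paragraph")
    #
    # runs the task string through every agent sequentially and returns the
    # last agent's output; `run` below fans the same call out across a
    # process pool instead.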

    def run(self, task: str, *args, **kwargs):
        """Run the workflow.

        Args:
            task (str): The task to run.
            *args: Additional positional arguments for the task execution.
            **kwargs: Additional keyword arguments for the task execution.

        Returns:
            List[Any]: The results of all executed tasks.
        """
        try:
            results = []
            with Manager() as manager:
                with Pool(processes=self.max_workers) as pool:
                    # Use manager.list() to collect results in a
                    # process-safe way
                    results_list = manager.list()
                    jobs = [
                        pool.apply_async(
                            self.execute_task,  # Pass the function, not a call
                            args=(task,) + args,  # Arguments as a tuple
                            kwds=kwargs,  # Keyword arguments as a dict
                            callback=results_list.append,
                        )
                        for _ in self.agents
                    ]

                    # Wait for all jobs to complete. Note that
                    # Pool.apply_async accepts no timeout argument; pass a
                    # timeout to job.get() here if a per-job limit is needed.
                    for job in jobs:
                        job.get()

                    results = list(results_list)

            return results
        except Exception as error:
            logger.error(f"Error in run: {error}")
            return None
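

# A runnable usage sketch. Assumptions: `EchoAgent` is a hypothetical
# stand-in for swarms.structs.agent.Agent (which normally wraps an LLM), and
# BaseWorkflow can be constructed without extra arguments; the demo only
# exercises the process-pool mechanics, not a real model backend.
if __name__ == "__main__":

    class EchoAgent:
        """Hypothetical minimal agent: echoes the task string back."""

        agent_name = "EchoAgent"

        def run(self, task, *args, **kwargs):
            return f"echo: {task}"

    workflow = MultiProcessWorkflow(
        max_workers=2, agents=[EchoAgent(), EchoAgent()]
    )
    print(workflow.run("Summarize the quarterly report"))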