You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
394 lines
13 KiB
394 lines
13 KiB
4 months ago
|
from typing import List, Callable, Union, Optional
|
||
|
from loguru import logger
|
||
|
from swarms.structs.base_swarm import BaseSwarm
|
||
|
from queue import PriorityQueue
|
||
|
from concurrent.futures import (
|
||
|
ThreadPoolExecutor,
|
||
|
as_completed,
|
||
|
)
|
||
|
import time
|
||
|
from pydantic import BaseModel, Field
|
||
|
|
||
|
|
||
|
class SwarmRunData(BaseModel):
    """
    Execution record for one swarm taking part in a federated run.

    A record starts out with status "Pending"; the orchestrating code then
    assigns its fields as the swarm is started, retried, and finished.
    """

    # Which swarm this record describes and what it was asked to do.
    swarm_name: str
    task: str
    priority: int

    # Timing: timestamps as produced by time.time(); duration in seconds.
    start_time: Optional[float] = Field(default=None)
    end_time: Optional[float] = Field(default=None)
    duration: Optional[float] = Field(default=None)

    # Lifecycle state and outcome of the run.
    status: str = Field(default="Pending")
    retries: int = Field(default=0)
    result: Optional[str] = Field(default=None)
    exception: Optional[str] = Field(default=None)
|
||
|
|
||
|
|
||
|
class FederatedSwarmModel(BaseModel):
    """
    Pydantic base model to capture and log data for the FederatedSwarm system.

    Holds the task being executed plus one ``SwarmRunData`` record per swarm.
    Records are looked up by their ``swarm_name``; if several records share a
    name, only the first one is ever updated.
    """

    task: str
    swarms_data: List[SwarmRunData] = Field(default_factory=list)

    def add_swarm(self, swarm_name: str, task: str, priority: int):
        """
        Append a fresh record (status "Pending") for a swarm.

        Args:
            swarm_name (str): Name later used to find this record.
            task (str): Task the swarm will receive.
            priority (int): Scheduling priority of the swarm.
        """
        self.swarms_data.append(
            SwarmRunData(
                swarm_name=swarm_name, task=task, priority=priority
            )
        )

    def update_swarm_status(
        self,
        swarm_name: str,
        status: str,
        start_time: Optional[float] = None,
        end_time: Optional[float] = None,
        retries: Optional[int] = None,
        result: Optional[str] = None,
        exception: Optional[str] = None,
    ):
        """
        Update the first record whose ``swarm_name`` matches.

        ``status`` is always applied. Every other field is only overwritten
        when a value is supplied, so a later call (e.g. marking a swarm
        "Completed") does not wipe data recorded by an earlier call (such as
        the retry count). If no record matches, nothing happens.

        Args:
            swarm_name (str): Name of the record to update.
            status (str): New status string.
            start_time (Optional[float]): Start timestamp; kept if None.
            end_time (Optional[float]): End timestamp; kept if None. When
                given and a start time is known, ``duration`` is recomputed
                as ``end_time - start_time``.
            retries (Optional[int]): Retry count; kept if None.
            result (Optional[str]): Result value; kept if None.
            exception (Optional[str]): Error description; kept if None.
        """
        for swarm in self.swarms_data:
            # Records are keyed by ``swarm_name``; SwarmRunData has no
            # ``name`` field.
            if swarm.swarm_name != swarm_name:
                continue

            swarm.status = status
            # Explicit None checks so a legitimate 0.0 timestamp is kept.
            if start_time is not None:
                swarm.start_time = start_time
            if end_time is not None:
                swarm.end_time = end_time
                # Only compute a duration when a start time exists;
                # subtracting from None would raise TypeError.
                if swarm.start_time is not None:
                    swarm.duration = end_time - swarm.start_time
            if retries is not None:
                swarm.retries = retries
            if result is not None:
                swarm.result = result
            if exception is not None:
                swarm.exception = exception
            break
|
||
|
|
||
|
|
||
|
class FederatedSwarm:
    """
    Coordinates several swarms (or plain callables) that all receive the same
    task, running them either concurrently on a thread pool or one after
    another.

    Swarms live in a ``PriorityQueue`` and are dispatched in ascending
    priority order (lower value first). A run does not remove swarms from the
    queue, so the run methods can be called repeatedly (e.g. once per queued
    task). Per-run metadata is kept in ``self.metadata`` and per-swarm
    outcomes in ``self.results``.
    """

    def __init__(
        self,
        swarms: List[Union[BaseSwarm, Callable]],
        max_workers: int = 4,
    ):
        """
        Initializes the FederatedSwarm with a list of swarms or callable objects and
        sets up a priority queue and thread pool for concurrency.

        Args:
            swarms (List[Union[BaseSwarm, Callable]]): Swarms to register.
                Each item is either a ``(swarm, priority)`` pair or a bare
                swarm, which is registered with priority 0.
            max_workers (int): The maximum number of concurrent workers
                (threads) to run swarms in parallel.

        Raises:
            TypeError: If a swarm neither has a callable ``run`` method nor is
                itself callable.
        """
        # Queue entries are (priority, insertion_seq, swarm). The sequence
        # number breaks priority ties so the heap never has to compare two
        # swarm objects, which raises TypeError for types without ordering.
        self.swarms = PriorityQueue()
        self._next_seq = 0
        self.max_workers = max_workers
        self.thread_pool = ThreadPoolExecutor(
            max_workers=self.max_workers
        )
        self.task_queue = []
        self.future_to_swarm = {}
        self.results = {}
        self.validate_swarms(swarms)

    @staticmethod
    def _swarm_name(swarm) -> str:
        """Return the metadata name of a swarm: a function's own name, otherwise its class name."""
        name = getattr(swarm, "__name__", None)
        if isinstance(name, str):
            return name
        return type(swarm).__name__

    def _register_swarm(self, swarm, priority) -> None:
        """Push a swarm onto the priority queue with a tie-breaking sequence number."""
        self.swarms.put((priority, self._next_seq, swarm))
        self._next_seq += 1

    def _ordered_swarms(self):
        """Return ``(priority, swarm)`` pairs in dispatch order without draining the queue."""
        entries = sorted(self.swarms.queue, key=lambda entry: entry[:2])
        return [(priority, swarm) for priority, _seq, swarm in entries]

    def init_metadata(self, task: str):
        """
        Initializes the Pydantic base model to capture metadata about the
        current task and swarms (one record per registered swarm).

        Args:
            task (str): The task about to be run.
        """
        self.metadata = FederatedSwarmModel(task=task)
        for priority, swarm in self._ordered_swarms():
            self.metadata.add_swarm(
                swarm_name=self._swarm_name(swarm),
                task=task,
                priority=priority,
            )
        logger.info(f"Metadata initialized for task '{task}'.")

    def validate_swarms(
        self, swarms: List[Union[BaseSwarm, Callable]]
    ):
        """
        Validates and adds swarms to the priority queue.

        A swarm is accepted if it has a callable ``run(task)`` method or is
        itself callable as ``swarm(task)``.

        Args:
            swarms (List[Union[BaseSwarm, Callable]]): Items that are either
                ``(swarm, priority)`` pairs or bare swarms (priority 0).

        Raises:
            TypeError: If a swarm has no callable ``run`` and is not callable.
        """
        for item in swarms or ():
            if isinstance(item, (tuple, list)) and len(item) == 2:
                swarm, priority = item
            else:
                swarm, priority = item, 0

            if callable(getattr(swarm, "run", None)):
                logger.info(f"{swarm} has a 'run' method.")
            elif callable(swarm):
                logger.info(f"{swarm} is callable.")
            else:
                raise TypeError(
                    f"{swarm} is not callable and does not have a 'run(task)' method."
                )

            self._register_swarm(swarm, priority)
            logger.info(
                f"Swarm {swarm} added with priority {priority}."
            )

    def run_parallel(
        self,
        task: str,
        timeout: Optional[float] = None,
        retries: int = 0,
    ):
        """
        Runs all swarms in parallel with prioritization and optional timeout.

        Swarms are submitted to the thread pool in priority order. Each
        swarm's outcome is stored in ``self.results`` ("Success" or "Failed").

        Args:
            task (str): The task to be passed to each swarm.
            timeout (Optional[float]): Maximum time allowed for each attempt.
            retries (int): Number of retries allowed for failed swarms.
        """
        logger.info(
            f"Running task '{task}' in parallel with timeout: {timeout}, retries: {retries}"
        )
        self.init_metadata(task)

        # Track only this run's futures so a repeated call does not
        # re-collect futures left over from earlier runs.
        self.future_to_swarm = {}
        for _priority, swarm in self._ordered_swarms():
            future = self.thread_pool.submit(
                self._run_with_retry,
                swarm,
                task,
                retries,
                timeout,
                self._swarm_name(swarm),
            )
            self.future_to_swarm[future] = swarm

        for future in as_completed(self.future_to_swarm):
            swarm = self.future_to_swarm[future]
            swarm_name = self._swarm_name(swarm)
            try:
                result = future.result()
            except Exception as e:
                # _run_with_retry has already recorded the failure
                # (status, end time, retries, exception) in the metadata.
                logger.error(f"Swarm {swarm_name} failed: {e}")
                self.results[swarm] = "Failed"
            else:
                self.results[swarm] = result
                logger.info(
                    f"Swarm {swarm_name} completed successfully."
                )

    def run_sequentially(
        self,
        task: str,
        retries: int = 0,
        timeout: Optional[float] = None,
    ):
        """
        Runs all swarms sequentially in order of priority.

        A failing swarm is logged and recorded as "Failed" in ``self.results``;
        the remaining swarms still run.

        Args:
            task (str): The task to pass to each swarm.
            retries (int): Number of retries for failed swarms.
            timeout (Optional[float]): Optional time limit for each attempt.
        """
        logger.info(f"Running task '{task}' sequentially.")
        # _run_with_retry writes to self.metadata, so it must exist.
        self.init_metadata(task)

        for priority, swarm in self._ordered_swarms():
            try:
                logger.info(
                    f"Running swarm {swarm} with priority {priority}."
                )
                result = self._run_with_retry(
                    swarm, task, retries, timeout, self._swarm_name(swarm)
                )
            except Exception as e:
                logger.error(f"Swarm {swarm} failed with error: {e}")
                self.results[swarm] = "Failed"
            else:
                self.results[swarm] = result
                logger.info(f"Swarm {swarm} completed successfully.")

    def _run_with_retry(
        self,
        swarm: Union[BaseSwarm, Callable],
        task: str,
        retries: int,
        timeout: Optional[float],
        swarm_name: str,
    ):
        """
        Helper function to run a swarm with a retry mechanism and optional timeout.

        The timeout is checked after each attempt returns: a slow attempt is
        not interrupted, it is only treated as a failure (TimeoutError) once
        it finishes. A falsy timeout (None or 0) disables the check.

        Args:
            swarm (Union[BaseSwarm, Callable]): The swarm to run.
            task (str): The task to pass to the swarm.
            retries (int): The number of retries allowed for the swarm in
                case of failure; negative values are treated as 0.
            timeout (Optional[float]): Maximum time allowed per attempt.
            swarm_name (str): Name of the swarm (used for metadata).

        Returns:
            str: "Success" when an attempt succeeds.

        Raises:
            Exception: The last attempt's exception once retries are exhausted.
        """
        # Guarantee at least one attempt, so the function never silently
        # returns None without running the swarm.
        retries = max(0, retries)
        attempts = 0
        # Start of the first attempt; the recorded duration spans all attempts.
        start_time = time.time()
        while attempts <= retries:
            try:
                logger.info(
                    f"Running swarm {swarm}. Attempt: {attempts + 1}"
                )
                self.metadata.update_swarm_status(
                    swarm_name=swarm_name,
                    status="Running",
                    start_time=start_time,
                )

                attempt_start = time.time()
                if callable(getattr(swarm, "run", None)):
                    swarm.run(task)
                else:
                    swarm(task)
                duration = time.time() - attempt_start

                if timeout and duration > timeout:
                    raise TimeoutError(
                        f"Swarm {swarm} timed out after {duration:.2f}s."
                    )

                self.metadata.update_swarm_status(
                    swarm_name=swarm_name,
                    status="Completed",
                    end_time=time.time(),
                    retries=attempts,
                    result="Success",
                )
                return "Success"
            except Exception as e:
                logger.error(f"Swarm {swarm} failed: {e}")
                attempts += 1
                if attempts > retries:
                    self.metadata.update_swarm_status(
                        swarm_name=swarm_name,
                        status="Failed",
                        end_time=time.time(),
                        retries=attempts,
                        exception=str(e),
                    )
                    logger.error(f"Swarm {swarm} exhausted retries.")
                    raise

    def add_swarm(
        self, swarm: Union[BaseSwarm, Callable], priority: int
    ):
        """
        Adds a new swarm to the FederatedSwarm at runtime.

        Args:
            swarm (Union[BaseSwarm, Callable]): The swarm to add.
            priority (int): The priority level for the swarm (lower runs first).
        """
        self._register_swarm(swarm, priority)
        logger.info(
            f"Swarm {swarm} added dynamically with priority {priority}."
        )

    def queue_task(self, task: str):
        """
        Adds a task to the internal task queue for batch processing.

        Args:
            task (str): The task to queue.
        """
        self.task_queue.append(task)
        logger.info(f"Task '{task}' added to the queue.")

    def process_task_queue(self):
        """
        Processes all queued tasks, running every swarm in parallel for each
        task, then clears the queue.
        """
        for task in self.task_queue:
            logger.info(f"Processing task: {task}")
            self.run_parallel(task)
        self.task_queue = []

    def log_swarm_results(self):
        """
        Logs the results of all swarms after execution.
        """
        logger.info("Logging swarm results...")
        for swarm, result in self.results.items():
            logger.info(f"Swarm {swarm}: {result}")

    def get_swarm_status(self) -> dict:
        """
        Retrieves the status of each swarm from the most recent parallel run.

        Returns:
            dict: Maps each swarm to "Completed", "Failed", "Running",
            "Pending" (not started yet) or "Cancelled".
        """
        status = {}
        for future, swarm in self.future_to_swarm.items():
            if future.cancelled():
                status[swarm] = "Cancelled"
            elif future.done():
                # A finished future may have finished by raising.
                status[swarm] = (
                    "Failed"
                    if future.exception() is not None
                    else "Completed"
                )
            elif future.running():
                status[swarm] = "Running"
            else:
                status[swarm] = "Pending"
        return status

    def cancel_running_swarms(self):
        """
        Cancels outstanding swarms and shuts down the thread pool.

        Swarms that have not started yet are cancelled; threads already
        executing a swarm cannot be interrupted and run to completion. After
        this call the pool is shut down, so further ``run_parallel`` calls
        cannot submit work.
        """
        logger.warning("Cancelling all running swarms...")
        for future in self.future_to_swarm:
            future.cancel()
        self.thread_pool.shutdown(wait=False)
        logger.info("All running swarms cancelled.")
|
||
|
|
||
|
|
||
|
# Example Usage:
|
||
|
|
||
|
|
||
|
# class ExampleSwarm(BaseSwarm):
|
||
|
# def run(self, task: str):
|
||
|
# logger.info(f"ExampleSwarm is processing task: {task}")
|
||
|
|
||
|
|
||
|
# def example_callable(task: str):
|
||
|
# logger.info(f"Callable is processing task: {task}")
|
||
|
|
||
|
|
||
|
# if __name__ == "__main__":
|
||
|
# swarms = [(ExampleSwarm(), 1), (example_callable, 2)]
|
||
|
# federated_swarm = FederatedSwarm(swarms)
|
||
|
|
||
|
# # Run in parallel
|
||
|
# federated_swarm.run_parallel(
|
||
|
# "Process data", timeout=10, retries=3
|
||
|
# )
|
||
|
|
||
|
# # Run sequentially
|
||
|
# federated_swarm.run_sequentially("Process data sequentially")
|
||
|
|
||
|
# # Log results
|
||
|
# federated_swarm.log_swarm_results()
|
||
|
|
||
|
# # Get status of swarms
|
||
|
# status = federated_swarm.get_swarm_status()
|
||
|
# logger.info(f"Swarm statuses: {status}")
|
||
|
|
||
|
# # Cancel running swarms (if needed)
|
||
|
# # federated_swarm.cancel_running_swarms()
|