|
|
|
@ -5,6 +5,8 @@ from typing import Dict, List, Optional
|
|
|
|
|
from swarms.structs.base import BaseStructure
|
|
|
|
|
from swarms.structs.task import Task
|
|
|
|
|
|
|
|
|
|
from swarms.utils.logger import logger
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
|
|
|
|
|
class ConcurrentWorkflow(BaseStructure):
|
|
|
|
@ -51,11 +53,17 @@ class ConcurrentWorkflow(BaseStructure):
|
|
|
|
|
if tasks:
|
|
|
|
|
for task in tasks:
|
|
|
|
|
self.task_pool.append(task)
|
|
|
|
|
logger.info(
|
|
|
|
|
f"Added task {task} to ConcurrentWorkflow."
|
|
|
|
|
)
|
|
|
|
|
else:
|
|
|
|
|
if task:
|
|
|
|
|
self.task_pool.append(task)
|
|
|
|
|
logger.info(
|
|
|
|
|
f"Added task {task} to ConcurrentWorkflow."
|
|
|
|
|
)
|
|
|
|
|
except Exception as error:
|
|
|
|
|
print(f"[ERROR][ConcurrentWorkflow] {error}")
|
|
|
|
|
logger.warning(f"[ERROR][ConcurrentWorkflow] {error}")
|
|
|
|
|
raise error
|
|
|
|
|
|
|
|
|
|
def run(self):
|
|
|
|
@ -90,3 +98,12 @@ class ConcurrentWorkflow(BaseStructure):
|
|
|
|
|
print(f"Task {task} generated an exception: {e}")
|
|
|
|
|
|
|
|
|
|
return results if self.return_results else None
|
|
|
|
|
|
|
|
|
|
def list_tasks(self):
|
|
|
|
|
"""Prints a list of the tasks in the workflow."""
|
|
|
|
|
for task in self.task_pool:
|
|
|
|
|
logger.info(task)
|
|
|
|
|
|
|
|
|
|
def save(self):
|
|
|
|
|
"""Saves the state of the workflow to a file."""
|
|
|
|
|
self.save_state(self.saved_state_filepath)
|
|
|
|
|