|
|
|
@ -3,6 +3,11 @@ from typing import List
|
|
|
|
|
from swarms.structs.base import BaseStructure
|
|
|
|
|
from swarms.structs.task import Task
|
|
|
|
|
|
|
|
|
|
import logging
|
|
|
|
|
|
|
|
|
|
logging.basicConfig(level=logging.INFO)
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class RecursiveWorkflow(BaseStructure):
|
|
|
|
|
"""
|
|
|
|
@ -28,7 +33,7 @@ class RecursiveWorkflow(BaseStructure):
|
|
|
|
|
|
|
|
|
|
def __init__(self, stop_token: str = "<DONE>"):
|
|
|
|
|
self.stop_token = stop_token
|
|
|
|
|
self.tasks = List[Task]
|
|
|
|
|
self.task_pool = List[Task]
|
|
|
|
|
|
|
|
|
|
assert (
|
|
|
|
|
self.stop_token is not None
|
|
|
|
@ -44,11 +49,19 @@ class RecursiveWorkflow(BaseStructure):
|
|
|
|
|
try:
|
|
|
|
|
if tasks:
|
|
|
|
|
for task in tasks:
|
|
|
|
|
self.tasks.append(task)
|
|
|
|
|
self.task_pool.append(task)
|
|
|
|
|
logger.info(
|
|
|
|
|
"[INFO][RecursiveWorkflow] Added task"
|
|
|
|
|
f" {task} to workflow"
|
|
|
|
|
)
|
|
|
|
|
else:
|
|
|
|
|
self.tasks.append(task)
|
|
|
|
|
self.task_pool.append(task)
|
|
|
|
|
logger.info(
|
|
|
|
|
f"[INFO][RecursiveWorkflow] Added task {task} to"
|
|
|
|
|
" workflow"
|
|
|
|
|
)
|
|
|
|
|
except Exception as error:
|
|
|
|
|
print(f"[ERROR][ConcurrentWorkflow] {error}")
|
|
|
|
|
logger.warning(f"[ERROR][RecursiveWorkflow] {error}")
|
|
|
|
|
raise error
|
|
|
|
|
|
|
|
|
|
def run(self):
|
|
|
|
@ -59,11 +72,12 @@ class RecursiveWorkflow(BaseStructure):
|
|
|
|
|
None
|
|
|
|
|
"""
|
|
|
|
|
try:
|
|
|
|
|
for task in self.tasks:
|
|
|
|
|
for task in self.task_pool:
|
|
|
|
|
while True:
|
|
|
|
|
result = task.execute()
|
|
|
|
|
if self.stop_token in result:
|
|
|
|
|
break
|
|
|
|
|
logger.info(f"{result}")
|
|
|
|
|
except Exception as error:
|
|
|
|
|
print(f"[ERROR][RecursiveWorkflow] {error}")
|
|
|
|
|
logger.warning(f"[ERROR][RecursiveWorkflow] {error}")
|
|
|
|
|
raise error
|
|
|
|
|