parent
df2be1d22e
commit
d6ce848d72
@ -1,99 +0,0 @@
|
|||||||
import logging
|
|
||||||
from typing import List
|
|
||||||
|
|
||||||
from swarms.structs.base_structure import BaseStructure
|
|
||||||
from swarms.structs.task import Task
|
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class RecursiveWorkflow(BaseStructure):
|
|
||||||
"""
|
|
||||||
RecursiveWorkflow class for running a task recursively until a stopping condition is met.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
task (Task): The task to execute.
|
|
||||||
stop_token (Any): The token that indicates when to stop the workflow.
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
task (Task): The task to execute.
|
|
||||||
stop_token (Any): The token that indicates when to stop the workflow.
|
|
||||||
|
|
||||||
Examples:
|
|
||||||
>>> from swarm_models import OpenAIChat
|
|
||||||
>>> from swarms.structs import RecursiveWorkflow, Task
|
|
||||||
>>> llm = OpenAIChat(openai_api_key="")
|
|
||||||
>>> task = Task(llm, "What's the weather in miami")
|
|
||||||
>>> workflow = RecursiveWorkflow()
|
|
||||||
>>> workflow.add(task)
|
|
||||||
>>> workflow.run()
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
stop_token: str = "<DONE>",
|
|
||||||
stopping_conditions: callable = None,
|
|
||||||
max_loops: int = 1,
|
|
||||||
*args,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
self.stop_token = stop_token
|
|
||||||
self.stopping_conditions = stopping_conditions
|
|
||||||
self.task_pool = []
|
|
||||||
|
|
||||||
assert (
|
|
||||||
self.stop_token is not None
|
|
||||||
), "stop_token cannot be None"
|
|
||||||
|
|
||||||
def add(self, task: Task = None, tasks: List[Task] = None):
|
|
||||||
"""Adds a task to the workflow.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
task (Task): _description_
|
|
||||||
tasks (List[Task]): _description_
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
if tasks:
|
|
||||||
for task in tasks:
|
|
||||||
if isinstance(task, Task):
|
|
||||||
self.task_pool.append(task)
|
|
||||||
logger.info(
|
|
||||||
"[INFO][RecursiveWorkflow] Added task"
|
|
||||||
f" {task} to workflow"
|
|
||||||
)
|
|
||||||
elif isinstance(task, Task):
|
|
||||||
self.task_pool.append(task)
|
|
||||||
logger.info(
|
|
||||||
f"[INFO][RecursiveWorkflow] Added task {task} to"
|
|
||||||
" workflow"
|
|
||||||
)
|
|
||||||
except Exception as error:
|
|
||||||
logger.warning(f"[ERROR][RecursiveWorkflow] {error}")
|
|
||||||
raise error
|
|
||||||
|
|
||||||
def run(self):
|
|
||||||
"""
|
|
||||||
Executes the tasks in the workflow until the stop token is encountered.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
None
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
loop = 0
|
|
||||||
while loop < self.max_loops:
|
|
||||||
for task in self.task_pool:
|
|
||||||
while True:
|
|
||||||
result = task.run()
|
|
||||||
if (
|
|
||||||
result is not None
|
|
||||||
and self.stop_token in result
|
|
||||||
):
|
|
||||||
break
|
|
||||||
print(f"{result}")
|
|
||||||
loop += 1
|
|
||||||
|
|
||||||
return result
|
|
||||||
except Exception as error:
|
|
||||||
logger.warning(f"[ERROR][RecursiveWorkflow] {error}")
|
|
||||||
raise error
|
|
Loading…
Reference in new issue