[FEATS] [BaseWorkflow] [try_except_wrapper] [+++]

pull/343/head
Kye 1 year ago
parent 311d47ef4d
commit 2f48cfc071

@ -8,4 +8,5 @@ from swarms.models import * # noqa: E402, F403
from swarms.telemetry import * # noqa: E402, F403 from swarms.telemetry import * # noqa: E402, F403
from swarms.utils import * # noqa: E402, F403 from swarms.utils import * # noqa: E402, F403
from swarms.prompts import * # noqa: E402, F403 from swarms.prompts import * # noqa: E402, F403
# from swarms.cli import * # noqa: E402, F403 # from swarms.cli import * # noqa: E402, F403

@ -0,0 +1,314 @@
import json
from typing import Any, Dict, List, Optional
from termcolor import colored
from swarms.structs.base import BaseStructure
from swarms.structs.task import Task
class BaseWorkflow(BaseStructure):
"""
Base class for workflows.
Attributes:
task_pool (list): A list to store tasks.
Methods:
add(task: Task = None, tasks: List[Task] = None, *args, **kwargs):
Adds a task or a list of tasks to the task pool.
run():
Abstract method to run the workflow.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.task_pool = []
def add(self, task: Task = None, tasks: List[Task] = None, *args, **kwargs):
"""
Adds a task or a list of tasks to the task pool.
Args:
task (Task, optional): A single task to add. Defaults to None.
tasks (List[Task], optional): A list of tasks to add. Defaults to None.
Raises:
ValueError: If neither task nor tasks are provided.
"""
if task:
self.task_pool.append(task)
elif tasks:
self.task_pool.extend(tasks)
else:
raise ValueError("You must provide a task or a list of tasks")
def run(self):
"""
Abstract method to run the workflow.
"""
raise NotImplementedError("You must implement this method")
def __sequential_loop(self):
"""
Abstract method for the sequential loop.
"""
# raise NotImplementedError("You must implement this method")
pass
def __log(self, message: str):
"""
Logs a message if verbose mode is enabled.
Args:
message (str): The message to log.
"""
if self.verbose:
print(message)
def __str__(self):
return f"Workflow with {len(self.task_pool)} tasks"
def __repr__(self):
return f"Workflow with {len(self.task_pool)} tasks"
def reset(self) -> None:
"""Resets the workflow by clearing the results of each task."""
try:
for task in self.tasks:
task.result = None
except Exception as error:
print(
colored(f"Error resetting workflow: {error}", "red"),
)
def get_task_results(self) -> Dict[str, Any]:
"""
Returns the results of each task in the workflow.
Returns:
Dict[str, Any]: The results of each task in the workflow
"""
try:
return {
task.description: task.result for task in self.tasks
}
except Exception as error:
print(
colored(
f"Error getting task results: {error}", "red"
),
)
def remove_task(self, task: str) -> None:
"""Remove tasks from sequential workflow"""
try:
self.tasks = [
task
for task in self.tasks
if task.description != task
]
except Exception as error:
print(
colored(
f"Error removing task from workflow: {error}",
"red",
),
)
def update_task(self, task: str, **updates) -> None:
"""
Updates the arguments of a task in the workflow.
Args:
task (str): The description of the task to update.
**updates: The updates to apply to the task.
Raises:
ValueError: If the task is not found in the workflow.
Examples:
>>> from swarms.models import OpenAIChat
>>> from swarms.structs import SequentialWorkflow
>>> llm = OpenAIChat(openai_api_key="")
>>> workflow = SequentialWorkflow(max_loops=1)
>>> workflow.add("What's the weather in miami", llm)
>>> workflow.add("Create a report on these metrics", llm)
>>> workflow.update_task("What's the weather in miami", max_tokens=1000)
>>> workflow.tasks[0].kwargs
{'max_tokens': 1000}
"""
try:
for task in self.tasks:
if task.description == task:
task.kwargs.update(updates)
break
else:
raise ValueError(
f"Task {task} not found in workflow."
)
except Exception as error:
print(
colored(
f"Error updating task in workflow: {error}", "red"
),
)
def delete_task(self, task: str) -> None:
"""
Delete a task from the workflow.
Args:
task (str): The description of the task to delete.
Raises:
ValueError: If the task is not found in the workflow.
Examples:
>>> from swarms.models import OpenAIChat
>>> from swarms.structs import SequentialWorkflow
>>> llm = OpenAIChat(openai_api_key="")
>>> workflow = SequentialWorkflow(max_loops=1)
>>> workflow.add("What's the weather in miami", llm)
>>> workflow.add("Create a report on these metrics", llm)
>>> workflow.delete_task("What's the weather in miami")
>>> workflow.tasks
[Task(description='Create a report on these metrics', agent=Agent(llm=OpenAIChat(openai_api_key=''), max_loops=1, dashboard=False), args=[], kwargs={}, result=None, history=[])]
"""
try:
for task in self.tasks:
if task.description == task:
self.tasks.remove(task)
break
else:
raise ValueError(
f"Task {task} not found in workflow."
)
except Exception as error:
print(
colored(
f"Error deleting task from workflow: {error}",
"red",
),
)
def save_workflow_state(
self,
filepath: Optional[str] = "sequential_workflow_state.json",
**kwargs,
) -> None:
"""
Saves the workflow state to a json file.
Args:
filepath (str): The path to save the workflow state to.
Examples:
>>> from swarms.models import OpenAIChat
>>> from swarms.structs import SequentialWorkflow
>>> llm = OpenAIChat(openai_api_key="")
>>> workflow = SequentialWorkflow(max_loops=1)
>>> workflow.add("What's the weather in miami", llm)
>>> workflow.add("Create a report on these metrics", llm)
>>> workflow.save_workflow_state("sequential_workflow_state.json")
"""
try:
filepath = filepath or self.saved_state_filepath
with open(filepath, "w") as f:
# Saving the state as a json for simplicuty
state = {
"tasks": [
{
"description": task.description,
"args": task.args,
"kwargs": task.kwargs,
"result": task.result,
"history": task.history,
}
for task in self.tasks
],
"max_loops": self.max_loops,
}
json.dump(state, f, indent=4)
except Exception as error:
print(
colored(
f"Error saving workflow state: {error}",
"red",
)
)
def add_objective_to_workflow(self, task: str, **kwargs) -> None:
"""Adds an objective to the workflow."""
try:
print(
colored(
"""
Adding Objective to Workflow...""",
"green",
attrs=["bold", "underline"],
)
)
task = Task(
description=task,
agent=kwargs["agent"],
args=list(kwargs["args"]),
kwargs=kwargs["kwargs"],
)
self.tasks.append(task)
except Exception as error:
print(
colored(
f"Error adding objective to workflow: {error}",
"red",
)
)
def load_workflow_state(
self, filepath: str = None, **kwargs
) -> None:
"""
Loads the workflow state from a json file and restores the workflow state.
Args:
filepath (str): The path to load the workflow state from.
Examples:
>>> from swarms.models import OpenAIChat
>>> from swarms.structs import SequentialWorkflow
>>> llm = OpenAIChat(openai_api_key="")
>>> workflow = SequentialWorkflow(max_loops=1)
>>> workflow.add("What's the weather in miami", llm)
>>> workflow.add("Create a report on these metrics", llm)
>>> workflow.save_workflow_state("sequential_workflow_state.json")
>>> workflow.load_workflow_state("sequential_workflow_state.json")
"""
try:
filepath = filepath or self.restore_state_filepath
with open(filepath, "r") as f:
state = json.load(f)
self.max_loops = state["max_loops"]
self.tasks = []
for task_state in state["tasks"]:
task = Task(
description=task_state["description"],
agent=task_state["agent"],
args=task_state["args"],
kwargs=task_state["kwargs"],
result=task_state["result"],
history=task_state["history"],
)
self.tasks.append(task)
except Exception as error:
print(
colored(
f"Error loading workflow state: {error}",
"red",
)
)

@ -24,6 +24,7 @@ class NonlinearWorkflow(BaseStruct):
>>> workflow.run() >>> workflow.run()
""" """
def __init__(self, stopping_token: str = "<DONE>"): def __init__(self, stopping_token: str = "<DONE>"):
self.tasks = {} self.tasks = {}
self.edges = {} self.edges = {}

@ -34,7 +34,6 @@ class RecursiveWorkflow(BaseStruct):
self.stop_token is not None self.stop_token is not None
), "stop_token cannot be None" ), "stop_token cannot be None"
def add(self, task: Task, tasks: List[Task] = None): def add(self, task: Task, tasks: List[Task] = None):
"""Adds a task to the workflow. """Adds a task to the workflow.
@ -52,7 +51,6 @@ class RecursiveWorkflow(BaseStruct):
print(f"[ERROR][ConcurrentWorkflow] {error}") print(f"[ERROR][ConcurrentWorkflow] {error}")
raise error raise error
def run(self): def run(self):
""" """
Executes the tasks in the workflow until the stop token is encountered. Executes the tasks in the workflow until the stop token is encountered.

@ -0,0 +1,28 @@
def try_except_wrapper(func):
"""
A decorator that wraps a function with a try-except block.
It catches any exception that occurs during the execution of the function,
prints an error message, and returns None.
It also prints a message indicating the exit of the function.
Args:
func (function): The function to be wrapped.
Returns:
function: The wrapped function.
"""
def wrapper(*args, **kwargs):
try:
result = func(*args, **kwargs)
return result
except Exception as error:
print(
f"An error occurred in function {func.__name__}:"
f" {error}"
)
return None
finally:
print(f"Exiting function: {func.__name__}")
return wrapper

@ -0,0 +1,66 @@
import os
import pytest
import json
from swarms.models import OpenAIChat
from swarms.structs import BaseWorkflow
from dotenv import load_dotenv
load_dotenv()
api_key = os.environ.get("OPENAI_API_KEY")
def setup_workflow():
llm = OpenAIChat(openai_api_key=api_key)
workflow = BaseWorkflow(max_loops=1)
workflow.add("What's the weather in miami", llm)
workflow.add("Create a report on these metrics", llm)
workflow.save_workflow_state("workflow_state.json")
return workflow
def teardown_workflow():
os.remove("workflow_state.json")
def test_load_workflow_state():
workflow = setup_workflow()
workflow.load_workflow_state("workflow_state.json")
assert workflow.max_loops == 1
assert len(workflow.tasks) == 2
assert (
workflow.tasks[0].description == "What's the weather in miami"
)
assert (
workflow.tasks[1].description
== "Create a report on these metrics"
)
teardown_workflow()
def test_load_workflow_state_with_missing_file():
workflow = setup_workflow()
with pytest.raises(FileNotFoundError):
workflow.load_workflow_state("non_existent_file.json")
teardown_workflow()
def test_load_workflow_state_with_invalid_file():
workflow = setup_workflow()
with open("invalid_file.json", "w") as f:
f.write("This is not valid JSON")
with pytest.raises(json.JSONDecodeError):
workflow.load_workflow_state("invalid_file.json")
os.remove("invalid_file.json")
teardown_workflow()
def test_load_workflow_state_with_missing_keys():
workflow = setup_workflow()
with open("missing_keys.json", "w") as f:
json.dump({"max_loops": 1}, f)
with pytest.raises(KeyError):
workflow.load_workflow_state("missing_keys.json")
os.remove("missing_keys.json")
teardown_workflow()

@ -0,0 +1,45 @@
from swarms.utils.try_except_wrapper import try_except_wrapper
def test_try_except_wrapper_with_no_exception():
@try_except_wrapper
def add(x, y):
return x + y
result = add(1, 2)
assert (
result == 3
), "The function should return the sum of the arguments"
def test_try_except_wrapper_with_exception():
@try_except_wrapper
def divide(x, y):
return x / y
result = divide(1, 0)
assert (
result is None
), "The function should return None when an exception is raised"
def test_try_except_wrapper_with_multiple_arguments():
@try_except_wrapper
def concatenate(*args):
return "".join(args)
result = concatenate("Hello", " ", "world")
assert (
result == "Hello world"
), "The function should concatenate the arguments"
def test_try_except_wrapper_with_keyword_arguments():
@try_except_wrapper
def greet(name="world"):
return f"Hello, {name}"
result = greet(name="Alice")
assert (
result == "Hello, Alice"
), "The function should use the keyword arguments"
Loading…
Cancel
Save