From 2f48cfc071d0baa04d86ff883b531725bdaea5ea Mon Sep 17 00:00:00 2001 From: Kye Date: Sun, 31 Dec 2023 18:29:52 -0500 Subject: [PATCH] [FEATS] [BaseWorkflow] [try_except_wrapper] [+++] --- concurrent_workflow.py | 4 +- recursive_example.py | 4 +- swarms/__init__.py | 3 +- swarms/structs/base_workflow.py | 314 +++++++++++++++++++++++++ swarms/structs/nonlinear_workflow.py | 3 +- swarms/structs/recursive_workflow.py | 2 - swarms/utils/try_except_wrapper.py | 28 +++ tests/structs/test_base_workflow.py | 66 ++++++ tests/utils/test_try_except_wrapper.py | 45 ++++ 9 files changed, 461 insertions(+), 8 deletions(-) create mode 100644 swarms/utils/try_except_wrapper.py create mode 100644 tests/structs/test_base_workflow.py create mode 100644 tests/utils/test_try_except_wrapper.py diff --git a/concurrent_workflow.py b/concurrent_workflow.py index f152e4bb..a228d247 100644 --- a/concurrent_workflow.py +++ b/concurrent_workflow.py @@ -1,5 +1,5 @@ -import os -from dotenv import load_dotenv +import os +from dotenv import load_dotenv from swarms import OpenAIChat, Task, ConcurrentWorkflow, Agent # Load environment variables from .env file diff --git a/recursive_example.py b/recursive_example.py index 9ec182e0..9760b606 100644 --- a/recursive_example.py +++ b/recursive_example.py @@ -1,5 +1,5 @@ -import os -from dotenv import load_dotenv +import os +from dotenv import load_dotenv from swarms import OpenAIChat, Task, RecursiveWorkflow, Agent # Load environment variables from .env file diff --git a/swarms/__init__.py b/swarms/__init__.py index 3b53b810..bd7435f8 100644 --- a/swarms/__init__.py +++ b/swarms/__init__.py @@ -8,4 +8,5 @@ from swarms.models import * # noqa: E402, F403 from swarms.telemetry import * # noqa: E402, F403 from swarms.utils import * # noqa: E402, F403 from swarms.prompts import * # noqa: E402, F403 -# from swarms.cli import * # noqa: E402, F403 \ No newline at end of file + +# from swarms.cli import * # noqa: E402, F403 diff --git a/swarms/structs/base_workflow.py b/swarms/structs/base_workflow.py index e69de29b..d1457c99 100644 --- a/swarms/structs/base_workflow.py +++ b/swarms/structs/base_workflow.py @@ -0,0 +1,314 @@ +import json +from typing import Any, Dict, List, Optional + +from termcolor import colored + +from swarms.structs.base import BaseStructure +from swarms.structs.task import Task + + +class BaseWorkflow(BaseStructure): + """ + Base class for workflows. + + Attributes: + task_pool (list): A list to store tasks. + + Methods: + add(task: Task = None, tasks: List[Task] = None, *args, **kwargs): + Adds a task or a list of tasks to the task pool. + run(): + Abstract method to run the workflow. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.task_pool = [] + + def add(self, task: Task = None, tasks: List[Task] = None, *args, **kwargs): + """ + Adds a task or a list of tasks to the task pool. + + Args: + task (Task, optional): A single task to add. Defaults to None. + tasks (List[Task], optional): A list of tasks to add. Defaults to None. + + Raises: + ValueError: If neither task nor tasks are provided. + """ + if task: + self.task_pool.append(task) + elif tasks: + self.task_pool.extend(tasks) + else: + raise ValueError("You must provide a task or a list of tasks") + + def run(self): + """ + Abstract method to run the workflow. + """ + raise NotImplementedError("You must implement this method") + + def __sequential_loop(self): + """ + Abstract method for the sequential loop. + """ + # raise NotImplementedError("You must implement this method") + pass + + def __log(self, message: str): + """ + Logs a message if verbose mode is enabled. + + Args: + message (str): The message to log. + """ + if self.verbose: + print(message) + + def __str__(self): + return f"Workflow with {len(self.task_pool)} tasks" + + def __repr__(self): + return f"Workflow with {len(self.task_pool)} tasks" + + def reset(self) -> None: + """Resets the workflow by clearing the results of each task.""" + try: + for task in self.tasks: + task.result = None + except Exception as error: + print( + colored(f"Error resetting workflow: {error}", "red"), + ) + + def get_task_results(self) -> Dict[str, Any]: + """ + Returns the results of each task in the workflow. + + Returns: + Dict[str, Any]: The results of each task in the workflow + """ + try: + return { + task.description: task.result for task in self.tasks + } + except Exception as error: + print( + colored( + f"Error getting task results: {error}", "red" + ), + ) + + def remove_task(self, task: str) -> None: + """Remove tasks from sequential workflow""" + try: + self.tasks = [ + task + for task in self.tasks + if task.description != task + ] + except Exception as error: + print( + colored( + f"Error removing task from workflow: {error}", + "red", + ), + ) + + def update_task(self, task: str, **updates) -> None: + """ + Updates the arguments of a task in the workflow. + + Args: + task (str): The description of the task to update. + **updates: The updates to apply to the task. + + Raises: + ValueError: If the task is not found in the workflow. + + Examples: + >>> from swarms.models import OpenAIChat + >>> from swarms.structs import SequentialWorkflow + >>> llm = OpenAIChat(openai_api_key="") + >>> workflow = SequentialWorkflow(max_loops=1) + >>> workflow.add("What's the weather in miami", llm) + >>> workflow.add("Create a report on these metrics", llm) + >>> workflow.update_task("What's the weather in miami", max_tokens=1000) + >>> workflow.tasks[0].kwargs + {'max_tokens': 1000} + + """ + try: + for task in self.tasks: + if task.description == task: + task.kwargs.update(updates) + break + else: + raise ValueError( + f"Task {task} not found in workflow." + ) + except Exception as error: + print( + colored( + f"Error updating task in workflow: {error}", "red" + ), + ) + + def delete_task(self, task: str) -> None: + """ + Delete a task from the workflow. + + Args: + task (str): The description of the task to delete. + + Raises: + ValueError: If the task is not found in the workflow. + + Examples: + >>> from swarms.models import OpenAIChat + >>> from swarms.structs import SequentialWorkflow + >>> llm = OpenAIChat(openai_api_key="") + >>> workflow = SequentialWorkflow(max_loops=1) + >>> workflow.add("What's the weather in miami", llm) + >>> workflow.add("Create a report on these metrics", llm) + >>> workflow.delete_task("What's the weather in miami") + >>> workflow.tasks + [Task(description='Create a report on these metrics', agent=Agent(llm=OpenAIChat(openai_api_key=''), max_loops=1, dashboard=False), args=[], kwargs={}, result=None, history=[])] + """ + try: + for task in self.tasks: + if task.description == task: + self.tasks.remove(task) + break + else: + raise ValueError( + f"Task {task} not found in workflow." + ) + except Exception as error: + print( + colored( + f"Error deleting task from workflow: {error}", + "red", + ), + ) + + + def save_workflow_state( + self, + filepath: Optional[str] = "sequential_workflow_state.json", + **kwargs, + ) -> None: + """ + Saves the workflow state to a json file. + + Args: + filepath (str): The path to save the workflow state to. + + Examples: + >>> from swarms.models import OpenAIChat + >>> from swarms.structs import SequentialWorkflow + >>> llm = OpenAIChat(openai_api_key="") + >>> workflow = SequentialWorkflow(max_loops=1) + >>> workflow.add("What's the weather in miami", llm) + >>> workflow.add("Create a report on these metrics", llm) + >>> workflow.save_workflow_state("sequential_workflow_state.json") + """ + try: + filepath = filepath or self.saved_state_filepath + + with open(filepath, "w") as f: + # Saving the state as a json for simplicuty + state = { + "tasks": [ + { + "description": task.description, + "args": task.args, + "kwargs": task.kwargs, + "result": task.result, + "history": task.history, + } + for task in self.tasks + ], + "max_loops": self.max_loops, + } + json.dump(state, f, indent=4) + except Exception as error: + print( + colored( + f"Error saving workflow state: {error}", + "red", + ) + ) + + def add_objective_to_workflow(self, task: str, **kwargs) -> None: + """Adds an objective to the workflow.""" + try: + print( + colored( + """ + Adding Objective to Workflow...""", + "green", + attrs=["bold", "underline"], + ) + ) + + task = Task( + description=task, + agent=kwargs["agent"], + args=list(kwargs["args"]), + kwargs=kwargs["kwargs"], + ) + self.tasks.append(task) + except Exception as error: + print( + colored( + f"Error adding objective to workflow: {error}", + "red", + ) + ) + + def load_workflow_state( + self, filepath: str = None, **kwargs + ) -> None: + """ + Loads the workflow state from a json file and restores the workflow state. + + Args: + filepath (str): The path to load the workflow state from. + + Examples: + >>> from swarms.models import OpenAIChat + >>> from swarms.structs import SequentialWorkflow + >>> llm = OpenAIChat(openai_api_key="") + >>> workflow = SequentialWorkflow(max_loops=1) + >>> workflow.add("What's the weather in miami", llm) + >>> workflow.add("Create a report on these metrics", llm) + >>> workflow.save_workflow_state("sequential_workflow_state.json") + >>> workflow.load_workflow_state("sequential_workflow_state.json") + + """ + try: + filepath = filepath or self.restore_state_filepath + + with open(filepath, "r") as f: + state = json.load(f) + self.max_loops = state["max_loops"] + self.tasks = [] + for task_state in state["tasks"]: + task = Task( + description=task_state["description"], + agent=task_state["agent"], + args=task_state["args"], + kwargs=task_state["kwargs"], + result=task_state["result"], + history=task_state["history"], + ) + self.tasks.append(task) + except Exception as error: + print( + colored( + f"Error loading workflow state: {error}", + "red", + ) + ) diff --git a/swarms/structs/nonlinear_workflow.py b/swarms/structs/nonlinear_workflow.py index 8fb42cb0..e1724d09 100644 --- a/swarms/structs/nonlinear_workflow.py +++ b/swarms/structs/nonlinear_workflow.py @@ -24,6 +24,7 @@ class NonlinearWorkflow(BaseStruct): >>> workflow.run() """ + def __init__(self, stopping_token: str = ""): self.tasks = {} self.edges = {} @@ -81,7 +82,7 @@ class NonlinearWorkflow(BaseStruct): for deps in edges.values(): for task in ready_tasks: if task in deps: - deps.remove(task) + deps.remove(task) except Exception as error: print(f"[ERROR][NonlinearWorkflow] {error}") raise error diff --git a/swarms/structs/recursive_workflow.py b/swarms/structs/recursive_workflow.py index e4e0785f..38487ec0 100644 --- a/swarms/structs/recursive_workflow.py +++ b/swarms/structs/recursive_workflow.py @@ -34,7 +34,6 @@ class RecursiveWorkflow(BaseStruct): self.stop_token is not None ), "stop_token cannot be None" - def add(self, task: Task, tasks: List[Task] = None): """Adds a task to the workflow. @@ -52,7 +51,6 @@ class RecursiveWorkflow(BaseStruct): print(f"[ERROR][ConcurrentWorkflow] {error}") raise error - def run(self): """ Executes the tasks in the workflow until the stop token is encountered. diff --git a/swarms/utils/try_except_wrapper.py b/swarms/utils/try_except_wrapper.py new file mode 100644 index 00000000..a12b4393 --- /dev/null +++ b/swarms/utils/try_except_wrapper.py @@ -0,0 +1,28 @@ +def try_except_wrapper(func): + """ + A decorator that wraps a function with a try-except block. + It catches any exception that occurs during the execution of the function, + prints an error message, and returns None. + It also prints a message indicating the exit of the function. + + Args: + func (function): The function to be wrapped. + + Returns: + function: The wrapped function. + """ + + def wrapper(*args, **kwargs): + try: + result = func(*args, **kwargs) + return result + except Exception as error: + print( + f"An error occurred in function {func.__name__}:" + f" {error}" + ) + return None + finally: + print(f"Exiting function: {func.__name__}") + + return wrapper diff --git a/tests/structs/test_base_workflow.py b/tests/structs/test_base_workflow.py new file mode 100644 index 00000000..17be5ea8 --- /dev/null +++ b/tests/structs/test_base_workflow.py @@ -0,0 +1,66 @@ +import os +import pytest +import json +from swarms.models import OpenAIChat +from swarms.structs import BaseWorkflow + +from dotenv import load_dotenv + +load_dotenv() + +api_key = os.environ.get("OPENAI_API_KEY") + + +def setup_workflow(): + llm = OpenAIChat(openai_api_key=api_key) + workflow = BaseWorkflow(max_loops=1) + workflow.add("What's the weather in miami", llm) + workflow.add("Create a report on these metrics", llm) + workflow.save_workflow_state("workflow_state.json") + return workflow + + +def teardown_workflow(): + os.remove("workflow_state.json") + + +def test_load_workflow_state(): + workflow = setup_workflow() + workflow.load_workflow_state("workflow_state.json") + assert workflow.max_loops == 1 + assert len(workflow.tasks) == 2 + assert ( + workflow.tasks[0].description == "What's the weather in miami" + ) + assert ( + workflow.tasks[1].description + == "Create a report on these metrics" + ) + teardown_workflow() + + +def test_load_workflow_state_with_missing_file(): + workflow = setup_workflow() + with pytest.raises(FileNotFoundError): + workflow.load_workflow_state("non_existent_file.json") + teardown_workflow() + + +def test_load_workflow_state_with_invalid_file(): + workflow = setup_workflow() + with open("invalid_file.json", "w") as f: + f.write("This is not valid JSON") + with pytest.raises(json.JSONDecodeError): + workflow.load_workflow_state("invalid_file.json") + os.remove("invalid_file.json") + teardown_workflow() + + +def test_load_workflow_state_with_missing_keys(): + workflow = setup_workflow() + with open("missing_keys.json", "w") as f: + json.dump({"max_loops": 1}, f) + with pytest.raises(KeyError): + workflow.load_workflow_state("missing_keys.json") + os.remove("missing_keys.json") + teardown_workflow() diff --git a/tests/utils/test_try_except_wrapper.py b/tests/utils/test_try_except_wrapper.py new file mode 100644 index 00000000..26b509fb --- /dev/null +++ b/tests/utils/test_try_except_wrapper.py @@ -0,0 +1,45 @@ +from swarms.utils.try_except_wrapper import try_except_wrapper + + +def test_try_except_wrapper_with_no_exception(): + @try_except_wrapper + def add(x, y): + return x + y + + result = add(1, 2) + assert ( + result == 3 + ), "The function should return the sum of the arguments" + + +def test_try_except_wrapper_with_exception(): + @try_except_wrapper + def divide(x, y): + return x / y + + result = divide(1, 0) + assert ( + result is None + ), "The function should return None when an exception is raised" + + +def test_try_except_wrapper_with_multiple_arguments(): + @try_except_wrapper + def concatenate(*args): + return "".join(args) + + result = concatenate("Hello", " ", "world") + assert ( + result == "Hello world" + ), "The function should concatenate the arguments" + + +def test_try_except_wrapper_with_keyword_arguments(): + @try_except_wrapper + def greet(name="world"): + return f"Hello, {name}" + + result = greet(name="Alice") + assert ( + result == "Hello, Alice" + ), "The function should use the keyword arguments"