diff --git a/concurrent_swarm_example.py b/concurrent_swarm_example.py new file mode 100644 index 00000000..724346d4 --- /dev/null +++ b/concurrent_swarm_example.py @@ -0,0 +1,33 @@ +from swarms import Agent, ConcurrentWorkflow +from swarms.prompts.finance_agent_sys_prompt import ( + FINANCIAL_AGENT_SYS_PROMPT, +) + +if __name__ == "__main__": + # Assuming you've already initialized some agents outside of this class + agents = [ + Agent( + agent_name=f"Financial-Analysis-Agent-{i}", + system_prompt=FINANCIAL_AGENT_SYS_PROMPT, + model_name="gpt-4o", + max_loops=1, + ) + for i in range(3) # Adjust number of agents as needed + ] + + # Initialize the workflow with the list of agents + workflow = ConcurrentWorkflow( + agents=agents, + metadata_output_path="agent_metadata_4.json", + output_type="list", + show_progress=False, + max_loops=3, + interactive=True, + ) + + # Define the task for all agents + task = "How can I establish a ROTH IRA to buy stocks and get a tax break? What are the criteria?" + + # Run the workflow and save metadata + metadata = workflow.run(task) + print(metadata) diff --git a/docs/swarms/structs/concurrentworkflow.md b/docs/swarms/structs/concurrentworkflow.md index a517177b..48297ed3 100644 --- a/docs/swarms/structs/concurrentworkflow.md +++ b/docs/swarms/structs/concurrentworkflow.md @@ -6,16 +6,17 @@ The `ConcurrentWorkflow` class is designed to facilitate the concurrent executio ### Key Features -- **Concurrent Execution**: Runs multiple agents simultaneously using Python's `asyncio` and `ThreadPoolExecutor`. -- **Metadata Collection**: Gathers detailed metadata about each agent's execution, including start and end times, duration, and output. -- **Customizable Output**: Allows the user to save metadata to a file or return it as a string or dictionary. -- **Error Handling**: Implements retry logic for improved reliability. -- **Batch Processing**: Supports running tasks in batches and parallel execution. -- **Asynchronous Execution**: Provides asynchronous run options for improved performance. - -## Class Definitions - -The `ConcurrentWorkflow` class is the core class that manages the concurrent execution of agents. It inherits from `BaseSwarm` and includes several key attributes and methods to facilitate this process. +- **Concurrent Execution**: Runs multiple agents simultaneously using Python's `ThreadPoolExecutor` +- **Interactive Mode**: Supports interactive task modification and execution +- **Caching System**: Implements LRU caching for repeated prompts +- **Progress Tracking**: Optional progress bar for task execution +- **Enhanced Error Handling**: Implements retry mechanism with exponential backoff +- **Input Validation**: Validates task inputs before execution +- **Batch Processing**: Supports running tasks in batches +- **Metadata Collection**: Gathers detailed metadata about each agent's execution +- **Customizable Output**: Allows saving metadata to file or returning as string/dictionary + +## Class Definition ### Attributes @@ -26,15 +27,18 @@ The `ConcurrentWorkflow` class is the core class that manages the concurrent exe | `agents` | `List[Agent]` | A list of agents to be executed concurrently. | | `metadata_output_path` | `str` | Path to save the metadata output. Defaults to `"agent_metadata.json"`. | | `auto_save` | `bool` | Flag indicating whether to automatically save the metadata. | -| `output_schema` | `BaseModel` | The output schema for the metadata, defaults to `MetadataSchema`. | -| `max_loops` | `int` | Maximum number of loops for the workflow, defaults to `1`. | +| `output_type` | `str` | The type of output format. Defaults to `"dict"`. | +| `max_loops` | `int` | Maximum number of loops for each agent. Defaults to `1`. | | `return_str_on` | `bool` | Flag to return output as string. Defaults to `False`. | -| `agent_responses` | `List[str]` | List of agent responses as strings. | | `auto_generate_prompts`| `bool` | Flag indicating whether to auto-generate prompts for agents. | -| `output_type` | `OutputType` | Type of output format to return. Defaults to `"dict"`. | | `return_entire_history`| `bool` | Flag to return entire conversation history. Defaults to `False`. | -| `conversation` | `Conversation` | Conversation object to track agent interactions. | -| `max_workers` | `int` | Maximum number of worker threads. Defaults to CPU count. | +| `interactive` | `bool` | Flag indicating whether to enable interactive mode. Defaults to `False`. | +| `cache_size` | `int` | The size of the cache. Defaults to `100`. | +| `max_retries` | `int` | The maximum number of retry attempts. Defaults to `3`. | +| `retry_delay` | `float` | The delay between retry attempts in seconds. Defaults to `1.0`. | +| `show_progress` | `bool` | Flag indicating whether to show progress. Defaults to `False`. | +| `_cache` | `dict` | The cache for storing agent outputs. | +| `_progress_bar` | `tqdm` | The progress bar for tracking execution. | ## Methods @@ -51,53 +55,72 @@ Initializes the `ConcurrentWorkflow` class with the provided parameters. | `agents` | `List[Agent]` | `[]` | A list of agents to be executed concurrently. | | `metadata_output_path`| `str` | `"agent_metadata.json"` | Path to save the metadata output. | | `auto_save` | `bool` | `True` | Flag indicating whether to automatically save the metadata. | -| `output_schema` | `BaseModel` | `MetadataSchema` | The output schema for the metadata. | -| `max_loops` | `int` | `1` | Maximum number of loops for the workflow. | +| `output_type` | `str` | `"dict"` | The type of output format. | +| `max_loops` | `int` | `1` | Maximum number of loops for each agent. | | `return_str_on` | `bool` | `False` | Flag to return output as string. | -| `agent_responses` | `List[str]` | `[]` | List of agent responses as strings. | | `auto_generate_prompts`| `bool` | `False` | Flag indicating whether to auto-generate prompts for agents. | -| `output_type` | `OutputType` | `"dict"` | Type of output format to return. | | `return_entire_history`| `bool` | `False` | Flag to return entire conversation history. | +| `interactive` | `bool` | `False` | Flag indicating whether to enable interactive mode. | +| `cache_size` | `int` | `100` | The size of the cache. | +| `max_retries` | `int` | `3` | The maximum number of retry attempts. | +| `retry_delay` | `float` | `1.0` | The delay between retry attempts in seconds. | +| `show_progress` | `bool` | `False` | Flag indicating whether to show progress. | #### Raises - `ValueError`: If the list of agents is empty or if the description is empty. +### ConcurrentWorkflow.disable_agent_prints + +Disables print statements for all agents in the workflow. + +```python +workflow.disable_agent_prints() +``` + ### ConcurrentWorkflow.activate_auto_prompt_engineering Activates the auto-generate prompts feature for all agents in the workflow. -#### Example - ```python -workflow = ConcurrentWorkflow(agents=[Agent()]) workflow.activate_auto_prompt_engineering() -# All agents in the workflow will now auto-generate prompts. ``` -### ConcurrentWorkflow.transform_metadata_schema_to_str +### ConcurrentWorkflow.enable_progress_bar -Transforms the metadata schema into a string format. +Enables the progress bar display for task execution. -#### Parameters +```python +workflow.enable_progress_bar() +``` -| Parameter | Type | Description | -|-------------|---------------------|-----------------------------------------------------------| -| `schema` | `MetadataSchema` | The metadata schema to transform. | +### ConcurrentWorkflow.disable_progress_bar -#### Returns +Disables the progress bar display. + +```python +workflow.disable_progress_bar() +``` + +### ConcurrentWorkflow.clear_cache -- `str`: The metadata schema as a formatted string. +Clears the task cache. -### ConcurrentWorkflow.save_metadata +```python +workflow.clear_cache() +``` -Saves the metadata to a JSON file based on the `auto_save` flag. +### ConcurrentWorkflow.get_cache_stats -#### Example +Gets cache statistics. + +#### Returns + +- `Dict[str, int]`: A dictionary containing cache statistics. ```python -workflow.save_metadata() -# Metadata will be saved to the specified path if auto_save is True. +stats = workflow.get_cache_stats() +print(stats) # {'cache_size': 5, 'max_cache_size': 100} ``` ### ConcurrentWorkflow.run @@ -134,136 +157,71 @@ Runs the workflow for a batch of tasks. #### Returns -- `List[Union[Dict[str, Any], str]]`: A list of final metadata for each task. - -#### Example - -```python -tasks = ["Task 1", "Task 2"] -results = workflow.run_batched(tasks) -print(results) -``` - - +- `List[Any]`: A list of results for each task. ## Usage Examples -### Example 1: Basic Usage +### Example 1: Basic Usage with Interactive Mode ```python -import os +from swarms import Agent, ConcurrentWorkflow -from swarms import Agent, ConcurrentWorkflow, OpenAIChat -# Define custom system prompts for each social media platform -TWITTER_AGENT_SYS_PROMPT = """ -You are a Twitter marketing expert specializing in real estate. Your task is to create engaging, concise tweets to promote properties, analyze trends to maximize engagement, and use appropriate hashtags and timing to reach potential buyers. -""" - -INSTAGRAM_AGENT_SYS_PROMPT = """ -You are an Instagram marketing expert focusing on real estate. Your task is to create visually appealing posts with engaging captions and hashtags to showcase properties, targeting specific demographics interested in real estate. -""" - -FACEBOOK_AGENT_SYS_PROMPT = """ -You are a Facebook marketing expert for real estate. Your task is to craft posts optimized for engagement and reach on Facebook, including using images, links, and targeted messaging to attract potential property buyers. -""" - -LINKEDIN_AGENT_SYS_PROMPT = """ -You are a LinkedIn marketing expert for the real estate industry. Your task is to create professional and informative posts, highlighting property features, market trends, and investment opportunities, tailored to professionals and investors. -""" - -EMAIL_AGENT_SYS_PROMPT = """ -You are an Email marketing expert specializing in real estate. Your task is to write compelling email campaigns to promote properties, focusing on personalization, subject lines, and effective call-to-action strategies to drive conversions. -""" - -# Initialize your agents for different social media platforms +# Initialize agents agents = [ Agent( - agent_name="Twitter-RealEstate-Agent", - system_prompt=TWITTER_AGENT_SYS_PROMPT, - model_name="gpt-4o", - max_loops=1, - dynamic_temperature_enabled=True, - saved_state_path="twitter_realestate_agent.json", - user_name="swarm_corp", - retry_attempts=1, - ), - Agent( - agent_name="Instagram-RealEstate-Agent", - system_prompt=INSTAGRAM_AGENT_SYS_PROMPT, - model_name="gpt-4o", - max_loops=1, - dynamic_temperature_enabled=True, - saved_state_path="instagram_realestate_agent.json", - user_name="swarm_corp", - retry_attempts=1, - ), - Agent( - agent_name="Facebook-RealEstate-Agent", - system_prompt=FACEBOOK_AGENT_SYS_PROMPT, - model_name="gpt-4o", + agent_name=f"Agent-{i}", + system_prompt="You are a helpful assistant.", + model_name="gpt-4", max_loops=1, - dynamic_temperature_enabled=True, - saved_state_path="facebook_realestate_agent.json", - user_name="swarm_corp", - retry_attempts=1, - ), - Agent( - agent_name="LinkedIn-RealEstate-Agent", - system_prompt=LINKEDIN_AGENT_SYS_PROMPT, - model_name="gpt-4o", - max_loops=1, - dynamic_temperature_enabled=True, - saved_state_path="linkedin_realestate_agent.json", - user_name="swarm_corp", - retry_attempts=1, - ), - Agent( - agent_name="Email-RealEstate-Agent", - system_prompt=EMAIL_AGENT_SYS_PROMPT, - model_name="gpt-4o", - max_loops=1, - dynamic_temperature_enabled=True, - saved_state_path="email_realestate_agent.json", - user_name="swarm_corp", - retry_attempts=1, - ), + ) + for i in range(3) ] -# Initialize workflow +# Initialize workflow with interactive mode workflow = ConcurrentWorkflow( - name="Real Estate Marketing Swarm", + name="Interactive Workflow", agents=agents, - metadata_output_path="metadata.json", - description="Concurrent swarm of content generators for real estate!", - auto_save=True, + interactive=True, + show_progress=True, + cache_size=100, + max_retries=3, + retry_delay=1.0 ) # Run workflow -task = "Create a marketing campaign for a luxury beachfront property in Miami, focusing on its stunning ocean views, private beach access, and state-of-the-art amenities." -metadata = workflow.run(task) -print(metadata) +task = "What are the benefits of using Python for data analysis?" +result = workflow.run(task) +print(result) ``` -### Example 2: Custom Output Handling +### Example 2: Batch Processing with Progress Bar ```python -# Initialize workflow with string output +# Initialize workflow workflow = ConcurrentWorkflow( - name="Real Estate Marketing Swarm", + name="Batch Processing Workflow", agents=agents, - metadata_output_path="metadata.json", - description="Concurrent swarm of content generators for real estate!", - auto_save=True, - return_str_on=True + show_progress=True, + auto_save=True ) -# Run workflow -task = "Develop a marketing strategy for a newly renovated historic townhouse in Boston, emphasizing its blend of classic architecture and modern amenities." -metadata_str = workflow.run(task) -print(metadata_str) +# Define tasks +tasks = [ + "Analyze the impact of climate change on agriculture", + "Evaluate renewable energy solutions", + "Assess water conservation strategies" +] + +# Run batch processing +results = workflow.run_batched(tasks) + +# Process results +for task, result in zip(tasks, results): + print(f"Task: {task}") + print(f"Result: {result}\n") ``` -### Example 3: Error Handling and Debugging +### Example 3: Error Handling and Retries ```python import logging @@ -271,71 +229,38 @@ import logging # Set up logging logging.basicConfig(level=logging.INFO) -# Initialize workflow +# Initialize workflow with retry settings workflow = ConcurrentWorkflow( - name="Real Estate Marketing Swarm", + name="Reliable Workflow", agents=agents, - metadata_output_path="metadata.json", - description="Concurrent swarm of content generators for real estate!", - auto_save=True + max_retries=3, + retry_delay=1.0, + show_progress=True ) # Run workflow with error handling try: - task = "Create a marketing campaign for a eco-friendly tiny house community in Portland, Oregon." - metadata = workflow.run(task) - print(metadata) + task = "Generate a comprehensive market analysis report" + result = workflow.run(task) + print(result) except Exception as e: - logging.error(f"An error occurred during workflow execution: {str(e)}") - # Additional error handling or debugging steps can be added here -``` - -### Example 4: Batch Processing - -```python -# Initialize workflow -workflow = ConcurrentWorkflow( - name="Real Estate Marketing Swarm", - agents=agents, - metadata_output_path="metadata_batch.json", - description="Concurrent swarm of content generators for real estate!", - auto_save=True -) - -# Define a list of tasks -tasks = [ - "Market a family-friendly suburban home with a large backyard and excellent schools nearby.", - "Promote a high-rise luxury apartment in New York City with panoramic skyline views.", - "Advertise a ski-in/ski-out chalet in Aspen, Colorado, perfect for winter sports enthusiasts." -] - -# Run workflow in batch mode -results = workflow.run_batched(tasks) - -# Process and print results -for task, result in zip(tasks, results): - print(f"Task: {task}") - print(f"Result: {result}\n") + logging.error(f"An error occurred: {str(e)}") ``` - - ## Tips and Best Practices -- **Agent Initialization**: Ensure that all agents are correctly initialized with their required configurations before passing them to `ConcurrentWorkflow`. -- **Metadata Management**: Use the `auto_save` flag to automatically save metadata if you plan to run multiple workflows in succession. -- **Concurrency Limits**: Adjust the number of agents based on your system's capabilities to avoid overloading resources. -- **Error Handling**: Implement try-except blocks when running workflows to catch and handle exceptions gracefully. -- **Batch Processing**: For large numbers of tasks, consider using `run_batched` or `run_parallel` methods to improve overall throughput. -- **Asynchronous Operations**: Utilize asynchronous methods (`run_async`, `run_batched_async`, `run_parallel_async`) when dealing with I/O-bound tasks or when you need to maintain responsiveness in your application. -- **Logging**: Implement detailed logging to track the progress of your workflows and troubleshoot any issues that may arise. -- **Resource Management**: Be mindful of API rate limits and resource consumption, especially when running large batches or parallel executions. -- **Testing**: Thoroughly test your workflows with various inputs and edge cases to ensure robust performance in production environments. +- **Agent Initialization**: Ensure all agents are correctly initialized with required configurations. +- **Interactive Mode**: Use interactive mode for tasks requiring user input or modification. +- **Caching**: Utilize the caching system for repeated tasks to improve performance. +- **Progress Tracking**: Enable progress bar for long-running tasks to monitor execution. +- **Error Handling**: Implement proper error handling and use retry mechanism for reliability. +- **Resource Management**: Monitor cache size and clear when necessary. +- **Batch Processing**: Use batch processing for multiple related tasks. +- **Logging**: Implement detailed logging for debugging and monitoring. ## References and Resources -- [Python's `asyncio` Documentation](https://docs.python.org/3/library/asyncio.html) -- [Pydantic Documentation](https://pydantic-docs.helpmanual.io/) -- [ThreadPoolExecutor in Python](https://docs.python.org/3/library/concurrent.futures.html#concurrent.futures.ThreadPoolExecutor) -- [Loguru for Logging in Python](https://loguru.readthedocs.io/en/stable/) -- [Tenacity: Retry library for Python](https://tenacity.readthedocs.io/en/latest/) \ No newline at end of file +- [Python's ThreadPoolExecutor Documentation](https://docs.python.org/3/library/concurrent.futures.html#concurrent.futures.ThreadPoolExecutor) +- [tqdm Progress Bar Documentation](https://tqdm.github.io/) +- [Python's functools.lru_cache Documentation](https://docs.python.org/3/library/functools.html#functools.lru_cache) +- [Loguru for Logging in Python](https://loguru.readthedocs.io/en/stable/) \ No newline at end of file diff --git a/swarms/communication/duckdb_wrap.py b/swarms/communication/duckdb_wrap.py index 7ea3223d..2ef95779 100644 --- a/swarms/communication/duckdb_wrap.py +++ b/swarms/communication/duckdb_wrap.py @@ -46,6 +46,7 @@ class Message: class DateTimeEncoder(json.JSONEncoder): """Custom JSON encoder for handling datetime objects.""" + def default(self, obj): if isinstance(obj, datetime.datetime): return obj.isoformat() @@ -540,7 +541,10 @@ class DuckDBConversation: except json.JSONDecodeError: pass - message = {"role": row[0], "content": content} # role column + message = { + "role": row[0], + "content": content, + } # role column if row[2]: # timestamp column message["timestamp"] = row[2] @@ -562,7 +566,9 @@ class DuckDBConversation: Returns: str: JSON string representation of the conversation """ - return json.dumps(self.to_dict(), indent=2, cls=DateTimeEncoder) + return json.dumps( + self.to_dict(), indent=2, cls=DateTimeEncoder + ) def to_yaml(self) -> str: """ @@ -585,7 +591,9 @@ class DuckDBConversation: """ try: with open(filename, "w") as f: - json.dump(self.to_dict(), f, indent=2, cls=DateTimeEncoder) + json.dump( + self.to_dict(), f, indent=2, cls=DateTimeEncoder + ) return True except Exception as e: if self.enable_logging: @@ -614,12 +622,13 @@ class DuckDBConversation: # Add all messages for message in messages: # Convert timestamp string back to datetime if it exists - timestamp = None if "timestamp" in message: try: - timestamp = datetime.datetime.fromisoformat(message["timestamp"]) + datetime.datetime.fromisoformat( + message["timestamp"] + ) except (ValueError, TypeError): - timestamp = message["timestamp"] + message["timestamp"] self.add( role=message["role"], diff --git a/swarms/structs/agent.py b/swarms/structs/agent.py index f7bda171..f37f5d61 100644 --- a/swarms/structs/agent.py +++ b/swarms/structs/agent.py @@ -1111,31 +1111,15 @@ class Agent: # Convert to a str if the response is not a str response = self.parse_llm_output(response) - # self.short_memory.add( - # role=self.agent_name, content=response - # ) - - # # Print - # self.pretty_print(response, loop_count) - - # # Output Cleaner - # self.output_cleaner_op(response) - - # 9. Batch memory updates and prints - update_tasks = [ - lambda: self.short_memory.add( - role=self.agent_name, content=response - ), - lambda: self.pretty_print( - response, loop_count - ), - lambda: self.output_cleaner_op(response), - ] - - with ThreadPoolExecutor( - max_workers=len(update_tasks) - ) as executor: - executor.map(lambda f: f(), update_tasks) + self.short_memory.add( + role=self.agent_name, content=response + ) + + # Print + self.pretty_print(response, loop_count) + + # Output Cleaner + self.output_cleaner_op(response) ####### MCP TOOL HANDLING ####### if ( @@ -1156,21 +1140,23 @@ class Agent: role="Tool Executor", content=out ) - agent_print( - f"{self.agent_name} - Tool Executor", - out, - loop_count, - self.streaming_on, - ) + if self.no_print is False: + agent_print( + f"{self.agent_name} - Tool Executor", + out, + loop_count, + self.streaming_on, + ) out = self.llm.run(out) - agent_print( - f"{self.agent_name} - Agent Analysis", - out, - loop_count, - self.streaming_on, - ) + if self.no_print is False: + agent_print( + f"{self.agent_name} - Agent Analysis", + out, + loop_count, + self.streaming_on, + ) self.short_memory.add( role=self.agent_name, content=out @@ -2738,6 +2724,8 @@ class Agent: f"{self.agent_name}: {response}", title=f"Agent Name: {self.agent_name} [Max Loops: {loop_count}]", ) + elif self.no_print is True: + pass else: # logger.info(f"Response: {response}") formatter.print_panel( diff --git a/swarms/structs/concurrent_workflow.py b/swarms/structs/concurrent_workflow.py index 58cda962..d6dfd619 100644 --- a/swarms/structs/concurrent_workflow.py +++ b/swarms/structs/concurrent_workflow.py @@ -1,70 +1,32 @@ import os -import uuid +import time from concurrent.futures import ThreadPoolExecutor -from datetime import datetime +from functools import lru_cache from typing import Any, Callable, Dict, List, Optional, Union -from pydantic import BaseModel, Field +from tqdm import tqdm from swarms.structs.agent import Agent from swarms.structs.base_swarm import BaseSwarm -from swarms.utils.file_processing import create_file_in_folder -from swarms.utils.loguru_logger import initialize_logger from swarms.structs.conversation import Conversation -from swarms.structs.swarm_id_generator import generate_swarm_id -from swarms.structs.output_types import OutputType +from swarms.utils.formatter import formatter +from swarms.utils.history_output_formatter import ( + history_output_formatter, +) +from swarms.utils.loguru_logger import initialize_logger logger = initialize_logger(log_folder="concurrent_workflow") -class AgentOutputSchema(BaseModel): - run_id: Optional[str] = Field( - ..., description="Unique ID for the run" - ) - agent_name: Optional[str] = Field( - ..., description="Name of the agent" - ) - task: Optional[str] = Field( - ..., description="Task or query given to the agent" - ) - output: Optional[str] = Field( - ..., description="Output generated by the agent" - ) - start_time: Optional[datetime] = Field( - ..., description="Start time of the task" - ) - end_time: Optional[datetime] = Field( - ..., description="End time of the task" - ) - duration: Optional[float] = Field( - ..., - description="Duration taken to complete the task (in seconds)", - ) - - -class MetadataSchema(BaseModel): - swarm_id: Optional[str] = Field( - generate_swarm_id(), description="Unique ID for the run" - ) - task: Optional[str] = Field( - ..., description="Task or query given to all agents" - ) - description: Optional[str] = Field( - "Concurrent execution of multiple agents", - description="Description of the workflow", - ) - agents: Optional[List[AgentOutputSchema]] = Field( - ..., description="List of agent outputs and metadata" - ) - timestamp: Optional[datetime] = Field( - default_factory=datetime.now, - description="Timestamp of the workflow execution", - ) - - class ConcurrentWorkflow(BaseSwarm): """ Represents a concurrent workflow that executes multiple agents concurrently in a production-grade manner. + Features include: + - Interactive model support + - Caching for repeated prompts + - Optional progress tracking + - Enhanced error handling and retries + - Input validation Args: name (str): The name of the workflow. Defaults to "ConcurrentWorkflow". @@ -72,11 +34,16 @@ class ConcurrentWorkflow(BaseSwarm): agents (List[Agent]): The list of agents to be executed concurrently. Defaults to an empty list. metadata_output_path (str): The path to save the metadata output. Defaults to "agent_metadata.json". auto_save (bool): Flag indicating whether to automatically save the metadata. Defaults to False. - output_schema (BaseModel): The output schema for the metadata. Defaults to MetadataSchema. + output_type (str): The type of output format. Defaults to "dict". max_loops (int): The maximum number of loops for each agent. Defaults to 1. return_str_on (bool): Flag indicating whether to return the output as a string. Defaults to False. - agent_responses (list): The list of agent responses. Defaults to an empty list. auto_generate_prompts (bool): Flag indicating whether to auto-generate prompts for agents. Defaults to False. + return_entire_history (bool): Flag indicating whether to return the entire conversation history. Defaults to False. + interactive (bool): Flag indicating whether to enable interactive mode. Defaults to False. + cache_size (int): The size of the cache. Defaults to 100. + max_retries (int): The maximum number of retry attempts. Defaults to 3. + retry_delay (float): The delay between retry attempts in seconds. Defaults to 1.0. + show_progress (bool): Flag indicating whether to show progress. Defaults to False. Raises: ValueError: If the list of agents is empty or if the description is empty. @@ -87,13 +54,18 @@ class ConcurrentWorkflow(BaseSwarm): agents (List[Agent]): The list of agents to be executed concurrently. metadata_output_path (str): The path to save the metadata output. auto_save (bool): Flag indicating whether to automatically save the metadata. - output_schema (BaseModel): The output schema for the metadata. + output_type (str): The type of output format. max_loops (int): The maximum number of loops for each agent. return_str_on (bool): Flag indicating whether to return the output as a string. - agent_responses (list): The list of agent responses. auto_generate_prompts (bool): Flag indicating whether to auto-generate prompts for agents. - retry_attempts (int): The number of retry attempts for failed agent executions. - retry_wait_time (int): The initial wait time for retries in seconds. + return_entire_history (bool): Flag indicating whether to return the entire conversation history. + interactive (bool): Flag indicating whether to enable interactive mode. + cache_size (int): The size of the cache. + max_retries (int): The maximum number of retry attempts. + retry_delay (float): The delay between retry attempts in seconds. + show_progress (bool): Flag indicating whether to show progress. + _cache (dict): The cache for storing agent outputs. + _progress_bar (tqdm): The progress bar for tracking execution. """ def __init__( @@ -103,13 +75,16 @@ class ConcurrentWorkflow(BaseSwarm): agents: List[Union[Agent, Callable]] = [], metadata_output_path: str = "agent_metadata.json", auto_save: bool = True, - output_schema: BaseModel = MetadataSchema, + output_type: str = "dict-all-except-first", max_loops: int = 1, return_str_on: bool = False, - agent_responses: list = [], auto_generate_prompts: bool = False, - output_type: OutputType = "dict", return_entire_history: bool = False, + interactive: bool = False, + cache_size: int = 100, + max_retries: int = 3, + retry_delay: float = 1.0, + show_progress: bool = False, *args, **kwargs, ): @@ -125,18 +100,22 @@ class ConcurrentWorkflow(BaseSwarm): self.agents = agents self.metadata_output_path = metadata_output_path self.auto_save = auto_save - self.output_schema = output_schema self.max_loops = max_loops self.return_str_on = return_str_on - self.agent_responses = agent_responses self.auto_generate_prompts = auto_generate_prompts self.max_workers = os.cpu_count() self.output_type = output_type self.return_entire_history = return_entire_history self.tasks = [] # Initialize tasks list + self.interactive = interactive + self.cache_size = cache_size + self.max_retries = max_retries + self.retry_delay = retry_delay + self.show_progress = show_progress + self._cache = {} + self._progress_bar = None self.reliability_check() - self.conversation = Conversation() def disable_agent_prints(self): @@ -145,29 +124,47 @@ class ConcurrentWorkflow(BaseSwarm): def reliability_check(self): try: - logger.info("Starting reliability checks") + formatter.print_panel( + content=f"\n 🏷️ Name: {self.name}\n 📝 Description: {self.description}\n 🤖 Agents: {len(self.agents)}\n 🔄 Max Loops: {self.max_loops}\n ", + title="⚙️ Concurrent Workflow Settings", + style="bold blue", + ) + formatter.print_panel( + content="🔍 Starting reliability checks", + title="🔒 Reliability Checks", + style="bold blue", + ) if self.name is None: - logger.error("A name is required for the swarm") - raise ValueError("A name is required for the swarm") + logger.error("❌ A name is required for the swarm") + raise ValueError( + "❌ A name is required for the swarm" + ) - if not self.agents: - logger.error("The list of agents must not be empty.") + if not self.agents or len(self.agents) <= 1: + logger.error( + "❌ The list of agents must not be empty." + ) raise ValueError( - "The list of agents must not be empty." + "❌ The list of agents must not be empty." ) if not self.description: - logger.error("A description is required.") - raise ValueError("A description is required.") + logger.error("❌ A description is required.") + raise ValueError("❌ A description is required.") + + formatter.print_panel( + content="✅ Reliability checks completed successfully", + title="🎉 Reliability Checks", + style="bold green", + ) - logger.info("Reliability checks completed successfully") except ValueError as e: - logger.error(f"Reliability check failed: {e}") + logger.error(f"❌ Reliability check failed: {e}") raise except Exception as e: logger.error( - f"An unexpected error occurred during reliability checks: {e}" + f"💥 An unexpected error occurred during reliability checks: {e}" ) raise @@ -184,147 +181,179 @@ class ConcurrentWorkflow(BaseSwarm): for agent in self.agents: agent.auto_generate_prompt = True - # @retry(wait=wait_exponential(min=2), stop=stop_after_attempt(3)) - - def transform_metadata_schema_to_str( - self, schema: MetadataSchema - ): - """ - Converts the metadata swarm schema into a string format with the agent name, response, and time. - - Args: - schema (MetadataSchema): The metadata schema to convert. - - Returns: - str: The string representation of the metadata schema. - - Example: - >>> metadata_schema = MetadataSchema() - >>> metadata_str = workflow.transform_metadata_schema_to_str(metadata_schema) - >>> print(metadata_str) - """ - self.agent_responses = [ - f"Agent Name: {agent.agent_name}\nResponse: {agent.output}\n\n" - for agent in schema.agents - ] - - # Return the agent responses as a string - return "\n".join(self.agent_responses) - - def save_metadata(self): - """ - Saves the metadata to a JSON file based on the auto_save flag. - - Example: - >>> workflow.save_metadata() - >>> # Metadata will be saved to the specified path if auto_save is True. - """ - # Save metadata to a JSON file - if self.auto_save: - logger.info( - f"Saving metadata to {self.metadata_output_path}" + @lru_cache(maxsize=100) + def _cached_run(self, task: str, agent_id: int) -> Any: + """Cached version of agent execution to avoid redundant computations""" + return self.agents[agent_id].run(task=task) + + def enable_progress_bar(self): + """Enable progress bar display""" + self.show_progress = True + + def disable_progress_bar(self): + """Disable progress bar display""" + if self._progress_bar: + self._progress_bar.close() + self._progress_bar = None + self.show_progress = False + + def _create_progress_bar(self, total: int): + """Create a progress bar for tracking execution""" + if self.show_progress: + try: + self._progress_bar = tqdm( + total=total, + desc="Processing tasks", + unit="task", + disable=not self.show_progress, + ncols=100, + bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}]", + ) + except Exception as e: + logger.warning(f"Failed to create progress bar: {e}") + self.show_progress = False + self._progress_bar = None + return self._progress_bar + + def _update_progress(self, increment: int = 1): + """Update the progress bar""" + if self._progress_bar and self.show_progress: + try: + self._progress_bar.update(increment) + except Exception as e: + logger.warning(f"Failed to update progress bar: {e}") + self.disable_progress_bar() + + def _validate_input(self, task: str) -> bool: + """Validate input task""" + if not isinstance(task, str): + raise ValueError("Task must be a string") + if not task.strip(): + raise ValueError("Task cannot be empty") + return True + + def _handle_interactive(self, task: str) -> str: + """Handle interactive mode for task input""" + if self.interactive: + from swarms.utils.formatter import formatter + + # Display current task in a panel + formatter.print_panel( + content=f"Current task: {task}", + title="Task Status", + style="bold blue", ) - create_file_in_folder( - os.getenv("WORKSPACE_DIR"), - self.metadata_output_path, - self.output_schema.model_dump_json(indent=4), + + # Get user input with formatted prompt + formatter.print_panel( + content="Do you want to modify this task? (y/n/q to quit): ", + title="User Input", + style="bold green", ) + response = input().lower() + + if response == "q": + return None + elif response == "y": + formatter.print_panel( + content="Enter new task: ", + title="New Task Input", + style="bold yellow", + ) + new_task = input() + return new_task + return task + + def _run_with_retry( + self, agent: Agent, task: str, img: str = None + ) -> Any: + """Run agent with retry mechanism""" + for attempt in range(self.max_retries): + try: + output = agent.run(task=task, img=img) + self.conversation.add(agent.agent_name, output) + return output + except Exception as e: + if attempt == self.max_retries - 1: + logger.error( + f"Error running agent {agent.agent_name} after {self.max_retries} attempts: {e}" + ) + raise + logger.warning( + f"Attempt {attempt + 1} failed for agent {agent.agent_name}: {e}" + ) + time.sleep( + self.retry_delay * (attempt + 1) + ) # Exponential backoff def _run( self, task: str, img: str = None, *args, **kwargs ) -> Union[Dict[str, Any], str]: """ - Runs the workflow for the given task, executes agents concurrently using ThreadPoolExecutor, and saves metadata. + Enhanced run method with caching, progress tracking, and better error handling + """ - Args: - task (str): The task or query to give to all agents. - img (str): The image to be processed by the agents. + # Validate and potentially modify task + self._validate_input(task) + task = self._handle_interactive(task) - Returns: - Dict[str, Any]: The final metadata as a dictionary. - str: The final metadata as a string if return_str_on is True. + # Add task to conversation + self.conversation.add("User", task) - Example: - >>> metadata = workflow.run(task="Example task", img="example.jpg") - >>> print(metadata) - """ - logger.info( - f"Running concurrent workflow with {len(self.agents)} agents." - ) - - self.conversation.add( - "User", - task, - ) + # Create progress bar if enabled + if self.show_progress: + self._create_progress_bar(len(self.agents)) def run_agent( agent: Agent, task: str, img: str = None - ) -> AgentOutputSchema: - start_time = datetime.now() + ) -> Any: try: - output = agent.run(task=task) - - self.conversation.add( - agent.agent_name, - output, - ) + # Check cache first + cache_key = f"{task}_{agent.agent_name}" + if cache_key in self._cache: + output = self._cache[cache_key] + else: + output = self._run_with_retry(agent, task, img) + # Update cache + if len(self._cache) >= self.cache_size: + self._cache.pop(next(iter(self._cache))) + self._cache[cache_key] = output + + self._update_progress() + return output except Exception as e: logger.error( f"Error running agent {agent.agent_name}: {e}" ) + self._update_progress() raise - end_time = datetime.now() - duration = (end_time - start_time).total_seconds() - - agent_output = AgentOutputSchema( - run_id=uuid.uuid4().hex, - agent_name=agent.agent_name, - task=task, - output=output, - start_time=start_time, - end_time=end_time, - duration=duration, - ) - - logger.info( - f"Agent {agent.agent_name} completed task: {task} in {duration:.2f} seconds." - ) - - return agent_output - - with ThreadPoolExecutor( - max_workers=os.cpu_count() - ) as executor: - agent_outputs = list( - executor.map( - lambda agent: run_agent(agent, task), self.agents + try: + with ThreadPoolExecutor( + max_workers=self.max_workers + ) as executor: + list( + executor.map( + lambda agent: run_agent(agent, task), + self.agents, + ) ) - ) - - self.output_schema = MetadataSchema( - swarm_id=uuid.uuid4().hex, - task=task, - description=self.description, - agents=agent_outputs, + finally: + if self._progress_bar and self.show_progress: + try: + self._progress_bar.close() + except Exception as e: + logger.warning( + f"Failed to close progress bar: {e}" + ) + finally: + self._progress_bar = None + + return history_output_formatter( + self.conversation, + type=self.output_type, ) - self.save_metadata() - - if self.return_str_on: - return self.transform_metadata_schema_to_str( - self.output_schema - ) - - elif self.return_entire_history: - return self.conversation.return_history_as_string() - elif self.output_type == "list": - return self.conversation.return_messages_as_list() - elif self.output_type == "dict": - return self.conversation.return_messages_as_dictionary() - else: - return self.output_schema.model_dump_json(indent=4) - def run( self, task: Optional[str] = None, @@ -333,9 +362,11 @@ class ConcurrentWorkflow(BaseSwarm): **kwargs, ) -> Any: """ - Executes the agent's run method on a specified device. + Executes the agent's run method on a specified device with optional interactive mode. - This method attempts to execute the agent's run method on a specified device, either CPU or GPU. It logs the device selection and the number of cores or GPU ID used. If the device is set to CPU, it can use all available cores or a specific core specified by `device_id`. If the device is set to GPU, it uses the GPU specified by `device_id`. + This method attempts to execute the agent's run method on a specified device, either CPU or GPU. + It supports both standard execution and interactive mode where users can modify tasks and continue + the workflow interactively. Args: task (Optional[str], optional): The task to be executed. Defaults to None. @@ -359,8 +390,73 @@ class ConcurrentWorkflow(BaseSwarm): self.tasks.append(task) try: - outputs = self._run(task, img, *args, **kwargs) - return outputs + # Handle interactive mode + if self.interactive: + current_task = task + loop_count = 0 + + while loop_count < self.max_loops: + if ( + self.max_loops is not None + and loop_count >= self.max_loops + ): + formatter.print_panel( + content=f"Maximum number of loops ({self.max_loops}) reached.", + title="Session Complete", + style="bold red", + ) + break + + if current_task is None: + formatter.print_panel( + content="Enter your task (or 'q' to quit): ", + title="Task Input", + style="bold blue", + ) + current_task = input() + if current_task.lower() == "q": + break + + # Run the workflow with the current task + try: + outputs = self._run( + current_task, img, *args, **kwargs + ) + formatter.print_panel( + content=str(outputs), + title="Workflow Result", + style="bold green", + ) + except Exception as e: + formatter.print_panel( + content=f"Error: {str(e)}", + title="Error", + style="bold red", + ) + + # Ask if user wants to continue + formatter.print_panel( + content="Do you want to continue with a new task? (y/n): ", + title="Continue Session", + style="bold yellow", + ) + if input().lower() != "y": + break + + current_task = None + loop_count += 1 + + formatter.print_panel( + content="Interactive session ended.", + title="Session Complete", + style="bold blue", + ) + return outputs + else: + # Standard non-interactive execution + outputs = self._run(task, img, *args, **kwargs) + return outputs + except ValueError as e: logger.error(f"Invalid device specified: {e}") raise e @@ -368,29 +464,48 @@ class ConcurrentWorkflow(BaseSwarm): logger.error(f"An error occurred during execution: {e}") raise e - def run_batched( - self, tasks: List[str] - ) -> List[Union[Dict[str, Any], str]]: + def run_batched(self, tasks: List[str]) -> Any: """ - Runs the workflow for a batch of tasks, executes agents concurrently for each task, and saves metadata in a production-grade manner. + Enhanced batched execution with progress tracking + """ + if not tasks: + raise ValueError("Tasks list cannot be empty") - Args: - tasks (List[str]): A list of tasks or queries to give to all agents. + results = [] - Returns: - List[Union[Dict[str, Any], str]]: A list of final metadata for each task, either as a dictionary or a string. + # Create progress bar if enabled + if self.show_progress: + self._create_progress_bar(len(tasks)) + + try: + for task in tasks: + result = self.run(task) + results.append(result) + self._update_progress() + finally: + if self._progress_bar and self.show_progress: + try: + self._progress_bar.close() + except Exception as e: + logger.warning( + f"Failed to close progress bar: {e}" + ) + finally: + self._progress_bar = None - Example: - >>> tasks = ["Task 1", "Task 2"] - >>> results = workflow.run_batched(tasks) - >>> print(results) - """ - results = [] - for task in tasks: - result = self.run(task) - results.append(result) return results + def clear_cache(self): + """Clear the task cache""" + self._cache.clear() + + def get_cache_stats(self) -> Dict[str, int]: + """Get cache statistics""" + return { + "cache_size": len(self._cache), + "max_cache_size": self.cache_size, + } + # if __name__ == "__main__": # # Assuming you've already initialized some agents outside of this class diff --git a/swarms/structs/multi_model_gpu_manager.py b/swarms/structs/multi_model_gpu_manager.py index d1d04bae..221bdb6d 100644 --- a/swarms/structs/multi_model_gpu_manager.py +++ b/swarms/structs/multi_model_gpu_manager.py @@ -1168,7 +1168,6 @@ class ModelGrid: try: # This would need to be implemented based on the specific model types # and tasks supported. Here's a simple placeholder: - model = model_metadata.model if model_metadata.model_type == ModelType.PYTORCH: # Run PyTorch model diff --git a/tests/communication/test_duckdb_conversation.py b/tests/communication/test_duckdb_conversation.py index 9494de43..be837ad5 100644 --- a/tests/communication/test_duckdb_conversation.py +++ b/tests/communication/test_duckdb_conversation.py @@ -1,6 +1,4 @@ import os -import json -from datetime import datetime from pathlib import Path import tempfile import threading @@ -10,6 +8,7 @@ from swarms.communication.duckdb_wrap import ( MessageType, ) + def setup_test(): """Set up test environment.""" temp_dir = tempfile.TemporaryDirectory() @@ -21,25 +20,34 @@ def setup_test(): ) return temp_dir, db_path, conversation + def cleanup_test(temp_dir, db_path): """Clean up test environment.""" if os.path.exists(db_path): os.remove(db_path) temp_dir.cleanup() + def test_initialization(): """Test conversation initialization.""" temp_dir, db_path, _ = setup_test() try: conv = DuckDBConversation(db_path=str(db_path)) assert conv.db_path == db_path, "Database path mismatch" - assert conv.table_name == "conversations", "Table name mismatch" - assert conv.enable_timestamps is True, "Timestamps should be enabled" - assert conv.current_conversation_id is not None, "Conversation ID should not be None" + assert ( + conv.table_name == "conversations" + ), "Table name mismatch" + assert ( + conv.enable_timestamps is True + ), "Timestamps should be enabled" + assert ( + conv.current_conversation_id is not None + ), "Conversation ID should not be None" print("✓ Initialization test passed") finally: cleanup_test(temp_dir, db_path) + def test_add_message(): """Test adding a single message.""" temp_dir, db_path, conversation = setup_test() @@ -50,11 +58,14 @@ def test_add_message(): message_type=MessageType.USER, ) assert msg_id is not None, "Message ID should not be None" - assert isinstance(msg_id, int), "Message ID should be an integer" + assert isinstance( + msg_id, int + ), "Message ID should be an integer" print("✓ Add message test passed") finally: cleanup_test(temp_dir, db_path) + def test_add_complex_message(): """Test adding a message with complex content.""" temp_dir, db_path, conversation = setup_test() @@ -62,20 +73,21 @@ def test_add_complex_message(): complex_content = { "text": "Hello", "data": [1, 2, 3], - "nested": {"key": "value"} + "nested": {"key": "value"}, } msg_id = conversation.add( role="assistant", content=complex_content, message_type=MessageType.ASSISTANT, metadata={"source": "test"}, - token_count=10 + token_count=10, ) assert msg_id is not None, "Message ID should not be None" print("✓ Add complex message test passed") finally: cleanup_test(temp_dir, db_path) + def test_batch_add(): """Test batch adding messages.""" temp_dir, db_path, conversation = setup_test() @@ -84,21 +96,24 @@ def test_batch_add(): Message( role="user", content="First message", - message_type=MessageType.USER + message_type=MessageType.USER, ), Message( role="assistant", content="Second message", - message_type=MessageType.ASSISTANT - ) + message_type=MessageType.ASSISTANT, + ), ] msg_ids = conversation.batch_add(messages) assert len(msg_ids) == 2, "Should have 2 message IDs" - assert all(isinstance(id, int) for id in msg_ids), "All IDs should be integers" + assert all( + isinstance(id, int) for id in msg_ids + ), "All IDs should be integers" print("✓ Batch add test passed") finally: cleanup_test(temp_dir, db_path) + def test_get_str(): """Test getting conversation as string.""" temp_dir, db_path, conversation = setup_test() @@ -107,30 +122,38 @@ def test_get_str(): conversation.add("assistant", "Hi there!") conv_str = conversation.get_str() assert "user: Hello" in conv_str, "User message not found" - assert "assistant: Hi there!" in conv_str, "Assistant message not found" + assert ( + "assistant: Hi there!" in conv_str + ), "Assistant message not found" print("✓ Get string test passed") finally: cleanup_test(temp_dir, db_path) + def test_get_messages(): """Test getting messages with pagination.""" temp_dir, db_path, conversation = setup_test() try: for i in range(5): conversation.add("user", f"Message {i}") - + all_messages = conversation.get_messages() assert len(all_messages) == 5, "Should have 5 messages" - + limited_messages = conversation.get_messages(limit=2) - assert len(limited_messages) == 2, "Should have 2 limited messages" - + assert ( + len(limited_messages) == 2 + ), "Should have 2 limited messages" + offset_messages = conversation.get_messages(offset=2) - assert len(offset_messages) == 3, "Should have 3 offset messages" + assert ( + len(offset_messages) == 3 + ), "Should have 3 offset messages" print("✓ Get messages test passed") finally: cleanup_test(temp_dir, db_path) + def test_search_messages(): """Test searching messages.""" temp_dir, db_path, conversation = setup_test() @@ -138,152 +161,215 @@ def test_search_messages(): conversation.add("user", "Hello world") conversation.add("assistant", "Hello there") conversation.add("user", "Goodbye world") - + results = conversation.search_messages("world") - assert len(results) == 2, "Should find 2 messages with 'world'" - assert all("world" in msg["content"] for msg in results), "All results should contain 'world'" + assert ( + len(results) == 2 + ), "Should find 2 messages with 'world'" + assert all( + "world" in msg["content"] for msg in results + ), "All results should contain 'world'" print("✓ Search messages test passed") finally: cleanup_test(temp_dir, db_path) + def test_get_statistics(): """Test getting conversation statistics.""" temp_dir, db_path, conversation = setup_test() try: conversation.add("user", "Hello", token_count=2) conversation.add("assistant", "Hi", token_count=1) - + stats = conversation.get_statistics() - assert stats["total_messages"] == 2, "Should have 2 total messages" - assert stats["unique_roles"] == 2, "Should have 2 unique roles" - assert stats["total_tokens"] == 3, "Should have 3 total tokens" + assert ( + stats["total_messages"] == 2 + ), "Should have 2 total messages" + assert ( + stats["unique_roles"] == 2 + ), "Should have 2 unique roles" + assert ( + stats["total_tokens"] == 3 + ), "Should have 3 total tokens" print("✓ Get statistics test passed") finally: cleanup_test(temp_dir, db_path) + def test_json_operations(): """Test JSON save and load operations.""" temp_dir, db_path, conversation = setup_test() try: conversation.add("user", "Hello") conversation.add("assistant", "Hi") - + json_path = Path(temp_dir.name) / "test_conversation.json" conversation.save_as_json(str(json_path)) assert json_path.exists(), "JSON file should exist" - + new_conversation = DuckDBConversation( db_path=str(Path(temp_dir.name) / "new.duckdb") ) - assert new_conversation.load_from_json(str(json_path)), "Should load from JSON" - assert len(new_conversation.get_messages()) == 2, "Should have 2 messages after load" + assert new_conversation.load_from_json( + str(json_path) + ), "Should load from JSON" + assert ( + len(new_conversation.get_messages()) == 2 + ), "Should have 2 messages after load" print("✓ JSON operations test passed") finally: cleanup_test(temp_dir, db_path) + def test_yaml_operations(): """Test YAML save and load operations.""" temp_dir, db_path, conversation = setup_test() try: conversation.add("user", "Hello") conversation.add("assistant", "Hi") - + yaml_path = Path(temp_dir.name) / "test_conversation.yaml" conversation.save_as_yaml(str(yaml_path)) assert yaml_path.exists(), "YAML file should exist" - + new_conversation = DuckDBConversation( db_path=str(Path(temp_dir.name) / "new.duckdb") ) - assert new_conversation.load_from_yaml(str(yaml_path)), "Should load from YAML" - assert len(new_conversation.get_messages()) == 2, "Should have 2 messages after load" + assert new_conversation.load_from_yaml( + str(yaml_path) + ), "Should load from YAML" + assert ( + len(new_conversation.get_messages()) == 2 + ), "Should have 2 messages after load" print("✓ YAML operations test passed") finally: cleanup_test(temp_dir, db_path) + def test_message_types(): """Test different message types.""" temp_dir, db_path, conversation = setup_test() try: - conversation.add("system", "System message", message_type=MessageType.SYSTEM) - conversation.add("user", "User message", message_type=MessageType.USER) - conversation.add("assistant", "Assistant message", message_type=MessageType.ASSISTANT) - conversation.add("function", "Function message", message_type=MessageType.FUNCTION) - conversation.add("tool", "Tool message", message_type=MessageType.TOOL) - + conversation.add( + "system", + "System message", + message_type=MessageType.SYSTEM, + ) + conversation.add( + "user", "User message", message_type=MessageType.USER + ) + conversation.add( + "assistant", + "Assistant message", + message_type=MessageType.ASSISTANT, + ) + conversation.add( + "function", + "Function message", + message_type=MessageType.FUNCTION, + ) + conversation.add( + "tool", "Tool message", message_type=MessageType.TOOL + ) + messages = conversation.get_messages() assert len(messages) == 5, "Should have 5 messages" - assert all("message_type" in msg for msg in messages), "All messages should have type" + assert all( + "message_type" in msg for msg in messages + ), "All messages should have type" print("✓ Message types test passed") finally: cleanup_test(temp_dir, db_path) + def test_delete_operations(): """Test deletion operations.""" temp_dir, db_path, conversation = setup_test() try: conversation.add("user", "Hello") conversation.add("assistant", "Hi") - - assert conversation.delete_current_conversation(), "Should delete conversation" - assert len(conversation.get_messages()) == 0, "Should have no messages after delete" - + + assert ( + conversation.delete_current_conversation() + ), "Should delete conversation" + assert ( + len(conversation.get_messages()) == 0 + ), "Should have no messages after delete" + conversation.add("user", "New message") assert conversation.clear_all(), "Should clear all messages" - assert len(conversation.get_messages()) == 0, "Should have no messages after clear" + assert ( + len(conversation.get_messages()) == 0 + ), "Should have no messages after clear" print("✓ Delete operations test passed") finally: cleanup_test(temp_dir, db_path) + def test_concurrent_operations(): """Test concurrent operations.""" temp_dir, db_path, conversation = setup_test() try: + def add_messages(): for i in range(10): conversation.add("user", f"Message {i}") - - threads = [threading.Thread(target=add_messages) for _ in range(5)] + + threads = [ + threading.Thread(target=add_messages) for _ in range(5) + ] for thread in threads: thread.start() for thread in threads: thread.join() - + messages = conversation.get_messages() - assert len(messages) == 50, "Should have 50 messages (10 * 5 threads)" + assert ( + len(messages) == 50 + ), "Should have 50 messages (10 * 5 threads)" print("✓ Concurrent operations test passed") finally: cleanup_test(temp_dir, db_path) + def test_error_handling(): """Test error handling.""" temp_dir, db_path, conversation = setup_test() try: # Test invalid message type try: - conversation.add("user", "Message", message_type="invalid") - assert False, "Should raise exception for invalid message type" + conversation.add( + "user", "Message", message_type="invalid" + ) + assert ( + False + ), "Should raise exception for invalid message type" except Exception: pass - + # Test invalid JSON content try: conversation.add("user", {"invalid": object()}) - assert False, "Should raise exception for invalid JSON content" + assert ( + False + ), "Should raise exception for invalid JSON content" except Exception: pass - + # Test invalid file operations try: conversation.load_from_json("/nonexistent/path.json") - assert False, "Should raise exception for invalid file path" + assert ( + False + ), "Should raise exception for invalid file path" except Exception: pass - + print("✓ Error handling test passed") finally: cleanup_test(temp_dir, db_path) + def run_all_tests(): """Run all tests.""" print("Running DuckDB Conversation tests...") @@ -301,17 +387,18 @@ def run_all_tests(): test_message_types, test_delete_operations, test_concurrent_operations, - test_error_handling + test_error_handling, ] - + for test in tests: try: test() except Exception as e: print(f"✗ {test.__name__} failed: {str(e)}") raise - + print("\nAll tests completed successfully!") -if __name__ == '__main__': - run_all_tests() \ No newline at end of file + +if __name__ == "__main__": + run_all_tests() diff --git a/tests/communication/test_sqlite_wrapper.py b/tests/communication/test_sqlite_wrapper.py index 2a42ce76..d188ec10 100644 --- a/tests/communication/test_sqlite_wrapper.py +++ b/tests/communication/test_sqlite_wrapper.py @@ -98,8 +98,8 @@ def test_basic_conversation() -> bool: # Test adding messages console.print("\n[bold]Adding messages...[/bold]") - msg_id1 = conversation.add("user", "Hello") - msg_id2 = conversation.add("assistant", "Hi there!") + conversation.add("user", "Hello") + conversation.add("assistant", "Hi there!") # Test getting messages console.print("\n[bold]Retrieved messages:[/bold]") diff --git a/tests/structs/test_agent.py b/tests/structs/test_agent.py index 1661e354..1c4c7971 100644 --- a/tests/structs/test_agent.py +++ b/tests/structs/test_agent.py @@ -156,7 +156,7 @@ def test_stopping_token(mocked_sleep, basic_flow): # Test interactive mode -def test_interactive_mode(basic_flow): +def test_interactive(basic_flow): basic_flow.interactive = True assert basic_flow.interactive @@ -309,7 +309,7 @@ def test_flow_run(flow_instance): assert len(response) > 0 -def test_flow_interactive_mode(flow_instance): +def test_flow_interactive(flow_instance): # Test the interactive mode of the Agent class flow_instance.interactive = True response = flow_instance.run("Test task")