concurrent docs updates + concurrent workflow updates

pull/801/merge
Kye Gomez 1 day ago
parent 688c9eab05
commit e9a7c7994c

@ -0,0 +1,33 @@
from swarms import Agent, ConcurrentWorkflow
from swarms.prompts.finance_agent_sys_prompt import (
    FINANCIAL_AGENT_SYS_PROMPT,
)

if __name__ == "__main__":
    # Assuming you've already initialized some agents outside of this class
    agents = [
        Agent(
            agent_name=f"Financial-Analysis-Agent-{i}",
            system_prompt=FINANCIAL_AGENT_SYS_PROMPT,
            model_name="gpt-4o",
            max_loops=1,
        )
        for i in range(3)  # Adjust number of agents as needed
    ]

    # Initialize the workflow with the list of agents
    workflow = ConcurrentWorkflow(
        agents=agents,
        metadata_output_path="agent_metadata_4.json",
        output_type="list",
        show_progress=False,
        max_loops=3,
        interactive=True,
    )

    # Define the task for all agents
    task = "How can I establish a ROTH IRA to buy stocks and get a tax break? What are the criteria?"

    # Run the workflow and save metadata
    metadata = workflow.run(task)
    print(metadata)

@ -6,16 +6,17 @@ The `ConcurrentWorkflow` class is designed to facilitate the concurrent executio
### Key Features
- **Concurrent Execution**: Runs multiple agents simultaneously using Python's `asyncio` and `ThreadPoolExecutor`.
- **Metadata Collection**: Gathers detailed metadata about each agent's execution, including start and end times, duration, and output.
- **Customizable Output**: Allows the user to save metadata to a file or return it as a string or dictionary.
- **Error Handling**: Implements retry logic for improved reliability.
- **Batch Processing**: Supports running tasks in batches and parallel execution.
- **Asynchronous Execution**: Provides asynchronous run options for improved performance.
## Class Definitions
The `ConcurrentWorkflow` class is the core class that manages the concurrent execution of agents. It inherits from `BaseSwarm` and includes several key attributes and methods to facilitate this process.
- **Concurrent Execution**: Runs multiple agents simultaneously using Python's `ThreadPoolExecutor`
- **Interactive Mode**: Supports interactive task modification and execution
- **Caching System**: Implements LRU caching for repeated prompts
- **Progress Tracking**: Optional progress bar for task execution
- **Enhanced Error Handling**: Implements a retry mechanism whose delay increases with each attempt (see the sketch after this list)
- **Input Validation**: Validates task inputs before execution
- **Batch Processing**: Supports running tasks in batches
- **Metadata Collection**: Gathers detailed metadata about each agent's execution
- **Customizable Output**: Allows saving metadata to file or returning as string/dictionary
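
A minimal sketch that exercises several of these features together. This assumes `agents` is a list of initialized `Agent` objects; every parameter shown is documented in the attribute table below, and the agent names and task string are purely illustrative:

```python
from swarms import Agent, ConcurrentWorkflow

agents = [
    Agent(
        agent_name=f"Feature-Demo-Agent-{i}",
        system_prompt="You are a helpful assistant.",
        model_name="gpt-4o",
        max_loops=1,
    )
    for i in range(2)
]

workflow = ConcurrentWorkflow(
    name="Feature Demo",
    agents=agents,
    show_progress=True,  # progress tracking via tqdm
    cache_size=100,      # cache for repeated prompts
    max_retries=3,       # retry attempts on agent failure
    retry_delay=1.0,     # base delay between retries, in seconds
)

print(workflow.run("Compare three sorting algorithms"))
```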
## Class Definition
### Attributes
@ -26,15 +27,18 @@ The `ConcurrentWorkflow` class is the core class that manages the concurrent exe
| `agents` | `List[Agent]` | A list of agents to be executed concurrently. |
| `metadata_output_path` | `str` | Path to save the metadata output. Defaults to `"agent_metadata.json"`. |
| `auto_save` | `bool` | Flag indicating whether to automatically save the metadata. |
| `output_schema` | `BaseModel` | The output schema for the metadata, defaults to `MetadataSchema`. |
| `max_loops` | `int` | Maximum number of loops for the workflow, defaults to `1`. |
| `output_type` | `str` | The type of output format. Defaults to `"dict"`. |
| `max_loops` | `int` | Maximum number of loops for each agent. Defaults to `1`. |
| `return_str_on` | `bool` | Flag to return output as string. Defaults to `False`. |
| `agent_responses` | `List[str]` | List of agent responses as strings. |
| `auto_generate_prompts`| `bool` | Flag indicating whether to auto-generate prompts for agents. |
| `output_type` | `OutputType` | Type of output format to return. Defaults to `"dict"`. |
| `return_entire_history`| `bool` | Flag to return entire conversation history. Defaults to `False`. |
| `conversation` | `Conversation` | Conversation object to track agent interactions. |
| `max_workers` | `int` | Maximum number of worker threads. Defaults to CPU count. |
| `interactive` | `bool` | Flag indicating whether to enable interactive mode. Defaults to `False`. |
| `cache_size` | `int` | The size of the cache. Defaults to `100`. |
| `max_retries` | `int` | The maximum number of retry attempts. Defaults to `3`. |
| `retry_delay` | `float` | The delay between retry attempts in seconds. Defaults to `1.0`. |
| `show_progress` | `bool` | Flag indicating whether to show progress. Defaults to `False`. |
| `_cache` | `dict` | The cache for storing agent outputs. |
| `_progress_bar` | `tqdm` | The progress bar for tracking execution. |
## Methods
@ -51,53 +55,72 @@ Initializes the `ConcurrentWorkflow` class with the provided parameters.
| `agents` | `List[Agent]` | `[]` | A list of agents to be executed concurrently. |
| `metadata_output_path`| `str` | `"agent_metadata.json"` | Path to save the metadata output. |
| `auto_save` | `bool` | `True` | Flag indicating whether to automatically save the metadata. |
| `output_schema` | `BaseModel` | `MetadataSchema` | The output schema for the metadata. |
| `max_loops` | `int` | `1` | Maximum number of loops for the workflow. |
| `output_type` | `str` | `"dict"` | The type of output format. |
| `max_loops` | `int` | `1` | Maximum number of loops for each agent. |
| `return_str_on` | `bool` | `False` | Flag to return output as string. |
| `agent_responses` | `List[str]` | `[]` | List of agent responses as strings. |
| `auto_generate_prompts`| `bool` | `False` | Flag indicating whether to auto-generate prompts for agents. |
| `output_type` | `OutputType` | `"dict"` | Type of output format to return. |
| `return_entire_history`| `bool` | `False` | Flag to return entire conversation history. |
| `interactive` | `bool` | `False` | Flag indicating whether to enable interactive mode. |
| `cache_size` | `int` | `100` | The size of the cache. |
| `max_retries` | `int` | `3` | The maximum number of retry attempts. |
| `retry_delay` | `float` | `1.0` | The delay between retry attempts in seconds. |
| `show_progress` | `bool` | `False` | Flag indicating whether to show progress. |
#### Raises
- `ValueError`: If the list of agents is empty or if the description is empty.
### ConcurrentWorkflow.disable_agent_prints
Disables print statements for all agents in the workflow.
```python
workflow.disable_agent_prints()
```
### ConcurrentWorkflow.activate_auto_prompt_engineering
Activates the auto-generate prompts feature for all agents in the workflow.
#### Example
```python
workflow = ConcurrentWorkflow(agents=[Agent()])
workflow.activate_auto_prompt_engineering()
# All agents in the workflow will now auto-generate prompts.
```
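
Both toggles apply to every agent in the swarm. A tiny sketch, assuming `agents` holds at least two initialized `Agent` objects (the new reliability check rejects smaller lists):

```python
workflow = ConcurrentWorkflow(agents=agents)

workflow.disable_agent_prints()              # silence per-agent console output
workflow.activate_auto_prompt_engineering()  # agents now auto-generate prompts
```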
### ConcurrentWorkflow.transform_metadata_schema_to_str
Transforms the metadata schema into a string format.
#### Parameters
| Parameter | Type | Description |
|-------------|---------------------|-----------------------------------------------------------|
| `schema` | `MetadataSchema` | The metadata schema to transform. |
#### Returns
- `str`: The metadata schema as a formatted string.
### ConcurrentWorkflow.save_metadata
Saves the metadata to a JSON file based on the `auto_save` flag.
#### Example
```python
workflow.save_metadata()
# Metadata will be saved to the specified path if auto_save is True.
```
### ConcurrentWorkflow.enable_progress_bar
Enables the progress bar display for task execution.
```python
workflow.enable_progress_bar()
```
### ConcurrentWorkflow.disable_progress_bar
Disables the progress bar display.
```python
workflow.disable_progress_bar()
```
### ConcurrentWorkflow.clear_cache
Clears the task cache.
```python
workflow.clear_cache()
```
### ConcurrentWorkflow.get_cache_stats
Gets cache statistics.
#### Returns
- `Dict[str, int]`: A dictionary containing cache statistics.
```python
stats = workflow.get_cache_stats()
print(stats)  # {'cache_size': 5, 'max_cache_size': 100}
```
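
A short sketch of how the cache behaves across repeated runs, assuming a non-interactive `workflow` built with the default `cache_size`; outputs are keyed by task and agent name, so an identical task is answered from the cache, and the task string and printed stats are illustrative:

```python
# First run executes every agent and fills the cache
workflow.run("Summarize Q1 earnings")

# An identical task with the same agents is served from the cache
workflow.run("Summarize Q1 earnings")

print(workflow.get_cache_stats())  # e.g. {'cache_size': 3, 'max_cache_size': 100}

# Reset between unrelated workloads
workflow.clear_cache()
```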
### ConcurrentWorkflow.run
@ -134,136 +157,71 @@ Runs the workflow for a batch of tasks.
#### Returns
- `List[Union[Dict[str, Any], str]]`: A list of final metadata for each task.
- `List[Any]`: A list of results for each task.
#### Example
```python
tasks = ["Task 1", "Task 2"]
results = workflow.run_batched(tasks)
print(results)
```
## Usage Examples
### Example 1: Basic Usage

```python
import os

from swarms import Agent, ConcurrentWorkflow

# Define custom system prompts for each social media platform
TWITTER_AGENT_SYS_PROMPT = """
You are a Twitter marketing expert specializing in real estate. Your task is to create engaging, concise tweets to promote properties, analyze trends to maximize engagement, and use appropriate hashtags and timing to reach potential buyers.
"""

INSTAGRAM_AGENT_SYS_PROMPT = """
You are an Instagram marketing expert focusing on real estate. Your task is to create visually appealing posts with engaging captions and hashtags to showcase properties, targeting specific demographics interested in real estate.
"""

FACEBOOK_AGENT_SYS_PROMPT = """
You are a Facebook marketing expert for real estate. Your task is to craft posts optimized for engagement and reach on Facebook, including using images, links, and targeted messaging to attract potential property buyers.
"""

LINKEDIN_AGENT_SYS_PROMPT = """
You are a LinkedIn marketing expert for the real estate industry. Your task is to create professional and informative posts, highlighting property features, market trends, and investment opportunities, tailored to professionals and investors.
"""

EMAIL_AGENT_SYS_PROMPT = """
You are an Email marketing expert specializing in real estate. Your task is to write compelling email campaigns to promote properties, focusing on personalization, subject lines, and effective call-to-action strategies to drive conversions.
"""

# Initialize your agents for different social media platforms
agents = [
    Agent(
        agent_name="Twitter-RealEstate-Agent",
        system_prompt=TWITTER_AGENT_SYS_PROMPT,
        model_name="gpt-4o",
        max_loops=1,
        dynamic_temperature_enabled=True,
        saved_state_path="twitter_realestate_agent.json",
        user_name="swarm_corp",
        retry_attempts=1,
    ),
    Agent(
        agent_name="Instagram-RealEstate-Agent",
        system_prompt=INSTAGRAM_AGENT_SYS_PROMPT,
        model_name="gpt-4o",
        max_loops=1,
        dynamic_temperature_enabled=True,
        saved_state_path="instagram_realestate_agent.json",
        user_name="swarm_corp",
        retry_attempts=1,
    ),
    Agent(
        agent_name="Facebook-RealEstate-Agent",
        system_prompt=FACEBOOK_AGENT_SYS_PROMPT,
        model_name="gpt-4o",
        max_loops=1,
        dynamic_temperature_enabled=True,
        saved_state_path="facebook_realestate_agent.json",
        user_name="swarm_corp",
        retry_attempts=1,
    ),
    Agent(
        agent_name="LinkedIn-RealEstate-Agent",
        system_prompt=LINKEDIN_AGENT_SYS_PROMPT,
        model_name="gpt-4o",
        max_loops=1,
        dynamic_temperature_enabled=True,
        saved_state_path="linkedin_realestate_agent.json",
        user_name="swarm_corp",
        retry_attempts=1,
    ),
    Agent(
        agent_name="Email-RealEstate-Agent",
        system_prompt=EMAIL_AGENT_SYS_PROMPT,
        model_name="gpt-4o",
        max_loops=1,
        dynamic_temperature_enabled=True,
        saved_state_path="email_realestate_agent.json",
        user_name="swarm_corp",
        retry_attempts=1,
    ),
]

# Initialize workflow
workflow = ConcurrentWorkflow(
    name="Real Estate Marketing Swarm",
    agents=agents,
    metadata_output_path="metadata.json",
    description="Concurrent swarm of content generators for real estate!",
    auto_save=True,
)

# Run workflow
task = "Create a marketing campaign for a luxury beachfront property in Miami, focusing on its stunning ocean views, private beach access, and state-of-the-art amenities."
metadata = workflow.run(task)
print(metadata)
```

### Example 1: Basic Usage with Interactive Mode

```python
from swarms import Agent, ConcurrentWorkflow

# Initialize agents
agents = [
    Agent(
        agent_name=f"Agent-{i}",
        system_prompt="You are a helpful assistant.",
        model_name="gpt-4",
        max_loops=1,
    )
    for i in range(3)
]

# Initialize workflow with interactive mode
workflow = ConcurrentWorkflow(
    name="Interactive Workflow",
    agents=agents,
    interactive=True,
    show_progress=True,
    cache_size=100,
    max_retries=3,
    retry_delay=1.0,
)

# Run workflow
task = "What are the benefits of using Python for data analysis?"
result = workflow.run(task)
print(result)
```
### Example 2: Custom Output Handling

```python
# Initialize workflow with string output
workflow = ConcurrentWorkflow(
    name="Real Estate Marketing Swarm",
    agents=agents,
    metadata_output_path="metadata.json",
    description="Concurrent swarm of content generators for real estate!",
    auto_save=True,
    return_str_on=True
)

# Run workflow
task = "Develop a marketing strategy for a newly renovated historic townhouse in Boston, emphasizing its blend of classic architecture and modern amenities."
metadata_str = workflow.run(task)
print(metadata_str)
```

### Example 2: Batch Processing with Progress Bar

```python
# Initialize workflow
workflow = ConcurrentWorkflow(
    name="Batch Processing Workflow",
    agents=agents,
    show_progress=True,
    auto_save=True
)

# Define tasks
tasks = [
    "Analyze the impact of climate change on agriculture",
    "Evaluate renewable energy solutions",
    "Assess water conservation strategies"
]

# Run batch processing
results = workflow.run_batched(tasks)

# Process results
for task, result in zip(tasks, results):
    print(f"Task: {task}")
    print(f"Result: {result}\n")
```
### Example 3: Error Handling and Debugging
@ -271,71 +229,38 @@ import logging

```python
import logging

# Set up logging
logging.basicConfig(level=logging.INFO)

# Initialize workflow
workflow = ConcurrentWorkflow(
    name="Real Estate Marketing Swarm",
    agents=agents,
    metadata_output_path="metadata.json",
    description="Concurrent swarm of content generators for real estate!",
    auto_save=True
)

# Run workflow with error handling
try:
    task = "Create a marketing campaign for an eco-friendly tiny house community in Portland, Oregon."
    metadata = workflow.run(task)
    print(metadata)
except Exception as e:
    logging.error(f"An error occurred during workflow execution: {str(e)}")
    # Additional error handling or debugging steps can be added here
```

### Example 4: Batch Processing

```python
# Initialize workflow
workflow = ConcurrentWorkflow(
    name="Real Estate Marketing Swarm",
    agents=agents,
    metadata_output_path="metadata_batch.json",
    description="Concurrent swarm of content generators for real estate!",
    auto_save=True
)

# Define a list of tasks
tasks = [
    "Market a family-friendly suburban home with a large backyard and excellent schools nearby.",
    "Promote a high-rise luxury apartment in New York City with panoramic skyline views.",
    "Advertise a ski-in/ski-out chalet in Aspen, Colorado, perfect for winter sports enthusiasts."
]

# Run workflow in batch mode
results = workflow.run_batched(tasks)

# Process and print results
for task, result in zip(tasks, results):
    print(f"Task: {task}")
    print(f"Result: {result}\n")
```

### Example 3: Error Handling and Retries

```python
import logging

# Set up logging
logging.basicConfig(level=logging.INFO)

# Initialize workflow with retry settings
workflow = ConcurrentWorkflow(
    name="Reliable Workflow",
    agents=agents,
    max_retries=3,
    retry_delay=1.0,
    show_progress=True
)

# Run workflow with error handling
try:
    task = "Generate a comprehensive market analysis report"
    result = workflow.run(task)
    print(result)
except Exception as e:
    logging.error(f"An error occurred: {str(e)}")
```
## Tips and Best Practices
- **Agent Initialization**: Ensure that all agents are correctly initialized with their required configurations before passing them to `ConcurrentWorkflow`.
- **Metadata Management**: Use the `auto_save` flag to automatically save metadata if you plan to run multiple workflows in succession.
- **Concurrency Limits**: Adjust the number of agents based on your system's capabilities to avoid overloading resources.
- **Error Handling**: Implement try-except blocks when running workflows to catch and handle exceptions gracefully.
- **Batch Processing**: For large numbers of tasks, consider using `run_batched` or `run_parallel` methods to improve overall throughput.
- **Asynchronous Operations**: Utilize asynchronous methods (`run_async`, `run_batched_async`, `run_parallel_async`) when dealing with I/O-bound tasks or when you need to maintain responsiveness in your application.
- **Logging**: Implement detailed logging to track the progress of your workflows and troubleshoot any issues that may arise.
- **Resource Management**: Be mindful of API rate limits and resource consumption, especially when running large batches or parallel executions.
- **Testing**: Thoroughly test your workflows with various inputs and edge cases to ensure robust performance in production environments.
- **Agent Initialization**: Ensure all agents are correctly initialized with required configurations.
- **Interactive Mode**: Use interactive mode for tasks requiring user input or modification.
- **Caching**: Utilize the caching system for repeated tasks to improve performance.
- **Progress Tracking**: Enable progress bar for long-running tasks to monitor execution.
- **Error Handling**: Implement proper error handling and use the retry mechanism for reliability, as shown in the sketch after this list.
- **Resource Management**: Monitor cache size and clear when necessary.
- **Batch Processing**: Use batch processing for multiple related tasks.
- **Logging**: Implement detailed logging for debugging and monitoring.
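
A minimal sketch tying several of these tips together, assuming `agents` is a list of initialized `Agent` objects; the retry, progress, and cache helpers are all documented above, and the workflow name and task string are illustrative:

```python
import logging

logging.basicConfig(level=logging.INFO)

workflow = ConcurrentWorkflow(
    name="Best-Practices Workflow",
    agents=agents,
    max_retries=3,
    retry_delay=1.0,
    show_progress=True,
)

try:
    result = workflow.run("Summarize current mortgage rate trends")
    logging.info("Workflow finished: %s", result)
except Exception as e:
    logging.error("Workflow failed after retries: %s", e)
finally:
    # Inspect and reset the cache between unrelated workloads
    logging.info("Cache stats: %s", workflow.get_cache_stats())
    workflow.clear_cache()
```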
## References and Resources
- [Python's `asyncio` Documentation](https://docs.python.org/3/library/asyncio.html)
- [Pydantic Documentation](https://pydantic-docs.helpmanual.io/)
- [ThreadPoolExecutor in Python](https://docs.python.org/3/library/concurrent.futures.html#concurrent.futures.ThreadPoolExecutor)
- [Python's ThreadPoolExecutor Documentation](https://docs.python.org/3/library/concurrent.futures.html#concurrent.futures.ThreadPoolExecutor)
- [tqdm Progress Bar Documentation](https://tqdm.github.io/)
- [Python's functools.lru_cache Documentation](https://docs.python.org/3/library/functools.html#functools.lru_cache)
- [Loguru for Logging in Python](https://loguru.readthedocs.io/en/stable/)
- [Tenacity: Retry library for Python](https://tenacity.readthedocs.io/en/latest/)

@ -46,6 +46,7 @@ class Message:
class DateTimeEncoder(json.JSONEncoder):
"""Custom JSON encoder for handling datetime objects."""
def default(self, obj):
if isinstance(obj, datetime.datetime):
return obj.isoformat()
@ -540,7 +541,10 @@ class DuckDBConversation:
except json.JSONDecodeError:
pass
message = {"role": row[0], "content": content} # role column
message = {
"role": row[0],
"content": content,
} # role column
if row[2]: # timestamp column
message["timestamp"] = row[2]
@ -562,7 +566,9 @@ class DuckDBConversation:
Returns:
str: JSON string representation of the conversation
"""
return json.dumps(self.to_dict(), indent=2, cls=DateTimeEncoder)
return json.dumps(
self.to_dict(), indent=2, cls=DateTimeEncoder
)
def to_yaml(self) -> str:
"""
@ -585,7 +591,9 @@ class DuckDBConversation:
"""
try:
with open(filename, "w") as f:
json.dump(self.to_dict(), f, indent=2, cls=DateTimeEncoder)
json.dump(
self.to_dict(), f, indent=2, cls=DateTimeEncoder
)
return True
except Exception as e:
if self.enable_logging:
@ -614,12 +622,13 @@ class DuckDBConversation:
# Add all messages
for message in messages:
# Convert timestamp string back to datetime if it exists
timestamp = None
if "timestamp" in message:
try:
timestamp = datetime.datetime.fromisoformat(message["timestamp"])
datetime.datetime.fromisoformat(
message["timestamp"]
)
except (ValueError, TypeError):
timestamp = message["timestamp"]
message["timestamp"]
self.add(
role=message["role"],

@ -1111,31 +1111,15 @@ class Agent:
# Convert to a str if the response is not a str
response = self.parse_llm_output(response)
# self.short_memory.add(
# role=self.agent_name, content=response
# )
# # Print
# self.pretty_print(response, loop_count)
# # Output Cleaner
# self.output_cleaner_op(response)
# 9. Batch memory updates and prints
update_tasks = [
lambda: self.short_memory.add(
role=self.agent_name, content=response
),
lambda: self.pretty_print(
response, loop_count
),
lambda: self.output_cleaner_op(response),
]
with ThreadPoolExecutor(
max_workers=len(update_tasks)
) as executor:
executor.map(lambda f: f(), update_tasks)
self.short_memory.add(
role=self.agent_name, content=response
)
# Print
self.pretty_print(response, loop_count)
# Output Cleaner
self.output_cleaner_op(response)
####### MCP TOOL HANDLING #######
if (
@ -1156,21 +1140,23 @@ class Agent:
role="Tool Executor", content=out
)
agent_print(
f"{self.agent_name} - Tool Executor",
out,
loop_count,
self.streaming_on,
)
if self.no_print is False:
agent_print(
f"{self.agent_name} - Tool Executor",
out,
loop_count,
self.streaming_on,
)
out = self.llm.run(out)
agent_print(
f"{self.agent_name} - Agent Analysis",
out,
loop_count,
self.streaming_on,
)
if self.no_print is False:
agent_print(
f"{self.agent_name} - Agent Analysis",
out,
loop_count,
self.streaming_on,
)
self.short_memory.add(
role=self.agent_name, content=out
@ -2738,6 +2724,8 @@ class Agent:
f"{self.agent_name}: {response}",
title=f"Agent Name: {self.agent_name} [Max Loops: {loop_count}]",
)
elif self.no_print is True:
pass
else:
# logger.info(f"Response: {response}")
formatter.print_panel(

@ -1,70 +1,32 @@
import os
import uuid
import time
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from functools import lru_cache
from typing import Any, Callable, Dict, List, Optional, Union
from pydantic import BaseModel, Field
from tqdm import tqdm
from swarms.structs.agent import Agent
from swarms.structs.base_swarm import BaseSwarm
from swarms.utils.file_processing import create_file_in_folder
from swarms.utils.loguru_logger import initialize_logger
from swarms.structs.conversation import Conversation
from swarms.structs.swarm_id_generator import generate_swarm_id
from swarms.structs.output_types import OutputType
from swarms.utils.formatter import formatter
from swarms.utils.history_output_formatter import (
history_output_formatter,
)
from swarms.utils.loguru_logger import initialize_logger
logger = initialize_logger(log_folder="concurrent_workflow")
class AgentOutputSchema(BaseModel):
run_id: Optional[str] = Field(
..., description="Unique ID for the run"
)
agent_name: Optional[str] = Field(
..., description="Name of the agent"
)
task: Optional[str] = Field(
..., description="Task or query given to the agent"
)
output: Optional[str] = Field(
..., description="Output generated by the agent"
)
start_time: Optional[datetime] = Field(
..., description="Start time of the task"
)
end_time: Optional[datetime] = Field(
..., description="End time of the task"
)
duration: Optional[float] = Field(
...,
description="Duration taken to complete the task (in seconds)",
)
class MetadataSchema(BaseModel):
swarm_id: Optional[str] = Field(
generate_swarm_id(), description="Unique ID for the run"
)
task: Optional[str] = Field(
..., description="Task or query given to all agents"
)
description: Optional[str] = Field(
"Concurrent execution of multiple agents",
description="Description of the workflow",
)
agents: Optional[List[AgentOutputSchema]] = Field(
..., description="List of agent outputs and metadata"
)
timestamp: Optional[datetime] = Field(
default_factory=datetime.now,
description="Timestamp of the workflow execution",
)
class ConcurrentWorkflow(BaseSwarm):
"""
Represents a concurrent workflow that executes multiple agents concurrently in a production-grade manner.
Features include:
- Interactive mode support
- Caching for repeated prompts
- Optional progress tracking
- Enhanced error handling and retries
- Input validation
Args:
name (str): The name of the workflow. Defaults to "ConcurrentWorkflow".
@ -72,11 +34,16 @@ class ConcurrentWorkflow(BaseSwarm):
agents (List[Agent]): The list of agents to be executed concurrently. Defaults to an empty list.
metadata_output_path (str): The path to save the metadata output. Defaults to "agent_metadata.json".
auto_save (bool): Flag indicating whether to automatically save the metadata. Defaults to False.
output_schema (BaseModel): The output schema for the metadata. Defaults to MetadataSchema.
output_type (str): The type of output format. Defaults to "dict".
max_loops (int): The maximum number of loops for each agent. Defaults to 1.
return_str_on (bool): Flag indicating whether to return the output as a string. Defaults to False.
agent_responses (list): The list of agent responses. Defaults to an empty list.
auto_generate_prompts (bool): Flag indicating whether to auto-generate prompts for agents. Defaults to False.
return_entire_history (bool): Flag indicating whether to return the entire conversation history. Defaults to False.
interactive (bool): Flag indicating whether to enable interactive mode. Defaults to False.
cache_size (int): The size of the cache. Defaults to 100.
max_retries (int): The maximum number of retry attempts. Defaults to 3.
retry_delay (float): The delay between retry attempts in seconds. Defaults to 1.0.
show_progress (bool): Flag indicating whether to show progress. Defaults to False.
Raises:
ValueError: If the list of agents is empty or if the description is empty.
@ -87,13 +54,18 @@ class ConcurrentWorkflow(BaseSwarm):
agents (List[Agent]): The list of agents to be executed concurrently.
metadata_output_path (str): The path to save the metadata output.
auto_save (bool): Flag indicating whether to automatically save the metadata.
output_schema (BaseModel): The output schema for the metadata.
output_type (str): The type of output format.
max_loops (int): The maximum number of loops for each agent.
return_str_on (bool): Flag indicating whether to return the output as a string.
agent_responses (list): The list of agent responses.
auto_generate_prompts (bool): Flag indicating whether to auto-generate prompts for agents.
retry_attempts (int): The number of retry attempts for failed agent executions.
retry_wait_time (int): The initial wait time for retries in seconds.
return_entire_history (bool): Flag indicating whether to return the entire conversation history.
interactive (bool): Flag indicating whether to enable interactive mode.
cache_size (int): The size of the cache.
max_retries (int): The maximum number of retry attempts.
retry_delay (float): The delay between retry attempts in seconds.
show_progress (bool): Flag indicating whether to show progress.
_cache (dict): The cache for storing agent outputs.
_progress_bar (tqdm): The progress bar for tracking execution.
"""
def __init__(
@ -103,13 +75,16 @@ class ConcurrentWorkflow(BaseSwarm):
agents: List[Union[Agent, Callable]] = [],
metadata_output_path: str = "agent_metadata.json",
auto_save: bool = True,
output_schema: BaseModel = MetadataSchema,
output_type: str = "dict-all-except-first",
max_loops: int = 1,
return_str_on: bool = False,
agent_responses: list = [],
auto_generate_prompts: bool = False,
output_type: OutputType = "dict",
return_entire_history: bool = False,
interactive: bool = False,
cache_size: int = 100,
max_retries: int = 3,
retry_delay: float = 1.0,
show_progress: bool = False,
*args,
**kwargs,
):
@ -125,18 +100,22 @@ class ConcurrentWorkflow(BaseSwarm):
self.agents = agents
self.metadata_output_path = metadata_output_path
self.auto_save = auto_save
self.output_schema = output_schema
self.max_loops = max_loops
self.return_str_on = return_str_on
self.agent_responses = agent_responses
self.auto_generate_prompts = auto_generate_prompts
self.max_workers = os.cpu_count()
self.output_type = output_type
self.return_entire_history = return_entire_history
self.tasks = [] # Initialize tasks list
self.interactive = interactive
self.cache_size = cache_size
self.max_retries = max_retries
self.retry_delay = retry_delay
self.show_progress = show_progress
self._cache = {}
self._progress_bar = None
self.reliability_check()
self.conversation = Conversation()
def disable_agent_prints(self):
@ -145,29 +124,47 @@ class ConcurrentWorkflow(BaseSwarm):
def reliability_check(self):
try:
logger.info("Starting reliability checks")
formatter.print_panel(
content=f"\n 🏷️ Name: {self.name}\n 📝 Description: {self.description}\n 🤖 Agents: {len(self.agents)}\n 🔄 Max Loops: {self.max_loops}\n ",
title="⚙️ Concurrent Workflow Settings",
style="bold blue",
)
formatter.print_panel(
content="🔍 Starting reliability checks",
title="🔒 Reliability Checks",
style="bold blue",
)
if self.name is None:
logger.error("A name is required for the swarm")
raise ValueError("A name is required for the swarm")
logger.error("❌ A name is required for the swarm")
raise ValueError(
"❌ A name is required for the swarm"
)
if not self.agents:
logger.error("The list of agents must not be empty.")
if not self.agents or len(self.agents) <= 1:
logger.error(
"❌ The list of agents must not be empty."
)
raise ValueError(
"The list of agents must not be empty."
"The list of agents must not be empty."
)
if not self.description:
logger.error("A description is required.")
raise ValueError("A description is required.")
logger.error("❌ A description is required.")
raise ValueError("❌ A description is required.")
formatter.print_panel(
content="✅ Reliability checks completed successfully",
title="🎉 Reliability Checks",
style="bold green",
)
logger.info("Reliability checks completed successfully")
except ValueError as e:
logger.error(f"Reliability check failed: {e}")
logger.error(f"Reliability check failed: {e}")
raise
except Exception as e:
logger.error(
f"An unexpected error occurred during reliability checks: {e}"
f"💥 An unexpected error occurred during reliability checks: {e}"
)
raise
@ -184,147 +181,179 @@ class ConcurrentWorkflow(BaseSwarm):
for agent in self.agents:
agent.auto_generate_prompt = True
# @retry(wait=wait_exponential(min=2), stop=stop_after_attempt(3))
def transform_metadata_schema_to_str(
self, schema: MetadataSchema
):
"""
Converts the metadata swarm schema into a string format with the agent name, response, and time.
Args:
schema (MetadataSchema): The metadata schema to convert.
Returns:
str: The string representation of the metadata schema.
Example:
>>> metadata_schema = MetadataSchema()
>>> metadata_str = workflow.transform_metadata_schema_to_str(metadata_schema)
>>> print(metadata_str)
"""
self.agent_responses = [
f"Agent Name: {agent.agent_name}\nResponse: {agent.output}\n\n"
for agent in schema.agents
]
# Return the agent responses as a string
return "\n".join(self.agent_responses)
def save_metadata(self):
    """
    Saves the metadata to a JSON file based on the auto_save flag.

    Example:
        >>> workflow.save_metadata()
        >>> # Metadata will be saved to the specified path if auto_save is True.
    """
    # Save metadata to a JSON file
    if self.auto_save:
        logger.info(
            f"Saving metadata to {self.metadata_output_path}"
        )
        create_file_in_folder(
            os.getenv("WORKSPACE_DIR"),
            self.metadata_output_path,
            self.output_schema.model_dump_json(indent=4),
        )

@lru_cache(maxsize=100)
def _cached_run(self, task: str, agent_id: int) -> Any:
    """Cached version of agent execution to avoid redundant computations"""
    return self.agents[agent_id].run(task=task)

def enable_progress_bar(self):
    """Enable progress bar display"""
    self.show_progress = True

def disable_progress_bar(self):
    """Disable progress bar display"""
    if self._progress_bar:
        self._progress_bar.close()
        self._progress_bar = None
    self.show_progress = False

def _create_progress_bar(self, total: int):
    """Create a progress bar for tracking execution"""
    if self.show_progress:
        try:
            self._progress_bar = tqdm(
                total=total,
                desc="Processing tasks",
                unit="task",
                disable=not self.show_progress,
                ncols=100,
                bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}]",
            )
        except Exception as e:
            logger.warning(f"Failed to create progress bar: {e}")
            self.show_progress = False
            self._progress_bar = None
    return self._progress_bar

def _update_progress(self, increment: int = 1):
    """Update the progress bar"""
    if self._progress_bar and self.show_progress:
        try:
            self._progress_bar.update(increment)
        except Exception as e:
            logger.warning(f"Failed to update progress bar: {e}")
            self.disable_progress_bar()

def _validate_input(self, task: str) -> bool:
    """Validate input task"""
    if not isinstance(task, str):
        raise ValueError("Task must be a string")
    if not task.strip():
        raise ValueError("Task cannot be empty")
    return True

def _handle_interactive(self, task: str) -> str:
    """Handle interactive mode for task input"""
    if self.interactive:
        from swarms.utils.formatter import formatter

        # Display current task in a panel
        formatter.print_panel(
            content=f"Current task: {task}",
            title="Task Status",
            style="bold blue",
        )

        # Get user input with formatted prompt
        formatter.print_panel(
            content="Do you want to modify this task? (y/n/q to quit): ",
            title="User Input",
            style="bold green",
        )
        response = input().lower()

        if response == "q":
            return None
        elif response == "y":
            formatter.print_panel(
                content="Enter new task: ",
                title="New Task Input",
                style="bold yellow",
            )
            new_task = input()
            return new_task
    return task
def _run_with_retry(
self, agent: Agent, task: str, img: str = None
) -> Any:
"""Run agent with retry mechanism"""
for attempt in range(self.max_retries):
try:
output = agent.run(task=task, img=img)
self.conversation.add(agent.agent_name, output)
return output
except Exception as e:
if attempt == self.max_retries - 1:
logger.error(
f"Error running agent {agent.agent_name} after {self.max_retries} attempts: {e}"
)
raise
logger.warning(
f"Attempt {attempt + 1} failed for agent {agent.agent_name}: {e}"
)
time.sleep(
self.retry_delay * (attempt + 1)
) # Backoff delay grows linearly with each attempt
def _run(
    self, task: str, img: str = None, *args, **kwargs
) -> Union[Dict[str, Any], str]:
    """
    Runs the workflow for the given task, executes agents concurrently using ThreadPoolExecutor, and saves metadata.

    Args:
        task (str): The task or query to give to all agents.
        img (str): The image to be processed by the agents.

    Returns:
        Dict[str, Any]: The final metadata as a dictionary.
        str: The final metadata as a string if return_str_on is True.

    Example:
        >>> metadata = workflow.run(task="Example task", img="example.jpg")
        >>> print(metadata)
    """
    logger.info(
        f"Running concurrent workflow with {len(self.agents)} agents."
    )
    self.conversation.add(
        "User",
        task,
    )

    def run_agent(
        agent: Agent, task: str, img: str = None
    ) -> AgentOutputSchema:
        start_time = datetime.now()
        try:
            output = agent.run(task=task)
            self.conversation.add(
                agent.agent_name,
                output,
            )
        except Exception as e:
            logger.error(
                f"Error running agent {agent.agent_name}: {e}"
            )
            raise

        end_time = datetime.now()
        duration = (end_time - start_time).total_seconds()
        agent_output = AgentOutputSchema(
            run_id=uuid.uuid4().hex,
            agent_name=agent.agent_name,
            task=task,
            output=output,
            start_time=start_time,
            end_time=end_time,
            duration=duration,
        )
        logger.info(
            f"Agent {agent.agent_name} completed task: {task} in {duration:.2f} seconds."
        )
        return agent_output

    with ThreadPoolExecutor(
        max_workers=os.cpu_count()
    ) as executor:
        agent_outputs = list(
            executor.map(
                lambda agent: run_agent(agent, task), self.agents
            )
        )

    self.output_schema = MetadataSchema(
        swarm_id=uuid.uuid4().hex,
        task=task,
        description=self.description,
        agents=agent_outputs,
    )

    self.save_metadata()

    if self.return_str_on:
        return self.transform_metadata_schema_to_str(
            self.output_schema
        )
    elif self.return_entire_history:
        return self.conversation.return_history_as_string()
    elif self.output_type == "list":
        return self.conversation.return_messages_as_list()
    elif self.output_type == "dict":
        return self.conversation.return_messages_as_dictionary()
    else:
        return self.output_schema.model_dump_json(indent=4)

def _run(
    self, task: str, img: str = None, *args, **kwargs
) -> Union[Dict[str, Any], str]:
    """
    Enhanced run method with caching, progress tracking, and better error handling
    """
    # Validate and potentially modify task
    self._validate_input(task)
    task = self._handle_interactive(task)

    # Add task to conversation
    self.conversation.add("User", task)

    # Create progress bar if enabled
    if self.show_progress:
        self._create_progress_bar(len(self.agents))

    def run_agent(
        agent: Agent, task: str, img: str = None
    ) -> Any:
        try:
            # Check cache first
            cache_key = f"{task}_{agent.agent_name}"
            if cache_key in self._cache:
                output = self._cache[cache_key]
            else:
                output = self._run_with_retry(agent, task, img)

                # Update cache
                if len(self._cache) >= self.cache_size:
                    self._cache.pop(next(iter(self._cache)))
                self._cache[cache_key] = output

            self._update_progress()
            return output
        except Exception as e:
            logger.error(
                f"Error running agent {agent.agent_name}: {e}"
            )
            self._update_progress()
            raise

    try:
        with ThreadPoolExecutor(
            max_workers=self.max_workers
        ) as executor:
            list(
                executor.map(
                    lambda agent: run_agent(agent, task),
                    self.agents,
                )
            )
    finally:
        if self._progress_bar and self.show_progress:
            try:
                self._progress_bar.close()
            except Exception as e:
                logger.warning(
                    f"Failed to close progress bar: {e}"
                )
            finally:
                self._progress_bar = None

    return history_output_formatter(
        self.conversation,
        type=self.output_type,
    )
def run(
self,
task: Optional[str] = None,
@ -333,9 +362,11 @@ class ConcurrentWorkflow(BaseSwarm):
**kwargs,
) -> Any:
"""
Executes the agent's run method on a specified device.
Executes the agent's run method on a specified device with optional interactive mode.
This method attempts to execute the agent's run method on a specified device, either CPU or GPU. It logs the device selection and the number of cores or GPU ID used. If the device is set to CPU, it can use all available cores or a specific core specified by `device_id`. If the device is set to GPU, it uses the GPU specified by `device_id`.
This method attempts to execute the agent's run method on a specified device, either CPU or GPU.
It supports both standard execution and interactive mode where users can modify tasks and continue
the workflow interactively.
Args:
task (Optional[str], optional): The task to be executed. Defaults to None.
@ -359,8 +390,73 @@ class ConcurrentWorkflow(BaseSwarm):
self.tasks.append(task)
try:
outputs = self._run(task, img, *args, **kwargs)
return outputs
# Handle interactive mode
if self.interactive:
current_task = task
loop_count = 0
while loop_count < self.max_loops:
if (
self.max_loops is not None
and loop_count >= self.max_loops
):
formatter.print_panel(
content=f"Maximum number of loops ({self.max_loops}) reached.",
title="Session Complete",
style="bold red",
)
break
if current_task is None:
formatter.print_panel(
content="Enter your task (or 'q' to quit): ",
title="Task Input",
style="bold blue",
)
current_task = input()
if current_task.lower() == "q":
break
# Run the workflow with the current task
try:
outputs = self._run(
current_task, img, *args, **kwargs
)
formatter.print_panel(
content=str(outputs),
title="Workflow Result",
style="bold green",
)
except Exception as e:
formatter.print_panel(
content=f"Error: {str(e)}",
title="Error",
style="bold red",
)
# Ask if user wants to continue
formatter.print_panel(
content="Do you want to continue with a new task? (y/n): ",
title="Continue Session",
style="bold yellow",
)
if input().lower() != "y":
break
current_task = None
loop_count += 1
formatter.print_panel(
content="Interactive session ended.",
title="Session Complete",
style="bold blue",
)
return outputs
else:
# Standard non-interactive execution
outputs = self._run(task, img, *args, **kwargs)
return outputs
except ValueError as e:
logger.error(f"Invalid device specified: {e}")
raise e
@ -368,29 +464,48 @@ class ConcurrentWorkflow(BaseSwarm):
logger.error(f"An error occurred during execution: {e}")
raise e
def run_batched(
    self, tasks: List[str]
) -> List[Union[Dict[str, Any], str]]:
    """
    Runs the workflow for a batch of tasks, executes agents concurrently for each task, and saves metadata in a production-grade manner.

    Args:
        tasks (List[str]): A list of tasks or queries to give to all agents.

    Returns:
        List[Union[Dict[str, Any], str]]: A list of final metadata for each task, either as a dictionary or a string.

    Example:
        >>> tasks = ["Task 1", "Task 2"]
        >>> results = workflow.run_batched(tasks)
        >>> print(results)
    """
    results = []
    for task in tasks:
        result = self.run(task)
        results.append(result)
    return results

def run_batched(self, tasks: List[str]) -> Any:
    """
    Enhanced batched execution with progress tracking
    """
    if not tasks:
        raise ValueError("Tasks list cannot be empty")

    results = []

    # Create progress bar if enabled
    if self.show_progress:
        self._create_progress_bar(len(tasks))

    try:
        for task in tasks:
            result = self.run(task)
            results.append(result)
            self._update_progress()
    finally:
        if self._progress_bar and self.show_progress:
            try:
                self._progress_bar.close()
            except Exception as e:
                logger.warning(
                    f"Failed to close progress bar: {e}"
                )
            finally:
                self._progress_bar = None

    return results
def clear_cache(self):
"""Clear the task cache"""
self._cache.clear()
def get_cache_stats(self) -> Dict[str, int]:
"""Get cache statistics"""
return {
"cache_size": len(self._cache),
"max_cache_size": self.cache_size,
}
# if __name__ == "__main__":
# # Assuming you've already initialized some agents outside of this class

@ -1168,7 +1168,6 @@ class ModelGrid:
try:
# This would need to be implemented based on the specific model types
# and tasks supported. Here's a simple placeholder:
model = model_metadata.model
if model_metadata.model_type == ModelType.PYTORCH:
# Run PyTorch model

@ -1,6 +1,4 @@
import os
import json
from datetime import datetime
from pathlib import Path
import tempfile
import threading
@ -10,6 +8,7 @@ from swarms.communication.duckdb_wrap import (
MessageType,
)
def setup_test():
"""Set up test environment."""
temp_dir = tempfile.TemporaryDirectory()
@ -21,25 +20,34 @@ def setup_test():
)
return temp_dir, db_path, conversation
def cleanup_test(temp_dir, db_path):
"""Clean up test environment."""
if os.path.exists(db_path):
os.remove(db_path)
temp_dir.cleanup()
def test_initialization():
"""Test conversation initialization."""
temp_dir, db_path, _ = setup_test()
try:
conv = DuckDBConversation(db_path=str(db_path))
assert conv.db_path == db_path, "Database path mismatch"
assert conv.table_name == "conversations", "Table name mismatch"
assert conv.enable_timestamps is True, "Timestamps should be enabled"
assert conv.current_conversation_id is not None, "Conversation ID should not be None"
assert (
conv.table_name == "conversations"
), "Table name mismatch"
assert (
conv.enable_timestamps is True
), "Timestamps should be enabled"
assert (
conv.current_conversation_id is not None
), "Conversation ID should not be None"
print("✓ Initialization test passed")
finally:
cleanup_test(temp_dir, db_path)
def test_add_message():
"""Test adding a single message."""
temp_dir, db_path, conversation = setup_test()
@ -50,11 +58,14 @@ def test_add_message():
message_type=MessageType.USER,
)
assert msg_id is not None, "Message ID should not be None"
assert isinstance(msg_id, int), "Message ID should be an integer"
assert isinstance(
msg_id, int
), "Message ID should be an integer"
print("✓ Add message test passed")
finally:
cleanup_test(temp_dir, db_path)
def test_add_complex_message():
"""Test adding a message with complex content."""
temp_dir, db_path, conversation = setup_test()
@ -62,20 +73,21 @@ def test_add_complex_message():
complex_content = {
"text": "Hello",
"data": [1, 2, 3],
"nested": {"key": "value"}
"nested": {"key": "value"},
}
msg_id = conversation.add(
role="assistant",
content=complex_content,
message_type=MessageType.ASSISTANT,
metadata={"source": "test"},
token_count=10
token_count=10,
)
assert msg_id is not None, "Message ID should not be None"
print("✓ Add complex message test passed")
finally:
cleanup_test(temp_dir, db_path)
def test_batch_add():
"""Test batch adding messages."""
temp_dir, db_path, conversation = setup_test()
@ -84,21 +96,24 @@ def test_batch_add():
Message(
role="user",
content="First message",
message_type=MessageType.USER
message_type=MessageType.USER,
),
Message(
role="assistant",
content="Second message",
message_type=MessageType.ASSISTANT
)
message_type=MessageType.ASSISTANT,
),
]
msg_ids = conversation.batch_add(messages)
assert len(msg_ids) == 2, "Should have 2 message IDs"
assert all(isinstance(id, int) for id in msg_ids), "All IDs should be integers"
assert all(
isinstance(id, int) for id in msg_ids
), "All IDs should be integers"
print("✓ Batch add test passed")
finally:
cleanup_test(temp_dir, db_path)
def test_get_str():
"""Test getting conversation as string."""
temp_dir, db_path, conversation = setup_test()
@ -107,11 +122,14 @@ def test_get_str():
conversation.add("assistant", "Hi there!")
conv_str = conversation.get_str()
assert "user: Hello" in conv_str, "User message not found"
assert "assistant: Hi there!" in conv_str, "Assistant message not found"
assert (
"assistant: Hi there!" in conv_str
), "Assistant message not found"
print("✓ Get string test passed")
finally:
cleanup_test(temp_dir, db_path)
def test_get_messages():
"""Test getting messages with pagination."""
temp_dir, db_path, conversation = setup_test()
@ -123,14 +141,19 @@ def test_get_messages():
assert len(all_messages) == 5, "Should have 5 messages"
limited_messages = conversation.get_messages(limit=2)
assert len(limited_messages) == 2, "Should have 2 limited messages"
assert (
len(limited_messages) == 2
), "Should have 2 limited messages"
offset_messages = conversation.get_messages(offset=2)
assert len(offset_messages) == 3, "Should have 3 offset messages"
assert (
len(offset_messages) == 3
), "Should have 3 offset messages"
print("✓ Get messages test passed")
finally:
cleanup_test(temp_dir, db_path)
def test_search_messages():
"""Test searching messages."""
temp_dir, db_path, conversation = setup_test()
@ -140,12 +163,17 @@ def test_search_messages():
conversation.add("user", "Goodbye world")
results = conversation.search_messages("world")
assert len(results) == 2, "Should find 2 messages with 'world'"
assert all("world" in msg["content"] for msg in results), "All results should contain 'world'"
assert (
len(results) == 2
), "Should find 2 messages with 'world'"
assert all(
"world" in msg["content"] for msg in results
), "All results should contain 'world'"
print("✓ Search messages test passed")
finally:
cleanup_test(temp_dir, db_path)
def test_get_statistics():
"""Test getting conversation statistics."""
temp_dir, db_path, conversation = setup_test()
@ -154,13 +182,20 @@ def test_get_statistics():
conversation.add("assistant", "Hi", token_count=1)
stats = conversation.get_statistics()
assert stats["total_messages"] == 2, "Should have 2 total messages"
assert stats["unique_roles"] == 2, "Should have 2 unique roles"
assert stats["total_tokens"] == 3, "Should have 3 total tokens"
assert (
stats["total_messages"] == 2
), "Should have 2 total messages"
assert (
stats["unique_roles"] == 2
), "Should have 2 unique roles"
assert (
stats["total_tokens"] == 3
), "Should have 3 total tokens"
print("✓ Get statistics test passed")
finally:
cleanup_test(temp_dir, db_path)
def test_json_operations():
"""Test JSON save and load operations."""
temp_dir, db_path, conversation = setup_test()
@ -175,12 +210,17 @@ def test_json_operations():
new_conversation = DuckDBConversation(
db_path=str(Path(temp_dir.name) / "new.duckdb")
)
assert new_conversation.load_from_json(str(json_path)), "Should load from JSON"
assert len(new_conversation.get_messages()) == 2, "Should have 2 messages after load"
assert new_conversation.load_from_json(
str(json_path)
), "Should load from JSON"
assert (
len(new_conversation.get_messages()) == 2
), "Should have 2 messages after load"
print("✓ JSON operations test passed")
finally:
cleanup_test(temp_dir, db_path)
def test_yaml_operations():
"""Test YAML save and load operations."""
temp_dir, db_path, conversation = setup_test()
@ -195,29 +235,53 @@ def test_yaml_operations():
new_conversation = DuckDBConversation(
db_path=str(Path(temp_dir.name) / "new.duckdb")
)
assert new_conversation.load_from_yaml(str(yaml_path)), "Should load from YAML"
assert len(new_conversation.get_messages()) == 2, "Should have 2 messages after load"
assert new_conversation.load_from_yaml(
str(yaml_path)
), "Should load from YAML"
assert (
len(new_conversation.get_messages()) == 2
), "Should have 2 messages after load"
print("✓ YAML operations test passed")
finally:
cleanup_test(temp_dir, db_path)
def test_message_types():
"""Test different message types."""
temp_dir, db_path, conversation = setup_test()
try:
conversation.add("system", "System message", message_type=MessageType.SYSTEM)
conversation.add("user", "User message", message_type=MessageType.USER)
conversation.add("assistant", "Assistant message", message_type=MessageType.ASSISTANT)
conversation.add("function", "Function message", message_type=MessageType.FUNCTION)
conversation.add("tool", "Tool message", message_type=MessageType.TOOL)
conversation.add(
"system",
"System message",
message_type=MessageType.SYSTEM,
)
conversation.add(
"user", "User message", message_type=MessageType.USER
)
conversation.add(
"assistant",
"Assistant message",
message_type=MessageType.ASSISTANT,
)
conversation.add(
"function",
"Function message",
message_type=MessageType.FUNCTION,
)
conversation.add(
"tool", "Tool message", message_type=MessageType.TOOL
)
messages = conversation.get_messages()
assert len(messages) == 5, "Should have 5 messages"
assert all("message_type" in msg for msg in messages), "All messages should have type"
assert all(
"message_type" in msg for msg in messages
), "All messages should have type"
print("✓ Message types test passed")
finally:
cleanup_test(temp_dir, db_path)
def test_delete_operations():
"""Test deletion operations."""
temp_dir, db_path, conversation = setup_test()
@ -225,58 +289,79 @@ def test_delete_operations():
conversation.add("user", "Hello")
conversation.add("assistant", "Hi")
assert conversation.delete_current_conversation(), "Should delete conversation"
assert len(conversation.get_messages()) == 0, "Should have no messages after delete"
assert (
conversation.delete_current_conversation()
), "Should delete conversation"
assert (
len(conversation.get_messages()) == 0
), "Should have no messages after delete"
conversation.add("user", "New message")
assert conversation.clear_all(), "Should clear all messages"
assert len(conversation.get_messages()) == 0, "Should have no messages after clear"
assert (
len(conversation.get_messages()) == 0
), "Should have no messages after clear"
print("✓ Delete operations test passed")
finally:
cleanup_test(temp_dir, db_path)
def test_concurrent_operations():
"""Test concurrent operations."""
temp_dir, db_path, conversation = setup_test()
try:
def add_messages():
for i in range(10):
conversation.add("user", f"Message {i}")
threads = [threading.Thread(target=add_messages) for _ in range(5)]
threads = [
threading.Thread(target=add_messages) for _ in range(5)
]
for thread in threads:
thread.start()
for thread in threads:
thread.join()
messages = conversation.get_messages()
assert len(messages) == 50, "Should have 50 messages (10 * 5 threads)"
assert (
len(messages) == 50
), "Should have 50 messages (10 * 5 threads)"
print("✓ Concurrent operations test passed")
finally:
cleanup_test(temp_dir, db_path)
def test_error_handling():
"""Test error handling."""
temp_dir, db_path, conversation = setup_test()
try:
# Test invalid message type
try:
conversation.add("user", "Message", message_type="invalid")
assert False, "Should raise exception for invalid message type"
conversation.add(
"user", "Message", message_type="invalid"
)
assert (
False
), "Should raise exception for invalid message type"
except Exception:
pass
# Test invalid JSON content
try:
conversation.add("user", {"invalid": object()})
assert False, "Should raise exception for invalid JSON content"
assert (
False
), "Should raise exception for invalid JSON content"
except Exception:
pass
# Test invalid file operations
try:
conversation.load_from_json("/nonexistent/path.json")
assert False, "Should raise exception for invalid file path"
assert (
False
), "Should raise exception for invalid file path"
except Exception:
pass
@ -284,6 +369,7 @@ def test_error_handling():
finally:
cleanup_test(temp_dir, db_path)
def run_all_tests():
"""Run all tests."""
print("Running DuckDB Conversation tests...")
@ -301,7 +387,7 @@ def run_all_tests():
test_message_types,
test_delete_operations,
test_concurrent_operations,
test_error_handling
test_error_handling,
]
for test in tests:
@ -313,5 +399,6 @@ def run_all_tests():
print("\nAll tests completed successfully!")
if __name__ == '__main__':
if __name__ == "__main__":
run_all_tests()

@ -98,8 +98,8 @@ def test_basic_conversation() -> bool:
# Test adding messages
console.print("\n[bold]Adding messages...[/bold]")
msg_id1 = conversation.add("user", "Hello")
msg_id2 = conversation.add("assistant", "Hi there!")
conversation.add("user", "Hello")
conversation.add("assistant", "Hi there!")
# Test getting messages
console.print("\n[bold]Retrieved messages:[/bold]")

@ -156,7 +156,7 @@ def test_stopping_token(mocked_sleep, basic_flow):
# Test interactive mode
def test_interactive_mode(basic_flow):
def test_interactive(basic_flow):
basic_flow.interactive = True
assert basic_flow.interactive
@ -309,7 +309,7 @@ def test_flow_run(flow_instance):
assert len(response) > 0
def test_flow_interactive_mode(flow_instance):
def test_flow_interactive(flow_instance):
# Test the interactive mode of the Agent class
flow_instance.interactive = True
response = flow_instance.run("Test task")
