[FEAT][Agent][save][load] [FIX][openai assistants]

pull/692/head
Kye Gomez 3 weeks ago
parent 9b1106ac91
commit 9abe300548

@ -0,0 +1,348 @@
import json
import os
import platform
import sys
import traceback
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple
import psutil
import requests
from loguru import logger
from swarm_models import OpenAIChat
from swarms.structs.agent import Agent
@dataclass
class SwarmSystemInfo:
"""System information for Swarms issue reports."""
os_name: str
os_version: str
python_version: str
cpu_usage: float
memory_usage: float
disk_usage: float
swarms_version: str # Added Swarms version tracking
cuda_available: bool # Added CUDA availability check
gpu_info: Optional[str] # Added GPU information
class SwarmsIssueReporter:
"""
Production-grade GitHub issue reporter specifically designed for the Swarms library.
Automatically creates detailed issues for the https://github.com/kyegomez/swarms repository.
Features:
- Swarms-specific error categorization
- Automatic version and dependency tracking
- CUDA and GPU information collection
- Integration with Swarms logging system
- Detailed environment information
"""
REPO_OWNER = "kyegomez"
REPO_NAME = "swarms"
ISSUE_CATEGORIES = {
"agent": ["agent", "automation"],
"memory": ["memory", "storage"],
"tool": ["tools", "integration"],
"llm": ["llm", "model"],
"performance": ["performance", "optimization"],
"compatibility": ["compatibility", "environment"],
}
def __init__(
self,
github_token: str,
rate_limit: int = 10,
rate_period: int = 3600,
log_file: str = "swarms_issues.log",
enable_duplicate_check: bool = True,
):
"""
Initialize the Swarms Issue Reporter.
Args:
github_token (str): GitHub personal access token
rate_limit (int): Maximum number of issues to create per rate_period
rate_period (int): Time period for rate limiting in seconds
log_file (str): Path to log file
enable_duplicate_check (bool): Whether to check for duplicate issues
"""
self.github_token = github_token
self.rate_limit = rate_limit
self.rate_period = rate_period
self.enable_duplicate_check = enable_duplicate_check
self.github_token = os.getenv("GITHUB_API_KEY")
# Initialize logging
log_path = os.path.join(os.getcwd(), "logs", log_file)
os.makedirs(os.path.dirname(log_path), exist_ok=True)
logger.add(
log_path,
rotation="1 day",
retention="1 month",
compression="zip",
)
# Issue tracking
self.issues_created = []
self.last_issue_time = datetime.now()
# Validate GitHub token
# self._validate_github_credentials()
def _get_swarms_version(self) -> str:
"""Get the installed version of Swarms."""
try:
import swarms
return swarms.__version__
except:
return "Unknown"
def _get_gpu_info(self) -> Tuple[bool, Optional[str]]:
"""Get GPU information and CUDA availability."""
try:
import torch
cuda_available = torch.cuda.is_available()
if cuda_available:
gpu_info = torch.cuda.get_device_name(0)
return cuda_available, gpu_info
return False, None
except:
return False, None
def _get_system_info(self) -> SwarmSystemInfo:
"""Collect system and Swarms-specific information."""
cuda_available, gpu_info = self._get_gpu_info()
return SwarmSystemInfo(
os_name=platform.system(),
os_version=platform.version(),
python_version=sys.version,
cpu_usage=psutil.cpu_percent(),
memory_usage=psutil.virtual_memory().percent,
disk_usage=psutil.disk_usage("/").percent,
swarms_version=self._get_swarms_version(),
cuda_available=cuda_available,
gpu_info=gpu_info,
)
def _categorize_error(
self, error: Exception, context: Dict
) -> List[str]:
"""Categorize the error and return appropriate labels."""
error_str = str(error).lower()
type(error).__name__
labels = ["bug", "automated"]
# Check error message and context for category keywords
for (
category,
category_labels,
) in self.ISSUE_CATEGORIES.items():
if any(
keyword in error_str for keyword in category_labels
):
labels.extend(category_labels)
break
# Add severity label based on error type
if issubclass(type(error), (SystemError, MemoryError)):
labels.append("severity:critical")
elif issubclass(type(error), (ValueError, TypeError)):
labels.append("severity:medium")
else:
labels.append("severity:low")
return list(set(labels)) # Remove duplicates
def _format_swarms_issue_body(
self,
error: Exception,
system_info: SwarmSystemInfo,
context: Dict,
) -> str:
"""Format the issue body with Swarms-specific information."""
return f"""
## Swarms Error Report
- **Error Type**: {type(error).__name__}
- **Error Message**: {str(error)}
- **Swarms Version**: {system_info.swarms_version}
## Environment Information
- **OS**: {system_info.os_name} {system_info.os_version}
- **Python Version**: {system_info.python_version}
- **CUDA Available**: {system_info.cuda_available}
- **GPU**: {system_info.gpu_info or "N/A"}
- **CPU Usage**: {system_info.cpu_usage}%
- **Memory Usage**: {system_info.memory_usage}%
- **Disk Usage**: {system_info.disk_usage}%
## Stack Trace
{traceback.format_exc()}
## Context
{json.dumps(context, indent=2)}
## Dependencies
{self._get_dependencies_info()}
## Time of Occurrence
{datetime.now().isoformat()}
---
*This issue was automatically generated by SwarmsIssueReporter*
"""
def _get_dependencies_info(self) -> str:
"""Get information about installed dependencies."""
try:
import pkg_resources
deps = []
for dist in pkg_resources.working_set:
deps.append(f"- {dist.key} {dist.version}")
return "\n".join(deps)
except:
return "Unable to fetch dependency information"
# First, add this method to your SwarmsIssueReporter class
def _check_rate_limit(self) -> bool:
"""Check if we're within rate limits."""
now = datetime.now()
time_diff = (now - self.last_issue_time).total_seconds()
if (
len(self.issues_created) >= self.rate_limit
and time_diff < self.rate_period
):
logger.warning("Rate limit exceeded for issue creation")
return False
# Clean up old issues from tracking
self.issues_created = [
time
for time in self.issues_created
if (now - time).total_seconds() < self.rate_period
]
return True
def report_swarms_issue(
self,
error: Exception,
agent: Optional[Agent] = None,
context: Dict[str, Any] = None,
priority: str = "normal",
) -> Optional[int]:
"""
Report a Swarms-specific issue to GitHub.
Args:
error (Exception): The exception to report
agent (Optional[Agent]): The Swarms agent instance that encountered the error
context (Dict[str, Any]): Additional context about the error
priority (str): Issue priority ("low", "normal", "high", "critical")
Returns:
Optional[int]: Issue number if created successfully
"""
try:
if not self._check_rate_limit():
logger.warning(
"Skipping issue creation due to rate limit"
)
return None
# Collect system information
system_info = self._get_system_info()
# Prepare context with agent information if available
full_context = context or {}
if agent:
full_context.update(
{
"agent_name": agent.agent_name,
"agent_description": agent.agent_description,
"max_loops": agent.max_loops,
"context_length": agent.context_length,
}
)
# Create issue title
title = f"[{type(error).__name__}] {str(error)[:100]}"
if agent:
title = f"[Agent: {agent.agent_name}] {title}"
# Get appropriate labels
labels = self._categorize_error(error, full_context)
labels.append(f"priority:{priority}")
# Create the issue
url = f"https://api.github.com/repos/{self.REPO_OWNER}/{self.REPO_NAME}/issues"
data = {
"title": title,
"body": self._format_swarms_issue_body(
error, system_info, full_context
),
"labels": labels,
}
response = requests.post(
url,
headers={
"Authorization": f"token {self.github_token}"
},
json=data,
)
response.raise_for_status()
issue_number = response.json()["number"]
logger.info(
f"Successfully created Swarms issue #{issue_number}"
)
return issue_number
except Exception as e:
logger.error(f"Error creating Swarms issue: {str(e)}")
return None
# from swarms import Agent
# from swarm_models import OpenAIChat
# from swarms.utils.issue_reporter import SwarmsIssueReporter
# import os
# Setup the reporter with your GitHub token
reporter = SwarmsIssueReporter(
github_token=os.getenv("GITHUB_API_KEY")
)
# Force an error to test the reporter
try:
# This will raise an error since the input isn't valid
# Create an agent that might have issues
model = OpenAIChat(model_name="gpt-4o")
agent = Agent(agent_name="Test-Agent", max_loops=1)
result = agent.run(None)
raise ValueError("test")
except Exception as e:
# Report the issue
issue_number = reporter.report_swarms_issue(
error=e,
agent=agent,
context={"task": "test_run"},
priority="high",
)
print(f"Created issue number: {issue_number}")

@ -192,15 +192,6 @@ nav:
- Conversation: "swarms/structs/conversation.md"
# - Task: "swarms/structs/task.md"
- Full API Reference: "swarms/framework/reference.md"
- Contributing:
- Contributing: "swarms/contributing.md"
- Tests: "swarms/framework/test.md"
- Code Cleanliness: "swarms/framework/code_cleanliness.md"
- Philosophy: "swarms/concept/philosophy.md"
- Changelog:
- Swarms 5.6.8: "swarms/changelog/5_6_8.md"
- Swarms 5.8.1: "swarms/changelog/5_8_1.md"
- Swarms 5.9.2: "swarms/changelog/changelog_new.md"
- Swarm Models:
- Overview: "swarms/models/index.md"
# - Models Available: "swarms/models/index.md"
@ -254,16 +245,25 @@ nav:
# - Tools API:
# - Overview: "swarms_platform/tools_api.md"
# - Add Tools: "swarms_platform/fetch_tools.md"
- Guides:
- Unlocking Efficiency and Cost Savings in Healthcare; How Swarms of LLM Agents Can Revolutionize Medical Operations and Save Millions: "guides/healthcare_blog.md"
- Understanding Agent Evaluation Mechanisms: "guides/agent_evals.md"
- Agent Glossary: "swarms/glossary.md"
- The Ultimate Technical Guide to the Swarms CLI; A Step-by-Step Developers Guide: "swarms/cli/cli_guide.md"
- Prompting Guide:
- The Essence of Enterprise-Grade Prompting: "swarms/prompts/essence.md"
- An Analysis on Prompting Strategies: "swarms/prompts/overview.md"
- Managing Prompts in Production: "swarms/prompts/main.md"
# - Guides:
# - Unlocking Efficiency and Cost Savings in Healthcare; How Swarms of LLM Agents Can Revolutionize Medical Operations and Save Millions: "guides/healthcare_blog.md"
# - Understanding Agent Evaluation Mechanisms: "guides/agent_evals.md"
# - Agent Glossary: "swarms/glossary.md"
# - The Ultimate Technical Guide to the Swarms CLI; A Step-by-Step Developers Guide: "swarms/cli/cli_guide.md"
# - Prompting Guide:
# - The Essence of Enterprise-Grade Prompting: "swarms/prompts/essence.md"
# - An Analysis on Prompting Strategies: "swarms/prompts/overview.md"
# - Managing Prompts in Production: "swarms/prompts/main.md"
- Community:
- Contributing:
- Contributing: "swarms/contributing.md"
- Tests: "swarms/framework/test.md"
- Code Cleanliness: "swarms/framework/code_cleanliness.md"
- Philosophy: "swarms/concept/philosophy.md"
- Changelog:
- Swarms 5.6.8: "swarms/changelog/5_6_8.md"
- Swarms 5.8.1: "swarms/changelog/5_8_1.md"
- Swarms 5.9.2: "swarms/changelog/changelog_new.md"
- Bounty Program: "corporate/bounty_program.md"
- Corporate:
- Culture: "corporate/culture.md"

@ -2,6 +2,10 @@ from swarms import Agent
from swarms.prompts.finance_agent_sys_prompt import (
FINANCIAL_AGENT_SYS_PROMPT,
)
from swarm_models import OpenAIChat
model = OpenAIChat(model_name="gpt-4o")
# Initialize the agent
agent = Agent(
@ -9,20 +13,21 @@ agent = Agent(
agent_description="Personal finance advisor agent",
system_prompt=FINANCIAL_AGENT_SYS_PROMPT
+ "Output the <DONE> token when you're done creating a portfolio of etfs, index, funds, and more for AI",
model_name="gpt-4o", # Use any model from litellm
max_loops="auto",
max_loops=1,
llm=model,
dynamic_temperature_enabled=True,
user_name="Kye",
retry_attempts=3,
streaming_on=True,
context_length=16000,
# streaming_on=True,
context_length=8192,
return_step_meta=False,
output_type="str", # "json", "dict", "csv" OR "string" "yaml" and
auto_generate_prompt=False, # Auto generate prompt for the agent based on name, description, and system prompt, task
max_tokens=16000, # max output tokens
interactive=True,
max_tokens=4000, # max output tokens
# interactive=True,
stopping_token="<DONE>",
execute_tool=True,
saved_state_path="agent_00.json",
interactive=False,
)
agent.run(

@ -0,0 +1,189 @@
import requests
import datetime
from typing import List, Dict, Tuple
from loguru import logger
from swarms import Agent
from swarm_models import OpenAIChat
# GitHub API Configurations
GITHUB_REPO = "kyegomez/swarms" # Swarms GitHub repository
GITHUB_API_URL = f"https://api.github.com/repos/{GITHUB_REPO}/commits"
# Initialize Loguru
logger.add(
"commit_summary.log",
rotation="1 MB",
level="INFO",
backtrace=True,
diagnose=True,
)
# Step 1: Fetch the latest commits from GitHub
def fetch_latest_commits(
repo_url: str, limit: int = 5
) -> List[Dict[str, str]]:
"""
Fetch the latest commits from a public GitHub repository.
"""
logger.info(
f"Fetching the latest {limit} commits from {repo_url}"
)
try:
params = {"per_page": limit}
response = requests.get(repo_url, params=params)
response.raise_for_status()
commits = response.json()
commit_data = []
for commit in commits:
commit_data.append(
{
"sha": commit["sha"][:7], # Short commit hash
"author": commit["commit"]["author"]["name"],
"message": commit["commit"]["message"],
"date": commit["commit"]["author"]["date"],
}
)
logger.success("Successfully fetched commit data")
return commit_data
except Exception as e:
logger.error(f"Error fetching commits: {e}")
raise
# Step 2: Format commits and fetch current time
def format_commits_with_time(
commits: List[Dict[str, str]]
) -> Tuple[str, str]:
"""
Format commit data into a readable string and return current time.
"""
current_time = datetime.datetime.now().strftime(
"%Y-%m-%d %H:%M:%S"
)
logger.info(f"Formatting commits at {current_time}")
commit_summary = "\n".join(
[
f"- `{commit['sha']}` by {commit['author']} on {commit['date']}: {commit['message']}"
for commit in commits
]
)
logger.success("Commits formatted successfully")
return current_time, commit_summary
# Step 3: Build a dynamic system prompt
def build_custom_system_prompt(
current_time: str, commit_summary: str
) -> str:
"""
Build a dynamic system prompt with the current time and commit summary.
"""
logger.info("Building the custom system prompt for the agent")
prompt = f"""
You are a software analyst tasked with summarizing the latest commits from the Swarms GitHub repository.
The current time is **{current_time}**.
Here are the latest commits:
{commit_summary}
**Your task**:
1. Summarize the changes into a clear and concise table in **markdown format**.
2. Highlight the key improvements and fixes.
3. End your output with the token `<DONE>`.
Make sure the table includes the following columns: Commit SHA, Author, Date, and Commit Message.
"""
logger.success("System prompt created successfully")
return prompt
# Step 4: Initialize the Agent
def initialize_agent() -> Agent:
"""
Initialize the Swarms agent with OpenAI model.
"""
logger.info("Initializing the agent with GPT-4o")
model = OpenAIChat(model_name="gpt-4o")
agent = Agent(
agent_name="Commit-Summarization-Agent",
agent_description="Fetch and summarize GitHub commits for Swarms repository.",
system_prompt="", # Will set dynamically
max_loops=1,
llm=model,
dynamic_temperature_enabled=True,
user_name="Kye",
retry_attempts=3,
context_length=8192,
return_step_meta=False,
output_type="str",
auto_generate_prompt=False,
max_tokens=4000,
stopping_token="<DONE>",
interactive=False,
)
logger.success("Agent initialized successfully")
return agent
# Step 5: Run the Agent with Data
def summarize_commits_with_agent(agent: Agent, prompt: str) -> str:
"""
Pass the system prompt to the agent and fetch the result.
"""
logger.info("Sending data to the agent for summarization")
try:
result = agent.run(
f"{prompt}",
all_cores=True,
)
logger.success("Agent completed the summarization task")
return result
except Exception as e:
logger.error(f"Agent encountered an error: {e}")
raise
# Main Execution
if __name__ == "__main__":
try:
logger.info("Starting commit summarization process")
# Fetch latest commits
latest_commits = fetch_latest_commits(GITHUB_API_URL, limit=5)
# Format commits and get current time
current_time, commit_summary = format_commits_with_time(
latest_commits
)
# Build the custom system prompt
custom_system_prompt = build_custom_system_prompt(
current_time, commit_summary
)
# Initialize agent
agent = initialize_agent()
# Set the dynamic system prompt
agent.system_prompt = custom_system_prompt
# Run the agent and summarize commits
result = summarize_commits_with_agent(
agent, custom_system_prompt
)
# Print the result
print("### Commit Summary in Markdown:")
print(result)
except Exception as e:
logger.critical(f"Process failed: {e}")

@ -0,0 +1,44 @@
from swarms import Agent
from swarms.prompts.finance_agent_sys_prompt import (
FINANCIAL_AGENT_SYS_PROMPT,
)
from swarm_models import OpenAIChat
model = OpenAIChat(model_name="gpt-4o")
# Initialize the agent
agent = Agent(
agent_name="Financial-Analysis-Agent",
agent_description="Personal finance advisor agent",
system_prompt=FINANCIAL_AGENT_SYS_PROMPT
+ "Output the <DONE> token when you're done creating a portfolio of etfs, index, funds, and more for AI",
max_loops=1,
llm=model,
dynamic_temperature_enabled=True,
user_name="Kye",
retry_attempts=3,
# streaming_on=True,
context_length=8192,
return_step_meta=False,
output_type="str", # "json", "dict", "csv" OR "string" "yaml" and
auto_generate_prompt=False, # Auto generate prompt for the agent based on name, description, and system prompt, task
max_tokens=4000, # max output tokens
# interactive=True,
stopping_token="<DONE>",
saved_state_path="agent_00.json",
interactive=False,
)
async def run_agent():
await agent.arun(
"Create a table of super high growth opportunities for AI. I have $40k to invest in ETFs, index funds, and more. Please create a table in markdown.",
all_cores=True,
)
if __name__ == "__main__":
import asyncio
asyncio.run(run_agent())

@ -0,0 +1,14 @@
from swarms.prompts.finance_agent_sys_prompt import (
FINANCIAL_AGENT_SYS_PROMPT,
)
from swarms.agents.openai_assistant import OpenAIAssistant
agent = OpenAIAssistant(
name="test", instructions=FINANCIAL_AGENT_SYS_PROMPT
)
print(
agent.run(
"Create a table of super high growth opportunities for AI. I have $40k to invest in ETFs, index funds, and more. Please create a table in markdown.",
)
)

@ -5,7 +5,7 @@ build-backend = "poetry.core.masonry.api"
[tool.poetry]
name = "swarms"
version = "6.6.9"
version = "6.7.0"
description = "Swarms - TGSC"
license = "MIT"
authors = ["Kye Gomez <kye@apac.ai>"]

@ -1,14 +1,47 @@
from typing import Optional, List, Dict, Any, Callable
import json
import os
import subprocess
import sys
import time
from openai import OpenAI
from typing import Any, Callable, Dict, List, Optional
from loguru import logger
from swarms.structs.agent import Agent
import json
def check_openai_package():
"""Check if the OpenAI package is installed, and install it if not."""
try:
import openai
return openai
except ImportError:
logger.info(
"OpenAI package not found. Attempting to install..."
)
# Attempt to install the OpenAI package
try:
subprocess.check_call(
[sys.executable, "-m", "pip", "install", "openai"]
)
logger.info("OpenAI package installed successfully.")
import openai # Re-import the package after installation
return openai
except subprocess.CalledProcessError as e:
logger.error(f"Failed to install OpenAI package: {e}")
raise RuntimeError(
"OpenAI package installation failed."
) from e
class OpenAIAssistant(Agent):
"""
OpenAI Assistant wrapper for the swarms framework.
Integrates OpenAI's Assistants API with the swarms architecture.
Example:
>>> assistant = OpenAIAssistant(
... name="Math Tutor",
@ -22,6 +55,7 @@ class OpenAIAssistant(Agent):
def __init__(
self,
name: str,
description: str = "Standard openai assistant wrapper",
instructions: Optional[str] = None,
model: str = "gpt-4o",
tools: Optional[List[Dict[str, Any]]] = None,
@ -29,47 +63,63 @@ class OpenAIAssistant(Agent):
metadata: Optional[Dict[str, Any]] = None,
functions: Optional[List[Dict[str, Any]]] = None,
*args,
**kwargs
**kwargs,
):
"""Initialize the OpenAI Assistant.
Args:
name: Name of the assistant
instructions: System instructions for the assistant
model: Model to use (default: gpt-4-turbo-preview)
model: Model to use (default: gpt-4o)
tools: List of tools to enable (code_interpreter, retrieval)
file_ids: List of file IDs to attach
metadata: Additional metadata
functions: List of custom functions to make available
"""
self.name = name
self.description = description
self.instructions = instructions
self.model = model
self.tools = tools
self.file_ids = file_ids
self.metadata = metadata
self.functions = functions
super().__init__(*args, **kwargs)
# Initialize tools list with any provided functions
self.tools = tools or []
if functions:
for func in functions:
self.tools.append({
"type": "function",
"function": func
})
self.tools.append(
{"type": "function", "function": func}
)
# Create the OpenAI Assistant
self.client = OpenAI()
openai = check_openai_package()
self.client = openai.OpenAI(
api_key=os.getenv("OPENAI_API_KEY")
)
self.assistant = self.client.beta.assistants.create(
name=name,
instructions=instructions,
model=model,
tools=self.tools,
file_ids=file_ids or [],
metadata=metadata or {}
# file_ids=file_ids or [],
metadata=metadata or {},
)
# Store available functions
self.available_functions: Dict[str, Callable] = {}
def add_function(self, func: Callable, description: str, parameters: Dict[str, Any]) -> None:
def add_function(
self,
func: Callable,
description: str,
parameters: Dict[str, Any],
) -> None:
"""Add a function that the assistant can call.
Args:
func: The function to make available to the assistant
description: Description of what the function does
@ -78,27 +128,23 @@ class OpenAIAssistant(Agent):
func_dict = {
"name": func.__name__,
"description": description,
"parameters": parameters
"parameters": parameters,
}
# Add to tools list
self.tools.append({
"type": "function",
"function": func_dict
})
self.tools.append({"type": "function", "function": func_dict})
# Store function reference
self.available_functions[func.__name__] = func
# Update assistant with new tools
self.assistant = self.client.beta.assistants.update(
assistant_id=self.assistant.id,
tools=self.tools
assistant_id=self.assistant.id, tools=self.tools
)
def _handle_tool_calls(self, run, thread_id: str) -> None:
"""Handle any required tool calls during a run.
This method processes any tool calls required by the assistant during execution.
It extracts function calls, executes them with provided arguments, and submits
the results back to the assistant.
@ -114,38 +160,46 @@ class OpenAIAssistant(Agent):
Exception: If there are errors executing the tool calls
"""
while run.status == "requires_action":
tool_calls = run.required_action.submit_tool_outputs.tool_calls
tool_calls = (
run.required_action.submit_tool_outputs.tool_calls
)
tool_outputs = []
for tool_call in tool_calls:
if tool_call.type == "function":
# Get function details
function_name = tool_call.function.name
function_args = json.loads(tool_call.function.arguments)
function_args = json.loads(
tool_call.function.arguments
)
# Call function if available
if function_name in self.available_functions:
function_response = self.available_functions[function_name](**function_args)
tool_outputs.append({
"tool_call_id": tool_call.id,
"output": str(function_response)
})
function_response = self.available_functions[
function_name
](**function_args)
tool_outputs.append(
{
"tool_call_id": tool_call.id,
"output": str(function_response),
}
)
# Submit outputs back to the run
run = self.client.beta.threads.runs.submit_tool_outputs(
thread_id=thread_id,
run_id=run.id,
tool_outputs=tool_outputs
tool_outputs=tool_outputs,
)
# Wait for processing
run = self._wait_for_run(run)
return run
def _wait_for_run(self, run) -> Any:
"""Wait for a run to complete and handle any required actions.
This method polls the OpenAI API to check the status of a run until it completes
or fails. It handles intermediate states like required actions and implements
exponential backoff.
@ -161,10 +215,9 @@ class OpenAIAssistant(Agent):
"""
while True:
run = self.client.beta.threads.runs.retrieve(
thread_id=run.thread_id,
run_id=run.id
thread_id=run.thread_id, run_id=run.id
)
if run.status == "completed":
break
elif run.status == "requires_action":
@ -172,15 +225,17 @@ class OpenAIAssistant(Agent):
if run.status == "completed":
break
elif run.status in ["failed", "expired"]:
raise Exception(f"Run failed with status: {run.status}")
raise Exception(
f"Run failed with status: {run.status}"
)
time.sleep(3) # Wait 3 seconds before checking again
return run
def _ensure_thread(self):
"""Ensure a thread exists for the conversation.
This method checks if there is an active thread for the current conversation.
If no thread exists, it creates a new one. This maintains conversation context
across multiple interactions.
@ -188,10 +243,11 @@ class OpenAIAssistant(Agent):
Side Effects:
Sets self.thread if it doesn't exist
"""
if not self.thread:
self.thread = self.client.beta.threads.create()
self.thread = self.client.beta.threads.create()
def add_message(self, content: str, file_ids: Optional[List[str]] = None) -> None:
def add_message(
self, content: str, file_ids: Optional[List[str]] = None
) -> None:
"""Add a message to the thread.
This method adds a new user message to the conversation thread. It ensures
@ -211,20 +267,18 @@ class OpenAIAssistant(Agent):
thread_id=self.thread.id,
role="user",
content=content,
file_ids=file_ids or []
# file_ids=file_ids or [],
)
def _get_response(self) -> str:
"""Get the latest assistant response from the thread."""
messages = self.client.beta.threads.messages.list(
thread_id=self.thread.id,
order="desc",
limit=1
thread_id=self.thread.id, order="desc", limit=1
)
if not messages.data:
return ""
message = messages.data[0]
if message.role == "assistant":
return message.content[0].text.value
@ -235,25 +289,25 @@ class OpenAIAssistant(Agent):
Args:
task: The task or prompt to send to the assistant
Returns:
The assistant's response as a string
"""
self._ensure_thread()
# Add the user message
self.add_message(task)
# Create and run the assistant
run = self.client.beta.threads.runs.create(
thread_id=self.thread.id,
assistant_id=self.assistant.id,
instructions=self.instructions
instructions=self.instructions,
)
# Wait for completion
run = self._wait_for_run(run)
# Only get and return the response if run completed successfully
if run.status == "completed":
return self._get_response()
@ -261,4 +315,4 @@ class OpenAIAssistant(Agent):
def call(self, task: str, *args, **kwargs) -> str:
"""Alias for run() to maintain compatibility with different agent interfaces."""
return self.run(task, *args, **kwargs)
return self.run(task, *args, **kwargs)

@ -73,22 +73,6 @@ from swarms.structs.utils import (
find_token_in_text,
parse_tasks,
)
from swarms.structs.swarm_router import (
SwarmRouter,
SwarmType,
swarm_router,
)
from swarms.structs.swarm_arange import SwarmRearrange
from swarms.structs.multi_agent_exec import (
run_agents_concurrently,
run_agents_concurrently_async,
run_single_agent,
run_agents_concurrently_multiprocess,
run_agents_sequentially,
run_agents_with_different_tasks,
run_agent_with_timeout,
run_agents_with_resource_monitoring,
)
from swarms.structs.async_workflow import AsyncWorkflow
__all__ = [

File diff suppressed because it is too large Load Diff

@ -0,0 +1,419 @@
import json
import logging
import time
import uuid
from datetime import datetime
from typing import Any, Dict, List, Optional
import yaml
from pydantic import BaseModel
from swarm_models.tiktoken_wrapper import TikTokenizer
logger = logging.getLogger(__name__)
class MemoryMetadata(BaseModel):
"""Metadata for memory entries"""
timestamp: Optional[float] = time.time()
role: Optional[str] = None
agent_name: Optional[str] = None
session_id: Optional[str] = None
memory_type: Optional[str] = None # 'short_term' or 'long_term'
token_count: Optional[int] = None
message_id: Optional[str] = str(uuid.uuid4())
class MemoryEntry(BaseModel):
"""Single memory entry with content and metadata"""
content: Optional[str] = None
metadata: Optional[MemoryMetadata] = None
class MemoryConfig(BaseModel):
"""Configuration for memory manager"""
max_short_term_tokens: Optional[int] = 4096
max_entries: Optional[int] = None
system_messages_token_buffer: Optional[int] = 1000
enable_long_term_memory: Optional[bool] = False
auto_archive: Optional[bool] = True
archive_threshold: Optional[float] = 0.8 # Archive when 80% full
class MemoryManager:
"""
Manages both short-term and long-term memory for an agent, handling token limits,
archival, and context retrieval.
Args:
config (MemoryConfig): Configuration for memory management
tokenizer (Optional[Any]): Tokenizer to use for token counting
long_term_memory (Optional[Any]): Vector store or database for long-term storage
"""
def __init__(
self,
config: MemoryConfig,
tokenizer: Optional[Any] = None,
long_term_memory: Optional[Any] = None,
):
self.config = config
self.tokenizer = tokenizer or TikTokenizer()
self.long_term_memory = long_term_memory
# Initialize memories
self.short_term_memory: List[MemoryEntry] = []
self.system_messages: List[MemoryEntry] = []
# Memory statistics
self.total_tokens_processed: int = 0
self.archived_entries_count: int = 0
def create_memory_entry(
self,
content: str,
role: str,
agent_name: str,
session_id: str,
memory_type: str = "short_term",
) -> MemoryEntry:
"""Create a new memory entry with metadata"""
metadata = MemoryMetadata(
timestamp=time.time(),
role=role,
agent_name=agent_name,
session_id=session_id,
memory_type=memory_type,
token_count=self.tokenizer.count_tokens(content),
)
return MemoryEntry(content=content, metadata=metadata)
def add_memory(
self,
content: str,
role: str,
agent_name: str,
session_id: str,
is_system: bool = False,
) -> None:
"""Add a new memory entry to appropriate storage"""
entry = self.create_memory_entry(
content=content,
role=role,
agent_name=agent_name,
session_id=session_id,
memory_type="system" if is_system else "short_term",
)
if is_system:
self.system_messages.append(entry)
else:
self.short_term_memory.append(entry)
# Check if archiving is needed
if self.should_archive():
self.archive_old_memories()
self.total_tokens_processed += entry.metadata.token_count
def get_current_token_count(self) -> int:
"""Get total tokens in short-term memory"""
return sum(
entry.metadata.token_count
for entry in self.short_term_memory
)
def get_system_messages_token_count(self) -> int:
"""Get total tokens in system messages"""
return sum(
entry.metadata.token_count
for entry in self.system_messages
)
def should_archive(self) -> bool:
"""Check if archiving is needed based on configuration"""
if not self.config.auto_archive:
return False
current_usage = (
self.get_current_token_count()
/ self.config.max_short_term_tokens
)
return current_usage >= self.config.archive_threshold
def archive_old_memories(self) -> None:
"""Move older memories to long-term storage"""
if not self.long_term_memory:
logger.warning(
"No long-term memory storage configured for archiving"
)
return
while self.should_archive():
# Get oldest non-system message
if not self.short_term_memory:
break
oldest_entry = self.short_term_memory.pop(0)
# Store in long-term memory
self.store_in_long_term_memory(oldest_entry)
self.archived_entries_count += 1
def store_in_long_term_memory(self, entry: MemoryEntry) -> None:
"""Store a memory entry in long-term memory"""
if self.long_term_memory is None:
logger.warning(
"Attempted to store in non-existent long-term memory"
)
return
try:
self.long_term_memory.add(str(entry.model_dump()))
except Exception as e:
logger.error(f"Error storing in long-term memory: {e}")
# Re-add to short-term if storage fails
self.short_term_memory.insert(0, entry)
def get_relevant_context(
self, query: str, max_tokens: Optional[int] = None
) -> str:
"""
Get relevant context from both memory types
Args:
query (str): Query to match against memories
max_tokens (Optional[int]): Maximum tokens to return
Returns:
str: Combined relevant context
"""
contexts = []
# Add system messages first
for entry in self.system_messages:
contexts.append(entry.content)
# Add short-term memory
for entry in reversed(self.short_term_memory):
contexts.append(entry.content)
# Query long-term memory if available
if self.long_term_memory is not None:
long_term_context = self.long_term_memory.query(query)
if long_term_context:
contexts.append(str(long_term_context))
# Combine and truncate if needed
combined = "\n".join(contexts)
if max_tokens:
combined = self.truncate_to_token_limit(
combined, max_tokens
)
return combined
def truncate_to_token_limit(
self, text: str, max_tokens: int
) -> str:
"""Truncate text to fit within token limit"""
current_tokens = self.tokenizer.count_tokens(text)
if current_tokens <= max_tokens:
return text
# Truncate by splitting into sentences and rebuilding
sentences = text.split(". ")
result = []
current_count = 0
for sentence in sentences:
sentence_tokens = self.tokenizer.count_tokens(sentence)
if current_count + sentence_tokens <= max_tokens:
result.append(sentence)
current_count += sentence_tokens
else:
break
return ". ".join(result)
def clear_short_term_memory(
self, preserve_system: bool = True
) -> None:
"""Clear short-term memory with option to preserve system messages"""
if not preserve_system:
self.system_messages.clear()
self.short_term_memory.clear()
logger.info(
"Cleared short-term memory"
+ " (preserved system messages)"
if preserve_system
else ""
)
def get_memory_stats(self) -> Dict[str, Any]:
"""Get detailed memory statistics"""
return {
"short_term_messages": len(self.short_term_memory),
"system_messages": len(self.system_messages),
"current_tokens": self.get_current_token_count(),
"system_tokens": self.get_system_messages_token_count(),
"max_tokens": self.config.max_short_term_tokens,
"token_usage_percent": round(
(
self.get_current_token_count()
/ self.config.max_short_term_tokens
)
* 100,
2,
),
"has_long_term_memory": self.long_term_memory is not None,
"archived_entries": self.archived_entries_count,
"total_tokens_processed": self.total_tokens_processed,
}
def save_memory_snapshot(self, file_path: str) -> None:
"""Save current memory state to file"""
try:
data = {
"timestamp": datetime.now().isoformat(),
"config": self.config.model_dump(),
"system_messages": [
entry.model_dump()
for entry in self.system_messages
],
"short_term_memory": [
entry.model_dump()
for entry in self.short_term_memory
],
"stats": self.get_memory_stats(),
}
with open(file_path, "w") as f:
if file_path.endswith(".yaml"):
yaml.dump(data, f)
else:
json.dump(data, f, indent=2)
logger.info(f"Saved memory snapshot to {file_path}")
except Exception as e:
logger.error(f"Error saving memory snapshot: {e}")
raise
def load_memory_snapshot(self, file_path: str) -> None:
"""Load memory state from file"""
try:
with open(file_path, "r") as f:
if file_path.endswith(".yaml"):
data = yaml.safe_load(f)
else:
data = json.load(f)
self.config = MemoryConfig(**data["config"])
self.system_messages = [
MemoryEntry(**entry)
for entry in data["system_messages"]
]
self.short_term_memory = [
MemoryEntry(**entry)
for entry in data["short_term_memory"]
]
logger.info(f"Loaded memory snapshot from {file_path}")
except Exception as e:
logger.error(f"Error loading memory snapshot: {e}")
raise
def search_memories(
self, query: str, memory_type: str = "all"
) -> List[MemoryEntry]:
"""
Search through memories of specified type
Args:
query (str): Search query
memory_type (str): Type of memories to search ("short_term", "system", "long_term", or "all")
Returns:
List[MemoryEntry]: Matching memory entries
"""
results = []
if memory_type in ["short_term", "all"]:
results.extend(
[
entry
for entry in self.short_term_memory
if query.lower() in entry.content.lower()
]
)
if memory_type in ["system", "all"]:
results.extend(
[
entry
for entry in self.system_messages
if query.lower() in entry.content.lower()
]
)
if (
memory_type in ["long_term", "all"]
and self.long_term_memory is not None
):
long_term_results = self.long_term_memory.query(query)
if long_term_results:
# Convert long-term results to MemoryEntry format
for result in long_term_results:
content = str(result)
metadata = MemoryMetadata(
timestamp=time.time(),
role="long_term",
agent_name="system",
session_id="long_term",
memory_type="long_term",
token_count=self.tokenizer.count_tokens(
content
),
)
results.append(
MemoryEntry(
content=content, metadata=metadata
)
)
return results
def get_memory_by_timeframe(
self, start_time: float, end_time: float
) -> List[MemoryEntry]:
"""Get memories within a specific timeframe"""
return [
entry
for entry in self.short_term_memory
if start_time <= entry.metadata.timestamp <= end_time
]
def export_memories(
self, file_path: str, format: str = "json"
) -> None:
"""Export memories to file in specified format"""
data = {
"system_messages": [
entry.model_dump() for entry in self.system_messages
],
"short_term_memory": [
entry.model_dump() for entry in self.short_term_memory
],
"stats": self.get_memory_stats(),
}
with open(file_path, "w") as f:
if format == "yaml":
yaml.dump(data, f)
else:
json.dump(data, f, indent=2)

@ -1,9 +1,10 @@
import asyncio
from typing import Any, Callable, List, Optional
from typing import Any, List
from swarms.structs.base_workflow import BaseWorkflow
from swarms.structs.agent import Agent
from swarms.utils.loguru_logger import logger
class AsyncWorkflow(BaseWorkflow):
def __init__(
self,
@ -13,7 +14,7 @@ class AsyncWorkflow(BaseWorkflow):
dashboard: bool = False,
autosave: bool = False,
verbose: bool = False,
**kwargs
**kwargs,
):
super().__init__(agents=agents, **kwargs)
self.name = name
@ -26,17 +27,25 @@ class AsyncWorkflow(BaseWorkflow):
self.results = []
self.loop = None
async def _execute_agent_task(self, agent: Agent, task: str) -> Any:
async def _execute_agent_task(
self, agent: Agent, task: str
) -> Any:
"""Execute a single agent task asynchronously"""
try:
if self.verbose:
logger.info(f"Agent {agent.agent_name} processing task: {task}")
logger.info(
f"Agent {agent.agent_name} processing task: {task}"
)
result = await agent.arun(task)
if self.verbose:
logger.info(f"Agent {agent.agent_name} completed task")
logger.info(
f"Agent {agent.agent_name} completed task"
)
return result
except Exception as e:
logger.error(f"Error in agent {agent.agent_name}: {str(e)}")
logger.error(
f"Error in agent {agent.agent_name}: {str(e)}"
)
return str(e)
async def run(self, task: str) -> List[Any]:
@ -46,17 +55,22 @@ class AsyncWorkflow(BaseWorkflow):
try:
# Create tasks for all agents
tasks = [self._execute_agent_task(agent, task) for agent in self.agents]
tasks = [
self._execute_agent_task(agent, task)
for agent in self.agents
]
# Execute all tasks concurrently
self.results = await asyncio.gather(*tasks, return_exceptions=True)
self.results = await asyncio.gather(
*tasks, return_exceptions=True
)
if self.autosave:
# TODO: Implement autosave logic here
pass
return self.results
except Exception as e:
logger.error(f"Error in workflow execution: {str(e)}")
raise
raise

@ -0,0 +1,258 @@
import inspect
import json
import logging
import os
from datetime import datetime
from typing import Any, Dict, Set
from uuid import UUID
logger = logging.getLogger(__name__)
class SafeLoaderUtils:
"""
Utility class for safely loading and saving object states while automatically
detecting and preserving class instances and complex objects.
"""
@staticmethod
def is_class_instance(obj: Any) -> bool:
"""
Detect if an object is a class instance (excluding built-in types).
Args:
obj: Object to check
Returns:
bool: True if object is a class instance
"""
if obj is None:
return False
# Get the type of the object
obj_type = type(obj)
# Check if it's a class instance but not a built-in type
return (
hasattr(obj, "__dict__")
and not isinstance(obj, type)
and obj_type.__module__ != "builtins"
)
@staticmethod
def is_safe_type(value: Any) -> bool:
"""
Check if a value is of a safe, serializable type.
Args:
value: Value to check
Returns:
bool: True if the value is safe to serialize
"""
# Basic safe types
safe_types = (
type(None),
bool,
int,
float,
str,
datetime,
UUID,
)
if isinstance(value, safe_types):
return True
# Check containers
if isinstance(value, (list, tuple)):
return all(
SafeLoaderUtils.is_safe_type(item) for item in value
)
if isinstance(value, dict):
return all(
isinstance(k, str) and SafeLoaderUtils.is_safe_type(v)
for k, v in value.items()
)
# Check for common serializable types
try:
json.dumps(value)
return True
except (TypeError, OverflowError, ValueError):
return False
@staticmethod
def get_class_attributes(obj: Any) -> Set[str]:
"""
Get all attributes of a class, including inherited ones.
Args:
obj: Object to inspect
Returns:
Set[str]: Set of attribute names
"""
attributes = set()
# Get all attributes from class and parent classes
for cls in inspect.getmro(type(obj)):
attributes.update(cls.__dict__.keys())
# Add instance attributes
attributes.update(obj.__dict__.keys())
return attributes
@staticmethod
def create_state_dict(obj: Any) -> Dict[str, Any]:
"""
Create a dictionary of safe values from an object's state.
Args:
obj: Object to create state dict from
Returns:
Dict[str, Any]: Dictionary of safe values
"""
state_dict = {}
for attr_name in SafeLoaderUtils.get_class_attributes(obj):
# Skip private attributes
if attr_name.startswith("_"):
continue
try:
value = getattr(obj, attr_name, None)
if SafeLoaderUtils.is_safe_type(value):
state_dict[attr_name] = value
except Exception as e:
logger.debug(f"Skipped attribute {attr_name}: {e}")
return state_dict
@staticmethod
def preserve_instances(obj: Any) -> Dict[str, Any]:
"""
Automatically detect and preserve all class instances in an object.
Args:
obj: Object to preserve instances from
Returns:
Dict[str, Any]: Dictionary of preserved instances
"""
preserved = {}
for attr_name in SafeLoaderUtils.get_class_attributes(obj):
if attr_name.startswith("_"):
continue
try:
value = getattr(obj, attr_name, None)
if SafeLoaderUtils.is_class_instance(value):
preserved[attr_name] = value
except Exception as e:
logger.debug(f"Could not preserve {attr_name}: {e}")
return preserved
class SafeStateManager:
"""
Manages saving and loading object states while automatically handling
class instances and complex objects.
"""
@staticmethod
def save_state(obj: Any, file_path: str) -> None:
"""
Save an object's state to a file, automatically handling complex objects.
Args:
obj: Object to save state from
file_path: Path to save state to
"""
try:
# Create state dict with only safe values
state_dict = SafeLoaderUtils.create_state_dict(obj)
# Ensure directory exists
os.makedirs(os.path.dirname(file_path), exist_ok=True)
# Save to file
with open(file_path, "w") as f:
json.dump(state_dict, f, indent=4, default=str)
logger.info(f"Successfully saved state to: {file_path}")
except Exception as e:
logger.error(f"Error saving state: {e}")
raise
@staticmethod
def load_state(obj: Any, file_path: str) -> None:
"""
Load state into an object while preserving class instances.
Args:
obj: Object to load state into
file_path: Path to load state from
"""
try:
# Verify file exists
if not os.path.exists(file_path):
raise FileNotFoundError(
f"State file not found: {file_path}"
)
# Preserve existing instances
preserved = SafeLoaderUtils.preserve_instances(obj)
# Load state
with open(file_path, "r") as f:
state_dict = json.load(f)
# Set safe values
for key, value in state_dict.items():
if (
not key.startswith("_")
and key not in preserved
and SafeLoaderUtils.is_safe_type(value)
):
setattr(obj, key, value)
# Restore preserved instances
for key, value in preserved.items():
setattr(obj, key, value)
logger.info(
f"Successfully loaded state from: {file_path}"
)
except Exception as e:
logger.error(f"Error loading state: {e}")
raise
# # Example decorator for easy integration
# def safe_state_methods(cls: Type) -> Type:
# """
# Class decorator to add safe state loading/saving methods to a class.
# Args:
# cls: Class to decorate
# Returns:
# Type: Decorated class
# """
# def save(self, file_path: str) -> None:
# SafeStateManager.save_state(self, file_path)
# def load(self, file_path: str) -> None:
# SafeStateManager.load_state(self, file_path)
# cls.save = save
# cls.load = load
# return cls

@ -400,10 +400,8 @@ class BaseTool(BaseModel):
logger.info(
f"Converting tool: {name} into a OpenAI certified function calling schema. Add documentation and type hints."
)
tool_schema = (
get_openai_function_schema_from_func(
tool, name=name, description=description
)
tool_schema = get_openai_function_schema_from_func(
tool, name=name, description=description
)
logger.info(
@ -420,10 +418,12 @@ class BaseTool(BaseModel):
if tool_schemas:
combined_schema = {
"type": "function",
"functions": [schema["function"] for schema in tool_schemas]
"functions": [
schema["function"] for schema in tool_schemas
],
}
return json.dumps(combined_schema, indent=4)
return None
def check_func_if_have_docs(self, func: callable):

@ -77,7 +77,7 @@ class LiteLLM:
str: The content of the response from the model.
"""
try:
messages = self._prepare_messages(task)
response = completion(
@ -85,14 +85,15 @@ class LiteLLM:
messages=messages,
stream=self.stream,
temperature=self.temperature,
# max_completion_tokens=self.max_tokens,
max_tokens=self.max_tokens,
*args,
**kwargs,
)
content = response.choices[
0
].message.content # Accessing the content
return content
except Exception as error:
print(error)

@ -34,4 +34,4 @@ def initialize_logger(log_folder: str = "logs"):
retention="10 days",
# compression="zip",
)
return logger
return logger

@ -0,0 +1,598 @@
import asyncio
from swarms import Agent
from swarm_models import OpenAIChat
import os
import time
import json
import yaml
import tempfile
def test_basic_agent_functionality():
"""Test basic agent initialization and simple task execution"""
print("\nTesting basic agent functionality...")
model = OpenAIChat(model_name="gpt-4o")
agent = Agent(agent_name="Test-Agent", llm=model, max_loops=1)
response = agent.run("What is 2+2?")
assert response is not None, "Agent response should not be None"
# Test agent properties
assert (
agent.agent_name == "Test-Agent"
), "Agent name not set correctly"
assert agent.max_loops == 1, "Max loops not set correctly"
assert agent.llm is not None, "LLM not initialized"
print("✓ Basic agent functionality test passed")
def test_memory_management():
"""Test agent memory management functionality"""
print("\nTesting memory management...")
model = OpenAIChat(model_name="gpt-4o")
agent = Agent(
agent_name="Memory-Test-Agent",
llm=model,
max_loops=1,
context_length=8192,
)
# Test adding to memory
agent.add_memory("Test memory entry")
assert (
"Test memory entry"
in agent.short_memory.return_history_as_string()
)
# Test memory query
agent.memory_query("Test query")
# Test token counting
tokens = agent.check_available_tokens()
assert isinstance(tokens, int), "Token count should be an integer"
print("✓ Memory management test passed")
def test_agent_output_formats():
"""Test all available output formats"""
print("\nTesting all output formats...")
model = OpenAIChat(model_name="gpt-4o")
test_task = "Say hello!"
output_types = {
"str": str,
"string": str,
"list": str, # JSON string containing list
"json": str, # JSON string
"dict": dict,
"yaml": str,
}
for output_type, expected_type in output_types.items():
agent = Agent(
agent_name=f"{output_type.capitalize()}-Output-Agent",
llm=model,
max_loops=1,
output_type=output_type,
)
response = agent.run(test_task)
assert (
response is not None
), f"{output_type} output should not be None"
if output_type == "yaml":
# Verify YAML can be parsed
try:
yaml.safe_load(response)
print(f"{output_type} output valid")
except yaml.YAMLError:
assert False, f"Invalid YAML output for {output_type}"
elif output_type in ["json", "list"]:
# Verify JSON can be parsed
try:
json.loads(response)
print(f"{output_type} output valid")
except json.JSONDecodeError:
assert False, f"Invalid JSON output for {output_type}"
print("✓ Output formats test passed")
def test_agent_state_management():
"""Test comprehensive state management functionality"""
print("\nTesting state management...")
model = OpenAIChat(model_name="gpt-4o")
# Create temporary directory for test files
with tempfile.TemporaryDirectory() as temp_dir:
state_path = os.path.join(temp_dir, "agent_state.json")
# Create agent with initial state
agent1 = Agent(
agent_name="State-Test-Agent",
llm=model,
max_loops=1,
saved_state_path=state_path,
)
# Add some data to the agent
agent1.run("Remember this: Test message 1")
agent1.add_memory("Test message 2")
# Save state
agent1.save()
assert os.path.exists(state_path), "State file not created"
# Create new agent and load state
agent2 = Agent(
agent_name="State-Test-Agent", llm=model, max_loops=1
)
agent2.load(state_path)
# Verify state loaded correctly
history2 = agent2.short_memory.return_history_as_string()
assert (
"Test message 1" in history2
), "State not loaded correctly"
assert (
"Test message 2" in history2
), "Memory not loaded correctly"
# Test autosave functionality
agent3 = Agent(
agent_name="Autosave-Test-Agent",
llm=model,
max_loops=1,
saved_state_path=os.path.join(
temp_dir, "autosave_state.json"
),
autosave=True,
)
agent3.run("Test autosave")
time.sleep(2) # Wait for autosave
assert os.path.exists(
os.path.join(temp_dir, "autosave_state.json")
), "Autosave file not created"
print("✓ State management test passed")
def test_agent_tools_and_execution():
"""Test agent tool handling and execution"""
print("\nTesting tools and execution...")
def sample_tool(x: int, y: int) -> int:
"""Sample tool that adds two numbers"""
return x + y
model = OpenAIChat(model_name="gpt-4o")
agent = Agent(
agent_name="Tools-Test-Agent",
llm=model,
max_loops=1,
tools=[sample_tool],
)
# Test adding tools
agent.add_tool(lambda x: x * 2)
assert len(agent.tools) == 2, "Tool not added correctly"
# Test removing tools
agent.remove_tool(sample_tool)
assert len(agent.tools) == 1, "Tool not removed correctly"
# Test tool execution
response = agent.run("Calculate 2 + 2 using the sample tool")
assert response is not None, "Tool execution failed"
print("✓ Tools and execution test passed")
def test_agent_concurrent_execution():
"""Test agent concurrent execution capabilities"""
print("\nTesting concurrent execution...")
model = OpenAIChat(model_name="gpt-4o")
agent = Agent(
agent_name="Concurrent-Test-Agent", llm=model, max_loops=1
)
# Test bulk run
tasks = [
{"task": "Count to 3"},
{"task": "Say hello"},
{"task": "Tell a short joke"},
]
responses = agent.bulk_run(tasks)
assert len(responses) == len(tasks), "Not all tasks completed"
assert all(
response is not None for response in responses
), "Some tasks failed"
# Test concurrent tasks
concurrent_responses = agent.run_concurrent_tasks(
["Task 1", "Task 2", "Task 3"]
)
assert (
len(concurrent_responses) == 3
), "Not all concurrent tasks completed"
print("✓ Concurrent execution test passed")
def test_agent_error_handling():
"""Test agent error handling and recovery"""
print("\nTesting error handling...")
model = OpenAIChat(model_name="gpt-4o")
agent = Agent(
agent_name="Error-Test-Agent",
llm=model,
max_loops=1,
retry_attempts=3,
retry_interval=1,
)
# Test invalid tool execution
try:
agent.parse_and_execute_tools("invalid_json")
print("✓ Invalid tool execution handled")
except Exception:
assert True, "Expected error caught"
# Test recovery after error
response = agent.run("Continue after error")
assert response is not None, "Agent failed to recover after error"
print("✓ Error handling test passed")
def test_agent_configuration():
"""Test agent configuration and parameters"""
print("\nTesting agent configuration...")
model = OpenAIChat(model_name="gpt-4o")
agent = Agent(
agent_name="Config-Test-Agent",
llm=model,
max_loops=1,
temperature=0.7,
max_tokens=4000,
context_length=8192,
)
# Test configuration methods
agent.update_system_prompt("New system prompt")
agent.update_max_loops(2)
agent.update_loop_interval(2)
# Verify updates
assert agent.max_loops == 2, "Max loops not updated"
assert agent.loop_interval == 2, "Loop interval not updated"
# Test configuration export
config_dict = agent.to_dict()
assert isinstance(
config_dict, dict
), "Configuration export failed"
# Test YAML export
yaml_config = agent.to_yaml()
assert isinstance(yaml_config, str), "YAML export failed"
print("✓ Configuration test passed")
def test_agent_with_stopping_condition():
"""Test agent with custom stopping condition"""
print("\nTesting agent with stopping condition...")
def custom_stopping_condition(response: str) -> bool:
return "STOP" in response.upper()
model = OpenAIChat(model_name="gpt-4o")
agent = Agent(
agent_name="Stopping-Condition-Agent",
llm=model,
max_loops=5,
stopping_condition=custom_stopping_condition,
)
response = agent.run("Count up until you see the word STOP")
assert response is not None, "Stopping condition test failed"
print("✓ Stopping condition test passed")
def test_agent_with_retry_mechanism():
"""Test agent retry mechanism"""
print("\nTesting agent retry mechanism...")
model = OpenAIChat(model_name="gpt-4o")
agent = Agent(
agent_name="Retry-Test-Agent",
llm=model,
max_loops=1,
retry_attempts=3,
retry_interval=1,
)
response = agent.run("Tell me a joke.")
assert response is not None, "Retry mechanism test failed"
print("✓ Retry mechanism test passed")
def test_bulk_and_filtered_operations():
"""Test bulk operations and response filtering"""
print("\nTesting bulk and filtered operations...")
model = OpenAIChat(model_name="gpt-4o")
agent = Agent(
agent_name="Bulk-Filter-Test-Agent", llm=model, max_loops=1
)
# Test bulk run
bulk_tasks = [
{"task": "What is 2+2?"},
{"task": "Name a color"},
{"task": "Count to 3"},
]
bulk_responses = agent.bulk_run(bulk_tasks)
assert len(bulk_responses) == len(
bulk_tasks
), "Bulk run should return same number of responses as tasks"
# Test response filtering
agent.add_response_filter("color")
filtered_response = agent.filtered_run(
"What is your favorite color?"
)
assert (
"[FILTERED]" in filtered_response
), "Response filter not applied"
print("✓ Bulk and filtered operations test passed")
async def test_async_operations():
"""Test asynchronous operations"""
print("\nTesting async operations...")
model = OpenAIChat(model_name="gpt-4o")
agent = Agent(
agent_name="Async-Test-Agent", llm=model, max_loops=1
)
# Test single async run
response = await agent.arun("What is 1+1?")
assert response is not None, "Async run failed"
# Test concurrent async runs
tasks = ["Task 1", "Task 2", "Task 3"]
responses = await asyncio.gather(
*[agent.arun(task) for task in tasks]
)
assert len(responses) == len(
tasks
), "Not all async tasks completed"
print("✓ Async operations test passed")
def test_memory_and_state_persistence():
"""Test memory management and state persistence"""
print("\nTesting memory and state persistence...")
with tempfile.TemporaryDirectory() as temp_dir:
state_path = os.path.join(temp_dir, "test_state.json")
# Create agent with memory configuration
model = OpenAIChat(model_name="gpt-4o")
agent1 = Agent(
agent_name="Memory-State-Test-Agent",
llm=model,
max_loops=1,
saved_state_path=state_path,
context_length=8192,
autosave=True,
)
# Test memory operations
agent1.add_memory("Important fact: The sky is blue")
agent1.memory_query("What color is the sky?")
# Save state
agent1.save()
# Create new agent and load state
agent2 = Agent(
agent_name="Memory-State-Test-Agent",
llm=model,
max_loops=1,
)
agent2.load(state_path)
# Verify memory persistence
memory_content = (
agent2.short_memory.return_history_as_string()
)
assert (
"sky is blue" in memory_content
), "Memory not properly persisted"
print("✓ Memory and state persistence test passed")
def test_sentiment_and_evaluation():
"""Test sentiment analysis and response evaluation"""
print("\nTesting sentiment analysis and evaluation...")
def mock_sentiment_analyzer(text):
"""Mock sentiment analyzer that returns a score between 0 and 1"""
return 0.7 if "positive" in text.lower() else 0.3
def mock_evaluator(response):
"""Mock evaluator that checks response quality"""
return "GOOD" if len(response) > 10 else "BAD"
model = OpenAIChat(model_name="gpt-4o")
agent = Agent(
agent_name="Sentiment-Eval-Test-Agent",
llm=model,
max_loops=1,
sentiment_analyzer=mock_sentiment_analyzer,
sentiment_threshold=0.5,
evaluator=mock_evaluator,
)
# Test sentiment analysis
agent.run("Generate a positive message")
# Test evaluation
agent.run("Generate a detailed response")
print("✓ Sentiment and evaluation test passed")
def test_tool_management():
"""Test tool management functionality"""
print("\nTesting tool management...")
def tool1(x: int) -> int:
"""Sample tool 1"""
return x * 2
def tool2(x: int) -> int:
"""Sample tool 2"""
return x + 2
model = OpenAIChat(model_name="gpt-4o")
agent = Agent(
agent_name="Tool-Test-Agent",
llm=model,
max_loops=1,
tools=[tool1],
)
# Test adding tools
agent.add_tool(tool2)
assert len(agent.tools) == 2, "Tool not added correctly"
# Test removing tools
agent.remove_tool(tool1)
assert len(agent.tools) == 1, "Tool not removed correctly"
# Test adding multiple tools
agent.add_tools([tool1, tool2])
assert len(agent.tools) == 3, "Multiple tools not added correctly"
print("✓ Tool management test passed")
def test_system_prompt_and_configuration():
"""Test system prompt and configuration updates"""
print("\nTesting system prompt and configuration...")
model = OpenAIChat(model_name="gpt-4o")
agent = Agent(
agent_name="Config-Test-Agent", llm=model, max_loops=1
)
# Test updating system prompt
new_prompt = "You are a helpful assistant."
agent.update_system_prompt(new_prompt)
assert (
agent.system_prompt == new_prompt
), "System prompt not updated"
# Test configuration updates
agent.update_max_loops(5)
assert agent.max_loops == 5, "Max loops not updated"
agent.update_loop_interval(2)
assert agent.loop_interval == 2, "Loop interval not updated"
# Test configuration export
config_dict = agent.to_dict()
assert isinstance(
config_dict, dict
), "Configuration export failed"
print("✓ System prompt and configuration test passed")
def test_agent_with_dynamic_temperature():
"""Test agent with dynamic temperature"""
print("\nTesting agent with dynamic temperature...")
model = OpenAIChat(model_name="gpt-4o")
agent = Agent(
agent_name="Dynamic-Temp-Agent",
llm=model,
max_loops=2,
dynamic_temperature_enabled=True,
)
response = agent.run("Generate a creative story.")
assert response is not None, "Dynamic temperature test failed"
print("✓ Dynamic temperature test passed")
def run_all_tests():
"""Run all test functions"""
print("Starting Extended Agent functional tests...\n")
test_functions = [
test_basic_agent_functionality,
test_memory_management,
test_agent_output_formats,
test_agent_state_management,
test_agent_tools_and_execution,
test_agent_concurrent_execution,
test_agent_error_handling,
test_agent_configuration,
test_agent_with_stopping_condition,
test_agent_with_retry_mechanism,
test_agent_with_dynamic_temperature,
test_bulk_and_filtered_operations,
test_memory_and_state_persistence,
test_sentiment_and_evaluation,
test_tool_management,
test_system_prompt_and_configuration,
]
# Run synchronous tests
total_tests = len(test_functions) + 1 # +1 for async test
passed_tests = 0
for test in test_functions:
try:
test()
passed_tests += 1
except Exception as e:
print(f"✗ Test {test.__name__} failed: {str(e)}")
# Run async test
try:
asyncio.run(test_async_operations())
passed_tests += 1
except Exception as e:
print(f"✗ Async operations test failed: {str(e)}")
print("\nExtended Test Summary:")
print(f"Total Tests: {total_tests}")
print(f"Passed: {passed_tests}")
print(f"Failed: {total_tests - passed_tests}")
print(f"Success Rate: {(passed_tests/total_tests)*100:.2f}%")
if __name__ == "__main__":
run_all_tests()
Loading…
Cancel
Save