parent 9b1106ac91
commit 9abe300548
@@ -0,0 +1,348 @@
import json
import os
import platform
import sys
import traceback
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple

import psutil
import requests
from loguru import logger
from swarm_models import OpenAIChat

from swarms.structs.agent import Agent


@dataclass
class SwarmSystemInfo:
    """System information for Swarms issue reports."""

    os_name: str
    os_version: str
    python_version: str
    cpu_usage: float
    memory_usage: float
    disk_usage: float
    swarms_version: str  # Added Swarms version tracking
    cuda_available: bool  # Added CUDA availability check
    gpu_info: Optional[str]  # Added GPU information


class SwarmsIssueReporter:
    """
    Production-grade GitHub issue reporter specifically designed for the Swarms library.
    Automatically creates detailed issues for the https://github.com/kyegomez/swarms repository.

    Features:
    - Swarms-specific error categorization
    - Automatic version and dependency tracking
    - CUDA and GPU information collection
    - Integration with the Swarms logging system
    - Detailed environment information
    """

    REPO_OWNER = "kyegomez"
    REPO_NAME = "swarms"
    ISSUE_CATEGORIES = {
        "agent": ["agent", "automation"],
        "memory": ["memory", "storage"],
        "tool": ["tools", "integration"],
        "llm": ["llm", "model"],
        "performance": ["performance", "optimization"],
        "compatibility": ["compatibility", "environment"],
    }

    def __init__(
        self,
        github_token: str,
        rate_limit: int = 10,
        rate_period: int = 3600,
        log_file: str = "swarms_issues.log",
        enable_duplicate_check: bool = True,
    ):
        """
        Initialize the Swarms Issue Reporter.

        Args:
            github_token (str): GitHub personal access token
            rate_limit (int): Maximum number of issues to create per rate_period
            rate_period (int): Time period for rate limiting in seconds
            log_file (str): Path to the log file
            enable_duplicate_check (bool): Whether to check for duplicate issues
        """
        # Fall back to the GITHUB_API_KEY environment variable only when no
        # token is passed in, rather than unconditionally overwriting it
        self.github_token = github_token or os.getenv("GITHUB_API_KEY")
        self.rate_limit = rate_limit
        self.rate_period = rate_period
        self.enable_duplicate_check = enable_duplicate_check

        # Initialize logging
        log_path = os.path.join(os.getcwd(), "logs", log_file)
        os.makedirs(os.path.dirname(log_path), exist_ok=True)
        logger.add(
            log_path,
            rotation="1 day",
            retention="1 month",
            compression="zip",
        )

        # Issue tracking
        self.issues_created = []
        self.last_issue_time = datetime.now()

        # Validate GitHub token
        # self._validate_github_credentials()

    def _get_swarms_version(self) -> str:
        """Get the installed version of Swarms."""
        try:
            import swarms

            return swarms.__version__
        except Exception:
            return "Unknown"

    def _get_gpu_info(self) -> Tuple[bool, Optional[str]]:
        """Get GPU information and CUDA availability."""
        try:
            import torch

            cuda_available = torch.cuda.is_available()
            if cuda_available:
                gpu_info = torch.cuda.get_device_name(0)
                return cuda_available, gpu_info
            return False, None
        except Exception:
            return False, None

    def _get_system_info(self) -> SwarmSystemInfo:
        """Collect system and Swarms-specific information."""
        cuda_available, gpu_info = self._get_gpu_info()

        return SwarmSystemInfo(
            os_name=platform.system(),
            os_version=platform.version(),
            python_version=sys.version,
            cpu_usage=psutil.cpu_percent(),
            memory_usage=psutil.virtual_memory().percent,
            disk_usage=psutil.disk_usage("/").percent,
            swarms_version=self._get_swarms_version(),
            cuda_available=cuda_available,
            gpu_info=gpu_info,
        )

    def _categorize_error(
        self, error: Exception, context: Dict
    ) -> List[str]:
        """Categorize the error and return appropriate labels."""
        error_str = str(error).lower()

        labels = ["bug", "automated"]

        # Check the error message for category keywords
        for (
            category,
            category_labels,
        ) in self.ISSUE_CATEGORIES.items():
            if any(
                keyword in error_str for keyword in category_labels
            ):
                labels.extend(category_labels)
                break

        # Add a severity label based on the error type
        if issubclass(type(error), (SystemError, MemoryError)):
            labels.append("severity:critical")
        elif issubclass(type(error), (ValueError, TypeError)):
            labels.append("severity:medium")
        else:
            labels.append("severity:low")

        return list(set(labels))  # Remove duplicates

    def _format_swarms_issue_body(
        self,
        error: Exception,
        system_info: SwarmSystemInfo,
        context: Dict,
    ) -> str:
        """Format the issue body with Swarms-specific information."""
        return f"""
## Swarms Error Report
- **Error Type**: {type(error).__name__}
- **Error Message**: {str(error)}
- **Swarms Version**: {system_info.swarms_version}

## Environment Information
- **OS**: {system_info.os_name} {system_info.os_version}
- **Python Version**: {system_info.python_version}
- **CUDA Available**: {system_info.cuda_available}
- **GPU**: {system_info.gpu_info or "N/A"}
- **CPU Usage**: {system_info.cpu_usage}%
- **Memory Usage**: {system_info.memory_usage}%
- **Disk Usage**: {system_info.disk_usage}%

## Stack Trace
{traceback.format_exc()}

## Context
{json.dumps(context, indent=2)}

## Dependencies
{self._get_dependencies_info()}

## Time of Occurrence
{datetime.now().isoformat()}

---
*This issue was automatically generated by SwarmsIssueReporter*
"""

    def _get_dependencies_info(self) -> str:
        """Get information about installed dependencies."""
        try:
            import pkg_resources

            deps = []
            for dist in pkg_resources.working_set:
                deps.append(f"- {dist.key} {dist.version}")
            return "\n".join(deps)
        except Exception:
            return "Unable to fetch dependency information"

    def _check_rate_limit(self) -> bool:
        """Check if we're within rate limits."""
        now = datetime.now()
        time_diff = (now - self.last_issue_time).total_seconds()

        if (
            len(self.issues_created) >= self.rate_limit
            and time_diff < self.rate_period
        ):
            logger.warning("Rate limit exceeded for issue creation")
            return False

        # Clean up old issues from tracking
        self.issues_created = [
            time
            for time in self.issues_created
            if (now - time).total_seconds() < self.rate_period
        ]

        return True

    def report_swarms_issue(
        self,
        error: Exception,
        agent: Optional[Agent] = None,
        context: Optional[Dict[str, Any]] = None,
        priority: str = "normal",
    ) -> Optional[int]:
        """
        Report a Swarms-specific issue to GitHub.

        Args:
            error (Exception): The exception to report
            agent (Optional[Agent]): The Swarms agent instance that encountered the error
            context (Optional[Dict[str, Any]]): Additional context about the error
            priority (str): Issue priority ("low", "normal", "high", "critical")

        Returns:
            Optional[int]: Issue number if created successfully
        """
        try:
            if not self._check_rate_limit():
                logger.warning(
                    "Skipping issue creation due to rate limit"
                )
                return None

            # Collect system information
            system_info = self._get_system_info()

            # Prepare context with agent information if available
            full_context = context or {}
            if agent:
                full_context.update(
                    {
                        "agent_name": agent.agent_name,
                        "agent_description": agent.agent_description,
                        "max_loops": agent.max_loops,
                        "context_length": agent.context_length,
                    }
                )

            # Create the issue title
            title = f"[{type(error).__name__}] {str(error)[:100]}"
            if agent:
                title = f"[Agent: {agent.agent_name}] {title}"

            # Get appropriate labels
            labels = self._categorize_error(error, full_context)
            labels.append(f"priority:{priority}")

            # Create the issue
            url = f"https://api.github.com/repos/{self.REPO_OWNER}/{self.REPO_NAME}/issues"
            data = {
                "title": title,
                "body": self._format_swarms_issue_body(
                    error, system_info, full_context
                ),
                "labels": labels,
            }

            response = requests.post(
                url,
                headers={
                    "Authorization": f"token {self.github_token}"
                },
                json=data,
            )
            response.raise_for_status()

            issue_number = response.json()["number"]
            logger.info(
                f"Successfully created Swarms issue #{issue_number}"
            )

            return issue_number

        except Exception as e:
            logger.error(f"Error creating Swarms issue: {str(e)}")
            return None


# from swarms import Agent
# from swarm_models import OpenAIChat
# from swarms.utils.issue_reporter import SwarmsIssueReporter
# import os

# Set up the reporter with your GitHub token
reporter = SwarmsIssueReporter(
    github_token=os.getenv("GITHUB_API_KEY")
)


# Force an error to test the reporter
try:
    # Create an agent that might have issues
    model = OpenAIChat(model_name="gpt-4o")
    agent = Agent(agent_name="Test-Agent", max_loops=1)

    # This will raise an error since the input isn't valid
    result = agent.run(None)

    raise ValueError("test")
except Exception as e:
    # Report the issue
    issue_number = reporter.report_swarms_issue(
        error=e,
        agent=agent,
        context={"task": "test_run"},
        priority="high",
    )
    print(f"Created issue number: {issue_number}")
@@ -0,0 +1,189 @@
import datetime

import requests
from typing import List, Dict, Tuple
from loguru import logger
from swarms import Agent
from swarm_models import OpenAIChat

# GitHub API configuration
GITHUB_REPO = "kyegomez/swarms"  # Swarms GitHub repository
GITHUB_API_URL = f"https://api.github.com/repos/{GITHUB_REPO}/commits"

# Initialize Loguru
logger.add(
    "commit_summary.log",
    rotation="1 MB",
    level="INFO",
    backtrace=True,
    diagnose=True,
)


# Step 1: Fetch the latest commits from GitHub
def fetch_latest_commits(
    repo_url: str, limit: int = 5
) -> List[Dict[str, str]]:
    """
    Fetch the latest commits from a public GitHub repository.
    """
    logger.info(
        f"Fetching the latest {limit} commits from {repo_url}"
    )
    try:
        params = {"per_page": limit}
        response = requests.get(repo_url, params=params)
        response.raise_for_status()

        commits = response.json()
        commit_data = []

        for commit in commits:
            commit_data.append(
                {
                    "sha": commit["sha"][:7],  # Short commit hash
                    "author": commit["commit"]["author"]["name"],
                    "message": commit["commit"]["message"],
                    "date": commit["commit"]["author"]["date"],
                }
            )

        logger.success("Successfully fetched commit data")
        return commit_data

    except Exception as e:
        logger.error(f"Error fetching commits: {e}")
        raise


# Step 2: Format commits and fetch the current time
def format_commits_with_time(
    commits: List[Dict[str, str]]
) -> Tuple[str, str]:
    """
    Format commit data into a readable string and return the current time.
    """
    current_time = datetime.datetime.now().strftime(
        "%Y-%m-%d %H:%M:%S"
    )
    logger.info(f"Formatting commits at {current_time}")

    commit_summary = "\n".join(
        [
            f"- `{commit['sha']}` by {commit['author']} on {commit['date']}: {commit['message']}"
            for commit in commits
        ]
    )

    logger.success("Commits formatted successfully")
    return current_time, commit_summary


# Step 3: Build a dynamic system prompt
def build_custom_system_prompt(
    current_time: str, commit_summary: str
) -> str:
    """
    Build a dynamic system prompt with the current time and commit summary.
    """
    logger.info("Building the custom system prompt for the agent")
    prompt = f"""
You are a software analyst tasked with summarizing the latest commits from the Swarms GitHub repository.

The current time is **{current_time}**.

Here are the latest commits:
{commit_summary}

**Your task**:
1. Summarize the changes into a clear and concise table in **markdown format**.
2. Highlight the key improvements and fixes.
3. End your output with the token `<DONE>`.

Make sure the table includes the following columns: Commit SHA, Author, Date, and Commit Message.
"""
    logger.success("System prompt created successfully")
    return prompt


# Step 4: Initialize the Agent
def initialize_agent() -> Agent:
    """
    Initialize the Swarms agent with the OpenAI model.
    """
    logger.info("Initializing the agent with GPT-4o")
    model = OpenAIChat(model_name="gpt-4o")

    agent = Agent(
        agent_name="Commit-Summarization-Agent",
        agent_description="Fetch and summarize GitHub commits for the Swarms repository.",
        system_prompt="",  # Will be set dynamically
        max_loops=1,
        llm=model,
        dynamic_temperature_enabled=True,
        user_name="Kye",
        retry_attempts=3,
        context_length=8192,
        return_step_meta=False,
        output_type="str",
        auto_generate_prompt=False,
        max_tokens=4000,
        stopping_token="<DONE>",
        interactive=False,
    )
    logger.success("Agent initialized successfully")
    return agent


# Step 5: Run the Agent with Data
def summarize_commits_with_agent(agent: Agent, prompt: str) -> str:
    """
    Pass the system prompt to the agent and fetch the result.
    """
    logger.info("Sending data to the agent for summarization")
    try:
        result = agent.run(
            prompt,
            all_cores=True,
        )
        logger.success("Agent completed the summarization task")
        return result
    except Exception as e:
        logger.error(f"Agent encountered an error: {e}")
        raise


# Main Execution
if __name__ == "__main__":
    try:
        logger.info("Starting commit summarization process")

        # Fetch the latest commits
        latest_commits = fetch_latest_commits(GITHUB_API_URL, limit=5)

        # Format commits and get the current time
        current_time, commit_summary = format_commits_with_time(
            latest_commits
        )

        # Build the custom system prompt
        custom_system_prompt = build_custom_system_prompt(
            current_time, commit_summary
        )

        # Initialize the agent
        agent = initialize_agent()

        # Set the dynamic system prompt
        agent.system_prompt = custom_system_prompt

        # Run the agent and summarize the commits
        result = summarize_commits_with_agent(
            agent, custom_system_prompt
        )

        # Print the result
        print("### Commit Summary in Markdown:")
        print(result)

    except Exception as e:
        logger.critical(f"Process failed: {e}")
@@ -0,0 +1,44 @@
from swarms import Agent
from swarms.prompts.finance_agent_sys_prompt import (
    FINANCIAL_AGENT_SYS_PROMPT,
)
from swarm_models import OpenAIChat

model = OpenAIChat(model_name="gpt-4o")


# Initialize the agent
agent = Agent(
    agent_name="Financial-Analysis-Agent",
    agent_description="Personal finance advisor agent",
    system_prompt=FINANCIAL_AGENT_SYS_PROMPT
    + " Output the <DONE> token when you're done creating a portfolio of ETFs, index funds, and more for AI.",
    max_loops=1,
    llm=model,
    dynamic_temperature_enabled=True,
    user_name="Kye",
    retry_attempts=3,
    # streaming_on=True,
    context_length=8192,
    return_step_meta=False,
    output_type="str",  # Supported types include "str", "json", "dict", "csv", and "yaml"
    auto_generate_prompt=False,  # Auto-generate the prompt from the agent's name, description, system prompt, and task
    max_tokens=4000,  # Maximum output tokens
    # interactive=True,
    stopping_token="<DONE>",
    saved_state_path="agent_00.json",
    interactive=False,
)


async def run_agent():
    await agent.arun(
        "Create a table of super high growth opportunities for AI. I have $40k to invest in ETFs, index funds, and more. Please create a table in markdown.",
        all_cores=True,
    )


if __name__ == "__main__":
    import asyncio

    asyncio.run(run_agent())
@@ -0,0 +1,14 @@
from swarms.prompts.finance_agent_sys_prompt import (
    FINANCIAL_AGENT_SYS_PROMPT,
)
from swarms.agents.openai_assistant import OpenAIAssistant

agent = OpenAIAssistant(
    name="test", instructions=FINANCIAL_AGENT_SYS_PROMPT
)

print(
    agent.run(
        "Create a table of super high growth opportunities for AI. I have $40k to invest in ETFs, index funds, and more. Please create a table in markdown.",
    )
)
File diff suppressed because it is too large
@@ -0,0 +1,419 @@
import json
import logging
import time
import uuid
from datetime import datetime
from typing import Any, Dict, List, Optional

import yaml
from pydantic import BaseModel, Field
from swarm_models.tiktoken_wrapper import TikTokenizer

logger = logging.getLogger(__name__)


class MemoryMetadata(BaseModel):
    """Metadata for memory entries"""

    # default_factory gives each entry a fresh timestamp and message id;
    # a plain default would be evaluated once at class-definition time
    # and shared by every entry
    timestamp: Optional[float] = Field(default_factory=time.time)
    role: Optional[str] = None
    agent_name: Optional[str] = None
    session_id: Optional[str] = None
    memory_type: Optional[str] = None  # 'short_term' or 'long_term'
    token_count: Optional[int] = None
    message_id: Optional[str] = Field(
        default_factory=lambda: str(uuid.uuid4())
    )


class MemoryEntry(BaseModel):
    """Single memory entry with content and metadata"""

    content: Optional[str] = None
    metadata: Optional[MemoryMetadata] = None


class MemoryConfig(BaseModel):
    """Configuration for the memory manager"""

    max_short_term_tokens: Optional[int] = 4096
    max_entries: Optional[int] = None
    system_messages_token_buffer: Optional[int] = 1000
    enable_long_term_memory: Optional[bool] = False
    auto_archive: Optional[bool] = True
    archive_threshold: Optional[float] = 0.8  # Archive when 80% full


class MemoryManager:
    """
    Manages both short-term and long-term memory for an agent, handling token limits,
    archival, and context retrieval.

    Args:
        config (MemoryConfig): Configuration for memory management
        tokenizer (Optional[Any]): Tokenizer to use for token counting
        long_term_memory (Optional[Any]): Vector store or database for long-term storage
    """

    def __init__(
        self,
        config: MemoryConfig,
        tokenizer: Optional[Any] = None,
        long_term_memory: Optional[Any] = None,
    ):
        self.config = config
        self.tokenizer = tokenizer or TikTokenizer()
        self.long_term_memory = long_term_memory

        # Initialize memories
        self.short_term_memory: List[MemoryEntry] = []
        self.system_messages: List[MemoryEntry] = []

        # Memory statistics
        self.total_tokens_processed: int = 0
        self.archived_entries_count: int = 0

    def create_memory_entry(
        self,
        content: str,
        role: str,
        agent_name: str,
        session_id: str,
        memory_type: str = "short_term",
    ) -> MemoryEntry:
        """Create a new memory entry with metadata"""
        metadata = MemoryMetadata(
            timestamp=time.time(),
            role=role,
            agent_name=agent_name,
            session_id=session_id,
            memory_type=memory_type,
            token_count=self.tokenizer.count_tokens(content),
        )
        return MemoryEntry(content=content, metadata=metadata)

    def add_memory(
        self,
        content: str,
        role: str,
        agent_name: str,
        session_id: str,
        is_system: bool = False,
    ) -> None:
        """Add a new memory entry to the appropriate storage"""
        entry = self.create_memory_entry(
            content=content,
            role=role,
            agent_name=agent_name,
            session_id=session_id,
            memory_type="system" if is_system else "short_term",
        )

        if is_system:
            self.system_messages.append(entry)
        else:
            self.short_term_memory.append(entry)

        # Check if archiving is needed
        if self.should_archive():
            self.archive_old_memories()

        self.total_tokens_processed += entry.metadata.token_count

    def get_current_token_count(self) -> int:
        """Get the total tokens in short-term memory"""
        return sum(
            entry.metadata.token_count
            for entry in self.short_term_memory
        )

    def get_system_messages_token_count(self) -> int:
        """Get the total tokens in system messages"""
        return sum(
            entry.metadata.token_count
            for entry in self.system_messages
        )

    def should_archive(self) -> bool:
        """Check if archiving is needed based on configuration"""
        if not self.config.auto_archive:
            return False

        current_usage = (
            self.get_current_token_count()
            / self.config.max_short_term_tokens
        )
        return current_usage >= self.config.archive_threshold

    def archive_old_memories(self) -> None:
        """Move older memories to long-term storage"""
        if not self.long_term_memory:
            logger.warning(
                "No long-term memory storage configured for archiving"
            )
            return

        while self.should_archive():
            # Get the oldest non-system message
            if not self.short_term_memory:
                break

            oldest_entry = self.short_term_memory.pop(0)

            # Store in long-term memory
            self.store_in_long_term_memory(oldest_entry)
            self.archived_entries_count += 1

    def store_in_long_term_memory(self, entry: MemoryEntry) -> None:
        """Store a memory entry in long-term memory"""
        if self.long_term_memory is None:
            logger.warning(
                "Attempted to store in non-existent long-term memory"
            )
            return

        try:
            self.long_term_memory.add(str(entry.model_dump()))
        except Exception as e:
            logger.error(f"Error storing in long-term memory: {e}")
            # Re-add to short-term memory if storage fails
            self.short_term_memory.insert(0, entry)

    def get_relevant_context(
        self, query: str, max_tokens: Optional[int] = None
    ) -> str:
        """
        Get relevant context from both memory types

        Args:
            query (str): Query to match against memories
            max_tokens (Optional[int]): Maximum tokens to return

        Returns:
            str: Combined relevant context
        """
        contexts = []

        # Add system messages first
        for entry in self.system_messages:
            contexts.append(entry.content)

        # Add short-term memory, most recent first
        for entry in reversed(self.short_term_memory):
            contexts.append(entry.content)

        # Query long-term memory if available
        if self.long_term_memory is not None:
            long_term_context = self.long_term_memory.query(query)
            if long_term_context:
                contexts.append(str(long_term_context))

        # Combine and truncate if needed
        combined = "\n".join(contexts)
        if max_tokens:
            combined = self.truncate_to_token_limit(
                combined, max_tokens
            )

        return combined

    def truncate_to_token_limit(
        self, text: str, max_tokens: int
    ) -> str:
        """Truncate text to fit within a token limit"""
        current_tokens = self.tokenizer.count_tokens(text)

        if current_tokens <= max_tokens:
            return text

        # Truncate by splitting into sentences and rebuilding
        sentences = text.split(". ")
        result = []
        current_count = 0

        for sentence in sentences:
            sentence_tokens = self.tokenizer.count_tokens(sentence)
            if current_count + sentence_tokens <= max_tokens:
                result.append(sentence)
                current_count += sentence_tokens
            else:
                break

        return ". ".join(result)

    def clear_short_term_memory(
        self, preserve_system: bool = True
    ) -> None:
        """Clear short-term memory, with the option to preserve system messages"""
        if not preserve_system:
            self.system_messages.clear()
        self.short_term_memory.clear()
        # Parentheses around the conditional matter here: without them,
        # the whole message would be swallowed by the conditional and an
        # empty string would be logged when preserve_system is False
        logger.info(
            "Cleared short-term memory"
            + (
                " (preserved system messages)"
                if preserve_system
                else ""
            )
        )

    def get_memory_stats(self) -> Dict[str, Any]:
        """Get detailed memory statistics"""
        return {
            "short_term_messages": len(self.short_term_memory),
            "system_messages": len(self.system_messages),
            "current_tokens": self.get_current_token_count(),
            "system_tokens": self.get_system_messages_token_count(),
            "max_tokens": self.config.max_short_term_tokens,
            "token_usage_percent": round(
                (
                    self.get_current_token_count()
                    / self.config.max_short_term_tokens
                )
                * 100,
                2,
            ),
            "has_long_term_memory": self.long_term_memory is not None,
            "archived_entries": self.archived_entries_count,
            "total_tokens_processed": self.total_tokens_processed,
        }

    def save_memory_snapshot(self, file_path: str) -> None:
        """Save the current memory state to a file"""
        try:
            data = {
                "timestamp": datetime.now().isoformat(),
                "config": self.config.model_dump(),
                "system_messages": [
                    entry.model_dump()
                    for entry in self.system_messages
                ],
                "short_term_memory": [
                    entry.model_dump()
                    for entry in self.short_term_memory
                ],
                "stats": self.get_memory_stats(),
            }

            with open(file_path, "w") as f:
                if file_path.endswith(".yaml"):
                    yaml.dump(data, f)
                else:
                    json.dump(data, f, indent=2)

            logger.info(f"Saved memory snapshot to {file_path}")

        except Exception as e:
            logger.error(f"Error saving memory snapshot: {e}")
            raise

    def load_memory_snapshot(self, file_path: str) -> None:
        """Load a memory state from a file"""
        try:
            with open(file_path, "r") as f:
                if file_path.endswith(".yaml"):
                    data = yaml.safe_load(f)
                else:
                    data = json.load(f)

            self.config = MemoryConfig(**data["config"])
            self.system_messages = [
                MemoryEntry(**entry)
                for entry in data["system_messages"]
            ]
            self.short_term_memory = [
                MemoryEntry(**entry)
                for entry in data["short_term_memory"]
            ]

            logger.info(f"Loaded memory snapshot from {file_path}")

        except Exception as e:
            logger.error(f"Error loading memory snapshot: {e}")
            raise

    def search_memories(
        self, query: str, memory_type: str = "all"
    ) -> List[MemoryEntry]:
        """
        Search through memories of the specified type

        Args:
            query (str): Search query
            memory_type (str): Type of memories to search ("short_term", "system", "long_term", or "all")

        Returns:
            List[MemoryEntry]: Matching memory entries
        """
        results = []

        if memory_type in ["short_term", "all"]:
            results.extend(
                [
                    entry
                    for entry in self.short_term_memory
                    if query.lower() in entry.content.lower()
                ]
            )

        if memory_type in ["system", "all"]:
            results.extend(
                [
                    entry
                    for entry in self.system_messages
                    if query.lower() in entry.content.lower()
                ]
            )

        if (
            memory_type in ["long_term", "all"]
            and self.long_term_memory is not None
        ):
            long_term_results = self.long_term_memory.query(query)
            if long_term_results:
                # Convert long-term results to MemoryEntry format
                for result in long_term_results:
                    content = str(result)
                    metadata = MemoryMetadata(
                        timestamp=time.time(),
                        role="long_term",
                        agent_name="system",
                        session_id="long_term",
                        memory_type="long_term",
                        token_count=self.tokenizer.count_tokens(
                            content
                        ),
                    )
                    results.append(
                        MemoryEntry(
                            content=content, metadata=metadata
                        )
                    )

        return results

    def get_memory_by_timeframe(
        self, start_time: float, end_time: float
    ) -> List[MemoryEntry]:
        """Get memories within a specific timeframe"""
        return [
            entry
            for entry in self.short_term_memory
            if start_time <= entry.metadata.timestamp <= end_time
        ]

    def export_memories(
        self, file_path: str, format: str = "json"
    ) -> None:
        """Export memories to a file in the specified format"""
        data = {
            "system_messages": [
                entry.model_dump() for entry in self.system_messages
            ],
            "short_term_memory": [
                entry.model_dump() for entry in self.short_term_memory
            ],
            "stats": self.get_memory_stats(),
        }

        with open(file_path, "w") as f:
            if format == "yaml":
                yaml.dump(data, f)
            else:
                json.dump(data, f, indent=2)
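A minimal usage sketch for the MemoryManager above. The ListStore class is a hypothetical stand-in that only mirrors the add() and query() calls MemoryManager makes on its long-term store; the contents are illustrative.

class ListStore:
    """Hypothetical in-memory stand-in for a vector store."""

    def __init__(self):
        self.items = []

    def add(self, item: str) -> None:
        self.items.append(item)

    def query(self, query: str):
        return [i for i in self.items if query.lower() in i.lower()]


manager = MemoryManager(
    config=MemoryConfig(max_short_term_tokens=512),
    long_term_memory=ListStore(),
)
manager.add_memory(
    content="You are a helpful assistant.",
    role="system",
    agent_name="demo",
    session_id="s1",
    is_system=True,  # stored with the system messages
)
manager.add_memory(
    content="The user prefers markdown tables.",
    role="user",
    agent_name="demo",
    session_id="s1",
)
# System messages come first, then short-term entries, newest first
print(manager.get_relevant_context("markdown", max_tokens=256))
print(manager.get_memory_stats())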
@@ -0,0 +1,258 @@
import inspect
import json
import logging
import os
from datetime import datetime
from typing import Any, Dict, Set
from uuid import UUID

logger = logging.getLogger(__name__)


class SafeLoaderUtils:
    """
    Utility class for safely loading and saving object states while automatically
    detecting and preserving class instances and complex objects.
    """

    @staticmethod
    def is_class_instance(obj: Any) -> bool:
        """
        Detect if an object is a class instance (excluding built-in types).

        Args:
            obj: Object to check

        Returns:
            bool: True if the object is a class instance
        """
        if obj is None:
            return False

        # Get the type of the object
        obj_type = type(obj)

        # Check if it's a class instance but not a built-in type
        return (
            hasattr(obj, "__dict__")
            and not isinstance(obj, type)
            and obj_type.__module__ != "builtins"
        )

    @staticmethod
    def is_safe_type(value: Any) -> bool:
        """
        Check if a value is of a safe, serializable type.

        Args:
            value: Value to check

        Returns:
            bool: True if the value is safe to serialize
        """
        # Basic safe types
        safe_types = (
            type(None),
            bool,
            int,
            float,
            str,
            datetime,
            UUID,
        )

        if isinstance(value, safe_types):
            return True

        # Check containers
        if isinstance(value, (list, tuple)):
            return all(
                SafeLoaderUtils.is_safe_type(item) for item in value
            )

        if isinstance(value, dict):
            return all(
                isinstance(k, str) and SafeLoaderUtils.is_safe_type(v)
                for k, v in value.items()
            )

        # Fall back to a JSON round-trip check for other serializable types
        try:
            json.dumps(value)
            return True
        except (TypeError, OverflowError, ValueError):
            return False

    @staticmethod
    def get_class_attributes(obj: Any) -> Set[str]:
        """
        Get all attributes of a class, including inherited ones.

        Args:
            obj: Object to inspect

        Returns:
            Set[str]: Set of attribute names
        """
        attributes = set()

        # Get all attributes from the class and its parent classes
        for cls in inspect.getmro(type(obj)):
            attributes.update(cls.__dict__.keys())

        # Add instance attributes
        attributes.update(obj.__dict__.keys())

        return attributes

    @staticmethod
    def create_state_dict(obj: Any) -> Dict[str, Any]:
        """
        Create a dictionary of safe values from an object's state.

        Args:
            obj: Object to create the state dict from

        Returns:
            Dict[str, Any]: Dictionary of safe values
        """
        state_dict = {}

        for attr_name in SafeLoaderUtils.get_class_attributes(obj):
            # Skip private attributes
            if attr_name.startswith("_"):
                continue

            try:
                value = getattr(obj, attr_name, None)
                if SafeLoaderUtils.is_safe_type(value):
                    state_dict[attr_name] = value
            except Exception as e:
                logger.debug(f"Skipped attribute {attr_name}: {e}")

        return state_dict

    @staticmethod
    def preserve_instances(obj: Any) -> Dict[str, Any]:
        """
        Automatically detect and preserve all class instances in an object.

        Args:
            obj: Object to preserve instances from

        Returns:
            Dict[str, Any]: Dictionary of preserved instances
        """
        preserved = {}

        for attr_name in SafeLoaderUtils.get_class_attributes(obj):
            if attr_name.startswith("_"):
                continue

            try:
                value = getattr(obj, attr_name, None)
                if SafeLoaderUtils.is_class_instance(value):
                    preserved[attr_name] = value
            except Exception as e:
                logger.debug(f"Could not preserve {attr_name}: {e}")

        return preserved


class SafeStateManager:
    """
    Manages saving and loading object states while automatically handling
    class instances and complex objects.
    """

    @staticmethod
    def save_state(obj: Any, file_path: str) -> None:
        """
        Save an object's state to a file, automatically handling complex objects.

        Args:
            obj: Object to save state from
            file_path: Path to save state to
        """
        try:
            # Create a state dict with only safe values
            state_dict = SafeLoaderUtils.create_state_dict(obj)

            # Ensure the directory exists; fall back to the current
            # directory when the path has no directory component
            os.makedirs(
                os.path.dirname(file_path) or ".", exist_ok=True
            )

            # Save to file
            with open(file_path, "w") as f:
                json.dump(state_dict, f, indent=4, default=str)

            logger.info(f"Successfully saved state to: {file_path}")

        except Exception as e:
            logger.error(f"Error saving state: {e}")
            raise

    @staticmethod
    def load_state(obj: Any, file_path: str) -> None:
        """
        Load state into an object while preserving class instances.

        Args:
            obj: Object to load state into
            file_path: Path to load state from
        """
        try:
            # Verify the file exists
            if not os.path.exists(file_path):
                raise FileNotFoundError(
                    f"State file not found: {file_path}"
                )

            # Preserve existing instances
            preserved = SafeLoaderUtils.preserve_instances(obj)

            # Load state
            with open(file_path, "r") as f:
                state_dict = json.load(f)

            # Set safe values
            for key, value in state_dict.items():
                if (
                    not key.startswith("_")
                    and key not in preserved
                    and SafeLoaderUtils.is_safe_type(value)
                ):
                    setattr(obj, key, value)

            # Restore preserved instances
            for key, value in preserved.items():
                setattr(obj, key, value)

            logger.info(
                f"Successfully loaded state from: {file_path}"
            )

        except Exception as e:
            logger.error(f"Error loading state: {e}")
            raise


# # Example decorator for easy integration
# def safe_state_methods(cls: Type) -> Type:
#     """
#     Class decorator to add safe state loading/saving methods to a class.

#     Args:
#         cls: Class to decorate

#     Returns:
#         Type: Decorated class
#     """
#     def save(self, file_path: str) -> None:
#         SafeStateManager.save_state(self, file_path)

#     def load(self, file_path: str) -> None:
#         SafeStateManager.load_state(self, file_path)

#     cls.save = save
#     cls.load = load
#     return cls
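A minimal usage sketch for SafeStateManager. The Bot and Tokenizer classes and the file path are hypothetical, chosen only to show which attributes round-trip through JSON and which are preserved in memory.

class Tokenizer:
    """Hypothetical stand-in for a non-serializable dependency."""


class Bot:
    def __init__(self):
        self.name = "demo-bot"        # safe value, written to JSON
        self.max_loops = 3            # safe value, written to JSON
        self.tokenizer = Tokenizer()  # class instance, kept in memory


bot = Bot()
SafeStateManager.save_state(bot, "states/bot.json")

restored = Bot()
restored.max_loops = 99
SafeStateManager.load_state(restored, "states/bot.json")
assert restored.max_loops == 3  # safe value reloaded from disk
assert isinstance(restored.tokenizer, Tokenizer)  # instance preserved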
@ -0,0 +1,598 @@
|
|||||||
|
import asyncio
|
||||||
|
from swarms import Agent
|
||||||
|
from swarm_models import OpenAIChat
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
import json
|
||||||
|
import yaml
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
|
||||||
|
def test_basic_agent_functionality():
|
||||||
|
"""Test basic agent initialization and simple task execution"""
|
||||||
|
print("\nTesting basic agent functionality...")
|
||||||
|
|
||||||
|
model = OpenAIChat(model_name="gpt-4o")
|
||||||
|
agent = Agent(agent_name="Test-Agent", llm=model, max_loops=1)
|
||||||
|
|
||||||
|
response = agent.run("What is 2+2?")
|
||||||
|
assert response is not None, "Agent response should not be None"
|
||||||
|
|
||||||
|
# Test agent properties
|
||||||
|
assert (
|
||||||
|
agent.agent_name == "Test-Agent"
|
||||||
|
), "Agent name not set correctly"
|
||||||
|
assert agent.max_loops == 1, "Max loops not set correctly"
|
||||||
|
assert agent.llm is not None, "LLM not initialized"
|
||||||
|
|
||||||
|
print("✓ Basic agent functionality test passed")
|
||||||
|
|
||||||
|
|
||||||
|
def test_memory_management():
|
||||||
|
"""Test agent memory management functionality"""
|
||||||
|
print("\nTesting memory management...")
|
||||||
|
|
||||||
|
model = OpenAIChat(model_name="gpt-4o")
|
||||||
|
agent = Agent(
|
||||||
|
agent_name="Memory-Test-Agent",
|
||||||
|
llm=model,
|
||||||
|
max_loops=1,
|
||||||
|
context_length=8192,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test adding to memory
|
||||||
|
agent.add_memory("Test memory entry")
|
||||||
|
assert (
|
||||||
|
"Test memory entry"
|
||||||
|
in agent.short_memory.return_history_as_string()
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test memory query
|
||||||
|
agent.memory_query("Test query")
|
||||||
|
|
||||||
|
# Test token counting
|
||||||
|
tokens = agent.check_available_tokens()
|
||||||
|
assert isinstance(tokens, int), "Token count should be an integer"
|
||||||
|
|
||||||
|
print("✓ Memory management test passed")
|
||||||
|
|
||||||
|
|
||||||
|
def test_agent_output_formats():
|
||||||
|
"""Test all available output formats"""
|
||||||
|
print("\nTesting all output formats...")
|
||||||
|
|
||||||
|
model = OpenAIChat(model_name="gpt-4o")
|
||||||
|
test_task = "Say hello!"
|
||||||
|
|
||||||
|
output_types = {
|
||||||
|
"str": str,
|
||||||
|
"string": str,
|
||||||
|
"list": str, # JSON string containing list
|
||||||
|
"json": str, # JSON string
|
||||||
|
"dict": dict,
|
||||||
|
"yaml": str,
|
||||||
|
}
|
||||||
|
|
||||||
|
for output_type, expected_type in output_types.items():
|
||||||
|
agent = Agent(
|
||||||
|
agent_name=f"{output_type.capitalize()}-Output-Agent",
|
||||||
|
llm=model,
|
||||||
|
max_loops=1,
|
||||||
|
output_type=output_type,
|
||||||
|
)
|
||||||
|
|
||||||
|
response = agent.run(test_task)
|
||||||
|
assert (
|
||||||
|
response is not None
|
||||||
|
), f"{output_type} output should not be None"
|
||||||
|
|
||||||
|
if output_type == "yaml":
|
||||||
|
# Verify YAML can be parsed
|
||||||
|
try:
|
||||||
|
yaml.safe_load(response)
|
||||||
|
print(f"✓ {output_type} output valid")
|
||||||
|
except yaml.YAMLError:
|
||||||
|
assert False, f"Invalid YAML output for {output_type}"
|
||||||
|
elif output_type in ["json", "list"]:
|
||||||
|
# Verify JSON can be parsed
|
||||||
|
try:
|
||||||
|
json.loads(response)
|
||||||
|
print(f"✓ {output_type} output valid")
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
assert False, f"Invalid JSON output for {output_type}"
|
||||||
|
|
||||||
|
print("✓ Output formats test passed")
|
||||||
|
|
||||||
|
|
||||||
|
def test_agent_state_management():
|
||||||
|
"""Test comprehensive state management functionality"""
|
||||||
|
print("\nTesting state management...")
|
||||||
|
|
||||||
|
model = OpenAIChat(model_name="gpt-4o")
|
||||||
|
|
||||||
|
# Create temporary directory for test files
|
||||||
|
with tempfile.TemporaryDirectory() as temp_dir:
|
||||||
|
state_path = os.path.join(temp_dir, "agent_state.json")
|
||||||
|
|
||||||
|
# Create agent with initial state
|
||||||
|
agent1 = Agent(
|
||||||
|
agent_name="State-Test-Agent",
|
||||||
|
llm=model,
|
||||||
|
max_loops=1,
|
||||||
|
saved_state_path=state_path,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add some data to the agent
|
||||||
|
agent1.run("Remember this: Test message 1")
|
||||||
|
agent1.add_memory("Test message 2")
|
||||||
|
|
||||||
|
# Save state
|
||||||
|
agent1.save()
|
||||||
|
assert os.path.exists(state_path), "State file not created"
|
||||||
|
|
||||||
|
# Create new agent and load state
|
||||||
|
agent2 = Agent(
|
||||||
|
agent_name="State-Test-Agent", llm=model, max_loops=1
|
||||||
|
)
|
||||||
|
agent2.load(state_path)
|
||||||
|
|
||||||
|
# Verify state loaded correctly
|
||||||
|
history2 = agent2.short_memory.return_history_as_string()
|
||||||
|
assert (
|
||||||
|
"Test message 1" in history2
|
||||||
|
), "State not loaded correctly"
|
||||||
|
assert (
|
||||||
|
"Test message 2" in history2
|
||||||
|
), "Memory not loaded correctly"
|
||||||
|
|
||||||
|
# Test autosave functionality
|
||||||
|
agent3 = Agent(
|
||||||
|
agent_name="Autosave-Test-Agent",
|
||||||
|
llm=model,
|
||||||
|
max_loops=1,
|
||||||
|
saved_state_path=os.path.join(
|
||||||
|
temp_dir, "autosave_state.json"
|
||||||
|
),
|
||||||
|
autosave=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
agent3.run("Test autosave")
|
||||||
|
time.sleep(2) # Wait for autosave
|
||||||
|
assert os.path.exists(
|
||||||
|
os.path.join(temp_dir, "autosave_state.json")
|
||||||
|
), "Autosave file not created"
|
||||||
|
|
||||||
|
print("✓ State management test passed")
|
||||||
|
|
||||||
|
|
||||||
|
def test_agent_tools_and_execution():
|
||||||
|
"""Test agent tool handling and execution"""
|
||||||
|
print("\nTesting tools and execution...")
|
||||||
|
|
||||||
|
def sample_tool(x: int, y: int) -> int:
|
||||||
|
"""Sample tool that adds two numbers"""
|
||||||
|
return x + y
|
||||||
|
|
||||||
|
model = OpenAIChat(model_name="gpt-4o")
|
||||||
|
agent = Agent(
|
||||||
|
agent_name="Tools-Test-Agent",
|
||||||
|
llm=model,
|
||||||
|
max_loops=1,
|
||||||
|
tools=[sample_tool],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test adding tools
|
||||||
|
agent.add_tool(lambda x: x * 2)
|
||||||
|
assert len(agent.tools) == 2, "Tool not added correctly"
|
||||||
|
|
||||||
|
# Test removing tools
|
||||||
|
agent.remove_tool(sample_tool)
|
||||||
|
assert len(agent.tools) == 1, "Tool not removed correctly"
|
||||||
|
|
||||||
|
# Test tool execution
|
||||||
|
response = agent.run("Calculate 2 + 2 using the sample tool")
|
||||||
|
assert response is not None, "Tool execution failed"
|
||||||
|
|
||||||
|
print("✓ Tools and execution test passed")
|
||||||
|
|
||||||
|
|
||||||
|
def test_agent_concurrent_execution():
|
||||||
|
"""Test agent concurrent execution capabilities"""
|
||||||
|
print("\nTesting concurrent execution...")
|
||||||
|
|
||||||
|
model = OpenAIChat(model_name="gpt-4o")
|
||||||
|
agent = Agent(
|
||||||
|
agent_name="Concurrent-Test-Agent", llm=model, max_loops=1
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test bulk run
|
||||||
|
tasks = [
|
||||||
|
{"task": "Count to 3"},
|
||||||
|
{"task": "Say hello"},
|
||||||
|
{"task": "Tell a short joke"},
|
||||||
|
]
|
||||||
|
|
||||||
|
responses = agent.bulk_run(tasks)
|
||||||
|
assert len(responses) == len(tasks), "Not all tasks completed"
|
||||||
|
assert all(
|
||||||
|
response is not None for response in responses
|
||||||
|
), "Some tasks failed"
|
||||||
|
|
||||||
|
# Test concurrent tasks
|
||||||
|
concurrent_responses = agent.run_concurrent_tasks(
|
||||||
|
["Task 1", "Task 2", "Task 3"]
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
len(concurrent_responses) == 3
|
||||||
|
), "Not all concurrent tasks completed"
|
||||||
|
|
||||||
|
print("✓ Concurrent execution test passed")
|
||||||
|
|
||||||
|
|
||||||
|
def test_agent_error_handling():
|
||||||
|
"""Test agent error handling and recovery"""
|
||||||
|
print("\nTesting error handling...")
|
||||||
|
|
||||||
|
model = OpenAIChat(model_name="gpt-4o")
|
||||||
|
agent = Agent(
|
||||||
|
agent_name="Error-Test-Agent",
|
||||||
|
llm=model,
|
||||||
|
max_loops=1,
|
||||||
|
retry_attempts=3,
|
||||||
|
retry_interval=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test invalid tool execution
|
||||||
|
try:
|
||||||
|
agent.parse_and_execute_tools("invalid_json")
|
||||||
|
print("✓ Invalid tool execution handled")
|
||||||
|
except Exception:
|
||||||
|
assert True, "Expected error caught"
|
||||||
|
|
||||||
|
# Test recovery after error
|
||||||
|
response = agent.run("Continue after error")
|
||||||
|
assert response is not None, "Agent failed to recover after error"
|
||||||
|
|
||||||
|
print("✓ Error handling test passed")
|
||||||
|
|
||||||
|
|
||||||
|
def test_agent_configuration():
|
||||||
|
"""Test agent configuration and parameters"""
|
||||||
|
print("\nTesting agent configuration...")
|
||||||
|
|
||||||
|
model = OpenAIChat(model_name="gpt-4o")
|
||||||
|
agent = Agent(
|
||||||
|
agent_name="Config-Test-Agent",
|
||||||
|
llm=model,
|
||||||
|
max_loops=1,
|
||||||
|
temperature=0.7,
|
||||||
|
max_tokens=4000,
|
||||||
|
context_length=8192,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test configuration methods
|
||||||
|
agent.update_system_prompt("New system prompt")
|
||||||
|
agent.update_max_loops(2)
|
||||||
|
agent.update_loop_interval(2)
|
||||||
|
|
||||||
|
# Verify updates
|
||||||
|
assert agent.max_loops == 2, "Max loops not updated"
|
||||||
|
assert agent.loop_interval == 2, "Loop interval not updated"
|
||||||
|
|
||||||
|
# Test configuration export
|
||||||
|
config_dict = agent.to_dict()
|
||||||
|
assert isinstance(
|
||||||
|
config_dict, dict
|
||||||
|
), "Configuration export failed"
|
||||||
|
|
||||||
|
# Test YAML export
|
||||||
|
yaml_config = agent.to_yaml()
|
||||||
|
assert isinstance(yaml_config, str), "YAML export failed"
|
||||||
|
|
||||||
|
print("✓ Configuration test passed")
|
||||||
|
|
||||||
|
|
||||||
|
def test_agent_with_stopping_condition():
|
||||||
|
"""Test agent with custom stopping condition"""
|
||||||
|
print("\nTesting agent with stopping condition...")
|
||||||
|
|
||||||
|
def custom_stopping_condition(response: str) -> bool:
|
||||||
|
return "STOP" in response.upper()
|
||||||
|
|
||||||
|
model = OpenAIChat(model_name="gpt-4o")
|
||||||
|
agent = Agent(
|
||||||
|
agent_name="Stopping-Condition-Agent",
|
||||||
|
llm=model,
|
||||||
|
max_loops=5,
|
||||||
|
stopping_condition=custom_stopping_condition,
|
||||||
|
)
|
||||||
|
|
||||||
|
response = agent.run("Count up until you see the word STOP")
|
||||||
|
assert response is not None, "Stopping condition test failed"
|
||||||
|
print("✓ Stopping condition test passed")
|
||||||
|
|
||||||
|
|
||||||
|
def test_agent_with_retry_mechanism():
|
||||||
|
"""Test agent retry mechanism"""
|
||||||
|
print("\nTesting agent retry mechanism...")
|
||||||
|
|
||||||
|
model = OpenAIChat(model_name="gpt-4o")
|
||||||
|
agent = Agent(
|
||||||
|
agent_name="Retry-Test-Agent",
|
||||||
|
llm=model,
|
||||||
|
max_loops=1,
|
||||||
|
retry_attempts=3,
|
||||||
|
retry_interval=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
response = agent.run("Tell me a joke.")
|
||||||
|
assert response is not None, "Retry mechanism test failed"
|
||||||
|
print("✓ Retry mechanism test passed")
|
||||||
|
|
||||||
|
|
||||||
|
def test_bulk_and_filtered_operations():
|
||||||
|
"""Test bulk operations and response filtering"""
|
||||||
|
print("\nTesting bulk and filtered operations...")
|
||||||
|
|
||||||
|
model = OpenAIChat(model_name="gpt-4o")
|
||||||
|
agent = Agent(
|
||||||
|
agent_name="Bulk-Filter-Test-Agent", llm=model, max_loops=1
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test bulk run
|
||||||
|
bulk_tasks = [
|
||||||
|
{"task": "What is 2+2?"},
|
||||||
|
{"task": "Name a color"},
|
||||||
|
{"task": "Count to 3"},
|
||||||
|
]
|
||||||
|
bulk_responses = agent.bulk_run(bulk_tasks)
|
||||||
|
assert len(bulk_responses) == len(
|
||||||
|
bulk_tasks
|
||||||
|
), "Bulk run should return same number of responses as tasks"
|
||||||
|
|
||||||
|
# Test response filtering
|
||||||
|
agent.add_response_filter("color")
|
||||||
|
filtered_response = agent.filtered_run(
|
||||||
|
"What is your favorite color?"
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
"[FILTERED]" in filtered_response
|
||||||
|
), "Response filter not applied"
|
||||||
|
|
||||||
|
print("✓ Bulk and filtered operations test passed")
|
||||||
|
|
||||||
|
|
||||||
|

async def test_async_operations():
    """Test asynchronous operations"""
    print("\nTesting async operations...")

    model = OpenAIChat(model_name="gpt-4o")
    agent = Agent(
        agent_name="Async-Test-Agent", llm=model, max_loops=1
    )

    # Test single async run
    response = await agent.arun("What is 1+1?")
    assert response is not None, "Async run failed"

    # Test concurrent async runs
    tasks = ["Task 1", "Task 2", "Task 3"]
    responses = await asyncio.gather(
        *[agent.arun(task) for task in tasks]
    )
    assert len(responses) == len(tasks), "Not all async tasks completed"

    print("✓ Async operations test passed")

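
# Unbounded asyncio.gather over many tasks can overwhelm an API. A hedged
# sketch of bounding concurrency with a semaphore (a general asyncio pattern,
# not part of the Agent API):
async def gather_bounded(agent, tasks, limit: int = 3):
    """Run agent.arun over tasks with at most `limit` calls in flight at once."""
    semaphore = asyncio.Semaphore(limit)

    async def run_one(task):
        async with semaphore:
            return await agent.arun(task)

    return await asyncio.gather(*(run_one(task) for task in tasks))
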

def test_memory_and_state_persistence():
    """Test memory management and state persistence"""
    print("\nTesting memory and state persistence...")

    with tempfile.TemporaryDirectory() as temp_dir:
        state_path = os.path.join(temp_dir, "test_state.json")

        # Create agent with memory configuration
        model = OpenAIChat(model_name="gpt-4o")
        agent1 = Agent(
            agent_name="Memory-State-Test-Agent",
            llm=model,
            max_loops=1,
            saved_state_path=state_path,
            context_length=8192,
            autosave=True,
        )

        # Test memory operations
        agent1.add_memory("Important fact: The sky is blue")
        agent1.memory_query("What color is the sky?")

        # Save state
        agent1.save()

        # Create new agent and load state
        agent2 = Agent(
            agent_name="Memory-State-Test-Agent",
            llm=model,
            max_loops=1,
        )
        agent2.load(state_path)

        # Verify memory persistence
        memory_content = (
            agent2.short_memory.return_history_as_string()
        )
        assert (
            "sky is blue" in memory_content
        ), "Memory not properly persisted"

    print("✓ Memory and state persistence test passed")

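
# The save/load round trip above writes serialized state to saved_state_path.
# A minimal sanity check for such a file, assuming (as the .json filename
# suggests, though the on-disk format is not guaranteed here) that it is JSON:
def state_file_is_valid_json(path: str) -> bool:
    """Return True if the file at `path` parses as JSON."""
    import json

    try:
        with open(path) as f:
            json.load(f)
        return True
    except (OSError, ValueError):
        return False
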

def test_sentiment_and_evaluation():
    """Test sentiment analysis and response evaluation"""
    print("\nTesting sentiment analysis and evaluation...")

    def mock_sentiment_analyzer(text):
        """Mock sentiment analyzer that returns a score between 0 and 1"""
        return 0.7 if "positive" in text.lower() else 0.3

    def mock_evaluator(response):
        """Mock evaluator that checks response quality"""
        return "GOOD" if len(response) > 10 else "BAD"

    model = OpenAIChat(model_name="gpt-4o")
    agent = Agent(
        agent_name="Sentiment-Eval-Test-Agent",
        llm=model,
        max_loops=1,
        sentiment_analyzer=mock_sentiment_analyzer,
        sentiment_threshold=0.5,
        evaluator=mock_evaluator,
    )

    # Test sentiment analysis
    agent.run("Generate a positive message")

    # Test evaluation
    agent.run("Generate a detailed response")

    print("✓ Sentiment and evaluation test passed")

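
# A sketch of the gating logic a sentiment_threshold suggests: accept a
# response only if the analyzer's score clears the threshold. Illustrative
# only, not the Agent's internal control flow:
def passes_sentiment_gate(response: str, analyzer, threshold: float) -> bool:
    """Return True if analyzer(response) meets or exceeds the threshold."""
    return analyzer(response) >= threshold


# With an analyzer like mock_sentiment_analyzer above:
#     passes_sentiment_gate("a positive note", mock_sentiment_analyzer, 0.5)
#     -> True (0.7 >= 0.5)
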

def test_tool_management():
    """Test tool management functionality"""
    print("\nTesting tool management...")

    def tool1(x: int) -> int:
        """Sample tool 1"""
        return x * 2

    def tool2(x: int) -> int:
        """Sample tool 2"""
        return x + 2

    model = OpenAIChat(model_name="gpt-4o")
    agent = Agent(
        agent_name="Tool-Test-Agent",
        llm=model,
        max_loops=1,
        tools=[tool1],
    )

    # Test adding tools
    agent.add_tool(tool2)
    assert len(agent.tools) == 2, "Tool not added correctly"

    # Test removing tools
    agent.remove_tool(tool1)
    assert len(agent.tools) == 1, "Tool not removed correctly"

    # Test adding multiple tools
    agent.add_tools([tool1, tool2])
    assert len(agent.tools) == 3, "Multiple tools not added correctly"

    print("✓ Tool management test passed")

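
# A sketch of a name-keyed view over the tool list, useful for membership
# checks instead of the length-based asserts above. Note that the final
# agent.tools above holds tool2 twice, which a name-keyed dict would collapse.
# Illustrative helper, not part of the Agent API:
def tools_by_name(tools: list) -> dict:
    """Map each tool callable's __name__ to the callable itself."""
    return {tool.__name__: tool for tool in tools}


# Usage sketch: assert "tool2" in tools_by_name(agent.tools)
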

def test_system_prompt_and_configuration():
    """Test system prompt and configuration updates"""
    print("\nTesting system prompt and configuration...")

    model = OpenAIChat(model_name="gpt-4o")
    agent = Agent(
        agent_name="Config-Test-Agent", llm=model, max_loops=1
    )

    # Test updating system prompt
    new_prompt = "You are a helpful assistant."
    agent.update_system_prompt(new_prompt)
    assert agent.system_prompt == new_prompt, "System prompt not updated"

    # Test configuration updates
    agent.update_max_loops(5)
    assert agent.max_loops == 5, "Max loops not updated"

    agent.update_loop_interval(2)
    assert agent.loop_interval == 2, "Loop interval not updated"

    # Test configuration export
    config_dict = agent.to_dict()
    assert isinstance(config_dict, dict), "Configuration export failed"

    print("✓ System prompt and configuration test passed")


def test_agent_with_dynamic_temperature():
    """Test agent with dynamic temperature"""
    print("\nTesting agent with dynamic temperature...")

    model = OpenAIChat(model_name="gpt-4o")
    agent = Agent(
        agent_name="Dynamic-Temp-Agent",
        llm=model,
        max_loops=2,
        dynamic_temperature_enabled=True,
    )

    response = agent.run("Generate a creative story.")
    assert response is not None, "Dynamic temperature test failed"
    print("✓ Dynamic temperature test passed")


def run_all_tests():
    """Run all test functions"""
    print("Starting Extended Agent functional tests...\n")

    test_functions = [
        test_basic_agent_functionality,
        test_memory_management,
        test_agent_output_formats,
        test_agent_state_management,
        test_agent_tools_and_execution,
        test_agent_concurrent_execution,
        test_agent_error_handling,
        test_agent_configuration,
        test_agent_with_stopping_condition,
        test_agent_with_retry_mechanism,
        test_agent_with_dynamic_temperature,
        test_bulk_and_filtered_operations,
        test_memory_and_state_persistence,
        test_sentiment_and_evaluation,
        test_tool_management,
        test_system_prompt_and_configuration,
    ]

    # Run synchronous tests
    total_tests = len(test_functions) + 1  # +1 for the async test
    passed_tests = 0

    for test in test_functions:
        try:
            test()
            passed_tests += 1
        except Exception as e:
            print(f"✗ Test {test.__name__} failed: {e}")

    # Run async test
    try:
        asyncio.run(test_async_operations())
        passed_tests += 1
    except Exception as e:
        print(f"✗ Async operations test failed: {e}")

    print("\nExtended Test Summary:")
    print(f"Total Tests: {total_tests}")
    print(f"Passed: {passed_tests}")
    print(f"Failed: {total_tests - passed_tests}")
    print(f"Success Rate: {(passed_tests / total_tests) * 100:.2f}%")

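
# Note for CI use (an optional extension, not in the original script):
# run_all_tests() could return the failure count so the process exits
# non-zero when anything fails, e.g.
#
#     failures = run_all_tests()          # if it returned total - passed
#     sys.exit(1 if failures else 0)
#
# This assumes `sys` is imported at the top of the file.
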

if __name__ == "__main__":
    run_all_tests()