parent 9b1106ac91
commit 9abe300548
@@ -0,0 +1,348 @@
import json
import os
import platform
import sys
import traceback
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple

import psutil
import requests
from loguru import logger
from swarm_models import OpenAIChat

from swarms.structs.agent import Agent


@dataclass
class SwarmSystemInfo:
    """System information for Swarms issue reports."""

    os_name: str
    os_version: str
    python_version: str
    cpu_usage: float
    memory_usage: float
    disk_usage: float
    swarms_version: str  # Installed Swarms version
    cuda_available: bool  # CUDA availability
    gpu_info: Optional[str]  # GPU information


class SwarmsIssueReporter:
    """
    Production-grade GitHub issue reporter specifically designed for the Swarms library.
    Automatically creates detailed issues for the https://github.com/kyegomez/swarms repository.

    Features:
    - Swarms-specific error categorization
    - Automatic version and dependency tracking
    - CUDA and GPU information collection
    - Integration with the Swarms logging system
    - Detailed environment information
    """

    REPO_OWNER = "kyegomez"
    REPO_NAME = "swarms"
    ISSUE_CATEGORIES = {
        "agent": ["agent", "automation"],
        "memory": ["memory", "storage"],
        "tool": ["tools", "integration"],
        "llm": ["llm", "model"],
        "performance": ["performance", "optimization"],
        "compatibility": ["compatibility", "environment"],
    }

    def __init__(
        self,
        github_token: str,
        rate_limit: int = 10,
        rate_period: int = 3600,
        log_file: str = "swarms_issues.log",
        enable_duplicate_check: bool = True,
    ):
        """
        Initialize the Swarms Issue Reporter.

        Args:
            github_token (str): GitHub personal access token
            rate_limit (int): Maximum number of issues to create per rate_period
            rate_period (int): Time period for rate limiting in seconds
            log_file (str): Path to log file
            enable_duplicate_check (bool): Whether to check for duplicate issues
        """
        # Fall back to the GITHUB_API_KEY environment variable if no token is passed
        self.github_token = github_token or os.getenv("GITHUB_API_KEY")
        self.rate_limit = rate_limit
        self.rate_period = rate_period
        self.enable_duplicate_check = enable_duplicate_check

        # Initialize logging
        log_path = os.path.join(os.getcwd(), "logs", log_file)
        os.makedirs(os.path.dirname(log_path), exist_ok=True)
        logger.add(
            log_path,
            rotation="1 day",
            retention="1 month",
            compression="zip",
        )

        # Issue tracking
        self.issues_created = []
        self.last_issue_time = datetime.now()

        # Validate GitHub token
        # self._validate_github_credentials()

    def _get_swarms_version(self) -> str:
        """Get the installed version of Swarms."""
        try:
            import swarms

            return swarms.__version__
        except Exception:
            return "Unknown"

    def _get_gpu_info(self) -> Tuple[bool, Optional[str]]:
        """Get GPU information and CUDA availability."""
        try:
            import torch

            cuda_available = torch.cuda.is_available()
            if cuda_available:
                gpu_info = torch.cuda.get_device_name(0)
                return cuda_available, gpu_info
            return False, None
        except Exception:
            return False, None

    def _get_system_info(self) -> SwarmSystemInfo:
        """Collect system and Swarms-specific information."""
        cuda_available, gpu_info = self._get_gpu_info()

        return SwarmSystemInfo(
            os_name=platform.system(),
            os_version=platform.version(),
            python_version=sys.version,
            cpu_usage=psutil.cpu_percent(),
            memory_usage=psutil.virtual_memory().percent,
            disk_usage=psutil.disk_usage("/").percent,
            swarms_version=self._get_swarms_version(),
            cuda_available=cuda_available,
            gpu_info=gpu_info,
        )

    def _categorize_error(
        self, error: Exception, context: Dict
    ) -> List[str]:
        """Categorize the error and return appropriate labels."""
        error_str = str(error).lower()

        labels = ["bug", "automated"]

        # Check the error message for category keywords
        for category, category_labels in self.ISSUE_CATEGORIES.items():
            if any(keyword in error_str for keyword in category_labels):
                labels.extend(category_labels)
                break

        # Add a severity label based on the error type
        if issubclass(type(error), (SystemError, MemoryError)):
            labels.append("severity:critical")
        elif issubclass(type(error), (ValueError, TypeError)):
            labels.append("severity:medium")
        else:
            labels.append("severity:low")

        return list(set(labels))  # Remove duplicates

    def _format_swarms_issue_body(
        self,
        error: Exception,
        system_info: SwarmSystemInfo,
        context: Dict,
    ) -> str:
        """Format the issue body with Swarms-specific information."""
        return f"""
## Swarms Error Report
- **Error Type**: {type(error).__name__}
- **Error Message**: {str(error)}
- **Swarms Version**: {system_info.swarms_version}

## Environment Information
- **OS**: {system_info.os_name} {system_info.os_version}
- **Python Version**: {system_info.python_version}
- **CUDA Available**: {system_info.cuda_available}
- **GPU**: {system_info.gpu_info or "N/A"}
- **CPU Usage**: {system_info.cpu_usage}%
- **Memory Usage**: {system_info.memory_usage}%
- **Disk Usage**: {system_info.disk_usage}%

## Stack Trace
{traceback.format_exc()}

## Context
{json.dumps(context, indent=2)}

## Dependencies
{self._get_dependencies_info()}

## Time of Occurrence
{datetime.now().isoformat()}

---
*This issue was automatically generated by SwarmsIssueReporter*
"""

    def _get_dependencies_info(self) -> str:
        """Get information about installed dependencies."""
        try:
            import pkg_resources

            deps = []
            for dist in pkg_resources.working_set:
                deps.append(f"- {dist.key} {dist.version}")
            return "\n".join(deps)
        except Exception:
            return "Unable to fetch dependency information"

    def _check_rate_limit(self) -> bool:
        """Check if we're within rate limits."""
        now = datetime.now()
        time_diff = (now - self.last_issue_time).total_seconds()

        if (
            len(self.issues_created) >= self.rate_limit
            and time_diff < self.rate_period
        ):
            logger.warning("Rate limit exceeded for issue creation")
            return False

        # Clean up old issues from tracking
        self.issues_created = [
            time
            for time in self.issues_created
            if (now - time).total_seconds() < self.rate_period
        ]

        return True

    def report_swarms_issue(
        self,
        error: Exception,
        agent: Optional[Agent] = None,
        context: Dict[str, Any] = None,
        priority: str = "normal",
    ) -> Optional[int]:
        """
        Report a Swarms-specific issue to GitHub.

        Args:
            error (Exception): The exception to report
            agent (Optional[Agent]): The Swarms agent instance that encountered the error
            context (Dict[str, Any]): Additional context about the error
            priority (str): Issue priority ("low", "normal", "high", "critical")

        Returns:
            Optional[int]: Issue number if created successfully
        """
        try:
            if not self._check_rate_limit():
                logger.warning("Skipping issue creation due to rate limit")
                return None

            # Collect system information
            system_info = self._get_system_info()

            # Prepare context with agent information if available
            full_context = context or {}
            if agent:
                full_context.update(
                    {
                        "agent_name": agent.agent_name,
                        "agent_description": agent.agent_description,
                        "max_loops": agent.max_loops,
                        "context_length": agent.context_length,
                    }
                )

            # Create the issue title
            title = f"[{type(error).__name__}] {str(error)[:100]}"
            if agent:
                title = f"[Agent: {agent.agent_name}] {title}"

            # Get appropriate labels
            labels = self._categorize_error(error, full_context)
            labels.append(f"priority:{priority}")

            # Create the issue
            url = f"https://api.github.com/repos/{self.REPO_OWNER}/{self.REPO_NAME}/issues"
            data = {
                "title": title,
                "body": self._format_swarms_issue_body(
                    error, system_info, full_context
                ),
                "labels": labels,
            }

            response = requests.post(
                url,
                headers={"Authorization": f"token {self.github_token}"},
                json=data,
            )
            response.raise_for_status()

            issue_number = response.json()["number"]
            logger.info(f"Successfully created Swarms issue #{issue_number}")

            return issue_number

        except Exception as e:
            logger.error(f"Error creating Swarms issue: {str(e)}")
            return None


# from swarms import Agent
# from swarm_models import OpenAIChat
# from swarms.utils.issue_reporter import SwarmsIssueReporter
# import os

# Set up the reporter with your GitHub token
reporter = SwarmsIssueReporter(
    github_token=os.getenv("GITHUB_API_KEY")
)


# Force an error to test the reporter
try:
    # Run an agent with invalid input to trigger an error
    model = OpenAIChat(model_name="gpt-4o")
    agent = Agent(agent_name="Test-Agent", max_loops=1)

    result = agent.run(None)

    raise ValueError("test")
except Exception as e:
    # Report the issue
    issue_number = reporter.report_swarms_issue(
        error=e,
        agent=agent,
        context={"task": "test_run"},
        priority="high",
    )
    print(f"Created issue number: {issue_number}")
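For reference, one way the reporter above could be wired into existing agent code is a small try/except wrapper. This is an illustrative sketch only, not part of the commit; the helper name is made up.

# Hypothetical helper: run a task and automatically file an issue on failure.
def run_with_issue_reporting(agent, task, reporter):
    try:
        return agent.run(task)
    except Exception as exc:
        reporter.report_swarms_issue(
            error=exc,
            agent=agent,
            context={"task": task},
            priority="normal",
        )
        raise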
@@ -0,0 +1,189 @@
import datetime
from typing import Dict, List, Tuple

import requests
from loguru import logger
from swarms import Agent
from swarm_models import OpenAIChat

# GitHub API configuration
GITHUB_REPO = "kyegomez/swarms"  # Swarms GitHub repository
GITHUB_API_URL = f"https://api.github.com/repos/{GITHUB_REPO}/commits"

# Initialize Loguru
logger.add(
    "commit_summary.log",
    rotation="1 MB",
    level="INFO",
    backtrace=True,
    diagnose=True,
)


# Step 1: Fetch the latest commits from GitHub
def fetch_latest_commits(
    repo_url: str, limit: int = 5
) -> List[Dict[str, str]]:
    """Fetch the latest commits from a public GitHub repository."""
    logger.info(f"Fetching the latest {limit} commits from {repo_url}")
    try:
        params = {"per_page": limit}
        response = requests.get(repo_url, params=params)
        response.raise_for_status()

        commits = response.json()
        commit_data = []

        for commit in commits:
            commit_data.append(
                {
                    "sha": commit["sha"][:7],  # Short commit hash
                    "author": commit["commit"]["author"]["name"],
                    "message": commit["commit"]["message"],
                    "date": commit["commit"]["author"]["date"],
                }
            )

        logger.success("Successfully fetched commit data")
        return commit_data

    except Exception as e:
        logger.error(f"Error fetching commits: {e}")
        raise


# Step 2: Format commits and fetch the current time
def format_commits_with_time(
    commits: List[Dict[str, str]]
) -> Tuple[str, str]:
    """Format commit data into a readable string and return the current time."""
    current_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    logger.info(f"Formatting commits at {current_time}")

    commit_summary = "\n".join(
        f"- `{commit['sha']}` by {commit['author']} on {commit['date']}: {commit['message']}"
        for commit in commits
    )

    logger.success("Commits formatted successfully")
    return current_time, commit_summary


# Step 3: Build a dynamic system prompt
def build_custom_system_prompt(
    current_time: str, commit_summary: str
) -> str:
    """Build a dynamic system prompt with the current time and commit summary."""
    logger.info("Building the custom system prompt for the agent")
    prompt = f"""
You are a software analyst tasked with summarizing the latest commits from the Swarms GitHub repository.

The current time is **{current_time}**.

Here are the latest commits:
{commit_summary}

**Your task**:
1. Summarize the changes into a clear and concise table in **markdown format**.
2. Highlight the key improvements and fixes.
3. End your output with the token `<DONE>`.

Make sure the table includes the following columns: Commit SHA, Author, Date, and Commit Message.
"""
    logger.success("System prompt created successfully")
    return prompt


# Step 4: Initialize the agent
def initialize_agent() -> Agent:
    """Initialize the Swarms agent with an OpenAI model."""
    logger.info("Initializing the agent with GPT-4o")
    model = OpenAIChat(model_name="gpt-4o")

    agent = Agent(
        agent_name="Commit-Summarization-Agent",
        agent_description="Fetch and summarize GitHub commits for the Swarms repository.",
        system_prompt="",  # Will be set dynamically
        max_loops=1,
        llm=model,
        dynamic_temperature_enabled=True,
        user_name="Kye",
        retry_attempts=3,
        context_length=8192,
        return_step_meta=False,
        output_type="str",
        auto_generate_prompt=False,
        max_tokens=4000,
        stopping_token="<DONE>",
        interactive=False,
    )
    logger.success("Agent initialized successfully")
    return agent


# Step 5: Run the agent with the data
def summarize_commits_with_agent(agent: Agent, prompt: str) -> str:
    """Pass the system prompt to the agent and fetch the result."""
    logger.info("Sending data to the agent for summarization")
    try:
        result = agent.run(
            f"{prompt}",
            all_cores=True,
        )
        logger.success("Agent completed the summarization task")
        return result
    except Exception as e:
        logger.error(f"Agent encountered an error: {e}")
        raise


# Main execution
if __name__ == "__main__":
    try:
        logger.info("Starting commit summarization process")

        # Fetch the latest commits
        latest_commits = fetch_latest_commits(GITHUB_API_URL, limit=5)

        # Format the commits and get the current time
        current_time, commit_summary = format_commits_with_time(latest_commits)

        # Build the custom system prompt
        custom_system_prompt = build_custom_system_prompt(
            current_time, commit_summary
        )

        # Initialize the agent
        agent = initialize_agent()

        # Set the dynamic system prompt
        agent.system_prompt = custom_system_prompt

        # Run the agent and summarize the commits
        result = summarize_commits_with_agent(agent, custom_system_prompt)

        # Print the result
        print("### Commit Summary in Markdown:")
        print(result)

    except Exception as e:
        logger.critical(f"Process failed: {e}")
@@ -0,0 +1,44 @@
import asyncio

from swarms import Agent
from swarms.prompts.finance_agent_sys_prompt import (
    FINANCIAL_AGENT_SYS_PROMPT,
)
from swarm_models import OpenAIChat

model = OpenAIChat(model_name="gpt-4o")


# Initialize the agent
agent = Agent(
    agent_name="Financial-Analysis-Agent",
    agent_description="Personal finance advisor agent",
    system_prompt=FINANCIAL_AGENT_SYS_PROMPT
    + " Output the <DONE> token when you're done creating a portfolio of ETFs, index funds, and more for AI.",
    max_loops=1,
    llm=model,
    dynamic_temperature_enabled=True,
    user_name="Kye",
    retry_attempts=3,
    # streaming_on=True,
    context_length=8192,
    return_step_meta=False,
    output_type="str",  # "str", "string", "json", "dict", "csv", or "yaml"
    auto_generate_prompt=False,  # Auto-generate the prompt from the agent's name, description, system prompt, and task
    max_tokens=4000,  # Maximum output tokens
    # interactive=True,
    stopping_token="<DONE>",
    saved_state_path="agent_00.json",
    interactive=False,
)


async def run_agent():
    await agent.arun(
        "Create a table of super high growth opportunities for AI. I have $40k to invest in ETFs, index funds, and more. Please create a table in markdown.",
        all_cores=True,
    )


if __name__ == "__main__":
    asyncio.run(run_agent())
@@ -0,0 +1,14 @@
from swarms.prompts.finance_agent_sys_prompt import (
    FINANCIAL_AGENT_SYS_PROMPT,
)
from swarms.agents.openai_assistant import OpenAIAssistant

agent = OpenAIAssistant(
    name="test", instructions=FINANCIAL_AGENT_SYS_PROMPT
)

print(
    agent.run(
        "Create a table of super high growth opportunities for AI. I have $40k to invest in ETFs, index funds, and more. Please create a table in markdown.",
    )
)
File diff suppressed because it is too large
@@ -0,0 +1,419 @@
import json
import logging
import time
import uuid
from datetime import datetime
from typing import Any, Dict, List, Optional

import yaml
from pydantic import BaseModel
from swarm_models.tiktoken_wrapper import TikTokenizer

logger = logging.getLogger(__name__)


class MemoryMetadata(BaseModel):
    """Metadata for memory entries"""

    timestamp: Optional[float] = time.time()
    role: Optional[str] = None
    agent_name: Optional[str] = None
    session_id: Optional[str] = None
    memory_type: Optional[str] = None  # 'short_term' or 'long_term'
    token_count: Optional[int] = None
    message_id: Optional[str] = str(uuid.uuid4())


class MemoryEntry(BaseModel):
    """Single memory entry with content and metadata"""

    content: Optional[str] = None
    metadata: Optional[MemoryMetadata] = None


class MemoryConfig(BaseModel):
    """Configuration for the memory manager"""

    max_short_term_tokens: Optional[int] = 4096
    max_entries: Optional[int] = None
    system_messages_token_buffer: Optional[int] = 1000
    enable_long_term_memory: Optional[bool] = False
    auto_archive: Optional[bool] = True
    archive_threshold: Optional[float] = 0.8  # Archive when 80% full


class MemoryManager:
    """
    Manages both short-term and long-term memory for an agent, handling token limits,
    archival, and context retrieval.

    Args:
        config (MemoryConfig): Configuration for memory management
        tokenizer (Optional[Any]): Tokenizer to use for token counting
        long_term_memory (Optional[Any]): Vector store or database for long-term storage
    """

    def __init__(
        self,
        config: MemoryConfig,
        tokenizer: Optional[Any] = None,
        long_term_memory: Optional[Any] = None,
    ):
        self.config = config
        self.tokenizer = tokenizer or TikTokenizer()
        self.long_term_memory = long_term_memory

        # Initialize memories
        self.short_term_memory: List[MemoryEntry] = []
        self.system_messages: List[MemoryEntry] = []

        # Memory statistics
        self.total_tokens_processed: int = 0
        self.archived_entries_count: int = 0

    def create_memory_entry(
        self,
        content: str,
        role: str,
        agent_name: str,
        session_id: str,
        memory_type: str = "short_term",
    ) -> MemoryEntry:
        """Create a new memory entry with metadata"""
        metadata = MemoryMetadata(
            timestamp=time.time(),
            role=role,
            agent_name=agent_name,
            session_id=session_id,
            memory_type=memory_type,
            token_count=self.tokenizer.count_tokens(content),
        )
        return MemoryEntry(content=content, metadata=metadata)

    def add_memory(
        self,
        content: str,
        role: str,
        agent_name: str,
        session_id: str,
        is_system: bool = False,
    ) -> None:
        """Add a new memory entry to the appropriate storage"""
        entry = self.create_memory_entry(
            content=content,
            role=role,
            agent_name=agent_name,
            session_id=session_id,
            memory_type="system" if is_system else "short_term",
        )

        if is_system:
            self.system_messages.append(entry)
        else:
            self.short_term_memory.append(entry)

        # Check if archiving is needed
        if self.should_archive():
            self.archive_old_memories()

        self.total_tokens_processed += entry.metadata.token_count

    def get_current_token_count(self) -> int:
        """Get the total tokens in short-term memory"""
        return sum(
            entry.metadata.token_count
            for entry in self.short_term_memory
        )

    def get_system_messages_token_count(self) -> int:
        """Get the total tokens in system messages"""
        return sum(
            entry.metadata.token_count
            for entry in self.system_messages
        )

    def should_archive(self) -> bool:
        """Check whether archiving is needed based on the configuration"""
        if not self.config.auto_archive:
            return False

        current_usage = (
            self.get_current_token_count()
            / self.config.max_short_term_tokens
        )
        return current_usage >= self.config.archive_threshold

    def archive_old_memories(self) -> None:
        """Move older memories to long-term storage"""
        if not self.long_term_memory:
            logger.warning(
                "No long-term memory storage configured for archiving"
            )
            return

        while self.should_archive():
            # Get the oldest non-system message
            if not self.short_term_memory:
                break

            oldest_entry = self.short_term_memory.pop(0)

            # Store it in long-term memory
            self.store_in_long_term_memory(oldest_entry)
            self.archived_entries_count += 1

    def store_in_long_term_memory(self, entry: MemoryEntry) -> None:
        """Store a memory entry in long-term memory"""
        if self.long_term_memory is None:
            logger.warning(
                "Attempted to store in non-existent long-term memory"
            )
            return

        try:
            self.long_term_memory.add(str(entry.model_dump()))
        except Exception as e:
            logger.error(f"Error storing in long-term memory: {e}")
            # Re-add to short-term memory if storage fails
            self.short_term_memory.insert(0, entry)

    def get_relevant_context(
        self, query: str, max_tokens: Optional[int] = None
    ) -> str:
        """
        Get relevant context from both memory types

        Args:
            query (str): Query to match against memories
            max_tokens (Optional[int]): Maximum tokens to return

        Returns:
            str: Combined relevant context
        """
        contexts = []

        # Add system messages first
        for entry in self.system_messages:
            contexts.append(entry.content)

        # Add short-term memory
        for entry in reversed(self.short_term_memory):
            contexts.append(entry.content)

        # Query long-term memory if available
        if self.long_term_memory is not None:
            long_term_context = self.long_term_memory.query(query)
            if long_term_context:
                contexts.append(str(long_term_context))

        # Combine and truncate if needed
        combined = "\n".join(contexts)
        if max_tokens:
            combined = self.truncate_to_token_limit(combined, max_tokens)

        return combined

    def truncate_to_token_limit(self, text: str, max_tokens: int) -> str:
        """Truncate text to fit within the token limit"""
        current_tokens = self.tokenizer.count_tokens(text)

        if current_tokens <= max_tokens:
            return text

        # Truncate by splitting into sentences and rebuilding
        sentences = text.split(". ")
        result = []
        current_count = 0

        for sentence in sentences:
            sentence_tokens = self.tokenizer.count_tokens(sentence)
            if current_count + sentence_tokens <= max_tokens:
                result.append(sentence)
                current_count += sentence_tokens
            else:
                break

        return ". ".join(result)

    def clear_short_term_memory(
        self, preserve_system: bool = True
    ) -> None:
        """Clear short-term memory, optionally preserving system messages"""
        if not preserve_system:
            self.system_messages.clear()
        self.short_term_memory.clear()
        logger.info(
            "Cleared short-term memory"
            + (" (preserved system messages)" if preserve_system else "")
        )

    def get_memory_stats(self) -> Dict[str, Any]:
        """Get detailed memory statistics"""
        return {
            "short_term_messages": len(self.short_term_memory),
            "system_messages": len(self.system_messages),
            "current_tokens": self.get_current_token_count(),
            "system_tokens": self.get_system_messages_token_count(),
            "max_tokens": self.config.max_short_term_tokens,
            "token_usage_percent": round(
                (
                    self.get_current_token_count()
                    / self.config.max_short_term_tokens
                )
                * 100,
                2,
            ),
            "has_long_term_memory": self.long_term_memory is not None,
            "archived_entries": self.archived_entries_count,
            "total_tokens_processed": self.total_tokens_processed,
        }

    def save_memory_snapshot(self, file_path: str) -> None:
        """Save the current memory state to a file"""
        try:
            data = {
                "timestamp": datetime.now().isoformat(),
                "config": self.config.model_dump(),
                "system_messages": [
                    entry.model_dump() for entry in self.system_messages
                ],
                "short_term_memory": [
                    entry.model_dump() for entry in self.short_term_memory
                ],
                "stats": self.get_memory_stats(),
            }

            with open(file_path, "w") as f:
                if file_path.endswith(".yaml"):
                    yaml.dump(data, f)
                else:
                    json.dump(data, f, indent=2)

            logger.info(f"Saved memory snapshot to {file_path}")

        except Exception as e:
            logger.error(f"Error saving memory snapshot: {e}")
            raise

    def load_memory_snapshot(self, file_path: str) -> None:
        """Load the memory state from a file"""
        try:
            with open(file_path, "r") as f:
                if file_path.endswith(".yaml"):
                    data = yaml.safe_load(f)
                else:
                    data = json.load(f)

            self.config = MemoryConfig(**data["config"])
            self.system_messages = [
                MemoryEntry(**entry) for entry in data["system_messages"]
            ]
            self.short_term_memory = [
                MemoryEntry(**entry) for entry in data["short_term_memory"]
            ]

            logger.info(f"Loaded memory snapshot from {file_path}")

        except Exception as e:
            logger.error(f"Error loading memory snapshot: {e}")
            raise

    def search_memories(
        self, query: str, memory_type: str = "all"
    ) -> List[MemoryEntry]:
        """
        Search through memories of the specified type

        Args:
            query (str): Search query
            memory_type (str): Type of memories to search ("short_term", "system", "long_term", or "all")

        Returns:
            List[MemoryEntry]: Matching memory entries
        """
        results = []

        if memory_type in ["short_term", "all"]:
            results.extend(
                entry
                for entry in self.short_term_memory
                if query.lower() in entry.content.lower()
            )

        if memory_type in ["system", "all"]:
            results.extend(
                entry
                for entry in self.system_messages
                if query.lower() in entry.content.lower()
            )

        if (
            memory_type in ["long_term", "all"]
            and self.long_term_memory is not None
        ):
            long_term_results = self.long_term_memory.query(query)
            if long_term_results:
                # Convert long-term results to MemoryEntry format
                for result in long_term_results:
                    content = str(result)
                    metadata = MemoryMetadata(
                        timestamp=time.time(),
                        role="long_term",
                        agent_name="system",
                        session_id="long_term",
                        memory_type="long_term",
                        token_count=self.tokenizer.count_tokens(content),
                    )
                    results.append(
                        MemoryEntry(content=content, metadata=metadata)
                    )

        return results

    def get_memory_by_timeframe(
        self, start_time: float, end_time: float
    ) -> List[MemoryEntry]:
        """Get memories within a specific timeframe"""
        return [
            entry
            for entry in self.short_term_memory
            if start_time <= entry.metadata.timestamp <= end_time
        ]

    def export_memories(
        self, file_path: str, format: str = "json"
    ) -> None:
        """Export memories to a file in the specified format"""
        data = {
            "system_messages": [
                entry.model_dump() for entry in self.system_messages
            ],
            "short_term_memory": [
                entry.model_dump() for entry in self.short_term_memory
            ],
            "stats": self.get_memory_stats(),
        }

        with open(file_path, "w") as f:
            if format == "yaml":
                yaml.dump(data, f)
            else:
                json.dump(data, f, indent=2)
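A minimal usage sketch for the MemoryManager above. This is not part of the commit; the import path and configuration values are assumptions for illustration.

# Hypothetical usage of MemoryManager (import path assumed, hence commented out).
# from swarms.structs.memory_manager import MemoryManager, MemoryConfig

config = MemoryConfig(max_short_term_tokens=2048, archive_threshold=0.8)
manager = MemoryManager(config=config)

# Register a system message, then a normal short-term entry.
manager.add_memory(
    "You are a helpful agent.",
    role="system",
    agent_name="demo-agent",
    session_id="session-1",
    is_system=True,
)
manager.add_memory(
    "The sky is blue.",
    role="user",
    agent_name="demo-agent",
    session_id="session-1",
)

# Inspect token usage and pull the combined context for a query.
print(manager.get_memory_stats())
print(manager.get_relevant_context("sky", max_tokens=256))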
@@ -0,0 +1,258 @@
import inspect
import json
import logging
import os
from datetime import datetime
from typing import Any, Dict, Set
from uuid import UUID

logger = logging.getLogger(__name__)


class SafeLoaderUtils:
    """
    Utility class for safely loading and saving object states while automatically
    detecting and preserving class instances and complex objects.
    """

    @staticmethod
    def is_class_instance(obj: Any) -> bool:
        """
        Detect whether an object is a class instance (excluding built-in types).

        Args:
            obj: Object to check

        Returns:
            bool: True if the object is a class instance
        """
        if obj is None:
            return False

        # Get the type of the object
        obj_type = type(obj)

        # Check if it's a class instance but not a built-in type
        return (
            hasattr(obj, "__dict__")
            and not isinstance(obj, type)
            and obj_type.__module__ != "builtins"
        )

    @staticmethod
    def is_safe_type(value: Any) -> bool:
        """
        Check whether a value is of a safe, serializable type.

        Args:
            value: Value to check

        Returns:
            bool: True if the value is safe to serialize
        """
        # Basic safe types
        safe_types = (
            type(None),
            bool,
            int,
            float,
            str,
            datetime,
            UUID,
        )

        if isinstance(value, safe_types):
            return True

        # Check containers
        if isinstance(value, (list, tuple)):
            return all(
                SafeLoaderUtils.is_safe_type(item) for item in value
            )

        if isinstance(value, dict):
            return all(
                isinstance(k, str) and SafeLoaderUtils.is_safe_type(v)
                for k, v in value.items()
            )

        # Fall back to checking JSON serializability
        try:
            json.dumps(value)
            return True
        except (TypeError, OverflowError, ValueError):
            return False

    @staticmethod
    def get_class_attributes(obj: Any) -> Set[str]:
        """
        Get all attributes of a class, including inherited ones.

        Args:
            obj: Object to inspect

        Returns:
            Set[str]: Set of attribute names
        """
        attributes = set()

        # Get all attributes from the class and its parent classes
        for cls in inspect.getmro(type(obj)):
            attributes.update(cls.__dict__.keys())

        # Add instance attributes
        attributes.update(obj.__dict__.keys())

        return attributes

    @staticmethod
    def create_state_dict(obj: Any) -> Dict[str, Any]:
        """
        Create a dictionary of safe values from an object's state.

        Args:
            obj: Object to create the state dict from

        Returns:
            Dict[str, Any]: Dictionary of safe values
        """
        state_dict = {}

        for attr_name in SafeLoaderUtils.get_class_attributes(obj):
            # Skip private attributes
            if attr_name.startswith("_"):
                continue

            try:
                value = getattr(obj, attr_name, None)
                if SafeLoaderUtils.is_safe_type(value):
                    state_dict[attr_name] = value
            except Exception as e:
                logger.debug(f"Skipped attribute {attr_name}: {e}")

        return state_dict

    @staticmethod
    def preserve_instances(obj: Any) -> Dict[str, Any]:
        """
        Automatically detect and preserve all class instances in an object.

        Args:
            obj: Object to preserve instances from

        Returns:
            Dict[str, Any]: Dictionary of preserved instances
        """
        preserved = {}

        for attr_name in SafeLoaderUtils.get_class_attributes(obj):
            if attr_name.startswith("_"):
                continue

            try:
                value = getattr(obj, attr_name, None)
                if SafeLoaderUtils.is_class_instance(value):
                    preserved[attr_name] = value
            except Exception as e:
                logger.debug(f"Could not preserve {attr_name}: {e}")

        return preserved


class SafeStateManager:
    """
    Manages saving and loading object states while automatically handling
    class instances and complex objects.
    """

    @staticmethod
    def save_state(obj: Any, file_path: str) -> None:
        """
        Save an object's state to a file, automatically handling complex objects.

        Args:
            obj: Object to save state from
            file_path: Path to save state to
        """
        try:
            # Create a state dict with only safe values
            state_dict = SafeLoaderUtils.create_state_dict(obj)

            # Ensure the directory exists
            os.makedirs(os.path.dirname(file_path), exist_ok=True)

            # Save to file
            with open(file_path, "w") as f:
                json.dump(state_dict, f, indent=4, default=str)

            logger.info(f"Successfully saved state to: {file_path}")

        except Exception as e:
            logger.error(f"Error saving state: {e}")
            raise

    @staticmethod
    def load_state(obj: Any, file_path: str) -> None:
        """
        Load state into an object while preserving class instances.

        Args:
            obj: Object to load state into
            file_path: Path to load state from
        """
        try:
            # Verify the file exists
            if not os.path.exists(file_path):
                raise FileNotFoundError(
                    f"State file not found: {file_path}"
                )

            # Preserve existing instances
            preserved = SafeLoaderUtils.preserve_instances(obj)

            # Load the state
            with open(file_path, "r") as f:
                state_dict = json.load(f)

            # Set safe values
            for key, value in state_dict.items():
                if (
                    not key.startswith("_")
                    and key not in preserved
                    and SafeLoaderUtils.is_safe_type(value)
                ):
                    setattr(obj, key, value)

            # Restore preserved instances
            for key, value in preserved.items():
                setattr(obj, key, value)

            logger.info(f"Successfully loaded state from: {file_path}")

        except Exception as e:
            logger.error(f"Error loading state: {e}")
            raise


# # Example decorator for easy integration
# def safe_state_methods(cls: Type) -> Type:
#     """
#     Class decorator to add safe state loading/saving methods to a class.
#
#     Args:
#         cls: Class to decorate
#
#     Returns:
#         Type: Decorated class
#     """
#     def save(self, file_path: str) -> None:
#         SafeStateManager.save_state(self, file_path)
#
#     def load(self, file_path: str) -> None:
#         SafeStateManager.load_state(self, file_path)
#
#     cls.save = save
#     cls.load = load
#     return cls
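A short sketch of how SafeStateManager could round-trip an object's state. Illustrative only and not part of the commit; the DemoConfig class, attribute values, file path, and import path are invented for the example.

# Hypothetical round-trip with SafeStateManager (import path assumed, hence commented out).
# from swarms.utils.safe_state_utils import SafeStateManager
import logging


class DemoConfig:
    def __init__(self):
        self.name = "demo"                       # safe value: written to JSON
        self.retries = 3                         # safe value: written to JSON
        self._secret = "hidden"                  # private attribute: skipped
        self.client = logging.getLogger("demo")  # class instance: preserved, never serialized


cfg = DemoConfig()
SafeStateManager.save_state(cfg, "states/demo_config.json")  # writes only name and retries
cfg.retries = 99
SafeStateManager.load_state(cfg, "states/demo_config.json")  # restores retries=3, keeps cfg.client intact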
@@ -0,0 +1,598 @@
import asyncio
import json
import os
import tempfile
import time

import yaml
from swarms import Agent
from swarm_models import OpenAIChat


def test_basic_agent_functionality():
    """Test basic agent initialization and simple task execution"""
    print("\nTesting basic agent functionality...")

    model = OpenAIChat(model_name="gpt-4o")
    agent = Agent(agent_name="Test-Agent", llm=model, max_loops=1)

    response = agent.run("What is 2+2?")
    assert response is not None, "Agent response should not be None"

    # Test agent properties
    assert agent.agent_name == "Test-Agent", "Agent name not set correctly"
    assert agent.max_loops == 1, "Max loops not set correctly"
    assert agent.llm is not None, "LLM not initialized"

    print("✓ Basic agent functionality test passed")


def test_memory_management():
    """Test agent memory management functionality"""
    print("\nTesting memory management...")

    model = OpenAIChat(model_name="gpt-4o")
    agent = Agent(
        agent_name="Memory-Test-Agent",
        llm=model,
        max_loops=1,
        context_length=8192,
    )

    # Test adding to memory
    agent.add_memory("Test memory entry")
    assert "Test memory entry" in agent.short_memory.return_history_as_string()

    # Test memory query
    agent.memory_query("Test query")

    # Test token counting
    tokens = agent.check_available_tokens()
    assert isinstance(tokens, int), "Token count should be an integer"

    print("✓ Memory management test passed")


def test_agent_output_formats():
    """Test all available output formats"""
    print("\nTesting all output formats...")

    model = OpenAIChat(model_name="gpt-4o")
    test_task = "Say hello!"

    output_types = {
        "str": str,
        "string": str,
        "list": str,  # JSON string containing a list
        "json": str,  # JSON string
        "dict": dict,
        "yaml": str,
    }

    for output_type, expected_type in output_types.items():
        agent = Agent(
            agent_name=f"{output_type.capitalize()}-Output-Agent",
            llm=model,
            max_loops=1,
            output_type=output_type,
        )

        response = agent.run(test_task)
        assert response is not None, f"{output_type} output should not be None"

        if output_type == "yaml":
            # Verify the YAML can be parsed
            try:
                yaml.safe_load(response)
                print(f"✓ {output_type} output valid")
            except yaml.YAMLError:
                assert False, f"Invalid YAML output for {output_type}"
        elif output_type in ["json", "list"]:
            # Verify the JSON can be parsed
            try:
                json.loads(response)
                print(f"✓ {output_type} output valid")
            except json.JSONDecodeError:
                assert False, f"Invalid JSON output for {output_type}"

    print("✓ Output formats test passed")


def test_agent_state_management():
    """Test comprehensive state management functionality"""
    print("\nTesting state management...")

    model = OpenAIChat(model_name="gpt-4o")

    # Create a temporary directory for test files
    with tempfile.TemporaryDirectory() as temp_dir:
        state_path = os.path.join(temp_dir, "agent_state.json")

        # Create an agent with an initial state
        agent1 = Agent(
            agent_name="State-Test-Agent",
            llm=model,
            max_loops=1,
            saved_state_path=state_path,
        )

        # Add some data to the agent
        agent1.run("Remember this: Test message 1")
        agent1.add_memory("Test message 2")

        # Save state
        agent1.save()
        assert os.path.exists(state_path), "State file not created"

        # Create a new agent and load the state
        agent2 = Agent(agent_name="State-Test-Agent", llm=model, max_loops=1)
        agent2.load(state_path)

        # Verify the state loaded correctly
        history2 = agent2.short_memory.return_history_as_string()
        assert "Test message 1" in history2, "State not loaded correctly"
        assert "Test message 2" in history2, "Memory not loaded correctly"

        # Test autosave functionality
        agent3 = Agent(
            agent_name="Autosave-Test-Agent",
            llm=model,
            max_loops=1,
            saved_state_path=os.path.join(temp_dir, "autosave_state.json"),
            autosave=True,
        )

        agent3.run("Test autosave")
        time.sleep(2)  # Wait for autosave
        assert os.path.exists(
            os.path.join(temp_dir, "autosave_state.json")
        ), "Autosave file not created"

    print("✓ State management test passed")


def test_agent_tools_and_execution():
    """Test agent tool handling and execution"""
    print("\nTesting tools and execution...")

    def sample_tool(x: int, y: int) -> int:
        """Sample tool that adds two numbers"""
        return x + y

    model = OpenAIChat(model_name="gpt-4o")
    agent = Agent(
        agent_name="Tools-Test-Agent",
        llm=model,
        max_loops=1,
        tools=[sample_tool],
    )

    # Test adding tools
    agent.add_tool(lambda x: x * 2)
    assert len(agent.tools) == 2, "Tool not added correctly"

    # Test removing tools
    agent.remove_tool(sample_tool)
    assert len(agent.tools) == 1, "Tool not removed correctly"

    # Test tool execution
    response = agent.run("Calculate 2 + 2 using the sample tool")
    assert response is not None, "Tool execution failed"

    print("✓ Tools and execution test passed")


def test_agent_concurrent_execution():
    """Test agent concurrent execution capabilities"""
    print("\nTesting concurrent execution...")

    model = OpenAIChat(model_name="gpt-4o")
    agent = Agent(agent_name="Concurrent-Test-Agent", llm=model, max_loops=1)

    # Test bulk run
    tasks = [
        {"task": "Count to 3"},
        {"task": "Say hello"},
        {"task": "Tell a short joke"},
    ]

    responses = agent.bulk_run(tasks)
    assert len(responses) == len(tasks), "Not all tasks completed"
    assert all(response is not None for response in responses), "Some tasks failed"

    # Test concurrent tasks
    concurrent_responses = agent.run_concurrent_tasks(["Task 1", "Task 2", "Task 3"])
    assert len(concurrent_responses) == 3, "Not all concurrent tasks completed"

    print("✓ Concurrent execution test passed")


def test_agent_error_handling():
    """Test agent error handling and recovery"""
    print("\nTesting error handling...")

    model = OpenAIChat(model_name="gpt-4o")
    agent = Agent(
        agent_name="Error-Test-Agent",
        llm=model,
        max_loops=1,
        retry_attempts=3,
        retry_interval=1,
    )

    # Test invalid tool execution
    try:
        agent.parse_and_execute_tools("invalid_json")
        print("✓ Invalid tool execution handled")
    except Exception:
        assert True, "Expected error caught"

    # Test recovery after an error
    response = agent.run("Continue after error")
    assert response is not None, "Agent failed to recover after error"

    print("✓ Error handling test passed")


def test_agent_configuration():
    """Test agent configuration and parameters"""
    print("\nTesting agent configuration...")

    model = OpenAIChat(model_name="gpt-4o")
    agent = Agent(
        agent_name="Config-Test-Agent",
        llm=model,
        max_loops=1,
        temperature=0.7,
        max_tokens=4000,
        context_length=8192,
    )

    # Test configuration methods
    agent.update_system_prompt("New system prompt")
    agent.update_max_loops(2)
    agent.update_loop_interval(2)

    # Verify the updates
    assert agent.max_loops == 2, "Max loops not updated"
    assert agent.loop_interval == 2, "Loop interval not updated"

    # Test configuration export
    config_dict = agent.to_dict()
    assert isinstance(config_dict, dict), "Configuration export failed"

    # Test YAML export
    yaml_config = agent.to_yaml()
    assert isinstance(yaml_config, str), "YAML export failed"

    print("✓ Configuration test passed")


def test_agent_with_stopping_condition():
    """Test agent with a custom stopping condition"""
    print("\nTesting agent with stopping condition...")

    def custom_stopping_condition(response: str) -> bool:
        return "STOP" in response.upper()

    model = OpenAIChat(model_name="gpt-4o")
    agent = Agent(
        agent_name="Stopping-Condition-Agent",
        llm=model,
        max_loops=5,
        stopping_condition=custom_stopping_condition,
    )

    response = agent.run("Count up until you see the word STOP")
    assert response is not None, "Stopping condition test failed"
    print("✓ Stopping condition test passed")


def test_agent_with_retry_mechanism():
    """Test the agent retry mechanism"""
    print("\nTesting agent retry mechanism...")

    model = OpenAIChat(model_name="gpt-4o")
    agent = Agent(
        agent_name="Retry-Test-Agent",
        llm=model,
        max_loops=1,
        retry_attempts=3,
        retry_interval=1,
    )

    response = agent.run("Tell me a joke.")
    assert response is not None, "Retry mechanism test failed"
    print("✓ Retry mechanism test passed")


def test_bulk_and_filtered_operations():
    """Test bulk operations and response filtering"""
    print("\nTesting bulk and filtered operations...")

    model = OpenAIChat(model_name="gpt-4o")
    agent = Agent(agent_name="Bulk-Filter-Test-Agent", llm=model, max_loops=1)

    # Test bulk run
    bulk_tasks = [
        {"task": "What is 2+2?"},
        {"task": "Name a color"},
        {"task": "Count to 3"},
    ]
    bulk_responses = agent.bulk_run(bulk_tasks)
    assert len(bulk_responses) == len(
        bulk_tasks
    ), "Bulk run should return the same number of responses as tasks"

    # Test response filtering
    agent.add_response_filter("color")
    filtered_response = agent.filtered_run("What is your favorite color?")
    assert "[FILTERED]" in filtered_response, "Response filter not applied"

    print("✓ Bulk and filtered operations test passed")


async def test_async_operations():
    """Test asynchronous operations"""
    print("\nTesting async operations...")

    model = OpenAIChat(model_name="gpt-4o")
    agent = Agent(agent_name="Async-Test-Agent", llm=model, max_loops=1)

    # Test a single async run
    response = await agent.arun("What is 1+1?")
    assert response is not None, "Async run failed"

    # Test concurrent async runs
    tasks = ["Task 1", "Task 2", "Task 3"]
    responses = await asyncio.gather(*[agent.arun(task) for task in tasks])
    assert len(responses) == len(tasks), "Not all async tasks completed"

    print("✓ Async operations test passed")


def test_memory_and_state_persistence():
    """Test memory management and state persistence"""
    print("\nTesting memory and state persistence...")

    with tempfile.TemporaryDirectory() as temp_dir:
        state_path = os.path.join(temp_dir, "test_state.json")

        # Create an agent with a memory configuration
        model = OpenAIChat(model_name="gpt-4o")
        agent1 = Agent(
            agent_name="Memory-State-Test-Agent",
            llm=model,
            max_loops=1,
            saved_state_path=state_path,
            context_length=8192,
            autosave=True,
        )

        # Test memory operations
        agent1.add_memory("Important fact: The sky is blue")
        agent1.memory_query("What color is the sky?")

        # Save state
        agent1.save()

        # Create a new agent and load the state
        agent2 = Agent(
            agent_name="Memory-State-Test-Agent",
            llm=model,
            max_loops=1,
        )
        agent2.load(state_path)

        # Verify memory persistence
        memory_content = agent2.short_memory.return_history_as_string()
        assert "sky is blue" in memory_content, "Memory not properly persisted"

    print("✓ Memory and state persistence test passed")


def test_sentiment_and_evaluation():
    """Test sentiment analysis and response evaluation"""
    print("\nTesting sentiment analysis and evaluation...")

    def mock_sentiment_analyzer(text):
        """Mock sentiment analyzer that returns a score between 0 and 1"""
        return 0.7 if "positive" in text.lower() else 0.3

    def mock_evaluator(response):
        """Mock evaluator that checks response quality"""
        return "GOOD" if len(response) > 10 else "BAD"

    model = OpenAIChat(model_name="gpt-4o")
    agent = Agent(
        agent_name="Sentiment-Eval-Test-Agent",
        llm=model,
        max_loops=1,
        sentiment_analyzer=mock_sentiment_analyzer,
        sentiment_threshold=0.5,
        evaluator=mock_evaluator,
    )

    # Test sentiment analysis
    agent.run("Generate a positive message")

    # Test evaluation
    agent.run("Generate a detailed response")

    print("✓ Sentiment and evaluation test passed")


def test_tool_management():
    """Test tool management functionality"""
    print("\nTesting tool management...")

    def tool1(x: int) -> int:
        """Sample tool 1"""
        return x * 2

    def tool2(x: int) -> int:
        """Sample tool 2"""
        return x + 2

    model = OpenAIChat(model_name="gpt-4o")
    agent = Agent(
        agent_name="Tool-Test-Agent",
        llm=model,
        max_loops=1,
        tools=[tool1],
    )

    # Test adding tools
    agent.add_tool(tool2)
    assert len(agent.tools) == 2, "Tool not added correctly"

    # Test removing tools
    agent.remove_tool(tool1)
    assert len(agent.tools) == 1, "Tool not removed correctly"

    # Test adding multiple tools
    agent.add_tools([tool1, tool2])
    assert len(agent.tools) == 3, "Multiple tools not added correctly"

    print("✓ Tool management test passed")


def test_system_prompt_and_configuration():
    """Test system prompt and configuration updates"""
    print("\nTesting system prompt and configuration...")

    model = OpenAIChat(model_name="gpt-4o")
    agent = Agent(agent_name="Config-Test-Agent", llm=model, max_loops=1)

    # Test updating the system prompt
    new_prompt = "You are a helpful assistant."
    agent.update_system_prompt(new_prompt)
    assert agent.system_prompt == new_prompt, "System prompt not updated"

    # Test configuration updates
    agent.update_max_loops(5)
    assert agent.max_loops == 5, "Max loops not updated"

    agent.update_loop_interval(2)
    assert agent.loop_interval == 2, "Loop interval not updated"

    # Test configuration export
    config_dict = agent.to_dict()
    assert isinstance(config_dict, dict), "Configuration export failed"

    print("✓ System prompt and configuration test passed")


def test_agent_with_dynamic_temperature():
    """Test agent with dynamic temperature"""
    print("\nTesting agent with dynamic temperature...")

    model = OpenAIChat(model_name="gpt-4o")
    agent = Agent(
        agent_name="Dynamic-Temp-Agent",
        llm=model,
        max_loops=2,
        dynamic_temperature_enabled=True,
    )

    response = agent.run("Generate a creative story.")
    assert response is not None, "Dynamic temperature test failed"
    print("✓ Dynamic temperature test passed")


def run_all_tests():
    """Run all test functions"""
    print("Starting Extended Agent functional tests...\n")

    test_functions = [
        test_basic_agent_functionality,
        test_memory_management,
        test_agent_output_formats,
        test_agent_state_management,
        test_agent_tools_and_execution,
        test_agent_concurrent_execution,
        test_agent_error_handling,
        test_agent_configuration,
        test_agent_with_stopping_condition,
        test_agent_with_retry_mechanism,
        test_agent_with_dynamic_temperature,
        test_bulk_and_filtered_operations,
        test_memory_and_state_persistence,
        test_sentiment_and_evaluation,
        test_tool_management,
        test_system_prompt_and_configuration,
    ]

    # Run the synchronous tests
    total_tests = len(test_functions) + 1  # +1 for the async test
    passed_tests = 0

    for test in test_functions:
        try:
            test()
            passed_tests += 1
        except Exception as e:
            print(f"✗ Test {test.__name__} failed: {str(e)}")

    # Run the async test
    try:
        asyncio.run(test_async_operations())
        passed_tests += 1
    except Exception as e:
        print(f"✗ Async operations test failed: {str(e)}")

    print("\nExtended Test Summary:")
    print(f"Total Tests: {total_tests}")
    print(f"Passed: {passed_tests}")
    print(f"Failed: {total_tests - passed_tests}")
    print(f"Success Rate: {(passed_tests/total_tests)*100:.2f}%")


if __name__ == "__main__":
    run_all_tests()