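"""Async benchmark for a swarms Financial-Analysis-Agent.

Runs a set of finance-related queries against the agent for a fixed
number of iterations, records execution time and process-level
CPU/memory metrics via psutil, and writes a JSON report with summary
statistics to the output directory (``benchmark_results`` by default).
"""
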
import asyncio
import concurrent.futures
import json
import os
import psutil
import datetime
from pathlib import Path
from typing import List, Dict, Any, Optional

from swarms.structs.agent import Agent
from loguru import logger


class AgentBenchmark:
    def __init__(
        self,
        num_iterations: int = 5,
        output_dir: str = "benchmark_results",
    ):
        self.num_iterations = num_iterations
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(exist_ok=True)

        # Process pool for CPU-bound tasks (reserved; the current
        # benchmark does not submit work to it).
        # os.cpu_count() may return None, so fall back to 1.
        self.process_pool = concurrent.futures.ProcessPoolExecutor(
            max_workers=min(os.cpu_count() or 1, 4)
        )

        # Thread pool for I/O-bound tasks such as blocking agent calls
        self.thread_pool = concurrent.futures.ThreadPoolExecutor(
            max_workers=min((os.cpu_count() or 1) * 2, 8)
        )

        self.default_queries = [
            "Conduct an analysis of the best real undervalued ETFs",
            "What are the top performing tech stocks this quarter?",
            "Analyze current market trends in renewable energy sector",
            "Compare Bitcoin and Ethereum investment potential",
            "Evaluate the risk factors in emerging markets",
        ]

        self.agent = self._initialize_agent()
        self.process = psutil.Process()

        # Cache for storing repeated (query, iteration) results
        self._query_cache = {}

    def _initialize_agent(self) -> Agent:
        return Agent(
            agent_name="Financial-Analysis-Agent",
            agent_description="Personal finance advisor agent",
            # system_prompt=FINANCIAL_AGENT_SYS_PROMPT,
            max_loops=1,
            model_name="gpt-4o-mini",
            dynamic_temperature_enabled=True,
            interactive=False,
        )

    def _get_system_metrics(self) -> Dict[str, float]:
        """Capture CPU and resident-memory usage of the current process."""
        # Note: psutil's cpu_percent() measures usage since its previous
        # call, so the very first reading in a process may be 0.0
        return {
            "cpu_percent": self.process.cpu_percent(),
            "memory_mb": self.process.memory_info().rss / 1024 / 1024,
        }

    def _calculate_statistics(
        self, values: List[float]
    ) -> Dict[str, float]:
        """Compute summary statistics for a list of samples."""
        if not values:
            return {}

        sorted_values = sorted(values)
        n = len(sorted_values)
        mean_val = sum(values) / n

        stats = {
            "mean": mean_val,
            # Upper median for even-length lists
            "median": sorted_values[n // 2],
            "min": sorted_values[0],
            "max": sorted_values[-1],
        }

        # Only calculate the (population) standard deviation if we
        # have more than one sample
        if n > 1:
            stats["std_dev"] = (
                sum((x - mean_val) ** 2 for x in values) / n
            ) ** 0.5

        return {k: round(v, 3) for k, v in stats.items()}

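    # Illustrative call (hypothetical sample values, not from a real run):
    #   _calculate_statistics([1.0, 2.0, 3.0])
    #   -> {"mean": 2.0, "median": 2.0, "min": 1.0, "max": 3.0, "std_dev": 0.816}
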
    async def process_iteration(
        self, query: str, iteration: int
    ) -> Dict[str, Any]:
        """Process a single iteration of a query"""
        try:
            # Check the cache; keys include the iteration index, so
            # hits only occur for repeated (query, iteration) pairs
            cache_key = f"{query}_{iteration}"
            if cache_key in self._query_cache:
                return self._query_cache[cache_key]

            iteration_start = datetime.datetime.now()
            pre_metrics = self._get_system_metrics()

            # Run the blocking agent call in the thread pool so that
            # concurrently gathered iterations do not stall the event loop
            try:
                loop = asyncio.get_running_loop()
                await loop.run_in_executor(
                    self.thread_pool, self.agent.run, query
                )
                success = True
            except Exception as e:
                logger.warning(f"Agent run failed: {e}")
                success = False

            execution_time = (
                datetime.datetime.now() - iteration_start
            ).total_seconds()
            post_metrics = self._get_system_metrics()

            result = {
                "execution_time": execution_time,
                "success": success,
                "pre_metrics": pre_metrics,
                "post_metrics": post_metrics,
                "iteration_data": {
                    "iteration": iteration + 1,
                    "execution_time": round(execution_time, 3),
                    "success": success,
                    "system_metrics": {
                        "pre": pre_metrics,
                        "post": post_metrics,
                    },
                },
            }

            # Cache the result
            self._query_cache[cache_key] = result
            return result

        except Exception as e:
            logger.error(f"Error in iteration {iteration}: {e}")
            raise

    async def run_benchmark(
        self, queries: Optional[List[str]] = None
    ) -> Dict[str, Any]:
        """Run the benchmark asynchronously"""
        queries = queries or self.default_queries
        benchmark_data = {
            "metadata": {
                "timestamp": datetime.datetime.now().isoformat(),
                "num_iterations": self.num_iterations,
                "agent_config": {
                    "model_name": self.agent.model_name,
                    "max_loops": self.agent.max_loops,
                },
            },
            "results": {},
        }

        async def process_query(query: str):
            query_results = {
                "execution_times": [],
                "system_metrics": [],
                "iterations": [],
            }

            # Process iterations concurrently
            tasks = [
                self.process_iteration(query, i)
                for i in range(self.num_iterations)
            ]
            iteration_results = await asyncio.gather(*tasks)

            for result in iteration_results:
                query_results["execution_times"].append(
                    result["execution_time"]
                )
                query_results["system_metrics"].append(
                    result["post_metrics"]
                )
                query_results["iterations"].append(
                    result["iteration_data"]
                )

            # Calculate statistics
            query_results["statistics"] = {
                "execution_time": self._calculate_statistics(
                    query_results["execution_times"]
                ),
                "memory_usage": self._calculate_statistics(
                    [
                        m["memory_mb"]
                        for m in query_results["system_metrics"]
                    ]
                ),
                "cpu_usage": self._calculate_statistics(
                    [
                        m["cpu_percent"]
                        for m in query_results["system_metrics"]
                    ]
                ),
            }

            return query, query_results

        # Execute all queries concurrently
        query_tasks = [process_query(query) for query in queries]
        query_results = await asyncio.gather(*query_tasks)

        for query, results in query_results:
            benchmark_data["results"][query] = results

        return benchmark_data

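    # Shape of the dict returned by run_benchmark (abridged sketch):
    #   {
    #     "metadata": {"timestamp", "num_iterations", "agent_config"},
    #     "results": {<query>: {"execution_times", "system_metrics",
    #                           "iterations", "statistics"}},
    #   }
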
    def save_results(self, benchmark_data: Dict[str, Any]) -> str:
        """Save benchmark results efficiently"""
        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = (
            self.output_dir / f"benchmark_results_{timestamp}.json"
        )

        # Write results in a single operation
        with open(filename, "w") as f:
            json.dump(benchmark_data, f, indent=2)

        logger.info(f"Benchmark results saved to: {filename}")
        return str(filename)

    def print_summary(self, results: Dict[str, Any]):
        """Print a summary of the benchmark results"""
        print("\n=== Benchmark Summary ===")
        for query, data in results["results"].items():
            print(f"\nQuery: {query[:50]}...")
            stats = data["statistics"]["execution_time"]
            print(f"Average time: {stats['mean']:.2f}s")
            print(
                f"Memory usage (avg): {data['statistics']['memory_usage']['mean']:.1f}MB"
            )
            print(
                f"CPU usage (avg): {data['statistics']['cpu_usage']['mean']:.1f}%"
            )

    async def run_with_timeout(
        self, timeout: int = 300
    ) -> Dict[str, Any]:
        """Run benchmark with timeout"""
        try:
            return await asyncio.wait_for(
                self.run_benchmark(), timeout
            )
        except asyncio.TimeoutError:
            logger.error(
                f"Benchmark timed out after {timeout} seconds"
            )
            raise

    def cleanup(self):
        """Cleanup resources"""
        self.process_pool.shutdown()
        self.thread_pool.shutdown()
        self._query_cache.clear()


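# Example driver: a single-iteration smoke run with a five-minute
# timeout; results are written to JSON and summarized on stdout.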
async def main():
    benchmark = None
    try:
        # Create and run benchmark
        benchmark = AgentBenchmark(num_iterations=1)

        # Run benchmark with timeout
        results = await benchmark.run_with_timeout(timeout=300)

        # Save results
        benchmark.save_results(results)

        # Print summary
        benchmark.print_summary(results)

    except Exception as e:
        logger.error(f"Benchmark failed: {e}")
    finally:
        # Cleanup resources (guard against construction failure,
        # which would otherwise leave `benchmark` unbound here)
        if benchmark is not None:
            benchmark.cleanup()


if __name__ == "__main__":
    # Run the async main function
    asyncio.run(main())