[Improvement][AgentRouter][ Now leverages litellm embedding instead chromadb] [CLEANUP][Removed un-used util files and cleaned up artifacts]
parent
111bdb0157
commit
6c6b9911e0
@ -0,0 +1,119 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Example demonstrating the AOP network error handling feature.
|
||||
|
||||
This example shows how the AOP server handles network connection issues
|
||||
with custom error messages and automatic retry logic.
|
||||
"""
|
||||
|
||||
from swarms import Agent
|
||||
from swarms.structs.aop import AOP
|
||||
|
||||
|
||||
def main():
|
||||
"""Demonstrate AOP network error handling functionality."""
|
||||
|
||||
# Create a simple agent
|
||||
agent = Agent(
|
||||
agent_name="network_test_agent",
|
||||
agent_description="An agent for testing network error handling",
|
||||
system_prompt="You are a helpful assistant for network testing.",
|
||||
)
|
||||
|
||||
# Create AOP with network monitoring enabled
|
||||
aop = AOP(
|
||||
server_name="Network Resilient AOP Server",
|
||||
description="An AOP server with network error handling and retry logic",
|
||||
agents=[agent],
|
||||
port=8003,
|
||||
host="localhost",
|
||||
persistence=True, # Enable persistence for automatic restart
|
||||
max_restart_attempts=3,
|
||||
restart_delay=2.0,
|
||||
network_monitoring=True, # Enable network monitoring
|
||||
max_network_retries=5, # Allow up to 5 network retries
|
||||
network_retry_delay=3.0, # Wait 3 seconds between network retries
|
||||
network_timeout=10.0, # 10 second network timeout
|
||||
verbose=True,
|
||||
)
|
||||
|
||||
print("AOP Network Error Handling Demo")
|
||||
print("=" * 40)
|
||||
print()
|
||||
|
||||
# Show initial network status
|
||||
print("Initial network status:")
|
||||
network_status = aop.get_network_status()
|
||||
for key, value in network_status.items():
|
||||
print(f" {key}: {value}")
|
||||
print()
|
||||
|
||||
# Show persistence status
|
||||
print("Persistence status:")
|
||||
persistence_status = aop.get_persistence_status()
|
||||
for key, value in persistence_status.items():
|
||||
print(f" {key}: {value}")
|
||||
print()
|
||||
|
||||
print("Network error handling features:")
|
||||
print("✅ Custom error messages with emojis")
|
||||
print("✅ Automatic network connectivity testing")
|
||||
print("✅ Configurable retry attempts and delays")
|
||||
print("✅ Network error detection and classification")
|
||||
print("✅ Graceful degradation and recovery")
|
||||
print()
|
||||
|
||||
print("To test network error handling:")
|
||||
print("1. Start the server (it will run on localhost:8003)")
|
||||
print("2. Simulate network issues by:")
|
||||
print(" - Disconnecting your network")
|
||||
print(" - Blocking the port with firewall")
|
||||
print(" - Stopping the network service")
|
||||
print("3. Watch the custom error messages and retry attempts")
|
||||
print("4. Reconnect and see automatic recovery")
|
||||
print()
|
||||
|
||||
try:
|
||||
print("Starting server with network monitoring...")
|
||||
print("Press Ctrl+C to stop the demo")
|
||||
print()
|
||||
|
||||
# This will run with network monitoring enabled
|
||||
aop.run()
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\nDemo interrupted by user")
|
||||
print("Network status at shutdown:")
|
||||
network_status = aop.get_network_status()
|
||||
for key, value in network_status.items():
|
||||
print(f" {key}: {value}")
|
||||
except Exception as e:
|
||||
print(f"\nUnexpected error: {e}")
|
||||
print("This demonstrates how non-network errors are handled")
|
||||
|
||||
|
||||
def simulate_network_issues():
|
||||
"""
|
||||
Simulate various network issues for testing.
|
||||
|
||||
This function can be used to test the network error handling
|
||||
in a controlled environment.
|
||||
"""
|
||||
print("Network Issue Simulation:")
|
||||
print("1. Connection Refused - Server not running")
|
||||
print("2. Connection Reset - Server closed connection")
|
||||
print("3. Timeout - Server not responding")
|
||||
print("4. Host Resolution Failed - Invalid hostname")
|
||||
print("5. Network Unreachable - No route to host")
|
||||
print()
|
||||
print("The AOP server will detect these errors and:")
|
||||
print("- Display custom error messages with emojis")
|
||||
print("- Attempt automatic reconnection")
|
||||
print("- Test network connectivity before retry")
|
||||
print("- Give up after max retry attempts")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
print("\n" + "=" * 40)
|
||||
simulate_network_issues()
|
@ -0,0 +1,223 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Example demonstrating AOP network management and monitoring.
|
||||
|
||||
This example shows how to monitor and manage network connectivity
|
||||
in an AOP server with real-time status updates.
|
||||
"""
|
||||
|
||||
import time
|
||||
import threading
|
||||
from swarms import Agent
|
||||
from swarms.structs.aop import AOP
|
||||
|
||||
|
||||
def monitor_network_status(aop_instance):
|
||||
"""Monitor network status in a separate thread."""
|
||||
while True:
|
||||
try:
|
||||
network_status = aop_instance.get_network_status()
|
||||
persistence_status = aop_instance.get_persistence_status()
|
||||
|
||||
print(f"\n{'='*60}")
|
||||
print(
|
||||
f"📊 REAL-TIME STATUS MONITOR - {time.strftime('%H:%M:%S')}"
|
||||
)
|
||||
print(f"{'='*60}")
|
||||
|
||||
# Network Status
|
||||
print("🌐 NETWORK STATUS:")
|
||||
print(
|
||||
f" Monitoring: {'✅ Enabled' if network_status['network_monitoring_enabled'] else '❌ Disabled'}"
|
||||
)
|
||||
print(
|
||||
f" Connected: {'✅ Yes' if network_status['network_connected'] else '❌ No'}"
|
||||
)
|
||||
print(
|
||||
f" Retry Count: {network_status['network_retry_count']}/{network_status['max_network_retries']}"
|
||||
)
|
||||
print(
|
||||
f" Remaining Retries: {network_status['remaining_network_retries']}"
|
||||
)
|
||||
print(
|
||||
f" Host: {network_status['host']}:{network_status['port']}"
|
||||
)
|
||||
print(f" Timeout: {network_status['network_timeout']}s")
|
||||
print(
|
||||
f" Retry Delay: {network_status['network_retry_delay']}s"
|
||||
)
|
||||
|
||||
if network_status["last_network_error"]:
|
||||
print(
|
||||
f" Last Error: {network_status['last_network_error']}"
|
||||
)
|
||||
|
||||
# Persistence Status
|
||||
print("\n🔄 PERSISTENCE STATUS:")
|
||||
print(
|
||||
f" Enabled: {'✅ Yes' if persistence_status['persistence_enabled'] else '❌ No'}"
|
||||
)
|
||||
print(
|
||||
f" Shutdown Requested: {'❌ Yes' if persistence_status['shutdown_requested'] else '✅ No'}"
|
||||
)
|
||||
print(
|
||||
f" Restart Count: {persistence_status['restart_count']}/{persistence_status['max_restart_attempts']}"
|
||||
)
|
||||
print(
|
||||
f" Remaining Restarts: {persistence_status['remaining_restarts']}"
|
||||
)
|
||||
print(
|
||||
f" Restart Delay: {persistence_status['restart_delay']}s"
|
||||
)
|
||||
|
||||
# Connection Health
|
||||
if network_status["network_connected"]:
|
||||
print("\n💚 CONNECTION HEALTH: Excellent")
|
||||
elif network_status["network_retry_count"] == 0:
|
||||
print("\n🟡 CONNECTION HEALTH: Unknown")
|
||||
elif network_status["remaining_network_retries"] > 0:
|
||||
print(
|
||||
f"\n🟠 CONNECTION HEALTH: Recovering ({network_status['remaining_network_retries']} retries left)"
|
||||
)
|
||||
else:
|
||||
print(
|
||||
"\n🔴 CONNECTION HEALTH: Critical (No retries left)"
|
||||
)
|
||||
|
||||
print(f"{'='*60}")
|
||||
|
||||
# Check if we should stop monitoring
|
||||
if (
|
||||
persistence_status["shutdown_requested"]
|
||||
and not persistence_status["persistence_enabled"]
|
||||
):
|
||||
print("🛑 Shutdown requested, stopping monitor...")
|
||||
break
|
||||
|
||||
time.sleep(5) # Update every 5 seconds
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Monitor error: {e}")
|
||||
time.sleep(5)
|
||||
|
||||
|
||||
def main():
|
||||
"""Demonstrate AOP network management."""
|
||||
|
||||
# Create a simple agent
|
||||
agent = Agent(
|
||||
agent_name="network_monitor_agent",
|
||||
agent_description="An agent for network monitoring demo",
|
||||
system_prompt="You are a helpful assistant for network monitoring.",
|
||||
)
|
||||
|
||||
# Create AOP with comprehensive network monitoring
|
||||
aop = AOP(
|
||||
server_name="Network Managed AOP Server",
|
||||
description="An AOP server with comprehensive network management",
|
||||
agents=[agent],
|
||||
port=8004,
|
||||
host="localhost",
|
||||
persistence=True,
|
||||
max_restart_attempts=5,
|
||||
restart_delay=3.0,
|
||||
network_monitoring=True,
|
||||
max_network_retries=10,
|
||||
network_retry_delay=2.0,
|
||||
network_timeout=5.0,
|
||||
verbose=True,
|
||||
)
|
||||
|
||||
print("AOP Network Management Demo")
|
||||
print("=" * 50)
|
||||
print()
|
||||
|
||||
# Show initial configuration
|
||||
print("Initial Configuration:")
|
||||
print(f" Server: {aop.server_name}")
|
||||
print(f" Host: {aop.host}:{aop.port}")
|
||||
print(f" Persistence: {aop.persistence}")
|
||||
print(f" Network Monitoring: {aop.network_monitoring}")
|
||||
print(f" Max Network Retries: {aop.max_network_retries}")
|
||||
print(f" Network Timeout: {aop.network_timeout}s")
|
||||
print()
|
||||
|
||||
# Start monitoring in background
|
||||
print("Starting network status monitor...")
|
||||
monitor_thread = threading.Thread(
|
||||
target=monitor_network_status, args=(aop,), daemon=True
|
||||
)
|
||||
monitor_thread.start()
|
||||
|
||||
print("Available commands:")
|
||||
print(" 'start' - Start the server")
|
||||
print(" 'status' - Show current status")
|
||||
print(" 'reset_network' - Reset network retry counter")
|
||||
print(" 'disable_network' - Disable network monitoring")
|
||||
print(" 'enable_network' - Enable network monitoring")
|
||||
print(" 'shutdown' - Request graceful shutdown")
|
||||
print(" 'quit' - Exit the program")
|
||||
print()
|
||||
|
||||
try:
|
||||
while True:
|
||||
command = input("Enter command: ").strip().lower()
|
||||
|
||||
if command == "start":
|
||||
print(
|
||||
"Starting server... (Press Ctrl+C to test network error handling)"
|
||||
)
|
||||
try:
|
||||
aop.run()
|
||||
except KeyboardInterrupt:
|
||||
print("Server interrupted!")
|
||||
|
||||
elif command == "status":
|
||||
print("\nCurrent Status:")
|
||||
network_status = aop.get_network_status()
|
||||
persistence_status = aop.get_persistence_status()
|
||||
|
||||
print("Network:")
|
||||
for key, value in network_status.items():
|
||||
print(f" {key}: {value}")
|
||||
|
||||
print("\nPersistence:")
|
||||
for key, value in persistence_status.items():
|
||||
print(f" {key}: {value}")
|
||||
|
||||
elif command == "reset_network":
|
||||
aop.reset_network_retry_count()
|
||||
print("Network retry counter reset!")
|
||||
|
||||
elif command == "disable_network":
|
||||
aop.network_monitoring = False
|
||||
print("Network monitoring disabled!")
|
||||
|
||||
elif command == "enable_network":
|
||||
aop.network_monitoring = True
|
||||
print("Network monitoring enabled!")
|
||||
|
||||
elif command == "shutdown":
|
||||
aop.request_shutdown()
|
||||
print("Shutdown requested!")
|
||||
|
||||
elif command == "quit":
|
||||
print("Exiting...")
|
||||
break
|
||||
|
||||
else:
|
||||
print(
|
||||
"Unknown command. Try: start, status, reset_network, disable_network, enable_network, shutdown, quit"
|
||||
)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\nExiting...")
|
||||
finally:
|
||||
# Clean shutdown
|
||||
aop.disable_persistence()
|
||||
aop.request_shutdown()
|
||||
print("Cleanup completed")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -0,0 +1,62 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Example demonstrating the AOP persistence feature.
|
||||
|
||||
This example shows how to use the persistence mode to create a server
|
||||
that automatically restarts when stopped, with failsafe protection.
|
||||
"""
|
||||
|
||||
from swarms import Agent
|
||||
from swarms.structs.aop import AOP
|
||||
|
||||
|
||||
def main():
|
||||
"""Demonstrate AOP persistence functionality."""
|
||||
|
||||
# Create a simple agent
|
||||
agent = Agent(
|
||||
agent_name="example_agent",
|
||||
agent_description="An example agent for persistence demo",
|
||||
system_prompt="You are a helpful assistant.",
|
||||
)
|
||||
|
||||
# Create AOP with persistence enabled
|
||||
aop = AOP(
|
||||
server_name="Persistent AOP Server",
|
||||
description="A persistent AOP server that auto-restarts",
|
||||
agents=[agent],
|
||||
port=8001,
|
||||
persistence=True, # Enable persistence
|
||||
max_restart_attempts=5, # Allow up to 5 restarts
|
||||
restart_delay=3.0, # Wait 3 seconds between restarts
|
||||
verbose=True,
|
||||
)
|
||||
|
||||
print("Starting persistent AOP server...")
|
||||
print("Press Ctrl+C to test the restart functionality")
|
||||
print("The server will restart automatically up to 5 times")
|
||||
print("After 5 failed restarts, it will shut down permanently")
|
||||
print()
|
||||
|
||||
# Show persistence status
|
||||
status = aop.get_persistence_status()
|
||||
print(f"Persistence Status: {status}")
|
||||
print()
|
||||
|
||||
try:
|
||||
# This will run with persistence enabled
|
||||
aop.run()
|
||||
except KeyboardInterrupt:
|
||||
print("\nReceived interrupt signal")
|
||||
print(
|
||||
"In persistence mode, the server would normally restart"
|
||||
)
|
||||
print(
|
||||
"To disable persistence and shut down gracefully, call:"
|
||||
)
|
||||
print(" aop.disable_persistence()")
|
||||
print(" aop.request_shutdown()")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -0,0 +1,141 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Example demonstrating AOP persistence management methods.
|
||||
|
||||
This example shows how to control persistence mode at runtime,
|
||||
including enabling/disabling persistence and monitoring status.
|
||||
"""
|
||||
|
||||
import time
|
||||
import threading
|
||||
from swarms import Agent
|
||||
from swarms.structs.aop import AOP
|
||||
|
||||
|
||||
def monitor_persistence(aop_instance):
|
||||
"""Monitor persistence status in a separate thread."""
|
||||
while True:
|
||||
status = aop_instance.get_persistence_status()
|
||||
print("\n[Monitor] Persistence Status:")
|
||||
print(f" - Enabled: {status['persistence_enabled']}")
|
||||
print(
|
||||
f" - Shutdown Requested: {status['shutdown_requested']}"
|
||||
)
|
||||
print(f" - Restart Count: {status['restart_count']}")
|
||||
print(
|
||||
f" - Remaining Restarts: {status['remaining_restarts']}"
|
||||
)
|
||||
print(
|
||||
f" - Max Restart Attempts: {status['max_restart_attempts']}"
|
||||
)
|
||||
print(f" - Restart Delay: {status['restart_delay']}s")
|
||||
|
||||
if status["shutdown_requested"]:
|
||||
break
|
||||
|
||||
time.sleep(10) # Check every 10 seconds
|
||||
|
||||
|
||||
def main():
|
||||
"""Demonstrate AOP persistence management."""
|
||||
|
||||
# Create a simple agent
|
||||
agent = Agent(
|
||||
agent_name="management_agent",
|
||||
agent_description="An agent for persistence management demo",
|
||||
system_prompt="You are a helpful assistant for testing persistence.",
|
||||
)
|
||||
|
||||
# Create AOP with persistence initially disabled
|
||||
aop = AOP(
|
||||
server_name="Managed AOP Server",
|
||||
description="An AOP server with runtime persistence management",
|
||||
agents=[agent],
|
||||
port=8002,
|
||||
persistence=False, # Start with persistence disabled
|
||||
max_restart_attempts=3,
|
||||
restart_delay=2.0,
|
||||
verbose=True,
|
||||
)
|
||||
|
||||
print("AOP Persistence Management Demo")
|
||||
print("=" * 40)
|
||||
print()
|
||||
|
||||
# Show initial status
|
||||
print("Initial persistence status:")
|
||||
status = aop.get_persistence_status()
|
||||
for key, value in status.items():
|
||||
print(f" {key}: {value}")
|
||||
print()
|
||||
|
||||
# Start monitoring in background
|
||||
monitor_thread = threading.Thread(
|
||||
target=monitor_persistence, args=(aop,), daemon=True
|
||||
)
|
||||
monitor_thread.start()
|
||||
|
||||
print("Available commands:")
|
||||
print(" 'enable' - Enable persistence mode")
|
||||
print(" 'disable' - Disable persistence mode")
|
||||
print(" 'shutdown' - Request graceful shutdown")
|
||||
print(" 'reset' - Reset restart counter")
|
||||
print(" 'status' - Show current status")
|
||||
print(" 'start' - Start the server")
|
||||
print(" 'quit' - Exit the program")
|
||||
print()
|
||||
|
||||
try:
|
||||
while True:
|
||||
command = input("Enter command: ").strip().lower()
|
||||
|
||||
if command == "enable":
|
||||
aop.enable_persistence()
|
||||
print("Persistence enabled!")
|
||||
|
||||
elif command == "disable":
|
||||
aop.disable_persistence()
|
||||
print("Persistence disabled!")
|
||||
|
||||
elif command == "shutdown":
|
||||
aop.request_shutdown()
|
||||
print("Shutdown requested!")
|
||||
|
||||
elif command == "reset":
|
||||
aop.reset_restart_count()
|
||||
print("Restart counter reset!")
|
||||
|
||||
elif command == "status":
|
||||
status = aop.get_persistence_status()
|
||||
print("Current status:")
|
||||
for key, value in status.items():
|
||||
print(f" {key}: {value}")
|
||||
|
||||
elif command == "start":
|
||||
print(
|
||||
"Starting server... (Press Ctrl+C to test restart)"
|
||||
)
|
||||
try:
|
||||
aop.run()
|
||||
except KeyboardInterrupt:
|
||||
print("Server interrupted!")
|
||||
|
||||
elif command == "quit":
|
||||
print("Exiting...")
|
||||
break
|
||||
|
||||
else:
|
||||
print(
|
||||
"Unknown command. Try: enable, disable, shutdown, reset, status, start, quit"
|
||||
)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\nExiting...")
|
||||
finally:
|
||||
# Clean shutdown
|
||||
aop.disable_persistence()
|
||||
aop.request_shutdown()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -1,343 +0,0 @@
|
||||
import base64
|
||||
from typing import Union, Dict, Any, Tuple
|
||||
import requests
|
||||
from pathlib import Path
|
||||
import wave
|
||||
import numpy as np
|
||||
|
||||
|
||||
def encode_audio_to_base64(audio_path: Union[str, Path]) -> str:
|
||||
"""
|
||||
Encode a WAV file to base64 string.
|
||||
|
||||
Args:
|
||||
audio_path (Union[str, Path]): Path to the WAV file
|
||||
|
||||
Returns:
|
||||
str: Base64 encoded string of the audio file
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: If the audio file doesn't exist
|
||||
ValueError: If the file is not a valid WAV file
|
||||
"""
|
||||
try:
|
||||
audio_path = Path(audio_path)
|
||||
if not audio_path.exists():
|
||||
raise FileNotFoundError(
|
||||
f"Audio file not found: {audio_path}"
|
||||
)
|
||||
|
||||
if not audio_path.suffix.lower() == ".wav":
|
||||
raise ValueError("File must be a WAV file")
|
||||
|
||||
with open(audio_path, "rb") as audio_file:
|
||||
audio_data = audio_file.read()
|
||||
return base64.b64encode(audio_data).decode("utf-8")
|
||||
except Exception as e:
|
||||
raise Exception(f"Error encoding audio file: {str(e)}")
|
||||
|
||||
|
||||
def decode_base64_to_audio(
|
||||
base64_string: str, output_path: Union[str, Path]
|
||||
) -> None:
|
||||
"""
|
||||
Decode a base64 string to a WAV file.
|
||||
|
||||
Args:
|
||||
base64_string (str): Base64 encoded audio data
|
||||
output_path (Union[str, Path]): Path where the WAV file should be saved
|
||||
|
||||
Raises:
|
||||
ValueError: If the base64 string is invalid
|
||||
IOError: If there's an error writing the file
|
||||
"""
|
||||
try:
|
||||
output_path = Path(output_path)
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
audio_data = base64.b64decode(base64_string)
|
||||
with open(output_path, "wb") as audio_file:
|
||||
audio_file.write(audio_data)
|
||||
except Exception as e:
|
||||
raise Exception(f"Error decoding audio data: {str(e)}")
|
||||
|
||||
|
||||
def download_audio_from_url(
|
||||
url: str, output_path: Union[str, Path]
|
||||
) -> None:
|
||||
"""
|
||||
Download an audio file from a URL and save it locally.
|
||||
|
||||
Args:
|
||||
url (str): URL of the audio file
|
||||
output_path (Union[str, Path]): Path where the audio file should be saved
|
||||
|
||||
Raises:
|
||||
requests.RequestException: If there's an error downloading the file
|
||||
IOError: If there's an error saving the file
|
||||
"""
|
||||
try:
|
||||
output_path = Path(output_path)
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
response = requests.get(url)
|
||||
response.raise_for_status()
|
||||
|
||||
with open(output_path, "wb") as audio_file:
|
||||
audio_file.write(response.content)
|
||||
except Exception as e:
|
||||
raise Exception(f"Error downloading audio file: {str(e)}")
|
||||
|
||||
|
||||
def process_audio_with_model(
|
||||
audio_path: Union[str, Path],
|
||||
model: str,
|
||||
prompt: str,
|
||||
voice: str = "alloy",
|
||||
format: str = "wav",
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Process an audio file with a model that supports audio input/output.
|
||||
|
||||
Args:
|
||||
audio_path (Union[str, Path]): Path to the input WAV file
|
||||
model (str): Model name to use for processing
|
||||
prompt (str): Text prompt to accompany the audio
|
||||
voice (str, optional): Voice to use for audio output. Defaults to "alloy"
|
||||
format (str, optional): Audio format. Defaults to "wav"
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: Model response containing both text and audio if applicable
|
||||
|
||||
Raises:
|
||||
ImportError: If litellm is not installed
|
||||
ValueError: If the model doesn't support audio processing
|
||||
"""
|
||||
try:
|
||||
from litellm import (
|
||||
completion,
|
||||
supports_audio_input,
|
||||
supports_audio_output,
|
||||
)
|
||||
|
||||
if not supports_audio_input(model):
|
||||
raise ValueError(
|
||||
f"Model {model} does not support audio input"
|
||||
)
|
||||
|
||||
# Encode the audio file
|
||||
encoded_audio = encode_audio_to_base64(audio_path)
|
||||
|
||||
# Prepare the messages
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": prompt},
|
||||
{
|
||||
"type": "input_audio",
|
||||
"input_audio": {
|
||||
"data": encoded_audio,
|
||||
"format": format,
|
||||
},
|
||||
},
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
# Make the API call
|
||||
response = completion(
|
||||
model=model,
|
||||
modalities=["text", "audio"],
|
||||
audio={"voice": voice, "format": format},
|
||||
messages=messages,
|
||||
)
|
||||
|
||||
return response
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Please install litellm: pip install litellm"
|
||||
)
|
||||
except Exception as e:
|
||||
raise Exception(
|
||||
f"Error processing audio with model: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
def read_wav_file(
|
||||
file_path: Union[str, Path],
|
||||
) -> Tuple[np.ndarray, int]:
|
||||
"""
|
||||
Read a WAV file and return its audio data and sample rate.
|
||||
|
||||
Args:
|
||||
file_path (Union[str, Path]): Path to the WAV file
|
||||
|
||||
Returns:
|
||||
Tuple[np.ndarray, int]: Audio data as numpy array and sample rate
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: If the file doesn't exist
|
||||
ValueError: If the file is not a valid WAV file
|
||||
"""
|
||||
try:
|
||||
file_path = Path(file_path)
|
||||
if not file_path.exists():
|
||||
raise FileNotFoundError(
|
||||
f"Audio file not found: {file_path}"
|
||||
)
|
||||
|
||||
with wave.open(str(file_path), "rb") as wav_file:
|
||||
# Get audio parameters
|
||||
n_channels = wav_file.getnchannels()
|
||||
sample_width = wav_file.getsampwidth()
|
||||
frame_rate = wav_file.getframerate()
|
||||
n_frames = wav_file.getnframes()
|
||||
|
||||
# Read audio data
|
||||
frames = wav_file.readframes(n_frames)
|
||||
|
||||
# Convert to numpy array
|
||||
dtype = np.int16 if sample_width == 2 else np.int8
|
||||
audio_data = np.frombuffer(frames, dtype=dtype)
|
||||
|
||||
# Reshape if stereo
|
||||
if n_channels == 2:
|
||||
audio_data = audio_data.reshape(-1, 2)
|
||||
|
||||
return audio_data, frame_rate
|
||||
|
||||
except Exception as e:
|
||||
raise Exception(f"Error reading WAV file: {str(e)}")
|
||||
|
||||
|
||||
def write_wav_file(
|
||||
audio_data: np.ndarray,
|
||||
file_path: Union[str, Path],
|
||||
sample_rate: int,
|
||||
sample_width: int = 2,
|
||||
) -> None:
|
||||
"""
|
||||
Write audio data to a WAV file.
|
||||
|
||||
Args:
|
||||
audio_data (np.ndarray): Audio data as numpy array
|
||||
file_path (Union[str, Path]): Path where to save the WAV file
|
||||
sample_rate (int): Sample rate of the audio
|
||||
sample_width (int, optional): Sample width in bytes. Defaults to 2 (16-bit)
|
||||
|
||||
Raises:
|
||||
ValueError: If the audio data is invalid
|
||||
IOError: If there's an error writing the file
|
||||
"""
|
||||
try:
|
||||
file_path = Path(file_path)
|
||||
file_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Ensure audio data is in the correct format
|
||||
if audio_data.dtype != np.int16 and sample_width == 2:
|
||||
audio_data = (audio_data * 32767).astype(np.int16)
|
||||
elif audio_data.dtype != np.int8 and sample_width == 1:
|
||||
audio_data = (audio_data * 127).astype(np.int8)
|
||||
|
||||
# Determine number of channels
|
||||
n_channels = (
|
||||
2
|
||||
if len(audio_data.shape) > 1 and audio_data.shape[1] == 2
|
||||
else 1
|
||||
)
|
||||
|
||||
with wave.open(str(file_path), "wb") as wav_file:
|
||||
wav_file.setnchannels(n_channels)
|
||||
wav_file.setsampwidth(sample_width)
|
||||
wav_file.setframerate(sample_rate)
|
||||
wav_file.writeframes(audio_data.tobytes())
|
||||
|
||||
except Exception as e:
|
||||
raise Exception(f"Error writing WAV file: {str(e)}")
|
||||
|
||||
|
||||
def normalize_audio(audio_data: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
Normalize audio data to have maximum amplitude of 1.0.
|
||||
|
||||
Args:
|
||||
audio_data (np.ndarray): Input audio data
|
||||
|
||||
Returns:
|
||||
np.ndarray: Normalized audio data
|
||||
"""
|
||||
return audio_data / np.max(np.abs(audio_data))
|
||||
|
||||
|
||||
def convert_to_mono(audio_data: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
Convert stereo audio to mono by averaging channels.
|
||||
|
||||
Args:
|
||||
audio_data (np.ndarray): Input audio data (stereo)
|
||||
|
||||
Returns:
|
||||
np.ndarray: Mono audio data
|
||||
"""
|
||||
if len(audio_data.shape) == 1:
|
||||
return audio_data
|
||||
return np.mean(audio_data, axis=1)
|
||||
|
||||
|
||||
def encode_wav_to_base64(
|
||||
audio_data: np.ndarray, sample_rate: int
|
||||
) -> str:
|
||||
"""
|
||||
Convert audio data to base64 encoded WAV string.
|
||||
|
||||
Args:
|
||||
audio_data (np.ndarray): Audio data
|
||||
sample_rate (int): Sample rate of the audio
|
||||
|
||||
Returns:
|
||||
str: Base64 encoded WAV data
|
||||
"""
|
||||
# Create a temporary WAV file in memory
|
||||
with wave.open("temp.wav", "wb") as wav_file:
|
||||
wav_file.setnchannels(1 if len(audio_data.shape) == 1 else 2)
|
||||
wav_file.setsampwidth(2) # 16-bit
|
||||
wav_file.setframerate(sample_rate)
|
||||
wav_file.writeframes(audio_data.tobytes())
|
||||
|
||||
# Read the file and encode to base64
|
||||
with open("temp.wav", "rb") as f:
|
||||
wav_bytes = f.read()
|
||||
|
||||
# Clean up temporary file
|
||||
Path("temp.wav").unlink()
|
||||
|
||||
return base64.b64encode(wav_bytes).decode("utf-8")
|
||||
|
||||
|
||||
def decode_base64_to_wav(
|
||||
base64_string: str,
|
||||
) -> Tuple[np.ndarray, int]:
|
||||
"""
|
||||
Convert base64 encoded WAV string to audio data and sample rate.
|
||||
|
||||
Args:
|
||||
base64_string (str): Base64 encoded WAV data
|
||||
|
||||
Returns:
|
||||
Tuple[np.ndarray, int]: Audio data and sample rate
|
||||
"""
|
||||
# Decode base64 string
|
||||
wav_bytes = base64.b64decode(base64_string)
|
||||
|
||||
# Write to temporary file
|
||||
with open("temp.wav", "wb") as f:
|
||||
f.write(wav_bytes)
|
||||
|
||||
# Read the WAV file
|
||||
audio_data, sample_rate = read_wav_file("temp.wav")
|
||||
|
||||
# Clean up temporary file
|
||||
Path("temp.wav").unlink()
|
||||
|
||||
return audio_data, sample_rate
|
@ -1,151 +0,0 @@
|
||||
"""
|
||||
Package installation utility that checks for package existence and installs if needed.
|
||||
Supports both pip and conda package managers.
|
||||
"""
|
||||
|
||||
import importlib.util
|
||||
import subprocess
|
||||
import sys
|
||||
from typing import Literal, Optional, Union
|
||||
from swarms.utils.loguru_logger import initialize_logger
|
||||
|
||||
|
||||
from importlib.metadata import distribution, PackageNotFoundError
|
||||
|
||||
logger = initialize_logger("autocheckpackages")
|
||||
|
||||
|
||||
def check_and_install_package(
|
||||
package_name: str,
|
||||
package_manager: Literal["pip", "conda"] = "pip",
|
||||
version: Optional[str] = None,
|
||||
upgrade: bool = False,
|
||||
) -> bool:
|
||||
"""
|
||||
Check if a package is installed and install it if not found.
|
||||
|
||||
Args:
|
||||
package_name: Name of the package to check/install
|
||||
package_manager: Package manager to use ('pip' or 'conda')
|
||||
version: Specific version to install (optional)
|
||||
upgrade: Whether to upgrade the package if it exists
|
||||
|
||||
Returns:
|
||||
bool: True if package is available after check/install, False if installation failed
|
||||
|
||||
Raises:
|
||||
ValueError: If invalid package manager is specified
|
||||
"""
|
||||
try:
|
||||
# Check if package exists
|
||||
if package_manager == "pip":
|
||||
try:
|
||||
distribution(package_name)
|
||||
if not upgrade:
|
||||
logger.info(
|
||||
f"Package {package_name} is already installed"
|
||||
)
|
||||
return True
|
||||
except PackageNotFoundError:
|
||||
pass
|
||||
|
||||
# Construct installation command
|
||||
cmd = [sys.executable, "-m", "pip", "install"]
|
||||
if upgrade:
|
||||
cmd.append("--upgrade")
|
||||
|
||||
if version:
|
||||
cmd.append(f"{package_name}=={version}")
|
||||
else:
|
||||
cmd.append(package_name)
|
||||
|
||||
elif package_manager == "conda":
|
||||
# Check if conda is available
|
||||
try:
|
||||
subprocess.run(
|
||||
["conda", "--version"],
|
||||
check=True,
|
||||
capture_output=True,
|
||||
)
|
||||
except (subprocess.CalledProcessError, FileNotFoundError):
|
||||
logger.error(
|
||||
"Conda is not available. Please install conda first."
|
||||
)
|
||||
return False
|
||||
|
||||
# Construct conda command
|
||||
cmd = ["conda", "install", "-y"]
|
||||
if version:
|
||||
cmd.append(f"{package_name}={version}")
|
||||
else:
|
||||
cmd.append(package_name)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid package manager: {package_manager}"
|
||||
)
|
||||
|
||||
# Run installation
|
||||
logger.info(f"Installing {package_name}...")
|
||||
subprocess.run(
|
||||
cmd, check=True, capture_output=True, text=True
|
||||
)
|
||||
|
||||
# Verify installation
|
||||
try:
|
||||
importlib.import_module(package_name)
|
||||
logger.info(f"Successfully installed {package_name}")
|
||||
return True
|
||||
except ImportError:
|
||||
logger.error(
|
||||
f"Package {package_name} was installed but cannot be imported"
|
||||
)
|
||||
return False
|
||||
|
||||
except subprocess.CalledProcessError as e:
|
||||
logger.error(f"Failed to install {package_name}: {e.stderr}")
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Unexpected error while installing {package_name}: {str(e)}"
|
||||
)
|
||||
return False
|
||||
|
||||
|
||||
def auto_check_and_download_package(
|
||||
packages: Union[str, list[str]],
|
||||
package_manager: Literal["pip", "conda"] = "pip",
|
||||
upgrade: bool = False,
|
||||
) -> bool:
|
||||
"""
|
||||
Ensure multiple packages are installed.
|
||||
|
||||
Args:
|
||||
packages: Single package name or list of package names
|
||||
package_manager: Package manager to use ('pip' or 'conda')
|
||||
upgrade: Whether to upgrade existing packages
|
||||
|
||||
Returns:
|
||||
bool: True if all packages are available, False if any installation failed
|
||||
"""
|
||||
if isinstance(packages, str):
|
||||
packages = [packages]
|
||||
|
||||
success = True
|
||||
for package in packages:
|
||||
if ":" in package:
|
||||
name, version = package.split(":")
|
||||
if not check_and_install_package(
|
||||
name, package_manager, version, upgrade
|
||||
):
|
||||
success = False
|
||||
else:
|
||||
if not check_and_install_package(
|
||||
package, package_manager, upgrade=upgrade
|
||||
):
|
||||
success = False
|
||||
|
||||
return success
|
||||
|
||||
|
||||
# if __name__ == "__main__":
|
||||
# print(auto_check_and_download_package("torch"))
|
@ -1,54 +0,0 @@
|
||||
from typing import Any
|
||||
from litellm import image_generation
|
||||
|
||||
|
||||
class ImageGenerator:
|
||||
def __init__(
|
||||
self,
|
||||
model: str | None = None,
|
||||
n: int | None = 2,
|
||||
quality: Any = None,
|
||||
response_format: str | None = None,
|
||||
size: str | None = 10,
|
||||
style: str | None = None,
|
||||
user: str | None = None,
|
||||
input_fidelity: str | None = None,
|
||||
timeout: int = 600,
|
||||
output_path_folder: str | None = "images",
|
||||
api_key: str | None = None,
|
||||
api_base: str | None = None,
|
||||
):
|
||||
self.model = model
|
||||
self.n = n
|
||||
self.quality = quality
|
||||
self.response_format = response_format
|
||||
self.size = size
|
||||
self.style = style
|
||||
self.user = user
|
||||
self.input_fidelity = input_fidelity
|
||||
self.timeout = timeout
|
||||
self.output_path_folder = output_path_folder
|
||||
self.api_key = api_key
|
||||
self.api_base = api_base
|
||||
|
||||
def run(self, task: str = None):
|
||||
|
||||
return image_generation(
|
||||
prompt=task,
|
||||
model=self.model,
|
||||
n=self.n,
|
||||
quality=self.quality,
|
||||
response_format=self.response_format,
|
||||
size=self.size,
|
||||
style=self.style,
|
||||
user=self.user,
|
||||
input_fidelity=self.input_fidelity,
|
||||
timeout=self.timeout,
|
||||
)
|
||||
|
||||
|
||||
# if __name__ == "__main__":
|
||||
# image_generator = ImageGenerator()
|
||||
# print(image_generator.run(task="A beautiful sunset over a calm ocean"))
|
||||
|
||||
# print(model_list)
|
@ -0,0 +1,387 @@
|
||||
"""
|
||||
Simplified test suite for AgentRouter class using pytest.
|
||||
|
||||
This module contains focused tests for the core functionality of the AgentRouter class.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
from swarms.structs.agent_router import AgentRouter
|
||||
from swarms.structs.agent import Agent
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_agent():
|
||||
"""Create a real agent for testing."""
|
||||
with patch("swarms.structs.agent.LiteLLM") as mock_llm:
|
||||
mock_llm.return_value.run.return_value = "Test response"
|
||||
return Agent(
|
||||
agent_name="test_agent",
|
||||
agent_description="A test agent",
|
||||
system_prompt="You are a test agent",
|
||||
model_name="gpt-4o-mini",
|
||||
max_loops=1,
|
||||
verbose=False,
|
||||
print_on=False,
|
||||
)
|
||||
|
||||
|
||||
def test_agent_router_initialization_default():
|
||||
"""Test AgentRouter initialization with default parameters."""
|
||||
with patch("swarms.structs.agent_router.embedding"):
|
||||
router = AgentRouter()
|
||||
|
||||
assert router.embedding_model == "text-embedding-ada-002"
|
||||
assert router.n_agents == 1
|
||||
assert router.api_key is None
|
||||
assert router.api_base is None
|
||||
assert router.agents == []
|
||||
assert router.agent_embeddings == []
|
||||
assert router.agent_metadata == []
|
||||
|
||||
|
||||
def test_agent_router_initialization_custom():
|
||||
"""Test AgentRouter initialization with custom parameters."""
|
||||
with patch("swarms.structs.agent_router.embedding"), patch(
|
||||
"swarms.structs.agent.LiteLLM"
|
||||
) as mock_llm:
|
||||
mock_llm.return_value.run.return_value = "Test response"
|
||||
agents = [
|
||||
Agent(
|
||||
agent_name="test1",
|
||||
model_name="gpt-4o-mini",
|
||||
max_loops=1,
|
||||
verbose=False,
|
||||
print_on=False,
|
||||
),
|
||||
Agent(
|
||||
agent_name="test2",
|
||||
model_name="gpt-4o-mini",
|
||||
max_loops=1,
|
||||
verbose=False,
|
||||
print_on=False,
|
||||
),
|
||||
]
|
||||
router = AgentRouter(
|
||||
embedding_model="custom-model",
|
||||
n_agents=3,
|
||||
api_key="custom_key",
|
||||
api_base="custom_base",
|
||||
agents=agents,
|
||||
)
|
||||
|
||||
assert router.embedding_model == "custom-model"
|
||||
assert router.n_agents == 3
|
||||
assert router.api_key == "custom_key"
|
||||
assert router.api_base == "custom_base"
|
||||
assert len(router.agents) == 2
|
||||
|
||||
|
||||
def test_cosine_similarity_identical_vectors():
|
||||
"""Test cosine similarity with identical vectors."""
|
||||
router = AgentRouter()
|
||||
vec1 = [1.0, 0.0, 0.0]
|
||||
vec2 = [1.0, 0.0, 0.0]
|
||||
|
||||
result = router._cosine_similarity(vec1, vec2)
|
||||
assert result == 1.0
|
||||
|
||||
|
||||
def test_cosine_similarity_orthogonal_vectors():
|
||||
"""Test cosine similarity with orthogonal vectors."""
|
||||
router = AgentRouter()
|
||||
vec1 = [1.0, 0.0, 0.0]
|
||||
vec2 = [0.0, 1.0, 0.0]
|
||||
|
||||
result = router._cosine_similarity(vec1, vec2)
|
||||
assert result == 0.0
|
||||
|
||||
|
||||
def test_cosine_similarity_opposite_vectors():
|
||||
"""Test cosine similarity with opposite vectors."""
|
||||
router = AgentRouter()
|
||||
vec1 = [1.0, 0.0, 0.0]
|
||||
vec2 = [-1.0, 0.0, 0.0]
|
||||
|
||||
result = router._cosine_similarity(vec1, vec2)
|
||||
assert result == -1.0
|
||||
|
||||
|
||||
def test_cosine_similarity_different_lengths():
|
||||
"""Test cosine similarity with vectors of different lengths."""
|
||||
router = AgentRouter()
|
||||
vec1 = [1.0, 0.0]
|
||||
vec2 = [1.0, 0.0, 0.0]
|
||||
|
||||
with pytest.raises(
|
||||
ValueError, match="Vectors must have the same length"
|
||||
):
|
||||
router._cosine_similarity(vec1, vec2)
|
||||
|
||||
|
||||
@patch("swarms.structs.agent_router.embedding")
|
||||
def test_generate_embedding_success(mock_embedding):
|
||||
"""Test successful embedding generation."""
|
||||
mock_embedding.return_value.data = [
|
||||
Mock(embedding=[0.1, 0.2, 0.3, 0.4])
|
||||
]
|
||||
|
||||
router = AgentRouter()
|
||||
result = router._generate_embedding("test text")
|
||||
|
||||
assert result == [0.1, 0.2, 0.3, 0.4]
|
||||
mock_embedding.assert_called_once()
|
||||
|
||||
|
||||
@patch("swarms.structs.agent_router.embedding")
|
||||
def test_generate_embedding_error(mock_embedding):
|
||||
"""Test embedding generation error handling."""
|
||||
mock_embedding.side_effect = Exception("API Error")
|
||||
|
||||
router = AgentRouter()
|
||||
|
||||
with pytest.raises(Exception, match="API Error"):
|
||||
router._generate_embedding("test text")
|
||||
|
||||
|
||||
@patch("swarms.structs.agent_router.embedding")
|
||||
def test_add_agent_success(mock_embedding, test_agent):
|
||||
"""Test successful agent addition."""
|
||||
mock_embedding.return_value.data = [
|
||||
Mock(embedding=[0.1, 0.2, 0.3])
|
||||
]
|
||||
|
||||
router = AgentRouter()
|
||||
router.add_agent(test_agent)
|
||||
|
||||
assert len(router.agents) == 1
|
||||
assert len(router.agent_embeddings) == 1
|
||||
assert len(router.agent_metadata) == 1
|
||||
assert router.agents[0] == test_agent
|
||||
assert router.agent_embeddings[0] == [0.1, 0.2, 0.3]
|
||||
assert router.agent_metadata[0]["name"] == "test_agent"
|
||||
|
||||
|
||||
@patch("swarms.structs.agent_router.embedding")
|
||||
def test_add_agent_retry_error(mock_embedding, test_agent):
|
||||
"""Test agent addition with retry mechanism failure."""
|
||||
mock_embedding.side_effect = Exception("Embedding error")
|
||||
|
||||
router = AgentRouter()
|
||||
|
||||
# Should raise RetryError after retries are exhausted
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
router.add_agent(test_agent)
|
||||
|
||||
# Check that it's a retry error or contains the original error
|
||||
assert "Embedding error" in str(
|
||||
exc_info.value
|
||||
) or "RetryError" in str(exc_info.value)
|
||||
|
||||
|
||||
@patch("swarms.structs.agent_router.embedding")
|
||||
def test_add_agents_multiple(mock_embedding):
|
||||
"""Test adding multiple agents."""
|
||||
mock_embedding.return_value.data = [
|
||||
Mock(embedding=[0.1, 0.2, 0.3])
|
||||
]
|
||||
|
||||
with patch("swarms.structs.agent.LiteLLM") as mock_llm:
|
||||
mock_llm.return_value.run.return_value = "Test response"
|
||||
router = AgentRouter()
|
||||
agents = [
|
||||
Agent(
|
||||
agent_name="agent1",
|
||||
model_name="gpt-4o-mini",
|
||||
max_loops=1,
|
||||
verbose=False,
|
||||
print_on=False,
|
||||
),
|
||||
Agent(
|
||||
agent_name="agent2",
|
||||
model_name="gpt-4o-mini",
|
||||
max_loops=1,
|
||||
verbose=False,
|
||||
print_on=False,
|
||||
),
|
||||
Agent(
|
||||
agent_name="agent3",
|
||||
model_name="gpt-4o-mini",
|
||||
max_loops=1,
|
||||
verbose=False,
|
||||
print_on=False,
|
||||
),
|
||||
]
|
||||
|
||||
router.add_agents(agents)
|
||||
|
||||
assert len(router.agents) == 3
|
||||
assert len(router.agent_embeddings) == 3
|
||||
assert len(router.agent_metadata) == 3
|
||||
|
||||
|
||||
@patch("swarms.structs.agent_router.embedding")
|
||||
def test_find_best_agent_success(mock_embedding):
|
||||
"""Test successful best agent finding."""
|
||||
# Mock embeddings for agents and task
|
||||
mock_embedding.side_effect = [
|
||||
Mock(data=[Mock(embedding=[0.1, 0.2, 0.3])]), # agent1
|
||||
Mock(data=[Mock(embedding=[0.4, 0.5, 0.6])]), # agent2
|
||||
Mock(data=[Mock(embedding=[0.7, 0.8, 0.9])]), # task
|
||||
]
|
||||
|
||||
with patch("swarms.structs.agent.LiteLLM") as mock_llm:
|
||||
mock_llm.return_value.run.return_value = "Test response"
|
||||
router = AgentRouter()
|
||||
agent1 = Agent(
|
||||
agent_name="agent1",
|
||||
agent_description="First agent",
|
||||
system_prompt="Prompt 1",
|
||||
model_name="gpt-4o-mini",
|
||||
max_loops=1,
|
||||
verbose=False,
|
||||
print_on=False,
|
||||
)
|
||||
agent2 = Agent(
|
||||
agent_name="agent2",
|
||||
agent_description="Second agent",
|
||||
system_prompt="Prompt 2",
|
||||
model_name="gpt-4o-mini",
|
||||
max_loops=1,
|
||||
verbose=False,
|
||||
print_on=False,
|
||||
)
|
||||
|
||||
router.add_agent(agent1)
|
||||
router.add_agent(agent2)
|
||||
|
||||
# Mock the similarity calculation to return predictable results
|
||||
with patch.object(
|
||||
router, "_cosine_similarity"
|
||||
) as mock_similarity:
|
||||
mock_similarity.side_effect = [
|
||||
0.8,
|
||||
0.6,
|
||||
] # agent1 more similar
|
||||
|
||||
result = router.find_best_agent("test task")
|
||||
|
||||
assert result == agent1
|
||||
|
||||
|
||||
def test_find_best_agent_no_agents():
|
||||
"""Test finding best agent when no agents are available."""
|
||||
with patch("swarms.structs.agent_router.embedding"):
|
||||
router = AgentRouter()
|
||||
|
||||
result = router.find_best_agent("test task")
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
@patch("swarms.structs.agent_router.embedding")
|
||||
def test_find_best_agent_retry_error(mock_embedding):
|
||||
"""Test error handling in find_best_agent with retry mechanism."""
|
||||
mock_embedding.side_effect = Exception("API Error")
|
||||
|
||||
with patch("swarms.structs.agent.LiteLLM") as mock_llm:
|
||||
mock_llm.return_value.run.return_value = "Test response"
|
||||
router = AgentRouter()
|
||||
router.agents = [
|
||||
Agent(
|
||||
agent_name="agent1",
|
||||
model_name="gpt-4o-mini",
|
||||
max_loops=1,
|
||||
verbose=False,
|
||||
print_on=False,
|
||||
)
|
||||
]
|
||||
router.agent_embeddings = [[0.1, 0.2, 0.3]]
|
||||
|
||||
# Should raise RetryError after retries are exhausted
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
router.find_best_agent("test task")
|
||||
|
||||
# Check that it's a retry error or contains the original error
|
||||
assert "API Error" in str(
|
||||
exc_info.value
|
||||
) or "RetryError" in str(exc_info.value)
|
||||
|
||||
|
||||
@patch("swarms.structs.agent_router.embedding")
|
||||
def test_update_agent_history_success(mock_embedding, test_agent):
|
||||
"""Test successful agent history update."""
|
||||
mock_embedding.return_value.data = [
|
||||
Mock(embedding=[0.1, 0.2, 0.3])
|
||||
]
|
||||
|
||||
router = AgentRouter()
|
||||
router.add_agent(test_agent)
|
||||
|
||||
# Update agent history
|
||||
router.update_agent_history("test_agent")
|
||||
|
||||
# Verify the embedding was regenerated
|
||||
assert (
|
||||
mock_embedding.call_count == 2
|
||||
) # Once for add, once for update
|
||||
|
||||
|
||||
def test_update_agent_history_agent_not_found():
|
||||
"""Test updating history for non-existent agent."""
|
||||
with patch(
|
||||
"swarms.structs.agent_router.embedding"
|
||||
) as mock_embedding:
|
||||
mock_embedding.return_value.data = [
|
||||
Mock(embedding=[0.1, 0.2, 0.3])
|
||||
]
|
||||
router = AgentRouter()
|
||||
|
||||
# Should not raise an exception, just log a warning
|
||||
router.update_agent_history("non_existent_agent")
|
||||
|
||||
|
||||
@patch("swarms.structs.agent_router.embedding")
|
||||
def test_agent_metadata_structure(mock_embedding, test_agent):
|
||||
"""Test the structure of agent metadata."""
|
||||
mock_embedding.return_value.data = [
|
||||
Mock(embedding=[0.1, 0.2, 0.3])
|
||||
]
|
||||
|
||||
router = AgentRouter()
|
||||
router.add_agent(test_agent)
|
||||
|
||||
metadata = router.agent_metadata[0]
|
||||
assert "name" in metadata
|
||||
assert "text" in metadata
|
||||
assert metadata["name"] == "test_agent"
|
||||
assert (
|
||||
"test_agent A test agent You are a test agent"
|
||||
in metadata["text"]
|
||||
)
|
||||
|
||||
|
||||
def test_agent_router_edge_cases():
|
||||
"""Test various edge cases."""
|
||||
with patch(
|
||||
"swarms.structs.agent_router.embedding"
|
||||
) as mock_embedding:
|
||||
mock_embedding.return_value.data = [
|
||||
Mock(embedding=[0.1, 0.2, 0.3])
|
||||
]
|
||||
|
||||
router = AgentRouter()
|
||||
|
||||
# Test with empty string task
|
||||
result = router.find_best_agent("")
|
||||
assert result is None
|
||||
|
||||
# Test with very long task description
|
||||
long_task = "test " * 1000
|
||||
result = router.find_best_agent(long_task)
|
||||
assert result is None
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__])
|
@ -1,104 +0,0 @@
|
||||
from swarms.utils.auto_download_check_packages import (
|
||||
auto_check_and_download_package,
|
||||
check_and_install_package,
|
||||
)
|
||||
|
||||
|
||||
def test_check_and_install_package_pip():
|
||||
result = check_and_install_package("numpy", package_manager="pip")
|
||||
print(f"Test result for 'numpy' installation using pip: {result}")
|
||||
assert result, "Failed to install or verify 'numpy' using pip"
|
||||
|
||||
|
||||
def test_check_and_install_package_conda():
|
||||
result = check_and_install_package(
|
||||
"numpy", package_manager="conda"
|
||||
)
|
||||
print(
|
||||
f"Test result for 'numpy' installation using conda: {result}"
|
||||
)
|
||||
assert result, "Failed to install or verify 'numpy' using conda"
|
||||
|
||||
|
||||
def test_check_and_install_specific_version():
|
||||
result = check_and_install_package(
|
||||
"numpy", package_manager="pip", version="1.21.0"
|
||||
)
|
||||
print(
|
||||
f"Test result for specific version of 'numpy' installation using pip: {result}"
|
||||
)
|
||||
assert (
|
||||
result
|
||||
), "Failed to install or verify specific version of 'numpy' using pip"
|
||||
|
||||
|
||||
def test_check_and_install_with_upgrade():
|
||||
result = check_and_install_package(
|
||||
"numpy", package_manager="pip", upgrade=True
|
||||
)
|
||||
print(f"Test result for 'numpy' upgrade using pip: {result}")
|
||||
assert result, "Failed to upgrade 'numpy' using pip"
|
||||
|
||||
|
||||
def test_auto_check_and_download_single_package():
|
||||
result = auto_check_and_download_package(
|
||||
"scipy", package_manager="pip"
|
||||
)
|
||||
print(f"Test result for 'scipy' installation using pip: {result}")
|
||||
assert result, "Failed to install or verify 'scipy' using pip"
|
||||
|
||||
|
||||
def test_auto_check_and_download_multiple_packages():
|
||||
packages = ["scipy", "pandas"]
|
||||
result = auto_check_and_download_package(
|
||||
packages, package_manager="pip"
|
||||
)
|
||||
print(
|
||||
f"Test result for multiple packages installation using pip: {result}"
|
||||
)
|
||||
assert (
|
||||
result
|
||||
), f"Failed to install or verify one or more packages in {packages} using pip"
|
||||
|
||||
|
||||
def test_auto_check_and_download_multiple_packages_with_versions():
|
||||
packages = ["numpy:1.21.0", "pandas:1.3.0"]
|
||||
result = auto_check_and_download_package(
|
||||
packages, package_manager="pip"
|
||||
)
|
||||
print(
|
||||
f"Test result for multiple packages with versions installation using pip: {result}"
|
||||
)
|
||||
assert (
|
||||
result
|
||||
), f"Failed to install or verify one or more packages in {packages} with specific versions using pip"
|
||||
|
||||
|
||||
# Example of running tests
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
test_check_and_install_package_pip()
|
||||
print("test_check_and_install_package_pip passed")
|
||||
|
||||
test_check_and_install_package_conda()
|
||||
print("test_check_and_install_package_conda passed")
|
||||
|
||||
test_check_and_install_specific_version()
|
||||
print("test_check_and_install_specific_version passed")
|
||||
|
||||
test_check_and_install_with_upgrade()
|
||||
print("test_check_and_install_with_upgrade passed")
|
||||
|
||||
test_auto_check_and_download_single_package()
|
||||
print("test_auto_check_and_download_single_package passed")
|
||||
|
||||
test_auto_check_and_download_multiple_packages()
|
||||
print("test_auto_check_and_download_multiple_packages passed")
|
||||
|
||||
test_auto_check_and_download_multiple_packages_with_versions()
|
||||
print(
|
||||
"test_auto_check_and_download_multiple_packages_with_versions passed"
|
||||
)
|
||||
|
||||
except AssertionError as e:
|
||||
print(f"Test failed: {str(e)}")
|
Loading…
Reference in new issue