You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
swarms/tests/communication/test_sqlite_wrapper.py

387 lines
11 KiB

import json
import datetime
import os
from typing import Dict, List, Any, Tuple
from loguru import logger
from swarms.communication.sqlite_wrap import (
SQLiteConversation,
Message,
MessageType,
)
from rich.console import Console
from rich.table import Table
from rich.panel import Panel
# Shared rich console; the print helpers and test functions in this file
# all write their formatted output through it.
console = Console()
def print_test_header(test_name: str) -> None:
    """Announce the start of a test inside a compact rich panel.

    Args:
        test_name: Human-readable name shown in the banner.
    """
    banner = f"[bold blue]Running Test: {test_name}[/bold blue]"
    header_panel = Panel(banner, expand=False)
    console.print(header_panel)
def print_test_result(
    test_name: str, success: bool, message: str, execution_time: float
) -> None:
    """Write the outcome of one test to the console.

    Args:
        test_name: Name of the test that was run.
        success: Whether the test passed.
        message: Result or error text to display.
        execution_time: Duration of the test in seconds.
    """
    if success:
        verdict = "[bold green]PASSED[/bold green]"
    else:
        verdict = "[bold red]FAILED[/bold red]"

    console.print(f"\n{verdict} - {test_name}")
    console.print(f"Message: {message}")
    console.print(f"Execution time: {execution_time:.3f} seconds\n")
def print_messages(
    messages: List[Dict], title: str = "Messages"
) -> None:
    """Print messages in a formatted table.

    Args:
        messages: Message dictionaries. The keys "role", "content",
            "message_type" and "timestamp" are read; a missing key is
            rendered as an empty cell.
        title: Caption displayed above the table.
    """
    table = Table(title=title)
    table.add_column("Role", style="cyan")
    table.add_column("Content", style="green")
    table.add_column("Type", style="yellow")
    table.add_column("Timestamp", style="magenta")
    for msg in messages:
        content = msg.get("content", "")
        # Inspect the type *before* stringifying: once the value has been
        # passed through str() the dict/list branch can never be taken.
        if isinstance(content, (dict, list)):
            try:
                content = json.dumps(content)
            except (TypeError, ValueError):
                # Structured content holding non-JSON values is still
                # displayable; fall back to its plain string form.
                content = str(content)
        else:
            content = str(content)
        table.add_row(
            msg.get("role", ""),
            content,
            str(msg.get("message_type", "")),
            str(msg.get("timestamp", "")),
        )
    console.print(table)
def run_test(
    test_func: callable, *args, **kwargs
) -> Tuple[bool, str, float]:
    """
    Run a test function and return its results.

    Args:
        test_func: The test function to run
        *args: Arguments for the test function
        **kwargs: Keyword arguments for the test function

    Returns:
        Tuple[bool, str, float]: (success, message, execution_time).
        On success, message is ``str()`` of the function's return value.
        On failure, it is the exception text, or the exception class name
        when that text is empty (e.g. a bare ``assert`` that failed).
    """
    start_time = datetime.datetime.now()
    try:
        message = str(test_func(*args, **kwargs))
        success = True
    except Exception as e:
        # A bare `assert cond` raises AssertionError with no text; report
        # the exception type so the failure is never an empty string.
        message = str(e) or type(e).__name__
        success = False
    execution_time = (
        datetime.datetime.now() - start_time
    ).total_seconds()
    return success, message, execution_time
def test_basic_conversation() -> bool:
    """Test basic conversation operations.

    Adds one user and one assistant message, reads them back and checks
    the count and the role order.

    Returns:
        bool: True when every assertion holds.

    Raises:
        AssertionError: If the stored messages differ from what was added.
    """
    print_test_header("Basic Conversation Test")
    db_path = "test_conversations.db"
    try:
        conversation = SQLiteConversation(db_path=db_path)
        # Test adding messages
        console.print("\n[bold]Adding messages...[/bold]")
        conversation.add("user", "Hello")
        conversation.add("assistant", "Hi there!")
        # Test getting messages
        console.print("\n[bold]Retrieved messages:[/bold]")
        messages = conversation.get_messages()
        print_messages(messages)
        assert len(messages) == 2
        assert messages[0]["role"] == "user"
        assert messages[1]["role"] == "assistant"
    finally:
        # Cleanup must also run when an assertion fails; the tests in this
        # file share db_path, so a leftover file would skew later counts.
        if os.path.exists(db_path):
            os.remove(db_path)
    return True
def test_message_types() -> bool:
    """Test different message types and content formats.

    Stores a plain string, a dict, a list and a function-typed message,
    then checks that four messages are returned.

    Returns:
        bool: True when every assertion holds.

    Raises:
        AssertionError: If the number of stored messages is not four.
    """
    print_test_header("Message Types Test")
    db_path = "test_conversations.db"
    try:
        conversation = SQLiteConversation(db_path=db_path)
        # Test different content types
        console.print(
            "\n[bold]Adding different message types...[/bold]"
        )
        conversation.add("user", "Simple text")
        conversation.add(
            "assistant", {"type": "json", "content": "Complex data"}
        )
        conversation.add("system", ["list", "of", "items"])
        conversation.add(
            "function",
            "Function result",
            message_type=MessageType.FUNCTION,
        )
        console.print("\n[bold]Retrieved messages:[/bold]")
        messages = conversation.get_messages()
        print_messages(messages)
        assert len(messages) == 4
    finally:
        # Cleanup must also run when an assertion fails; the tests in this
        # file share db_path, so a leftover file would skew later counts.
        if os.path.exists(db_path):
            os.remove(db_path)
    return True
def test_conversation_operations() -> bool:
    """Test various conversation operations.

    Batch-adds three messages, prints them together with the statistics
    and per-role counts, then checks the totals.

    Returns:
        bool: True when every assertion holds.

    Raises:
        AssertionError: If the statistics or role counts are unexpected.
    """
    print_test_header("Conversation Operations Test")
    db_path = "test_conversations.db"
    try:
        conversation = SQLiteConversation(db_path=db_path)
        # Test batch operations
        console.print("\n[bold]Adding batch messages...[/bold]")
        messages = [
            Message(role="user", content="Message 1"),
            Message(role="assistant", content="Message 2"),
            Message(role="user", content="Message 3"),
        ]
        conversation.batch_add(messages)
        console.print("\n[bold]Retrieved messages:[/bold]")
        all_messages = conversation.get_messages()
        print_messages(all_messages)
        # Test statistics
        console.print("\n[bold]Conversation Statistics:[/bold]")
        stats = conversation.get_statistics()
        console.print(json.dumps(stats, indent=2))
        # Test role counting
        console.print("\n[bold]Role Counts:[/bold]")
        role_counts = conversation.count_messages_by_role()
        console.print(json.dumps(role_counts, indent=2))
        assert stats["total_messages"] == 3
        assert role_counts["user"] == 2
        assert role_counts["assistant"] == 1
    finally:
        # Cleanup must also run when an assertion fails; the tests in this
        # file share db_path, so a leftover file would skew later counts.
        if os.path.exists(db_path):
            os.remove(db_path)
    return True
def test_file_operations() -> bool:
    """Test file operations (JSON/YAML).

    Saves a one-message conversation to JSON and to YAML, starting a new
    conversation and loading the file back after each save.

    Returns:
        bool: True when every assertion holds.

    Raises:
        AssertionError: If a save or load call returns a falsy value.
    """
    print_test_header("File Operations Test")
    db_path = "test_conversations.db"
    json_path = "test_conversation.json"
    yaml_path = "test_conversation.yaml"
    try:
        conversation = SQLiteConversation(db_path=db_path)
        conversation.add("user", "Test message")
        # Test JSON operations
        console.print("\n[bold]Testing JSON operations...[/bold]")
        assert conversation.save_as_json(json_path)
        console.print(f"Saved to JSON: {json_path}")
        conversation.start_new_conversation()
        assert conversation.load_from_json(json_path)
        console.print("Loaded from JSON")
        # Test YAML operations
        console.print("\n[bold]Testing YAML operations...[/bold]")
        assert conversation.save_as_yaml(yaml_path)
        console.print(f"Saved to YAML: {yaml_path}")
        conversation.start_new_conversation()
        assert conversation.load_from_yaml(yaml_path)
        console.print("Loaded from YAML")
    finally:
        # Cleanup must also run when an assertion fails. Only files that
        # were actually created are removed, so a failure before the YAML
        # step does not raise a second error from the cleanup itself.
        for path in (db_path, json_path, yaml_path):
            if os.path.exists(path):
                os.remove(path)
    return True
def test_search_and_filter() -> bool:
    """Test search and filter operations.

    Adds three messages, searches for the term "world" and filters by the
    "user" role, expecting two hits from each query.

    Returns:
        bool: True when every assertion holds.

    Raises:
        AssertionError: If either query returns an unexpected hit count.
    """
    print_test_header("Search and Filter Test")
    db_path = "test_conversations.db"
    try:
        conversation = SQLiteConversation(db_path=db_path)
        # Add test messages
        console.print("\n[bold]Adding test messages...[/bold]")
        conversation.add("user", "Hello world")
        conversation.add("assistant", "Hello there")
        conversation.add("user", "Goodbye world")
        # Test search
        console.print("\n[bold]Searching for 'world'...[/bold]")
        results = conversation.search_messages("world")
        print_messages(results, "Search Results")
        # Test role filtering
        console.print("\n[bold]Filtering user messages...[/bold]")
        user_messages = conversation.get_messages_by_role("user")
        print_messages(user_messages, "User Messages")
        assert len(results) == 2
        assert len(user_messages) == 2
    finally:
        # Cleanup must also run when an assertion fails; the tests in this
        # file share db_path, so a leftover file would skew later counts.
        if os.path.exists(db_path):
            os.remove(db_path)
    return True
def test_conversation_management() -> bool:
    """Test conversation management features.

    Checks that starting a new conversation yields a different
    conversation ID, and that deleting the current conversation reports
    success.

    Returns:
        bool: True when every assertion holds.

    Raises:
        AssertionError: If the IDs match or the deletion returns a falsy
            value.
    """
    print_test_header("Conversation Management Test")
    db_path = "test_conversations.db"
    try:
        conversation = SQLiteConversation(db_path=db_path)
        # Test conversation ID generation
        console.print("\n[bold]Testing conversation IDs...[/bold]")
        conv_id1 = conversation.get_conversation_id()
        console.print(f"First conversation ID: {conv_id1}")
        conversation.start_new_conversation()
        conv_id2 = conversation.get_conversation_id()
        console.print(f"Second conversation ID: {conv_id2}")
        assert conv_id1 != conv_id2
        # Test conversation deletion
        console.print(
            "\n[bold]Testing conversation deletion...[/bold]"
        )
        conversation.add("user", "Test message")
        assert conversation.delete_current_conversation()
        console.print("Conversation deleted successfully")
    finally:
        # Cleanup must also run when an assertion fails; the tests in this
        # file share db_path, so a leftover file would skew later counts.
        if os.path.exists(db_path):
            os.remove(db_path)
    return True
def generate_test_report(
    test_results: List[Dict[str, Any]]
) -> Dict[str, Any]:
    """
    Build a JSON-serialisable report from individual test results.

    Args:
        test_results: One dict per test; the "success" and
            "execution_time" entries are read here.

    Returns:
        Dict with a "timestamp", a "summary" of counts and timings, and
        the unmodified "test_results" list.
    """
    count = len(test_results)
    passed = len([entry for entry in test_results if entry["success"]])
    elapsed = sum(entry["execution_time"] for entry in test_results)

    # An empty run has no meaningful mean; report 0 rather than dividing.
    if count > 0:
        mean_elapsed = elapsed / count
    else:
        mean_elapsed = 0

    summary = {
        "total_tests": count,
        "passed_tests": passed,
        "failed_tests": count - passed,
        "total_execution_time": elapsed,
        "average_execution_time": mean_elapsed,
    }
    return {
        "timestamp": datetime.datetime.now().isoformat(),
        "summary": summary,
        "test_results": test_results,
    }
def run_all_tests() -> None:
    """Execute every test in order, log each outcome and save a report.

    Each test is run through ``run_test``, so a failing test is recorded
    instead of aborting the suite. The collected results are written to
    "test_report.json" and a summary panel is printed.
    """
    console.print(
        Panel(
            "[bold blue]Starting Test Suite[/bold blue]", expand=False
        )
    )
    suite = [
        ("Basic Conversation", test_basic_conversation),
        ("Message Types", test_message_types),
        ("Conversation Operations", test_conversation_operations),
        ("File Operations", test_file_operations),
        ("Search and Filter", test_search_and_filter),
        ("Conversation Management", test_conversation_management),
    ]

    collected = []
    for label, func in suite:
        logger.info(f"Running test: {label}")
        passed, detail, elapsed = run_test(func)
        print_test_result(label, passed, detail, elapsed)
        record = {
            "test_name": label,
            "success": passed,
            "message": detail,
            "execution_time": elapsed,
            "timestamp": datetime.datetime.now().isoformat(),
        }
        if not passed:
            logger.error(f"Test failed: {label} - {detail}")
        else:
            logger.success(f"Test passed: {label}")
        collected.append(record)

    # Generate and save report
    report = generate_test_report(collected)
    report_path = "test_report.json"
    with open(report_path, "w") as f:
        json.dump(report, f, indent=2)

    # Print final summary
    totals = report["summary"]
    console.print("\n[bold blue]Test Suite Summary[/bold blue]")
    summary_text = (
        f"Total tests: {totals['total_tests']}\n"
        f"Passed tests: {totals['passed_tests']}\n"
        f"Failed tests: {totals['failed_tests']}\n"
        f"Total execution time: {totals['total_execution_time']:.2f} seconds"
    )
    console.print(Panel(summary_text, title="Summary", expand=False))
    logger.info(f"Test report saved to {report_path}")
# Run the full suite only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    run_all_tests()