"""Console test runner for ``SQLiteConversation``.

Each ``test_*`` function exercises one group of conversation features against
a throw-away SQLite file, prints its progress with ``rich`` and removes the
files it created.  ``run_all_tests`` runs them in order and writes a JSON
report to ``test_report.json``.
"""

import datetime
import json
import os
import time
from typing import Any, Callable, Dict, List, Tuple

from loguru import logger
from rich.console import Console
from rich.panel import Panel
from rich.table import Table

from swarms.communication.sqlite_wrap import (
    SQLiteConversation,
    Message,
    MessageType,
)

console = Console()


def _remove_files(*paths: str) -> None:
    """Delete every path in *paths*, ignoring files that do not exist.

    Used from ``finally`` blocks: a test may fail before one of its files has
    been created, and a ``FileNotFoundError`` raised during cleanup would
    otherwise replace the exception that actually caused the failure.
    """
    for path in paths:
        try:
            os.remove(path)
        except FileNotFoundError:
            # Nothing to clean up for this path.
            pass


def print_test_header(test_name: str) -> None:
    """Print a panel announcing that *test_name* is about to run."""
    console.print(
        Panel(
            f"[bold blue]Running Test: {test_name}[/bold blue]",
            expand=False,
        )
    )


def print_test_result(
    test_name: str, success: bool, message: str, execution_time: float
) -> None:
    """Print the outcome of one test.

    Args:
        test_name: Display name of the test.
        success: Whether the test passed.
        message: Result text (return value or error message).
        execution_time: Elapsed time in seconds.
    """
    status = (
        "[bold green]PASSED[/bold green]"
        if success
        else "[bold red]FAILED[/bold red]"
    )
    console.print(f"\n{status} - {test_name}")
    console.print(f"Message: {message}")
    console.print(f"Execution time: {execution_time:.3f} seconds\n")


def print_messages(
    messages: List[Dict], title: str = "Messages"
) -> None:
    """Print *messages* as a table with role, content, type and timestamp.

    Args:
        messages: Message dictionaries; the keys ``role``, ``content``,
            ``message_type`` and ``timestamp`` are read, each optional.
        title: Title shown above the table.
    """
    table = Table(title=title)
    table.add_column("Role", style="cyan")
    table.add_column("Content", style="green")
    table.add_column("Type", style="yellow")
    table.add_column("Timestamp", style="magenta")

    for msg in messages:
        raw_content = msg.get("content", "")
        # Inspect the raw value *before* stringifying it; converting first
        # made the dict/list branch unreachable.
        if isinstance(raw_content, (dict, list)):
            # Display only, so non-JSON values are deliberately stringified
            # rather than allowed to abort the printout.
            content = json.dumps(raw_content, default=str)
        else:
            content = str(raw_content)
        table.add_row(
            msg.get("role", ""),
            content,
            str(msg.get("message_type", "")),
            str(msg.get("timestamp", "")),
        )

    console.print(table)


def run_test(
    test_func: Callable[..., Any], *args, **kwargs
) -> Tuple[bool, str, float]:
    """Run a test function and return its results.

    Args:
        test_func: The test function to run.
        *args: Positional arguments for the test function.
        **kwargs: Keyword arguments for the test function.

    Returns:
        Tuple[bool, str, float]: ``(success, message, execution_time)``.
        ``success`` is False when the function raised an ``Exception``;
        ``message`` is then the exception text, or its ``repr`` when the text
        is empty (as for a bare ``assert``).  ``execution_time`` is in
        seconds.
    """
    # perf_counter is monotonic, so the measurement is unaffected by
    # wall-clock adjustments during the test.
    start = time.perf_counter()
    try:
        result = test_func(*args, **kwargs)
    except Exception as exc:
        # Top-level test boundary: any failure is reported, not propagated.
        elapsed = time.perf_counter() - start
        return False, str(exc) or repr(exc), elapsed
    elapsed = time.perf_counter() - start
    return True, str(result), elapsed


def test_basic_conversation() -> bool:
    """Add two messages and check they are returned with the right roles."""
    print_test_header("Basic Conversation Test")

    db_path = "test_conversations.db"
    try:
        conversation = SQLiteConversation(db_path=db_path)

        # Test adding messages
        console.print("\n[bold]Adding messages...[/bold]")
        conversation.add("user", "Hello")
        conversation.add("assistant", "Hi there!")

        # Test getting messages
        console.print("\n[bold]Retrieved messages:[/bold]")
        messages = conversation.get_messages()
        print_messages(messages)

        assert len(messages) == 2
        assert messages[0]["role"] == "user"
        assert messages[1]["role"] == "assistant"
    finally:
        # Always clean up so a failure cannot leak state into later tests,
        # which reuse the same database path.
        _remove_files(db_path)

    return True


def test_message_types() -> bool:
    """Add text, dict, list and FUNCTION-typed messages and count them."""
    print_test_header("Message Types Test")

    db_path = "test_conversations.db"
    try:
        conversation = SQLiteConversation(db_path=db_path)

        # Test different content types
        console.print("\n[bold]Adding different message types...[/bold]")
        conversation.add("user", "Simple text")
        conversation.add(
            "assistant", {"type": "json", "content": "Complex data"}
        )
        conversation.add("system", ["list", "of", "items"])
        conversation.add(
            "function",
            "Function result",
            message_type=MessageType.FUNCTION,
        )

        console.print("\n[bold]Retrieved messages:[/bold]")
        messages = conversation.get_messages()
        print_messages(messages)

        assert len(messages) == 4
    finally:
        _remove_files(db_path)

    return True


def test_conversation_operations() -> bool:
    """Batch-add three messages and check statistics and per-role counts."""
    print_test_header("Conversation Operations Test")

    db_path = "test_conversations.db"
    try:
        conversation = SQLiteConversation(db_path=db_path)

        # Test batch operations
        console.print("\n[bold]Adding batch messages...[/bold]")
        messages = [
            Message(role="user", content="Message 1"),
            Message(role="assistant", content="Message 2"),
            Message(role="user", content="Message 3"),
        ]
        conversation.batch_add(messages)

        console.print("\n[bold]Retrieved messages:[/bold]")
        all_messages = conversation.get_messages()
        print_messages(all_messages)

        # Test statistics
        console.print("\n[bold]Conversation Statistics:[/bold]")
        stats = conversation.get_statistics()
        console.print(json.dumps(stats, indent=2))

        # Test role counting
        console.print("\n[bold]Role Counts:[/bold]")
        role_counts = conversation.count_messages_by_role()
        console.print(json.dumps(role_counts, indent=2))

        assert stats["total_messages"] == 3
        assert role_counts["user"] == 2
        assert role_counts["assistant"] == 1
    finally:
        _remove_files(db_path)

    return True


def test_file_operations() -> bool:
    """Save the conversation to JSON and YAML and load each file back."""
    print_test_header("File Operations Test")

    db_path = "test_conversations.db"
    json_path = "test_conversation.json"
    yaml_path = "test_conversation.yaml"

    try:
        conversation = SQLiteConversation(db_path=db_path)
        conversation.add("user", "Test message")

        # Test JSON operations
        console.print("\n[bold]Testing JSON operations...[/bold]")
        assert conversation.save_as_json(json_path)
        console.print(f"Saved to JSON: {json_path}")

        conversation.start_new_conversation()
        assert conversation.load_from_json(json_path)
        console.print("Loaded from JSON")

        # Test YAML operations
        console.print("\n[bold]Testing YAML operations...[/bold]")
        assert conversation.save_as_yaml(yaml_path)
        console.print(f"Saved to YAML: {yaml_path}")

        conversation.start_new_conversation()
        assert conversation.load_from_yaml(yaml_path)
        console.print("Loaded from YAML")
    finally:
        # Some of these files may not exist yet if the test failed early.
        _remove_files(db_path, json_path, yaml_path)

    return True


def test_search_and_filter() -> bool:
    """Search for a keyword and filter by role, expecting two hits each."""
    print_test_header("Search and Filter Test")

    db_path = "test_conversations.db"
    try:
        conversation = SQLiteConversation(db_path=db_path)

        # Add test messages
        console.print("\n[bold]Adding test messages...[/bold]")
        conversation.add("user", "Hello world")
        conversation.add("assistant", "Hello there")
        conversation.add("user", "Goodbye world")

        # Test search
        console.print("\n[bold]Searching for 'world'...[/bold]")
        results = conversation.search_messages("world")
        print_messages(results, "Search Results")

        # Test role filtering
        console.print("\n[bold]Filtering user messages...[/bold]")
        user_messages = conversation.get_messages_by_role("user")
        print_messages(user_messages, "User Messages")

        assert len(results) == 2
        assert len(user_messages) == 2
    finally:
        _remove_files(db_path)

    return True


def test_conversation_management() -> bool:
    """Check a new conversation gets a new ID and that deletion succeeds."""
    print_test_header("Conversation Management Test")

    db_path = "test_conversations.db"
    try:
        conversation = SQLiteConversation(db_path=db_path)

        # Test conversation ID generation
        console.print("\n[bold]Testing conversation IDs...[/bold]")
        conv_id1 = conversation.get_conversation_id()
        console.print(f"First conversation ID: {conv_id1}")

        conversation.start_new_conversation()
        conv_id2 = conversation.get_conversation_id()
        console.print(f"Second conversation ID: {conv_id2}")

        assert conv_id1 != conv_id2

        # Test conversation deletion
        console.print("\n[bold]Testing conversation deletion...[/bold]")
        conversation.add("user", "Test message")
        assert conversation.delete_current_conversation()
        console.print("Conversation deleted successfully")
    finally:
        _remove_files(db_path)

    return True


def generate_test_report(
    test_results: List[Dict[str, Any]]
) -> Dict[str, Any]:
    """Generate a JSON-serialisable test report.

    Args:
        test_results: One dict per test; the keys ``success`` and
            ``execution_time`` are read from each.

    Returns:
        Dict with a ``timestamp``, a ``summary`` (totals, pass/fail counts,
        total and average execution time) and the original ``test_results``.
        The average is 0 when there are no results.
    """
    total_tests = len(test_results)
    passed_tests = sum(
        1 for result in test_results if result["success"]
    )
    failed_tests = total_tests - passed_tests
    total_time = sum(
        result["execution_time"] for result in test_results
    )

    report = {
        "timestamp": datetime.datetime.now().isoformat(),
        "summary": {
            "total_tests": total_tests,
            "passed_tests": passed_tests,
            "failed_tests": failed_tests,
            "total_execution_time": total_time,
            # Guard against division by zero for an empty result list.
            "average_execution_time": (
                total_time / total_tests if total_tests > 0 else 0
            ),
        },
        "test_results": test_results,
    }

    return report


def run_all_tests() -> None:
    """Run every test, print the results and save ``test_report.json``."""
    console.print(
        Panel(
            "[bold blue]Starting Test Suite[/bold blue]", expand=False
        )
    )

    tests = [
        ("Basic Conversation", test_basic_conversation),
        ("Message Types", test_message_types),
        ("Conversation Operations", test_conversation_operations),
        ("File Operations", test_file_operations),
        ("Search and Filter", test_search_and_filter),
        ("Conversation Management", test_conversation_management),
    ]

    test_results = []

    for test_name, test_func in tests:
        logger.info(f"Running test: {test_name}")
        success, message, execution_time = run_test(test_func)

        print_test_result(test_name, success, message, execution_time)

        result = {
            "test_name": test_name,
            "success": success,
            "message": message,
            "execution_time": execution_time,
            "timestamp": datetime.datetime.now().isoformat(),
        }

        if success:
            logger.success(f"Test passed: {test_name}")
        else:
            logger.error(f"Test failed: {test_name} - {message}")

        test_results.append(result)

    # Generate and save report
    report = generate_test_report(test_results)
    report_path = "test_report.json"

    # Explicit encoding so the report does not depend on the platform locale.
    with open(report_path, "w", encoding="utf-8") as f:
        json.dump(report, f, indent=2)

    # Print final summary
    summary = report["summary"]
    console.print("\n[bold blue]Test Suite Summary[/bold blue]")
    console.print(
        Panel(
            f"Total tests: {summary['total_tests']}\n"
            f"Passed tests: {summary['passed_tests']}\n"
            f"Failed tests: {summary['failed_tests']}\n"
            f"Total execution time: {summary['total_execution_time']:.2f} seconds",
            title="Summary",
            expand=False,
        )
    )

    logger.info(f"Test report saved to {report_path}")


if __name__ == "__main__":
    run_all_tests()