"""Script-style tests for ``DuckDBConversation``.

Each test builds its own temporary DuckDB file through ``setup_test`` and
removes it again through ``cleanup_test``. Run the file directly to execute
every test in order via ``run_all_tests``.
"""

import os
from pathlib import Path
import tempfile
import threading

from swarms.communication.duckdb_wrap import (
    DuckDBConversation,
    Message,
    MessageType,
)


def setup_test():
    """Set up test environment.

    Returns:
        A ``(temp_dir, db_path, conversation)`` tuple: the
        ``TemporaryDirectory`` that owns the files, the ``Path`` of the
        database inside it, and a ``DuckDBConversation`` opened on that
        path with timestamps and logging enabled.
    """
    temp_dir = tempfile.TemporaryDirectory()
    db_path = Path(temp_dir.name) / "test_conversations.duckdb"
    conversation = DuckDBConversation(
        db_path=str(db_path),
        enable_timestamps=True,
        enable_logging=True,
    )
    return temp_dir, db_path, conversation


def cleanup_test(temp_dir, db_path):
    """Clean up test environment.

    Args:
        temp_dir: The ``TemporaryDirectory`` returned by ``setup_test``.
        db_path: Path of the database file to delete if it exists.
    """
    try:
        if os.path.exists(db_path):
            os.remove(db_path)
    finally:
        # Always release the directory, even when removing the database
        # file fails, so a failed removal does not leak the temp directory.
        temp_dir.cleanup()


def _expect_exception(operation, failure_message):
    """Call ``operation`` and fail unless it raises an exception.

    The ``AssertionError`` is raised from the ``else`` clause, outside the
    guarded call. Raising it inside the ``try`` body would let the
    ``except Exception`` handler swallow it, because ``AssertionError`` is
    itself a subclass of ``Exception``.

    Args:
        operation: Zero-argument callable expected to raise.
        failure_message: Message of the ``AssertionError`` raised when
            ``operation`` returns normally.

    Raises:
        AssertionError: If ``operation`` completes without raising.
    """
    try:
        operation()
    except Exception:
        return
    else:
        raise AssertionError(failure_message)


def test_initialization():
    """Test conversation initialization."""
    temp_dir, db_path, _ = setup_test()
    try:
        # A second instance built with only db_path, so the remaining
        # attributes checked below come from the constructor defaults.
        conv = DuckDBConversation(db_path=str(db_path))
        assert conv.db_path == db_path, "Database path mismatch"
        assert (
            conv.table_name == "conversations"
        ), "Table name mismatch"
        assert (
            conv.enable_timestamps is True
        ), "Timestamps should be enabled"
        assert (
            conv.current_conversation_id is not None
        ), "Conversation ID should not be None"
        print("✓ Initialization test passed")
    finally:
        cleanup_test(temp_dir, db_path)


def test_add_message():
    """Test adding a single message."""
    temp_dir, db_path, conversation = setup_test()
    try:
        msg_id = conversation.add(
            role="user",
            content="Hello, world!",
            message_type=MessageType.USER,
        )
        assert msg_id is not None, "Message ID should not be None"
        assert isinstance(
            msg_id, int
        ), "Message ID should be an integer"
        print("✓ Add message test passed")
    finally:
        cleanup_test(temp_dir, db_path)


def test_add_complex_message():
    """Test adding a message with complex content."""
    temp_dir, db_path, conversation = setup_test()
    try:
        complex_content = {
            "text": "Hello",
            "data": [1, 2, 3],
            "nested": {"key": "value"},
        }
        msg_id = conversation.add(
            role="assistant",
            content=complex_content,
            message_type=MessageType.ASSISTANT,
            metadata={"source": "test"},
            token_count=10,
        )
        assert msg_id is not None, "Message ID should not be None"
        print("✓ Add complex message test passed")
    finally:
        cleanup_test(temp_dir, db_path)


def test_batch_add():
    """Test batch adding messages."""
    temp_dir, db_path, conversation = setup_test()
    try:
        messages = [
            Message(
                role="user",
                content="First message",
                message_type=MessageType.USER,
            ),
            Message(
                role="assistant",
                content="Second message",
                message_type=MessageType.ASSISTANT,
            ),
        ]
        msg_ids = conversation.batch_add(messages)
        assert len(msg_ids) == 2, "Should have 2 message IDs"
        assert all(
            isinstance(msg_id, int) for msg_id in msg_ids
        ), "All IDs should be integers"
        print("✓ Batch add test passed")
    finally:
        cleanup_test(temp_dir, db_path)


def test_get_str():
    """Test getting conversation as string."""
    temp_dir, db_path, conversation = setup_test()
    try:
        conversation.add("user", "Hello")
        conversation.add("assistant", "Hi there!")
        conv_str = conversation.get_str()
        assert "user: Hello" in conv_str, "User message not found"
        assert (
            "assistant: Hi there!" in conv_str
        ), "Assistant message not found"
        print("✓ Get string test passed")
    finally:
        cleanup_test(temp_dir, db_path)


def test_get_messages():
    """Test getting messages with pagination."""
    temp_dir, db_path, conversation = setup_test()
    try:
        for i in range(5):
            conversation.add("user", f"Message {i}")

        all_messages = conversation.get_messages()
        assert len(all_messages) == 5, "Should have 5 messages"

        limited_messages = conversation.get_messages(limit=2)
        assert (
            len(limited_messages) == 2
        ), "Should have 2 limited messages"

        offset_messages = conversation.get_messages(offset=2)
        assert (
            len(offset_messages) == 3
        ), "Should have 3 offset messages"
        print("✓ Get messages test passed")
    finally:
        cleanup_test(temp_dir, db_path)


def test_search_messages():
    """Test searching messages."""
    temp_dir, db_path, conversation = setup_test()
    try:
        conversation.add("user", "Hello world")
        conversation.add("assistant", "Hello there")
        conversation.add("user", "Goodbye world")

        results = conversation.search_messages("world")
        assert (
            len(results) == 2
        ), "Should find 2 messages with 'world'"
        assert all(
            "world" in msg["content"] for msg in results
        ), "All results should contain 'world'"
        print("✓ Search messages test passed")
    finally:
        cleanup_test(temp_dir, db_path)


def test_get_statistics():
    """Test getting conversation statistics."""
    temp_dir, db_path, conversation = setup_test()
    try:
        conversation.add("user", "Hello", token_count=2)
        conversation.add("assistant", "Hi", token_count=1)

        stats = conversation.get_statistics()
        assert (
            stats["total_messages"] == 2
        ), "Should have 2 total messages"
        assert (
            stats["unique_roles"] == 2
        ), "Should have 2 unique roles"
        assert (
            stats["total_tokens"] == 3
        ), "Should have 3 total tokens"
        print("✓ Get statistics test passed")
    finally:
        cleanup_test(temp_dir, db_path)


def test_json_operations():
    """Test JSON save and load operations."""
    temp_dir, db_path, conversation = setup_test()
    try:
        conversation.add("user", "Hello")
        conversation.add("assistant", "Hi")

        json_path = Path(temp_dir.name) / "test_conversation.json"
        conversation.save_as_json(str(json_path))
        assert json_path.exists(), "JSON file should exist"

        # Load into a conversation backed by a separate database file.
        new_conversation = DuckDBConversation(
            db_path=str(Path(temp_dir.name) / "new.duckdb")
        )
        assert new_conversation.load_from_json(
            str(json_path)
        ), "Should load from JSON"
        assert (
            len(new_conversation.get_messages()) == 2
        ), "Should have 2 messages after load"
        print("✓ JSON operations test passed")
    finally:
        cleanup_test(temp_dir, db_path)


def test_yaml_operations():
    """Test YAML save and load operations."""
    temp_dir, db_path, conversation = setup_test()
    try:
        conversation.add("user", "Hello")
        conversation.add("assistant", "Hi")

        yaml_path = Path(temp_dir.name) / "test_conversation.yaml"
        conversation.save_as_yaml(str(yaml_path))
        assert yaml_path.exists(), "YAML file should exist"

        # Load into a conversation backed by a separate database file.
        new_conversation = DuckDBConversation(
            db_path=str(Path(temp_dir.name) / "new.duckdb")
        )
        assert new_conversation.load_from_yaml(
            str(yaml_path)
        ), "Should load from YAML"
        assert (
            len(new_conversation.get_messages()) == 2
        ), "Should have 2 messages after load"
        print("✓ YAML operations test passed")
    finally:
        cleanup_test(temp_dir, db_path)


def test_message_types():
    """Test different message types."""
    temp_dir, db_path, conversation = setup_test()
    try:
        conversation.add(
            "system",
            "System message",
            message_type=MessageType.SYSTEM,
        )
        conversation.add(
            "user", "User message", message_type=MessageType.USER
        )
        conversation.add(
            "assistant",
            "Assistant message",
            message_type=MessageType.ASSISTANT,
        )
        conversation.add(
            "function",
            "Function message",
            message_type=MessageType.FUNCTION,
        )
        conversation.add(
            "tool", "Tool message", message_type=MessageType.TOOL
        )

        messages = conversation.get_messages()
        assert len(messages) == 5, "Should have 5 messages"
        assert all(
            "message_type" in msg for msg in messages
        ), "All messages should have type"
        print("✓ Message types test passed")
    finally:
        cleanup_test(temp_dir, db_path)


def test_delete_operations():
    """Test deletion operations."""
    temp_dir, db_path, conversation = setup_test()
    try:
        conversation.add("user", "Hello")
        conversation.add("assistant", "Hi")

        assert (
            conversation.delete_current_conversation()
        ), "Should delete conversation"
        assert (
            len(conversation.get_messages()) == 0
        ), "Should have no messages after delete"

        conversation.add("user", "New message")
        assert conversation.clear_all(), "Should clear all messages"
        assert (
            len(conversation.get_messages()) == 0
        ), "Should have no messages after clear"
        print("✓ Delete operations test passed")
    finally:
        cleanup_test(temp_dir, db_path)


def test_concurrent_operations():
    """Test concurrent operations.

    Five threads each add ten messages to the same conversation; the test
    then checks that all fifty are returned by ``get_messages``.
    """
    temp_dir, db_path, conversation = setup_test()
    try:

        def add_messages():
            for i in range(10):
                conversation.add("user", f"Message {i}")

        threads = [
            threading.Thread(target=add_messages) for _ in range(5)
        ]
        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()

        messages = conversation.get_messages()
        assert (
            len(messages) == 50
        ), "Should have 50 messages (10 * 5 threads)"
        print("✓ Concurrent operations test passed")
    finally:
        cleanup_test(temp_dir, db_path)


def test_error_handling():
    """Test error handling.

    Each invalid operation must raise. The checks go through
    ``_expect_exception`` so that a missing exception really fails the
    test instead of being swallowed by a broad ``except Exception``.
    """
    temp_dir, db_path, conversation = setup_test()
    try:
        # Test invalid message type
        _expect_exception(
            lambda: conversation.add(
                "user", "Message", message_type="invalid"
            ),
            "Should raise exception for invalid message type",
        )

        # Test invalid JSON content
        _expect_exception(
            lambda: conversation.add("user", {"invalid": object()}),
            "Should raise exception for invalid JSON content",
        )

        # Test invalid file operations
        # NOTE(review): test_json_operations treats load_from_json's return
        # value as a success flag. If a missing file is reported by a falsy
        # return rather than an exception, this check will now fail and the
        # expectation should be changed to assert on the return value.
        _expect_exception(
            lambda: conversation.load_from_json("/nonexistent/path.json"),
            "Should raise exception for invalid file path",
        )

        print("✓ Error handling test passed")
    finally:
        cleanup_test(temp_dir, db_path)


def run_all_tests():
    """Run all tests.

    Tests run in a fixed order. The first failure is reported on stdout
    and then re-raised, so later tests do not run.
    """
    print("Running DuckDB Conversation tests...")

    tests = [
        test_initialization,
        test_add_message,
        test_add_complex_message,
        test_batch_add,
        test_get_str,
        test_get_messages,
        test_search_messages,
        test_get_statistics,
        test_json_operations,
        test_yaml_operations,
        test_message_types,
        test_delete_operations,
        test_concurrent_operations,
        test_error_handling,
    ]

    for test in tests:
        try:
            test()
        except Exception as e:
            print(f"✗ {test.__name__} failed: {str(e)}")
            raise

    print("\nAll tests completed successfully!")


if __name__ == "__main__":
    run_all_tests()