You cannot select more than 25 topics.
			Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
					
					
						
							405 lines
						
					
					
						
							12 KiB
						
					
					
				
			
		
		
	
	
							405 lines
						
					
					
						
							12 KiB
						
					
					
				| import os
 | |
| from pathlib import Path
 | |
| import tempfile
 | |
| import threading
 | |
| from swarms.communication.duckdb_wrap import (
 | |
|     DuckDBConversation,
 | |
|     Message,
 | |
|     MessageType,
 | |
| )
 | |
| 
 | |
| 
 | |
def setup_test():
    """Create an isolated database file and conversation for a single test.

    Returns:
        tuple: ``(temp_dir, db_path, conversation)`` where ``temp_dir`` is the
        ``tempfile.TemporaryDirectory`` that owns every file the test makes,
        ``db_path`` is the ``Path`` of ``test_conversations.duckdb`` inside
        it, and ``conversation`` is a ``DuckDBConversation`` opened on that
        file with timestamps and logging enabled.
    """
    workspace = tempfile.TemporaryDirectory()
    database_file = Path(workspace.name).joinpath("test_conversations.duckdb")
    conv = DuckDBConversation(
        db_path=str(database_file),
        enable_timestamps=True,
        enable_logging=True,
    )
    return workspace, database_file, conv
 | |
| 
 | |
| 
 | |
def cleanup_test(temp_dir, db_path):
    """Tear down the environment created by ``setup_test``.

    Removes the database file if it still exists, then deletes the whole
    temporary directory. The directory cleanup runs in a ``finally`` block,
    so a failure while removing the database file (for example, a file still
    locked by an open connection) no longer leaks the temporary directory.
    That failure is still re-raised.

    Args:
        temp_dir: The ``tempfile.TemporaryDirectory`` returned by
            ``setup_test``.
        db_path: Path (``str`` or ``Path``) of the database file inside
            ``temp_dir``.
    """
    try:
        try:
            os.remove(db_path)
        except FileNotFoundError:
            # Never created or already removed; nothing to delete.
            pass
    finally:
        temp_dir.cleanup()
 | |
| 
 | |
| 
 | |
def test_initialization():
    """Verify the attributes of a freshly constructed DuckDBConversation.

    A second conversation is opened on the path prepared by ``setup_test``,
    passing only ``db_path``. The test expects the stored path to equal that
    path, the table name to be ``"conversations"``, timestamps to be on, and
    a current conversation id to have been assigned.
    """
    temp_dir, db_path, _ = setup_test()
    try:
        fresh = DuckDBConversation(db_path=str(db_path))
        assert fresh.db_path == db_path, "Database path mismatch"
        assert fresh.table_name == "conversations", "Table name mismatch"
        assert fresh.enable_timestamps is True, (
            "Timestamps should be enabled"
        )
        assert fresh.current_conversation_id is not None, (
            "Conversation ID should not be None"
        )
        print("✓ Initialization test passed")
    finally:
        cleanup_test(temp_dir, db_path)
 | |
| 
 | |
| 
 | |
def test_add_message():
    """Add one plain-text user message and check the returned id.

    The id returned by ``add`` must be present and must be an ``int``.
    """
    temp_dir, db_path, conversation = setup_test()
    try:
        payload = dict(
            role="user",
            content="Hello, world!",
            message_type=MessageType.USER,
        )
        new_id = conversation.add(**payload)
        assert new_id is not None, "Message ID should not be None"
        assert isinstance(new_id, int), "Message ID should be an integer"
        print("✓ Add message test passed")
    finally:
        cleanup_test(temp_dir, db_path)
 | |
| 
 | |
| 
 | |
def test_add_complex_message():
    """Add an assistant message whose content is a nested dict.

    The message also carries metadata and a token count. The test only
    checks that ``add`` returns a non-``None`` id.
    """
    temp_dir, db_path, conversation = setup_test()
    try:
        nested_part = {"key": "value"}
        structured = {
            "text": "Hello",
            "data": [1, 2, 3],
            "nested": nested_part,
        }
        new_id = conversation.add(
            role="assistant",
            content=structured,
            message_type=MessageType.ASSISTANT,
            metadata={"source": "test"},
            token_count=10,
        )
        assert new_id is not None, "Message ID should not be None"
        print("✓ Add complex message test passed")
    finally:
        cleanup_test(temp_dir, db_path)
 | |
| 
 | |
| 
 | |
def test_batch_add():
    """Test batch adding messages.

    Inserts a user message and an assistant message with a single
    ``batch_add`` call, then checks that exactly one integer id is returned
    per message.
    """
    temp_dir, db_path, conversation = setup_test()
    try:
        messages = [
            Message(
                role="user",
                content="First message",
                message_type=MessageType.USER,
            ),
            Message(
                role="assistant",
                content="Second message",
                message_type=MessageType.ASSISTANT,
            ),
        ]
        msg_ids = conversation.batch_add(messages)
        assert len(msg_ids) == 2, "Should have 2 message IDs"
        # Named ``msg_id`` rather than ``id`` to avoid shadowing the builtin.
        assert all(
            isinstance(msg_id, int) for msg_id in msg_ids
        ), "All IDs should be integers"
        print("✓ Batch add test passed")
    finally:
        cleanup_test(temp_dir, db_path)
 | |
| 
 | |
| 
 | |
def test_get_str():
    """Render the conversation with ``get_str`` and look for each turn.

    Each added message is expected to appear as ``"<role>: <content>"`` in
    the rendered text.
    """
    temp_dir, db_path, conversation = setup_test()
    try:
        for speaker, text in (("user", "Hello"), ("assistant", "Hi there!")):
            conversation.add(speaker, text)
        rendered = conversation.get_str()
        assert "user: Hello" in rendered, "User message not found"
        assert "assistant: Hi there!" in rendered, (
            "Assistant message not found"
        )
        print("✓ Get string test passed")
    finally:
        cleanup_test(temp_dir, db_path)
 | |
| 
 | |
| 
 | |
def test_get_messages():
    """Exercise ``get_messages`` with no paging, with ``limit``, and with ``offset``.

    Five messages are stored. The test expects all five with no arguments,
    two with ``limit=2``, and the remaining three with ``offset=2``.
    """
    temp_dir, db_path, conversation = setup_test()
    try:
        for index in range(5):
            conversation.add("user", f"Message {index}")

        everything = conversation.get_messages()
        assert len(everything) == 5, "Should have 5 messages"

        first_two = conversation.get_messages(limit=2)
        assert len(first_two) == 2, "Should have 2 limited messages"

        after_two = conversation.get_messages(offset=2)
        assert len(after_two) == 3, "Should have 3 offset messages"
        print("✓ Get messages test passed")
    finally:
        cleanup_test(temp_dir, db_path)
 | |
| 
 | |
| 
 | |
def test_search_messages():
    """Search stored messages for a substring.

    Of the three seeded messages, two contain ``"world"``. The search must
    return exactly those two, and each hit's ``"content"`` must contain the
    term.
    """
    temp_dir, db_path, conversation = setup_test()
    try:
        seed = (
            ("user", "Hello world"),
            ("assistant", "Hello there"),
            ("user", "Goodbye world"),
        )
        for speaker, text in seed:
            conversation.add(speaker, text)

        hits = conversation.search_messages("world")
        assert len(hits) == 2, "Should find 2 messages with 'world'"
        assert all("world" in hit["content"] for hit in hits), (
            "All results should contain 'world'"
        )
        print("✓ Search messages test passed")
    finally:
        cleanup_test(temp_dir, db_path)
 | |
| 
 | |
| 
 | |
def test_get_statistics():
    """Check the aggregate counters returned by ``get_statistics``.

    After one user message (2 tokens) and one assistant message (1 token),
    the test expects 2 messages, 2 distinct roles and 3 tokens in total.
    """
    temp_dir, db_path, conversation = setup_test()
    try:
        conversation.add("user", "Hello", token_count=2)
        conversation.add("assistant", "Hi", token_count=1)

        stats = conversation.get_statistics()
        expectations = (
            ("total_messages", 2, "Should have 2 total messages"),
            ("unique_roles", 2, "Should have 2 unique roles"),
            ("total_tokens", 3, "Should have 3 total tokens"),
        )
        for field, expected, failure_note in expectations:
            assert stats[field] == expected, failure_note
        print("✓ Get statistics test passed")
    finally:
        cleanup_test(temp_dir, db_path)
 | |
| 
 | |
| 
 | |
def test_json_operations():
    """Test JSON save and load operations.

    Saves a two-message conversation to a JSON file in the temporary
    directory, then loads it into a fresh conversation backed by a separate
    ``new.duckdb`` file in the same directory. That file is removed when
    ``cleanup_test`` cleans the directory. Both messages must come back.

    The ``load_from_json`` call is kept out of the ``assert`` statement:
    Python strips asserts under ``-O``, which would skip the load itself.
    """
    temp_dir, db_path, conversation = setup_test()
    try:
        conversation.add("user", "Hello")
        conversation.add("assistant", "Hi")

        workdir = Path(temp_dir.name)
        json_path = workdir / "test_conversation.json"
        conversation.save_as_json(str(json_path))
        assert json_path.exists(), "JSON file should exist"

        new_conversation = DuckDBConversation(
            db_path=str(workdir / "new.duckdb")
        )
        loaded = new_conversation.load_from_json(str(json_path))
        assert loaded, "Should load from JSON"
        assert (
            len(new_conversation.get_messages()) == 2
        ), "Should have 2 messages after load"
        print("✓ JSON operations test passed")
    finally:
        cleanup_test(temp_dir, db_path)
 | |
| 
 | |
| 
 | |
def test_yaml_operations():
    """Test YAML save and load operations.

    Saves a two-message conversation to a YAML file in the temporary
    directory, then loads it into a fresh conversation backed by a separate
    ``new.duckdb`` file in the same directory. That file is removed when
    ``cleanup_test`` cleans the directory. Both messages must come back.

    The ``load_from_yaml`` call is kept out of the ``assert`` statement:
    Python strips asserts under ``-O``, which would skip the load itself.
    """
    temp_dir, db_path, conversation = setup_test()
    try:
        conversation.add("user", "Hello")
        conversation.add("assistant", "Hi")

        workdir = Path(temp_dir.name)
        yaml_path = workdir / "test_conversation.yaml"
        conversation.save_as_yaml(str(yaml_path))
        assert yaml_path.exists(), "YAML file should exist"

        new_conversation = DuckDBConversation(
            db_path=str(workdir / "new.duckdb")
        )
        loaded = new_conversation.load_from_yaml(str(yaml_path))
        assert loaded, "Should load from YAML"
        assert (
            len(new_conversation.get_messages()) == 2
        ), "Should have 2 messages after load"
        print("✓ YAML operations test passed")
    finally:
        cleanup_test(temp_dir, db_path)
 | |
| 
 | |
| 
 | |
def test_message_types():
    """Store one message of each MessageType and read them back.

    Covers SYSTEM, USER, ASSISTANT, FUNCTION and TOOL. All five messages
    must be returned, and each returned record must have a
    ``"message_type"`` key.
    """
    temp_dir, db_path, conversation = setup_test()
    try:
        samples = (
            ("system", "System message", MessageType.SYSTEM),
            ("user", "User message", MessageType.USER),
            ("assistant", "Assistant message", MessageType.ASSISTANT),
            ("function", "Function message", MessageType.FUNCTION),
            ("tool", "Tool message", MessageType.TOOL),
        )
        for speaker, text, kind in samples:
            conversation.add(speaker, text, message_type=kind)

        stored = conversation.get_messages()
        assert len(stored) == 5, "Should have 5 messages"
        assert all("message_type" in record for record in stored), (
            "All messages should have type"
        )
        print("✓ Message types test passed")
    finally:
        cleanup_test(temp_dir, db_path)
 | |
| 
 | |
| 
 | |
def test_delete_operations():
    """Test deletion operations.

    First deletes the current conversation and expects no messages to
    remain. Then adds a new message, calls ``clear_all`` and again expects
    an empty message list. Both operations are expected to return a truthy
    value on success.

    The mutating calls are kept out of the ``assert`` statements: Python
    strips asserts under ``-O``, which would skip the deletions themselves.
    """
    temp_dir, db_path, conversation = setup_test()
    try:
        conversation.add("user", "Hello")
        conversation.add("assistant", "Hi")

        deleted = conversation.delete_current_conversation()
        assert deleted, "Should delete conversation"
        assert (
            len(conversation.get_messages()) == 0
        ), "Should have no messages after delete"

        conversation.add("user", "New message")
        cleared = conversation.clear_all()
        assert cleared, "Should clear all messages"
        assert (
            len(conversation.get_messages()) == 0
        ), "Should have no messages after clear"
        print("✓ Delete operations test passed")
    finally:
        cleanup_test(temp_dir, db_path)
 | |
| 
 | |
| 
 | |
def test_concurrent_operations():
    """Test concurrent operations.

    Five threads each add ten user messages to the same conversation. The
    conversation must then hold all 50 messages.

    An exception raised inside a ``threading.Thread`` target is not
    propagated by ``join()``. Worker failures are therefore collected and
    re-raised in this thread, so the test reports the real error instead of
    only a message-count mismatch.
    """
    temp_dir, db_path, conversation = setup_test()
    try:
        worker_errors = []

        def add_messages():
            try:
                for i in range(10):
                    conversation.add("user", f"Message {i}")
            except Exception as exc:
                # list.append is thread-safe in CPython; no extra lock needed.
                worker_errors.append(exc)

        threads = [
            threading.Thread(target=add_messages) for _ in range(5)
        ]
        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()

        if worker_errors:
            raise AssertionError(
                f"{len(worker_errors)} worker thread(s) failed; "
                f"first error: {worker_errors[0]!r}"
            ) from worker_errors[0]

        messages = conversation.get_messages()
        assert (
            len(messages) == 50
        ), "Should have 50 messages (10 * 5 threads)"
        print("✓ Concurrent operations test passed")
    finally:
        cleanup_test(temp_dir, db_path)
 | |
| 
 | |
| 
 | |
def test_error_handling():
    """Test that invalid input and invalid file paths are rejected.

    The failure check for each case sits in the ``else`` branch of its
    ``try``, not inside the ``try`` body. ``AssertionError`` is a subclass
    of ``Exception``, so an ``assert False`` inside
    ``try: ... except Exception: pass`` would be swallowed and the test
    could never fail. ``raise AssertionError`` is used instead of
    ``assert False`` so the check also survives ``python -O``.
    """
    temp_dir, db_path, conversation = setup_test()
    try:
        # Invalid message type: a plain string instead of a MessageType.
        try:
            conversation.add("user", "Message", message_type="invalid")
        except Exception:
            pass
        else:
            raise AssertionError(
                "Should raise exception for invalid message type"
            )

        # Content that cannot be JSON-serialized (an arbitrary object).
        try:
            conversation.add("user", {"invalid": object()})
        except Exception:
            pass
        else:
            raise AssertionError(
                "Should raise exception for invalid JSON content"
            )

        # Loading from a path that does not exist.
        try:
            loaded = conversation.load_from_json("/nonexistent/path.json")
        except Exception:
            pass
        else:
            # test_json_operations treats this method's return value as a
            # success flag, so a falsy result also counts as a failure.
            if loaded:
                raise AssertionError(
                    "Should raise exception for invalid file path"
                )

        print("✓ Error handling test passed")
    finally:
        cleanup_test(temp_dir, db_path)
 | |
| 
 | |
| 
 | |
def run_all_tests():
    """Run every test in this module in a fixed order.

    Stops at the first failing test: prints a ✗ line naming it, then
    re-raises the original exception.
    """
    print("Running DuckDB Conversation tests...")
    suite = (
        test_initialization,
        test_add_message,
        test_add_complex_message,
        test_batch_add,
        test_get_str,
        test_get_messages,
        test_search_messages,
        test_get_statistics,
        test_json_operations,
        test_yaml_operations,
        test_message_types,
        test_delete_operations,
        test_concurrent_operations,
        test_error_handling,
    )

    for case in suite:
        try:
            case()
        except Exception as exc:
            print(f"✗ {case.__name__} failed: {str(exc)}")
            raise

    print("\nAll tests completed successfully!")
 | |
| 
 | |
| 
 | |
if __name__ == "__main__":
    # Lets the suite be run directly as a script.
    run_all_tests()
 |