You cannot select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
					
					
						
							405 lines
						
					
					
						
							12 KiB
						
					
					
				
			
		
		
	
	
							405 lines
						
					
					
						
							12 KiB
						
					
					
				import os
 | 
						|
from pathlib import Path
 | 
						|
import tempfile
 | 
						|
import threading
 | 
						|
from swarms.communication.duckdb_wrap import (
 | 
						|
    DuckDBConversation,
 | 
						|
    Message,
 | 
						|
    MessageType,
 | 
						|
)
 | 
						|
 | 
						|
 | 
						|
def setup_test():
    """Create an isolated, DuckDB-backed conversation for a single test.

    Returns:
        A ``(temp_dir, db_path, conversation)`` tuple: the
        ``tempfile.TemporaryDirectory`` that owns the files, the ``Path`` of
        the database file inside it, and a ``DuckDBConversation`` opened on
        that path with timestamps and logging enabled.
    """
    workspace = tempfile.TemporaryDirectory()
    database_file = Path(workspace.name) / "test_conversations.duckdb"
    options = {
        "db_path": str(database_file),
        "enable_timestamps": True,
        "enable_logging": True,
    }
    return workspace, database_file, DuckDBConversation(**options)
 | 
						|
 | 
						|
 | 
						|
def cleanup_test(temp_dir, db_path):
    """Clean up the test environment created by ``setup_test``.

    Args:
        temp_dir: The ``tempfile.TemporaryDirectory`` to dispose of.
        db_path: Path of the database file to delete; it may already be gone.

    Raises:
        OSError: If the database file exists but cannot be removed. The
            temporary directory cleanup is still attempted before the error
            propagates.
    """
    try:
        if os.path.exists(db_path):
            os.remove(db_path)
    finally:
        # Always release the directory, even when removing the file failed,
        # so one stuck file does not leave the whole workspace behind.
        temp_dir.cleanup()
 | 
						|
 | 
						|
 | 
						|
def test_initialization():
    """Check the attributes of a conversation built with default options."""
    temp_dir, db_path, _ = setup_test()
    try:
        fresh = DuckDBConversation(db_path=str(db_path))

        # NOTE(review): the constructor receives str(db_path) while the
        # comparison uses the Path object; this only holds if the class
        # normalises db_path to a Path - confirm against DuckDBConversation.
        assert fresh.db_path == db_path, "Database path mismatch"

        table_matches = fresh.table_name == "conversations"
        assert table_matches, "Table name mismatch"

        timestamps_on = fresh.enable_timestamps is True
        assert timestamps_on, "Timestamps should be enabled"

        has_conversation_id = fresh.current_conversation_id is not None
        assert has_conversation_id, "Conversation ID should not be None"

        print("✓ Initialization test passed")
    finally:
        cleanup_test(temp_dir, db_path)
 | 
						|
 | 
						|
 | 
						|
def test_add_message():
    """Add one user message and check that an integer ID comes back."""
    temp_dir, db_path, conversation = setup_test()
    try:
        new_id = conversation.add(
            role="user", content="Hello, world!", message_type=MessageType.USER
        )
        assert new_id is not None, "Message ID should not be None"
        assert isinstance(new_id, int), "Message ID should be an integer"
        print("✓ Add message test passed")
    finally:
        cleanup_test(temp_dir, db_path)
 | 
						|
 | 
						|
 | 
						|
def test_add_complex_message():
    """Add a message whose content is a nested dict, with metadata and a token count."""
    temp_dir, db_path, conversation = setup_test()
    try:
        payload = {"text": "Hello", "data": [1, 2, 3], "nested": {"key": "value"}}
        extras = {"metadata": {"source": "test"}, "token_count": 10}
        new_id = conversation.add(
            role="assistant",
            content=payload,
            message_type=MessageType.ASSISTANT,
            **extras,
        )
        assert new_id is not None, "Message ID should not be None"
        print("✓ Add complex message test passed")
    finally:
        cleanup_test(temp_dir, db_path)
 | 
						|
 | 
						|
 | 
						|
def test_batch_add():
    """Insert two messages in one batch and check the returned IDs."""
    temp_dir, db_path, conversation = setup_test()
    try:
        user_message = Message(
            role="user",
            content="First message",
            message_type=MessageType.USER,
        )
        assistant_message = Message(
            role="assistant",
            content="Second message",
            message_type=MessageType.ASSISTANT,
        )
        returned_ids = conversation.batch_add([user_message, assistant_message])

        assert len(returned_ids) == 2, "Should have 2 message IDs"
        ids_are_ints = all(
            isinstance(message_id, int) for message_id in returned_ids
        )
        assert ids_are_ints, "All IDs should be integers"
        print("✓ Batch add test passed")
    finally:
        cleanup_test(temp_dir, db_path)
 | 
						|
 | 
						|
 | 
						|
def test_get_str():
    """Check that the string rendering contains each ``role: content`` pair."""
    temp_dir, db_path, conversation = setup_test()
    try:
        for speaker, text in (("user", "Hello"), ("assistant", "Hi there!")):
            conversation.add(speaker, text)

        transcript = conversation.get_str()
        assert "user: Hello" in transcript, "User message not found"
        assert "assistant: Hi there!" in transcript, "Assistant message not found"
        print("✓ Get string test passed")
    finally:
        cleanup_test(temp_dir, db_path)
 | 
						|
 | 
						|
 | 
						|
def test_get_messages():
    """Check message retrieval with no arguments, with ``limit`` and with ``offset``."""
    temp_dir, db_path, conversation = setup_test()
    try:
        for number in range(5):
            conversation.add("user", f"Message {number}")

        everything = conversation.get_messages()
        assert len(everything) == 5, "Should have 5 messages"

        first_page = conversation.get_messages(limit=2)
        assert len(first_page) == 2, "Should have 2 limited messages"

        remainder = conversation.get_messages(offset=2)
        assert len(remainder) == 3, "Should have 3 offset messages"

        print("✓ Get messages test passed")
    finally:
        cleanup_test(temp_dir, db_path)
 | 
						|
 | 
						|
 | 
						|
def test_search_messages():
    """Search three stored messages for 'world' and expect exactly two hits."""
    temp_dir, db_path, conversation = setup_test()
    try:
        seeded = (
            ("user", "Hello world"),
            ("assistant", "Hello there"),
            ("user", "Goodbye world"),
        )
        for speaker, text in seeded:
            conversation.add(speaker, text)

        hits = conversation.search_messages("world")
        assert len(hits) == 2, "Should find 2 messages with 'world'"

        every_hit_matches = all("world" in hit["content"] for hit in hits)
        assert every_hit_matches, "All results should contain 'world'"

        print("✓ Search messages test passed")
    finally:
        cleanup_test(temp_dir, db_path)
 | 
						|
 | 
						|
 | 
						|
def test_get_statistics():
    """Check message, role and token totals reported for two messages."""
    temp_dir, db_path, conversation = setup_test()
    try:
        conversation.add("user", "Hello", token_count=2)
        conversation.add("assistant", "Hi", token_count=1)

        summary = conversation.get_statistics()
        expectations = (
            ("total_messages", 2, "Should have 2 total messages"),
            ("unique_roles", 2, "Should have 2 unique roles"),
            ("total_tokens", 3, "Should have 3 total tokens"),
        )
        # Checked in order so the first mismatch is the one reported.
        for key, expected, failure in expectations:
            assert summary[key] == expected, failure

        print("✓ Get statistics test passed")
    finally:
        cleanup_test(temp_dir, db_path)
 | 
						|
 | 
						|
 | 
						|
def test_json_operations():
    """Save a conversation to JSON and load it into a second database."""
    temp_dir, db_path, conversation = setup_test()
    try:
        conversation.add("user", "Hello")
        conversation.add("assistant", "Hi")

        export_file = Path(temp_dir.name) / "test_conversation.json"
        conversation.save_as_json(str(export_file))
        assert export_file.exists(), "JSON file should exist"

        # Load into a separate database file so the check does not see the
        # messages already stored by the original conversation.
        second_db = Path(temp_dir.name) / "new.duckdb"
        restored = DuckDBConversation(db_path=str(second_db))
        loaded_ok = restored.load_from_json(str(export_file))
        assert loaded_ok, "Should load from JSON"

        restored_count = len(restored.get_messages())
        assert restored_count == 2, "Should have 2 messages after load"
        print("✓ JSON operations test passed")
    finally:
        cleanup_test(temp_dir, db_path)
 | 
						|
 | 
						|
 | 
						|
def test_yaml_operations():
    """Save a conversation to YAML and load it into a second database."""
    temp_dir, db_path, conversation = setup_test()
    try:
        conversation.add("user", "Hello")
        conversation.add("assistant", "Hi")

        export_file = Path(temp_dir.name) / "test_conversation.yaml"
        conversation.save_as_yaml(str(export_file))
        assert export_file.exists(), "YAML file should exist"

        # Load into a separate database file so the check does not see the
        # messages already stored by the original conversation.
        second_db = Path(temp_dir.name) / "new.duckdb"
        restored = DuckDBConversation(db_path=str(second_db))
        loaded_ok = restored.load_from_yaml(str(export_file))
        assert loaded_ok, "Should load from YAML"

        restored_count = len(restored.get_messages())
        assert restored_count == 2, "Should have 2 messages after load"
        print("✓ YAML operations test passed")
    finally:
        cleanup_test(temp_dir, db_path)
 | 
						|
 | 
						|
 | 
						|
def test_message_types():
    """Store one message per ``MessageType`` member used here and read them back."""
    temp_dir, db_path, conversation = setup_test()
    try:
        samples = (
            ("system", "System message", MessageType.SYSTEM),
            ("user", "User message", MessageType.USER),
            ("assistant", "Assistant message", MessageType.ASSISTANT),
            ("function", "Function message", MessageType.FUNCTION),
            ("tool", "Tool message", MessageType.TOOL),
        )
        for speaker, text, kind in samples:
            conversation.add(speaker, text, message_type=kind)

        stored = conversation.get_messages()
        assert len(stored) == 5, "Should have 5 messages"

        every_message_typed = all("message_type" in entry for entry in stored)
        assert every_message_typed, "All messages should have type"
        print("✓ Message types test passed")
    finally:
        cleanup_test(temp_dir, db_path)
 | 
						|
 | 
						|
 | 
						|
def test_delete_operations():
    """Check ``delete_current_conversation`` and ``clear_all`` each leave no messages."""
    temp_dir, db_path, conversation = setup_test()
    try:
        conversation.add("user", "Hello")
        conversation.add("assistant", "Hi")

        deleted = conversation.delete_current_conversation()
        assert deleted, "Should delete conversation"
        left_after_delete = conversation.get_messages()
        assert len(left_after_delete) == 0, "Should have no messages after delete"

        conversation.add("user", "New message")
        cleared = conversation.clear_all()
        assert cleared, "Should clear all messages"
        left_after_clear = conversation.get_messages()
        assert len(left_after_clear) == 0, "Should have no messages after clear"

        print("✓ Delete operations test passed")
    finally:
        cleanup_test(temp_dir, db_path)
 | 
						|
 | 
						|
 | 
						|
def test_concurrent_operations():
    """Add messages from five threads at once and expect all fifty to be stored."""
    temp_dir, db_path, conversation = setup_test()
    try:

        def writer():
            # Each worker appends ten messages to the shared conversation.
            for number in range(10):
                conversation.add("user", f"Message {number}")

        workers = []
        for _ in range(5):
            workers.append(threading.Thread(target=writer))

        # Start every worker before joining any, so they overlap.
        for worker in workers:
            worker.start()
        for worker in workers:
            worker.join()

        stored = conversation.get_messages()
        assert len(stored) == 50, "Should have 50 messages (10 * 5 threads)"
        print("✓ Concurrent operations test passed")
    finally:
        cleanup_test(temp_dir, db_path)
 | 
						|
 | 
						|
 | 
						|
def _expect_exception(action, failure_message):
    """Call *action* and fail with *failure_message* unless it raises."""
    try:
        action()
    except Exception:
        return
    # Raised outside the ``try`` so the handler above cannot swallow it.
    raise AssertionError(failure_message)


def test_error_handling():
    """Test error handling.

    Checks that an invalid message type and non-serialisable content are
    rejected with an exception, and that loading from a nonexistent JSON
    path does not succeed.

    The failure checks are kept outside any ``except Exception`` handler:
    ``AssertionError`` is itself an ``Exception``, so an ``assert False``
    placed inside such a ``try`` block is silently swallowed and the test
    can never fail.
    """
    temp_dir, db_path, conversation = setup_test()
    try:
        # Test invalid message type
        _expect_exception(
            lambda: conversation.add(
                "user", "Message", message_type="invalid"
            ),
            "Should raise exception for invalid message type",
        )

        # Test invalid JSON content
        _expect_exception(
            lambda: conversation.add("user", {"invalid": object()}),
            "Should raise exception for invalid JSON content",
        )

        # Test invalid file operations
        # NOTE(review): test_json_operations asserts on the return value of
        # load_from_json, so it appears to be a success flag. A falsy return
        # is therefore accepted as a rejection alongside an exception -
        # confirm against DuckDBConversation.
        try:
            loaded = conversation.load_from_json("/nonexistent/path.json")
        except Exception:
            loaded = False
        assert (
            not loaded
        ), "Should raise exception or report failure for invalid file path"

        print("✓ Error handling test passed")
    finally:
        cleanup_test(temp_dir, db_path)
 | 
						|
 | 
						|
 | 
						|
def run_all_tests():
    """Run every test in order, stopping at the first failure.

    A failing test has its name and error printed, then the exception is
    re-raised so the process exits with a traceback.
    """
    print("Running DuckDB Conversation tests...")

    suite = (
        test_initialization,
        test_add_message,
        test_add_complex_message,
        test_batch_add,
        test_get_str,
        test_get_messages,
        test_search_messages,
        test_get_statistics,
        test_json_operations,
        test_yaml_operations,
        test_message_types,
        test_delete_operations,
        test_concurrent_operations,
        test_error_handling,
    )

    for case in suite:
        try:
            case()
        except Exception as error:
            reason = str(error)
            print(f"✗ {case.__name__} failed: {reason}")
            raise

    print("\nAll tests completed successfully!")
 | 
						|
 | 
						|
 | 
						|
if __name__ == "__main__":
    # Run the whole suite only when this file is executed as a script.
    run_all_tests()
 |