You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
swarms/tests/communication/test_duckdb_conversation.py

317 lines
11 KiB

import os
import json
from datetime import datetime
from pathlib import Path
import tempfile
import threading
from swarms.communication.duckdb_wrap import (
DuckDBConversation,
Message,
MessageType,
)
def setup_test():
    """Create a throwaway DuckDB conversation inside a temporary directory.

    Returns:
        A ``(temp_dir, db_path, conversation)`` tuple: the
        ``tempfile.TemporaryDirectory`` handle, the ``Path`` of the database
        file inside it, and a ``DuckDBConversation`` constructed on that path
        with ``enable_timestamps`` and ``enable_logging`` set to ``True``.
    """
    workspace = tempfile.TemporaryDirectory()
    database_file = Path(workspace.name).joinpath("test_conversations.duckdb")
    options = {
        "db_path": str(database_file),
        "enable_timestamps": True,
        "enable_logging": True,
    }
    return workspace, database_file, DuckDBConversation(**options)
def cleanup_test(temp_dir, db_path):
    """Remove the database file (if present) and the temporary directory.

    Args:
        temp_dir: Object exposing ``cleanup()``, as returned by
            ``tempfile.TemporaryDirectory``.
        db_path: Location of the database file to delete when it exists.
    """
    database_present = os.path.exists(db_path)
    if database_present:
        os.remove(db_path)

    # Always release the directory, whether or not a database file was found.
    temp_dir.cleanup()
def test_initialization():
    """Check the attributes of a freshly constructed conversation.

    Constructs a ``DuckDBConversation`` on the temporary database path and
    asserts its ``db_path``, ``table_name``, ``enable_timestamps`` and
    ``current_conversation_id`` attributes.

    Raises:
        AssertionError: If any attribute has an unexpected value.
    """
    temp_dir, db_path, _ = setup_test()
    try:
        conv = DuckDBConversation(db_path=str(db_path))
        # The constructor receives a ``str`` while ``db_path`` is a ``Path``.
        # A ``str`` never compares equal to a ``Path``, so normalise the
        # attribute before comparing; this holds whether the class keeps the
        # value as ``str`` or as ``Path``.
        assert Path(conv.db_path) == db_path, "Database path mismatch"
        assert conv.table_name == "conversations", "Table name mismatch"
        assert conv.enable_timestamps is True, "Timestamps should be enabled"
        assert (
            conv.current_conversation_id is not None
        ), "Conversation ID should not be None"
        print("✓ Initialization test passed")
    finally:
        cleanup_test(temp_dir, db_path)
def test_add_message():
    """Add one user message and check that an integer ID comes back."""
    temp_dir, db_path, conversation = setup_test()
    try:
        new_id = conversation.add(
            role="user", content="Hello, world!", message_type=MessageType.USER
        )

        assert new_id is not None, "Message ID should not be None"
        assert isinstance(new_id, int), "Message ID should be an integer"

        print("✓ Add message test passed")
    finally:
        cleanup_test(temp_dir, db_path)
def test_add_complex_message():
    """Add an assistant message whose content is a nested dict.

    The message also carries metadata and a token count; the test only
    asserts that a non-``None`` ID is returned.
    """
    temp_dir, db_path, conversation = setup_test()
    try:
        payload = dict(
            text="Hello",
            data=[1, 2, 3],
            nested=dict(key="value"),
        )
        new_id = conversation.add(
            role="assistant",
            content=payload,
            message_type=MessageType.ASSISTANT,
            metadata=dict(source="test"),
            token_count=10,
        )
        assert new_id is not None, "Message ID should not be None"
        print("✓ Add complex message test passed")
    finally:
        cleanup_test(temp_dir, db_path)
def test_batch_add():
    """Insert two ``Message`` objects via ``batch_add`` and check the IDs.

    Expects exactly two IDs back, each of them an ``int``.
    """
    temp_dir, db_path, conversation = setup_test()
    try:
        specs = (
            ("user", "First message", MessageType.USER),
            ("assistant", "Second message", MessageType.ASSISTANT),
        )
        batch = [
            Message(role=role, content=content, message_type=kind)
            for role, content, kind in specs
        ]

        returned_ids = conversation.batch_add(batch)

        assert len(returned_ids) == 2, "Should have 2 message IDs"
        assert all(
            isinstance(message_id, int) for message_id in returned_ids
        ), "All IDs should be integers"
        print("✓ Batch add test passed")
    finally:
        cleanup_test(temp_dir, db_path)
def test_get_str():
    """Check that ``get_str`` renders each message as ``"<role>: <content>"``."""
    temp_dir, db_path, conversation = setup_test()
    try:
        for role, content in (("user", "Hello"), ("assistant", "Hi there!")):
            conversation.add(role, content)

        rendered = conversation.get_str()

        assert "user: Hello" in rendered, "User message not found"
        assert (
            "assistant: Hi there!" in rendered
        ), "Assistant message not found"
        print("✓ Get string test passed")
    finally:
        cleanup_test(temp_dir, db_path)
def test_get_messages():
    """Check ``get_messages`` with no arguments, with ``limit`` and with ``offset``.

    Five messages are stored; the test expects 5 without arguments, 2 with
    ``limit=2`` and 3 with ``offset=2``.
    """
    temp_dir, db_path, conversation = setup_test()
    try:
        for i in range(5):
            conversation.add("user", f"Message {i}")

        everything = conversation.get_messages()
        assert len(everything) == 5, "Should have 5 messages"

        first_two = conversation.get_messages(limit=2)
        assert len(first_two) == 2, "Should have 2 limited messages"

        after_two = conversation.get_messages(offset=2)
        assert len(after_two) == 3, "Should have 3 offset messages"

        print("✓ Get messages test passed")
    finally:
        cleanup_test(temp_dir, db_path)
def test_search_messages():
    """Search for ``"world"`` and expect only the two matching messages."""
    temp_dir, db_path, conversation = setup_test()
    try:
        seed = (
            ("user", "Hello world"),
            ("assistant", "Hello there"),
            ("user", "Goodbye world"),
        )
        for role, content in seed:
            conversation.add(role, content)

        hits = conversation.search_messages("world")

        assert len(hits) == 2, "Should find 2 messages with 'world'"
        assert all(
            "world" in hit["content"] for hit in hits
        ), "All results should contain 'world'"
        print("✓ Search messages test passed")
    finally:
        cleanup_test(temp_dir, db_path)
def test_get_statistics():
    """Check the message, role and token totals reported by ``get_statistics``."""
    temp_dir, db_path, conversation = setup_test()
    try:
        conversation.add("user", "Hello", token_count=2)
        conversation.add("assistant", "Hi", token_count=1)

        summary = conversation.get_statistics()

        assert summary["total_messages"] == 2, "Should have 2 total messages"
        assert summary["unique_roles"] == 2, "Should have 2 unique roles"
        assert summary["total_tokens"] == 3, "Should have 3 total tokens"
        print("✓ Get statistics test passed")
    finally:
        cleanup_test(temp_dir, db_path)
def test_json_operations():
    """Save a conversation as JSON and load it into a second conversation.

    Two messages are saved to a JSON file in the temporary directory; a new
    ``DuckDBConversation`` on a separate database file then loads that file
    and is expected to contain two messages.
    """
    temp_dir, db_path, conversation = setup_test()
    try:
        for role, content in (("user", "Hello"), ("assistant", "Hi")):
            conversation.add(role, content)

        workspace = Path(temp_dir.name)
        json_path = workspace / "test_conversation.json"
        conversation.save_as_json(str(json_path))
        assert json_path.exists(), "JSON file should exist"

        second_db = workspace / "new.duckdb"
        new_conversation = DuckDBConversation(db_path=str(second_db))
        loaded = new_conversation.load_from_json(str(json_path))
        assert loaded, "Should load from JSON"

        restored = new_conversation.get_messages()
        assert len(restored) == 2, "Should have 2 messages after load"
        print("✓ JSON operations test passed")
    finally:
        cleanup_test(temp_dir, db_path)
def test_yaml_operations():
    """Save a conversation as YAML and load it into a second conversation.

    Two messages are saved to a YAML file in the temporary directory; a new
    ``DuckDBConversation`` on a separate database file then loads that file
    and is expected to contain two messages.
    """
    temp_dir, db_path, conversation = setup_test()
    try:
        for role, content in (("user", "Hello"), ("assistant", "Hi")):
            conversation.add(role, content)

        workspace = Path(temp_dir.name)
        yaml_path = workspace / "test_conversation.yaml"
        conversation.save_as_yaml(str(yaml_path))
        assert yaml_path.exists(), "YAML file should exist"

        second_db = workspace / "new.duckdb"
        new_conversation = DuckDBConversation(db_path=str(second_db))
        loaded = new_conversation.load_from_yaml(str(yaml_path))
        assert loaded, "Should load from YAML"

        restored = new_conversation.get_messages()
        assert len(restored) == 2, "Should have 2 messages after load"
        print("✓ YAML operations test passed")
    finally:
        cleanup_test(temp_dir, db_path)
def test_message_types():
    """Store one message per ``MessageType`` used here and check they all come back.

    Every returned message is expected to contain a ``"message_type"`` key.
    """
    temp_dir, db_path, conversation = setup_test()
    try:
        samples = (
            ("system", "System message", MessageType.SYSTEM),
            ("user", "User message", MessageType.USER),
            ("assistant", "Assistant message", MessageType.ASSISTANT),
            ("function", "Function message", MessageType.FUNCTION),
            ("tool", "Tool message", MessageType.TOOL),
        )
        for role, content, kind in samples:
            conversation.add(role, content, message_type=kind)

        stored = conversation.get_messages()

        assert len(stored) == 5, "Should have 5 messages"
        assert all(
            "message_type" in entry for entry in stored
        ), "All messages should have type"
        print("✓ Message types test passed")
    finally:
        cleanup_test(temp_dir, db_path)
def test_delete_operations():
    """Check ``delete_current_conversation`` and ``clear_all``.

    After each call the test expects a truthy return value and an empty
    result from ``get_messages``.
    """
    temp_dir, db_path, conversation = setup_test()
    try:
        for role, content in (("user", "Hello"), ("assistant", "Hi")):
            conversation.add(role, content)

        deleted = conversation.delete_current_conversation()
        assert deleted, "Should delete conversation"
        after_delete = conversation.get_messages()
        assert len(after_delete) == 0, "Should have no messages after delete"

        conversation.add("user", "New message")
        cleared = conversation.clear_all()
        assert cleared, "Should clear all messages"
        after_clear = conversation.get_messages()
        assert len(after_clear) == 0, "Should have no messages after clear"

        print("✓ Delete operations test passed")
    finally:
        cleanup_test(temp_dir, db_path)
def test_concurrent_operations():
    """Add messages from five threads and expect all fifty to be stored.

    Each thread adds ten messages to the same conversation object; all
    threads are started before any is joined.
    """
    temp_dir, db_path, conversation = setup_test()
    try:

        def add_messages():
            for i in range(10):
                conversation.add("user", f"Message {i}")

        workers = []
        for _ in range(5):
            workers.append(threading.Thread(target=add_messages))

        for worker in workers:
            worker.start()
        for worker in workers:
            worker.join()

        stored = conversation.get_messages()
        assert len(stored) == 50, "Should have 50 messages (10 * 5 threads)"
        print("✓ Concurrent operations test passed")
    finally:
        cleanup_test(temp_dir, db_path)
def test_error_handling():
    """Check that invalid inputs make the conversation raise.

    Three calls are expected to raise: ``add`` with a string instead of a
    ``MessageType``, ``add`` with content holding a bare ``object()``, and
    ``load_from_json`` with a path that does not exist.

    Raises:
        AssertionError: If any of the three calls returns without raising.
    """
    temp_dir, db_path, conversation = setup_test()
    try:
        # The "did not raise" failure lives in the ``else`` branch on purpose:
        # ``AssertionError`` is a subclass of ``Exception``, so an
        # ``assert False`` placed inside the ``try`` body would be swallowed
        # by ``except Exception`` and the test could never fail.

        # Test invalid message type
        try:
            conversation.add("user", "Message", message_type="invalid")
        except Exception:
            pass
        else:
            raise AssertionError(
                "Should raise exception for invalid message type"
            )

        # Test invalid JSON content
        try:
            conversation.add("user", {"invalid": object()})
        except Exception:
            pass
        else:
            raise AssertionError(
                "Should raise exception for invalid JSON content"
            )

        # Test invalid file operations
        try:
            conversation.load_from_json("/nonexistent/path.json")
        except Exception:
            pass
        else:
            raise AssertionError(
                "Should raise exception for invalid file path"
            )

        print("✓ Error handling test passed")
    finally:
        cleanup_test(temp_dir, db_path)
def run_all_tests():
    """Run every test function in this module in a fixed order.

    Stops at the first failure: the failing test's name and error text are
    printed and the exception is re-raised. A success banner is printed only
    when every test returns normally.
    """
    print("Running DuckDB Conversation tests...")

    suite = (
        test_initialization,
        test_add_message,
        test_add_complex_message,
        test_batch_add,
        test_get_str,
        test_get_messages,
        test_search_messages,
        test_get_statistics,
        test_json_operations,
        test_yaml_operations,
        test_message_types,
        test_delete_operations,
        test_concurrent_operations,
        test_error_handling,
    )

    for test in suite:
        try:
            test()
        except Exception as e:
            print(f"{test.__name__} failed: {str(e)}")
            raise

    print("\nAll tests completed successfully!")
# Run the whole suite when this module is executed directly as a script.
if __name__ == "__main__":
    run_all_tests()