You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
405 lines
12 KiB
405 lines
12 KiB
import os
|
|
from pathlib import Path
|
|
import tempfile
|
|
import threading
|
|
from swarms.communication.duckdb_wrap import (
|
|
DuckDBConversation,
|
|
Message,
|
|
MessageType,
|
|
)
|
|
|
|
|
|
def setup_test():
    """Build an isolated scratch environment holding a fresh conversation.

    Returns:
        tuple: ``(temp_dir, db_path, conversation)``.
            ``temp_dir`` is the ``tempfile.TemporaryDirectory`` that owns
            every file the test creates. ``db_path`` is the ``Path`` of the
            database file inside it. ``conversation`` is a
            ``DuckDBConversation`` constructed on that path with timestamps
            and logging enabled.
    """
    scratch = tempfile.TemporaryDirectory()
    database_file = Path(scratch.name).joinpath("test_conversations.duckdb")
    conv = DuckDBConversation(
        db_path=str(database_file),
        enable_timestamps=True,
        enable_logging=True,
    )
    return scratch, database_file, conv
|
|
|
|
|
|
def cleanup_test(temp_dir, db_path):
    """Tear down the environment created by ``setup_test``.

    Deletes the database file if it still exists, then removes the whole
    temporary directory. The directory removal runs in a ``finally`` clause.
    It therefore still happens when deleting the database file fails, for
    example because ``db_path`` is a directory, or (on Windows) because the
    file is still open. That failure is then re-raised to the caller instead
    of leaking the scratch directory.

    Args:
        temp_dir: The ``tempfile.TemporaryDirectory`` returned by ``setup_test``.
        db_path: Path (``str`` or ``os.PathLike``) of the database file.

    Raises:
        OSError: If the database file exists but cannot be removed.
    """
    try:
        if os.path.exists(db_path):
            os.remove(db_path)
    finally:
        # Always release the scratch directory, even if the remove failed.
        temp_dir.cleanup()
|
|
|
|
|
|
def test_initialization():
    """Verify the attributes of a conversation built with only a ``db_path``.

    Checks the stored database path, the table name, the timestamp flag and
    that a current conversation ID was assigned.
    """
    temp_dir, db_path, _unused = setup_test()
    try:
        fresh = DuckDBConversation(db_path=str(db_path))

        assert fresh.db_path == db_path, "Database path mismatch"
        assert fresh.table_name == "conversations", "Table name mismatch"
        assert fresh.enable_timestamps is True, "Timestamps should be enabled"
        assert (
            fresh.current_conversation_id is not None
        ), "Conversation ID should not be None"

        print("✓ Initialization test passed")
    finally:
        cleanup_test(temp_dir, db_path)
|
|
|
|
|
|
def test_add_message():
    """Adding a single plain-text user message yields an integer ID."""
    temp_dir, db_path, conversation = setup_test()
    try:
        payload = dict(
            role="user",
            content="Hello, world!",
            message_type=MessageType.USER,
        )
        new_id = conversation.add(**payload)

        assert new_id is not None, "Message ID should not be None"
        assert isinstance(new_id, int), "Message ID should be an integer"

        print("✓ Add message test passed")
    finally:
        cleanup_test(temp_dir, db_path)
|
|
|
|
|
|
def test_add_complex_message():
    """Adding a message with structured content, metadata and a token count.

    The content is a nested dict (string, list and sub-dict). The test only
    requires that a message ID comes back.
    """
    temp_dir, db_path, conversation = setup_test()
    try:
        structured = dict(
            text="Hello",
            data=[1, 2, 3],
            nested=dict(key="value"),
        )
        new_id = conversation.add(
            role="assistant",
            content=structured,
            message_type=MessageType.ASSISTANT,
            metadata=dict(source="test"),
            token_count=10,
        )

        assert new_id is not None, "Message ID should not be None"

        print("✓ Add complex message test passed")
    finally:
        cleanup_test(temp_dir, db_path)
|
|
|
|
|
|
def test_batch_add():
    """``batch_add`` returns one integer ID per inserted ``Message``."""
    temp_dir, db_path, conversation = setup_test()
    try:
        specs = (
            ("user", "First message", MessageType.USER),
            ("assistant", "Second message", MessageType.ASSISTANT),
        )
        batch = [
            Message(role=role, content=text, message_type=kind)
            for role, text, kind in specs
        ]

        returned_ids = conversation.batch_add(batch)

        assert len(returned_ids) == 2, "Should have 2 message IDs"
        assert all(
            isinstance(message_id, int) for message_id in returned_ids
        ), "All IDs should be integers"

        print("✓ Batch add test passed")
    finally:
        cleanup_test(temp_dir, db_path)
|
|
|
|
|
|
def test_get_str():
    """``get_str`` renders each message as ``"<role>: <content>"``."""
    temp_dir, db_path, conversation = setup_test()
    try:
        for speaker, line in (("user", "Hello"), ("assistant", "Hi there!")):
            conversation.add(speaker, line)

        transcript = conversation.get_str()

        assert "user: Hello" in transcript, "User message not found"
        assert (
            "assistant: Hi there!" in transcript
        ), "Assistant message not found"

        print("✓ Get string test passed")
    finally:
        cleanup_test(temp_dir, db_path)
|
|
|
|
|
|
def test_get_messages():
    """``get_messages`` supports no paging, a ``limit`` and an ``offset``.

    Five messages are inserted. All five come back unpaged, two with
    ``limit=2``, and three with ``offset=2``.
    """
    temp_dir, db_path, conversation = setup_test()
    try:
        for n in range(5):
            conversation.add("user", f"Message {n}")

        everything = conversation.get_messages()
        assert len(everything) == 5, "Should have 5 messages"

        first_two = conversation.get_messages(limit=2)
        assert len(first_two) == 2, "Should have 2 limited messages"

        after_two = conversation.get_messages(offset=2)
        assert len(after_two) == 3, "Should have 3 offset messages"

        print("✓ Get messages test passed")
    finally:
        cleanup_test(temp_dir, db_path)
|
|
|
|
|
|
def test_search_messages():
    """``search_messages`` returns exactly the messages containing the query."""
    temp_dir, db_path, conversation = setup_test()
    try:
        seed = [
            ("user", "Hello world"),
            ("assistant", "Hello there"),
            ("user", "Goodbye world"),
        ]
        for role, text in seed:
            conversation.add(role, text)

        hits = conversation.search_messages("world")

        assert len(hits) == 2, "Should find 2 messages with 'world'"
        assert all(
            "world" in hit["content"] for hit in hits
        ), "All results should contain 'world'"

        print("✓ Search messages test passed")
    finally:
        cleanup_test(temp_dir, db_path)
|
|
|
|
|
|
def test_get_statistics():
    """``get_statistics`` reports message, role and token totals.

    Two messages from two different roles with token counts 2 and 1 should
    give totals of 2 messages, 2 unique roles and 3 tokens.
    """
    temp_dir, db_path, conversation = setup_test()
    try:
        conversation.add("user", "Hello", token_count=2)
        conversation.add("assistant", "Hi", token_count=1)

        stats = conversation.get_statistics()

        # (statistic key, expected value, failure message), checked in order.
        expectations = (
            ("total_messages", 2, "Should have 2 total messages"),
            ("unique_roles", 2, "Should have 2 unique roles"),
            ("total_tokens", 3, "Should have 3 total tokens"),
        )
        for key, expected, message in expectations:
            assert stats[key] == expected, message

        print("✓ Get statistics test passed")
    finally:
        cleanup_test(temp_dir, db_path)
|
|
|
|
|
|
def test_json_operations():
    """Round-trip a conversation through a JSON file.

    Saves two messages with ``save_as_json``, then loads the file into a
    second conversation backed by a separate database file in the same
    scratch directory, and checks both messages arrive.
    """
    temp_dir, db_path, conversation = setup_test()
    try:
        conversation.add("user", "Hello")
        conversation.add("assistant", "Hi")

        workdir = Path(temp_dir.name)
        export_file = workdir / "test_conversation.json"
        conversation.save_as_json(str(export_file))
        assert export_file.exists(), "JSON file should exist"

        reloaded = DuckDBConversation(db_path=str(workdir / "new.duckdb"))
        # The load call stays inside the assert, as in the original.
        assert reloaded.load_from_json(
            str(export_file)
        ), "Should load from JSON"
        assert (
            len(reloaded.get_messages()) == 2
        ), "Should have 2 messages after load"

        print("✓ JSON operations test passed")
    finally:
        cleanup_test(temp_dir, db_path)
|
|
|
|
|
|
def test_yaml_operations():
    """Round-trip a conversation through a YAML file.

    Saves two messages with ``save_as_yaml``, then loads the file into a
    second conversation backed by a separate database file in the same
    scratch directory, and checks both messages arrive.
    """
    temp_dir, db_path, conversation = setup_test()
    try:
        conversation.add("user", "Hello")
        conversation.add("assistant", "Hi")

        workdir = Path(temp_dir.name)
        export_file = workdir / "test_conversation.yaml"
        conversation.save_as_yaml(str(export_file))
        assert export_file.exists(), "YAML file should exist"

        reloaded = DuckDBConversation(db_path=str(workdir / "new.duckdb"))
        # The load call stays inside the assert, as in the original.
        assert reloaded.load_from_yaml(
            str(export_file)
        ), "Should load from YAML"
        assert (
            len(reloaded.get_messages()) == 2
        ), "Should have 2 messages after load"

        print("✓ YAML operations test passed")
    finally:
        cleanup_test(temp_dir, db_path)
|
|
|
|
|
|
def test_message_types():
    """Messages of every ``MessageType`` can be stored and carry their type.

    One message per type (system, user, assistant, function, tool) is added
    in that order. All five must come back with a ``message_type`` field.
    """
    temp_dir, db_path, conversation = setup_test()
    try:
        samples = (
            ("system", "System message", MessageType.SYSTEM),
            ("user", "User message", MessageType.USER),
            ("assistant", "Assistant message", MessageType.ASSISTANT),
            ("function", "Function message", MessageType.FUNCTION),
            ("tool", "Tool message", MessageType.TOOL),
        )
        for role, text, kind in samples:
            conversation.add(role, text, message_type=kind)

        stored = conversation.get_messages()

        assert len(stored) == 5, "Should have 5 messages"
        assert all(
            "message_type" in entry for entry in stored
        ), "All messages should have type"

        print("✓ Message types test passed")
    finally:
        cleanup_test(temp_dir, db_path)
|
|
|
|
|
|
def test_delete_operations():
    """``delete_current_conversation`` and ``clear_all`` both empty the store.

    Each mutating call is made inside its assert, as in the original, and
    must report success. Afterwards ``get_messages`` must return nothing.
    """
    temp_dir, db_path, conversation = setup_test()
    try:
        for role, text in (("user", "Hello"), ("assistant", "Hi")):
            conversation.add(role, text)

        assert (
            conversation.delete_current_conversation()
        ), "Should delete conversation"
        assert (
            len(conversation.get_messages()) == 0
        ), "Should have no messages after delete"

        conversation.add("user", "New message")

        assert conversation.clear_all(), "Should clear all messages"
        assert (
            len(conversation.get_messages()) == 0
        ), "Should have no messages after clear"

        print("✓ Delete operations test passed")
    finally:
        cleanup_test(temp_dir, db_path)
|
|
|
|
|
|
def test_concurrent_operations():
    """Add messages from several threads at once and check none are lost.

    Five threads each add ten user messages to the same conversation. The
    store must then hold all fifty.

    An exception raised inside a ``threading.Thread`` target does not reach
    the thread that calls ``join()``. Each worker therefore records any
    exception it hits, and the test fails on those errors first. Without
    this, a worker failure would show up only as a misleading
    message-count mismatch.
    """
    temp_dir, db_path, conversation = setup_test()
    try:
        worker_errors = []

        def add_messages():
            try:
                for i in range(10):
                    conversation.add("user", f"Message {i}")
            except Exception as exc:
                # Hand the failure back to the main thread for reporting.
                worker_errors.append(exc)

        threads = [
            threading.Thread(target=add_messages) for _ in range(5)
        ]
        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()

        assert not worker_errors, (
            f"Worker threads raised exceptions: {worker_errors!r}"
        )

        messages = conversation.get_messages()
        assert (
            len(messages) == 50
        ), "Should have 50 messages (10 * 5 threads)"

        print("✓ Concurrent operations test passed")
    finally:
        cleanup_test(temp_dir, db_path)
|
|
|
|
|
|
def test_error_handling():
    """Invalid input must be rejected.

    Covers three cases:

    * an invalid ``message_type`` (a plain string, not a ``MessageType``);
    * content that cannot be serialized to JSON (an arbitrary ``object()``);
    * loading from a JSON file that does not exist.

    Every "should have failed" check raises ``AssertionError`` from the
    ``else`` clause of its ``try`` statement, never inside the ``try`` body.
    ``AssertionError`` is itself a subclass of ``Exception``, so a failure
    raised inside the body would be caught by ``except Exception`` and the
    test could never fail.
    """
    temp_dir, db_path, conversation = setup_test()
    try:
        # Test invalid message type
        try:
            conversation.add("user", "Message", message_type="invalid")
        except Exception:
            pass
        else:
            raise AssertionError(
                "Should raise exception for invalid message type"
            )

        # Test invalid JSON content
        try:
            conversation.add("user", {"invalid": object()})
        except Exception:
            pass
        else:
            raise AssertionError(
                "Should raise exception for invalid JSON content"
            )

        # Test invalid file operations
        try:
            loaded = conversation.load_from_json("/nonexistent/path.json")
        except Exception:
            pass
        else:
            # load_from_json also reports success through its return value
            # (test_json_operations asserts it is truthy on success), so a
            # falsy result is accepted as a failure signal as well.
            if loaded:
                raise AssertionError(
                    "Should raise exception or report failure for invalid "
                    "file path"
                )

        print("✓ Error handling test passed")
    finally:
        cleanup_test(temp_dir, db_path)
|
|
|
|
|
|
def run_all_tests():
    """Run every test in this module in order and stop at the first failure.

    When a test raises, its name and the error text are printed, and then the
    exception is re-raised so the caller still gets the original traceback.
    """
    print("Running DuckDB Conversation tests...")

    suite = (
        test_initialization,
        test_add_message,
        test_add_complex_message,
        test_batch_add,
        test_get_str,
        test_get_messages,
        test_search_messages,
        test_get_statistics,
        test_json_operations,
        test_yaml_operations,
        test_message_types,
        test_delete_operations,
        test_concurrent_operations,
        test_error_handling,
    )

    for case in suite:
        try:
            case()
        except Exception as err:
            print(f"✗ {case.__name__} failed: {str(err)}")
            raise

    print("\nAll tests completed successfully!")
|
|
|
|
|
|
# Script entry point: `python <this file>` runs the whole suite without pytest.
if __name__ == "__main__":
    run_all_tests()
|