parent
24b4288943
commit
3b1f514545
@ -1,404 +0,0 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
import tempfile
|
||||
import threading
|
||||
from swarms.communication.duckdb_wrap import (
|
||||
DuckDBConversation,
|
||||
Message,
|
||||
MessageType,
|
||||
)
|
||||
|
||||
|
||||
def setup_test():
|
||||
"""Set up test environment."""
|
||||
temp_dir = tempfile.TemporaryDirectory()
|
||||
db_path = Path(temp_dir.name) / "test_conversations.duckdb"
|
||||
conversation = DuckDBConversation(
|
||||
db_path=str(db_path),
|
||||
enable_timestamps=True,
|
||||
enable_logging=True,
|
||||
)
|
||||
return temp_dir, db_path, conversation
|
||||
|
||||
|
||||
def cleanup_test(temp_dir, db_path):
|
||||
"""Clean up test environment."""
|
||||
if os.path.exists(db_path):
|
||||
os.remove(db_path)
|
||||
temp_dir.cleanup()
|
||||
|
||||
|
||||
def test_initialization():
|
||||
"""Test conversation initialization."""
|
||||
temp_dir, db_path, _ = setup_test()
|
||||
try:
|
||||
conv = DuckDBConversation(db_path=str(db_path))
|
||||
assert conv.db_path == db_path, "Database path mismatch"
|
||||
assert (
|
||||
conv.table_name == "conversations"
|
||||
), "Table name mismatch"
|
||||
assert (
|
||||
conv.enable_timestamps is True
|
||||
), "Timestamps should be enabled"
|
||||
assert (
|
||||
conv.current_conversation_id is not None
|
||||
), "Conversation ID should not be None"
|
||||
print("✓ Initialization test passed")
|
||||
finally:
|
||||
cleanup_test(temp_dir, db_path)
|
||||
|
||||
|
||||
def test_add_message():
|
||||
"""Test adding a single message."""
|
||||
temp_dir, db_path, conversation = setup_test()
|
||||
try:
|
||||
msg_id = conversation.add(
|
||||
role="user",
|
||||
content="Hello, world!",
|
||||
message_type=MessageType.USER,
|
||||
)
|
||||
assert msg_id is not None, "Message ID should not be None"
|
||||
assert isinstance(
|
||||
msg_id, int
|
||||
), "Message ID should be an integer"
|
||||
print("✓ Add message test passed")
|
||||
finally:
|
||||
cleanup_test(temp_dir, db_path)
|
||||
|
||||
|
||||
def test_add_complex_message():
|
||||
"""Test adding a message with complex content."""
|
||||
temp_dir, db_path, conversation = setup_test()
|
||||
try:
|
||||
complex_content = {
|
||||
"text": "Hello",
|
||||
"data": [1, 2, 3],
|
||||
"nested": {"key": "value"},
|
||||
}
|
||||
msg_id = conversation.add(
|
||||
role="assistant",
|
||||
content=complex_content,
|
||||
message_type=MessageType.ASSISTANT,
|
||||
metadata={"source": "test"},
|
||||
token_count=10,
|
||||
)
|
||||
assert msg_id is not None, "Message ID should not be None"
|
||||
print("✓ Add complex message test passed")
|
||||
finally:
|
||||
cleanup_test(temp_dir, db_path)
|
||||
|
||||
|
||||
def test_batch_add():
|
||||
"""Test batch adding messages."""
|
||||
temp_dir, db_path, conversation = setup_test()
|
||||
try:
|
||||
messages = [
|
||||
Message(
|
||||
role="user",
|
||||
content="First message",
|
||||
message_type=MessageType.USER,
|
||||
),
|
||||
Message(
|
||||
role="assistant",
|
||||
content="Second message",
|
||||
message_type=MessageType.ASSISTANT,
|
||||
),
|
||||
]
|
||||
msg_ids = conversation.batch_add(messages)
|
||||
assert len(msg_ids) == 2, "Should have 2 message IDs"
|
||||
assert all(
|
||||
isinstance(id, int) for id in msg_ids
|
||||
), "All IDs should be integers"
|
||||
print("✓ Batch add test passed")
|
||||
finally:
|
||||
cleanup_test(temp_dir, db_path)
|
||||
|
||||
|
||||
def test_get_str():
|
||||
"""Test getting conversation as string."""
|
||||
temp_dir, db_path, conversation = setup_test()
|
||||
try:
|
||||
conversation.add("user", "Hello")
|
||||
conversation.add("assistant", "Hi there!")
|
||||
conv_str = conversation.get_str()
|
||||
assert "user: Hello" in conv_str, "User message not found"
|
||||
assert (
|
||||
"assistant: Hi there!" in conv_str
|
||||
), "Assistant message not found"
|
||||
print("✓ Get string test passed")
|
||||
finally:
|
||||
cleanup_test(temp_dir, db_path)
|
||||
|
||||
|
||||
def test_get_messages():
|
||||
"""Test getting messages with pagination."""
|
||||
temp_dir, db_path, conversation = setup_test()
|
||||
try:
|
||||
for i in range(5):
|
||||
conversation.add("user", f"Message {i}")
|
||||
|
||||
all_messages = conversation.get_messages()
|
||||
assert len(all_messages) == 5, "Should have 5 messages"
|
||||
|
||||
limited_messages = conversation.get_messages(limit=2)
|
||||
assert (
|
||||
len(limited_messages) == 2
|
||||
), "Should have 2 limited messages"
|
||||
|
||||
offset_messages = conversation.get_messages(offset=2)
|
||||
assert (
|
||||
len(offset_messages) == 3
|
||||
), "Should have 3 offset messages"
|
||||
print("✓ Get messages test passed")
|
||||
finally:
|
||||
cleanup_test(temp_dir, db_path)
|
||||
|
||||
|
||||
def test_search_messages():
|
||||
"""Test searching messages."""
|
||||
temp_dir, db_path, conversation = setup_test()
|
||||
try:
|
||||
conversation.add("user", "Hello world")
|
||||
conversation.add("assistant", "Hello there")
|
||||
conversation.add("user", "Goodbye world")
|
||||
|
||||
results = conversation.search_messages("world")
|
||||
assert (
|
||||
len(results) == 2
|
||||
), "Should find 2 messages with 'world'"
|
||||
assert all(
|
||||
"world" in msg["content"] for msg in results
|
||||
), "All results should contain 'world'"
|
||||
print("✓ Search messages test passed")
|
||||
finally:
|
||||
cleanup_test(temp_dir, db_path)
|
||||
|
||||
|
||||
def test_get_statistics():
|
||||
"""Test getting conversation statistics."""
|
||||
temp_dir, db_path, conversation = setup_test()
|
||||
try:
|
||||
conversation.add("user", "Hello", token_count=2)
|
||||
conversation.add("assistant", "Hi", token_count=1)
|
||||
|
||||
stats = conversation.get_statistics()
|
||||
assert (
|
||||
stats["total_messages"] == 2
|
||||
), "Should have 2 total messages"
|
||||
assert (
|
||||
stats["unique_roles"] == 2
|
||||
), "Should have 2 unique roles"
|
||||
assert (
|
||||
stats["total_tokens"] == 3
|
||||
), "Should have 3 total tokens"
|
||||
print("✓ Get statistics test passed")
|
||||
finally:
|
||||
cleanup_test(temp_dir, db_path)
|
||||
|
||||
|
||||
def test_json_operations():
|
||||
"""Test JSON save and load operations."""
|
||||
temp_dir, db_path, conversation = setup_test()
|
||||
try:
|
||||
conversation.add("user", "Hello")
|
||||
conversation.add("assistant", "Hi")
|
||||
|
||||
json_path = Path(temp_dir.name) / "test_conversation.json"
|
||||
conversation.save_as_json(str(json_path))
|
||||
assert json_path.exists(), "JSON file should exist"
|
||||
|
||||
new_conversation = DuckDBConversation(
|
||||
db_path=str(Path(temp_dir.name) / "new.duckdb")
|
||||
)
|
||||
assert new_conversation.load_from_json(
|
||||
str(json_path)
|
||||
), "Should load from JSON"
|
||||
assert (
|
||||
len(new_conversation.get_messages()) == 2
|
||||
), "Should have 2 messages after load"
|
||||
print("✓ JSON operations test passed")
|
||||
finally:
|
||||
cleanup_test(temp_dir, db_path)
|
||||
|
||||
|
||||
def test_yaml_operations():
|
||||
"""Test YAML save and load operations."""
|
||||
temp_dir, db_path, conversation = setup_test()
|
||||
try:
|
||||
conversation.add("user", "Hello")
|
||||
conversation.add("assistant", "Hi")
|
||||
|
||||
yaml_path = Path(temp_dir.name) / "test_conversation.yaml"
|
||||
conversation.save_as_yaml(str(yaml_path))
|
||||
assert yaml_path.exists(), "YAML file should exist"
|
||||
|
||||
new_conversation = DuckDBConversation(
|
||||
db_path=str(Path(temp_dir.name) / "new.duckdb")
|
||||
)
|
||||
assert new_conversation.load_from_yaml(
|
||||
str(yaml_path)
|
||||
), "Should load from YAML"
|
||||
assert (
|
||||
len(new_conversation.get_messages()) == 2
|
||||
), "Should have 2 messages after load"
|
||||
print("✓ YAML operations test passed")
|
||||
finally:
|
||||
cleanup_test(temp_dir, db_path)
|
||||
|
||||
|
||||
def test_message_types():
|
||||
"""Test different message types."""
|
||||
temp_dir, db_path, conversation = setup_test()
|
||||
try:
|
||||
conversation.add(
|
||||
"system",
|
||||
"System message",
|
||||
message_type=MessageType.SYSTEM,
|
||||
)
|
||||
conversation.add(
|
||||
"user", "User message", message_type=MessageType.USER
|
||||
)
|
||||
conversation.add(
|
||||
"assistant",
|
||||
"Assistant message",
|
||||
message_type=MessageType.ASSISTANT,
|
||||
)
|
||||
conversation.add(
|
||||
"function",
|
||||
"Function message",
|
||||
message_type=MessageType.FUNCTION,
|
||||
)
|
||||
conversation.add(
|
||||
"tool", "Tool message", message_type=MessageType.TOOL
|
||||
)
|
||||
|
||||
messages = conversation.get_messages()
|
||||
assert len(messages) == 5, "Should have 5 messages"
|
||||
assert all(
|
||||
"message_type" in msg for msg in messages
|
||||
), "All messages should have type"
|
||||
print("✓ Message types test passed")
|
||||
finally:
|
||||
cleanup_test(temp_dir, db_path)
|
||||
|
||||
|
||||
def test_delete_operations():
|
||||
"""Test deletion operations."""
|
||||
temp_dir, db_path, conversation = setup_test()
|
||||
try:
|
||||
conversation.add("user", "Hello")
|
||||
conversation.add("assistant", "Hi")
|
||||
|
||||
assert (
|
||||
conversation.delete_current_conversation()
|
||||
), "Should delete conversation"
|
||||
assert (
|
||||
len(conversation.get_messages()) == 0
|
||||
), "Should have no messages after delete"
|
||||
|
||||
conversation.add("user", "New message")
|
||||
assert conversation.clear_all(), "Should clear all messages"
|
||||
assert (
|
||||
len(conversation.get_messages()) == 0
|
||||
), "Should have no messages after clear"
|
||||
print("✓ Delete operations test passed")
|
||||
finally:
|
||||
cleanup_test(temp_dir, db_path)
|
||||
|
||||
|
||||
def test_concurrent_operations():
|
||||
"""Test concurrent operations."""
|
||||
temp_dir, db_path, conversation = setup_test()
|
||||
try:
|
||||
|
||||
def add_messages():
|
||||
for i in range(10):
|
||||
conversation.add("user", f"Message {i}")
|
||||
|
||||
threads = [
|
||||
threading.Thread(target=add_messages) for _ in range(5)
|
||||
]
|
||||
for thread in threads:
|
||||
thread.start()
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
|
||||
messages = conversation.get_messages()
|
||||
assert (
|
||||
len(messages) == 50
|
||||
), "Should have 50 messages (10 * 5 threads)"
|
||||
print("✓ Concurrent operations test passed")
|
||||
finally:
|
||||
cleanup_test(temp_dir, db_path)
|
||||
|
||||
|
||||
def test_error_handling():
|
||||
"""Test error handling."""
|
||||
temp_dir, db_path, conversation = setup_test()
|
||||
try:
|
||||
# Test invalid message type
|
||||
try:
|
||||
conversation.add(
|
||||
"user", "Message", message_type="invalid"
|
||||
)
|
||||
assert (
|
||||
False
|
||||
), "Should raise exception for invalid message type"
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Test invalid JSON content
|
||||
try:
|
||||
conversation.add("user", {"invalid": object()})
|
||||
assert (
|
||||
False
|
||||
), "Should raise exception for invalid JSON content"
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Test invalid file operations
|
||||
try:
|
||||
conversation.load_from_json("/nonexistent/path.json")
|
||||
assert (
|
||||
False
|
||||
), "Should raise exception for invalid file path"
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
print("✓ Error handling test passed")
|
||||
finally:
|
||||
cleanup_test(temp_dir, db_path)
|
||||
|
||||
|
||||
def run_all_tests():
|
||||
"""Run all tests."""
|
||||
print("Running DuckDB Conversation tests...")
|
||||
tests = [
|
||||
test_initialization,
|
||||
test_add_message,
|
||||
test_add_complex_message,
|
||||
test_batch_add,
|
||||
test_get_str,
|
||||
test_get_messages,
|
||||
test_search_messages,
|
||||
test_get_statistics,
|
||||
test_json_operations,
|
||||
test_yaml_operations,
|
||||
test_message_types,
|
||||
test_delete_operations,
|
||||
test_concurrent_operations,
|
||||
test_error_handling,
|
||||
]
|
||||
|
||||
for test in tests:
|
||||
try:
|
||||
test()
|
||||
except Exception as e:
|
||||
print(f"✗ {test.__name__} failed: {str(e)}")
|
||||
raise
|
||||
|
||||
print("\nAll tests completed successfully!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_all_tests()
|
@ -1,445 +0,0 @@
|
||||
import json
|
||||
import time
|
||||
import os
|
||||
import sys
|
||||
import socket
|
||||
import subprocess
|
||||
from datetime import datetime
|
||||
from typing import Dict, Callable, Tuple
|
||||
from loguru import logger
|
||||
from swarms.communication.pulsar_struct import (
|
||||
PulsarConversation,
|
||||
Message,
|
||||
)
|
||||
|
||||
|
||||
def check_pulsar_client_installed() -> bool:
|
||||
"""Check if pulsar-client package is installed."""
|
||||
try:
|
||||
import pulsar
|
||||
|
||||
return True
|
||||
except ImportError:
|
||||
return False
|
||||
|
||||
|
||||
def install_pulsar_client() -> bool:
|
||||
"""Install pulsar-client package using pip."""
|
||||
try:
|
||||
logger.info("Installing pulsar-client package...")
|
||||
result = subprocess.run(
|
||||
[sys.executable, "-m", "pip", "install", "pulsar-client"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
)
|
||||
if result.returncode == 0:
|
||||
logger.info("Successfully installed pulsar-client")
|
||||
return True
|
||||
else:
|
||||
logger.error(
|
||||
f"Failed to install pulsar-client: {result.stderr}"
|
||||
)
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"Error installing pulsar-client: {str(e)}")
|
||||
return False
|
||||
|
||||
|
||||
def check_port_available(
|
||||
host: str = "localhost", port: int = 6650
|
||||
) -> bool:
|
||||
"""Check if a port is open on the given host."""
|
||||
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
try:
|
||||
sock.settimeout(2) # 2 second timeout
|
||||
result = sock.connect_ex((host, port))
|
||||
return result == 0
|
||||
except Exception:
|
||||
return False
|
||||
finally:
|
||||
sock.close()
|
||||
|
||||
|
||||
def setup_test_broker() -> Tuple[bool, str]:
|
||||
"""
|
||||
Set up a test broker for running tests.
|
||||
Returns (success, message).
|
||||
"""
|
||||
try:
|
||||
from pulsar import Client
|
||||
|
||||
# Create a memory-based standalone broker for testing
|
||||
client = Client("pulsar://localhost:6650")
|
||||
producer = client.create_producer("test-topic")
|
||||
producer.close()
|
||||
client.close()
|
||||
return True, "Test broker setup successful"
|
||||
except Exception as e:
|
||||
return False, f"Failed to set up test broker: {str(e)}"
|
||||
|
||||
|
||||
class PulsarTestSuite:
|
||||
"""Custom test suite for PulsarConversation class."""
|
||||
|
||||
def __init__(self, pulsar_host: str = "pulsar://localhost:6650"):
|
||||
self.pulsar_host = pulsar_host
|
||||
self.host = pulsar_host.split("://")[1].split(":")[0]
|
||||
self.port = int(pulsar_host.split(":")[-1])
|
||||
self.test_results = {
|
||||
"test_suite": "PulsarConversation Tests",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"total_tests": 0,
|
||||
"passed_tests": 0,
|
||||
"failed_tests": 0,
|
||||
"skipped_tests": 0,
|
||||
"results": [],
|
||||
}
|
||||
|
||||
def check_pulsar_setup(self) -> bool:
|
||||
"""
|
||||
Check if Pulsar is properly set up and provide guidance if it's not.
|
||||
"""
|
||||
# First check if pulsar-client is installed
|
||||
if not check_pulsar_client_installed():
|
||||
logger.error(
|
||||
"\nPulsar client library is not installed. Installing now..."
|
||||
)
|
||||
if not install_pulsar_client():
|
||||
logger.error(
|
||||
"\nFailed to install pulsar-client. Please install it manually:\n"
|
||||
" $ pip install pulsar-client\n"
|
||||
)
|
||||
return False
|
||||
|
||||
# Import the newly installed package
|
||||
try:
|
||||
from swarms.communication.pulsar_struct import (
|
||||
PulsarConversation,
|
||||
Message,
|
||||
)
|
||||
except ImportError as e:
|
||||
logger.error(
|
||||
f"Failed to import PulsarConversation after installation: {str(e)}"
|
||||
)
|
||||
return False
|
||||
|
||||
# Try to set up test broker
|
||||
success, message = setup_test_broker()
|
||||
if not success:
|
||||
logger.error(
|
||||
f"\nFailed to set up test environment: {message}"
|
||||
)
|
||||
return False
|
||||
|
||||
logger.info("Pulsar setup check passed successfully")
|
||||
return True
|
||||
|
||||
def run_test(self, test_func: Callable) -> Dict:
|
||||
"""Run a single test and return its result."""
|
||||
start_time = time.time()
|
||||
test_name = test_func.__name__
|
||||
|
||||
try:
|
||||
logger.info(f"Running test: {test_name}")
|
||||
test_func()
|
||||
success = True
|
||||
error = None
|
||||
status = "PASSED"
|
||||
except Exception as e:
|
||||
success = False
|
||||
error = str(e)
|
||||
status = "FAILED"
|
||||
logger.error(f"Test {test_name} failed: {error}")
|
||||
|
||||
end_time = time.time()
|
||||
duration = round(end_time - start_time, 3)
|
||||
|
||||
result = {
|
||||
"test_name": test_name,
|
||||
"success": success,
|
||||
"duration": duration,
|
||||
"error": error,
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"status": status,
|
||||
}
|
||||
|
||||
self.test_results["total_tests"] += 1
|
||||
if success:
|
||||
self.test_results["passed_tests"] += 1
|
||||
else:
|
||||
self.test_results["failed_tests"] += 1
|
||||
|
||||
self.test_results["results"].append(result)
|
||||
return result
|
||||
|
||||
def test_initialization(self):
|
||||
"""Test PulsarConversation initialization."""
|
||||
conversation = PulsarConversation(
|
||||
pulsar_host=self.pulsar_host,
|
||||
system_prompt="Test system prompt",
|
||||
)
|
||||
assert conversation.conversation_id is not None
|
||||
assert conversation.health_check()["client_connected"] is True
|
||||
conversation.__del__()
|
||||
|
||||
def test_add_message(self):
|
||||
"""Test adding a message."""
|
||||
conversation = PulsarConversation(
|
||||
pulsar_host=self.pulsar_host
|
||||
)
|
||||
msg_id = conversation.add("user", "Test message")
|
||||
assert msg_id is not None
|
||||
|
||||
# Verify message was added
|
||||
messages = conversation.get_messages()
|
||||
assert len(messages) > 0
|
||||
assert messages[0]["content"] == "Test message"
|
||||
conversation.__del__()
|
||||
|
||||
def test_batch_add_messages(self):
|
||||
"""Test adding multiple messages."""
|
||||
conversation = PulsarConversation(
|
||||
pulsar_host=self.pulsar_host
|
||||
)
|
||||
messages = [
|
||||
Message(role="user", content="Message 1"),
|
||||
Message(role="assistant", content="Message 2"),
|
||||
]
|
||||
msg_ids = conversation.batch_add(messages)
|
||||
assert len(msg_ids) == 2
|
||||
|
||||
# Verify messages were added
|
||||
stored_messages = conversation.get_messages()
|
||||
assert len(stored_messages) == 2
|
||||
assert stored_messages[0]["content"] == "Message 1"
|
||||
assert stored_messages[1]["content"] == "Message 2"
|
||||
conversation.__del__()
|
||||
|
||||
def test_get_messages(self):
|
||||
"""Test retrieving messages."""
|
||||
conversation = PulsarConversation(
|
||||
pulsar_host=self.pulsar_host
|
||||
)
|
||||
conversation.add("user", "Test message")
|
||||
messages = conversation.get_messages()
|
||||
assert len(messages) > 0
|
||||
conversation.__del__()
|
||||
|
||||
def test_search_messages(self):
|
||||
"""Test searching messages."""
|
||||
conversation = PulsarConversation(
|
||||
pulsar_host=self.pulsar_host
|
||||
)
|
||||
conversation.add("user", "Unique test message")
|
||||
results = conversation.search("unique")
|
||||
assert len(results) > 0
|
||||
conversation.__del__()
|
||||
|
||||
def test_conversation_clear(self):
|
||||
"""Test clearing conversation."""
|
||||
conversation = PulsarConversation(
|
||||
pulsar_host=self.pulsar_host
|
||||
)
|
||||
conversation.add("user", "Test message")
|
||||
conversation.clear()
|
||||
messages = conversation.get_messages()
|
||||
assert len(messages) == 0
|
||||
conversation.__del__()
|
||||
|
||||
def test_conversation_export_import(self):
|
||||
"""Test exporting and importing conversation."""
|
||||
conversation = PulsarConversation(
|
||||
pulsar_host=self.pulsar_host
|
||||
)
|
||||
conversation.add("user", "Test message")
|
||||
conversation.export_conversation("test_export.json")
|
||||
|
||||
new_conversation = PulsarConversation(
|
||||
pulsar_host=self.pulsar_host
|
||||
)
|
||||
new_conversation.import_conversation("test_export.json")
|
||||
messages = new_conversation.get_messages()
|
||||
assert len(messages) > 0
|
||||
conversation.__del__()
|
||||
new_conversation.__del__()
|
||||
|
||||
def test_message_count(self):
|
||||
"""Test message counting."""
|
||||
conversation = PulsarConversation(
|
||||
pulsar_host=self.pulsar_host
|
||||
)
|
||||
conversation.add("user", "Message 1")
|
||||
conversation.add("assistant", "Message 2")
|
||||
counts = conversation.count_messages_by_role()
|
||||
assert counts["user"] == 1
|
||||
assert counts["assistant"] == 1
|
||||
conversation.__del__()
|
||||
|
||||
def test_conversation_string(self):
|
||||
"""Test string representation."""
|
||||
conversation = PulsarConversation(
|
||||
pulsar_host=self.pulsar_host
|
||||
)
|
||||
conversation.add("user", "Test message")
|
||||
string_rep = conversation.get_str()
|
||||
assert "Test message" in string_rep
|
||||
conversation.__del__()
|
||||
|
||||
def test_conversation_json(self):
|
||||
"""Test JSON conversion."""
|
||||
conversation = PulsarConversation(
|
||||
pulsar_host=self.pulsar_host
|
||||
)
|
||||
conversation.add("user", "Test message")
|
||||
json_data = conversation.to_json()
|
||||
assert isinstance(json_data, str)
|
||||
assert "Test message" in json_data
|
||||
conversation.__del__()
|
||||
|
||||
def test_conversation_yaml(self):
|
||||
"""Test YAML conversion."""
|
||||
conversation = PulsarConversation(
|
||||
pulsar_host=self.pulsar_host
|
||||
)
|
||||
conversation.add("user", "Test message")
|
||||
yaml_data = conversation.to_yaml()
|
||||
assert isinstance(yaml_data, str)
|
||||
assert "Test message" in yaml_data
|
||||
conversation.__del__()
|
||||
|
||||
def test_last_message(self):
|
||||
"""Test getting last message."""
|
||||
conversation = PulsarConversation(
|
||||
pulsar_host=self.pulsar_host
|
||||
)
|
||||
conversation.add("user", "Test message")
|
||||
last_msg = conversation.get_last_message()
|
||||
assert last_msg["content"] == "Test message"
|
||||
conversation.__del__()
|
||||
|
||||
def test_messages_by_role(self):
|
||||
"""Test getting messages by role."""
|
||||
conversation = PulsarConversation(
|
||||
pulsar_host=self.pulsar_host
|
||||
)
|
||||
conversation.add("user", "User message")
|
||||
conversation.add("assistant", "Assistant message")
|
||||
user_messages = conversation.get_messages_by_role("user")
|
||||
assert len(user_messages) == 1
|
||||
conversation.__del__()
|
||||
|
||||
def test_conversation_summary(self):
|
||||
"""Test getting conversation summary."""
|
||||
conversation = PulsarConversation(
|
||||
pulsar_host=self.pulsar_host
|
||||
)
|
||||
conversation.add("user", "Test message")
|
||||
summary = conversation.get_conversation_summary()
|
||||
assert summary["message_count"] == 1
|
||||
conversation.__del__()
|
||||
|
||||
def test_conversation_statistics(self):
|
||||
"""Test getting conversation statistics."""
|
||||
conversation = PulsarConversation(
|
||||
pulsar_host=self.pulsar_host
|
||||
)
|
||||
conversation.add("user", "Test message")
|
||||
stats = conversation.get_statistics()
|
||||
assert stats["total_messages"] == 1
|
||||
conversation.__del__()
|
||||
|
||||
def test_health_check(self):
|
||||
"""Test health check functionality."""
|
||||
conversation = PulsarConversation(
|
||||
pulsar_host=self.pulsar_host
|
||||
)
|
||||
health = conversation.health_check()
|
||||
assert health["client_connected"] is True
|
||||
conversation.__del__()
|
||||
|
||||
def test_cache_stats(self):
|
||||
"""Test cache statistics."""
|
||||
conversation = PulsarConversation(
|
||||
pulsar_host=self.pulsar_host
|
||||
)
|
||||
stats = conversation.get_cache_stats()
|
||||
assert "hits" in stats
|
||||
assert "misses" in stats
|
||||
conversation.__del__()
|
||||
|
||||
def run_all_tests(self):
|
||||
"""Run all test cases."""
|
||||
if not self.check_pulsar_setup():
|
||||
logger.error(
|
||||
"Pulsar setup check failed. Please check the error messages above."
|
||||
)
|
||||
return
|
||||
|
||||
test_methods = [
|
||||
method
|
||||
for method in dir(self)
|
||||
if method.startswith("test_")
|
||||
and callable(getattr(self, method))
|
||||
]
|
||||
|
||||
logger.info(f"Running {len(test_methods)} tests...")
|
||||
|
||||
for method_name in test_methods:
|
||||
test_method = getattr(self, method_name)
|
||||
self.run_test(test_method)
|
||||
|
||||
self.save_results()
|
||||
|
||||
def save_results(self):
|
||||
"""Save test results to JSON file."""
|
||||
total_tests = (
|
||||
self.test_results["passed_tests"]
|
||||
+ self.test_results["failed_tests"]
|
||||
)
|
||||
|
||||
if total_tests > 0:
|
||||
self.test_results["success_rate"] = round(
|
||||
(self.test_results["passed_tests"] / total_tests)
|
||||
* 100,
|
||||
2,
|
||||
)
|
||||
else:
|
||||
self.test_results["success_rate"] = 0
|
||||
|
||||
# Add test environment info
|
||||
self.test_results["environment"] = {
|
||||
"pulsar_host": self.pulsar_host,
|
||||
"pulsar_port": self.port,
|
||||
"pulsar_client_installed": check_pulsar_client_installed(),
|
||||
"os": os.uname().sysname,
|
||||
"python_version": subprocess.check_output(
|
||||
["python", "--version"]
|
||||
)
|
||||
.decode()
|
||||
.strip(),
|
||||
}
|
||||
|
||||
with open("pulsar_test_results.json", "w") as f:
|
||||
json.dump(self.test_results, f, indent=2)
|
||||
|
||||
logger.info(
|
||||
f"\nTest Results Summary:\n"
|
||||
f"Total tests: {self.test_results['total_tests']}\n"
|
||||
f"Passed: {self.test_results['passed_tests']}\n"
|
||||
f"Failed: {self.test_results['failed_tests']}\n"
|
||||
f"Skipped: {self.test_results['skipped_tests']}\n"
|
||||
f"Success rate: {self.test_results['success_rate']}%\n"
|
||||
f"Results saved to: pulsar_test_results.json"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
test_suite = PulsarTestSuite()
|
||||
test_suite.run_all_tests()
|
||||
except KeyboardInterrupt:
|
||||
logger.warning("Tests interrupted by user")
|
||||
exit(1)
|
||||
except Exception as e:
|
||||
logger.error(f"Test suite failed: {str(e)}")
|
||||
exit(1)
|
@ -1,350 +0,0 @@
|
||||
import time
|
||||
import json
|
||||
from datetime import datetime
|
||||
from loguru import logger
|
||||
|
||||
from swarms.communication.redis_wrap import (
|
||||
RedisConversation,
|
||||
REDIS_AVAILABLE,
|
||||
)
|
||||
|
||||
|
||||
class TestResults:
|
||||
def __init__(self):
|
||||
self.results = []
|
||||
self.start_time = datetime.now()
|
||||
self.end_time = None
|
||||
self.total_tests = 0
|
||||
self.passed_tests = 0
|
||||
self.failed_tests = 0
|
||||
|
||||
def add_result(
|
||||
self, test_name: str, passed: bool, error: str = None
|
||||
):
|
||||
self.total_tests += 1
|
||||
if passed:
|
||||
self.passed_tests += 1
|
||||
status = "✅ PASSED"
|
||||
else:
|
||||
self.failed_tests += 1
|
||||
status = "❌ FAILED"
|
||||
|
||||
self.results.append(
|
||||
{
|
||||
"test_name": test_name,
|
||||
"status": status,
|
||||
"error": error if error else "None",
|
||||
}
|
||||
)
|
||||
|
||||
def generate_markdown(self) -> str:
|
||||
self.end_time = datetime.now()
|
||||
duration = (self.end_time - self.start_time).total_seconds()
|
||||
|
||||
md = [
|
||||
"# Redis Conversation Test Results",
|
||||
"",
|
||||
f"Test Run: {self.start_time.strftime('%Y-%m-%d %H:%M:%S')}",
|
||||
f"Duration: {duration:.2f} seconds",
|
||||
"",
|
||||
"## Summary",
|
||||
f"- Total Tests: {self.total_tests}",
|
||||
f"- Passed: {self.passed_tests}",
|
||||
f"- Failed: {self.failed_tests}",
|
||||
f"- Success Rate: {(self.passed_tests/self.total_tests*100):.1f}%",
|
||||
"",
|
||||
"## Detailed Results",
|
||||
"",
|
||||
"| Test Name | Status | Error |",
|
||||
"|-----------|--------|-------|",
|
||||
]
|
||||
|
||||
for result in self.results:
|
||||
md.append(
|
||||
f"| {result['test_name']} | {result['status']} | {result['error']} |"
|
||||
)
|
||||
|
||||
return "\n".join(md)
|
||||
|
||||
|
||||
class RedisConversationTester:
|
||||
def __init__(self):
|
||||
self.results = TestResults()
|
||||
self.conversation = None
|
||||
self.redis_server = None
|
||||
|
||||
def run_test(self, test_func: callable, test_name: str):
|
||||
"""Run a single test and record its result."""
|
||||
try:
|
||||
test_func()
|
||||
self.results.add_result(test_name, True)
|
||||
except Exception as e:
|
||||
self.results.add_result(test_name, False, str(e))
|
||||
logger.error(f"Test '{test_name}' failed: {str(e)}")
|
||||
|
||||
def setup(self):
|
||||
"""Initialize Redis server and conversation for testing."""
|
||||
try:
|
||||
# Try first with external Redis (if available)
|
||||
logger.info(
|
||||
"Trying to connect to external Redis server..."
|
||||
)
|
||||
self.conversation = RedisConversation(
|
||||
system_prompt="Test System Prompt",
|
||||
redis_host="localhost",
|
||||
redis_port=6379,
|
||||
redis_retry_attempts=1,
|
||||
use_embedded_redis=False, # Try external first
|
||||
)
|
||||
logger.info(
|
||||
"Successfully connected to external Redis server"
|
||||
)
|
||||
return True
|
||||
except Exception as external_error:
|
||||
logger.info(
|
||||
f"External Redis connection failed: {external_error}"
|
||||
)
|
||||
logger.info("Trying to start embedded Redis server...")
|
||||
|
||||
try:
|
||||
# Fallback to embedded Redis
|
||||
self.conversation = RedisConversation(
|
||||
system_prompt="Test System Prompt",
|
||||
redis_host="localhost",
|
||||
redis_port=6379,
|
||||
redis_retry_attempts=3,
|
||||
use_embedded_redis=True,
|
||||
)
|
||||
logger.info(
|
||||
"Successfully started embedded Redis server"
|
||||
)
|
||||
return True
|
||||
except Exception as embedded_error:
|
||||
logger.error(
|
||||
"Both external and embedded Redis failed:"
|
||||
)
|
||||
logger.error(f" External: {external_error}")
|
||||
logger.error(f" Embedded: {embedded_error}")
|
||||
return False
|
||||
|
||||
def cleanup(self):
|
||||
"""Cleanup resources after tests."""
|
||||
if self.conversation:
|
||||
try:
|
||||
# Check if we have an embedded server to stop
|
||||
if (
|
||||
hasattr(self.conversation, "embedded_server")
|
||||
and self.conversation.embedded_server is not None
|
||||
):
|
||||
self.conversation.embedded_server.stop()
|
||||
# Close Redis client if it exists
|
||||
if (
|
||||
hasattr(self.conversation, "redis_client")
|
||||
and self.conversation.redis_client
|
||||
):
|
||||
self.conversation.redis_client.close()
|
||||
except Exception as e:
|
||||
logger.warning(f"Error during cleanup: {str(e)}")
|
||||
|
||||
def test_initialization(self):
|
||||
"""Test basic initialization."""
|
||||
assert (
|
||||
self.conversation is not None
|
||||
), "Failed to initialize RedisConversation"
|
||||
assert (
|
||||
self.conversation.system_prompt == "Test System Prompt"
|
||||
), "System prompt not set correctly"
|
||||
|
||||
def test_add_message(self):
|
||||
"""Test adding messages."""
|
||||
self.conversation.add("user", "Hello")
|
||||
self.conversation.add("assistant", "Hi there!")
|
||||
messages = self.conversation.return_messages_as_list()
|
||||
assert len(messages) >= 2, "Failed to add messages"
|
||||
|
||||
def test_json_message(self):
|
||||
"""Test adding JSON messages."""
|
||||
json_content = {"key": "value", "nested": {"data": 123}}
|
||||
self.conversation.add("system", json_content)
|
||||
last_message = self.conversation.get_final_message_content()
|
||||
|
||||
# Parse the JSON string back to dict for comparison
|
||||
if isinstance(last_message, str):
|
||||
try:
|
||||
parsed_content = json.loads(last_message)
|
||||
assert isinstance(
|
||||
parsed_content, dict
|
||||
), "Failed to handle JSON message"
|
||||
except json.JSONDecodeError:
|
||||
assert (
|
||||
False
|
||||
), "JSON message was not stored as valid JSON"
|
||||
else:
|
||||
assert isinstance(
|
||||
last_message, dict
|
||||
), "Failed to handle JSON message"
|
||||
|
||||
def test_search(self):
|
||||
"""Test search functionality."""
|
||||
self.conversation.add("user", "searchable message")
|
||||
results = self.conversation.search("searchable")
|
||||
assert len(results) > 0, "Search failed to find message"
|
||||
|
||||
def test_delete(self):
|
||||
"""Test message deletion."""
|
||||
initial_count = len(
|
||||
self.conversation.return_messages_as_list()
|
||||
)
|
||||
if initial_count > 0:
|
||||
self.conversation.delete(0)
|
||||
new_count = len(
|
||||
self.conversation.return_messages_as_list()
|
||||
)
|
||||
assert (
|
||||
new_count == initial_count - 1
|
||||
), "Failed to delete message"
|
||||
|
||||
def test_update(self):
|
||||
"""Test message update."""
|
||||
# Add initial message
|
||||
self.conversation.add("user", "original message")
|
||||
|
||||
all_messages = self.conversation.return_messages_as_list()
|
||||
if len(all_messages) > 0:
|
||||
self.conversation.update(0, "user", "updated message")
|
||||
self.conversation.query(0)
|
||||
assert True, "Update method executed successfully"
|
||||
|
||||
def test_clear(self):
|
||||
"""Test clearing conversation."""
|
||||
self.conversation.add("user", "test message")
|
||||
self.conversation.clear()
|
||||
messages = self.conversation.return_messages_as_list()
|
||||
assert len(messages) == 0, "Failed to clear conversation"
|
||||
|
||||
def test_export_import(self):
|
||||
"""Test export and import functionality."""
|
||||
self.conversation.add("user", "export test")
|
||||
self.conversation.export_conversation("test_export.txt")
|
||||
self.conversation.clear()
|
||||
self.conversation.import_conversation("test_export.txt")
|
||||
messages = self.conversation.return_messages_as_list()
|
||||
assert (
|
||||
len(messages) > 0
|
||||
), "Failed to export/import conversation"
|
||||
|
||||
def test_json_operations(self):
|
||||
"""Test JSON operations."""
|
||||
self.conversation.add("user", "json test")
|
||||
json_data = self.conversation.to_json()
|
||||
assert isinstance(
|
||||
json.loads(json_data), list
|
||||
), "Failed to convert to JSON"
|
||||
|
||||
def test_yaml_operations(self):
|
||||
"""Test YAML operations."""
|
||||
self.conversation.add("user", "yaml test")
|
||||
yaml_data = self.conversation.to_yaml()
|
||||
assert isinstance(yaml_data, str), "Failed to convert to YAML"
|
||||
|
||||
def test_token_counting(self):
|
||||
"""Test token counting functionality."""
|
||||
self.conversation.add("user", "token test message")
|
||||
time.sleep(1) # Wait for async token counting
|
||||
messages = self.conversation.to_dict()
|
||||
assert isinstance(
|
||||
messages, list
|
||||
), "Token counting test completed"
|
||||
|
||||
def test_cache_operations(self):
|
||||
"""Test cache operations."""
|
||||
self.conversation.add("user", "cache test")
|
||||
stats = self.conversation.get_cache_stats()
|
||||
assert isinstance(stats, dict), "Failed to get cache stats"
|
||||
|
||||
def test_conversation_stats(self):
|
||||
"""Test conversation statistics."""
|
||||
self.conversation.add("user", "stats test")
|
||||
counts = self.conversation.count_messages_by_role()
|
||||
assert isinstance(
|
||||
counts, dict
|
||||
), "Failed to get message counts"
|
||||
|
||||
def run_all_tests(self):
|
||||
"""Run all tests and generate report."""
|
||||
if not REDIS_AVAILABLE:
|
||||
logger.error(
|
||||
"Redis is not available. Please install redis package."
|
||||
)
|
||||
return "# Redis Tests Failed\n\nRedis package is not installed."
|
||||
|
||||
try:
|
||||
if not self.setup():
|
||||
logger.warning(
|
||||
"Failed to setup Redis connection. This is expected on systems without Redis server."
|
||||
)
|
||||
|
||||
# Generate a report indicating the limitation
|
||||
setup_failed_md = [
|
||||
"# Redis Conversation Test Results",
|
||||
"",
|
||||
f"Test Run: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}",
|
||||
"",
|
||||
"## Summary",
|
||||
"❌ **Redis Server Setup Failed**",
|
||||
"",
|
||||
"The Redis conversation class will work properly when a Redis server is available.",
|
||||
]
|
||||
|
||||
return "\n".join(setup_failed_md)
|
||||
|
||||
tests = [
|
||||
(self.test_initialization, "Initialization Test"),
|
||||
(self.test_add_message, "Add Message Test"),
|
||||
(self.test_json_message, "JSON Message Test"),
|
||||
(self.test_search, "Search Test"),
|
||||
(self.test_delete, "Delete Test"),
|
||||
(self.test_update, "Update Test"),
|
||||
(self.test_clear, "Clear Test"),
|
||||
(self.test_export_import, "Export/Import Test"),
|
||||
(self.test_json_operations, "JSON Operations Test"),
|
||||
(self.test_yaml_operations, "YAML Operations Test"),
|
||||
(self.test_token_counting, "Token Counting Test"),
|
||||
(self.test_cache_operations, "Cache Operations Test"),
|
||||
(
|
||||
self.test_conversation_stats,
|
||||
"Conversation Stats Test",
|
||||
),
|
||||
]
|
||||
|
||||
for test_func, test_name in tests:
|
||||
self.run_test(test_func, test_name)
|
||||
|
||||
return self.results.generate_markdown()
|
||||
finally:
|
||||
self.cleanup()
|
||||
|
||||
|
||||
def main():
|
||||
"""Main function to run tests and save results."""
|
||||
tester = RedisConversationTester()
|
||||
markdown_results = tester.run_all_tests()
|
||||
|
||||
# Save results to file
|
||||
try:
|
||||
with open(
|
||||
"redis_test_results.md", "w", encoding="utf-8"
|
||||
) as f:
|
||||
f.write(markdown_results)
|
||||
logger.info(
|
||||
"Test results have been saved to redis_test_results.md"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save test results: {e}")
|
||||
|
||||
# Also print results to console
|
||||
print(markdown_results)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -1,386 +0,0 @@
|
||||
import json
|
||||
import datetime
|
||||
import os
|
||||
from typing import Dict, List, Any, Tuple
|
||||
from loguru import logger
|
||||
from swarms.communication.sqlite_wrap import (
|
||||
SQLiteConversation,
|
||||
Message,
|
||||
MessageType,
|
||||
)
|
||||
from rich.console import Console
|
||||
from rich.table import Table
|
||||
from rich.panel import Panel
|
||||
|
||||
console = Console()
|
||||
|
||||
|
||||
def print_test_header(test_name: str) -> None:
|
||||
"""Print a formatted test header."""
|
||||
console.print(
|
||||
Panel(
|
||||
f"[bold blue]Running Test: {test_name}[/bold blue]",
|
||||
expand=False,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def print_test_result(
|
||||
test_name: str, success: bool, message: str, execution_time: float
|
||||
) -> None:
|
||||
"""Print a formatted test result."""
|
||||
status = (
|
||||
"[bold green]PASSED[/bold green]"
|
||||
if success
|
||||
else "[bold red]FAILED[/bold red]"
|
||||
)
|
||||
console.print(f"\n{status} - {test_name}")
|
||||
console.print(f"Message: {message}")
|
||||
console.print(f"Execution time: {execution_time:.3f} seconds\n")
|
||||
|
||||
|
||||
def print_messages(
|
||||
messages: List[Dict], title: str = "Messages"
|
||||
) -> None:
|
||||
"""Print messages in a formatted table."""
|
||||
table = Table(title=title)
|
||||
table.add_column("Role", style="cyan")
|
||||
table.add_column("Content", style="green")
|
||||
table.add_column("Type", style="yellow")
|
||||
table.add_column("Timestamp", style="magenta")
|
||||
|
||||
for msg in messages:
|
||||
content = str(msg.get("content", ""))
|
||||
if isinstance(content, (dict, list)):
|
||||
content = json.dumps(content)
|
||||
table.add_row(
|
||||
msg.get("role", ""),
|
||||
content,
|
||||
str(msg.get("message_type", "")),
|
||||
str(msg.get("timestamp", "")),
|
||||
)
|
||||
|
||||
console.print(table)
|
||||
|
||||
|
||||
def run_test(
|
||||
test_func: callable, *args, **kwargs
|
||||
) -> Tuple[bool, str, float]:
|
||||
"""
|
||||
Run a test function and return its results.
|
||||
|
||||
Args:
|
||||
test_func: The test function to run
|
||||
*args: Arguments for the test function
|
||||
**kwargs: Keyword arguments for the test function
|
||||
|
||||
Returns:
|
||||
Tuple[bool, str, float]: (success, message, execution_time)
|
||||
"""
|
||||
start_time = datetime.datetime.now()
|
||||
try:
|
||||
result = test_func(*args, **kwargs)
|
||||
end_time = datetime.datetime.now()
|
||||
execution_time = (end_time - start_time).total_seconds()
|
||||
return True, str(result), execution_time
|
||||
except Exception as e:
|
||||
end_time = datetime.datetime.now()
|
||||
execution_time = (end_time - start_time).total_seconds()
|
||||
return False, str(e), execution_time
|
||||
|
||||
|
||||
def test_basic_conversation() -> bool:
|
||||
"""Test basic conversation operations."""
|
||||
print_test_header("Basic Conversation Test")
|
||||
|
||||
db_path = "test_conversations.db"
|
||||
conversation = SQLiteConversation(db_path=db_path)
|
||||
|
||||
# Test adding messages
|
||||
console.print("\n[bold]Adding messages...[/bold]")
|
||||
conversation.add("user", "Hello")
|
||||
conversation.add("assistant", "Hi there!")
|
||||
|
||||
# Test getting messages
|
||||
console.print("\n[bold]Retrieved messages:[/bold]")
|
||||
messages = conversation.get_messages()
|
||||
print_messages(messages)
|
||||
|
||||
assert len(messages) == 2
|
||||
assert messages[0]["role"] == "user"
|
||||
assert messages[1]["role"] == "assistant"
|
||||
|
||||
# Cleanup
|
||||
os.remove(db_path)
|
||||
return True
|
||||
|
||||
|
||||
def test_message_types() -> bool:
|
||||
"""Test different message types and content formats."""
|
||||
print_test_header("Message Types Test")
|
||||
|
||||
db_path = "test_conversations.db"
|
||||
conversation = SQLiteConversation(db_path=db_path)
|
||||
|
||||
# Test different content types
|
||||
console.print("\n[bold]Adding different message types...[/bold]")
|
||||
conversation.add("user", "Simple text")
|
||||
conversation.add(
|
||||
"assistant", {"type": "json", "content": "Complex data"}
|
||||
)
|
||||
conversation.add("system", ["list", "of", "items"])
|
||||
conversation.add(
|
||||
"function",
|
||||
"Function result",
|
||||
message_type=MessageType.FUNCTION,
|
||||
)
|
||||
|
||||
console.print("\n[bold]Retrieved messages:[/bold]")
|
||||
messages = conversation.get_messages()
|
||||
print_messages(messages)
|
||||
|
||||
assert len(messages) == 4
|
||||
|
||||
# Cleanup
|
||||
os.remove(db_path)
|
||||
return True
|
||||
|
||||
|
||||
def test_conversation_operations() -> bool:
|
||||
"""Test various conversation operations."""
|
||||
print_test_header("Conversation Operations Test")
|
||||
|
||||
db_path = "test_conversations.db"
|
||||
conversation = SQLiteConversation(db_path=db_path)
|
||||
|
||||
# Test batch operations
|
||||
console.print("\n[bold]Adding batch messages...[/bold]")
|
||||
messages = [
|
||||
Message(role="user", content="Message 1"),
|
||||
Message(role="assistant", content="Message 2"),
|
||||
Message(role="user", content="Message 3"),
|
||||
]
|
||||
conversation.batch_add(messages)
|
||||
|
||||
console.print("\n[bold]Retrieved messages:[/bold]")
|
||||
all_messages = conversation.get_messages()
|
||||
print_messages(all_messages)
|
||||
|
||||
# Test statistics
|
||||
console.print("\n[bold]Conversation Statistics:[/bold]")
|
||||
stats = conversation.get_statistics()
|
||||
console.print(json.dumps(stats, indent=2))
|
||||
|
||||
# Test role counting
|
||||
console.print("\n[bold]Role Counts:[/bold]")
|
||||
role_counts = conversation.count_messages_by_role()
|
||||
console.print(json.dumps(role_counts, indent=2))
|
||||
|
||||
assert stats["total_messages"] == 3
|
||||
assert role_counts["user"] == 2
|
||||
assert role_counts["assistant"] == 1
|
||||
|
||||
# Cleanup
|
||||
os.remove(db_path)
|
||||
return True
|
||||
|
||||
|
||||
def test_file_operations() -> bool:
|
||||
"""Test file operations (JSON/YAML)."""
|
||||
print_test_header("File Operations Test")
|
||||
|
||||
db_path = "test_conversations.db"
|
||||
json_path = "test_conversation.json"
|
||||
yaml_path = "test_conversation.yaml"
|
||||
|
||||
conversation = SQLiteConversation(db_path=db_path)
|
||||
conversation.add("user", "Test message")
|
||||
|
||||
# Test JSON operations
|
||||
console.print("\n[bold]Testing JSON operations...[/bold]")
|
||||
assert conversation.save_as_json(json_path)
|
||||
console.print(f"Saved to JSON: {json_path}")
|
||||
|
||||
conversation.start_new_conversation()
|
||||
assert conversation.load_from_json(json_path)
|
||||
console.print("Loaded from JSON")
|
||||
|
||||
# Test YAML operations
|
||||
console.print("\n[bold]Testing YAML operations...[/bold]")
|
||||
assert conversation.save_as_yaml(yaml_path)
|
||||
console.print(f"Saved to YAML: {yaml_path}")
|
||||
|
||||
conversation.start_new_conversation()
|
||||
assert conversation.load_from_yaml(yaml_path)
|
||||
console.print("Loaded from YAML")
|
||||
|
||||
# Cleanup
|
||||
os.remove(db_path)
|
||||
os.remove(json_path)
|
||||
os.remove(yaml_path)
|
||||
return True
|
||||
|
||||
|
||||
def test_search_and_filter() -> bool:
|
||||
"""Test search and filter operations."""
|
||||
print_test_header("Search and Filter Test")
|
||||
|
||||
db_path = "test_conversations.db"
|
||||
conversation = SQLiteConversation(db_path=db_path)
|
||||
|
||||
# Add test messages
|
||||
console.print("\n[bold]Adding test messages...[/bold]")
|
||||
conversation.add("user", "Hello world")
|
||||
conversation.add("assistant", "Hello there")
|
||||
conversation.add("user", "Goodbye world")
|
||||
|
||||
# Test search
|
||||
console.print("\n[bold]Searching for 'world'...[/bold]")
|
||||
results = conversation.search_messages("world")
|
||||
print_messages(results, "Search Results")
|
||||
|
||||
# Test role filtering
|
||||
console.print("\n[bold]Filtering user messages...[/bold]")
|
||||
user_messages = conversation.get_messages_by_role("user")
|
||||
print_messages(user_messages, "User Messages")
|
||||
|
||||
assert len(results) == 2
|
||||
assert len(user_messages) == 2
|
||||
|
||||
# Cleanup
|
||||
os.remove(db_path)
|
||||
return True
|
||||
|
||||
|
||||
def test_conversation_management() -> bool:
|
||||
"""Test conversation management features."""
|
||||
print_test_header("Conversation Management Test")
|
||||
|
||||
db_path = "test_conversations.db"
|
||||
conversation = SQLiteConversation(db_path=db_path)
|
||||
|
||||
# Test conversation ID generation
|
||||
console.print("\n[bold]Testing conversation IDs...[/bold]")
|
||||
conv_id1 = conversation.get_conversation_id()
|
||||
console.print(f"First conversation ID: {conv_id1}")
|
||||
|
||||
conversation.start_new_conversation()
|
||||
conv_id2 = conversation.get_conversation_id()
|
||||
console.print(f"Second conversation ID: {conv_id2}")
|
||||
|
||||
assert conv_id1 != conv_id2
|
||||
|
||||
# Test conversation deletion
|
||||
console.print("\n[bold]Testing conversation deletion...[/bold]")
|
||||
conversation.add("user", "Test message")
|
||||
assert conversation.delete_current_conversation()
|
||||
console.print("Conversation deleted successfully")
|
||||
|
||||
# Cleanup
|
||||
os.remove(db_path)
|
||||
return True
|
||||
|
||||
|
||||
def generate_test_report(
|
||||
test_results: List[Dict[str, Any]],
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Generate a test report in JSON format.
|
||||
|
||||
Args:
|
||||
test_results: List of test results
|
||||
|
||||
Returns:
|
||||
Dict containing the test report
|
||||
"""
|
||||
total_tests = len(test_results)
|
||||
passed_tests = sum(
|
||||
1 for result in test_results if result["success"]
|
||||
)
|
||||
failed_tests = total_tests - passed_tests
|
||||
total_time = sum(
|
||||
result["execution_time"] for result in test_results
|
||||
)
|
||||
|
||||
report = {
|
||||
"timestamp": datetime.datetime.now().isoformat(),
|
||||
"summary": {
|
||||
"total_tests": total_tests,
|
||||
"passed_tests": passed_tests,
|
||||
"failed_tests": failed_tests,
|
||||
"total_execution_time": total_time,
|
||||
"average_execution_time": (
|
||||
total_time / total_tests if total_tests > 0 else 0
|
||||
),
|
||||
},
|
||||
"test_results": test_results,
|
||||
}
|
||||
|
||||
return report
|
||||
|
||||
|
||||
def run_all_tests() -> None:
|
||||
"""Run all tests and generate a report."""
|
||||
console.print(
|
||||
Panel(
|
||||
"[bold blue]Starting Test Suite[/bold blue]", expand=False
|
||||
)
|
||||
)
|
||||
|
||||
tests = [
|
||||
("Basic Conversation", test_basic_conversation),
|
||||
("Message Types", test_message_types),
|
||||
("Conversation Operations", test_conversation_operations),
|
||||
("File Operations", test_file_operations),
|
||||
("Search and Filter", test_search_and_filter),
|
||||
("Conversation Management", test_conversation_management),
|
||||
]
|
||||
|
||||
test_results = []
|
||||
|
||||
for test_name, test_func in tests:
|
||||
logger.info(f"Running test: {test_name}")
|
||||
success, message, execution_time = run_test(test_func)
|
||||
|
||||
print_test_result(test_name, success, message, execution_time)
|
||||
|
||||
result = {
|
||||
"test_name": test_name,
|
||||
"success": success,
|
||||
"message": message,
|
||||
"execution_time": execution_time,
|
||||
"timestamp": datetime.datetime.now().isoformat(),
|
||||
}
|
||||
|
||||
if success:
|
||||
logger.success(f"Test passed: {test_name}")
|
||||
else:
|
||||
logger.error(f"Test failed: {test_name} - {message}")
|
||||
|
||||
test_results.append(result)
|
||||
|
||||
# Generate and save report
|
||||
report = generate_test_report(test_results)
|
||||
report_path = "test_report.json"
|
||||
|
||||
with open(report_path, "w") as f:
|
||||
json.dump(report, f, indent=2)
|
||||
|
||||
# Print final summary
|
||||
console.print("\n[bold blue]Test Suite Summary[/bold blue]")
|
||||
console.print(
|
||||
Panel(
|
||||
f"Total tests: {report['summary']['total_tests']}\n"
|
||||
f"Passed tests: {report['summary']['passed_tests']}\n"
|
||||
f"Failed tests: {report['summary']['failed_tests']}\n"
|
||||
f"Total execution time: {report['summary']['total_execution_time']:.2f} seconds",
|
||||
title="Summary",
|
||||
expand=False,
|
||||
)
|
||||
)
|
||||
|
||||
logger.info(f"Test report saved to {report_path}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_all_tests()
|
File diff suppressed because it is too large
Load Diff
@ -1,283 +0,0 @@
|
||||
import datetime
|
||||
from datetime import timedelta
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from swarm_models.gpt4_vision_api import GPT4VisionAPI
|
||||
from swarms.prompts.multi_modal_autonomous_instruction_prompt import (
|
||||
MULTI_MODAL_AUTO_AGENT_SYSTEM_PROMPT_1,
|
||||
)
|
||||
from swarms.structs.agent import Agent
|
||||
from swarms.structs.task import Task
|
||||
|
||||
load_dotenv()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def llm():
|
||||
return GPT4VisionAPI()
|
||||
|
||||
|
||||
def test_agent_run_task(llm):
|
||||
task = (
|
||||
"Analyze this image of an assembly line and identify any"
|
||||
" issues such as misaligned parts, defects, or deviations"
|
||||
" from the standard assembly process. IF there is anything"
|
||||
" unsafe in the image, explain why it is unsafe and how it"
|
||||
" could be improved."
|
||||
)
|
||||
img = "assembly_line.jpg"
|
||||
|
||||
agent = Agent(
|
||||
llm=llm,
|
||||
max_loops="auto",
|
||||
sop=MULTI_MODAL_AUTO_AGENT_SYSTEM_PROMPT_1,
|
||||
dashboard=True,
|
||||
)
|
||||
|
||||
result = agent.run(task=task, img=img)
|
||||
|
||||
# Add assertions here to verify the expected behavior of the agent's run method
|
||||
assert isinstance(result, dict)
|
||||
assert "response" in result
|
||||
assert "dashboard_data" in result
|
||||
# Add more assertions as needed
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def task():
|
||||
agents = [Agent(llm=llm, id=f"Agent_{i}") for i in range(5)]
|
||||
return Task(
|
||||
id="Task_1", task="Task_Name", agents=agents, dependencies=[]
|
||||
)
|
||||
|
||||
|
||||
# Basic tests
|
||||
|
||||
|
||||
def test_task_init(task):
|
||||
assert task.id == "Task_1"
|
||||
assert task.task == "Task_Name"
|
||||
assert isinstance(task.agents, list)
|
||||
assert len(task.agents) == 5
|
||||
assert isinstance(task.dependencies, list)
|
||||
|
||||
|
||||
def test_task_execute(task, mocker):
|
||||
mocker.patch.object(Agent, "run", side_effect=[1, 2, 3, 4, 5])
|
||||
parent_results = {}
|
||||
task.execute(parent_results)
|
||||
assert isinstance(task.results, list)
|
||||
assert len(task.results) == 5
|
||||
for result in task.results:
|
||||
assert isinstance(result, int)
|
||||
|
||||
|
||||
# Parameterized tests
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_agents", [1, 3, 5, 10])
|
||||
def test_task_num_agents(task, num_agents, mocker):
|
||||
task.agents = [Agent(id=f"Agent_{i}") for i in range(num_agents)]
|
||||
mocker.patch.object(Agent, "run", return_value=1)
|
||||
parent_results = {}
|
||||
task.execute(parent_results)
|
||||
assert len(task.results) == num_agents
|
||||
|
||||
|
||||
# Exception testing
|
||||
|
||||
|
||||
def test_task_execute_with_dependency_error(task, mocker):
|
||||
task.dependencies = ["NonExistentTask"]
|
||||
mocker.patch.object(Agent, "run", return_value=1)
|
||||
parent_results = {}
|
||||
with pytest.raises(KeyError):
|
||||
task.execute(parent_results)
|
||||
|
||||
|
||||
# Mocking and monkeypatching tests
|
||||
|
||||
|
||||
def test_task_execute_with_mocked_agents(task, mocker):
|
||||
mock_agents = [Mock(spec=Agent) for _ in range(5)]
|
||||
mocker.patch.object(task, "agents", mock_agents)
|
||||
for mock_agent in mock_agents:
|
||||
mock_agent.run.return_value = 1
|
||||
parent_results = {}
|
||||
task.execute(parent_results)
|
||||
assert len(task.results) == 5
|
||||
|
||||
|
||||
def test_task_creation():
|
||||
agent = Agent()
|
||||
task = Task(id="1", task="Task1", result=None, agents=[agent])
|
||||
assert task.id == "1"
|
||||
assert task.task == "Task1"
|
||||
assert task.result is None
|
||||
assert task.agents == [agent]
|
||||
|
||||
|
||||
def test_task_with_dependencies():
|
||||
agent = Agent()
|
||||
task = Task(
|
||||
id="2",
|
||||
task="Task2",
|
||||
result=None,
|
||||
agents=[agent],
|
||||
dependencies=["Task1"],
|
||||
)
|
||||
assert task.dependencies == ["Task1"]
|
||||
|
||||
|
||||
def test_task_with_args():
|
||||
agent = Agent()
|
||||
task = Task(
|
||||
id="3",
|
||||
task="Task3",
|
||||
result=None,
|
||||
agents=[agent],
|
||||
args=["arg1", "arg2"],
|
||||
)
|
||||
assert task.args == ["arg1", "arg2"]
|
||||
|
||||
|
||||
def test_task_with_kwargs():
|
||||
agent = Agent()
|
||||
task = Task(
|
||||
id="4",
|
||||
task="Task4",
|
||||
result=None,
|
||||
agents=[agent],
|
||||
kwargs={"kwarg1": "value1"},
|
||||
)
|
||||
assert task.kwargs == {"kwarg1": "value1"}
|
||||
|
||||
|
||||
# ... continue creating tests for different scenarios
|
||||
|
||||
|
||||
# Test execute method
|
||||
def test_execute():
|
||||
agent = Agent()
|
||||
task = Task(id="5", task="Task5", result=None, agents=[agent])
|
||||
# Assuming execute method returns True on successful execution
|
||||
assert task.run() is True
|
||||
|
||||
|
||||
def test_task_execute_with_agent(mocker):
|
||||
mock_agent = mocker.Mock(spec=Agent)
|
||||
mock_agent.run.return_value = "result"
|
||||
task = Task(description="Test task", agent=mock_agent)
|
||||
task.run()
|
||||
assert task.result == "result"
|
||||
assert task.history == ["result"]
|
||||
|
||||
|
||||
def test_task_execute_with_callable(mocker):
|
||||
mock_callable = mocker.Mock()
|
||||
mock_callable.run.return_value = "result"
|
||||
task = Task(description="Test task", agent=mock_callable)
|
||||
task.run()
|
||||
assert task.result == "result"
|
||||
assert task.history == ["result"]
|
||||
|
||||
|
||||
def test_task_execute_with_condition(mocker):
|
||||
mock_agent = mocker.Mock(spec=Agent)
|
||||
mock_agent.run.return_value = "result"
|
||||
condition = mocker.Mock(return_value=True)
|
||||
task = Task(
|
||||
description="Test task", agent=mock_agent, condition=condition
|
||||
)
|
||||
task.run()
|
||||
assert task.result == "result"
|
||||
assert task.history == ["result"]
|
||||
|
||||
|
||||
def test_task_execute_with_condition_false(mocker):
|
||||
mock_agent = mocker.Mock(spec=Agent)
|
||||
mock_agent.run.return_value = "result"
|
||||
condition = mocker.Mock(return_value=False)
|
||||
task = Task(
|
||||
description="Test task", agent=mock_agent, condition=condition
|
||||
)
|
||||
task.run()
|
||||
assert task.result is None
|
||||
assert task.history == []
|
||||
|
||||
|
||||
def test_task_execute_with_action(mocker):
|
||||
mock_agent = mocker.Mock(spec=Agent)
|
||||
mock_agent.run.return_value = "result"
|
||||
action = mocker.Mock()
|
||||
task = Task(
|
||||
description="Test task", agent=mock_agent, action=action
|
||||
)
|
||||
task.run()
|
||||
assert task.result == "result"
|
||||
assert task.history == ["result"]
|
||||
action.assert_called_once()
|
||||
|
||||
|
||||
def test_task_handle_scheduled_task_now(mocker):
|
||||
mock_agent = mocker.Mock(spec=Agent)
|
||||
mock_agent.run.return_value = "result"
|
||||
task = Task(
|
||||
description="Test task",
|
||||
agent=mock_agent,
|
||||
schedule_time=datetime.now(),
|
||||
)
|
||||
task.handle_scheduled_task()
|
||||
assert task.result == "result"
|
||||
assert task.history == ["result"]
|
||||
|
||||
|
||||
def test_task_handle_scheduled_task_future(mocker):
|
||||
mock_agent = mocker.Mock(spec=Agent)
|
||||
mock_agent.run.return_value = "result"
|
||||
task = Task(
|
||||
description="Test task",
|
||||
agent=mock_agent,
|
||||
schedule_time=datetime.now() + timedelta(days=1),
|
||||
)
|
||||
with mocker.patch.object(
|
||||
task.scheduler, "enter"
|
||||
) as mock_enter, mocker.patch.object(
|
||||
task.scheduler, "run"
|
||||
) as mock_run:
|
||||
task.handle_scheduled_task()
|
||||
mock_enter.assert_called_once()
|
||||
mock_run.assert_called_once()
|
||||
|
||||
|
||||
def test_task_set_trigger():
|
||||
task = Task(description="Test task", agent=Agent())
|
||||
|
||||
def trigger():
|
||||
return True
|
||||
|
||||
task.set_trigger(trigger)
|
||||
assert task.trigger == trigger
|
||||
|
||||
|
||||
def test_task_set_action():
|
||||
task = Task(description="Test task", agent=Agent())
|
||||
|
||||
def action():
|
||||
return True
|
||||
|
||||
task.set_action(action)
|
||||
assert task.action == action
|
||||
|
||||
|
||||
def test_task_set_condition():
|
||||
task = Task(description="Test task", agent=Agent())
|
||||
|
||||
def condition():
|
||||
return True
|
||||
|
||||
task.set_condition(condition)
|
||||
assert task.condition == condition
|
Loading…
Reference in new issue