You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
283 lines
9.4 KiB
283 lines
9.4 KiB
import json
import os
import time
from datetime import datetime
from typing import Callable, Optional

from loguru import logger

from swarms.communication.redis_wrap import (
    RedisConversation,
    REDIS_AVAILABLE,
)
|
|
|
|
|
|
class TestResults:
    """Accumulate per-test outcomes and render them as a Markdown report.

    Attributes:
        results: One dict per recorded test with keys ``test_name``,
            ``status`` and ``error``.
        start_time: When this collector was created.
        end_time: ``None`` until :meth:`generate_markdown` is called, which
            stamps it with the current time.
        total_tests, passed_tests, failed_tests: Running counters updated
            by :meth:`add_result`.
    """

    # The class name starts with "Test", so pytest would otherwise try to
    # collect it as a test class whenever this module is imported in a run.
    __test__ = False

    def __init__(self):
        self.results = []
        self.start_time = datetime.now()
        self.end_time = None
        self.total_tests = 0
        self.passed_tests = 0
        self.failed_tests = 0

    def add_result(
        self, test_name: str, passed: bool, error: Optional[str] = None
    ):
        """Record the outcome of a single test.

        Args:
            test_name: Human-readable name shown in the report.
            passed: Whether the test succeeded.
            error: Failure message; stored as the string ``"None"`` when
                falsy (missing or empty).
        """
        self.total_tests += 1
        if passed:
            self.passed_tests += 1
            status = "✅ PASSED"
        else:
            self.failed_tests += 1
            status = "❌ FAILED"

        self.results.append(
            {
                "test_name": test_name,
                "status": status,
                "error": error if error else "None",
            }
        )

    @staticmethod
    def _table_cell(value) -> str:
        """Return *value* as text that cannot break a Markdown table row."""
        # A literal "|" would open a new column and a line break would end
        # the row early, so escape pipes and fold line breaks into spaces.
        return " ".join(str(value).splitlines()).replace("|", "\\|")

    def generate_markdown(self) -> str:
        """Render the collected results as a Markdown document.

        Side effect: sets ``end_time`` to now, so the reported duration
        spans from construction up to this call.

        Returns:
            The report as a single newline-joined string.
        """
        self.end_time = datetime.now()
        duration = (self.end_time - self.start_time).total_seconds()

        # Avoid ZeroDivisionError when nothing was recorded (e.g. the suite
        # aborted before running any test).
        if self.total_tests:
            success_rate = (
                f"{self.passed_tests / self.total_tests * 100:.1f}%"
            )
        else:
            success_rate = "N/A"

        md = [
            "# Redis Conversation Test Results",
            "",
            f"Test Run: {self.start_time.strftime('%Y-%m-%d %H:%M:%S')}",
            f"Duration: {duration:.2f} seconds",
            "",
            "## Summary",
            f"- Total Tests: {self.total_tests}",
            f"- Passed: {self.passed_tests}",
            f"- Failed: {self.failed_tests}",
            f"- Success Rate: {success_rate}",
            "",
            "## Detailed Results",
            "",
            "| Test Name | Status | Error |",
            "|-----------|--------|-------|",
        ]

        for result in self.results:
            md.append(
                f"| {self._table_cell(result['test_name'])} "
                f"| {result['status']} "
                f"| {self._table_cell(result['error'])} |"
            )

        return "\n".join(md)
|
|
|
|
|
|
class RedisConversationTester:
    """Run a fixed suite of functional checks against ``RedisConversation``.

    Each ``test_*`` method exercises one feature and signals failure by
    raising (normally ``AssertionError``). :meth:`run_test` records each
    outcome in a ``TestResults`` collector. All tests share the single
    conversation created by :meth:`setup` and run in the order listed in
    :meth:`run_all_tests`, so later tests see messages left by earlier ones.
    """

    def __init__(self):
        self.results = TestResults()
        self.conversation = None  # created by setup()
        # Only consulted by cleanup(); nothing in this class assigns it.
        self.redis_server = None

    def run_test(self, test_func: Callable[[], None], test_name: str):
        """Run a single test and record its result.

        Any exception raised by *test_func* is caught, logged and recorded
        as a failure, so one failing test does not stop the suite.
        """
        try:
            test_func()
            self.results.add_result(test_name, True)
        except Exception as e:
            self.results.add_result(test_name, False, str(e))
            logger.error(f"Test '{test_name}' failed: {str(e)}")

    def setup(self):
        """Create the ``RedisConversation`` under test.

        Returns:
            True on success; False if construction raised (the error is
            logged).
        """
        try:
            # No server is started here; use_embedded_redis=True presumably
            # makes RedisConversation provide one itself — confirm.
            self.conversation = RedisConversation(
                system_prompt="Test System Prompt",
                redis_host="localhost",
                redis_port=6379,
                redis_retry_attempts=3,
                use_embedded_redis=True,
            )
            return True
        except Exception as e:
            logger.error(
                f"Failed to initialize Redis conversation: {str(e)}"
            )
            return False

    def cleanup(self):
        """Cleanup resources after tests."""
        if self.redis_server:
            self.redis_server.stop()
        # NOTE(review): self.conversation is not closed here; check whether
        # RedisConversation offers a close/shutdown method worth calling.

    def test_initialization(self):
        """Test basic initialization."""
        assert (
            self.conversation is not None
        ), "Failed to initialize RedisConversation"
        assert (
            self.conversation.system_prompt == "Test System Prompt"
        ), "System prompt not set correctly"

    def test_add_message(self):
        """Test adding messages."""
        self.conversation.add("user", "Hello")
        self.conversation.add("assistant", "Hi there!")
        messages = self.conversation.return_messages_as_list()
        assert len(messages) >= 2, "Failed to add messages"

    def test_json_message(self):
        """Test adding JSON messages."""
        json_content = {"key": "value", "nested": {"data": 123}}
        self.conversation.add("system", json_content)
        last_message = self.conversation.get_final_message_content()
        # Expects the dict content to come back as a JSON string.
        assert isinstance(
            json.loads(last_message), dict
        ), "Failed to handle JSON message"

    def test_search(self):
        """Test search functionality."""
        self.conversation.add("user", "searchable message")
        results = self.conversation.search("searchable")
        assert len(results) > 0, "Search failed to find message"

    def test_delete(self):
        """Test message deletion.

        Relies on earlier tests having stored at least one message.
        """
        initial_count = len(
            self.conversation.return_messages_as_list()
        )
        self.conversation.delete(0)
        new_count = len(self.conversation.return_messages_as_list())
        assert (
            new_count == initial_count - 1
        ), "Failed to delete message"

    def test_update(self):
        """Test message update."""
        self.conversation.add("user", "original message")

        # NOTE(review): index 0 is the first stored message, which is not
        # necessarily the one added above (earlier tests left messages);
        # the check only verifies that update() and query() agree.
        self.conversation.update(0, "user", "updated message")
        updated_message = self.conversation.query(0)

        assert (
            updated_message["content"] == "updated message"
        ), "Message content should be updated"

    def test_clear(self):
        """Test clearing conversation."""
        self.conversation.add("user", "test message")
        self.conversation.clear()
        messages = self.conversation.return_messages_as_list()
        assert len(messages) == 0, "Failed to clear conversation"

    def test_export_import(self):
        """Test export and import functionality.

        The export file is written to the current working directory and
        removed afterwards, whether or not the round trip succeeds.
        """
        export_path = "test_export.txt"
        try:
            self.conversation.add("user", "export test")
            self.conversation.export_conversation(export_path)
            self.conversation.clear()
            self.conversation.import_conversation(export_path)
            messages = self.conversation.return_messages_as_list()
        finally:
            # Don't leave test artifacts in the working directory.
            if os.path.exists(export_path):
                os.remove(export_path)
        assert (
            len(messages) > 0
        ), "Failed to export/import conversation"

    def test_json_operations(self):
        """Test JSON operations."""
        self.conversation.add("user", "json test")
        json_data = self.conversation.to_json()
        assert isinstance(
            json.loads(json_data), list
        ), "Failed to convert to JSON"

    def test_yaml_operations(self):
        """Test YAML operations."""
        self.conversation.add("user", "yaml test")
        yaml_data = self.conversation.to_yaml()
        assert isinstance(yaml_data, str), "Failed to convert to YAML"

    def test_token_counting(self):
        """Test token counting functionality."""
        self.conversation.add("user", "token test message")
        time.sleep(1)  # Wait for async token counting
        messages = self.conversation.to_dict()
        assert any(
            "token_count" in msg for msg in messages
        ), "Failed to count tokens"

    def test_cache_operations(self):
        """Test cache operations."""
        self.conversation.add("user", "cache test")
        stats = self.conversation.get_cache_stats()
        assert isinstance(stats, dict), "Failed to get cache stats"

    def test_conversation_stats(self):
        """Test conversation statistics."""
        self.conversation.add("user", "stats test")
        counts = self.conversation.count_messages_by_role()
        assert isinstance(
            counts, dict
        ), "Failed to get message counts"

    def run_all_tests(self):
        """Run all tests and generate report.

        Returns:
            A Markdown report. If the redis package is missing or setup
            fails, this is a short failure notice instead of the full
            results table.
        """
        if not REDIS_AVAILABLE:
            logger.error(
                "Redis is not available. Please install redis package."
            )
            return "# Redis Tests Failed\n\nRedis package is not installed."

        try:
            if not self.setup():
                logger.error("Failed to setup Redis connection.")
                return "# Redis Tests Failed\n\nFailed to connect to Redis server."

            # Order matters: tests share one conversation (see class doc).
            tests = [
                (self.test_initialization, "Initialization Test"),
                (self.test_add_message, "Add Message Test"),
                (self.test_json_message, "JSON Message Test"),
                (self.test_search, "Search Test"),
                (self.test_delete, "Delete Test"),
                (self.test_update, "Update Test"),
                (self.test_clear, "Clear Test"),
                (self.test_export_import, "Export/Import Test"),
                (self.test_json_operations, "JSON Operations Test"),
                (self.test_yaml_operations, "YAML Operations Test"),
                (self.test_token_counting, "Token Counting Test"),
                (self.test_cache_operations, "Cache Operations Test"),
                (
                    self.test_conversation_stats,
                    "Conversation Stats Test",
                ),
            ]

            for test_func, test_name in tests:
                self.run_test(test_func, test_name)

            return self.results.generate_markdown()
        finally:
            self.cleanup()
|
|
|
|
|
|
def main():
    """Run the Redis conversation test suite and save a Markdown report.

    The report is written to ``redis_test_results.md`` in the current
    working directory, overwriting any existing file.
    """
    tester = RedisConversationTester()
    markdown_results = tester.run_all_tests()

    # The report contains non-ASCII status emoji ("✅"/"❌"), so write UTF-8
    # explicitly rather than relying on the platform's default encoding,
    # which may be unable to encode them (e.g. cp1252 on Windows).
    with open("redis_test_results.md", "w", encoding="utf-8") as f:
        f.write(markdown_results)

    logger.info(
        "Test results have been saved to redis_test_results.md"
    )
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point; importing this module runs nothing.
    main()
|