You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
446 lines
14 KiB
446 lines
14 KiB
import contextlib
import importlib
import json
import os
import platform
import socket
import subprocess
import sys
import time
from datetime import datetime
from typing import Callable, Dict, Tuple

from loguru import logger

from swarms.communication.pulsar_struct import (
    PulsarConversation,
    Message,
)
|
|
|
|
|
|
def check_pulsar_client_installed() -> bool:
    """Report whether the ``pulsar`` module (pulsar-client) can be imported.

    Returns:
        True if ``import pulsar`` succeeds, False if it raises ImportError.
    """
    try:
        import pulsar  # noqa: F401  (imported only to probe availability)
    except ImportError:
        return False
    return True
|
|
|
|
|
|
def install_pulsar_client() -> bool:
    """Install the ``pulsar-client`` package with the running interpreter's pip.

    Runs ``<sys.executable> -m pip install pulsar-client`` and logs the outcome.

    Returns:
        True when pip exits with status 0; False when pip fails or any
        exception occurs while launching it (the error is logged, not raised).
    """
    try:
        logger.info("Installing pulsar-client package...")
        pip_command = [sys.executable, "-m", "pip", "install", "pulsar-client"]
        completed = subprocess.run(
            pip_command,
            capture_output=True,
            text=True,
        )
        if completed.returncode != 0:
            logger.error(
                f"Failed to install pulsar-client: {completed.stderr}"
            )
            return False
        logger.info("Successfully installed pulsar-client")
        return True
    except Exception as e:
        logger.error(f"Error installing pulsar-client: {str(e)}")
        return False
|
|
|
|
|
|
def check_port_available(
    host: str = "localhost", port: int = 6650
) -> bool:
    """Check if a port is open on the given host.

    Attempts an IPv4 TCP connection with a 2-second timeout.

    Returns:
        True if the connection succeeds (something is listening), False if it
        is refused/times out or any exception occurs (e.g. name resolution).
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
        try:
            probe.settimeout(2)  # seconds
            return probe.connect_ex((host, port)) == 0
        except Exception:
            return False
|
|
|
|
|
|
def setup_test_broker(
    service_url: str = "pulsar://localhost:6650",
    topic: str = "test-topic",
) -> Tuple[bool, str]:
    """
    Verify that a Pulsar broker is reachable by creating and closing a producer.

    This does not start a broker; it only connects to one that is already
    running at ``service_url``.

    Args:
        service_url: Broker service URL to connect to.
        topic: Topic on which the throwaway producer is created.

    Returns:
        ``(success, message)``: ``(True, "Test broker setup successful")``
        on success, otherwise ``(False, "Failed to set up test broker: ...")``.
        Ordinary exceptions (including a missing ``pulsar`` package) are
        reported through the return value rather than raised.
    """
    client = None
    try:
        from pulsar import Client

        client = Client(service_url)
        producer = client.create_producer(topic)
        producer.close()
        client.close()
        client = None  # already released; nothing left for the error path
        return True, "Test broker setup successful"
    except Exception as e:
        if client is not None:
            # Release the connection even though the probe failed, so a
            # failed check does not leak the client.
            try:
                client.close()
            except Exception:
                pass  # best-effort cleanup; the original failure is reported
        return False, f"Failed to set up test broker: {str(e)}"
|
|
|
|
|
|
class PulsarTestSuite:
    """Custom test suite for PulsarConversation class.

    Every callable attribute whose name starts with ``test_`` is discovered by
    :meth:`run_all_tests`, executed through :meth:`run_test`, and its outcome
    is accumulated in ``self.test_results``; :meth:`save_results` writes that
    record to a JSON file.
    """

    def __init__(self, pulsar_host: str = "pulsar://localhost:6650"):
        """
        Args:
            pulsar_host: Pulsar service URL, e.g. ``pulsar://localhost:6650``.
                The port may be omitted (6650, or 6651 for ``pulsar+ssl``);
                a trailing path or extra comma-separated hosts are ignored
                when extracting ``self.host`` / ``self.port``.

        Raises:
            ValueError: if no host can be extracted from ``pulsar_host`` or
                its port is not an integer.
        """
        self.pulsar_host = pulsar_host
        self.host, self.port = self._parse_host_port(pulsar_host)
        self.test_results = {
            "test_suite": "PulsarConversation Tests",
            "timestamp": datetime.now().isoformat(),
            "total_tests": 0,
            "passed_tests": 0,
            "failed_tests": 0,
            # Reported in the summary; nothing in this suite increments it.
            "skipped_tests": 0,
            "results": [],
        }

    @staticmethod
    def _parse_host_port(url: str) -> Tuple[str, int]:
        """Extract ``(host, port)`` from a Pulsar service URL."""
        scheme, sep, rest = url.partition("://")
        if not sep:
            # No scheme given, e.g. "localhost:6650".
            scheme, rest = "", url
        # Keep only the first host of a multi-host URL and drop any path.
        authority = rest.split("/", 1)[0].split(",", 1)[0]
        host, colon, port_text = authority.rpartition(":")
        if not colon:
            host, port_text = authority, ""
        if not host:
            raise ValueError(f"Cannot determine broker host from {url!r}")
        if port_text:
            port = int(port_text)
        elif scheme == "pulsar+ssl":
            port = 6651
        else:
            port = 6650
        return host, port

    def check_pulsar_setup(self) -> bool:
        """
        Check if Pulsar is properly set up and provide guidance if it's not.

        Installs ``pulsar-client`` when it is missing, warns when nothing is
        listening on the configured host/port, then probes the broker with
        ``setup_test_broker()``.

        Returns:
            True when the client library is importable and the broker probe
            succeeded; False otherwise (the reason is logged).
        """
        if not check_pulsar_client_installed():
            logger.error(
                "\nPulsar client library is not installed. Installing now..."
            )
            if not install_pulsar_client():
                logger.error(
                    "\nFailed to install pulsar-client. Please install it manually:\n"
                    "  $ pip install pulsar-client\n"
                )
                return False

            # A package installed while this process is running may not be
            # visible to the import system until its finder caches are
            # invalidated; re-check importability afterwards.
            importlib.invalidate_caches()
            if not check_pulsar_client_installed():
                logger.error(
                    "Failed to import pulsar-client after installation"
                )
                return False

        if not check_port_available(self.host, self.port):
            logger.warning(
                f"Nothing is accepting TCP connections on {self.host}:{self.port}. "
                "Start a Pulsar broker (e.g. in standalone mode) before running the tests."
            )

        # NOTE(review): setup_test_broker() is called with its defaults, so it
        # probes its own default broker URL, which is not necessarily
        # self.pulsar_host.
        success, message = setup_test_broker()
        if not success:
            logger.error(
                f"\nFailed to set up test environment: {message}"
            )
            return False

        logger.info("Pulsar setup check passed successfully")
        return True

    def run_test(self, test_func: Callable) -> Dict:
        """Run a single test and return its result.

        Any exception raised by ``test_func`` marks the test FAILED. The
        result record is appended to ``self.test_results["results"]`` and
        the total/passed/failed counters are updated.

        Returns:
            Dict with ``test_name``, ``success``, ``duration`` (seconds,
            3 decimals), ``error`` (None on success), ``timestamp`` and
            ``status`` ("PASSED"/"FAILED").
        """
        test_name = test_func.__name__
        # perf_counter is monotonic, so durations are immune to clock changes.
        start_time = time.perf_counter()

        try:
            logger.info(f"Running test: {test_name}")
            test_func()
        except Exception as e:
            success = False
            # A bare ``assert`` raises AssertionError with an empty message;
            # fall back to the exception type so the failure is identifiable.
            error = str(e) or type(e).__name__
            status = "FAILED"
            logger.error(f"Test {test_name} failed: {error}")
        else:
            success = True
            error = None
            status = "PASSED"

        duration = round(time.perf_counter() - start_time, 3)

        result = {
            "test_name": test_name,
            "success": success,
            "duration": duration,
            "error": error,
            "timestamp": datetime.now().isoformat(),
            "status": status,
        }

        self.test_results["total_tests"] += 1
        if success:
            self.test_results["passed_tests"] += 1
        else:
            self.test_results["failed_tests"] += 1

        self.test_results["results"].append(result)
        return result

    @contextlib.contextmanager
    def _conversation(self, **kwargs):
        """Yield a PulsarConversation on ``self.pulsar_host`` and always tear it down.

        Teardown runs in ``finally`` so a failing assertion does not leak the
        conversation.
        """
        conversation = PulsarConversation(
            pulsar_host=self.pulsar_host, **kwargs
        )
        try:
            yield conversation
        finally:
            # Teardown is done by calling __del__ explicitly, as elsewhere in
            # this suite. NOTE(review): prefer a public close() if
            # PulsarConversation provides one.
            conversation.__del__()

    def test_initialization(self):
        """Test PulsarConversation initialization."""
        with self._conversation(
            system_prompt="Test system prompt"
        ) as conversation:
            assert conversation.conversation_id is not None
            assert conversation.health_check()["client_connected"] is True

    def test_add_message(self):
        """Test adding a message."""
        with self._conversation() as conversation:
            msg_id = conversation.add("user", "Test message")
            assert msg_id is not None

            # Verify message was added
            messages = conversation.get_messages()
            assert len(messages) > 0
            assert messages[0]["content"] == "Test message"

    def test_batch_add_messages(self):
        """Test adding multiple messages."""
        with self._conversation() as conversation:
            messages = [
                Message(role="user", content="Message 1"),
                Message(role="assistant", content="Message 2"),
            ]
            msg_ids = conversation.batch_add(messages)
            assert len(msg_ids) == 2

            # Verify messages were added
            stored_messages = conversation.get_messages()
            assert len(stored_messages) == 2
            assert stored_messages[0]["content"] == "Message 1"
            assert stored_messages[1]["content"] == "Message 2"

    def test_get_messages(self):
        """Test retrieving messages."""
        with self._conversation() as conversation:
            conversation.add("user", "Test message")
            messages = conversation.get_messages()
            assert len(messages) > 0

    def test_search_messages(self):
        """Test searching messages."""
        with self._conversation() as conversation:
            conversation.add("user", "Unique test message")
            results = conversation.search("unique")
            assert len(results) > 0

    def test_conversation_clear(self):
        """Test clearing conversation."""
        with self._conversation() as conversation:
            conversation.add("user", "Test message")
            conversation.clear()
            messages = conversation.get_messages()
            assert len(messages) == 0

    def test_conversation_export_import(self):
        """Test exporting and importing conversation.

        The export file is removed afterwards, whether or not the test passes.
        """
        export_path = "test_export.json"
        try:
            with self._conversation() as conversation:
                conversation.add("user", "Test message")
                conversation.export_conversation(export_path)

                with self._conversation() as new_conversation:
                    new_conversation.import_conversation(export_path)
                    messages = new_conversation.get_messages()
                    assert len(messages) > 0
        finally:
            with contextlib.suppress(FileNotFoundError):
                os.remove(export_path)

    def test_message_count(self):
        """Test message counting."""
        with self._conversation() as conversation:
            conversation.add("user", "Message 1")
            conversation.add("assistant", "Message 2")
            counts = conversation.count_messages_by_role()
            assert counts["user"] == 1
            assert counts["assistant"] == 1

    def test_conversation_string(self):
        """Test string representation."""
        with self._conversation() as conversation:
            conversation.add("user", "Test message")
            string_rep = conversation.get_str()
            assert "Test message" in string_rep

    def test_conversation_json(self):
        """Test JSON conversion."""
        with self._conversation() as conversation:
            conversation.add("user", "Test message")
            json_data = conversation.to_json()
            assert isinstance(json_data, str)
            assert "Test message" in json_data

    def test_conversation_yaml(self):
        """Test YAML conversion."""
        with self._conversation() as conversation:
            conversation.add("user", "Test message")
            yaml_data = conversation.to_yaml()
            assert isinstance(yaml_data, str)
            assert "Test message" in yaml_data

    def test_last_message(self):
        """Test getting last message."""
        with self._conversation() as conversation:
            conversation.add("user", "Test message")
            last_msg = conversation.get_last_message()
            assert last_msg["content"] == "Test message"

    def test_messages_by_role(self):
        """Test getting messages by role."""
        with self._conversation() as conversation:
            conversation.add("user", "User message")
            conversation.add("assistant", "Assistant message")
            user_messages = conversation.get_messages_by_role("user")
            assert len(user_messages) == 1

    def test_conversation_summary(self):
        """Test getting conversation summary."""
        with self._conversation() as conversation:
            conversation.add("user", "Test message")
            summary = conversation.get_conversation_summary()
            assert summary["message_count"] == 1

    def test_conversation_statistics(self):
        """Test getting conversation statistics."""
        with self._conversation() as conversation:
            conversation.add("user", "Test message")
            stats = conversation.get_statistics()
            assert stats["total_messages"] == 1

    def test_health_check(self):
        """Test health check functionality."""
        with self._conversation() as conversation:
            health = conversation.health_check()
            assert health["client_connected"] is True

    def test_cache_stats(self):
        """Test cache statistics."""
        with self._conversation() as conversation:
            stats = conversation.get_cache_stats()
            assert "hits" in stats
            assert "misses" in stats

    def run_all_tests(self):
        """Run every ``test_*`` method (alphabetical order), then save results.

        Nothing is run or saved when :meth:`check_pulsar_setup` fails.
        """
        if not self.check_pulsar_setup():
            logger.error(
                "Pulsar setup check failed. Please check the error messages above."
            )
            return

        # The callable() check excludes data attributes such as test_results.
        test_methods = [
            name
            for name in dir(self)
            if name.startswith("test_") and callable(getattr(self, name))
        ]

        logger.info(f"Running {len(test_methods)} tests...")

        for method_name in test_methods:
            self.run_test(getattr(self, method_name))

        self.save_results()

    def save_results(self, output_path: str = "pulsar_test_results.json"):
        """Save test results to JSON file.

        Adds ``success_rate`` (percentage, 2 decimals; 0 when no tests ran)
        and an ``environment`` section to ``self.test_results``, writes it to
        ``output_path`` and logs a summary.

        Args:
            output_path: Destination JSON file (UTF-8).
        """
        passed = self.test_results["passed_tests"]
        total_tests = passed + self.test_results["failed_tests"]
        self.test_results["success_rate"] = (
            round(passed / total_tests * 100, 2) if total_tests > 0 else 0
        )

        # Add test environment info
        self.test_results["environment"] = {
            "pulsar_host": self.pulsar_host,
            "pulsar_port": self.port,
            "pulsar_client_installed": check_pulsar_client_installed(),
            # platform.system() works on every OS; os.uname() is POSIX-only.
            "os": platform.system(),
            # Report the interpreter actually running the suite, in the same
            # "Python X.Y.Z" form `python --version` prints, instead of
            # shelling out to whatever `python` is first on PATH.
            "python_version": f"Python {platform.python_version()}",
        }

        with open(output_path, "w", encoding="utf-8") as f:
            json.dump(self.test_results, f, indent=2)

        logger.info(
            f"\nTest Results Summary:\n"
            f"Total tests: {self.test_results['total_tests']}\n"
            f"Passed: {self.test_results['passed_tests']}\n"
            f"Failed: {self.test_results['failed_tests']}\n"
            f"Skipped: {self.test_results['skipped_tests']}\n"
            f"Success rate: {self.test_results['success_rate']}%\n"
            f"Results saved to: {output_path}"
        )
|
|
|
|
|
|
if __name__ == "__main__":
    # Exit status: 0 only when the setup check passed, at least one test ran
    # and none failed; 1 otherwise (including interruption or a crash).
    try:
        test_suite = PulsarTestSuite()
        test_suite.run_all_tests()
    except KeyboardInterrupt:
        logger.warning("Tests interrupted by user")
        sys.exit(1)
    except Exception as e:
        # logger.exception also records the traceback of the crash.
        logger.exception(f"Test suite failed: {str(e)}")
        sys.exit(1)

    summary = test_suite.test_results
    # total_tests stays 0 when run_all_tests bailed out on a failed setup check.
    if summary["failed_tests"] or not summary["total_tests"]:
        sys.exit(1)
|