You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
swarms/tests/communication/test_pulsar.py

446 lines
14 KiB

import importlib
import json
import os
import platform
import socket
import subprocess
import sys
import time
from contextlib import contextmanager
from datetime import datetime
from typing import Callable, Dict, Tuple

from loguru import logger

from swarms.communication.pulsar_struct import (
    PulsarConversation,
    Message,
)
def check_pulsar_client_installed() -> bool:
    """
    Report whether the ``pulsar`` module (from ``pulsar-client``) is importable.

    Returns:
        False if ``import pulsar`` raises ImportError, True otherwise.
    """
    try:
        import pulsar  # noqa: F401  (imported only to probe availability)
    except ImportError:
        return False
    return True
def install_pulsar_client() -> bool:
    """
    Install the ``pulsar-client`` package into the running interpreter.

    pip is launched as ``<sys.executable> -m pip`` so the package is
    installed for the same interpreter that runs this script.

    Returns:
        True if pip exited with status 0; False if it exited non-zero or
        if launching it raised.
    """
    pip_command = [sys.executable, "-m", "pip", "install", "pulsar-client"]
    try:
        logger.info("Installing pulsar-client package...")
        completed = subprocess.run(
            pip_command, capture_output=True, text=True
        )
        if completed.returncode != 0:
            logger.error(
                f"Failed to install pulsar-client: {completed.stderr}"
            )
            return False
        logger.info("Successfully installed pulsar-client")
        return True
    except Exception as exc:
        logger.error(f"Error installing pulsar-client: {str(exc)}")
        return False
def check_port_available(
    host: str = "localhost", port: int = 6650
) -> bool:
    """
    Return True if a TCP connection to ``host:port`` succeeds.

    Despite its name, this reports whether something is *listening* on the
    port (reachable), not whether the port is free. The connect attempt
    uses a 2-second timeout; any exception raised while connecting
    (e.g. a host name that cannot be resolved) yields False.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
        try:
            probe.settimeout(2)  # seconds
            return probe.connect_ex((host, port)) == 0
        except Exception:
            return False
def setup_test_broker(
    pulsar_host: str = "pulsar://localhost:6650",
    topic: str = "test-topic",
) -> Tuple[bool, str]:
    """
    Verify that an already-running Pulsar broker accepts producers.

    This does NOT start a broker. It connects a client to ``pulsar_host``,
    creates a producer on ``topic``, and then closes both.

    Args:
        pulsar_host: Pulsar service URL to connect to.
        topic: Topic used for the probe producer.

    Returns:
        ``(success, message)``. On any failure, including a missing
        ``pulsar`` package, ``success`` is False and ``message`` starts
        with "Failed to set up test broker: ".
    """
    client = None
    try:
        from pulsar import Client

        client = Client(pulsar_host)
        producer = client.create_producer(topic)
        producer.close()
        client.close()
        client = None  # closed cleanly; nothing left for `finally`
        return True, "Test broker setup successful"
    except Exception as e:
        return False, f"Failed to set up test broker: {str(e)}"
    finally:
        # Release the client when create_producer()/close() raised, so a
        # failed probe does not leak the connection.
        if client is not None:
            try:
                client.close()
            except Exception:
                pass  # best-effort cleanup; the failure is already reported
class PulsarTestSuite:
    """
    Custom test suite for PulsarConversation class.

    Every callable attribute whose name starts with ``test_`` is discovered
    and run by ``run_all_tests``. Outcomes accumulate in
    ``self.test_results``, and ``save_results`` writes them to JSON.
    """

    def __init__(self, pulsar_host: str = "pulsar://localhost:6650"):
        """
        Args:
            pulsar_host: Pulsar service URL passed to every
                PulsarConversation the tests create. Its first host/port
                pair is also used for the broker reachability check.
        """
        self.pulsar_host = pulsar_host
        self.host, self.port = self._parse_host_port(pulsar_host)
        self.test_results = {
            "test_suite": "PulsarConversation Tests",
            "timestamp": datetime.now().isoformat(),
            "total_tests": 0,
            "passed_tests": 0,
            "failed_tests": 0,
            # Never incremented by this class; kept for the report schema.
            "skipped_tests": 0,
            "results": [],
        }

    @staticmethod
    def _parse_host_port(
        pulsar_host: str, default_port: int = 6650
    ) -> Tuple[str, int]:
        """
        Extract ``(host, port)`` of the first broker in a Pulsar service URL.

        Accepts URLs with or without a scheme and with or without an
        explicit port (falls back to ``default_port``). For comma-separated
        multi-host URLs such as ``pulsar://h1:6650,h2:6650``, only the
        first host is used.
        """
        netloc = pulsar_host.split("://", 1)[-1]
        first = netloc.split(",", 1)[0].split("/", 1)[0]
        host, sep, port = first.rpartition(":")
        if not sep:
            return first, default_port
        return host, int(port)

    @contextmanager
    def _conversation(self, **kwargs):
        """
        Yield a PulsarConversation bound to ``self.pulsar_host`` and always
        tear it down afterwards.

        Teardown runs in ``finally``, so a failing assertion inside the
        ``with`` body does not leak the conversation.

        Args:
            **kwargs: Extra keyword arguments for PulsarConversation
                (e.g. ``system_prompt``).
        """
        conversation = PulsarConversation(
            pulsar_host=self.pulsar_host, **kwargs
        )
        try:
            yield conversation
        finally:
            # Explicit teardown, as the suite has always done.
            # NOTE(review): prefer a public close()/shutdown method if
            # PulsarConversation provides one.
            conversation.__del__()

    def check_pulsar_setup(self) -> bool:
        """
        Check whether Pulsar is usable, logging guidance when it is not.

        Steps:
          1. Ensure ``pulsar`` is importable, installing ``pulsar-client``
             via pip if it is missing.
          2. Ensure something accepts TCP connections at
             ``self.host:self.port``, so the suite fails fast with a clear
             message instead of waiting on client timeouts.
          3. Ensure a producer can be created via ``setup_test_broker()``.

        Returns:
            True when every step succeeds, False otherwise.
        """
        if not check_pulsar_client_installed():
            logger.warning(
                "\nPulsar client library is not installed. Installing now..."
            )
            if not install_pulsar_client():
                logger.error(
                    "\nFailed to install pulsar-client. Please install it manually:\n"
                    " $ pip install pulsar-client\n"
                )
                return False
            # The package was installed while this interpreter is running;
            # refresh the import system's finder caches so it can be found.
            importlib.invalidate_caches()
            if not check_pulsar_client_installed():
                logger.error(
                    "pulsar-client was installed but still cannot be imported; "
                    "please re-run the test suite in a fresh interpreter."
                )
                return False

        if not check_port_available(self.host, self.port):
            logger.error(
                f"\nNo Pulsar broker is reachable at {self.host}:{self.port}. "
                "Start one first, e.g.:\n"
                " $ docker run -it -p 6650:6650 -p 8080:8080 "
                "apachepulsar/pulsar:latest bin/pulsar standalone\n"
            )
            return False

        # NOTE: setup_test_broker() is called with its defaults, which target
        # pulsar://localhost:6650 regardless of self.pulsar_host.
        success, message = setup_test_broker()
        if not success:
            logger.error(
                f"\nFailed to set up test environment: {message}"
            )
            return False

        logger.info("Pulsar setup check passed successfully")
        return True

    def run_test(self, test_func: Callable) -> Dict:
        """
        Run one test callable, record its outcome, and return the record.

        Any exception raised by ``test_func`` marks the test FAILED. The
        exception is not re-raised, so the remaining tests still run.

        Returns:
            Dict with ``test_name``, ``success``, ``duration`` (seconds,
            rounded to 3 decimals), ``error`` (None on success),
            ``timestamp`` and ``status`` ("PASSED"/"FAILED"). The same dict
            is appended to ``self.test_results["results"]``.
        """
        test_name = test_func.__name__
        start = time.perf_counter()  # monotonic, unlike time.time()
        error = None
        try:
            logger.info(f"Running test: {test_name}")
            test_func()
        except Exception as e:
            # Bare `assert` failures have an empty str(); fall back to repr
            # so the report still names the failure (e.g. "AssertionError()").
            error = str(e) or repr(e)
            logger.exception(f"Test {test_name} failed: {error}")
        duration = round(time.perf_counter() - start, 3)
        success = error is None

        result = {
            "test_name": test_name,
            "success": success,
            "duration": duration,
            "error": error,
            "timestamp": datetime.now().isoformat(),
            "status": "PASSED" if success else "FAILED",
        }

        self.test_results["total_tests"] += 1
        if success:
            self.test_results["passed_tests"] += 1
        else:
            self.test_results["failed_tests"] += 1
        self.test_results["results"].append(result)
        return result

    def test_initialization(self):
        """Test PulsarConversation initialization."""
        with self._conversation(
            system_prompt="Test system prompt"
        ) as conversation:
            assert conversation.conversation_id is not None
            assert conversation.health_check()["client_connected"] is True

    def test_add_message(self):
        """Test adding a message."""
        with self._conversation() as conversation:
            msg_id = conversation.add("user", "Test message")
            assert msg_id is not None
            # Verify message was added
            messages = conversation.get_messages()
            assert len(messages) > 0
            assert messages[0]["content"] == "Test message"

    def test_batch_add_messages(self):
        """Test adding multiple messages."""
        with self._conversation() as conversation:
            messages = [
                Message(role="user", content="Message 1"),
                Message(role="assistant", content="Message 2"),
            ]
            msg_ids = conversation.batch_add(messages)
            assert len(msg_ids) == 2
            # Verify messages were added in order
            stored_messages = conversation.get_messages()
            assert len(stored_messages) == 2
            assert stored_messages[0]["content"] == "Message 1"
            assert stored_messages[1]["content"] == "Message 2"

    def test_get_messages(self):
        """Test retrieving messages."""
        with self._conversation() as conversation:
            conversation.add("user", "Test message")
            messages = conversation.get_messages()
            assert len(messages) > 0

    def test_search_messages(self):
        """Test searching messages."""
        with self._conversation() as conversation:
            conversation.add("user", "Unique test message")
            results = conversation.search("unique")
            assert len(results) > 0

    def test_conversation_clear(self):
        """Test clearing conversation."""
        with self._conversation() as conversation:
            conversation.add("user", "Test message")
            conversation.clear()
            messages = conversation.get_messages()
            assert len(messages) == 0

    def test_conversation_export_import(self):
        """Test exporting and importing conversation."""
        export_path = "test_export.json"
        try:
            with self._conversation() as conversation:
                conversation.add("user", "Test message")
                conversation.export_conversation(export_path)
                with self._conversation() as new_conversation:
                    new_conversation.import_conversation(export_path)
                    messages = new_conversation.get_messages()
                    assert len(messages) > 0
        finally:
            # Don't leave the export artefact in the working directory.
            if os.path.exists(export_path):
                os.remove(export_path)

    def test_message_count(self):
        """Test message counting."""
        with self._conversation() as conversation:
            conversation.add("user", "Message 1")
            conversation.add("assistant", "Message 2")
            counts = conversation.count_messages_by_role()
            assert counts["user"] == 1
            assert counts["assistant"] == 1

    def test_conversation_string(self):
        """Test string representation."""
        with self._conversation() as conversation:
            conversation.add("user", "Test message")
            string_rep = conversation.get_str()
            assert "Test message" in string_rep

    def test_conversation_json(self):
        """Test JSON conversion."""
        with self._conversation() as conversation:
            conversation.add("user", "Test message")
            json_data = conversation.to_json()
            assert isinstance(json_data, str)
            assert "Test message" in json_data

    def test_conversation_yaml(self):
        """Test YAML conversion."""
        with self._conversation() as conversation:
            conversation.add("user", "Test message")
            yaml_data = conversation.to_yaml()
            assert isinstance(yaml_data, str)
            assert "Test message" in yaml_data

    def test_last_message(self):
        """Test getting last message."""
        with self._conversation() as conversation:
            conversation.add("user", "Test message")
            last_msg = conversation.get_last_message()
            assert last_msg["content"] == "Test message"

    def test_messages_by_role(self):
        """Test getting messages by role."""
        with self._conversation() as conversation:
            conversation.add("user", "User message")
            conversation.add("assistant", "Assistant message")
            user_messages = conversation.get_messages_by_role("user")
            assert len(user_messages) == 1

    def test_conversation_summary(self):
        """Test getting conversation summary."""
        with self._conversation() as conversation:
            conversation.add("user", "Test message")
            summary = conversation.get_conversation_summary()
            assert summary["message_count"] == 1

    def test_conversation_statistics(self):
        """Test getting conversation statistics."""
        with self._conversation() as conversation:
            conversation.add("user", "Test message")
            stats = conversation.get_statistics()
            assert stats["total_messages"] == 1

    def test_health_check(self):
        """Test health check functionality."""
        with self._conversation() as conversation:
            health = conversation.health_check()
            assert health["client_connected"] is True

    def test_cache_stats(self):
        """Test cache statistics."""
        with self._conversation() as conversation:
            stats = conversation.get_cache_stats()
            assert "hits" in stats
            assert "misses" in stats

    def run_all_tests(self) -> bool:
        """
        Run the setup check, then every ``test_*`` method, then save results.

        Returns:
            False if the setup check failed (no tests run, nothing saved);
            otherwise True when no test failed.
        """
        if not self.check_pulsar_setup():
            logger.error(
                "Pulsar setup check failed. Please check the error messages above."
            )
            return False

        # The callable() filter matters: `test_results` also starts with
        # "test_" but is a dict, not a test.
        test_methods = [
            name
            for name in dir(self)
            if name.startswith("test_") and callable(getattr(self, name))
        ]
        logger.info(f"Running {len(test_methods)} tests...")
        for name in test_methods:
            self.run_test(getattr(self, name))

        self.save_results()
        return self.test_results["failed_tests"] == 0

    def save_results(
        self, output_path: str = "pulsar_test_results.json"
    ) -> None:
        """
        Save test results (plus environment info) to a JSON file and log a
        summary.

        Args:
            output_path: Destination file; defaults to
                ``pulsar_test_results.json`` in the working directory.
        """
        passed = self.test_results["passed_tests"]
        executed = passed + self.test_results["failed_tests"]
        self.test_results["success_rate"] = (
            round(passed / executed * 100, 2) if executed > 0 else 0
        )

        # platform.* works on every OS (os.uname() does not exist on
        # Windows) and describes the interpreter actually running the
        # suite, not whatever "python" resolves to on PATH.
        self.test_results["environment"] = {
            "pulsar_host": self.pulsar_host,
            "pulsar_port": self.port,
            "pulsar_client_installed": check_pulsar_client_installed(),
            "os": platform.system(),
            "python_version": f"Python {platform.python_version()}",
        }

        with open(output_path, "w", encoding="utf-8") as f:
            json.dump(self.test_results, f, indent=2)

        logger.info(
            f"\nTest Results Summary:\n"
            f"Total tests: {self.test_results['total_tests']}\n"
            f"Passed: {self.test_results['passed_tests']}\n"
            f"Failed: {self.test_results['failed_tests']}\n"
            f"Skipped: {self.test_results['skipped_tests']}\n"
            f"Success rate: {self.test_results['success_rate']}%\n"
            f"Results saved to: {output_path}"
        )
if __name__ == "__main__":
    # Script entry point: exit status is 1 on interruption, on an unexpected
    # error, when any test failed, or when no test ran at all (which is what
    # happens when the Pulsar setup check fails).
    try:
        test_suite = PulsarTestSuite()
        test_suite.run_all_tests()
    except KeyboardInterrupt:
        logger.warning("Tests interrupted by user")
        sys.exit(1)
    except Exception as e:
        logger.exception(f"Test suite failed: {str(e)}")
        sys.exit(1)

    summary = test_suite.test_results
    if summary["failed_tests"] > 0 or summary["total_tests"] == 0:
        sys.exit(1)