You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
561 lines
15 KiB
561 lines
15 KiB
import os
|
|
from loguru import logger
|
|
from swarms.structs.conversation import Conversation
|
|
|
|
|
|
def assert_equal(actual, expected, message=""):
    """Assert that ``actual`` equals ``expected``, logging the outcome.

    Args:
        actual: Value produced by the code under test.
        expected: Value the caller expects.
        message: Human-readable description included in the log line and
            in the raised error.

    Raises:
        AssertionError: If ``actual != expected``.
    """
    differs = actual != expected
    if not differs:
        logger.success(f"Assertion passed: {message}")
        return

    # The same detail text is used for the log entry and the exception.
    detail = f"{message}\nExpected: {expected}\nActual: {actual}"
    logger.error(f"Assertion failed: {detail}")
    raise AssertionError(detail)
|
|
|
|
|
|
def assert_true(condition, message=""):
    """Assert that ``condition`` is truthy, logging the outcome.

    Args:
        condition: Value evaluated for truthiness.
        message: Description used in the log line and as the error text.

    Raises:
        AssertionError: If ``condition`` is falsy.
    """
    if condition:
        logger.success(f"Assertion passed: {message}")
        return

    logger.error(f"Assertion failed: {message}")
    raise AssertionError(message)
|
|
|
|
|
|
def test_conversation_initialization():
    """Build a Conversation with defaults and with explicit options."""
    logger.info("Testing conversation initialization")

    # Bare constructor: check the instance type and the default provider.
    default_conv = Conversation()
    assert_true(
        isinstance(default_conv, Conversation),
        message="Should create Conversation instance",
    )
    assert_equal(
        default_conv.provider,
        "in-memory",
        message="Default provider should be in-memory",
    )

    # Explicit options should be readable back as attributes.
    options = {
        "name": "test-conv",
        "system_prompt": "Test system prompt",
        "time_enabled": True,
        "token_count": True,
    }
    custom_conv = Conversation(**options)
    assert_equal(
        custom_conv.name,
        "test-conv",
        message="Name should be set correctly",
    )
    assert_equal(
        custom_conv.system_prompt,
        "Test system prompt",
        message="System prompt should be set",
    )
    assert_true(custom_conv.time_enabled, message="Time should be enabled")
    assert_true(
        custom_conv.token_count,
        message="Token count should be enabled",
    )
|
|
|
|
|
|
def test_add_message():
    """Add a text message and a dict message and inspect the stored entries."""
    logger.info("Testing add message functionality")

    conv = Conversation(time_enabled=True, token_count=True)

    # Plain string content.
    conv.add("user", "Hello, world!")
    assert_equal(
        len(conv.conversation_history),
        1,
        message="Should have one message",
    )
    text_role = conv.conversation_history[0]["role"]
    assert_equal(text_role, "user", message="Role should be user")
    text_content = conv.conversation_history[0]["content"]
    assert_equal(
        text_content,
        "Hello, world!",
        message="Content should match",
    )

    # Dictionary content.
    payload = {"key": "value"}
    conv.add("assistant", payload)
    assert_equal(
        len(conv.conversation_history),
        2,
        message="Should have two messages",
    )
    dict_role = conv.conversation_history[1]["role"]
    assert_equal(
        dict_role,
        "assistant",
        message="Role should be assistant",
    )
    dict_content = conv.conversation_history[1]["content"]
    assert_equal(
        dict_content,
        payload,
        message="Content should match dict",
    )
|
|
|
|
|
|
def test_delete_message():
    """Delete one of two messages and check what remains in the history."""
    logger.info("Testing delete message functionality")

    conv = Conversation()
    for text in ("Message 1", "Message 2"):
        conv.add("user", text)

    count_before = len(conv.conversation_history)
    # The assertions below expect "0" to address the first message.
    conv.delete("0")

    count_after = len(conv.conversation_history)
    assert_equal(
        count_after,
        count_before - 1,
        message="Conversation history should be shorter by one",
    )
    survivor = conv.conversation_history[0]["content"]
    assert_equal(
        survivor,
        "Message 2",
        message="Remaining message should be Message 2",
    )
|
|
|
|
|
|
def test_update_message():
    """Update an existing message and confirm the new content is stored."""
    logger.info("Testing update message functionality")

    conv = Conversation()
    conv.add("user", "Original message")
    conv.update("0", "user", "Updated message")

    stored_content = conv.conversation_history[0]["content"]
    assert_equal(
        stored_content,
        "Updated message",
        message="Message should be updated",
    )
|
|
|
|
|
|
def test_search_messages():
    """Search the history for keywords and check the number of hits."""
    logger.info("Testing search functionality")

    conv = Conversation()
    seed_messages = (
        ("user", "Hello world"),
        ("assistant", "Hello user"),
        ("user", "Goodbye world"),
    )
    for role, text in seed_messages:
        conv.add(role, text)

    hello_hits = conv.search("Hello")
    assert_equal(
        len(hello_hits),
        2,
        message="Should find two messages with 'Hello'",
    )

    goodbye_hits = conv.search("Goodbye")
    assert_equal(
        len(goodbye_hits),
        1,
        message="Should find one message with 'Goodbye'",
    )
|
|
|
|
|
|
def test_export_import():
    """Export a conversation to JSON and import it into a fresh instance.

    The export file is removed in a ``finally`` block so that a failing
    assertion or an exception from export/import does not leave the file
    behind in the working directory.
    """
    logger.info("Testing export/import functionality")

    conv = Conversation(name="export-test")
    conv.add("user", "Test message")

    # Test JSON export/import
    test_file = "test_conversation_export.json"
    try:
        conv.export_conversation(test_file)
        assert_true(os.path.exists(test_file), "Export file should exist")

        new_conv = Conversation(name="import-test")
        new_conv.import_conversation(test_file)

        assert_equal(
            len(new_conv.conversation_history),
            len(conv.conversation_history),
            "Imported conversation should have same number of messages",
        )
    finally:
        # Cleanup: the file may not exist if the export itself failed.
        if os.path.exists(test_file):
            os.remove(test_file)
|
|
|
|
|
|
def test_message_counting():
    """Add one message per role and check the per-role counts."""
    logger.info("Testing message counting functionality")

    conv = Conversation()
    conv.add("user", "User message")
    conv.add("assistant", "Assistant message")
    conv.add("system", "System message")

    counts = conv.count_messages_by_role()
    # Each role was added exactly once above.
    for role in ("user", "assistant", "system"):
        assert_equal(counts[role], 1, f"Should have one {role} message")
|
|
|
|
|
|
def test_conversation_string_representation():
    """Check the string and JSON renderings of a one-message conversation."""
    logger.info("Testing string representation methods")

    conv = Conversation()
    conv.add("user", "Test message")

    history_text = conv.return_history_as_string()
    assert_true(
        "user: Test message" in history_text,
        message="String representation should contain message",
    )

    serialized = conv.to_json()
    assert_true(
        isinstance(serialized, str),
        message="JSON representation should be string",
    )
    assert_true(
        "Test message" in serialized,
        message="JSON should contain message content",
    )
|
|
|
|
|
|
def test_memory_management():
    """Exercise ``clear`` and ``truncate_memory_with_tokenizer``."""
    logger.info("Testing memory management functions")

    # clear() should leave an empty history.
    conv = Conversation()
    conv.add("user", "Message 1")
    conv.add("assistant", "Message 2")
    conv.clear()
    remaining = len(conv.conversation_history)
    assert_equal(
        remaining,
        0,
        message="History should be empty after clear",
    )

    # After truncation the stored content should be shorter than the input.
    conv = Conversation(context_length=100, token_count=True)
    sentence = "This is a very long message that should be truncated "
    long_message = sentence * 10
    conv.add("user", long_message)
    conv.truncate_memory_with_tokenizer()
    stored_length = len(conv.conversation_history[0]["content"])
    assert_true(
        stored_length < len(long_message),
        message="Message should be truncated",
    )
|
|
|
|
|
|
def test_backend_initialization():
    """Construct a Conversation for each backend and check ``backend``."""
    logger.info("Testing backend initialization")

    # Redis takes its own connection settings.
    redis_settings = {
        "backend": "redis",
        "redis_host": "localhost",
        "redis_port": 6379,
        "redis_db": 0,
        "use_embedded_redis": True,
    }
    conv = Conversation(**redis_settings)
    assert_equal(conv.backend, "redis", "Backend should be redis")

    # SQLite and DuckDB are built with the same path and table arguments.
    for backend_name in ("sqlite", "duckdb"):
        conv = Conversation(
            backend=backend_name,
            db_path=":memory:",
            table_name="test_conversations",
        )
        assert_equal(
            conv.backend,
            backend_name,
            f"Backend should be {backend_name}",
        )
|
|
|
|
|
|
def test_conversation_with_system_prompt():
    """Check the initial history created from prompt and rule arguments."""
    logger.info("Testing conversation with system prompt and rules")

    setup = {
        "system_prompt": "You are a helpful assistant",
        "rules": "Be concise and clear",
        "custom_rules_prompt": "Follow these guidelines",
        "time_enabled": True,
    }
    conv = Conversation(**setup)

    history = conv.conversation_history
    # One entry is expected per prompt-like constructor argument.
    assert_equal(
        len(history),
        3,
        message="Should have system prompt, rules, and custom rules",
    )
    assert_equal(
        history[0]["content"],
        "You are a helpful assistant",
        message="System prompt should match",
    )
    assert_equal(
        history[1]["content"],
        "Be concise and clear",
        message="Rules should match",
    )
    has_timestamp = "timestamp" in history[0]
    assert_true(has_timestamp, message="Messages should have timestamps")
|
|
|
|
|
|
def test_batch_operations():
    """Add several messages in one call, then search across them."""
    logger.info("Testing batch operations")

    conv = Conversation()

    # Roles and contents are passed as two parallel lists.
    conv.add_multiple_messages(
        ["user", "assistant", "user"],
        ["Hello", "Hi there", "How are you?"],
    )
    stored = len(conv.conversation_history)
    assert_equal(stored, 3, message="Should have three messages")

    # Search over the batch-added messages.
    hits = conv.search("Hi")
    assert_equal(
        len(hits),
        1,
        message="Should find one message with 'Hi'",
    )
|
|
|
|
|
|
def test_conversation_export_formats():
    """Export a conversation as YAML and as JSON and check both files exist.

    Both output files are removed in a ``finally`` block so that a failing
    assertion or an exception from ``export`` does not leave them behind
    in the working directory.
    """
    logger.info("Testing export formats")

    conv = Conversation(name="export-test")
    conv.add("user", "Test message")

    yaml_path = "test_conversation.yaml"
    json_path = "test_conversation.json"
    try:
        # Test YAML export
        conv.export_method = "yaml"
        conv.save_filepath = yaml_path
        conv.export()
        assert_true(os.path.exists(yaml_path), "YAML file should exist")

        # Test JSON export
        conv.export_method = "json"
        conv.save_filepath = json_path
        conv.export()
        assert_true(os.path.exists(json_path), "JSON file should exist")
    finally:
        # Cleanup: either file may be missing if an earlier step failed.
        for path in (yaml_path, json_path):
            if os.path.exists(path):
                os.remove(path)
|
|
|
|
|
|
def test_conversation_with_token_counting():
    """Check that stored messages carry a ``token_count`` key."""
    logger.info("Testing token counting functionality")

    counting_options = {
        "token_count": True,
        "tokenizer_model_name": "gpt-4.1",
        "context_length": 1000,
    }
    conv = Conversation(**counting_options)

    # Plain text message.
    conv.add("user", "This is a test message")
    text_has_count = "token_count" in conv.conversation_history[0]
    assert_true(text_has_count, message="Message should have token count")

    # Structured (dict) message.
    conv.add("assistant", {"response": "This is a structured response"})
    dict_has_count = "token_count" in conv.conversation_history[1]
    assert_true(
        dict_has_count,
        message="Structured message should have token count",
    )
|
|
|
|
|
|
def test_conversation_message_categories():
    """Add categorised messages and check the keys of the token summary."""
    logger.info("Testing message categories")

    conv = Conversation()

    # One message per category.
    conv.add("user", "Input message", category="input")
    conv.add("assistant", "Output message", category="output")

    # The summary is expected to expose one key per kind below.
    token_counts = conv.export_and_count_categories()
    for kind in ("input", "output", "total"):
        assert_true(
            f"{kind}_tokens" in token_counts,
            f"Should have {kind} token count",
        )
|
|
|
|
|
|
def test_conversation_persistence():
    """Export a conversation, reload it by name and compare the state."""
    logger.info("Testing conversation persistence")

    # Create and export the original conversation.
    original = Conversation(
        name="persistence-test",
        system_prompt="Test prompt",
        time_enabled=True,
        autosave=True,
    )
    original.add("user", "Test message")
    # NOTE(review): whatever export() writes is never cleaned up here;
    # confirm whether a file is left behind after the test.
    original.export()

    # Reload by name and check what was restored.
    restored = Conversation.load_conversation(name="persistence-test")
    assert_equal(
        restored.system_prompt,
        "Test prompt",
        message="System prompt should persist",
    )
    restored_count = len(restored.conversation_history)
    assert_equal(
        restored_count,
        2,
        message="Should have system prompt and message",
    )
|
|
|
|
|
|
def test_conversation_utilities():
    """Exercise the last-message, list, dictionary and message-id helpers."""
    logger.info("Testing utility methods")

    conv = Conversation(message_id_on=True)
    conv.add("user", "First message")
    conv.add("assistant", "Second message")

    # Last message rendered as a string.
    latest = conv.get_last_message_as_string()
    assert_true(
        "Second message" in latest,
        message="Should get correct last message",
    )

    # Messages returned as a list.
    as_list = conv.return_messages_as_list()
    assert_equal(
        len(as_list),
        2,
        message="Should have two messages in list",
    )

    # Messages returned as a dictionary.
    as_dict = conv.return_messages_as_dictionary()
    assert_equal(
        len(as_dict),
        2,
        message="Should have two messages in dictionary",
    )

    # message_id_on=True was passed above, so entries should carry an id.
    has_id = "message_id" in conv.conversation_history[0]
    assert_true(has_id, message="Messages should have IDs when enabled")
|
|
|
|
|
|
def test_conversation_error_handling():
    """Check that invalid export methods and backends raise ValueError."""
    logger.info("Testing error handling")

    conv = Conversation()

    # Invalid export method. The assignment stays inside the try so that a
    # ValueError from either statement counts as the expected failure.
    try:
        conv.export_method = "invalid"
        conv.export()
    except ValueError:
        assert_true(
            True,
            message="Should catch ValueError for invalid export method",
        )
    else:
        assert_true(
            False,
            message="Should raise ValueError for invalid export method",
        )

    # Invalid backend name passed to the constructor.
    try:
        Conversation(backend="invalid_backend")
    except ValueError:
        assert_true(
            True,
            message="Should catch ValueError for invalid backend",
        )
    else:
        assert_true(
            False,
            message="Should raise ValueError for invalid backend",
        )
|
|
|
|
|
|
def run_all_tests():
    """Run every test function in order and log a pass/fail summary.

    Returns:
        A ``(passed, failed)`` tuple of counts.
    """
    logger.info("Starting all tests")

    suite = (
        test_conversation_initialization,
        test_add_message,
        test_delete_message,
        test_update_message,
        test_search_messages,
        test_export_import,
        test_message_counting,
        test_conversation_string_representation,
        test_memory_management,
        test_backend_initialization,
        test_conversation_with_system_prompt,
        test_batch_operations,
        test_conversation_export_formats,
        test_conversation_with_token_counting,
        test_conversation_message_categories,
        test_conversation_persistence,
        test_conversation_utilities,
        test_conversation_error_handling,
    )

    passed = failed = 0
    for case in suite:
        label = case.__name__
        # Any exception marks the case as failed; the run continues so
        # that every test gets a chance to execute.
        try:
            logger.info(f"Running {label}")
            case()
            passed += 1
            logger.success(f"{label} passed")
        except Exception as e:
            failed += 1
            logger.error(f"{label} failed: {e!s}")

    logger.info(f"Test summary: {passed} passed, {failed} failed")
    return passed, failed
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the whole suite and report failures through the exit status.
    passed, failed = run_all_tests()
    if failed > 0:
        # SystemExit is raised directly because the bare ``exit`` helper is
        # injected by the ``site`` module for interactive use and is not
        # guaranteed to exist (for example under ``python -S``).
        raise SystemExit(1)
|