You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
swarms/tests/structs/test_conversation.py

568 lines
18 KiB

import os
import json
import time
import datetime
import yaml
from swarms.structs.conversation import (
Conversation,
generate_conversation_id,
)
def run_all_tests():
    """Run every Conversation test and print a pass/fail line for each.

    Each nested ``test_*`` function is executed through ``run_test``, which
    records the outcome instead of propagating the exception, so one failing
    test does not stop the remaining ones from running.

    Returns:
        list[str]: One ``"<test name> passed"`` or
        ``"<test name> failed: <ExceptionType>: <message>"`` entry per test,
        in execution order (the same lines that are printed).
    """
    # Function-scope import: several tests remove their scratch directories.
    import shutil

    test_results = []

    def run_test(test_func):
        """Call ``test_func`` and append its outcome to ``test_results``."""
        try:
            test_func()
        except Exception as e:
            # Include the exception type: a plain ``assert`` without a message
            # stringifies to "", which would leave the failure unexplained.
            test_results.append(
                f"{test_func.__name__} failed: {type(e).__name__}: {e}"
            )
        else:
            test_results.append(f"{test_func.__name__} passed")

    def test_basic_initialization():
        """Test basic initialization of Conversation"""
        conv = Conversation()
        assert conv.id is not None
        assert conv.conversation_history is not None
        assert isinstance(conv.conversation_history, list)

        # Test with custom ID
        custom_id = generate_conversation_id()
        conv_with_id = Conversation(id=custom_id)
        assert conv_with_id.id == custom_id

        # Test with custom name
        conv_with_name = Conversation(name="Test Conversation")
        assert conv_with_name.name == "Test Conversation"

    def test_initialization_with_settings():
        """Test that constructor settings are exposed as attributes"""
        conv = Conversation(
            system_prompt="Test system prompt",
            time_enabled=True,
            autosave=True,
            token_count=True,
            provider="in-memory",
            context_length=4096,
            rules="Test rules",
            custom_rules_prompt="Custom rules",
            user="TestUser:",
            save_as_yaml=True,
            save_as_json_bool=True,
        )

        # Test all settings
        assert conv.system_prompt == "Test system prompt"
        assert conv.time_enabled is True
        assert conv.autosave is True
        assert conv.token_count is True
        assert conv.provider == "in-memory"
        assert conv.context_length == 4096
        assert conv.rules == "Test rules"
        assert conv.custom_rules_prompt == "Custom rules"
        assert conv.user == "TestUser:"
        assert conv.save_as_yaml is True
        assert conv.save_as_json_bool is True

    def test_message_manipulation():
        """Test adding, deleting, and updating messages"""
        conv = Conversation()

        # Test adding messages with different content types
        conv.add("user", "Hello")  # String content
        conv.add("assistant", {"response": "Hi"})  # Dict content
        conv.add("system", ["Hello", "Hi"])  # List content
        assert len(conv.conversation_history) == 3
        assert isinstance(
            conv.conversation_history[1]["content"], dict
        )
        assert isinstance(
            conv.conversation_history[2]["content"], list
        )

        # Test adding multiple messages
        conv.add_multiple(
            ["user", "assistant", "system"],
            ["Hi", "Hello there", "System message"],
        )
        assert len(conv.conversation_history) == 6

        # Test updating message with different content type
        conv.update(0, "user", {"updated": "content"})
        assert isinstance(
            conv.conversation_history[0]["content"], dict
        )

        # Test deleting multiple messages
        conv.delete(0)
        conv.delete(0)
        assert len(conv.conversation_history) == 4

    def test_message_retrieval():
        """Test message retrieval methods"""
        conv = Conversation()

        # Add messages in specific order for testing
        conv.add("user", "Test message")
        conv.add("assistant", "Test response")
        conv.add("system", "System message")

        # Test query - note: messages might have system prompt prepended
        message = conv.query(0)
        assert "Test message" in message["content"]

        # Test search with multiple results
        results = conv.search("Test")
        # At least two messages should contain "Test"
        assert len(results) >= 2
        assert any(
            "Test message" in str(msg["content"]) for msg in results
        )
        assert any(
            "Test response" in str(msg["content"]) for msg in results
        )

        # Test get_last_message_as_string
        last_message = conv.get_last_message_as_string()
        assert "System message" in last_message

        # Test return_messages_as_list
        messages_list = conv.return_messages_as_list()
        # At least our 3 added messages
        assert len(messages_list) >= 3
        assert any("Test message" in msg for msg in messages_list)

        # Test return_messages_as_dictionary
        messages_dict = conv.return_messages_as_dictionary()
        # At least our 3 added messages
        assert len(messages_dict) >= 3
        assert all(isinstance(m, dict) for m in messages_dict)
        assert all(
            {"role", "content"} <= set(m.keys())
            for m in messages_dict
        )

        # Test get_final_message and content
        assert "System message" in conv.get_final_message()
        assert "System message" in conv.get_final_message_content()

        # Test return_all_except_first
        remaining_messages = conv.return_all_except_first()
        # At least 2 messages after removing first
        assert len(remaining_messages) >= 2

        # Test return_all_except_first_string
        remaining_string = conv.return_all_except_first_string()
        assert isinstance(remaining_string, str)

    def test_saving_loading():
        """Test saving and loading conversation"""
        test_dir = "./test_conversations"
        try:
            # Test with save_enabled
            conv = Conversation(
                save_enabled=True,
                conversations_dir=test_dir,
            )
            conv.add("user", "Test save message")

            # Test save_as_json
            test_file = os.path.join(
                test_dir, "test_conversation.json"
            )
            conv.save_as_json(test_file)
            assert os.path.exists(test_file)

            # Test load_from_json
            new_conv = Conversation()
            new_conv.load_from_json(test_file)
            assert len(new_conv.conversation_history) == 1
            assert (
                new_conv.conversation_history[0]["content"]
                == "Test save message"
            )

            # Test class method load_conversation
            loaded_conv = Conversation.load_conversation(
                name=conv.id, conversations_dir=test_dir
            )
            assert loaded_conv.id == conv.id
        finally:
            # Cleanup runs even when an assertion fails, and removes the
            # directory even if it holds more than the one file saved above.
            if os.path.exists(test_dir):
                shutil.rmtree(test_dir)

    def test_output_formats():
        """Test different output formats"""
        conv = Conversation()
        conv.add("user", "Test message")
        conv.add("assistant", {"response": "Test"})

        # Test JSON output
        json_output = conv.to_json()
        assert isinstance(json_output, str)
        parsed_json = json.loads(json_output)
        assert len(parsed_json) == 2

        # Test dict output
        dict_output = conv.to_dict()
        assert isinstance(dict_output, list)
        assert len(dict_output) == 2

        # Test YAML output
        yaml_output = conv.to_yaml()
        assert isinstance(yaml_output, str)
        parsed_yaml = yaml.safe_load(yaml_output)
        assert len(parsed_yaml) == 2

        # Test return_json
        json_str = conv.return_json()
        assert isinstance(json_str, str)
        assert len(json.loads(json_str)) == 2

    def test_memory_management():
        """Test memory management functions"""
        conv = Conversation()

        # Test clear
        conv.add("user", "Test message")
        conv.clear()
        assert len(conv.conversation_history) == 0

        # Test clear_memory
        conv.add("user", "Test message")
        conv.clear_memory()
        assert len(conv.conversation_history) == 0

        # Test batch operations
        messages = [
            {"role": "user", "content": "Message 1"},
            {"role": "assistant", "content": "Response 1"},
        ]
        conv.batch_add(messages)
        assert len(conv.conversation_history) == 2

        # Test truncate_memory_with_tokenizer
        if conv.tokenizer:  # Only if tokenizer is available
            conv.truncate_memory_with_tokenizer()
            assert len(conv.conversation_history) > 0

    def test_conversation_metadata():
        """Test conversation metadata and listing"""
        test_dir = "./test_conversations_metadata"
        os.makedirs(test_dir, exist_ok=True)
        try:
            # Create a conversation with metadata
            conv = Conversation(
                name="Test Conv",
                system_prompt="System",
                rules="Rules",
                custom_rules_prompt="Custom",
                conversations_dir=test_dir,
                save_enabled=True,
                autosave=True,
            )

            # Add a message to trigger save
            conv.add("user", "Test message")

            # Give a small delay for autosave
            time.sleep(0.1)

            # List conversations and verify
            conversations = Conversation.list_conversations(test_dir)
            assert len(conversations) >= 1
            found_conv = next(
                (
                    c
                    for c in conversations
                    if c["name"] == "Test Conv"
                ),
                None,
            )
            assert found_conv is not None
            assert found_conv["id"] == conv.id
        finally:
            # Cleanup
            if os.path.exists(test_dir):
                shutil.rmtree(test_dir)

    def test_time_enabled_messages():
        """Test time-enabled messages"""
        conv = Conversation(time_enabled=True)
        conv.add("user", "Time test")

        # Verify timestamp in message
        message = conv.conversation_history[0]
        assert "timestamp" in message
        assert isinstance(message["timestamp"], str)

        # Verify time in content when time_enabled is True
        assert "Time:" in message["content"]

    def test_provider_specific():
        """Test provider-specific functionality"""
        # Test in-memory provider
        conv_memory = Conversation(provider="in-memory")
        conv_memory.add("user", "Test")
        assert len(conv_memory.conversation_history) == 1

        # Test mem0 provider if available
        try:
            conv_mem0 = Conversation(provider="mem0")
            conv_mem0.add("user", "Test")
            # Add appropriate assertions based on mem0 behavior
        except Exception:
            # Best-effort: skip if mem0 is not available. Catching Exception
            # (not a bare except) lets Ctrl-C / SystemExit still propagate.
            pass

    def test_tool_output():
        """Test tool output handling"""
        conv = Conversation()
        tool_output = {
            "tool_name": "test_tool",
            "output": "test result",
        }
        conv.add_tool_output_to_agent("tool", tool_output)
        assert len(conv.conversation_history) == 1
        assert conv.conversation_history[0]["role"] == "tool"
        assert conv.conversation_history[0]["content"] == tool_output

    def test_autosave_functionality():
        """Test autosave functionality and related features"""
        test_dir = "./test_conversations_autosave"
        os.makedirs(test_dir, exist_ok=True)
        try:
            # Test with autosave and save_enabled True
            conv = Conversation(
                autosave=True,
                save_enabled=True,
                conversations_dir=test_dir,
                name="autosave_test",
            )

            # Add a message and verify it was auto-saved
            conv.add("user", "Test autosave message")
            save_path = os.path.join(test_dir, f"{conv.id}.json")

            # Give a small delay for autosave to complete
            time.sleep(0.1)
            assert os.path.exists(
                save_path
            ), f"Save file not found at {save_path}"

            # Load the saved conversation and verify content
            loaded_conv = Conversation.load_conversation(
                name=conv.id, conversations_dir=test_dir
            )
            found_message = any(
                "Test autosave message" in str(msg["content"])
                for msg in loaded_conv.conversation_history
            )
            assert (
                found_message
            ), "Message not found in loaded conversation"

            # Clean up first conversation files
            if os.path.exists(save_path):
                os.remove(save_path)

            # Test with save_enabled=False
            conv_no_save = Conversation(
                autosave=False,  # Changed to False to prevent autosave
                save_enabled=False,
                conversations_dir=test_dir,
            )
            conv_no_save.add("user", "This shouldn't be saved")
            save_path_no_save = os.path.join(
                test_dir, f"{conv_no_save.id}.json"
            )
            time.sleep(0.1)  # Give time for potential save
            assert not os.path.exists(
                save_path_no_save
            ), "File should not exist when save_enabled is False"
        finally:
            # Cleanup
            if os.path.exists(test_dir):
                shutil.rmtree(test_dir)

    def test_advanced_message_handling():
        """Test advanced message handling features"""
        conv = Conversation()

        # Test adding messages with metadata
        metadata = {"timestamp": "2024-01-01", "session_id": "123"}
        conv.add("user", "Test with metadata", metadata=metadata)

        # Test batch operations with different content types
        messages = [
            {"role": "user", "content": "Message 1"},
            {
                "role": "assistant",
                "content": {"response": "Complex response"},
            },
            {"role": "system", "content": ["Multiple", "Items"]},
        ]
        conv.batch_add(messages)
        # Including the first message
        assert len(conv.conversation_history) == 4

        # Test message format consistency
        for msg in conv.conversation_history:
            assert "role" in msg
            assert "content" in msg
            if "timestamp" in msg:
                assert isinstance(msg["timestamp"], str)

    def test_conversation_metadata_handling():
        """Test handling of conversation metadata and attributes"""
        test_dir = "./test_conversations_metadata_handling"
        os.makedirs(test_dir, exist_ok=True)
        try:
            # Test initialization with all optional parameters
            conv = Conversation(
                name="Test Conv",
                system_prompt="System Prompt",
                time_enabled=True,
                context_length=2048,
                rules="Test Rules",
                custom_rules_prompt="Custom Rules",
                user="CustomUser:",
                provider="in-memory",
                conversations_dir=test_dir,
                save_enabled=True,
            )

            # Verify all attributes are set correctly
            assert conv.name == "Test Conv"
            assert conv.system_prompt == "System Prompt"
            assert conv.time_enabled is True
            assert conv.context_length == 2048
            assert conv.rules == "Test Rules"
            assert conv.custom_rules_prompt == "Custom Rules"
            assert conv.user == "CustomUser:"
            assert conv.provider == "in-memory"

            # Test saving and loading preserves metadata
            conv.save_as_json()

            # Load using load_conversation
            loaded_conv = Conversation.load_conversation(
                name=conv.id, conversations_dir=test_dir
            )

            # Verify metadata was preserved
            assert loaded_conv.name == "Test Conv"
            assert loaded_conv.system_prompt == "System Prompt"
            assert loaded_conv.rules == "Test Rules"
        finally:
            # Cleanup
            shutil.rmtree(test_dir)

    def test_time_enabled_features():
        """Test time-enabled message features"""
        conv = Conversation(time_enabled=True)

        # Add message and verify timestamp
        conv.add("user", "Time test message")
        message = conv.conversation_history[0]

        # Verify timestamp format
        assert "timestamp" in message
        try:
            datetime.datetime.fromisoformat(message["timestamp"])
        except ValueError as err:
            # Raise explicitly: ``assert False`` is stripped under ``python -O``.
            raise AssertionError("Invalid timestamp format") from err

        # Verify time in content
        assert "Time:" in message["content"]
        assert (
            datetime.datetime.now().strftime("%Y-%m-%d")
            in message["content"]
        )

    def test_provider_specific_features():
        """Test provider-specific features and behaviors"""
        # Test in-memory provider
        conv_memory = Conversation(provider="in-memory")
        conv_memory.add("user", "Test in-memory")
        assert len(conv_memory.conversation_history) == 1
        assert (
            "Test in-memory"
            in conv_memory.get_last_message_as_string()
        )

        # Test mem0 provider if available
        try:
            # Availability probe only; the imported name is intentionally
            # unused. Skip actual mem0 testing since it requires async.
            from mem0 import AsyncMemory  # noqa: F401
        except ImportError:
            pass

        # Test invalid provider
        invalid_provider = "invalid_provider"
        try:
            Conversation(provider=invalid_provider)
        except ValueError:
            # This is the expected behavior
            pass
        else:
            # If we get here, the provider was accepted when it shouldn't
            # have been
            raise AssertionError(
                f"Should have raised ValueError for provider '{invalid_provider}'"
            )

    # Run all tests
    tests = [
        test_basic_initialization,
        test_initialization_with_settings,
        test_message_manipulation,
        test_message_retrieval,
        test_saving_loading,
        test_output_formats,
        test_memory_management,
        test_conversation_metadata,
        test_time_enabled_messages,
        test_provider_specific,
        test_tool_output,
        test_autosave_functionality,
        test_advanced_message_handling,
        test_conversation_metadata_handling,
        test_time_enabled_features,
        test_provider_specific_features,
    ]
    for test in tests:
        run_test(test)

    # Print results
    print("\nTest Results:")
    for result in test_results:
        print(result)

    return test_results
# Script entry point: run the whole suite only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    run_all_tests()