You cannot select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
					
					
						
							568 lines
						
					
					
						
							18 KiB
						
					
					
				
			
		
		
	
	
							568 lines
						
					
					
						
							18 KiB
						
					
					
				| import os
 | |
| import json
 | |
| import time
 | |
| import datetime
 | |
| import yaml
 | |
| from swarms.structs.conversation import (
 | |
|     Conversation,
 | |
|     generate_conversation_id,
 | |
| )
 | |
| 
 | |
| 
 | |
def run_all_tests():
    """Run every Conversation test, print a report and return the results.

    Each test is a nested function executed through ``run_test``, which
    records one status line per test instead of letting a failure abort
    the remaining tests.

    Returns:
        list[str]: One status line per test, in execution order. Passing
        tests start with "✅", failing tests start with "❌".
    """
    # Imported once here; the nested tests that clean up directories reach
    # it through the closure instead of re-importing in every ``finally``.
    import shutil

    test_results = []

    def run_test(test_func):
        """Run ``test_func`` and append a pass/fail line to ``test_results``."""
        try:
            test_func()
        except Exception as e:
            # Include the exception type: a bare ``assert`` raises an
            # AssertionError whose str() is empty, which would otherwise
            # yield an uninformative "failed: " line.
            test_results.append(
                f"❌ {test_func.__name__} failed: {type(e).__name__}: {e}"
            )
        else:
            test_results.append(f"✅ {test_func.__name__} passed")

    def test_basic_initialization():
        """Test basic initialization of Conversation"""
        conv = Conversation()
        assert conv.id is not None
        assert conv.conversation_history is not None
        assert isinstance(conv.conversation_history, list)

        # Test with custom ID
        custom_id = generate_conversation_id()
        conv_with_id = Conversation(id=custom_id)
        assert conv_with_id.id == custom_id

        # Test with custom name
        conv_with_name = Conversation(name="Test Conversation")
        assert conv_with_name.name == "Test Conversation"

    def test_initialization_with_settings():
        """Test initialization with various settings"""
        conv = Conversation(
            system_prompt="Test system prompt",
            time_enabled=True,
            autosave=True,
            token_count=True,
            provider="in-memory",
            context_length=4096,
            rules="Test rules",
            custom_rules_prompt="Custom rules",
            user="TestUser:",
            save_as_yaml=True,
            save_as_json_bool=True,
        )

        # Test all settings
        assert conv.system_prompt == "Test system prompt"
        assert conv.time_enabled is True
        assert conv.autosave is True
        assert conv.token_count is True
        assert conv.provider == "in-memory"
        assert conv.context_length == 4096
        assert conv.rules == "Test rules"
        assert conv.custom_rules_prompt == "Custom rules"
        assert conv.user == "TestUser:"
        assert conv.save_as_yaml is True
        assert conv.save_as_json_bool is True

    def test_message_manipulation():
        """Test adding, deleting, and updating messages"""
        conv = Conversation()

        # Test adding messages with different content types
        conv.add("user", "Hello")  # String content
        conv.add("assistant", {"response": "Hi"})  # Dict content
        conv.add("system", ["Hello", "Hi"])  # List content

        assert len(conv.conversation_history) == 3
        assert isinstance(
            conv.conversation_history[1]["content"], dict
        )
        assert isinstance(
            conv.conversation_history[2]["content"], list
        )

        # Test adding multiple messages
        conv.add_multiple(
            ["user", "assistant", "system"],
            ["Hi", "Hello there", "System message"],
        )
        assert len(conv.conversation_history) == 6

        # Test updating message with different content type
        conv.update(0, "user", {"updated": "content"})
        assert isinstance(
            conv.conversation_history[0]["content"], dict
        )

        # Test deleting multiple messages
        conv.delete(0)
        conv.delete(0)
        assert len(conv.conversation_history) == 4

    def test_message_retrieval():
        """Test message retrieval methods"""
        conv = Conversation()

        # Add messages in specific order for testing
        conv.add("user", "Test message")
        conv.add("assistant", "Test response")
        conv.add("system", "System message")

        # Test query - note: messages might have system prompt prepended
        message = conv.query(0)
        assert "Test message" in message["content"]

        # Test search with multiple results
        results = conv.search("Test")
        assert (
            len(results) >= 2
        )  # At least two messages should contain "Test"
        assert any(
            "Test message" in str(msg["content"]) for msg in results
        )
        assert any(
            "Test response" in str(msg["content"]) for msg in results
        )

        # Test get_last_message_as_string
        last_message = conv.get_last_message_as_string()
        assert "System message" in last_message

        # Test return_messages_as_list
        messages_list = conv.return_messages_as_list()
        assert (
            len(messages_list) >= 3
        )  # At least our 3 added messages
        assert any("Test message" in msg for msg in messages_list)

        # Test return_messages_as_dictionary
        messages_dict = conv.return_messages_as_dictionary()
        assert (
            len(messages_dict) >= 3
        )  # At least our 3 added messages
        assert all(isinstance(m, dict) for m in messages_dict)
        assert all(
            {"role", "content"} <= set(m.keys())
            for m in messages_dict
        )

        # Test get_final_message and content
        assert "System message" in conv.get_final_message()
        assert "System message" in conv.get_final_message_content()

        # Test return_all_except_first
        remaining_messages = conv.return_all_except_first()
        assert (
            len(remaining_messages) >= 2
        )  # At least 2 messages after removing first

        # Test return_all_except_first_string
        remaining_string = conv.return_all_except_first_string()
        assert isinstance(remaining_string, str)

    def test_saving_loading():
        """Test saving and loading conversation"""
        test_dir = "./test_conversations"
        os.makedirs(test_dir, exist_ok=True)

        try:
            # Test with save_enabled
            conv = Conversation(
                save_enabled=True,
                conversations_dir=test_dir,
            )
            conv.add("user", "Test save message")

            # Test save_as_json
            test_file = os.path.join(
                test_dir, "test_conversation.json"
            )
            conv.save_as_json(test_file)
            assert os.path.exists(test_file)

            # Test load_from_json
            new_conv = Conversation()
            new_conv.load_from_json(test_file)
            assert len(new_conv.conversation_history) == 1
            assert (
                new_conv.conversation_history[0]["content"]
                == "Test save message"
            )

            # Test class method load_conversation
            loaded_conv = Conversation.load_conversation(
                name=conv.id, conversations_dir=test_dir
            )
            assert loaded_conv.id == conv.id

        finally:
            # rmtree (instead of remove + rmdir) also copes with any extra
            # files the Conversation wrote into the directory, and it runs
            # even when an assertion above fails.
            shutil.rmtree(test_dir, ignore_errors=True)

    def test_output_formats():
        """Test different output formats"""
        conv = Conversation()
        conv.add("user", "Test message")
        conv.add("assistant", {"response": "Test"})

        # Test JSON output
        json_output = conv.to_json()
        assert isinstance(json_output, str)
        parsed_json = json.loads(json_output)
        assert len(parsed_json) == 2

        # Test dict output
        dict_output = conv.to_dict()
        assert isinstance(dict_output, list)
        assert len(dict_output) == 2

        # Test YAML output
        yaml_output = conv.to_yaml()
        assert isinstance(yaml_output, str)
        parsed_yaml = yaml.safe_load(yaml_output)
        assert len(parsed_yaml) == 2

        # Test return_json
        json_str = conv.return_json()
        assert isinstance(json_str, str)
        assert len(json.loads(json_str)) == 2

    def test_memory_management():
        """Test memory management functions"""
        conv = Conversation()

        # Test clear
        conv.add("user", "Test message")
        conv.clear()
        assert len(conv.conversation_history) == 0

        # Test clear_memory
        conv.add("user", "Test message")
        conv.clear_memory()
        assert len(conv.conversation_history) == 0

        # Test batch operations
        messages = [
            {"role": "user", "content": "Message 1"},
            {"role": "assistant", "content": "Response 1"},
        ]
        conv.batch_add(messages)
        assert len(conv.conversation_history) == 2

        # Test truncate_memory_with_tokenizer
        if conv.tokenizer:  # Only if tokenizer is available
            conv.truncate_memory_with_tokenizer()
            assert len(conv.conversation_history) > 0

    def test_conversation_metadata():
        """Test conversation metadata and listing"""
        test_dir = "./test_conversations_metadata"
        os.makedirs(test_dir, exist_ok=True)

        try:
            # Create a conversation with metadata
            conv = Conversation(
                name="Test Conv",
                system_prompt="System",
                rules="Rules",
                custom_rules_prompt="Custom",
                conversations_dir=test_dir,
                save_enabled=True,
                autosave=True,
            )

            # Add a message to trigger save
            conv.add("user", "Test message")

            # Give a small delay for autosave
            time.sleep(0.1)

            # List conversations and verify
            conversations = Conversation.list_conversations(test_dir)
            assert len(conversations) >= 1
            found_conv = next(
                (
                    c
                    for c in conversations
                    if c["name"] == "Test Conv"
                ),
                None,
            )
            assert found_conv is not None
            assert found_conv["id"] == conv.id

        finally:
            # Cleanup
            if os.path.exists(test_dir):
                shutil.rmtree(test_dir)

    def test_time_enabled_messages():
        """Test time-enabled messages"""
        conv = Conversation(time_enabled=True)
        conv.add("user", "Time test")

        # Verify timestamp in message
        message = conv.conversation_history[0]
        assert "timestamp" in message
        assert isinstance(message["timestamp"], str)

        # Verify time in content when time_enabled is True
        assert "Time:" in message["content"]

    def test_provider_specific():
        """Test provider-specific functionality"""
        # Test in-memory provider
        conv_memory = Conversation(provider="in-memory")
        conv_memory.add("user", "Test")
        assert len(conv_memory.conversation_history) == 1

        # Test mem0 provider if available
        try:
            conv_mem0 = Conversation(provider="mem0")
            conv_mem0.add("user", "Test")
            # TODO: add assertions once the expected mem0 behavior is known.
        except Exception:
            # Best-effort: skip when mem0 is not available. Catch Exception
            # rather than using a bare ``except`` so KeyboardInterrupt and
            # SystemExit still propagate.
            pass

    def test_tool_output():
        """Test tool output handling"""
        conv = Conversation()
        tool_output = {
            "tool_name": "test_tool",
            "output": "test result",
        }
        conv.add_tool_output_to_agent("tool", tool_output)

        assert len(conv.conversation_history) == 1
        assert conv.conversation_history[0]["role"] == "tool"
        assert conv.conversation_history[0]["content"] == tool_output

    def test_autosave_functionality():
        """Test autosave functionality and related features"""
        test_dir = "./test_conversations_autosave"
        os.makedirs(test_dir, exist_ok=True)

        try:
            # Test with autosave and save_enabled True
            conv = Conversation(
                autosave=True,
                save_enabled=True,
                conversations_dir=test_dir,
                name="autosave_test",
            )

            # Add a message and verify it was auto-saved
            conv.add("user", "Test autosave message")
            save_path = os.path.join(test_dir, f"{conv.id}.json")

            # Give a small delay for autosave to complete
            time.sleep(0.1)

            assert os.path.exists(
                save_path
            ), f"Save file not found at {save_path}"

            # Load the saved conversation and verify content
            loaded_conv = Conversation.load_conversation(
                name=conv.id, conversations_dir=test_dir
            )
            found_message = any(
                "Test autosave message" in str(msg["content"])
                for msg in loaded_conv.conversation_history
            )
            assert (
                found_message
            ), "Message not found in loaded conversation"

            # Clean up first conversation files
            if os.path.exists(save_path):
                os.remove(save_path)

            # Test with save_enabled=False
            conv_no_save = Conversation(
                autosave=False,  # Changed to False to prevent autosave
                save_enabled=False,
                conversations_dir=test_dir,
            )
            conv_no_save.add("user", "This shouldn't be saved")
            save_path_no_save = os.path.join(
                test_dir, f"{conv_no_save.id}.json"
            )
            time.sleep(0.1)  # Give time for potential save
            assert not os.path.exists(
                save_path_no_save
            ), "File should not exist when save_enabled is False"

        finally:
            # Cleanup
            if os.path.exists(test_dir):
                shutil.rmtree(test_dir)

    def test_advanced_message_handling():
        """Test advanced message handling features"""
        conv = Conversation()

        # Test adding messages with metadata
        metadata = {"timestamp": "2024-01-01", "session_id": "123"}
        conv.add("user", "Test with metadata", metadata=metadata)

        # Test batch operations with different content types
        messages = [
            {"role": "user", "content": "Message 1"},
            {
                "role": "assistant",
                "content": {"response": "Complex response"},
            },
            {"role": "system", "content": ["Multiple", "Items"]},
        ]
        conv.batch_add(messages)
        assert (
            len(conv.conversation_history) == 4
        )  # Including the first message

        # Test message format consistency
        for msg in conv.conversation_history:
            assert "role" in msg
            assert "content" in msg
            if "timestamp" in msg:
                assert isinstance(msg["timestamp"], str)

    def test_conversation_metadata_handling():
        """Test handling of conversation metadata and attributes"""
        test_dir = "./test_conversations_metadata_handling"
        os.makedirs(test_dir, exist_ok=True)

        try:
            # Test initialization with all optional parameters
            conv = Conversation(
                name="Test Conv",
                system_prompt="System Prompt",
                time_enabled=True,
                context_length=2048,
                rules="Test Rules",
                custom_rules_prompt="Custom Rules",
                user="CustomUser:",
                provider="in-memory",
                conversations_dir=test_dir,
                save_enabled=True,
            )

            # Verify all attributes are set correctly
            assert conv.name == "Test Conv"
            assert conv.system_prompt == "System Prompt"
            assert conv.time_enabled is True
            assert conv.context_length == 2048
            assert conv.rules == "Test Rules"
            assert conv.custom_rules_prompt == "Custom Rules"
            assert conv.user == "CustomUser:"
            assert conv.provider == "in-memory"

            # Test saving and loading preserves metadata
            conv.save_as_json()

            # Load using load_conversation
            loaded_conv = Conversation.load_conversation(
                name=conv.id, conversations_dir=test_dir
            )

            # Verify metadata was preserved
            assert loaded_conv.name == "Test Conv"
            assert loaded_conv.system_prompt == "System Prompt"
            assert loaded_conv.rules == "Test Rules"

        finally:
            # Cleanup; ignore_errors so a cleanup problem cannot mask the
            # real test outcome (consistent with the guarded sibling tests).
            shutil.rmtree(test_dir, ignore_errors=True)

    def test_time_enabled_features():
        """Test time-enabled message features"""
        conv = Conversation(time_enabled=True)

        # Capture the local date on both sides of add(): a message created
        # around midnight may carry either day, so comparing against a
        # single datetime.now() taken afterwards would make the test flaky.
        day_before = datetime.datetime.now().strftime("%Y-%m-%d")
        conv.add("user", "Time test message")
        day_after = datetime.datetime.now().strftime("%Y-%m-%d")
        message = conv.conversation_history[0]

        # Verify timestamp format
        assert "timestamp" in message
        try:
            datetime.datetime.fromisoformat(message["timestamp"])
        except ValueError as exc:
            raise AssertionError("Invalid timestamp format") from exc

        # Verify time in content
        assert "Time:" in message["content"]
        assert (
            day_before in message["content"]
            or day_after in message["content"]
        )

    def test_provider_specific_features():
        """Test provider-specific features and behaviors"""
        # Test in-memory provider
        conv_memory = Conversation(provider="in-memory")
        conv_memory.add("user", "Test in-memory")
        assert len(conv_memory.conversation_history) == 1
        assert (
            "Test in-memory"
            in conv_memory.get_last_message_as_string()
        )

        # The mem0 provider is intentionally not exercised here: it requires
        # an async client, so there is nothing synchronous to assert on.

        # Test invalid provider
        invalid_provider = "invalid_provider"
        try:
            Conversation(provider=invalid_provider)
        except ValueError:
            pass  # Expected: unknown providers are rejected.
        else:
            # The provider was accepted when it shouldn't have been.
            raise AssertionError(
                f"Should have raised ValueError for provider '{invalid_provider}'"
            )

    # Run all tests
    tests = [
        test_basic_initialization,
        test_initialization_with_settings,
        test_message_manipulation,
        test_message_retrieval,
        test_saving_loading,
        test_output_formats,
        test_memory_management,
        test_conversation_metadata,
        test_time_enabled_messages,
        test_provider_specific,
        test_tool_output,
        test_autosave_functionality,
        test_advanced_message_handling,
        test_conversation_metadata_handling,
        test_time_enabled_features,
        test_provider_specific_features,
    ]

    for test in tests:
        run_test(test)

    # Print results
    print("\nTest Results:")
    for result in test_results:
        print(result)

    passed = sum(result.startswith("✅") for result in test_results)
    print(f"\n{passed}/{len(tests)} tests passed")

    return test_results
 | |
| 
 | |
| 
 | |
if __name__ == "__main__":
    import sys

    # Run the ad-hoc suite when executed as a script.
    results = run_all_tests()

    # Exit non-zero when any test failed so shell/CI callers can detect it.
    # ``results`` is falsy if run_all_tests does not return its report, in
    # which case the exit status is left at the default of 0.
    if results and any(line.startswith("❌") for line in results):
        sys.exit(1)
 |