parent
e9a7c7994c
commit
ed46063dcc
@ -0,0 +1,241 @@
|
|||||||
|
from swarms.structs.conversation import Conversation
|
||||||
|
import time
|
||||||
|
import threading
|
||||||
|
import random
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
|
||||||
|
def test_conversation_cache():
|
||||||
|
"""
|
||||||
|
Test the caching functionality of the Conversation class.
|
||||||
|
This test demonstrates:
|
||||||
|
1. Cache hits and misses
|
||||||
|
2. Token counting with caching
|
||||||
|
3. Cache statistics
|
||||||
|
4. Thread safety
|
||||||
|
5. Different content types
|
||||||
|
6. Edge cases
|
||||||
|
7. Performance metrics
|
||||||
|
"""
|
||||||
|
print("\n=== Testing Conversation Cache ===")
|
||||||
|
|
||||||
|
# Create a conversation with caching enabled
|
||||||
|
conv = Conversation(cache_enabled=True)
|
||||||
|
|
||||||
|
# Test 1: Basic caching with repeated messages
|
||||||
|
print("\nTest 1: Basic caching with repeated messages")
|
||||||
|
message = "This is a test message that should be cached"
|
||||||
|
|
||||||
|
# First add (should be a cache miss)
|
||||||
|
print("\nAdding first message...")
|
||||||
|
conv.add("user", message)
|
||||||
|
time.sleep(0.1) # Wait for token counting thread
|
||||||
|
|
||||||
|
# Second add (should be a cache hit)
|
||||||
|
print("\nAdding same message again...")
|
||||||
|
conv.add("user", message)
|
||||||
|
time.sleep(0.1) # Wait for token counting thread
|
||||||
|
|
||||||
|
# Check cache stats
|
||||||
|
stats = conv.get_cache_stats()
|
||||||
|
print("\nCache stats after repeated message:")
|
||||||
|
print(f"Hits: {stats['hits']}")
|
||||||
|
print(f"Misses: {stats['misses']}")
|
||||||
|
print(f"Cached tokens: {stats['cached_tokens']}")
|
||||||
|
print(f"Hit rate: {stats['hit_rate']:.2%}")
|
||||||
|
|
||||||
|
# Test 2: Different content types
|
||||||
|
print("\nTest 2: Different content types")
|
||||||
|
|
||||||
|
# Test with dictionary
|
||||||
|
dict_content = {"key": "value", "nested": {"inner": "data"}}
|
||||||
|
print("\nAdding dictionary content...")
|
||||||
|
conv.add("user", dict_content)
|
||||||
|
time.sleep(0.1)
|
||||||
|
|
||||||
|
# Test with list
|
||||||
|
list_content = ["item1", "item2", {"nested": "data"}]
|
||||||
|
print("\nAdding list content...")
|
||||||
|
conv.add("user", list_content)
|
||||||
|
time.sleep(0.1)
|
||||||
|
|
||||||
|
# Test 3: Thread safety
|
||||||
|
print("\nTest 3: Thread safety with concurrent adds")
|
||||||
|
|
||||||
|
def add_message(msg):
|
||||||
|
conv.add("user", msg)
|
||||||
|
|
||||||
|
# Add multiple messages concurrently
|
||||||
|
messages = [f"Concurrent message {i}" for i in range(5)]
|
||||||
|
for msg in messages:
|
||||||
|
add_message(msg)
|
||||||
|
|
||||||
|
time.sleep(0.5) # Wait for all token counting threads
|
||||||
|
|
||||||
|
# Test 4: Cache with different message lengths
|
||||||
|
print("\nTest 4: Cache with different message lengths")
|
||||||
|
|
||||||
|
# Short message
|
||||||
|
short_msg = "Short"
|
||||||
|
conv.add("user", short_msg)
|
||||||
|
time.sleep(0.1)
|
||||||
|
|
||||||
|
# Long message
|
||||||
|
long_msg = "This is a much longer message that should have more tokens and might be cached differently"
|
||||||
|
conv.add("user", long_msg)
|
||||||
|
time.sleep(0.1)
|
||||||
|
|
||||||
|
# Test 5: Cache statistics after all tests
|
||||||
|
print("\nTest 5: Final cache statistics")
|
||||||
|
final_stats = conv.get_cache_stats()
|
||||||
|
print("\nFinal cache stats:")
|
||||||
|
print(f"Total hits: {final_stats['hits']}")
|
||||||
|
print(f"Total misses: {final_stats['misses']}")
|
||||||
|
print(f"Total cached tokens: {final_stats['cached_tokens']}")
|
||||||
|
print(f"Total tokens: {final_stats['total_tokens']}")
|
||||||
|
print(f"Overall hit rate: {final_stats['hit_rate']:.2%}")
|
||||||
|
|
||||||
|
# Test 6: Display conversation with cache status
|
||||||
|
print("\nTest 6: Display conversation with cache status")
|
||||||
|
print("\nConversation history:")
|
||||||
|
print(conv.get_str())
|
||||||
|
|
||||||
|
# Test 7: Cache disabled
|
||||||
|
print("\nTest 7: Cache disabled")
|
||||||
|
conv_disabled = Conversation(cache_enabled=False)
|
||||||
|
conv_disabled.add("user", message)
|
||||||
|
time.sleep(0.1)
|
||||||
|
conv_disabled.add("user", message)
|
||||||
|
time.sleep(0.1)
|
||||||
|
|
||||||
|
disabled_stats = conv_disabled.get_cache_stats()
|
||||||
|
print("\nCache stats with caching disabled:")
|
||||||
|
print(f"Hits: {disabled_stats['hits']}")
|
||||||
|
print(f"Misses: {disabled_stats['misses']}")
|
||||||
|
print(f"Cached tokens: {disabled_stats['cached_tokens']}")
|
||||||
|
|
||||||
|
# Test 8: High concurrency stress test
|
||||||
|
print("\nTest 8: High concurrency stress test")
|
||||||
|
conv_stress = Conversation(cache_enabled=True)
|
||||||
|
|
||||||
|
def stress_test_worker(messages: List[str]):
|
||||||
|
for msg in messages:
|
||||||
|
conv_stress.add("user", msg)
|
||||||
|
time.sleep(random.uniform(0.01, 0.05))
|
||||||
|
|
||||||
|
# Create multiple threads with different messages
|
||||||
|
threads = []
|
||||||
|
for i in range(5):
|
||||||
|
thread_messages = [
|
||||||
|
f"Stress test message {i}_{j}" for j in range(10)
|
||||||
|
]
|
||||||
|
t = threading.Thread(
|
||||||
|
target=stress_test_worker, args=(thread_messages,)
|
||||||
|
)
|
||||||
|
threads.append(t)
|
||||||
|
t.start()
|
||||||
|
|
||||||
|
# Wait for all threads to complete
|
||||||
|
for t in threads:
|
||||||
|
t.join()
|
||||||
|
|
||||||
|
time.sleep(0.5) # Wait for token counting
|
||||||
|
stress_stats = conv_stress.get_cache_stats()
|
||||||
|
print("\nStress test stats:")
|
||||||
|
print(
|
||||||
|
f"Total messages: {stress_stats['hits'] + stress_stats['misses']}"
|
||||||
|
)
|
||||||
|
print(f"Cache hits: {stress_stats['hits']}")
|
||||||
|
print(f"Cache misses: {stress_stats['misses']}")
|
||||||
|
|
||||||
|
# Test 9: Complex nested structures
|
||||||
|
print("\nTest 9: Complex nested structures")
|
||||||
|
complex_content = {
|
||||||
|
"nested": {
|
||||||
|
"array": [1, 2, 3, {"deep": "value"}],
|
||||||
|
"object": {
|
||||||
|
"key": "value",
|
||||||
|
"nested_array": ["a", "b", "c"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"simple": "value",
|
||||||
|
}
|
||||||
|
|
||||||
|
# Add complex content multiple times
|
||||||
|
for _ in range(3):
|
||||||
|
conv.add("user", complex_content)
|
||||||
|
time.sleep(0.1)
|
||||||
|
|
||||||
|
# Test 10: Large message test
|
||||||
|
print("\nTest 10: Large message test")
|
||||||
|
large_message = "x" * 10000 # 10KB message
|
||||||
|
conv.add("user", large_message)
|
||||||
|
time.sleep(0.1)
|
||||||
|
|
||||||
|
# Test 11: Mixed content types in sequence
|
||||||
|
print("\nTest 11: Mixed content types in sequence")
|
||||||
|
mixed_sequence = [
|
||||||
|
"Simple string",
|
||||||
|
{"key": "value"},
|
||||||
|
["array", "items"],
|
||||||
|
"Simple string", # Should be cached
|
||||||
|
{"key": "value"}, # Should be cached
|
||||||
|
["array", "items"], # Should be cached
|
||||||
|
]
|
||||||
|
|
||||||
|
for content in mixed_sequence:
|
||||||
|
conv.add("user", content)
|
||||||
|
time.sleep(0.1)
|
||||||
|
|
||||||
|
# Test 12: Cache performance metrics
|
||||||
|
print("\nTest 12: Cache performance metrics")
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
# Add 100 messages quickly
|
||||||
|
for i in range(100):
|
||||||
|
conv.add("user", f"Performance test message {i}")
|
||||||
|
|
||||||
|
end_time = time.time()
|
||||||
|
performance_stats = conv.get_cache_stats()
|
||||||
|
|
||||||
|
print("\nPerformance metrics:")
|
||||||
|
print(f"Time taken: {end_time - start_time:.2f} seconds")
|
||||||
|
print(f"Messages per second: {100 / (end_time - start_time):.2f}")
|
||||||
|
print(f"Cache hit rate: {performance_stats['hit_rate']:.2%}")
|
||||||
|
|
||||||
|
# Test 13: Cache with special characters
|
||||||
|
print("\nTest 13: Cache with special characters")
|
||||||
|
special_chars = [
|
||||||
|
"Hello! @#$%^&*()",
|
||||||
|
"Unicode: 你好世界",
|
||||||
|
"Emoji: 😀🎉🌟",
|
||||||
|
"Hello! @#$%^&*()", # Should be cached
|
||||||
|
"Unicode: 你好世界", # Should be cached
|
||||||
|
"Emoji: 😀🎉🌟", # Should be cached
|
||||||
|
]
|
||||||
|
|
||||||
|
for content in special_chars:
|
||||||
|
conv.add("user", content)
|
||||||
|
time.sleep(0.1)
|
||||||
|
|
||||||
|
# Test 14: Cache with different roles
|
||||||
|
print("\nTest 14: Cache with different roles")
|
||||||
|
roles = ["user", "assistant", "system", "function"]
|
||||||
|
for role in roles:
|
||||||
|
conv.add(role, "Same message different role")
|
||||||
|
time.sleep(0.1)
|
||||||
|
|
||||||
|
# Final statistics
|
||||||
|
print("\n=== Final Cache Statistics ===")
|
||||||
|
final_stats = conv.get_cache_stats()
|
||||||
|
print(f"Total hits: {final_stats['hits']}")
|
||||||
|
print(f"Total misses: {final_stats['misses']}")
|
||||||
|
print(f"Total cached tokens: {final_stats['cached_tokens']}")
|
||||||
|
print(f"Total tokens: {final_stats['total_tokens']}")
|
||||||
|
print(f"Overall hit rate: {final_stats['hit_rate']:.2%}")
|
||||||
|
|
||||||
|
print("\n=== Cache Testing Complete ===")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test_conversation_cache()
|
Loading…
Reference in new issue