You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
698 lines
21 KiB
698 lines
21 KiB
import shutil
|
|
from pathlib import Path
|
|
from datetime import datetime
|
|
from loguru import logger
|
|
from swarms.structs.conversation import Conversation
|
|
|
|
|
|
def setup_temp_conversations_dir():
    """Return an empty ``temp_test_conversations`` directory for cache files.

    The directory is created relative to the current working directory. If it
    already exists (e.g. left over from an earlier run) it is deleted first,
    so every caller starts from an empty directory.

    Returns:
        Path: The freshly created directory.
    """
    directory = Path("temp_test_conversations")
    if directory.exists():
        # Start clean: stale cache files would leak into the next test.
        shutil.rmtree(directory)
    directory.mkdir()
    logger.info(f"Created temporary test directory: {directory}")
    return directory
|
|
|
|
|
|
def create_test_conversation(temp_dir):
    """Build a named conversation whose cache lives under ``temp_dir``.

    The conversation is seeded with one user greeting followed by one
    assistant reply.

    Args:
        temp_dir: Path-like directory, passed as a string to
            ``Conversation(conversations_dir=...)``.

    Returns:
        Conversation: The populated conversation.
    """
    seed_messages = (
        ("user", "Hello, world!"),
        ("assistant", "Hello, user!"),
    )
    conversation = Conversation(
        name="test_conversation", conversations_dir=str(temp_dir)
    )
    for role, content in seed_messages:
        conversation.add(role, content)
    logger.info("Created test conversation with basic messages")
    return conversation
|
|
|
|
|
|
def test_add_message():
    """Check that ``add`` appends exactly one message with the given role/content.

    Returns:
        bool: True if every assertion holds, False on the first failure.
    """
    logger.info("Running test_add_message")
    conversation = Conversation()
    conversation.add("user", "Hello, world!")
    history = conversation.conversation_history
    try:
        # Short-circuit keeps the length check ahead of any indexing.
        assert (
            len(history) == 1
            and history[0]["role"] == "user"
            and history[0]["content"] == "Hello, world!"
        )
    except AssertionError as err:
        logger.error(f"test_add_message failed: {err}")
        return False
    logger.success("test_add_message passed")
    return True
|
|
|
|
|
|
def test_add_message_with_time():
    """Check that an added message carries a ``timestamp`` when timing is on.

    Returns:
        bool: True if every assertion holds, False on the first failure.
    """
    logger.info("Running test_add_message_with_time")
    # This test asserts a "timestamp" key below, so timestamps must be
    # enabled; constructing with time_enabled=False contradicted the test.
    conv = Conversation(time_enabled=True)
    conv.add("user", "Hello, world!")
    try:
        assert len(conv.conversation_history) == 1
        message = conv.conversation_history[0]
        assert message["role"] == "user"
        assert message["content"] == "Hello, world!"
        assert "timestamp" in message
        logger.success("test_add_message_with_time passed")
        return True
    except AssertionError as e:
        logger.error(f"test_add_message_with_time failed: {str(e)}")
        return False
|
|
|
|
|
|
def test_delete_message():
    """Check that deleting index 0 leaves a one-message conversation empty.

    Returns:
        bool: True if the history is empty afterwards, False otherwise.
    """
    logger.info("Running test_delete_message")
    conversation = Conversation()
    conversation.add("user", "Hello, world!")
    conversation.delete(0)
    try:
        assert len(conversation.conversation_history) == 0
    except AssertionError as err:
        logger.error(f"test_delete_message failed: {err}")
        return False
    logger.success("test_delete_message passed")
    return True
|
|
|
|
|
|
def test_delete_message_out_of_bounds():
    """Check that deleting a nonexistent index raises ``IndexError``.

    Returns:
        bool: True if ``IndexError`` was raised, False if ``delete`` returned
        normally. Any other exception propagates to the caller.
    """
    logger.info("Running test_delete_message_out_of_bounds")
    conversation = Conversation()
    conversation.add("user", "Hello, world!")
    try:
        conversation.delete(1)
    except IndexError:
        logger.success("test_delete_message_out_of_bounds passed")
        return True
    # Reaching here means delete(1) did not complain about the bad index.
    logger.error(
        "test_delete_message_out_of_bounds failed: Expected IndexError"
    )
    return False
|
|
|
|
|
|
def test_update_message():
    """Check that ``update`` replaces the role and content of message 0.

    Returns:
        bool: True if every assertion holds, False on the first failure.
    """
    logger.info("Running test_update_message")
    conversation = Conversation()
    conversation.add("user", "Hello, world!")
    conversation.update(0, "assistant", "Hello, user!")
    history = conversation.conversation_history
    try:
        assert len(history) == 1
        updated = history[0]
        assert updated["role"] == "assistant"
        assert updated["content"] == "Hello, user!"
    except AssertionError as err:
        logger.error(f"test_update_message failed: {err}")
        return False
    logger.success("test_update_message passed")
    return True
|
|
|
|
|
|
def test_update_message_out_of_bounds():
    """Check that updating a nonexistent index raises ``IndexError``.

    Returns:
        bool: True if ``IndexError`` was raised, False if ``update`` returned
        normally. Any other exception propagates to the caller.
    """
    logger.info("Running test_update_message_out_of_bounds")
    conversation = Conversation()
    conversation.add("user", "Hello, world!")
    try:
        conversation.update(1, "assistant", "Hello, user!")
    except IndexError:
        logger.success("test_update_message_out_of_bounds passed")
        return True
    # Reaching here means update(1, ...) silently accepted the bad index.
    logger.error(
        "test_update_message_out_of_bounds failed: Expected IndexError"
    )
    return False
|
|
|
|
|
|
def test_return_history_as_string():
    """Check the exact string rendering of a two-message history.

    Returns:
        bool: True if the rendering matches the expected text, else False.
    """
    logger.info("Running test_return_history_as_string")
    conversation = Conversation()
    for role, text in (
        ("user", "Hello, world!"),
        ("assistant", "Hello, user!"),
    ):
        conversation.add(role, text)
    rendered = conversation.return_history_as_string()
    wanted = "user: Hello, world!\n\nassistant: Hello, user!\n\n"
    try:
        assert rendered == wanted
    except AssertionError as err:
        logger.error(
            f"test_return_history_as_string failed: {err}"
        )
        return False
    logger.success("test_return_history_as_string passed")
    return True
|
|
|
|
|
|
def test_search():
    """Check that ``search`` returns both messages containing "Hello", in order.

    Returns:
        bool: True if every assertion holds, False on the first failure.
    """
    logger.info("Running test_search")
    conversation = Conversation()
    conversation.add("user", "Hello, world!")
    conversation.add("assistant", "Hello, user!")
    hits = conversation.search("Hello")
    try:
        assert len(hits) == 2
        first, second = hits
        assert first["content"] == "Hello, world!"
        assert second["content"] == "Hello, user!"
    except AssertionError as err:
        logger.error(f"test_search failed: {err}")
        return False
    logger.success("test_search passed")
    return True
|
|
|
|
|
|
def test_conversation_cache_creation():
    """Check that adding a message creates ``cache_test.json`` in the cache dir.

    The temporary directory is always removed afterwards.

    Returns:
        bool: True if the cache file exists after ``add``, else False.
    """
    logger.info("Running test_conversation_cache_creation")
    workdir = setup_temp_conversations_dir()
    try:
        conversation = Conversation(
            name="cache_test", conversations_dir=str(workdir)
        )
        conversation.add("user", "Test message")
        created = (workdir / "cache_test.json").exists()
        if not created:
            logger.error(
                "test_conversation_cache_creation failed: Cache file not created"
            )
        else:
            logger.success("test_conversation_cache_creation passed")
        return created
    finally:
        shutil.rmtree(workdir)
|
|
|
|
|
|
def test_conversation_cache_loading():
    """Check that ``load_conversation`` restores a cached one-message history.

    One conversation writes a message under the name "load_test", and a
    second is loaded by that name from the same directory. The temporary
    directory is always removed afterwards.

    Returns:
        bool: True if the loaded history holds exactly the written message.
    """
    logger.info("Running test_conversation_cache_loading")
    workdir = setup_temp_conversations_dir()
    try:
        writer = Conversation(
            name="load_test", conversations_dir=str(workdir)
        )
        writer.add("user", "Test message")

        reader = Conversation.load_conversation(
            name="load_test", conversations_dir=str(workdir)
        )
        history = reader.conversation_history
        matches = (
            len(history) == 1
            and history[0]["content"] == "Test message"
        )
        if not matches:
            logger.error(
                "test_conversation_cache_loading failed: Loaded conversation mismatch"
            )
        else:
            logger.success("test_conversation_cache_loading passed")
        return matches
    finally:
        shutil.rmtree(workdir)
|
|
|
|
|
|
def test_add_multiple_messages():
    """Check that ``add_multiple_messages`` appends messages in role order.

    Returns:
        bool: True if every assertion holds, False on the first failure.
    """
    logger.info("Running test_add_multiple_messages")
    conversation = Conversation()
    roles = ["user", "assistant", "system"]
    contents = ["Hello", "Hi there", "System message"]
    conversation.add_multiple_messages(roles, contents)
    history = conversation.conversation_history
    try:
        assert len(history) == 3
        for position, expected_role in enumerate(
            ("user", "assistant", "system")
        ):
            assert history[position]["role"] == expected_role
    except AssertionError as err:
        logger.error(f"test_add_multiple_messages failed: {err}")
        return False
    logger.success("test_add_multiple_messages passed")
    return True
|
|
|
|
|
|
def test_query():
    """Check that ``query(0)`` returns the first message's role and content.

    Returns:
        bool: True if every assertion holds, False on the first failure.
    """
    logger.info("Running test_query")
    conversation = Conversation()
    conversation.add("user", "Test message")
    try:
        message = conversation.query(0)
        assert (
            message["role"] == "user"
            and message["content"] == "Test message"
        )
    except AssertionError as err:
        logger.error(f"test_query failed: {err}")
        return False
    logger.success("test_query passed")
    return True
|
|
|
|
|
|
def test_display_conversation():
    """Smoke-test ``display_conversation``: it must not raise.

    Returns:
        bool: True if display completed, False if it raised any ``Exception``.
    """
    logger.info("Running test_display_conversation")
    conversation = Conversation()
    for role, text in (("user", "Hello"), ("assistant", "Hi")):
        conversation.add(role, text)
    try:
        conversation.display_conversation()
        logger.success("test_display_conversation passed")
        return True
    except Exception as err:
        logger.error(f"test_display_conversation failed: {err}")
        return False
|
|
|
|
|
|
def test_count_messages_by_role():
    """Check that each of three distinct roles is counted exactly once.

    Returns:
        bool: True if every assertion holds, False on the first failure.
    """
    logger.info("Running test_count_messages_by_role")
    conversation = Conversation()
    conversation.add("user", "Hello")
    conversation.add("assistant", "Hi")
    conversation.add("system", "System message")
    try:
        counts = conversation.count_messages_by_role()
        for role in ("user", "assistant", "system"):
            assert counts[role] == 1
    except AssertionError as err:
        logger.error(f"test_count_messages_by_role failed: {err}")
        return False
    logger.success("test_count_messages_by_role passed")
    return True
|
|
|
|
|
|
def test_get_str():
    """Check that ``get_str`` output contains the "role: content" line.

    Returns:
        bool: True if the expected substring is present, else False.
    """
    logger.info("Running test_get_str")
    conversation = Conversation()
    conversation.add("user", "Hello")
    try:
        rendered = conversation.get_str()
        assert "user: Hello" in rendered
    except AssertionError as err:
        logger.error(f"test_get_str failed: {err}")
        return False
    logger.success("test_get_str passed")
    return True
|
|
|
|
|
|
def test_to_json():
    """Check that ``to_json`` returns a string that includes the message text.

    Returns:
        bool: True if every assertion holds, False otherwise.
    """
    logger.info("Running test_to_json")
    conversation = Conversation()
    conversation.add("user", "Hello")
    try:
        payload = conversation.to_json()
        assert isinstance(payload, str) and "Hello" in payload
    except AssertionError as err:
        logger.error(f"test_to_json failed: {err}")
        return False
    logger.success("test_to_json passed")
    return True
|
|
|
|
|
|
def test_to_dict():
    """Check that ``to_dict`` returns a list whose first entry has the content.

    Returns:
        bool: True if every assertion holds, False otherwise.
    """
    logger.info("Running test_to_dict")
    conversation = Conversation()
    conversation.add("user", "Hello")
    try:
        as_dict = conversation.to_dict()
        # isinstance is checked first so indexing only happens on a list.
        assert isinstance(as_dict, list) and as_dict[0]["content"] == "Hello"
    except AssertionError as err:
        logger.error(f"test_to_dict failed: {err}")
        return False
    logger.success("test_to_dict passed")
    return True
|
|
|
|
|
|
def test_to_yaml():
    """Check that ``to_yaml`` returns a string that includes the message text.

    Returns:
        bool: True if every assertion holds, False otherwise.
    """
    logger.info("Running test_to_yaml")
    conversation = Conversation()
    conversation.add("user", "Hello")
    try:
        document = conversation.to_yaml()
        assert isinstance(document, str) and "Hello" in document
    except AssertionError as err:
        logger.error(f"test_to_yaml failed: {err}")
        return False
    logger.success("test_to_yaml passed")
    return True
|
|
|
|
|
|
def test_get_last_message_as_string():
    """Check that the last message renders as "assistant: Last".

    Returns:
        bool: True if the rendering matches, False otherwise.
    """
    logger.info("Running test_get_last_message_as_string")
    conversation = Conversation()
    for role, text in (("user", "First"), ("assistant", "Last")):
        conversation.add(role, text)
    try:
        last = conversation.get_last_message_as_string()
        assert last == "assistant: Last"
    except AssertionError as err:
        logger.error(
            f"test_get_last_message_as_string failed: {err}"
        )
        return False
    logger.success("test_get_last_message_as_string passed")
    return True
|
|
|
|
|
|
def test_return_messages_as_list():
    """Check that ``return_messages_as_list`` yields "role: content" strings.

    Returns:
        bool: True if every assertion holds, False on the first failure.
    """
    logger.info("Running test_return_messages_as_list")
    conversation = Conversation()
    conversation.add("user", "Hello")
    conversation.add("assistant", "Hi")
    try:
        lines = conversation.return_messages_as_list()
        assert len(lines) == 2
        assert lines[0] == "user: Hello" and lines[1] == "assistant: Hi"
    except AssertionError as err:
        logger.error(f"test_return_messages_as_list failed: {err}")
        return False
    logger.success("test_return_messages_as_list passed")
    return True
|
|
|
|
|
|
def test_return_messages_as_dictionary():
    """Check that ``return_messages_as_dictionary`` exposes role and content.

    Returns:
        bool: True if every assertion holds, False on the first failure.
    """
    logger.info("Running test_return_messages_as_dictionary")
    conversation = Conversation()
    conversation.add("user", "Hello")
    try:
        entries = conversation.return_messages_as_dictionary()
        assert len(entries) == 1
        only = entries[0]
        assert only["role"] == "user" and only["content"] == "Hello"
    except AssertionError as err:
        logger.error(
            f"test_return_messages_as_dictionary failed: {err}"
        )
        return False
    logger.success("test_return_messages_as_dictionary passed")
    return True
|
|
|
|
|
|
def test_add_tool_output_to_agent():
    """Check that a tool output dict is stored as one "tool" message.

    Returns:
        bool: True if every assertion holds, False on the first failure.
    """
    logger.info("Running test_add_tool_output_to_agent")
    conversation = Conversation()
    payload = {"name": "test_tool", "output": "test result"}
    try:
        conversation.add_tool_output_to_agent("tool", payload)
        history = conversation.conversation_history
        assert len(history) == 1
        entry = history[0]
        assert entry["role"] == "tool" and entry["content"] == payload
    except AssertionError as err:
        logger.error(
            f"test_add_tool_output_to_agent failed: {err}"
        )
        return False
    logger.success("test_add_tool_output_to_agent passed")
    return True
|
|
|
|
|
|
def test_get_final_message():
    """Check that ``get_final_message`` renders the last message as "role: content".

    Returns:
        bool: True if the rendering matches, False otherwise.
    """
    logger.info("Running test_get_final_message")
    conversation = Conversation()
    for role, text in (("user", "First"), ("assistant", "Last")):
        conversation.add(role, text)
    try:
        final = conversation.get_final_message()
        assert final == "assistant: Last"
    except AssertionError as err:
        logger.error(f"test_get_final_message failed: {err}")
        return False
    logger.success("test_get_final_message passed")
    return True
|
|
|
|
|
|
def test_get_final_message_content():
    """Check that ``get_final_message_content`` returns only the last content.

    Returns:
        bool: True if the content matches, False otherwise.
    """
    logger.info("Running test_get_final_message_content")
    conversation = Conversation()
    for role, text in (("user", "First"), ("assistant", "Last")):
        conversation.add(role, text)
    try:
        content = conversation.get_final_message_content()
        assert content == "Last"
    except AssertionError as err:
        logger.error(
            f"test_get_final_message_content failed: {err}"
        )
        return False
    logger.success("test_get_final_message_content passed")
    return True
|
|
|
|
|
|
def test_return_all_except_first():
    """Check that ``return_all_except_first`` drops only the leading message.

    Returns:
        bool: True if every assertion holds, False on the first failure.
    """
    logger.info("Running test_return_all_except_first")
    conversation = Conversation()
    conversation.add("system", "System")
    conversation.add("user", "Hello")
    conversation.add("assistant", "Hi")
    try:
        remainder = conversation.return_all_except_first()
        assert len(remainder) == 2
        for message, expected_role in zip(remainder, ("user", "assistant")):
            assert message["role"] == expected_role
    except AssertionError as err:
        logger.error(f"test_return_all_except_first failed: {err}")
        return False
    logger.success("test_return_all_except_first passed")
    return True
|
|
|
|
|
|
def test_return_all_except_first_string():
    """Check the string form of all-but-first omits the leading system message.

    Returns:
        bool: True if every assertion holds, False on the first failure.
    """
    logger.info("Running test_return_all_except_first_string")
    conversation = Conversation()
    conversation.add("system", "System")
    conversation.add("user", "Hello")
    conversation.add("assistant", "Hi")
    try:
        text = conversation.return_all_except_first_string()
        assert "Hello" in text and "Hi" in text
        assert "System" not in text
    except AssertionError as err:
        logger.error(
            f"test_return_all_except_first_string failed: {err}"
        )
        return False
    logger.success("test_return_all_except_first_string passed")
    return True
|
|
|
|
|
|
def test_batch_add():
    """Check that ``batch_add`` appends every message dict in order.

    Returns:
        bool: True if every assertion holds, False on the first failure.
    """
    logger.info("Running test_batch_add")
    conversation = Conversation()
    batch = [
        {"role": "user", "content": "Hello"},
        {"role": "assistant", "content": "Hi"},
    ]
    try:
        conversation.batch_add(batch)
        history = conversation.conversation_history
        assert len(history) == 2
        for entry, expected_role in zip(history, ("user", "assistant")):
            assert entry["role"] == expected_role
    except AssertionError as err:
        logger.error(f"test_batch_add failed: {err}")
        return False
    logger.success("test_batch_add passed")
    return True
|
|
|
|
|
|
def test_get_cache_stats():
    """Check that ``get_cache_stats`` reports all expected statistic keys.

    Returns:
        bool: True if every key is present, False on the first missing one.
    """
    logger.info("Running test_get_cache_stats")
    conversation = Conversation(cache_enabled=True)
    conversation.add("user", "Hello")
    try:
        stats = conversation.get_cache_stats()
        for key in (
            "hits",
            "misses",
            "cached_tokens",
            "total_tokens",
            "hit_rate",
        ):
            assert key in stats
    except AssertionError as err:
        logger.error(f"test_get_cache_stats failed: {err}")
        return False
    logger.success("test_get_cache_stats passed")
    return True
|
|
|
|
|
|
def test_list_cached_conversations():
    """Check that a saved conversation name appears in the cached listing.

    The temporary directory is always removed afterwards.

    Returns:
        bool: True if "test_list" is listed, False otherwise.
    """
    logger.info("Running test_list_cached_conversations")
    workdir = setup_temp_conversations_dir()
    try:
        conversation = Conversation(
            name="test_list", conversations_dir=str(workdir)
        )
        conversation.add("user", "Test message")

        cached = Conversation.list_cached_conversations(str(workdir))
        try:
            assert "test_list" in cached
        except AssertionError as err:
            logger.error(
                f"test_list_cached_conversations failed: {err}"
            )
            return False
        logger.success("test_list_cached_conversations passed")
        return True
    finally:
        shutil.rmtree(workdir)
|
|
|
|
|
|
def test_clear():
    """Check that ``clear`` empties a populated conversation history.

    Returns:
        bool: True if the history is empty after clearing, else False.
    """
    logger.info("Running test_clear")
    conversation = Conversation()
    for role, text in (("user", "Hello"), ("assistant", "Hi")):
        conversation.add(role, text)
    try:
        conversation.clear()
        # Re-read the attribute after clear() in case it rebinds the list.
        assert len(conversation.conversation_history) == 0
    except AssertionError as err:
        logger.error(f"test_clear failed: {err}")
        return False
    logger.success("test_clear passed")
    return True
|
|
|
|
|
|
def test_save_and_load_json():
    """Check a JSON save/load round trip through ``save_as_json``/``load_from_json``.

    The temporary directory is always removed afterwards.

    Returns:
        bool: True if the reloaded history matches the saved one, else False.
    """
    logger.info("Running test_save_and_load_json")
    workdir = setup_temp_conversations_dir()
    target = str(workdir / "test_save.json")

    try:
        source = Conversation()
        source.add("user", "Hello")
        source.save_as_json(target)

        restored = Conversation()
        restored.load_from_json(target)
        history = restored.conversation_history

        try:
            assert len(history) == 1
            assert history[0]["content"] == "Hello"
        except AssertionError as err:
            logger.error(f"test_save_and_load_json failed: {err}")
            return False
        logger.success("test_save_and_load_json passed")
        return True
    finally:
        shutil.rmtree(workdir)
|
|
|
|
|
|
def run_all_tests():
    """Run every test function in this module and collect per-test results.

    Each test function returns a truthy value on success. An exception raised
    by a test is caught and recorded, so one broken test does not abort the
    suite.

    Returns:
        list[dict]: One dict per test, in execution order, with keys
        ``name`` (function name), ``result`` ("PASS", "FAIL" or "ERROR"),
        ``duration`` (elapsed seconds as a float) and, for "ERROR" entries
        only, ``error`` (the exception message).
    """
    # perf_counter is a monotonic high-resolution clock; datetime.now() is
    # wall-clock time and can jump (NTP sync, DST on naive local time),
    # producing skewed or even negative durations.
    import time

    logger.info("Starting test suite execution")
    test_results = []
    test_functions = [
        test_add_message,
        test_add_message_with_time,
        test_delete_message,
        test_delete_message_out_of_bounds,
        test_update_message,
        test_update_message_out_of_bounds,
        test_return_history_as_string,
        test_search,
        test_conversation_cache_creation,
        test_conversation_cache_loading,
        test_add_multiple_messages,
        test_query,
        test_display_conversation,
        test_count_messages_by_role,
        test_get_str,
        test_to_json,
        test_to_dict,
        test_to_yaml,
        test_get_last_message_as_string,
        test_return_messages_as_list,
        test_return_messages_as_dictionary,
        test_add_tool_output_to_agent,
        test_get_final_message,
        test_get_final_message_content,
        test_return_all_except_first,
        test_return_all_except_first_string,
        test_batch_add,
        test_get_cache_stats,
        test_list_cached_conversations,
        test_clear,
        test_save_and_load_json,
    ]

    for test_func in test_functions:
        start = time.perf_counter()
        try:
            result = test_func()
        except Exception as e:
            duration = time.perf_counter() - start
            test_results.append(
                {
                    "name": test_func.__name__,
                    "result": "ERROR",
                    "error": str(e),
                    "duration": duration,
                }
            )
            logger.error(
                f"Test {test_func.__name__} failed with error: {str(e)}"
            )
        else:
            duration = time.perf_counter() - start
            test_results.append(
                {
                    "name": test_func.__name__,
                    "result": "PASS" if result else "FAIL",
                    "duration": duration,
                }
            )

    return test_results
|
|
|
|
|
|
def generate_markdown_report(results):
    """Render test results as a Markdown report.

    Args:
        results: Sequence of result dicts as produced by ``run_all_tests``,
            each with ``name``, ``result`` ("PASS"/"FAIL"/"ERROR"),
            ``duration`` (seconds) and an optional ``error`` message.

    Returns:
        str: Markdown text with a run timestamp, a summary of pass/fail/error
        counts, and a per-test results table.
    """
    logger.info("Generating test report")

    def _cell(value):
        """Make ``value`` safe to embed in a single Markdown table cell."""
        # A raw '|' starts a new column and a newline ends the row; exception
        # messages frequently contain both, which previously broke the table.
        return " ".join(str(value).splitlines()).replace("|", "\\|")

    # Summary
    total_tests = len(results)
    passed_tests = sum(1 for r in results if r["result"] == "PASS")
    failed_tests = sum(1 for r in results if r["result"] == "FAIL")
    error_tests = sum(1 for r in results if r["result"] == "ERROR")

    logger.info(f"Total Tests: {total_tests}")
    logger.info(f"Passed: {passed_tests}")
    logger.info(f"Failed: {failed_tests}")
    logger.info(f"Errors: {error_tests}")

    parts = [
        "# Test Results Report\n\n",
        f"Test Run Date: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n",
        "## Summary\n\n",
        f"- Total Tests: {total_tests}\n",
        f"- Passed: {passed_tests}\n",
        f"- Failed: {failed_tests}\n",
        f"- Errors: {error_tests}\n\n",
        # Detailed Results
        "## Detailed Results\n\n",
        "| Test Name | Result | Duration (s) | Error |\n",
        "|-----------|---------|--------------|-------|\n",
    ]

    for result in results:
        name = _cell(result["name"])
        test_result = _cell(result["result"])
        duration = f"{result['duration']:.4f}"
        error = _cell(result.get("error", ""))
        parts.append(f"| {name} | {test_result} | {duration} | {error} |\n")

    return "".join(parts)
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: run the suite, render the Markdown report and
    # write it to test_results.md in the current working directory.
    logger.info("Starting test execution")
    results = run_all_tests()
    report = generate_markdown_report(results)

    # Save report to file. The report embeds exception messages that may
    # contain non-ASCII text, so do not rely on the platform's default
    # encoding (e.g. cp1252 on Windows), which could fail to encode it.
    with open("test_results.md", "w", encoding="utf-8") as f:
        f.write(report)

    logger.success(
        "Test execution completed. Results saved to test_results.md"
    )
|