You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
swarms/tests/communication/test_conversation.py

698 lines
21 KiB

import shutil
import tempfile
import time
from datetime import datetime
from pathlib import Path

from loguru import logger

from swarms.structs.conversation import Conversation
def setup_temp_conversations_dir():
    """Create a fresh, uniquely named temporary directory for conversation cache files.

    The directory is created with ``tempfile.mkdtemp`` so concurrent test
    runs cannot collide, and no pre-existing directory in the current
    working directory is ever deleted (the previous implementation
    unconditionally ``rmtree``-ed ``./temp_test_conversations``).
    Callers are responsible for removing the directory; the tests in this
    module do so with ``shutil.rmtree`` in a ``finally`` block.

    Returns:
        Path: The newly created, empty directory.
    """
    temp_dir = Path(tempfile.mkdtemp(prefix="temp_test_conversations_"))
    logger.info(f"Created temporary test directory: {temp_dir}")
    return temp_dir
def create_test_conversation(temp_dir):
    """Build a small two-message conversation persisted under ``temp_dir``.

    Args:
        temp_dir: Directory (Path or str) handed to ``Conversation`` as its
            ``conversations_dir``.

    Returns:
        Conversation: A conversation named "test_conversation" holding a
        user greeting followed by an assistant reply.
    """
    seed_messages = (
        ("user", "Hello, world!"),
        ("assistant", "Hello, user!"),
    )
    conversation = Conversation(
        name="test_conversation",
        conversations_dir=str(temp_dir),
    )
    for role, text in seed_messages:
        conversation.add(role, text)
    logger.info("Created test conversation with basic messages")
    return conversation
def test_add_message():
    """Add one user message and verify its role and content are stored.

    Returns:
        bool: True when every check holds, False when an assertion fails
        (the failure is logged instead of raised).
    """
    logger.info("Running test_add_message")
    conversation = Conversation()
    conversation.add("user", "Hello, world!")
    try:
        history = conversation.conversation_history
        assert len(history) == 1
        first = history[0]
        assert first["role"] == "user"
        assert first["content"] == "Hello, world!"
    except AssertionError as err:
        logger.error(f"test_add_message failed: {err}")
        return False
    else:
        logger.success("test_add_message passed")
        return True
def test_add_message_with_time():
    """Check that a message added with time tracking enabled carries a timestamp.

    Returns:
        bool: True when every check holds, False when an assertion fails
        (the failure is logged instead of raised).
    """
    logger.info("Running test_add_message_with_time")
    # The test asserts a "timestamp" key below, so time tracking must be
    # switched on. The original setup passed time_enabled=False, which
    # contradicted both the test's name and its own assertion.
    # NOTE(review): assumes time_enabled is the flag that makes
    # Conversation.add record a timestamp — confirm against Conversation.
    conv = Conversation(time_enabled=True)
    conv.add("user", "Hello, world!")
    try:
        assert len(conv.conversation_history) == 1
        assert conv.conversation_history[0]["role"] == "user"
        assert (
            conv.conversation_history[0]["content"] == "Hello, world!"
        )
        assert "timestamp" in conv.conversation_history[0]
        logger.success("test_add_message_with_time passed")
        return True
    except AssertionError as e:
        logger.error(f"test_add_message_with_time failed: {str(e)}")
        return False
def test_delete_message():
    """Delete the only message and check that the history becomes empty.

    Returns:
        bool: True on success, False if the assertion fails.
    """
    logger.info("Running test_delete_message")
    conversation = Conversation()
    conversation.add("user", "Hello, world!")
    conversation.delete(0)
    try:
        assert len(conversation.conversation_history) == 0
    except AssertionError as err:
        logger.error(f"test_delete_message failed: {err}")
        return False
    logger.success("test_delete_message passed")
    return True
def test_delete_message_out_of_bounds():
    """Deleting index 1 of a one-message conversation must raise IndexError.

    Returns:
        bool: True if IndexError was raised, False if the call succeeded.
        Any other exception propagates to the caller.
    """
    logger.info("Running test_delete_message_out_of_bounds")
    conversation = Conversation()
    conversation.add("user", "Hello, world!")
    try:
        conversation.delete(1)
    except IndexError:
        logger.success("test_delete_message_out_of_bounds passed")
        return True
    logger.error(
        "test_delete_message_out_of_bounds failed: Expected IndexError"
    )
    return False
def test_update_message():
    """Overwrite message 0 and verify that both its role and content change.

    Returns:
        bool: True on success, False if an assertion fails.
    """
    logger.info("Running test_update_message")
    conversation = Conversation()
    conversation.add("user", "Hello, world!")
    conversation.update(0, "assistant", "Hello, user!")
    try:
        history = conversation.conversation_history
        assert len(history) == 1
        assert history[0]["role"] == "assistant"
        assert history[0]["content"] == "Hello, user!"
    except AssertionError as err:
        logger.error(f"test_update_message failed: {err}")
        return False
    else:
        logger.success("test_update_message passed")
        return True
def test_update_message_out_of_bounds():
    """Updating index 1 of a one-message conversation must raise IndexError.

    Returns:
        bool: True if IndexError was raised, False if the update succeeded.
        Any other exception propagates to the caller.
    """
    logger.info("Running test_update_message_out_of_bounds")
    conversation = Conversation()
    conversation.add("user", "Hello, world!")
    try:
        conversation.update(1, "assistant", "Hello, user!")
    except IndexError:
        logger.success("test_update_message_out_of_bounds passed")
        return True
    logger.error(
        "test_update_message_out_of_bounds failed: Expected IndexError"
    )
    return False
def test_return_history_as_string():
    """Compare the string rendering of a two-message history to an exact literal.

    Returns:
        bool: True if the rendering matches, False otherwise.
    """
    logger.info("Running test_return_history_as_string")
    conversation = Conversation()
    for role, text in (
        ("user", "Hello, world!"),
        ("assistant", "Hello, user!"),
    ):
        conversation.add(role, text)
    rendered = conversation.return_history_as_string()
    wanted = "user: Hello, world!\n\nassistant: Hello, user!\n\n"
    try:
        assert rendered == wanted
    except AssertionError as err:
        logger.error(
            f"test_return_history_as_string failed: {err}"
        )
        return False
    logger.success("test_return_history_as_string passed")
    return True
def test_search():
    """Search for "Hello" and expect both messages back, in insertion order.

    Returns:
        bool: True on success, False if an assertion fails.
    """
    logger.info("Running test_search")
    conversation = Conversation()
    conversation.add("user", "Hello, world!")
    conversation.add("assistant", "Hello, user!")
    matches = conversation.search("Hello")
    try:
        assert len(matches) == 2
        assert matches[0]["content"] == "Hello, world!"
        assert matches[1]["content"] == "Hello, user!"
    except AssertionError as err:
        logger.error(f"test_search failed: {err}")
        return False
    else:
        logger.success("test_search passed")
        return True
def test_conversation_cache_creation():
    """Adding a message to a named conversation should create ``<name>.json``.

    The temporary directory is always removed afterwards.

    Returns:
        bool: True if ``cache_test.json`` exists after the add, else False.
    """
    logger.info("Running test_conversation_cache_creation")
    temp_dir = setup_temp_conversations_dir()
    try:
        conversation = Conversation(
            name="cache_test", conversations_dir=str(temp_dir)
        )
        conversation.add("user", "Test message")
        file_created = (temp_dir / "cache_test.json").exists()
        if not file_created:
            logger.error(
                "test_conversation_cache_creation failed: Cache file not created"
            )
        else:
            logger.success("test_conversation_cache_creation passed")
        return file_created
    finally:
        shutil.rmtree(temp_dir)
def test_conversation_cache_loading():
    """Persist a conversation, reload it by name, and compare the contents.

    The temporary directory is always removed afterwards.

    Returns:
        bool: True if the reloaded conversation holds exactly the one saved
        message, else False.
    """
    logger.info("Running test_conversation_cache_loading")
    temp_dir = setup_temp_conversations_dir()
    try:
        writer = Conversation(
            name="load_test", conversations_dir=str(temp_dir)
        )
        writer.add("user", "Test message")
        reader = Conversation.load_conversation(
            name="load_test", conversations_dir=str(temp_dir)
        )
        loaded = reader.conversation_history
        matches = len(loaded) == 1 and loaded[0]["content"] == "Test message"
        if not matches:
            logger.error(
                "test_conversation_cache_loading failed: Loaded conversation mismatch"
            )
        else:
            logger.success("test_conversation_cache_loading passed")
        return matches
    finally:
        shutil.rmtree(temp_dir)
def test_add_multiple_messages():
    """Bulk-add three messages and verify the stored roles keep their order.

    Returns:
        bool: True on success, False if an assertion fails.
    """
    logger.info("Running test_add_multiple_messages")
    conversation = Conversation()
    conversation.add_multiple_messages(
        ["user", "assistant", "system"],
        ["Hello", "Hi there", "System message"],
    )
    expected_roles = ("user", "assistant", "system")
    try:
        history = conversation.conversation_history
        assert len(history) == 3
        for position, role in enumerate(expected_roles):
            assert history[position]["role"] == role
    except AssertionError as err:
        logger.error(f"test_add_multiple_messages failed: {err}")
        return False
    logger.success("test_add_multiple_messages passed")
    return True
def test_query():
    """Query message 0 and check its role and content.

    Returns:
        bool: True on success, False if an assertion fails.
    """
    logger.info("Running test_query")
    conversation = Conversation()
    conversation.add("user", "Test message")
    try:
        message = conversation.query(0)
        assert message["role"] == "user"
        assert message["content"] == "Test message"
    except AssertionError as err:
        logger.error(f"test_query failed: {err}")
        return False
    else:
        logger.success("test_query passed")
        return True
def test_display_conversation():
    """Smoke test: displaying a two-message conversation must not raise.

    Returns:
        bool: True if display completed, False if any exception was raised
        (the exception message is logged).
    """
    logger.info("Running test_display_conversation")
    conversation = Conversation()
    for role, text in (("user", "Hello"), ("assistant", "Hi")):
        conversation.add(role, text)
    try:
        conversation.display_conversation()
    except Exception as err:
        logger.error(f"test_display_conversation failed: {err}")
        return False
    logger.success("test_display_conversation passed")
    return True
def test_count_messages_by_role():
    """Add one message per role and expect a count of 1 for each role.

    Returns:
        bool: True on success, False if an assertion fails.
    """
    logger.info("Running test_count_messages_by_role")
    conversation = Conversation()
    seeded = (
        ("user", "Hello"),
        ("assistant", "Hi"),
        ("system", "System message"),
    )
    for role, text in seeded:
        conversation.add(role, text)
    try:
        per_role = conversation.count_messages_by_role()
        for role, _ in seeded:
            assert per_role[role] == 1
    except AssertionError as err:
        logger.error(f"test_count_messages_by_role failed: {err}")
        return False
    logger.success("test_count_messages_by_role passed")
    return True
def test_get_str():
    """The string form of the conversation should contain "user: Hello".

    Returns:
        bool: True on success, False if the assertion fails.
    """
    logger.info("Running test_get_str")
    conversation = Conversation()
    conversation.add("user", "Hello")
    try:
        text = conversation.get_str()
        assert "user: Hello" in text
    except AssertionError as err:
        logger.error(f"test_get_str failed: {err}")
        return False
    else:
        logger.success("test_get_str passed")
        return True
def test_to_json():
    """JSON export should be a string that contains the message text.

    Returns:
        bool: True on success, False if an assertion fails.
    """
    logger.info("Running test_to_json")
    conversation = Conversation()
    conversation.add("user", "Hello")
    try:
        serialized = conversation.to_json()
        assert isinstance(serialized, str) and "Hello" in serialized
    except AssertionError as err:
        logger.error(f"test_to_json failed: {err}")
        return False
    logger.success("test_to_json passed")
    return True
def test_to_dict():
    """Dict export should be a list whose first entry holds the message text.

    Returns:
        bool: True on success, False if an assertion fails.
    """
    logger.info("Running test_to_dict")
    conversation = Conversation()
    conversation.add("user", "Hello")
    try:
        exported = conversation.to_dict()
        assert isinstance(exported, list)
        assert exported[0]["content"] == "Hello"
    except AssertionError as err:
        logger.error(f"test_to_dict failed: {err}")
        return False
    else:
        logger.success("test_to_dict passed")
        return True
def test_to_yaml():
    """YAML export should be a string that contains the message text.

    Returns:
        bool: True on success, False if an assertion fails.
    """
    logger.info("Running test_to_yaml")
    conversation = Conversation()
    conversation.add("user", "Hello")
    try:
        serialized = conversation.to_yaml()
        assert isinstance(serialized, str) and "Hello" in serialized
    except AssertionError as err:
        logger.error(f"test_to_yaml failed: {err}")
        return False
    logger.success("test_to_yaml passed")
    return True
def test_get_last_message_as_string():
    """The last message should render as "assistant: Last".

    Returns:
        bool: True on success, False if the assertion fails.
    """
    logger.info("Running test_get_last_message_as_string")
    conversation = Conversation()
    for role, text in (("user", "First"), ("assistant", "Last")):
        conversation.add(role, text)
    try:
        last = conversation.get_last_message_as_string()
        assert last == "assistant: Last"
    except AssertionError as err:
        logger.error(
            f"test_get_last_message_as_string failed: {err}"
        )
        return False
    logger.success("test_get_last_message_as_string passed")
    return True
def test_return_messages_as_list():
    """Each message should be rendered as a "role: content" list item, in order.

    Returns:
        bool: True on success, False if an assertion fails.
    """
    logger.info("Running test_return_messages_as_list")
    conversation = Conversation()
    conversation.add("user", "Hello")
    conversation.add("assistant", "Hi")
    try:
        rendered = conversation.return_messages_as_list()
        assert len(rendered) == 2
        for got, want in zip(rendered, ("user: Hello", "assistant: Hi")):
            assert got == want
    except AssertionError as err:
        logger.error(f"test_return_messages_as_list failed: {err}")
        return False
    else:
        logger.success("test_return_messages_as_list passed")
        return True
def test_return_messages_as_dictionary():
    """A single message should come back as one dict with role and content.

    Returns:
        bool: True on success, False if an assertion fails.
    """
    logger.info("Running test_return_messages_as_dictionary")
    conversation = Conversation()
    conversation.add("user", "Hello")
    try:
        as_dicts = conversation.return_messages_as_dictionary()
        assert len(as_dicts) == 1
        only = as_dicts[0]
        assert only["role"] == "user"
        assert only["content"] == "Hello"
    except AssertionError as err:
        logger.error(
            f"test_return_messages_as_dictionary failed: {err}"
        )
        return False
    logger.success("test_return_messages_as_dictionary passed")
    return True
def test_add_tool_output_to_agent():
    """A tool output dict should be stored verbatim under the "tool" role.

    Returns:
        bool: True on success, False if an assertion fails.
    """
    logger.info("Running test_add_tool_output_to_agent")
    conversation = Conversation()
    payload = {"name": "test_tool", "output": "test result"}
    try:
        conversation.add_tool_output_to_agent("tool", payload)
        history = conversation.conversation_history
        assert len(history) == 1
        assert history[0]["role"] == "tool"
        assert history[0]["content"] == payload
    except AssertionError as err:
        logger.error(
            f"test_add_tool_output_to_agent failed: {err}"
        )
        return False
    else:
        logger.success("test_add_tool_output_to_agent passed")
        return True
def test_get_final_message():
    """The final message should render as "assistant: Last".

    Returns:
        bool: True on success, False if the assertion fails.
    """
    logger.info("Running test_get_final_message")
    conversation = Conversation()
    for role, text in (("user", "First"), ("assistant", "Last")):
        conversation.add(role, text)
    try:
        final = conversation.get_final_message()
        assert final == "assistant: Last"
    except AssertionError as err:
        logger.error(f"test_get_final_message failed: {err}")
        return False
    logger.success("test_get_final_message passed")
    return True
def test_get_final_message_content():
    """Only the content of the final message ("Last") should be returned.

    Returns:
        bool: True on success, False if the assertion fails.
    """
    logger.info("Running test_get_final_message_content")
    conversation = Conversation()
    for role, text in (("user", "First"), ("assistant", "Last")):
        conversation.add(role, text)
    try:
        content = conversation.get_final_message_content()
        assert content == "Last"
    except AssertionError as err:
        logger.error(
            f"test_get_final_message_content failed: {err}"
        )
        return False
    else:
        logger.success("test_get_final_message_content passed")
        return True
def test_return_all_except_first():
    """Dropping the leading system message should leave user then assistant.

    Returns:
        bool: True on success, False if an assertion fails.
    """
    logger.info("Running test_return_all_except_first")
    conversation = Conversation()
    for role, text in (
        ("system", "System"),
        ("user", "Hello"),
        ("assistant", "Hi"),
    ):
        conversation.add(role, text)
    try:
        remainder = conversation.return_all_except_first()
        assert len(remainder) == 2
        assert remainder[0]["role"] == "user"
        assert remainder[1]["role"] == "assistant"
    except AssertionError as err:
        logger.error(f"test_return_all_except_first failed: {err}")
        return False
    logger.success("test_return_all_except_first passed")
    return True
def test_return_all_except_first_string():
    """The string form without the first message must omit the system text.

    Returns:
        bool: True on success, False if an assertion fails.
    """
    logger.info("Running test_return_all_except_first_string")
    conversation = Conversation()
    for role, text in (
        ("system", "System"),
        ("user", "Hello"),
        ("assistant", "Hi"),
    ):
        conversation.add(role, text)
    try:
        tail_text = conversation.return_all_except_first_string()
        assert "Hello" in tail_text
        assert "Hi" in tail_text
        assert "System" not in tail_text
    except AssertionError as err:
        logger.error(
            f"test_return_all_except_first_string failed: {err}"
        )
        return False
    else:
        logger.success("test_return_all_except_first_string passed")
        return True
def test_batch_add():
    """Batch-add two message dicts and verify count and role order.

    Returns:
        bool: True on success, False if an assertion fails.
    """
    logger.info("Running test_batch_add")
    conversation = Conversation()
    batch = [
        {"role": "user", "content": "Hello"},
        {"role": "assistant", "content": "Hi"},
    ]
    try:
        conversation.batch_add(batch)
        history = conversation.conversation_history
        assert len(history) == 2
        assert history[0]["role"] == "user"
        assert history[1]["role"] == "assistant"
    except AssertionError as err:
        logger.error(f"test_batch_add failed: {err}")
        return False
    logger.success("test_batch_add passed")
    return True
def test_get_cache_stats():
    """With caching enabled, the stats mapping must expose all expected keys.

    Returns:
        bool: True on success, False if any key is missing.
    """
    logger.info("Running test_get_cache_stats")
    conversation = Conversation(cache_enabled=True)
    conversation.add("user", "Hello")
    required_keys = (
        "hits",
        "misses",
        "cached_tokens",
        "total_tokens",
        "hit_rate",
    )
    try:
        stats = conversation.get_cache_stats()
        for key in required_keys:
            assert key in stats
    except AssertionError as err:
        logger.error(f"test_get_cache_stats failed: {err}")
        return False
    else:
        logger.success("test_get_cache_stats passed")
        return True
def test_list_cached_conversations():
    """A persisted conversation should appear in the cached-conversation listing.

    The temporary directory is always removed afterwards.

    Returns:
        bool: True if "test_list" is listed, False otherwise.
    """
    logger.info("Running test_list_cached_conversations")
    temp_dir = setup_temp_conversations_dir()
    try:
        conversation = Conversation(
            name="test_list", conversations_dir=str(temp_dir)
        )
        conversation.add("user", "Test message")
        listed = Conversation.list_cached_conversations(str(temp_dir))
        try:
            assert "test_list" in listed
        except AssertionError as err:
            logger.error(
                f"test_list_cached_conversations failed: {err}"
            )
            return False
        logger.success("test_list_cached_conversations passed")
        return True
    finally:
        shutil.rmtree(temp_dir)
def test_clear():
    """Clearing a populated conversation should leave an empty history.

    Returns:
        bool: True on success, False if the assertion fails.
    """
    logger.info("Running test_clear")
    conversation = Conversation()
    for role, text in (("user", "Hello"), ("assistant", "Hi")):
        conversation.add(role, text)
    try:
        conversation.clear()
        assert len(conversation.conversation_history) == 0
    except AssertionError as err:
        logger.error(f"test_clear failed: {err}")
        return False
    else:
        logger.success("test_clear passed")
        return True
def test_save_and_load_json():
    """Round-trip a conversation through a JSON file and compare contents.

    The temporary directory is always removed afterwards.

    Returns:
        bool: True if the reloaded conversation holds the one saved message,
        False if an assertion fails.
    """
    logger.info("Running test_save_and_load_json")
    temp_dir = setup_temp_conversations_dir()
    json_path = str(temp_dir / "test_save.json")
    try:
        source = Conversation()
        source.add("user", "Hello")
        source.save_as_json(json_path)
        restored = Conversation()
        restored.load_from_json(json_path)
        try:
            assert len(restored.conversation_history) == 1
            assert restored.conversation_history[0]["content"] == "Hello"
        except AssertionError as err:
            logger.error(f"test_save_and_load_json failed: {err}")
            return False
        logger.success("test_save_and_load_json passed")
        return True
    finally:
        shutil.rmtree(temp_dir)
def run_all_tests():
    """Run every test function in this module and collect the outcomes.

    The tests report assertion failures by returning False rather than
    raising, so a truthy return maps to "PASS" and a falsy one to "FAIL".
    Any exception that escapes a test is recorded as "ERROR" with its
    message, and the run continues with the next test.

    Returns:
        list[dict]: One entry per test, in execution order, with keys
        "name", "result", "duration" (seconds, float) and, for "ERROR"
        entries only, "error".
    """
    logger.info("Starting test suite execution")
    test_results = []
    test_functions = [
        test_add_message,
        test_add_message_with_time,
        test_delete_message,
        test_delete_message_out_of_bounds,
        test_update_message,
        test_update_message_out_of_bounds,
        test_return_history_as_string,
        test_search,
        test_conversation_cache_creation,
        test_conversation_cache_loading,
        test_add_multiple_messages,
        test_query,
        test_display_conversation,
        test_count_messages_by_role,
        test_get_str,
        test_to_json,
        test_to_dict,
        test_to_yaml,
        test_get_last_message_as_string,
        test_return_messages_as_list,
        test_return_messages_as_dictionary,
        test_add_tool_output_to_agent,
        test_get_final_message,
        test_get_final_message_content,
        test_return_all_except_first,
        test_return_all_except_first_string,
        test_batch_add,
        test_get_cache_stats,
        test_list_cached_conversations,
        test_clear,
        test_save_and_load_json,
    ]
    for test_func in test_functions:
        # perf_counter is monotonic. Unlike datetime.now(), durations cannot
        # go negative or jump if the system clock is adjusted mid-run.
        start = time.perf_counter()
        try:
            result = test_func()
        except Exception as e:
            duration = time.perf_counter() - start
            test_results.append(
                {
                    "name": test_func.__name__,
                    "result": "ERROR",
                    "error": str(e),
                    "duration": duration,
                }
            )
            logger.error(
                f"Test {test_func.__name__} failed with error: {str(e)}"
            )
        else:
            duration = time.perf_counter() - start
            test_results.append(
                {
                    "name": test_func.__name__,
                    "result": "PASS" if result else "FAIL",
                    "duration": duration,
                }
            )
    return test_results
def _escape_table_cell(text):
    """Make ``text`` safe to embed in a single Markdown table cell.

    A line break would end the table row early and an unescaped ``|`` would
    start a new column, so line breaks are collapsed to spaces and pipes are
    backslash-escaped.
    """
    return " ".join(str(text).splitlines()).replace("|", "\\|")


def generate_markdown_report(results):
    """Render test results as a Markdown report.

    Args:
        results: List of result dicts as produced by ``run_all_tests``. Each
            has "name", "result" ("PASS"/"FAIL"/"ERROR"), "duration"
            (seconds) and optionally "error".

    Returns:
        str: A Markdown document with a summary section and a table with one
        row per test. Error text is escaped so it cannot break the table.
    """
    logger.info("Generating test report")
    # Summary
    total_tests = len(results)
    passed_tests = sum(1 for r in results if r["result"] == "PASS")
    failed_tests = sum(1 for r in results if r["result"] == "FAIL")
    error_tests = sum(1 for r in results if r["result"] == "ERROR")
    logger.info(f"Total Tests: {total_tests}")
    logger.info(f"Passed: {passed_tests}")
    logger.info(f"Failed: {failed_tests}")
    logger.info(f"Errors: {error_tests}")
    parts = [
        "# Test Results Report\n\n",
        f"Test Run Date: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n",
        "## Summary\n\n",
        f"- Total Tests: {total_tests}\n",
        f"- Passed: {passed_tests}\n",
        f"- Failed: {failed_tests}\n",
        f"- Errors: {error_tests}\n\n",
        # Detailed Results
        "## Detailed Results\n\n",
        "| Test Name | Result | Duration (s) | Error |\n",
        "|-----------|---------|--------------|-------|\n",
    ]
    for result in results:
        name = result["name"]
        test_result = result["result"]
        duration = f"{result['duration']:.4f}"
        # Exception messages may contain "|" or newlines; escape them so the
        # row stays a single, well-formed table row.
        error = _escape_table_cell(result.get("error", ""))
        parts.append(
            f"| {name} | {test_result} | {duration} | {error} |\n"
        )
    return "".join(parts)
if __name__ == "__main__":
    # Script entry point: run the whole suite and write a Markdown report.
    logger.info("Starting test execution")
    results = run_all_tests()
    report = generate_markdown_report(results)
    # Save report to file. The encoding is explicit because open() otherwise
    # uses the locale's preferred encoding, which can fail on (or mangle)
    # non-ASCII exception text on some platforms.
    with open("test_results.md", "w", encoding="utf-8") as f:
        f.write(report)
    logger.success(
        "Test execution completed. Results saved to test_results.md"
    )