You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
swarms/tests/structs/test_conversation.py

243 lines
7.2 KiB

6 months ago
import pytest
from swarms.structs.conversation import Conversation
@pytest.fixture
def conversation():
    """Provide a Conversation seeded with one user and one assistant message."""
    seeded = Conversation()
    # Order matters: the user greeting is added before the assistant reply.
    for speaker, text in (
        ("user", "Hello, world!"),
        ("assistant", "Hello, user!"),
    ):
        seeded.add(speaker, text)
    return seeded
def test_add_message():
    """Adding one message stores a single entry with the given role and content."""
    chat = Conversation()
    chat.add("user", "Hello, world!")

    history = chat.conversation_history
    assert len(history) == 1

    entry = history[0]
    assert entry["role"] == "user"
    assert entry["content"] == "Hello, world!"
def test_add_message_with_time():
    """With time_enabled=True, a stored message also carries a "timestamp" key."""
    chat = Conversation(time_enabled=True)
    chat.add("user", "Hello, world!")

    history = chat.conversation_history
    assert len(history) == 1

    entry = history[0]
    assert entry["role"] == "user"
    assert entry["content"] == "Hello, world!"
    assert "timestamp" in entry
def test_delete_message():
    """Deleting the only message by index leaves the history empty."""
    chat = Conversation()
    chat.add("user", "Hello, world!")

    chat.delete(0)

    remaining = len(chat.conversation_history)
    assert remaining == 0
def test_delete_message_out_of_bounds():
    """Deleting an index past the end of the history raises IndexError."""
    chat = Conversation()
    chat.add("user", "Hello, world!")

    invalid_index = 1  # only index 0 holds a message
    with pytest.raises(IndexError):
        chat.delete(invalid_index)
def test_update_message():
    """Updating index 0 replaces its role and content without adding an entry."""
    chat = Conversation()
    chat.add("user", "Hello, world!")

    chat.update(0, "assistant", "Hello, user!")

    history = chat.conversation_history
    assert len(history) == 1

    entry = history[0]
    assert entry["role"] == "assistant"
    assert entry["content"] == "Hello, user!"
def test_update_message_out_of_bounds():
    """Updating an index past the end of the history raises IndexError."""
    chat = Conversation()
    chat.add("user", "Hello, world!")

    invalid_index = 1  # only index 0 holds a message
    with pytest.raises(IndexError):
        chat.update(invalid_index, "assistant", "Hello, user!")
def test_return_history_as_string_with_messages(conversation):
    """The seeded conversation renders both messages in order.

    The previous check (``result is not None``) would pass for any string,
    including an empty one. The expected text below uses the same
    ``"<role>: <content>\\n\\n"`` format that the other
    ``return_history_as_string`` tests in this module assert for these two
    messages.
    """
    result = conversation.return_history_as_string()
    expected = "user: Hello, world!\n\nassistant: Hello, user!\n\n"
    assert result == expected
def test_return_history_as_string_with_no_messages():
    """A conversation with no messages renders as the empty string."""
    rendered = Conversation().return_history_as_string()
    assert rendered == ""
@pytest.mark.parametrize(
    "role, content",
    [
        ("user", "Hello, world!"),
        ("assistant", "Hello, user!"),
        ("system", "System message"),
        ("function", "Function message"),
    ],
)
def test_return_history_as_string_with_different_roles(role, content):
    """A single message renders as "<role>: <content>" plus a blank line, for each role."""
    chat = Conversation()
    chat.add(role, content)

    rendered = chat.return_history_as_string()

    assert rendered == f"{role}: {content}\n\n"
@pytest.mark.parametrize("message_count", range(1, 11))
def test_return_history_as_string_with_multiple_messages(message_count):
    """N user messages render as N concatenated "user: Message k" entries."""
    chat = Conversation()
    expected_parts = []
    # Build the conversation and the expected rendering side by side.
    for i in range(message_count):
        chat.add("user", f"Message {i + 1}")
        expected_parts.append(f"user: Message {i + 1}\n\n")

    rendered = chat.return_history_as_string()

    assert rendered == "".join(expected_parts)
@pytest.mark.parametrize(
    "content",
    [
        "Hello, world!",
        "This is a longer message with multiple words.",
        "This message\nhas multiple\nlines.",
        "This message has special characters: !@#$%^&*()",
        "This message has unicode characters: 你好,世界!",
    ],
)
def test_return_history_as_string_with_different_contents(content):
    """Message content (multi-line, symbols, unicode) is rendered verbatim."""
    chat = Conversation()
    chat.add("user", content)

    rendered = chat.return_history_as_string()

    assert rendered == f"user: {content}\n\n"
def test_return_history_as_string_with_large_message(conversation):
    """A very large message appended to the seeded conversation renders in full."""
    large_message = "Hello, world! " * 10000  # 10,000 repetitions
    # The two seeded fixture messages come first, then the large one.
    expected = (
        "user: Hello, world!\n\nassistant: Hello, user!\n\nuser:"
        f" {large_message}\n\n"
    )

    conversation.add("user", large_message)

    assert conversation.return_history_as_string() == expected
def test_search_keyword_in_conversation(conversation):
    """Searching "Hello" in the seeded conversation returns both messages in order."""
    matches = conversation.search_keyword_in_conversation("Hello")

    assert len(matches) == 2
    first_hit = matches[0]
    second_hit = matches[1]
    assert first_hit["content"] == "Hello, world!"
    assert second_hit["content"] == "Hello, user!"
def test_export_import_conversation(conversation, tmp_path):
    """Exporting then importing into a fresh Conversation yields the same rendered history."""
    export_path = tmp_path / "conversation.txt"
    conversation.export_conversation(export_path)

    restored = Conversation()
    restored.import_conversation(export_path)

    restored_text = restored.return_history_as_string()
    original_text = conversation.return_history_as_string()
    assert restored_text == original_text
def test_count_messages_by_role(conversation):
    """The seeded conversation counts one "user" and one "assistant" message."""
    tally = conversation.count_messages_by_role()

    for role_name in ("user", "assistant"):
        assert tally[role_name] == 1
def test_display_conversation(capsys, conversation):
    """display_conversation() writes both seeded messages to stdout."""
    conversation.display_conversation()

    printed = capsys.readouterr().out
    assert "user: Hello, world!\n\n" in printed
    assert "assistant: Hello, user!\n\n" in printed
def test_display_conversation_detailed(capsys, conversation):
    """display_conversation(detailed=True) still writes both seeded messages to stdout."""
    conversation.display_conversation(detailed=True)

    printed = capsys.readouterr().out
    assert "user: Hello, world!\n\n" in printed
    assert "assistant: Hello, user!\n\n" in printed
def test_search():
    """search("Hello") returns both messages that contain the keyword, in order."""
    chat = Conversation()
    chat.add("user", "Hello, world!")
    chat.add("assistant", "Hello, user!")

    hits = chat.search("Hello")

    assert len(hits) == 2
    first_hit = hits[0]
    second_hit = hits[1]
    assert first_hit["content"] == "Hello, world!"
    assert second_hit["content"] == "Hello, user!"
def test_return_history_as_string():
    """Two messages render as "<role>: <content>" entries, each followed by a blank line."""
    chat = Conversation()
    chat.add("user", "Hello, world!")
    chat.add("assistant", "Hello, user!")

    rendered = chat.return_history_as_string()

    assert rendered == "user: Hello, world!\n\nassistant: Hello, user!\n\n"
def test_search_no_results():
    """search() with a keyword absent from every message returns no results."""
    chat = Conversation()
    chat.add("user", "Hello, world!")
    chat.add("assistant", "Hello, user!")

    hits = chat.search("Goodbye")

    hit_count = len(hits)
    assert hit_count == 0
def test_search_case_insensitive():
    """search("hello") in lowercase still matches both "Hello, ..." messages."""
    chat = Conversation()
    chat.add("user", "Hello, world!")
    chat.add("assistant", "Hello, user!")

    hits = chat.search("hello")

    assert len(hits) == 2
    first_hit = hits[0]
    second_hit = hits[1]
    assert first_hit["content"] == "Hello, world!"
    assert second_hit["content"] == "Hello, user!"
def test_search_multiple_occurrences():
    """A message containing the keyword twice is returned once by search()."""
    chat = Conversation()
    chat.add("user", "Hello, world! Hello, world!")
    chat.add("assistant", "Hello, user!")

    hits = chat.search("Hello")

    # Two messages match, so two results are expected despite three occurrences.
    assert len(hits) == 2
    first_hit = hits[0]
    second_hit = hits[1]
    assert first_hit["content"] == "Hello, world! Hello, world!"
    assert second_hit["content"] == "Hello, user!"
def test_query_no_results():
    """query() with a keyword absent from every message returns no results."""
    chat = Conversation()
    chat.add("user", "Hello, world!")
    chat.add("assistant", "Hello, user!")

    found = chat.query("Goodbye")

    found_count = len(found)
    assert found_count == 0
def test_query_case_insensitive():
    """query("hello") in lowercase still matches both "Hello, ..." messages."""
    chat = Conversation()
    chat.add("user", "Hello, world!")
    chat.add("assistant", "Hello, user!")

    found = chat.query("hello")

    assert len(found) == 2
    first_hit = found[0]
    second_hit = found[1]
    assert first_hit["content"] == "Hello, world!"
    assert second_hit["content"] == "Hello, user!"
def test_query_multiple_occurrences():
    """A message containing the keyword twice is returned once by query()."""
    chat = Conversation()
    chat.add("user", "Hello, world! Hello, world!")
    chat.add("assistant", "Hello, user!")

    found = chat.query("Hello")

    # Two messages match, so two results are expected despite three occurrences.
    assert len(found) == 2
    first_hit = found[0]
    second_hit = found[1]
    assert first_hit["content"] == "Hello, world! Hello, world!"
    assert second_hit["content"] == "Hello, user!"