You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
swarms/tests/memory/test_langchainchromavectorm...

93 lines
2.3 KiB

6 months ago
# LangchainChromaVectorMemory
from unittest.mock import MagicMock, patch
import pytest
from swarms.memory import LangchainChromaVectorMemory
# Fixtures for setting up the memory and mocks
@pytest.fixture()
def vector_memory(tmp_path):
    """Provide a LangchainChromaVectorMemory stored under a per-test temp dir.

    The storage location is ``<tmp_path>/vector_memory``, so each test gets
    an isolated location that pytest cleans up.

    NOTE(review): this fixture does not request ``embeddings_mock`` or
    ``chroma_mock``, so the constructor runs with nothing patched. Confirm
    that constructing the memory needs no network access or API credentials.
    """
    return LangchainChromaVectorMemory(loc=tmp_path / "vector_memory")
def _patched(target):
    """Patch the dotted ``target`` for a fixture's lifetime, yielding the mock.

    Each mock fixture delegates to this generator with ``yield from``. The
    patch is therefore undone when pytest resumes the fixture for teardown.
    """
    with patch(target) as replacement:
        yield replacement


# NOTE(review): patching ``swarms.memory.<Name>`` only replaces lookups made
# through the ``swarms.memory`` namespace at call time. If the memory class
# resolves these names from its own module, these mocks have no effect. Verify
# the correct patch target.
@pytest.fixture()
def embeddings_mock():
    """Replace ``swarms.memory.OpenAIEmbeddings`` with a mock for one test."""
    yield from _patched("swarms.memory.OpenAIEmbeddings")


@pytest.fixture()
def chroma_mock():
    """Replace ``swarms.memory.Chroma`` with a mock for one test."""
    yield from _patched("swarms.memory.Chroma")


@pytest.fixture()
def qa_mock():
    """Replace ``swarms.memory.RetrievalQA`` with a mock for one test."""
    yield from _patched("swarms.memory.RetrievalQA")
# Example test cases
def test_initialization_default_settings(vector_memory):
    """A freshly built memory uses the default chunking settings.

    Checks the default ``chunk_size`` and ``chunk_overlap`` values, and that
    the storage location passed in by the fixture exists.
    """
    assert vector_memory.chunk_size == 1000
    # The expected overlap is 10% (ratio 0.1) of the default chunk size of 1000.
    assert vector_memory.chunk_overlap == 100
    # NOTE(review): the fixture only builds a path under tmp_path. This passes
    # only if the constructor creates the directory, so confirm that it does.
    assert vector_memory.loc.exists()
def test_add_entry(vector_memory, embeddings_mock):
    """Adding a text entry results in a call to the db's ``add_texts``.

    NOTE(review): pytest sets up same-scope fixtures in argument order. The
    ``embeddings_mock`` patch is therefore not active while ``vector_memory``
    is constructed, only during the test body.
    """
    db = vector_memory.db
    with patch.object(db, "add_texts") as spy_add_texts:
        vector_memory.add("Example text")
        spy_add_texts.assert_called()
def test_search_memory_returns_list(vector_memory):
    """``search_memory`` with k=5 returns a list.

    NOTE(review): nothing is patched here. ``search_memory`` runs against
    whatever db the ``vector_memory`` fixture constructed.
    """
    assert isinstance(
        vector_memory.search_memory("example query", k=5),
        list,
    )
def test_ask_question_returns_string(vector_memory, qa_mock):
    """``query`` answers a natural-language question with a ``str``.

    NOTE(review): ``qa_mock`` has no configured return value. If ``query``
    returns the output of a RetrievalQA chain built through the patched name,
    that output will be a MagicMock, not a str. Confirm the intended setup.
    """
    answer = vector_memory.query("What is the color of the sky?")
    assert isinstance(answer, str)
@pytest.mark.parametrize(
    "query,k,search_type",
    [
        ("example query", 5, "mmr"),
        # A non-positive k should produce no results (an empty list).
        ("example query", 0, "mmr"),
        ("example query", 3, "cos"),
    ],
)
def test_search_memory_different_params(
    vector_memory, query, k, search_type
):
    """``search_memory`` returns one result per backend hit for each search type.

    Both db search methods are stubbed to return exactly ``max(k, 0)`` items.
    This keeps the length assertion consistent with the mocked data. The
    previous stubs returned a single item or ``None``, which made the check
    fail for k=5 and k=3 and raise ``TypeError`` for k=0.

    The stubs are built per run rather than at collection time, so no mock
    objects are shared between parametrized cases.
    """
    n_results = max(k, 0)
    # Mirror the LangChain return shapes. max_marginal_relevance_search yields
    # documents; similarity_search_with_score yields (document, score) pairs.
    mmr_results = [MagicMock() for _ in range(n_results)]
    scored_results = [(MagicMock(), 0.0) for _ in range(n_results)]
    with patch.object(
        vector_memory.db,
        "max_marginal_relevance_search",
        return_value=mmr_results,
    ), patch.object(
        vector_memory.db,
        "similarity_search_with_score",
        return_value=scored_results,
    ):
        # ``type`` is search_memory's keyword. The local name avoids shadowing
        # the builtin.
        result = vector_memory.search_memory(query, k=k, type=search_type)
    assert len(result) == n_results