You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
97 lines
2.3 KiB
97 lines
2.3 KiB
# LangchainChromaVectorMemory
|
|
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
import pytest
|
|
|
|
from swarms.memory import LangchainChromaVectorMemory
|
|
|
|
|
|
# Fixtures for setting up the memory and mocks
|
|
@pytest.fixture()
def vector_memory(tmp_path):
    """Build a ``LangchainChromaVectorMemory`` rooted in a per-test directory.

    ``tmp_path`` is pytest's built-in per-test temporary directory, so every
    test receives its own isolated ``vector_memory`` location.

    NOTE(review): this fixture does not request ``embeddings_mock`` or
    ``chroma_mock``, so the instance is constructed with whatever the real
    classes do — confirm that is intended.
    """
    memory_dir = tmp_path.joinpath("vector_memory")
    return LangchainChromaVectorMemory(loc=memory_dir)
|
|
|
|
|
|
@pytest.fixture()
def embeddings_mock():
    """Replace ``swarms.memory.OpenAIEmbeddings`` with a mock for one test.

    Yields the mock standing in for the class; leaving the ``with`` block
    restores the original attribute at teardown.
    """
    with patch("swarms.memory.OpenAIEmbeddings") as fake_embeddings_cls:
        yield fake_embeddings_cls
|
|
|
|
|
|
@pytest.fixture()
def chroma_mock():
    """Replace ``swarms.memory.Chroma`` with a mock for one test.

    Yields the mock standing in for the class; leaving the ``with`` block
    restores the original attribute at teardown.
    """
    with patch("swarms.memory.Chroma") as fake_chroma_cls:
        yield fake_chroma_cls
|
|
|
|
|
|
@pytest.fixture()
def qa_mock():
    """Replace ``swarms.memory.RetrievalQA`` with a mock for one test.

    Yields the mock standing in for the class; leaving the ``with`` block
    restores the original attribute at teardown.
    """
    with patch("swarms.memory.RetrievalQA") as fake_qa_cls:
        yield fake_qa_cls
|
|
|
|
|
|
# Example test cases
|
|
def test_initialization_default_settings(vector_memory):
    """A memory built with defaults has the expected chunking and location."""
    expected_chunk_size = 1000
    # Assumes the default overlap fraction is 0.1 of the chunk size,
    # i.e. 0.1 * 1000 == 100.
    expected_overlap = 100

    assert vector_memory.chunk_size == expected_chunk_size
    assert vector_memory.chunk_overlap == expected_overlap
    assert vector_memory.loc.exists()
|
|
|
|
|
|
def test_add_entry(vector_memory, embeddings_mock):
    """Adding text results in a call to the backing store's ``add_texts``.

    NOTE(review): ``embeddings_mock`` is requested after ``vector_memory``,
    so the embeddings patch is likely applied only after the memory has
    been constructed — confirm whether it has any effect here.
    """
    store = vector_memory.db
    with patch.object(store, "add_texts") as add_texts_spy:
        vector_memory.add("Example text")
        add_texts_spy.assert_called()
|
|
|
|
|
|
def test_search_memory_returns_list(vector_memory):
    """``search_memory`` hands back a list for an ordinary query.

    NOTE(review): the store is not patched here, so this exercises the real
    ``vector_memory.db`` search path.
    """
    hits = vector_memory.search_memory("example query", k=5)
    assert isinstance(hits, list)
|
|
|
|
|
|
def test_ask_question_returns_string(vector_memory, qa_mock):
    """``query`` answers a natural-language question with a ``str``.

    NOTE(review): ``qa_mock`` is requested after ``vector_memory``, so the
    ``RetrievalQA`` patch is likely applied only after the memory (and any
    QA chain it builds) already exists. Confirm whether the mock is
    actually in effect. If it were, ``query`` would probably return a mock
    rather than a ``str``.
    """
    question = "What is the color of the sky?"
    answer = vector_memory.query(question)
    assert isinstance(answer, str)
|
|
|
|
|
|
@pytest.mark.parametrize(
    "query,k,search_type",
    [
        ("example query", 5, "mmr"),
        # k of 0 (or negative): the mocked store returns no documents, so
        # the search result must be empty (not None).
        ("example query", 0, "mmr"),
        ("example query", 3, "cos"),
    ],
)
def test_search_memory_different_params(
    vector_memory, query, k, search_type
):
    """Check that ``search_memory`` yields one result per mocked match.

    Both search methods on ``vector_memory.db`` are patched to return
    exactly ``max(k, 0)`` placeholder results. This keeps the mocked store
    consistent with the length assertion below.

    Args:
        vector_memory: Fixture providing the memory under test.
        query: Search text passed through to ``search_memory``.
        k: Number of results requested.
        search_type: Forwarded as ``search_memory``'s ``type`` keyword.
            It is named ``search_type`` here to avoid shadowing the
            ``type`` builtin.
    """
    expected_len = max(k, 0)
    docs = [MagicMock() for _ in range(expected_len)]
    # LangChain's similarity_search_with_score returns (document, score)
    # pairs, so the similarity mock mirrors that shape instead of returning
    # bare documents.
    # NOTE(review): assumes search_memory keeps every result one-to-one,
    # i.e. no score-threshold filter drops a 0.0 score. Confirm.
    scored_docs = [(doc, 0.0) for doc in docs]

    with patch.object(
        vector_memory.db,
        "max_marginal_relevance_search",
        return_value=docs,
    ), patch.object(
        vector_memory.db,
        "similarity_search_with_score",
        return_value=scored_docs,
    ):
        result = vector_memory.search_memory(
            query, k=k, type=search_type
        )

    assert len(result) == expected_len
|