diff --git a/tests/memory/test_main.py b/tests/memory/test_main.py deleted file mode 100644 index 63d56907..00000000 --- a/tests/memory/test_main.py +++ /dev/null @@ -1,58 +0,0 @@ -import pytest -from unittest.mock import Mock -from swarms.memory.ocean import OceanDB - - -@pytest.fixture -def mock_ocean_client(): - return Mock() - - -@pytest.fixture -def mock_collection(): - return Mock() - - -@pytest.fixture -def ocean_db(mock_ocean_client): - OceanDB.client = mock_ocean_client - return OceanDB() - - -def test_init(ocean_db, mock_ocean_client): - mock_ocean_client.heartbeat.return_value = "OK" - assert ocean_db.client.heartbeat() == "OK" - - -def test_create_collection( - ocean_db, mock_ocean_client, mock_collection -): - mock_ocean_client.create_collection.return_value = mock_collection - collection = ocean_db.create_collection("test", "text") - assert collection == mock_collection - - -def test_append_document(ocean_db, mock_collection): - document = "test_document" - id = "test_id" - ocean_db.append_document(mock_collection, document, id) - mock_collection.add.assert_called_once_with( - documents=[document], ids=[id] - ) - - -def test_add_documents(ocean_db, mock_collection): - documents = ["test_document1", "test_document2"] - ids = ["test_id1", "test_id2"] - ocean_db.add_documents(mock_collection, documents, ids) - mock_collection.add.assert_called_once_with( - documents=documents, ids=ids - ) - - -def test_query(ocean_db, mock_collection): - query_texts = ["test_query"] - n_results = 10 - mock_collection.query.return_value = "query_result" - result = ocean_db.query(mock_collection, query_texts, n_results) - assert result == "query_result" diff --git a/tests/memory/test_oceandb.py b/tests/memory/test_oceandb.py deleted file mode 100644 index e760dc61..00000000 --- a/tests/memory/test_oceandb.py +++ /dev/null @@ -1,103 +0,0 @@ -import pytest -from unittest.mock import Mock, patch -from swarms.memory.ocean import OceanDB - - -def test_init(): - with patch("oceandb.Client") as MockClient: - MockClient.return_value.heartbeat.return_value = "OK" - db = OceanDB(MockClient) - MockClient.assert_called_once() - assert db.client == MockClient - - -def test_init_exception(): - with patch("oceandb.Client") as MockClient: - MockClient.side_effect = Exception("Client error") - with pytest.raises(Exception) as e: - OceanDB(MockClient) - assert str(e.value) == "Client error" - - -def test_create_collection(): - with patch("oceandb.Client") as MockClient: - db = OceanDB(MockClient) - db.create_collection("test", "modality") - MockClient.create_collection.assert_called_once_with( - "test", embedding_function=Mock.ANY - ) - - -def test_create_collection_exception(): - with patch("oceandb.Client") as MockClient: - MockClient.create_collection.side_effect = Exception( - "Create collection error" - ) - db = OceanDB(MockClient) - with pytest.raises(Exception) as e: - db.create_collection("test", "modality") - assert str(e.value) == "Create collection error" - - -def test_append_document(): - with patch("oceandb.Client") as MockClient: - db = OceanDB(MockClient) - collection = Mock() - db.append_document(collection, "doc", "id") - collection.add.assert_called_once_with( - documents=["doc"], ids=["id"] - ) - - -def test_append_document_exception(): - with patch("oceandb.Client") as MockClient: - db = OceanDB(MockClient) - collection = Mock() - collection.add.side_effect = Exception( - "Append document error" - ) - with pytest.raises(Exception) as e: - db.append_document(collection, "doc", "id") - assert str(e.value) == "Append document error" - - -def test_add_documents(): - with patch("oceandb.Client") as MockClient: - db = OceanDB(MockClient) - collection = Mock() - db.add_documents(collection, ["doc1", "doc2"], ["id1", "id2"]) - collection.add.assert_called_once_with( - documents=["doc1", "doc2"], ids=["id1", "id2"] - ) - - -def test_add_documents_exception(): - with patch("oceandb.Client") as MockClient: - db = OceanDB(MockClient) - collection = Mock() - collection.add.side_effect = Exception("Add documents error") - with pytest.raises(Exception) as e: - db.add_documents( - collection, ["doc1", "doc2"], ["id1", "id2"] - ) - assert str(e.value) == "Add documents error" - - -def test_query(): - with patch("oceandb.Client") as MockClient: - db = OceanDB(MockClient) - collection = Mock() - db.query(collection, ["query1", "query2"], 2) - collection.query.assert_called_once_with( - query_texts=["query1", "query2"], n_results=2 - ) - - -def test_query_exception(): - with patch("oceandb.Client") as MockClient: - db = OceanDB(MockClient) - collection = Mock() - collection.query.side_effect = Exception("Query error") - with pytest.raises(Exception) as e: - db.query(collection, ["query1", "query2"], 2) - assert str(e.value) == "Query error" diff --git a/tests/memory/test_pg.py b/tests/memory/test_pg.py index ba564586..2bddfb27 100644 --- a/tests/memory/test_pg.py +++ b/tests/memory/test_pg.py @@ -1,6 +1,6 @@ import pytest from unittest.mock import patch -from swarms.memory import PgVectorVectorStore +from swarms.memory.pg import PgVectorVectorStore from dotenv import load_dotenv import os diff --git a/tests/memory/test_pinecone.py b/tests/memory/test_pinecone.py index 9cc99781..7c71503e 100644 --- a/tests/memory/test_pinecone.py +++ b/tests/memory/test_pinecone.py @@ -1,6 +1,6 @@ import os from unittest.mock import patch -from swarms.memory import PineconeVectorStore +from swarms.memory.pinecone import PineconeVectorStore api_key = os.getenv("PINECONE_API_KEY") or ""