"""Unit tests for ``swarms.memory.ocean.OceanDB`` using mocked collaborators."""

from unittest.mock import Mock

import pytest

from swarms.memory.ocean import OceanDB


@pytest.fixture
def mock_ocean_client():
    """Return a fresh ``Mock`` standing in for the OceanDB client."""
    return Mock()


@pytest.fixture
def mock_collection():
    """Return a fresh ``Mock`` standing in for a collection object."""
    return Mock()


@pytest.fixture
def ocean_db(mock_ocean_client, monkeypatch):
    """Return an ``OceanDB`` instance built while ``OceanDB.client`` is mocked.

    The class attribute is set through ``monkeypatch`` so it is restored (or
    removed again) at teardown instead of leaking the mock into later tests.
    """
    # raising=False: the class may not define ``client`` before patching.
    monkeypatch.setattr(OceanDB, "client", mock_ocean_client, raising=False)
    # NOTE(review): if OceanDB.__init__ assigns self.client itself, the
    # instance attribute will shadow this class-level mock -- confirm.
    return OceanDB()


def test_init(ocean_db, mock_ocean_client):
    """The client reachable from ``ocean_db`` reports the mocked heartbeat."""
    mock_ocean_client.heartbeat.return_value = "OK"

    assert ocean_db.client.heartbeat() == "OK"


def test_create_collection(ocean_db, mock_ocean_client, mock_collection):
    """``create_collection`` returns what the client's ``create_collection`` gives."""
    mock_ocean_client.create_collection.return_value = mock_collection

    collection = ocean_db.create_collection("test", "text")

    assert collection == mock_collection


def test_append_document(ocean_db, mock_collection):
    """``append_document`` adds one document and id, each wrapped in a list."""
    document = "test_document"
    doc_id = "test_id"  # not named ``id`` to avoid shadowing the builtin

    ocean_db.append_document(mock_collection, document, doc_id)

    mock_collection.add.assert_called_once_with(
        documents=[document], ids=[doc_id]
    )


def test_add_documents(ocean_db, mock_collection):
    """``add_documents`` forwards the document and id lists to ``add``."""
    documents = ["test_document1", "test_document2"]
    ids = ["test_id1", "test_id2"]

    ocean_db.add_documents(mock_collection, documents, ids)

    mock_collection.add.assert_called_once_with(documents=documents, ids=ids)


def test_query(ocean_db, mock_collection):
    """``query`` returns whatever the collection's ``query`` returns."""
    query_texts = ["test_query"]
    n_results = 10
    mock_collection.query.return_value = "query_result"

    result = ocean_db.query(mock_collection, query_texts, n_results)

    assert result == "query_result"