import pytest
from unittest.mock import Mock
from swarms.memory.oceandb import OceanDB


@pytest.fixture
def mock_ocean_client():
    """Provide a fresh ``Mock`` standing in for the OceanDB client.

    Each test receives its own instance, so return values and call
    records configured in one test do not carry over to another.
    """
    fake_client = Mock()
    return fake_client


@pytest.fixture
def mock_collection():
    """Provide a fresh ``Mock`` standing in for an OceanDB collection.

    Tests pass this object to ``OceanDB`` methods and then inspect the
    calls recorded on it (e.g. ``add`` / ``query``).
    """
    fake_collection = Mock()
    return fake_collection


@pytest.fixture
def ocean_db(mock_ocean_client, monkeypatch):
    """Build an ``OceanDB`` whose class-level ``client`` is the mock client.

    The mock is installed with ``monkeypatch`` rather than plain assignment.
    Plain assignment permanently replaced ``OceanDB.client`` for every later
    test in the session, whereas ``monkeypatch`` restores (or removes) the
    attribute at teardown.

    Args:
        mock_ocean_client: The ``Mock`` client fixture to install.
        monkeypatch: pytest's built-in fixture for reversible patching.

    Returns:
        A new ``OceanDB`` instance created while the patch is active.
    """
    # raising=False: allow the patch even if OceanDB defines no class-level
    # ``client`` attribute; monkeypatch then deletes it again at teardown.
    monkeypatch.setattr(OceanDB, "client", mock_ocean_client, raising=False)
    # NOTE(review): if OceanDB.__init__ assigns its own ``self.client``, the
    # instance attribute shadows this class-level patch — confirm against
    # swarms.memory.oceandb.
    return OceanDB()


def test_init(ocean_db, mock_ocean_client):
    """The fixture's client answers ``heartbeat`` with the configured status."""
    expected_status = "OK"
    mock_ocean_client.heartbeat.return_value = expected_status
    # NOTE(review): this passes whenever ``ocean_db.client`` is the mock, so
    # it checks the fixture wiring more than OceanDB's own initialisation.
    status = ocean_db.client.heartbeat()
    assert status == expected_status


def test_create_collection(ocean_db, mock_ocean_client, mock_collection):
    """``create_collection`` returns the collection produced by the client."""
    client_create = mock_ocean_client.create_collection
    client_create.return_value = mock_collection
    created = ocean_db.create_collection("test", "text")
    assert created == mock_collection


def test_append_document(ocean_db, mock_collection):
    """``append_document`` forwards one document/id pair to ``collection.add``.

    The single document and id must each be wrapped in a one-element list
    and passed as the ``documents`` / ``ids`` keyword arguments.
    """
    document = "test_document"
    # Named ``doc_id`` rather than ``id`` to avoid shadowing the builtin.
    doc_id = "test_id"
    ocean_db.append_document(mock_collection, document, doc_id)
    mock_collection.add.assert_called_once_with(documents=[document], ids=[doc_id])


def test_add_documents(ocean_db, mock_collection):
    """``add_documents`` passes the document and id lists straight to ``add``."""
    docs = ["test_document1", "test_document2"]
    doc_ids = ["test_id1", "test_id2"]
    ocean_db.add_documents(mock_collection, docs, doc_ids)
    mock_collection.add.assert_called_once_with(documents=docs, ids=doc_ids)


def test_query(ocean_db, mock_collection):
    """``query`` returns whatever the collection's ``query`` call produced."""
    mock_collection.query.return_value = "query_result"
    texts = ["test_query"]
    limit = 10
    outcome = ocean_db.query(mock_collection, texts, limit)
    assert outcome == "query_result"