You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
199 lines
6.2 KiB
199 lines
6.2 KiB
from unittest.mock import Mock, patch
|
|
|
|
import pytest
|
|
|
|
from swarms.memory import WeaviateDB
|
|
|
|
|
|
# Define fixture for a WeaviateDB instance with mocked methods
|
|
@pytest.fixture
def weaviate_client_mock():
    """Build a ``WeaviateDB`` with placeholder settings and mocked collection calls.

    The ``create``, ``get`` and ``query`` attributes of ``client.collections``
    and the ``insert``, ``update`` and ``delete_by_id`` attributes of
    ``client.collections.data`` are replaced with ``Mock`` objects so the
    tests can assert on the calls made through them.
    """
    db = WeaviateDB(
        http_host="mock_host",
        http_port="mock_port",
        http_secure=False,
        grpc_host="mock_grpc_host",
        grpc_port="mock_grpc_port",
        grpc_secure=False,
        auth_client_secret="mock_api_key",
        additional_headers={"X-OpenAI-Api-Key": "mock_openai_api_key"},
        additional_config=Mock(),
    )

    # Swap every collection-level entry point the tests inspect for a Mock.
    collections = db.client.collections
    for attr in ("create", "get", "query"):
        setattr(collections, attr, Mock())
    for attr in ("insert", "update", "delete_by_id"):
        setattr(collections.data, attr, Mock())

    return db
|
|
|
|
|
|
# Define tests for the WeaviateDB class
|
|
def test_create_collection(weaviate_client_mock):
    """Check that ``create_collection`` forwards its arguments to ``collections.create``.

    With no vectorizer configuration supplied, the call is expected to carry
    ``vectorizer_config=None``.
    """
    weaviate_client_mock.create_collection(
        "test_collection", [{"name": "property"}]
    )

    create = weaviate_client_mock.client.collections.create
    expected_kwargs = {
        "name": "test_collection",
        "vectorizer_config": None,
        "properties": [{"name": "property"}],
    }
    create.assert_called_with(**expected_kwargs)
|
|
|
|
|
|
def test_add_object(weaviate_client_mock):
    """Check that ``add`` looks up the collection and inserts the properties."""
    payload = {"name": "John"}

    weaviate_client_mock.add("test_collection", payload)

    collections = weaviate_client_mock.client.collections
    collections.get.assert_called_with("test_collection")
    collections.data.insert.assert_called_with(payload)
|
|
|
|
|
|
def test_query_objects(weaviate_client_mock):
    """Check that ``query`` looks up the collection and runs a BM25 search.

    No limit is passed, so the BM25 call is expected to carry ``limit=10``.
    """
    search_text = "name:John"

    weaviate_client_mock.query("test_collection", search_text)

    collections = weaviate_client_mock.client.collections
    collections.get.assert_called_with("test_collection")
    collections.query.bm25.assert_called_with(query=search_text, limit=10)
|
|
|
|
|
|
def test_update_object(weaviate_client_mock):
    """Check that ``update`` looks up the collection and updates the object by id."""
    target_id = "12345"
    new_values = {"name": "Jane"}

    weaviate_client_mock.update("test_collection", target_id, new_values)

    collections = weaviate_client_mock.client.collections
    collections.get.assert_called_with("test_collection")
    collections.data.update.assert_called_with(target_id, new_values)
|
|
|
|
|
|
def test_delete_object(weaviate_client_mock):
    """Check that ``delete`` looks up the collection and deletes the object by id."""
    target_id = "12345"

    weaviate_client_mock.delete("test_collection", target_id)

    collections = weaviate_client_mock.client.collections
    collections.get.assert_called_with("test_collection")
    collections.data.delete_by_id.assert_called_with(target_id)
|
|
|
|
|
|
def test_create_collection_with_vectorizer_config(
    weaviate_client_mock,
):
    """Check that an explicit vectorizer config is passed through to ``collections.create``."""
    vectorizer_settings = {"config_key": "config_value"}

    weaviate_client_mock.create_collection(
        "test_collection", [{"name": "property"}], vectorizer_settings
    )

    create = weaviate_client_mock.client.collections.create
    expected_kwargs = {
        "name": "test_collection",
        "vectorizer_config": vectorizer_settings,
        "properties": [{"name": "property"}],
    }
    create.assert_called_with(**expected_kwargs)
|
|
|
|
|
|
def test_query_objects_with_limit(weaviate_client_mock):
    """Check that a caller-supplied limit reaches the BM25 search call."""
    search_text = "name:John"
    max_results = 20

    weaviate_client_mock.query("test_collection", search_text, max_results)

    collections = weaviate_client_mock.client.collections
    collections.get.assert_called_with("test_collection")
    collections.query.bm25.assert_called_with(
        query=search_text, limit=max_results
    )
|
|
|
|
|
|
def test_query_objects_without_limit(weaviate_client_mock):
    """Check that omitting the limit results in a BM25 call with ``limit=10``."""
    search_text = "name:John"

    weaviate_client_mock.query("test_collection", search_text)

    collections = weaviate_client_mock.client.collections
    collections.get.assert_called_with("test_collection")
    collections.query.bm25.assert_called_with(query=search_text, limit=10)
|
|
|
|
|
|
def test_create_collection_failure(weaviate_client_mock):
    """Check that ``create_collection`` raises when ``collections.create`` fails.

    The failure is injected with ``patch.object`` on the fixture's mocked
    client. A dotted-string ``patch`` target must be an importable path, and
    ``weaviate_client`` is not a module this file imports or defines, so the
    string form could not resolve to the object under test.
    """
    collections = weaviate_client_mock.client.collections

    with patch.object(
        collections,
        "create",
        side_effect=Exception("Create error"),
    ):
        # NOTE(review): the match pattern presumes WeaviateDB reports the
        # failure with an "Error creating collection" message - confirm
        # against the implementation.
        with pytest.raises(
            Exception, match="Error creating collection"
        ):
            weaviate_client_mock.create_collection(
                "test_collection", [{"name": "property"}]
            )
|
|
|
|
|
|
def test_add_object_failure(weaviate_client_mock):
    """Check that ``add`` raises when the underlying ``data.insert`` fails.

    The failure is injected with ``patch.object`` on the fixture's mocked
    client. A dotted-string ``patch`` target must be an importable path, and
    ``weaviate_client`` is not a module this file imports or defines, so the
    string form could not resolve to the object under test.
    """
    properties = {"name": "John"}
    data_api = weaviate_client_mock.client.collections.data

    with patch.object(
        data_api,
        "insert",
        side_effect=Exception("Insert error"),
    ):
        # NOTE(review): the match pattern presumes WeaviateDB reports the
        # failure with an "Error adding object" message - confirm against
        # the implementation.
        with pytest.raises(Exception, match="Error adding object"):
            weaviate_client_mock.add("test_collection", properties)
|
|
|
|
|
|
def test_query_objects_failure(weaviate_client_mock):
    """Check that ``query`` raises when the underlying ``query.bm25`` fails.

    The failure is injected with ``patch.object`` on the fixture's mocked
    client. A dotted-string ``patch`` target must be an importable path, and
    ``weaviate_client`` is not a module this file imports or defines, so the
    string form could not resolve to the object under test.
    """
    query = "name:John"
    query_api = weaviate_client_mock.client.collections.query

    with patch.object(
        query_api,
        "bm25",
        side_effect=Exception("Query error"),
    ):
        # NOTE(review): the match pattern presumes WeaviateDB reports the
        # failure with an "Error querying objects" message - confirm against
        # the implementation.
        with pytest.raises(Exception, match="Error querying objects"):
            weaviate_client_mock.query("test_collection", query)
|
|
|
|
|
|
def test_update_object_failure(weaviate_client_mock):
    """Check that ``update`` raises when the underlying ``data.update`` fails.

    The failure is injected with ``patch.object`` on the fixture's mocked
    client. A dotted-string ``patch`` target must be an importable path, and
    ``weaviate_client`` is not a module this file imports or defines, so the
    string form could not resolve to the object under test.
    """
    object_id = "12345"
    properties = {"name": "Jane"}
    data_api = weaviate_client_mock.client.collections.data

    with patch.object(
        data_api,
        "update",
        side_effect=Exception("Update error"),
    ):
        # NOTE(review): the match pattern presumes WeaviateDB reports the
        # failure with an "Error updating object" message - confirm against
        # the implementation.
        with pytest.raises(Exception, match="Error updating object"):
            weaviate_client_mock.update(
                "test_collection", object_id, properties
            )
|
|
|
|
|
|
def test_delete_object_failure(weaviate_client_mock):
    """Check that ``delete`` raises when the underlying ``data.delete_by_id`` fails.

    The failure is injected with ``patch.object`` on the fixture's mocked
    client. A dotted-string ``patch`` target must be an importable path, and
    ``weaviate_client`` is not a module this file imports or defines, so the
    string form could not resolve to the object under test.
    """
    object_id = "12345"
    data_api = weaviate_client_mock.client.collections.data

    with patch.object(
        data_api,
        "delete_by_id",
        side_effect=Exception("Delete error"),
    ):
        # NOTE(review): the match pattern presumes WeaviateDB reports the
        # failure with an "Error deleting object" message - confirm against
        # the implementation.
        with pytest.raises(Exception, match="Error deleting object"):
            weaviate_client_mock.delete("test_collection", object_id)
|