"""Unit tests for the WeaviateDB wrapper, run against a mocked client."""

from unittest.mock import Mock, patch

import pytest

from swarms.memory import WeaviateDB


@pytest.fixture
def weaviate_client_mock():
    """Return a WeaviateDB whose collection-level client calls are Mocks.

    The instance is built with placeholder connection settings, then the
    `create`, `get`, `query` and `data.*` entry points on
    `client.client.collections` are replaced so tests can inspect calls.

    NOTE(review): the WeaviateDB constructor itself runs for real here;
    confirm it does not try to open a network connection.
    NOTE(review): the `data`/`query` mocks hang directly off
    `client.collections`. If WeaviateDB reaches them through the object
    returned by `collections.get(...)`, the assertions in this module
    should target `collections.get.return_value` instead -- confirm
    against the implementation.
    """
    client = WeaviateDB(
        http_host="mock_host",
        http_port="mock_port",
        http_secure=False,
        grpc_host="mock_grpc_host",
        grpc_port="mock_grpc_port",
        grpc_secure=False,
        auth_client_secret="mock_api_key",
        additional_headers={"X-OpenAI-Api-Key": "mock_openai_api_key"},
        additional_config=Mock(),
    )

    # Replace the calls the tests observe with plain Mocks.
    collections = client.client.collections
    collections.create = Mock()
    collections.get = Mock()
    collections.query = Mock()
    collections.data.insert = Mock()
    collections.data.update = Mock()
    collections.data.delete_by_id = Mock()

    return client


# ---------------------------------------------------------------------------
# Happy-path tests
# ---------------------------------------------------------------------------


def test_create_collection(weaviate_client_mock):
    """create_collection forwards name/properties with no vectorizer."""
    weaviate_client_mock.create_collection(
        "test_collection", [{"name": "property"}]
    )

    weaviate_client_mock.client.collections.create.assert_called_with(
        name="test_collection",
        vectorizer_config=None,
        properties=[{"name": "property"}],
    )


def test_add_object(weaviate_client_mock):
    """add looks up the collection and inserts the given properties."""
    properties = {"name": "John"}

    weaviate_client_mock.add("test_collection", properties)

    collections = weaviate_client_mock.client.collections
    collections.get.assert_called_with("test_collection")
    collections.data.insert.assert_called_with(properties)


def test_query_objects(weaviate_client_mock):
    """query runs a bm25 search with the default limit of 10."""
    query = "name:John"

    weaviate_client_mock.query("test_collection", query)

    collections = weaviate_client_mock.client.collections
    collections.get.assert_called_with("test_collection")
    collections.query.bm25.assert_called_with(query=query, limit=10)


def test_update_object(weaviate_client_mock):
    """update looks up the collection and updates the object by id."""
    object_id = "12345"
    properties = {"name": "Jane"}

    weaviate_client_mock.update("test_collection", object_id, properties)

    collections = weaviate_client_mock.client.collections
    collections.get.assert_called_with("test_collection")
    collections.data.update.assert_called_with(object_id, properties)


def test_delete_object(weaviate_client_mock):
    """delete looks up the collection and deletes the object by id."""
    object_id = "12345"

    weaviate_client_mock.delete("test_collection", object_id)

    collections = weaviate_client_mock.client.collections
    collections.get.assert_called_with("test_collection")
    collections.data.delete_by_id.assert_called_with(object_id)


def test_create_collection_with_vectorizer_config(
    weaviate_client_mock,
):
    """create_collection passes an explicit vectorizer config through."""
    vectorizer_config = {"config_key": "config_value"}

    weaviate_client_mock.create_collection(
        "test_collection", [{"name": "property"}], vectorizer_config
    )

    weaviate_client_mock.client.collections.create.assert_called_with(
        name="test_collection",
        vectorizer_config=vectorizer_config,
        properties=[{"name": "property"}],
    )


def test_query_objects_with_limit(weaviate_client_mock):
    """query forwards an explicitly supplied limit to bm25."""
    query = "name:John"
    limit = 20

    weaviate_client_mock.query("test_collection", query, limit)

    collections = weaviate_client_mock.client.collections
    collections.get.assert_called_with("test_collection")
    collections.query.bm25.assert_called_with(query=query, limit=limit)


def test_query_objects_without_limit(weaviate_client_mock):
    """query falls back to a limit of 10 when none is supplied."""
    query = "name:John"

    weaviate_client_mock.query("test_collection", query)

    collections = weaviate_client_mock.client.collections
    collections.get.assert_called_with("test_collection")
    collections.query.bm25.assert_called_with(query=query, limit=10)


# ---------------------------------------------------------------------------
# Failure-path tests
#
# Failures are injected with patch.object on the mocked client held by the
# fixture. A dotted-string patch target must be an importable module path,
# and no `weaviate_client` module is imported or defined by this test file.
# ---------------------------------------------------------------------------


def test_create_collection_failure(weaviate_client_mock):
    """An error from collections.create is reported as a creation error."""
    with patch.object(
        weaviate_client_mock.client.collections,
        "create",
        side_effect=Exception("Create error"),
    ):
        with pytest.raises(Exception, match="Error creating collection"):
            weaviate_client_mock.create_collection(
                "test_collection", [{"name": "property"}]
            )


def test_add_object_failure(weaviate_client_mock):
    """An error from data.insert is reported as an add error."""
    properties = {"name": "John"}

    with patch.object(
        weaviate_client_mock.client.collections.data,
        "insert",
        side_effect=Exception("Insert error"),
    ):
        with pytest.raises(Exception, match="Error adding object"):
            weaviate_client_mock.add("test_collection", properties)


def test_query_objects_failure(weaviate_client_mock):
    """An error from query.bm25 is reported as a query error."""
    query = "name:John"

    with patch.object(
        weaviate_client_mock.client.collections.query,
        "bm25",
        side_effect=Exception("Query error"),
    ):
        with pytest.raises(Exception, match="Error querying objects"):
            weaviate_client_mock.query("test_collection", query)


def test_update_object_failure(weaviate_client_mock):
    """An error from data.update is reported as an update error."""
    object_id = "12345"
    properties = {"name": "Jane"}

    with patch.object(
        weaviate_client_mock.client.collections.data,
        "update",
        side_effect=Exception("Update error"),
    ):
        with pytest.raises(Exception, match="Error updating object"):
            weaviate_client_mock.update(
                "test_collection", object_id, properties
            )


def test_delete_object_failure(weaviate_client_mock):
    """An error from data.delete_by_id is reported as a delete error."""
    object_id = "12345"

    with patch.object(
        weaviate_client_mock.client.collections.data,
        "delete_by_id",
        side_effect=Exception("Delete error"),
    ):
        with pytest.raises(Exception, match="Error deleting object"):
            weaviate_client_mock.delete("test_collection", object_id)