From 830c7bf651dcac656b5cd5b120133dac8e6a7f5d Mon Sep 17 00:00:00 2001 From: Kye Date: Sat, 9 Dec 2023 17:39:50 -0800 Subject: [PATCH] [WEAVIAT] --- docs/swarms/memory/weaviate.md | 215 ++++++++++++++++++++++++++++++ mkdocs.yml | 1 + playground/memory/chroma_usage.py | 5 +- swarms/memory/base_vectordb.py | 4 +- swarms/memory/weaviate.py | 214 ++++++++++++++++++++++++++++- swarms/structs/agent.py | 7 +- tests/memory/test_weaviate.py | 196 +++++++++++++++++++++++++++ 7 files changed, 630 insertions(+), 12 deletions(-) create mode 100644 docs/swarms/memory/weaviate.md create mode 100644 tests/memory/test_weaviate.py diff --git a/docs/swarms/memory/weaviate.md b/docs/swarms/memory/weaviate.md new file mode 100644 index 00000000..da28be4e --- /dev/null +++ b/docs/swarms/memory/weaviate.md @@ -0,0 +1,215 @@ +# Weaviate API Client Documentation + +## Overview + +The Weaviate API Client is an interface to Weaviate, a vector database with a GraphQL API. This client allows you to interact with Weaviate programmatically, making it easier to create collections, add objects, query data, update objects, and delete objects within your Weaviate instance. + +This documentation provides a comprehensive guide on how to use the Weaviate API Client, including its initialization, methods, and usage examples. + +## Table of Contents + +- [Installation](#installation) +- [Initialization](#initialization) +- [Methods](#methods) + - [create_collection](#create-collection) + - [add](#add) + - [query](#query) + - [update](#update) + - [delete](#delete) +- [Examples](#examples) + +## Installation + +Before using the Weaviate API Client, make sure to install the `weaviate-client` library. You can install it using pip: + +```bash +pip install weaviate-client +``` + +## Initialization + +To use the Weaviate API Client, you need to initialize an instance of the `WeaviateClient` class. Here are the parameters you can pass to the constructor: + +| Parameter | Type | Description | +|----------------------|----------------|----------------------------------------------------------------------------------------------------------------------------------| +| `http_host` | str | The HTTP host of the Weaviate server. | +| `http_port` | str | The HTTP port of the Weaviate server. | +| `http_secure` | bool | Whether to use HTTPS. | +| `grpc_host` | Optional[str] | The gRPC host of the Weaviate server. (Optional) | +| `grpc_port` | Optional[str] | The gRPC port of the Weaviate server. (Optional) | +| `grpc_secure` | Optional[bool] | Whether to use gRPC over TLS. (Optional) | +| `auth_client_secret` | Optional[Any] | The authentication client secret. (Optional) | +| `additional_headers` | Optional[Dict[str, str]] | Additional headers to send with requests. (Optional) | +| `additional_config` | Optional[weaviate.AdditionalConfig] | Additional configuration for the client. (Optional) | +| `connection_params` | Dict[str, Any] | Dictionary containing connection parameters. This parameter is used internally and can be ignored in most cases. | + +Here's an example of how to initialize a WeaviateClient: + +```python +from weaviate_client import WeaviateClient + +weaviate_client = WeaviateClient( + http_host="YOUR_HTTP_HOST", + http_port="YOUR_HTTP_PORT", + http_secure=True, + grpc_host="YOUR_gRPC_HOST", + grpc_port="YOUR_gRPC_PORT", + grpc_secure=True, + auth_client_secret="YOUR_APIKEY", + additional_headers={"X-OpenAI-Api-Key": "YOUR_OPENAI_APIKEY"}, + additional_config=None, # You can pass additional configuration here +) +``` + +## Methods + +### `create_collection` + +The `create_collection` method allows you to create a new collection in Weaviate. A collection is a container for storing objects with specific properties. + +#### Parameters + +- `name` (str): The name of the collection. +- `properties` (List[Dict[str, Any]]): A list of dictionaries specifying the properties of objects to be stored in the collection. +- `vectorizer_config` (Any, optional): Additional vectorizer configuration for the collection. (Optional) + +#### Usage + +```python +weaviate_client.create_collection( + name="my_collection", + properties=[ + {"name": "property1", "dataType": ["string"]}, + {"name": "property2", "dataType": ["int"]}, + ], + vectorizer_config=None # Optional vectorizer configuration +) +``` + +### `add` + +The `add` method allows you to add an object to a specified collection in Weaviate. + +#### Parameters + +- `collection_name` (str): The name of the collection where the object will be added. +- `properties` (Dict[str, Any]): A dictionary specifying the properties of the object to be added. + +#### Usage + +```python +weaviate_client.add( + collection_name="my_collection", + properties={"property1": "value1", "property2": 42} +) +``` + +### `query` + +The `query` method allows you to query objects from a specified collection in Weaviate. + +#### Parameters + +- `collection_name` (str): The name of the collection to query. +- `query` (str): The query string specifying the search criteria. +- `limit` (int, optional): The maximum number of results to return. (Default: 10) + +#### Usage + +```python +results = weaviate_client.query( + collection_name="my_collection", + query="property1:value1", + limit=20 # Optional, specify the limit + + if needed +) +``` + +### `update` + +The `update` method allows you to update an object in a specified collection in Weaviate. + +#### Parameters + +- `collection_name` (str): The name of the collection where the object exists. +- `object_id` (str): The ID of the object to be updated. +- `properties` (Dict[str, Any]): A dictionary specifying the properties to update. + +#### Usage + +```python +weaviate_client.update( + collection_name="my_collection", + object_id="object123", + properties={"property1": "new_value", "property2": 99} +) +``` + +### `delete` + +The `delete` method allows you to delete an object from a specified collection in Weaviate. + +#### Parameters + +- `collection_name` (str): The name of the collection from which to delete the object. +- `object_id` (str): The ID of the object to delete. + +#### Usage + +```python +weaviate_client.delete( + collection_name="my_collection", + object_id="object123" +) +``` + +## Examples + +Here are three examples demonstrating how to use the Weaviate API Client for common tasks: + +### Example 1: Creating a Collection + +```python +weaviate_client.create_collection( + name="people", + properties=[ + {"name": "name", "dataType": ["string"]}, + {"name": "age", "dataType": ["int"]} + ] +) +``` + +### Example 2: Adding an Object + +```python +weaviate_client.add( + collection_name="people", + properties={"name": "John", "age": 30} +) +``` + +### Example 3: Querying Objects + +```python +results = weaviate_client.query( + collection_name="people", + query="name:John", + limit=5 +) +``` + +These examples cover the basic operations of creating collections, adding objects, and querying objects using the Weaviate API Client. + +## Additional Information and Tips + +- If you encounter any errors during the operations, the client will raise exceptions with informative error messages. +- You can explore more advanced features and configurations in the Weaviate documentation. +- Make sure to handle authentication and security appropriately when using the client in production environments. + +## References and Resources + +- [Weaviate Documentation](https://weaviate.readthedocs.io/en/latest/): Official documentation for Weaviate. +- [Weaviate GitHub Repository](https://github.com/semi-technologies/weaviate): The source code and issue tracker for Weaviate. + +This documentation provides a comprehensive guide on using the Weaviate API Client to interact with Weaviate, making it easier to manage and query your data. \ No newline at end of file diff --git a/mkdocs.yml b/mkdocs.yml index f2f6294e..aed284a0 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -101,6 +101,7 @@ nav: - Agent: "swarms/structs/agent.md" - SequentialWorkflow: 'swarms/structs/sequential_workflow.md' - swarms.memory: + - Weaviate: "swarms/memory/weaviate.md" - PineconeVectorStoreStore: "swarms/memory/pinecone.md" - PGVectorStore: "swarms/memory/pg.md" - Guides: diff --git a/playground/memory/chroma_usage.py b/playground/memory/chroma_usage.py index 21ae475a..c17efa3a 100644 --- a/playground/memory/chroma_usage.py +++ b/playground/memory/chroma_usage.py @@ -2,9 +2,10 @@ from swarms.memory import chroma chromadbcl = chroma.ChromaClient() -chromadbcl.add_vectors(["This is a document", "BONSAIIIIIII", "the walking dead"]) +chromadbcl.add_vectors( + ["This is a document", "BONSAIIIIIII", "the walking dead"] +) results = chromadbcl.search_vectors("zombie", limit=1) print(results) - diff --git a/swarms/memory/base_vectordb.py b/swarms/memory/base_vectordb.py index 991bc8b5..841c6147 100644 --- a/swarms/memory/base_vectordb.py +++ b/swarms/memory/base_vectordb.py @@ -17,9 +17,7 @@ class VectorDatabase(ABC): pass @abstractmethod - def query( - self, text: str, num_results: int - ) -> Dict[str, Any]: + def query(self, text: str, num_results: int) -> Dict[str, Any]: """ Query the database for vectors similar to the given vector. diff --git a/swarms/memory/weaviate.py b/swarms/memory/weaviate.py index a482f71b..2f06e302 100644 --- a/swarms/memory/weaviate.py +++ b/swarms/memory/weaviate.py @@ -1,4 +1,216 @@ """ Weaviate API Client - """ +import os +import subprocess +from typing import Any, Dict, List, Optional + +from swarms.memory.base_vectordb import VectorDatabase + +try: + import weaviate +except ImportError as error: + print("pip install weaviate-client") + subprocess.run(["pip", "install", "weaviate-client"]) + + +class WeaviateClient(VectorDatabase): + """ + + Weaviate API Client + Interface to Weaviate, a vector database with a GraphQL API. + + Args: + http_host (str): The HTTP host of the Weaviate server. + http_port (str): The HTTP port of the Weaviate server. + http_secure (bool): Whether to use HTTPS. + grpc_host (Optional[str]): The gRPC host of the Weaviate server. + grpc_port (Optional[str]): The gRPC port of the Weaviate server. + grpc_secure (Optional[bool]): Whether to use gRPC over TLS. + auth_client_secret (Optional[Any]): The authentication client secret. + additional_headers (Optional[Dict[str, str]]): Additional headers to send with requests. + additional_config (Optional[weaviate.AdditionalConfig]): Additional configuration for the client. + + Methods: + create_collection: Create a new collection in Weaviate. + add: Add an object to a specified collection. + query: Query objects from a specified collection. + update: Update an object in a specified collection. + delete: Delete an object from a specified collection. + + Examples: + >>> from swarms.memory import WeaviateClient + """ + + def __init__( + self, + http_host: str, + http_port: str, + http_secure: bool, + grpc_host: Optional[str] = None, + grpc_port: Optional[str] = None, + grpc_secure: Optional[bool] = None, + auth_client_secret: Optional[Any] = None, + additional_headers: Optional[Dict[str, str]] = None, + additional_config: Optional[weaviate.AdditionalConfig] = None, + connection_params: Dict[str, Any] = None, + *args, + **kwargs, + ): + super().__init__(*args, **kwargs) + self.http_host = http_host + self.http_port = http_port + self.http_secure = http_secure + self.grpc_host = grpc_host + self.grpc_port = grpc_port + self.grpc_secure = grpc_secure + self.auth_client_secret = auth_client_secret + self.additional_headers = additional_headers + self.additional_config = additional_config + self.connection_params = connection_params + + # If connection_params are provided, use them to initialize the client. + connection_params = weaviate.ConnectionParams.from_params( + http_host=http_host, + http_port=http_port, + http_secure=http_secure, + grpc_host=grpc_host, + grpc_port=grpc_port, + grpc_secure=grpc_secure, + ) + + # If additional headers are provided, add them to the connection params. + self.client = weaviate.WeaviateClient( + connection_params=connection_params, + auth_client_secret=auth_client_secret, + additional_headers=additional_headers, + additional_config=additional_config, + ) + + def create_collection( + self, + name: str, + properties: List[Dict[str, Any]], + vectorizer_config: Any = None, + ): + """Create a new collection in Weaviate. + + Args: + name (str): _description_ + properties (List[Dict[str, Any]]): _description_ + vectorizer_config (Any, optional): _description_. Defaults to None. + """ + try: + out = self.client.collections.create( + name=name, + vectorizer_config=vectorizer_config, + properties=properties, + ) + print(out) + except Exception as error: + print(f"Error creating collection: {error}") + raise + + def add(self, collection_name: str, properties: Dict[str, Any]): + """Add an object to a specified collection. + + Args: + collection_name (str): _description_ + properties (Dict[str, Any]): _description_ + + Returns: + _type_: _description_ + """ + try: + collection = self.client.collections.get(collection_name) + return collection.data.insert(properties) + except Exception as error: + print(f"Error adding object: {error}") + raise + + def query( + self, collection_name: str, query: str, limit: int = 10 + ): + """Query objects from a specified collection. + + Args: + collection_name (str): _description_ + query (str): _description_ + limit (int, optional): _description_. Defaults to 10. + + Returns: + _type_: _description_ + """ + try: + collection = self.client.collections.get(collection_name) + response = collection.query.bm25(query=query, limit=limit) + return [o.properties for o in response.objects] + except Exception as error: + print(f"Error querying objects: {error}") + raise + + def update( + self, + collection_name: str, + object_id: str, + properties: Dict[str, Any], + ): + """UPdate an object in a specified collection. + + Args: + collection_name (str): _description_ + object_id (str): _description_ + properties (Dict[str, Any]): _description_ + """ + try: + collection = self.client.collections.get(collection_name) + collection.data.update(object_id, properties) + except Exception as error: + print(f"Error updating object: {error}") + raise + + def delete(self, collection_name: str, object_id: str): + """Delete an object from a specified collection. + + Args: + collection_name (str): _description_ + object_id (str): _description_ + """ + try: + collection = self.client.collections.get(collection_name) + collection.data.delete_by_id(object_id) + except Exception as error: + print(f"Error deleting object: {error}") + raise + + +# # Example usage +# connection_params = { +# "http_host": "YOUR_HTTP_HOST", +# "http_port": "YOUR_HTTP_PORT", +# "http_secure": True, +# "grpc_host": "YOUR_gRPC_HOST", +# "grpc_port": "YOUR_gRPC_PORT", +# "grpc_secure": True, +# "auth_client_secret": weaviate.AuthApiKey("YOUR_APIKEY"), +# "additional_headers": {"X-OpenAI-Api-Key": "YOUR_OPENAI_APIKEY"}, +# "additional_config": weaviate.AdditionalConfig( +# startup_period=10, timeout=(5, 15) +# ), +# } + +# weaviate_client = WeaviateClient(connection_params) + + +# # Example usage +# weaviate_client = WeaviateClient( +# http_host="YOUR_HTTP_HOST", +# http_port="YOUR_HTTP_PORT", +# http_secure=True, +# grpc_host="YOUR_gRPC_HOST", +# grpc_port="YOUR_gRPC_PORT", +# grpc_secure=True, +# auth_client_secret=weaviate.AuthApiKey("YOUR_APIKEY"), +# additional_headers={"X-OpenAI-Api-Key": "YOUR_OPENAI_APIKEY"}, +# additional_config=weaviate.AdditionalConfig(startup_period=10, timeout=(5, 15)) +# ) diff --git a/swarms/structs/agent.py b/swarms/structs/agent.py index aca1b041..2ce86479 100644 --- a/swarms/structs/agent.py +++ b/swarms/structs/agent.py @@ -763,11 +763,7 @@ class Agent: """ return agent_history_prompt - def agent_memory_prompt( - self, - query, - prompt - ): + def agent_memory_prompt(self, query, prompt): """ Generate the agent long term memory prompt @@ -789,7 +785,6 @@ class Agent: return context_injected_prompt - async def run_concurrent(self, tasks: List[str], **kwargs): """ Run a batch of tasks concurrently and handle an infinite level of task inputs. diff --git a/tests/memory/test_weaviate.py b/tests/memory/test_weaviate.py new file mode 100644 index 00000000..09dc6d45 --- /dev/null +++ b/tests/memory/test_weaviate.py @@ -0,0 +1,196 @@ +import pytest +from unittest.mock import Mock, patch +from swarms.memory.weaviate import WeaviateClient + + +# Define fixture for a WeaviateClient instance with mocked methods +@pytest.fixture +def weaviate_client_mock(): + client = WeaviateClient( + http_host="mock_host", + http_port="mock_port", + http_secure=False, + grpc_host="mock_grpc_host", + grpc_port="mock_grpc_port", + grpc_secure=False, + auth_client_secret="mock_api_key", + additional_headers={ + "X-OpenAI-Api-Key": "mock_openai_api_key" + }, + additional_config=Mock(), + ) + + # Mock the methods + client.client.collections.create = Mock() + client.client.collections.get = Mock() + client.client.collections.query = Mock() + client.client.collections.data.insert = Mock() + client.client.collections.data.update = Mock() + client.client.collections.data.delete_by_id = Mock() + + return client + + +# Define tests for the WeaviateClient class +def test_create_collection(weaviate_client_mock): + # Test creating a collection + weaviate_client_mock.create_collection( + "test_collection", [{"name": "property"}] + ) + weaviate_client_mock.client.collections.create.assert_called_with( + name="test_collection", + vectorizer_config=None, + properties=[{"name": "property"}], + ) + + +def test_add_object(weaviate_client_mock): + # Test adding an object + properties = {"name": "John"} + weaviate_client_mock.add("test_collection", properties) + weaviate_client_mock.client.collections.get.assert_called_with( + "test_collection" + ) + weaviate_client_mock.client.collections.data.insert.assert_called_with( + properties + ) + + +def test_query_objects(weaviate_client_mock): + # Test querying objects + query = "name:John" + weaviate_client_mock.query("test_collection", query) + weaviate_client_mock.client.collections.get.assert_called_with( + "test_collection" + ) + weaviate_client_mock.client.collections.query.bm25.assert_called_with( + query=query, limit=10 + ) + + +def test_update_object(weaviate_client_mock): + # Test updating an object + object_id = "12345" + properties = {"name": "Jane"} + weaviate_client_mock.update( + "test_collection", object_id, properties + ) + weaviate_client_mock.client.collections.get.assert_called_with( + "test_collection" + ) + weaviate_client_mock.client.collections.data.update.assert_called_with( + object_id, properties + ) + + +def test_delete_object(weaviate_client_mock): + # Test deleting an object + object_id = "12345" + weaviate_client_mock.delete("test_collection", object_id) + weaviate_client_mock.client.collections.get.assert_called_with( + "test_collection" + ) + weaviate_client_mock.client.collections.data.delete_by_id.assert_called_with( + object_id + ) + + +def test_create_collection_with_vectorizer_config( + weaviate_client_mock, +): + # Test creating a collection with vectorizer configuration + vectorizer_config = {"config_key": "config_value"} + weaviate_client_mock.create_collection( + "test_collection", [{"name": "property"}], vectorizer_config + ) + weaviate_client_mock.client.collections.create.assert_called_with( + name="test_collection", + vectorizer_config=vectorizer_config, + properties=[{"name": "property"}], + ) + + +def test_query_objects_with_limit(weaviate_client_mock): + # Test querying objects with a specified limit + query = "name:John" + limit = 20 + weaviate_client_mock.query("test_collection", query, limit) + weaviate_client_mock.client.collections.get.assert_called_with( + "test_collection" + ) + weaviate_client_mock.client.collections.query.bm25.assert_called_with( + query=query, limit=limit + ) + + +def test_query_objects_without_limit(weaviate_client_mock): + # Test querying objects without specifying a limit + query = "name:John" + weaviate_client_mock.query("test_collection", query) + weaviate_client_mock.client.collections.get.assert_called_with( + "test_collection" + ) + weaviate_client_mock.client.collections.query.bm25.assert_called_with( + query=query, limit=10 + ) + + +def test_create_collection_failure(weaviate_client_mock): + # Test failure when creating a collection + with patch( + "weaviate_client.weaviate.collections.create", + side_effect=Exception("Create error"), + ): + with pytest.raises( + Exception, match="Error creating collection" + ): + weaviate_client_mock.create_collection( + "test_collection", [{"name": "property"}] + ) + + +def test_add_object_failure(weaviate_client_mock): + # Test failure when adding an object + properties = {"name": "John"} + with patch( + "weaviate_client.weaviate.collections.data.insert", + side_effect=Exception("Insert error"), + ): + with pytest.raises(Exception, match="Error adding object"): + weaviate_client_mock.add("test_collection", properties) + + +def test_query_objects_failure(weaviate_client_mock): + # Test failure when querying objects + query = "name:John" + with patch( + "weaviate_client.weaviate.collections.query.bm25", + side_effect=Exception("Query error"), + ): + with pytest.raises(Exception, match="Error querying objects"): + weaviate_client_mock.query("test_collection", query) + + +def test_update_object_failure(weaviate_client_mock): + # Test failure when updating an object + object_id = "12345" + properties = {"name": "Jane"} + with patch( + "weaviate_client.weaviate.collections.data.update", + side_effect=Exception("Update error"), + ): + with pytest.raises(Exception, match="Error updating object"): + weaviate_client_mock.update( + "test_collection", object_id, properties + ) + + +def test_delete_object_failure(weaviate_client_mock): + # Test failure when deleting an object + object_id = "12345" + with patch( + "weaviate_client.weaviate.collections.data.delete_by_id", + side_effect=Exception("Delete error"), + ): + with pytest.raises(Exception, match="Error deleting object"): + weaviate_client_mock.delete("test_collection", object_id)