diff --git a/swarms/memory/__init__.py b/swarms/memory/__init__.py index 71a7871d..631db943 100644 --- a/swarms/memory/__init__.py +++ b/swarms/memory/__init__.py @@ -1,5 +1,6 @@ from swarms.memory.base_vectordb import VectorDatabase from swarms.memory.short_term_memory import ShortTermMemory from swarms.memory.sqlite import SQLiteDB +from swarms.memory.weaviate_db import WeaviateClient -__all__ = ["VectorDatabase", "ShortTermMemory", "SQLiteDB"] +__all__ = ["VectorDatabase", "ShortTermMemory", "SQLiteDB", "WeaviateClient"] diff --git a/swarms/memory/weaviate_db.py b/swarms/memory/weaviate_db.py new file mode 100644 index 00000000..83bb1b77 --- /dev/null +++ b/swarms/memory/weaviate_db.py @@ -0,0 +1,182 @@ +""" +Weaviate API Client +""" + +from typing import Any, Dict, List, Optional + +from swarms.memory.base_vectordb import VectorDatabase + +try: + import weaviate +except ImportError as error: + print("pip install weaviate-client") + +class WeaviateClient(VectorDatabase): + """ + + Weaviate API Client + Interface to Weaviate, a vector database with a GraphQL API. + + Args: + http_host (str): The HTTP host of the Weaviate server. + http_port (str): The HTTP port of the Weaviate server. + http_secure (bool): Whether to use HTTPS. + grpc_host (Optional[str]): The gRPC host of the Weaviate server. + grpc_port (Optional[str]): The gRPC port of the Weaviate server. + grpc_secure (Optional[bool]): Whether to use gRPC over TLS. + auth_client_secret (Optional[Any]): The authentication client secret. + additional_headers (Optional[Dict[str, str]]): Additional headers to send with requests. + additional_config (Optional[weaviate.AdditionalConfig]): Additional configuration for the client. + + Methods: + create_collection: Create a new collection in Weaviate. + add: Add an object to a specified collection. + query: Query objects from a specified collection. + update: Update an object in a specified collection. + delete: Delete an object from a specified collection. + + Examples: + >>> from swarms.memory import WeaviateClient + """ + + def __init__( + self, + http_host: str, + http_port: str, + http_secure: bool, + grpc_host: Optional[str] = None, + grpc_port: Optional[str] = None, + grpc_secure: Optional[bool] = None, + auth_client_secret: Optional[Any] = None, + additional_headers: Optional[Dict[str, str]] = None, + additional_config: Optional[weaviate.AdditionalConfig] = None, + connection_params: Dict[str, Any] = None, + *args, + **kwargs, + ): + super().__init__(*args, **kwargs) + self.http_host = http_host + self.http_port = http_port + self.http_secure = http_secure + self.grpc_host = grpc_host + self.grpc_port = grpc_port + self.grpc_secure = grpc_secure + self.auth_client_secret = auth_client_secret + self.additional_headers = additional_headers + self.additional_config = additional_config + self.connection_params = connection_params + + # If connection_params are provided, use them to initialize the client. + connection_params = weaviate.ConnectionParams.from_params( + http_host=http_host, + http_port=http_port, + http_secure=http_secure, + grpc_host=grpc_host, + grpc_port=grpc_port, + grpc_secure=grpc_secure, + ) + + # If additional headers are provided, add them to the connection params. + self.client = weaviate.WeaviateClient( + connection_params=connection_params, + auth_client_secret=auth_client_secret, + additional_headers=additional_headers, + additional_config=additional_config, + ) + + def create_collection( + self, + name: str, + properties: List[Dict[str, Any]], + vectorizer_config: Any = None, + ): + """Create a new collection in Weaviate. + + Args: + name (str): _description_ + properties (List[Dict[str, Any]]): _description_ + vectorizer_config (Any, optional): _description_. Defaults to None. + """ + try: + out = self.client.collections.create( + name=name, + vectorizer_config=vectorizer_config, + properties=properties, + ) + print(out) + except Exception as error: + print(f"Error creating collection: {error}") + raise + + def add(self, collection_name: str, properties: Dict[str, Any]): + """Add an object to a specified collection. + + Args: + collection_name (str): _description_ + properties (Dict[str, Any]): _description_ + + Returns: + _type_: _description_ + """ + try: + collection = self.client.collections.get(collection_name) + return collection.data.insert(properties) + except Exception as error: + print(f"Error adding object: {error}") + raise + + def query( + self, collection_name: str, query: str, limit: int = 10 + ): + """Query objects from a specified collection. + + Args: + collection_name (str): _description_ + query (str): _description_ + limit (int, optional): _description_. Defaults to 10. + + Returns: + _type_: _description_ + """ + try: + collection = self.client.collections.get(collection_name) + response = collection.query.bm25(query=query, limit=limit) + return [o.properties for o in response.objects] + except Exception as error: + print(f"Error querying objects: {error}") + raise + + def update( + self, + collection_name: str, + object_id: str, + properties: Dict[str, Any], + ): + """UPdate an object in a specified collection. + + Args: + collection_name (str): _description_ + object_id (str): _description_ + properties (Dict[str, Any]): _description_ + """ + try: + collection = self.client.collections.get(collection_name) + collection.data.update(object_id, properties) + except Exception as error: + print(f"Error updating object: {error}") + raise + + def delete(self, collection_name: str, object_id: str): + """Delete an object from a specified collection. + + Args: + collection_name (str): _description_ + object_id (str): _description_ + """ + try: + collection = self.client.collections.get(collection_name) + collection.data.delete_by_id(object_id) + except Exception as error: + print(f"Error deleting object: {error}") + raise +