parent
a199e95ed7
commit
6a9cd36a32
@ -1,11 +1,6 @@
|
||||
try:
|
||||
from swarms.memory.weaviate import WeaviateClient
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
from swarms.memory.base_vectordb import VectorDatabase
|
||||
|
||||
__all__ = [
|
||||
"WeaviateClient",
|
||||
"VectorDatabase",
|
||||
]
|
||||
|
@ -1,216 +0,0 @@
|
||||
"""
|
||||
Weaviate API Client
|
||||
"""
|
||||
import os
|
||||
import subprocess
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from swarms.memory.base_vectordb import VectorDatabase
|
||||
|
||||
try:
|
||||
import weaviate
|
||||
except ImportError as error:
|
||||
print("pip install weaviate-client")
|
||||
subprocess.run(["pip", "install", "weaviate-client"])
|
||||
|
||||
|
||||
class WeaviateClient(VectorDatabase):
|
||||
"""
|
||||
|
||||
Weaviate API Client
|
||||
Interface to Weaviate, a vector database with a GraphQL API.
|
||||
|
||||
Args:
|
||||
http_host (str): The HTTP host of the Weaviate server.
|
||||
http_port (str): The HTTP port of the Weaviate server.
|
||||
http_secure (bool): Whether to use HTTPS.
|
||||
grpc_host (Optional[str]): The gRPC host of the Weaviate server.
|
||||
grpc_port (Optional[str]): The gRPC port of the Weaviate server.
|
||||
grpc_secure (Optional[bool]): Whether to use gRPC over TLS.
|
||||
auth_client_secret (Optional[Any]): The authentication client secret.
|
||||
additional_headers (Optional[Dict[str, str]]): Additional headers to send with requests.
|
||||
additional_config (Optional[weaviate.AdditionalConfig]): Additional configuration for the client.
|
||||
|
||||
Methods:
|
||||
create_collection: Create a new collection in Weaviate.
|
||||
add: Add an object to a specified collection.
|
||||
query: Query objects from a specified collection.
|
||||
update: Update an object in a specified collection.
|
||||
delete: Delete an object from a specified collection.
|
||||
|
||||
Examples:
|
||||
>>> from swarms.memory import WeaviateClient
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
http_host: str,
|
||||
http_port: str,
|
||||
http_secure: bool,
|
||||
grpc_host: Optional[str] = None,
|
||||
grpc_port: Optional[str] = None,
|
||||
grpc_secure: Optional[bool] = None,
|
||||
auth_client_secret: Optional[Any] = None,
|
||||
additional_headers: Optional[Dict[str, str]] = None,
|
||||
additional_config: Optional[weaviate.AdditionalConfig] = None,
|
||||
connection_params: Dict[str, Any] = None,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.http_host = http_host
|
||||
self.http_port = http_port
|
||||
self.http_secure = http_secure
|
||||
self.grpc_host = grpc_host
|
||||
self.grpc_port = grpc_port
|
||||
self.grpc_secure = grpc_secure
|
||||
self.auth_client_secret = auth_client_secret
|
||||
self.additional_headers = additional_headers
|
||||
self.additional_config = additional_config
|
||||
self.connection_params = connection_params
|
||||
|
||||
# If connection_params are provided, use them to initialize the client.
|
||||
connection_params = weaviate.ConnectionParams.from_params(
|
||||
http_host=http_host,
|
||||
http_port=http_port,
|
||||
http_secure=http_secure,
|
||||
grpc_host=grpc_host,
|
||||
grpc_port=grpc_port,
|
||||
grpc_secure=grpc_secure,
|
||||
)
|
||||
|
||||
# If additional headers are provided, add them to the connection params.
|
||||
self.client = weaviate.WeaviateClient(
|
||||
connection_params=connection_params,
|
||||
auth_client_secret=auth_client_secret,
|
||||
additional_headers=additional_headers,
|
||||
additional_config=additional_config,
|
||||
)
|
||||
|
||||
def create_collection(
|
||||
self,
|
||||
name: str,
|
||||
properties: List[Dict[str, Any]],
|
||||
vectorizer_config: Any = None,
|
||||
):
|
||||
"""Create a new collection in Weaviate.
|
||||
|
||||
Args:
|
||||
name (str): _description_
|
||||
properties (List[Dict[str, Any]]): _description_
|
||||
vectorizer_config (Any, optional): _description_. Defaults to None.
|
||||
"""
|
||||
try:
|
||||
out = self.client.collections.create(
|
||||
name=name,
|
||||
vectorizer_config=vectorizer_config,
|
||||
properties=properties,
|
||||
)
|
||||
print(out)
|
||||
except Exception as error:
|
||||
print(f"Error creating collection: {error}")
|
||||
raise
|
||||
|
||||
def add(self, collection_name: str, properties: Dict[str, Any]):
|
||||
"""Add an object to a specified collection.
|
||||
|
||||
Args:
|
||||
collection_name (str): _description_
|
||||
properties (Dict[str, Any]): _description_
|
||||
|
||||
Returns:
|
||||
_type_: _description_
|
||||
"""
|
||||
try:
|
||||
collection = self.client.collections.get(collection_name)
|
||||
return collection.data.insert(properties)
|
||||
except Exception as error:
|
||||
print(f"Error adding object: {error}")
|
||||
raise
|
||||
|
||||
def query(
|
||||
self, collection_name: str, query: str, limit: int = 10
|
||||
):
|
||||
"""Query objects from a specified collection.
|
||||
|
||||
Args:
|
||||
collection_name (str): _description_
|
||||
query (str): _description_
|
||||
limit (int, optional): _description_. Defaults to 10.
|
||||
|
||||
Returns:
|
||||
_type_: _description_
|
||||
"""
|
||||
try:
|
||||
collection = self.client.collections.get(collection_name)
|
||||
response = collection.query.bm25(query=query, limit=limit)
|
||||
return [o.properties for o in response.objects]
|
||||
except Exception as error:
|
||||
print(f"Error querying objects: {error}")
|
||||
raise
|
||||
|
||||
def update(
|
||||
self,
|
||||
collection_name: str,
|
||||
object_id: str,
|
||||
properties: Dict[str, Any],
|
||||
):
|
||||
"""UPdate an object in a specified collection.
|
||||
|
||||
Args:
|
||||
collection_name (str): _description_
|
||||
object_id (str): _description_
|
||||
properties (Dict[str, Any]): _description_
|
||||
"""
|
||||
try:
|
||||
collection = self.client.collections.get(collection_name)
|
||||
collection.data.update(object_id, properties)
|
||||
except Exception as error:
|
||||
print(f"Error updating object: {error}")
|
||||
raise
|
||||
|
||||
def delete(self, collection_name: str, object_id: str):
|
||||
"""Delete an object from a specified collection.
|
||||
|
||||
Args:
|
||||
collection_name (str): _description_
|
||||
object_id (str): _description_
|
||||
"""
|
||||
try:
|
||||
collection = self.client.collections.get(collection_name)
|
||||
collection.data.delete_by_id(object_id)
|
||||
except Exception as error:
|
||||
print(f"Error deleting object: {error}")
|
||||
raise
|
||||
|
||||
|
||||
# # Example usage
|
||||
# connection_params = {
|
||||
# "http_host": "YOUR_HTTP_HOST",
|
||||
# "http_port": "YOUR_HTTP_PORT",
|
||||
# "http_secure": True,
|
||||
# "grpc_host": "YOUR_gRPC_HOST",
|
||||
# "grpc_port": "YOUR_gRPC_PORT",
|
||||
# "grpc_secure": True,
|
||||
# "auth_client_secret": weaviate.AuthApiKey("YOUR_APIKEY"),
|
||||
# "additional_headers": {"X-OpenAI-Api-Key": "YOUR_OPENAI_APIKEY"},
|
||||
# "additional_config": weaviate.AdditionalConfig(
|
||||
# startup_period=10, timeout=(5, 15)
|
||||
# ),
|
||||
# }
|
||||
|
||||
# weaviate_client = WeaviateClient(connection_params)
|
||||
|
||||
|
||||
# # Example usage
|
||||
# weaviate_client = WeaviateClient(
|
||||
# http_host="YOUR_HTTP_HOST",
|
||||
# http_port="YOUR_HTTP_PORT",
|
||||
# http_secure=True,
|
||||
# grpc_host="YOUR_gRPC_HOST",
|
||||
# grpc_port="YOUR_gRPC_PORT",
|
||||
# grpc_secure=True,
|
||||
# auth_client_secret=weaviate.AuthApiKey("YOUR_APIKEY"),
|
||||
# additional_headers={"X-OpenAI-Api-Key": "YOUR_OPENAI_APIKEY"},
|
||||
# additional_config=weaviate.AdditionalConfig(startup_period=10, timeout=(5, 15))
|
||||
# )
|
Loading…
Reference in new issue