parent
a4aab51655
commit
b4e614ce51
@ -1,106 +0,0 @@
|
|||||||
from abc import ABC, abstractmethod
|
|
||||||
from typing import Any, Dict, List
|
|
||||||
from chromadb.utils import embedding_functions
|
|
||||||
from httpx import RequestError
|
|
||||||
import chromadb
|
|
||||||
|
|
||||||
from swarms.memory.base_vector_db import VectorDatabase
|
|
||||||
|
|
||||||
|
|
||||||
class ChromaClient(VectorDatabase):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
collection_name: str = "chromadb-collection",
|
|
||||||
model_name: str = "BAAI/bge-small-en-v1.5",
|
|
||||||
):
|
|
||||||
try:
|
|
||||||
self.client = chromadb.Client()
|
|
||||||
self.collection_name = collection_name
|
|
||||||
self.model = None
|
|
||||||
self.collection = None
|
|
||||||
self._load_embedding_model(model_name)
|
|
||||||
self._setup_collection()
|
|
||||||
except RequestError as e:
|
|
||||||
print(f"Error setting up QdrantClient: {e}")
|
|
||||||
|
|
||||||
def _load_embedding_model(self, model_name: str):
|
|
||||||
"""
|
|
||||||
Loads the sentence embedding model specified by the model name.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model_name (str): The name of the model to load for generating embeddings.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
self.model =embedding_functions.SentenceTransformerEmbeddingFunction(model_name=model_name)
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error loading embedding model: {e}")
|
|
||||||
|
|
||||||
def _setup_collection(self):
|
|
||||||
try:
|
|
||||||
self.collection = self.client.get_collection(name=self.collection_name, embedding_function=self.model)
|
|
||||||
except Exception as e:
|
|
||||||
print(f"{e}. Creating new collection: {self.collection}")
|
|
||||||
|
|
||||||
self.collection = self.client.create_collection(name=self.collection_name, embedding_function=self.model)
|
|
||||||
|
|
||||||
|
|
||||||
def add_vectors(self, docs: List[str]):
|
|
||||||
"""
|
|
||||||
Adds vector representations of documents to the Qdrant collection.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
docs (List[dict]): A list of documents where each document is a dictionary with at least a 'page_content' key.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
OperationResponse or None: Returns the operation information if successful, otherwise None.
|
|
||||||
"""
|
|
||||||
points = []
|
|
||||||
ids = []
|
|
||||||
for i, doc in enumerate(docs):
|
|
||||||
try:
|
|
||||||
points.append(doc)
|
|
||||||
ids.append("id"+str(i))
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error processing document at index {i}: {e}")
|
|
||||||
|
|
||||||
try:
|
|
||||||
self.collection.add(
|
|
||||||
documents=points,
|
|
||||||
ids=ids
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error adding vectors: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
def search_vectors(self, query: str, limit: int = 2):
|
|
||||||
"""
|
|
||||||
Searches the collection for vectors similar to the query vector.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
query (str): The query string to be converted into a vector and used for searching.
|
|
||||||
limit (int): The number of search results to return. Defaults to 3.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
SearchResult or None: Returns the search results if successful, otherwise None.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
search_result = self.collection.query(
|
|
||||||
query_texts=query,
|
|
||||||
n_results=limit,
|
|
||||||
)
|
|
||||||
return search_result
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error searching vectors: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
def add(self, vector: Dict[str, Any], metadata: Dict[str, Any]) -> None:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def query(self, vector: Dict[str, Any], num_results: int) -> Dict[str, Any]:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def delete(self, vector_id: str) -> None:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def update(self, vector_id: str, vector: Dict[str, Any], metadata: Dict[str, Any]) -> None:
|
|
||||||
pass
|
|
Loading…
Reference in new issue