commit
a4aab51655
@ -0,0 +1,10 @@
|
|||||||
|
from swarms.memory import chroma
|
||||||
|
|
||||||
|
chromadbcl = chroma.ChromaClient()
|
||||||
|
|
||||||
|
chromadbcl.add_vectors(["This is a document", "BONSAIIIIIII", "the walking dead"])
|
||||||
|
|
||||||
|
results = chromadbcl.search_vectors("zombie", limit=1)
|
||||||
|
|
||||||
|
print(results)
|
||||||
|
|
@ -0,0 +1,106 @@
|
|||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Any, Dict, List
|
||||||
|
from chromadb.utils import embedding_functions
|
||||||
|
from httpx import RequestError
|
||||||
|
import chromadb
|
||||||
|
|
||||||
|
from swarms.memory.base_vector_db import VectorDatabase
|
||||||
|
|
||||||
|
|
||||||
|
class ChromaClient(VectorDatabase):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
collection_name: str = "chromadb-collection",
|
||||||
|
model_name: str = "BAAI/bge-small-en-v1.5",
|
||||||
|
):
|
||||||
|
try:
|
||||||
|
self.client = chromadb.Client()
|
||||||
|
self.collection_name = collection_name
|
||||||
|
self.model = None
|
||||||
|
self.collection = None
|
||||||
|
self._load_embedding_model(model_name)
|
||||||
|
self._setup_collection()
|
||||||
|
except RequestError as e:
|
||||||
|
print(f"Error setting up QdrantClient: {e}")
|
||||||
|
|
||||||
|
def _load_embedding_model(self, model_name: str):
|
||||||
|
"""
|
||||||
|
Loads the sentence embedding model specified by the model name.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_name (str): The name of the model to load for generating embeddings.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
self.model =embedding_functions.SentenceTransformerEmbeddingFunction(model_name=model_name)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error loading embedding model: {e}")
|
||||||
|
|
||||||
|
def _setup_collection(self):
|
||||||
|
try:
|
||||||
|
self.collection = self.client.get_collection(name=self.collection_name, embedding_function=self.model)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"{e}. Creating new collection: {self.collection}")
|
||||||
|
|
||||||
|
self.collection = self.client.create_collection(name=self.collection_name, embedding_function=self.model)
|
||||||
|
|
||||||
|
|
||||||
|
def add_vectors(self, docs: List[str]):
|
||||||
|
"""
|
||||||
|
Adds vector representations of documents to the Qdrant collection.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
docs (List[dict]): A list of documents where each document is a dictionary with at least a 'page_content' key.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
OperationResponse or None: Returns the operation information if successful, otherwise None.
|
||||||
|
"""
|
||||||
|
points = []
|
||||||
|
ids = []
|
||||||
|
for i, doc in enumerate(docs):
|
||||||
|
try:
|
||||||
|
points.append(doc)
|
||||||
|
ids.append("id"+str(i))
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error processing document at index {i}: {e}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.collection.add(
|
||||||
|
documents=points,
|
||||||
|
ids=ids
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error adding vectors: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def search_vectors(self, query: str, limit: int = 2):
|
||||||
|
"""
|
||||||
|
Searches the collection for vectors similar to the query vector.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query (str): The query string to be converted into a vector and used for searching.
|
||||||
|
limit (int): The number of search results to return. Defaults to 3.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
SearchResult or None: Returns the search results if successful, otherwise None.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
search_result = self.collection.query(
|
||||||
|
query_texts=query,
|
||||||
|
n_results=limit,
|
||||||
|
)
|
||||||
|
return search_result
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error searching vectors: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def add(self, vector: Dict[str, Any], metadata: Dict[str, Any]) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def query(self, vector: Dict[str, Any], num_results: int) -> Dict[str, Any]:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def delete(self, vector_id: str) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def update(self, vector_id: str, vector: Dict[str, Any], metadata: Dict[str, Any]) -> None:
|
||||||
|
pass
|
Loading…
Reference in new issue