You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
161 lines
5.5 KiB
161 lines
5.5 KiB
from typing import List
|
|
|
|
from httpx import RequestError
|
|
|
|
try:
|
|
from sentence_transformers import SentenceTransformer
|
|
except ImportError:
|
|
print("Please install the sentence-transformers package")
|
|
print("pip install sentence-transformers")
|
|
|
|
try:
|
|
from qdrant_client import QdrantClient
|
|
from qdrant_client.http.models import (
|
|
Distance,
|
|
PointStruct,
|
|
VectorParams,
|
|
)
|
|
except ImportError:
|
|
print("Please install the qdrant-client package")
|
|
print("pip install qdrant-client")
|
|
|
|
|
|
class Qdrant:
|
|
"""
|
|
Qdrant class for managing collections and performing vector operations using QdrantClient.
|
|
|
|
Attributes:
|
|
client (QdrantClient): The Qdrant client for interacting with the Qdrant server.
|
|
collection_name (str): Name of the collection to be managed in Qdrant.
|
|
model (SentenceTransformer): The model used for generating sentence embeddings.
|
|
|
|
Args:
|
|
api_key (str): API key for authenticating with Qdrant.
|
|
host (str): Host address of the Qdrant server.
|
|
port (int): Port number of the Qdrant server. Defaults to 6333.
|
|
collection_name (str): Name of the collection to be used or created. Defaults to "qdrant".
|
|
model_name (str): Name of the model to be used for embeddings. Defaults to "BAAI/bge-small-en-v1.5".
|
|
https (bool): Flag to indicate if HTTPS should be used. Defaults to True.
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
api_key: str,
|
|
host: str,
|
|
port: int = 6333,
|
|
collection_name: str = "qdrant",
|
|
model_name: str = "BAAI/bge-small-en-v1.5",
|
|
https: bool = True,
|
|
):
|
|
try:
|
|
self.client = QdrantClient(
|
|
url=host, port=port, api_key=api_key
|
|
)
|
|
self.collection_name = collection_name
|
|
self._load_embedding_model(model_name)
|
|
self._setup_collection()
|
|
except RequestError as e:
|
|
print(f"Error setting up QdrantClient: {e}")
|
|
|
|
def _load_embedding_model(self, model_name: str):
|
|
"""
|
|
Loads the sentence embedding model specified by the model name.
|
|
|
|
Args:
|
|
model_name (str): The name of the model to load for generating embeddings.
|
|
"""
|
|
try:
|
|
self.model = SentenceTransformer(model_name)
|
|
except Exception as e:
|
|
print(f"Error loading embedding model: {e}")
|
|
|
|
def _setup_collection(self):
|
|
try:
|
|
exists = self.client.get_collection(self.collection_name)
|
|
if exists:
|
|
print(
|
|
f"Collection '{self.collection_name}' already"
|
|
" exists."
|
|
)
|
|
except Exception:
|
|
self.client.create_collection(
|
|
collection_name=self.collection_name,
|
|
vectors_config=VectorParams(
|
|
size=self.model.get_sentence_embedding_dimension(),
|
|
distance=Distance.DOT,
|
|
),
|
|
)
|
|
print(f"Collection '{self.collection_name}' created.")
|
|
|
|
def add(self, docs: List[dict], *args, **kwargs):
|
|
"""
|
|
Adds vector representations of documents to the Qdrant collection.
|
|
|
|
Args:
|
|
docs (List[dict]): A list of documents where each document is a dictionary with at least a 'page_content' key.
|
|
|
|
Returns:
|
|
OperationResponse or None: Returns the operation information if successful, otherwise None.
|
|
"""
|
|
points = []
|
|
for i, doc in enumerate(docs):
|
|
try:
|
|
if "page_content" in doc:
|
|
embedding = self.model.encode(
|
|
doc["page_content"], normalize_embeddings=True
|
|
)
|
|
points.append(
|
|
PointStruct(
|
|
id=i + 1,
|
|
vector=embedding,
|
|
payload={"content": doc["page_content"]},
|
|
)
|
|
)
|
|
else:
|
|
print(
|
|
f"Document at index {i} is missing"
|
|
" 'page_content' key"
|
|
)
|
|
except Exception as e:
|
|
print(f"Error processing document at index {i}: {e}")
|
|
|
|
try:
|
|
operation_info = self.client.upsert(
|
|
collection_name=self.collection_name,
|
|
wait=True,
|
|
points=points,
|
|
*args,
|
|
**kwargs,
|
|
)
|
|
return operation_info
|
|
except Exception as e:
|
|
print(f"Error adding vectors: {e}")
|
|
return None
|
|
|
|
def query(self, query: str, limit: int = 3, *args, **kwargs):
|
|
"""
|
|
Searches the collection for vectors similar to the query vector.
|
|
|
|
Args:
|
|
query (str): The query string to be converted into a vector and used for searching.
|
|
limit (int): The number of search results to return. Defaults to 3.
|
|
|
|
Returns:
|
|
SearchResult or None: Returns the search results if successful, otherwise None.
|
|
"""
|
|
try:
|
|
query_vector = self.model.encode(
|
|
query, normalize_embeddings=True, *args, **kwargs
|
|
)
|
|
search_result = self.client.search(
|
|
collection_name=self.collection_name,
|
|
query_vector=query_vector,
|
|
limit=limit,
|
|
*args,
|
|
**kwargs,
|
|
)
|
|
return search_result
|
|
except Exception as e:
|
|
print(f"Error searching vectors: {e}")
|
|
return None
|