|
|
|
@ -4,24 +4,33 @@ from httpx import RequestError
|
|
|
|
|
from qdrant_client import QdrantClient
|
|
|
|
|
from qdrant_client.http.models import Distance, VectorParams, PointStruct
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Qdrant:
|
|
|
|
|
def __init__(self, api_key: str, host: str, port: int = 6333, collection_name: str = "qdrant", model_name: str = "BAAI/bge-small-en-v1.5", https: bool = True):
|
|
|
|
|
def __init__(
|
|
|
|
|
self,
|
|
|
|
|
api_key: str,
|
|
|
|
|
host: str,
|
|
|
|
|
port: int = 6333,
|
|
|
|
|
collection_name: str = "qdrant",
|
|
|
|
|
model_name: str = "BAAI/bge-small-en-v1.5",
|
|
|
|
|
https: bool = True,
|
|
|
|
|
):
|
|
|
|
|
"""
|
|
|
|
|
Qdrant class for managing collections and performing vector operations using QdrantClient.
|
|
|
|
|
Qdrant class for managing collections and performing vector operations using QdrantClient.
|
|
|
|
|
|
|
|
|
|
Attributes:
|
|
|
|
|
client (QdrantClient): The Qdrant client for interacting with the Qdrant server.
|
|
|
|
|
collection_name (str): Name of the collection to be managed in Qdrant.
|
|
|
|
|
model (SentenceTransformer): The model used for generating sentence embeddings.
|
|
|
|
|
Attributes:
|
|
|
|
|
client (QdrantClient): The Qdrant client for interacting with the Qdrant server.
|
|
|
|
|
collection_name (str): Name of the collection to be managed in Qdrant.
|
|
|
|
|
model (SentenceTransformer): The model used for generating sentence embeddings.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
api_key (str): API key for authenticating with Qdrant.
|
|
|
|
|
host (str): Host address of the Qdrant server.
|
|
|
|
|
port (int): Port number of the Qdrant server. Defaults to 6333.
|
|
|
|
|
collection_name (str): Name of the collection to be used or created. Defaults to "qdrant".
|
|
|
|
|
model_name (str): Name of the model to be used for embeddings. Defaults to "BAAI/bge-small-en-v1.5".
|
|
|
|
|
https (bool): Flag to indicate if HTTPS should be used. Defaults to True.
|
|
|
|
|
"""
|
|
|
|
|
Args:
|
|
|
|
|
api_key (str): API key for authenticating with Qdrant.
|
|
|
|
|
host (str): Host address of the Qdrant server.
|
|
|
|
|
port (int): Port number of the Qdrant server. Defaults to 6333.
|
|
|
|
|
collection_name (str): Name of the collection to be used or created. Defaults to "qdrant".
|
|
|
|
|
model_name (str): Name of the model to be used for embeddings. Defaults to "BAAI/bge-small-en-v1.5".
|
|
|
|
|
https (bool): Flag to indicate if HTTPS should be used. Defaults to True.
|
|
|
|
|
"""
|
|
|
|
|
try:
|
|
|
|
|
self.client = QdrantClient(url=host, port=port, api_key=api_key)
|
|
|
|
|
self.collection_name = collection_name
|
|
|
|
@ -50,7 +59,10 @@ class Qdrant:
|
|
|
|
|
except Exception as e:
|
|
|
|
|
self.client.create_collection(
|
|
|
|
|
collection_name=self.collection_name,
|
|
|
|
|
vectors_config=VectorParams(size=self.model.get_sentence_embedding_dimension(), distance=Distance.DOT),
|
|
|
|
|
vectors_config=VectorParams(
|
|
|
|
|
size=self.model.get_sentence_embedding_dimension(),
|
|
|
|
|
distance=Distance.DOT,
|
|
|
|
|
),
|
|
|
|
|
)
|
|
|
|
|
print(f"Collection '{self.collection_name}' created.")
|
|
|
|
|
|
|
|
|
@ -67,11 +79,21 @@ class Qdrant:
|
|
|
|
|
points = []
|
|
|
|
|
for i, doc in enumerate(docs):
|
|
|
|
|
try:
|
|
|
|
|
if 'page_content' in doc:
|
|
|
|
|
embedding = self.model.encode(doc['page_content'], normalize_embeddings=True)
|
|
|
|
|
points.append(PointStruct(id=i + 1, vector=embedding, payload={"content": doc['page_content']}))
|
|
|
|
|
if "page_content" in doc:
|
|
|
|
|
embedding = self.model.encode(
|
|
|
|
|
doc["page_content"], normalize_embeddings=True
|
|
|
|
|
)
|
|
|
|
|
points.append(
|
|
|
|
|
PointStruct(
|
|
|
|
|
id=i + 1,
|
|
|
|
|
vector=embedding,
|
|
|
|
|
payload={"content": doc["page_content"]},
|
|
|
|
|
)
|
|
|
|
|
)
|
|
|
|
|
else:
|
|
|
|
|
print(f"Document at index {i} is missing 'page_content' key")
|
|
|
|
|
print(
|
|
|
|
|
f"Document at index {i} is missing 'page_content' key"
|
|
|
|
|
)
|
|
|
|
|
except Exception as e:
|
|
|
|
|
print(f"Error processing document at index {i}: {e}")
|
|
|
|
|
|
|
|
|
@ -102,7 +124,7 @@ class Qdrant:
|
|
|
|
|
search_result = self.client.search(
|
|
|
|
|
collection_name=self.collection_name,
|
|
|
|
|
query_vector=query_vector,
|
|
|
|
|
limit=limit
|
|
|
|
|
limit=limit,
|
|
|
|
|
)
|
|
|
|
|
return search_result
|
|
|
|
|
except Exception as e:
|
|
|
|
|