commit
e113a77dc8
@ -0,0 +1,81 @@
|
||||
# Qdrant Client Library
|
||||
|
||||
## Overview
|
||||
|
||||
The Qdrant Client Library is designed for interacting with the Qdrant vector database, allowing efficient storage and retrieval of high-dimensional vector data. It integrates with machine learning models for embedding and is particularly suited for search and recommendation systems.
|
||||
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
pip install qdrant-client sentence-transformers httpx
|
||||
```
|
||||
|
||||
## Class Definition: Qdrant
|
||||
|
||||
```python
|
||||
class Qdrant:
|
||||
def __init__(self, api_key: str, host: str, port: int = 6333, collection_name: str = "qdrant", model_name: str = "BAAI/bge-small-en-v1.5", https: bool = True):
|
||||
...
|
||||
```
|
||||
|
||||
### Constructor Parameters
|
||||
|
||||
| Parameter | Type | Description | Default Value |
|
||||
|-----------------|---------|--------------------------------------------------|-----------------------|
|
||||
| api_key | str | API key for authentication. | - |
|
||||
| host | str | Host address of the Qdrant server. | - |
|
||||
| port | int | Port number for the Qdrant server. | 6333 |
|
||||
| collection_name | str | Name of the collection to be used or created. | "qdrant" |
|
||||
| model_name | str | Name of the sentence transformer model. | "BAAI/bge-small-en-v1.5" |
|
||||
| https | bool | Flag to use HTTPS for connection. | True |
|
||||
|
||||
### Methods
|
||||
|
||||
#### `_load_embedding_model(model_name: str)`
|
||||
|
||||
Loads the sentence embedding model.
|
||||
|
||||
#### `_setup_collection()`
|
||||
|
||||
Checks if the specified collection exists in Qdrant; if not, creates it.
|
||||
|
||||
#### `add_vectors(docs: List[dict]) -> OperationResponse`
|
||||
|
||||
Embeds the `page_content` of each document and upserts the resulting vectors into the Qdrant collection.
|
||||
|
||||
#### `search_vectors(query: str, limit: int = 3) -> SearchResult`
|
||||
|
||||
Encodes the query string with the embedding model and searches the Qdrant collection for the most similar vectors.
|
||||
|
||||
## Usage Examples
|
||||
|
||||
### Example 1: Setting Up the Qdrant Client
|
||||
|
||||
```python
|
||||
from swarms.memory.qdrant import Qdrant
|
||||
|
||||
qdrant_client = Qdrant(api_key="your_api_key", host="localhost", port=6333)
|
||||
```
|
||||
|
||||
### Example 2: Adding Vectors to a Collection
|
||||
|
||||
```python
|
||||
documents = [
|
||||
{"page_content": "Sample text 1"},
|
||||
{"page_content": "Sample text 2"}
|
||||
]
|
||||
|
||||
operation_info = qdrant_client.add_vectors(documents)
|
||||
print(operation_info)
|
||||
```
|
||||
|
||||
### Example 3: Searching for Vectors
|
||||
|
||||
```python
|
||||
search_result = qdrant_client.search_vectors("Sample search query")
|
||||
print(search_result)
|
||||
```
|
||||
|
||||
## Further Information
|
||||
|
||||
Refer to the [Qdrant Documentation](https://qdrant.tech/docs) for more details on the Qdrant vector database.
|
@ -0,0 +1,18 @@
|
||||
import os

from langchain.document_loaders import CSVLoader
from swarms.memory import qdrant

loader = CSVLoader(file_path="../document_parsing/aipg/aipg.csv", encoding="utf-8-sig")
# add_vectors() documents its input as dicts with a 'page_content' key, while
# the loader yields document objects, so convert them explicitly.
docs = [{"page_content": doc.page_content} for doc in loader.load()]

# Initialize the Qdrant instance
# See qdrant documentation on how to run locally
# Credentials are read from the environment so that no secret is committed to
# the repository. NOTE(review): any key that was previously committed here
# should be treated as leaked and rotated.
qdrant_client = qdrant.Qdrant(
    host=os.environ["QDRANT_HOST"],
    collection_name="qdrant",
    api_key=os.environ["QDRANT_API_KEY"],
)
qdrant_client.add_vectors(docs)

# Perform a search
search_query = "Who is jojo"
search_results = qdrant_client.search_vectors(search_query)
print("Search Results:")
# search_vectors() returns None on failure; fall back to an empty list.
for result in search_results or []:
    print(result)
|
@ -1,6 +1,110 @@
|
||||
"""
|
||||
QDRANT MEMORY CLASS
|
||||
from typing import List
|
||||
from sentence_transformers import SentenceTransformer
|
||||
from httpx import RequestError
|
||||
from qdrant_client import QdrantClient
|
||||
from qdrant_client.http.models import Distance, VectorParams, PointStruct
|
||||
|
||||
class Qdrant:
    """Manage a Qdrant collection and run text-embedding vector operations.

    Attributes:
        client (QdrantClient): The Qdrant client for interacting with the Qdrant server.
        collection_name (str): Name of the collection to be managed in Qdrant.
        model (SentenceTransformer): The model used for generating sentence embeddings.
    """

    def __init__(self, api_key: str, host: str, port: int = 6333, collection_name: str = "qdrant", model_name: str = "BAAI/bge-small-en-v1.5", https: bool = True):
        """
        Connect to Qdrant, load the embedding model and ensure the collection exists.

        A ``RequestError`` raised during setup is printed rather than
        propagated, so the instance may be only partially initialised.

        Args:
            api_key (str): API key for authenticating with Qdrant.
            host (str): Host address of the Qdrant server (passed to the client as ``url``).
            port (int): Port number of the Qdrant server. Defaults to 6333.
            collection_name (str): Name of the collection to be used or created. Defaults to "qdrant".
            model_name (str): Name of the model to be used for embeddings. Defaults to "BAAI/bge-small-en-v1.5".
            https (bool): Flag to indicate if HTTPS should be used. Defaults to True.
                NOTE(review): accepted for interface compatibility but not
                currently forwarded to the client.
        """
        # Set first so the attribute exists even if the client setup fails.
        self.collection_name = collection_name
        try:
            self.client = QdrantClient(url=host, port=port, api_key=api_key)
            self._load_embedding_model(model_name)
            self._setup_collection()
        except RequestError as e:
            print(f"Error setting up QdrantClient: {e}")

    def _load_embedding_model(self, model_name: str):
        """
        Loads the sentence embedding model specified by the model name.

        On failure the error is printed and ``self.model`` is left unset.

        Args:
            model_name (str): The name of the model to load for generating embeddings.
        """
        try:
            self.model = SentenceTransformer(model_name)
        except Exception as e:
            print(f"Error loading embedding model: {e}")

    def _setup_collection(self):
        """
        Ensure the configured collection exists, creating it when the lookup fails.

        The new collection is sized from the embedding model's output dimension
        and uses dot-product distance.
        """
        try:
            exists = self.client.get_collection(self.collection_name)
        except Exception:
            # Any lookup failure is treated as "collection missing".
            # NOTE(review): this also covers connectivity errors - confirm
            # whether a narrower exception type should be caught here.
            self.client.create_collection(
                collection_name=self.collection_name,
                vectors_config=VectorParams(size=self.model.get_sentence_embedding_dimension(), distance=Distance.DOT),
            )
            print(f"Collection '{self.collection_name}' created.")
        else:
            if exists:
                print(f"Collection '{self.collection_name}' already exists.")

    def add_vectors(self, docs: List[dict]):
        """
        Adds vector representations of documents to the Qdrant collection.

        Point ids are derived from each document's position in ``docs``
        (index + 1), so a later call reuses the same ids.

        Args:
            docs (List[dict]): A list of documents where each document is a dictionary with at least a 'page_content' key.
                Objects exposing a ``page_content`` attribute are accepted as well.

        Returns:
            OperationResponse or None: Returns the operation information if successful, otherwise None
            (including when no document could be converted into a point).
        """
        points = []
        for i, doc in enumerate(docs):
            try:
                # Support both attribute-style documents and mapping-style ones.
                if hasattr(doc, "page_content"):
                    content = doc.page_content
                elif "page_content" in doc:
                    content = doc["page_content"]
                else:
                    print(f"Document at index {i} is missing 'page_content' key")
                    continue

                embedding = self.model.encode(content, normalize_embeddings=True)
                # Hand the client a plain list of floats; array-like results
                # are converted via tolist() when available.
                vector = embedding.tolist() if hasattr(embedding, "tolist") else list(embedding)
                points.append(PointStruct(id=i + 1, vector=vector, payload={"content": content}))
            except Exception as e:
                print(f"Error processing document at index {i}: {e}")

        if not points:
            # Nothing to send; avoid issuing an empty upsert request.
            print("No valid documents to add; skipping upsert.")
            return None

        try:
            operation_info = self.client.upsert(
                collection_name=self.collection_name,
                wait=True,
                points=points,
            )
            return operation_info
        except Exception as e:
            print(f"Error adding vectors: {e}")
            return None

    def search_vectors(self, query: str, limit: int = 3):
        """
        Searches the collection for vectors similar to the query vector.

        Args:
            query (str): The query string to be converted into a vector and used for searching.
            limit (int): The number of search results to return. Defaults to 3.

        Returns:
            SearchResult or None: Returns the search results if successful, otherwise None.
        """
        try:
            query_vector = self.model.encode(query, normalize_embeddings=True)
            search_result = self.client.search(
                collection_name=self.collection_name,
                query_vector=query_vector,
                limit=limit,
            )
            return search_result
        except Exception as e:
            print(f"Error searching vectors: {e}")
            return None
|
||||
|
@ -0,0 +1,40 @@
|
||||
import pytest
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
from swarms.memory.qdrant import Qdrant
|
||||
|
||||
|
||||
@pytest.fixture
def mock_qdrant_client():
    """Patch QdrantClient where the Qdrant wrapper looks it up.

    Yields the instance mock that every ``QdrantClient(...)`` call returns.
    """
    with patch("swarms.memory.qdrant.QdrantClient") as MockQdrantClient:
        yield MockQdrantClient.return_value


@pytest.fixture
def mock_sentence_transformer():
    """Patch SentenceTransformer in the module under test.

    Yields the class mock so tests can assert on constructor calls; the
    model instance is available as ``.return_value``.
    """
    with patch("swarms.memory.qdrant.SentenceTransformer") as MockSentenceTransformer:
        # Return a plain list so point construction gets a usable vector.
        MockSentenceTransformer.return_value.encode.return_value = [0.1, 0.2, 0.3]
        yield MockSentenceTransformer


@pytest.fixture
def qdrant_client(mock_qdrant_client, mock_sentence_transformer):
    """Build a Qdrant wrapper backed entirely by the mocks above."""
    client = Qdrant(api_key="your_api_key", host="your_host")
    yield client


def test_qdrant_init(qdrant_client, mock_qdrant_client):
    assert qdrant_client.client is mock_qdrant_client


def test_load_embedding_model(qdrant_client, mock_sentence_transformer):
    qdrant_client._load_embedding_model("model_name")
    # __init__ already loaded the default model, so check the latest call.
    mock_sentence_transformer.assert_called_with("model_name")
    assert qdrant_client.model is mock_sentence_transformer.return_value


def test_setup_collection(qdrant_client, mock_qdrant_client):
    # Discard the get_collection call made while constructing the fixture.
    mock_qdrant_client.get_collection.reset_mock()
    qdrant_client._setup_collection()
    mock_qdrant_client.get_collection.assert_called_once_with(qdrant_client.collection_name)


def test_add_vectors(qdrant_client, mock_qdrant_client):
    docs = [{"page_content": "Sample text"}]
    qdrant_client.add_vectors(docs)
    mock_qdrant_client.upsert.assert_called_once()
    _, kwargs = mock_qdrant_client.upsert.call_args
    assert len(kwargs["points"]) == 1


def test_search_vectors(qdrant_client, mock_qdrant_client):
    qdrant_client.search_vectors("test query")
    mock_qdrant_client.search.assert_called_once()
|
Loading…
Reference in new issue