Initiaal push

pull/175/head
Sashin 1 year ago
parent d985fdad00
commit 018a0f437f

@ -9,51 +9,56 @@ class Qdrant:
self._load_embedding_model(model_name) self._load_embedding_model(model_name)
self._setup_collection() self._setup_collection()
def _load_embedding_model(self, model_name): def _load_embedding_model(self, model_name: str):
# Load the embedding model try:
self.model = SentenceTransformer(model_name) self.model = SentenceTransformer(model_name)
except Exception as e:
print(f"Error loading embedding model: {e}")
def _setup_collection(self): def _setup_collection(self):
# Check if the collection already exists
try: try:
exists = self.client.get_collection(self.collection_name) exists = self.client.get_collection(self.collection_name)
return if exists:
except Exception: print(f"Collection '{self.collection_name}' already exists.")
# Collection does not exist, create it except Exception as e:
self.client.create_collection( self.client.create_collection(
collection_name=self.collection_name, collection_name=self.collection_name,
vectors_config=VectorParams(size=self.model.get_sentence_embedding_dimension(), distance=Distance.DOT), vectors_config=VectorParams(size=self.model.get_sentence_embedding_dimension(), distance=Distance.DOT),
) )
print(f"Collection '{self.collection_name}' created.") print(f"Collection '{self.collection_name}' created.")
else:
print(f"Collection '{self.collection_name}' already exists.")
def add_vectors(self, docs): def add_vectors(self, docs: List[dict]):
# Add vectors with payloads to the collection
points = [] points = []
for i, doc in enumerate(docs): for i, doc in enumerate(docs):
if doc.page_content: try:
embedding = self.model.encode(doc.page_content, normalize_embeddings=True) if 'page_content' in doc:
points.append(PointStruct(id=i + 1, vector=embedding, payload={"content":doc.page_content})) embedding = self.model.encode(doc['page_content'], normalize_embeddings=True)
else: points.append(PointStruct(id=i + 1, vector=embedding, payload={"content": doc['page_content']}))
print(f"Document at index {i} is missing 'text' or 'payload' key") else:
print(f"Document at index {i} is missing 'page_content' key")
operation_info = self.client.upsert( except Exception as e:
collection_name=self.collection_name, print(f"Error processing document at index {i}: {e}")
wait=True,
points=points,
)
print(operation_info)
def search_vectors(self, query, limit=3):
query_vector= self.model.encode(query, normalize_embeddings=True)
# Search for similar vectors
search_result = self.client.search(
collection_name=self.collection_name,
query_vector=query_vector,
limit=limit
)
return search_result
try:
operation_info = self.client.upsert(
collection_name=self.collection_name,
wait=True,
points=points,
)
return operation_info
except Exception as e:
print(f"Error adding vectors: {e}")
return None
#TODO, use kwargs in constructor, have search result be text def search_vectors(self, query: str, limit: int = 3):
try:
query_vector = self.model.encode(query, normalize_embeddings=True)
search_result = self.client.search(
collection_name=self.collection_name,
query_vector=query_vector,
limit=limit
)
return search_result
except Exception as e:
print(f"Error searching vectors: {e}")
return None

@ -0,0 +1,40 @@
import pytest
from unittest.mock import Mock, patch
from swarms.memory.qdrant import Qdrant
@pytest.fixture
def mock_qdrant_client():
with patch('your_module.QdrantClient') as MockQdrantClient:
yield MockQdrantClient()
@pytest.fixture
def mock_sentence_transformer():
with patch('sentence_transformers.SentenceTransformer') as MockSentenceTransformer:
yield MockSentenceTransformer()
@pytest.fixture
def qdrant_client(mock_qdrant_client, mock_sentence_transformer):
client = Qdrant(api_key="your_api_key", host="your_host")
yield client
def test_qdrant_init(qdrant_client, mock_qdrant_client):
assert qdrant_client.client is not None
def test_load_embedding_model(qdrant_client, mock_sentence_transformer):
qdrant_client._load_embedding_model("model_name")
mock_sentence_transformer.assert_called_once_with("model_name")
def test_setup_collection(qdrant_client, mock_qdrant_client):
qdrant_client._setup_collection()
mock_qdrant_client.get_collection.assert_called_once_with(qdrant_client.collection_name)
def test_add_vectors(qdrant_client, mock_qdrant_client):
mock_doc = Mock(page_content="Sample text")
qdrant_client.add_vectors([mock_doc])
mock_qdrant_client.upsert.assert_called_once()
def test_search_vectors(qdrant_client, mock_qdrant_client):
qdrant_client.search_vectors("test query")
mock_qdrant_client.search.assert_called_once()
Loading…
Cancel
Save