diff --git a/swarms/memory/qdrant.py b/swarms/memory/qdrant.py index 24f0dc82..c06efeba 100644 --- a/swarms/memory/qdrant.py +++ b/swarms/memory/qdrant.py @@ -9,51 +9,56 @@ class Qdrant: self._load_embedding_model(model_name) self._setup_collection() - def _load_embedding_model(self, model_name): - # Load the embedding model - self.model = SentenceTransformer(model_name) + def _load_embedding_model(self, model_name: str): + try: + self.model = SentenceTransformer(model_name) + except Exception as e: + print(f"Error loading embedding model: {e}") def _setup_collection(self): - # Check if the collection already exists try: exists = self.client.get_collection(self.collection_name) - return - except Exception: - # Collection does not exist, create it + if exists: + print(f"Collection '{self.collection_name}' already exists.") + except Exception as e: self.client.create_collection( collection_name=self.collection_name, vectors_config=VectorParams(size=self.model.get_sentence_embedding_dimension(), distance=Distance.DOT), ) print(f"Collection '{self.collection_name}' created.") - else: - print(f"Collection '{self.collection_name}' already exists.") - def add_vectors(self, docs): - # Add vectors with payloads to the collection + def add_vectors(self, docs: List[dict]): points = [] for i, doc in enumerate(docs): - if doc.page_content: - embedding = self.model.encode(doc.page_content, normalize_embeddings=True) - points.append(PointStruct(id=i + 1, vector=embedding, payload={"content":doc.page_content})) - else: - print(f"Document at index {i} is missing 'text' or 'payload' key") - - operation_info = self.client.upsert( - collection_name=self.collection_name, - wait=True, - points=points, - ) - print(operation_info) - def search_vectors(self, query, limit=3): - query_vector= self.model.encode(query, normalize_embeddings=True) - # Search for similar vectors - search_result = self.client.search( - collection_name=self.collection_name, - query_vector=query_vector, - limit=limit - ) - return search_result - + try: + if 'page_content' in doc: + embedding = self.model.encode(doc['page_content'], normalize_embeddings=True) + points.append(PointStruct(id=i + 1, vector=embedding, payload={"content": doc['page_content']})) + else: + print(f"Document at index {i} is missing 'page_content' key") + except Exception as e: + print(f"Error processing document at index {i}: {e}") + try: + operation_info = self.client.upsert( + collection_name=self.collection_name, + wait=True, + points=points, + ) + return operation_info + except Exception as e: + print(f"Error adding vectors: {e}") + return None -#TODO, use kwargs in constructor, have search result be text \ No newline at end of file + def search_vectors(self, query: str, limit: int = 3): + try: + query_vector = self.model.encode(query, normalize_embeddings=True) + search_result = self.client.search( + collection_name=self.collection_name, + query_vector=query_vector, + limit=limit + ) + return search_result + except Exception as e: + print(f"Error searching vectors: {e}") + return None diff --git a/tests/memory/qdrant.py b/tests/memory/qdrant.py new file mode 100644 index 00000000..577ede2a --- /dev/null +++ b/tests/memory/qdrant.py @@ -0,0 +1,40 @@ +import pytest +from unittest.mock import Mock, patch + +from swarms.memory.qdrant import Qdrant + + +@pytest.fixture +def mock_qdrant_client(): + with patch('your_module.QdrantClient') as MockQdrantClient: + yield MockQdrantClient() + +@pytest.fixture +def mock_sentence_transformer(): + with patch('sentence_transformers.SentenceTransformer') as MockSentenceTransformer: + yield MockSentenceTransformer() + +@pytest.fixture +def qdrant_client(mock_qdrant_client, mock_sentence_transformer): + client = Qdrant(api_key="your_api_key", host="your_host") + yield client + +def test_qdrant_init(qdrant_client, mock_qdrant_client): + assert qdrant_client.client is not None + +def test_load_embedding_model(qdrant_client, mock_sentence_transformer): + qdrant_client._load_embedding_model("model_name") + mock_sentence_transformer.assert_called_once_with("model_name") + +def test_setup_collection(qdrant_client, mock_qdrant_client): + qdrant_client._setup_collection() + mock_qdrant_client.get_collection.assert_called_once_with(qdrant_client.collection_name) + +def test_add_vectors(qdrant_client, mock_qdrant_client): + mock_doc = Mock(page_content="Sample text") + qdrant_client.add_vectors([mock_doc]) + mock_qdrant_client.upsert.assert_called_once() + +def test_search_vectors(qdrant_client, mock_qdrant_client): + qdrant_client.search_vectors("test query") + mock_qdrant_client.search.assert_called_once()