diff --git a/swarms/memory/qdrant.py b/swarms/memory/qdrant.py index 56596965..83ff5593 100644 --- a/swarms/memory/qdrant.py +++ b/swarms/memory/qdrant.py @@ -1,5 +1,6 @@ import subprocess from typing import List + from httpx import RequestError try: @@ -15,8 +16,8 @@ try: from qdrant_client import QdrantClient from qdrant_client.http.models import ( Distance, - VectorParams, PointStruct, + VectorParams, ) except ImportError: print("Please install the qdrant-client package") @@ -91,7 +92,7 @@ class Qdrant: ) print(f"Collection '{self.collection_name}' created.") - def add_vectors(self, docs: List[dict]): + def add(self, docs: List[dict], *args, **kwargs): """ Adds vector representations of documents to the Qdrant collection. @@ -128,13 +129,15 @@ class Qdrant: collection_name=self.collection_name, wait=True, points=points, + *args, + **kwargs, ) return operation_info except Exception as e: print(f"Error adding vectors: {e}") return None - def search_vectors(self, query: str, limit: int = 3): + def query(self, query: str, limit: int = 3, *args, **kwargs): """ Searches the collection for vectors similar to the query vector. @@ -147,12 +150,14 @@ class Qdrant: """ try: query_vector = self.model.encode( - query, normalize_embeddings=True + query, normalize_embeddings=True, *args, **kwargs ) search_result = self.client.search( collection_name=self.collection_name, query_vector=query_vector, limit=limit, + *args, + **kwargs, ) return search_result except Exception as e: