parent
							
								
									2906576fc2
								
							
						
					
					
						commit
						dec98fe28f
					
				@ -1,6 +1,59 @@
 | 
				
			||||
"""
 | 
				
			||||
QDRANT MEMORY CLASS
 | 
				
			||||
from httpx import RequestError
 | 
				
			||||
from qdrant_client import QdrantClient
 | 
				
			||||
from qdrant_client.http.models import Distance, VectorParams, PointStruct
 | 
				
			||||
from sentence_transformers import SentenceTransformer
 | 
				
			||||
class Qdrant:
    """Store and search text documents in a Qdrant collection.

    Text is embedded with a SentenceTransformer model; the resulting vectors
    are written to (and queried from) a single collection whose name is fixed
    at construction time.
    """

    def __init__(
        self,
        api_key: str,
        host: str,
        port: int = 6333,
        collection_name: str = "qdrant",
        model_name: str = "BAAI/bge-small-en-v1.5",
        https: bool = True,
    ):
        """Connect to Qdrant, load the embedding model and ensure the collection exists.

        Args:
            api_key: API key forwarded to ``QdrantClient``.
            host: Server address; forwarded to ``QdrantClient`` as ``url``.
            port: Port forwarded to ``QdrantClient``.
            collection_name: Collection used by every operation on this instance.
            model_name: Name passed to ``SentenceTransformer``.
            https: Accepted for backward compatibility; currently NOT forwarded
                to the client.
        """
        self.client = QdrantClient(url=host, port=port, api_key=api_key)
        self.collection_name = collection_name
        # The model must be loaded before the collection is set up, because
        # the collection's vector size is read from the model.
        self._load_embedding_model(model_name)
        self._setup_collection()

    def _load_embedding_model(self, model_name: str) -> None:
        """Load the SentenceTransformer model and keep it on ``self.model``."""
        self.model = SentenceTransformer(model_name)

    def _setup_collection(self) -> None:
        """Create the collection if looking it up fails; otherwise report that it exists."""
        try:
            self.client.get_collection(self.collection_name)
        except Exception:
            # NOTE(review): every lookup failure is treated as "collection is
            # missing" (the original behaviour). A connection or auth error
            # will therefore surface from create_collection below instead.
            self.client.create_collection(
                collection_name=self.collection_name,
                vectors_config=VectorParams(
                    size=self.model.get_sentence_embedding_dimension(),
                    distance=Distance.DOT,
                ),
            )
            print(f"Collection '{self.collection_name}' created.")
        else:
            # Previously unreachable: the try body returned before `else` ran.
            print(f"Collection '{self.collection_name}' already exists.")

    def add_vectors(self, docs) -> None:
        """Embed each document's ``page_content`` and upsert it into the collection.

        Args:
            docs: Iterable of objects exposing a ``page_content`` attribute.
                Documents whose ``page_content`` is empty are reported and
                skipped.

        Point ids are the 1-based position of the document in ``docs``.
        NOTE(review): ids restart at 1 on every call, so a later call
        presumably overwrites points written by an earlier one - confirm
        this is intended.
        """
        points = []
        for i, doc in enumerate(docs):
            if not doc.page_content:
                print(f"Document at index {i} has empty page_content; skipping")
                continue
            embedding = self.model.encode(doc.page_content, normalize_embeddings=True)
            # Hand the client a plain list of floats rather than an array object.
            vector = embedding.tolist() if hasattr(embedding, "tolist") else list(embedding)
            points.append(
                PointStruct(id=i + 1, vector=vector, payload={"content": doc.page_content})
            )

        if not points:
            # Nothing to write; avoid sending an empty upsert request.
            print("No documents with content to add; skipping upsert.")
            return

        operation_info = self.client.upsert(
            collection_name=self.collection_name,
            wait=True,
            points=points,
        )
        print(operation_info)

    def search_vectors(self, query, limit: int = 3):
        """Embed ``query`` and return the client's search result for the collection.

        Args:
            query: Text to embed and search for.
            limit: Maximum number of hits requested from the client.

        Returns:
            Whatever ``self.client.search`` returns, unmodified.
        """
        query_vector = self.model.encode(query, normalize_embeddings=True)
        return self.client.search(
            collection_name=self.collection_name,
            query_vector=query_vector,
            limit=limit,
        )
 | 
				
			||||
 | 
				
			||||
 | 
				
			||||
 | 
				
			||||
# TODO: accept **kwargs in the constructor and return search results as text
 | 
				
			||||
@ -0,0 +1,18 @@
 | 
				
			||||
import os

from langchain.document_loaders import CSVLoader
from swarms.memory import qdrant

# Credentials are read from the environment so that no secret is committed
# to source control. Set QDRANT_HOST to the server URL and QDRANT_API_KEY to
# the API key before running this example.
# See qdrant documentation on how to run locally
qdrant_host = os.environ.get("QDRANT_HOST")
qdrant_api_key = os.environ.get("QDRANT_API_KEY")
if not qdrant_host or not qdrant_api_key:
    raise SystemExit(
        "Set the QDRANT_HOST and QDRANT_API_KEY environment variables before running this example."
    )

# Load the CSV rows as documents.
loader = CSVLoader(file_path="../document_parsing/aipg/aipg.csv", encoding="utf-8-sig")
docs = loader.load()

# Initialize the Qdrant instance and index the documents.
qdrant_client = qdrant.Qdrant(
    host=qdrant_host,
    collection_name="qdrant",
    api_key=qdrant_api_key,
)
qdrant_client.add_vectors(docs)

# Perform a search
search_query = "Who is jojo"
search_results = qdrant_client.search_vectors(search_query)
print("Search Results:")
for result in search_results:
    print(result)
 | 
				
			||||
					Loading…
					
					
				
		Reference in new issue