From dec98fe28f26de4a63fdbd27af4bbd17e1e0f0d4 Mon Sep 17 00:00:00 2001 From: Sashin Date: Thu, 23 Nov 2023 22:18:08 +0200 Subject: [PATCH] Initiaal push --- pyproject.toml | 2 ++ swarms/memory/qdrant.py | 59 +++++++++++++++++++++++++++++++++-- swarms/memory/qdrant/usage.py | 18 +++++++++++ 3 files changed, 76 insertions(+), 3 deletions(-) create mode 100644 swarms/memory/qdrant/usage.py diff --git a/pyproject.toml b/pyproject.toml index 075bbd15..b9b0f89a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,7 +23,9 @@ classifiers = [ [tool.poetry.dependencies] python = "^3.8.1" +sentence_transformers = "*" transformers = "*" +qdrant_client = "*" openai = "0.28.0" langchain = "*" asyncio = "*" diff --git a/swarms/memory/qdrant.py b/swarms/memory/qdrant.py index 7bc5018e..24f0dc82 100644 --- a/swarms/memory/qdrant.py +++ b/swarms/memory/qdrant.py @@ -1,6 +1,59 @@ -""" -QDRANT MEMORY CLASS +from httpx import RequestError +from qdrant_client import QdrantClient +from qdrant_client.http.models import Distance, VectorParams, PointStruct +from sentence_transformers import SentenceTransformer +class Qdrant: + def __init__(self,api_key, host, port=6333, collection_name="qdrant", model_name="BAAI/bge-small-en-v1.5", https=True ): + self.client = QdrantClient(url=host, port=port, api_key=api_key) #, port=port, api_key=api_key, https=False + self.collection_name = collection_name + self._load_embedding_model(model_name) + self._setup_collection() + def _load_embedding_model(self, model_name): + # Load the embedding model + self.model = SentenceTransformer(model_name) + def _setup_collection(self): + # Check if the collection already exists + try: + exists = self.client.get_collection(self.collection_name) + return + except Exception: + # Collection does not exist, create it + self.client.create_collection( + collection_name=self.collection_name, + vectors_config=VectorParams(size=self.model.get_sentence_embedding_dimension(), distance=Distance.DOT), + ) + print(f"Collection '{self.collection_name}' created.") + else: + print(f"Collection '{self.collection_name}' already exists.") -""" + def add_vectors(self, docs): + # Add vectors with payloads to the collection + points = [] + for i, doc in enumerate(docs): + if doc.page_content: + embedding = self.model.encode(doc.page_content, normalize_embeddings=True) + points.append(PointStruct(id=i + 1, vector=embedding, payload={"content":doc.page_content})) + else: + print(f"Document at index {i} is missing 'text' or 'payload' key") + + operation_info = self.client.upsert( + collection_name=self.collection_name, + wait=True, + points=points, + ) + print(operation_info) + def search_vectors(self, query, limit=3): + query_vector= self.model.encode(query, normalize_embeddings=True) + # Search for similar vectors + search_result = self.client.search( + collection_name=self.collection_name, + query_vector=query_vector, + limit=limit + ) + return search_result + + + +#TODO, use kwargs in constructor, have search result be text \ No newline at end of file diff --git a/swarms/memory/qdrant/usage.py b/swarms/memory/qdrant/usage.py new file mode 100644 index 00000000..0378d540 --- /dev/null +++ b/swarms/memory/qdrant/usage.py @@ -0,0 +1,18 @@ +from langchain.document_loaders import CSVLoader +from swarms.memory import qdrant + +loader = CSVLoader(file_path="../document_parsing/aipg/aipg.csv", encoding='utf-8-sig') +docs = loader.load() + + +# Initialize the Qdrant instance +# See qdrant documentation on how to run locally +qdrant_client = qdrant.Qdrant(host ="https://697ea26c-2881-4e17-8af4-817fcb5862e8.europe-west3-0.gcp.cloud.qdrant.io", collection_name="qdrant", api_key="BhG2_yINqNU-aKovSEBadn69Zszhbo5uaqdJ6G_qDkdySjAljvuPqQ") +qdrant_client.add_vectors(docs) + +# Perform a search +search_query = "Who is jojo" +search_results = qdrant_client.search_vectors(search_query) +print("Search Results:") +for result in search_results: + print(result)