import functools
import threading
from pathlib import Path

from langchain.chains import RetrievalQA
from langchain.chains.question_answering import load_qa_chain
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import Chroma

from swarms.memory.base_vectordb import AbstractVectorDatabase
from swarms.models.popular_llms import OpenAIChat


def synchronized_mem(method):
    """
    Decorator that synchronizes access to a method using a lock.

    Args:
        method: The method to be decorated.

    Returns:
        The decorated method.
    """

    @functools.wraps(method)
    def wrapper(self, *args, **kwargs):
        with self.lock:
            try:
                return method(self, *args, **kwargs)
            except Exception as e:
                print(f"Failed to execute {method.__name__}: {e}")

    return wrapper


class LangchainChromaVectorMemory(AbstractVectorDatabase):
    """
    A vector memory for storing and retrieving text entries.

    Attributes:
        loc (str): The location of the vector memory.
        chunk_size (int): The size of each text chunk.
        chunk_overlap (int): The overlap between text chunks, in characters.
        embeddings (OpenAIEmbeddings): The embeddings used for text representation.
        count (int): The current count of text entries in the vector memory.
        lock (threading.Lock): A lock for thread safety.
        db (Chroma): The Chroma database for storing text entries.
        qa (RetrievalQA): The retrieval QA system for answering questions.

    Methods:
        __init__: Initializes the VectorMemory object.
        _init_db: Initializes the Chroma database.
        _init_retriever: Initializes the retrieval QA system.
        add: Adds an entry to the vector memory.
        search_memory: Searches the vector memory for similar entries.
        query: Asks a question of the vector memory.
    """

    def __init__(
        self,
        loc=None,
        chunk_size: int = 1000,
        chunk_overlap_frac: float = 0.1,
        *args,
        **kwargs,
    ):
        """
        Initializes the VectorMemory object.

        Args:
            loc (str): The location of the vector memory. If None, defaults
                to "./tmp/vector_memory".
            chunk_size (int): The size of each text chunk.
            chunk_overlap_frac (float): The fraction of overlap between text chunks.
        """
        if loc is None:
            loc = "./tmp/vector_memory"
        self.loc = Path(loc)
        self.chunk_size = chunk_size
        self.chunk_overlap = int(chunk_size * chunk_overlap_frac)
        self.embeddings = OpenAIEmbeddings()
        self.count = 0
        self.lock = threading.Lock()

        self.db = self._init_db()
        self.qa = self._init_retriever()

    def _init_db(self):
        """
        Initializes the Chroma database.

        Returns:
            Chroma: The initialized Chroma database.
        """
        # TODO: find how to initialize Chroma without any seed text.
        texts = ["init"]
        chroma_db = Chroma.from_texts(
            texts=texts,
            embedding=self.embeddings,
            persist_directory=str(self.loc),
        )
        self.count = chroma_db._collection.count()
        return chroma_db

    def _init_retriever(self):
        """
        Initializes the retrieval QA system.

        Returns:
            RetrievalQA: The initialized retrieval QA system.
        """
        model = OpenAIChat(
            model_name="gpt-3.5-turbo",
        )
        qa_chain = load_qa_chain(model, chain_type="stuff")
        retriever = self.db.as_retriever(
            search_type="mmr", search_kwargs={"k": 10}
        )
        qa = RetrievalQA(
            combine_documents_chain=qa_chain, retriever=retriever
        )
        return qa

    @synchronized_mem
    def add(self, entry: str):
        """
        Add an entry to the internal memory.

        Args:
            entry (str): The entry to be added.

        Returns:
            bool: True if the entry was successfully added, False otherwise.
        """
        text_splitter = CharacterTextSplitter(
            chunk_size=self.chunk_size,
            chunk_overlap=self.chunk_overlap,
            separator=" ",
        )
        texts = text_splitter.split_text(entry)

        self.db.add_texts(texts)
        # Refresh the count from the collection rather than accumulating it,
        # which would double-count the entries already stored.
        self.count = self.db._collection.count()
        self.db.persist()
        return True

    @synchronized_mem
    def search_memory(
        self, query: str, k=10, type="mmr", distance_threshold=0.5
    ):
        """
        Search the vector memory for entries similar to the query.

        Args:
            query (str): The query to search for.
            k (int): The number of results to return.
            type (str): The type of search to perform: "cos" or "mmr".
            distance_threshold (float): The distance threshold for "cos"
                search; results with a distance greater than this value
                are dropped.

        Returns:
            list[str]: A list of the top k results.
        """
        self.count = self.db._collection.count()
        if k > self.count:
            k = self.count - 1
        if k <= 0:
            return None

        if type == "mmr":
            texts = self.db.max_marginal_relevance_search(
                query=query, k=k, fetch_k=min(20, self.count)
            )
            texts = [text.page_content for text in texts]
        elif type == "cos":
            texts = self.db.similarity_search_with_score(query=query, k=k)
            texts = [
                text[0].page_content
                for text in texts
                if text[-1] < distance_threshold
            ]
        else:
            raise ValueError(f"Unknown search type: {type}")
        return texts

    @synchronized_mem
    def query(self, question: str):
        """
        Ask a question of the vector memory.

        Args:
            question (str): The question to ask.

        Returns:
            str: The answer to the question.
        """
        answer = self.qa.run(question)
        return answer