You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
swarms/swarms/memory/lanchain_chroma.py

198 lines
6.0 KiB

import threading
from pathlib import Path
from langchain.chains import RetrievalQA
from langchain.chains.question_answering import load_qa_chain
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import Chroma
from swarms.memory.base_vectordb import AbstractVectorDatabase
from swarms.models.popular_llms import OpenAIChat
def synchronized_mem(method):
    """
    Decorator that synchronizes access to a method using a lock.

    The wrapped method runs while holding ``self.lock``. Any ``Exception``
    raised by the method is caught and reported with ``print``; in that case
    the wrapper returns ``None`` instead of propagating the error.

    Args:
        method: The method to be decorated. The instance it is bound to must
            expose a ``lock`` attribute usable as a context manager.

    Returns:
        The decorated method.
    """
    # Function-scope import so this block brings its own dependency.
    import functools

    # functools.wraps keeps the wrapped method's __name__ and __doc__, so
    # introspection and error messages refer to the real method.
    @functools.wraps(method)
    def wrapper(self, *args, **kwargs):
        with self.lock:
            try:
                return method(self, *args, **kwargs)
            except Exception as e:
                # Deliberate best-effort: report the failure and hand the
                # caller None rather than raising while the lock is held.
                print(f"Failed to execute {method.__name__}: {e}")
                return None

    return wrapper
class LangchainChromaVectorMemory(AbstractVectorDatabase):
    """
    A class representing a vector memory for storing and retrieving text entries.

    Attributes:
        loc (Path): The location of the vector memory.
        chunk_size (int): The size of each text chunk.
        chunk_overlap (int): The overlap between consecutive text chunks,
            computed as ``chunk_size * chunk_overlap_frac`` and truncated to
            an integer.
        embeddings (OpenAIEmbeddings): The embeddings used for text representation.
        count (int): The number of entries in the underlying Chroma collection
            as of the last add/search.
        lock (threading.Lock): A lock for thread safety.
        db (Chroma): The Chroma database for storing text entries.
        qa (RetrievalQA): The retrieval QA system for answering questions.

    Methods:
        __init__: Initializes the VectorMemory object.
        _init_db: Initializes the Chroma database.
        _init_retriever: Initializes the retrieval QA system.
        add: Adds an entry to the vector memory.
        search_memory: Searches the vector memory for similar entries.
        query: Asks a question to the vector memory.
    """

    def __init__(
        self,
        loc=None,
        chunk_size: int = 1000,
        chunk_overlap_frac: float = 0.1,
        *args,
        **kwargs,
    ):
        """
        Initializes the VectorMemory object.

        Args:
            loc (str): The location of the vector memory. If None, defaults to
                "./tmp/vector_memory".
            chunk_size (int): The size of each text chunk.
            chunk_overlap_frac (float): The fraction of overlap between text chunks.
            *args: Accepted but not used by this constructor.
            **kwargs: Accepted but not used by this constructor.
        """
        if loc is None:
            loc = "./tmp/vector_memory"
        self.loc = Path(loc)
        self.chunk_size = chunk_size
        # The splitter takes an integer overlap; truncate the fractional
        # product instead of handing it a float.
        self.chunk_overlap = int(chunk_size * chunk_overlap_frac)
        self.embeddings = OpenAIEmbeddings()
        self.count = 0
        self.lock = threading.Lock()
        # Order matters: the retriever is built on top of the database.
        self.db = self._init_db()
        self.qa = self._init_retriever()

    def _init_db(self):
        """
        Initializes the Chroma database.

        The database is seeded with a single placeholder text ("init") and
        ``self.count`` is set to the resulting collection size.

        Returns:
            Chroma: The initialized Chroma database.
        """
        texts = [
            "init"
        ]  # TODO find how to initialize Chroma without any text
        chroma_db = Chroma.from_texts(
            texts=texts,
            embedding=self.embeddings,
            persist_directory=str(self.loc),
        )
        self.count = chroma_db._collection.count()
        return chroma_db

    def _init_retriever(self):
        """
        Initializes the retrieval QA system.

        Builds a "stuff" question-answering chain on a gpt-3.5-turbo chat
        model and pairs it with an MMR retriever (k=10) over ``self.db``.

        Returns:
            RetrievalQA: The initialized retrieval QA system.
        """
        model = OpenAIChat(
            model_name="gpt-3.5-turbo",
        )
        qa_chain = load_qa_chain(model, chain_type="stuff")
        retriever = self.db.as_retriever(
            search_type="mmr", search_kwargs={"k": 10}
        )
        qa = RetrievalQA(
            combine_documents_chain=qa_chain, retriever=retriever
        )
        return qa

    @synchronized_mem
    def add(self, entry: str):
        """
        Add an entry to the internal memory.

        The entry is split into chunks on spaces, added to the database and
        persisted.

        Args:
            entry (str): The entry to be added.

        Returns:
            bool: True if the entry was successfully added. On failure the
            ``synchronized_mem`` decorator prints the error and the call
            returns None.
        """
        text_splitter = CharacterTextSplitter(
            chunk_size=self.chunk_size,
            chunk_overlap=self.chunk_overlap,
            separator=" ",
        )
        texts = text_splitter.split_text(entry)
        self.db.add_texts(texts)
        # The collection count is already the total number of entries, so
        # assign it; accumulating with += would inflate the counter.
        self.count = self.db._collection.count()
        self.db.persist()
        return True

    @synchronized_mem
    def search_memory(
        self, query: str, k=10, type="mmr", distance_threshold=0.5
    ):
        """
        Searching the vector memory for similar entries.

        Args:
            query (str): The query to search for.
            k (int): The number of results to return. If it exceeds the number
                of stored entries it is reduced to ``count - 1``.
            type (str): The type of search to perform: "cos" or "mmr".
            distance_threshold (float): The distance threshold used by the
                "cos" search. Results whose score is not strictly below
                distance_threshold are dropped.

        Returns:
            list[str]: A list of up to k results, or None when there is
            nothing to search (k <= 0 after clamping). An unsupported
            ``type`` raises ValueError inside the method, which the
            ``synchronized_mem`` decorator reports before returning None.
        """
        self.count = self.db._collection.count()
        if k > self.count:
            k = self.count - 1
        if k <= 0:
            return None

        if type == "mmr":
            # fetch_k must be at least k, otherwise MMR cannot return k
            # results; it is also capped at the number of stored entries.
            texts = self.db.max_marginal_relevance_search(
                query=query, k=k, fetch_k=min(max(20, k), self.count)
            )
            texts = [text.page_content for text in texts]
        elif type == "cos":
            texts = self.db.similarity_search_with_score(
                query=query, k=k
            )
            # Each result is a (document, score) pair; keep close matches only.
            texts = [
                text[0].page_content
                for text in texts
                if text[-1] < distance_threshold
            ]
        else:
            # Fail with a clear message instead of an unbound-local error.
            raise ValueError(
                f"Unsupported search type {type!r}; expected 'mmr' or 'cos'"
            )
        return texts

    @synchronized_mem
    def query(self, question: str):
        """
        Ask a question to the vector memory.

        Args:
            question (str): The question to ask.

        Returns:
            str: The answer to the question, as produced by the retrieval
            QA chain.
        """
        answer = self.qa.run(question)
        return answer