import threading
from pathlib import Path

from langchain.chains import RetrievalQA
from langchain.chains.question_answering import load_qa_chain
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import Chroma

from swarms.memory.base_vectordb import AbstractVectorDatabase
from swarms.models.popular_llms import OpenAIChat


def synchronized_mem(method):
    """
    Decorator that synchronizes access to a method using the instance's lock.

    If the wrapped method raises, the exception is printed and None is
    returned instead of re-raising.

    Args:
        method: The method to be decorated.

    Returns:
        The decorated method.
    """

    def wrapper(self, *args, **kwargs):
        with self.lock:
            try:
                return method(self, *args, **kwargs)
            except Exception as e:
                print(f"Failed to execute {method.__name__}: {e}")

    return wrapper

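
# A minimal illustration (not part of the module) of what synchronized_mem
# expects from its host class: any object exposing a ``lock`` attribute.
#
#     class Counter:
#         def __init__(self):
#             self.lock = threading.Lock()
#             self.value = 0
#
#         @synchronized_mem
#         def increment(self):
#             self.value += 1
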
class LangchainChromaVectorMemory(AbstractVectorDatabase):
    """
    A vector memory for storing and retrieving text entries, backed by a
    Chroma database and a LangChain retrieval QA chain.

    Attributes:
        loc (Path): The location of the vector memory on disk.
        chunk_size (int): The size of each text chunk.
        chunk_overlap (int): The overlap between text chunks, derived from
            chunk_overlap_frac.
        embeddings (OpenAIEmbeddings): The embeddings used for text representation.
        count (int): The current count of text entries in the vector memory.
        lock (threading.Lock): A lock for thread safety.
        db (Chroma): The Chroma database for storing text entries.
        qa (RetrievalQA): The retrieval QA system for answering questions.

    Methods:
        __init__: Initializes the vector memory.
        _init_db: Initializes the Chroma database.
        _init_retriever: Initializes the retrieval QA system.
        add: Adds an entry to the vector memory.
        search_memory: Searches the vector memory for similar entries.
        query: Asks a question of the vector memory.
    """

    def __init__(
        self,
        loc=None,
        chunk_size: int = 1000,
        chunk_overlap_frac: float = 0.1,
        *args,
        **kwargs,
    ):
        """
        Initializes the vector memory.

        Args:
            loc (str): The location of the vector memory. If None, defaults
                to "./tmp/vector_memory".
            chunk_size (int): The size of each text chunk.
            chunk_overlap_frac (float): The fraction of overlap between text chunks.
        """
        if loc is None:
            loc = "./tmp/vector_memory"
        self.loc = Path(loc)
        self.chunk_size = chunk_size
        # The text splitter expects an integer overlap, so round down here.
        self.chunk_overlap = int(chunk_size * chunk_overlap_frac)
        self.embeddings = OpenAIEmbeddings()
        self.count = 0
        self.lock = threading.Lock()

        self.db = self._init_db()
        self.qa = self._init_retriever()

    def _init_db(self):
        """
        Initializes the Chroma database.

        Returns:
            Chroma: The initialized Chroma database.
        """
        texts = [
            "init"
        ]  # TODO find how to initialize Chroma without any text
        chroma_db = Chroma.from_texts(
            texts=texts,
            embedding=self.embeddings,
            persist_directory=str(self.loc),
        )
        self.count = chroma_db._collection.count()
        return chroma_db

    def _init_retriever(self):
        """
        Initializes the retrieval QA system.

        Returns:
            RetrievalQA: The initialized retrieval QA system.
        """
        model = OpenAIChat(
            model_name="gpt-3.5-turbo",
        )
        qa_chain = load_qa_chain(model, chain_type="stuff")
        retriever = self.db.as_retriever(
            search_type="mmr", search_kwargs={"k": 10}
        )
        qa = RetrievalQA(
            combine_documents_chain=qa_chain, retriever=retriever
        )
        return qa

    @synchronized_mem
    def add(self, entry: str):
        """
        Add an entry to the internal memory.

        Args:
            entry (str): The entry to be added.

        Returns:
            bool: True if the entry was successfully added, False otherwise.
        """
        text_splitter = CharacterTextSplitter(
            chunk_size=self.chunk_size,
            chunk_overlap=self.chunk_overlap,
            separator=" ",
        )
        texts = text_splitter.split_text(entry)

        self.db.add_texts(texts)
        # Refresh the count from the collection rather than accumulating it,
        # so it reflects the actual number of stored chunks.
        self.count = self.db._collection.count()
        self.db.persist()
        return True

    @synchronized_mem
    def search_memory(
        self, query: str, k=10, type="mmr", distance_threshold=0.5
    ):
        """
        Search the vector memory for entries similar to the query.

        Args:
            query (str): The query to search for.
            k (int): The number of results to return.
            type (str): The type of search to perform: "cos" or "mmr".
            distance_threshold (float): The distance threshold for "cos"
                search; results with distance > distance_threshold are dropped.

        Returns:
            list[str]: A list of the top k results, or None if the memory is
                empty.
        """
        self.count = self.db._collection.count()
        if k > self.count:
            k = self.count - 1
        if k <= 0:
            return None

        if type == "mmr":
            texts = self.db.max_marginal_relevance_search(
                query=query, k=k, fetch_k=min(20, self.count)
            )
            texts = [text.page_content for text in texts]
        elif type == "cos":
            texts = self.db.similarity_search_with_score(
                query=query, k=k
            )
            texts = [
                text[0].page_content
                for text in texts
                if text[-1] < distance_threshold
            ]
        else:
            raise ValueError(f"Unknown search type: {type}")

        return texts

    @synchronized_mem
    def query(self, question: str):
        """
        Ask a question of the vector memory.

        Args:
            question (str): The question to ask.

        Returns:
            str: The answer to the question.
        """
        answer = self.qa.run(question)
        return answer
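

# ---------------------------------------------------------------------------
# Usage sketch, not part of the module: a minimal example of how this memory
# might be exercised. It assumes OPENAI_API_KEY is set in the environment,
# network access to the OpenAI API, and the langchain/chromadb versions these
# imports target; the "./tmp/demo_memory" path is illustrative.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    memory = LangchainChromaVectorMemory(loc="./tmp/demo_memory")

    # Entries are chunked, embedded, and persisted to the Chroma database.
    memory.add(
        "The swarms memory layer persists text entries in a local Chroma"
        " database so they can be retrieved later by similarity search."
    )

    # Retrieve similar chunks via maximal marginal relevance ("mmr") or
    # cosine similarity with a distance cutoff ("cos").
    print(memory.search_memory("How are entries persisted?", k=3, type="mmr"))

    # Answer a question over the retrieved context.
    print(memory.query("Where are text entries stored?"))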