more fixes for stuff documents chain

pull/570/head
Richard Anthony Hein 8 months ago
parent 3bae493d3d
commit 2bac1e6b6e

@@ -20,7 +20,7 @@ from fastapi.routing import APIRouter
 from fastapi.staticfiles import StaticFiles
 from huggingface_hub import login
 from langchain.callbacks import StreamingStdOutCallbackHandler
-from langchain.memory import VectorStoreRetrieverMemory
+from langchain.memory import ConversationBufferMemory
 from langchain.memory.chat_message_histories.in_memory import ChatMessageHistory
 from langchain_core.messages import SystemMessage, AIMessage, HumanMessage
 from langchain.chains.history_aware_retriever import create_history_aware_retriever
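Note on the import swap: `ConversationBufferMemory` replays the raw message history into the prompt under `memory_key`, whereas the removed `VectorStoreRetrieverMemory` pulls semantically similar past turns from a vector store. A minimal sketch of the new behavior; only the two imports come from this diff, the messages are illustrative:

```python
# Minimal sketch of ConversationBufferMemory; only the imports are from
# this diff -- the messages and printout are illustrative.
from langchain.memory import ConversationBufferMemory
from langchain.memory.chat_message_histories.in_memory import ChatMessageHistory

history = ChatMessageHistory()
history.add_user_message("What is a vector store?")
history.add_ai_message("A database indexed by embedding similarity.")

memory = ConversationBufferMemory(
    chat_memory=history,
    memory_key="chat_history",
    return_messages=True,
)
# The buffer is returned verbatim, ready to be injected into a prompt.
print(memory.load_memory_variables({}))
```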
@@ -43,7 +43,6 @@ from swarms.prompts.conversational_RAG import (
     QA_PROMPT_TEMPLATE,
     QA_PROMPT_TEMPLATE_STR,
     QA_CONDENSE_TEMPLATE_STR,
-    SUMMARY_PROMPT_TEMPLATE,
 )

 from swarms.server.vector_store import VectorStorage
@@ -62,7 +61,8 @@ from swarms.server.server_models import (
 )

 # Explicitly specify the path to the .env file
-dotenv_path = os.path.join(os.path.dirname(__file__), '.env')
+# Two folders above the current file's directory
+dotenv_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), '.env')
 load_dotenv(dotenv_path)

 hf_token = os.environ.get("HUGGINFACEHUB_API_KEY")  # Get the Huggingface API Token
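Each `os.path.dirname` strips one path component, so the tripled call resolves `.env` two folders above this file's directory. A quick sketch of the path arithmetic with a hypothetical layout:

```python
# Quick sketch of the path arithmetic; the layout below is hypothetical.
import os

module_path = "/repo/swarms/server/app.py"  # stand-in for __file__

print(os.path.dirname(module_path))                        # /repo/swarms/server
print(os.path.dirname(os.path.dirname(module_path)))       # /repo/swarms
print(os.path.join(
    os.path.dirname(os.path.dirname(os.path.dirname(module_path))), ".env"
))                                                         # /repo/.env
```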
@@ -147,7 +147,7 @@ if not os.path.exists(uploads):
     os.makedirs(uploads)

 # Initialize the vector store
-vector_store = VectorStorage(directory=uploads, useGPU=useGPU)
+vector_store = VectorStorage(directoryOrUrl=uploads, useGPU=useGPU)


 async def create_chain(
@@ -181,48 +181,57 @@ async def create_chain(
     # if llm is VLLMAsync:
     #     llm.max_tokens = max_tokens_to_gen

-    retriever = await vector_store.getRetriever(os.path.join(file.username, file.filename))
+    retriever = await vector_store.getRetriever()

     chat_memory = ChatMessageHistory()
     for message in messages:
         if message.role == Role.USER:
-            human_msg = HumanMessage(message.content)
-            chat_memory.add_user_message(human_msg)
+            chat_memory.add_user_message(message.content)
         elif message.role == Role.ASSISTANT:
-            ai_msg = AIMessage(message.content)
-            chat_memory.add_ai_message(ai_msg)
-        elif message.role == Role.SYSTEM:
-            system_msg = SystemMessage(message.content)
-            chat_memory.add_message(system_msg)
-
-    ### Contextualize question ###
-    contextualize_q_system_prompt = """Given a chat history and the latest user question \
-which might reference context in the chat history, formulate a standalone question \
-which can be understood without the chat history. Do NOT answer the question, \
-just reformulate it if needed and otherwise return it as is."""
-    contextualize_q_prompt = QA_PROMPT_TEMPLATE
-    history_aware_retriever = create_history_aware_retriever(
-        llm, retriever, contextualize_q_prompt
-    )
-
-    ### Answer question ###
-    qa_system_prompt = """You are an assistant for question-answering tasks. \
-Use the following pieces of retrieved context to answer the question. \
-If you don't know the answer, just say that you don't know. \
-Use three sentences maximum and keep the answer concise.\
-
-{context}"""
-    qa_prompt = QA_PROMPT_TEMPLATE
-    question_answer_chain = create_stuff_documents_chain(llm, qa_prompt, document_prompt=DOCUMENT_PROMPT_TEMPLATE)
-    from langchain_core.runnables import RunnablePassthrough
-    rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)
-    return rag_chain
+            chat_memory.add_ai_message(message.content)
+
+    memory = ConversationBufferMemory(
+        chat_memory=chat_memory,
+        memory_key="chat_history",
+        input_key="question",
+        output_key="answer",
+        return_messages=True,
+    )
+
+    question_generator = LLMChain(
+        llm=llm,
+        prompt=CONDENSE_PROMPT_TEMPLATE,
+        memory=memory,
+        verbose=True,
+        output_key="answer",
+    )
+
+    stuff_chain = LLMChain(
+        llm=llm,
+        prompt=prompt,
+        verbose=True,
+        output_key="answer",
+    )
+
+    doc_chain = StuffDocumentsChain(
+        llm_chain=stuff_chain,
+        document_variable_name="context",
+        document_prompt=DOCUMENT_PROMPT_TEMPLATE,
+        verbose=True,
+        output_key="answer",
+        memory=memory,
+    )
+
+    return ConversationalRetrievalChain(
+        combine_docs_chain=doc_chain,
+        memory=memory,
+        retriever=retriever,
+        question_generator=question_generator,
+        return_generated_question=False,
+        return_source_documents=True,
+        output_key="answer",
+        verbose=True,
+    )
router = APIRouter()
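Taken together, the new `create_chain` condenses the latest question via `question_generator`, retrieves documents, and stuffs them into the `{context}` slot of `doc_chain`. A hedged sketch of how the returned chain might be invoked; the keys are inferred from the `input_key`/`output_key` wiring above, and nothing here is shown in the diff itself:

```python
# Hedged invocation sketch; assumes `chain` is the ConversationalRetrievalChain
# returned by create_chain above, called from inside an async context.
result = await chain.acall({"question": "How do I upload a document?"})

print(result["answer"])             # output_key="answer"
print(result["source_documents"])   # return_source_documents=True
# chat_history is injected by ConversationBufferMemory, so callers only
# need to pass the new question.
```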
@@ -244,22 +253,20 @@ async def chat(request: ChatRequest):
         ),
     )

-    response = LangchainStreamingResponse(
+    json = {
+        "question": request.messages[-1].content,
+        "chat_history": [message.content for message in request.messages[:-1]],
+        # "callbacks": [
+        #     StreamingStdOutCallbackHandler(),
+        #     TokenStreamingCallbackHandler(output_key="answer"),
+        #     SourceDocumentsStreamingCallbackHandler(),
+        # ],
+    }
+
+    return LangchainStreamingResponse(
         chain,
-        config={
-            "input": request.messages[-1].content,
-            "chat_history": [message.content for message in request.messages[:-1]],
-            "context": "{context}",
-            "callbacks": [
-                StreamingStdOutCallbackHandler(),
-                TokenStreamingCallbackHandler(output_key="answer"),
-                SourceDocumentsStreamingCallbackHandler(),
-            ],
-        },
+        config=json,
     )
-    return response


 app.include_router(router, tags=["chat"])
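On the wire, the endpoint now maps the last message to `question` and all earlier messages to `chat_history`. A hypothetical client call; the route path and payload shape are assumptions inferred from `ChatRequest`, not confirmed by this diff:

```python
# Hypothetical client; route path, port, and payload shape are assumptions.
import httpx

payload = {
    "messages": [
        {"role": "user", "content": "Summarize the uploaded document."},
    ]
}

with httpx.stream("POST", "http://localhost:8000/chat", json=payload) as response:
    for chunk in response.iter_text():  # LangchainStreamingResponse streams tokens
        print(chunk, end="", flush=True)
```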

swarms/server/vector_store.py

@@ -16,7 +16,7 @@ from swarms.server.async_parent_document_retriever import AsyncParentDocumentRetriever
 store_type = "local"  # "redis" or "local"

 class VectorStorage:
-    def __init__(self, directory, useGPU=False):
+    def __init__(self, directoryOrUrl, useGPU=False):
         self.embeddings = HuggingFaceBgeEmbeddings(
             cache_folder="./.embeddings",
             model_name="BAAI/bge-large-en",
@@ -24,7 +24,7 @@ class VectorStorage:
             encode_kwargs={"normalize_embeddings": True},
             query_instruction="Represent this sentence for searching relevant passages: ",
         )
-        self.directory = directory
+        self.directoryOrUrl = directoryOrUrl
         self.child_splitter = RecursiveCharacterTextSplitter(
             chunk_size=200, chunk_overlap=20
         )
@@ -62,16 +62,16 @@
         print(f"Start vectorstore initialization time: {start_time}")

         # for each subdirectory in the directory, create a new collection if it doesn't exist
-        dirs = directories or os.listdir(self.directory)
+        dirs = directories or os.listdir(self.directoryOrUrl)

         # make sure the subdir is not a file on MacOS (which has a hidden .DS_Store file)
         dirs = [
             subdir
             for subdir in dirs
-            if not os.path.isfile(f"{self.directory}/{subdir}")
+            if not os.path.isfile(f"{self.directoryOrUrl}/{subdir}")
         ]
         print(f"{len(dirs)} subdirectories to load: {dirs}")

-        self.retrievers[self.directory] = await self.initRetriever(self.directory)
+        self.retrievers[self.directoryOrUrl] = await self.initRetriever(self.directoryOrUrl)

         end_time = datetime.now()
         print("Vectorstore initialization complete.")
@@ -97,7 +97,7 @@
         max_files = 1000

         # Load existing metadata
-        metadata_file = f"{self.directory}/metadata.json"
+        metadata_file = f"{self.directoryOrUrl}/metadata.json"
         metadata = {"processDate": str(datetime.now()), "processed_files": []}
         processed_files = set()  # Track processed files
         if os.path.isfile(metadata_file):
@@ -107,7 +107,7 @@
         # Get a list of all files in the directory and exclude processed files
         all_files = [
-            file for file in glob.glob(f"{self.directory}/**/*.md", recursive=True)
+            file for file in glob.glob(f"{self.directoryOrUrl}/**/*.md", recursive=True)
             if file not in processed_files
         ]
@@ -131,16 +131,16 @@
                 "processed_at": str(datetime.now())
             })

-        print(f"Creating new collection for {self.directory}...")
+        print(f"Creating new collection for {self.directoryOrUrl}...")
         # Create or get the collection
         collection = self.client.create_collection(
-            name=self.directory,
+            name=self.directoryOrUrl,
             get_or_create=True,
             metadata={"processDate": metadata["processDate"]},
         )

         # Reload vectorstore based on collection
-        vectorstore = self.getVectorStore(collection_name=self.directory)
+        vectorstore = self.getVectorStore(collection_name=self.directoryOrUrl)

         # Create a new parent document retriever
         retriever = AsyncParentDocumentRetriever(
@@ -151,8 +151,8 @@
         )

         # force reload of collection to make sure we don't have the default langchain collection
-        collection = self.client.get_collection(name=self.directory)
-        vectorstore = self.getVectorStore(collection_name=self.directory)
+        collection = self.client.get_collection(name=self.directoryOrUrl)
+        vectorstore = self.getVectorStore(collection_name=self.directoryOrUrl)

         # Add documents to the collection and docstore
         print(f"Adding {len(documents)} documents to collection...")
@@ -182,8 +182,8 @@
             print(f"Time taken: {subdir_end_time - subdir_start_time}")

         # Reload vectorstore based on collection to pass to parent doc retriever
-        collection = self.client.get_collection(name=self.directory)
-        vectorstore = self.getVectorStore(collection_name=self.directory)
+        # collection = self.client.get_collection(name=self.directoryOrUrl)
+        vectorstore = self.getVectorStore()
         retriever = AsyncParentDocumentRetriever(
             docstore=self.store,
             vectorstore=vectorstore,
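Since the `directory` to `directoryOrUrl` rename is purely mechanical inside the class, callers only change the keyword. A brief sketch of the caller-facing surface; the constructor keyword and the no-argument `getRetriever()` come from this diff, while the retriever query method is an assumption about the standard LangChain retriever interface:

```python
# Caller-facing sketch; VectorStorage usage is taken from this diff,
# aget_relevant_documents is assumed from the LangChain retriever API.
import asyncio

from swarms.server.vector_store import VectorStorage

async def main():
    store = VectorStorage(directoryOrUrl="uploads", useGPU=False)
    retriever = await store.getRetriever()  # per-file subpath argument was dropped
    docs = await retriever.aget_relevant_documents("parent document retriever")
    print(len(docs))

asyncio.run(main())
```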
