From 2bac1e6b6ea1a3269e2fd05c7d0aa9d54b3d2918 Mon Sep 17 00:00:00 2001
From: Richard Anthony Hein
Date: Mon, 19 Aug 2024 15:19:23 +0000
Subject: [PATCH] more fixes for stuff documents chain

---
 swarms/server/server.py       | 107 ++++++++++++++++++----------
 swarms/server/vector_store.py |  28 ++++-----
 2 files changed, 71 insertions(+), 64 deletions(-)

diff --git a/swarms/server/server.py b/swarms/server/server.py
index 7749ba90..95e1185b 100644
--- a/swarms/server/server.py
+++ b/swarms/server/server.py
@@ -20,7 +20,7 @@ from fastapi.routing import APIRouter
 from fastapi.staticfiles import StaticFiles
 from huggingface_hub import login
 from langchain.callbacks import StreamingStdOutCallbackHandler
-from langchain.memory import VectorStoreRetrieverMemory
+from langchain.memory import ConversationBufferMemory
 from langchain.memory.chat_message_histories.in_memory import ChatMessageHistory
 from langchain_core.messages import SystemMessage, AIMessage, HumanMessage
 from langchain.chains.history_aware_retriever import create_history_aware_retriever
@@ -43,7 +43,6 @@ from swarms.prompts.conversational_RAG import (
     QA_PROMPT_TEMPLATE,
     QA_PROMPT_TEMPLATE_STR,
     QA_CONDENSE_TEMPLATE_STR,
-    SUMMARY_PROMPT_TEMPLATE,
 )
 from swarms.server.vector_store import VectorStorage
@@ -62,7 +61,8 @@ from swarms.server.server_models import (
 )

 # Explicitly specify the path to the .env file
-dotenv_path = os.path.join(os.path.dirname(__file__), '.env')
+# Two folders above the current file's directory
+dotenv_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), '.env')
 load_dotenv(dotenv_path)
 hf_token = os.environ.get("HUGGINFACEHUB_API_KEY") # Get the Huggingface API Token
@@ -147,7 +147,7 @@ if not os.path.exists(uploads):
     os.makedirs(uploads)

 # Initialize the vector store
-vector_store = VectorStorage(directory=uploads, useGPU=useGPU)
+vector_store = VectorStorage(directoryOrUrl=uploads, useGPU=useGPU)


 async def create_chain(
@@ -181,48 +181,57 @@ async def create_chain(
     # if llm is VLLMAsync:
     #     llm.max_tokens = max_tokens_to_gen

-    retriever = await vector_store.getRetriever(os.path.join(file.username, file.filename))
+    retriever = await vector_store.getRetriever()

     chat_memory = ChatMessageHistory()
-
     for message in messages:
         if message.role == Role.USER:
-            human_msg = HumanMessage(message.content)
-            chat_memory.add_user_message(human_msg)
+            chat_memory.add_user_message(message.content)
         elif message.role == Role.ASSISTANT:
-            ai_msg = AIMessage(message.content)
-            chat_memory.add_ai_message(ai_msg)
-        elif message.role == Role.SYSTEM:
-            system_msg = SystemMessage(message.content)
-            chat_memory.add_message(system_msg)
-
-    ### Contextualize question ###
-    contextualize_q_system_prompt = """Given a chat history and the latest user question \
-    which might reference context in the chat history, formulate a standalone question \
-    which can be understood without the chat history. Do NOT answer the question, \
-    just reformulate it if needed and otherwise return it as is."""
-    contextualize_q_prompt = QA_PROMPT_TEMPLATE
-    history_aware_retriever = create_history_aware_retriever(
-        llm, retriever, contextualize_q_prompt
+            chat_memory.add_ai_message(message.content)
+
+    memory = ConversationBufferMemory(
+        chat_memory=chat_memory,
+        memory_key="chat_history",
+        input_key="question",
+        output_key="answer",
+        return_messages=True,
     )
+    question_generator = LLMChain(
+        llm=llm,
+        prompt=CONDENSE_PROMPT_TEMPLATE,
+        memory=memory,
+        verbose=True,
+        output_key="answer",
+    )
+    stuff_chain = LLMChain(
+        llm=llm,
+        prompt=prompt,
+        verbose=True,
+        output_key="answer",
+    )

-    ### Answer question ###
-    qa_system_prompt = """You are an assistant for question-answering tasks. \
-    Use the following pieces of retrieved context to answer the question. \
-    If you don't know the answer, just say that you don't know. \
-    Use three sentences maximum and keep the answer concise.\
-
-    {context}"""
-    qa_prompt = QA_PROMPT_TEMPLATE
-    question_answer_chain = create_stuff_documents_chain(llm, qa_prompt, document_prompt=DOCUMENT_PROMPT_TEMPLATE)
-
-    from langchain_core.runnables import RunnablePassthrough
+    doc_chain = StuffDocumentsChain(
+        llm_chain=stuff_chain,
+        document_variable_name="context",
+        document_prompt=DOCUMENT_PROMPT_TEMPLATE,
+        verbose=True,
+        output_key="answer",
+        memory=memory,
+    )

-    rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)
-
-    return rag_chain
+    return ConversationalRetrievalChain(
+        combine_docs_chain=doc_chain,
+        memory=memory,
+        retriever=retriever,
+        question_generator=question_generator,
+        return_generated_question=False,
+        return_source_documents=True,
+        output_key="answer",
+        verbose=True,
+    )


 router = APIRouter()
@@ -243,23 +252,21 @@ async def chat(request: ChatRequest):
             f"{B_INST}{B_SYS}{request.prompt.strip()}{E_SYS}{E_INST}"
         ),
     )
-
-    response = LangchainStreamingResponse(
+
+    json = {
+        "question": request.messages[-1].content,
+        "chat_history": [message.content for message in request.messages[:-1]],
+        # "callbacks": [
+        #     StreamingStdOutCallbackHandler(),
+        #     TokenStreamingCallbackHandler(output_key="answer"),
+        #     SourceDocumentsStreamingCallbackHandler(),
+        # ],
+    }
+    return LangchainStreamingResponse(
         chain,
-        config={
-            "input": request.messages[-1].content,
-            "chat_history": [message.content for message in request.messages[:-1]],
-            "context": "{context}",
-            "callbacks": [
-                StreamingStdOutCallbackHandler(),
-                TokenStreamingCallbackHandler(output_key="answer"),
-                SourceDocumentsStreamingCallbackHandler(),
-            ],
-        },
+        config=json,
     )

-    return response
-

 app.include_router(router, tags=["chat"])

diff --git a/swarms/server/vector_store.py b/swarms/server/vector_store.py
index ca1d37f4..04f2e6fe 100644
--- a/swarms/server/vector_store.py
+++ b/swarms/server/vector_store.py
@@ -16,7 +16,7 @@ from swarms.server.async_parent_document_retriever import AsyncParentDocumentRet
 store_type = "local" # "redis" or "local"

 class VectorStorage:
-    def __init__(self, directory, useGPU=False):
+    def __init__(self, directoryOrUrl, useGPU=False):
         self.embeddings = HuggingFaceBgeEmbeddings(
             cache_folder="./.embeddings",
             model_name="BAAI/bge-large-en",
@@ -24,7 +24,7 @@ class VectorStorage:
             encode_kwargs={"normalize_embeddings": True},
             query_instruction="Represent this sentence for searching relevant passages: ",
         )
-        self.directory = directory
+        self.directoryOrUrl = directoryOrUrl
         self.child_splitter = RecursiveCharacterTextSplitter(
             chunk_size=200, chunk_overlap=20
         )
@@ -62,16 +62,16 @@ class VectorStorage:
         print(f"Start vectorstore initialization time: {start_time}")

         # for each subdirectory in the directory, create a new collection if it doesn't exist
-        dirs = directories or os.listdir(self.directory)
+        dirs = directories or os.listdir(self.directoryOrUrl)
         # make sure the subdir is not a file on MacOS (which has a hidden .DS_Store file)
         dirs = [
             subdir
             for subdir in dirs
-            if not os.path.isfile(f"{self.directory}/{subdir}")
+            if not os.path.isfile(f"{self.directoryOrUrl}/{subdir}")
         ]
         print(f"{len(dirs)} subdirectories to load: {dirs}")

-        self.retrievers[self.directory] = await self.initRetriever(self.directory)
+        self.retrievers[self.directoryOrUrl] = await self.initRetriever(self.directoryOrUrl)

         end_time = datetime.now()
         print("Vectorstore initialization complete.")
@@ -97,7 +97,7 @@ class VectorStorage:
         max_files = 1000

         # Load existing metadata
-        metadata_file = f"{self.directory}/metadata.json"
+        metadata_file = f"{self.directoryOrUrl}/metadata.json"
         metadata = {"processDate": str(datetime.now()), "processed_files": []}
         processed_files = set() # Track processed files
         if os.path.isfile(metadata_file):
@@ -107,7 +107,7 @@ class VectorStorage:

         # Get a list of all files in the directory and exclude processed files
         all_files = [
-            file for file in glob.glob(f"{self.directory}/**/*.md", recursive=True)
+            file for file in glob.glob(f"{self.directoryOrUrl}/**/*.md", recursive=True)
             if file not in processed_files
         ]

@@ -131,16 +131,16 @@ class VectorStorage:
                 "processed_at": str(datetime.now())
             })

-        print(f"Creating new collection for {self.directory}...")
+        print(f"Creating new collection for {self.directoryOrUrl}...")
         # Create or get the collection
         collection = self.client.create_collection(
-            name=self.directory,
+            name=self.directoryOrUrl,
             get_or_create=True,
             metadata={"processDate": metadata["processDate"]},
         )

         # Reload vectorstore based on collection
-        vectorstore = self.getVectorStore(collection_name=self.directory)
+        vectorstore = self.getVectorStore(collection_name=self.directoryOrUrl)

         # Create a new parent document retriever
         retriever = AsyncParentDocumentRetriever(
@@ -151,8 +151,8 @@ class VectorStorage:
         )

         # force reload of collection to make sure we don't have the default langchain collection
-        collection = self.client.get_collection(name=self.directory)
-        vectorstore = self.getVectorStore(collection_name=self.directory)
+        collection = self.client.get_collection(name=self.directoryOrUrl)
+        vectorstore = self.getVectorStore(collection_name=self.directoryOrUrl)

         # Add documents to the collection and docstore
         print(f"Adding {len(documents)} documents to collection...")
@@ -182,8 +182,8 @@ class VectorStorage:
         print(f"Time taken: {subdir_end_time - subdir_start_time}")

         # Reload vectorstore based on collection to pass to parent doc retriever
-        collection = self.client.get_collection(name=self.directory)
-        vectorstore = self.getVectorStore(collection_name=self.directory)
+        # collection = self.client.get_collection(name=self.directoryOrUrl)
+        vectorstore = self.getVectorStore()
         retriever = AsyncParentDocumentRetriever(
             docstore=self.store,
             vectorstore=vectorstore,
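-- 
Reviewer notes (everything after the "-- " cut line is dropped by git am, so
the sketches below are not part of the commit).

On the .env change: os.path.dirname() applied three times strips the filename
and then walks up two directories, so the file is now resolved relative to the
repository root rather than swarms/server. A quick check of the arithmetic,
using a hypothetical layout:

    import os

    # If server.py lives at /repo/swarms/server/server.py:
    path = "/repo/swarms/server/server.py"
    print(os.path.dirname(path))                                    # /repo/swarms/server
    print(os.path.dirname(os.path.dirname(path)))                   # /repo/swarms
    print(os.path.dirname(os.path.dirname(os.path.dirname(path))))  # /repo
    # dotenv_path therefore resolves to /repo/.env.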
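On the create_chain() rewrite: the LCEL helpers (create_history_aware_retriever,
create_stuff_documents_chain, create_retrieval_chain) are replaced with the
classic ConversationalRetrievalChain, assembled from a question-condensing
LLMChain and a StuffDocumentsChain. The sketch below reproduces that wiring in
a self-contained, runnable form; FakeListLLM, StaticRetriever, and the inline
PromptTemplates are stand-ins for the real model, the VectorStorage retriever,
and the CONDENSE_PROMPT_TEMPLATE / QA_PROMPT_TEMPLATE / DOCUMENT_PROMPT_TEMPLATE
objects from swarms.prompts.conversational_RAG.

    from typing import List

    from langchain.chains import (
        ConversationalRetrievalChain,
        LLMChain,
        StuffDocumentsChain,
    )
    from langchain.memory import ChatMessageHistory, ConversationBufferMemory
    from langchain.prompts import PromptTemplate
    from langchain_community.llms import FakeListLLM
    from langchain_core.documents import Document
    from langchain_core.retrievers import BaseRetriever

    class StaticRetriever(BaseRetriever):
        """Stand-in retriever; the server uses VectorStorage.getRetriever()."""

        docs: List[Document]

        def _get_relevant_documents(self, query, *, run_manager):
            return self.docs

    # Canned outputs: the first feeds the question generator, the second the
    # answer chain.
    llm = FakeListLLM(responses=["Standalone question?", "A concise answer."])

    chat_memory = ChatMessageHistory()
    chat_memory.add_user_message("What do the uploaded docs cover?")
    chat_memory.add_ai_message("They cover the Swarms framework.")

    memory = ConversationBufferMemory(
        chat_memory=chat_memory,
        memory_key="chat_history",  # the key ConversationalRetrievalChain reads
        input_key="question",
        output_key="answer",  # needed because source documents are returned too
        return_messages=True,
    )

    # Condenses the follow-up question plus the history into a standalone question.
    question_generator = LLMChain(
        llm=llm,
        prompt=PromptTemplate.from_template(
            "History:\n{chat_history}\n\nRewrite as a standalone question: {question}"
        ),
    )

    # "Stuffs" every retrieved document into the single {context} slot.
    doc_chain = StuffDocumentsChain(
        llm_chain=LLMChain(
            llm=llm,
            prompt=PromptTemplate.from_template(
                "Context:\n{context}\n\nQuestion: {question}\nAnswer:"
            ),
        ),
        document_variable_name="context",
    )

    chain = ConversationalRetrievalChain(
        retriever=StaticRetriever(
            docs=[Document(page_content="Swarms is a multi-agent framework.")]
        ),
        combine_docs_chain=doc_chain,
        question_generator=question_generator,
        memory=memory,
        return_source_documents=True,
        output_key="answer",
    )

    result = chain.invoke({"question": "And what framework is that?"})
    print(result["answer"], "| sources:", len(result["source_documents"]))

The patch additionally sets verbose=True and output_key="answer" on the inner
chains and attaches memory to the StuffDocumentsChain; those are omitted here
for brevity and do not change the control flow.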
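On the vector_store.py changes: the rename from directory to directoryOrUrl is
mechanical, and the final hunk drops the per-collection lookup so that
getVectorStore() is called without a collection_name. Under the new signatures,
a smoke test would look roughly like the following; it assumes the store's
async initialization (the method whose body the second hunk shows, its name is
not visible in this diff) has already run, and that the returned
AsyncParentDocumentRetriever follows LangChain's async retriever interface:

    import asyncio

    from swarms.server.vector_store import VectorStorage

    async def main() -> None:
        # directoryOrUrl replaces the old directory kwarg; a local folder here.
        store = VectorStorage(directoryOrUrl="uploads", useGPU=False)
        # getRetriever() now takes no per-file path argument (see the server.py
        # hunk that rewrites create_chain).
        retriever = await store.getRetriever()
        docs = await retriever.aget_relevant_documents("what do the docs cover?")
        for doc in docs:
            print(doc.metadata, doc.page_content[:80])

    asyncio.run(main())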