more fixes for stuff documents chain

pull/570/head
Richard Anthony Hein 8 months ago
parent 3bae493d3d
commit 2bac1e6b6e

@ -20,7 +20,7 @@ from fastapi.routing import APIRouter
from fastapi.staticfiles import StaticFiles
from huggingface_hub import login
from langchain.callbacks import StreamingStdOutCallbackHandler
from langchain.memory import VectorStoreRetrieverMemory
from langchain.memory import ConversationBufferMemory
from langchain.memory.chat_message_histories.in_memory import ChatMessageHistory
from langchain_core.messages import SystemMessage, AIMessage, HumanMessage
from langchain.chains.history_aware_retriever import create_history_aware_retriever
@ -43,7 +43,6 @@ from swarms.prompts.conversational_RAG import (
QA_PROMPT_TEMPLATE,
QA_PROMPT_TEMPLATE_STR,
QA_CONDENSE_TEMPLATE_STR,
SUMMARY_PROMPT_TEMPLATE,
)
from swarms.server.vector_store import VectorStorage
@ -62,7 +61,8 @@ from swarms.server.server_models import (
)
# Explicitly specify the path to the .env file
dotenv_path = os.path.join(os.path.dirname(__file__), '.env')
# Two folders above the current file's directory
dotenv_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), '.env')
load_dotenv(dotenv_path)
hf_token = os.environ.get("HUGGINFACEHUB_API_KEY") # Get the Huggingface API Token
@ -147,7 +147,7 @@ if not os.path.exists(uploads):
os.makedirs(uploads)
# Initialize the vector store
vector_store = VectorStorage(directory=uploads, useGPU=useGPU)
vector_store = VectorStorage(directoryOrUrl=uploads, useGPU=useGPU)
async def create_chain(
@ -181,48 +181,57 @@ async def create_chain(
# if llm is VLLMAsync:
# llm.max_tokens = max_tokens_to_gen
retriever = await vector_store.getRetriever(os.path.join(file.username, file.filename))
retriever = await vector_store.getRetriever()
chat_memory = ChatMessageHistory()
for message in messages:
if message.role == Role.USER:
human_msg = HumanMessage(message.content)
chat_memory.add_user_message(human_msg)
chat_memory.add_user_message(message.content)
elif message.role == Role.ASSISTANT:
ai_msg = AIMessage(message.content)
chat_memory.add_ai_message(ai_msg)
elif message.role == Role.SYSTEM:
system_msg = SystemMessage(message.content)
chat_memory.add_message(system_msg)
### Contextualize question ###
contextualize_q_system_prompt = """Given a chat history and the latest user question \
which might reference context in the chat history, formulate a standalone question \
which can be understood without the chat history. Do NOT answer the question, \
just reformulate it if needed and otherwise return it as is."""
contextualize_q_prompt = QA_PROMPT_TEMPLATE
history_aware_retriever = create_history_aware_retriever(
llm, retriever, contextualize_q_prompt
chat_memory.add_ai_message(message.content)
memory = ConversationBufferMemory(
chat_memory=chat_memory,
memory_key="chat_history",
input_key="question",
output_key="answer",
return_messages=True,
)
question_generator = LLMChain(
llm=llm,
prompt=CONDENSE_PROMPT_TEMPLATE,
memory=memory,
verbose=True,
output_key="answer",
)
stuff_chain = LLMChain(
llm=llm,
prompt=prompt,
verbose=True,
output_key="answer",
)
### Answer question ###
qa_system_prompt = """You are an assistant for question-answering tasks. \
Use the following pieces of retrieved context to answer the question. \
If you don't know the answer, just say that you don't know. \
Use three sentences maximum and keep the answer concise.\
{context}"""
qa_prompt = QA_PROMPT_TEMPLATE
question_answer_chain = create_stuff_documents_chain(llm, qa_prompt, document_prompt=DOCUMENT_PROMPT_TEMPLATE)
from langchain_core.runnables import RunnablePassthrough
rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)
doc_chain = StuffDocumentsChain(
llm_chain=stuff_chain,
document_variable_name="context",
document_prompt=DOCUMENT_PROMPT_TEMPLATE,
verbose=True,
output_key="answer",
memory=memory,
)
return rag_chain
return ConversationalRetrievalChain(
combine_docs_chain=doc_chain,
memory=memory,
retriever=retriever,
question_generator=question_generator,
return_generated_question=False,
return_source_documents=True,
output_key="answer",
verbose=True,
)
router = APIRouter()
@ -244,22 +253,20 @@ async def chat(request: ChatRequest):
),
)
response = LangchainStreamingResponse(
chain,
config={
"input": request.messages[-1].content,
json = {
"question": request.messages[-1].content,
"chat_history": [message.content for message in request.messages[:-1]],
"context": "{context}",
"callbacks": [
StreamingStdOutCallbackHandler(),
TokenStreamingCallbackHandler(output_key="answer"),
SourceDocumentsStreamingCallbackHandler(),
],
},
# "callbacks": [
# StreamingStdOutCallbackHandler(),
# TokenStreamingCallbackHandler(output_key="answer"),
# SourceDocumentsStreamingCallbackHandler(),
# ],
}
return LangchainStreamingResponse(
chain,
config=json,
)
return response
app.include_router(router, tags=["chat"])

@ -16,7 +16,7 @@ from swarms.server.async_parent_document_retriever import AsyncParentDocumentRet
store_type = "local" # "redis" or "local"
class VectorStorage:
def __init__(self, directory, useGPU=False):
def __init__(self, directoryOrUrl, useGPU=False):
self.embeddings = HuggingFaceBgeEmbeddings(
cache_folder="./.embeddings",
model_name="BAAI/bge-large-en",
@ -24,7 +24,7 @@ class VectorStorage:
encode_kwargs={"normalize_embeddings": True},
query_instruction="Represent this sentence for searching relevant passages: ",
)
self.directory = directory
self.directoryOrUrl = directoryOrUrl
self.child_splitter = RecursiveCharacterTextSplitter(
chunk_size=200, chunk_overlap=20
)
@ -62,16 +62,16 @@ class VectorStorage:
print(f"Start vectorstore initialization time: {start_time}")
# for each subdirectory in the directory, create a new collection if it doesn't exist
dirs = directories or os.listdir(self.directory)
dirs = directories or os.listdir(self.directoryOrUrl)
# make sure the subdir is not a file on MacOS (which has a hidden .DS_Store file)
dirs = [
subdir
for subdir in dirs
if not os.path.isfile(f"{self.directory}/{subdir}")
if not os.path.isfile(f"{self.directoryOrUrl}/{subdir}")
]
print(f"{len(dirs)} subdirectories to load: {dirs}")
self.retrievers[self.directory] = await self.initRetriever(self.directory)
self.retrievers[self.directoryOrUrl] = await self.initRetriever(self.directoryOrUrl)
end_time = datetime.now()
print("Vectorstore initialization complete.")
@ -97,7 +97,7 @@ class VectorStorage:
max_files = 1000
# Load existing metadata
metadata_file = f"{self.directory}/metadata.json"
metadata_file = f"{self.directoryOrUrl}/metadata.json"
metadata = {"processDate": str(datetime.now()), "processed_files": []}
processed_files = set() # Track processed files
if os.path.isfile(metadata_file):
@ -107,7 +107,7 @@ class VectorStorage:
# Get a list of all files in the directory and exclude processed files
all_files = [
file for file in glob.glob(f"{self.directory}/**/*.md", recursive=True)
file for file in glob.glob(f"{self.directoryOrUrl}/**/*.md", recursive=True)
if file not in processed_files
]
@ -131,16 +131,16 @@ class VectorStorage:
"processed_at": str(datetime.now())
})
print(f"Creating new collection for {self.directory}...")
print(f"Creating new collection for {self.directoryOrUrl}...")
# Create or get the collection
collection = self.client.create_collection(
name=self.directory,
name=self.directoryOrUrl,
get_or_create=True,
metadata={"processDate": metadata["processDate"]},
)
# Reload vectorstore based on collection
vectorstore = self.getVectorStore(collection_name=self.directory)
vectorstore = self.getVectorStore(collection_name=self.directoryOrUrl)
# Create a new parent document retriever
retriever = AsyncParentDocumentRetriever(
@ -151,8 +151,8 @@ class VectorStorage:
)
# force reload of collection to make sure we don't have the default langchain collection
collection = self.client.get_collection(name=self.directory)
vectorstore = self.getVectorStore(collection_name=self.directory)
collection = self.client.get_collection(name=self.directoryOrUrl)
vectorstore = self.getVectorStore(collection_name=self.directoryOrUrl)
# Add documents to the collection and docstore
print(f"Adding {len(documents)} documents to collection...")
@ -182,8 +182,8 @@ class VectorStorage:
print(f"Time taken: {subdir_end_time - subdir_start_time}")
# Reload vectorstore based on collection to pass to parent doc retriever
collection = self.client.get_collection(name=self.directory)
vectorstore = self.getVectorStore(collection_name=self.directory)
# collection = self.client.get_collection(name=self.directoryOrUrl)
vectorstore = self.getVectorStore()
retriever = AsyncParentDocumentRetriever(
docstore=self.store,
vectorstore=vectorstore,

Loading…
Cancel
Save