@@ -20,7 +20,7 @@ from fastapi.routing import APIRouter
 from fastapi.staticfiles import StaticFiles
 from huggingface_hub import login
 from langchain.callbacks import StreamingStdOutCallbackHandler
-from langchain.memory import VectorStoreRetrieverMemory
+from langchain.memory import ConversationBufferMemory
 from langchain.memory.chat_message_histories.in_memory import ChatMessageHistory
 from langchain_core.messages import SystemMessage, AIMessage, HumanMessage
 from langchain.chains.history_aware_retriever import create_history_aware_retriever
@@ -43,7 +43,6 @@ from swarms.prompts.conversational_RAG import (
     QA_PROMPT_TEMPLATE,
     QA_PROMPT_TEMPLATE_STR,
     QA_CONDENSE_TEMPLATE_STR,
-    SUMMARY_PROMPT_TEMPLATE,
 )

 from swarms.server.vector_store import VectorStorage
@@ -62,7 +61,8 @@ from swarms.server.server_models import (
 )

 # Explicitly specify the path to the .env file
-dotenv_path = os.path.join(os.path.dirname(__file__), '.env')
+# Two folders above the current file's directory
+dotenv_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), '.env')
 load_dotenv(dotenv_path)

 hf_token = os.environ.get("HUGGINFACEHUB_API_KEY") # Get the Huggingface API Token
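A note on the `.env` change in this hunk: the three nested `os.path.dirname` calls make the server read the env file two directories above the module instead of the one sitting next to it. A minimal sketch of the path arithmetic, assuming a hypothetical checkout layout (only the `__file__`-style resolution is taken from the diff; the `/repo/...` paths are illustrative):

```python
import os

# Hypothetical module location; any absolute path behaves the same way.
module_file = "/repo/swarms/server/server.py"

# Old behaviour: .env next to the module.
old_path = os.path.join(os.path.dirname(module_file), ".env")
# -> /repo/swarms/server/.env

# New behaviour: two folders above the module's directory.
new_path = os.path.join(
    os.path.dirname(os.path.dirname(os.path.dirname(module_file))), ".env"
)
# -> /repo/.env

print(old_path, new_path)
```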
@@ -147,7 +147,7 @@ if not os.path.exists(uploads):
     os.makedirs(uploads)

 # Initialize the vector store
-vector_store = VectorStorage(directory=uploads, useGPU=useGPU)
+vector_store = VectorStorage(directoryOrUrl=uploads, useGPU=useGPU)


 async def create_chain(
@@ -181,48 +181,57 @@ async def create_chain(
     # if llm is VLLMAsync:
     #     llm.max_tokens = max_tokens_to_gen

-    retriever = await vector_store.getRetriever(os.path.join(file.username, file.filename))
+    retriever = await vector_store.getRetriever()

     chat_memory = ChatMessageHistory()
     for message in messages:
         if message.role == Role.USER:
-            human_msg = HumanMessage(message.content)
-            chat_memory.add_user_message(human_msg)
+            chat_memory.add_user_message(message.content)
         elif message.role == Role.ASSISTANT:
-            ai_msg = AIMessage(message.content)
-            chat_memory.add_ai_message(ai_msg)
-        elif message.role == Role.SYSTEM:
-            system_msg = SystemMessage(message.content)
-            chat_memory.add_message(system_msg)
+            chat_memory.add_ai_message(message.content)

-    ### Contextualize question ###
-    contextualize_q_system_prompt = """Given a chat history and the latest user question \
-    which might reference context in the chat history, formulate a standalone question \
-    which can be understood without the chat history. Do NOT answer the question, \
-    just reformulate it if needed and otherwise return it as is."""
-    contextualize_q_prompt = QA_PROMPT_TEMPLATE
-    history_aware_retriever = create_history_aware_retriever(
-        llm, retriever, contextualize_q_prompt
-    )
-
-    ### Answer question ###
-    qa_system_prompt = """You are an assistant for question-answering tasks. \
-    Use the following pieces of retrieved context to answer the question. \
-    If you don't know the answer, just say that you don't know. \
-    Use three sentences maximum and keep the answer concise.\
-
-    {context}"""
-    qa_prompt = QA_PROMPT_TEMPLATE
-    question_answer_chain = create_stuff_documents_chain(llm, qa_prompt, document_prompt=DOCUMENT_PROMPT_TEMPLATE)
-
-    from langchain_core.runnables import RunnablePassthrough
-
-    rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)
-
-    return rag_chain
+    memory = ConversationBufferMemory(
+        chat_memory=chat_memory,
+        memory_key="chat_history",
+        input_key="question",
+        output_key="answer",
+        return_messages=True,
+    )
+
+    question_generator = LLMChain(
+        llm=llm,
+        prompt=CONDENSE_PROMPT_TEMPLATE,
+        memory=memory,
+        verbose=True,
+        output_key="answer",
+    )
+
+    stuff_chain = LLMChain(
+        llm=llm,
+        prompt=prompt,
+        verbose=True,
+        output_key="answer",
+    )
+
+    doc_chain = StuffDocumentsChain(
+        llm_chain=stuff_chain,
+        document_variable_name="context",
+        document_prompt=DOCUMENT_PROMPT_TEMPLATE,
+        verbose=True,
+        output_key="answer",
+        memory=memory,
+    )
+
+    return ConversationalRetrievalChain(
+        combine_docs_chain=doc_chain,
+        memory=memory,
+        retriever=retriever,
+        question_generator=question_generator,
+        return_generated_question=False,
+        return_source_documents=True,
+        output_key="answer",
+        verbose=True,
+    )


 router = APIRouter()
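For orientation on the rewritten `create_chain` above: the incoming request messages are replayed into a `ChatMessageHistory`, wrapped in a `ConversationBufferMemory` keyed on `question`/`chat_history`/`answer`, and handed to a classic `ConversationalRetrievalChain` that condenses the follow-up question, stuffs retrieved documents into the QA prompt, and returns `answer` plus `source_documents`. Here is a small, runnable sketch of just the memory-seeding step; the sample turns and the commented-out invocation are illustrative assumptions, not code from this PR:

```python
from langchain.memory import ConversationBufferMemory
from langchain.memory.chat_message_histories.in_memory import ChatMessageHistory

# Stand-ins for ChatRequest.messages; the real code switches on Role.USER / Role.ASSISTANT.
prior_turns = [
    ("user", "What does the uploaded PDF say about rate limits?"),
    ("assistant", "It allows 60 requests per minute per API key."),
]

chat_memory = ChatMessageHistory()
for role, content in prior_turns:
    if role == "user":
        chat_memory.add_user_message(content)
    else:
        chat_memory.add_ai_message(content)

memory = ConversationBufferMemory(
    chat_memory=chat_memory,
    memory_key="chat_history",  # key the chain reads past turns from
    input_key="question",       # key the new question arrives under
    output_key="answer",        # key the chain writes its reply to
    return_messages=True,
)

# The chain returned by create_chain() would then be driven roughly like:
#   result = await chain.acall({"question": "And what about burst traffic?"})
#   result["answer"], result["source_documents"]
print(memory.load_memory_variables({})["chat_history"])
```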
@@ -244,22 +253,20 @@ async def chat(request: ChatRequest):
         ),
     )

-    response = LangchainStreamingResponse(
-        chain,
-        config={
-            "input": request.messages[-1].content,
-            "chat_history": [message.content for message in request.messages[:-1]],
-            "context": "{context}",
-            "callbacks": [
-                StreamingStdOutCallbackHandler(),
-                TokenStreamingCallbackHandler(output_key="answer"),
-                SourceDocumentsStreamingCallbackHandler(),
-            ],
-        },
-    )
-
-    return response
+    json = {
+        "question": request.messages[-1].content,
+        "chat_history": [message.content for message in request.messages[:-1]],
+        # "callbacks": [
+        #     StreamingStdOutCallbackHandler(),
+        #     TokenStreamingCallbackHandler(output_key="answer"),
+        #     SourceDocumentsStreamingCallbackHandler(),
+        # ],
+    }
+
+    return LangchainStreamingResponse(
+        chain,
+        config=json,
+    )


 app.include_router(router, tags=["chat"])
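Finally, a sketch of how the updated `/chat` handler slices the conversation: the last message becomes the chain's `question` and everything before it becomes `chat_history`, mirroring the `config` dict passed to `LangchainStreamingResponse`. The `Message` stand-in below only reproduces the `.role`/`.content` attributes visible in the diff; the real request model presumably comes from `swarms.server.server_models`:

```python
from dataclasses import dataclass


@dataclass
class Message:
    role: str
    content: str


messages = [
    Message("system", "You answer questions about the uploaded documents."),
    Message("user", "Summarise chapter 2."),
    Message("assistant", "Chapter 2 covers the vector store setup."),
    Message("user", "Which embedding model does it recommend?"),
]

# Mirrors the handler: the last message is the question, the rest is history.
config = {
    "question": messages[-1].content,
    "chat_history": [m.content for m in messages[:-1]],
}
print(config)
```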