@@ -20,8 +20,13 @@ from fastapi.routing import APIRouter
from fastapi.staticfiles import StaticFiles
from huggingface_hub import login
from langchain.callbacks import StreamingStdOutCallbackHandler
from langchain.memory import ConversationStringBufferMemory
from langchain.memory import VectorStoreRetrieverMemory
from langchain.memory.chat_message_histories.in_memory import ChatMessageHistory
from langchain_core.messages import SystemMessage, AIMessage, HumanMessage
from langchain.chains.history_aware_retriever import create_history_aware_retriever
from langchain.chains.retrieval import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.prompts.prompt import PromptTemplate
from langchain_community.chat_models import ChatOpenAI
from swarms.server.responses import LangchainStreamingResponse
@@ -174,71 +179,57 @@ async def create_chain(
retriever = await vector_store.getRetriever(os.path.join(file.username, file.filename))
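
# Rebuild the LangChain chat history from the prior messages in the request.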
chat_memory = ChatMessageHistory()
for message in messages:
    if message.role == Role.USER:
        chat_memory.add_user_message(HumanMessage(message.content))
    elif message.role == Role.ASSISTANT:
        chat_memory.add_ai_message(AIMessage(message.content))
    elif message.role == Role.SYSTEM:
        chat_memory.add_message(SystemMessage(message.content))
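
# String-buffer memory over the replayed history, summarized via SUMMARY_PROMPT_TEMPLATE.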
memory = ConversationStringBufferMemory(
    llm=llm,
    chat_memory=chat_memory,
    memory_key="chat_history",
    input_key="question",
    output_key="answer",
    prompt=SUMMARY_PROMPT_TEMPLATE,
    return_messages=False,
)

### Contextualize question ###
contextualize_q_system_prompt = """Given a chat history and the latest user question \
which might reference context in the chat history, formulate a standalone question \
which can be understood without the chat history. Do NOT answer the question, \
just reformulate it if needed and otherwise return it as is."""
contextualize_q_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", contextualize_q_system_prompt),
        MessagesPlaceholder("chat_history"),
        ("human", "{input}"),
    ]
)
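# Rewrites the latest question into a standalone query using the chat history,
# then runs the underlying retriever on the rewritten query.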
history_aware_retriever = create_history_aware_retriever(
    llm, retriever, contextualize_q_prompt
)

# memory = VectorStoreRetrieverMemory(
#     input_key="question",
#     output_key="answer",
#     chat_memory=chat_memory,
#     memory_key="chat_history",
#     return_docs=False,  # Change this to False
#     retriever=retriever,
#     return_messages=True,
#     prompt=SUMMARY_PROMPT_TEMPLATE
# )

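# Legacy chain that condenses the follow-up question and chat history into a
# standalone question for the ConversationalRetrievalChain path below.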
question_generator = LLMChain(
    llm=llm,
    prompt=CONDENSE_PROMPT_TEMPLATE,
    memory=memory,
    verbose=True,
    output_key="answer",
)

### Answer question ###
qa_system_prompt = """You are an assistant for question-answering tasks. \
Use the following pieces of retrieved context to answer the question. \
If you don't know the answer, just say that you don't know. \
Use three sentences maximum and keep the answer concise.\

{context}"""

stuff_chain = LLMChain(
    llm=llm,
    prompt=prompt,
    verbose=True,
    output_key="answer",
)

qa_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", qa_system_prompt),
        MessagesPlaceholder("chat_history"),
        ("human", "{input}"),
    ]
)
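# Stuff the retrieved documents into the {context} slot of qa_prompt.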
question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)
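# Legacy combine-documents chain consumed by the ConversationalRetrievalChain below.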
doc_chain = StuffDocumentsChain(
    llm_chain=stuff_chain,
    document_variable_name="context",
    document_prompt=DOCUMENT_PROMPT_TEMPLATE,
    verbose=True,
    output_key="answer",
    memory=memory,
)

return ConversationalRetrievalChain(
    combine_docs_chain=doc_chain,
    memory=memory,
    retriever=retriever,
    question_generator=question_generator,
    return_generated_question=False,
    return_source_documents=True,
    output_key="answer",
    verbose=True,
)

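# New pipeline: history-aware retrieval followed by the stuff-documents QA chain.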
rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)
return rag_chain

router = APIRouter()

@@ -249,7 +240,7 @@ router = APIRouter()
    description="Chatbot AI Service",
)
async def chat(request: ChatRequest):
chain: ConversationalRetrievalChain = await create_chain(
chain = await create_chain(
    file=request.file,
    messages=request.messages[:-1],
    model=request.model.id,
@@ -260,19 +251,12 @@ async def chat(request: ChatRequest):
    ),
)

# async for token in chain.astream(request.messages[-1].content):
#     print(f"token={token}")

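# JSON-serialized inputs consumed via the "inputs" entry of the streaming config.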
json_string = json.dumps(
    {
        "question": request.messages[-1].content,
        # "chat_history": [message.content for message in request.messages[:-1]],
    }
)
return LangchainStreamingResponse(
response = LangchainStreamingResponse(
    chain,
    config={
        "inputs": json_string,
        "input": request.messages[-1].content,
        "chat_history": [message.content for message in request.messages[:-1]],
        "context": "{context}",
        "callbacks": [
            StreamingStdOutCallbackHandler(),
            TokenStreamingCallbackHandler(output_key="answer"),
@@ -281,6 +265,8 @@ async def chat(request: ChatRequest):
    },
)

return response
app.include_router(router, tags=["chat"])