fixes to stream SSE events

pull/570/head
Richard Anthony Hein 8 months ago
parent 4c18e8d588
commit fde7febd29

@@ -149,7 +149,16 @@ class LangchainStreamingResponse(StreamingResponse):
             # TODO: migrate to `.ainvoke` when adding support
             # for LCEL
             if self.run_mode == ChainRunMode.ASYNC:
-                outputs = await self.chain.acall(**self.config)
+                async for outputs in self.chain.astream(input=self.config):
+                    if 'answer' in outputs:
+                        chunk = ServerSentEvent(
+                            data=outputs['answer']
+                        )
+                        # Send each chunk with the appropriate body type
+                        await send(
+                            {"type": "http.response.body", "body": ensure_bytes(chunk, None), "more_body": True}
+                        )
+
             else:
                 loop = asyncio.get_event_loop()
                 outputs = await loop.run_in_executor(
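Note on the ASYNC branch above: `chain.astream(...)` yields incremental output dicts, and each chunk carrying an `answer` key is framed as a server-sent event and written straight to the ASGI `send` channel. A minimal standalone sketch of that framing, assuming `ServerSentEvent` comes from sse_starlette; the bare ASGI app and token list are illustrative, not from this commit:

    # Minimal SSE framing over raw ASGI; assumes sse_starlette is installed.
    from sse_starlette import ServerSentEvent

    async def sse_app(scope, receive, send):
        await send({
            "type": "http.response.start",
            "status": 200,
            "headers": [(b"content-type", b"text/event-stream")],
        })
        for token in ("streamed", " ", "tokens"):
            event = ServerSentEvent(data=token)
            # encode() renders the "data: ...\n\n" wire format as bytes.
            await send({
                "type": "http.response.body",
                "body": event.encode(),
                "more_body": True,
            })
        # Close the response once the chain is exhausted.
        await send({"type": "http.response.body", "body": b"", "more_body": False})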

@@ -20,8 +20,13 @@ from fastapi.routing import APIRouter
 from fastapi.staticfiles import StaticFiles
 from huggingface_hub import login
 from langchain.callbacks import StreamingStdOutCallbackHandler
-from langchain.memory import ConversationStringBufferMemory
+from langchain.memory import VectorStoreRetrieverMemory
 from langchain.memory.chat_message_histories.in_memory import ChatMessageHistory
+from langchain_core.messages import SystemMessage, AIMessage, HumanMessage
+from langchain.chains.history_aware_retriever import create_history_aware_retriever
+from langchain.chains.retrieval import create_retrieval_chain
+from langchain.chains.combine_documents import create_stuff_documents_chain
+from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
 from langchain.prompts.prompt import PromptTemplate
 from langchain_community.chat_models import ChatOpenAI
 from swarms.server.responses import LangchainStreamingResponse
@@ -174,71 +179,57 @@ async def create_chain(
     retriever = await vector_store.getRetriever(os.path.join(file.username, file.filename))
 
-    chat_memory = ChatMessageHistory()
-    for message in messages:
-        if message.role == Role.USER:
-            chat_memory.add_user_message(message)
-        elif message.role == Role.ASSISTANT:
-            chat_memory.add_ai_message(message)
-        elif message.role == Role.SYSTEM:
-            chat_memory.add_message(message)
-
-    memory = ConversationStringBufferMemory(
-        llm=llm,
-        chat_memory=chat_memory,
-        memory_key="chat_history",
-        input_key="question",
-        output_key="answer",
-        prompt=SUMMARY_PROMPT_TEMPLATE,
-        return_messages=False,
-    )
-
-    question_generator = LLMChain(
-        llm=llm,
-        prompt=CONDENSE_PROMPT_TEMPLATE,
-        memory=memory,
-        verbose=True,
-        output_key="answer",
-    )
-
-    stuff_chain = LLMChain(
-        llm=llm,
-        prompt=prompt,
-        verbose=True,
-        output_key="answer",
-    )
-
-    doc_chain = StuffDocumentsChain(
-        llm_chain=stuff_chain,
-        document_variable_name="context",
-        document_prompt=DOCUMENT_PROMPT_TEMPLATE,
-        verbose=True,
-        output_key="answer",
-        memory=memory,
-    )
-
-    return ConversationalRetrievalChain(
-        combine_docs_chain=doc_chain,
-        memory=memory,
-        retriever=retriever,
-        question_generator=question_generator,
-        return_generated_question=False,
-        return_source_documents=True,
-        output_key="answer",
-        verbose=True,
-    )
+    # chat_memory = ChatMessageHistory()
+    # for message in messages:
+    #     if message.role == Role.USER:
+    #         human_msg = HumanMessage(message.content)
+    #         chat_memory.add_user_message(human_msg)
+    #     elif message.role == Role.ASSISTANT:
+    #         ai_msg = AIMessage(message.content)
+    #         chat_memory.add_ai_message(ai_msg)
+    #     elif message.role == Role.SYSTEM:
+    #         system_msg = SystemMessage(message.content)
+    #         chat_memory.add_message(system_msg)
+
+    ### Contextualize question ###
+    contextualize_q_system_prompt = """Given a chat history and the latest user question \
+which might reference context in the chat history, formulate a standalone question \
+which can be understood without the chat history. Do NOT answer the question, \
+just reformulate it if needed and otherwise return it as is."""
+    contextualize_q_prompt = ChatPromptTemplate.from_messages(
+        [
+            ("system", contextualize_q_system_prompt),
+            MessagesPlaceholder("chat_history"),
+            ("human", "{input}"),
+        ]
+    )
+    history_aware_retriever = create_history_aware_retriever(
+        llm, retriever, contextualize_q_prompt
+    )
+
+    # memory = VectorStoreRetrieverMemory(
+    #     input_key="question",
+    #     output_key="answer",
+    #     chat_memory=chat_memory,
+    #     memory_key="chat_history",
+    #     return_docs=False,  # Change this to False
+    #     retriever=retriever,
+    #     return_messages=True,
+    #     prompt=SUMMARY_PROMPT_TEMPLATE
+    # )
+
+    ### Answer question ###
+    qa_system_prompt = """You are an assistant for question-answering tasks. \
+Use the following pieces of retrieved context to answer the question. \
+If you don't know the answer, just say that you don't know. \
+Use three sentences maximum and keep the answer concise.\
+
+{context}"""
+    qa_prompt = ChatPromptTemplate.from_messages(
+        [
+            ("system", qa_system_prompt),
+            MessagesPlaceholder("chat_history"),
+            ("human", "{input}"),
+        ]
+    )
+    question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)
+
+    from langchain_core.runnables import RunnablePassthrough
+
+    rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)
+
+    return rag_chain
 
 router = APIRouter()
@@ -249,7 +240,7 @@ router = APIRouter()
     description="Chatbot AI Service",
 )
 async def chat(request: ChatRequest):
-    chain: ConversationalRetrievalChain = await create_chain(
+    chain = await create_chain(
         file=request.file,
         messages=request.messages[:-1],
         model=request.model.id,
@@ -260,19 +251,12 @@ async def chat(request: ChatRequest):
         ),
     )
 
-    # async for token in chain.astream(request.messages[-1].content):
-    #     print(f"token={token}")
-
-    json_string = json.dumps(
-        {
-            "question": request.messages[-1].content,
-            # "chat_history": [message.content for message in request.messages[:-1]],
-        }
-    )
-    return LangchainStreamingResponse(
+    response = LangchainStreamingResponse(
         chain,
         config={
-            "inputs": json_string,
+            "input": request.messages[-1].content,
+            "chat_history": [message.content for message in request.messages[:-1]],
+            "context": "{context}",
             "callbacks": [
                 StreamingStdOutCallbackHandler(),
                 TokenStreamingCallbackHandler(output_key="answer"),
@@ -281,6 +265,8 @@ async def chat(request: ChatRequest):
         },
     )
 
+    return response
+
 
 app.include_router(router, tags=["chat"])
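End to end, the endpoint now streams text/event-stream chunks as the answer is generated. A hedged client-side sketch with httpx; the URL, port, and payload shape are assumptions inferred from ChatRequest, not confirmed by this diff:

    # Consume the SSE stream; the URL and request body shape are assumptions.
    import asyncio
    import httpx

    async def consume_chat(payload: dict) -> None:
        async with httpx.AsyncClient(timeout=None) as client:
            async with client.stream(
                "POST", "http://localhost:8888/chat", json=payload
            ) as response:
                async for line in response.aiter_lines():
                    # SSE data lines look like "data: <token>".
                    if line.startswith("data: "):
                        print(line[len("data: "):], end="", flush=True)

    # asyncio.run(consume_chat({"model": {...}, "file": {...}, "messages": [...]}))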
