diff --git a/swarms/server/responses.py b/swarms/server/responses.py
index fab90a40..5b1785e1 100644
--- a/swarms/server/responses.py
+++ b/swarms/server/responses.py
@@ -149,7 +149,16 @@ class LangchainStreamingResponse(StreamingResponse):
             # TODO: migrate to `.ainvoke` when adding support
             # for LCEL
             if self.run_mode == ChainRunMode.ASYNC:
-                outputs = await self.chain.acall(**self.config)
+                async for outputs in self.chain.astream(input=self.config):
+                    if 'answer' in outputs:
+                        chunk = ServerSentEvent(
+                            data=outputs['answer']
+                        )
+                        # Send each chunk with the appropriate body type
+                        await send(
+                            {"type": "http.response.body", "body": ensure_bytes(chunk, None), "more_body": True}
+                        )
+
             else:
                 loop = asyncio.get_event_loop()
                 outputs = await loop.run_in_executor(
diff --git a/swarms/server/server.py b/swarms/server/server.py
index 14e5a49d..412eea8c 100644
--- a/swarms/server/server.py
+++ b/swarms/server/server.py
@@ -20,8 +20,13 @@ from fastapi.routing import APIRouter
 from fastapi.staticfiles import StaticFiles
 from huggingface_hub import login
 from langchain.callbacks import StreamingStdOutCallbackHandler
-from langchain.memory import ConversationStringBufferMemory
+from langchain.memory import VectorStoreRetrieverMemory
 from langchain.memory.chat_message_histories.in_memory import ChatMessageHistory
+from langchain_core.messages import SystemMessage, AIMessage, HumanMessage
+from langchain.chains.history_aware_retriever import create_history_aware_retriever
+from langchain.chains.retrieval import create_retrieval_chain
+from langchain.chains.combine_documents import create_stuff_documents_chain
+from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
 from langchain.prompts.prompt import PromptTemplate
 from langchain_community.chat_models import ChatOpenAI
 from swarms.server.responses import LangchainStreamingResponse
@@ -174,71 +179,57 @@ async def create_chain(
 
     retriever = await vector_store.getRetriever(os.path.join(file.username, file.filename))
 
-    chat_memory = ChatMessageHistory()
-
-    for message in messages:
-        if message.role == Role.USER:
-            chat_memory.add_user_message(message)
-        elif message.role == Role.ASSISTANT:
-            chat_memory.add_ai_message(message)
-        elif message.role == Role.SYSTEM:
-            chat_memory.add_message(message)
-
-    memory = ConversationStringBufferMemory(
-        llm=llm,
-        chat_memory=chat_memory,
-        memory_key="chat_history",
-        input_key="question",
-        output_key="answer",
-        prompt=SUMMARY_PROMPT_TEMPLATE,
-        return_messages=False,
+    # chat_memory = ChatMessageHistory()
+
+    # for message in messages:
+    #     if message.role == Role.USER:
+    #         human_msg = HumanMessage(message.content)
+    #         chat_memory.add_user_message(human_msg)
+    #     elif message.role == Role.ASSISTANT:
+    #         ai_msg = AIMessage(message.content)
+    #         chat_memory.add_ai_message(ai_msg)
+    #     elif message.role == Role.SYSTEM:
+    #         system_msg = SystemMessage(message.content)
+    #         chat_memory.add_message(system_msg)
+
+    ### Contextualize question ###
+    contextualize_q_system_prompt = """Given a chat history and the latest user question \
+    which might reference context in the chat history, formulate a standalone question \
+    which can be understood without the chat history. Do NOT answer the question, \
+    just reformulate it if needed and otherwise return it as is."""
+    contextualize_q_prompt = ChatPromptTemplate.from_messages(
+        [
+            ("system", contextualize_q_system_prompt),
+            MessagesPlaceholder("chat_history"),
+            ("human", "{input}"),
+        ]
+    )
+    history_aware_retriever = create_history_aware_retriever(
+        llm, retriever, contextualize_q_prompt
     )
 
-    # memory = VectorStoreRetrieverMemory(
-    #     input_key="question",
-    #     output_key="answer",
-    #     chat_memory=chat_memory,
-    #     memory_key="chat_history",
-    #     return_docs=False,  # Change this to False
-    #     retriever=retriever,
-    #     return_messages=True,
-    #     prompt=SUMMARY_PROMPT_TEMPLATE
-    # )
-
-    question_generator = LLMChain(
-        llm=llm,
-        prompt=CONDENSE_PROMPT_TEMPLATE,
-        memory=memory,
-        verbose=True,
-        output_key="answer",
-    )
+    ### Answer question ###
+    qa_system_prompt = """You are an assistant for question-answering tasks. \
+    Use the following pieces of retrieved context to answer the question. \
+    If you don't know the answer, just say that you don't know. \
+    Use three sentences maximum and keep the answer concise.\
 
-    stuff_chain = LLMChain(
-        llm=llm,
-        prompt=prompt,
-        verbose=True,
-        output_key="answer",
+    {context}"""
+    qa_prompt = ChatPromptTemplate.from_messages(
+        [
+            ("system", qa_system_prompt),
+            MessagesPlaceholder("chat_history"),
+            ("human", "{input}"),
+        ]
     )
+    question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)
 
-    doc_chain = StuffDocumentsChain(
-        llm_chain=stuff_chain,
-        document_variable_name="context",
-        document_prompt=DOCUMENT_PROMPT_TEMPLATE,
-        verbose=True,
-        output_key="answer",
-        memory=memory,
-    )
+    from langchain_core.runnables import RunnablePassthrough
 
-    return ConversationalRetrievalChain(
-        combine_docs_chain=doc_chain,
-        memory=memory,
-        retriever=retriever,
-        question_generator=question_generator,
-        return_generated_question=False,
-        return_source_documents=True,
-        output_key="answer",
-        verbose=True,
-    )
+    rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)
+
+    return rag_chain
 
 
 router = APIRouter()
@@ -249,7 +240,7 @@ router = APIRouter()
     description="Chatbot AI Service",
 )
 async def chat(request: ChatRequest):
-    chain: ConversationalRetrievalChain = await create_chain(
+    chain = await create_chain(
        file=request.file,
        messages=request.messages[:-1],
        model=request.model.id,
@@ -260,19 +251,12 @@ async def chat(request: ChatRequest):
        ),
    )
 
-    # async for token in chain.astream(request.messages[-1].content):
-    #     print(f"token={token}")
-
-    json_string = json.dumps(
-        {
-            "question": request.messages[-1].content,
-            # "chat_history": [message.content for message in request.messages[:-1]],
-        }
-    )
-    return LangchainStreamingResponse(
+    response = LangchainStreamingResponse(
         chain,
         config={
-            "inputs": json_string,
+            "input": request.messages[-1].content,
+            "chat_history": [message.content for message in request.messages[:-1]],
+            "context": "{context}",
             "callbacks": [
                 StreamingStdOutCallbackHandler(),
                 TokenStreamingCallbackHandler(output_key="answer"),
@@ -281,6 +265,8 @@ async def chat(request: ChatRequest):
         },
     )
 
+    return response
+
 
 
 app.include_router(router, tags=["chat"])
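A minimal usage sketch (not part of the diff above), assuming the create_retrieval_chain API used in the patch: the chain returned by the new create_chain is driven with a dict holding "input" plus prompt variables such as "chat_history", and its astream() yields dict chunks, only some of which carry an "answer" key, which is why responses.py now filters on 'answer' before emitting ServerSentEvents. The demo() wrapper and the sample inputs are illustrative only.

    import asyncio

    async def demo(rag_chain):
        # Stream the RAG chain the same way LangchainStreamingResponse now does:
        # iterate astream() and forward only chunks that contain an "answer".
        async for chunk in rag_chain.astream(
            {"input": "What does the uploaded file cover?", "chat_history": []}
        ):
            if "answer" in chunk:
                print(chunk["answer"], end="", flush=True)

    # asyncio.run(demo(rag_chain))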