From 96a3e46dbbb5a411be47410184602c2207e7e4cd Mon Sep 17 00:00:00 2001
From: Richard Anthony Hein
Date: Wed, 14 Aug 2024 01:17:19 +0000
Subject: [PATCH] fixes for gradio app

---
 swarms/server/server.py | 38 ++++++++++++++++++++++++--------------
 1 file changed, 24 insertions(+), 14 deletions(-)

diff --git a/swarms/server/server.py b/swarms/server/server.py
index 3f759a98..860f0567 100644
--- a/swarms/server/server.py
+++ b/swarms/server/server.py
@@ -7,7 +7,7 @@ from typing import List
 
 import langchain
 from pydantic import ValidationError, parse_obj_as
-from swarms.prompts.chat_prompt import Message, Role
+from swarms.prompts.chat_prompt import Message
 from swarms.server.callback_handlers import SourceDocumentsStreamingCallbackHandler, TokenStreamingCallbackHandler
 import tiktoken
 
@@ -20,7 +20,7 @@ from fastapi.routing import APIRouter
 from fastapi.staticfiles import StaticFiles
 from huggingface_hub import login
 from langchain.callbacks import StreamingStdOutCallbackHandler
-from langchain.memory import ConversationSummaryBufferMemory
+from langchain.memory import VectorStoreRetrieverMemory
 from langchain.memory.chat_message_histories.in_memory import ChatMessageHistory
 from langchain.prompts.prompt import PromptTemplate
 from langchain_community.chat_models import ChatOpenAI
@@ -48,6 +48,7 @@ from swarms.server.server_models import (
     AIModels,
     RAGFile,
     RAGFiles,
+    Role,
     State,
     GetRAGFileStateRequest,
     ProcessRAGFileRequest
@@ -176,23 +177,32 @@ async def create_chain(
 
     chat_memory = ChatMessageHistory()
     for message in messages:
-        if message.role == Role.HUMAN:
-            chat_memory.add_user_message(message.content)
-        elif message.role == Role.AI:
-            chat_memory.add_ai_message(message.content)
+        if message.role == Role.USER:
+            chat_memory.add_user_message(message)
+        elif message.role == Role.ASSISTANT:
+            chat_memory.add_ai_message(message)
         elif message.role == Role.SYSTEM:
-            chat_memory.add_message(message.content)
-        elif message.role == Role.FUNCTION:
-            chat_memory.add_message(message.content)
+            chat_memory.add_message(message)
+
+    # memory = ConversationSummaryBufferMemory(
+    #     llm=llm,
+    #     chat_memory=chat_memory,
+    #     memory_key="chat_history",
+    #     input_key="question",
+    #     output_key="answer",
+    #     prompt=SUMMARY_PROMPT_TEMPLATE,
+    #     return_messages=True,
+    # )
 
-    memory = ConversationSummaryBufferMemory(
-        llm=llm,
-        chat_memory=chat_memory,
-        memory_key="chat_history",
+    memory = VectorStoreRetrieverMemory(
         input_key="question",
         output_key="answer",
-        prompt=SUMMARY_PROMPT_TEMPLATE,
+        chat_memory=chat_memory,
+        memory_key="chat_history",
+        return_docs=True, # Change this to False
+        retriever=retriever,
         return_messages=True,
+        prompt=SUMMARY_PROMPT_TEMPLATE
     )
 
     question_generator = LLMChain(