fixes for gradio app

pull/570/head
Richard Anthony Hein 9 months ago
parent 35f70affee
commit 96a3e46dbb

@@ -7,7 +7,7 @@ from typing import List
 import langchain
 from pydantic import ValidationError, parse_obj_as
-from swarms.prompts.chat_prompt import Message, Role
+from swarms.prompts.chat_prompt import Message
 from swarms.server.callback_handlers import SourceDocumentsStreamingCallbackHandler, TokenStreamingCallbackHandler
 import tiktoken
@@ -20,7 +20,7 @@ from fastapi.routing import APIRouter
 from fastapi.staticfiles import StaticFiles
 from huggingface_hub import login
 from langchain.callbacks import StreamingStdOutCallbackHandler
-from langchain.memory import ConversationSummaryBufferMemory
+from langchain.memory import VectorStoreRetrieverMemory
 from langchain.memory.chat_message_histories.in_memory import ChatMessageHistory
 from langchain.prompts.prompt import PromptTemplate
 from langchain_community.chat_models import ChatOpenAI
@ -48,6 +48,7 @@ from swarms.server.server_models import (
AIModels, AIModels,
RAGFile, RAGFile,
RAGFiles, RAGFiles,
Role,
State, State,
GetRAGFileStateRequest, GetRAGFileStateRequest,
ProcessRAGFileRequest ProcessRAGFileRequest
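The next hunk switches the message-role checks from the old Role.HUMAN / Role.AI members to Role.USER / Role.ASSISTANT, matching the Role import moved to swarms.server.server_models above. A hypothetical sketch of that enum, inferred only from how this diff uses it; the member values are assumptions, not taken from the commit:

from enum import Enum

# Hypothetical shape of Role in swarms.server.server_models, inferred
# from the usages Role.SYSTEM / Role.USER / Role.ASSISTANT in this diff.
# The string values are assumed, not confirmed by the commit.
class Role(str, Enum):
    SYSTEM = "system"
    USER = "user"
    ASSISTANT = "assistant"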
@@ -176,23 +177,32 @@ async def create_chain(
     chat_memory = ChatMessageHistory()
     for message in messages:
-        if message.role == Role.HUMAN:
-            chat_memory.add_user_message(message.content)
-        elif message.role == Role.AI:
-            chat_memory.add_ai_message(message.content)
+        if message.role == Role.USER:
+            chat_memory.add_user_message(message)
+        elif message.role == Role.ASSISTANT:
+            chat_memory.add_ai_message(message)
         elif message.role == Role.SYSTEM:
-            chat_memory.add_message(message.content)
-        elif message.role == Role.FUNCTION:
-            chat_memory.add_message(message.content)
+            chat_memory.add_message(message)
 
-    memory = ConversationSummaryBufferMemory(
-        llm=llm,
-        chat_memory=chat_memory,
-        memory_key="chat_history",
+    # memory = ConversationSummaryBufferMemory(
+    #     llm=llm,
+    #     chat_memory=chat_memory,
+    #     memory_key="chat_history",
+    #     input_key="question",
+    #     output_key="answer",
+    #     prompt=SUMMARY_PROMPT_TEMPLATE,
+    #     return_messages=True,
+    # )
+    memory = VectorStoreRetrieverMemory(
         input_key="question",
         output_key="answer",
-        prompt=SUMMARY_PROMPT_TEMPLATE,
+        chat_memory=chat_memory,
+        memory_key="chat_history",
+        return_docs=True, # Change this to False
+        retriever=retriever,
         return_messages=True,
+        prompt=SUMMARY_PROMPT_TEMPLATE
     )
     question_generator = LLMChain(
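The core change above swaps ConversationSummaryBufferMemory for VectorStoreRetrieverMemory, so past turns are retrieved by vector similarity to the incoming question rather than kept in a summarized token buffer. A minimal, self-contained sketch of that memory class, using only constructor arguments known from the LangChain API; the FAISS store, embedding model, and k value are illustrative assumptions, not part of the commit:

from langchain.memory import VectorStoreRetrieverMemory
from langchain_community.embeddings import OpenAIEmbeddings
from langchain_community.vectorstores import FAISS

# Any vector-store retriever works; FAISS keeps the sketch local.
vectorstore = FAISS.from_texts(["seed text"], embedding=OpenAIEmbeddings())
retriever = vectorstore.as_retriever(search_kwargs={"k": 4})

memory = VectorStoreRetrieverMemory(
    retriever=retriever,        # supplies relevant past turns
    memory_key="chat_history",  # prompt variable the chain reads
    input_key="question",       # chain input used as the query
    return_docs=False,          # return joined text, not Document objects
)

# Each exchange is embedded into the store; lookup is by similarity,
# not recency, so old but relevant turns can resurface.
memory.save_context(
    {"question": "What does the Swarms server do?"},
    {"answer": "It serves a RAG chat API."},
)
print(memory.load_memory_variables({"question": "Tell me about the server"}))

Retrieval-backed memory trades the summary buffer's per-request LLM summarization cost for a similarity lookup, which is why a retriever is now required when the memory is constructed.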
