|
|
|
@ -7,7 +7,7 @@ from typing import List
|
|
|
|
|
|
|
|
|
|
import langchain
|
|
|
|
|
from pydantic import ValidationError, parse_obj_as
|
|
|
|
|
from swarms.prompts.chat_prompt import Message, Role
|
|
|
|
|
from swarms.prompts.chat_prompt import Message
|
|
|
|
|
from swarms.server.callback_handlers import SourceDocumentsStreamingCallbackHandler, TokenStreamingCallbackHandler
|
|
|
|
|
import tiktoken
|
|
|
|
|
|
|
|
|
@ -20,7 +20,7 @@ from fastapi.routing import APIRouter
|
|
|
|
|
from fastapi.staticfiles import StaticFiles
|
|
|
|
|
from huggingface_hub import login
|
|
|
|
|
from langchain.callbacks import StreamingStdOutCallbackHandler
|
|
|
|
|
from langchain.memory import ConversationSummaryBufferMemory
|
|
|
|
|
from langchain.memory import VectorStoreRetrieverMemory
|
|
|
|
|
from langchain.memory.chat_message_histories.in_memory import ChatMessageHistory
|
|
|
|
|
from langchain.prompts.prompt import PromptTemplate
|
|
|
|
|
from langchain_community.chat_models import ChatOpenAI
|
|
|
|
@ -48,6 +48,7 @@ from swarms.server.server_models import (
|
|
|
|
|
AIModels,
|
|
|
|
|
RAGFile,
|
|
|
|
|
RAGFiles,
|
|
|
|
|
Role,
|
|
|
|
|
State,
|
|
|
|
|
GetRAGFileStateRequest,
|
|
|
|
|
ProcessRAGFileRequest
|
|
|
|
@ -176,23 +177,32 @@ async def create_chain(
|
|
|
|
|
chat_memory = ChatMessageHistory()
|
|
|
|
|
|
|
|
|
|
for message in messages:
|
|
|
|
|
if message.role == Role.HUMAN:
|
|
|
|
|
chat_memory.add_user_message(message.content)
|
|
|
|
|
elif message.role == Role.AI:
|
|
|
|
|
chat_memory.add_ai_message(message.content)
|
|
|
|
|
if message.role == Role.USER:
|
|
|
|
|
chat_memory.add_user_message(message)
|
|
|
|
|
elif message.role == Role.ASSISTANT:
|
|
|
|
|
chat_memory.add_ai_message(message)
|
|
|
|
|
elif message.role == Role.SYSTEM:
|
|
|
|
|
chat_memory.add_message(message.content)
|
|
|
|
|
elif message.role == Role.FUNCTION:
|
|
|
|
|
chat_memory.add_message(message.content)
|
|
|
|
|
chat_memory.add_message(message)
|
|
|
|
|
|
|
|
|
|
# memory = ConversationSummaryBufferMemory(
|
|
|
|
|
# llm=llm,
|
|
|
|
|
# chat_memory=chat_memory,
|
|
|
|
|
# memory_key="chat_history",
|
|
|
|
|
# input_key="question",
|
|
|
|
|
# output_key="answer",
|
|
|
|
|
# prompt=SUMMARY_PROMPT_TEMPLATE,
|
|
|
|
|
# return_messages=True,
|
|
|
|
|
# )
|
|
|
|
|
|
|
|
|
|
memory = ConversationSummaryBufferMemory(
|
|
|
|
|
llm=llm,
|
|
|
|
|
chat_memory=chat_memory,
|
|
|
|
|
memory_key="chat_history",
|
|
|
|
|
memory = VectorStoreRetrieverMemory(
|
|
|
|
|
input_key="question",
|
|
|
|
|
output_key="answer",
|
|
|
|
|
prompt=SUMMARY_PROMPT_TEMPLATE,
|
|
|
|
|
chat_memory=chat_memory,
|
|
|
|
|
memory_key="chat_history",
|
|
|
|
|
return_docs=True, # Change this to False
|
|
|
|
|
retriever=retriever,
|
|
|
|
|
return_messages=True,
|
|
|
|
|
prompt=SUMMARY_PROMPT_TEMPLATE
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
question_generator = LLMChain(
|
|
|
|
|