""" Chatbot with RAG Server """
|
|
import asyncio
|
|
import logging
|
|
import os
|
|
from urllib.parse import urlparse, urljoin
|
|
# import torch
|
|
from contextlib import asynccontextmanager
|
|
from typing import AsyncIterator
|
|
from swarms.structs.agent import Agent
|
|
|
|
from dotenv import load_dotenv
|
|
from fastapi import FastAPI, HTTPException, Request
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from fastapi.responses import FileResponse, JSONResponse
|
|
from fastapi.routing import APIRouter
|
|
from fastapi.staticfiles import StaticFiles
|
|
from huggingface_hub import login
|
|
|
|
from swarms.prompts.chat_prompt import Message, Role
|
|
from swarms.prompts.conversational_RAG import (
|
|
B_INST,
|
|
B_SYS,
|
|
CONDENSE_PROMPT_TEMPLATE,
|
|
DOCUMENT_PROMPT_TEMPLATE,
|
|
E_INST,
|
|
E_SYS,
|
|
QA_PROMPT_TEMPLATE_STR,
|
|
)
|
|
from playground.demos.chatbot.server.responses import StreamingResponse
|
|
from playground.demos.chatbot.server.server_models import ChatRequest
|
|
from playground.demos.chatbot.server.vector_storage import RedisVectorStorage
|
|
from swarms.models.popular_llms import OpenAIChatLLM
|
|
|
|

# Explicitly specify the path to the .env file
# (two folders above the current file's directory).
dotenv_path = os.path.join(
    os.path.dirname(os.path.dirname(os.path.dirname(__file__))), ".env"
)
load_dotenv(dotenv_path)

# Get the Hugging Face API token (the variable name is spelled this way
# here, so it must match the .env file).
hf_token = os.environ.get("HUGGINFACEHUB_API_KEY")

# Directory where user-uploaded files are parsed for RAG.
uploads = os.environ.get("UPLOADS")

# model_dir = os.environ.get("MODEL_DIR")

# huggingface.co model (e.g. meta-llama/Llama-2-70b-hf)
# model_name = os.environ.get("MODEL_NAME")

# Set OpenAI's API key to "EMPTY" and the API base URL to vLLM's API
# server, or set them to a real OpenAI API key and base URL.
openai_api_key = os.environ.get("OPENAI_API_KEY") or "EMPTY"
openai_api_base = (
    os.environ.get("OPENAI_API_BASE") or "http://localhost:8000/v1"
)

# Map names to values so the error message can report which variables are
# missing (joining the raw values would fail on None).
env_vars = {
    "HUGGINFACEHUB_API_KEY": hf_token,
    "UPLOADS": uploads,
    "OPENAI_API_KEY": openai_api_key,
    "OPENAI_API_BASE": openai_api_base,
}
missing_vars = [name for name, value in env_vars.items() if not value]

if missing_vars:
    print(
        "Error: The following environment variables are not set: "
        + ", ".join(missing_vars)
    )
    exit(1)
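
# Example .env sketch (the keys are the ones read above; the values are
# illustrative placeholders, not real credentials):
#
#   HUGGINFACEHUB_API_KEY=hf_xxxxxxxxxxxxxxxx
#   UPLOADS=./uploads
#   OPENAI_API_KEY=EMPTY
#   OPENAI_API_BASE=http://localhost:8000/v1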

useMetal = os.environ.get("USE_METAL", "False") == "True"
use_gpu = os.environ.get("USE_GPU", "False") == "True"

print(f"Uploads={uploads}")
print(f"USE_METAL={useMetal}")
print(f"USE_GPU={use_gpu}")
# Avoid echoing the raw API key to the console; report only whether a
# real key is configured.
print(f"OPENAI_API_KEY={'<set>' if openai_api_key != 'EMPTY' else 'EMPTY'}")
print(f"OPENAI_API_BASE={openai_api_base}")

# Update tiktoken to include the model name (avoids a warning message):
# tiktoken.model.MODEL_TO_ENCODING.update(
#     {
#         model_name: "cl100k_base",
#     }
# )

print("Logging in to huggingface.co...")
login(token=hf_token)  # log in to huggingface.co

app = FastAPI(title="Chatbot")
router = APIRouter()

current_dir = os.path.dirname(__file__)
print("current_dir: " + current_dir)
static_dir = os.path.join(current_dir, "static")
print("static_dir: " + static_dir)

# Mount the static files at the "/static" URL path; mounting at the
# filesystem path would create an unreachable route.
app.mount("/static", StaticFiles(directory=static_dir), name="static")

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["GET", "POST"],
    allow_headers=["*"],
)
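
# Note: allow_origins=["*"] accepts requests from any origin; tighten the
# list before exposing this server beyond local development.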

# Create the ./uploads folder if it doesn't exist.
uploads = uploads or os.path.join(os.getcwd(), "uploads")
if not os.path.exists(uploads):
    os.makedirs(uploads)

# Initialize the vector store, hardcoded to crawl the Swarms documentation.
URL = "https://docs.swarms.world/en/latest/"
vector_store = RedisVectorStorage(use_gpu=use_gpu)
vector_store.crawl(URL)
print("Vector storage initialized.")


async def create_chat(
    messages: list[Message],
    model_name: str,
    prompt: str = QA_PROMPT_TEMPLATE_STR,
):
    """Builds the RAG prompt from the chat history and any relevant docs,
    then streams the agent's answer."""
    print("Creating chat from history and relevant docs if any ...")

    llm = OpenAIChatLLM(
        api_key=openai_api_key,
        base_url=openai_api_base,
        model=model_name,
        verbose=True,
        streaming=True,
    )

    # Render the chat history into a plain-text transcript.
    doc_retrieval_string = ""
    for message in messages:
        if message.role == Role.HUMAN:
            doc_retrieval_string += f"{Role.HUMAN}: {message.content}\r\n"
        elif message.role == Role.AI:
            doc_retrieval_string += f"{Role.AI}: {message.content}\r\n"

    # Retrieve documents relevant to the latest message.
    docs = vector_store.embed(messages[-1].content)

    # Concatenate the content of all retrieved documents.
    context = "\n".join(doc["content"] for doc in docs)

    sources = [
        urlparse(URL).scheme + "://" + doc["source_url"] for doc in docs
    ]

    print(f"context: {context}")

    # Fill the prompt template: {context} gets the concatenated document
    # content, {chat_history} the transcript, and {question} the last
    # message.
    prompt = prompt.replace("{context}", context)
    prompt = prompt.replace("{chat_history}", doc_retrieval_string)
    prompt = prompt.replace("{question}", messages[-1].content)

    # Initialize the agent.
    agent = Agent(
        agent_name="Swarms QA ChatBot",
        system_prompt=prompt,
        llm=llm,
        max_loops=1,
        autosave=True,
        dashboard=False,
        verbose=True,
        streaming_on=True,
        dynamic_temperature_enabled=False,
        saved_state_path="chatbot.json",
        user_name="RAH@EntangleIT.com",
        retry_attempts=3,
        context_length=200000,
        # Optional extras:
        # interactive=True,
        # stopping_token="Stop!",
        # tools=[],  # add your functions here
        # docs=[doc["content"] for doc in docs],
        # docs_folder="docs",  # enter your folder name
        # pdf_path="docs/finance_agent.pdf",
        # sop="Calculate the profit for a company.",
        # sop_list=["Calculate the profit for a company."],
        # agent_ops_on=True,
    )

    # Replay the prior turns into the agent's short-term memory.
    for message in messages[:-1]:
        if message.role in (Role.HUMAN, Role.AI):
            agent.add_message_to_memory(message.content)

    # Optionally add the retrieved docs to short-term memory as well:
    # for data in [doc["content"] for doc in docs]:
    #     agent.add_message_to_memory(role=Role.HUMAN, content=data)

    # Stream the agent's response, appending the source URLs to each
    # yielded chunk.
    async for response in agent.run_async(messages[-1].content):
        res = response
        res += "\n\nSources:\n"
        for source in sources:
            res += source + "\n"
        yield res

    # Legacy LangChain implementation, kept for reference:
    #
    # memory = ConversationBufferMemory(
    #     chat_memory=chat_memory,
    #     memory_key="chat_history",
    #     input_key="question",
    #     output_key="answer",
    #     return_messages=True,
    # )
    #
    # question_generator = LLMChain(
    #     llm=llm,
    #     prompt=CONDENSE_PROMPT_TEMPLATE,
    #     memory=memory,
    #     verbose=True,
    #     output_key="answer",
    # )
    #
    # stuff_chain = LLMChain(
    #     llm=llm,
    #     prompt=prompt,
    #     verbose=True,
    #     output_key="answer",
    # )
    #
    # doc_chain = StuffDocumentsChain(
    #     llm_chain=stuff_chain,
    #     document_variable_name="context",
    #     document_prompt=DOCUMENT_PROMPT_TEMPLATE,
    #     verbose=True,
    #     output_key="answer",
    #     memory=memory,
    # )
    #
    # return ConversationalRetrievalChain(
    #     combine_docs_chain=doc_chain,
    #     memory=memory,
    #     retriever=retriever,
    #     question_generator=question_generator,
    #     return_generated_question=False,
    #     return_source_documents=True,
    #     output_key="answer",
    #     verbose=True,
    # )


@app.post(
    "/chat",
    summary="Chatbot",
    description="Chatbot AI Service",
)
async def chat(request: ChatRequest):
    """Handles chatbot chat POST requests."""
    response = create_chat(
        messages=request.messages,
        prompt=request.prompt.strip(),
        model_name=request.model.id,
    )
    return StreamingResponse(content=response)
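
# Illustrative request (the field names follow ChatRequest as used above;
# the exact JSON shapes of `messages` and `model`, and the port, depend on
# server_models.py and how the server is launched):
#
#   curl -N -X POST http://localhost:<port>/chat \
#     -H "Content-Type: application/json" \
#     -d '{"model": {"id": "..."}, "prompt": "...", "messages": [...]}'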

# Legacy LangChain streaming response, kept for reference:
#
# json_config = {
#     "question": request.messages[-1].content,
#     "chat_history": [
#         message.content for message in request.messages[:-1]
#     ],
#     # "callbacks": [
#     #     StreamingStdOutCallbackHandler(),
#     #     TokenStreamingCallbackHandler(output_key="answer"),
#     #     SourceDocumentsStreamingCallbackHandler(),
#     # ],
# }
# return LangchainStreamingResponse(
#     chain=chain,
#     config=json_config,
#     run_mode="async",
# )


@app.get("/")
def root():
    """Swarms Chatbot API Root"""
    return {"message": "Swarms Chatbot API"}


@app.get("/favicon.ico")
def favicon():
    """Returns the favicon from the static directory."""
    file_name = "favicon.ico"
    # Resolve against the static directory next to this file
    # (app.root_path is an ASGI URL prefix, not a filesystem path).
    file_path = os.path.join(static_dir, file_name)
    return FileResponse(
        path=file_path,
        headers={
            "Content-Disposition": "attachment; filename=" + file_name
        },
    )


logging.basicConfig(level=logging.ERROR)


@app.exception_handler(HTTPException)
async def http_exception_handler(r: Request, exc: HTTPException):
    """Log and return exception details in the response."""
    logging.error(
        "HTTPException: %s executing request: %s", exc.detail, r.base_url
    )
    return JSONResponse(
        status_code=exc.status_code, content={"detail": exc.detail}
    )
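

# Run the API with uvicorn when executed directly. This entry point is an
# assumption (the original project may launch the app differently); the
# host and port below are illustrative defaults.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8888)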