many fixes for the new vectorstore and to support model changing in the UI

pull/570/head
Richard Anthony Hein 8 months ago
parent 71faeadfa4
commit 8a2ce30598

@@ -1,14 +1,13 @@
""" Chatbot with RAG Server """
import asyncio
import logging
import os
from urllib.parse import urlparse, urljoin
# import torch
from contextlib import asynccontextmanager
from typing import AsyncIterator
from swarms.structs.agent import Agent
import tiktoken
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
@@ -29,7 +28,7 @@ from swarms.prompts.conversational_RAG (
)
from playground.demos.chatbot.server.responses import StreamingResponse
from playground.demos.chatbot.server.server_models import ChatRequest
from playground.demos.chatbot.server.vector_store import VectorStorage
from playground.demos.chatbot.server.vector_storage import RedisVectorStorage
from swarms.models.popular_llms import OpenAIChatLLM
# Explicitly specify the path to the .env file
@@ -42,13 +41,15 @@ load_dotenv(dotenv_path)
hf_token = os.environ.get(
"HUGGINFACEHUB_API_KEY"
) # Get the Huggingface API Token
uploads = os.environ.get(
"UPLOADS"
) # Directory where user uploads files to be parsed for RAG
model_dir = os.environ.get("MODEL_DIR")
# model_dir = os.environ.get("MODEL_DIR")
# huggingface.co model (e.g. meta-llama/Llama-2-70b-hf)
model_name = os.environ.get("MODEL_NAME")
# model_name = os.environ.get("MODEL_NAME")
# Set OpenAI's API key to 'EMPTY' and API base URL to use vLLM's API server
# or set them to OpenAI API key and base URL.
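# Example (assumed vLLM defaults): OPENAI_API_KEY=EMPTY,
# OPENAI_API_BASE=http://localhost:8000/v1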
@@ -60,8 +61,6 @@ openai_api_base = (
env_vars = [
hf_token,
uploads,
model_dir,
model_name,
openai_api_key,
openai_api_base,
]
@@ -77,31 +76,22 @@ useMetal = os.environ.get("USE_METAL", "False") == "True"
use_gpu = os.environ.get("USE_GPU", "False") == "True"
print(f"Uploads={uploads}")
print(f"MODEL_DIR={model_dir}")
print(f"MODEL_NAME={model_name}")
print(f"USE_METAL={useMetal}")
print(f"USE_GPU={use_gpu}")
print(f"OPENAI_API_KEY={openai_api_key}")
print(f"OPENAI_API_BASE={openai_api_base}")
# update tiktoken to include the model name (avoids warning message)
tiktoken.model.MODEL_TO_ENCODING.update(
{
model_name: "cl100k_base",
}
)
# # update tiktoken to include the model name (avoids warning message)
# tiktoken.model.MODEL_TO_ENCODING.update(
# {
# model_name: "cl100k_base",
# }
# )
print("Logging in to huggingface.co...")
login(token=hf_token) # login to huggingface.co
@asynccontextmanager
async def lifespan(app: FastAPI):
"""Initializes the vector store in a background task."""
asyncio.create_task(vector_store.init_retrievers())
yield
app = FastAPI(title="Chatbot", lifespan=lifespan)
app = FastAPI(title="Chatbot")
router = APIRouter()
current_dir = os.path.dirname(__file__)
@@ -125,11 +115,15 @@ if not os.path.exists(uploads):
os.makedirs(uploads)
# Initialize the vector store
vector_store = VectorStorage(directory=uploads, use_gpu=use_gpu)
# Hardcoded for Swarms documentation
URL = "https://docs.swarms.world/en/latest/"
vector_store = RedisVectorStorage(use_gpu=use_gpu)
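# crawl() presumably fetches pages under URL, embeds their text, and indexes the vectors in Redis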
vector_store.crawl(URL)
print("Vector storage initialized.")
async def create_chat(
messages: list[Message],
model_name: str,
prompt: str = QA_PROMPT_TEMPLATE_STR,
):
"""Creates the RAG conversational retrieval chain."""
@@ -143,7 +137,6 @@ async def create_chat(
streaming=True,
)
retriever = await vector_store.get_retriever("swarms")
doc_retrieval_string = ""
for message in messages:
if message.role == Role.HUMAN:
@@ -151,11 +144,15 @@ async def create_chat(
elif message.role == Role.AI:
doc_retrieval_string += f"{Role.AI}: {message.content}\r\n"
docs = retriever.invoke(doc_retrieval_string)
docs = vector_store.embed(messages[-1].content)
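# embed() appears to embed the latest user message and return the most similar docs,
# each a dict with "content" and "source_url" keys (used below)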
# find {context} in prompt and replace it with the docs page_content.
# Concatenate the content of all documents
context = "\n".join(doc.page_content for doc in docs)
context = "\n".join(doc["content"] for doc in docs)
sources = [urlparse(URL).scheme + "://" + doc["source_url"] for doc in docs]
print(f"context: {context}")
# Replace {context} in the prompt with the concatenated document content
prompt = prompt.replace("{context}", context)
@@ -188,7 +185,7 @@ async def create_chat(
# sop="Calculate the profit for a company.",
# sop_list=["Calculate the profit for a company."],
user_name="RAH@EntangleIT.com",
docs=[doc.page_content for doc in docs],
# docs=[doc["content"] for doc in docs],
# # docs_folder="docs",
retry_attempts=3,
# context_length=1000,
@@ -205,8 +202,16 @@ async def create_chat(
elif message.role == Role.AI:
agent.add_message_to_memory(message.content)
# add docs to short term memory
# for data in [doc["content"] for doc in docs]:
# agent.add_message_to_memory(role=Role.HUMAN, content=data)
async for response in agent.run_async(messages[-1].content):
yield response
res = response
res += "\n\nSources:\n"
for source in sources:
res += source + "\n"
yield res
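# note: the "Sources:" list is appended to every streamed chunk, not just the final one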
# memory = ConversationBufferMemory(
# chat_memory=chat_memory,
@@ -260,7 +265,8 @@ async def chat(request: ChatRequest):
""" Handles chatbot chat POST requests """
response = create_chat(
messages=request.messages,
prompt=request.prompt.strip()
prompt=request.prompt.strip(),
model_name=request.model.id
)
# return response
return StreamingResponse(content=response)
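For reference, a minimal sketch of the RedisVectorStorage interface this commit relies on, inferred from the calls above; the real playground.demos.chatbot.server.vector_storage implementation may differ:

class RedisVectorStorage:
    """Sketch of the vector store interface assumed by the server code above."""

    def __init__(self, use_gpu: bool = False):
        self.use_gpu = use_gpu  # whether to run the embedding model on a GPU

    def crawl(self, url: str) -> None:
        """Fetch pages under url, embed their text, and index the vectors in Redis."""
        raise NotImplementedError

    def embed(self, query: str) -> list[dict]:
        """Embed query and return the most similar docs as
        [{"content": ..., "source_url": ...}, ...]."""
        raise NotImplementedError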
