many fixes for the new vectorstore and to support model changing in the UI

pull/570/head
Richard Anthony Hein 8 months ago
parent 71faeadfa4
commit 8a2ce30598

@@ -1,14 +1,13 @@
 """ Chatbot with RAG Server """
 import asyncio
 import logging
 import os
+from urllib.parse import urlparse, urljoin
 # import torch
 from contextlib import asynccontextmanager
 from typing import AsyncIterator
 from swarms.structs.agent import Agent
-import tiktoken
 from dotenv import load_dotenv
 from fastapi import FastAPI, HTTPException, Request
 from fastapi.middleware.cors import CORSMiddleware
@@ -29,7 +28,7 @@ from swarms.prompts.conversational_RAG import (
 )
 from playground.demos.chatbot.server.responses import StreamingResponse
 from playground.demos.chatbot.server.server_models import ChatRequest
-from playground.demos.chatbot.server.vector_store import VectorStorage
+from playground.demos.chatbot.server.vector_storage import RedisVectorStorage
 from swarms.models.popular_llms import OpenAIChatLLM

 # Explicitly specify the path to the .env file
@@ -42,13 +41,15 @@ load_dotenv(dotenv_path)
 hf_token = os.environ.get(
     "HUGGINFACEHUB_API_KEY"
 )  # Get the Huggingface API Token
 uploads = os.environ.get(
     "UPLOADS"
 )  # Directory where user uploads files to be parsed for RAG
-model_dir = os.environ.get("MODEL_DIR")
+# model_dir = os.environ.get("MODEL_DIR")
 # hugginface.co model (eg. meta-llama/Llama-2-70b-hf)
-model_name = os.environ.get("MODEL_NAME")
+# model_name = os.environ.get("MODEL_NAME")

 # Set OpenAI's API key to 'EMPTY' and API base URL to use vLLM's API server
 # or set them to OpenAI API key and base URL.
@@ -60,8 +61,6 @@ openai_api_base = (
 env_vars = [
     hf_token,
     uploads,
-    model_dir,
-    model_name,
     openai_api_key,
     openai_api_base,
 ]
@@ -77,31 +76,22 @@ useMetal = os.environ.get("USE_METAL", "False") == "True"
 use_gpu = os.environ.get("USE_GPU", "False") == "True"

 print(f"Uploads={uploads}")
-print(f"MODEL_DIR={model_dir}")
-print(f"MODEL_NAME={model_name}")
 print(f"USE_METAL={useMetal}")
 print(f"USE_GPU={use_gpu}")
 print(f"OPENAI_API_KEY={openai_api_key}")
 print(f"OPENAI_API_BASE={openai_api_base}")

-# update tiktoken to include the model name (avoids warning message)
-tiktoken.model.MODEL_TO_ENCODING.update(
-    {
-        model_name: "cl100k_base",
-    }
-)
+# # update tiktoken to include the model name (avoids warning message)
+# tiktoken.model.MODEL_TO_ENCODING.update(
+#     {
+#         model_name: "cl100k_base",
+#     }
+# )

 print("Logging in to huggingface.co...")
 login(token=hf_token)  # login to huggingface.co
-@asynccontextmanager
-async def lifespan(app: FastAPI):
-    """Initializes the vector store in a background task."""
-    asyncio.create_task(vector_store.init_retrievers())
-    yield
-
-app = FastAPI(title="Chatbot", lifespan=lifespan)
+app = FastAPI(title="Chatbot")

 router = APIRouter()

 current_dir = os.path.dirname(__file__)
@@ -125,11 +115,15 @@ if not os.path.exists(uploads):
     os.makedirs(uploads)

 # Initialize the vector store
-vector_store = VectorStorage(directory=uploads, use_gpu=use_gpu)
+# Hardcoded for Swarms documentation
+URL = "https://docs.swarms.world/en/latest/"
+vector_store = RedisVectorStorage(use_gpu=use_gpu)
+vector_store.crawl(URL)
+print("Vector storage initialized.")

 async def create_chat(
     messages: list[Message],
+    model_name: str,
     prompt: str = QA_PROMPT_TEMPLATE_STR,
 ):
     """Creates the RAG conversational retrieval chain."""
@@ -143,7 +137,6 @@ async def create_chat(
         streaming=True,
     )

-    retriever = await vector_store.get_retriever("swarms")
     doc_retrieval_string = ""
     for message in messages:
         if message.role == Role.HUMAN:
@@ -151,11 +144,15 @@ async def create_chat(
         elif message.role == Role.AI:
             doc_retrieval_string += f"{Role.AI}: {message.content}\r\n"

-    docs = retriever.invoke(doc_retrieval_string)
+    docs = vector_store.embed(messages[-1].content)

     # find {context} in prompt and replace it with the docs page_content.
     # Concatenate the content of all documents
-    context = "\n".join(doc.page_content for doc in docs)
+    context = "\n".join(doc["content"] for doc in docs)
+    sources = [urlparse(URL).scheme + "://" + doc["source_url"] for doc in docs]
+    print(f"context: {context}")

     # Replace {context} in the prompt with the concatenated document content
     prompt = prompt.replace("{context}", context)
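
The new `sources` expression implies that `doc["source_url"]` is stored without a scheme, so the scheme of the crawl root URL is prepended. A quick worked example, with an assumed stored value:

```python
from urllib.parse import urlparse

URL = "https://docs.swarms.world/en/latest/"
doc = {"source_url": "docs.swarms.world/en/latest/swarms/structs/agent/"}  # assumed shape
source = urlparse(URL).scheme + "://" + doc["source_url"]
print(source)  # -> https://docs.swarms.world/en/latest/swarms/structs/agent/
```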
@@ -188,7 +185,7 @@ async def create_chat(
         # sop="Calculate the profit for a company.",
         # sop_list=["Calculate the profit for a company."],
         user_name="RAH@EntangleIT.com",
-        docs=[doc.page_content for doc in docs],
+        # docs=[doc["content"] for doc in docs],
         # # docs_folder="docs",
         retry_attempts=3,
         # context_length=1000,
@@ -205,8 +202,16 @@ async def create_chat(
         elif message.role == Role.AI:
             agent.add_message_to_memory(message.content)

+    # add docs to short term memory
+    # for data in [doc["content"] for doc in docs]:
+    #     agent.add_message_to_memory(role=Role.HUMAN, content=data)
+
     async for response in agent.run_async(messages[-1].content):
-        yield response
+        res = response
+        res += "\n\nSources:\n"
+        for source in sources:
+            res += source + "\n"
+        yield res

     # memory = ConversationBufferMemory(
     #     chat_memory=chat_memory,
@@ -260,7 +265,8 @@ async def chat(request: ChatRequest):
     """ Handles chatbot chat POST requests """
     response = create_chat(
         messages=request.messages,
-        prompt=request.prompt.strip()
+        prompt=request.prompt.strip(),
+        model_name=request.model.id
     )
     # return response
     return StreamingResponse(content=response)
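
A hedged usage sketch of the model-changing path this commit enables: the UI now sends a model object whose `id` is threaded through `ChatRequest` into `create_chat(model_name=...)`. The endpoint path, port, and exact payload schema are assumptions here, since `server_models.py` is not part of this diff:

```python
# Hypothetical client call; URL and payload shape are assumed, not confirmed
# by this diff. Streams the chat response chunk by chunk.
import requests

payload = {
    "messages": [{"role": "human", "content": "What is a Swarms Agent?"}],
    "prompt": "You are a helpful assistant.\n\n{context}",
    "model": {"id": "gpt-4o"},  # becomes request.model.id -> create_chat(model_name=...)
}
with requests.post("http://localhost:8888/chat", json=payload, stream=True) as resp:
    for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
        print(chunk, end="")
```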
