@@ -1,14 +1,13 @@
 """ Chatbot with RAG Server """
-
 import asyncio
 import logging
 import os
+from urllib.parse import urlparse, urljoin
 # import torch
 from contextlib import asynccontextmanager
 from typing import AsyncIterator
 from swarms.structs.agent import Agent
-import tiktoken
 from dotenv import load_dotenv
 from fastapi import FastAPI, HTTPException, Request
 from fastapi.middleware.cors import CORSMiddleware
@@ -29,7 +28,7 @@ from swarms.prompts.conversational_RAG import (
 )
 from playground.demos.chatbot.server.responses import StreamingResponse
 from playground.demos.chatbot.server.server_models import ChatRequest
-from playground.demos.chatbot.server.vector_store import VectorStorage
+from playground.demos.chatbot.server.vector_storage import RedisVectorStorage
 from swarms.models.popular_llms import OpenAIChatLLM
 
 # Explicitly specify the path to the .env file
@@ -42,13 +41,15 @@ load_dotenv(dotenv_path)
 hf_token = os.environ.get(
     "HUGGINFACEHUB_API_KEY"
 ) # Get the Huggingface API Token
 
 uploads = os.environ.get(
     "UPLOADS"
 ) # Directory where user uploads files to be parsed for RAG
-model_dir = os.environ.get("MODEL_DIR")
+# model_dir = os.environ.get("MODEL_DIR")
+
 # hugginface.co model (eg. meta-llama/Llama-2-70b-hf)
-model_name = os.environ.get("MODEL_NAME")
+# model_name = os.environ.get("MODEL_NAME")
+
 
 # Set OpenAI's API key to 'EMPTY' and API base URL to use vLLM's API server
 # or set them to OpenAI API key and base URL.
@@ -60,8 +61,6 @@ openai_api_base = (
 env_vars = [
     hf_token,
     uploads,
-    model_dir,
-    model_name,
     openai_api_key,
     openai_api_base,
 ]
@@ -77,31 +76,22 @@ useMetal = os.environ.get("USE_METAL", "False") == "True"
 use_gpu = os.environ.get("USE_GPU", "False") == "True"
 
 print(f"Uploads={uploads}")
-print(f"MODEL_DIR={model_dir}")
-print(f"MODEL_NAME={model_name}")
 print(f"USE_METAL={useMetal}")
 print(f"USE_GPU={use_gpu}")
 print(f"OPENAI_API_KEY={openai_api_key}")
 print(f"OPENAI_API_BASE={openai_api_base}")
 
-# update tiktoken to include the model name (avoids warning message)
-tiktoken.model.MODEL_TO_ENCODING.update(
-    {
-        model_name: "cl100k_base",
-    }
-)
+# # update tiktoken to include the model name (avoids warning message)
+# tiktoken.model.MODEL_TO_ENCODING.update(
+#     {
+#         model_name: "cl100k_base",
+#     }
+# )
 
 print("Logging in to huggingface.co...")
 login(token=hf_token) # login to huggingface.co
 
-@asynccontextmanager
-async def lifespan(app: FastAPI):
-    """Initializes the vector store in a background task."""
-    asyncio.create_task(vector_store.init_retrievers())
-    yield
-
-app = FastAPI(title="Chatbot", lifespan=lifespan)
-
+app = FastAPI(title="Chatbot")
 router = APIRouter()
 
 current_dir = os.path.dirname(__file__)
@@ -125,11 +115,15 @@ if not os.path.exists(uploads):
     os.makedirs(uploads)
 
 # Initialize the vector store
-vector_store = VectorStorage(directory=uploads, use_gpu=use_gpu)
-
+# Hardcoded for Swarms documention
+URL = "https://docs.swarms.world/en/latest/"
+vector_store = RedisVectorStorage(use_gpu=use_gpu)
+vector_store.crawl(URL)
+print("Vector storage initialized.")
 
 async def create_chat(
     messages: list[Message],
+    model_name: str,
     prompt: str = QA_PROMPT_TEMPLATE_STR,
 ):
     """Creates the RAG conversational retrieval chain."""
@@ -143,7 +137,6 @@ async def create_chat(
         streaming=True,
     )
 
-    retriever = await vector_store.get_retriever("swarms")
     doc_retrieval_string = ""
     for message in messages:
        if message.role == Role.HUMAN:
@@ -151,11 +144,15 @@ async def create_chat(
         elif message.role == Role.AI:
             doc_retrieval_string += f"{Role.AI}: {message.content}\r\n"
 
-    docs = retriever.invoke(doc_retrieval_string)
+    docs = vector_store.embed(messages[-1].content)
 
     # find {context} in prompt and replace it with the docs page_content.
     # Concatenate the content of all documents
-    context = "\n".join(doc.page_content for doc in docs)
+    context = "\n".join(doc["content"] for doc in docs)
+    sources = [urlparse(URL).scheme + "://" + doc["source_url"] for doc in docs]
+
+    print(f"context: {context}")
+
 
     # Replace {context} in the prompt with the concatenated document content
     prompt = prompt.replace("{context}", context)
@@ -188,7 +185,7 @@ async def create_chat(
         # sop="Calculate the profit for a company.",
         # sop_list=["Calculate the profit for a company."],
         user_name="RAH@EntangleIT.com",
-        docs=[doc.page_content for doc in docs],
+        # docs=[doc["content"] for doc in docs],
         # # docs_folder="docs",
         retry_attempts=3,
         # context_length=1000,
@@ -205,8 +202,16 @@ async def create_chat(
         elif message.role == Role.AI:
             agent.add_message_to_memory(message.content)
 
+    # add docs to short term memory
+    # for data in [doc["content"] for doc in docs]:
+    #     agent.add_message_to_memory(role=Role.HUMAN, content=data)
+
     async for response in agent.run_async(messages[-1].content):
-        yield response
+        res = response
+        res += "\n\nSources:\n"
+        for source in sources:
+            res += source + "\n"
+        yield res
 
     # memory = ConversationBufferMemory(
     #     chat_memory=chat_memory,
@@ -260,7 +265,8 @@ async def chat(request: ChatRequest):
     """ Handles chatbot chat POST requests """
     response = create_chat(
         messages=request.messages,
-        prompt=request.prompt.strip()
+        prompt=request.prompt.strip(),
+        model_name=request.model.id
     )
    # return response
     return StreamingResponse(content=response)
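
A minimal sketch of how a client might exercise the updated chat endpoint, assuming the router mounts the handler at POST /chat on localhost:8888 and that ChatRequest carries `messages`, `prompt`, and a `model` object with an `id` field. These names are inferred from the call site in the last hunk, not confirmed against server_models.py:

import httpx

# Request body shape inferred from the diff's create_chat call site
# (request.messages, request.prompt, request.model.id); the field
# names and values here are assumptions, not a confirmed schema.
payload = {
    "messages": [{"role": "human", "content": "What is a swarm?"}],
    "prompt": "Use the following context to answer: {context}",
    "model": {"id": "gpt-4o"},
}

# Host, port, and route are placeholders; adjust to the server config.
with httpx.stream("POST", "http://localhost:8888/chat", json=payload) as resp:
    for chunk in resp.iter_text():
        # The server streams agent output, then a trailing "Sources:" block.
        print(chunk, end="", flush=True)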
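
The rewritten create_chat treats vector_store.embed() as returning a list of dicts with "content" and "source_url" keys. A self-contained stub mirroring the context/sources construction added above (illustrative data, not actual RedisVectorStorage output):

from urllib.parse import urlparse

URL = "https://docs.swarms.world/en/latest/"

# Stub documents in the shape the new create_chat expects from embed():
# "content" feeds the {context} substitution, "source_url" the citations.
docs = [
    {
        "content": "Swarms is a framework for orchestrating multiple agents.",
        "source_url": "docs.swarms.world/en/latest/index.html",
    },
]

context = "\n".join(doc["content"] for doc in docs)
sources = [urlparse(URL).scheme + "://" + doc["source_url"] for doc in docs]

print(context)
print(sources)  # ['https://docs.swarms.world/en/latest/index.html']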