# swarms/swarms/server/server.py

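"""FastAPI chatbot server for the swarms project.

Files uploaded via /processRAGFile are indexed into a vector store; /chat then
streams answers from a ConversationalRetrievalChain backed by an OpenAI-compatible
(e.g. vLLM) API endpoint.
"""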

import asyncio
import json
import logging
import os
from datetime import datetime
from typing import List
import langchain
from pydantic import ValidationError, parse_obj_as
from swarms.prompts.chat_prompt import Message, Role
from swarms.server.callback_handlers import (
    SourceDocumentsStreamingCallbackHandler,
    TokenStreamingCallbackHandler,
)
import tiktoken
# import torch
from dotenv import load_dotenv
from fastapi import FastAPI, File, Form, HTTPException, Request, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, JSONResponse
from fastapi.routing import APIRouter
from fastapi.staticfiles import StaticFiles
from huggingface_hub import login
from langchain.callbacks import StreamingStdOutCallbackHandler
from langchain.memory import ConversationSummaryBufferMemory
from langchain.memory.chat_message_histories.in_memory import ChatMessageHistory
from langchain.prompts.prompt import PromptTemplate
from langchain_community.chat_models import ChatOpenAI
from swarms.server.responses import LangchainStreamingResponse
from langchain.chains.conversational_retrieval.base import ConversationalRetrievalChain
from langchain.chains.llm import LLMChain
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
from swarms.prompts.conversational_RAG import (
    B_INST,
    B_SYS,
    CONDENSE_PROMPT_TEMPLATE,
    DOCUMENT_PROMPT_TEMPLATE,
    E_INST,
    E_SYS,
    QA_PROMPT_TEMPLATE,
    SUMMARY_PROMPT_TEMPLATE,
)
from swarms.server.vector_store import VectorStorage
from swarms.server.server_models import (
    ChatRequest,
    LogMessage,
    AIModel,
    AIModels,
    RAGFile,
    RAGFiles,
    State,
    GetRAGFileStateRequest,
    ProcessRAGFileRequest,
)
# Explicitly specify the path to the .env file
dotenv_path = os.path.join(os.path.dirname(__file__), '.env')
load_dotenv(dotenv_path)
hf_token = os.environ.get("HUGGINFACEHUB_API_KEY") # Get the Huggingface API Token
uploads = os.environ.get("UPLOADS") # Directory where user uploads files to be parsed for RAG
model_dir = os.environ.get("MODEL_DIR")
# hugginface.co model (eg. meta-llama/Llama-2-70b-hf)
model_name = os.environ.get("MODEL_NAME")
# Set OpenAI's API key to 'EMPTY' and API base URL to use vLLM's API server, or set them to OpenAI API key and base URL.
openai_api_key = os.environ.get("OPENAI_API_KEY") or "EMPTY"
openai_api_base = os.environ.get("OPENAI_API_BASE") or "http://localhost:8000/v1"
env_vars = {
    "HUGGINFACEHUB_API_KEY": hf_token,
    "UPLOADS": uploads,
    "MODEL_DIR": model_dir,
    "MODEL_NAME": model_name,
    "OPENAI_API_KEY": openai_api_key,
    "OPENAI_API_BASE": openai_api_base,
}
# Report the names of any unset variables (not their empty values).
missing_vars = [name for name, value in env_vars.items() if not value]
if missing_vars:
    print(
        f"Error: The following environment variables are not set: {', '.join(missing_vars)}"
    )
    exit(1)
useMetal = os.environ.get("USE_METAL", "False") == "True"

print(f"Uploads={uploads}")
print(f"MODEL_DIR={model_dir}")
print(f"MODEL_NAME={model_name}")
print(f"USE_METAL={useMetal}")
print(f"OPENAI_API_KEY={openai_api_key}")
print(f"OPENAI_API_BASE={openai_api_base}")

# update tiktoken to include the model name (avoids warning message)
tiktoken.model.MODEL_TO_ENCODING.update(
    {
        model_name: "cl100k_base",
    }
)

print("Logging in to huggingface.co...")
login(token=hf_token)  # login to huggingface.co

# langchain.debug = True
langchain.verbose = True
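# A hypothetical .env for this server might look like the sketch below. The variable
# names come from the lookups above; the values are placeholders, not defaults.
#
#   HUGGINFACEHUB_API_KEY=hf_xxxxxxxxxxxxxxxx
#   UPLOADS=./uploads
#   MODEL_DIR=/models/llama-2-70b
#   MODEL_NAME=meta-llama/Llama-2-70b-hf
#   OPENAI_API_KEY=EMPTY
#   OPENAI_API_BASE=http://localhost:8000/v1
#   USE_METAL=False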
from contextlib import asynccontextmanager


@asynccontextmanager
async def lifespan(app: FastAPI):
    asyncio.create_task(vector_store.initRetrievers())
    yield


app = FastAPI(title="Chatbot", lifespan=lifespan)
router = APIRouter()

current_dir = os.path.dirname(__file__)
print("current_dir: " + current_dir)
static_dir = os.path.join(current_dir, "static")
print("static_dir: " + static_dir)
app.mount(static_dir, StaticFiles(directory=static_dir), name="static")

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["GET", "POST"],
    allow_headers=["*"],
)

# Create ./uploads folder if it doesn't exist
uploads = uploads or os.path.join(os.getcwd(), "uploads")
if not os.path.exists(uploads):
    os.makedirs(uploads)

# Initialize the vector store
vector_store = VectorStorage(directory=uploads)
async def create_chain(
    messages: list[Message],
    model=model_dir,
    max_tokens_to_gen=2048,
    temperature=0.5,
    prompt: PromptTemplate = QA_PROMPT_TEMPLATE,
    file: RAGFile | None = None,
    key: str | None = None,
):
    """Build a ConversationalRetrievalChain over the retriever for the given RAG file."""
    print(
        f"Creating chain with key={key}, model={model}, max_tokens={max_tokens_to_gen}, "
        f"temperature={temperature}, prompt={prompt}, file={file.title if file else None}"
    )
    llm = ChatOpenAI(
        api_key=openai_api_key,
        base_url=openai_api_base,
        model=model_name,
        verbose=True,
        streaming=True,
    )
    # if llm is ALlamaCpp:
    #     llm.max_tokens = max_tokens_to_gen
    # elif llm is AGPT4All:
    #     llm.n_predict = max_tokens_to_gen
    # elif llm is AChatOllama:
    #     llm.max_tokens = max_tokens_to_gen
    # if llm is VLLMAsync:
    #     llm.max_tokens = max_tokens_to_gen
    retriever = await vector_store.getRetriever(
        os.path.join(file.username, file.filename)
    )

    # Rebuild the chat history for the conversation memory.
    chat_memory = ChatMessageHistory()
    for message in messages:
        if message.role == Role.HUMAN:
            chat_memory.add_user_message(message.content)
        elif message.role == Role.AI:
            chat_memory.add_ai_message(message.content)
        elif message.role == Role.SYSTEM:
            chat_memory.add_message(message.content)
        elif message.role == Role.FUNCTION:
            chat_memory.add_message(message.content)

    memory = ConversationSummaryBufferMemory(
        llm=llm,
        chat_memory=chat_memory,
        memory_key="chat_history",
        input_key="question",
        output_key="answer",
        prompt=SUMMARY_PROMPT_TEMPLATE,
        return_messages=True,
    )
    question_generator = LLMChain(
        llm=llm,
        prompt=CONDENSE_PROMPT_TEMPLATE,
        memory=memory,
        verbose=True,
        output_key="answer",
    )
    stuff_chain = LLMChain(
        llm=llm,
        prompt=prompt,
        verbose=True,
        output_key="answer",
    )
    doc_chain = StuffDocumentsChain(
        llm_chain=stuff_chain,
        document_variable_name="context",
        document_prompt=DOCUMENT_PROMPT_TEMPLATE,
        verbose=True,
        output_key="answer",
        memory=memory,
    )
    return ConversationalRetrievalChain(
        combine_docs_chain=doc_chain,
        memory=memory,
        retriever=retriever,
        question_generator=question_generator,
        return_generated_question=False,
        return_source_documents=True,
        output_key="answer",
        verbose=True,
    )
@router.post(
    "/chat",
    summary="Chatbot",
    description="Chatbot AI Service",
)
async def chat(request: ChatRequest):
    chain: ConversationalRetrievalChain = await create_chain(
        file=request.file,
        messages=request.messages[:-1],
        model=request.model.id,
        max_tokens_to_gen=request.maxTokens,
        temperature=request.temperature,
        prompt=PromptTemplate.from_template(
            f"{B_INST}{B_SYS}{request.prompt.strip()}{E_SYS}{E_INST}"
        ),
    )
    # async for token in chain.astream(request.messages[-1].content):
    #     print(f"token={token}")
    json_string = json.dumps(
        {
            "question": request.messages[-1].content,
            # "chat_history": [message.content for message in request.messages[:-1]],
        }
    )
    return LangchainStreamingResponse(
        chain,
        config={
            "inputs": json_string,
            "callbacks": [
                StreamingStdOutCallbackHandler(),
                TokenStreamingCallbackHandler(output_key="answer"),
                SourceDocumentsStreamingCallbackHandler(),
            ],
        },
    )
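# For illustration only: a request body for POST /chat presumably looks roughly like
# the sketch below. The exact schema is defined by ChatRequest in
# swarms.server.server_models; field names are inferred from how the request object
# is used above, and all values are placeholders.
#
#   {
#       "messages": [{"role": "human", "content": "What does chapter 2 cover?"}],
#       "model": {"id": "llama-2-70b.Q5_K_M"},
#       "maxTokens": 2048,
#       "temperature": 0.5,
#       "prompt": "You are a helpful assistant...",
#       "file": {"username": "alice", "filename": "book.html", "title": "book"}
#   }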
app.include_router(router, tags=["chat"])


@app.get("/")
def root():
    return {"message": "Chatbot API"}


@app.get("/favicon.ico")
def favicon():
    file_name = "favicon.ico"
    file_path = os.path.join(app.root_path, "static", file_name)
    return FileResponse(
        path=file_path,
        headers={"Content-Disposition": "attachment; filename=" + file_name},
    )


@app.post("/log")
def log_message(log_message: LogMessage):
    try:
        with open("log.txt", "a") as log_file:
            log_file.write(log_message.message + "\n")
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error saving log: {e}")
    return {"message": "Log saved successfully"}
@app.get("/models")
def get_models():
# llama7B = AIModel(
# id="llama-2-7b-chat-ggml-q4_0",
# name="llama-2-7b-chat-ggml-q4_0",
# maxLength=2048,
# tokenLimit=2048,
# )
# llama13B = AIModel(
# id="llama-2-13b-chat-ggml-q4_0",
# name="llama-2-13b-chat-ggml-q4_0",
# maxLength=2048,
# tokenLimit=2048,
# )
llama70B = AIModel(
id="llama-2-70b.Q5_K_M",
name="llama-2-70b.Q5_K_M",
maxLength=2048,
tokenLimit=2048,
)
models = AIModels(models=[llama70B])
return models
@app.get("/titles")
def getTitles():
titles = RAGFiles(
titles=[
# RAGFile(
# versionId="d8ad3b1d-c33c-4524-9691-e93967d4d863",
# title="d8ad3b1d-c33c-4524-9691-e93967d4d863",
# state=State.Unavailable,
# ),
RAGFile(
versionId=collection.name,
title=collection.name,
state=State.InProcess
if collection.name in processing_books
else State.Processed,
)
for collection in vector_store.list_collections()
if collection.name != "langchain"
]
)
return titles
processing_books: list[str] = []
processing_books_lock = asyncio.Lock()
logging.basicConfig(level=logging.ERROR)


@app.post("/titleState")
async def getTitleState(request: GetRAGFileStateRequest):
    # FastAPI + Pydantic will throw a 422 Unprocessable Entity if the request isn't the right type.
    # try:
    logging.debug(f"Received getTitleState request: {request}")
    titleStateRequest: GetRAGFileStateRequest = request
    # except ValidationError as e:
    #     print(f"Error validating JSON: {e}")
    #     raise HTTPException(status_code=422, detail=str(e))
    # except json.JSONDecodeError as e:
    #     print(f"Error parsing JSON: {e}")
    #     raise HTTPException(status_code=422, detail="Invalid JSON format")

    # check to see if the book has already been processed.
    # return the proper State directly to response.
    matchingCollection = next(
        (
            x
            for x in vector_store.list_collections()
            if x.name == titleStateRequest.versionRef
        ),
        None,
    )
    print("Got a Title State request for version " + titleStateRequest.versionRef)
    if titleStateRequest.versionRef in processing_books:
        return {"message": State.InProcess}
    elif matchingCollection is not None:
        return {"message": State.Processed}
    else:
        return {"message": State.Unavailable}
@app.post("/processRAGFile")
async def processRAGFile(
request: str = Form(...),
files: List[UploadFile] = File(...),
):
try:
logging.debug(f"Received processBook request: {request}")
# Parse the JSON string into a ProcessBookRequest object
fileRAGRequest: ProcessRAGFileRequest = parse_obj_as(
ProcessRAGFileRequest, json.loads(request)
)
except ValidationError as e:
print(f"Error validating JSON: {e}")
raise HTTPException(status_code=422, detail=str(e))
except json.JSONDecodeError as e:
print(f"Error parsing JSON: {e}")
raise HTTPException(status_code=422, detail="Invalid JSON format")
try:
print(
f"Processing file {fileRAGRequest.filename} for user {fileRAGRequest.username}."
)
# check to see if the file has already been processed.
# write html to subfolder
print(f"Writing file to path: {fileRAGRequest.username}/{fileRAGRequest.filename}...")
for index, segment in enumerate(files):
filename = segment.filename if segment.filename else str(index)
subDir = f"{fileRAGRequest.username}"
with open(os.path.join(subDir, filename), "wb") as htmlFile:
htmlFile.write(await segment.read())
# write metadata to subfolder
print(f"Writing metadata to subfolder {fileRAGRequest.username}...")
with open(os.path.join({fileRAGRequest.username}, "metadata.json"), "w") as metadataFile:
metaData = {
"filename": fileRAGRequest.filename,
"username": fileRAGRequest.username,
"processDate": datetime.now().isoformat(),
}
metadataFile.write(json.dumps(metaData))
vector_store.retrievers[
f"{fileRAGRequest.username}/{fileRAGRequest.filename}"
] = await vector_store.initRetriever(f"{fileRAGRequest.username}/{fileRAGRequest.filename}")
return {
"message": f"File {fileRAGRequest.filename} processed successfully."
}
except Exception as e:
logging.error(f"Error processing book: {e}")
return {"message": f"Error processing book: {e}"}
@app.exception_handler(HTTPException)
async def http_exception_handler(request: Request, exc: HTTPException):
    logging.error(f"HTTPException: {exc.detail}")
    return JSONResponse(status_code=exc.status_code, content={"detail": exc.detail})
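# A minimal sketch of how this app might be served locally, assuming uvicorn is
# installed and the module is importable as swarms.server.server (host and port
# below are placeholders):
#
#   uvicorn swarms.server.server:app --host 0.0.0.0 --port 8888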