Added FastAPI server for Chatbot API

Author: Richard Anthony Hein · branch pull/570/head · commit b9ada3d1bb (parent c86e62400a)

@@ -0,0 +1,448 @@
# modified from the Lanarky source code: https://github.com/auxon/lanarky
from typing import Any, Optional
from fastapi.websockets import WebSocket
from langchain.callbacks.base import AsyncCallbackHandler
from langchain.callbacks.streaming_stdout_final_only import (
FinalStreamingStdOutCallbackHandler,
)
from langchain.globals import get_llm_cache
from langchain.schema.document import Document
from pydantic import BaseModel
from starlette.types import Message, Send
from sse_starlette.sse import ensure_bytes, ServerSentEvent
from swarms.server.utils import StrEnum, model_dump_json
class LangchainEvents(StrEnum):
SOURCE_DOCUMENTS = "source_documents"
class BaseCallbackHandler(AsyncCallbackHandler):
"""Base callback handler for streaming / async applications."""
def __init__(self, **kwargs: dict[str, Any]) -> None:
super().__init__(**kwargs)
self.llm_cache_used = get_llm_cache() is not None
@property
def always_verbose(self) -> bool:
"""Verbose mode is always enabled."""
return True
async def on_chat_model_start(self, *args: Any, **kwargs: Any) -> Any: ...
class StreamingCallbackHandler(BaseCallbackHandler):
"""Callback handler for streaming responses."""
def __init__(
self,
*,
send: Send = None,
**kwargs: dict[str, Any],
) -> None:
"""Constructor method.
Args:
send: The ASGI send callable.
**kwargs: Keyword arguments to pass to the parent constructor.
"""
super().__init__(**kwargs)
self._send = send
self.streaming = None
@property
def send(self) -> Send:
return self._send
@send.setter
def send(self, value: Send) -> None:
"""Setter method for send property."""
if not callable(value):
raise ValueError("value must be a Callable")
self._send = value
def _construct_message(self, data: str, event: Optional[str] = None) -> Message:
"""Constructs message payload.
Args:
data: The data payload.
event: The event name.
"""
chunk = ServerSentEvent(data=data, event=event)
return {
"type": "http.response.body",
"body": ensure_bytes(chunk, None),
"more_body": True,
}
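# For reference (not part of the original commit): a handler constructed with
# StreamingCallbackHandler(send=send) emits ASGI messages shaped like
# {"type": "http.response.body", "body": b"...", "more_body": True}, where the
# body bytes carry one serialized server-sent event (e.g. an "event: completion" frame).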
class TokenStreamMode(StrEnum):
TEXT = "text"
JSON = "json"
class TokenEventData(BaseModel):
"""Event data payload for tokens."""
token: str = ""
def get_token_data(token: str, mode: TokenStreamMode) -> str:
"""Get token data based on mode.
Args:
token: The token to use.
mode: The stream mode.
"""
if mode not in list(TokenStreamMode):
raise ValueError(f"Invalid stream mode: {mode}")
if mode == TokenStreamMode.TEXT:
return token
else:
return model_dump_json(TokenEventData(token=token))
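# Example (illustrative, not part of the original commit):
#   get_token_data("Hello", TokenStreamMode.TEXT)  ->  "Hello"
#   get_token_data("Hello", TokenStreamMode.JSON)  ->  roughly '{"token": "Hello"}'
#                                                      (exact spacing depends on the pydantic version)
# The JSON form is what clients receive in the SSE data field when the default JSON mode is used.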
class TokenStreamingCallbackHandler(StreamingCallbackHandler):
"""Callback handler for streaming tokens."""
def __init__(
self,
*,
output_key: str,
mode: TokenStreamMode = TokenStreamMode.JSON,
**kwargs: dict[str, Any],
) -> None:
"""Constructor method.
Args:
output_key: chain output key.
mode: The stream mode.
**kwargs: Keyword arguments to pass to the parent constructor.
"""
super().__init__(**kwargs)
self.output_key = output_key
if mode not in list(TokenStreamMode):
raise ValueError(f"Invalid stream mode: {mode}")
self.mode = mode
async def on_chain_start(self, *args: Any, **kwargs: dict[str, Any]) -> None:
"""Run when chain starts running."""
self.streaming = False
async def on_llm_new_token(self, token: str, **kwargs: dict[str, Any]) -> None:
"""Run on new LLM token. Only available when streaming is enabled."""
if not self.streaming:
self.streaming = True
if self.llm_cache_used: # cache missed (or was never enabled) if we are here
self.llm_cache_used = False
message = self._construct_message(
data=get_token_data(token, self.mode), event="completion"
)
await self.send(message)
async def on_chain_end(
self, outputs: dict[str, Any], **kwargs: dict[str, Any]
) -> None:
"""Run when chain ends running.
Final output is streamed only if LLM cache is enabled.
"""
if self.llm_cache_used or not self.streaming:
if self.output_key in outputs:
message = self._construct_message(
data=get_token_data(outputs[self.output_key], self.mode),
event="completion",
)
await self.send(message)
else:
raise KeyError(f"missing outputs key: {self.output_key}")
class SourceDocumentsEventData(BaseModel):
"""Event data payload for source documents."""
source_documents: list[dict[str, Any]]
class SourceDocumentsStreamingCallbackHandler(StreamingCallbackHandler):
"""Callback handler for streaming source documents."""
async def on_chain_end(
self, outputs: dict[str, Any], **kwargs: dict[str, Any]
) -> None:
"""Run when chain ends running."""
if "source_documents" in outputs:
if not isinstance(outputs["source_documents"], list):
raise ValueError("source_documents must be a list")
if not isinstance(outputs["source_documents"][0], Document):
raise ValueError("source_documents must be a list of Document")
# NOTE: langchain is using pydantic_v1 for `Document`
source_documents: list[dict] = [
document.dict() for document in outputs["source_documents"]
]
message = self._construct_message(
data=model_dump_json(
SourceDocumentsEventData(source_documents=source_documents)
),
event=LangchainEvents.SOURCE_DOCUMENTS,
)
await self.send(message)
class FinalTokenStreamingCallbackHandler(
TokenStreamingCallbackHandler, FinalStreamingStdOutCallbackHandler
):
"""Callback handler for streaming final answer tokens.
Useful for streaming responses from Langchain Agents.
"""
def __init__(
self,
*,
answer_prefix_tokens: Optional[list[str]] = None,
strip_tokens: bool = True,
stream_prefix: bool = False,
**kwargs: dict[str, Any],
) -> None:
"""Constructor method.
Args:
answer_prefix_tokens: The answer prefix tokens to use.
strip_tokens: Whether to strip tokens.
stream_prefix: Whether to stream the answer prefix.
**kwargs: Keyword arguments to pass to the parent constructor.
"""
super().__init__(output_key=None, **kwargs)
FinalStreamingStdOutCallbackHandler.__init__(
self,
answer_prefix_tokens=answer_prefix_tokens,
strip_tokens=strip_tokens,
stream_prefix=stream_prefix,
)
async def on_llm_start(self, *args: Any, **kwargs: dict[str, Any]) -> None:
"""Run when LLM starts running."""
self.answer_reached = False
self.streaming = False
async def on_llm_new_token(self, token: str, **kwargs: dict[str, Any]) -> None:
"""Run on new LLM token. Only available when streaming is enabled."""
if not self.streaming:
self.streaming = True
# Remember the last n tokens, where n = len(answer_prefix_tokens)
self.append_to_last_tokens(token)
# Check if the last n tokens match the answer_prefix_tokens list ...
if self.check_if_answer_reached():
self.answer_reached = True
if self.stream_prefix:
message = self._construct_message(
data=get_token_data("".join(self.last_tokens), self.mode),
event="completion",
)
await self.send(message)
# ... if yes, then print tokens from now on
if self.answer_reached:
message = self._construct_message(
data=get_token_data(token, self.mode), event="completion"
)
await self.send(message)
class WebSocketCallbackHandler(BaseCallbackHandler):
"""Callback handler for websocket sessions."""
def __init__(
self,
*,
mode: TokenStreamMode = TokenStreamMode.JSON,
websocket: WebSocket = None,
**kwargs: dict[str, Any],
) -> None:
"""Constructor method.
Args:
mode: The stream mode.
websocket: The websocket to use.
**kwargs: Keyword arguments to pass to the parent constructor.
"""
super().__init__(**kwargs)
if mode not in list(TokenStreamMode):
raise ValueError(f"Invalid stream mode: {mode}")
self.mode = mode
self._websocket = websocket
self.streaming = None
@property
def websocket(self) -> WebSocket:
return self._websocket
@websocket.setter
def websocket(self, value: WebSocket) -> None:
"""Setter method for send property."""
if not isinstance(value, WebSocket):
raise ValueError("value must be a WebSocket")
self._websocket = value
def _construct_message(self, data: str, event: Optional[str] = None) -> Message:
"""Constructs message payload.
Args:
data: The data payload.
event: The event name.
"""
return dict(data=data, event=event)
class TokenWebSocketCallbackHandler(WebSocketCallbackHandler):
"""Callback handler for sending tokens in websocket sessions."""
def __init__(self, *, output_key: str, **kwargs: dict[str, Any]) -> None:
"""Constructor method.
Args:
output_key: chain output key.
**kwargs: Keyword arguments to pass to the parent constructor.
"""
super().__init__(**kwargs)
self.output_key = output_key
async def on_chain_start(self, *args: Any, **kwargs: dict[str, Any]) -> None:
"""Run when chain starts running."""
self.streaming = False
async def on_llm_new_token(self, token: str, **kwargs: dict[str, Any]) -> None:
"""Run on new LLM token. Only available when streaming is enabled."""
if not self.streaming:
self.streaming = True
if self.llm_cache_used: # cache missed (or was never enabled) if we are here
self.llm_cache_used = False
message = self._construct_message(
data=get_token_data(token, self.mode), event="completion"
)
await self.websocket.send_json(message)
async def on_chain_end(
self, outputs: dict[str, Any], **kwargs: dict[str, Any]
) -> None:
"""Run when chain ends running.
Final output is streamed only if LLM cache is enabled.
"""
if self.llm_cache_used or not self.streaming:
if self.output_key in outputs:
message = self._construct_message(
data=get_token_data(outputs[self.output_key], self.mode),
event="completion",
)
await self.websocket.send_json(message)
else:
raise KeyError(f"missing outputs key: {self.output_key}")
class SourceDocumentsWebSocketCallbackHandler(WebSocketCallbackHandler):
"""Callback handler for sending source documents in websocket sessions."""
async def on_chain_end(
self, outputs: dict[str, Any], **kwargs: dict[str, Any]
) -> None:
"""Run when chain ends running."""
if "source_documents" in outputs:
if not isinstance(outputs["source_documents"], list):
raise ValueError("source_documents must be a list")
if not isinstance(outputs["source_documents"][0], Document):
raise ValueError("source_documents must be a list of Document")
# NOTE: langchain is using pydantic_v1 for `Document`
source_documents: list[dict] = [
document.dict() for document in outputs["source_documents"]
]
message = self._construct_message(
data=model_dump_json(
SourceDocumentsEventData(source_documents=source_documents)
),
event=LangchainEvents.SOURCE_DOCUMENTS,
)
await self.websocket.send_json(message)
class FinalTokenWebSocketCallbackHandler(
TokenWebSocketCallbackHandler, FinalStreamingStdOutCallbackHandler
):
"""Callback handler for sending final answer tokens in websocket sessions.
Useful for streaming responses from Langchain Agents.
"""
def __init__(
self,
*,
answer_prefix_tokens: Optional[list[str]] = None,
strip_tokens: bool = True,
stream_prefix: bool = False,
**kwargs: dict[str, Any],
) -> None:
"""Constructor method.
Args:
answer_prefix_tokens: The answer prefix tokens to use.
strip_tokens: Whether to strip tokens.
stream_prefix: Whether to stream the answer prefix.
**kwargs: Keyword arguments to pass to the parent constructor.
"""
super().__init__(output_key=None, **kwargs)
FinalStreamingStdOutCallbackHandler.__init__(
self,
answer_prefix_tokens=answer_prefix_tokens,
strip_tokens=strip_tokens,
stream_prefix=stream_prefix,
)
async def on_llm_start(self, *args, **kwargs) -> None:
"""Run when LLM starts running."""
self.answer_reached = False
self.streaming = False
async def on_llm_new_token(self, token: str, **kwargs: dict[str, Any]) -> None:
"""Run on new LLM token. Only available when streaming is enabled."""
if not self.streaming:
self.streaming = True
# Remember the last n tokens, where n = len(answer_prefix_tokens)
self.append_to_last_tokens(token)
# Check if the last n tokens match the answer_prefix_tokens list ...
if self.check_if_answer_reached():
self.answer_reached = True
if self.stream_prefix:
message = self._construct_message(
data=get_token_data("".join(self.last_tokens), self.mode),
event="completion",
)
await self.websocket.send_json(message)
# ... if yes, then print tokens from now on
if self.answer_reached:
message = self._construct_message(
data=get_token_data(token, self.mode), event="completion"
)
await self.websocket.send_json(message)
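# Example (illustrative sketch, not part of this commit): wiring the websocket handlers
# above into a FastAPI endpoint. `some_chain` is a placeholder for any LangChain `Chain`
# whose output key is "answer"; the route and names here are assumptions.
#
# from fastapi import FastAPI, WebSocket
#
# app = FastAPI()
#
# @app.websocket("/ws/chat")
# async def ws_chat(websocket: WebSocket):
#     await websocket.accept()
#     handler = TokenWebSocketCallbackHandler(output_key="answer", websocket=websocket)
#     while True:
#         question = await websocket.receive_text()
#         await some_chain.acall({"question": question}, callbacks=[handler])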

@@ -0,0 +1,179 @@
import asyncio
from functools import partial
from typing import Any

from fastapi import status
from langchain.chains.base import Chain
from sse_starlette import ServerSentEvent
from sse_starlette.sse import EventSourceResponse, ensure_bytes
from starlette.types import Send

from swarms.server.utils import StrEnum
class HTTPStatusDetail(StrEnum):
INTERNAL_SERVER_ERROR = "Internal Server Error"
class StreamingResponse(EventSourceResponse):
"""`Response` class for streaming server-sent events.
Follows the
[EventSource protocol](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events#interfaces)
"""
def __init__(
self,
content: Any = iter(()),
*args: Any,
**kwargs: dict[str, Any],
) -> None:
"""Constructor method.
Args:
content: The content to stream.
"""
super().__init__(content=content, *args, **kwargs)
async def stream_response(self, send: Send) -> None:
"""Streams data chunks to client by iterating over `content`.
If an exception occurs while iterating over `content`, an
internal server error is sent to the client.
Args:
send: The send function from the ASGI framework.
"""
await send(
{
"type": "http.response.start",
"status": self.status_code,
"headers": self.raw_headers,
}
)
try:
async for data in self.body_iterator:
chunk = ensure_bytes(data, self.sep)
print(f"chunk: {chunk.decode()}")
await send(
{"type": "http.response.body", "body": chunk, "more_body": True}
)
except Exception as e:
print(f"body iterator error: {e}")
chunk = ServerSentEvent(
data=dict(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=HTTPStatusDetail.INTERNAL_SERVER_ERROR,
),
event="error",
)
await send(
{
"type": "http.response.body",
"body": ensure_bytes(chunk, None),
"more_body": True,
}
)
await send({"type": "http.response.body", "body": b"", "more_body": False})
class ChainRunMode(StrEnum):
"""Enum for LangChain run modes."""
ASYNC = "async"
SYNC = "sync"
class LangchainStreamingResponse(StreamingResponse):
"""StreamingResponse class for LangChain resources."""
def __init__(
self,
chain: Chain,
config: dict[str, Any],
run_mode: ChainRunMode = ChainRunMode.ASYNC,
*args: Any,
**kwargs: dict[str, Any],
) -> None:
"""Constructor method.
Args:
chain: A LangChain instance.
config: A config dict.
*args: Positional arguments to pass to the parent constructor.
**kwargs: Keyword arguments to pass to the parent constructor.
"""
super().__init__(*args, **kwargs)
self.chain = chain
self.config = config
if run_mode not in list(ChainRunMode):
raise ValueError(
f"Invalid run mode '{run_mode}'. Must be one of {list(ChainRunMode)}"
)
self.run_mode = run_mode
async def stream_response(self, send: Send) -> None:
"""Stream LangChain outputs.
If an exception occurs while iterating over the LangChain, an
internal server error is sent to the client.
Args:
send: The ASGI send callable.
"""
await send(
{
"type": "http.response.start",
"status": self.status_code,
"headers": self.raw_headers,
}
)
if "callbacks" in self.config:
for callback in self.config["callbacks"]:
if hasattr(callback, "send"):
callback.send = send
try:
# TODO: migrate to `.ainvoke` when adding support
# for LCEL
if self.run_mode == ChainRunMode.ASYNC:
outputs = await self.chain.acall(**self.config)
else:
loop = asyncio.get_event_loop()
outputs = await loop.run_in_executor(
None, partial(self.chain, **self.config)
)
if self.background is not None:
self.background.kwargs.update({"outputs": outputs})
except Exception as e:
print(f"chain runtime error: {e}")
if self.background is not None:
self.background.kwargs.update({"outputs": {}, "error": e})
chunk = ServerSentEvent(
data=dict(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=HTTPStatusDetail.INTERNAL_SERVER_ERROR,
),
event="error",
)
await send(
{
"type": "http.response.body",
"body": ensure_bytes(chunk, None),
"more_body": True,
}
)
await send({"type": "http.response.body", "body": b"", "more_body": False})

@@ -0,0 +1,69 @@
from typing import Any
from fastapi import status
from starlette.types import Send
from sse_starlette.sse import ensure_bytes, EventSourceResponse, ServerSentEvent
class StreamingResponse(EventSourceResponse):
"""`Response` class for streaming server-sent events.
Follows the
[EventSource protocol](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events#interfaces)
"""
def __init__(
self,
content: Any = iter(()),
*args: Any,
**kwargs: dict[str, Any],
) -> None:
"""Constructor method.
Args:
content: The content to stream.
"""
super().__init__(content=content, *args, **kwargs)
async def stream_response(self, send: Send) -> None:
"""Streams data chunks to client by iterating over `content`.
If an exception occurs while iterating over `content`, an
internal server error is sent to the client.
Args:
send: The send function from the ASGI framework.
"""
await send(
{
"type": "http.response.start",
"status": self.status_code,
"headers": self.raw_headers,
}
)
try:
async for data in self.body_iterator:
chunk = ensure_bytes(data, self.sep)
with open("log.txt", "a") as log_file:
log_file.write(f"chunk: {chunk.decode()}\n")
await send(
{"type": "http.response.body", "body": chunk, "more_body": True}
)
except Exception as e:
with open("log.txt", "a") as log_file:
log_file.write(f"body iterator error: {e}\n")
chunk = ServerSentEvent(
data=dict(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Internal Server Error",
),
event="error",
)
await send(
{
"type": "http.response.body",
"body": ensure_bytes(chunk, None),
"more_body": True,
}
)
await send({"type": "http.response.body", "body": b"", "more_body": False})

@@ -0,0 +1,445 @@
import asyncio
import json
import logging
import os
from datetime import datetime
from typing import List
import langchain
from pydantic import ValidationError, parse_obj_as
from swarms.prompts.chat_prompt import Message, Role
from swarms.server.callback_handlers import SourceDocumentsStreamingCallbackHandler, TokenStreamingCallbackHandler
import tiktoken
# import torch
from dotenv import load_dotenv
from fastapi import FastAPI, File, Form, HTTPException, Request, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, JSONResponse
from fastapi.routing import APIRouter
from fastapi.staticfiles import StaticFiles
from huggingface_hub import login
from langchain.callbacks import StreamingStdOutCallbackHandler
from langchain.memory import ConversationSummaryBufferMemory
from langchain.memory.chat_message_histories.in_memory import ChatMessageHistory
from langchain.prompts.prompt import PromptTemplate
from langchain_community.chat_models import ChatOpenAI
from swarms.server.responses import LangchainStreamingResponse
from langchain.chains.conversational_retrieval.base import ConversationalRetrievalChain
from langchain.chains.llm import LLMChain
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
from swarms.prompts.conversational_RAG import (
B_INST,
B_SYS,
CONDENSE_PROMPT_TEMPLATE,
DOCUMENT_PROMPT_TEMPLATE,
E_INST,
E_SYS,
QA_PROMPT_TEMPLATE,
SUMMARY_PROMPT_TEMPLATE,
)
from swarms.server.vector_store import VectorStorage
from swarms.server.server_models import (
ChatRequest,
LogMessage,
AIModel,
AIModels,
RAGFile,
RAGFiles,
State,
GetRAGFileStateRequest,
ProcessRAGFileRequest
)
# Explicitly specify the path to the .env file
dotenv_path = os.path.join(os.path.dirname(__file__), '.env')
load_dotenv(dotenv_path)
hf_token = os.environ.get("HUGGINFACEHUB_API_KEY") # Get the Huggingface API Token
uploads = os.environ.get("UPLOADS") # Directory where user uploads files to be parsed for RAG
model_dir = os.environ.get("MODEL_DIR")
# hugginface.co model (eg. meta-llama/Llama-2-70b-hf)
model_name = os.environ.get("MODEL_NAME")
# Set OpenAI's API key to 'EMPTY' and API base URL to use vLLM's API server, or set them to OpenAI API key and base URL.
openai_api_key = os.environ.get("OPENAI_API_KEY") or "EMPTY"
openai_api_base = os.environ.get("OPENAI_API_BASE") or "http://localhost:8000/v1"
env_vars = {
    "HUGGINFACEHUB_API_KEY": hf_token,
    "UPLOADS": uploads,
    "MODEL_DIR": model_dir,
    "MODEL_NAME": model_name,
    "OPENAI_API_KEY": openai_api_key,
    "OPENAI_API_BASE": openai_api_base,
}
missing_vars = [name for name, value in env_vars.items() if not value]
if missing_vars:
    print(
        f"Error: The following environment variables are not set: {', '.join(missing_vars)}"
    )
    exit(1)
useMetal = os.environ.get("USE_METAL", "False") == "True"
print(f"Uploads={uploads}")
print(f"MODEL_DIR={model_dir}")
print(f"MODEL_NAME={model_name}")
print(f"USE_METAL={useMetal}")
print(f"OPENAI_API_KEY={openai_api_key}")
print(f"OPENAI_API_BASE={openai_api_base}")
# update tiktoken to include the model name (avoids warning message)
tiktoken.model.MODEL_TO_ENCODING.update(
{
model_name: "cl100k_base",
}
)
print("Logging in to huggingface.co...")
login(token=hf_token) # login to huggingface.co
# langchain.debug = True
langchain.verbose = True
from contextlib import asynccontextmanager
@asynccontextmanager
async def lifespan(app: FastAPI):
asyncio.create_task(vector_store.initRetrievers())
yield
app = FastAPI(title="Chatbot", lifespan=lifespan)
router = APIRouter()
current_dir = os.path.dirname(__file__)
print("current_dir: " + current_dir)
static_dir = os.path.join(current_dir, "static")
print("static_dir: " + static_dir)
app.mount("/static", StaticFiles(directory=static_dir), name="static")
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["GET", "POST"],
allow_headers=["*"],
)
# Create ./uploads folder if it doesn't exist
uploads = uploads or os.path.join(os.getcwd(), "uploads")
if not os.path.exists(uploads):
os.makedirs(uploads)
# Initialize the vector store
vector_store = VectorStorage(directory=uploads)
async def create_chain(
messages: list[Message],
model=model_dir,
max_tokens_to_gen=2048,
temperature=0.5,
prompt: PromptTemplate = QA_PROMPT_TEMPLATE,
file: RAGFile | None = None,
key: str | None = None,
):
print(
f"Creating chain with key={key}, model={model}, max_tokens={max_tokens_to_gen}, temperature={temperature}, prompt={prompt}, file={file.title}"
)
llm = ChatOpenAI(
api_key=openai_api_key,
base_url=openai_api_base,
model=model_name,
verbose=True,
streaming=True,
)
# if llm is ALlamaCpp:
# llm.max_tokens = max_tokens_to_gen
# elif llm is AGPT4All:
# llm.n_predict = max_tokens_to_gen
# el
# if llm is AChatOllama:
# llm.max_tokens = max_tokens_to_gen
# if llm is VLLMAsync:
# llm.max_tokens = max_tokens_to_gen
retriever = await vector_store.getRetriever(os.path.join(file.username, file.filename))
chat_memory = ChatMessageHistory()
for message in messages:
if message.role == Role.HUMAN:
chat_memory.add_user_message(message.content)
elif message.role == Role.AI:
chat_memory.add_ai_message(message.content)
elif message.role == Role.SYSTEM:
chat_memory.add_message(message.content)
elif message.role == Role.FUNCTION:
chat_memory.add_message(message.content)
memory = ConversationSummaryBufferMemory(
llm=llm,
chat_memory=chat_memory,
memory_key="chat_history",
input_key="question",
output_key="answer",
prompt=SUMMARY_PROMPT_TEMPLATE,
return_messages=True,
)
question_generator = LLMChain(
llm=llm,
prompt=CONDENSE_PROMPT_TEMPLATE,
memory=memory,
verbose=True,
output_key="answer",
)
stuff_chain = LLMChain(
llm=llm,
prompt=prompt,
verbose=True,
output_key="answer",
)
doc_chain = StuffDocumentsChain(
llm_chain=stuff_chain,
document_variable_name="context",
document_prompt=DOCUMENT_PROMPT_TEMPLATE,
verbose=True,
output_key="answer",
memory=memory,
)
return ConversationalRetrievalChain(
combine_docs_chain=doc_chain,
memory=memory,
retriever=retriever,
question_generator=question_generator,
return_generated_question=False,
return_source_documents=True,
output_key="answer",
verbose=True,
)
@router.post(
"/chat",
summary="Chatbot",
description="Chatbot AI Service",
)
async def chat(request: ChatRequest):
chain: ConversationalRetrievalChain = await create_chain(
file=request.file,
messages=request.messages[:-1],
model=request.model.id,
max_tokens_to_gen=request.maxTokens,
temperature=request.temperature,
prompt=PromptTemplate.from_template(
f"{B_INST}{B_SYS}{request.prompt.strip()}{E_SYS}{E_INST}"
),
)
# async for token in chain.astream(request.messages[-1].content):
# print(f"token={token}")
return LangchainStreamingResponse(
    chain,
    config={
        "inputs": {
            "question": request.messages[-1].content,
            # "chat_history": [message.content for message in request.messages[:-1]],
        },
        "callbacks": [
            StreamingStdOutCallbackHandler(),
            TokenStreamingCallbackHandler(output_key="answer"),
            SourceDocumentsStreamingCallbackHandler(),
        ],
    },
)
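# Example client (illustrative, not part of this commit): /chat streams server-sent
# events, so any streaming HTTP client works. Host, port and payload values below are
# made up; field names follow the ChatRequest model.
#
# import httpx
# payload = {"id": "demo", "messages": [{"role": "user", "content": "What is RAG?"}]}
# with httpx.stream("POST", "http://localhost:8888/chat", json=payload, timeout=None) as r:
#     for line in r.iter_lines():
#         print(line)
#
# Tokens arrive as "event: completion" frames and retrieved documents as a single
# "event: source_documents" frame.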
app.include_router(router, tags=["chat"])
@app.get("/")
def root():
return {"message": "Chatbot API"}
@app.get("/favicon.ico")
def favicon():
file_name = "favicon.ico"
file_path = os.path.join(app.root_path, "static", file_name)
return FileResponse(
path=file_path,
headers={"Content-Disposition": "attachment; filename=" + file_name},
)
@app.post("/log")
def log_message(log_message: LogMessage):
try:
with open("log.txt", "a") as log_file:
log_file.write(log_message.message + "\n")
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error saving log: {e}")
return {"message": "Log saved successfully"}
@app.get("/models")
def get_models():
# llama7B = AIModel(
# id="llama-2-7b-chat-ggml-q4_0",
# name="llama-2-7b-chat-ggml-q4_0",
# maxLength=2048,
# tokenLimit=2048,
# )
# llama13B = AIModel(
# id="llama-2-13b-chat-ggml-q4_0",
# name="llama-2-13b-chat-ggml-q4_0",
# maxLength=2048,
# tokenLimit=2048,
# )
llama70B = AIModel(
id="llama-2-70b.Q5_K_M",
name="llama-2-70b.Q5_K_M",
maxLength=2048,
tokenLimit=2048,
)
models = AIModels(models=[llama70B])
return models
@app.get("/titles")
def getTitles():
titles = RAGFiles(
titles=[
# RAGFile(
# versionId="d8ad3b1d-c33c-4524-9691-e93967d4d863",
# title="d8ad3b1d-c33c-4524-9691-e93967d4d863",
# state=State.Unavailable,
# ),
RAGFile(
versionId=collection.name,
title=collection.name,
state=State.InProcess
if collection.name in processing_books
else State.Processed,
)
for collection in vector_store.list_collections()
if collection.name != "langchain"
]
)
return titles
processing_books: list[str] = []
processing_books_lock = asyncio.Lock()
logging.basicConfig(level=logging.ERROR)
@app.post("/titleState")
async def getTitleState(request: GetRAGFileStateRequest):
# FastAPI + Pydantic will throw a 422 Unprocessable Entity if the request isn't the right type.
# try:
logging.debug(f"Received getTitleState request: {request}")
titleStateRequest: GetRAGFileStateRequest = request
# except ValidationError as e:
# print(f"Error validating JSON: {e}")
# raise HTTPException(status_code=422, detail=str(e))
# except json.JSONDecodeError as e:
# print(f"Error parsing JSON: {e}")
# raise HTTPException(status_code=422, detail="Invalid JSON format")
# check to see if the book has already been processed.
# return the proper State directly to response.
matchingCollection = next(
(
x
for x in vector_store.list_collections()
if x.name == titleStateRequest.versionRef
),
None,
)
print("Got a Title State request for version " + titleStateRequest.versionRef)
if titleStateRequest.versionRef in processing_books:
return {"message": State.InProcess}
elif matchingCollection is not None:
return {"message": State.Processed}
else:
return {"message": State.Unavailable}
@app.post("/processRAGFile")
async def processRAGFile(
request: str = Form(...),
files: List[UploadFile] = File(...),
):
try:
logging.debug(f"Received processBook request: {request}")
# Parse the JSON string into a ProcessBookRequest object
fileRAGRequest: ProcessRAGFileRequest = parse_obj_as(
ProcessRAGFileRequest, json.loads(request)
)
except ValidationError as e:
print(f"Error validating JSON: {e}")
raise HTTPException(status_code=422, detail=str(e))
except json.JSONDecodeError as e:
print(f"Error parsing JSON: {e}")
raise HTTPException(status_code=422, detail="Invalid JSON format")
try:
print(
f"Processing file {fileRAGRequest.filename} for user {fileRAGRequest.username}."
)
# check to see if the file has already been processed.
# write html to subfolder
print(f"Writing file to path: {fileRAGRequest.username}/{fileRAGRequest.filename}...")
for index, segment in enumerate(files):
    filename = segment.filename if segment.filename else str(index)
    subDir = f"{fileRAGRequest.username}"
    os.makedirs(subDir, exist_ok=True)  # make sure the user's upload folder exists
    with open(os.path.join(subDir, filename), "wb") as htmlFile:
        htmlFile.write(await segment.read())
# write metadata to subfolder
print(f"Writing metadata to subfolder {fileRAGRequest.username}...")
with open(os.path.join(fileRAGRequest.username, "metadata.json"), "w") as metadataFile:
    metaData = {
        "filename": fileRAGRequest.filename,
        "username": fileRAGRequest.username,
        "processDate": datetime.now().isoformat(),
    }
    metadataFile.write(json.dumps(metaData))
vector_store.retrievers[
f"{fileRAGRequest.username}/{fileRAGRequest.filename}"
] = await vector_store.initRetriever(f"{fileRAGRequest.username}/{fileRAGRequest.filename}")
return {
"message": f"File {fileRAGRequest.filename} processed successfully."
}
except Exception as e:
logging.error(f"Error processing book: {e}")
return {"message": f"Error processing book: {e}"}
@app.exception_handler(HTTPException)
async def http_exception_handler(bookRequest: Request, exc: HTTPException):
logging.error(f"HTTPException: {exc.detail}")
return JSONResponse(status_code=exc.status_code, content={"detail": exc.detail})
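# Running the app (illustrative, not part of this commit): assuming this module lives
# at swarms/server/server.py, something like
#     uvicorn swarms.server.server:app --host 0.0.0.0 --port 8888
# would serve the API; the module path and port are deployment choices, not facts from
# the diff.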

@@ -0,0 +1,88 @@
try:
from enum import StrEnum
except ImportError:
from strenum import StrEnum
from pydantic import BaseModel
from swarms.prompts import QA_PROMPT_TEMPLATE_STR as DefaultSystemPrompt
class AIModel(BaseModel):
id: str
name: str
maxLength: int
tokenLimit: int
class AIModels(BaseModel):
models: list[AIModel]
class State(StrEnum):
Unavailable = "Unavailable"
InProcess = "InProcess"
Processed = "Processed"
class RAGFile(BaseModel):
filename: str
title: str
username: str
state: State = State.Unavailable
class RAGFiles(BaseModel):
files: list[RAGFile]
class Role(StrEnum):
SYSTEM = "system"
ASSISTANT = "assistant"
USER = "user"
class Message(BaseModel):
role: Role
content: str
class ChatRequest(BaseModel):
id: str
model: AIModel = AIModel(
id="llama-2-70b.Q5_K_M",
name="llama-2-70b.Q5_K_M",
maxLength=2048,
tokenLimit=2048,
)
messages: list[Message] = [
Message(role=Role.SYSTEM, content="Hello, how may I help you?"),
Message(role=Role.USER, content=""),
]
maxTokens: int = 2048
temperature: float = 0
prompt: str = DefaultSystemPrompt
file: RAGFile = RAGFile(filename="None", title="None", username="None")
class LogMessage(BaseModel):
message: str
class ConversationRequest(BaseModel):
id: str
name: str
title: RAGFile
messages: list[Message]
model: AIModel
prompt: str
temperature: float
folderId: str | None = None
class ProcessRAGFileRequest(BaseModel):
filename: str
username: str
class GetRAGFileStateRequest(BaseModel):
filename: str
username: str
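# Example (illustrative, not part of this commit): a minimal ChatRequest as /chat would
# deserialize it; only `id` has no default, every other field falls back to the defaults above.
#
# example_request = ChatRequest(
#     id="demo",
#     messages=[
#         Message(role=Role.SYSTEM, content="You answer questions about the uploaded file."),
#         Message(role=Role.USER, content="Summarize chapter one."),
#     ],
#     maxTokens=1024,
#     temperature=0.2,
# )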

@@ -0,0 +1,78 @@
from langchain.prompts.prompt import PromptTemplate
B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
QA_CONDENSE_TEMPLATE_STR = (
"Given the following Chat History and a Follow Up Question, "
"rephrase the follow up question to be a new Standalone Question, "
"but make sure the new question is still asking for the same "
"information as the original follow up question. Respond only "
" with the new Standalone Question. \n"
"Chat History: \n"
"{chat_history} \n"
"Follow Up Question: {question} \n"
"Standalone Question:"
)
CONDENSE_PROMPT_TEMPLATE = PromptTemplate.from_template(
    f"{B_INST}{B_SYS}{QA_CONDENSE_TEMPLATE_STR.strip()}{E_SYS}{E_INST}"
)
QA_PROMPT_TEMPLATE_STR = (
"HUMAN: \n You are a helpful AI assistant. "
"Use the following context and chat history to answer the "
"question at the end with a helpful answer. "
"Get straight to the point and always think things through step-by-step before answering. "
"If you don't know the answer, just say 'I don't know'; "
"don't try to make up an answer. \n\n"
"<context>{context}</context>\n"
"<chat_history>{chat_history}</chat_history>\n"
"<question>{question}</question>\n\n"
"AI: Here is the most relevant sentence in the context: \n"
)
QA_PROMPT_TEMPLATE = PromptTemplate.from_template(
f"{B_INST}{B_SYS}{QA_PROMPT_TEMPLATE_STR.strip()}{E_SYS}{E_INST}"
)
DOCUMENT_PROMPT_TEMPLATE = PromptTemplate(
input_variables=["page_content"], template="{page_content}"
)
_STUFF_PROMPT_TEMPLATE_STR = "Summarize the following context: {context}"
STUFF_PROMPT_TEMPLATE = PromptTemplate.from_template(
f"{B_INST}{B_SYS}{_STUFF_PROMPT_TEMPLATE_STR.strip()}{E_SYS}{E_INST}"
)
_SUMMARIZER_SYS_TEMPLATE = (
B_INST
+ B_SYS
+ """Progressively summarize the lines of conversation provided, adding onto the previous summary returning a new summary.
EXAMPLE
Current summary:
The human asks what the AI thinks of artificial intelligence. The AI thinks artificial intelligence is a force for good.
New lines of conversation:
Human: Why do you think artificial intelligence is a force for good?
AI: Because artificial intelligence will help humans reach their full potential.
New summary:
The human asks what the AI thinks of artificial intelligence. The AI thinks artificial intelligence is a force for good because it will help humans reach their full potential.
END OF EXAMPLE"""
+ E_SYS
+ E_INST
)
_SUMMARIZER_INST_TEMPLATE = (
B_INST
+ """Current summary:
{summary}
New lines of conversation:
{new_lines}
New summary:"""
+ E_INST
)
SUMMARY_PROMPT_TEMPLATE = PromptTemplate.from_template(
    template=(_SUMMARIZER_SYS_TEMPLATE + "\n" + _SUMMARIZER_INST_TEMPLATE).strip()
)
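# Example (illustrative, not part of this commit): QA_PROMPT_TEMPLATE expects
# `context`, `chat_history` and `question` variables, e.g.
#
# rendered = QA_PROMPT_TEMPLATE.format(
#     context="The sky appears blue because of Rayleigh scattering.",
#     chat_history="",
#     question="Why is the sky blue?",
# )
# `rendered` is the full [INST]<<SYS>> ... <</SYS>>[/INST]-wrapped prompt string.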

Binary file not shown (new image added, 146 KiB).

@@ -0,0 +1,51 @@
# modified from Lanarky source https://github.com/auxon/lanarky
from typing import Any
import pydantic
from pydantic.fields import FieldInfo
try:
from enum import StrEnum # type: ignore
except ImportError:
from enum import Enum
class StrEnum(str, Enum): ...
PYDANTIC_V2 = pydantic.VERSION.startswith("2.")
def model_dump(model: pydantic.BaseModel, **kwargs) -> dict[str, Any]:
"""Dump a pydantic model to a dictionary.
Args:
model: A pydantic model.
"""
if PYDANTIC_V2:
return model.model_dump(**kwargs)
else:
return model.dict(**kwargs)
def model_dump_json(model: pydantic.BaseModel, **kwargs) -> str:
"""Dump a pydantic model to a JSON string.
Args:
model: A pydantic model.
"""
if PYDANTIC_V2:
return model.model_dump_json(**kwargs)
else:
return model.json(**kwargs)
def model_fields(model: pydantic.BaseModel) -> dict[str, FieldInfo]:
"""Get the fields of a pydantic model.
Args:
model: A pydantic model.
"""
if PYDANTIC_V2:
return model.model_fields
else:
return model.__fields__
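# Example (illustrative, not part of this commit): these helpers smooth over the
# pydantic v1/v2 API split. `_Ping` is a throwaway model defined just for the example.
#
# class _Ping(pydantic.BaseModel):
#     token: str = ""
#
# model_dump(_Ping(token="hi"))       # -> {"token": "hi"}
# model_dump_json(_Ping(token="hi"))  # -> '{"token":"hi"}' (spacing varies by pydantic version)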