diff --git a/swarms/server/callback_handlers.py b/swarms/server/callback_handlers.py new file mode 100644 index 00000000..710b64d7 --- /dev/null +++ b/swarms/server/callback_handlers.py @@ -0,0 +1,448 @@ +# modified from Lanarky sourcecode https://github.com/auxon/lanarky +from typing import Any, Optional + +from fastapi.websockets import WebSocket +from langchain.callbacks.base import AsyncCallbackHandler +from langchain.callbacks.streaming_stdout_final_only import ( + FinalStreamingStdOutCallbackHandler, +) +from langchain.globals import get_llm_cache +from langchain.schema.document import Document +from pydantic import BaseModel +from starlette.types import Message, Send +from sse_starlette.sse import ensure_bytes, ServerSentEvent +from swarms.server.utils import StrEnum, model_dump_json + + +class LangchainEvents(StrEnum): + SOURCE_DOCUMENTS = "source_documents" + + +class BaseCallbackHandler(AsyncCallbackHandler): + """Base callback handler for streaming / async applications.""" + + def __init__(self, **kwargs: dict[str, Any]) -> None: + super().__init__(**kwargs) + self.llm_cache_used = get_llm_cache() is not None + + @property + def always_verbose(self) -> bool: + """Verbose mode is always enabled.""" + return True + + async def on_chat_model_start(self, *args: Any, **kwargs: Any) -> Any: ... + + +class StreamingCallbackHandler(BaseCallbackHandler): + """Callback handler for streaming responses.""" + + def __init__( + self, + *, + send: Send = None, + **kwargs: dict[str, Any], + ) -> None: + """Constructor method. + + Args: + send: The ASGI send callable. + **kwargs: Keyword arguments to pass to the parent constructor. + """ + super().__init__(**kwargs) + + self._send = send + self.streaming = None + + @property + def send(self) -> Send: + return self._send + + @send.setter + def send(self, value: Send) -> None: + """Setter method for send property.""" + if not callable(value): + raise ValueError("value must be a Callable") + self._send = value + + def _construct_message(self, data: str, event: Optional[str] = None) -> Message: + """Constructs message payload. + + Args: + data: The data payload. + event: The event name. + """ + chunk = ServerSentEvent(data=data, event=event) + return { + "type": "http.response.body", + "body": ensure_bytes(chunk, None), + "more_body": True, + } + + +class TokenStreamMode(StrEnum): + TEXT = "text" + JSON = "json" + + +class TokenEventData(BaseModel): + """Event data payload for tokens.""" + + token: str = "" + + +def get_token_data(token: str, mode: TokenStreamMode) -> str: + """Get token data based on mode. + + Args: + token: The token to use. + mode: The stream mode. + """ + if mode not in list(TokenStreamMode): + raise ValueError(f"Invalid stream mode: {mode}") + + if mode == TokenStreamMode.TEXT: + return token + else: + return model_dump_json(TokenEventData(token=token)) + + +class TokenStreamingCallbackHandler(StreamingCallbackHandler): + """Callback handler for streaming tokens.""" + + def __init__( + self, + *, + output_key: str, + mode: TokenStreamMode = TokenStreamMode.JSON, + **kwargs: dict[str, Any], + ) -> None: + """Constructor method. + + Args: + output_key: chain output key. + mode: The stream mode. + **kwargs: Keyword arguments to pass to the parent constructor. 
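# --- usage sketch (illustrative, not part of this patch) ---------------------
# get_token_data() decides what lands in the SSE "data" field: the raw token in
# TEXT mode, a small JSON object in JSON mode. A quick, pydantic-version-agnostic
# check, assuming the TokenStreamMode / get_token_data definitions above:
import json

assert get_token_data("hi", TokenStreamMode.TEXT) == "hi"
assert json.loads(get_token_data("hi", TokenStreamMode.JSON)) == {"token": "hi"}
# _construct_message() then wraps such a payload into an ASGI "http.response.body"
# event whose body bytes are the SSE frame encoded by sse_starlette.
# ------------------------------------------------------------------------------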
+ """ + super().__init__(**kwargs) + + self.output_key = output_key + + if mode not in list(TokenStreamMode): + raise ValueError(f"Invalid stream mode: {mode}") + self.mode = mode + + async def on_chain_start(self, *args: Any, **kwargs: dict[str, Any]) -> None: + """Run when chain starts running.""" + self.streaming = False + + async def on_llm_new_token(self, token: str, **kwargs: dict[str, Any]) -> None: + """Run on new LLM token. Only available when streaming is enabled.""" + if not self.streaming: + self.streaming = True + + if self.llm_cache_used: # cache missed (or was never enabled) if we are here + self.llm_cache_used = False + + message = self._construct_message( + data=get_token_data(token, self.mode), event="completion" + ) + await self.send(message) + + async def on_chain_end( + self, outputs: dict[str, Any], **kwargs: dict[str, Any] + ) -> None: + """Run when chain ends running. + + Final output is streamed only if LLM cache is enabled. + """ + if self.llm_cache_used or not self.streaming: + if self.output_key in outputs: + message = self._construct_message( + data=get_token_data(outputs[self.output_key], self.mode), + event="completion", + ) + await self.send(message) + else: + raise KeyError(f"missing outputs key: {self.output_key}") + + +class SourceDocumentsEventData(BaseModel): + """Event data payload for source documents.""" + + source_documents: list[dict[str, Any]] + + +class SourceDocumentsStreamingCallbackHandler(StreamingCallbackHandler): + """Callback handler for streaming source documents.""" + + async def on_chain_end( + self, outputs: dict[str, Any], **kwargs: dict[str, Any] + ) -> None: + """Run when chain ends running.""" + if "source_documents" in outputs: + if not isinstance(outputs["source_documents"], list): + raise ValueError("source_documents must be a list") + if not isinstance(outputs["source_documents"][0], Document): + raise ValueError("source_documents must be a list of Document") + + # NOTE: langchain is using pydantic_v1 for `Document` + source_documents: list[dict] = [ + document.dict() for document in outputs["source_documents"] + ] + message = self._construct_message( + data=model_dump_json( + SourceDocumentsEventData(source_documents=source_documents) + ), + event=LangchainEvents.SOURCE_DOCUMENTS, + ) + await self.send(message) + + +class FinalTokenStreamingCallbackHandler( + TokenStreamingCallbackHandler, FinalStreamingStdOutCallbackHandler +): + """Callback handler for streaming final answer tokens. + + Useful for streaming responses from Langchain Agents. + """ + + def __init__( + self, + *, + answer_prefix_tokens: Optional[list[str]] = None, + strip_tokens: bool = True, + stream_prefix: bool = False, + **kwargs: dict[str, Any], + ) -> None: + """Constructor method. + + Args: + answer_prefix_tokens: The answer prefix tokens to use. + strip_tokens: Whether to strip tokens. + stream_prefix: Whether to stream the answer prefix. + **kwargs: Keyword arguments to pass to the parent constructor. + """ + super().__init__(output_key=None, **kwargs) + + FinalStreamingStdOutCallbackHandler.__init__( + self, + answer_prefix_tokens=answer_prefix_tokens, + strip_tokens=strip_tokens, + stream_prefix=stream_prefix, + ) + + async def on_llm_start(self, *args: Any, **kwargs: dict[str, Any]) -> None: + """Run when LLM starts running.""" + self.answer_reached = False + self.streaming = False + + async def on_llm_new_token(self, token: str, **kwargs: dict[str, Any]) -> None: + """Run on new LLM token. 
Only available when streaming is enabled.""" + if not self.streaming: + self.streaming = True + + # Remember the last n tokens, where n = len(answer_prefix_tokens) + self.append_to_last_tokens(token) + + # Check if the last n tokens match the answer_prefix_tokens list ... + if self.check_if_answer_reached(): + self.answer_reached = True + if self.stream_prefix: + message = self._construct_message( + data=get_token_data("".join(self.last_tokens), self.mode), + event="completion", + ) + await self.send(message) + + # ... if yes, then print tokens from now on + if self.answer_reached: + message = self._construct_message( + data=get_token_data(token, self.mode), event="completion" + ) + await self.send(message) + + +class WebSocketCallbackHandler(BaseCallbackHandler): + """Callback handler for websocket sessions.""" + + def __init__( + self, + *, + mode: TokenStreamMode = TokenStreamMode.JSON, + websocket: WebSocket = None, + **kwargs: dict[str, Any], + ) -> None: + """Constructor method. + + Args: + mode: The stream mode. + websocket: The websocket to use. + **kwargs: Keyword arguments to pass to the parent constructor. + """ + super().__init__(**kwargs) + + if mode not in list(TokenStreamMode): + raise ValueError(f"Invalid stream mode: {mode}") + self.mode = mode + + self._websocket = websocket + self.streaming = None + + @property + def websocket(self) -> WebSocket: + return self._websocket + + @websocket.setter + def websocket(self, value: WebSocket) -> None: + """Setter method for send property.""" + if not isinstance(value, WebSocket): + raise ValueError("value must be a WebSocket") + self._websocket = value + + def _construct_message(self, data: str, event: Optional[str] = None) -> Message: + """Constructs message payload. + + Args: + data: The data payload. + event: The event name. + """ + return dict(data=data, event=event) + + +class TokenWebSocketCallbackHandler(WebSocketCallbackHandler): + """Callback handler for sending tokens in websocket sessions.""" + + def __init__(self, *, output_key: str, **kwargs: dict[str, Any]) -> None: + """Constructor method. + + Args: + output_key: chain output key. + **kwargs: Keyword arguments to pass to the parent constructor. + """ + super().__init__(**kwargs) + + self.output_key = output_key + + async def on_chain_start(self, *args: Any, **kwargs: dict[str, Any]) -> None: + """Run when chain starts running.""" + self.streaming = False + + async def on_llm_new_token(self, token: str, **kwargs: dict[str, Any]) -> None: + """Run on new LLM token. Only available when streaming is enabled.""" + if not self.streaming: + self.streaming = True + + if self.llm_cache_used: # cache missed (or was never enabled) if we are here + self.llm_cache_used = False + + message = self._construct_message( + data=get_token_data(token, self.mode), event="completion" + ) + await self.websocket.send_json(message) + + async def on_chain_end( + self, outputs: dict[str, Any], **kwargs: dict[str, Any] + ) -> None: + """Run when chain ends running. + + Final output is streamed only if LLM cache is enabled. 
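# --- usage sketch (illustrative, not part of this patch) ---------------------
# One possible way to attach the websocket handlers above to a FastAPI endpoint.
# The app, route path and build_chain() factory are assumptions; any LangChain
# chain whose outputs contain an "answer" key would do.
from fastapi import FastAPI, WebSocket

ws_demo_app = FastAPI()

@ws_demo_app.websocket("/ws/chat")
async def ws_chat(websocket: WebSocket) -> None:
    await websocket.accept()
    handler = TokenWebSocketCallbackHandler(output_key="answer")
    handler.websocket = websocket  # validated by the property setter above
    chain = build_chain()          # hypothetical chain factory
    while True:
        question = await websocket.receive_text()
        await chain.acall({"question": question}, callbacks=[handler])
# ------------------------------------------------------------------------------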
+ """ + if self.llm_cache_used or not self.streaming: + if self.output_key in outputs: + message = self._construct_message( + data=get_token_data(outputs[self.output_key], self.mode), + event="completion", + ) + await self.websocket.send_json(message) + else: + raise KeyError(f"missing outputs key: {self.output_key}") + + +class SourceDocumentsWebSocketCallbackHandler(WebSocketCallbackHandler): + """Callback handler for sending source documents in websocket sessions.""" + + async def on_chain_end( + self, outputs: dict[str, Any], **kwargs: dict[str, Any] + ) -> None: + """Run when chain ends running.""" + if "source_documents" in outputs: + if not isinstance(outputs["source_documents"], list): + raise ValueError("source_documents must be a list") + if not isinstance(outputs["source_documents"][0], Document): + raise ValueError("source_documents must be a list of Document") + + # NOTE: langchain is using pydantic_v1 for `Document` + source_documents: list[dict] = [ + document.dict() for document in outputs["source_documents"] + ] + message = self._construct_message( + data=model_dump_json( + SourceDocumentsEventData(source_documents=source_documents) + ), + event=LangchainEvents.SOURCE_DOCUMENTS, + ) + await self.websocket.send_json(message) + + +class FinalTokenWebSocketCallbackHandler( + TokenWebSocketCallbackHandler, FinalStreamingStdOutCallbackHandler +): + """Callback handler for sending final answer tokens in websocket sessions. + + Useful for streaming responses from Langchain Agents. + """ + + def __init__( + self, + *, + answer_prefix_tokens: Optional[list[str]] = None, + strip_tokens: bool = True, + stream_prefix: bool = False, + **kwargs: dict[str, Any], + ) -> None: + """Constructor method. + + Args: + answer_prefix_tokens: The answer prefix tokens to use. + strip_tokens: Whether to strip tokens. + stream_prefix: Whether to stream the answer prefix. + **kwargs: Keyword arguments to pass to the parent constructor. + """ + super().__init__(output_key=None, **kwargs) + + FinalStreamingStdOutCallbackHandler.__init__( + self, + answer_prefix_tokens=answer_prefix_tokens, + strip_tokens=strip_tokens, + stream_prefix=stream_prefix, + ) + + async def on_llm_start(self, *args, **kwargs) -> None: + """Run when LLM starts running.""" + self.answer_reached = False + self.streaming = False + + async def on_llm_new_token(self, token: str, **kwargs: dict[str, Any]) -> None: + """Run on new LLM token. Only available when streaming is enabled.""" + if not self.streaming: + self.streaming = True + + # Remember the last n tokens, where n = len(answer_prefix_tokens) + self.append_to_last_tokens(token) + + # Check if the last n tokens match the answer_prefix_tokens list ... + if self.check_if_answer_reached(): + self.answer_reached = True + if self.stream_prefix: + message = self._construct_message( + data=get_token_data("".join(self.last_tokens), self.mode), + event="completion", + ) + await self.websocket.send_json(message) + + # ... 
if yes, then print tokens from now on + if self.answer_reached: + message = self._construct_message( + data=get_token_data(token, self.mode), event="completion" + ) + await self.websocket.send_json(message) \ No newline at end of file diff --git a/swarms/server/responses.py b/swarms/server/responses.py new file mode 100644 index 00000000..fab90a40 --- /dev/null +++ b/swarms/server/responses.py @@ -0,0 +1,179 @@ +from typing import Any +import asyncio +from functools import partial +from typing import Any + +from fastapi import status +from langchain.chains.base import Chain +from starlette.types import Send +from fastapi import status +from sse_starlette import ServerSentEvent +from sse_starlette.sse import EventSourceResponse +from starlette.types import Send + +from swarms.server.utils import StrEnum + +from sse_starlette.sse import ensure_bytes + + +class HTTPStatusDetail(StrEnum): + INTERNAL_SERVER_ERROR = "Internal Server Error" + + +class StreamingResponse(EventSourceResponse): + """`Response` class for streaming server-sent events. + + Follows the + [EventSource protocol](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events#interfaces) + """ + + def __init__( + self, + content: Any = iter(()), + *args: Any, + **kwargs: dict[str, Any], + ) -> None: + """Constructor method. + + Args: + content: The content to stream. + """ + super().__init__(content=content, *args, **kwargs) + + async def stream_response(self, send: Send) -> None: + """Streams data chunks to client by iterating over `content`. + + If an exception occurs while iterating over `content`, an + internal server error is sent to the client. + + Args: + send: The send function from the ASGI framework. + """ + await send( + { + "type": "http.response.start", + "status": self.status_code, + "headers": self.raw_headers, + } + ) + + try: + async for data in self.body_iterator: + chunk = ensure_bytes(data, self.sep) + print(f"chunk: {chunk.decode()}") + await send( + {"type": "http.response.body", "body": chunk, "more_body": True} + ) + except Exception as e: + print(f"body iterator error: {e}") + chunk = ServerSentEvent( + data=dict( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=HTTPStatusDetail.INTERNAL_SERVER_ERROR, + ), + event="error", + ) + await send( + { + "type": "http.response.body", + "body": ensure_bytes(chunk, None), + "more_body": True, + } + ) + + await send({"type": "http.response.body", "body": b"", "more_body": False}) + + + +class ChainRunMode(StrEnum): + """Enum for LangChain run modes.""" + + ASYNC = "async" + SYNC = "sync" + + +class LangchainStreamingResponse(StreamingResponse): + """StreamingResponse class for LangChain resources.""" + + def __init__( + self, + chain: Chain, + config: dict[str, Any], + run_mode: ChainRunMode = ChainRunMode.ASYNC, + *args: Any, + **kwargs: dict[str, Any], + ) -> None: + """Constructor method. + + Args: + chain: A LangChain instance. + config: A config dict. + *args: Positional arguments to pass to the parent constructor. + **kwargs: Keyword arguments to pass to the parent constructor. + """ + super().__init__(*args, **kwargs) + + self.chain = chain + self.config = config + + if run_mode not in list(ChainRunMode): + raise ValueError( + f"Invalid run mode '{run_mode}'. Must be one of {list(ChainRunMode)}" + ) + + self.run_mode = run_mode + + async def stream_response(self, send: Send) -> None: + """Stream LangChain outputs. + + If an exception occurs while iterating over the LangChain, an + internal server error is sent to the client. 
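# --- usage sketch (illustrative, not part of this patch) ---------------------
# StreamingResponse (defined above) accepts any async iterable of SSE-compatible
# chunks, so a plain async generator is enough to exercise the transport on its
# own; the app and route below are assumptions made for the sketch.
from fastapi import FastAPI
from sse_starlette.sse import ServerSentEvent

sse_demo_app = FastAPI()

async def fake_tokens():
    for token in ("Hello", ", ", "world"):
        yield ServerSentEvent(data=token, event="completion")

@sse_demo_app.get("/demo-stream")
async def demo_stream():
    return StreamingResponse(content=fake_tokens())
# ------------------------------------------------------------------------------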
+ + Args: + send: The ASGI send callable. + """ + await send( + { + "type": "http.response.start", + "status": self.status_code, + "headers": self.raw_headers, + } + ) + + if "callbacks" in self.config: + for callback in self.config["callbacks"]: + if hasattr(callback, "send"): + callback.send = send + + try: + # TODO: migrate to `.ainvoke` when adding support + # for LCEL + if self.run_mode == ChainRunMode.ASYNC: + outputs = await self.chain.acall(**self.config) + else: + loop = asyncio.get_event_loop() + outputs = await loop.run_in_executor( + None, partial(self.chain, **self.config) + ) + if self.background is not None: + self.background.kwargs.update({"outputs": outputs}) + except Exception as e: + print(f"chain runtime error: {e}") + if self.background is not None: + self.background.kwargs.update({"outputs": {}, "error": e}) + chunk = ServerSentEvent( + data=dict( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=HTTPStatusDetail.INTERNAL_SERVER_ERROR, + ), + event="error", + ) + await send( + { + "type": "http.response.body", + "body": ensure_bytes(chunk, None), + "more_body": True, + } + ) + + await send({"type": "http.response.body", "body": b"", "more_body": False}) \ No newline at end of file diff --git a/swarms/server/responses/server_responses.py b/swarms/server/responses/server_responses.py new file mode 100644 index 00000000..adf9cb88 --- /dev/null +++ b/swarms/server/responses/server_responses.py @@ -0,0 +1,69 @@ +from typing import Any +from fastapi import status +from starlette.types import Send +from sse_starlette.sse import ensure_bytes, EventSourceResponse, ServerSentEvent + +class StreamingResponse(EventSourceResponse): + """`Response` class for streaming server-sent events. + + Follows the + [EventSource protocol](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events#interfaces) + """ + + def __init__( + self, + content: Any = iter(()), + *args: Any, + **kwargs: dict[str, Any], + ) -> None: + """Constructor method. + + Args: + content: The content to stream. + """ + super().__init__(content=content, *args, **kwargs) + + async def stream_response(self, send: Send) -> None: + """Streams data chunks to client by iterating over `content`. + + If an exception occurs while iterating over `content`, an + internal server error is sent to the client. + + Args: + send: The send function from the ASGI framework. 
+ """ + await send( + { + "type": "http.response.start", + "status": self.status_code, + "headers": self.raw_headers, + } + ) + + try: + async for data in self.body_iterator: + chunk = ensure_bytes(data, self.sep) + with open("log.txt", "a") as log_file: + log_file.write(f"chunk: {chunk.decode()}\n") + await send( + {"type": "http.response.body", "body": chunk, "more_body": True} + ) + except Exception as e: + with open("log.txt", "a") as log_file: + log_file.write(f"body iterator error: {e}\n") + chunk = ServerSentEvent( + data=dict( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Internal Server Error", + ), + event="error", + ) + await send( + { + "type": "http.response.body", + "body": ensure_bytes(chunk, None), + "more_body": True, + } + ) + + await send({"type": "http.response.body", "body": b"", "more_body": False}) \ No newline at end of file diff --git a/swarms/server/server.py b/swarms/server/server.py new file mode 100644 index 00000000..3f759a98 --- /dev/null +++ b/swarms/server/server.py @@ -0,0 +1,445 @@ +import asyncio +import json +import logging +import os +from datetime import datetime +from typing import List + +import langchain +from pydantic import ValidationError, parse_obj_as +from swarms.prompts.chat_prompt import Message, Role +from swarms.server.callback_handlers import SourceDocumentsStreamingCallbackHandler, TokenStreamingCallbackHandler +import tiktoken + +# import torch +from dotenv import load_dotenv +from fastapi import FastAPI, File, Form, HTTPException, Request, UploadFile +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import FileResponse, JSONResponse +from fastapi.routing import APIRouter +from fastapi.staticfiles import StaticFiles +from huggingface_hub import login +from langchain.callbacks import StreamingStdOutCallbackHandler +from langchain.memory import ConversationSummaryBufferMemory +from langchain.memory.chat_message_histories.in_memory import ChatMessageHistory +from langchain.prompts.prompt import PromptTemplate +from langchain_community.chat_models import ChatOpenAI +from swarms.server.responses import LangchainStreamingResponse +from langchain.chains.conversational_retrieval.base import ConversationalRetrievalChain +from langchain.chains.llm import LLMChain +from langchain.chains.combine_documents.stuff import StuffDocumentsChain +from swarms.prompts.conversational_RAG import ( + B_INST, + B_SYS, + CONDENSE_PROMPT_TEMPLATE, + DOCUMENT_PROMPT_TEMPLATE, + E_INST, + E_SYS, + QA_PROMPT_TEMPLATE, + SUMMARY_PROMPT_TEMPLATE, +) + +from swarms.server.vector_store import VectorStorage + +from swarms.server.server_models import ( + ChatRequest, + LogMessage, + AIModel, + AIModels, + RAGFile, + RAGFiles, + State, + GetRAGFileStateRequest, + ProcessRAGFileRequest +) + +# Explicitly specify the path to the .env file +dotenv_path = os.path.join(os.path.dirname(__file__), '.env') +load_dotenv(dotenv_path) + +hf_token = os.environ.get("HUGGINFACEHUB_API_KEY") # Get the Huggingface API Token +uploads = os.environ.get("UPLOADS") # Directory where user uploads files to be parsed for RAG +model_dir = os.environ.get("MODEL_DIR") + +# hugginface.co model (eg. meta-llama/Llama-2-70b-hf) +model_name = os.environ.get("MODEL_NAME") + +# Set OpenAI's API key to 'EMPTY' and API base URL to use vLLM's API server, or set them to OpenAI API key and base URL. 
+openai_api_key = os.environ.get("OPENAI_API_KEY") or "EMPTY" +openai_api_base = os.environ.get("OPENAI_API_BASE") or "http://localhost:8000/v1" + +env_vars = [ + hf_token, + uploads, + model_dir, + model_name, + openai_api_key, + openai_api_base, +] +missing_vars = [var for var in env_vars if not var] + +if missing_vars: + print( + f"Error: The following environment variables are not set: {', '.join(missing_vars)}" + ) + exit(1) + +useMetal = os.environ.get("USE_METAL", "False") == "True" + +print(f"Uploads={uploads}") +print(f"MODEL_DIR={model_dir}") +print(f"MODEL_NAME={model_name}") +print(f"USE_METAL={useMetal}") +print(f"OPENAI_API_KEY={openai_api_key}") +print(f"OPENAI_API_BASE={openai_api_base}") + +# update tiktoken to include the model name (avoids warning message) +tiktoken.model.MODEL_TO_ENCODING.update( + { + model_name: "cl100k_base", + } +) + +print("Logging in to huggingface.co...") +login(token=hf_token) # login to huggingface.co + +# langchain.debug = True +langchain.verbose = True + +from contextlib import asynccontextmanager + +@asynccontextmanager +async def lifespan(app: FastAPI): + asyncio.create_task(vector_store.initRetrievers()) + yield + +app = FastAPI(title="Chatbot", lifespan=lifespan) +router = APIRouter() + +current_dir = os.path.dirname(__file__) +print("current_dir: " + current_dir) +static_dir = os.path.join(current_dir, "static") +print("static_dir: " + static_dir) +app.mount(static_dir, StaticFiles(directory=static_dir), name="static") + +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["GET", "POST"], + allow_headers=["*"], +) + + +# Create ./uploads folder if it doesn't exist +uploads = uploads or os.path.join(os.getcwd(), "uploads") +if not os.path.exists(uploads): + os.makedirs(uploads) + +# Initialize the vector store +vector_store = VectorStorage(directory=uploads) + + +async def create_chain( + messages: list[Message], + model=model_dir, + max_tokens_to_gen=2048, + temperature=0.5, + prompt: PromptTemplate = QA_PROMPT_TEMPLATE, + file: RAGFile | None = None, + key: str | None = None, +): + print( + f"Creating chain with key={key}, model={model}, max_tokens={max_tokens_to_gen}, temperature={temperature}, prompt={prompt}, file={file.title}" + ) + + llm = ChatOpenAI( + api_key=openai_api_key, + base_url=openai_api_base, + model=model_name, + verbose=True, + streaming=True, + ) + + # if llm is ALlamaCpp: + # llm.max_tokens = max_tokens_to_gen + # elif llm is AGPT4All: + # llm.n_predict = max_tokens_to_gen + # el + # if llm is AChatOllama: + # llm.max_tokens = max_tokens_to_gen + # if llm is VLLMAsync: + # llm.max_tokens = max_tokens_to_gen + + retriever = await vector_store.getRetriever(os.path.join(file.username, file.filename)) + + chat_memory = ChatMessageHistory() + + for message in messages: + if message.role == Role.HUMAN: + chat_memory.add_user_message(message.content) + elif message.role == Role.AI: + chat_memory.add_ai_message(message.content) + elif message.role == Role.SYSTEM: + chat_memory.add_message(message.content) + elif message.role == Role.FUNCTION: + chat_memory.add_message(message.content) + + memory = ConversationSummaryBufferMemory( + llm=llm, + chat_memory=chat_memory, + memory_key="chat_history", + input_key="question", + output_key="answer", + prompt=SUMMARY_PROMPT_TEMPLATE, + return_messages=True, + ) + + question_generator = LLMChain( + llm=llm, + prompt=CONDENSE_PROMPT_TEMPLATE, + memory=memory, + verbose=True, + output_key="answer", + ) + + stuff_chain = LLMChain( + 
llm=llm, + prompt=prompt, + verbose=True, + output_key="answer", + ) + + doc_chain = StuffDocumentsChain( + llm_chain=stuff_chain, + document_variable_name="context", + document_prompt=DOCUMENT_PROMPT_TEMPLATE, + verbose=True, + output_key="answer", + memory=memory, + ) + + return ConversationalRetrievalChain( + combine_docs_chain=doc_chain, + memory=memory, + retriever=retriever, + question_generator=question_generator, + return_generated_question=False, + return_source_documents=True, + output_key="answer", + verbose=True, + ) + + +router = APIRouter() + +@router.post( + "/chat", + summary="Chatbot", + description="Chatbot AI Service", +) +async def chat(request: ChatRequest): + chain: ConversationalRetrievalChain = await create_chain( + file=request.file, + messages=request.messages[:-1], + model=request.model.id, + max_tokens_to_gen=request.maxTokens, + temperature=request.temperature, + prompt=PromptTemplate.from_template( + f"{B_INST}{B_SYS}{request.prompt.strip()}{E_SYS}{E_INST}" + ), + ) + + # async for token in chain.astream(request.messages[-1].content): + # print(f"token={token}") + + json_string = json.dumps( + { + "question": request.messages[-1].content, + # "chat_history": [message.content for message in request.messages[:-1]], + } + ) + return LangchainStreamingResponse( + chain, + config={ + "inputs": json_string, + "callbacks": [ + StreamingStdOutCallbackHandler(), + TokenStreamingCallbackHandler(output_key="answer"), + SourceDocumentsStreamingCallbackHandler(), + ], + }, + ) + + +app.include_router(router, tags=["chat"]) + + +@app.get("/") +def root(): + return {"message": "Chatbot API"} + + +@app.get("/favicon.ico") +def favicon(): + file_name = "favicon.ico" + file_path = os.path.join(app.root_path, "static", file_name) + return FileResponse( + path=file_path, + headers={"Content-Disposition": "attachment; filename=" + file_name}, + ) + + +@app.post("/log") +def log_message(log_message: LogMessage): + try: + with open("log.txt", "a") as log_file: + log_file.write(log_message.message + "\n") + except Exception as e: + raise HTTPException(status_code=500, detail=f"Error saving log: {e}") + return {"message": "Log saved successfully"} + + +@app.get("/models") +def get_models(): + # llama7B = AIModel( + # id="llama-2-7b-chat-ggml-q4_0", + # name="llama-2-7b-chat-ggml-q4_0", + # maxLength=2048, + # tokenLimit=2048, + # ) + # llama13B = AIModel( + # id="llama-2-13b-chat-ggml-q4_0", + # name="llama-2-13b-chat-ggml-q4_0", + # maxLength=2048, + # tokenLimit=2048, + # ) + llama70B = AIModel( + id="llama-2-70b.Q5_K_M", + name="llama-2-70b.Q5_K_M", + maxLength=2048, + tokenLimit=2048, + ) + models = AIModels(models=[llama70B]) + return models + + +@app.get("/titles") +def getTitles(): + titles = RAGFiles( + titles=[ + # RAGFile( + # versionId="d8ad3b1d-c33c-4524-9691-e93967d4d863", + # title="d8ad3b1d-c33c-4524-9691-e93967d4d863", + # state=State.Unavailable, + # ), + RAGFile( + versionId=collection.name, + title=collection.name, + state=State.InProcess + if collection.name in processing_books + else State.Processed, + ) + for collection in vector_store.list_collections() + if collection.name != "langchain" + ] + ) + return titles + + +processing_books: list[str] = [] +processing_books_lock = asyncio.Lock() + +logging.basicConfig(level=logging.ERROR) + + +@app.post("/titleState") +async def getTitleState(request: GetRAGFileStateRequest): + # FastAPI + Pydantic will throw a 422 Unprocessable Entity if the request isn't the right type. 
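# --- client sketch (illustrative, not part of this patch) --------------------
# Consuming the /chat SSE stream defined above with httpx. The payload mirrors
# ChatRequest from swarms/server/server_models.py; the host, port and file
# values are assumptions.
import httpx

async def consume_chat() -> None:
    payload = {
        "id": "demo-conversation",
        "messages": [
            {"role": "system", "content": "Hello, how may I help you?"},
            {"role": "user", "content": "Summarize the uploaded document."},
        ],
        "maxTokens": 1024,
        "temperature": 0,
        "prompt": "You are a helpful AI assistant.",
        "file": {"filename": "doc.html", "title": "doc.html", "username": "alice"},
    }
    async with httpx.AsyncClient(timeout=None) as client:
        async with client.stream(
            "POST", "http://localhost:8888/chat", json=payload
        ) as response:
            async for line in response.aiter_lines():
                if line.startswith("data:"):
                    print(line[len("data:"):].strip())  # token or source_documents JSON
# ------------------------------------------------------------------------------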
+ # try: + logging.debug(f"Received getTitleState request: {request}") + titleStateRequest: GetRAGFileStateRequest = request + # except ValidationError as e: + # print(f"Error validating JSON: {e}") + # raise HTTPException(status_code=422, detail=str(e)) + # except json.JSONDecodeError as e: + # print(f"Error parsing JSON: {e}") + # raise HTTPException(status_code=422, detail="Invalid JSON format") + # check to see if the book has already been processed. + # return the proper State directly to response. + matchingCollection = next( + ( + x + for x in vector_store.list_collections() + if x.name == titleStateRequest.versionRef + ), + None, + ) + print("Got a Title State request for version " + titleStateRequest.versionRef) + if titleStateRequest.versionRef in processing_books: + return {"message": State.InProcess} + elif matchingCollection is not None: + return {"message": State.Processed} + else: + return {"message": State.Unavailable} + + +@app.post("/processRAGFile") +async def processRAGFile( + request: str = Form(...), + files: List[UploadFile] = File(...), +): + try: + logging.debug(f"Received processBook request: {request}") + # Parse the JSON string into a ProcessBookRequest object + fileRAGRequest: ProcessRAGFileRequest = parse_obj_as( + ProcessRAGFileRequest, json.loads(request) + ) + except ValidationError as e: + print(f"Error validating JSON: {e}") + raise HTTPException(status_code=422, detail=str(e)) + except json.JSONDecodeError as e: + print(f"Error parsing JSON: {e}") + raise HTTPException(status_code=422, detail="Invalid JSON format") + + try: + print( + f"Processing file {fileRAGRequest.filename} for user {fileRAGRequest.username}." + ) + # check to see if the file has already been processed. + # write html to subfolder + print(f"Writing file to path: {fileRAGRequest.username}/{fileRAGRequest.filename}...") + + for index, segment in enumerate(files): + filename = segment.filename if segment.filename else str(index) + subDir = f"{fileRAGRequest.username}" + with open(os.path.join(subDir, filename), "wb") as htmlFile: + htmlFile.write(await segment.read()) + + # write metadata to subfolder + print(f"Writing metadata to subfolder {fileRAGRequest.username}...") + with open(os.path.join({fileRAGRequest.username}, "metadata.json"), "w") as metadataFile: + metaData = { + "filename": fileRAGRequest.filename, + "username": fileRAGRequest.username, + "processDate": datetime.now().isoformat(), + } + metadataFile.write(json.dumps(metaData)) + + vector_store.retrievers[ + f"{fileRAGRequest.username}/{fileRAGRequest.filename}" + ] = await vector_store.initRetriever(f"{fileRAGRequest.username}/{fileRAGRequest.filename}") + + return { + "message": f"File {fileRAGRequest.filename} processed successfully." 
+ } + except Exception as e: + logging.error(f"Error processing book: {e}") + return {"message": f"Error processing book: {e}"} + + +@app.exception_handler(HTTPException) +async def http_exception_handler(bookRequest: Request, exc: HTTPException): + logging.error(f"HTTPException: {exc.detail}") + return JSONResponse(status_code=exc.status_code, content={"detail": exc.detail}) + diff --git a/swarms/server/server_models.py b/swarms/server/server_models.py new file mode 100644 index 00000000..1c4491b7 --- /dev/null +++ b/swarms/server/server_models.py @@ -0,0 +1,88 @@ +try: + from enum import StrEnum +except ImportError: + from strenum import StrEnum + +from pydantic import BaseModel +from swarms.prompts import QA_PROMPT_TEMPLATE_STR as DefaultSystemPrompt + +class AIModel(BaseModel): + id: str + name: str + maxLength: int + tokenLimit: int + + +class AIModels(BaseModel): + models: list[AIModel] + + +class State(StrEnum): + Unavailable = "Unavailable" + InProcess = "InProcess" + Processed = "Processed" + + +class RAGFile(BaseModel): + filename: str + title: str + username: str + state: State = State.Unavailable + + +class RAGFiles(BaseModel): + files: list[RAGFile] + + +class Role(StrEnum): + SYSTEM = "system" + ASSISTANT = "assistant" + USER = "user" + + +class Message(BaseModel): + role: Role + content: str + + +class ChatRequest(BaseModel): + id: str + model: AIModel = AIModel( + id="llama-2-70b.Q5_K_M", + name="llama-2-70b.Q5_K_M", + maxLength=2048, + tokenLimit=2048, + ) + messages: list[Message] = [ + Message(role=Role.SYSTEM, content="Hello, how may I help you?"), + Message(role=Role.USER, content=""), + ] + maxTokens: int = 2048 + temperature: float = 0 + prompt: str = DefaultSystemPrompt + file: RAGFile = RAGFile(filename="None", title="None", username="None") + + +class LogMessage(BaseModel): + message: str + + +class ConversationRequest(BaseModel): + id: str + name: str + title: RAGFile + messages: list[Message] + model: AIModel + prompt: str + temperature: float + folderId: str | None = None + + +class ProcessRAGFileRequest(BaseModel): + filename: str + username: str + + +class GetRAGFileStateRequest(BaseModel): + filename: str + username: str \ No newline at end of file diff --git a/swarms/server/server_prompts.py b/swarms/server/server_prompts.py new file mode 100644 index 00000000..00666cb7 --- /dev/null +++ b/swarms/server/server_prompts.py @@ -0,0 +1,78 @@ +from langchain.prompts.prompt import PromptTemplate + +B_INST, E_INST = "[INST]", "[/INST]" +B_SYS, E_SYS = "<>\n", "\n<>\n\n" + +QA_CONDENSE_TEMPLATE_STR = ( + "Given the following Chat History and a Follow Up Question, " + "rephrase the follow up question to be a new Standalone Question, " + "but make sure the new question is still asking for the same " + "information as the original follow up question. Respond only " + " with the new Standalone Question. \n" + "Chat History: \n" + "{chat_history} \n" + "Follow Up Question: {question} \n" + "Standalone Question:" +) + +CONDENSE_TEMPLATE = PromptTemplate.from_template( + f"{B_INST}{B_SYS}{QA_CONDENSE_TEMPLATE_STR.strip()}{E_SYS}{E_INST}" +) + +QA_PROMPT_TEMPLATE_STR = ( + "HUMAN: \n You are a helpful AI assistant. " + "Use the following context and chat history to answer the " + "question at the end with a helpful answer. " + "Get straight to the point and always think things through step-by-step before answering. " + "If you don't know the answer, just say 'I don't know'; " + "don't try to make up an answer. 
\n\n" + "{context}\n" + "{chat_history}\n" + "{question}\n\n" + "AI: Here is the most relevant sentence in the context: \n" +) + +QA_PROMPT_TEMPLATE = PromptTemplate.from_template( + f"{B_INST}{B_SYS}{QA_PROMPT_TEMPLATE_STR.strip()}{E_SYS}{E_INST}" +) + +DOCUMENT_PROMPT_TEMPLATE = PromptTemplate( + input_variables=["page_content"], template="{page_content}" +) + +_STUFF_PROMPT_TEMPLATE_STR = "Summarize the following context: {context}" + +STUFF_PROMPT_TEMPLATE = PromptTemplate.from_template( + f"{B_INST}{B_SYS}{_STUFF_PROMPT_TEMPLATE_STR.strip()}{E_SYS}{E_INST}" +) + +_SUMMARIZER_SYS_TEMPLATE = ( + B_INST + + B_SYS + + """Progressively summarize the lines of conversation provided, adding onto the previous summary returning a new summary. + EXAMPLE + Current summary: + The human asks what the AI thinks of artificial intelligence. The AI thinks artificial intelligence is a force for good. + New lines of conversation: + Human: Why do you think artificial intelligence is a force for good? + AI: Because artificial intelligence will help humans reach their full potential. + New summary: + The human asks what the AI thinks of artificial intelligence. The AI thinks artificial intelligence is a force for good because it will help humans reach their full potential. + END OF EXAMPLE""" + + E_SYS + + E_INST +) + +_SUMMARIZER_INST_TEMPLATE = ( + B_INST + + """Current summary: + {summary} + New lines of conversation: + {new_lines} + New summary:""" + + E_INST +) + +SUMMARY_PROMPT = PromptTemplate.from_template( + template=(_SUMMARIZER_SYS_TEMPLATE + "\n" + _SUMMARIZER_INST_TEMPLATE).strip() +) \ No newline at end of file diff --git a/swarms/server/static/favicon.ico b/swarms/server/static/favicon.ico new file mode 100644 index 00000000..988bba1c Binary files /dev/null and b/swarms/server/static/favicon.ico differ diff --git a/swarms/server/utils.py b/swarms/server/utils.py new file mode 100644 index 00000000..89e3d7c8 --- /dev/null +++ b/swarms/server/utils.py @@ -0,0 +1,51 @@ +# modified from Lanarky source https://github.com/auxon/lanarky +from typing import Any + +import pydantic +from pydantic.fields import FieldInfo + +try: + from enum import StrEnum # type: ignore +except ImportError: + from enum import Enum + + class StrEnum(str, Enum): ... + + +PYDANTIC_V2 = pydantic.VERSION.startswith("2.") + + +def model_dump(model: pydantic.BaseModel, **kwargs) -> dict[str, Any]: + """Dump a pydantic model to a dictionary. + + Args: + model: A pydantic model. + """ + if PYDANTIC_V2: + return model.model_dump(**kwargs) + else: + return model.dict(**kwargs) + + +def model_dump_json(model: pydantic.BaseModel, **kwargs) -> str: + """Dump a pydantic model to a JSON string. + + Args: + model: A pydantic model. + """ + if PYDANTIC_V2: + return model.model_dump_json(**kwargs) + else: + return model.json(**kwargs) + + +def model_fields(model: pydantic.BaseModel) -> dict[str, FieldInfo]: + """Get the fields of a pydantic model. + + Args: + model: A pydantic model. + """ + if PYDANTIC_V2: + return model.model_fields + else: + return model.__fields__ \ No newline at end of file