diff --git a/swarms/server/callback_handlers.py b/swarms/server/callback_handlers.py
new file mode 100644
index 00000000..710b64d7
--- /dev/null
+++ b/swarms/server/callback_handlers.py
@@ -0,0 +1,448 @@
+# modified from Lanarky source code https://github.com/auxon/lanarky
+from typing import Any, Optional
+
+from fastapi.websockets import WebSocket
+from langchain.callbacks.base import AsyncCallbackHandler
+from langchain.callbacks.streaming_stdout_final_only import (
+ FinalStreamingStdOutCallbackHandler,
+)
+from langchain.globals import get_llm_cache
+from langchain.schema.document import Document
+from pydantic import BaseModel
+from starlette.types import Message, Send
+from sse_starlette.sse import ensure_bytes, ServerSentEvent
+from swarms.server.utils import StrEnum, model_dump_json
+
+
+class LangchainEvents(StrEnum):
+ SOURCE_DOCUMENTS = "source_documents"
+
+
+class BaseCallbackHandler(AsyncCallbackHandler):
+ """Base callback handler for streaming / async applications."""
+
+ def __init__(self, **kwargs: dict[str, Any]) -> None:
+ super().__init__(**kwargs)
+ self.llm_cache_used = get_llm_cache() is not None
+
+ @property
+ def always_verbose(self) -> bool:
+ """Verbose mode is always enabled."""
+ return True
+
+ async def on_chat_model_start(self, *args: Any, **kwargs: Any) -> Any: ...
+
+
+class StreamingCallbackHandler(BaseCallbackHandler):
+ """Callback handler for streaming responses."""
+
+ def __init__(
+ self,
+ *,
+        send: Optional[Send] = None,
+ **kwargs: dict[str, Any],
+ ) -> None:
+ """Constructor method.
+
+ Args:
+ send: The ASGI send callable.
+ **kwargs: Keyword arguments to pass to the parent constructor.
+ """
+ super().__init__(**kwargs)
+
+ self._send = send
+ self.streaming = None
+
+ @property
+ def send(self) -> Send:
+ return self._send
+
+ @send.setter
+ def send(self, value: Send) -> None:
+ """Setter method for send property."""
+ if not callable(value):
+ raise ValueError("value must be a Callable")
+ self._send = value
+
+ def _construct_message(self, data: str, event: Optional[str] = None) -> Message:
+ """Constructs message payload.
+
+ Args:
+ data: The data payload.
+ event: The event name.
+ """
+ chunk = ServerSentEvent(data=data, event=event)
+ return {
+ "type": "http.response.body",
+ "body": ensure_bytes(chunk, None),
+ "more_body": True,
+ }
+
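+# Wire-format sketch (an assumption based on sse_starlette defaults, not verified against
+# a specific version): each message body constructed above is a serialized SSE frame such as
+#   b'event: completion\r\ndata: {"token": "..."}\r\n\r\n'
+# sent with more_body=True so the HTTP response stays open between chunks.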
+
+class TokenStreamMode(StrEnum):
+ TEXT = "text"
+ JSON = "json"
+
+
+class TokenEventData(BaseModel):
+ """Event data payload for tokens."""
+
+ token: str = ""
+
+
+def get_token_data(token: str, mode: TokenStreamMode) -> str:
+ """Get token data based on mode.
+
+ Args:
+ token: The token to use.
+ mode: The stream mode.
+ """
+ if mode not in list(TokenStreamMode):
+ raise ValueError(f"Invalid stream mode: {mode}")
+
+ if mode == TokenStreamMode.TEXT:
+ return token
+ else:
+ return model_dump_json(TokenEventData(token=token))
+
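+# Illustrative example (pydantic v2 output shown; v1 spacing may differ slightly):
+#   get_token_data("Hi", TokenStreamMode.TEXT) -> "Hi"
+#   get_token_data("Hi", TokenStreamMode.JSON) -> '{"token":"Hi"}'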
+
+class TokenStreamingCallbackHandler(StreamingCallbackHandler):
+ """Callback handler for streaming tokens."""
+
+ def __init__(
+ self,
+ *,
+ output_key: str,
+ mode: TokenStreamMode = TokenStreamMode.JSON,
+ **kwargs: dict[str, Any],
+ ) -> None:
+ """Constructor method.
+
+ Args:
+ output_key: chain output key.
+ mode: The stream mode.
+ **kwargs: Keyword arguments to pass to the parent constructor.
+ """
+ super().__init__(**kwargs)
+
+ self.output_key = output_key
+
+ if mode not in list(TokenStreamMode):
+ raise ValueError(f"Invalid stream mode: {mode}")
+ self.mode = mode
+
+ async def on_chain_start(self, *args: Any, **kwargs: dict[str, Any]) -> None:
+ """Run when chain starts running."""
+ self.streaming = False
+
+ async def on_llm_new_token(self, token: str, **kwargs: dict[str, Any]) -> None:
+ """Run on new LLM token. Only available when streaming is enabled."""
+ if not self.streaming:
+ self.streaming = True
+
+ if self.llm_cache_used: # cache missed (or was never enabled) if we are here
+ self.llm_cache_used = False
+
+ message = self._construct_message(
+ data=get_token_data(token, self.mode), event="completion"
+ )
+ await self.send(message)
+
+ async def on_chain_end(
+ self, outputs: dict[str, Any], **kwargs: dict[str, Any]
+ ) -> None:
+ """Run when chain ends running.
+
+ Final output is streamed only if LLM cache is enabled.
+ """
+ if self.llm_cache_used or not self.streaming:
+ if self.output_key in outputs:
+ message = self._construct_message(
+ data=get_token_data(outputs[self.output_key], self.mode),
+ event="completion",
+ )
+ await self.send(message)
+ else:
+ raise KeyError(f"missing outputs key: {self.output_key}")
+
+
+class SourceDocumentsEventData(BaseModel):
+ """Event data payload for source documents."""
+
+ source_documents: list[dict[str, Any]]
+
+
+class SourceDocumentsStreamingCallbackHandler(StreamingCallbackHandler):
+ """Callback handler for streaming source documents."""
+
+ async def on_chain_end(
+ self, outputs: dict[str, Any], **kwargs: dict[str, Any]
+ ) -> None:
+ """Run when chain ends running."""
+ if "source_documents" in outputs:
+ if not isinstance(outputs["source_documents"], list):
+ raise ValueError("source_documents must be a list")
+ if not isinstance(outputs["source_documents"][0], Document):
+ raise ValueError("source_documents must be a list of Document")
+
+ # NOTE: langchain is using pydantic_v1 for `Document`
+ source_documents: list[dict] = [
+ document.dict() for document in outputs["source_documents"]
+ ]
+ message = self._construct_message(
+ data=model_dump_json(
+ SourceDocumentsEventData(source_documents=source_documents)
+ ),
+ event=LangchainEvents.SOURCE_DOCUMENTS,
+ )
+ await self.send(message)
+
+
+class FinalTokenStreamingCallbackHandler(
+ TokenStreamingCallbackHandler, FinalStreamingStdOutCallbackHandler
+):
+ """Callback handler for streaming final answer tokens.
+
+ Useful for streaming responses from Langchain Agents.
+ """
+
+ def __init__(
+ self,
+ *,
+ answer_prefix_tokens: Optional[list[str]] = None,
+ strip_tokens: bool = True,
+ stream_prefix: bool = False,
+ **kwargs: dict[str, Any],
+ ) -> None:
+ """Constructor method.
+
+ Args:
+ answer_prefix_tokens: The answer prefix tokens to use.
+ strip_tokens: Whether to strip tokens.
+ stream_prefix: Whether to stream the answer prefix.
+ **kwargs: Keyword arguments to pass to the parent constructor.
+ """
+ super().__init__(output_key=None, **kwargs)
+
+ FinalStreamingStdOutCallbackHandler.__init__(
+ self,
+ answer_prefix_tokens=answer_prefix_tokens,
+ strip_tokens=strip_tokens,
+ stream_prefix=stream_prefix,
+ )
+
+ async def on_llm_start(self, *args: Any, **kwargs: dict[str, Any]) -> None:
+ """Run when LLM starts running."""
+ self.answer_reached = False
+ self.streaming = False
+
+ async def on_llm_new_token(self, token: str, **kwargs: dict[str, Any]) -> None:
+ """Run on new LLM token. Only available when streaming is enabled."""
+ if not self.streaming:
+ self.streaming = True
+
+ # Remember the last n tokens, where n = len(answer_prefix_tokens)
+ self.append_to_last_tokens(token)
+
+ # Check if the last n tokens match the answer_prefix_tokens list ...
+ if self.check_if_answer_reached():
+ self.answer_reached = True
+ if self.stream_prefix:
+ message = self._construct_message(
+ data=get_token_data("".join(self.last_tokens), self.mode),
+ event="completion",
+ )
+ await self.send(message)
+
+ # ... if yes, then print tokens from now on
+ if self.answer_reached:
+ message = self._construct_message(
+ data=get_token_data(token, self.mode), event="completion"
+ )
+ await self.send(message)
+
+
+class WebSocketCallbackHandler(BaseCallbackHandler):
+ """Callback handler for websocket sessions."""
+
+ def __init__(
+ self,
+ *,
+ mode: TokenStreamMode = TokenStreamMode.JSON,
+        websocket: Optional[WebSocket] = None,
+ **kwargs: dict[str, Any],
+ ) -> None:
+ """Constructor method.
+
+ Args:
+ mode: The stream mode.
+ websocket: The websocket to use.
+ **kwargs: Keyword arguments to pass to the parent constructor.
+ """
+ super().__init__(**kwargs)
+
+ if mode not in list(TokenStreamMode):
+ raise ValueError(f"Invalid stream mode: {mode}")
+ self.mode = mode
+
+ self._websocket = websocket
+ self.streaming = None
+
+ @property
+ def websocket(self) -> WebSocket:
+ return self._websocket
+
+ @websocket.setter
+ def websocket(self, value: WebSocket) -> None:
+ """Setter method for send property."""
+ if not isinstance(value, WebSocket):
+ raise ValueError("value must be a WebSocket")
+ self._websocket = value
+
+ def _construct_message(self, data: str, event: Optional[str] = None) -> Message:
+ """Constructs message payload.
+
+ Args:
+ data: The data payload.
+ event: The event name.
+ """
+ return dict(data=data, event=event)
+
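+# Payload sketch: unlike the SSE handlers above, the websocket handlers emit plain JSON
+# objects via `websocket.send_json`. Illustrative frame for a token in JSON mode:
+#   {"data": "{\"token\": \"Hi\"}", "event": "completion"}
+# (the "data" field carries the serialized TokenEventData payload).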
+
+class TokenWebSocketCallbackHandler(WebSocketCallbackHandler):
+ """Callback handler for sending tokens in websocket sessions."""
+
+ def __init__(self, *, output_key: str, **kwargs: dict[str, Any]) -> None:
+ """Constructor method.
+
+ Args:
+ output_key: chain output key.
+ **kwargs: Keyword arguments to pass to the parent constructor.
+ """
+ super().__init__(**kwargs)
+
+ self.output_key = output_key
+
+ async def on_chain_start(self, *args: Any, **kwargs: dict[str, Any]) -> None:
+ """Run when chain starts running."""
+ self.streaming = False
+
+ async def on_llm_new_token(self, token: str, **kwargs: dict[str, Any]) -> None:
+ """Run on new LLM token. Only available when streaming is enabled."""
+ if not self.streaming:
+ self.streaming = True
+
+ if self.llm_cache_used: # cache missed (or was never enabled) if we are here
+ self.llm_cache_used = False
+
+ message = self._construct_message(
+ data=get_token_data(token, self.mode), event="completion"
+ )
+ await self.websocket.send_json(message)
+
+ async def on_chain_end(
+ self, outputs: dict[str, Any], **kwargs: dict[str, Any]
+ ) -> None:
+ """Run when chain ends running.
+
+ Final output is streamed only if LLM cache is enabled.
+ """
+ if self.llm_cache_used or not self.streaming:
+ if self.output_key in outputs:
+ message = self._construct_message(
+ data=get_token_data(outputs[self.output_key], self.mode),
+ event="completion",
+ )
+ await self.websocket.send_json(message)
+ else:
+ raise KeyError(f"missing outputs key: {self.output_key}")
+
+
+class SourceDocumentsWebSocketCallbackHandler(WebSocketCallbackHandler):
+ """Callback handler for sending source documents in websocket sessions."""
+
+ async def on_chain_end(
+ self, outputs: dict[str, Any], **kwargs: dict[str, Any]
+ ) -> None:
+ """Run when chain ends running."""
+ if "source_documents" in outputs:
+ if not isinstance(outputs["source_documents"], list):
+ raise ValueError("source_documents must be a list")
+ if not isinstance(outputs["source_documents"][0], Document):
+ raise ValueError("source_documents must be a list of Document")
+
+ # NOTE: langchain is using pydantic_v1 for `Document`
+ source_documents: list[dict] = [
+ document.dict() for document in outputs["source_documents"]
+ ]
+ message = self._construct_message(
+ data=model_dump_json(
+ SourceDocumentsEventData(source_documents=source_documents)
+ ),
+ event=LangchainEvents.SOURCE_DOCUMENTS,
+ )
+ await self.websocket.send_json(message)
+
+
+class FinalTokenWebSocketCallbackHandler(
+ TokenWebSocketCallbackHandler, FinalStreamingStdOutCallbackHandler
+):
+ """Callback handler for sending final answer tokens in websocket sessions.
+
+ Useful for streaming responses from Langchain Agents.
+ """
+
+ def __init__(
+ self,
+ *,
+ answer_prefix_tokens: Optional[list[str]] = None,
+ strip_tokens: bool = True,
+ stream_prefix: bool = False,
+ **kwargs: dict[str, Any],
+ ) -> None:
+ """Constructor method.
+
+ Args:
+ answer_prefix_tokens: The answer prefix tokens to use.
+ strip_tokens: Whether to strip tokens.
+ stream_prefix: Whether to stream the answer prefix.
+ **kwargs: Keyword arguments to pass to the parent constructor.
+ """
+ super().__init__(output_key=None, **kwargs)
+
+ FinalStreamingStdOutCallbackHandler.__init__(
+ self,
+ answer_prefix_tokens=answer_prefix_tokens,
+ strip_tokens=strip_tokens,
+ stream_prefix=stream_prefix,
+ )
+
+    async def on_llm_start(self, *args: Any, **kwargs: dict[str, Any]) -> None:
+ """Run when LLM starts running."""
+ self.answer_reached = False
+ self.streaming = False
+
+ async def on_llm_new_token(self, token: str, **kwargs: dict[str, Any]) -> None:
+ """Run on new LLM token. Only available when streaming is enabled."""
+ if not self.streaming:
+ self.streaming = True
+
+ # Remember the last n tokens, where n = len(answer_prefix_tokens)
+ self.append_to_last_tokens(token)
+
+ # Check if the last n tokens match the answer_prefix_tokens list ...
+ if self.check_if_answer_reached():
+ self.answer_reached = True
+ if self.stream_prefix:
+ message = self._construct_message(
+ data=get_token_data("".join(self.last_tokens), self.mode),
+ event="completion",
+ )
+ await self.websocket.send_json(message)
+
+ # ... if yes, then print tokens from now on
+ if self.answer_reached:
+ message = self._construct_message(
+ data=get_token_data(token, self.mode), event="completion"
+ )
+ await self.websocket.send_json(message)
\ No newline at end of file
diff --git a/swarms/server/responses.py b/swarms/server/responses.py
new file mode 100644
index 00000000..fab90a40
--- /dev/null
+++ b/swarms/server/responses.py
@@ -0,0 +1,179 @@
+import asyncio
+from functools import partial
+from typing import Any
+
+from fastapi import status
+from langchain.chains.base import Chain
+from sse_starlette import ServerSentEvent
+from sse_starlette.sse import EventSourceResponse, ensure_bytes
+from starlette.types import Send
+
+from swarms.server.utils import StrEnum
+
+
+class HTTPStatusDetail(StrEnum):
+ INTERNAL_SERVER_ERROR = "Internal Server Error"
+
+
+class StreamingResponse(EventSourceResponse):
+ """`Response` class for streaming server-sent events.
+
+ Follows the
+ [EventSource protocol](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events#interfaces)
+ """
+
+ def __init__(
+ self,
+ content: Any = iter(()),
+ *args: Any,
+ **kwargs: dict[str, Any],
+ ) -> None:
+ """Constructor method.
+
+ Args:
+ content: The content to stream.
+ """
+ super().__init__(content=content, *args, **kwargs)
+
+ async def stream_response(self, send: Send) -> None:
+ """Streams data chunks to client by iterating over `content`.
+
+ If an exception occurs while iterating over `content`, an
+ internal server error is sent to the client.
+
+ Args:
+ send: The send function from the ASGI framework.
+ """
+ await send(
+ {
+ "type": "http.response.start",
+ "status": self.status_code,
+ "headers": self.raw_headers,
+ }
+ )
+
+ try:
+ async for data in self.body_iterator:
+ chunk = ensure_bytes(data, self.sep)
+ print(f"chunk: {chunk.decode()}")
+ await send(
+ {"type": "http.response.body", "body": chunk, "more_body": True}
+ )
+ except Exception as e:
+ print(f"body iterator error: {e}")
+ chunk = ServerSentEvent(
+ data=dict(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+ detail=HTTPStatusDetail.INTERNAL_SERVER_ERROR,
+ ),
+ event="error",
+ )
+ await send(
+ {
+ "type": "http.response.body",
+ "body": ensure_bytes(chunk, None),
+ "more_body": True,
+ }
+ )
+
+ await send({"type": "http.response.body", "body": b"", "more_body": False})
+
+
+class ChainRunMode(StrEnum):
+ """Enum for LangChain run modes."""
+
+ ASYNC = "async"
+ SYNC = "sync"
+
+
+class LangchainStreamingResponse(StreamingResponse):
+ """StreamingResponse class for LangChain resources."""
+
+ def __init__(
+ self,
+ chain: Chain,
+ config: dict[str, Any],
+ run_mode: ChainRunMode = ChainRunMode.ASYNC,
+ *args: Any,
+ **kwargs: dict[str, Any],
+ ) -> None:
+ """Constructor method.
+
+ Args:
+ chain: A LangChain instance.
+ config: A config dict.
+ *args: Positional arguments to pass to the parent constructor.
+ **kwargs: Keyword arguments to pass to the parent constructor.
+ """
+ super().__init__(*args, **kwargs)
+
+ self.chain = chain
+ self.config = config
+
+ if run_mode not in list(ChainRunMode):
+ raise ValueError(
+ f"Invalid run mode '{run_mode}'. Must be one of {list(ChainRunMode)}"
+ )
+
+ self.run_mode = run_mode
+
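+    # Usage sketch (mirrors the /chat route in swarms/server/server.py): build a chain,
+    # then return LangchainStreamingResponse(chain, config={"inputs": {...},
+    # "callbacks": [TokenStreamingCallbackHandler(output_key="answer"), ...]}) from a
+    # FastAPI endpoint. Any callback exposing a `send` attribute is wired to the ASGI
+    # send callable in stream_response below, so tokens are forwarded as SSE chunks.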
+ async def stream_response(self, send: Send) -> None:
+ """Stream LangChain outputs.
+
+ If an exception occurs while iterating over the LangChain, an
+ internal server error is sent to the client.
+
+ Args:
+ send: The ASGI send callable.
+ """
+ await send(
+ {
+ "type": "http.response.start",
+ "status": self.status_code,
+ "headers": self.raw_headers,
+ }
+ )
+
+ if "callbacks" in self.config:
+ for callback in self.config["callbacks"]:
+ if hasattr(callback, "send"):
+ callback.send = send
+
+ try:
+ # TODO: migrate to `.ainvoke` when adding support
+ # for LCEL
+ if self.run_mode == ChainRunMode.ASYNC:
+ outputs = await self.chain.acall(**self.config)
+ else:
+ loop = asyncio.get_event_loop()
+ outputs = await loop.run_in_executor(
+ None, partial(self.chain, **self.config)
+ )
+ if self.background is not None:
+ self.background.kwargs.update({"outputs": outputs})
+ except Exception as e:
+ print(f"chain runtime error: {e}")
+ if self.background is not None:
+ self.background.kwargs.update({"outputs": {}, "error": e})
+ chunk = ServerSentEvent(
+ data=dict(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+ detail=HTTPStatusDetail.INTERNAL_SERVER_ERROR,
+ ),
+ event="error",
+ )
+ await send(
+ {
+ "type": "http.response.body",
+ "body": ensure_bytes(chunk, None),
+ "more_body": True,
+ }
+ )
+
+ await send({"type": "http.response.body", "body": b"", "more_body": False})
\ No newline at end of file
diff --git a/swarms/server/responses/server_responses.py b/swarms/server/responses/server_responses.py
new file mode 100644
index 00000000..adf9cb88
--- /dev/null
+++ b/swarms/server/responses/server_responses.py
@@ -0,0 +1,69 @@
+from typing import Any
+from fastapi import status
+from starlette.types import Send
+from sse_starlette.sse import ensure_bytes, EventSourceResponse, ServerSentEvent
+
+class StreamingResponse(EventSourceResponse):
+ """`Response` class for streaming server-sent events.
+
+ Follows the
+ [EventSource protocol](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events#interfaces)
+ """
+
+ def __init__(
+ self,
+ content: Any = iter(()),
+ *args: Any,
+ **kwargs: dict[str, Any],
+ ) -> None:
+ """Constructor method.
+
+ Args:
+ content: The content to stream.
+ """
+ super().__init__(content=content, *args, **kwargs)
+
+ async def stream_response(self, send: Send) -> None:
+ """Streams data chunks to client by iterating over `content`.
+
+ If an exception occurs while iterating over `content`, an
+ internal server error is sent to the client.
+
+ Args:
+ send: The send function from the ASGI framework.
+ """
+ await send(
+ {
+ "type": "http.response.start",
+ "status": self.status_code,
+ "headers": self.raw_headers,
+ }
+ )
+
+ try:
+ async for data in self.body_iterator:
+ chunk = ensure_bytes(data, self.sep)
+ with open("log.txt", "a") as log_file:
+ log_file.write(f"chunk: {chunk.decode()}\n")
+ await send(
+ {"type": "http.response.body", "body": chunk, "more_body": True}
+ )
+ except Exception as e:
+ with open("log.txt", "a") as log_file:
+ log_file.write(f"body iterator error: {e}\n")
+ chunk = ServerSentEvent(
+ data=dict(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+ detail="Internal Server Error",
+ ),
+ event="error",
+ )
+ await send(
+ {
+ "type": "http.response.body",
+ "body": ensure_bytes(chunk, None),
+ "more_body": True,
+ }
+ )
+
+ await send({"type": "http.response.body", "body": b"", "more_body": False})
\ No newline at end of file
diff --git a/swarms/server/server.py b/swarms/server/server.py
new file mode 100644
index 00000000..3f759a98
--- /dev/null
+++ b/swarms/server/server.py
@@ -0,0 +1,445 @@
+import asyncio
+import json
+import logging
+import os
+from datetime import datetime
+from typing import List
+
+import langchain
+from pydantic import ValidationError, parse_obj_as
+from swarms.prompts.chat_prompt import Message, Role
+from swarms.server.callback_handlers import (
+    SourceDocumentsStreamingCallbackHandler,
+    TokenStreamingCallbackHandler,
+)
+import tiktoken
+
+# import torch
+from dotenv import load_dotenv
+from fastapi import FastAPI, File, Form, HTTPException, Request, UploadFile
+from fastapi.middleware.cors import CORSMiddleware
+from fastapi.responses import FileResponse, JSONResponse
+from fastapi.routing import APIRouter
+from fastapi.staticfiles import StaticFiles
+from huggingface_hub import login
+from langchain.callbacks import StreamingStdOutCallbackHandler
+from langchain.memory import ConversationSummaryBufferMemory
+from langchain.memory.chat_message_histories.in_memory import ChatMessageHistory
+from langchain.prompts.prompt import PromptTemplate
+from langchain_community.chat_models import ChatOpenAI
+from swarms.server.responses import LangchainStreamingResponse
+from langchain.chains.conversational_retrieval.base import ConversationalRetrievalChain
+from langchain.chains.llm import LLMChain
+from langchain.chains.combine_documents.stuff import StuffDocumentsChain
+from swarms.prompts.conversational_RAG import (
+ B_INST,
+ B_SYS,
+ CONDENSE_PROMPT_TEMPLATE,
+ DOCUMENT_PROMPT_TEMPLATE,
+ E_INST,
+ E_SYS,
+ QA_PROMPT_TEMPLATE,
+ SUMMARY_PROMPT_TEMPLATE,
+)
+
+from swarms.server.vector_store import VectorStorage
+
+from swarms.server.server_models import (
+ ChatRequest,
+ LogMessage,
+ AIModel,
+ AIModels,
+ RAGFile,
+ RAGFiles,
+ State,
+ GetRAGFileStateRequest,
+ ProcessRAGFileRequest
+)
+
+# Explicitly specify the path to the .env file
+dotenv_path = os.path.join(os.path.dirname(__file__), '.env')
+load_dotenv(dotenv_path)
+
+hf_token = os.environ.get("HUGGINFACEHUB_API_KEY") # Get the Huggingface API Token
+uploads = os.environ.get("UPLOADS") # Directory where user uploads files to be parsed for RAG
+model_dir = os.environ.get("MODEL_DIR")
+
+# hugginface.co model (eg. meta-llama/Llama-2-70b-hf)
+model_name = os.environ.get("MODEL_NAME")
+
+# Set OpenAI's API key to 'EMPTY' and API base URL to use vLLM's API server, or set them to OpenAI API key and base URL.
+openai_api_key = os.environ.get("OPENAI_API_KEY") or "EMPTY"
+openai_api_base = os.environ.get("OPENAI_API_BASE") or "http://localhost:8000/v1"
+
+env_vars = {
+    "HUGGINGFACEHUB_API_KEY": hf_token,
+    "UPLOADS": uploads,
+    "MODEL_DIR": model_dir,
+    "MODEL_NAME": model_name,
+    "OPENAI_API_KEY": openai_api_key,
+    "OPENAI_API_BASE": openai_api_base,
+}
+missing_vars = [name for name, value in env_vars.items() if not value]
+
+if missing_vars:
+    print(
+        f"Error: The following environment variables are not set: {', '.join(missing_vars)}"
+    )
+    exit(1)
+
+useMetal = os.environ.get("USE_METAL", "False") == "True"
+
+print(f"Uploads={uploads}")
+print(f"MODEL_DIR={model_dir}")
+print(f"MODEL_NAME={model_name}")
+print(f"USE_METAL={useMetal}")
+print(f"OPENAI_API_KEY={openai_api_key}")
+print(f"OPENAI_API_BASE={openai_api_base}")
+
+# update tiktoken to include the model name (avoids warning message)
+tiktoken.model.MODEL_TO_ENCODING.update(
+ {
+ model_name: "cl100k_base",
+ }
+)
+
+print("Logging in to huggingface.co...")
+login(token=hf_token) # login to huggingface.co
+
+# langchain.debug = True
+langchain.verbose = True
+
+from contextlib import asynccontextmanager
+
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+ asyncio.create_task(vector_store.initRetrievers())
+ yield
+
+app = FastAPI(title="Chatbot", lifespan=lifespan)
+router = APIRouter()
+
+current_dir = os.path.dirname(__file__)
+print("current_dir: " + current_dir)
+static_dir = os.path.join(current_dir, "static")
+print("static_dir: " + static_dir)
+app.mount("/static", StaticFiles(directory=static_dir), name="static")
+
+app.add_middleware(
+ CORSMiddleware,
+ allow_origins=["*"],
+ allow_credentials=True,
+ allow_methods=["GET", "POST"],
+ allow_headers=["*"],
+)
+
+
+# Create ./uploads folder if it doesn't exist
+uploads = uploads or os.path.join(os.getcwd(), "uploads")
+if not os.path.exists(uploads):
+ os.makedirs(uploads)
+
+# Initialize the vector store
+vector_store = VectorStorage(directory=uploads)
+
+
+async def create_chain(
+ messages: list[Message],
+ model=model_dir,
+ max_tokens_to_gen=2048,
+ temperature=0.5,
+ prompt: PromptTemplate = QA_PROMPT_TEMPLATE,
+ file: RAGFile | None = None,
+ key: str | None = None,
+):
+    print(
+        f"Creating chain with key={key}, model={model}, max_tokens={max_tokens_to_gen}, "
+        f"temperature={temperature}, prompt={prompt}, file={file.title if file else None}"
+    )
+
+ llm = ChatOpenAI(
+ api_key=openai_api_key,
+ base_url=openai_api_base,
+ model=model_name,
+ verbose=True,
+ streaming=True,
+ )
+
+ # if llm is ALlamaCpp:
+ # llm.max_tokens = max_tokens_to_gen
+ # elif llm is AGPT4All:
+ # llm.n_predict = max_tokens_to_gen
+ # el
+ # if llm is AChatOllama:
+ # llm.max_tokens = max_tokens_to_gen
+ # if llm is VLLMAsync:
+ # llm.max_tokens = max_tokens_to_gen
+
+    retriever = await vector_store.getRetriever(
+        os.path.join(file.username, file.filename)
+    )
+
+ chat_memory = ChatMessageHistory()
+
+ for message in messages:
+ if message.role == Role.HUMAN:
+ chat_memory.add_user_message(message.content)
+ elif message.role == Role.AI:
+ chat_memory.add_ai_message(message.content)
+ elif message.role == Role.SYSTEM:
+ chat_memory.add_message(message.content)
+ elif message.role == Role.FUNCTION:
+ chat_memory.add_message(message.content)
+
+ memory = ConversationSummaryBufferMemory(
+ llm=llm,
+ chat_memory=chat_memory,
+ memory_key="chat_history",
+ input_key="question",
+ output_key="answer",
+ prompt=SUMMARY_PROMPT_TEMPLATE,
+ return_messages=True,
+ )
+
+ question_generator = LLMChain(
+ llm=llm,
+ prompt=CONDENSE_PROMPT_TEMPLATE,
+ memory=memory,
+ verbose=True,
+ output_key="answer",
+ )
+
+ stuff_chain = LLMChain(
+ llm=llm,
+ prompt=prompt,
+ verbose=True,
+ output_key="answer",
+ )
+
+ doc_chain = StuffDocumentsChain(
+ llm_chain=stuff_chain,
+ document_variable_name="context",
+ document_prompt=DOCUMENT_PROMPT_TEMPLATE,
+ verbose=True,
+ output_key="answer",
+ memory=memory,
+ )
+
+ return ConversationalRetrievalChain(
+ combine_docs_chain=doc_chain,
+ memory=memory,
+ retriever=retriever,
+ question_generator=question_generator,
+ return_generated_question=False,
+ return_source_documents=True,
+ output_key="answer",
+ verbose=True,
+ )
+
+
+@router.post(
+ "/chat",
+ summary="Chatbot",
+ description="Chatbot AI Service",
+)
+async def chat(request: ChatRequest):
+ chain: ConversationalRetrievalChain = await create_chain(
+ file=request.file,
+ messages=request.messages[:-1],
+ model=request.model.id,
+ max_tokens_to_gen=request.maxTokens,
+ temperature=request.temperature,
+ prompt=PromptTemplate.from_template(
+ f"{B_INST}{B_SYS}{request.prompt.strip()}{E_SYS}{E_INST}"
+ ),
+ )
+
+ # async for token in chain.astream(request.messages[-1].content):
+ # print(f"token={token}")
+
+    chain_inputs = {
+        "question": request.messages[-1].content,
+        # "chat_history": [message.content for message in request.messages[:-1]],
+    }
+ return LangchainStreamingResponse(
+ chain,
+ config={
+ "inputs": json_string,
+ "callbacks": [
+ StreamingStdOutCallbackHandler(),
+ TokenStreamingCallbackHandler(output_key="answer"),
+ SourceDocumentsStreamingCallbackHandler(),
+ ],
+ },
+ )
+
+
+app.include_router(router, tags=["chat"])
+
+
+@app.get("/")
+def root():
+ return {"message": "Chatbot API"}
+
+
+@app.get("/favicon.ico")
+def favicon():
+ file_name = "favicon.ico"
+    file_path = os.path.join(static_dir, file_name)
+ return FileResponse(
+ path=file_path,
+ headers={"Content-Disposition": "attachment; filename=" + file_name},
+ )
+
+
+@app.post("/log")
+def log_message(log_message: LogMessage):
+ try:
+ with open("log.txt", "a") as log_file:
+ log_file.write(log_message.message + "\n")
+ except Exception as e:
+ raise HTTPException(status_code=500, detail=f"Error saving log: {e}")
+ return {"message": "Log saved successfully"}
+
+
+@app.get("/models")
+def get_models():
+ # llama7B = AIModel(
+ # id="llama-2-7b-chat-ggml-q4_0",
+ # name="llama-2-7b-chat-ggml-q4_0",
+ # maxLength=2048,
+ # tokenLimit=2048,
+ # )
+ # llama13B = AIModel(
+ # id="llama-2-13b-chat-ggml-q4_0",
+ # name="llama-2-13b-chat-ggml-q4_0",
+ # maxLength=2048,
+ # tokenLimit=2048,
+ # )
+ llama70B = AIModel(
+ id="llama-2-70b.Q5_K_M",
+ name="llama-2-70b.Q5_K_M",
+ maxLength=2048,
+ tokenLimit=2048,
+ )
+ models = AIModels(models=[llama70B])
+ return models
+
+
+@app.get("/titles")
+def getTitles():
+ titles = RAGFiles(
+ titles=[
+ # RAGFile(
+ # versionId="d8ad3b1d-c33c-4524-9691-e93967d4d863",
+ # title="d8ad3b1d-c33c-4524-9691-e93967d4d863",
+ # state=State.Unavailable,
+ # ),
+ RAGFile(
+ versionId=collection.name,
+ title=collection.name,
+ state=State.InProcess
+ if collection.name in processing_books
+ else State.Processed,
+ )
+ for collection in vector_store.list_collections()
+ if collection.name != "langchain"
+ ]
+ )
+ return titles
+
+
+processing_books: list[str] = []
+processing_books_lock = asyncio.Lock()
+
+logging.basicConfig(level=logging.ERROR)
+
+
+@app.post("/titleState")
+async def getTitleState(request: GetRAGFileStateRequest):
+ # FastAPI + Pydantic will throw a 422 Unprocessable Entity if the request isn't the right type.
+ # try:
+ logging.debug(f"Received getTitleState request: {request}")
+ titleStateRequest: GetRAGFileStateRequest = request
+ # except ValidationError as e:
+ # print(f"Error validating JSON: {e}")
+ # raise HTTPException(status_code=422, detail=str(e))
+ # except json.JSONDecodeError as e:
+ # print(f"Error parsing JSON: {e}")
+ # raise HTTPException(status_code=422, detail="Invalid JSON format")
+ # check to see if the book has already been processed.
+ # return the proper State directly to response.
+ matchingCollection = next(
+ (
+ x
+ for x in vector_store.list_collections()
+ if x.name == titleStateRequest.versionRef
+ ),
+ None,
+ )
+ print("Got a Title State request for version " + titleStateRequest.versionRef)
+ if titleStateRequest.versionRef in processing_books:
+ return {"message": State.InProcess}
+ elif matchingCollection is not None:
+ return {"message": State.Processed}
+ else:
+ return {"message": State.Unavailable}
+
+
+@app.post("/processRAGFile")
+async def processRAGFile(
+ request: str = Form(...),
+ files: List[UploadFile] = File(...),
+):
+ try:
+ logging.debug(f"Received processBook request: {request}")
+ # Parse the JSON string into a ProcessBookRequest object
+ fileRAGRequest: ProcessRAGFileRequest = parse_obj_as(
+ ProcessRAGFileRequest, json.loads(request)
+ )
+ except ValidationError as e:
+ print(f"Error validating JSON: {e}")
+ raise HTTPException(status_code=422, detail=str(e))
+ except json.JSONDecodeError as e:
+ print(f"Error parsing JSON: {e}")
+ raise HTTPException(status_code=422, detail="Invalid JSON format")
+
+ try:
+ print(
+ f"Processing file {fileRAGRequest.filename} for user {fileRAGRequest.username}."
+ )
+ # check to see if the file has already been processed.
+ # write html to subfolder
+ print(f"Writing file to path: {fileRAGRequest.username}/{fileRAGRequest.filename}...")
+
+ for index, segment in enumerate(files):
+ filename = segment.filename if segment.filename else str(index)
+ subDir = f"{fileRAGRequest.username}"
+ with open(os.path.join(subDir, filename), "wb") as htmlFile:
+ htmlFile.write(await segment.read())
+
+ # write metadata to subfolder
+ print(f"Writing metadata to subfolder {fileRAGRequest.username}...")
+        with open(
+            os.path.join(fileRAGRequest.username, "metadata.json"), "w"
+        ) as metadataFile:
+ metaData = {
+ "filename": fileRAGRequest.filename,
+ "username": fileRAGRequest.username,
+ "processDate": datetime.now().isoformat(),
+ }
+ metadataFile.write(json.dumps(metaData))
+
+ vector_store.retrievers[
+ f"{fileRAGRequest.username}/{fileRAGRequest.filename}"
+ ] = await vector_store.initRetriever(f"{fileRAGRequest.username}/{fileRAGRequest.filename}")
+
+ return {
+ "message": f"File {fileRAGRequest.filename} processed successfully."
+ }
+ except Exception as e:
+ logging.error(f"Error processing book: {e}")
+ return {"message": f"Error processing book: {e}"}
+
+
+@app.exception_handler(HTTPException)
+async def http_exception_handler(request: Request, exc: HTTPException):
+ logging.error(f"HTTPException: {exc.detail}")
+ return JSONResponse(status_code=exc.status_code, content={"detail": exc.detail})
+
diff --git a/swarms/server/server_models.py b/swarms/server/server_models.py
new file mode 100644
index 00000000..1c4491b7
--- /dev/null
+++ b/swarms/server/server_models.py
@@ -0,0 +1,88 @@
+try:
+ from enum import StrEnum
+except ImportError:
+ from strenum import StrEnum
+
+from pydantic import BaseModel
+from swarms.prompts import QA_PROMPT_TEMPLATE_STR as DefaultSystemPrompt
+
+class AIModel(BaseModel):
+ id: str
+ name: str
+ maxLength: int
+ tokenLimit: int
+
+
+class AIModels(BaseModel):
+ models: list[AIModel]
+
+
+class State(StrEnum):
+ Unavailable = "Unavailable"
+ InProcess = "InProcess"
+ Processed = "Processed"
+
+
+class RAGFile(BaseModel):
+ filename: str
+ title: str
+ username: str
+ state: State = State.Unavailable
+
+
+class RAGFiles(BaseModel):
+ files: list[RAGFile]
+
+
+class Role(StrEnum):
+ SYSTEM = "system"
+ ASSISTANT = "assistant"
+ USER = "user"
+
+
+class Message(BaseModel):
+ role: Role
+ content: str
+
+
+class ChatRequest(BaseModel):
+ id: str
+ model: AIModel = AIModel(
+ id="llama-2-70b.Q5_K_M",
+ name="llama-2-70b.Q5_K_M",
+ maxLength=2048,
+ tokenLimit=2048,
+ )
+ messages: list[Message] = [
+ Message(role=Role.SYSTEM, content="Hello, how may I help you?"),
+ Message(role=Role.USER, content=""),
+ ]
+ maxTokens: int = 2048
+ temperature: float = 0
+ prompt: str = DefaultSystemPrompt
+ file: RAGFile = RAGFile(filename="None", title="None", username="None")
+
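+# Illustrative /chat request body for ChatRequest (field values are placeholders):
+# {
+#     "id": "conversation-1",
+#     "messages": [{"role": "user", "content": "Summarize the uploaded file."}],
+#     "maxTokens": 2048,
+#     "temperature": 0,
+#     "prompt": "You are a helpful AI assistant...",
+#     "file": {"filename": "notes.html", "title": "notes", "username": "alice"}
+# }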
+
+class LogMessage(BaseModel):
+ message: str
+
+
+class ConversationRequest(BaseModel):
+ id: str
+ name: str
+ title: RAGFile
+ messages: list[Message]
+ model: AIModel
+ prompt: str
+ temperature: float
+ folderId: str | None = None
+
+
+class ProcessRAGFileRequest(BaseModel):
+ filename: str
+ username: str
+
+
+class GetRAGFileStateRequest(BaseModel):
+ filename: str
+ username: str
\ No newline at end of file
diff --git a/swarms/server/server_prompts.py b/swarms/server/server_prompts.py
new file mode 100644
index 00000000..00666cb7
--- /dev/null
+++ b/swarms/server/server_prompts.py
@@ -0,0 +1,78 @@
+from langchain.prompts.prompt import PromptTemplate
+
+B_INST, E_INST = "[INST]", "[/INST]"
+B_SYS, E_SYS = "<>\n", "\n<>\n\n"
+
+QA_CONDENSE_TEMPLATE_STR = (
+ "Given the following Chat History and a Follow Up Question, "
+ "rephrase the follow up question to be a new Standalone Question, "
+ "but make sure the new question is still asking for the same "
+ "information as the original follow up question. Respond only "
+ " with the new Standalone Question. \n"
+ "Chat History: \n"
+ "{chat_history} \n"
+ "Follow Up Question: {question} \n"
+ "Standalone Question:"
+)
+
+CONDENSE_TEMPLATE = PromptTemplate.from_template(
+ f"{B_INST}{B_SYS}{QA_CONDENSE_TEMPLATE_STR.strip()}{E_SYS}{E_INST}"
+)
+
+QA_PROMPT_TEMPLATE_STR = (
+ "HUMAN: \n You are a helpful AI assistant. "
+ "Use the following context and chat history to answer the "
+ "question at the end with a helpful answer. "
+ "Get straight to the point and always think things through step-by-step before answering. "
+ "If you don't know the answer, just say 'I don't know'; "
+ "don't try to make up an answer. \n\n"
+ "{context}\n"
+ "{chat_history}\n"
+ "{question}\n\n"
+ "AI: Here is the most relevant sentence in the context: \n"
+)
+
+QA_PROMPT_TEMPLATE = PromptTemplate.from_template(
+ f"{B_INST}{B_SYS}{QA_PROMPT_TEMPLATE_STR.strip()}{E_SYS}{E_INST}"
+)
+
+DOCUMENT_PROMPT_TEMPLATE = PromptTemplate(
+ input_variables=["page_content"], template="{page_content}"
+)
+
+_STUFF_PROMPT_TEMPLATE_STR = "Summarize the following context: {context}"
+
+STUFF_PROMPT_TEMPLATE = PromptTemplate.from_template(
+ f"{B_INST}{B_SYS}{_STUFF_PROMPT_TEMPLATE_STR.strip()}{E_SYS}{E_INST}"
+)
+
+_SUMMARIZER_SYS_TEMPLATE = (
+ B_INST
+ + B_SYS
+ + """Progressively summarize the lines of conversation provided, adding onto the previous summary returning a new summary.
+ EXAMPLE
+ Current summary:
+ The human asks what the AI thinks of artificial intelligence. The AI thinks artificial intelligence is a force for good.
+ New lines of conversation:
+ Human: Why do you think artificial intelligence is a force for good?
+ AI: Because artificial intelligence will help humans reach their full potential.
+ New summary:
+ The human asks what the AI thinks of artificial intelligence. The AI thinks artificial intelligence is a force for good because it will help humans reach their full potential.
+ END OF EXAMPLE"""
+ + E_SYS
+ + E_INST
+)
+
+_SUMMARIZER_INST_TEMPLATE = (
+ B_INST
+ + """Current summary:
+ {summary}
+ New lines of conversation:
+ {new_lines}
+ New summary:"""
+ + E_INST
+)
+
+SUMMARY_PROMPT = PromptTemplate.from_template(
+ template=(_SUMMARIZER_SYS_TEMPLATE + "\n" + _SUMMARIZER_INST_TEMPLATE).strip()
+)
\ No newline at end of file
diff --git a/swarms/server/static/favicon.ico b/swarms/server/static/favicon.ico
new file mode 100644
index 00000000..988bba1c
Binary files /dev/null and b/swarms/server/static/favicon.ico differ
diff --git a/swarms/server/utils.py b/swarms/server/utils.py
new file mode 100644
index 00000000..89e3d7c8
--- /dev/null
+++ b/swarms/server/utils.py
@@ -0,0 +1,51 @@
+# modified from Lanarky source https://github.com/auxon/lanarky
+from typing import Any
+
+import pydantic
+from pydantic.fields import FieldInfo
+
+try:
+ from enum import StrEnum # type: ignore
+except ImportError:
+ from enum import Enum
+
+ class StrEnum(str, Enum): ...
+
+
+PYDANTIC_V2 = pydantic.VERSION.startswith("2.")
+
+
+def model_dump(model: pydantic.BaseModel, **kwargs) -> dict[str, Any]:
+ """Dump a pydantic model to a dictionary.
+
+ Args:
+ model: A pydantic model.
+ """
+ if PYDANTIC_V2:
+ return model.model_dump(**kwargs)
+ else:
+ return model.dict(**kwargs)
+
+
+def model_dump_json(model: pydantic.BaseModel, **kwargs) -> str:
+ """Dump a pydantic model to a JSON string.
+
+ Args:
+ model: A pydantic model.
+ """
+ if PYDANTIC_V2:
+ return model.model_dump_json(**kwargs)
+ else:
+ return model.json(**kwargs)
+
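+# Example (illustrative): for a model `class T(BaseModel): x: int = 1`,
+# model_dump(T()) returns {"x": 1} and model_dump_json(T()) returns '{"x":1}'
+# whether pydantic v1 or v2 is installed (v1's JSON output may differ in spacing).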
+
+def model_fields(model: pydantic.BaseModel) -> dict[str, FieldInfo]:
+ """Get the fields of a pydantic model.
+
+ Args:
+ model: A pydantic model.
+ """
+ if PYDANTIC_V2:
+ return model.model_fields
+ else:
+ return model.__fields__
\ No newline at end of file