diff --git a/swarms/server/async_parent_document_retriever.py b/swarms/server/async_parent_document_retriever.py index fe46422d..20c1e825 100644 --- a/swarms/server/async_parent_document_retriever.py +++ b/swarms/server/async_parent_document_retriever.py @@ -1,6 +1,8 @@ +""" AsyncParentDocumentRetriever is used by RAG +to split up documents into smaller *and* larger related chunks. """ import pickle import uuid -from typing import ClassVar, Collection, List, Optional, Tuple +from typing import Any, ClassVar, Collection, List, Optional, Tuple from langchain.callbacks.manager import ( AsyncCallbackManagerForRetrieverRun, @@ -161,6 +163,7 @@ class AsyncParentDocumentRetriever(ParentDocumentRetriever): documents: List[Document], ids: Optional[List[str]] = None, add_to_docstore: bool = True, + **kwargs: Any ) -> None: """Adds documents to the docstore and vectorstores. @@ -215,6 +218,7 @@ class AsyncParentDocumentRetriever(ParentDocumentRetriever): documents: List[Document], ids: Optional[List[str]] = None, add_to_docstore: bool = True, + **kwargs: Any ) -> None: """Adds documents to the docstore and vectorstores. @@ -251,7 +255,7 @@ class AsyncParentDocumentRetriever(ParentDocumentRetriever): if len(documents) < 1: return - + for i, doc in enumerate(documents): _id = doc_ids[i] sub_docs = self.child_splitter.split_documents([doc]) @@ -275,4 +279,4 @@ class AsyncParentDocumentRetriever(ParentDocumentRetriever): serialized_docs = [(id, pickle.dumps(doc)) for id, doc in full_docs] self.docstore.mset(serialized_docs) else: - self.docstore.mset(full_docs) \ No newline at end of file + self.docstore.mset(full_docs) diff --git a/swarms/server/responses.py b/swarms/server/responses.py index 5b1785e1..48c9cae9 100644 --- a/swarms/server/responses.py +++ b/swarms/server/responses.py @@ -1,22 +1,19 @@ -from typing import Any +""" Customized Langchain StreamingResponse for Server-Side Events (SSE) """ import asyncio from functools import partial from typing import Any from fastapi import status from langchain.chains.base import Chain -from starlette.types import Send -from fastapi import status from sse_starlette import ServerSentEvent -from sse_starlette.sse import EventSourceResponse +from sse_starlette.sse import EventSourceResponse, ensure_bytes from starlette.types import Send from swarms.server.utils import StrEnum -from sse_starlette.sse import ensure_bytes - class HTTPStatusDetail(StrEnum): + """ HTTP error descriptions. """ INTERNAL_SERVER_ERROR = "Internal Server Error" @@ -24,13 +21,14 @@ class StreamingResponse(EventSourceResponse): """`Response` class for streaming server-sent events. Follows the - [EventSource protocol](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events#interfaces) + [EventSource protocol] + (https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events#interfaces) """ def __init__( self, - content: Any = iter(()), *args: Any, + content: Any = iter(()), **kwargs: dict[str, Any], ) -> None: """Constructor method. @@ -97,10 +95,10 @@ class LangchainStreamingResponse(StreamingResponse): def __init__( self, + *args: Any, chain: Chain, config: dict[str, Any], run_mode: ChainRunMode = ChainRunMode.ASYNC, - *args: Any, **kwargs: dict[str, Any], ) -> None: """Constructor method. @@ -146,8 +144,6 @@ class LangchainStreamingResponse(StreamingResponse): callback.send = send try: - # TODO: migrate to `.ainvoke` when adding support - # for LCEL if self.run_mode == ChainRunMode.ASYNC: async for outputs in self.chain.astream(input=self.config): if 'answer' in outputs: @@ -156,7 +152,11 @@ class LangchainStreamingResponse(StreamingResponse): ) # Send each chunk with the appropriate body type await send( - {"type": "http.response.body", "body": ensure_bytes(chunk, None), "more_body": True} + { + "type": "http.response.body", + "body": ensure_bytes(chunk, None), + "more_body": True + } ) else: @@ -185,4 +185,4 @@ class LangchainStreamingResponse(StreamingResponse): } ) - await send({"type": "http.response.body", "body": b"", "more_body": False}) \ No newline at end of file + await send({"type": "http.response.body", "body": b"", "more_body": False}) diff --git a/swarms/server/server_models.py b/swarms/server/server_models.py index 1c4491b7..8b029931 100644 --- a/swarms/server/server_models.py +++ b/swarms/server/server_models.py @@ -1,3 +1,4 @@ +""" Chatbot Server API Models """ try: from enum import StrEnum except ImportError: @@ -7,45 +8,48 @@ from pydantic import BaseModel from swarms.prompts import QA_PROMPT_TEMPLATE_STR as DefaultSystemPrompt class AIModel(BaseModel): + """ Defines the model a user selected. """ id: str name: str maxLength: int tokenLimit: int -class AIModels(BaseModel): - models: list[AIModel] - - class State(StrEnum): - Unavailable = "Unavailable" - InProcess = "InProcess" - Processed = "Processed" + """ State of RAGFile that's been uploaded. """ + UNAVAILABLE = "UNAVAILABLE" + PROCESSING = "PROCESSING" + PROCESSED = "PROCESSED" class RAGFile(BaseModel): + """ Defines a file uploaded by the users for RAG processing. """ filename: str title: str username: str - state: State = State.Unavailable + state: State = State.UNAVAILABLE class RAGFiles(BaseModel): + """ Defines a list of RAGFile objects. """ files: list[RAGFile] class Role(StrEnum): + """ The role of a message in a conversation. """ SYSTEM = "system" ASSISTANT = "assistant" USER = "user" class Message(BaseModel): + """ Defines the type of a Message with a role and content. """ role: Role content: str class ChatRequest(BaseModel): + """ The model for a ChatRequest expected by the Chatbot Chat POST endpoint. """ id: str model: AIModel = AIModel( id="llama-2-70b.Q5_K_M", @@ -61,28 +65,3 @@ class ChatRequest(BaseModel): temperature: float = 0 prompt: str = DefaultSystemPrompt file: RAGFile = RAGFile(filename="None", title="None", username="None") - - -class LogMessage(BaseModel): - message: str - - -class ConversationRequest(BaseModel): - id: str - name: str - title: RAGFile - messages: list[Message] - model: AIModel - prompt: str - temperature: float - folderId: str | None = None - - -class ProcessRAGFileRequest(BaseModel): - filename: str - username: str - - -class GetRAGFileStateRequest(BaseModel): - filename: str - username: str \ No newline at end of file diff --git a/swarms/server/utils.py b/swarms/server/utils.py index 89e3d7c8..a8cc3048 100644 --- a/swarms/server/utils.py +++ b/swarms/server/utils.py @@ -1,20 +1,10 @@ -# modified from Lanarky source https://github.com/auxon/lanarky +""" modified from Lanarky source https://github.com/auxon/lanarky """ from typing import Any - import pydantic from pydantic.fields import FieldInfo - -try: - from enum import StrEnum # type: ignore -except ImportError: - from enum import Enum - - class StrEnum(str, Enum): ... - - PYDANTIC_V2 = pydantic.VERSION.startswith("2.") - - + + def model_dump(model: pydantic.BaseModel, **kwargs) -> dict[str, Any]: """Dump a pydantic model to a dictionary. @@ -23,10 +13,10 @@ def model_dump(model: pydantic.BaseModel, **kwargs) -> dict[str, Any]: """ if PYDANTIC_V2: return model.model_dump(**kwargs) - else: - return model.dict(**kwargs) - - + + return model.dict(**kwargs) + + def model_dump_json(model: pydantic.BaseModel, **kwargs) -> str: """Dump a pydantic model to a JSON string. @@ -35,10 +25,10 @@ def model_dump_json(model: pydantic.BaseModel, **kwargs) -> str: """ if PYDANTIC_V2: return model.model_dump_json(**kwargs) - else: - return model.json(**kwargs) - - + + return model.json(**kwargs) + + def model_fields(model: pydantic.BaseModel) -> dict[str, FieldInfo]: """Get the fields of a pydantic model. @@ -47,5 +37,5 @@ def model_fields(model: pydantic.BaseModel) -> dict[str, FieldInfo]: """ if PYDANTIC_V2: return model.model_fields - else: - return model.__fields__ \ No newline at end of file + + return model.__fields__