fixed lint errors

pull/570/head
Richard Anthony Hein 8 months ago
parent 5cf7b8d798
commit 15578a3555

@ -1,6 +1,8 @@
""" AsyncParentDocumentRetriever is used by RAG
to split up documents into smaller *and* larger related chunks. """
import pickle import pickle
import uuid import uuid
from typing import ClassVar, Collection, List, Optional, Tuple from typing import Any, ClassVar, Collection, List, Optional, Tuple
from langchain.callbacks.manager import ( from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun, AsyncCallbackManagerForRetrieverRun,
@ -161,6 +163,7 @@ class AsyncParentDocumentRetriever(ParentDocumentRetriever):
documents: List[Document], documents: List[Document],
ids: Optional[List[str]] = None, ids: Optional[List[str]] = None,
add_to_docstore: bool = True, add_to_docstore: bool = True,
**kwargs: Any
) -> None: ) -> None:
"""Adds documents to the docstore and vectorstores. """Adds documents to the docstore and vectorstores.
@ -215,6 +218,7 @@ class AsyncParentDocumentRetriever(ParentDocumentRetriever):
documents: List[Document], documents: List[Document],
ids: Optional[List[str]] = None, ids: Optional[List[str]] = None,
add_to_docstore: bool = True, add_to_docstore: bool = True,
**kwargs: Any
) -> None: ) -> None:
"""Adds documents to the docstore and vectorstores. """Adds documents to the docstore and vectorstores.

@ -1,22 +1,19 @@
from typing import Any """ Customized Langchain StreamingResponse for Server-Side Events (SSE) """
import asyncio import asyncio
from functools import partial from functools import partial
from typing import Any from typing import Any
from fastapi import status from fastapi import status
from langchain.chains.base import Chain from langchain.chains.base import Chain
from starlette.types import Send
from fastapi import status
from sse_starlette import ServerSentEvent from sse_starlette import ServerSentEvent
from sse_starlette.sse import EventSourceResponse from sse_starlette.sse import EventSourceResponse, ensure_bytes
from starlette.types import Send from starlette.types import Send
from swarms.server.utils import StrEnum from swarms.server.utils import StrEnum
from sse_starlette.sse import ensure_bytes
class HTTPStatusDetail(StrEnum): class HTTPStatusDetail(StrEnum):
""" HTTP error descriptions. """
INTERNAL_SERVER_ERROR = "Internal Server Error" INTERNAL_SERVER_ERROR = "Internal Server Error"
@ -24,13 +21,14 @@ class StreamingResponse(EventSourceResponse):
"""`Response` class for streaming server-sent events. """`Response` class for streaming server-sent events.
Follows the Follows the
[EventSource protocol](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events#interfaces) [EventSource protocol]
(https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events#interfaces)
""" """
def __init__( def __init__(
self, self,
content: Any = iter(()),
*args: Any, *args: Any,
content: Any = iter(()),
**kwargs: dict[str, Any], **kwargs: dict[str, Any],
) -> None: ) -> None:
"""Constructor method. """Constructor method.
@ -97,10 +95,10 @@ class LangchainStreamingResponse(StreamingResponse):
def __init__( def __init__(
self, self,
*args: Any,
chain: Chain, chain: Chain,
config: dict[str, Any], config: dict[str, Any],
run_mode: ChainRunMode = ChainRunMode.ASYNC, run_mode: ChainRunMode = ChainRunMode.ASYNC,
*args: Any,
**kwargs: dict[str, Any], **kwargs: dict[str, Any],
) -> None: ) -> None:
"""Constructor method. """Constructor method.
@ -146,8 +144,6 @@ class LangchainStreamingResponse(StreamingResponse):
callback.send = send callback.send = send
try: try:
# TODO: migrate to `.ainvoke` when adding support
# for LCEL
if self.run_mode == ChainRunMode.ASYNC: if self.run_mode == ChainRunMode.ASYNC:
async for outputs in self.chain.astream(input=self.config): async for outputs in self.chain.astream(input=self.config):
if 'answer' in outputs: if 'answer' in outputs:
@ -156,7 +152,11 @@ class LangchainStreamingResponse(StreamingResponse):
) )
# Send each chunk with the appropriate body type # Send each chunk with the appropriate body type
await send( await send(
{"type": "http.response.body", "body": ensure_bytes(chunk, None), "more_body": True} {
"type": "http.response.body",
"body": ensure_bytes(chunk, None),
"more_body": True
}
) )
else: else:

@ -1,3 +1,4 @@
""" Chatbot Server API Models """
try: try:
from enum import StrEnum from enum import StrEnum
except ImportError: except ImportError:
@ -7,45 +8,48 @@ from pydantic import BaseModel
from swarms.prompts import QA_PROMPT_TEMPLATE_STR as DefaultSystemPrompt from swarms.prompts import QA_PROMPT_TEMPLATE_STR as DefaultSystemPrompt
class AIModel(BaseModel): class AIModel(BaseModel):
""" Defines the model a user selected. """
id: str id: str
name: str name: str
maxLength: int maxLength: int
tokenLimit: int tokenLimit: int
class AIModels(BaseModel):
models: list[AIModel]
class State(StrEnum): class State(StrEnum):
Unavailable = "Unavailable" """ State of RAGFile that's been uploaded. """
InProcess = "InProcess" UNAVAILABLE = "UNAVAILABLE"
Processed = "Processed" PROCESSING = "PROCESSING"
PROCESSED = "PROCESSED"
class RAGFile(BaseModel): class RAGFile(BaseModel):
""" Defines a file uploaded by the users for RAG processing. """
filename: str filename: str
title: str title: str
username: str username: str
state: State = State.Unavailable state: State = State.UNAVAILABLE
class RAGFiles(BaseModel): class RAGFiles(BaseModel):
""" Defines a list of RAGFile objects. """
files: list[RAGFile] files: list[RAGFile]
class Role(StrEnum): class Role(StrEnum):
""" The role of a message in a conversation. """
SYSTEM = "system" SYSTEM = "system"
ASSISTANT = "assistant" ASSISTANT = "assistant"
USER = "user" USER = "user"
class Message(BaseModel): class Message(BaseModel):
""" Defines the type of a Message with a role and content. """
role: Role role: Role
content: str content: str
class ChatRequest(BaseModel): class ChatRequest(BaseModel):
""" The model for a ChatRequest expected by the Chatbot Chat POST endpoint. """
id: str id: str
model: AIModel = AIModel( model: AIModel = AIModel(
id="llama-2-70b.Q5_K_M", id="llama-2-70b.Q5_K_M",
@ -61,28 +65,3 @@ class ChatRequest(BaseModel):
temperature: float = 0 temperature: float = 0
prompt: str = DefaultSystemPrompt prompt: str = DefaultSystemPrompt
file: RAGFile = RAGFile(filename="None", title="None", username="None") file: RAGFile = RAGFile(filename="None", title="None", username="None")
class LogMessage(BaseModel):
message: str
class ConversationRequest(BaseModel):
id: str
name: str
title: RAGFile
messages: list[Message]
model: AIModel
prompt: str
temperature: float
folderId: str | None = None
class ProcessRAGFileRequest(BaseModel):
filename: str
username: str
class GetRAGFileStateRequest(BaseModel):
filename: str
username: str

@ -1,17 +1,7 @@
# modified from Lanarky source https://github.com/auxon/lanarky """ modified from Lanarky source https://github.com/auxon/lanarky """
from typing import Any from typing import Any
import pydantic import pydantic
from pydantic.fields import FieldInfo from pydantic.fields import FieldInfo
try:
from enum import StrEnum # type: ignore
except ImportError:
from enum import Enum
class StrEnum(str, Enum): ...
PYDANTIC_V2 = pydantic.VERSION.startswith("2.") PYDANTIC_V2 = pydantic.VERSION.startswith("2.")
@ -23,7 +13,7 @@ def model_dump(model: pydantic.BaseModel, **kwargs) -> dict[str, Any]:
""" """
if PYDANTIC_V2: if PYDANTIC_V2:
return model.model_dump(**kwargs) return model.model_dump(**kwargs)
else:
return model.dict(**kwargs) return model.dict(**kwargs)
@ -35,7 +25,7 @@ def model_dump_json(model: pydantic.BaseModel, **kwargs) -> str:
""" """
if PYDANTIC_V2: if PYDANTIC_V2:
return model.model_dump_json(**kwargs) return model.model_dump_json(**kwargs)
else:
return model.json(**kwargs) return model.json(**kwargs)
@ -47,5 +37,5 @@ def model_fields(model: pydantic.BaseModel) -> dict[str, FieldInfo]:
""" """
if PYDANTIC_V2: if PYDANTIC_V2:
return model.model_fields return model.model_fields
else:
return model.__fields__ return model.__fields__
Loading…
Cancel
Save