From f021ed168a2722817e73cfc0b7cb764adaeffe9a Mon Sep 17 00:00:00 2001 From: Richard Anthony Hein Date: Mon, 19 Aug 2024 20:33:16 +0000 Subject: [PATCH] remove extra dependencies and fix more lint errors --- swarms/server/responses.py | 26 ++++++-------------------- 1 file changed, 6 insertions(+), 20 deletions(-) diff --git a/swarms/server/responses.py b/swarms/server/responses.py index 48c9cae9..51d2087a 100644 --- a/swarms/server/responses.py +++ b/swarms/server/responses.py @@ -9,12 +9,6 @@ from sse_starlette import ServerSentEvent from sse_starlette.sse import EventSourceResponse, ensure_bytes from starlette.types import Send -from swarms.server.utils import StrEnum - - -class HTTPStatusDetail(StrEnum): - """ HTTP error descriptions. """ - INTERNAL_SERVER_ERROR = "Internal Server Error" class StreamingResponse(EventSourceResponse): @@ -67,7 +61,7 @@ class StreamingResponse(EventSourceResponse): chunk = ServerSentEvent( data=dict( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=HTTPStatusDetail.INTERNAL_SERVER_ERROR, + detail="Internal Server Error", ), event="error", ) @@ -82,14 +76,6 @@ class StreamingResponse(EventSourceResponse): await send({"type": "http.response.body", "body": b"", "more_body": False}) - -class ChainRunMode(StrEnum): - """Enum for LangChain run modes.""" - - ASYNC = "async" - SYNC = "sync" - - class LangchainStreamingResponse(StreamingResponse): """StreamingResponse class for LangChain resources.""" @@ -98,7 +84,7 @@ class LangchainStreamingResponse(StreamingResponse): *args: Any, chain: Chain, config: dict[str, Any], - run_mode: ChainRunMode = ChainRunMode.ASYNC, + run_mode: str, **kwargs: dict[str, Any], ) -> None: """Constructor method. @@ -114,9 +100,9 @@ class LangchainStreamingResponse(StreamingResponse): self.chain = chain self.config = config - if run_mode not in list(ChainRunMode): + if run_mode not in list(["async", "sync"]): raise ValueError( - f"Invalid run mode '{run_mode}'. Must be one of {list(ChainRunMode)}" + f"Invalid run mode '{run_mode}'. Must be one of {list(['async', 'sync'])}" ) self.run_mode = run_mode @@ -144,7 +130,7 @@ class LangchainStreamingResponse(StreamingResponse): callback.send = send try: - if self.run_mode == ChainRunMode.ASYNC: + if self.run_mode == "async": async for outputs in self.chain.astream(input=self.config): if 'answer' in outputs: chunk = ServerSentEvent( @@ -173,7 +159,7 @@ class LangchainStreamingResponse(StreamingResponse): chunk = ServerSentEvent( data=dict( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=HTTPStatusDetail.INTERNAL_SERVER_ERROR, + detail="Internal Server Error", ), event="error", )