remove extra dependencies and fix more lint errors

pull/570/head
Richard Anthony Hein 8 months ago
parent 4df911b781
commit f021ed168a

@ -9,12 +9,6 @@ from sse_starlette import ServerSentEvent
from sse_starlette.sse import EventSourceResponse, ensure_bytes from sse_starlette.sse import EventSourceResponse, ensure_bytes
from starlette.types import Send from starlette.types import Send
from swarms.server.utils import StrEnum
class HTTPStatusDetail(StrEnum):
""" HTTP error descriptions. """
INTERNAL_SERVER_ERROR = "Internal Server Error"
class StreamingResponse(EventSourceResponse): class StreamingResponse(EventSourceResponse):
@ -67,7 +61,7 @@ class StreamingResponse(EventSourceResponse):
chunk = ServerSentEvent( chunk = ServerSentEvent(
data=dict( data=dict(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=HTTPStatusDetail.INTERNAL_SERVER_ERROR, detail="Internal Server Error",
), ),
event="error", event="error",
) )
@ -82,14 +76,6 @@ class StreamingResponse(EventSourceResponse):
await send({"type": "http.response.body", "body": b"", "more_body": False}) await send({"type": "http.response.body", "body": b"", "more_body": False})
class ChainRunMode(StrEnum):
"""Enum for LangChain run modes."""
ASYNC = "async"
SYNC = "sync"
class LangchainStreamingResponse(StreamingResponse): class LangchainStreamingResponse(StreamingResponse):
"""StreamingResponse class for LangChain resources.""" """StreamingResponse class for LangChain resources."""
@ -98,7 +84,7 @@ class LangchainStreamingResponse(StreamingResponse):
*args: Any, *args: Any,
chain: Chain, chain: Chain,
config: dict[str, Any], config: dict[str, Any],
run_mode: ChainRunMode = ChainRunMode.ASYNC, run_mode: str,
**kwargs: dict[str, Any], **kwargs: dict[str, Any],
) -> None: ) -> None:
"""Constructor method. """Constructor method.
@ -114,9 +100,9 @@ class LangchainStreamingResponse(StreamingResponse):
self.chain = chain self.chain = chain
self.config = config self.config = config
if run_mode not in list(ChainRunMode): if run_mode not in list(["async", "sync"]):
raise ValueError( raise ValueError(
f"Invalid run mode '{run_mode}'. Must be one of {list(ChainRunMode)}" f"Invalid run mode '{run_mode}'. Must be one of {list(['async', 'sync'])}"
) )
self.run_mode = run_mode self.run_mode = run_mode
@ -144,7 +130,7 @@ class LangchainStreamingResponse(StreamingResponse):
callback.send = send callback.send = send
try: try:
if self.run_mode == ChainRunMode.ASYNC: if self.run_mode == "async":
async for outputs in self.chain.astream(input=self.config): async for outputs in self.chain.astream(input=self.config):
if 'answer' in outputs: if 'answer' in outputs:
chunk = ServerSentEvent( chunk = ServerSentEvent(
@ -173,7 +159,7 @@ class LangchainStreamingResponse(StreamingResponse):
chunk = ServerSentEvent( chunk = ServerSentEvent(
data=dict( data=dict(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=HTTPStatusDetail.INTERNAL_SERVER_ERROR, detail="Internal Server Error",
), ),
event="error", event="error",
) )

Loading…
Cancel
Save