|
|
|
@ -9,12 +9,6 @@ from sse_starlette import ServerSentEvent
|
|
|
|
|
from sse_starlette.sse import EventSourceResponse, ensure_bytes
|
|
|
|
|
from starlette.types import Send
|
|
|
|
|
|
|
|
|
|
from swarms.server.utils import StrEnum
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class HTTPStatusDetail(StrEnum):
|
|
|
|
|
""" HTTP error descriptions. """
|
|
|
|
|
INTERNAL_SERVER_ERROR = "Internal Server Error"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class StreamingResponse(EventSourceResponse):
|
|
|
|
@ -67,7 +61,7 @@ class StreamingResponse(EventSourceResponse):
|
|
|
|
|
chunk = ServerSentEvent(
|
|
|
|
|
data=dict(
|
|
|
|
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
|
|
|
detail=HTTPStatusDetail.INTERNAL_SERVER_ERROR,
|
|
|
|
|
detail="Internal Server Error",
|
|
|
|
|
),
|
|
|
|
|
event="error",
|
|
|
|
|
)
|
|
|
|
@ -82,14 +76,6 @@ class StreamingResponse(EventSourceResponse):
|
|
|
|
|
await send({"type": "http.response.body", "body": b"", "more_body": False})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ChainRunMode(StrEnum):
|
|
|
|
|
"""Enum for LangChain run modes."""
|
|
|
|
|
|
|
|
|
|
ASYNC = "async"
|
|
|
|
|
SYNC = "sync"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class LangchainStreamingResponse(StreamingResponse):
|
|
|
|
|
"""StreamingResponse class for LangChain resources."""
|
|
|
|
|
|
|
|
|
@ -98,7 +84,7 @@ class LangchainStreamingResponse(StreamingResponse):
|
|
|
|
|
*args: Any,
|
|
|
|
|
chain: Chain,
|
|
|
|
|
config: dict[str, Any],
|
|
|
|
|
run_mode: ChainRunMode = ChainRunMode.ASYNC,
|
|
|
|
|
run_mode: str,
|
|
|
|
|
**kwargs: dict[str, Any],
|
|
|
|
|
) -> None:
|
|
|
|
|
"""Constructor method.
|
|
|
|
@ -114,9 +100,9 @@ class LangchainStreamingResponse(StreamingResponse):
|
|
|
|
|
self.chain = chain
|
|
|
|
|
self.config = config
|
|
|
|
|
|
|
|
|
|
if run_mode not in list(ChainRunMode):
|
|
|
|
|
if run_mode not in list(["async", "sync"]):
|
|
|
|
|
raise ValueError(
|
|
|
|
|
f"Invalid run mode '{run_mode}'. Must be one of {list(ChainRunMode)}"
|
|
|
|
|
f"Invalid run mode '{run_mode}'. Must be one of {list(['async', 'sync'])}"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
self.run_mode = run_mode
|
|
|
|
@ -144,7 +130,7 @@ class LangchainStreamingResponse(StreamingResponse):
|
|
|
|
|
callback.send = send
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
if self.run_mode == ChainRunMode.ASYNC:
|
|
|
|
|
if self.run_mode == "async":
|
|
|
|
|
async for outputs in self.chain.astream(input=self.config):
|
|
|
|
|
if 'answer' in outputs:
|
|
|
|
|
chunk = ServerSentEvent(
|
|
|
|
@ -173,7 +159,7 @@ class LangchainStreamingResponse(StreamingResponse):
|
|
|
|
|
chunk = ServerSentEvent(
|
|
|
|
|
data=dict(
|
|
|
|
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
|
|
|
detail=HTTPStatusDetail.INTERNAL_SERVER_ERROR,
|
|
|
|
|
detail="Internal Server Error",
|
|
|
|
|
),
|
|
|
|
|
event="error",
|
|
|
|
|
)
|
|
|
|
|