|
|
|
@ -4,7 +4,6 @@ from functools import partial
|
|
|
|
|
from typing import Any, AsyncIterator
|
|
|
|
|
|
|
|
|
|
from fastapi import status
|
|
|
|
|
from langchain.chains.base import Chain
|
|
|
|
|
from sse_starlette import ServerSentEvent
|
|
|
|
|
from sse_starlette.sse import EventSourceResponse, ensure_bytes
|
|
|
|
|
from starlette.types import Send
|
|
|
|
@ -76,104 +75,3 @@ class StreamingResponse(EventSourceResponse):
|
|
|
|
|
|
|
|
|
|
def enable_compression(self, force: bool = False):
    """Refuse to enable compression.

    Args:
        force: Accepted for signature compatibility; never read.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError
|
|
|
|
|
|
|
|
|
|
class LangchainStreamingResponse(StreamingResponse):
    """StreamingResponse class for LangChain resources."""

    # Execution modes accepted by the constructor.
    _RUN_MODES = ("async", "sync")

    def __init__(
        self,
        *args: Any,
        chain: Chain,
        config: dict[str, Any],
        run_mode: str,
        **kwargs: Any,
    ) -> None:
        """Constructor method.

        Args:
            *args: Positional arguments to pass to the parent constructor.
            chain: A LangChain instance.
            config: A config dict. It is passed to the chain when the
                response is streamed; an optional ``"callbacks"`` entry is
                inspected for callbacks exposing a ``send`` attribute.
            run_mode: How the chain is executed, ``"async"`` or ``"sync"``.
            **kwargs: Keyword arguments to pass to the parent constructor.

        Raises:
            ValueError: If ``run_mode`` is not ``"async"`` or ``"sync"``.
        """
        super().__init__(*args, **kwargs)

        self.chain = chain
        self.config = config

        if run_mode not in self._RUN_MODES:
            raise ValueError(
                f"Invalid run mode '{run_mode}'. "
                f"Must be one of {list(self._RUN_MODES)}"
            )

        self.run_mode = run_mode

    @staticmethod
    def _body_message(event: ServerSentEvent) -> dict[str, Any]:
        """Wrap an SSE event in a non-final ``http.response.body`` message."""
        return {
            "type": "http.response.body",
            "body": ensure_bytes(event, None),
            "more_body": True,
        }

    async def stream_response(self, send: Send) -> None:
        """Stream LangChain outputs.

        In ``"async"`` mode the chain is iterated with ``astream`` and every
        chunk containing an ``"answer"`` key is sent as a server-sent event.
        In ``"sync"`` mode the chain is called in the event loop's default
        executor; nothing is sent from here for its result.

        If an exception occurs while running the LangChain, an
        internal server error is sent to the client as an ``error`` event.

        Args:
            send: The ASGI send callable.
        """
        await send(
            {
                "type": "http.response.start",
                "status": self.status_code,
                "headers": self.raw_headers,
            }
        )

        # Give every callback that exposes a ``send`` attribute the ASGI
        # send callable of this response.
        if "callbacks" in self.config:
            for callback in self.config["callbacks"]:
                if hasattr(callback, "send"):
                    callback.send = send

        try:
            # Pre-bind so an empty stream does not leave ``outputs`` undefined
            # (which would be reported to the client as a runtime error).
            outputs: dict[str, Any] = {}
            if self.run_mode == "async":
                async for outputs in self.chain.astream(input=self.config):
                    if "answer" in outputs:
                        event = ServerSentEvent(data=outputs["answer"])
                        # Send each chunk with the appropriate body type
                        await send(self._body_message(event))
            else:
                # Run the blocking chain call off the event loop thread.
                loop = asyncio.get_running_loop()
                outputs = await loop.run_in_executor(
                    None, partial(self.chain, **self.config)
                )
            # In async mode ``outputs`` is the last streamed chunk, if any.
            if self.background is not None:
                self.background.kwargs.update({"outputs": outputs})
        except Exception as e:
            # Broad catch is deliberate: the response start has already been
            # sent, so the failure is reported in-band as an SSE error event.
            print(f"chain runtime error: {e}")
            if self.background is not None:
                self.background.kwargs.update({"outputs": {}, "error": e})
            error_event = ServerSentEvent(
                data=dict(
                    status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                    detail="Internal Server Error",
                ),
                event="error",
            )
            await send(self._body_message(error_event))

        # Empty final body marks the end of the response.
        await send({"type": "http.response.body", "body": b"", "more_body": False})

    def enable_compression(self, force: bool = False):
        """Compression is not supported; always raises NotImplementedError."""
        raise NotImplementedError
|
|
|
|
|
|
|
|
|
|