removed final langchain dependency

pull/570/head
Richard Anthony Hein 8 months ago
parent f422450347
commit 3ac06c76a3

@ -4,7 +4,6 @@ from functools import partial
from typing import Any, AsyncIterator from typing import Any, AsyncIterator
from fastapi import status from fastapi import status
from langchain.chains.base import Chain
from sse_starlette import ServerSentEvent from sse_starlette import ServerSentEvent
from sse_starlette.sse import EventSourceResponse, ensure_bytes from sse_starlette.sse import EventSourceResponse, ensure_bytes
from starlette.types import Send from starlette.types import Send
@ -76,104 +75,3 @@ class StreamingResponse(EventSourceResponse):
def enable_compression(self, force: bool=False): def enable_compression(self, force: bool=False):
raise NotImplementedError raise NotImplementedError
class LangchainStreamingResponse(StreamingResponse):
"""StreamingResponse class for LangChain resources."""
def __init__(
self,
*args: Any,
chain: Chain,
config: dict[str, Any],
run_mode: str,
**kwargs: dict[str, Any],
) -> None:
"""Constructor method.
Args:
chain: A LangChain instance.
config: A config dict.
*args: Positional arguments to pass to the parent constructor.
**kwargs: Keyword arguments to pass to the parent constructor.
"""
super().__init__(*args, **kwargs)
self.chain = chain
self.config = config
if run_mode not in list(["async", "sync"]):
raise ValueError(
f"Invalid run mode '{run_mode}'. Must be one of {list(['async', 'sync'])}"
)
self.run_mode = run_mode
async def stream_response(self, send: Send) -> None:
"""Stream LangChain outputs.
If an exception occurs while iterating over the LangChain, an
internal server error is sent to the client.
Args:
send: The ASGI send callable.
"""
await send(
{
"type": "http.response.start",
"status": self.status_code,
"headers": self.raw_headers,
}
)
if "callbacks" in self.config:
for callback in self.config["callbacks"]:
if hasattr(callback, "send"):
callback.send = send
try:
if self.run_mode == "async":
async for outputs in self.chain.astream(input=self.config):
if 'answer' in outputs:
chunk = ServerSentEvent(
data=outputs['answer']
)
# Send each chunk with the appropriate body type
await send(
{
"type": "http.response.body",
"body": ensure_bytes(chunk, None),
"more_body": True
}
)
else:
loop = asyncio.get_event_loop()
outputs = await loop.run_in_executor(
None, partial(self.chain, **self.config)
)
if self.background is not None:
self.background.kwargs.update({"outputs": outputs})
except Exception as e:
print(f"chain runtime error: {e}")
if self.background is not None:
self.background.kwargs.update({"outputs": {}, "error": e})
chunk = ServerSentEvent(
data=dict(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Internal Server Error",
),
event="error",
)
await send(
{
"type": "http.response.body",
"body": ensure_bytes(chunk, None),
"more_body": True,
}
)
await send({"type": "http.response.body", "body": b"", "more_body": False})
def enable_compression(self, force: bool=False):
raise NotImplementedError

Loading…
Cancel
Save