parent c86e62400a
commit b9ada3d1bb
@@ -0,0 +1,448 @@
# modified from Lanarky source code https://github.com/auxon/lanarky
from typing import Any, Optional

from fastapi.websockets import WebSocket
from langchain.callbacks.base import AsyncCallbackHandler
from langchain.callbacks.streaming_stdout_final_only import (
    FinalStreamingStdOutCallbackHandler,
)
from langchain.globals import get_llm_cache
from langchain.schema.document import Document
from pydantic import BaseModel
from starlette.types import Message, Send
from sse_starlette.sse import ensure_bytes, ServerSentEvent
from swarms.server.utils import StrEnum, model_dump_json


class LangchainEvents(StrEnum):
    SOURCE_DOCUMENTS = "source_documents"


class BaseCallbackHandler(AsyncCallbackHandler):
    """Base callback handler for streaming / async applications."""

    def __init__(self, **kwargs: dict[str, Any]) -> None:
        super().__init__(**kwargs)
        self.llm_cache_used = get_llm_cache() is not None

    @property
    def always_verbose(self) -> bool:
        """Verbose mode is always enabled."""
        return True

    async def on_chat_model_start(self, *args: Any, **kwargs: Any) -> Any: ...


class StreamingCallbackHandler(BaseCallbackHandler):
    """Callback handler for streaming responses."""

    def __init__(
        self,
        *,
        send: Send = None,
        **kwargs: dict[str, Any],
    ) -> None:
        """Constructor method.

        Args:
            send: The ASGI send callable.
            **kwargs: Keyword arguments to pass to the parent constructor.
        """
        super().__init__(**kwargs)

        self._send = send
        self.streaming = None

    @property
    def send(self) -> Send:
        return self._send

    @send.setter
    def send(self, value: Send) -> None:
        """Setter method for send property."""
        if not callable(value):
            raise ValueError("value must be a Callable")
        self._send = value

    def _construct_message(self, data: str, event: Optional[str] = None) -> Message:
        """Constructs message payload.

        Args:
            data: The data payload.
            event: The event name.
        """
        chunk = ServerSentEvent(data=data, event=event)
        return {
            "type": "http.response.body",
            "body": ensure_bytes(chunk, None),
            "more_body": True,
        }


class TokenStreamMode(StrEnum):
    TEXT = "text"
    JSON = "json"


class TokenEventData(BaseModel):
    """Event data payload for tokens."""

    token: str = ""


def get_token_data(token: str, mode: TokenStreamMode) -> str:
    """Get token data based on mode.

    Args:
        token: The token to use.
        mode: The stream mode.
    """
    if mode not in list(TokenStreamMode):
        raise ValueError(f"Invalid stream mode: {mode}")

    if mode == TokenStreamMode.TEXT:
        return token
    else:
        return model_dump_json(TokenEventData(token=token))


class TokenStreamingCallbackHandler(StreamingCallbackHandler):
    """Callback handler for streaming tokens."""

    def __init__(
        self,
        *,
        output_key: str,
        mode: TokenStreamMode = TokenStreamMode.JSON,
        **kwargs: dict[str, Any],
    ) -> None:
        """Constructor method.

        Args:
            output_key: chain output key.
            mode: The stream mode.
            **kwargs: Keyword arguments to pass to the parent constructor.
        """
        super().__init__(**kwargs)

        self.output_key = output_key

        if mode not in list(TokenStreamMode):
            raise ValueError(f"Invalid stream mode: {mode}")
        self.mode = mode

    async def on_chain_start(self, *args: Any, **kwargs: dict[str, Any]) -> None:
        """Run when chain starts running."""
        self.streaming = False

    async def on_llm_new_token(self, token: str, **kwargs: dict[str, Any]) -> None:
        """Run on new LLM token. Only available when streaming is enabled."""
        if not self.streaming:
            self.streaming = True

        if self.llm_cache_used:  # cache missed (or was never enabled) if we are here
            self.llm_cache_used = False

        message = self._construct_message(
            data=get_token_data(token, self.mode), event="completion"
        )
        await self.send(message)

    async def on_chain_end(
        self, outputs: dict[str, Any], **kwargs: dict[str, Any]
    ) -> None:
        """Run when chain ends running.

        Final output is streamed only if LLM cache is enabled.
        """
        if self.llm_cache_used or not self.streaming:
            if self.output_key in outputs:
                message = self._construct_message(
                    data=get_token_data(outputs[self.output_key], self.mode),
                    event="completion",
                )
                await self.send(message)
            else:
                raise KeyError(f"missing outputs key: {self.output_key}")


class SourceDocumentsEventData(BaseModel):
    """Event data payload for source documents."""

    source_documents: list[dict[str, Any]]


class SourceDocumentsStreamingCallbackHandler(StreamingCallbackHandler):
    """Callback handler for streaming source documents."""

    async def on_chain_end(
        self, outputs: dict[str, Any], **kwargs: dict[str, Any]
    ) -> None:
        """Run when chain ends running."""
        if "source_documents" in outputs:
            if not isinstance(outputs["source_documents"], list):
                raise ValueError("source_documents must be a list")
            if not isinstance(outputs["source_documents"][0], Document):
                raise ValueError("source_documents must be a list of Document")

            # NOTE: langchain is using pydantic_v1 for `Document`
            source_documents: list[dict] = [
                document.dict() for document in outputs["source_documents"]
            ]
            message = self._construct_message(
                data=model_dump_json(
                    SourceDocumentsEventData(source_documents=source_documents)
                ),
                event=LangchainEvents.SOURCE_DOCUMENTS,
            )
            await self.send(message)


class FinalTokenStreamingCallbackHandler(
    TokenStreamingCallbackHandler, FinalStreamingStdOutCallbackHandler
):
    """Callback handler for streaming final answer tokens.

    Useful for streaming responses from Langchain Agents.
    """

    def __init__(
        self,
        *,
        answer_prefix_tokens: Optional[list[str]] = None,
        strip_tokens: bool = True,
        stream_prefix: bool = False,
        **kwargs: dict[str, Any],
    ) -> None:
        """Constructor method.

        Args:
            answer_prefix_tokens: The answer prefix tokens to use.
            strip_tokens: Whether to strip tokens.
            stream_prefix: Whether to stream the answer prefix.
            **kwargs: Keyword arguments to pass to the parent constructor.
        """
        super().__init__(output_key=None, **kwargs)

        FinalStreamingStdOutCallbackHandler.__init__(
            self,
            answer_prefix_tokens=answer_prefix_tokens,
            strip_tokens=strip_tokens,
            stream_prefix=stream_prefix,
        )

    async def on_llm_start(self, *args: Any, **kwargs: dict[str, Any]) -> None:
        """Run when LLM starts running."""
        self.answer_reached = False
        self.streaming = False

    async def on_llm_new_token(self, token: str, **kwargs: dict[str, Any]) -> None:
        """Run on new LLM token. Only available when streaming is enabled."""
        if not self.streaming:
            self.streaming = True

        # Remember the last n tokens, where n = len(answer_prefix_tokens)
        self.append_to_last_tokens(token)

        # Check if the last n tokens match the answer_prefix_tokens list ...
        if self.check_if_answer_reached():
            self.answer_reached = True
            if self.stream_prefix:
                message = self._construct_message(
                    data=get_token_data("".join(self.last_tokens), self.mode),
                    event="completion",
                )
                await self.send(message)

        # ... if yes, then print tokens from now on
        if self.answer_reached:
            message = self._construct_message(
                data=get_token_data(token, self.mode), event="completion"
            )
            await self.send(message)


class WebSocketCallbackHandler(BaseCallbackHandler):
    """Callback handler for websocket sessions."""

    def __init__(
        self,
        *,
        mode: TokenStreamMode = TokenStreamMode.JSON,
        websocket: WebSocket = None,
        **kwargs: dict[str, Any],
    ) -> None:
        """Constructor method.

        Args:
            mode: The stream mode.
            websocket: The websocket to use.
            **kwargs: Keyword arguments to pass to the parent constructor.
        """
        super().__init__(**kwargs)

        if mode not in list(TokenStreamMode):
            raise ValueError(f"Invalid stream mode: {mode}")
        self.mode = mode

        self._websocket = websocket
        self.streaming = None

    @property
    def websocket(self) -> WebSocket:
        return self._websocket

    @websocket.setter
    def websocket(self, value: WebSocket) -> None:
        """Setter method for websocket property."""
        if not isinstance(value, WebSocket):
            raise ValueError("value must be a WebSocket")
        self._websocket = value

    def _construct_message(self, data: str, event: Optional[str] = None) -> Message:
        """Constructs message payload.

        Args:
            data: The data payload.
            event: The event name.
        """
        return dict(data=data, event=event)


class TokenWebSocketCallbackHandler(WebSocketCallbackHandler):
    """Callback handler for sending tokens in websocket sessions."""

    def __init__(self, *, output_key: str, **kwargs: dict[str, Any]) -> None:
        """Constructor method.

        Args:
            output_key: chain output key.
            **kwargs: Keyword arguments to pass to the parent constructor.
        """
        super().__init__(**kwargs)

        self.output_key = output_key

    async def on_chain_start(self, *args: Any, **kwargs: dict[str, Any]) -> None:
        """Run when chain starts running."""
        self.streaming = False

    async def on_llm_new_token(self, token: str, **kwargs: dict[str, Any]) -> None:
        """Run on new LLM token. Only available when streaming is enabled."""
        if not self.streaming:
            self.streaming = True

        if self.llm_cache_used:  # cache missed (or was never enabled) if we are here
            self.llm_cache_used = False

        message = self._construct_message(
            data=get_token_data(token, self.mode), event="completion"
        )
        await self.websocket.send_json(message)

    async def on_chain_end(
        self, outputs: dict[str, Any], **kwargs: dict[str, Any]
    ) -> None:
        """Run when chain ends running.

        Final output is streamed only if LLM cache is enabled.
        """
        if self.llm_cache_used or not self.streaming:
            if self.output_key in outputs:
                message = self._construct_message(
                    data=get_token_data(outputs[self.output_key], self.mode),
                    event="completion",
                )
                await self.websocket.send_json(message)
            else:
                raise KeyError(f"missing outputs key: {self.output_key}")


class SourceDocumentsWebSocketCallbackHandler(WebSocketCallbackHandler):
    """Callback handler for sending source documents in websocket sessions."""

    async def on_chain_end(
        self, outputs: dict[str, Any], **kwargs: dict[str, Any]
    ) -> None:
        """Run when chain ends running."""
        if "source_documents" in outputs:
            if not isinstance(outputs["source_documents"], list):
                raise ValueError("source_documents must be a list")
            if not isinstance(outputs["source_documents"][0], Document):
                raise ValueError("source_documents must be a list of Document")

            # NOTE: langchain is using pydantic_v1 for `Document`
            source_documents: list[dict] = [
                document.dict() for document in outputs["source_documents"]
            ]
            message = self._construct_message(
                data=model_dump_json(
                    SourceDocumentsEventData(source_documents=source_documents)
                ),
                event=LangchainEvents.SOURCE_DOCUMENTS,
            )
            await self.websocket.send_json(message)


class FinalTokenWebSocketCallbackHandler(
    TokenWebSocketCallbackHandler, FinalStreamingStdOutCallbackHandler
):
    """Callback handler for sending final answer tokens in websocket sessions.

    Useful for streaming responses from Langchain Agents.
    """

    def __init__(
        self,
        *,
        answer_prefix_tokens: Optional[list[str]] = None,
        strip_tokens: bool = True,
        stream_prefix: bool = False,
        **kwargs: dict[str, Any],
    ) -> None:
        """Constructor method.

        Args:
            answer_prefix_tokens: The answer prefix tokens to use.
            strip_tokens: Whether to strip tokens.
            stream_prefix: Whether to stream the answer prefix.
            **kwargs: Keyword arguments to pass to the parent constructor.
        """
        super().__init__(output_key=None, **kwargs)

        FinalStreamingStdOutCallbackHandler.__init__(
            self,
            answer_prefix_tokens=answer_prefix_tokens,
            strip_tokens=strip_tokens,
            stream_prefix=stream_prefix,
        )

    async def on_llm_start(self, *args, **kwargs) -> None:
        """Run when LLM starts running."""
        self.answer_reached = False
        self.streaming = False

    async def on_llm_new_token(self, token: str, **kwargs: dict[str, Any]) -> None:
        """Run on new LLM token. Only available when streaming is enabled."""
        if not self.streaming:
            self.streaming = True

        # Remember the last n tokens, where n = len(answer_prefix_tokens)
        self.append_to_last_tokens(token)

        # Check if the last n tokens match the answer_prefix_tokens list ...
        if self.check_if_answer_reached():
            self.answer_reached = True
            if self.stream_prefix:
                message = self._construct_message(
                    data=get_token_data("".join(self.last_tokens), self.mode),
                    event="completion",
                )
                await self.websocket.send_json(message)

        # ... if yes, then print tokens from now on
        if self.answer_reached:
            message = self._construct_message(
                data=get_token_data(token, self.mode), event="completion"
            )
            await self.websocket.send_json(message)
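A hedged usage sketch (not part of the commit): one way to attach the websocket callback handlers above to a LangChain call from inside a FastAPI websocket route. The chain object and its "question"/"answer" keys are assumptions for illustration only; the import path follows the one used elsewhere in this commit.

from fastapi import WebSocket
from langchain.chains.base import Chain

from swarms.server.callback_handlers import (
    SourceDocumentsWebSocketCallbackHandler,
    TokenWebSocketCallbackHandler,
)


async def stream_answer_over_websocket(chain: Chain, websocket: WebSocket, question: str) -> None:
    # Tokens and source documents are pushed to the connected client as JSON
    # messages while the chain runs; the handlers defined above do the sending.
    # `chain` is assumed to be built elsewhere (e.g. by the server's create_chain).
    await chain.acall(
        inputs={"question": question},
        callbacks=[
            TokenWebSocketCallbackHandler(output_key="answer", websocket=websocket),
            SourceDocumentsWebSocketCallbackHandler(websocket=websocket),
        ],
    )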
@@ -0,0 +1,179 @@
import asyncio
from functools import partial
from typing import Any

from fastapi import status
from langchain.chains.base import Chain
from sse_starlette import ServerSentEvent
from sse_starlette.sse import EventSourceResponse, ensure_bytes
from starlette.types import Send

from swarms.server.utils import StrEnum


class HTTPStatusDetail(StrEnum):
    INTERNAL_SERVER_ERROR = "Internal Server Error"


class StreamingResponse(EventSourceResponse):
    """`Response` class for streaming server-sent events.

    Follows the
    [EventSource protocol](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events#interfaces)
    """

    def __init__(
        self,
        content: Any = iter(()),
        *args: Any,
        **kwargs: dict[str, Any],
    ) -> None:
        """Constructor method.

        Args:
            content: The content to stream.
        """
        super().__init__(content=content, *args, **kwargs)

    async def stream_response(self, send: Send) -> None:
        """Streams data chunks to client by iterating over `content`.

        If an exception occurs while iterating over `content`, an
        internal server error is sent to the client.

        Args:
            send: The send function from the ASGI framework.
        """
        await send(
            {
                "type": "http.response.start",
                "status": self.status_code,
                "headers": self.raw_headers,
            }
        )

        try:
            async for data in self.body_iterator:
                chunk = ensure_bytes(data, self.sep)
                print(f"chunk: {chunk.decode()}")
                await send(
                    {"type": "http.response.body", "body": chunk, "more_body": True}
                )
        except Exception as e:
            print(f"body iterator error: {e}")
            chunk = ServerSentEvent(
                data=dict(
                    status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                    detail=HTTPStatusDetail.INTERNAL_SERVER_ERROR,
                ),
                event="error",
            )
            await send(
                {
                    "type": "http.response.body",
                    "body": ensure_bytes(chunk, None),
                    "more_body": True,
                }
            )

        await send({"type": "http.response.body", "body": b"", "more_body": False})


class ChainRunMode(StrEnum):
    """Enum for LangChain run modes."""

    ASYNC = "async"
    SYNC = "sync"


class LangchainStreamingResponse(StreamingResponse):
    """StreamingResponse class for LangChain resources."""

    def __init__(
        self,
        chain: Chain,
        config: dict[str, Any],
        run_mode: ChainRunMode = ChainRunMode.ASYNC,
        *args: Any,
        **kwargs: dict[str, Any],
    ) -> None:
        """Constructor method.

        Args:
            chain: A LangChain instance.
            config: A config dict.
            *args: Positional arguments to pass to the parent constructor.
            **kwargs: Keyword arguments to pass to the parent constructor.
        """
        super().__init__(*args, **kwargs)

        self.chain = chain
        self.config = config

        if run_mode not in list(ChainRunMode):
            raise ValueError(
                f"Invalid run mode '{run_mode}'. Must be one of {list(ChainRunMode)}"
            )

        self.run_mode = run_mode

    async def stream_response(self, send: Send) -> None:
        """Stream LangChain outputs.

        If an exception occurs while iterating over the LangChain, an
        internal server error is sent to the client.

        Args:
            send: The ASGI send callable.
        """
        await send(
            {
                "type": "http.response.start",
                "status": self.status_code,
                "headers": self.raw_headers,
            }
        )

        if "callbacks" in self.config:
            for callback in self.config["callbacks"]:
                if hasattr(callback, "send"):
                    callback.send = send

        try:
            # TODO: migrate to `.ainvoke` when adding support
            # for LCEL
            if self.run_mode == ChainRunMode.ASYNC:
                outputs = await self.chain.acall(**self.config)
            else:
                loop = asyncio.get_event_loop()
                outputs = await loop.run_in_executor(
                    None, partial(self.chain, **self.config)
                )
            if self.background is not None:
                self.background.kwargs.update({"outputs": outputs})
        except Exception as e:
            print(f"chain runtime error: {e}")
            if self.background is not None:
                self.background.kwargs.update({"outputs": {}, "error": e})
            chunk = ServerSentEvent(
                data=dict(
                    status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                    detail=HTTPStatusDetail.INTERNAL_SERVER_ERROR,
                ),
                event="error",
            )
            await send(
                {
                    "type": "http.response.body",
                    "body": ensure_bytes(chunk, None),
                    "more_body": True,
                }
            )

        await send({"type": "http.response.body", "body": b"", "more_body": False})
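A hedged usage sketch (not part of the commit): serving an async generator of ServerSentEvent items through the StreamingResponse class above, assuming this module is importable as swarms.server.responses (the path the server file below uses for LangchainStreamingResponse). The /ticks route and its payloads are illustrative only.

import asyncio

from fastapi import FastAPI
from sse_starlette import ServerSentEvent

from swarms.server.responses import StreamingResponse

app = FastAPI()


async def ticker():
    # Yield three events, one per second; each becomes one SSE frame on the wire.
    for i in range(3):
        yield ServerSentEvent(data=str(i), event="tick")
        await asyncio.sleep(1)


@app.get("/ticks")
async def ticks():
    return StreamingResponse(content=ticker())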
@@ -0,0 +1,69 @@
from typing import Any
from fastapi import status
from starlette.types import Send
from sse_starlette.sse import ensure_bytes, EventSourceResponse, ServerSentEvent


class StreamingResponse(EventSourceResponse):
    """`Response` class for streaming server-sent events.

    Follows the
    [EventSource protocol](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events#interfaces)
    """

    def __init__(
        self,
        content: Any = iter(()),
        *args: Any,
        **kwargs: dict[str, Any],
    ) -> None:
        """Constructor method.

        Args:
            content: The content to stream.
        """
        super().__init__(content=content, *args, **kwargs)

    async def stream_response(self, send: Send) -> None:
        """Streams data chunks to client by iterating over `content`.

        If an exception occurs while iterating over `content`, an
        internal server error is sent to the client.

        Args:
            send: The send function from the ASGI framework.
        """
        await send(
            {
                "type": "http.response.start",
                "status": self.status_code,
                "headers": self.raw_headers,
            }
        )

        try:
            async for data in self.body_iterator:
                chunk = ensure_bytes(data, self.sep)
                with open("log.txt", "a") as log_file:
                    log_file.write(f"chunk: {chunk.decode()}\n")
                await send(
                    {"type": "http.response.body", "body": chunk, "more_body": True}
                )
        except Exception as e:
            with open("log.txt", "a") as log_file:
                log_file.write(f"body iterator error: {e}\n")
            chunk = ServerSentEvent(
                data=dict(
                    status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                    detail="Internal Server Error",
                ),
                event="error",
            )
            await send(
                {
                    "type": "http.response.body",
                    "body": ensure_bytes(chunk, None),
                    "more_body": True,
                }
            )

        await send({"type": "http.response.body", "body": b"", "more_body": False})
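A hedged client-side sketch (not part of the commit): reading the event stream these response classes emit, using httpx. The host, port, and /ticks path are assumptions and must match wherever the app is actually served.

import asyncio

import httpx


async def read_sse(url: str = "http://localhost:8888/ticks") -> None:
    # Stream the response and print each non-empty SSE line ("event: ..." / "data: ...").
    async with httpx.AsyncClient(timeout=None) as client:
        async with client.stream("GET", url) as response:
            async for line in response.aiter_lines():
                if line:
                    print(line)


if __name__ == "__main__":
    asyncio.run(read_sse())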
@@ -0,0 +1,445 @@
import asyncio
import json
import logging
import os
from datetime import datetime
from typing import List

import langchain
from pydantic import ValidationError, parse_obj_as
from swarms.prompts.chat_prompt import Message, Role
from swarms.server.callback_handlers import (
    SourceDocumentsStreamingCallbackHandler,
    TokenStreamingCallbackHandler,
)
import tiktoken

# import torch
from dotenv import load_dotenv
from fastapi import FastAPI, File, Form, HTTPException, Request, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, JSONResponse
from fastapi.routing import APIRouter
from fastapi.staticfiles import StaticFiles
from huggingface_hub import login
from langchain.callbacks import StreamingStdOutCallbackHandler
from langchain.memory import ConversationSummaryBufferMemory
from langchain.memory.chat_message_histories.in_memory import ChatMessageHistory
from langchain.prompts.prompt import PromptTemplate
from langchain_community.chat_models import ChatOpenAI
from swarms.server.responses import LangchainStreamingResponse
from langchain.chains.conversational_retrieval.base import ConversationalRetrievalChain
from langchain.chains.llm import LLMChain
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
from swarms.prompts.conversational_RAG import (
    B_INST,
    B_SYS,
    CONDENSE_PROMPT_TEMPLATE,
    DOCUMENT_PROMPT_TEMPLATE,
    E_INST,
    E_SYS,
    QA_PROMPT_TEMPLATE,
    SUMMARY_PROMPT_TEMPLATE,
)

from swarms.server.vector_store import VectorStorage

from swarms.server.server_models import (
    ChatRequest,
    LogMessage,
    AIModel,
    AIModels,
    RAGFile,
    RAGFiles,
    State,
    GetRAGFileStateRequest,
    ProcessRAGFileRequest,
)

# Explicitly specify the path to the .env file
dotenv_path = os.path.join(os.path.dirname(__file__), ".env")
load_dotenv(dotenv_path)

hf_token = os.environ.get("HUGGINFACEHUB_API_KEY")  # Get the Huggingface API Token
uploads = os.environ.get("UPLOADS")  # Directory where user uploads files to be parsed for RAG
model_dir = os.environ.get("MODEL_DIR")

# huggingface.co model (eg. meta-llama/Llama-2-70b-hf)
model_name = os.environ.get("MODEL_NAME")

# Set OpenAI's API key to 'EMPTY' and API base URL to use vLLM's API server, or set them to OpenAI API key and base URL.
openai_api_key = os.environ.get("OPENAI_API_KEY") or "EMPTY"
openai_api_base = os.environ.get("OPENAI_API_BASE") or "http://localhost:8000/v1"

env_vars = {
    "HUGGINFACEHUB_API_KEY": hf_token,
    "UPLOADS": uploads,
    "MODEL_DIR": model_dir,
    "MODEL_NAME": model_name,
    "OPENAI_API_KEY": openai_api_key,
    "OPENAI_API_BASE": openai_api_base,
}
missing_vars = [name for name, value in env_vars.items() if not value]

if missing_vars:
    print(
        f"Error: The following environment variables are not set: {', '.join(missing_vars)}"
    )
    exit(1)

useMetal = os.environ.get("USE_METAL", "False") == "True"

print(f"Uploads={uploads}")
print(f"MODEL_DIR={model_dir}")
print(f"MODEL_NAME={model_name}")
print(f"USE_METAL={useMetal}")
print(f"OPENAI_API_KEY={openai_api_key}")
print(f"OPENAI_API_BASE={openai_api_base}")

# update tiktoken to include the model name (avoids warning message)
tiktoken.model.MODEL_TO_ENCODING.update(
    {
        model_name: "cl100k_base",
    }
)

print("Logging in to huggingface.co...")
login(token=hf_token)  # login to huggingface.co

# langchain.debug = True
langchain.verbose = True

from contextlib import asynccontextmanager


@asynccontextmanager
async def lifespan(app: FastAPI):
    asyncio.create_task(vector_store.initRetrievers())
    yield


app = FastAPI(title="Chatbot", lifespan=lifespan)
router = APIRouter()

current_dir = os.path.dirname(__file__)
print("current_dir: " + current_dir)
static_dir = os.path.join(current_dir, "static")
print("static_dir: " + static_dir)
app.mount(static_dir, StaticFiles(directory=static_dir), name="static")

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["GET", "POST"],
    allow_headers=["*"],
)


# Create ./uploads folder if it doesn't exist
uploads = uploads or os.path.join(os.getcwd(), "uploads")
if not os.path.exists(uploads):
    os.makedirs(uploads)

# Initialize the vector store
vector_store = VectorStorage(directory=uploads)


async def create_chain(
    messages: list[Message],
    model=model_dir,
    max_tokens_to_gen=2048,
    temperature=0.5,
    prompt: PromptTemplate = QA_PROMPT_TEMPLATE,
    file: RAGFile | None = None,
    key: str | None = None,
):
    print(
        f"Creating chain with key={key}, model={model}, max_tokens={max_tokens_to_gen}, temperature={temperature}, prompt={prompt}, file={file.title}"
    )

    llm = ChatOpenAI(
        api_key=openai_api_key,
        base_url=openai_api_base,
        model=model_name,
        verbose=True,
        streaming=True,
    )

    # if llm is ALlamaCpp:
    #     llm.max_tokens = max_tokens_to_gen
    # elif llm is AGPT4All:
    #     llm.n_predict = max_tokens_to_gen
    # el
    # if llm is AChatOllama:
    #     llm.max_tokens = max_tokens_to_gen
    # if llm is VLLMAsync:
    #     llm.max_tokens = max_tokens_to_gen

    retriever = await vector_store.getRetriever(
        os.path.join(file.username, file.filename)
    )

    chat_memory = ChatMessageHistory()

    for message in messages:
        if message.role == Role.HUMAN:
            chat_memory.add_user_message(message.content)
        elif message.role == Role.AI:
            chat_memory.add_ai_message(message.content)
        elif message.role == Role.SYSTEM:
            chat_memory.add_message(message.content)
        elif message.role == Role.FUNCTION:
            chat_memory.add_message(message.content)

    memory = ConversationSummaryBufferMemory(
        llm=llm,
        chat_memory=chat_memory,
        memory_key="chat_history",
        input_key="question",
        output_key="answer",
        prompt=SUMMARY_PROMPT_TEMPLATE,
        return_messages=True,
    )

    question_generator = LLMChain(
        llm=llm,
        prompt=CONDENSE_PROMPT_TEMPLATE,
        memory=memory,
        verbose=True,
        output_key="answer",
    )

    stuff_chain = LLMChain(
        llm=llm,
        prompt=prompt,
        verbose=True,
        output_key="answer",
    )

    doc_chain = StuffDocumentsChain(
        llm_chain=stuff_chain,
        document_variable_name="context",
        document_prompt=DOCUMENT_PROMPT_TEMPLATE,
        verbose=True,
        output_key="answer",
        memory=memory,
    )

    return ConversationalRetrievalChain(
        combine_docs_chain=doc_chain,
        memory=memory,
        retriever=retriever,
        question_generator=question_generator,
        return_generated_question=False,
        return_source_documents=True,
        output_key="answer",
        verbose=True,
    )


@router.post(
    "/chat",
    summary="Chatbot",
    description="Chatbot AI Service",
)
async def chat(request: ChatRequest):
    chain: ConversationalRetrievalChain = await create_chain(
        file=request.file,
        messages=request.messages[:-1],
        model=request.model.id,
        max_tokens_to_gen=request.maxTokens,
        temperature=request.temperature,
        prompt=PromptTemplate.from_template(
            f"{B_INST}{B_SYS}{request.prompt.strip()}{E_SYS}{E_INST}"
        ),
    )

    # async for token in chain.astream(request.messages[-1].content):
    #     print(f"token={token}")

    json_string = json.dumps(
        {
            "question": request.messages[-1].content,
            # "chat_history": [message.content for message in request.messages[:-1]],
        }
    )
    return LangchainStreamingResponse(
        chain,
        config={
            "inputs": json_string,
            "callbacks": [
                StreamingStdOutCallbackHandler(),
                TokenStreamingCallbackHandler(output_key="answer"),
                SourceDocumentsStreamingCallbackHandler(),
            ],
        },
    )


app.include_router(router, tags=["chat"])


@app.get("/")
def root():
    return {"message": "Chatbot API"}


@app.get("/favicon.ico")
def favicon():
    file_name = "favicon.ico"
    file_path = os.path.join(app.root_path, "static", file_name)
    return FileResponse(
        path=file_path,
        headers={"Content-Disposition": "attachment; filename=" + file_name},
    )


@app.post("/log")
def log_message(log_message: LogMessage):
    try:
        with open("log.txt", "a") as log_file:
            log_file.write(log_message.message + "\n")
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error saving log: {e}")
    return {"message": "Log saved successfully"}


@app.get("/models")
def get_models():
    # llama7B = AIModel(
    #     id="llama-2-7b-chat-ggml-q4_0",
    #     name="llama-2-7b-chat-ggml-q4_0",
    #     maxLength=2048,
    #     tokenLimit=2048,
    # )
    # llama13B = AIModel(
    #     id="llama-2-13b-chat-ggml-q4_0",
    #     name="llama-2-13b-chat-ggml-q4_0",
    #     maxLength=2048,
    #     tokenLimit=2048,
    # )
    llama70B = AIModel(
        id="llama-2-70b.Q5_K_M",
        name="llama-2-70b.Q5_K_M",
        maxLength=2048,
        tokenLimit=2048,
    )
    models = AIModels(models=[llama70B])
    return models


@app.get("/titles")
def getTitles():
    titles = RAGFiles(
        titles=[
            # RAGFile(
            #     versionId="d8ad3b1d-c33c-4524-9691-e93967d4d863",
            #     title="d8ad3b1d-c33c-4524-9691-e93967d4d863",
            #     state=State.Unavailable,
            # ),
            RAGFile(
                versionId=collection.name,
                title=collection.name,
                state=State.InProcess
                if collection.name in processing_books
                else State.Processed,
            )
            for collection in vector_store.list_collections()
            if collection.name != "langchain"
        ]
    )
    return titles


processing_books: list[str] = []
processing_books_lock = asyncio.Lock()

logging.basicConfig(level=logging.ERROR)


@app.post("/titleState")
async def getTitleState(request: GetRAGFileStateRequest):
    # FastAPI + Pydantic will throw a 422 Unprocessable Entity if the request isn't the right type.
    # try:
    logging.debug(f"Received getTitleState request: {request}")
    titleStateRequest: GetRAGFileStateRequest = request
    # except ValidationError as e:
    #     print(f"Error validating JSON: {e}")
    #     raise HTTPException(status_code=422, detail=str(e))
    # except json.JSONDecodeError as e:
    #     print(f"Error parsing JSON: {e}")
    #     raise HTTPException(status_code=422, detail="Invalid JSON format")
    # check to see if the book has already been processed.
    # return the proper State directly to response.
    matchingCollection = next(
        (
            x
            for x in vector_store.list_collections()
            if x.name == titleStateRequest.versionRef
        ),
        None,
    )
    print("Got a Title State request for version " + titleStateRequest.versionRef)
    if titleStateRequest.versionRef in processing_books:
        return {"message": State.InProcess}
    elif matchingCollection is not None:
        return {"message": State.Processed}
    else:
        return {"message": State.Unavailable}


@app.post("/processRAGFile")
async def processRAGFile(
    request: str = Form(...),
    files: List[UploadFile] = File(...),
):
    try:
        logging.debug(f"Received processBook request: {request}")
        # Parse the JSON string into a ProcessRAGFileRequest object
        fileRAGRequest: ProcessRAGFileRequest = parse_obj_as(
            ProcessRAGFileRequest, json.loads(request)
        )
    except ValidationError as e:
        print(f"Error validating JSON: {e}")
        raise HTTPException(status_code=422, detail=str(e))
    except json.JSONDecodeError as e:
        print(f"Error parsing JSON: {e}")
        raise HTTPException(status_code=422, detail="Invalid JSON format")

    try:
        print(
            f"Processing file {fileRAGRequest.filename} for user {fileRAGRequest.username}."
        )
        # check to see if the file has already been processed.
        # write html to subfolder
        print(
            f"Writing file to path: {fileRAGRequest.username}/{fileRAGRequest.filename}..."
        )

        for index, segment in enumerate(files):
            filename = segment.filename if segment.filename else str(index)
            subDir = f"{fileRAGRequest.username}"
            with open(os.path.join(subDir, filename), "wb") as htmlFile:
                htmlFile.write(await segment.read())

        # write metadata to subfolder
        print(f"Writing metadata to subfolder {fileRAGRequest.username}...")
        with open(
            os.path.join(fileRAGRequest.username, "metadata.json"), "w"
        ) as metadataFile:
            metaData = {
                "filename": fileRAGRequest.filename,
                "username": fileRAGRequest.username,
                "processDate": datetime.now().isoformat(),
            }
            metadataFile.write(json.dumps(metaData))

        vector_store.retrievers[
            f"{fileRAGRequest.username}/{fileRAGRequest.filename}"
        ] = await vector_store.initRetriever(
            f"{fileRAGRequest.username}/{fileRAGRequest.filename}"
        )

        return {
            "message": f"File {fileRAGRequest.filename} processed successfully."
        }
    except Exception as e:
        logging.error(f"Error processing book: {e}")
        return {"message": f"Error processing book: {e}"}


@app.exception_handler(HTTPException)
async def http_exception_handler(bookRequest: Request, exc: HTTPException):
    logging.error(f"HTTPException: {exc.detail}")
    return JSONResponse(status_code=exc.status_code, content={"detail": exc.detail})
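A hedged sketch (not part of the commit) of serving the FastAPI app above with uvicorn. The module path "swarms.server.server" and the port are assumptions, since the file's final location is not shown in this diff.

import uvicorn

if __name__ == "__main__":
    # Assumed import string; replace with the actual module path of the FastAPI app.
    uvicorn.run("swarms.server.server:app", host="0.0.0.0", port=8888)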
@@ -0,0 +1,88 @@
try:
    from enum import StrEnum
except ImportError:
    from strenum import StrEnum

from pydantic import BaseModel
from swarms.prompts import QA_PROMPT_TEMPLATE_STR as DefaultSystemPrompt


class AIModel(BaseModel):
    id: str
    name: str
    maxLength: int
    tokenLimit: int


class AIModels(BaseModel):
    models: list[AIModel]


class State(StrEnum):
    Unavailable = "Unavailable"
    InProcess = "InProcess"
    Processed = "Processed"


class RAGFile(BaseModel):
    filename: str
    title: str
    username: str
    state: State = State.Unavailable


class RAGFiles(BaseModel):
    files: list[RAGFile]


class Role(StrEnum):
    SYSTEM = "system"
    ASSISTANT = "assistant"
    USER = "user"


class Message(BaseModel):
    role: Role
    content: str


class ChatRequest(BaseModel):
    id: str
    model: AIModel = AIModel(
        id="llama-2-70b.Q5_K_M",
        name="llama-2-70b.Q5_K_M",
        maxLength=2048,
        tokenLimit=2048,
    )
    messages: list[Message] = [
        Message(role=Role.SYSTEM, content="Hello, how may I help you?"),
        Message(role=Role.USER, content=""),
    ]
    maxTokens: int = 2048
    temperature: float = 0
    prompt: str = DefaultSystemPrompt
    file: RAGFile = RAGFile(filename="None", title="None", username="None")


class LogMessage(BaseModel):
    message: str


class ConversationRequest(BaseModel):
    id: str
    name: str
    title: RAGFile
    messages: list[Message]
    model: AIModel
    prompt: str
    temperature: float
    folderId: str | None = None


class ProcessRAGFileRequest(BaseModel):
    filename: str
    username: str


class GetRAGFileStateRequest(BaseModel):
    filename: str
    username: str
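A hedged sketch (not part of the commit): building a ChatRequest in code and serializing it the way a client would POST it to /chat. Field values are illustrative; the module paths follow the imports used elsewhere in this commit.

from swarms.server.server_models import ChatRequest, Message, RAGFile, Role
from swarms.server.utils import model_dump_json

request = ChatRequest(
    id="demo-conversation",
    messages=[
        Message(role=Role.SYSTEM, content="Hello, how may I help you?"),
        Message(role=Role.USER, content="What does chapter 2 cover?"),
    ],
    file=RAGFile(filename="book.html", title="Book", username="demo"),
)
print(model_dump_json(request))  # JSON body suitable for POST /chat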
@@ -0,0 +1,78 @@
from langchain.prompts.prompt import PromptTemplate

B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"

QA_CONDENSE_TEMPLATE_STR = (
    "Given the following Chat History and a Follow Up Question, "
    "rephrase the follow up question to be a new Standalone Question, "
    "but make sure the new question is still asking for the same "
    "information as the original follow up question. Respond only "
    " with the new Standalone Question. \n"
    "Chat History: \n"
    "{chat_history} \n"
    "Follow Up Question: {question} \n"
    "Standalone Question:"
)

CONDENSE_PROMPT_TEMPLATE = PromptTemplate.from_template(
    f"{B_INST}{B_SYS}{QA_CONDENSE_TEMPLATE_STR.strip()}{E_SYS}{E_INST}"
)

QA_PROMPT_TEMPLATE_STR = (
    "HUMAN: \n You are a helpful AI assistant. "
    "Use the following context and chat history to answer the "
    "question at the end with a helpful answer. "
    "Get straight to the point and always think things through step-by-step before answering. "
    "If you don't know the answer, just say 'I don't know'; "
    "don't try to make up an answer. \n\n"
    "<context>{context}</context>\n"
    "<chat_history>{chat_history}</chat_history>\n"
    "<question>{question}</question>\n\n"
    "AI: Here is the most relevant sentence in the context: \n"
)

QA_PROMPT_TEMPLATE = PromptTemplate.from_template(
    f"{B_INST}{B_SYS}{QA_PROMPT_TEMPLATE_STR.strip()}{E_SYS}{E_INST}"
)

DOCUMENT_PROMPT_TEMPLATE = PromptTemplate(
    input_variables=["page_content"], template="{page_content}"
)

_STUFF_PROMPT_TEMPLATE_STR = "Summarize the following context: {context}"

STUFF_PROMPT_TEMPLATE = PromptTemplate.from_template(
    f"{B_INST}{B_SYS}{_STUFF_PROMPT_TEMPLATE_STR.strip()}{E_SYS}{E_INST}"
)

_SUMMARIZER_SYS_TEMPLATE = (
    B_INST
    + B_SYS
    + """Progressively summarize the lines of conversation provided, adding onto the previous summary returning a new summary.
EXAMPLE
Current summary:
The human asks what the AI thinks of artificial intelligence. The AI thinks artificial intelligence is a force for good.
New lines of conversation:
Human: Why do you think artificial intelligence is a force for good?
AI: Because artificial intelligence will help humans reach their full potential.
New summary:
The human asks what the AI thinks of artificial intelligence. The AI thinks artificial intelligence is a force for good because it will help humans reach their full potential.
END OF EXAMPLE"""
    + E_SYS
    + E_INST
)

_SUMMARIZER_INST_TEMPLATE = (
    B_INST
    + """Current summary:
{summary}
New lines of conversation:
{new_lines}
New summary:"""
    + E_INST
)

SUMMARY_PROMPT_TEMPLATE = PromptTemplate.from_template(
    template=(_SUMMARIZER_SYS_TEMPLATE + "\n" + _SUMMARIZER_INST_TEMPLATE).strip()
)
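A hedged sketch (not part of the commit): rendering the QA prompt with its three input variables. The context, history, and question strings are illustrative only.

from swarms.prompts.conversational_RAG import QA_PROMPT_TEMPLATE

print(
    QA_PROMPT_TEMPLATE.format(
        context="The warranty period is 24 months from the date of purchase.",
        chat_history="Human: Hi\nAI: Hello! How can I help?",
        question="How long is the warranty?",
    )
)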
(binary image file added: 146 KiB)
@@ -0,0 +1,51 @@
# modified from Lanarky source https://github.com/auxon/lanarky
from typing import Any

import pydantic
from pydantic.fields import FieldInfo

try:
    from enum import StrEnum  # type: ignore
except ImportError:
    from enum import Enum

    class StrEnum(str, Enum): ...


PYDANTIC_V2 = pydantic.VERSION.startswith("2.")


def model_dump(model: pydantic.BaseModel, **kwargs) -> dict[str, Any]:
    """Dump a pydantic model to a dictionary.

    Args:
        model: A pydantic model.
    """
    if PYDANTIC_V2:
        return model.model_dump(**kwargs)
    else:
        return model.dict(**kwargs)


def model_dump_json(model: pydantic.BaseModel, **kwargs) -> str:
    """Dump a pydantic model to a JSON string.

    Args:
        model: A pydantic model.
    """
    if PYDANTIC_V2:
        return model.model_dump_json(**kwargs)
    else:
        return model.json(**kwargs)


def model_fields(model: pydantic.BaseModel) -> dict[str, FieldInfo]:
    """Get the fields of a pydantic model.

    Args:
        model: A pydantic model.
    """
    if PYDANTIC_V2:
        return model.model_fields
    else:
        return model.__fields__
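A hedged sketch (not part of the commit): the helpers above give a single call site that works on both pydantic v1 and v2. The Point model is illustrative only.

import pydantic

from swarms.server.utils import model_dump, model_dump_json


class Point(pydantic.BaseModel):
    x: int
    y: int


point = Point(x=1, y=2)
print(model_dump(point))       # dict via .model_dump() on v2, .dict() on v1
print(model_dump_json(point))  # JSON string via .model_dump_json() on v2, .json() on v1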