You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
258 lines
7.3 KiB
258 lines
7.3 KiB
from __future__ import annotations
|
|
|
|
from abc import abstractmethod
|
|
from typing import TYPE_CHECKING, Any, Dict, List, Sequence
|
|
|
|
from pydantic import Field
|
|
|
|
from swarms.utils.serializable import Serializable
|
|
|
|
if TYPE_CHECKING:
|
|
from langchain.prompts.chat import ChatPromptTemplate
|
|
|
|
|
|
def get_buffer_string(
    messages: Sequence[BaseMessage], human_prefix: str = "Human", ai_prefix: str = "AI"
) -> str:
    """Render a sequence of messages as a single newline-separated transcript.

    Every message becomes one ``"<role>: <content>"`` entry. For an AIMessage
    whose ``additional_kwargs`` contains ``"function_call"``, that value is
    appended directly after the content (no separator is inserted).

    Args:
        messages: Messages to render, in order.
        human_prefix: Role label used for HumanMessages.
        ai_prefix: Role label used for AIMessages.

    Returns:
        All rendered entries joined with a newline; an empty string when
        ``messages`` is empty.

    Raises:
        ValueError: If a message is not a Human, AI, System, Function or
            Chat message.

    Example:
        .. code-block:: python

            messages = [
                HumanMessage(content="Hi, how are you?"),
                AIMessage(content="Good, how are you?"),
            ]
            get_buffer_string(messages)
            # -> "Human: Hi, how are you?\\nAI: Good, how are you?"
    """

    def _speaker(msg) -> str:
        # Classes are tested in a fixed order; the first match decides the label.
        if isinstance(msg, HumanMessage):
            return human_prefix
        if isinstance(msg, AIMessage):
            return ai_prefix
        if isinstance(msg, SystemMessage):
            return "System"
        if isinstance(msg, FunctionMessage):
            return "Function"
        if isinstance(msg, ChatMessage):
            # Chat messages carry their own speaker label.
            return msg.role
        raise ValueError(f"Got unsupported message type: {msg}")

    rendered = []
    for msg in messages:
        entry = f"{_speaker(msg)}: {msg.content}"
        if isinstance(msg, AIMessage) and "function_call" in msg.additional_kwargs:
            entry += f"{msg.additional_kwargs['function_call']}"
        rendered.append(entry)

    return "\n".join(rendered)
|
|
|
|
|
|
class BaseMessage(Serializable):
    """Abstract base class shared by every message type.

    Messages are what ChatModels take as input and produce as output.
    """

    content: str
    """Text carried by the message."""

    additional_kwargs: dict = Field(default_factory=dict)
    """Extra information attached to the message; a new empty dict by default."""

    @property
    @abstractmethod
    def type(self) -> str:
        """Identifier of the message kind, used for serialization."""

    @property
    def lc_serializable(self) -> bool:
        """Report whether this class is LangChain serializable (always ``True``)."""
        return True

    def __add__(self, other: Any) -> ChatPromptTemplate:
        """Combine this message with ``other`` into a chat prompt.

        The message is wrapped in a one-message ``ChatPromptTemplate`` and
        ``other`` is then added to that template.
        """
        # Imported here because the module-level import only exists under
        # TYPE_CHECKING.
        from langchain.prompts.chat import ChatPromptTemplate

        return ChatPromptTemplate(messages=[self]) + other
|
|
|
|
|
|
class BaseMessageChunk(BaseMessage):
    """A partial message that can be concatenated with other chunks using ``+``."""

    def _merge_kwargs_dict(
        self, left: Dict[str, Any], right: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Merge two ``additional_kwargs`` dicts into a new dict.

        Neither input is mutated: ``left`` is shallow-copied and nested dicts
        are rebuilt by the recursive merge.

        Args:
            left: Base mapping.
            right: Mapping merged on top of ``left``.

        Returns:
            A new dict. Keys found in only one mapping are kept unchanged,
            string values found in both are concatenated (left then right),
            and dict values found in both are merged recursively.

        Raises:
            ValueError: If a key is present in both mappings and the existing
                value is not an instance of the incoming value's type, or the
                shared value is neither a ``str`` nor a ``dict``.
        """
        merged = left.copy()
        for key, value in right.items():
            if key not in merged:
                merged[key] = value
                continue

            current = merged[key]
            if not isinstance(current, type(value)):
                raise ValueError(
                    f'additional_kwargs["{key}"] already exists in this message,'
                    " but with a different type."
                )
            if isinstance(current, str):
                merged[key] += value
            elif isinstance(current, dict):
                merged[key] = self._merge_kwargs_dict(current, value)
            else:
                raise ValueError(
                    f"Additional kwargs key {key} already exists in this message."
                )
        return merged

    def __add__(self, other: Any) -> BaseMessageChunk:  # type: ignore
        """Concatenate this chunk with another chunk.

        The result is an instance of this chunk's own class. Its ``content``
        is the two contents joined (self then other) and its
        ``additional_kwargs`` is the merge of both chunks' kwargs. All other
        fields are taken from this (left-hand) chunk.

        Args:
            other: The chunk to append; must be a ``BaseMessageChunk``.

        Returns:
            A new chunk of type ``self.__class__``.

        Raises:
            TypeError: If ``other`` is not a ``BaseMessageChunk``.
            ValueError: If the two ``additional_kwargs`` dicts cannot be merged.
        """
        if not isinstance(other, BaseMessageChunk):
            raise TypeError(
                'unsupported operand type(s) for +: "'
                f"{self.__class__.__name__}"
                f'" and "{other.__class__.__name__}"'
            )

        # Start from this chunk's own field values. Passing only content and
        # additional_kwargs cannot construct subclasses that declare required
        # fields (ChatMessageChunk.role, FunctionMessageChunk.name) and would
        # reset optional ones such as `example` to their defaults.
        fields = self.dict()
        fields["content"] = self.content + other.content
        fields["additional_kwargs"] = self._merge_kwargs_dict(
            self.additional_kwargs, other.additional_kwargs
        )
        return self.__class__(**fields)
|
|
|
|
|
|
class HumanMessage(BaseMessage):
    """A message authored by a human."""

    example: bool = False
    """Marks the message as part of an example conversation passed to the model."""

    @property
    def type(self) -> str:
        """Serialization tag for this message kind: ``"human"``."""
        return "human"
|
|
|
|
|
|
class HumanMessageChunk(HumanMessage, BaseMessageChunk):
    """Chunk form of ``HumanMessage``; declares no fields or methods of its own."""
|
|
|
|
|
|
class AIMessage(BaseMessage):
    """A message produced by an AI."""

    example: bool = False
    """Marks the message as part of an example conversation passed to the model."""

    @property
    def type(self) -> str:
        """Serialization tag for this message kind: ``"ai"``."""
        return "ai"
|
|
|
|
|
|
class AIMessageChunk(AIMessage, BaseMessageChunk):
    """Chunk form of ``AIMessage``; declares no fields or methods of its own."""
|
|
|
|
|
|
class SystemMessage(BaseMessage):
    """A message that primes AI behavior.

    It is usually supplied as the first entry of a sequence of input messages.
    """

    @property
    def type(self) -> str:
        """Serialization tag for this message kind: ``"system"``."""
        return "system"
|
|
|
|
|
|
class SystemMessageChunk(SystemMessage, BaseMessageChunk):
    """Chunk form of ``SystemMessage``; declares no fields or methods of its own."""
|
|
|
|
|
|
class FunctionMessage(BaseMessage):
    """A message that hands the result of an executed function back to a model."""

    name: str
    """Name of the function that was executed."""

    @property
    def type(self) -> str:
        """Serialization tag for this message kind: ``"function"``."""
        return "function"
|
|
|
|
|
|
class FunctionMessageChunk(FunctionMessage, BaseMessageChunk):
    """Chunk form of ``FunctionMessage``; declares no fields or methods of its own."""
|
|
|
|
|
|
class ChatMessage(BaseMessage):
    """A message whose speaker (its role) can be set to any string."""

    role: str
    """Speaker / role the message is attributed to."""

    @property
    def type(self) -> str:
        """Serialization tag for this message kind: ``"chat"``."""
        return "chat"
|
|
|
|
|
|
class ChatMessageChunk(ChatMessage, BaseMessageChunk):
    """Chunk form of ``ChatMessage``; declares no fields or methods of its own."""
|
|
|
|
|
|
def _message_to_dict(message: BaseMessage) -> dict:
|
|
return {"type": message.type, "data": message.dict()}
|
|
|
|
|
|
def messages_to_dict(messages: Sequence[BaseMessage]) -> List[dict]:
|
|
"""Convert a sequence of Messages to a list of dictionaries.
|
|
|
|
Args:
|
|
messages: Sequence of messages (as BaseMessages) to convert.
|
|
|
|
Returns:
|
|
List of messages as dicts.
|
|
"""
|
|
return [_message_to_dict(m) for m in messages]
|
|
|
|
|
|
def _message_from_dict(message: dict) -> BaseMessage:
|
|
_type = message["type"]
|
|
if _type == "human":
|
|
return HumanMessage(**message["data"])
|
|
elif _type == "ai":
|
|
return AIMessage(**message["data"])
|
|
elif _type == "system":
|
|
return SystemMessage(**message["data"])
|
|
elif _type == "chat":
|
|
return ChatMessage(**message["data"])
|
|
elif _type == "function":
|
|
return FunctionMessage(**message["data"])
|
|
else:
|
|
raise ValueError(f"Got unexpected message type: {_type}")
|
|
|
|
|
|
def messages_from_dict(messages: List[dict]) -> List[BaseMessage]:
|
|
"""Convert a sequence of messages from dicts to Message objects.
|
|
|
|
Args:
|
|
messages: Sequence of messages (as dicts) to convert.
|
|
|
|
Returns:
|
|
List of messages (BaseMessages).
|
|
"""
|
|
return [_message_from_dict(m) for m in messages]
|