parent ca55d478ae
commit e41c707fa9
[binary image added: 14 KiB]
@@ -0,0 +1,167 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Tuple

from pydantic import Field

from langchain.memory.utils import get_prompt_input_key

from swarms.agents.memory.chat_message_history import ChatMessageHistory
from swarms.agents.models.prompts.base import AIMessage, BaseMessage, HumanMessage
from swarms.utils.serializable import Serializable


class BaseMemory(Serializable, ABC):
    """Abstract base class for memory in Chains.

    Memory refers to state in Chains. Memory can be used to store information about
    past executions of a Chain and inject that information into the inputs of
    future executions of the Chain. For example, for conversational Chains Memory
    can be used to store conversations and automatically add them to future model
    prompts so that the model has the necessary context to respond coherently to
    the latest input.

    Example:
        .. code-block:: python

            class SimpleMemory(BaseMemory):
                memories: Dict[str, Any] = dict()

                @property
                def memory_variables(self) -> List[str]:
                    return list(self.memories.keys())

                def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]:
                    return self.memories

                def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
                    pass

                def clear(self) -> None:
                    pass
    """  # noqa: E501

    class Config:
        """Configuration for this pydantic object."""

        arbitrary_types_allowed = True

    @property
    @abstractmethod
    def memory_variables(self) -> List[str]:
        """The string keys this memory class will add to chain inputs."""

    @abstractmethod
    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Return key-value pairs given the text input to the chain."""

    @abstractmethod
    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
        """Save the context of this chain run to memory."""

    @abstractmethod
    def clear(self) -> None:
        """Clear memory contents."""


class BaseChatMessageHistory(ABC):
    """Abstract base class for storing chat message history.

    See `ChatMessageHistory` for a default implementation.

    Example:
        .. code-block:: python

            class FileChatMessageHistory(BaseChatMessageHistory):
                storage_path: str
                session_id: str

                @property
                def messages(self):
                    with open(os.path.join(self.storage_path, self.session_id), encoding="utf-8") as f:
                        messages = json.loads(f.read())
                    return messages_from_dict(messages)

                def add_message(self, message: BaseMessage) -> None:
                    messages = messages_to_dict(self.messages)
                    messages.append(_message_to_dict(message))
                    with open(os.path.join(self.storage_path, self.session_id), "w") as f:
                        json.dump(messages, f)

                def clear(self):
                    with open(os.path.join(self.storage_path, self.session_id), "w") as f:
                        f.write("[]")
    """

    messages: List[BaseMessage]
    """A list of Messages stored in-memory."""

    def add_user_message(self, message: str) -> None:
        """Convenience method for adding a human message string to the store.

        Args:
            message: The string contents of a human message.
        """
        self.add_message(HumanMessage(content=message))

    def add_ai_message(self, message: str) -> None:
        """Convenience method for adding an AI message string to the store.

        Args:
            message: The string contents of an AI message.
        """
        self.add_message(AIMessage(content=message))

    # TODO: Make this an abstractmethod.
    def add_message(self, message: BaseMessage) -> None:
        """Add a Message object to the store.

        Args:
            message: A BaseMessage object to store.
        """
        raise NotImplementedError

    @abstractmethod
    def clear(self) -> None:
        """Remove all messages from the store."""


class BaseChatMemory(BaseMemory, ABC):
    """Abstract base class for chat memory."""

    chat_memory: BaseChatMessageHistory = Field(default_factory=ChatMessageHistory)
    output_key: Optional[str] = None
    input_key: Optional[str] = None
    return_messages: bool = False

    def _get_input_output(
        self, inputs: Dict[str, Any], outputs: Dict[str, str]
    ) -> Tuple[str, str]:
        if self.input_key is None:
            prompt_input_key = get_prompt_input_key(inputs, self.memory_variables)
        else:
            prompt_input_key = self.input_key
        if self.output_key is None:
            if len(outputs) != 1:
                raise ValueError(f"One output key expected, got {outputs.keys()}")
            output_key = list(outputs.keys())[0]
        else:
            output_key = self.output_key
        return inputs[prompt_input_key], outputs[output_key]

    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
        """Save context from this conversation to buffer."""
        input_str, output_str = self._get_input_output(inputs, outputs)
        self.chat_memory.add_user_message(input_str)
        self.chat_memory.add_ai_message(output_str)

    def clear(self) -> None:
        """Clear memory contents."""
        self.chat_memory.clear()
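A minimal, illustrative sketch (not part of this commit) of how a concrete subclass can tie these pieces together. It assumes the module paths used in the imports above and the get_buffer_string helper added later in this diff.

from typing import Any, Dict, List

from swarms.agents.memory.base_memory import BaseChatMemory
from swarms.agents.models.prompts.base import get_buffer_string  # path assumed from the import in base_memory


class ConversationBufferMemory(BaseChatMemory):
    """Illustrative buffer memory: exposes the whole chat history as one string."""

    memory_key: str = "history"

    @property
    def memory_variables(self) -> List[str]:
        return [self.memory_key]

    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        # BaseChatMemory.save_context has already appended Human/AI messages
        # to self.chat_memory; here they are rendered for the next prompt.
        return {self.memory_key: get_buffer_string(self.chat_memory.messages)}


memory = ConversationBufferMemory()
memory.save_context({"input": "Hi there"}, {"output": "Hello!"})
print(memory.load_memory_variables({}))
# {'history': 'Human: Hi there\nAI: Hello!'}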
@@ -0,0 +1,21 @@
from typing import List

from pydantic import BaseModel

from swarms.agents.memory.base_memory import BaseChatMessageHistory, BaseMessage


class ChatMessageHistory(BaseChatMessageHistory, BaseModel):
    """In-memory implementation of chat message history.

    Stores messages in an in-memory list.
    """

    messages: List[BaseMessage] = []

    def add_message(self, message: BaseMessage) -> None:
        """Add a self-created message to the store."""
        self.messages.append(message)

    def clear(self) -> None:
        self.messages = []
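A small usage sketch (not part of the commit) showing the convenience methods this class inherits from BaseChatMessageHistory:

from swarms.agents.memory.chat_message_history import ChatMessageHistory

history = ChatMessageHistory()
history.add_user_message("What is the capital of France?")
history.add_ai_message("Paris.")

print([m.type for m in history.messages])  # ['human', 'ai']
history.clear()
print(history.messages)                    # []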
@@ -0,0 +1,81 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Any, Sequence

from pydantic import Field

from swarms.utils.serializable import Serializable


class Document(Serializable):
    """Class for storing a piece of text and associated metadata."""

    page_content: str
    """String text."""
    metadata: dict = Field(default_factory=dict)
    """Arbitrary metadata about the page content (e.g., source, relationships to other
    documents, etc.).
    """


class BaseDocumentTransformer(ABC):
    """Abstract base class for document transformation systems.

    A document transformation system takes a sequence of Documents and returns a
    sequence of transformed Documents.

    Example:
        .. code-block:: python

            class EmbeddingsRedundantFilter(BaseDocumentTransformer, BaseModel):
                embeddings: Embeddings
                similarity_fn: Callable = cosine_similarity
                similarity_threshold: float = 0.95

                class Config:
                    arbitrary_types_allowed = True

                def transform_documents(
                    self, documents: Sequence[Document], **kwargs: Any
                ) -> Sequence[Document]:
                    stateful_documents = get_stateful_documents(documents)
                    embedded_documents = _get_embeddings_from_stateful_docs(
                        self.embeddings, stateful_documents
                    )
                    included_idxs = _filter_similar_embeddings(
                        embedded_documents, self.similarity_fn, self.similarity_threshold
                    )
                    return [stateful_documents[i] for i in sorted(included_idxs)]

                async def atransform_documents(
                    self, documents: Sequence[Document], **kwargs: Any
                ) -> Sequence[Document]:
                    raise NotImplementedError
    """  # noqa: E501

    @abstractmethod
    def transform_documents(
        self, documents: Sequence[Document], **kwargs: Any
    ) -> Sequence[Document]:
        """Transform a list of documents.

        Args:
            documents: A sequence of Documents to be transformed.

        Returns:
            A list of transformed Documents.
        """

    @abstractmethod
    async def atransform_documents(
        self, documents: Sequence[Document], **kwargs: Any
    ) -> Sequence[Document]:
        """Asynchronously transform a list of documents.

        Args:
            documents: A sequence of Documents to be transformed.

        Returns:
            A list of transformed Documents.
        """
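For illustration only (not part of this commit), a minimal concrete transformer against the interface above. The import path is a guess since the file's location is not shown in this diff, and the upper-casing behaviour is invented:

from typing import Any, Sequence

# Hypothetical import path; adjust to wherever Document and
# BaseDocumentTransformer actually live in the package.
from swarms.agents.memory.document import BaseDocumentTransformer, Document


class UpperCaseTransformer(BaseDocumentTransformer):
    """Toy transformer: upper-cases page content and keeps metadata intact."""

    def transform_documents(
        self, documents: Sequence[Document], **kwargs: Any
    ) -> Sequence[Document]:
        return [
            Document(page_content=d.page_content.upper(), metadata=d.metadata)
            for d in documents
        ]

    async def atransform_documents(
        self, documents: Sequence[Document], **kwargs: Any
    ) -> Sequence[Document]:
        # No real async work needed here; reuse the synchronous path.
        return self.transform_documents(documents, **kwargs)


docs = [Document(page_content="hello world", metadata={"source": "demo"})]
print(UpperCaseTransformer().transform_documents(docs)[0].page_content)  # HELLO WORLD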
@@ -0,0 +1,23 @@
from typing import Any, Dict, List

from swarms.agents.memory.base import get_buffer_string


def get_prompt_input_key(inputs: Dict[str, Any], memory_variables: List[str]) -> str:
    """
    Get the prompt input key.

    Args:
        inputs: Dict[str, Any]
        memory_variables: List[str]

    Returns:
        A prompt input key.
    """
    # "stop" is a special key that can be passed as input but is not used to
    # format the prompt.
    prompt_input_keys = list(set(inputs).difference(memory_variables + ["stop"]))
    if len(prompt_input_keys) != 1:
        raise ValueError(f"One input key expected, got {prompt_input_keys}")
    return prompt_input_keys[0]
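A short worked example of the key-selection rule (the import path below is an assumption, since the file's final location is not shown in this diff):

# Hypothetical import path for the module above.
from swarms.agents.memory.utils import get_prompt_input_key

inputs = {"question": "What is 2 + 2?", "history": "Human: hi\nAI: hello", "stop": ["\n"]}
memory_variables = ["history"]

# "history" (a memory variable) and the special "stop" key are excluded,
# leaving exactly one candidate key.
print(get_prompt_input_key(inputs, memory_variables))  # question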
@@ -0,0 +1,256 @@
from __future__ import annotations

from abc import abstractmethod
from typing import TYPE_CHECKING, Any, Dict, List, Sequence

from pydantic import Field

from swarms.utils.serializable import Serializable

if TYPE_CHECKING:
    from langchain.prompts.chat import ChatPromptTemplate


def get_buffer_string(
    messages: Sequence[BaseMessage], human_prefix: str = "Human", ai_prefix: str = "AI"
) -> str:
    """Convert a sequence of Messages to strings and concatenate them into one string.

    Args:
        messages: Messages to be converted to strings.
        human_prefix: The prefix to prepend to contents of HumanMessages.
        ai_prefix: The prefix to prepend to contents of AIMessages.

    Returns:
        A single string concatenation of all input messages.

    Example:
        .. code-block:: python

            from langchain.schema import AIMessage, HumanMessage

            messages = [
                HumanMessage(content="Hi, how are you?"),
                AIMessage(content="Good, how are you?"),
            ]
            get_buffer_string(messages)
            # -> "Human: Hi, how are you?\nAI: Good, how are you?"
    """
    string_messages = []
    for m in messages:
        if isinstance(m, HumanMessage):
            role = human_prefix
        elif isinstance(m, AIMessage):
            role = ai_prefix
        elif isinstance(m, SystemMessage):
            role = "System"
        elif isinstance(m, FunctionMessage):
            role = "Function"
        elif isinstance(m, ChatMessage):
            role = m.role
        else:
            raise ValueError(f"Got unsupported message type: {m}")
        message = f"{role}: {m.content}"
        if isinstance(m, AIMessage) and "function_call" in m.additional_kwargs:
            message += f"{m.additional_kwargs['function_call']}"
        string_messages.append(message)

    return "\n".join(string_messages)


class BaseMessage(Serializable):
    """The base abstract Message class.

    Messages are the inputs and outputs of ChatModels.
    """

    content: str
    """The string contents of the message."""

    additional_kwargs: dict = Field(default_factory=dict)
    """Any additional information."""

    @property
    @abstractmethod
    def type(self) -> str:
        """Type of the Message, used for serialization."""

    @property
    def lc_serializable(self) -> bool:
        """Whether this class is LangChain serializable."""
        return True

    def __add__(self, other: Any) -> ChatPromptTemplate:
        from langchain.prompts.chat import ChatPromptTemplate

        prompt = ChatPromptTemplate(messages=[self])
        return prompt + other


class BaseMessageChunk(BaseMessage):
    def _merge_kwargs_dict(
        self, left: Dict[str, Any], right: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Merge additional_kwargs from another BaseMessageChunk into this one."""
        merged = left.copy()
        for k, v in right.items():
            if k not in merged:
                merged[k] = v
            elif type(merged[k]) != type(v):
                raise ValueError(
                    f'additional_kwargs["{k}"] already exists in this message,'
                    " but with a different type."
                )
            elif isinstance(merged[k], str):
                merged[k] += v
            elif isinstance(merged[k], dict):
                merged[k] = self._merge_kwargs_dict(merged[k], v)
            else:
                raise ValueError(
                    f"Additional kwargs key {k} already exists in this message."
                )
        return merged

    def __add__(self, other: Any) -> BaseMessageChunk:  # type: ignore
        if isinstance(other, BaseMessageChunk):
            # If both are (subclasses of) BaseMessageChunk,
            # concat into a single BaseMessageChunk
            return self.__class__(
                content=self.content + other.content,
                additional_kwargs=self._merge_kwargs_dict(
                    self.additional_kwargs, other.additional_kwargs
                ),
            )
        else:
            raise TypeError(
                'unsupported operand type(s) for +: "'
                f"{self.__class__.__name__}"
                f'" and "{other.__class__.__name__}"'
            )


class HumanMessage(BaseMessage):
    """A Message from a human."""

    example: bool = False
    """Whether this Message is being passed in to the model as part of an example
    conversation.
    """

    @property
    def type(self) -> str:
        """Type of the message, used for serialization."""
        return "human"


class HumanMessageChunk(HumanMessage, BaseMessageChunk):
    pass


class AIMessage(BaseMessage):
    """A Message from an AI."""

    example: bool = False
    """Whether this Message is being passed in to the model as part of an example
    conversation.
    """

    @property
    def type(self) -> str:
        """Type of the message, used for serialization."""
        return "ai"


class AIMessageChunk(AIMessage, BaseMessageChunk):
    pass


class SystemMessage(BaseMessage):
    """A Message for priming AI behavior, usually passed in as the first of a sequence
    of input messages.
    """

    @property
    def type(self) -> str:
        """Type of the message, used for serialization."""
        return "system"


class SystemMessageChunk(SystemMessage, BaseMessageChunk):
    pass


class FunctionMessage(BaseMessage):
    """A Message for passing the result of executing a function back to a model."""

    name: str
    """The name of the function that was executed."""

    @property
    def type(self) -> str:
        """Type of the message, used for serialization."""
        return "function"


class FunctionMessageChunk(FunctionMessage, BaseMessageChunk):
    pass


class ChatMessage(BaseMessage):
    """A Message that can be assigned an arbitrary speaker (i.e. role)."""

    role: str
    """The speaker / role of the Message."""

    @property
    def type(self) -> str:
        """Type of the message, used for serialization."""
        return "chat"


class ChatMessageChunk(ChatMessage, BaseMessageChunk):
    pass


def _message_to_dict(message: BaseMessage) -> dict:
    return {"type": message.type, "data": message.dict()}


def messages_to_dict(messages: Sequence[BaseMessage]) -> List[dict]:
    """Convert a sequence of Messages to a list of dictionaries.

    Args:
        messages: Sequence of messages (as BaseMessages) to convert.

    Returns:
        List of messages as dicts.
    """
    return [_message_to_dict(m) for m in messages]


def _message_from_dict(message: dict) -> BaseMessage:
    _type = message["type"]
    if _type == "human":
        return HumanMessage(**message["data"])
    elif _type == "ai":
        return AIMessage(**message["data"])
    elif _type == "system":
        return SystemMessage(**message["data"])
    elif _type == "chat":
        return ChatMessage(**message["data"])
    elif _type == "function":
        return FunctionMessage(**message["data"])
    else:
        raise ValueError(f"Got unexpected message type: {_type}")


def messages_from_dict(messages: List[dict]) -> List[BaseMessage]:
    """Convert a sequence of messages from dicts to Message objects.

    Args:
        messages: Sequence of messages (as dicts) to convert.

    Returns:
        List of messages (BaseMessages).
    """
    return [_message_from_dict(m) for m in messages]
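A small usage sketch (not part of the commit) of the serialization helpers and chunk concatenation defined above; the import path follows the one used in base_memory.py earlier in this diff:

from swarms.agents.models.prompts.base import (  # path assumed from the import in base_memory
    AIMessage,
    AIMessageChunk,
    HumanMessage,
    get_buffer_string,
    messages_from_dict,
    messages_to_dict,
)

# Round-trip messages through plain dicts (e.g. for JSON storage).
original = [HumanMessage(content="Hi"), AIMessage(content="Hello!")]
restored = messages_from_dict(messages_to_dict(original))
print(get_buffer_string(restored))  # -> "Human: Hi\nAI: Hello!"

# Streaming-style chunks concatenate with "+", merging additional_kwargs.
chunk = AIMessageChunk(content="Hel") + AIMessageChunk(content="lo!")
print(chunk.content)  # Hello!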
@@ -1,3 +0,0 @@
.env
__pycache__
.venv
@@ -0,0 +1,163 @@
from abc import ABC
from typing import Any, Dict, List, Literal, TypedDict, Union, cast

from pydantic import BaseModel, PrivateAttr


class BaseSerialized(TypedDict):
    """Base class for serialized objects."""

    lc: int
    id: List[str]


class SerializedConstructor(BaseSerialized):
    """Serialized constructor."""

    type: Literal["constructor"]
    kwargs: Dict[str, Any]


class SerializedSecret(BaseSerialized):
    """Serialized secret."""

    type: Literal["secret"]


class SerializedNotImplemented(BaseSerialized):
    """Serialized not implemented."""

    type: Literal["not_implemented"]


class Serializable(BaseModel, ABC):
    """Serializable base class."""

    @property
    def lc_serializable(self) -> bool:
        """
        Return whether or not the class is serializable.
        """
        return False

    @property
    def lc_namespace(self) -> List[str]:
        """
        Return the namespace of the langchain object.
        e.g. ["langchain", "llms", "openai"]
        """
        return self.__class__.__module__.split(".")

    @property
    def lc_secrets(self) -> Dict[str, str]:
        """
        Return a map of constructor argument names to secret ids.
        e.g. {"openai_api_key": "OPENAI_API_KEY"}
        """
        return dict()

    @property
    def lc_attributes(self) -> Dict:
        """
        Return attribute name/value pairs that should be included in the
        serialized kwargs. These attributes must be accepted by the
        constructor.
        """
        return {}

    class Config:
        extra = "ignore"

    _lc_kwargs = PrivateAttr(default_factory=dict)

    def __init__(self, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        self._lc_kwargs = kwargs

    def to_json(self) -> Union[SerializedConstructor, SerializedNotImplemented]:
        if not self.lc_serializable:
            return self.to_json_not_implemented()

        secrets = dict()
        # Get latest values for kwargs if there is an attribute with the same name
        lc_kwargs = {
            k: getattr(self, k, v)
            for k, v in self._lc_kwargs.items()
            if not (self.__exclude_fields__ or {}).get(k, False)  # type: ignore
        }

        # Merge the lc_secrets and lc_attributes from every class in the MRO
        for cls in [None, *self.__class__.mro()]:
            # Once we get to Serializable, we're done
            if cls is Serializable:
                break

            # Get a reference to self bound to each class in the MRO
            this = cast(Serializable, self if cls is None else super(cls, self))

            secrets.update(this.lc_secrets)
            lc_kwargs.update(this.lc_attributes)

        # Include all secrets, even if not specified in kwargs,
        # as these secrets may be passed as an environment variable instead
        for key in secrets.keys():
            secret_value = getattr(self, key, None) or lc_kwargs.get(key)
            if secret_value is not None:
                lc_kwargs.update({key: secret_value})

        return {
            "lc": 1,
            "type": "constructor",
            "id": [*self.lc_namespace, self.__class__.__name__],
            "kwargs": lc_kwargs
            if not secrets
            else _replace_secrets(lc_kwargs, secrets),
        }

    def to_json_not_implemented(self) -> SerializedNotImplemented:
        return to_json_not_implemented(self)


def _replace_secrets(
    root: Dict[Any, Any], secrets_map: Dict[str, str]
) -> Dict[Any, Any]:
    result = root.copy()
    for path, secret_id in secrets_map.items():
        [*parts, last] = path.split(".")
        current = result
        for part in parts:
            if part not in current:
                break
            current[part] = current[part].copy()
            current = current[part]
        if last in current:
            current[last] = {
                "lc": 1,
                "type": "secret",
                "id": [secret_id],
            }
    return result


def to_json_not_implemented(obj: object) -> SerializedNotImplemented:
    """Serialize a "not implemented" object.

    Args:
        obj: object to serialize

    Returns:
        SerializedNotImplemented
    """
    _id: List[str] = []
    try:
        if hasattr(obj, "__name__"):
            _id = [*obj.__module__.split("."), obj.__name__]
        elif hasattr(obj, "__class__"):
            _id = [*obj.__class__.__module__.split("."), obj.__class__.__name__]
    except Exception:
        pass
    return {
        "lc": 1,
        "type": "not_implemented",
        "id": _id,
    }
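An illustrative subclass (not part of the commit) showing how lc_serializable and lc_secrets drive to_json; the class, field names, and secret id are invented for the example, and the module path is assumed:

from typing import Dict

from swarms.utils.serializable import Serializable


class FakeLLM(Serializable):
    """Illustrative: an API-key field that should never be serialized in the clear."""

    model: str
    api_key: str

    @property
    def lc_serializable(self) -> bool:
        return True

    @property
    def lc_secrets(self) -> Dict[str, str]:
        return {"api_key": "FAKE_API_KEY"}


llm = FakeLLM(model="demo-1", api_key="sk-secret")
print(llm.to_json())
# When run as a script:
# {'lc': 1, 'type': 'constructor', 'id': ['__main__', 'FakeLLM'],
#  'kwargs': {'model': 'demo-1',
#             'api_key': {'lc': 1, 'type': 'secret', 'id': ['FAKE_API_KEY']}}}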