You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
163 lines
4.5 KiB
163 lines
4.5 KiB
from abc import ABC
|
|
from typing import Any, Dict, List, Literal, TypedDict, Union, cast
|
|
|
|
from pydantic import BaseModel, PrivateAttr
|
|
|
|
|
|
class BaseSerialized(TypedDict):
    """Keys shared by every serialized representation in this module."""

    # Integer marker; the serializers in this module always emit 1.
    lc: int
    # Identifier of the serialized thing, as a list of string parts.
    id: List[str]
|
|
|
|
|
|
class SerializedConstructor(BaseSerialized):
    """Serialized form of an object that can be rebuilt from its kwargs."""

    # Discriminator: always the literal string "constructor".
    type: Literal["constructor"]
    # Keyword arguments recorded for the object.
    kwargs: Dict[str, Any]
|
|
|
|
|
|
class SerializedSecret(BaseSerialized):
    """Serialized placeholder that stands in for a secret value."""

    # Discriminator: always the literal string "secret".
    type: Literal["secret"]
|
|
|
|
|
|
class SerializedNotImplemented(BaseSerialized):
    """Serialized placeholder for an object that is not serializable."""

    # Discriminator: always the literal string "not_implemented".
    type: Literal["not_implemented"]
|
|
|
|
|
|
class Serializable(BaseModel, ABC):
    """Base model that can describe itself as a JSON-ready dictionary.

    Subclasses opt in by overriding ``lc_serializable`` to return ``True``;
    until they do, ``to_json`` returns a "not_implemented" placeholder.
    """

    @property
    def lc_serializable(self) -> bool:
        """Whether this class is serializable. ``False`` unless overridden."""
        return False

    @property
    def lc_namespace(self) -> List[str]:
        """Namespace of the object: the class's module path split on dots.

        eg. ["langchain", "llms", "openai"]
        """
        return self.__class__.__module__.split(".")

    @property
    def lc_secrets(self) -> Dict[str, str]:
        """Map of constructor argument names to secret ids.

        eg. {"openai_api_key": "OPENAI_API_KEY"}
        """
        return {}

    @property
    def lc_attributes(self) -> Dict:
        """Extra attributes to merge into the serialized kwargs.

        Returned as a mapping; ``to_json`` merges it into the kwargs it
        emits. These attributes must be accepted by the constructor.
        """
        return {}

    class Config:
        extra = "ignore"

    # Keyword arguments the instance was constructed with (set in __init__).
    _lc_kwargs = PrivateAttr(default_factory=dict)

    def __init__(self, **kwargs: Any) -> None:
        """Initialise the model and remember the constructor kwargs."""
        super().__init__(**kwargs)
        self._lc_kwargs = kwargs

    def to_json(self) -> Union[SerializedConstructor, SerializedNotImplemented]:
        """Serialize this object.

        Returns:
            A "not_implemented" placeholder when ``lc_serializable`` is
            falsy; otherwise a "constructor" dictionary whose ``id`` is the
            namespace plus the class name and whose ``kwargs`` are the
            recorded constructor kwargs (refreshed from same-named
            attributes), merged with ``lc_attributes``, with secret values
            replaced by placeholders.
        """
        if not self.lc_serializable:
            return self.to_json_not_implemented()

        # Start from the constructor kwargs, preferring the current value of
        # an attribute with the same name and dropping excluded fields.
        kwargs_out: Dict[str, Any] = {}
        for name, initial in self._lc_kwargs.items():
            if (self.__exclude_fields__ or {}).get(name, False):  # type: ignore
                continue
            kwargs_out[name] = getattr(self, name, initial)

        # Merge lc_secrets and lc_attributes from self, then from each class
        # in the MRO, stopping once Serializable itself is reached.
        secret_ids: Dict[str, str] = {}
        for klass in [None, *self.__class__.mro()]:
            if klass is Serializable:
                break

            # ``self`` viewed through each class in the MRO, so every level's
            # property implementation is consulted.
            view = cast(Serializable, self if klass is None else super(klass, self))

            secret_ids.update(view.lc_secrets)
            kwargs_out.update(view.lc_attributes)

        # Include all secrets, even if not specified in kwargs, as these
        # secrets may be passed as an environment variable instead.
        for name in secret_ids:
            value = getattr(self, name, None) or kwargs_out.get(name)
            if value is not None:
                kwargs_out[name] = value

        identifier = [*self.lc_namespace, self.__class__.__name__]
        if secret_ids:
            serialized_kwargs = _replace_secrets(kwargs_out, secret_ids)
        else:
            serialized_kwargs = kwargs_out

        return {
            "lc": 1,
            "type": "constructor",
            "id": identifier,
            "kwargs": serialized_kwargs,
        }

    def to_json_not_implemented(self) -> SerializedNotImplemented:
        """Return the "not_implemented" placeholder for this object."""
        return to_json_not_implemented(self)
|
|
|
|
|
|
def _replace_secrets(
|
|
root: Dict[Any, Any], secrets_map: Dict[str, str]
|
|
) -> Dict[Any, Any]:
|
|
result = root.copy()
|
|
for path, secret_id in secrets_map.items():
|
|
[*parts, last] = path.split(".")
|
|
current = result
|
|
for part in parts:
|
|
if part not in current:
|
|
break
|
|
current[part] = current[part].copy()
|
|
current = current[part]
|
|
if last in current:
|
|
current[last] = {
|
|
"lc": 1,
|
|
"type": "secret",
|
|
"id": [secret_id],
|
|
}
|
|
return result
|
|
|
|
|
|
def to_json_not_implemented(obj: object) -> SerializedNotImplemented:
    """Build the "not_implemented" placeholder for an object.

    Args:
        obj: Object to describe.

    Returns:
        A dictionary with ``"lc": 1``, ``"type": "not_implemented"`` and an
        ``"id"`` made of the module path split on dots followed by the name.
        The module and name come from ``obj`` itself when it has a
        ``__name__`` attribute, otherwise from its class. If the lookup
        raises, ``"id"`` is left as an empty list.
    """
    identifier: List[str] = []
    try:
        if hasattr(obj, "__name__"):
            named = obj
        elif hasattr(obj, "__class__"):
            named = obj.__class__
        else:
            named = None
        if named is not None:
            identifier = [*named.__module__.split("."), named.__name__]
    except Exception:
        # Best effort only: an object we cannot introspect gets an empty id.
        pass
    return {
        "lc": 1,
        "type": "not_implemented",
        "id": identifier,
    }