antthropic remove pydantic

pull/250/head
evelynmitchell 1 year ago
parent 86c262e43a
commit 1412aef5e2

@ -24,7 +24,7 @@ from langchain.callbacks.manager import (
CallbackManagerForLLMRun,
)
from langchain.llms.base import LLM
from pydantic import model_validator, ConfigDict, Field, SecretStr
from langchain.schema.language_model import BaseLanguageModel
from langchain.schema.output import GenerationChunk
from langchain.schema.prompt import PromptValue
@ -219,21 +219,13 @@ def build_extra_kwargs(
return extra_kwargs
def convert_to_secret_str(value: Union[SecretStr, str]) -> SecretStr:
"""Convert a string to a SecretStr if needed."""
if isinstance(value, SecretStr):
return value
return SecretStr(value)
class _AnthropicCommon(BaseLanguageModel):
client: Any = None #: :meta private:
async_client: Any = None #: :meta private:
model: str = Field(default="claude-2", alias="model_name")
model: str ="claude-2"
"""Model name to use."""
max_tokens_to_sample: int = Field(default=256, alias="max_tokens")
max_tokens_to_sample: int =256
"""Denotes the number of tokens to predict per generation."""
temperature: Optional[float] = None
@ -258,9 +250,8 @@ class _AnthropicCommon(BaseLanguageModel):
HUMAN_PROMPT: Optional[str] = None
AI_PROMPT: Optional[str] = None
count_tokens: Optional[Callable[[str], int]] = None
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
model_kwargs: Dict[str, Any] = {}
@model_validator(mode="before")
@classmethod
def build_extra(cls, values: Dict) -> Dict:
extra = values.get("model_kwargs", {})
@ -270,14 +261,11 @@ class _AnthropicCommon(BaseLanguageModel):
)
return values
@model_validator()
@classmethod
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
values["anthropic_api_key"] = convert_to_secret_str(
get_from_dict_or_env(
values["anthropic_api_key"] = get_from_dict_or_env(
values, "anthropic_api_key", "ANTHROPIC_API_KEY"
)
)
# Get custom api url from environment.
values["anthropic_api_url"] = get_from_dict_or_env(
@ -378,9 +366,7 @@ class Anthropic(LLM, _AnthropicCommon):
prompt = f"{anthropic.HUMAN_PROMPT} {prompt}{anthropic.AI_PROMPT}"
response = model(prompt)
"""
model_config = ConfigDict(populate_by_name=True, arbitrary_types_allowed=True)
@model_validator()
@classmethod
def raise_warning(cls, values: Dict) -> Dict:
"""Raise warning that this class is deprecated."""

Loading…
Cancel
Save