|
|
|
@ -24,7 +24,7 @@ from langchain.callbacks.manager import (
|
|
|
|
|
CallbackManagerForLLMRun,
|
|
|
|
|
)
|
|
|
|
|
from langchain.llms.base import LLM
|
|
|
|
|
from pydantic import model_validator, ConfigDict, Field, SecretStr
|
|
|
|
|
|
|
|
|
|
from langchain.schema.language_model import BaseLanguageModel
|
|
|
|
|
from langchain.schema.output import GenerationChunk
|
|
|
|
|
from langchain.schema.prompt import PromptValue
|
|
|
|
@ -219,21 +219,13 @@ def build_extra_kwargs(
|
|
|
|
|
|
|
|
|
|
return extra_kwargs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def convert_to_secret_str(value: Union[SecretStr, str]) -> SecretStr:
|
|
|
|
|
"""Convert a string to a SecretStr if needed."""
|
|
|
|
|
if isinstance(value, SecretStr):
|
|
|
|
|
return value
|
|
|
|
|
return SecretStr(value)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class _AnthropicCommon(BaseLanguageModel):
|
|
|
|
|
client: Any = None #: :meta private:
|
|
|
|
|
async_client: Any = None #: :meta private:
|
|
|
|
|
model: str = Field(default="claude-2", alias="model_name")
|
|
|
|
|
model: str ="claude-2"
|
|
|
|
|
"""Model name to use."""
|
|
|
|
|
|
|
|
|
|
max_tokens_to_sample: int = Field(default=256, alias="max_tokens")
|
|
|
|
|
max_tokens_to_sample: int =256
|
|
|
|
|
"""Denotes the number of tokens to predict per generation."""
|
|
|
|
|
|
|
|
|
|
temperature: Optional[float] = None
|
|
|
|
@ -258,9 +250,8 @@ class _AnthropicCommon(BaseLanguageModel):
|
|
|
|
|
HUMAN_PROMPT: Optional[str] = None
|
|
|
|
|
AI_PROMPT: Optional[str] = None
|
|
|
|
|
count_tokens: Optional[Callable[[str], int]] = None
|
|
|
|
|
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
|
|
|
|
model_kwargs: Dict[str, Any] = {}
|
|
|
|
|
|
|
|
|
|
@model_validator(mode="before")
|
|
|
|
|
@classmethod
|
|
|
|
|
def build_extra(cls, values: Dict) -> Dict:
|
|
|
|
|
extra = values.get("model_kwargs", {})
|
|
|
|
@ -270,14 +261,11 @@ class _AnthropicCommon(BaseLanguageModel):
|
|
|
|
|
)
|
|
|
|
|
return values
|
|
|
|
|
|
|
|
|
|
@model_validator()
|
|
|
|
|
@classmethod
|
|
|
|
|
def validate_environment(cls, values: Dict) -> Dict:
|
|
|
|
|
"""Validate that api key and python package exists in environment."""
|
|
|
|
|
values["anthropic_api_key"] = convert_to_secret_str(
|
|
|
|
|
get_from_dict_or_env(
|
|
|
|
|
values["anthropic_api_key"] = get_from_dict_or_env(
|
|
|
|
|
values, "anthropic_api_key", "ANTHROPIC_API_KEY"
|
|
|
|
|
)
|
|
|
|
|
)
|
|
|
|
|
# Get custom api url from environment.
|
|
|
|
|
values["anthropic_api_url"] = get_from_dict_or_env(
|
|
|
|
@ -378,9 +366,7 @@ class Anthropic(LLM, _AnthropicCommon):
|
|
|
|
|
prompt = f"{anthropic.HUMAN_PROMPT} {prompt}{anthropic.AI_PROMPT}"
|
|
|
|
|
response = model(prompt)
|
|
|
|
|
"""
|
|
|
|
|
model_config = ConfigDict(populate_by_name=True, arbitrary_types_allowed=True)
|
|
|
|
|
|
|
|
|
|
@model_validator()
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def raise_warning(cls, values: Dict) -> Dict:
|
|
|
|
|
"""Raise warning that this class is deprecated."""
|
|
|
|
|