From 1412aef5e2ec9271adcfa123afb8f50f67f685bd Mon Sep 17 00:00:00 2001 From: evelynmitchell Date: Sun, 3 Dec 2023 17:34:01 -0700 Subject: [PATCH] antthropic remove pydantic --- swarms/models/anthropic.py | 26 ++++++-------------------- 1 file changed, 6 insertions(+), 20 deletions(-) diff --git a/swarms/models/anthropic.py b/swarms/models/anthropic.py index fe30ac4f..1e61daa8 100644 --- a/swarms/models/anthropic.py +++ b/swarms/models/anthropic.py @@ -24,7 +24,7 @@ from langchain.callbacks.manager import ( CallbackManagerForLLMRun, ) from langchain.llms.base import LLM -from pydantic import model_validator, ConfigDict, Field, SecretStr + from langchain.schema.language_model import BaseLanguageModel from langchain.schema.output import GenerationChunk from langchain.schema.prompt import PromptValue @@ -219,21 +219,13 @@ def build_extra_kwargs( return extra_kwargs - -def convert_to_secret_str(value: Union[SecretStr, str]) -> SecretStr: - """Convert a string to a SecretStr if needed.""" - if isinstance(value, SecretStr): - return value - return SecretStr(value) - - class _AnthropicCommon(BaseLanguageModel): client: Any = None #: :meta private: async_client: Any = None #: :meta private: - model: str = Field(default="claude-2", alias="model_name") + model: str ="claude-2" """Model name to use.""" - max_tokens_to_sample: int = Field(default=256, alias="max_tokens") + max_tokens_to_sample: int =256 """Denotes the number of tokens to predict per generation.""" temperature: Optional[float] = None @@ -258,9 +250,8 @@ class _AnthropicCommon(BaseLanguageModel): HUMAN_PROMPT: Optional[str] = None AI_PROMPT: Optional[str] = None count_tokens: Optional[Callable[[str], int]] = None - model_kwargs: Dict[str, Any] = Field(default_factory=dict) + model_kwargs: Dict[str, Any] = {} - @model_validator(mode="before") @classmethod def build_extra(cls, values: Dict) -> Dict: extra = values.get("model_kwargs", {}) @@ -270,14 +261,11 @@ class _AnthropicCommon(BaseLanguageModel): ) return values - @model_validator() @classmethod def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" - values["anthropic_api_key"] = convert_to_secret_str( - get_from_dict_or_env( + values["anthropic_api_key"] = get_from_dict_or_env( values, "anthropic_api_key", "ANTHROPIC_API_KEY" - ) ) # Get custom api url from environment. values["anthropic_api_url"] = get_from_dict_or_env( @@ -378,9 +366,7 @@ class Anthropic(LLM, _AnthropicCommon): prompt = f"{anthropic.HUMAN_PROMPT} {prompt}{anthropic.AI_PROMPT}" response = model(prompt) """ - model_config = ConfigDict(populate_by_name=True, arbitrary_types_allowed=True) - - @model_validator() + @classmethod def raise_warning(cls, values: Dict) -> Dict: """Raise warning that this class is deprecated."""