From b0347ac296ff75c662e424ca010e325b085b1684 Mon Sep 17 00:00:00 2001 From: evelynmitchell Date: Sun, 3 Dec 2023 15:19:22 -0700 Subject: [PATCH] pydantic bump --- swarms/memory/schemas.py | 36 ++++++++++++------------- swarms/models/anthropic.py | 18 ++++++------- swarms/models/cohere_chat.py | 11 +++----- swarms/models/dalle3.py | 5 ++-- swarms/models/eleven_labs.py | 5 ++-- swarms/models/fastvit.py | 2 ++ swarms/models/kosmos2.py | 2 ++ swarms/models/openai_embeddings.py | 16 +++++------ swarms/models/openai_function_caller.py | 5 ++-- swarms/models/openai_models.py | 22 ++++++++------- swarms/models/palm.py | 4 ++- swarms/models/ssd_1b.py | 5 ++-- swarms/models/timm.py | 7 ++--- swarms/tools/tool.py | 8 +++--- swarms/utils/serializable.py | 6 ++--- 15 files changed, 77 insertions(+), 75 deletions(-) diff --git a/swarms/memory/schemas.py b/swarms/memory/schemas.py index 9147a909..589a80ae 100644 --- a/swarms/memory/schemas.py +++ b/swarms/memory/schemas.py @@ -12,7 +12,7 @@ class TaskInput(BaseModel): description=( "The input parameters for the task. Any value is allowed." ), - example='{\n"debug": false,\n"mode": "benchmarks"\n}', + examples=['{\n"debug": false,\n"mode": "benchmarks"\n}'], ) @@ -20,17 +20,17 @@ class Artifact(BaseModel): artifact_id: str = Field( ..., description="Id of the artifact", - example="b225e278-8b4c-4f99-a696-8facf19f0e56", + examples=["b225e278-8b4c-4f99-a696-8facf19f0e56"], ) file_name: str = Field( - ..., description="Filename of the artifact", example="main.py" + ..., description="Filename of the artifact", examples=["main.py"] ) relative_path: Optional[str] = Field( None, description=( "Relative path of the artifact in the agent's workspace" ), - example="python/code/", + examples=["python/code/"], ) @@ -41,7 +41,7 @@ class ArtifactUpload(BaseModel): description=( "Relative path of the artifact in the agent's workspace" ), - example="python/code/", + examples=["python/code/"], ) @@ -52,7 +52,7 @@ class StepInput(BaseModel): "Input parameters for the task step. Any value is" " allowed." ), - example='{\n"file_to_refactor": "models.py"\n}', + examples=['{\n"file_to_refactor": "models.py"\n}'], ) @@ -63,7 +63,7 @@ class StepOutput(BaseModel): "Output that the task step has produced. Any value is" " allowed." ), - example='{\n"tokens": 7894,\n"estimated_cost": "0,24$"\n}', + examples=['{\n"tokens": 7894,\n"estimated_cost": "0,24$"\n}'], ) @@ -71,9 +71,9 @@ class TaskRequestBody(BaseModel): input: Optional[str] = Field( None, description="Input prompt for the task.", - example=( + examples=[( "Write the words you receive to the file 'output.txt'." - ), + )], ) additional_input: Optional[TaskInput] = None @@ -82,15 +82,15 @@ class Task(TaskRequestBody): task_id: str = Field( ..., description="The ID of the task.", - example="50da533e-3904-4401-8a07-c49adf88b5eb", + examples=["50da533e-3904-4401-8a07-c49adf88b5eb"], ) artifacts: List[Artifact] = Field( [], description="A list of artifacts that the task has produced.", - example=[ + examples=[[ "7a49f31c-f9c6-4346-a22c-e32bc5af4d8e", "ab7b4091-2560-4692-a4fe-d831ea3ca7d6", - ], + ]], ) @@ -98,7 +98,7 @@ class StepRequestBody(BaseModel): input: Optional[str] = Field( None, description="Input prompt for the step.", - example="Washington", + examples=["Washington"], ) additional_input: Optional[StepInput] = None @@ -113,17 +113,17 @@ class Step(StepRequestBody): task_id: str = Field( ..., description="The ID of the task this step belongs to.", - example="50da533e-3904-4401-8a07-c49adf88b5eb", + examples=["50da533e-3904-4401-8a07-c49adf88b5eb"], ) step_id: str = Field( ..., description="The ID of the task step.", - example="6bb1801a-fd80-45e8-899a-4dd723cc602e", + examples=["6bb1801a-fd80-45e8-899a-4dd723cc602e"], ) name: Optional[str] = Field( None, description="The name of the task step.", - example="Write to file", + examples=["Write to file"], ) status: Status = Field( ..., description="The status of the task step." @@ -131,11 +131,11 @@ class Step(StepRequestBody): output: Optional[str] = Field( None, description="Output of the task step.", - example=( + examples=[( "I am going to use the write_to_file command and write" " Washington to a file called output.txt" " Dict: extra = values.get("model_kwargs", {}) all_required_field_names = get_pydantic_field_names(cls) @@ -269,7 +270,8 @@ class _AnthropicCommon(BaseLanguageModel): ) return values - @root_validator() + @model_validator() + @classmethod def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" values["anthropic_api_key"] = convert_to_secret_str( @@ -376,14 +378,10 @@ class Anthropic(LLM, _AnthropicCommon): prompt = f"{anthropic.HUMAN_PROMPT} {prompt}{anthropic.AI_PROMPT}" response = model(prompt) """ + model_config = ConfigDict(populate_by_name=True, arbitrary_types_allowed=True) - class Config: - """Configuration for this pydantic object.""" - - allow_population_by_field_name = True - arbitrary_types_allowed = True - - @root_validator() + @model_validator() + @classmethod def raise_warning(cls, values: Dict) -> Dict: """Raise warning that this class is deprecated.""" warnings.warn( diff --git a/swarms/models/cohere_chat.py b/swarms/models/cohere_chat.py index 1a31d82e..efd8728a 100644 --- a/swarms/models/cohere_chat.py +++ b/swarms/models/cohere_chat.py @@ -16,7 +16,7 @@ from langchain.callbacks.manager import ( from langchain.llms.base import LLM from langchain.llms.utils import enforce_stop_tokens from langchain.load.serializable import Serializable -from pydantic import Extra, Field, root_validator +from pydantic import model_validator, ConfigDict, Field from langchain.utils import get_from_dict_or_env logger = logging.getLogger(__name__) @@ -85,7 +85,8 @@ class BaseCohere(Serializable): user_agent: str = "langchain" """Identifier for the application making the request.""" - @root_validator() + @model_validator() + @classmethod def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" try: @@ -145,11 +146,7 @@ class Cohere(LLM, BaseCohere): max_retries: int = 10 """Maximum number of retries to make when generating.""" - - class Config: - """Configuration for this pydantic object.""" - - extra = Extra.forbid + model_config = ConfigDict(extra="forbid") @property def _default_params(self) -> Dict[str, Any]: diff --git a/swarms/models/dalle3.py b/swarms/models/dalle3.py index 40f63418..17790c74 100644 --- a/swarms/models/dalle3.py +++ b/swarms/models/dalle3.py @@ -13,7 +13,7 @@ from cachetools import TTLCache from dotenv import load_dotenv from openai import OpenAI from PIL import Image -from pydantic import validator +from pydantic import field_validator from termcolor import colored load_dotenv() @@ -92,7 +92,8 @@ class Dalle3: arbitrary_types_allowed = True - @validator("max_retries", "time_seconds") + @field_validator("max_retries", "time_seconds") + @classmethod def must_be_positive(cls, value): if value <= 0: raise ValueError("Must be positive") diff --git a/swarms/models/eleven_labs.py b/swarms/models/eleven_labs.py index 2d55e864..759c65bb 100644 --- a/swarms/models/eleven_labs.py +++ b/swarms/models/eleven_labs.py @@ -3,7 +3,7 @@ from enum import Enum from typing import Any, Dict, Union from langchain.utils import get_from_dict_or_env -from pydantic import root_validator +from pydantic import model_validator from swarms.tools.tool import BaseTool @@ -59,7 +59,8 @@ class ElevenLabsText2SpeechTool(BaseTool): " Italian, French, Portuguese, and Hindi. " ) - @root_validator(pre=True) + @model_validator(mode="before") + @classmethod def validate_environment(cls, values: Dict) -> Dict: """Validate that api key exists in environment.""" _ = get_from_dict_or_env( diff --git a/swarms/models/fastvit.py b/swarms/models/fastvit.py index a6fc31f8..f3b60587 100644 --- a/swarms/models/fastvit.py +++ b/swarms/models/fastvit.py @@ -20,6 +20,8 @@ class ClassificationResult(BaseModel): class_id: List[StrictInt] confidence: List[StrictFloat] + # TODO[pydantic]: We couldn't refactor the `validator`, please replace it by `field_validator` manually. + # Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-validators for more information. @validator("class_id", "confidence", pre=True, each_item=True) def check_list_contents(cls, v): assert isinstance(v, int) or isinstance( diff --git a/swarms/models/kosmos2.py b/swarms/models/kosmos2.py index 9a9a0de3..d251ea23 100644 --- a/swarms/models/kosmos2.py +++ b/swarms/models/kosmos2.py @@ -20,6 +20,8 @@ class Detections(BaseModel): ), "All fields must have the same length." return values + # TODO[pydantic]: We couldn't refactor the `validator`, please replace it by `field_validator` manually. + # Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-validators for more information. @validator( "xyxy", "class_id", "confidence", pre=True, each_item=True ) diff --git a/swarms/models/openai_embeddings.py b/swarms/models/openai_embeddings.py index 0cbbdbee..3265a141 100644 --- a/swarms/models/openai_embeddings.py +++ b/swarms/models/openai_embeddings.py @@ -16,7 +16,7 @@ from typing import ( ) import numpy as np -from pydantic import BaseModel, Extra, Field, root_validator +from pydantic import model_validator, ConfigDict, BaseModel, Field from tenacity import ( AsyncRetrying, before_sleep_log, @@ -186,7 +186,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings): """ - client: Any #: :meta private: + client: Any = None #: :meta private: model: str = "text-embedding-ada-002" deployment: str = model # to support Azure OpenAI Service custom deployment names openai_api_version: Optional[str] = None @@ -227,13 +227,10 @@ class OpenAIEmbeddings(BaseModel, Embeddings): """Whether to show a progress bar when embedding.""" model_kwargs: Dict[str, Any] = Field(default_factory=dict) """Holds any model parameters valid for `create` call not explicitly specified.""" + model_config = ConfigDict(extra="forbid") - class Config: - """Configuration for this pydantic object.""" - - extra = Extra.forbid - - @root_validator(pre=True) + @model_validator(mode="before") + @classmethod def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]: """Build extra kwargs from additional params that were passed in.""" all_required_field_names = get_pydantic_field_names(cls) @@ -264,7 +261,8 @@ class OpenAIEmbeddings(BaseModel, Embeddings): values["model_kwargs"] = extra return values - @root_validator() + @model_validator() + @classmethod def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" values["openai_api_key"] = get_from_dict_or_env( diff --git a/swarms/models/openai_function_caller.py b/swarms/models/openai_function_caller.py index 6542e457..feb04387 100644 --- a/swarms/models/openai_function_caller.py +++ b/swarms/models/openai_function_caller.py @@ -2,7 +2,7 @@ from typing import Any, Dict, List, Optional, Union import openai import requests -from pydantic import BaseModel, validator +from pydantic import field_validator, BaseModel from tenacity import ( retry, stop_after_attempt, @@ -78,7 +78,8 @@ class FunctionSpecification(BaseModel): parameters: Dict[str, Any] required: Optional[List[str]] = None - @validator("parameters") + @field_validator("parameters") + @classmethod def check_parameters(cls, params): if not isinstance(params, dict): raise ValueError("Parameters must be a dictionary.") diff --git a/swarms/models/openai_models.py b/swarms/models/openai_models.py index 14332ff2..12830cec 100644 --- a/swarms/models/openai_models.py +++ b/swarms/models/openai_models.py @@ -37,6 +37,7 @@ from langchain.utils.utils import build_extra_kwargs from importlib.metadata import version from packaging.version import parse +from pydantic import model_validator, ConfigDict logger = logging.getLogger(__name__) @@ -247,13 +248,10 @@ class BaseOpenAI(BaseLLM): """Initialize the OpenAI object.""" data.get("model_name", "") return super().__new__(cls) + model_config = ConfigDict(populate_by_name=True) - class Config: - """Configuration for this pydantic object.""" - - allow_population_by_field_name = True - - @root_validator(pre=True) + @model_validator(mode="before") + @classmethod def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]: """Build extra kwargs from additional params that were passed in.""" all_required_field_names = get_pydantic_field_names(cls) @@ -263,7 +261,8 @@ class BaseOpenAI(BaseLLM): ) return values - @root_validator() + @model_validator() + @classmethod def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" values["openai_api_key"] = get_from_dict_or_env( @@ -758,7 +757,8 @@ class AzureOpenAI(BaseOpenAI): openai_api_type: str = "" openai_api_version: str = "" - @root_validator() + @model_validator() + @classmethod def validate_azure_settings(cls, values: Dict) -> Dict: values["openai_api_version"] = get_from_dict_or_env( values, @@ -847,7 +847,8 @@ class OpenAIChat(BaseLLM): disallowed_special: Union[Literal["all"], Collection[str]] = "all" """Set of special tokens that are not allowed。""" - @root_validator(pre=True) + @model_validator(mode="before") + @classmethod def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]: """Build extra kwargs from additional params that were passed in.""" all_required_field_names = { @@ -865,7 +866,8 @@ class OpenAIChat(BaseLLM): values["model_kwargs"] = extra return values - @root_validator() + @model_validator() + @classmethod def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" openai_api_key = get_from_dict_or_env( diff --git a/swarms/models/palm.py b/swarms/models/palm.py index d61d4856..e016a776 100644 --- a/swarms/models/palm.py +++ b/swarms/models/palm.py @@ -15,6 +15,7 @@ from tenacity import ( stop_after_attempt, wait_exponential, ) +from pydantic import model_validator logger = logging.getLogger(__name__) @@ -104,7 +105,8 @@ class GooglePalm(BaseLLM, BaseModel): """Number of chat completions to generate for each prompt. Note that the API may not return the full n completions if duplicates are generated.""" - @root_validator() + @model_validator() + @classmethod def validate_environment(cls, values: Dict) -> Dict: """Validate api key, python package exists.""" google_api_key = get_from_dict_or_env( diff --git a/swarms/models/ssd_1b.py b/swarms/models/ssd_1b.py index d3b9086b..9a905bd4 100644 --- a/swarms/models/ssd_1b.py +++ b/swarms/models/ssd_1b.py @@ -9,7 +9,7 @@ import backoff import torch from diffusers import StableDiffusionXLPipeline from PIL import Image -from pydantic import validator +from pydantic import field_validator from termcolor import colored from cachetools import TTLCache @@ -72,7 +72,8 @@ class SSD1B: arbitrary_types_allowed = True - @validator("max_retries", "time_seconds") + @field_validator("max_retries", "time_seconds") + @classmethod def must_be_positive(cls, value): if value <= 0: raise ValueError("Must be positive") diff --git a/swarms/models/timm.py b/swarms/models/timm.py index d1c42165..8dec0bc9 100644 --- a/swarms/models/timm.py +++ b/swarms/models/timm.py @@ -2,17 +2,14 @@ from typing import List import timm import torch -from pydantic import BaseModel +from pydantic import ConfigDict, BaseModel class TimmModelInfo(BaseModel): model_name: str pretrained: bool in_chans: int - - class Config: - # Use strict typing for all fields - strict = True + model_config = ConfigDict(strict=True) class TimmModel: diff --git a/swarms/tools/tool.py b/swarms/tools/tool.py index 1029a183..ba7752bd 100644 --- a/swarms/tools/tool.py +++ b/swarms/tools/tool.py @@ -30,11 +30,10 @@ from langchain.callbacks.manager import ( from langchain.load.serializable import Serializable from pydantic import ( - BaseModel, + model_validator, BaseModel, Extra, Field, create_model, - root_validator, validate_arguments, ) from langchain.schema.runnable import ( @@ -192,6 +191,8 @@ class ChildTool(BaseTool): ] = False """Handle the content of the ToolException thrown.""" + # TODO[pydantic]: The `Config` class inherits from another class, please create the `model_config` manually. + # Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-config for more information. class Config(Serializable.Config): """Configuration for this pydantic object.""" @@ -276,7 +277,8 @@ class ChildTool(BaseTool): } return tool_input - @root_validator() + @model_validator() + @classmethod def raise_deprecation(cls, values: Dict) -> Dict: """Raise deprecation warning if callback_manager is used.""" if values.get("callback_manager") is not None: diff --git a/swarms/utils/serializable.py b/swarms/utils/serializable.py index de9444ef..3cc3a5f6 100644 --- a/swarms/utils/serializable.py +++ b/swarms/utils/serializable.py @@ -1,7 +1,7 @@ from abc import ABC from typing import Any, Dict, List, Literal, TypedDict, Union, cast -from pydantic import BaseModel, PrivateAttr +from pydantic import ConfigDict, BaseModel, PrivateAttr class BaseSerialized(TypedDict): @@ -64,9 +64,7 @@ class Serializable(BaseModel, ABC): constructor. """ return {} - - class Config: - extra = "ignore" + model_config = ConfigDict(extra="ignore") _lc_kwargs = PrivateAttr(default_factory=dict)