pydantic bump

pull/250/head
evelynmitchell 1 year ago
parent 915fd0641b
commit b0347ac296

@@ -12,7 +12,7 @@ class TaskInput(BaseModel):
         description=(
             "The input parameters for the task. Any value is allowed."
         ),
-        example='{\n"debug": false,\n"mode": "benchmarks"\n}',
+        examples=['{\n"debug": false,\n"mode": "benchmarks"\n}'],
     )
@@ -20,17 +20,17 @@ class Artifact(BaseModel):
     artifact_id: str = Field(
         ...,
         description="Id of the artifact",
-        example="b225e278-8b4c-4f99-a696-8facf19f0e56",
+        examples=["b225e278-8b4c-4f99-a696-8facf19f0e56"],
     )
     file_name: str = Field(
-        ..., description="Filename of the artifact", example="main.py"
+        ..., description="Filename of the artifact", examples=["main.py"]
     )
     relative_path: Optional[str] = Field(
         None,
         description=(
             "Relative path of the artifact in the agent's workspace"
         ),
-        example="python/code/",
+        examples=["python/code/"],
     )
@@ -41,7 +41,7 @@ class ArtifactUpload(BaseModel):
         description=(
             "Relative path of the artifact in the agent's workspace"
         ),
-        example="python/code/",
+        examples=["python/code/"],
     )
@@ -52,7 +52,7 @@ class StepInput(BaseModel):
             "Input parameters for the task step. Any value is"
             " allowed."
         ),
-        example='{\n"file_to_refactor": "models.py"\n}',
+        examples=['{\n"file_to_refactor": "models.py"\n}'],
     )
@@ -63,7 +63,7 @@ class StepOutput(BaseModel):
             "Output that the task step has produced. Any value is"
             " allowed."
         ),
-        example='{\n"tokens": 7894,\n"estimated_cost": "0,24$"\n}',
+        examples=['{\n"tokens": 7894,\n"estimated_cost": "0,24$"\n}'],
     )
@@ -71,9 +71,9 @@ class TaskRequestBody(BaseModel):
     input: Optional[str] = Field(
         None,
         description="Input prompt for the task.",
-        example=(
-            "Write the words you receive to the file 'output.txt'."
-        ),
+        examples=[(
+            "Write the words you receive to the file 'output.txt'."
+        )],
     )
     additional_input: Optional[TaskInput] = None
@@ -82,15 +82,15 @@ class Task(TaskRequestBody):
     task_id: str = Field(
         ...,
         description="The ID of the task.",
-        example="50da533e-3904-4401-8a07-c49adf88b5eb",
+        examples=["50da533e-3904-4401-8a07-c49adf88b5eb"],
     )
     artifacts: List[Artifact] = Field(
         [],
         description="A list of artifacts that the task has produced.",
-        example=[
-            "7a49f31c-f9c6-4346-a22c-e32bc5af4d8e",
-            "ab7b4091-2560-4692-a4fe-d831ea3ca7d6",
-        ],
+        examples=[[
+            "7a49f31c-f9c6-4346-a22c-e32bc5af4d8e",
+            "ab7b4091-2560-4692-a4fe-d831ea3ca7d6",
+        ]],
     )
@@ -98,7 +98,7 @@ class StepRequestBody(BaseModel):
     input: Optional[str] = Field(
         None,
         description="Input prompt for the step.",
-        example="Washington",
+        examples=["Washington"],
     )
     additional_input: Optional[StepInput] = None
@@ -113,17 +113,17 @@ class Step(StepRequestBody):
     task_id: str = Field(
         ...,
         description="The ID of the task this step belongs to.",
-        example="50da533e-3904-4401-8a07-c49adf88b5eb",
+        examples=["50da533e-3904-4401-8a07-c49adf88b5eb"],
     )
     step_id: str = Field(
         ...,
         description="The ID of the task step.",
-        example="6bb1801a-fd80-45e8-899a-4dd723cc602e",
+        examples=["6bb1801a-fd80-45e8-899a-4dd723cc602e"],
     )
     name: Optional[str] = Field(
         None,
         description="The name of the task step.",
-        example="Write to file",
+        examples=["Write to file"],
     )
     status: Status = Field(
         ..., description="The status of the task step."
@@ -131,11 +131,11 @@ class Step(StepRequestBody):
     output: Optional[str] = Field(
         None,
         description="Output of the task step.",
-        example=(
-            "I am going to use the write_to_file command and write"
-            " Washington to a file called output.txt"
-            " <write_to_file('output.txt', 'Washington')"
-        ),
+        examples=[(
+            "I am going to use the write_to_file command and write"
+            " Washington to a file called output.txt"
+            " <write_to_file('output.txt', 'Washington')"
+        )],
     )
     additional_output: Optional[StepOutput] = None
     artifacts: List[Artifact] = Field(

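The hunks above only rename `example=` to `examples=[...]`: Pydantic v2's `Field` exposes an explicit `examples` parameter that takes a list, whereas the singular `example` keyword was never an official parameter and only passed through as an extra JSON-schema kwarg. A minimal sketch of the v2 form (class and field names are illustrative, not the commit's code):

    from pydantic import BaseModel, Field

    class TaskInputSketch(BaseModel):  # hypothetical stand-in
        # In v2, `examples` is a real Field parameter and takes a list of sample values.
        input: str = Field(
            ...,
            description="The input parameters for the task. Any value is allowed.",
            examples=['{\n"debug": false,\n"mode": "benchmarks"\n}'],
        )

    # The values land in the generated JSON schema under "examples":
    # TaskInputSketch.model_json_schema()["properties"]["input"]["examples"]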
@@ -24,7 +24,7 @@ from langchain.callbacks.manager import (
     CallbackManagerForLLMRun,
 )
 from langchain.llms.base import LLM
-from pydantic import Field, SecretStr, root_validator
+from pydantic import model_validator, ConfigDict, Field, SecretStr
 from langchain.schema.language_model import BaseLanguageModel
 from langchain.schema.output import GenerationChunk
 from langchain.schema.prompt import PromptValue
@@ -260,7 +260,8 @@ class _AnthropicCommon(BaseLanguageModel):
     count_tokens: Optional[Callable[[str], int]] = None
     model_kwargs: Dict[str, Any] = Field(default_factory=dict)
-    @root_validator(pre=True)
+    @model_validator(mode="before")
+    @classmethod
     def build_extra(cls, values: Dict) -> Dict:
         extra = values.get("model_kwargs", {})
         all_required_field_names = get_pydantic_field_names(cls)
@@ -269,7 +270,8 @@ class _AnthropicCommon(BaseLanguageModel):
         )
         return values
-    @root_validator()
+    @model_validator()
+    @classmethod
     def validate_environment(cls, values: Dict) -> Dict:
         """Validate that api key and python package exists in environment."""
         values["anthropic_api_key"] = convert_to_secret_str(
@@ -376,14 +378,10 @@ class Anthropic(LLM, _AnthropicCommon):
             prompt = f"{anthropic.HUMAN_PROMPT} {prompt}{anthropic.AI_PROMPT}"
             response = model(prompt)
     """
+    model_config = ConfigDict(populate_by_name=True, arbitrary_types_allowed=True)
-    class Config:
-        """Configuration for this pydantic object."""
-        allow_population_by_field_name = True
-        arbitrary_types_allowed = True
-    @root_validator()
+    @model_validator()
+    @classmethod
     def raise_warning(cls, values: Dict) -> Dict:
         """Raise warning that this class is deprecated."""
         warnings.warn(

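Many of the remaining hunks convert `@root_validator(...)` to `@model_validator(...)`. Note that in Pydantic v2 the decorator takes a required keyword `mode`, so the bare `@model_validator()` form used in several of these hunks would still need one. A rough sketch of the two common shapes (the model and field names below are illustrative, not from the commit):

    from pydantic import BaseModel, model_validator

    class EnvSketch(BaseModel):  # illustrative model, not taken from the diff
        api_key: str = ""

        # v1 `@root_validator(pre=True)` -> v2 `mode="before"`: a classmethod that
        # receives the raw input (usually a dict) before field validation.
        @model_validator(mode="before")
        @classmethod
        def build_extra(cls, values):
            if isinstance(values, dict):
                values.setdefault("api_key", "from-env")
            return values

        # v1 `@root_validator()` -> v2 `mode="after"`: runs after field validation
        # and receives the constructed model instance instead of a values dict.
        @model_validator(mode="after")
        def validate_environment(self) -> "EnvSketch":
            if not self.api_key:
                raise ValueError("api_key missing")
            return self

    print(EnvSketch().api_key)  # "from-env"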
@@ -16,7 +16,7 @@ from langchain.callbacks.manager import (
 from langchain.llms.base import LLM
 from langchain.llms.utils import enforce_stop_tokens
 from langchain.load.serializable import Serializable
-from pydantic import Extra, Field, root_validator
+from pydantic import model_validator, ConfigDict, Field
 from langchain.utils import get_from_dict_or_env
 logger = logging.getLogger(__name__)
@@ -85,7 +85,8 @@ class BaseCohere(Serializable):
     user_agent: str = "langchain"
     """Identifier for the application making the request."""
-    @root_validator()
+    @model_validator()
+    @classmethod
     def validate_environment(cls, values: Dict) -> Dict:
         """Validate that api key and python package exists in environment."""
         try:
@@ -145,11 +146,7 @@ class Cohere(LLM, BaseCohere):
     max_retries: int = 10
     """Maximum number of retries to make when generating."""
+    model_config = ConfigDict(extra="forbid")
-    class Config:
-        """Configuration for this pydantic object."""
-        extra = Extra.forbid
     @property
     def _default_params(self) -> Dict[str, Any]:

@@ -13,7 +13,7 @@ from cachetools import TTLCache
 from dotenv import load_dotenv
 from openai import OpenAI
 from PIL import Image
-from pydantic import validator
+from pydantic import field_validator
 from termcolor import colored
 load_dotenv()
@@ -92,7 +92,8 @@ class Dalle3:
         arbitrary_types_allowed = True
-    @validator("max_retries", "time_seconds")
+    @field_validator("max_retries", "time_seconds")
+    @classmethod
     def must_be_positive(cls, value):
         if value <= 0:
             raise ValueError("Must be positive")

@@ -3,7 +3,7 @@ from enum import Enum
 from typing import Any, Dict, Union
 from langchain.utils import get_from_dict_or_env
-from pydantic import root_validator
+from pydantic import model_validator
 from swarms.tools.tool import BaseTool
@@ -59,7 +59,8 @@ class ElevenLabsText2SpeechTool(BaseTool):
         " Italian, French, Portuguese, and Hindi. "
     )
-    @root_validator(pre=True)
+    @model_validator(mode="before")
+    @classmethod
     def validate_environment(cls, values: Dict) -> Dict:
         """Validate that api key exists in environment."""
         _ = get_from_dict_or_env(

@@ -20,6 +20,8 @@ class ClassificationResult(BaseModel):
     class_id: List[StrictInt]
     confidence: List[StrictFloat]
+    # TODO[pydantic]: We couldn't refactor the `validator`, please replace it by `field_validator` manually.
+    # Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-validators for more information.
     @validator("class_id", "confidence", pre=True, each_item=True)
     def check_list_contents(cls, v):
         assert isinstance(v, int) or isinstance(

@@ -20,6 +20,8 @@ class Detections(BaseModel):
         ), "All fields must have the same length."
         return values
+    # TODO[pydantic]: We couldn't refactor the `validator`, please replace it by `field_validator` manually.
+    # Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-validators for more information.
     @validator(
         "xyxy", "class_id", "confidence", pre=True, each_item=True
     )

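Both `TODO[pydantic]` hunks above flag validators that rely on `each_item=True`, which `field_validator` no longer supports. One possible manual rewrite, assuming the field layout shown in the first hunk (the class name and validator body are a guess at the original intent, not the committed code):

    from typing import List
    from pydantic import BaseModel, StrictFloat, StrictInt, field_validator

    class ClassificationResultSketch(BaseModel):  # illustrative stand-in
        class_id: List[StrictInt]
        confidence: List[StrictFloat]

        # `each_item=True` has no direct v2 equivalent: the validator now receives
        # the whole list, so iterate explicitly (mode="before" mirrors v1's pre=True).
        @field_validator("class_id", "confidence", mode="before")
        @classmethod
        def check_list_contents(cls, v):
            for item in v:
                assert isinstance(item, (int, float)), "Invalid type in list"
            return v

    ClassificationResultSketch(class_id=[1, 2], confidence=[0.9, 0.8])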
@@ -16,7 +16,7 @@ from typing import (
 )
 import numpy as np
-from pydantic import BaseModel, Extra, Field, root_validator
+from pydantic import model_validator, ConfigDict, BaseModel, Field
 from tenacity import (
     AsyncRetrying,
     before_sleep_log,
@@ -186,7 +186,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
     """
-    client: Any  #: :meta private:
+    client: Any = None  #: :meta private:
     model: str = "text-embedding-ada-002"
     deployment: str = model  # to support Azure OpenAI Service custom deployment names
     openai_api_version: Optional[str] = None
@@ -227,13 +227,10 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
     """Whether to show a progress bar when embedding."""
     model_kwargs: Dict[str, Any] = Field(default_factory=dict)
     """Holds any model parameters valid for `create` call not explicitly specified."""
+    model_config = ConfigDict(extra="forbid")
-    class Config:
-        """Configuration for this pydantic object."""
-        extra = Extra.forbid
-    @root_validator(pre=True)
+    @model_validator(mode="before")
+    @classmethod
     def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
         """Build extra kwargs from additional params that were passed in."""
         all_required_field_names = get_pydantic_field_names(cls)
@@ -264,7 +261,8 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
         values["model_kwargs"] = extra
         return values
-    @root_validator()
+    @model_validator()
+    @classmethod
     def validate_environment(cls, values: Dict) -> Dict:
         """Validate that api key and python package exists in environment."""
         values["openai_api_key"] = get_from_dict_or_env(

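The `client: Any` to `client: Any = None` change reflects a v2 behavior difference: fields annotated `Any` or `Optional[...]` no longer receive an implicit `None` default, so without one the field becomes required. A small illustration under that assumption (not code from the commit):

    from typing import Any
    from pydantic import BaseModel, ValidationError

    class WithDefault(BaseModel):
        client: Any = None  # optional only because the default is explicit

    class WithoutDefault(BaseModel):
        client: Any         # required in v2; v1 would have implied a None default

    WithDefault()  # fine
    try:
        WithoutDefault()
    except ValidationError as err:
        print(err)  # reports `client` as a required field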
@@ -2,7 +2,7 @@ from typing import Any, Dict, List, Optional, Union
 import openai
 import requests
-from pydantic import BaseModel, validator
+from pydantic import field_validator, BaseModel
 from tenacity import (
     retry,
     stop_after_attempt,
@@ -78,7 +78,8 @@ class FunctionSpecification(BaseModel):
     parameters: Dict[str, Any]
     required: Optional[List[str]] = None
-    @validator("parameters")
+    @field_validator("parameters")
+    @classmethod
     def check_parameters(cls, params):
         if not isinstance(params, dict):
             raise ValueError("Parameters must be a dictionary.")

@@ -37,6 +37,7 @@ from langchain.utils.utils import build_extra_kwargs
 from importlib.metadata import version
 from packaging.version import parse
+from pydantic import model_validator, ConfigDict
 logger = logging.getLogger(__name__)
@@ -247,13 +248,10 @@ class BaseOpenAI(BaseLLM):
         """Initialize the OpenAI object."""
         data.get("model_name", "")
         return super().__new__(cls)
+    model_config = ConfigDict(populate_by_name=True)
-    class Config:
-        """Configuration for this pydantic object."""
-        allow_population_by_field_name = True
-    @root_validator(pre=True)
+    @model_validator(mode="before")
+    @classmethod
     def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
         """Build extra kwargs from additional params that were passed in."""
         all_required_field_names = get_pydantic_field_names(cls)
@@ -263,7 +261,8 @@ class BaseOpenAI(BaseLLM):
         )
         return values
-    @root_validator()
+    @model_validator()
+    @classmethod
     def validate_environment(cls, values: Dict) -> Dict:
         """Validate that api key and python package exists in environment."""
         values["openai_api_key"] = get_from_dict_or_env(
@@ -758,7 +757,8 @@ class AzureOpenAI(BaseOpenAI):
     openai_api_type: str = ""
     openai_api_version: str = ""
-    @root_validator()
+    @model_validator()
+    @classmethod
     def validate_azure_settings(cls, values: Dict) -> Dict:
         values["openai_api_version"] = get_from_dict_or_env(
             values,
@@ -847,7 +847,8 @@ class OpenAIChat(BaseLLM):
     disallowed_special: Union[Literal["all"], Collection[str]] = "all"
     """Set of special tokens that are not allowed。"""
-    @root_validator(pre=True)
+    @model_validator(mode="before")
+    @classmethod
     def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
         """Build extra kwargs from additional params that were passed in."""
         all_required_field_names = {
@@ -865,7 +866,8 @@ class OpenAIChat(BaseLLM):
         values["model_kwargs"] = extra
         return values
-    @root_validator()
+    @model_validator()
+    @classmethod
     def validate_environment(cls, values: Dict) -> Dict:
         """Validate that api key and python package exists in environment."""
         openai_api_key = get_from_dict_or_env(

@@ -15,6 +15,7 @@ from tenacity import (
     stop_after_attempt,
     wait_exponential,
 )
+from pydantic import model_validator
 logger = logging.getLogger(__name__)
@@ -104,7 +105,8 @@ class GooglePalm(BaseLLM, BaseModel):
     """Number of chat completions to generate for each prompt. Note that the API may
     not return the full n completions if duplicates are generated."""
-    @root_validator()
+    @model_validator()
+    @classmethod
     def validate_environment(cls, values: Dict) -> Dict:
         """Validate api key, python package exists."""
         google_api_key = get_from_dict_or_env(

@@ -9,7 +9,7 @@ import backoff
 import torch
 from diffusers import StableDiffusionXLPipeline
 from PIL import Image
-from pydantic import validator
+from pydantic import field_validator
 from termcolor import colored
 from cachetools import TTLCache
@@ -72,7 +72,8 @@ class SSD1B:
         arbitrary_types_allowed = True
-    @validator("max_retries", "time_seconds")
+    @field_validator("max_retries", "time_seconds")
+    @classmethod
     def must_be_positive(cls, value):
         if value <= 0:
             raise ValueError("Must be positive")

@@ -2,17 +2,14 @@ from typing import List
 import timm
 import torch
-from pydantic import BaseModel
+from pydantic import ConfigDict, BaseModel
 class TimmModelInfo(BaseModel):
     model_name: str
     pretrained: bool
     in_chans: int
+    model_config = ConfigDict(strict=True)
-    class Config:
-        # Use strict typing for all fields
-        strict = True
 class TimmModel:

@@ -30,11 +30,10 @@ from langchain.callbacks.manager import (
 from langchain.load.serializable import Serializable
 from pydantic import (
-    BaseModel,
+    model_validator, BaseModel,
     Extra,
     Field,
     create_model,
-    root_validator,
     validate_arguments,
 )
 from langchain.schema.runnable import (
@@ -192,6 +191,8 @@ class ChildTool(BaseTool):
     ] = False
     """Handle the content of the ToolException thrown."""
+    # TODO[pydantic]: The `Config` class inherits from another class, please create the `model_config` manually.
+    # Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-config for more information.
     class Config(Serializable.Config):
         """Configuration for this pydantic object."""
@@ -276,7 +277,8 @@ class ChildTool(BaseTool):
         }
         return tool_input
-    @root_validator()
+    @model_validator()
+    @classmethod
     def raise_deprecation(cls, values: Dict) -> Dict:
         """Raise deprecation warning if callback_manager is used."""
         if values.get("callback_manager") is not None:

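The `Config(Serializable.Config)` TODO above is the one case the automated bump could not convert, since a `model_config` dict is not produced by inheriting from a parent `Config` class. One possible manual rewrite relies on v2 merging `model_config` down the class hierarchy (Serializable's own config becomes `extra="ignore"` in the final hunk below); the `arbitrary_types_allowed` option is an assumed placeholder, not taken from the diff:

    from pydantic import BaseModel, ConfigDict

    class SerializableSketch(BaseModel):        # stand-in for Serializable
        model_config = ConfigDict(extra="ignore")

    class ChildToolSketch(SerializableSketch):  # stand-in for ChildTool
        # v2 merges model_config across the class hierarchy, so extra="ignore" is
        # inherited; only the subclass-specific options need to be declared here.
        model_config = ConfigDict(arbitrary_types_allowed=True)

    # ChildToolSketch.model_config now carries both extra="ignore" and
    # arbitrary_types_allowed=True.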
@@ -1,7 +1,7 @@
 from abc import ABC
 from typing import Any, Dict, List, Literal, TypedDict, Union, cast
-from pydantic import BaseModel, PrivateAttr
+from pydantic import ConfigDict, BaseModel, PrivateAttr
 class BaseSerialized(TypedDict):
@@ -64,9 +64,7 @@ class Serializable(BaseModel, ABC):
         constructor.
         """
         return {}
+    model_config = ConfigDict(extra="ignore")
-    class Config:
-        extra = "ignore"
     _lc_kwargs = PrivateAttr(default_factory=dict)
