pydantic bump

pull/250/head
evelynmitchell 1 year ago
parent 915fd0641b
commit b0347ac296

@ -12,7 +12,7 @@ class TaskInput(BaseModel):
description=(
"The input parameters for the task. Any value is allowed."
),
example='{\n"debug": false,\n"mode": "benchmarks"\n}',
examples=['{\n"debug": false,\n"mode": "benchmarks"\n}'],
)
@ -20,17 +20,17 @@ class Artifact(BaseModel):
artifact_id: str = Field(
...,
description="Id of the artifact",
example="b225e278-8b4c-4f99-a696-8facf19f0e56",
examples=["b225e278-8b4c-4f99-a696-8facf19f0e56"],
)
file_name: str = Field(
..., description="Filename of the artifact", example="main.py"
..., description="Filename of the artifact", examples=["main.py"]
)
relative_path: Optional[str] = Field(
None,
description=(
"Relative path of the artifact in the agent's workspace"
),
example="python/code/",
examples=["python/code/"],
)
@ -41,7 +41,7 @@ class ArtifactUpload(BaseModel):
description=(
"Relative path of the artifact in the agent's workspace"
),
example="python/code/",
examples=["python/code/"],
)
@ -52,7 +52,7 @@ class StepInput(BaseModel):
"Input parameters for the task step. Any value is"
" allowed."
),
example='{\n"file_to_refactor": "models.py"\n}',
examples=['{\n"file_to_refactor": "models.py"\n}'],
)
@ -63,7 +63,7 @@ class StepOutput(BaseModel):
"Output that the task step has produced. Any value is"
" allowed."
),
example='{\n"tokens": 7894,\n"estimated_cost": "0,24$"\n}',
examples=['{\n"tokens": 7894,\n"estimated_cost": "0,24$"\n}'],
)
@ -71,9 +71,9 @@ class TaskRequestBody(BaseModel):
input: Optional[str] = Field(
None,
description="Input prompt for the task.",
example=(
examples=[(
"Write the words you receive to the file 'output.txt'."
),
)],
)
additional_input: Optional[TaskInput] = None
@ -82,15 +82,15 @@ class Task(TaskRequestBody):
task_id: str = Field(
...,
description="The ID of the task.",
example="50da533e-3904-4401-8a07-c49adf88b5eb",
examples=["50da533e-3904-4401-8a07-c49adf88b5eb"],
)
artifacts: List[Artifact] = Field(
[],
description="A list of artifacts that the task has produced.",
example=[
examples=[[
"7a49f31c-f9c6-4346-a22c-e32bc5af4d8e",
"ab7b4091-2560-4692-a4fe-d831ea3ca7d6",
],
]],
)
@ -98,7 +98,7 @@ class StepRequestBody(BaseModel):
input: Optional[str] = Field(
None,
description="Input prompt for the step.",
example="Washington",
examples=["Washington"],
)
additional_input: Optional[StepInput] = None
@ -113,17 +113,17 @@ class Step(StepRequestBody):
task_id: str = Field(
...,
description="The ID of the task this step belongs to.",
example="50da533e-3904-4401-8a07-c49adf88b5eb",
examples=["50da533e-3904-4401-8a07-c49adf88b5eb"],
)
step_id: str = Field(
...,
description="The ID of the task step.",
example="6bb1801a-fd80-45e8-899a-4dd723cc602e",
examples=["6bb1801a-fd80-45e8-899a-4dd723cc602e"],
)
name: Optional[str] = Field(
None,
description="The name of the task step.",
example="Write to file",
examples=["Write to file"],
)
status: Status = Field(
..., description="The status of the task step."
@ -131,11 +131,11 @@ class Step(StepRequestBody):
output: Optional[str] = Field(
None,
description="Output of the task step.",
example=(
examples=[(
"I am going to use the write_to_file command and write"
" Washington to a file called output.txt"
" <write_to_file('output.txt', 'Washington')"
),
)],
)
additional_output: Optional[StepOutput] = None
artifacts: List[Artifact] = Field(
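
Note on the pattern applied throughout this file: pydantic v2 drops the ad-hoc `example=` keyword on `Field` and instead takes a list via `examples=`, which is emitted into the generated JSON schema. A minimal sketch of the v2 form, using a stand-in model rather than the schema classes above:

from pydantic import BaseModel, Field

class DemoInput(BaseModel):
    # v1 accepted Field(..., example="...") as an arbitrary extra kwarg;
    # v2 expects a list of examples and places it in the JSON schema.
    payload: str = Field(
        ...,
        description="The input parameters for the task. Any value is allowed.",
        examples=['{\n"debug": false,\n"mode": "benchmarks"\n}'],
    )

print(DemoInput.model_json_schema()["properties"]["payload"]["examples"])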

@ -24,7 +24,7 @@ from langchain.callbacks.manager import (
CallbackManagerForLLMRun,
)
from langchain.llms.base import LLM
from pydantic import Field, SecretStr, root_validator
from pydantic import model_validator, ConfigDict, Field, SecretStr
from langchain.schema.language_model import BaseLanguageModel
from langchain.schema.output import GenerationChunk
from langchain.schema.prompt import PromptValue
@ -260,7 +260,8 @@ class _AnthropicCommon(BaseLanguageModel):
count_tokens: Optional[Callable[[str], int]] = None
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
@root_validator(pre=True)
@model_validator(mode="before")
@classmethod
def build_extra(cls, values: Dict) -> Dict:
extra = values.get("model_kwargs", {})
all_required_field_names = get_pydantic_field_names(cls)
@ -269,7 +270,8 @@ class _AnthropicCommon(BaseLanguageModel):
)
return values
@root_validator()
@model_validator()
@classmethod
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
values["anthropic_api_key"] = convert_to_secret_str(
@ -376,14 +378,10 @@ class Anthropic(LLM, _AnthropicCommon):
prompt = f"{anthropic.HUMAN_PROMPT} {prompt}{anthropic.AI_PROMPT}"
response = model(prompt)
"""
model_config = ConfigDict(populate_by_name=True, arbitrary_types_allowed=True)
class Config:
"""Configuration for this pydantic object."""
allow_population_by_field_name = True
arbitrary_types_allowed = True
@root_validator()
@model_validator()
@classmethod
def raise_warning(cls, values: Dict) -> Dict:
"""Raise warning that this class is deprecated."""
warnings.warn(
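
Two review notes on the validator changes in this file (they also apply to the Cohere, OpenAI, GooglePalm and tool hunks further down). `@root_validator(pre=True)` maps directly to `@model_validator(mode="before")`, but the bare `@model_validator()` replacing `@root_validator()` is likely to fail at import time: in the pydantic v2 releases I'm aware of, `mode` is a required keyword-only argument, and the `mode="after"` form receives the constructed model rather than a `values` dict. A hedged sketch of both shapes, with an illustrative model in place of `_AnthropicCommon`:

from typing import Any, Dict, Optional
from pydantic import BaseModel, model_validator

class ExampleCommon(BaseModel):
    anthropic_api_key: Optional[str] = None
    model_kwargs: Dict[str, Any] = {}

    @model_validator(mode="before")
    @classmethod
    def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        # "before" still sees the raw input dict, as root_validator(pre=True) did
        values.setdefault("model_kwargs", {})
        return values

    @model_validator(mode="after")
    def validate_environment(self) -> "ExampleCommon":
        # "after" runs on the constructed instance; the old values-dict body
        # would need to be rewritten in terms of self
        if self.anthropic_api_key is None:
            raise ValueError("anthropic_api_key must be set")
        return self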

@ -16,7 +16,7 @@ from langchain.callbacks.manager import (
from langchain.llms.base import LLM
from langchain.llms.utils import enforce_stop_tokens
from langchain.load.serializable import Serializable
from pydantic import Extra, Field, root_validator
from pydantic import model_validator, ConfigDict, Field
from langchain.utils import get_from_dict_or_env
logger = logging.getLogger(__name__)
@ -85,7 +85,8 @@ class BaseCohere(Serializable):
user_agent: str = "langchain"
"""Identifier for the application making the request."""
@root_validator()
@model_validator()
@classmethod
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
try:
@ -145,11 +146,7 @@ class Cohere(LLM, BaseCohere):
max_retries: int = 10
"""Maximum number of retries to make when generating."""
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
model_config = ConfigDict(extra="forbid")
@property
def _default_params(self) -> Dict[str, Any]:
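
The `class Config` / `Extra.forbid` pair collapses into a single `model_config = ConfigDict(extra="forbid")` attribute, and the `Extra` import can be dropped. A quick sketch of the equivalent behavior (the model here is made up):

from pydantic import BaseModel, ConfigDict, ValidationError

class ForbidsExtra(BaseModel):
    model_config = ConfigDict(extra="forbid")
    temperature: float = 0.75

try:
    ForbidsExtra(temperature=0.2, typo_field=1)
except ValidationError as exc:
    print(exc.errors()[0]["type"])  # "extra_forbidden"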

@ -13,7 +13,7 @@ from cachetools import TTLCache
from dotenv import load_dotenv
from openai import OpenAI
from PIL import Image
from pydantic import validator
from pydantic import field_validator
from termcolor import colored
load_dotenv()
@ -92,7 +92,8 @@ class Dalle3:
arbitrary_types_allowed = True
@validator("max_retries", "time_seconds")
@field_validator("max_retries", "time_seconds")
@classmethod
def must_be_positive(cls, value):
if value <= 0:
raise ValueError("Must be positive")
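
`@validator` becomes `@field_validator`, and the explicit `@classmethod` is needed because v2 no longer wraps validators implicitly. A standalone sketch of the pattern, assuming the host class is a pydantic model (the fields below are placeholders for Dalle3's settings):

from pydantic import BaseModel, field_validator

class RetrySettings(BaseModel):
    max_retries: int = 3
    time_seconds: int = 60

    @field_validator("max_retries", "time_seconds")
    @classmethod
    def must_be_positive(cls, value: int) -> int:
        if value <= 0:
            raise ValueError("Must be positive")
        return value  # field validators must return the (possibly transformed) value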

@ -3,7 +3,7 @@ from enum import Enum
from typing import Any, Dict, Union
from langchain.utils import get_from_dict_or_env
from pydantic import root_validator
from pydantic import model_validator
from swarms.tools.tool import BaseTool
@ -59,7 +59,8 @@ class ElevenLabsText2SpeechTool(BaseTool):
" Italian, French, Portuguese, and Hindi. "
)
@root_validator(pre=True)
@model_validator(mode="before")
@classmethod
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key exists in environment."""
_ = get_from_dict_or_env(

@ -20,6 +20,8 @@ class ClassificationResult(BaseModel):
class_id: List[StrictInt]
confidence: List[StrictFloat]
# TODO[pydantic]: We couldn't refactor the `validator`, please replace it by `field_validator` manually.
# Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-validators for more information.
@validator("class_id", "confidence", pre=True, each_item=True)
def check_list_contents(cls, v):
assert isinstance(v, int) or isinstance(

@ -20,6 +20,8 @@ class Detections(BaseModel):
), "All fields must have the same length."
return values
# TODO[pydantic]: We couldn't refactor the `validator`, please replace it by `field_validator` manually.
# Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-validators for more information.
@validator(
"xyxy", "class_id", "confidence", pre=True, each_item=True
)
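
The TODO comments left in these two schema files are accurate: `each_item=True` has no counterpart in v2. One manual replacement, sketched here on the assumption that the fields stay plain lists, is a `field_validator(..., mode="before")` that iterates the incoming sequence itself (per-item `Annotated` validators are the other common option). The body mirrors the truncated assert shown above:

from typing import List
from pydantic import BaseModel, StrictFloat, StrictInt, field_validator

class ClassificationResult(BaseModel):
    class_id: List[StrictInt]
    confidence: List[StrictFloat]

    @field_validator("class_id", "confidence", mode="before")
    @classmethod
    def check_list_contents(cls, v):
        # v2 hands the whole list to the validator; replicate each_item=True by looping
        for item in v:
            assert isinstance(item, (int, float)), "Items must be int or float"
        return v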

@ -16,7 +16,7 @@ from typing import (
)
import numpy as np
from pydantic import BaseModel, Extra, Field, root_validator
from pydantic import model_validator, ConfigDict, BaseModel, Field
from tenacity import (
AsyncRetrying,
before_sleep_log,
@ -186,7 +186,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
"""
client: Any #: :meta private:
client: Any = None #: :meta private:
model: str = "text-embedding-ada-002"
deployment: str = model # to support Azure OpenAI Service custom deployment names
openai_api_version: Optional[str] = None
@ -227,13 +227,10 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
"""Whether to show a progress bar when embedding."""
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
"""Holds any model parameters valid for `create` call not explicitly specified."""
model_config = ConfigDict(extra="forbid")
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
@root_validator(pre=True)
@model_validator(mode="before")
@classmethod
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Build extra kwargs from additional params that were passed in."""
all_required_field_names = get_pydantic_field_names(cls)
@ -264,7 +261,8 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
values["model_kwargs"] = extra
return values
@root_validator()
@model_validator()
@classmethod
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
values["openai_api_key"] = get_from_dict_or_env(
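
The `client: Any` → `client: Any = None` change is another v2 migration detail: v1 gave fields annotated `Optional[...]` or `Any` an implicit default of `None`, while v2 treats any field without a default as required. A tiny sketch of the difference (class names are invented):

from typing import Any
from pydantic import BaseModel, ValidationError

class WithoutDefault(BaseModel):
    client: Any          # required in v2, even though the type is Any

class WithDefault(BaseModel):
    client: Any = None   # optional again, matching the hunk above

try:
    WithoutDefault()
except ValidationError:
    print("client is required when no default is given")

print(WithDefault().client)  # None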

@ -2,7 +2,7 @@ from typing import Any, Dict, List, Optional, Union
import openai
import requests
from pydantic import BaseModel, validator
from pydantic import field_validator, BaseModel
from tenacity import (
retry,
stop_after_attempt,
@ -78,7 +78,8 @@ class FunctionSpecification(BaseModel):
parameters: Dict[str, Any]
required: Optional[List[str]] = None
@validator("parameters")
@field_validator("parameters")
@classmethod
def check_parameters(cls, params):
if not isinstance(params, dict):
raise ValueError("Parameters must be a dictionary.")

@ -37,6 +37,7 @@ from langchain.utils.utils import build_extra_kwargs
from importlib.metadata import version
from packaging.version import parse
from pydantic import model_validator, ConfigDict
logger = logging.getLogger(__name__)
@ -247,13 +248,10 @@ class BaseOpenAI(BaseLLM):
"""Initialize the OpenAI object."""
data.get("model_name", "")
return super().__new__(cls)
model_config = ConfigDict(populate_by_name=True)
class Config:
"""Configuration for this pydantic object."""
allow_population_by_field_name = True
@root_validator(pre=True)
@model_validator(mode="before")
@classmethod
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Build extra kwargs from additional params that were passed in."""
all_required_field_names = get_pydantic_field_names(cls)
@ -263,7 +261,8 @@ class BaseOpenAI(BaseLLM):
)
return values
@root_validator()
@model_validator()
@classmethod
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
values["openai_api_key"] = get_from_dict_or_env(
@ -758,7 +757,8 @@ class AzureOpenAI(BaseOpenAI):
openai_api_type: str = ""
openai_api_version: str = ""
@root_validator()
@model_validator()
@classmethod
def validate_azure_settings(cls, values: Dict) -> Dict:
values["openai_api_version"] = get_from_dict_or_env(
values,
@ -847,7 +847,8 @@ class OpenAIChat(BaseLLM):
disallowed_special: Union[Literal["all"], Collection[str]] = "all"

"""Set of special tokens that are not allowed。"""
@root_validator(pre=True)
@model_validator(mode="before")
@classmethod
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Build extra kwargs from additional params that were passed in."""
all_required_field_names = {
@ -865,7 +866,8 @@ class OpenAIChat(BaseLLM):
values["model_kwargs"] = extra
return values
@root_validator()
@model_validator()
@classmethod
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
openai_api_key = get_from_dict_or_env(
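
`allow_population_by_field_name` is renamed to `populate_by_name` in v2's `ConfigDict`; alias handling itself is unchanged. A short sketch with an illustrative aliased field (not the real `BaseOpenAI` attribute, which would also trip v2's `model_` protected-namespace warning):

from pydantic import BaseModel, ConfigDict, Field

class Wrapper(BaseModel):
    model_config = ConfigDict(populate_by_name=True)
    engine_name: str = Field(default="gpt-3.5-turbo-instruct", alias="engine")

print(Wrapper(engine="gpt-4").engine_name)       # populate by alias
print(Wrapper(engine_name="gpt-4").engine_name)  # populate by field name, enabled by populate_by_name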

@ -15,6 +15,7 @@ from tenacity import (
stop_after_attempt,
wait_exponential,
)
from pydantic import model_validator
logger = logging.getLogger(__name__)
@ -104,7 +105,8 @@ class GooglePalm(BaseLLM, BaseModel):
"""Number of chat completions to generate for each prompt. Note that the API may
not return the full n completions if duplicates are generated."""
@root_validator()
@model_validator()
@classmethod
def validate_environment(cls, values: Dict) -> Dict:
"""Validate api key, python package exists."""
google_api_key = get_from_dict_or_env(

@ -9,7 +9,7 @@ import backoff
import torch
from diffusers import StableDiffusionXLPipeline
from PIL import Image
from pydantic import validator
from pydantic import field_validator
from termcolor import colored
from cachetools import TTLCache
@ -72,7 +72,8 @@ class SSD1B:
arbitrary_types_allowed = True
@validator("max_retries", "time_seconds")
@field_validator("max_retries", "time_seconds")
@classmethod
def must_be_positive(cls, value):
if value <= 0:
raise ValueError("Must be positive")

@ -2,17 +2,14 @@ from typing import List
import timm
import torch
from pydantic import BaseModel
from pydantic import ConfigDict, BaseModel
class TimmModelInfo(BaseModel):
model_name: str
pretrained: bool
in_chans: int
class Config:
# Use strict typing for all fields
strict = True
model_config = ConfigDict(strict=True)
class TimmModel:
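
`class Config: strict = True` becomes `model_config = ConfigDict(strict=True)`. In v2, strict mode turns off the usual type coercion, so strings are no longer accepted where an int or bool is declared. A sketch reusing the same field names (note that v2 also warns about the `model_` prefix in `model_name` unless `protected_namespaces` is relaxed):

from pydantic import BaseModel, ConfigDict, ValidationError

class TimmModelInfoSketch(BaseModel):
    model_config = ConfigDict(strict=True)
    model_name: str
    pretrained: bool
    in_chans: int

TimmModelInfoSketch(model_name="resnet18", pretrained=True, in_chans=3)  # accepted
try:
    TimmModelInfoSketch(model_name="resnet18", pretrained=True, in_chans="3")  # string refused
except ValidationError as exc:
    print(exc.errors()[0]["type"])  # "int_type"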

@ -30,11 +30,10 @@ from langchain.callbacks.manager import (
from langchain.load.serializable import Serializable
from pydantic import (
BaseModel,
model_validator, BaseModel,
Extra,
Field,
create_model,
root_validator,
validate_arguments,
)
from langchain.schema.runnable import (
@ -192,6 +191,8 @@ class ChildTool(BaseTool):
] = False
"""Handle the content of the ToolException thrown."""
# TODO[pydantic]: The `Config` class inherits from another class, please create the `model_config` manually.
# Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-config for more information.
class Config(Serializable.Config):
"""Configuration for this pydantic object."""
@ -276,7 +277,8 @@ class ChildTool(BaseTool):
}
return tool_input
@root_validator()
@model_validator()
@classmethod
def raise_deprecation(cls, values: Dict) -> Dict:
"""Raise deprecation warning if callback_manager is used."""
if values.get("callback_manager") is not None:
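
The remaining TODO in this file is the one place the bump cannot be mechanical: `class Config(Serializable.Config)` has no direct v2 equivalent because configs are now plain dicts. v2 merges `model_config` down the inheritance chain, so declaring only the child's overrides is often enough; if an explicit merge is preferred, the parent dict can be spread. A hedged sketch, with `arbitrary_types_allowed` used purely as an illustrative override and `Serializable` mirroring the change at the bottom of this diff:

from pydantic import BaseModel, ConfigDict

class Serializable(BaseModel):
    model_config = ConfigDict(extra="ignore")

class ChildTool(Serializable):
    # option 1: rely on v2 merging the parent's model_config with the child's overrides
    # option 2 (shown): merge explicitly so the final dict is visible in one place
    model_config = ConfigDict(**{**Serializable.model_config,
                                 "arbitrary_types_allowed": True})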

@ -1,7 +1,7 @@
from abc import ABC
from typing import Any, Dict, List, Literal, TypedDict, Union, cast
from pydantic import BaseModel, PrivateAttr
from pydantic import ConfigDict, BaseModel, PrivateAttr
class BaseSerialized(TypedDict):
@ -64,9 +64,7 @@ class Serializable(BaseModel, ABC):
constructor.
"""
return {}
class Config:
extra = "ignore"
model_config = ConfigDict(extra="ignore")
_lc_kwargs = PrivateAttr(default_factory=dict)
