From 43198ef71322c27d101062f9c3feb25f87096e94 Mon Sep 17 00:00:00 2001 From: Eternal Reclaimer <98760976+kyegomez@users.noreply.github.com> Date: Tue, 5 Dec 2023 11:52:49 -0800 Subject: [PATCH] Revert "pydantic bump fix for #249 " --- .github/workflows/docker-compose.yml | 40 ------------ .github/workflows/test.yml | 3 +- .github/workflows/testing.yml | 2 +- Dockerfile | 2 - pyproject.toml | 6 +- requirements.txt | 2 +- swarms/memory/schemas.py | 36 +++++------ swarms/models/anthropic.py | 36 ++++++++--- swarms/models/cohere_chat.py | 11 ++-- swarms/models/dalle3.py | 5 +- swarms/models/eleven_labs.py | 5 +- swarms/models/fastvit.py | 2 - swarms/models/kosmos2.py | 2 - swarms/models/openai_embeddings.py | 16 ++--- swarms/models/openai_function_caller.py | 5 +- swarms/models/openai_models.py | 19 +++--- swarms/models/palm.py | 4 +- swarms/models/ssd_1b.py | 5 +- swarms/models/timm.py | 7 ++- swarms/tools/tool.py | 84 +++++++++++++++++++++---- swarms/utils/serializable.py | 6 +- tests/Dockerfile | 2 - 22 files changed, 166 insertions(+), 134 deletions(-) delete mode 100644 .github/workflows/docker-compose.yml diff --git a/.github/workflows/docker-compose.yml b/.github/workflows/docker-compose.yml deleted file mode 100644 index 3927c541..00000000 --- a/.github/workflows/docker-compose.yml +++ /dev/null @@ -1,40 +0,0 @@ ---- -# This is a github action to run docker-compose -# docker-compose.yml -# to run the docker build in the top level directory -# to run the docker build in the tests directory and run the tests with pytest -# docker-compose run --rm app pytest -on: - push: - branches: [ main ] - paths: - - 'docker-compose.yml' - - 'Dockerfile' - - 'tests/**' - - 'app/**' - - 'app.py' - - 'requirements.txt' - - 'README.md' - - '.github/workflows/**' - - '.github/workflows/docker-compose.yml' - - '.github/workflows/main.yml' - - '.github/workflows/python-app.yml' - - '.github/workflows/python-app.yml' - - '.github/workflows' - -name: Docker Compose - -jobs: - build: - runs-on: ubuntu-latest - steps: - - name: Checkout repository - uses: actions/checkout@v4 - - # Add your build and test steps here - - - name: Build and run docker services - run: | - docker-compose build - docker-compose up -d - docker-compose run --rm app pytest diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index f4baf4f2..d9dafc76 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -10,7 +10,6 @@ on: env: POETRY_VERSION: "1.4.2" -jobs: test: runs-on: ubuntu-latest strategy: @@ -47,7 +46,7 @@ jobs: make extended_tests fi shell: bash - + name: Python ${{ matrix.python-version }} ${{ matrix.test_type }} steps: - uses: actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} diff --git a/.github/workflows/testing.yml b/.github/workflows/testing.yml index 2607281f..ae572d22 100644 --- a/.github/workflows/testing.yml +++ b/.github/workflows/testing.yml @@ -16,7 +16,7 @@ jobs: - name: Set up Python uses: actions/setup-python@v4 with: - python-version: 3.11 + python-version: 3.x - name: Install dependencies run: | diff --git a/Dockerfile b/Dockerfile index e05a00ea..aa11856d 100644 --- a/Dockerfile +++ b/Dockerfile @@ -2,8 +2,6 @@ # ================================== # Use an official Python runtime as a parent image FROM python:3.9-slim -RUN apt-get update && apt-get -y install libgl1-mesa-dev libglib2.0-0; apt-get clean -RUN pip install opencv-contrib-python-headless # Set environment variables ENV PYTHONDONTWRITEBYTECODE 1 diff --git a/pyproject.toml b/pyproject.toml index 0ec01ccb..0ed3e85f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api" [tool.poetry] name = "swarms" -version = "2.5.8" +version = "2.5.7" description = "Swarms - Pytorch" license = "MIT" authors = ["Kye Gomez "] @@ -52,11 +52,11 @@ ratelimit = "*" beautifulsoup4 = "*" cohere = "*" huggingface-hub = "*" -pydantic = "2.*" +pydantic = "1.10.12" tenacity = "*" Pillow = "*" chromadb = "*" -opencv-python-headless +opencv-python-headless = "*" tabulate = "*" termcolor = "*" black = "*" diff --git a/requirements.txt b/requirements.txt index 028f5a03..0bc6a065 100644 --- a/requirements.txt +++ b/requirements.txt @@ -17,7 +17,7 @@ faiss-cpu openai==0.28.0 attrs datasets -pydantic>2 +pydantic==1.10.12 soundfile huggingface-hub google-generativeai diff --git a/swarms/memory/schemas.py b/swarms/memory/schemas.py index 589a80ae..9147a909 100644 --- a/swarms/memory/schemas.py +++ b/swarms/memory/schemas.py @@ -12,7 +12,7 @@ class TaskInput(BaseModel): description=( "The input parameters for the task. Any value is allowed." ), - examples=['{\n"debug": false,\n"mode": "benchmarks"\n}'], + example='{\n"debug": false,\n"mode": "benchmarks"\n}', ) @@ -20,17 +20,17 @@ class Artifact(BaseModel): artifact_id: str = Field( ..., description="Id of the artifact", - examples=["b225e278-8b4c-4f99-a696-8facf19f0e56"], + example="b225e278-8b4c-4f99-a696-8facf19f0e56", ) file_name: str = Field( - ..., description="Filename of the artifact", examples=["main.py"] + ..., description="Filename of the artifact", example="main.py" ) relative_path: Optional[str] = Field( None, description=( "Relative path of the artifact in the agent's workspace" ), - examples=["python/code/"], + example="python/code/", ) @@ -41,7 +41,7 @@ class ArtifactUpload(BaseModel): description=( "Relative path of the artifact in the agent's workspace" ), - examples=["python/code/"], + example="python/code/", ) @@ -52,7 +52,7 @@ class StepInput(BaseModel): "Input parameters for the task step. Any value is" " allowed." ), - examples=['{\n"file_to_refactor": "models.py"\n}'], + example='{\n"file_to_refactor": "models.py"\n}', ) @@ -63,7 +63,7 @@ class StepOutput(BaseModel): "Output that the task step has produced. Any value is" " allowed." ), - examples=['{\n"tokens": 7894,\n"estimated_cost": "0,24$"\n}'], + example='{\n"tokens": 7894,\n"estimated_cost": "0,24$"\n}', ) @@ -71,9 +71,9 @@ class TaskRequestBody(BaseModel): input: Optional[str] = Field( None, description="Input prompt for the task.", - examples=[( + example=( "Write the words you receive to the file 'output.txt'." - )], + ), ) additional_input: Optional[TaskInput] = None @@ -82,15 +82,15 @@ class Task(TaskRequestBody): task_id: str = Field( ..., description="The ID of the task.", - examples=["50da533e-3904-4401-8a07-c49adf88b5eb"], + example="50da533e-3904-4401-8a07-c49adf88b5eb", ) artifacts: List[Artifact] = Field( [], description="A list of artifacts that the task has produced.", - examples=[[ + example=[ "7a49f31c-f9c6-4346-a22c-e32bc5af4d8e", "ab7b4091-2560-4692-a4fe-d831ea3ca7d6", - ]], + ], ) @@ -98,7 +98,7 @@ class StepRequestBody(BaseModel): input: Optional[str] = Field( None, description="Input prompt for the step.", - examples=["Washington"], + example="Washington", ) additional_input: Optional[StepInput] = None @@ -113,17 +113,17 @@ class Step(StepRequestBody): task_id: str = Field( ..., description="The ID of the task this step belongs to.", - examples=["50da533e-3904-4401-8a07-c49adf88b5eb"], + example="50da533e-3904-4401-8a07-c49adf88b5eb", ) step_id: str = Field( ..., description="The ID of the task step.", - examples=["6bb1801a-fd80-45e8-899a-4dd723cc602e"], + example="6bb1801a-fd80-45e8-899a-4dd723cc602e", ) name: Optional[str] = Field( None, description="The name of the task step.", - examples=["Write to file"], + example="Write to file", ) status: Status = Field( ..., description="The status of the task step." @@ -131,11 +131,11 @@ class Step(StepRequestBody): output: Optional[str] = Field( None, description="Output of the task step.", - examples=[( + example=( "I am going to use the write_to_file command and write" " Washington to a file called output.txt" " SecretStr: + """Convert a string to a SecretStr if needed.""" + if isinstance(value, SecretStr): + return value + return SecretStr(value) + + class _AnthropicCommon(BaseLanguageModel): client: Any = None #: :meta private: async_client: Any = None #: :meta private: - model: str ="claude-2" + model: str = Field(default="claude-2", alias="model_name") """Model name to use.""" - max_tokens_to_sample: int =256 + max_tokens_to_sample: int = Field(default=256, alias="max_tokens") """Denotes the number of tokens to predict per generation.""" temperature: Optional[float] = None @@ -245,14 +253,14 @@ class _AnthropicCommon(BaseLanguageModel): anthropic_api_url: Optional[str] = None - anthropic_api_key: Optional[str] = None + anthropic_api_key: Optional[SecretStr] = None HUMAN_PROMPT: Optional[str] = None AI_PROMPT: Optional[str] = None count_tokens: Optional[Callable[[str], int]] = None - model_kwargs: Dict[str, Any] = {} + model_kwargs: Dict[str, Any] = Field(default_factory=dict) - @classmethod + @root_validator(pre=True) def build_extra(cls, values: Dict) -> Dict: extra = values.get("model_kwargs", {}) all_required_field_names = get_pydantic_field_names(cls) @@ -261,11 +269,13 @@ class _AnthropicCommon(BaseLanguageModel): ) return values - @classmethod + @root_validator() def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" - values["anthropic_api_key"] = get_from_dict_or_env( + values["anthropic_api_key"] = convert_to_secret_str( + get_from_dict_or_env( values, "anthropic_api_key", "ANTHROPIC_API_KEY" + ) ) # Get custom api url from environment. values["anthropic_api_url"] = get_from_dict_or_env( @@ -366,8 +376,14 @@ class Anthropic(LLM, _AnthropicCommon): prompt = f"{anthropic.HUMAN_PROMPT} {prompt}{anthropic.AI_PROMPT}" response = model(prompt) """ - - @classmethod + + class Config: + """Configuration for this pydantic object.""" + + allow_population_by_field_name = True + arbitrary_types_allowed = True + + @root_validator() def raise_warning(cls, values: Dict) -> Dict: """Raise warning that this class is deprecated.""" warnings.warn( diff --git a/swarms/models/cohere_chat.py b/swarms/models/cohere_chat.py index efd8728a..1a31d82e 100644 --- a/swarms/models/cohere_chat.py +++ b/swarms/models/cohere_chat.py @@ -16,7 +16,7 @@ from langchain.callbacks.manager import ( from langchain.llms.base import LLM from langchain.llms.utils import enforce_stop_tokens from langchain.load.serializable import Serializable -from pydantic import model_validator, ConfigDict, Field +from pydantic import Extra, Field, root_validator from langchain.utils import get_from_dict_or_env logger = logging.getLogger(__name__) @@ -85,8 +85,7 @@ class BaseCohere(Serializable): user_agent: str = "langchain" """Identifier for the application making the request.""" - @model_validator() - @classmethod + @root_validator() def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" try: @@ -146,7 +145,11 @@ class Cohere(LLM, BaseCohere): max_retries: int = 10 """Maximum number of retries to make when generating.""" - model_config = ConfigDict(extra="forbid") + + class Config: + """Configuration for this pydantic object.""" + + extra = Extra.forbid @property def _default_params(self) -> Dict[str, Any]: diff --git a/swarms/models/dalle3.py b/swarms/models/dalle3.py index 17790c74..40f63418 100644 --- a/swarms/models/dalle3.py +++ b/swarms/models/dalle3.py @@ -13,7 +13,7 @@ from cachetools import TTLCache from dotenv import load_dotenv from openai import OpenAI from PIL import Image -from pydantic import field_validator +from pydantic import validator from termcolor import colored load_dotenv() @@ -92,8 +92,7 @@ class Dalle3: arbitrary_types_allowed = True - @field_validator("max_retries", "time_seconds") - @classmethod + @validator("max_retries", "time_seconds") def must_be_positive(cls, value): if value <= 0: raise ValueError("Must be positive") diff --git a/swarms/models/eleven_labs.py b/swarms/models/eleven_labs.py index 759c65bb..2d55e864 100644 --- a/swarms/models/eleven_labs.py +++ b/swarms/models/eleven_labs.py @@ -3,7 +3,7 @@ from enum import Enum from typing import Any, Dict, Union from langchain.utils import get_from_dict_or_env -from pydantic import model_validator +from pydantic import root_validator from swarms.tools.tool import BaseTool @@ -59,8 +59,7 @@ class ElevenLabsText2SpeechTool(BaseTool): " Italian, French, Portuguese, and Hindi. " ) - @model_validator(mode="before") - @classmethod + @root_validator(pre=True) def validate_environment(cls, values: Dict) -> Dict: """Validate that api key exists in environment.""" _ = get_from_dict_or_env( diff --git a/swarms/models/fastvit.py b/swarms/models/fastvit.py index f3b60587..a6fc31f8 100644 --- a/swarms/models/fastvit.py +++ b/swarms/models/fastvit.py @@ -20,8 +20,6 @@ class ClassificationResult(BaseModel): class_id: List[StrictInt] confidence: List[StrictFloat] - # TODO[pydantic]: We couldn't refactor the `validator`, please replace it by `field_validator` manually. - # Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-validators for more information. @validator("class_id", "confidence", pre=True, each_item=True) def check_list_contents(cls, v): assert isinstance(v, int) or isinstance( diff --git a/swarms/models/kosmos2.py b/swarms/models/kosmos2.py index d251ea23..9a9a0de3 100644 --- a/swarms/models/kosmos2.py +++ b/swarms/models/kosmos2.py @@ -20,8 +20,6 @@ class Detections(BaseModel): ), "All fields must have the same length." return values - # TODO[pydantic]: We couldn't refactor the `validator`, please replace it by `field_validator` manually. - # Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-validators for more information. @validator( "xyxy", "class_id", "confidence", pre=True, each_item=True ) diff --git a/swarms/models/openai_embeddings.py b/swarms/models/openai_embeddings.py index 3265a141..0cbbdbee 100644 --- a/swarms/models/openai_embeddings.py +++ b/swarms/models/openai_embeddings.py @@ -16,7 +16,7 @@ from typing import ( ) import numpy as np -from pydantic import model_validator, ConfigDict, BaseModel, Field +from pydantic import BaseModel, Extra, Field, root_validator from tenacity import ( AsyncRetrying, before_sleep_log, @@ -186,7 +186,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings): """ - client: Any = None #: :meta private: + client: Any #: :meta private: model: str = "text-embedding-ada-002" deployment: str = model # to support Azure OpenAI Service custom deployment names openai_api_version: Optional[str] = None @@ -227,10 +227,13 @@ class OpenAIEmbeddings(BaseModel, Embeddings): """Whether to show a progress bar when embedding.""" model_kwargs: Dict[str, Any] = Field(default_factory=dict) """Holds any model parameters valid for `create` call not explicitly specified.""" - model_config = ConfigDict(extra="forbid") - @model_validator(mode="before") - @classmethod + class Config: + """Configuration for this pydantic object.""" + + extra = Extra.forbid + + @root_validator(pre=True) def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]: """Build extra kwargs from additional params that were passed in.""" all_required_field_names = get_pydantic_field_names(cls) @@ -261,8 +264,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings): values["model_kwargs"] = extra return values - @model_validator() - @classmethod + @root_validator() def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" values["openai_api_key"] = get_from_dict_or_env( diff --git a/swarms/models/openai_function_caller.py b/swarms/models/openai_function_caller.py index feb04387..6542e457 100644 --- a/swarms/models/openai_function_caller.py +++ b/swarms/models/openai_function_caller.py @@ -2,7 +2,7 @@ from typing import Any, Dict, List, Optional, Union import openai import requests -from pydantic import field_validator, BaseModel +from pydantic import BaseModel, validator from tenacity import ( retry, stop_after_attempt, @@ -78,8 +78,7 @@ class FunctionSpecification(BaseModel): parameters: Dict[str, Any] required: Optional[List[str]] = None - @field_validator("parameters") - @classmethod + @validator("parameters") def check_parameters(cls, params): if not isinstance(params, dict): raise ValueError("Parameters must be a dictionary.") diff --git a/swarms/models/openai_models.py b/swarms/models/openai_models.py index 233b99c3..14332ff2 100644 --- a/swarms/models/openai_models.py +++ b/swarms/models/openai_models.py @@ -38,7 +38,6 @@ from importlib.metadata import version from packaging.version import parse - logger = logging.getLogger(__name__) @@ -249,8 +248,12 @@ class BaseOpenAI(BaseLLM): data.get("model_name", "") return super().__new__(cls) - - @classmethod + class Config: + """Configuration for this pydantic object.""" + + allow_population_by_field_name = True + + @root_validator(pre=True) def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]: """Build extra kwargs from additional params that were passed in.""" all_required_field_names = get_pydantic_field_names(cls) @@ -260,8 +263,7 @@ class BaseOpenAI(BaseLLM): ) return values - - @classmethod + @root_validator() def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" values["openai_api_key"] = get_from_dict_or_env( @@ -756,7 +758,7 @@ class AzureOpenAI(BaseOpenAI): openai_api_type: str = "" openai_api_version: str = "" - @classmethod + @root_validator() def validate_azure_settings(cls, values: Dict) -> Dict: values["openai_api_version"] = get_from_dict_or_env( values, @@ -845,7 +847,7 @@ class OpenAIChat(BaseLLM): disallowed_special: Union[Literal["all"], Collection[str]] = "all" """Set of special tokens that are not allowed。""" - @classmethod + @root_validator(pre=True) def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]: """Build extra kwargs from additional params that were passed in.""" all_required_field_names = { @@ -863,8 +865,7 @@ class OpenAIChat(BaseLLM): values["model_kwargs"] = extra return values - - @classmethod + @root_validator() def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" openai_api_key = get_from_dict_or_env( diff --git a/swarms/models/palm.py b/swarms/models/palm.py index e016a776..d61d4856 100644 --- a/swarms/models/palm.py +++ b/swarms/models/palm.py @@ -15,7 +15,6 @@ from tenacity import ( stop_after_attempt, wait_exponential, ) -from pydantic import model_validator logger = logging.getLogger(__name__) @@ -105,8 +104,7 @@ class GooglePalm(BaseLLM, BaseModel): """Number of chat completions to generate for each prompt. Note that the API may not return the full n completions if duplicates are generated.""" - @model_validator() - @classmethod + @root_validator() def validate_environment(cls, values: Dict) -> Dict: """Validate api key, python package exists.""" google_api_key = get_from_dict_or_env( diff --git a/swarms/models/ssd_1b.py b/swarms/models/ssd_1b.py index 9a905bd4..d3b9086b 100644 --- a/swarms/models/ssd_1b.py +++ b/swarms/models/ssd_1b.py @@ -9,7 +9,7 @@ import backoff import torch from diffusers import StableDiffusionXLPipeline from PIL import Image -from pydantic import field_validator +from pydantic import validator from termcolor import colored from cachetools import TTLCache @@ -72,8 +72,7 @@ class SSD1B: arbitrary_types_allowed = True - @field_validator("max_retries", "time_seconds") - @classmethod + @validator("max_retries", "time_seconds") def must_be_positive(cls, value): if value <= 0: raise ValueError("Must be positive") diff --git a/swarms/models/timm.py b/swarms/models/timm.py index 8dec0bc9..d1c42165 100644 --- a/swarms/models/timm.py +++ b/swarms/models/timm.py @@ -2,14 +2,17 @@ from typing import List import timm import torch -from pydantic import ConfigDict, BaseModel +from pydantic import BaseModel class TimmModelInfo(BaseModel): model_name: str pretrained: bool in_chans: int - model_config = ConfigDict(strict=True) + + class Config: + # Use strict typing for all fields + strict = True class TimmModel: diff --git a/swarms/tools/tool.py b/swarms/tools/tool.py index 838b89bb..1029a183 100644 --- a/swarms/tools/tool.py +++ b/swarms/tools/tool.py @@ -29,7 +29,14 @@ from langchain.callbacks.manager import ( ) from langchain.load.serializable import Serializable - +from pydantic import ( + BaseModel, + Extra, + Field, + create_model, + root_validator, + validate_arguments, +) from langchain.schema.runnable import ( Runnable, RunnableConfig, @@ -41,9 +48,62 @@ class SchemaAnnotationError(TypeError): """Raised when 'args_schema' is missing or has an incorrect type annotation.""" +def _create_subset_model( + name: str, model: BaseModel, field_names: list +) -> Type[BaseModel]: + """Create a pydantic model with only a subset of model's fields.""" + fields = {} + for field_name in field_names: + field = model.__fields__[field_name] + fields[field_name] = (field.outer_type_, field.field_info) + return create_model(name, **fields) # type: ignore + +def _get_filtered_args( + inferred_model: Type[BaseModel], + func: Callable, +) -> dict: + """Get the arguments from a function's signature.""" + schema = inferred_model.schema()["properties"] + valid_keys = signature(func).parameters + return { + k: schema[k] + for k in valid_keys + if k not in ("run_manager", "callbacks") + } +class _SchemaConfig: + """Configuration for the pydantic model.""" + + extra: Any = Extra.forbid + arbitrary_types_allowed: bool = True + + +def create_schema_from_function( + model_name: str, + func: Callable, +) -> Type[BaseModel]: + """Create a pydantic schema from a function's signature. + Args: + model_name: Name to assign to the generated pydandic schema + func: Function to generate the schema from + Returns: + A pydantic model with the same arguments as the function + """ + # https://docs.pydantic.dev/latest/usage/validation_decorator/ + validated = validate_arguments(func, config=_SchemaConfig) # type: ignore + inferred_model = validated.model # type: ignore + if "run_manager" in inferred_model.__fields__: + del inferred_model.__fields__["run_manager"] + if "callbacks" in inferred_model.__fields__: + del inferred_model.__fields__["callbacks"] + # Pydantic adds placeholder virtual fields we need to strip + valid_properties = _get_filtered_args(inferred_model, func) + return _create_subset_model( + f"{model_name}Schema", inferred_model, list(valid_properties) + ) + class ToolException(Exception): """An optional exception that tool throws when execution error occurs. @@ -71,7 +131,7 @@ class BaseTool(RunnableSerializable[Union[str, Dict], Any]): if args_schema_type is not None: if ( args_schema_type is None - # or args_schema_type == BaseModel + or args_schema_type == BaseModel ): # Throw errors for common mis-annotations. # TODO: Use get_args / get_origin and fully @@ -108,10 +168,11 @@ class ChildTool(BaseTool): verbose: bool = False """Whether to log the tool's progress.""" - callbacks: Callbacks = None + callbacks: Callbacks = Field(default=None, exclude=True) """Callbacks to be called during tool execution.""" - # TODO: I don't know how to remove Field here - callback_manager: Optional[BaseCallbackManager] = None + callback_manager: Optional[BaseCallbackManager] = Field( + default=None, exclude=True + ) """Deprecated. Please use callbacks instead.""" tags: Optional[List[str]] = None """Optional list of tags associated with the tool. Defaults to None @@ -131,11 +192,9 @@ class ChildTool(BaseTool): ] = False """Handle the content of the ToolException thrown.""" - # TODO[pydantic]: The `Config` class inherits from another class, please create the `model_config` manually. - # Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-config for more information. class Config(Serializable.Config): """Configuration for this pydantic object.""" - model_config = {} + arbitrary_types_allowed = True @property @@ -155,8 +214,7 @@ class ChildTool(BaseTool): # --- Runnable --- @property - # TODO - def input_schema(self): + def input_schema(self) -> Type[BaseModel]: """The tool's input schema.""" if self.args_schema is not None: return self.args_schema @@ -218,7 +276,7 @@ class ChildTool(BaseTool): } return tool_input - @classmethod + @root_validator() def raise_deprecation(cls, values: Dict) -> Dict: """Raise deprecation warning if callback_manager is used.""" if values.get("callback_manager") is not None: @@ -613,7 +671,9 @@ class StructuredTool(BaseTool): """Tool that can operate on any number of inputs.""" description: str = "" - + args_schema: Type[BaseModel] = Field( + ..., description="The tool schema." + ) """The input arguments' schema.""" func: Optional[Callable[..., Any]] """The function to run when the tool is called.""" diff --git a/swarms/utils/serializable.py b/swarms/utils/serializable.py index 3cc3a5f6..de9444ef 100644 --- a/swarms/utils/serializable.py +++ b/swarms/utils/serializable.py @@ -1,7 +1,7 @@ from abc import ABC from typing import Any, Dict, List, Literal, TypedDict, Union, cast -from pydantic import ConfigDict, BaseModel, PrivateAttr +from pydantic import BaseModel, PrivateAttr class BaseSerialized(TypedDict): @@ -64,7 +64,9 @@ class Serializable(BaseModel, ABC): constructor. """ return {} - model_config = ConfigDict(extra="ignore") + + class Config: + extra = "ignore" _lc_kwargs = PrivateAttr(default_factory=dict) diff --git a/tests/Dockerfile b/tests/Dockerfile index e28fbc8e..f6e46515 100644 --- a/tests/Dockerfile +++ b/tests/Dockerfile @@ -2,8 +2,6 @@ # -================== # Use an official Python runtime as a parent image FROM python:3.9-slim -RUN apt-get update && apt-get -y install libgl1-mesa-dev libglib2.0-0; apt-get clean -RUN pip install opencv-contrib-python-headless # Set environment variables to make Python output unbuffered and disable the PIP cache ENV PYTHONDONTWRITEBYTECODE 1