Merge pull request #266 from kyegomez/revert-250-master

Revert "pydantic bump fix for #249 "
pull/268/head
Eternal Reclaimer 1 year ago committed by GitHub
commit fa7e1c769b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1,40 +0,0 @@
---
# This is a github action to run docker-compose
# docker-compose.yml
# to run the docker build in the top level directory
# to run the docker build in the tests directory and run the tests with pytest
# docker-compose run --rm app pytest
on:
push:
branches: [ main ]
paths:
- 'docker-compose.yml'
- 'Dockerfile'
- 'tests/**'
- 'app/**'
- 'app.py'
- 'requirements.txt'
- 'README.md'
- '.github/workflows/**'
- '.github/workflows/docker-compose.yml'
- '.github/workflows/main.yml'
- '.github/workflows/python-app.yml'
- '.github/workflows/python-app.yml'
- '.github/workflows'
name: Docker Compose
jobs:
build:
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@v4
# Add your build and test steps here
- name: Build and run docker services
run: |
docker-compose build
docker-compose up -d
docker-compose run --rm app pytest

@ -10,7 +10,6 @@ on:
env: env:
POETRY_VERSION: "1.4.2" POETRY_VERSION: "1.4.2"
jobs:
test: test:
runs-on: ubuntu-latest runs-on: ubuntu-latest
strategy: strategy:
@ -47,7 +46,7 @@ jobs:
make extended_tests make extended_tests
fi fi
shell: bash shell: bash
name: Python ${{ matrix.python-version }} ${{ matrix.test_type }}
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }} - name: Set up Python ${{ matrix.python-version }}

@ -16,7 +16,7 @@ jobs:
- name: Set up Python - name: Set up Python
uses: actions/setup-python@v4 uses: actions/setup-python@v4
with: with:
python-version: 3.11 python-version: 3.x
- name: Install dependencies - name: Install dependencies
run: | run: |

@ -2,8 +2,6 @@
# ================================== # ==================================
# Use an official Python runtime as a parent image # Use an official Python runtime as a parent image
FROM python:3.9-slim FROM python:3.9-slim
RUN apt-get update && apt-get -y install libgl1-mesa-dev libglib2.0-0; apt-get clean
RUN pip install opencv-contrib-python-headless
# Set environment variables # Set environment variables
ENV PYTHONDONTWRITEBYTECODE 1 ENV PYTHONDONTWRITEBYTECODE 1

@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"
[tool.poetry] [tool.poetry]
name = "swarms" name = "swarms"
version = "2.5.8" version = "2.5.7"
description = "Swarms - Pytorch" description = "Swarms - Pytorch"
license = "MIT" license = "MIT"
authors = ["Kye Gomez <kye@apac.ai>"] authors = ["Kye Gomez <kye@apac.ai>"]
@ -52,11 +52,11 @@ ratelimit = "*"
beautifulsoup4 = "*" beautifulsoup4 = "*"
cohere = "*" cohere = "*"
huggingface-hub = "*" huggingface-hub = "*"
pydantic = "2.*" pydantic = "1.10.12"
tenacity = "*" tenacity = "*"
Pillow = "*" Pillow = "*"
chromadb = "*" chromadb = "*"
opencv-python-headless opencv-python-headless = "*"
tabulate = "*" tabulate = "*"
termcolor = "*" termcolor = "*"
black = "*" black = "*"

@ -17,7 +17,7 @@ faiss-cpu
openai==0.28.0 openai==0.28.0
attrs attrs
datasets datasets
pydantic>2 pydantic==1.10.12
soundfile soundfile
huggingface-hub huggingface-hub
google-generativeai google-generativeai

@ -12,7 +12,7 @@ class TaskInput(BaseModel):
description=( description=(
"The input parameters for the task. Any value is allowed." "The input parameters for the task. Any value is allowed."
), ),
examples=['{\n"debug": false,\n"mode": "benchmarks"\n}'], example='{\n"debug": false,\n"mode": "benchmarks"\n}',
) )
@ -20,17 +20,17 @@ class Artifact(BaseModel):
artifact_id: str = Field( artifact_id: str = Field(
..., ...,
description="Id of the artifact", description="Id of the artifact",
examples=["b225e278-8b4c-4f99-a696-8facf19f0e56"], example="b225e278-8b4c-4f99-a696-8facf19f0e56",
) )
file_name: str = Field( file_name: str = Field(
..., description="Filename of the artifact", examples=["main.py"] ..., description="Filename of the artifact", example="main.py"
) )
relative_path: Optional[str] = Field( relative_path: Optional[str] = Field(
None, None,
description=( description=(
"Relative path of the artifact in the agent's workspace" "Relative path of the artifact in the agent's workspace"
), ),
examples=["python/code/"], example="python/code/",
) )
@ -41,7 +41,7 @@ class ArtifactUpload(BaseModel):
description=( description=(
"Relative path of the artifact in the agent's workspace" "Relative path of the artifact in the agent's workspace"
), ),
examples=["python/code/"], example="python/code/",
) )
@ -52,7 +52,7 @@ class StepInput(BaseModel):
"Input parameters for the task step. Any value is" "Input parameters for the task step. Any value is"
" allowed." " allowed."
), ),
examples=['{\n"file_to_refactor": "models.py"\n}'], example='{\n"file_to_refactor": "models.py"\n}',
) )
@ -63,7 +63,7 @@ class StepOutput(BaseModel):
"Output that the task step has produced. Any value is" "Output that the task step has produced. Any value is"
" allowed." " allowed."
), ),
examples=['{\n"tokens": 7894,\n"estimated_cost": "0,24$"\n}'], example='{\n"tokens": 7894,\n"estimated_cost": "0,24$"\n}',
) )
@ -71,9 +71,9 @@ class TaskRequestBody(BaseModel):
input: Optional[str] = Field( input: Optional[str] = Field(
None, None,
description="Input prompt for the task.", description="Input prompt for the task.",
examples=[( example=(
"Write the words you receive to the file 'output.txt'." "Write the words you receive to the file 'output.txt'."
)], ),
) )
additional_input: Optional[TaskInput] = None additional_input: Optional[TaskInput] = None
@ -82,15 +82,15 @@ class Task(TaskRequestBody):
task_id: str = Field( task_id: str = Field(
..., ...,
description="The ID of the task.", description="The ID of the task.",
examples=["50da533e-3904-4401-8a07-c49adf88b5eb"], example="50da533e-3904-4401-8a07-c49adf88b5eb",
) )
artifacts: List[Artifact] = Field( artifacts: List[Artifact] = Field(
[], [],
description="A list of artifacts that the task has produced.", description="A list of artifacts that the task has produced.",
examples=[[ example=[
"7a49f31c-f9c6-4346-a22c-e32bc5af4d8e", "7a49f31c-f9c6-4346-a22c-e32bc5af4d8e",
"ab7b4091-2560-4692-a4fe-d831ea3ca7d6", "ab7b4091-2560-4692-a4fe-d831ea3ca7d6",
]], ],
) )
@ -98,7 +98,7 @@ class StepRequestBody(BaseModel):
input: Optional[str] = Field( input: Optional[str] = Field(
None, None,
description="Input prompt for the step.", description="Input prompt for the step.",
examples=["Washington"], example="Washington",
) )
additional_input: Optional[StepInput] = None additional_input: Optional[StepInput] = None
@ -113,17 +113,17 @@ class Step(StepRequestBody):
task_id: str = Field( task_id: str = Field(
..., ...,
description="The ID of the task this step belongs to.", description="The ID of the task this step belongs to.",
examples=["50da533e-3904-4401-8a07-c49adf88b5eb"], example="50da533e-3904-4401-8a07-c49adf88b5eb",
) )
step_id: str = Field( step_id: str = Field(
..., ...,
description="The ID of the task step.", description="The ID of the task step.",
examples=["6bb1801a-fd80-45e8-899a-4dd723cc602e"], example="6bb1801a-fd80-45e8-899a-4dd723cc602e",
) )
name: Optional[str] = Field( name: Optional[str] = Field(
None, None,
description="The name of the task step.", description="The name of the task step.",
examples=["Write to file"], example="Write to file",
) )
status: Status = Field( status: Status = Field(
..., description="The status of the task step." ..., description="The status of the task step."
@ -131,11 +131,11 @@ class Step(StepRequestBody):
output: Optional[str] = Field( output: Optional[str] = Field(
None, None,
description="Output of the task step.", description="Output of the task step.",
examples=[( example=(
"I am going to use the write_to_file command and write" "I am going to use the write_to_file command and write"
" Washington to a file called output.txt" " Washington to a file called output.txt"
" <write_to_file('output.txt', 'Washington')" " <write_to_file('output.txt', 'Washington')"
)], ),
) )
additional_output: Optional[StepOutput] = None additional_output: Optional[StepOutput] = None
artifacts: List[Artifact] = Field( artifacts: List[Artifact] = Field(

@ -24,7 +24,7 @@ from langchain.callbacks.manager import (
CallbackManagerForLLMRun, CallbackManagerForLLMRun,
) )
from langchain.llms.base import LLM from langchain.llms.base import LLM
from pydantic import Field, SecretStr, root_validator
from langchain.schema.language_model import BaseLanguageModel from langchain.schema.language_model import BaseLanguageModel
from langchain.schema.output import GenerationChunk from langchain.schema.output import GenerationChunk
from langchain.schema.prompt import PromptValue from langchain.schema.prompt import PromptValue
@ -219,13 +219,21 @@ def build_extra_kwargs(
return extra_kwargs return extra_kwargs
def convert_to_secret_str(value: Union[SecretStr, str]) -> SecretStr:
"""Convert a string to a SecretStr if needed."""
if isinstance(value, SecretStr):
return value
return SecretStr(value)
class _AnthropicCommon(BaseLanguageModel): class _AnthropicCommon(BaseLanguageModel):
client: Any = None #: :meta private: client: Any = None #: :meta private:
async_client: Any = None #: :meta private: async_client: Any = None #: :meta private:
model: str ="claude-2" model: str = Field(default="claude-2", alias="model_name")
"""Model name to use.""" """Model name to use."""
max_tokens_to_sample: int =256 max_tokens_to_sample: int = Field(default=256, alias="max_tokens")
"""Denotes the number of tokens to predict per generation.""" """Denotes the number of tokens to predict per generation."""
temperature: Optional[float] = None temperature: Optional[float] = None
@ -245,14 +253,14 @@ class _AnthropicCommon(BaseLanguageModel):
anthropic_api_url: Optional[str] = None anthropic_api_url: Optional[str] = None
anthropic_api_key: Optional[str] = None anthropic_api_key: Optional[SecretStr] = None
HUMAN_PROMPT: Optional[str] = None HUMAN_PROMPT: Optional[str] = None
AI_PROMPT: Optional[str] = None AI_PROMPT: Optional[str] = None
count_tokens: Optional[Callable[[str], int]] = None count_tokens: Optional[Callable[[str], int]] = None
model_kwargs: Dict[str, Any] = {} model_kwargs: Dict[str, Any] = Field(default_factory=dict)
@classmethod @root_validator(pre=True)
def build_extra(cls, values: Dict) -> Dict: def build_extra(cls, values: Dict) -> Dict:
extra = values.get("model_kwargs", {}) extra = values.get("model_kwargs", {})
all_required_field_names = get_pydantic_field_names(cls) all_required_field_names = get_pydantic_field_names(cls)
@ -261,12 +269,14 @@ class _AnthropicCommon(BaseLanguageModel):
) )
return values return values
@classmethod @root_validator()
def validate_environment(cls, values: Dict) -> Dict: def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment.""" """Validate that api key and python package exists in environment."""
values["anthropic_api_key"] = get_from_dict_or_env( values["anthropic_api_key"] = convert_to_secret_str(
get_from_dict_or_env(
values, "anthropic_api_key", "ANTHROPIC_API_KEY" values, "anthropic_api_key", "ANTHROPIC_API_KEY"
) )
)
# Get custom api url from environment. # Get custom api url from environment.
values["anthropic_api_url"] = get_from_dict_or_env( values["anthropic_api_url"] = get_from_dict_or_env(
values, values,
@ -367,7 +377,13 @@ class Anthropic(LLM, _AnthropicCommon):
response = model(prompt) response = model(prompt)
""" """
@classmethod class Config:
"""Configuration for this pydantic object."""
allow_population_by_field_name = True
arbitrary_types_allowed = True
@root_validator()
def raise_warning(cls, values: Dict) -> Dict: def raise_warning(cls, values: Dict) -> Dict:
"""Raise warning that this class is deprecated.""" """Raise warning that this class is deprecated."""
warnings.warn( warnings.warn(

@ -16,7 +16,7 @@ from langchain.callbacks.manager import (
from langchain.llms.base import LLM from langchain.llms.base import LLM
from langchain.llms.utils import enforce_stop_tokens from langchain.llms.utils import enforce_stop_tokens
from langchain.load.serializable import Serializable from langchain.load.serializable import Serializable
from pydantic import model_validator, ConfigDict, Field from pydantic import Extra, Field, root_validator
from langchain.utils import get_from_dict_or_env from langchain.utils import get_from_dict_or_env
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -85,8 +85,7 @@ class BaseCohere(Serializable):
user_agent: str = "langchain" user_agent: str = "langchain"
"""Identifier for the application making the request.""" """Identifier for the application making the request."""
@model_validator() @root_validator()
@classmethod
def validate_environment(cls, values: Dict) -> Dict: def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment.""" """Validate that api key and python package exists in environment."""
try: try:
@ -146,7 +145,11 @@ class Cohere(LLM, BaseCohere):
max_retries: int = 10 max_retries: int = 10
"""Maximum number of retries to make when generating.""" """Maximum number of retries to make when generating."""
model_config = ConfigDict(extra="forbid")
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
@property @property
def _default_params(self) -> Dict[str, Any]: def _default_params(self) -> Dict[str, Any]:

@ -13,7 +13,7 @@ from cachetools import TTLCache
from dotenv import load_dotenv from dotenv import load_dotenv
from openai import OpenAI from openai import OpenAI
from PIL import Image from PIL import Image
from pydantic import field_validator from pydantic import validator
from termcolor import colored from termcolor import colored
load_dotenv() load_dotenv()
@ -92,8 +92,7 @@ class Dalle3:
arbitrary_types_allowed = True arbitrary_types_allowed = True
@field_validator("max_retries", "time_seconds") @validator("max_retries", "time_seconds")
@classmethod
def must_be_positive(cls, value): def must_be_positive(cls, value):
if value <= 0: if value <= 0:
raise ValueError("Must be positive") raise ValueError("Must be positive")

@ -3,7 +3,7 @@ from enum import Enum
from typing import Any, Dict, Union from typing import Any, Dict, Union
from langchain.utils import get_from_dict_or_env from langchain.utils import get_from_dict_or_env
from pydantic import model_validator from pydantic import root_validator
from swarms.tools.tool import BaseTool from swarms.tools.tool import BaseTool
@ -59,8 +59,7 @@ class ElevenLabsText2SpeechTool(BaseTool):
" Italian, French, Portuguese, and Hindi. " " Italian, French, Portuguese, and Hindi. "
) )
@model_validator(mode="before") @root_validator(pre=True)
@classmethod
def validate_environment(cls, values: Dict) -> Dict: def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key exists in environment.""" """Validate that api key exists in environment."""
_ = get_from_dict_or_env( _ = get_from_dict_or_env(

@ -20,8 +20,6 @@ class ClassificationResult(BaseModel):
class_id: List[StrictInt] class_id: List[StrictInt]
confidence: List[StrictFloat] confidence: List[StrictFloat]
# TODO[pydantic]: We couldn't refactor the `validator`, please replace it by `field_validator` manually.
# Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-validators for more information.
@validator("class_id", "confidence", pre=True, each_item=True) @validator("class_id", "confidence", pre=True, each_item=True)
def check_list_contents(cls, v): def check_list_contents(cls, v):
assert isinstance(v, int) or isinstance( assert isinstance(v, int) or isinstance(

@ -20,8 +20,6 @@ class Detections(BaseModel):
), "All fields must have the same length." ), "All fields must have the same length."
return values return values
# TODO[pydantic]: We couldn't refactor the `validator`, please replace it by `field_validator` manually.
# Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-validators for more information.
@validator( @validator(
"xyxy", "class_id", "confidence", pre=True, each_item=True "xyxy", "class_id", "confidence", pre=True, each_item=True
) )

@ -16,7 +16,7 @@ from typing import (
) )
import numpy as np import numpy as np
from pydantic import model_validator, ConfigDict, BaseModel, Field from pydantic import BaseModel, Extra, Field, root_validator
from tenacity import ( from tenacity import (
AsyncRetrying, AsyncRetrying,
before_sleep_log, before_sleep_log,
@ -186,7 +186,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
""" """
client: Any = None #: :meta private: client: Any #: :meta private:
model: str = "text-embedding-ada-002" model: str = "text-embedding-ada-002"
deployment: str = model # to support Azure OpenAI Service custom deployment names deployment: str = model # to support Azure OpenAI Service custom deployment names
openai_api_version: Optional[str] = None openai_api_version: Optional[str] = None
@ -227,10 +227,13 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
"""Whether to show a progress bar when embedding.""" """Whether to show a progress bar when embedding."""
model_kwargs: Dict[str, Any] = Field(default_factory=dict) model_kwargs: Dict[str, Any] = Field(default_factory=dict)
"""Holds any model parameters valid for `create` call not explicitly specified.""" """Holds any model parameters valid for `create` call not explicitly specified."""
model_config = ConfigDict(extra="forbid")
@model_validator(mode="before") class Config:
@classmethod """Configuration for this pydantic object."""
extra = Extra.forbid
@root_validator(pre=True)
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]: def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Build extra kwargs from additional params that were passed in.""" """Build extra kwargs from additional params that were passed in."""
all_required_field_names = get_pydantic_field_names(cls) all_required_field_names = get_pydantic_field_names(cls)
@ -261,8 +264,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
values["model_kwargs"] = extra values["model_kwargs"] = extra
return values return values
@model_validator() @root_validator()
@classmethod
def validate_environment(cls, values: Dict) -> Dict: def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment.""" """Validate that api key and python package exists in environment."""
values["openai_api_key"] = get_from_dict_or_env( values["openai_api_key"] = get_from_dict_or_env(

@ -2,7 +2,7 @@ from typing import Any, Dict, List, Optional, Union
import openai import openai
import requests import requests
from pydantic import field_validator, BaseModel from pydantic import BaseModel, validator
from tenacity import ( from tenacity import (
retry, retry,
stop_after_attempt, stop_after_attempt,
@ -78,8 +78,7 @@ class FunctionSpecification(BaseModel):
parameters: Dict[str, Any] parameters: Dict[str, Any]
required: Optional[List[str]] = None required: Optional[List[str]] = None
@field_validator("parameters") @validator("parameters")
@classmethod
def check_parameters(cls, params): def check_parameters(cls, params):
if not isinstance(params, dict): if not isinstance(params, dict):
raise ValueError("Parameters must be a dictionary.") raise ValueError("Parameters must be a dictionary.")

@ -38,7 +38,6 @@ from importlib.metadata import version
from packaging.version import parse from packaging.version import parse
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -249,8 +248,12 @@ class BaseOpenAI(BaseLLM):
data.get("model_name", "") data.get("model_name", "")
return super().__new__(cls) return super().__new__(cls)
class Config:
"""Configuration for this pydantic object."""
@classmethod allow_population_by_field_name = True
@root_validator(pre=True)
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]: def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Build extra kwargs from additional params that were passed in.""" """Build extra kwargs from additional params that were passed in."""
all_required_field_names = get_pydantic_field_names(cls) all_required_field_names = get_pydantic_field_names(cls)
@ -260,8 +263,7 @@ class BaseOpenAI(BaseLLM):
) )
return values return values
@root_validator()
@classmethod
def validate_environment(cls, values: Dict) -> Dict: def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment.""" """Validate that api key and python package exists in environment."""
values["openai_api_key"] = get_from_dict_or_env( values["openai_api_key"] = get_from_dict_or_env(
@ -756,7 +758,7 @@ class AzureOpenAI(BaseOpenAI):
openai_api_type: str = "" openai_api_type: str = ""
openai_api_version: str = "" openai_api_version: str = ""
@classmethod @root_validator()
def validate_azure_settings(cls, values: Dict) -> Dict: def validate_azure_settings(cls, values: Dict) -> Dict:
values["openai_api_version"] = get_from_dict_or_env( values["openai_api_version"] = get_from_dict_or_env(
values, values,
@ -845,7 +847,7 @@ class OpenAIChat(BaseLLM):
disallowed_special: Union[Literal["all"], Collection[str]] = "all" disallowed_special: Union[Literal["all"], Collection[str]] = "all"
"""Set of special tokens that are not allowed。""" """Set of special tokens that are not allowed。"""
@classmethod @root_validator(pre=True)
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]: def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Build extra kwargs from additional params that were passed in.""" """Build extra kwargs from additional params that were passed in."""
all_required_field_names = { all_required_field_names = {
@ -863,8 +865,7 @@ class OpenAIChat(BaseLLM):
values["model_kwargs"] = extra values["model_kwargs"] = extra
return values return values
@root_validator()
@classmethod
def validate_environment(cls, values: Dict) -> Dict: def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment.""" """Validate that api key and python package exists in environment."""
openai_api_key = get_from_dict_or_env( openai_api_key = get_from_dict_or_env(

@ -15,7 +15,6 @@ from tenacity import (
stop_after_attempt, stop_after_attempt,
wait_exponential, wait_exponential,
) )
from pydantic import model_validator
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -105,8 +104,7 @@ class GooglePalm(BaseLLM, BaseModel):
"""Number of chat completions to generate for each prompt. Note that the API may """Number of chat completions to generate for each prompt. Note that the API may
not return the full n completions if duplicates are generated.""" not return the full n completions if duplicates are generated."""
@model_validator() @root_validator()
@classmethod
def validate_environment(cls, values: Dict) -> Dict: def validate_environment(cls, values: Dict) -> Dict:
"""Validate api key, python package exists.""" """Validate api key, python package exists."""
google_api_key = get_from_dict_or_env( google_api_key = get_from_dict_or_env(

@ -9,7 +9,7 @@ import backoff
import torch import torch
from diffusers import StableDiffusionXLPipeline from diffusers import StableDiffusionXLPipeline
from PIL import Image from PIL import Image
from pydantic import field_validator from pydantic import validator
from termcolor import colored from termcolor import colored
from cachetools import TTLCache from cachetools import TTLCache
@ -72,8 +72,7 @@ class SSD1B:
arbitrary_types_allowed = True arbitrary_types_allowed = True
@field_validator("max_retries", "time_seconds") @validator("max_retries", "time_seconds")
@classmethod
def must_be_positive(cls, value): def must_be_positive(cls, value):
if value <= 0: if value <= 0:
raise ValueError("Must be positive") raise ValueError("Must be positive")

@ -2,14 +2,17 @@ from typing import List
import timm import timm
import torch import torch
from pydantic import ConfigDict, BaseModel from pydantic import BaseModel
class TimmModelInfo(BaseModel): class TimmModelInfo(BaseModel):
model_name: str model_name: str
pretrained: bool pretrained: bool
in_chans: int in_chans: int
model_config = ConfigDict(strict=True)
class Config:
# Use strict typing for all fields
strict = True
class TimmModel: class TimmModel:

@ -29,7 +29,14 @@ from langchain.callbacks.manager import (
) )
from langchain.load.serializable import Serializable from langchain.load.serializable import Serializable
from pydantic import (
BaseModel,
Extra,
Field,
create_model,
root_validator,
validate_arguments,
)
from langchain.schema.runnable import ( from langchain.schema.runnable import (
Runnable, Runnable,
RunnableConfig, RunnableConfig,
@ -41,9 +48,62 @@ class SchemaAnnotationError(TypeError):
"""Raised when 'args_schema' is missing or has an incorrect type annotation.""" """Raised when 'args_schema' is missing or has an incorrect type annotation."""
def _create_subset_model(
name: str, model: BaseModel, field_names: list
) -> Type[BaseModel]:
"""Create a pydantic model with only a subset of model's fields."""
fields = {}
for field_name in field_names:
field = model.__fields__[field_name]
fields[field_name] = (field.outer_type_, field.field_info)
return create_model(name, **fields) # type: ignore
def _get_filtered_args(
inferred_model: Type[BaseModel],
func: Callable,
) -> dict:
"""Get the arguments from a function's signature."""
schema = inferred_model.schema()["properties"]
valid_keys = signature(func).parameters
return {
k: schema[k]
for k in valid_keys
if k not in ("run_manager", "callbacks")
}
class _SchemaConfig:
"""Configuration for the pydantic model."""
extra: Any = Extra.forbid
arbitrary_types_allowed: bool = True
def create_schema_from_function(
model_name: str,
func: Callable,
) -> Type[BaseModel]:
"""Create a pydantic schema from a function's signature.
Args:
model_name: Name to assign to the generated pydandic schema
func: Function to generate the schema from
Returns:
A pydantic model with the same arguments as the function
"""
# https://docs.pydantic.dev/latest/usage/validation_decorator/
validated = validate_arguments(func, config=_SchemaConfig) # type: ignore
inferred_model = validated.model # type: ignore
if "run_manager" in inferred_model.__fields__:
del inferred_model.__fields__["run_manager"]
if "callbacks" in inferred_model.__fields__:
del inferred_model.__fields__["callbacks"]
# Pydantic adds placeholder virtual fields we need to strip
valid_properties = _get_filtered_args(inferred_model, func)
return _create_subset_model(
f"{model_name}Schema", inferred_model, list(valid_properties)
)
class ToolException(Exception): class ToolException(Exception):
"""An optional exception that tool throws when execution error occurs. """An optional exception that tool throws when execution error occurs.
@ -71,7 +131,7 @@ class BaseTool(RunnableSerializable[Union[str, Dict], Any]):
if args_schema_type is not None: if args_schema_type is not None:
if ( if (
args_schema_type is None args_schema_type is None
# or args_schema_type == BaseModel or args_schema_type == BaseModel
): ):
# Throw errors for common mis-annotations. # Throw errors for common mis-annotations.
# TODO: Use get_args / get_origin and fully # TODO: Use get_args / get_origin and fully
@ -108,10 +168,11 @@ class ChildTool(BaseTool):
verbose: bool = False verbose: bool = False
"""Whether to log the tool's progress.""" """Whether to log the tool's progress."""
callbacks: Callbacks = None callbacks: Callbacks = Field(default=None, exclude=True)
"""Callbacks to be called during tool execution.""" """Callbacks to be called during tool execution."""
# TODO: I don't know how to remove Field here callback_manager: Optional[BaseCallbackManager] = Field(
callback_manager: Optional[BaseCallbackManager] = None default=None, exclude=True
)
"""Deprecated. Please use callbacks instead.""" """Deprecated. Please use callbacks instead."""
tags: Optional[List[str]] = None tags: Optional[List[str]] = None
"""Optional list of tags associated with the tool. Defaults to None """Optional list of tags associated with the tool. Defaults to None
@ -131,11 +192,9 @@ class ChildTool(BaseTool):
] = False ] = False
"""Handle the content of the ToolException thrown.""" """Handle the content of the ToolException thrown."""
# TODO[pydantic]: The `Config` class inherits from another class, please create the `model_config` manually.
# Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-config for more information.
class Config(Serializable.Config): class Config(Serializable.Config):
"""Configuration for this pydantic object.""" """Configuration for this pydantic object."""
model_config = {}
arbitrary_types_allowed = True arbitrary_types_allowed = True
@property @property
@ -155,8 +214,7 @@ class ChildTool(BaseTool):
# --- Runnable --- # --- Runnable ---
@property @property
# TODO def input_schema(self) -> Type[BaseModel]:
def input_schema(self):
"""The tool's input schema.""" """The tool's input schema."""
if self.args_schema is not None: if self.args_schema is not None:
return self.args_schema return self.args_schema
@ -218,7 +276,7 @@ class ChildTool(BaseTool):
} }
return tool_input return tool_input
@classmethod @root_validator()
def raise_deprecation(cls, values: Dict) -> Dict: def raise_deprecation(cls, values: Dict) -> Dict:
"""Raise deprecation warning if callback_manager is used.""" """Raise deprecation warning if callback_manager is used."""
if values.get("callback_manager") is not None: if values.get("callback_manager") is not None:
@ -613,7 +671,9 @@ class StructuredTool(BaseTool):
"""Tool that can operate on any number of inputs.""" """Tool that can operate on any number of inputs."""
description: str = "" description: str = ""
args_schema: Type[BaseModel] = Field(
..., description="The tool schema."
)
"""The input arguments' schema.""" """The input arguments' schema."""
func: Optional[Callable[..., Any]] func: Optional[Callable[..., Any]]
"""The function to run when the tool is called.""" """The function to run when the tool is called."""

@ -1,7 +1,7 @@
from abc import ABC from abc import ABC
from typing import Any, Dict, List, Literal, TypedDict, Union, cast from typing import Any, Dict, List, Literal, TypedDict, Union, cast
from pydantic import ConfigDict, BaseModel, PrivateAttr from pydantic import BaseModel, PrivateAttr
class BaseSerialized(TypedDict): class BaseSerialized(TypedDict):
@ -64,7 +64,9 @@ class Serializable(BaseModel, ABC):
constructor. constructor.
""" """
return {} return {}
model_config = ConfigDict(extra="ignore")
class Config:
extra = "ignore"
_lc_kwargs = PrivateAttr(default_factory=dict) _lc_kwargs = PrivateAttr(default_factory=dict)

@ -2,8 +2,6 @@
# -================== # -==================
# Use an official Python runtime as a parent image # Use an official Python runtime as a parent image
FROM python:3.9-slim FROM python:3.9-slim
RUN apt-get update && apt-get -y install libgl1-mesa-dev libglib2.0-0; apt-get clean
RUN pip install opencv-contrib-python-headless
# Set environment variables to make Python output unbuffered and disable the PIP cache # Set environment variables to make Python output unbuffered and disable the PIP cache
ENV PYTHONDONTWRITEBYTECODE 1 ENV PYTHONDONTWRITEBYTECODE 1

Loading…
Cancel
Save