[TimmModel]

pull/362/head
Kye 1 year ago
parent 73c946d4ba
commit 1a3b16a71a

@ -18,6 +18,7 @@ from swarms.models.wizard_storytelling import (
) # noqa: E402 ) # noqa: E402
from swarms.models.mpt import MPT7B # noqa: E402 from swarms.models.mpt import MPT7B # noqa: E402
from swarms.models.mixtral import Mixtral # noqa: E402 from swarms.models.mixtral import Mixtral # noqa: E402
# from swarms.models.modelscope_pipeline import ModelScopePipeline # from swarms.models.modelscope_pipeline import ModelScopePipeline
# from swarms.models.modelscope_llm import ( # from swarms.models.modelscope_llm import (
# ModelScopeAutoModel, # ModelScopeAutoModel,

@ -2,59 +2,42 @@ from typing import List
import timm import timm
import torch import torch
from pydantic import BaseModel from torch import Tensor
from swarms.models.base_multimodal_model import BaseMultiModalModel
class TimmModelInfo(BaseModel): class TimmModel(BaseMultiModalModel):
model_name: str """
pretrained: bool TimmModel is a class that wraps the timm library to provide a consistent
in_chans: int interface for creating and running models.
class Config:
# Use strict typing for all fields
strict = True
Args:
model_name: A string representing the name of the model to be created.
pretrained: A boolean indicating whether to use a pretrained model.
in_chans: An integer representing the number of input channels.
class TimmModel: Returns:
""" A TimmModel instance.
# Usage Example:
model_handler = TimmModelHandler() model = TimmModel('resnet18', pretrained=True, in_chans=3)
model_info = TimmModelInfo(model_name='resnet34', pretrained=True, in_chans=1) output_shape = model(input_tensor)
input_tensor = torch.randn(1, 1, 224, 224)
output_shape = model_handler(model_info=model_info, input_tensor=input_tensor)
print(output_shape)
""" """
def __init__(self): def __init__(
self, model_name: str, pretrained: bool, in_chans: int
):
self.model_name = model_name
self.pretrained = pretrained
self.in_chans = in_chans
self.models = self._get_supported_models() self.models = self._get_supported_models()
def _get_supported_models(self) -> List[str]: def _get_supported_models(self) -> List[str]:
"""Retrieve the list of supported models from timm.""" """Retrieve the list of supported models from timm."""
return timm.list_models() return timm.list_models()
def _create_model( def __call__(self, task: Tensor, *args, **kwargs) -> torch.Size:
self, model_info: TimmModelInfo
) -> torch.nn.Module:
"""
Create a model instance from timm with specified parameters.
Args:
model_info: An instance of TimmModelInfo containing model specifications.
Returns:
An instance of a pytorch model.
"""
return timm.create_model(
model_info.model_name,
pretrained=model_info.pretrained,
in_chans=model_info.in_chans,
)
def __call__(
self, model_info: TimmModelInfo, input_tensor: torch.Tensor
) -> torch.Size:
""" """
Create and run a model specified by `model_info` on `input_tensor`. Create and run a model specified by `model_info` on `input_tensor`.
@ -65,5 +48,8 @@ class TimmModel:
Returns: Returns:
The shape of the output from the model. The shape of the output from the model.
""" """
model = self._create_model(model_info) model = timm.create_model(self.model, *args, **kwargs)
return model(input_tensor).shape return model(task)
def list_models(self):
return timm.list_models()

Loading…
Cancel
Save