From 075f6320e1cc164d5cc595dcce5b306b512a2d92 Mon Sep 17 00:00:00 2001 From: Kye Date: Sun, 4 Feb 2024 01:35:48 -0800 Subject: [PATCH] [CODE QUALITY] --- pyproject.toml | 3 ++- swarms/models/__init__.py | 5 +++-- swarms/models/model_registry.py | 28 ++++++++++++++++++++++++++++ swarms/models/qwen.py | 2 +- 4 files changed, 34 insertions(+), 4 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index f4971b47..2a89e00f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api" [tool.poetry] name = "swarms" -version = "4.0.2" +version = "4.0.3" description = "Swarms - Pytorch" license = "MIT" authors = ["Kye Gomez "] @@ -21,6 +21,7 @@ classifiers = [ "Programming Language :: Python :: 3.10" ] + [tool.poetry.dependencies] python = "^3.6.1" torch = "2.1.1" diff --git a/swarms/models/__init__.py b/swarms/models/__init__.py index 25b10024..00d9d1f2 100644 --- a/swarms/models/__init__.py +++ b/swarms/models/__init__.py @@ -44,7 +44,8 @@ from swarms.models.timm import TimmModel # noqa: E402 from swarms.models.ultralytics_model import ( UltralyticsModel, ) # noqa: E402 -from swarms.models.vip_llava import VipLlavaMultiModal # noqa: E402 + +# from swarms.models.vip_llava import VipLlavaMultiModal # noqa: E402 from swarms.models.llava import LavaMultiModal # noqa: E402 from swarms.models.qwen import QwenVLMultiModal # noqa: E402 from swarms.models.clipq import CLIPQ # noqa: E402 @@ -118,7 +119,7 @@ __all__ = [ "TogetherLLM", "TimmModel", "UltralyticsModel", - "VipLlavaMultiModal", + # "VipLlavaMultiModal", "LavaMultiModal", "QwenVLMultiModal", "CLIPQ", diff --git a/swarms/models/model_registry.py b/swarms/models/model_registry.py index a65ca154..6da04282 100644 --- a/swarms/models/model_registry.py +++ b/swarms/models/model_registry.py @@ -52,3 +52,31 @@ class ModelRegistry: for name, model in self.models.items() if text in name } + + def run_model( + self, model_name: str, task: str, img: str, *args, **kwargs + ): + """ + Runs the specified model for the given task and image. + + Args: + model_name (str): The name of the model to run. + task (str): The task to perform using the model. + img (str): The image to process. + *args: Additional positional arguments to pass to the model's run method. + **kwargs: Additional keyword arguments to pass to the model's run method. + + Returns: + The result of running the model. + + Raises: + ValueError: If the specified model is not found in the model registry. + """ + if model_name not in self.models: + raise ValueError(f"Model {model_name} not found") + + # Get the model + model = self.models[model_name] + + # Run the model + return model.run(task, img, *args, **kwargs) diff --git a/swarms/models/qwen.py b/swarms/models/qwen.py index 5ed131bf..b5a4ed1a 100644 --- a/swarms/models/qwen.py +++ b/swarms/models/qwen.py @@ -29,7 +29,7 @@ class QwenVLMultiModal(BaseMultiModalModel): >>> print(response) """ - model_name: str = "Qwen/Qwen-VL-Chat" + model_name: str = "Qwen/Qwen-VL" device: str = "cuda" args: tuple = field(default_factory=tuple) kwargs: dict = field(default_factory=dict)