[CODE QUALITY]

pull/378/head^2
Kye 11 months ago
parent c4f496a5bd
commit 075f6320e1

@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"
[tool.poetry] [tool.poetry]
name = "swarms" name = "swarms"
version = "4.0.2" version = "4.0.3"
description = "Swarms - Pytorch" description = "Swarms - Pytorch"
license = "MIT" license = "MIT"
authors = ["Kye Gomez <kye@apac.ai>"] authors = ["Kye Gomez <kye@apac.ai>"]
@ -21,6 +21,7 @@ classifiers = [
"Programming Language :: Python :: 3.10" "Programming Language :: Python :: 3.10"
] ]
[tool.poetry.dependencies] [tool.poetry.dependencies]
python = "^3.6.1" python = "^3.6.1"
torch = "2.1.1" torch = "2.1.1"

@ -44,7 +44,8 @@ from swarms.models.timm import TimmModel # noqa: E402
from swarms.models.ultralytics_model import ( from swarms.models.ultralytics_model import (
UltralyticsModel, UltralyticsModel,
) # noqa: E402 ) # noqa: E402
from swarms.models.vip_llava import VipLlavaMultiModal # noqa: E402
# from swarms.models.vip_llava import VipLlavaMultiModal # noqa: E402
from swarms.models.llava import LavaMultiModal # noqa: E402 from swarms.models.llava import LavaMultiModal # noqa: E402
from swarms.models.qwen import QwenVLMultiModal # noqa: E402 from swarms.models.qwen import QwenVLMultiModal # noqa: E402
from swarms.models.clipq import CLIPQ # noqa: E402 from swarms.models.clipq import CLIPQ # noqa: E402
@ -118,7 +119,7 @@ __all__ = [
"TogetherLLM", "TogetherLLM",
"TimmModel", "TimmModel",
"UltralyticsModel", "UltralyticsModel",
"VipLlavaMultiModal", # "VipLlavaMultiModal",
"LavaMultiModal", "LavaMultiModal",
"QwenVLMultiModal", "QwenVLMultiModal",
"CLIPQ", "CLIPQ",

@ -52,3 +52,31 @@ class ModelRegistry:
for name, model in self.models.items() for name, model in self.models.items()
if text in name if text in name
} }
def run_model(
self, model_name: str, task: str, img: str, *args, **kwargs
):
"""
Runs the specified model for the given task and image.
Args:
model_name (str): The name of the model to run.
task (str): The task to perform using the model.
img (str): The image to process.
*args: Additional positional arguments to pass to the model's run method.
**kwargs: Additional keyword arguments to pass to the model's run method.
Returns:
The result of running the model.
Raises:
ValueError: If the specified model is not found in the model registry.
"""
if model_name not in self.models:
raise ValueError(f"Model {model_name} not found")
# Get the model
model = self.models[model_name]
# Run the model
return model.run(task, img, *args, **kwargs)

@ -29,7 +29,7 @@ class QwenVLMultiModal(BaseMultiModalModel):
>>> print(response) >>> print(response)
""" """
model_name: str = "Qwen/Qwen-VL-Chat" model_name: str = "Qwen/Qwen-VL"
device: str = "cuda" device: str = "cuda"
args: tuple = field(default_factory=tuple) args: tuple = field(default_factory=tuple)
kwargs: dict = field(default_factory=dict) kwargs: dict = field(default_factory=dict)

Loading…
Cancel
Save