You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
36 lines
1.2 KiB
36 lines
1.2 KiB
from unittest.mock import patch
|
|
|
|
from swarms.models.ultralytics_model import UltralyticsModel
|
|
|
|
|
|
def test_ultralytics_init():
    """UltralyticsModel should construct its YOLO backend from the model name.

    YOLO is patched in ``swarms.models.ultralytics_model`` (where
    ``UltralyticsModel`` looks the name up), not on the ``swarms.models``
    package: replacing the package attribute does not affect a name the
    module has already bound.
    """
    # NOTE(review): assumes ultralytics_model binds YOLO as a module global
    # (e.g. `from ultralytics import YOLO`) — confirm against that module.
    with patch("swarms.models.ultralytics_model.YOLO") as mock_yolo:
        model_name = "yolov5s"

        ultralytics = UltralyticsModel(model_name)

        # The constructor must call YOLO exactly once, with the model name.
        mock_yolo.assert_called_once_with(model_name)
        assert ultralytics.model_name == model_name
        # The wrapper must hold the very instance YOLO returned.
        assert ultralytics.model is mock_yolo.return_value
|
|
|
|
|
|
def test_ultralytics_call():
    """Calling the wrapper should forward its arguments to the YOLO instance.

    The task, positional args and keyword args must be passed through
    unchanged, and the instance's return value must be returned as-is.
    YOLO is patched in ``swarms.models.ultralytics_model``, where
    ``UltralyticsModel`` looks the name up.
    """
    # NOTE(review): assumes ultralytics_model binds YOLO as a module global
    # (e.g. `from ultralytics import YOLO`) — confirm against that module.
    with patch("swarms.models.ultralytics_model.YOLO") as mock_yolo:
        model_name = "yolov5s"
        ultralytics = UltralyticsModel(model_name)

        task = "detect"
        args = (1, 2, 3)
        kwargs = {"a": "A", "b": "B"}

        result = ultralytics(task, *args, **kwargs)

        # The instance returned by YOLO(model_name) is what gets invoked.
        mock_yolo.return_value.assert_called_once_with(task, *args, **kwargs)
        assert result is mock_yolo.return_value.return_value
|
|
|
|
|
|
def test_ultralytics_list_models():
    """list_models() should delegate to YOLO.list_models and return its result.

    YOLO is patched in ``swarms.models.ultralytics_model``, where
    ``UltralyticsModel`` looks the name up.
    """
    # NOTE(review): assumes ultralytics_model binds YOLO as a module global
    # (e.g. `from ultralytics import YOLO`) — confirm against that module.
    with patch("swarms.models.ultralytics_model.YOLO") as mock_yolo:
        model_name = "yolov5s"
        ultralytics = UltralyticsModel(model_name)

        result = ultralytics.list_models()

        # The call is asserted on the patched YOLO object itself,
        # not on the instance YOLO(model_name) returned.
        mock_yolo.list_models.assert_called_once()
        assert result is mock_yolo.list_models.return_value
|