You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
swarms/tests/models/test_ultralytics.py

35 lines
1.2 KiB

from unittest.mock import patch
from swarms.models.ultralytics_model import UltralyticsModel
def test_ultralytics_init():
    """Check that constructing ``UltralyticsModel`` creates a YOLO model.

    ``YOLO`` is replaced by a mock so the real class is never invoked. The
    test verifies that the model name is forwarded to ``YOLO`` exactly once
    and that the wrapper exposes both the name and the created model.
    """
    # mock.patch must target the namespace where the name is looked up.
    # UltralyticsModel is imported from ``swarms.models.ultralytics_model``,
    # so YOLO is patched there rather than on the ``swarms.models`` package.
    # NOTE(review): assumes that module binds ``YOLO`` at module level
    # (e.g. ``from ultralytics import YOLO``) -- confirm.
    with patch("swarms.models.ultralytics_model.YOLO") as mock_yolo:
        model_name = "yolov5s"
        ultralytics = UltralyticsModel(model_name)

        # The constructor should build the model from the given name only.
        mock_yolo.assert_called_once_with(model_name)
        assert ultralytics.model_name == model_name
        assert ultralytics.model == mock_yolo.return_value
def test_ultralytics_call():
    """Check that calling an ``UltralyticsModel`` delegates to the YOLO model.

    ``YOLO`` is replaced by a mock. The test verifies that the task plus all
    positional and keyword arguments are forwarded unchanged to the model
    instance, and that the model's return value is passed back to the caller.
    """
    # mock.patch must target the namespace where the name is looked up.
    # UltralyticsModel is imported from ``swarms.models.ultralytics_model``,
    # so YOLO is patched there rather than on the ``swarms.models`` package.
    # NOTE(review): assumes that module binds ``YOLO`` at module level
    # (e.g. ``from ultralytics import YOLO``) -- confirm.
    with patch("swarms.models.ultralytics_model.YOLO") as mock_yolo:
        model_name = "yolov5s"
        ultralytics = UltralyticsModel(model_name)

        task = "detect"
        args = (1, 2, 3)
        kwargs = {"a": "A", "b": "B"}
        result = ultralytics(task, *args, **kwargs)

        # ``mock_yolo.return_value`` is the model instance built by YOLO(...).
        mock_yolo.return_value.assert_called_once_with(
            task, *args, **kwargs
        )
        assert result == mock_yolo.return_value.return_value
def test_ultralytics_list_models():
    """Check that ``UltralyticsModel.list_models`` delegates to YOLO.

    ``YOLO`` is replaced by a mock. The test verifies that ``list_models`` is
    called once on the patched ``YOLO`` object and that its return value is
    handed back to the caller.
    """
    # mock.patch must target the namespace where the name is looked up.
    # UltralyticsModel is imported from ``swarms.models.ultralytics_model``,
    # so YOLO is patched there rather than on the ``swarms.models`` package.
    # NOTE(review): assumes that module binds ``YOLO`` at module level
    # (e.g. ``from ultralytics import YOLO``) -- confirm.
    with patch("swarms.models.ultralytics_model.YOLO") as mock_yolo:
        model_name = "yolov5s"
        ultralytics = UltralyticsModel(model_name)

        result = ultralytics.list_models()

        # NOTE(review): this asserts on the patched class attribute, not on
        # the instance (``mock_yolo.return_value``) -- confirm it matches how
        # ``list_models`` is implemented.
        mock_yolo.list_models.assert_called_once()
        assert result == mock_yolo.list_models.return_value