You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
123 lines
4.1 KiB
123 lines
4.1 KiB
1 year ago
|
from unittest.mock import MagicMock, patch
|
||
|
|
||
|
import pytest
|
||
|
|
||
|
from swarms.models.zeroscope import ZeroscopeTTV
|
||
|
|
||
|
|
||
|
@patch("swarms.models.zeroscope.DiffusionPipeline")
@patch("swarms.models.zeroscope.DPMSolverMultistepScheduler")
def test_zeroscope_ttv_init(mock_scheduler, mock_pipeline):
    """A default-constructed ZeroscopeTTV loads its pipeline and scheduler
    exactly once and exposes the expected default generation settings."""
    # Stacked @patch decorators apply bottom-up, so the scheduler mock is the
    # first positional argument and the pipeline mock the second.
    model = ZeroscopeTTV()

    mock_pipeline.from_pretrained.assert_called_once()
    mock_scheduler.assert_called_once()

    expected_defaults = {
        "model_name": "cerspense/zeroscope_v2_576w",
        "chunk_size": 1,
        "dim": 1,
        "num_inference_steps": 40,
        "height": 320,
        "width": 576,
        "num_frames": 36,
    }
    for attr_name, expected_value in expected_defaults.items():
        assert getattr(model, attr_name) == expected_value, attr_name
|
||
|
|
||
|
|
||
|
@patch("swarms.models.zeroscope.DiffusionPipeline")
@patch("swarms.models.zeroscope.DPMSolverMultistepScheduler")
def test_zeroscope_ttv_forward(mock_scheduler, mock_pipeline):
    """forward() runs the loaded pipeline with the default generation settings
    and returns the ``frames`` attribute of the pipeline output."""
    # Configure the pipeline mock BEFORE constructing ZeroscopeTTV. The
    # constructor calls ``from_pretrained`` during __init__, so a return_value
    # assigned afterwards would never reach the object under test. It would
    # keep using the default child mock instead.
    mock_pipeline_instance = MagicMock()
    mock_pipeline_instance.return_value = MagicMock(frames="Generated frames")
    mock_pipeline.from_pretrained.return_value = mock_pipeline_instance

    zeroscope = ZeroscopeTTV()

    # The constructor is expected to enable VAE slicing and forward chunking
    # on the loaded pipeline.
    mock_pipeline_instance.enable_vae_slicing.assert_called_once()
    mock_pipeline_instance.enable_forward_chunking.assert_called_once_with(
        chunk_size=1, dim=1
    )

    result = zeroscope.forward("Test task")

    assert result == "Generated frames"
    mock_pipeline_instance.assert_called_once_with(
        "Test task",
        num_inference_steps=40,
        height=320,
        width=576,
        num_frames=36,
    )
|
||
|
|
||
|
|
||
|
@patch("swarms.models.zeroscope.DiffusionPipeline")
@patch("swarms.models.zeroscope.DPMSolverMultistepScheduler")
def test_zeroscope_ttv_forward_error(mock_scheduler, mock_pipeline):
    """An exception raised by the pipeline propagates out of forward()."""
    # Install the failing pipeline BEFORE construction so the instance under
    # test actually holds it; ``from_pretrained`` is called inside __init__.
    # side_effect takes precedence over return_value, so no return_value is set.
    mock_pipeline_instance = MagicMock(side_effect=Exception("Test error"))
    mock_pipeline.from_pretrained.return_value = mock_pipeline_instance

    zeroscope = ZeroscopeTTV()

    with pytest.raises(Exception, match="Test error"):
        zeroscope.forward("Test task")
|
||
|
|
||
|
|
||
|
@patch("swarms.models.zeroscope.DiffusionPipeline")
@patch("swarms.models.zeroscope.DPMSolverMultistepScheduler")
def test_zeroscope_ttv_call(mock_scheduler, mock_pipeline):
    """Calling the model instance runs the pipeline with the default
    generation settings and returns the output's ``frames``."""
    # Configure the pipeline mock BEFORE construction. ``from_pretrained`` is
    # called inside __init__, so configuring it afterwards has no effect on
    # the object under test.
    mock_pipeline_instance = MagicMock()
    mock_pipeline_instance.return_value = MagicMock(frames="Generated frames")
    mock_pipeline.from_pretrained.return_value = mock_pipeline_instance

    zeroscope = ZeroscopeTTV()

    # Invoke through normal call syntax, which dispatches to __call__.
    result = zeroscope("Test task")

    assert result == "Generated frames"
    mock_pipeline_instance.assert_called_once_with(
        "Test task",
        num_inference_steps=40,
        height=320,
        width=576,
        num_frames=36,
    )
|
||
|
|
||
|
|
||
|
@patch("swarms.models.zeroscope.DiffusionPipeline")
@patch("swarms.models.zeroscope.DPMSolverMultistepScheduler")
def test_zeroscope_ttv_call_error(mock_scheduler, mock_pipeline):
    """An exception raised by the pipeline propagates out of a direct call."""
    # Install the failing pipeline BEFORE construction so the instance under
    # test actually holds it; ``from_pretrained`` is called inside __init__.
    # side_effect takes precedence over return_value, so no return_value is set.
    mock_pipeline_instance = MagicMock(side_effect=Exception("Test error"))
    mock_pipeline.from_pretrained.return_value = mock_pipeline_instance

    zeroscope = ZeroscopeTTV()

    with pytest.raises(Exception, match="Test error"):
        zeroscope("Test task")
|
||
|
|
||
|
|
||
|
@patch("swarms.models.zeroscope.DiffusionPipeline")
@patch("swarms.models.zeroscope.DPMSolverMultistepScheduler")
def test_zeroscope_ttv_save_video_path(mock_scheduler, mock_pipeline):
    """save_video_path() returns the path it was given; see the review note
    below about its expected interaction with the pipeline."""
    # Configure the pipeline mock BEFORE construction. ``from_pretrained`` is
    # called inside __init__, so configuring it afterwards has no effect on
    # the object under test.
    mock_pipeline_instance = MagicMock()
    mock_pipeline_instance.return_value = MagicMock(frames="Generated frames")
    mock_pipeline.from_pretrained.return_value = mock_pipeline_instance

    zeroscope = ZeroscopeTTV()

    result = zeroscope.save_video_path("Test video path")

    assert result == "Test video path"
    # NOTE(review): this assumes save_video_path feeds its argument to the
    # pipeline as a prompt with the default settings. Confirm against
    # ZeroscopeTTV.save_video_path; drop this assertion if it only handles paths.
    mock_pipeline_instance.assert_called_once_with(
        "Test video path",
        num_inference_steps=40,
        height=320,
        width=576,
        num_frames=36,
    )
|