You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
113 lines
3.5 KiB
113 lines
3.5 KiB
6 months ago
|
import pytest
|
||
|
import torch
|
||
|
from torch import nn
|
||
|
|
||
|
from swarms.utils import load_model_torch
|
||
|
|
||
|
|
||
|
class DummyModel(nn.Module):
    """Minimal one-layer network used as a checkpoint fixture.

    It owns a single ``nn.Linear`` called ``fc`` mapping 10 input features to
    2 outputs. Keep the attribute name ``fc``: it determines the state-dict
    keys (``fc.weight`` / ``fc.bias``), and tests access ``model.fc`` directly.
    """

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(in_features=10, out_features=2)

    def forward(self, x):
        """Run ``x`` (last dimension of size 10) through the linear layer."""
        logits = self.fc(x)
        return logits
|
||
|
|
||
|
|
||
|
# Test case 1: a state dict saved to disk is loaded back into a fresh model
def test_load_model_torch_success(tmp_path):
    """Round-trip a ``DummyModel`` checkpoint through ``load_model_torch``.

    The target model passed in is a fresh (randomly initialised) instance, so
    matching parameters after loading show that the checkpoint was actually
    applied, rather than the passed-in model being returned untouched.
    """
    model = DummyModel()

    # Save the model's state dict to a temporary directory
    model_path = tmp_path / "model.pt"
    torch.save(model.state_dict(), model_path)

    # Load the checkpoint into a new, independently initialised instance
    model_loaded = load_model_torch(model_path, model=DummyModel())

    # Check that the loaded model has the same architecture...
    assert isinstance(
        model_loaded, DummyModel
    ), "Loaded model type mismatch."

    # ...and the same parameters. Compare on CPU because the loader may
    # place the model on another device (e.g. a GPU when one is available).
    expected = model.state_dict()
    loaded = model_loaded.state_dict()
    assert (
        loaded.keys() == expected.keys()
    ), "Loaded parameter names mismatch."
    for name, tensor in expected.items():
        assert torch.equal(
            loaded[name].cpu(), tensor.cpu()
        ), f"Loaded parameter {name!r} does not match the saved checkpoint."
|
||
|
|
||
|
|
||
|
# Test case 2: Test if function raises FileNotFoundError for non-existent file
def test_load_model_torch_file_not_found(tmp_path):
    """A missing checkpoint path must raise ``FileNotFoundError``.

    The path is built inside pytest's per-test ``tmp_path`` so it is
    guaranteed not to exist, independent of the directory pytest runs from.
    It is passed as a plain string to keep covering str-path input.
    """
    missing_path = tmp_path / "non_existent_model.pt"
    assert not missing_path.exists()

    with pytest.raises(FileNotFoundError):
        load_model_torch(str(missing_path))
|
||
|
|
||
|
|
||
|
# Test case 3: a corrupt (non-checkpoint) file must surface as RuntimeError
def test_load_model_torch_invalid_file(tmp_path):
    """Feeding a plain-text file to ``load_model_torch`` raises ``RuntimeError``."""
    bogus_path = tmp_path / "invalid_model.pt"
    bogus_path.write_text("Invalid model file.")

    # NOTE(review): torch.load itself may raise an unpickling/value error for
    # this input; the test relies on load_model_torch converting it into a
    # RuntimeError — confirm against the implementation.
    with pytest.raises(RuntimeError):
        load_model_torch(bogus_path)
|
||
|
|
||
|
|
||
|
# Test case 4: 'strict' controls whether state-dict key mismatches are fatal
def test_load_model_torch_strict_handling(tmp_path):
    """Check that ``strict`` is honoured when checkpoint keys don't match.

    The checkpoint is a valid ``DummyModel`` state dict plus one unexpected
    key. A tensor *shape* mismatch is deliberately not used: PyTorch's
    ``load_state_dict`` rejects shape mismatches even with ``strict=False``,
    so such a checkpoint could not tell whether ``strict`` is respected.
    """
    state = DummyModel().state_dict()
    state["unexpected.weight"] = torch.zeros(1)
    model_path = tmp_path / "model.pt"
    torch.save(state, model_path)

    # strict=True: the unexpected key must be rejected
    with pytest.raises(RuntimeError):
        load_model_torch(model_path, model=DummyModel(), strict=True)

    # strict=False: the unexpected key is ignored and matching keys load
    model_loaded = load_model_torch(
        model_path, model=DummyModel(), strict=False
    )
    assert isinstance(model_loaded, DummyModel)
    assert torch.equal(
        model_loaded.fc.weight.detach().cpu(), state["fc.weight"].cpu()
    ), "Matching parameters were not loaded with strict=False."
|
||
|
|
||
|
|
||
|
# Test case 5: an explicit 'device' argument decides where the weights live
def test_load_model_torch_device_handling(tmp_path):
    """Loading with an explicit ``device`` leaves the parameters on it."""
    checkpoint = tmp_path / "model.pt"
    torch.save(DummyModel().state_dict(), checkpoint)

    # CPU is requested explicitly so the test runs on any machine; whether
    # it differs from the loader's automatic choice depends on the hardware.
    target = torch.device("cpu")
    loaded = load_model_torch(
        checkpoint, model=DummyModel(), device=target
    )

    weight_device = loaded.fc.weight.device
    assert weight_device == target, "Model not loaded to specified device."
|
||
|
|
||
|
|
||
|
# Test case 6: extra keyword arguments are forwarded to 'torch.load'
def test_load_model_torch_args_kwargs_handling(monkeypatch, tmp_path):
    """Check that extra kwargs given to ``load_model_torch`` reach ``torch.load``.

    ``torch.load`` is replaced by a recorder that returns a real state dict,
    so ``load_model_torch`` can complete normally; the recorded calls are
    checked afterwards. This ensures the test fails if ``torch.load`` is
    never called at all.
    """
    model = DummyModel()
    model_path = tmp_path / "model.pt"
    torch.save(model.state_dict(), model_path)
    state = model.state_dict()

    calls = []

    def mock_torch_load(*args, **kwargs):
        # Record the call and return a loadable state dict: load_model_torch
        # presumably feeds this result into ``model``, which would fail if
        # the mock returned None.
        calls.append((args, kwargs))
        return state

    # Monkeypatch 'torch.load' to observe what load_model_torch passes it
    monkeypatch.setattr(torch, "load", mock_torch_load)
    load_model_torch(
        model_path, model=DummyModel(), pickle_module="dummy_module"
    )

    assert calls, "'torch.load' was never called."
    assert all(
        kwargs.get("pickle_module") == "dummy_module"
        for _, kwargs in calls
    ), "Keyword arguments not passed to 'torch.load'."
|
||
|
|
||
|
|
||
|
# Test case 7: Test for model loading on CPU if no GPU is available
def test_load_model_torch_cpu(monkeypatch, tmp_path):
    """With CUDA reported unavailable and no ``device`` given, expect CPU."""
    model = DummyModel()
    model_path = tmp_path / "model.pt"
    torch.save(model.state_dict(), model_path)

    def mock_torch_cuda_is_available():
        return False

    # Simulate a machine without a GPU. The ``monkeypatch`` fixture patches
    # for the duration of this test only and restores the real
    # ``torch.cuda.is_available`` afterwards (``MonkeyPatch.setattr`` is an
    # instance method and must be called on the fixture, not the class).
    monkeypatch.setattr(
        torch.cuda, "is_available", mock_torch_cuda_is_available
    )
    model_loaded = load_model_torch(model_path, model=DummyModel())

    # Ensure model is loaded on CPU
    assert next(model_loaded.parameters()).device.type == "cpu"
|