You cannot select more than 25 topics.
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
					
					
						
							50 lines
						
					
					
						
							1.5 KiB
						
					
					
				
			
		
		
	
	
							50 lines
						
					
					
						
							1.5 KiB
						
					
					
				| import unittest
 | |
| from unittest.mock import Mock
 | |
| 
 | |
| import pytest
 | |
| import torch
 | |
| 
 | |
| from swarms.utils import prep_torch_inference
 | |
| 
 | |
| 
 | |
def test_prep_torch_inference():
    """prep_torch_inference returns the loaded model after calling eval() on it.

    load_model_torch is replaced with a Mock, so no real file is read.
    """
    model_path = "model_path"
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Mock auto-creates child attributes, so model_mock.eval is already a
    # Mock whose calls are recorded.
    model_mock = Mock()

    # Patch load_model_torch in the module that defines prep_torch_inference
    # (where the name is looked up at call time), not on the re-exporting
    # package. A `from ... import load_model_torch` inside that module binds
    # its own reference, which a patch on the package would not replace.
    # When prep_torch_inference is defined directly in swarms.utils this is
    # the same target as "swarms.utils.load_model_torch".
    patch_target = f"{prep_torch_inference.__module__}.load_model_torch"
    with unittest.mock.patch(patch_target, return_value=model_mock):
        model = prep_torch_inference(model_path, device)

    # The loader's result must be returned as-is, with eval() called once.
    assert model is model_mock
    model_mock.eval.assert_called_once()
 | |
| 
 | |
| 
 | |
@pytest.mark.parametrize(
    "model_path, device",
    [
        (
            "invalid_path",
            torch.device("cuda"),
        ),  # Invalid file path, valid device
        (None, torch.device("cuda")),  # None file path, valid device
        ("model_path", None),  # Placeholder (non-existent) path, None device
        (None, None),  # None file path, None device
    ],
)
def test_prep_torch_inference_exceptions(model_path, device):
    """Bad inputs are handled inside prep_torch_inference and yield None.

    This matches test_prep_torch_inference_return_none, which asserts None
    for the ("invalid_path", cuda) input. Previously this test expected
    pytest.raises(Exception) for that same input, so the two tests could
    never pass together. A bare pytest.raises(Exception) would also accept
    any unrelated error.
    """
    assert prep_torch_inference(model_path, device) is None
 | |
| 
 | |
| 
 | |
def test_prep_torch_inference_return_none():
    """An unloadable model path makes prep_torch_inference return None.

    load_model_torch is expected to fail on a path that does not exist, and
    prep_torch_inference is expected to handle that failure rather than
    propagate it.
    """
    bad_path = "invalid_path"
    cuda_device = torch.device("cuda")

    result = prep_torch_inference(bad_path, cuda_device)
    assert result is None
 |