From f39d722f2a54c7e2e1f26fb0dad20082559faa0a Mon Sep 17 00:00:00 2001
From: Kye
Date: Sat, 23 Dec 2023 18:46:09 -0500
Subject: [PATCH] [FEAT][prep_torch_inference]

---
 swarms/memory/weaviate_db.py               |  2 +-
 swarms/models/base_llm.py                  |  6 +--
 swarms/utils/__init__.py                   |  4 ++
 swarms/utils/prep_torch_model_inference.py | 30 +++++++++++++
 tests/models/test_hf.py                    | 10 ++---
 tests/utils/load_models_torch.py           |  4 +-
 tests/utils/prep_torch_model_inference.py  | 50 ++++++++++++++++++++++
 7 files changed, 95 insertions(+), 11 deletions(-)
 create mode 100644 swarms/utils/prep_torch_model_inference.py
 create mode 100644 tests/utils/prep_torch_model_inference.py

diff --git a/swarms/memory/weaviate_db.py b/swarms/memory/weaviate_db.py
index a6d0c4ab..6181ab75 100644
--- a/swarms/memory/weaviate_db.py
+++ b/swarms/memory/weaviate_db.py
@@ -8,7 +8,7 @@ from swarms.memory.base_vectordb import VectorDatabase
 
 try:
     import weaviate
-except ImportError as error:
+except ImportError:
     print("pip install weaviate-client")
 
 
diff --git a/swarms/models/base_llm.py b/swarms/models/base_llm.py
index 15e50790..bc1f67c7 100644
--- a/swarms/models/base_llm.py
+++ b/swarms/models/base_llm.py
@@ -395,7 +395,7 @@ class AbstractLLM(ABC):
             float: _description_
         """
         start_time = time.time()
-        tokens = self.track_resource_utilization(
+        self.track_resource_utilization(
             prompt
         )  # assuming `generate` is a method that generates tokens
         first_token_time = time.time()
@@ -411,7 +411,7 @@ class AbstractLLM(ABC):
             float: _description_
         """
         start_time = time.time()
-        tokens = self.run(prompt)
+        self.run(prompt)
         end_time = time.time()
         return end_time - start_time
 
@@ -426,6 +426,6 @@ class AbstractLLM(ABC):
         """
         start_time = time.time()
         for prompt in prompts:
-            tokens = self.run(prompt)
+            self.run(prompt)
         end_time = time.time()
         return len(prompts) / (end_time - start_time)
diff --git a/swarms/utils/__init__.py b/swarms/utils/__init__.py
index 46628aae..72fc7199 100644
--- a/swarms/utils/__init__.py
+++ b/swarms/utils/__init__.py
@@ -8,6 +8,9 @@ from swarms.utils.math_eval import math_eval
 from swarms.utils.llm_metrics_decorator import metrics_decorator
 from swarms.utils.device_checker_cuda import check_device
 from swarms.utils.load_model_torch import load_model_torch
+from swarms.utils.prep_torch_model_inference import (
+    prep_torch_inference,
+)
 
 __all__ = [
     "display_markdown_message",
@@ -18,4 +21,5 @@ __all__ = [
     "metrics_decorator",
     "check_device",
     "load_model_torch",
+    "prep_torch_inference",
 ]
diff --git a/swarms/utils/prep_torch_model_inference.py b/swarms/utils/prep_torch_model_inference.py
new file mode 100644
index 00000000..41bc07cc
--- /dev/null
+++ b/swarms/utils/prep_torch_model_inference.py
@@ -0,0 +1,30 @@
+import torch
+from swarms.utils.load_model_torch import load_model_torch
+
+
+def prep_torch_inference(
+    model_path: str = None,
+    device: torch.device = None,
+    *args,
+    **kwargs,
+):
+    """
+    Prepare a Torch model for inference.
+
+    Args:
+        model_path (str): Path to the model file.
+        device (torch.device): Device to run the model on.
+        *args: Additional positional arguments.
+        **kwargs: Additional keyword arguments.
+
+    Returns:
+        torch.nn.Module: The prepared model, or None if loading fails.
+ """ + try: + model = load_model_torch(model_path, device) + model.eval() + return model + except Exception as e: + # Add error handling code here + print(f"Error occurred while preparing Torch model: {e}") + return None diff --git a/tests/models/test_hf.py b/tests/models/test_hf.py index 3a66f045..48dcd008 100644 --- a/tests/models/test_hf.py +++ b/tests/models/test_hf.py @@ -65,11 +65,11 @@ def test_init_huggingface_llm(): assert llm.model_id == "test_model" assert llm.device == "cuda" assert llm.max_length == 1000 - assert llm.quantize == True + assert llm.quantize is True assert llm.quantization_config == {"config_key": "config_value"} - assert llm.verbose == True - assert llm.distributed == True - assert llm.decoding == True + assert llm.verbose is True + assert llm.distributed is True + assert llm.decoding is True assert llm.max_workers == 3 assert llm.repitition_penalty == 1.5 assert llm.no_repeat_ngram_size == 4 @@ -90,7 +90,7 @@ def test_load_model(mock_huggingface_llm): # Test running the model def test_run(mock_huggingface_llm): llm = HuggingfaceLLM(model_id="test_model") - result = llm.run("Test prompt") + llm.run("Test prompt") # Ensure that the run function is called assert True diff --git a/tests/utils/load_models_torch.py b/tests/utils/load_models_torch.py index 15a66537..12066bbe 100644 --- a/tests/utils/load_models_torch.py +++ b/tests/utils/load_models_torch.py @@ -25,14 +25,14 @@ def test_load_model_torch_no_device_specified(mocker): mock_model = MagicMock(spec=torch.nn.Module) mocker.patch("torch.load", return_value=mock_model) mocker.patch("torch.cuda.is_available", return_value=False) - model = load_model_torch("model_path") + load_model_torch("model_path") mock_model.to.assert_called_once_with(torch.device("cpu")) def test_load_model_torch_device_specified(mocker): mock_model = MagicMock(spec=torch.nn.Module) mocker.patch("torch.load", return_value=mock_model) - model = load_model_torch( + load_model_torch( "model_path", device=torch.device("cuda") ) mock_model.to.assert_called_once_with(torch.device("cuda")) diff --git a/tests/utils/prep_torch_model_inference.py b/tests/utils/prep_torch_model_inference.py new file mode 100644 index 00000000..91f22592 --- /dev/null +++ b/tests/utils/prep_torch_model_inference.py @@ -0,0 +1,50 @@ +import torch +from unittest.mock import MagicMock +from swarms.utils.prep_torch_model_inference import ( + prep_torch_inference, +) + + +def test_prep_torch_inference_no_model_path(): + result = prep_torch_inference() + assert result is None + + +def test_prep_torch_inference_model_not_found(mocker): + mocker.patch( + "swarms.utils.prep_torch_model_inference.load_model_torch", + side_effect=FileNotFoundError, + ) + result = prep_torch_inference("non_existent_model_path") + assert result is None + + +def test_prep_torch_inference_runtime_error(mocker): + mocker.patch( + "swarms.utils.prep_torch_model_inference.load_model_torch", + side_effect=RuntimeError, + ) + result = prep_torch_inference("model_path") + assert result is None + + +def test_prep_torch_inference_no_device_specified(mocker): + mock_model = MagicMock(spec=torch.nn.Module) + mocker.patch( + "swarms.utils.prep_torch_model_inference.load_model_torch", + return_value=mock_model, + ) + prep_torch_inference("model_path") + mock_model.eval.assert_called_once() + + +def test_prep_torch_inference_device_specified(mocker): + mock_model = MagicMock(spec=torch.nn.Module) + mocker.patch( + "swarms.utils.prep_torch_model_inference.load_model_torch", + return_value=mock_model, + ) + 
+    prep_torch_inference(
+        "model_path", device=torch.device("cuda")
+    )
+    mock_model.eval.assert_called_once()
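
Usage note: a minimal sketch of how the new helper is meant to be called,
assuming a serialized model saved at "model.pt" (the path and the input
shape below are hypothetical, not part of the patch):

    import torch
    from swarms.utils.prep_torch_model_inference import prep_torch_inference

    # prep_torch_inference returns None on failure instead of raising,
    # so the result must be checked before use.
    model = prep_torch_inference("model.pt", device=torch.device("cpu"))

    if model is not None:
        with torch.no_grad():  # inference only; no gradient tracking needed
            # hypothetical input shape for illustration
            output = model(torch.randn(1, 3, 224, 224))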