From d8f30bc8fcaf3d13927160d9e6afdeb6f1c2b93a Mon Sep 17 00:00:00 2001
From: Kye
Date: Tue, 5 Dec 2023 17:24:26 -0800
Subject: [PATCH] [TESTS for HuggingfacePipeline]

---
 swarms/models/huggingface_pipeline.py | 31 +++++++++++----
 tests/models/test_hf_pipeline.py      | 56 +++++++++++++++++++++++++++
 2 files changed, 80 insertions(+), 7 deletions(-)
 create mode 100644 tests/models/test_hf_pipeline.py

diff --git a/swarms/models/huggingface_pipeline.py b/swarms/models/huggingface_pipeline.py
index 81213d3b..6598c3d6 100644
--- a/swarms/models/huggingface_pipeline.py
+++ b/swarms/models/huggingface_pipeline.py
@@ -14,8 +14,6 @@ if torch.cuda.is_available():
 
 class HuggingfacePipeline(AbstractLLM):
     """HuggingfacePipeline
-
-
     Args:
         AbstractLLM (AbstractLLM): [description]
         task (str, optional): [description]. Defaults to "text-generation".
@@ -30,26 +28,45 @@ class HuggingfacePipeline(AbstractLLM):
     def __init__(
         self,
-        task: str = "text-generation",
+        task_type: str = "text-generation",
         model_name: str = None,
         use_fp8: bool = False,
         *args,
         **kwargs,
     ):
         super().__init__(*args, **kwargs)
+        self.task_type = task_type
+        self.model_name = model_name
+        self.use_fp8 = use_fp8
+
+        if torch.cuda.is_available():
+            self.use_fp8 = True
+        else:
+            self.use_fp8 = False
+
         self.pipe = pipeline(
-            task, model_name, use_fp8=use_fp8 * args, **kwargs
+            task_type, model_name, *args, use_fp8=self.use_fp8, **kwargs
         )
 
-    @abstractmethod
-    def run(self, task: str, *args, **kwargs):
+    def run(self, task: str, *args, **kwargs) -> str:
+        """Run the pipeline on the given task.
+
+        Args:
+            task (str): The prompt or input text passed to the pipeline.
+            *args: Additional positional arguments for the pipeline call.
+            **kwargs: Additional keyword arguments for the pipeline call.
+
+        Returns:
+            str: The pipeline output.
+        """
         try:
-            out = self.pipeline(task, *args, **kwargs)
+            out = self.pipe(task, *args, **kwargs)
             return out
-        except Exception as e:
+        except Exception as error:
             print(
                 colored(
                     f"Error in {self.__class__.__name__} pipeline: {error}",
                     "red",
                 )
             )
+            raise error
diff --git a/tests/models/test_hf_pipeline.py b/tests/models/test_hf_pipeline.py
new file mode 100644
index 00000000..8580dd56
--- /dev/null
+++ b/tests/models/test_hf_pipeline.py
@@ -0,0 +1,56 @@
+from unittest.mock import patch
+
+import pytest
+import torch
+
+from swarms.models.huggingface_pipeline import HuggingfacePipeline
+
+
+@pytest.fixture
+def mock_pipeline():
+    with patch("swarms.models.huggingface_pipeline.pipeline") as mock:
+        yield mock
+
+
+@pytest.fixture
+def pipeline(mock_pipeline):
+    return HuggingfacePipeline(
+        "text-generation", "meta-llama/Llama-2-13b-chat-hf"
+    )
+
+
+def test_init(pipeline, mock_pipeline):
+    assert pipeline.task_type == "text-generation"
+    assert pipeline.model_name == "meta-llama/Llama-2-13b-chat-hf"
+    assert (
+        pipeline.use_fp8 is True
+        if torch.cuda.is_available()
+        else pipeline.use_fp8 is False
+    )
+    mock_pipeline.assert_called_once_with(
+        "text-generation",
+        "meta-llama/Llama-2-13b-chat-hf",
+        use_fp8=pipeline.use_fp8,
+    )
+
+
+def test_run(pipeline, mock_pipeline):
+    mock_pipeline.return_value.return_value = "Generated text"
+    result = pipeline.run("Hello, world!")
+    assert result == "Generated text"
+    mock_pipeline.return_value.assert_called_once_with("Hello, world!")
+
+
+def test_run_with_exception(pipeline, mock_pipeline):
+    mock_pipeline.return_value.side_effect = Exception("Test exception")
+    with pytest.raises(Exception):
+        pipeline.run("Hello, world!")
+
+
+def test_run_with_different_task(pipeline, mock_pipeline):
+    mock_pipeline.return_value.return_value = "Generated text"
+    result = pipeline.run("text-classification", "Hello, world!")
+    assert result == "Generated text"
+    mock_pipeline.return_value.assert_called_once_with(
"text-classification", "Hello, world!" + )