[TESTS for HuggingfacePipeline]

pull/268/head
Kye 1 year ago
parent fcb89b1774
commit d8f30bc8fc

@ -14,8 +14,6 @@ if torch.cuda.is_available():
class HuggingfacePipeline(AbstractLLM):
"""HuggingfacePipeline
Args:
AbstractLLM (AbstractLLM): [description]
task (str, optional): [description]. Defaults to "text-generation".
@ -30,26 +28,45 @@ class HuggingfacePipeline(AbstractLLM):
def __init__(
self,
task: str = "text-generation",
task_type: str = "text-generation",
model_name: str = None,
use_fp8: bool = False,
*args,
**kwargs,
):
super().__init__(*args, **kwargs)
self.task_type = task_type
self.model_name = model_name
self.use_fp8 = use_fp8
if torch.cuda.is_available():
self.use_fp8 = True
else:
self.use_fp8 = False
self.pipe = pipeline(
task, model_name, use_fp8=use_fp8 * args, **kwargs
task_type, model_name, use_fp8=use_fp8 * args, **kwargs
)
@abstractmethod
def run(self, task: str, *args, **kwargs):
def run(self, task: str, *args, **kwargs) -> str:
"""Run the pipeline
Args:
task (str): [description]
*args: [description]
**kwargs: [description]
Returns:
_type_: _description_
"""
try:
out = self.pipeline(task, *args, **kwargs)
return out
except Exception as e:
except Exception as error:
print(
colored(
f"Error in {self.__class__.__name__} pipeline",
f"Error in {self.__class__.__name__} pipeline: {error}",
"red",
)
)

@ -0,0 +1,56 @@
from unittest.mock import patch
import pytest
import torch
from swarms.models.huggingface_pipeline import HuggingfacePipeline
@pytest.fixture
def mock_pipeline():
with patch("swarms.models.huggingface_pipeline.pipeline") as mock:
yield mock
@pytest.fixture
def pipeline(mock_pipeline):
return HuggingfacePipeline(
"text-generation", "meta-llama/Llama-2-13b-chat-hf"
)
def test_init(pipeline, mock_pipeline):
assert pipeline.task_type == "text-generation"
assert pipeline.model_name == "meta-llama/Llama-2-13b-chat-hf"
assert (
pipeline.use_fp8 is True
if torch.cuda.is_available()
else False
)
mock_pipeline.assert_called_once_with(
"text-generation",
"meta-llama/Llama-2-13b-chat-hf",
use_fp8=pipeline.use_fp8,
)
def test_run(pipeline, mock_pipeline):
mock_pipeline.return_value = "Generated text"
result = pipeline.run("Hello, world!")
assert result == "Generated text"
mock_pipeline.assert_called_once_with("Hello, world!")
def test_run_with_exception(pipeline, mock_pipeline):
mock_pipeline.side_effect = Exception("Test exception")
with pytest.raises(Exception):
pipeline.run("Hello, world!")
def test_run_with_different_task(pipeline, mock_pipeline):
mock_pipeline.return_value = "Generated text"
result = pipeline.run("text-classification", "Hello, world!")
assert result == "Generated text"
mock_pipeline.assert_called_once_with(
"text-classification", "Hello, world!"
)
Loading…
Cancel
Save