You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
swarms/tests/models/test_hf_pipeline.py

53 lines
1.5 KiB

6 months ago
from unittest.mock import patch
import pytest
import torch
from swarms.models.huggingface_pipeline import HuggingfacePipeline
@pytest.fixture
def mock_pipeline():
    """Patch the ``pipeline`` name in the module under test and yield the mock.

    The patch is active for as long as the requesting test runs and is
    undone when the ``with`` block exits after the test finishes.
    """
    patch_target = "swarms.models.huggingface_pipeline.pipeline"
    with patch(patch_target) as patched_factory:
        yield patched_factory
@pytest.fixture
def pipeline(mock_pipeline):
    """Build a ``HuggingfacePipeline`` for text generation.

    ``mock_pipeline`` is requested so that the patch it installs is already
    in place when the instance is constructed.
    """
    task_type = "text-generation"
    model_name = "meta-llama/Llama-2-13b-chat-hf"
    return HuggingfacePipeline(task_type, model_name)
def test_init(pipeline, mock_pipeline):
    """Check the attributes set at construction and the factory call made there.

    Expects ``task_type`` and ``model_name`` to hold the constructor
    arguments, ``use_fp8`` to equal CUDA availability, and the patched
    ``pipeline`` factory to have been called exactly once with the task,
    the model name and ``use_fp8``.
    """
    assert pipeline.task_type == "text-generation"
    assert pipeline.model_name == "meta-llama/Llama-2-13b-chat-hf"

    # Compare against the availability flag directly.  Writing this as
    # ``assert x is True if cuda else False`` parses as
    # ``assert ((x is True) if cuda else False)``, which asserts a bare
    # ``False`` (and therefore always fails) on machines without CUDA.
    expected_fp8 = torch.cuda.is_available()
    assert pipeline.use_fp8 is expected_fp8

    mock_pipeline.assert_called_once_with(
        "text-generation",
        "meta-llama/Llama-2-13b-chat-hf",
        use_fp8=pipeline.use_fp8,
    )
def test_run(pipeline, mock_pipeline):
    """Check that ``run`` returns the output of the underlying pipeline callable.

    ``mock_pipeline`` replaces the pipeline *factory*; the callable that the
    factory handed back is ``mock_pipeline.return_value``.
    """
    # The ``pipeline`` fixture is already constructed when this body runs, and
    # ``test_init`` shows the factory is called at construction with the task
    # and model name.  Expectations about the prompt therefore belong on the
    # object the factory returned, not on the factory itself.
    # NOTE(review): assumes ``run`` invokes that returned object with the
    # prompt -- confirm against ``HuggingfacePipeline.run``.
    model_callable = mock_pipeline.return_value
    model_callable.return_value = "Generated text"

    result = pipeline.run("Hello, world!")

    assert result == "Generated text"
    model_callable.assert_called_once_with("Hello, world!")
def test_run_with_exception(pipeline, mock_pipeline):
    """Check that an error raised by the underlying pipeline callable surfaces from ``run``."""
    # The failure is attached to the object the factory returned rather than
    # to the factory: the ``pipeline`` fixture is already built when this body
    # runs, so a side effect on the factory would only fire if the factory
    # were called again.
    # NOTE(review): assumes ``run`` calls that object and lets the exception
    # propagate -- confirm against ``HuggingfacePipeline.run``.
    mock_pipeline.return_value.side_effect = Exception("Test exception")

    with pytest.raises(Exception):
        pipeline.run("Hello, world!")
def test_run_with_different_task(pipeline, mock_pipeline):
    """Check that ``run`` forwards two positional arguments to the pipeline callable.

    ``mock_pipeline`` replaces the pipeline *factory*; the callable that the
    factory handed back is ``mock_pipeline.return_value``.
    """
    # The ``pipeline`` fixture is already constructed when this body runs, so
    # the return value and the call expectation are set on the object the
    # factory returned; the factory mock itself has only seen the
    # construction-time call.
    # NOTE(review): assumes ``run`` passes extra positional arguments straight
    # through -- confirm against ``HuggingfacePipeline.run``.
    model_callable = mock_pipeline.return_value
    model_callable.return_value = "Generated text"

    result = pipeline.run("text-classification", "Hello, world!")

    assert result == "Generated text"
    model_callable.assert_called_once_with(
        "text-classification", "Hello, world!"
    )