You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
107 lines
4.0 KiB
107 lines
4.0 KiB
import pytest
|
|
import torch
|
|
from transformers import AutoTokenizer
|
|
from swarms.models.yi_200k import Yi34B200k
|
|
|
|
|
|
# Create fixtures if needed
|
|
@pytest.fixture(scope="module")
def yi34b_model():
    """Provide one ``Yi34B200k`` instance shared by every test in this module.

    The fixture is module-scoped so the model is constructed once instead of
    once per test (and once per parametrized case).  The tests in this file
    only call the instance and read its ``model`` / ``tokenizer`` attributes.

    NOTE(review): sharing assumes calling the model does not leave state
    behind that affects later calls -- confirm against ``Yi34B200k``.
    """
    return Yi34B200k()
|
|
|
|
|
|
# Test cases for the Yi34B200k class
|
|
def test_yi34b_init(yi34b_model):
    """Check that the fixture instance exposes a torch model and a tokenizer.

    ``AutoTokenizer`` is a factory: ``AutoTokenizer.from_pretrained`` returns
    an instance of a concrete tokenizer class and ``AutoTokenizer`` itself
    cannot be instantiated, so ``isinstance(x, AutoTokenizer)`` is never true.
    The tokenizer is therefore checked against ``PreTrainedTokenizerBase``,
    the shared base class of the slow and fast Hugging Face tokenizers.
    """
    # Imported locally so this test does not rely on the file-level import list.
    from transformers import PreTrainedTokenizerBase

    assert isinstance(yi34b_model.model, torch.nn.Module)
    assert isinstance(yi34b_model.tokenizer, PreTrainedTokenizerBase)
|
|
|
|
|
|
def test_yi34b_generate_text(yi34b_model):
    """Calling the model with a plain prompt must give back a non-empty str."""
    output = yi34b_model("There's a place where time stands still.")

    # Type first, so a non-string result is reported as a type failure.
    assert isinstance(output, str)
    assert len(output) > 0
|
|
|
|
|
|
@pytest.mark.parametrize("max_length", [64, 128, 256, 512])
def test_yi34b_generate_text_with_length(yi34b_model, max_length):
    """Check that the generated text stays within the ``max_length`` budget.

    The previous assertion compared ``len(generated_text)`` -- a character
    count -- with ``max_length``.  A token usually spans several characters,
    so that comparison does not measure the limit being requested.  The text
    is re-tokenized with the model's own tokenizer and the token count is
    compared instead.

    NOTE(review): assumes ``max_length`` is a token budget, as in Hugging
    Face ``generate`` -- confirm against ``Yi34B200k.__call__``.
    """
    prompt = "There's a place where time stands still."
    generated_text = yi34b_model(prompt, max_length=max_length)

    # Special tokens are excluded so only the visible text is counted.
    token_ids = yi34b_model.tokenizer.encode(
        generated_text, add_special_tokens=False
    )
    assert len(token_ids) <= max_length
|
|
|
|
|
|
@pytest.mark.parametrize("temperature", [0.1, 0.5, 1.0])
def test_yi34b_generate_text_with_temperature(yi34b_model, temperature):
    """Check that generation with an accepted temperature returns a str.

    The invalid-temperature test in this file expects a ``ValueError`` with
    the message "temperature must be between 0.01 and 1.0".  The former
    parameter value ``1.5`` lies outside that range, so the two tests could
    not both pass; every value used here is inside the documented range.
    """
    prompt = "There's a place where time stands still."
    generated_text = yi34b_model(prompt, temperature=temperature)
    assert isinstance(generated_text, str)
|
|
|
|
|
|
def test_yi34b_generate_text_with_invalid_prompt(yi34b_model):
    """Passing ``None`` as the prompt is expected to raise ``ValueError``."""
    expect_rejection = pytest.raises(
        ValueError, match="Input prompt must be a non-empty string"
    )
    with expect_rejection:
        yi34b_model(None)  # None stands in for an invalid prompt
|
|
|
|
|
|
def test_yi34b_generate_text_with_invalid_max_length(yi34b_model):
    """A negative ``max_length`` is expected to raise ``ValueError``."""
    bad_max_length = -1
    expect_rejection = pytest.raises(
        ValueError, match="max_length must be a positive integer"
    )
    with expect_rejection:
        yi34b_model(
            "There's a place where time stands still.",
            max_length=bad_max_length,
        )
|
|
|
|
|
|
def test_yi34b_generate_text_with_invalid_temperature(yi34b_model):
    """Check that an out-of-range temperature raises ``ValueError``.

    ``pytest.raises(match=...)`` treats its argument as a regular expression
    searched against the exception text.  The dots in the expected message
    are escaped so they match literal periods only, not any character.
    """
    prompt = "There's a place where time stands still."
    temperature = 2.0  # Invalid temperature
    with pytest.raises(
        ValueError, match=r"temperature must be between 0\.01 and 1\.0"
    ):
        yi34b_model(prompt, temperature=temperature)
|
|
|
|
|
|
@pytest.mark.parametrize("top_k", [20, 30, 50])
def test_yi34b_generate_text_with_top_k(yi34b_model, top_k):
    """For each listed ``top_k`` the call must return a str."""
    output = yi34b_model(
        "There's a place where time stands still.",
        top_k=top_k,
    )
    assert isinstance(output, str)
|
|
|
|
|
|
@pytest.mark.parametrize("top_p", [0.5, 0.7, 0.9])
def test_yi34b_generate_text_with_top_p(yi34b_model, top_p):
    """For each listed ``top_p`` the call must return a str."""
    output = yi34b_model(
        "There's a place where time stands still.",
        top_p=top_p,
    )
    assert isinstance(output, str)
|
|
|
|
|
|
def test_yi34b_generate_text_with_invalid_top_k(yi34b_model):
    """A negative ``top_k`` is expected to raise ``ValueError``."""
    bad_top_k = -1
    expect_rejection = pytest.raises(
        ValueError, match="top_k must be a non-negative integer"
    )
    with expect_rejection:
        yi34b_model(
            "There's a place where time stands still.",
            top_k=bad_top_k,
        )
|
|
|
|
|
|
def test_yi34b_generate_text_with_invalid_top_p(yi34b_model):
    """Check that a ``top_p`` above 1.0 raises ``ValueError``.

    ``pytest.raises(match=...)`` treats its argument as a regular expression
    searched against the exception text.  The dots in the expected message
    are escaped so they match literal periods only, not any character.
    """
    prompt = "There's a place where time stands still."
    top_p = 1.5  # Invalid top_p
    with pytest.raises(
        ValueError, match=r"top_p must be between 0\.0 and 1\.0"
    ):
        yi34b_model(prompt, top_p=top_p)
|
|
|
|
|
|
@pytest.mark.parametrize("repitition_penalty", [1.0, 1.2, 1.5])
def test_yi34b_generate_text_with_repitition_penalty(yi34b_model, repitition_penalty):
    """For each listed ``repitition_penalty`` the call must return a str."""
    # The keyword spelling is reproduced exactly as the call site uses it.
    output = yi34b_model(
        "There's a place where time stands still.",
        repitition_penalty=repitition_penalty,
    )
    assert isinstance(output, str)
|
|
|
|
|
|
def test_yi34b_generate_text_with_invalid_repitition_penalty(yi34b_model):
    """A ``repitition_penalty`` of 0.0 is expected to raise ``ValueError``."""
    bad_penalty = 0.0
    expect_rejection = pytest.raises(
        ValueError, match="repitition_penalty must be a positive float"
    )
    with expect_rejection:
        yi34b_model(
            "There's a place where time stands still.",
            repitition_penalty=bad_penalty,
        )
|
|
|
|
|
|
def test_yi34b_generate_text_with_invalid_device(yi34b_model):
    """An unrecognised ``device_map`` value is expected to raise ``ValueError``."""
    bad_device_map = "invalid_device"
    expect_rejection = pytest.raises(ValueError, match="Invalid device_map")
    with expect_rejection:
        yi34b_model(
            "There's a place where time stands still.",
            device_map=bad_device_map,
        )
|