You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
87 lines
2.2 KiB
87 lines
2.2 KiB
import pytest
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
|
|
from swarms.models.mpt import MPT7B
|
|
|
|
|
|
def test_mpt7b_init():
    """Constructor should store its configuration and build real HF objects."""
    mpt = MPT7B(
        "mosaicml/mpt-7b-storywriter",
        "EleutherAI/gpt-neox-20b",
        max_tokens=150,
    )

    assert isinstance(mpt, MPT7B)
    assert mpt.model_name == "mosaicml/mpt-7b-storywriter"
    assert mpt.tokenizer_name == "EleutherAI/gpt-neox-20b"

    # Bug fix: the Auto* classes are factories -- from_pretrained() returns a
    # concrete subclass (e.g. GPTNeoXTokenizerFast / MPTForCausalLM), never an
    # Auto* instance, so isinstance(x, AutoTokenizer) was always False.
    # Check against the real base classes instead.
    from transformers import PreTrainedModel, PreTrainedTokenizerBase

    assert isinstance(mpt.tokenizer, PreTrainedTokenizerBase)
    assert isinstance(mpt.model, PreTrainedModel)
    assert mpt.max_tokens == 150
|
|
|
|
|
|
def test_mpt7b_run():
    """run('generate', ...) should return a string that begins with the prompt."""
    prompt = "Once upon a time in a land far, far away..."
    model = MPT7B(
        "mosaicml/mpt-7b-storywriter",
        "EleutherAI/gpt-neox-20b",
        max_tokens=150,
    )

    result = model.run("generate", prompt)

    assert isinstance(result, str)
    assert result.startswith(prompt)
|
|
|
|
|
|
def test_mpt7b_run_invalid_task():
    """An unrecognized task name passed to run() must raise ValueError."""
    model = MPT7B(
        "mosaicml/mpt-7b-storywriter",
        "EleutherAI/gpt-neox-20b",
        max_tokens=150,
    )

    with pytest.raises(ValueError):
        model.run("invalid_task", "Once upon a time in a land far, far away...")
|
|
|
|
|
|
def test_mpt7b_generate():
    """generate() should return a string continuation of the given prompt."""
    prompt = "Once upon a time in a land far, far away..."
    model = MPT7B(
        "mosaicml/mpt-7b-storywriter",
        "EleutherAI/gpt-neox-20b",
        max_tokens=150,
    )

    text = model.generate(prompt)

    assert isinstance(text, str)
    assert text.startswith(prompt)
|
|
|
|
|
|
def test_mpt7b_batch_generate():
    """batch_generate() should yield exactly one string per input prompt."""
    model = MPT7B(
        "mosaicml/mpt-7b-storywriter",
        "EleutherAI/gpt-neox-20b",
        max_tokens=150,
    )
    prompts = ["In the deep jungles,", "At the heart of the city,"]

    results = model.batch_generate(prompts, temperature=0.7)

    assert isinstance(results, list)
    assert len(results) == len(prompts)
    assert all(isinstance(item, str) for item in results)
|
|
|
|
|
|
def test_mpt7b_unfreeze_model():
    """After unfreeze_model(), every model parameter must have gradients enabled."""
    model = MPT7B(
        "mosaicml/mpt-7b-storywriter",
        "EleutherAI/gpt-neox-20b",
        max_tokens=150,
    )
    model.unfreeze_model()

    assert all(p.requires_grad for p in model.model.parameters())
|