You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
70 lines
2.0 KiB
70 lines
2.0 KiB
1 year ago
|
import pytest
|
||
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||
1 year ago
|
|
||
1 year ago
|
from swarms.models.mpt import MPT7B
|
||
|
|
||
|
|
||
|
def test_mpt7b_init():
    """Check that MPT7B stores its configuration and loads a tokenizer and model.

    ``AutoTokenizer`` and ``AutoModelForCausalLM`` are factory classes: their
    ``from_pretrained`` returns an instance of a concrete class, never an
    instance of the ``Auto*`` class itself. An ``isinstance`` check against
    them therefore cannot succeed. The checks below use the shared
    transformers base classes instead.

    NOTE(review): this assumes MPT7B builds ``tokenizer``/``model`` through
    transformers ``from_pretrained``; confirm against swarms.models.mpt.
    """
    # Function-local import keeps this fix self-contained. These are the public
    # base classes that all transformers tokenizers / models derive from.
    from transformers import PreTrainedModel, PreTrainedTokenizerBase

    mpt = MPT7B(
        "mosaicml/mpt-7b-storywriter", "EleutherAI/gpt-neox-20b", max_tokens=150
    )

    assert isinstance(mpt, MPT7B)
    assert mpt.model_name == "mosaicml/mpt-7b-storywriter"
    assert mpt.tokenizer_name == "EleutherAI/gpt-neox-20b"
    assert isinstance(mpt.tokenizer, PreTrainedTokenizerBase)
    assert isinstance(mpt.model, PreTrainedModel)
    assert mpt.max_tokens == 150
|
||
|
|
||
|
|
||
|
def test_mpt7b_run():
    """``run("generate", prompt)`` should return a string that begins with the prompt."""
    model = MPT7B(
        "mosaicml/mpt-7b-storywriter",
        "EleutherAI/gpt-neox-20b",
        max_tokens=150,
    )
    opening = "Once upon a time in a land far, far away..."

    result = model.run("generate", opening)

    assert isinstance(result, str)
    assert result.startswith(opening)
|
||
|
|
||
|
|
||
|
def test_mpt7b_run_invalid_task():
    """``run`` should reject an unrecognised task name with ``ValueError``."""
    model = MPT7B(
        "mosaicml/mpt-7b-storywriter",
        "EleutherAI/gpt-neox-20b",
        max_tokens=150,
    )
    opening = "Once upon a time in a land far, far away..."

    with pytest.raises(ValueError):
        model.run("invalid_task", opening)
|
||
|
|
||
|
|
||
|
def test_mpt7b_generate():
    """``generate(prompt)`` should return a string that begins with the prompt."""
    model = MPT7B(
        "mosaicml/mpt-7b-storywriter",
        "EleutherAI/gpt-neox-20b",
        max_tokens=150,
    )
    opening = "Once upon a time in a land far, far away..."

    result = model.generate(opening)

    assert isinstance(result, str)
    assert result.startswith(opening)
|
||
|
|
||
|
|
||
|
def test_mpt7b_batch_generate():
    """``batch_generate`` should return a list with one string per input prompt."""
    model = MPT7B(
        "mosaicml/mpt-7b-storywriter",
        "EleutherAI/gpt-neox-20b",
        max_tokens=150,
    )
    batch = ["In the deep jungles,", "At the heart of the city,"]

    results = model.batch_generate(batch, temperature=0.7)

    assert isinstance(results, list)
    assert len(results) == len(batch)
    assert all(isinstance(text, str) for text in results)
|
||
|
|
||
|
|
||
|
def test_mpt7b_unfreeze_model():
    """After ``unfreeze_model()`` every model parameter should be trainable.

    The parameter list is materialised and required to be non-empty, so the
    check cannot pass vacuously. Frozen parameters are collected by name, so
    a failure reports which ones were left with ``requires_grad=False``.
    """
    mpt = MPT7B(
        "mosaicml/mpt-7b-storywriter", "EleutherAI/gpt-neox-20b", max_tokens=150
    )

    mpt.unfreeze_model()

    # presumably mpt.model is a torch.nn.Module (the original iterated
    # .parameters()); named_parameters() is the named counterpart.
    named_params = list(mpt.model.named_parameters())
    # An empty parameter list would trivially satisfy "all require grad".
    assert named_params, "model exposes no parameters"

    frozen = [name for name, param in named_params if not param.requires_grad]
    assert not frozen, (
        f"parameters still frozen after unfreeze_model(): {frozen[:10]}"
    )
|