You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
swarms/tests/models/test_mpt7b.py

93 lines
2.2 KiB

import pytest
from transformers import AutoModelForCausalLM, AutoTokenizer
from swarms.models.mpt import MPT7B
def test_mpt7b_init():
mpt = MPT7B(
"mosaicml/mpt-7b-storywriter",
"EleutherAI/gpt-neox-20b",
max_tokens=150,
)
assert isinstance(mpt, MPT7B)
assert mpt.model_name == "mosaicml/mpt-7b-storywriter"
assert mpt.tokenizer_name == "EleutherAI/gpt-neox-20b"
assert isinstance(mpt.tokenizer, AutoTokenizer)
assert isinstance(mpt.model, AutoModelForCausalLM)
assert mpt.max_tokens == 150
def test_mpt7b_run():
mpt = MPT7B(
"mosaicml/mpt-7b-storywriter",
"EleutherAI/gpt-neox-20b",
max_tokens=150,
)
output = mpt.run(
"generate", "Once upon a time in a land far, far away..."
)
assert isinstance(output, str)
assert output.startswith(
"Once upon a time in a land far, far away..."
)
def test_mpt7b_run_invalid_task():
mpt = MPT7B(
"mosaicml/mpt-7b-storywriter",
"EleutherAI/gpt-neox-20b",
max_tokens=150,
)
with pytest.raises(ValueError):
mpt.run(
"invalid_task",
"Once upon a time in a land far, far away...",
)
def test_mpt7b_generate():
mpt = MPT7B(
"mosaicml/mpt-7b-storywriter",
"EleutherAI/gpt-neox-20b",
max_tokens=150,
)
output = mpt.generate(
"Once upon a time in a land far, far away..."
)
assert isinstance(output, str)
assert output.startswith(
"Once upon a time in a land far, far away..."
)
def test_mpt7b_batch_generate():
mpt = MPT7B(
"mosaicml/mpt-7b-storywriter",
"EleutherAI/gpt-neox-20b",
max_tokens=150,
)
prompts = ["In the deep jungles,", "At the heart of the city,"]
outputs = mpt.batch_generate(prompts, temperature=0.7)
assert isinstance(outputs, list)
assert len(outputs) == len(prompts)
for output in outputs:
assert isinstance(output, str)
def test_mpt7b_unfreeze_model():
mpt = MPT7B(
"mosaicml/mpt-7b-storywriter",
"EleutherAI/gpt-neox-20b",
max_tokens=150,
)
mpt.unfreeze_model()
for param in mpt.model.parameters():
assert param.requires_grad