You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
swarms/tests/models/test_mpt7b.py

70 lines
2.0 KiB

import pytest
from transformers import AutoModelForCausalLM, AutoTokenizer
from swarms.models.mpt import MPT7B
def test_mpt7b_init():
    """Constructing ``MPT7B`` records its settings and exposes a tokenizer and model.

    Verifies that the model name, tokenizer name and ``max_tokens`` passed to
    the constructor are stored on the instance, and that the ``tokenizer`` and
    ``model`` attributes are Hugging Face tokenizer / model objects.
    """
    # Local import: these base classes come from the ``transformers`` package
    # this module already depends on.
    from transformers import PreTrainedModel, PreTrainedTokenizerBase

    mpt = MPT7B(
        "mosaicml/mpt-7b-storywriter",
        "EleutherAI/gpt-neox-20b",
        max_tokens=150,
    )

    assert isinstance(mpt, MPT7B)
    assert mpt.model_name == "mosaicml/mpt-7b-storywriter"
    assert mpt.tokenizer_name == "EleutherAI/gpt-neox-20b"
    # ``AutoTokenizer`` and ``AutoModelForCausalLM`` are factory classes that
    # are never instantiated themselves, so an ``isinstance`` check against
    # them cannot succeed. Compare against the common base classes that every
    # concrete tokenizer / model derives from instead.
    assert isinstance(mpt.tokenizer, PreTrainedTokenizerBase)
    assert isinstance(mpt.model, PreTrainedModel)
    assert mpt.max_tokens == 150
def test_mpt7b_run():
    """``run`` with the ``"generate"`` task returns a string that begins with the prompt."""
    story_opening = "Once upon a time in a land far, far away..."
    model = MPT7B(
        "mosaicml/mpt-7b-storywriter",
        "EleutherAI/gpt-neox-20b",
        max_tokens=150,
    )

    generated = model.run("generate", story_opening)

    assert isinstance(generated, str)
    assert generated.startswith(story_opening)
def test_mpt7b_run_invalid_task():
    """``run`` raises ``ValueError`` when given a task name it does not recognise."""
    story_opening = "Once upon a time in a land far, far away..."
    # Build the wrapper outside the ``raises`` context so only the ``run``
    # call itself is expected to raise.
    model = MPT7B(
        "mosaicml/mpt-7b-storywriter",
        "EleutherAI/gpt-neox-20b",
        max_tokens=150,
    )

    with pytest.raises(ValueError):
        model.run("invalid_task", story_opening)
def test_mpt7b_generate():
    """Calling ``generate`` directly returns a string that begins with the prompt."""
    story_opening = "Once upon a time in a land far, far away..."
    model = MPT7B(
        "mosaicml/mpt-7b-storywriter",
        "EleutherAI/gpt-neox-20b",
        max_tokens=150,
    )

    generated = model.generate(story_opening)

    assert isinstance(generated, str)
    assert generated.startswith(story_opening)
def test_mpt7b_batch_generate():
    """``batch_generate`` returns a list with one string per input prompt."""
    model = MPT7B(
        "mosaicml/mpt-7b-storywriter",
        "EleutherAI/gpt-neox-20b",
        max_tokens=150,
    )
    batch = ["In the deep jungles,", "At the heart of the city,"]

    completions = model.batch_generate(batch, temperature=0.7)

    assert isinstance(completions, list)
    assert len(completions) == len(batch)
    # Every completion in the batch must be a string.
    assert all(isinstance(text, str) for text in completions)
def test_mpt7b_unfreeze_model():
    """After ``unfreeze_model`` every model parameter has ``requires_grad`` set."""
    model = MPT7B(
        "mosaicml/mpt-7b-storywriter",
        "EleutherAI/gpt-neox-20b",
        max_tokens=150,
    )

    model.unfreeze_model()

    assert all(weight.requires_grad for weight in model.model.parameters())