"""Tests for the ``MPT7B`` wrapper in ``swarms.models.mpt``.

Every test builds its own ``MPT7B`` instance through ``_make_mpt`` so that
state changed by one test (for example ``unfreeze_model``) cannot leak into
another.
"""

import pytest
from transformers import (  # noqa: F401 - Auto* names kept for compatibility
    AutoModelForCausalLM,
    AutoTokenizer,
    PreTrainedModel,
    PreTrainedTokenizerBase,
)

from swarms.models.mpt import MPT7B

# Shared configuration used by every test in this module.
MODEL_NAME = "mosaicml/mpt-7b-storywriter"
TOKENIZER_NAME = "EleutherAI/gpt-neox-20b"
MAX_TOKENS = 150
PROMPT = "Once upon a time in a land far, far away..."


def _make_mpt():
    """Build a fresh ``MPT7B`` with the module's shared configuration."""
    return MPT7B(MODEL_NAME, TOKENIZER_NAME, max_tokens=MAX_TOKENS)


def test_mpt7b_init():
    """The constructor stores its arguments and loads tokenizer and model."""
    mpt = _make_mpt()

    assert isinstance(mpt, MPT7B)
    assert mpt.model_name == MODEL_NAME
    assert mpt.tokenizer_name == TOKENIZER_NAME
    # AutoTokenizer / AutoModelForCausalLM are factories that cannot be
    # instantiated, so an isinstance check against them can never succeed.
    # The loaded objects are instances of the concrete base classes instead.
    assert isinstance(mpt.tokenizer, PreTrainedTokenizerBase)
    assert isinstance(mpt.model, PreTrainedModel)
    assert mpt.max_tokens == MAX_TOKENS


def test_mpt7b_run():
    """``run("generate", prompt)`` returns a string that begins with the prompt."""
    mpt = _make_mpt()

    output = mpt.run("generate", PROMPT)

    assert isinstance(output, str)
    assert output.startswith(PROMPT)


def test_mpt7b_run_invalid_task():
    """``run`` is expected to raise ``ValueError`` for an unknown task name."""
    mpt = _make_mpt()

    with pytest.raises(ValueError):
        mpt.run("invalid_task", PROMPT)


def test_mpt7b_generate():
    """``generate(prompt)`` returns a string that begins with the prompt."""
    mpt = _make_mpt()

    output = mpt.generate(PROMPT)

    assert isinstance(output, str)
    assert output.startswith(PROMPT)


def test_mpt7b_batch_generate():
    """``batch_generate`` returns one string per input prompt."""
    mpt = _make_mpt()
    prompts = ["In the deep jungles,", "At the heart of the city,"]

    outputs = mpt.batch_generate(prompts, temperature=0.7)

    assert isinstance(outputs, list)
    assert len(outputs) == len(prompts)
    for output in outputs:
        assert isinstance(output, str)


def test_mpt7b_unfreeze_model():
    """After ``unfreeze_model`` every model parameter requires gradients."""
    mpt = _make_mpt()

    mpt.unfreeze_model()

    for param in mpt.model.parameters():
        assert param.requires_grad