diff --git a/tests/swarms/test_model_parallizer.py b/tests/swarms/test_model_parallizer.py index e69de29b..29a23451 100644 --- a/tests/swarms/test_model_parallizer.py +++ b/tests/swarms/test_model_parallizer.py @@ -0,0 +1,150 @@ +import pytest +from swarms.swarms.model_parallizer import ModelParallelizer +from swarms.models import ( + HuggingfaceLLM, + Mixtral, + Dalle3, + GPT4VisionAPI, + ZeroscopeTTV, +) + +# Initialize the models +custom_config = { + "quantize": True, + "quantization_config": {"load_in_4bit": True}, + "verbose": True, +} +huggingface_llm = HuggingfaceLLM( + model_id="NousResearch/Nous-Hermes-2-Vision-Alpha", + **custom_config, +) +mixtral = Mixtral(load_in_4bit=True, use_flash_attention_2=True) +dalle3 = Dalle3(quality="high") +gpt4_vision_api = GPT4VisionAPI(max_tokens=1000) +zeroscope_ttv = ZeroscopeTTV() + + +def test_init(): + mp = ModelParallelizer( + [ + huggingface_llm, + mixtral, + dalle3, + gpt4_vision_api, + zeroscope_ttv, + ] + ) + assert isinstance(mp, ModelParallelizer) + + +def test_run(): + mp = ModelParallelizer([huggingface_llm]) + result = mp.run( + "Create a list of known biggest risks of structural collapse" + " with references" + ) + assert isinstance(result, str) + + +def test_run_all(): + mp = ModelParallelizer( + [ + huggingface_llm, + mixtral, + dalle3, + gpt4_vision_api, + zeroscope_ttv, + ] + ) + result = mp.run_all( + "Create a list of known biggest risks of structural collapse" + " with references" + ) + assert isinstance(result, list) + assert len(result) == 5 + + +def test_add_llm(): + mp = ModelParallelizer([huggingface_llm]) + mp.add_llm(mixtral) + assert len(mp.llms) == 2 + + +def test_remove_llm(): + mp = ModelParallelizer([huggingface_llm, mixtral]) + mp.remove_llm(mixtral) + assert len(mp.llms) == 1 + + +def test_save_responses_to_file(tmp_path): + mp = ModelParallelizer([huggingface_llm]) + mp.run( + "Create a list of known biggest risks of structural collapse" + " with references" + ) + file = tmp_path / "responses.txt" + mp.save_responses_to_file(file) + assert file.read_text() != "" + + +def test_get_task_history(): + mp = ModelParallelizer([huggingface_llm]) + mp.run( + "Create a list of known biggest risks of structural collapse" + " with references" + ) + assert mp.get_task_history() == [ + "Create a list of known biggest risks of structural collapse" + " with references" + ] + + +def test_summary(capsys): + mp = ModelParallelizer([huggingface_llm]) + mp.run( + "Create a list of known biggest risks of structural collapse" + " with references" + ) + mp.summary() + captured = capsys.readouterr() + assert "Tasks History:" in captured.out + + +def test_enable_load_balancing(): + mp = ModelParallelizer([huggingface_llm]) + mp.enable_load_balancing() + assert mp.load_balancing == True + + +def test_disable_load_balancing(): + mp = ModelParallelizer([huggingface_llm]) + mp.disable_load_balancing() + assert mp.load_balancing == False + + +def test_concurrent_run(): + mp = ModelParallelizer([huggingface_llm, mixtral]) + result = mp.concurrent_run( + "Create a list of known biggest risks of structural collapse" + " with references" + ) + assert isinstance(result, list) + assert len(result) == 2 + + +def test_concurrent_run_no_task(): + mp = ModelParallelizer([huggingface_llm]) + with pytest.raises(TypeError): + mp.concurrent_run() + + +def test_concurrent_run_non_string_task(): + mp = ModelParallelizer([huggingface_llm]) + with pytest.raises(TypeError): + mp.concurrent_run(123) + + +def test_concurrent_run_empty_task(): + mp = ModelParallelizer([huggingface_llm]) + result = mp.concurrent_run("") + assert result == [""]