parent
f7b8a442e0
commit
b640688995
@ -0,0 +1,150 @@
|
||||
import pytest
|
||||
from swarms.swarms.model_parallizer import ModelParallelizer
|
||||
from swarms.models import (
|
||||
HuggingfaceLLM,
|
||||
Mixtral,
|
||||
Dalle3,
|
||||
GPT4VisionAPI,
|
||||
ZeroscopeTTV,
|
||||
)
|
||||
|
||||
# Initialize the models
|
||||
custom_config = {
|
||||
"quantize": True,
|
||||
"quantization_config": {"load_in_4bit": True},
|
||||
"verbose": True,
|
||||
}
|
||||
huggingface_llm = HuggingfaceLLM(
|
||||
model_id="NousResearch/Nous-Hermes-2-Vision-Alpha",
|
||||
**custom_config,
|
||||
)
|
||||
mixtral = Mixtral(load_in_4bit=True, use_flash_attention_2=True)
|
||||
dalle3 = Dalle3(quality="high")
|
||||
gpt4_vision_api = GPT4VisionAPI(max_tokens=1000)
|
||||
zeroscope_ttv = ZeroscopeTTV()
|
||||
|
||||
|
||||
def test_init():
|
||||
mp = ModelParallelizer(
|
||||
[
|
||||
huggingface_llm,
|
||||
mixtral,
|
||||
dalle3,
|
||||
gpt4_vision_api,
|
||||
zeroscope_ttv,
|
||||
]
|
||||
)
|
||||
assert isinstance(mp, ModelParallelizer)
|
||||
|
||||
|
||||
def test_run():
|
||||
mp = ModelParallelizer([huggingface_llm])
|
||||
result = mp.run(
|
||||
"Create a list of known biggest risks of structural collapse"
|
||||
" with references"
|
||||
)
|
||||
assert isinstance(result, str)
|
||||
|
||||
|
||||
def test_run_all():
|
||||
mp = ModelParallelizer(
|
||||
[
|
||||
huggingface_llm,
|
||||
mixtral,
|
||||
dalle3,
|
||||
gpt4_vision_api,
|
||||
zeroscope_ttv,
|
||||
]
|
||||
)
|
||||
result = mp.run_all(
|
||||
"Create a list of known biggest risks of structural collapse"
|
||||
" with references"
|
||||
)
|
||||
assert isinstance(result, list)
|
||||
assert len(result) == 5
|
||||
|
||||
|
||||
def test_add_llm():
|
||||
mp = ModelParallelizer([huggingface_llm])
|
||||
mp.add_llm(mixtral)
|
||||
assert len(mp.llms) == 2
|
||||
|
||||
|
||||
def test_remove_llm():
|
||||
mp = ModelParallelizer([huggingface_llm, mixtral])
|
||||
mp.remove_llm(mixtral)
|
||||
assert len(mp.llms) == 1
|
||||
|
||||
|
||||
def test_save_responses_to_file(tmp_path):
|
||||
mp = ModelParallelizer([huggingface_llm])
|
||||
mp.run(
|
||||
"Create a list of known biggest risks of structural collapse"
|
||||
" with references"
|
||||
)
|
||||
file = tmp_path / "responses.txt"
|
||||
mp.save_responses_to_file(file)
|
||||
assert file.read_text() != ""
|
||||
|
||||
|
||||
def test_get_task_history():
|
||||
mp = ModelParallelizer([huggingface_llm])
|
||||
mp.run(
|
||||
"Create a list of known biggest risks of structural collapse"
|
||||
" with references"
|
||||
)
|
||||
assert mp.get_task_history() == [
|
||||
"Create a list of known biggest risks of structural collapse"
|
||||
" with references"
|
||||
]
|
||||
|
||||
|
||||
def test_summary(capsys):
|
||||
mp = ModelParallelizer([huggingface_llm])
|
||||
mp.run(
|
||||
"Create a list of known biggest risks of structural collapse"
|
||||
" with references"
|
||||
)
|
||||
mp.summary()
|
||||
captured = capsys.readouterr()
|
||||
assert "Tasks History:" in captured.out
|
||||
|
||||
|
||||
def test_enable_load_balancing():
|
||||
mp = ModelParallelizer([huggingface_llm])
|
||||
mp.enable_load_balancing()
|
||||
assert mp.load_balancing == True
|
||||
|
||||
|
||||
def test_disable_load_balancing():
|
||||
mp = ModelParallelizer([huggingface_llm])
|
||||
mp.disable_load_balancing()
|
||||
assert mp.load_balancing == False
|
||||
|
||||
|
||||
def test_concurrent_run():
|
||||
mp = ModelParallelizer([huggingface_llm, mixtral])
|
||||
result = mp.concurrent_run(
|
||||
"Create a list of known biggest risks of structural collapse"
|
||||
" with references"
|
||||
)
|
||||
assert isinstance(result, list)
|
||||
assert len(result) == 2
|
||||
|
||||
|
||||
def test_concurrent_run_no_task():
|
||||
mp = ModelParallelizer([huggingface_llm])
|
||||
with pytest.raises(TypeError):
|
||||
mp.concurrent_run()
|
||||
|
||||
|
||||
def test_concurrent_run_non_string_task():
|
||||
mp = ModelParallelizer([huggingface_llm])
|
||||
with pytest.raises(TypeError):
|
||||
mp.concurrent_run(123)
|
||||
|
||||
|
||||
def test_concurrent_run_empty_task():
|
||||
mp = ModelParallelizer([huggingface_llm])
|
||||
result = mp.concurrent_run("")
|
||||
assert result == [""]
|
Loading…
Reference in new issue