From b70e2ac8bc9fda4832d9a34458c6cde59e08a029 Mon Sep 17 00:00:00 2001 From: Kye Date: Sat, 9 Dec 2023 19:35:20 -0800 Subject: [PATCH] [TESTS] --- swarms/models/vllm.py | 0 swarms/utils/ray_traceback_wrapper.py | 91 +++++++++++++++++ tests/models/test_vllm.py | 141 ++++++++++++++++++++++++++ 3 files changed, 232 insertions(+) create mode 100644 swarms/models/vllm.py create mode 100644 swarms/utils/ray_traceback_wrapper.py create mode 100644 tests/models/test_vllm.py diff --git a/swarms/models/vllm.py b/swarms/models/vllm.py new file mode 100644 index 00000000..e69de29b diff --git a/swarms/utils/ray_traceback_wrapper.py b/swarms/utils/ray_traceback_wrapper.py new file mode 100644 index 00000000..650356cc --- /dev/null +++ b/swarms/utils/ray_traceback_wrapper.py @@ -0,0 +1,91 @@ +import subprocess +from typing import Optional, Tuple, List +from swarms.models.base_llm import AbstractLLM + +try: + from vllm import LLM, SamplingParams +except ImportError as error: + print(f"[ERROR] [vLLM] {error}") + subprocess.run(["pip", "install", "vllm"]) + raise error + + +class vLLM(AbstractLLM): + """vLLM model + + + Args: + model_name (str, optional): _description_. Defaults to "facebook/opt-13b". + tensor_parallel_size (int, optional): _description_. Defaults to 4. + trust_remote_code (bool, optional): _description_. Defaults to False. + revision (str, optional): _description_. Defaults to None. + temperature (float, optional): _description_. Defaults to 0.5. + top_p (float, optional): _description_. Defaults to 0.95. + *args: _description_. + **kwargs: _description_. + + Methods: + run: run the vLLM model + + Raises: + error: _description_ + + Examples: + >>> from swarms.models.vllm import vLLM + >>> vllm = vLLM() + >>> vllm.run("Hello world!") + + + """ + + def __init__( + self, + model_name: str = "facebook/opt-13b", + tensor_parallel_size: int = 4, + trust_remote_code: bool = False, + revision: str = None, + temperature: float = 0.5, + top_p: float = 0.95, + *args, + **kwargs, + ): + super().__init__(*args, **kwargs) + self.model_name = model_name + self.tensor_parallel_size = tensor_parallel_size + self.trust_remote_code = trust_remote_code + self.revision = revision + self.top_p = top_p + + # LLM model + self.llm = LLM( + model_name=self.model_name, + tensor_parallel_size=self.tensor_parallel_size, + trust_remote_code=self.trust_remote_code, + revision=self.revision, + *args, + **kwargs, + ) + + # Sampling parameters + self.sampling_params = SamplingParams( + temperature=temperature, top_p=top_p, *args, **kwargs + ) + + def run(self, task: str = None, *args, **kwargs): + """Run the vLLM model + + Args: + task (str, optional): _description_. Defaults to None. + + Raises: + error: _description_ + + Returns: + _type_: _description_ + """ + try: + outputs = self.llm.generate(task, self.sampling_params) + return outputs + except Exception as error: + print(f"[ERROR] [vLLM] [run] {error}") + raise error diff --git a/tests/models/test_vllm.py b/tests/models/test_vllm.py new file mode 100644 index 00000000..d15a13b9 --- /dev/null +++ b/tests/models/test_vllm.py @@ -0,0 +1,141 @@ +import pytest +from swarms.models.vllm import vLLM + + +# Fixture for initializing vLLM +@pytest.fixture +def vllm_instance(): + return vLLM() + + +# Test the default initialization of vLLM +def test_vllm_default_init(vllm_instance): + assert isinstance(vllm_instance, vLLM) + assert vllm_instance.model_name == "facebook/opt-13b" + assert vllm_instance.tensor_parallel_size == 4 + assert not vllm_instance.trust_remote_code + assert vllm_instance.revision is None + assert vllm_instance.temperature == 0.5 + assert vllm_instance.top_p == 0.95 + + +# Test custom initialization of vLLM +def test_vllm_custom_init(): + vllm_instance = vLLM( + model_name="custom_model", + tensor_parallel_size=8, + trust_remote_code=True, + revision="123", + temperature=0.7, + top_p=0.9, + ) + assert isinstance(vllm_instance, vLLM) + assert vllm_instance.model_name == "custom_model" + assert vllm_instance.tensor_parallel_size == 8 + assert vllm_instance.trust_remote_code + assert vllm_instance.revision == "123" + assert vllm_instance.temperature == 0.7 + assert vllm_instance.top_p == 0.9 + + +# Test the run method of vLLM +def test_vllm_run(vllm_instance): + task = "Hello, vLLM!" + result = vllm_instance.run(task) + assert isinstance(result, str) + assert len(result) > 0 + + +# Test run method with different temperature and top_p values +@pytest.mark.parametrize( + "temperature, top_p", [(0.2, 0.8), (0.8, 0.2)] +) +def test_vllm_run_with_params(vllm_instance, temperature, top_p): + task = "Temperature and Top-P Test" + result = vllm_instance.run( + task, temperature=temperature, top_p=top_p + ) + assert isinstance(result, str) + assert len(result) > 0 + + +# Test run method with a specific model revision +def test_vllm_run_with_revision(vllm_instance): + task = "Specific Model Revision Test" + result = vllm_instance.run(task, revision="abc123") + assert isinstance(result, str) + assert len(result) > 0 + + +# Test run method with a specific model name +def test_vllm_run_with_custom_model(vllm_instance): + task = "Custom Model Test" + custom_model_name = "my_custom_model" + result = vllm_instance.run(task, model_name=custom_model_name) + assert isinstance(result, str) + assert len(result) > 0 + assert vllm_instance.model_name == custom_model_name + + +# Test run method with invalid task input +def test_vllm_run_invalid_task(vllm_instance): + invalid_task = None + with pytest.raises(ValueError): + vllm_instance.run(invalid_task) + + +# Test run method with a very high temperature value +def test_vllm_run_high_temperature(vllm_instance): + task = "High Temperature Test" + high_temperature = 10.0 + result = vllm_instance.run(task, temperature=high_temperature) + assert isinstance(result, str) + assert len(result) > 0 + + +# Test run method with a very low top_p value +def test_vllm_run_low_top_p(vllm_instance): + task = "Low Top-P Test" + low_top_p = 0.01 + result = vllm_instance.run(task, top_p=low_top_p) + assert isinstance(result, str) + assert len(result) > 0 + + +# Test run method with an empty task +def test_vllm_run_empty_task(vllm_instance): + empty_task = "" + result = vllm_instance.run(empty_task) + assert isinstance(result, str) + assert len(result) == 0 + + +# Test initialization with invalid parameters +def test_vllm_invalid_init(): + with pytest.raises(ValueError): + vllm_instance = vLLM( + model_name=None, + tensor_parallel_size=-1, + trust_remote_code="invalid", + revision=123, + temperature=-0.1, + top_p=1.1, + ) + + +# Test running vLLM with a large number of parallel heads +def test_vllm_large_parallel_heads(): + vllm_instance = vLLM(tensor_parallel_size=16) + task = "Large Parallel Heads Test" + result = vllm_instance.run(task) + assert isinstance(result, str) + assert len(result) > 0 + + +# Test running vLLM with trust_remote_code set to True +def test_vllm_trust_remote_code(): + vllm_instance = vLLM(trust_remote_code=True) + task = "Trust Remote Code Test" + result = vllm_instance.run(task) + assert isinstance(result, str) + assert len(result) > 0