diff --git a/examples/vllm_example.py b/examples/vllm_example.py
new file mode 100644
index 00000000..231a68fc
--- /dev/null
+++ b/examples/vllm_example.py
@@ -0,0 +1,44 @@
+from swarms.utils.vllm_wrapper import VLLMWrapper
+
+def main():
+    # Initialize the vLLM wrapper with a model
+    # Note: You'll need to have the model downloaded or specify a HuggingFace model ID
+    llm = VLLMWrapper(
+        model_name="meta-llama/Llama-2-7b-chat-hf",  # Replace with your model path or HF model ID
+        temperature=0.7,
+        max_tokens=1000,
+    )
+
+    # Example task
+    task = "What are the benefits of using vLLM for inference?"
+
+    # Run inference
+    response = llm.run(task)
+    print("Response:", response)
+
+    # Example with system prompt
+    llm_with_system = VLLMWrapper(
+        model_name="meta-llama/Llama-2-7b-chat-hf",  # Replace with your model path or HF model ID
+        system_prompt="You are a helpful AI assistant that provides concise answers.",
+        temperature=0.7,
+    )
+
+    # Run inference with system prompt
+    response = llm_with_system.run(task)
+    print("\nResponse with system prompt:", response)
+
+    # Example with batched inference
+    tasks = [
+        "What is vLLM?",
+        "How does vLLM improve inference speed?",
+        "What are the main features of vLLM?"
+    ]
+
+    responses = llm.batched_run(tasks, batch_size=2)
+    print("\nBatched responses:")
+    for task, response in zip(tasks, responses):
+        print(f"\nTask: {task}")
+        print(f"Response: {response}")
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index 603fe0d0..fe442e60 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,3 @@
-
 torch>=2.1.1,<3.0
 transformers>=4.39.0,<4.51.0
 asyncio>=3.4.3,<4.0
@@ -23,3 +22,4 @@ pytest>=8.1.1
 networkx
 aiofiles
 httpx
+vllm>=0.2.0
diff --git a/swarms/utils/vllm_wrapper.py b/swarms/utils/vllm_wrapper.py
new file mode 100644
index 00000000..322ce1ad
--- /dev/null
+++ b/swarms/utils/vllm_wrapper.py
@@ -0,0 +1,138 @@
+from typing import List, Optional, Dict, Any
+from loguru import logger
+
+try:
+    from vllm import LLM, SamplingParams
+except ImportError:
+    import subprocess
+    import sys
+    print("Installing vllm")
+    subprocess.check_call([sys.executable, "-m", "pip", "install", "-U", "vllm"])
+    print("vllm installed")
+    from vllm import LLM, SamplingParams
+
+class VLLMWrapper:
+    """
+    A wrapper class for vLLM that provides a similar interface to LiteLLM.
+    This class handles model initialization and inference using vLLM.
+    """
+
+    def __init__(
+        self,
+        model_name: str = "meta-llama/Llama-2-7b-chat-hf",
+        system_prompt: Optional[str] = None,
+        stream: bool = False,
+        temperature: float = 0.5,
+        max_tokens: int = 4000,
+        max_completion_tokens: int = 4000,
+        tools_list_dictionary: Optional[List[Dict[str, Any]]] = None,
+        tool_choice: str = "auto",
+        parallel_tool_calls: bool = False,
+        *args,
+        **kwargs,
+    ):
+        """
+        Initialize the vLLM wrapper with the given parameters.
+
+        Args:
+            model_name (str): The name of the model to use. Defaults to "meta-llama/Llama-2-7b-chat-hf".
+            system_prompt (str, optional): The system prompt to use. Defaults to None.
+            stream (bool): Whether to stream the output. Defaults to False.
+            temperature (float): The temperature for sampling. Defaults to 0.5.
+            max_tokens (int): The maximum number of tokens to generate. Defaults to 4000.
+            max_completion_tokens (int): The maximum number of completion tokens. Defaults to 4000.
+            tools_list_dictionary (List[Dict[str, Any]], optional): List of available tools. Defaults to None.
+            tool_choice (str): How to choose tools. Defaults to "auto".
+            parallel_tool_calls (bool): Whether to allow parallel tool calls. Defaults to False.
+        """
+        self.model_name = model_name
+        self.system_prompt = system_prompt
+        self.stream = stream
+        self.temperature = temperature
+        self.max_tokens = max_tokens
+        self.max_completion_tokens = max_completion_tokens
+        self.tools_list_dictionary = tools_list_dictionary
+        self.tool_choice = tool_choice
+        self.parallel_tool_calls = parallel_tool_calls
+
+        # Initialize vLLM
+        self.llm = LLM(model=model_name, **kwargs)
+        self.sampling_params = SamplingParams(
+            temperature=temperature,
+            max_tokens=max_tokens,
+        )
+
+    def _prepare_prompt(self, task: str) -> str:
+        """
+        Prepare the prompt for the given task.
+
+        Args:
+            task (str): The task to prepare the prompt for.
+
+        Returns:
+            str: The prepared prompt.
+        """
+        if self.system_prompt:
+            return f"{self.system_prompt}\n\nUser: {task}\nAssistant:"
+        return f"User: {task}\nAssistant:"
+
+    def run(self, task: str, *args, **kwargs) -> str:
+        """
+        Run the model for the given task.
+
+        Args:
+            task (str): The task to run the model for.
+            *args: Additional positional arguments.
+            **kwargs: Additional keyword arguments.
+
+        Returns:
+            str: The model's response.
+        """
+        try:
+            prompt = self._prepare_prompt(task)
+
+            outputs = self.llm.generate(prompt, self.sampling_params)
+            response = outputs[0].outputs[0].text.strip()
+
+            return response
+
+        except Exception as error:
+            logger.error(f"Error in VLLMWrapper: {error}")
+            raise error
+
+    def __call__(self, task: str, *args, **kwargs) -> str:
+        """
+        Call the model for the given task.
+
+        Args:
+            task (str): The task to run the model for.
+            *args: Additional positional arguments.
+            **kwargs: Additional keyword arguments.
+
+        Returns:
+            str: The model's response.
+        """
+        return self.run(task, *args, **kwargs)
+
+    def batched_run(self, tasks: List[str], batch_size: int = 10) -> List[str]:
+        """
+        Run the model for multiple tasks in batches.
+
+        Args:
+            tasks (List[str]): List of tasks to run.
+            batch_size (int): Size of each batch. Defaults to 10.
+
+        Returns:
+            List[str]: List of model responses.
+        """
+        logger.info(f"Running tasks in batches of size {batch_size}. Total tasks: {len(tasks)}")
+        results = []
+
+        for i in range(0, len(tasks), batch_size):
+            batch = tasks[i:i + batch_size]
+            for task in batch:
+                logger.info(f"Running task: {task}")
+                results.append(self.run(task))
+
+        logger.info("Completed all tasks.")
+        return results
\ No newline at end of file
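Note on batched_run: as written it still submits prompts to vLLM one at a time, so batch_size only controls how the loop is chunked and logged. vLLM's LLM.generate also accepts a list of prompts and schedules them together, which is where most of its throughput advantage comes from. The sketch below is a possible alternative, not part of this patch; batched_run_native is a hypothetical name, and it reuses the wrapper's private _prepare_prompt helper and its llm / sampling_params attributes as defined above.

from typing import List

def batched_run_native(wrapper: "VLLMWrapper", tasks: List[str]) -> List[str]:
    # Hypothetical alternative to VLLMWrapper.batched_run: hand every prompt
    # to vLLM in a single generate() call and let its scheduler batch them,
    # instead of looping over tasks one run() call at a time.
    prompts = [wrapper._prepare_prompt(task) for task in tasks]
    # llm.generate accepts a list of prompt strings and returns one
    # RequestOutput per prompt, in the same order as the inputs.
    outputs = wrapper.llm.generate(prompts, wrapper.sampling_params)
    return [output.outputs[0].text.strip() for output in outputs]

Wired into the example above, responses = batched_run_native(llm, tasks) would return the same list shape as llm.batched_run(tasks) while leaving the batching to vLLM itself.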