Merge pull request #800 from ascender1729/feature/vllm-support

Add vLLM support with wrapper and example
pull/804/merge
Kye Gomez 2 weeks ago committed by GitHub
commit 09159905b0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -0,0 +1,44 @@
from swarms.utils.vllm_wrapper import VLLMWrapper
def main():
# Initialize the vLLM wrapper with a model
# Note: You'll need to have the model downloaded or specify a HuggingFace model ID
llm = VLLMWrapper(
model_name="meta-llama/Llama-2-7b-chat-hf", # Replace with your model path or HF model ID
temperature=0.7,
max_tokens=1000,
)
# Example task
task = "What are the benefits of using vLLM for inference?"
# Run inference
response = llm.run(task)
print("Response:", response)
# Example with system prompt
llm_with_system = VLLMWrapper(
model_name="meta-llama/Llama-2-7b-chat-hf", # Replace with your model path or HF model ID
system_prompt="You are a helpful AI assistant that provides concise answers.",
temperature=0.7,
)
# Run inference with system prompt
response = llm_with_system.run(task)
print("\nResponse with system prompt:", response)
# Example with batched inference
tasks = [
"What is vLLM?",
"How does vLLM improve inference speed?",
"What are the main features of vLLM?"
]
responses = llm.batched_run(tasks, batch_size=2)
print("\nBatched responses:")
for task, response in zip(tasks, responses):
print(f"\nTask: {task}")
print(f"Response: {response}")
if __name__ == "__main__":
main()

@ -1,4 +1,3 @@
torch>=2.1.1,<3.0
transformers>=4.39.0,<4.51.0
asyncio>=3.4.3,<4.0
@ -23,3 +22,4 @@ pytest>=8.1.1
networkx
aiofiles
httpx
vllm>=0.2.0

@ -0,0 +1,138 @@
from typing import List, Optional, Dict, Any
from loguru import logger
try:
from vllm import LLM, SamplingParams
except ImportError:
import subprocess
import sys
print("Installing vllm")
subprocess.check_call([sys.executable, "-m", "pip", "install", "-U", "vllm"])
print("vllm installed")
from vllm import LLM, SamplingParams
class VLLMWrapper:
"""
A wrapper class for vLLM that provides a similar interface to LiteLLM.
This class handles model initialization and inference using vLLM.
"""
def __init__(
self,
model_name: str = "meta-llama/Llama-2-7b-chat-hf",
system_prompt: Optional[str] = None,
stream: bool = False,
temperature: float = 0.5,
max_tokens: int = 4000,
max_completion_tokens: int = 4000,
tools_list_dictionary: Optional[List[Dict[str, Any]]] = None,
tool_choice: str = "auto",
parallel_tool_calls: bool = False,
*args,
**kwargs,
):
"""
Initialize the vLLM wrapper with the given parameters.
Args:
model_name (str): The name of the model to use. Defaults to "meta-llama/Llama-2-7b-chat-hf".
system_prompt (str, optional): The system prompt to use. Defaults to None.
stream (bool): Whether to stream the output. Defaults to False.
temperature (float): The temperature for sampling. Defaults to 0.5.
max_tokens (int): The maximum number of tokens to generate. Defaults to 4000.
max_completion_tokens (int): The maximum number of completion tokens. Defaults to 4000.
tools_list_dictionary (List[Dict[str, Any]], optional): List of available tools. Defaults to None.
tool_choice (str): How to choose tools. Defaults to "auto".
parallel_tool_calls (bool): Whether to allow parallel tool calls. Defaults to False.
"""
self.model_name = model_name
self.system_prompt = system_prompt
self.stream = stream
self.temperature = temperature
self.max_tokens = max_tokens
self.max_completion_tokens = max_completion_tokens
self.tools_list_dictionary = tools_list_dictionary
self.tool_choice = tool_choice
self.parallel_tool_calls = parallel_tool_calls
# Initialize vLLM
self.llm = LLM(model=model_name, **kwargs)
self.sampling_params = SamplingParams(
temperature=temperature,
max_tokens=max_tokens,
)
def _prepare_prompt(self, task: str) -> str:
"""
Prepare the prompt for the given task.
Args:
task (str): The task to prepare the prompt for.
Returns:
str: The prepared prompt.
"""
if self.system_prompt:
return f"{self.system_prompt}\n\nUser: {task}\nAssistant:"
return f"User: {task}\nAssistant:"
def run(self, task: str, *args, **kwargs) -> str:
"""
Run the model for the given task.
Args:
task (str): The task to run the model for.
*args: Additional positional arguments.
**kwargs: Additional keyword arguments.
Returns:
str: The model's response.
"""
try:
prompt = self._prepare_prompt(task)
outputs = self.llm.generate(prompt, self.sampling_params)
response = outputs[0].outputs[0].text.strip()
return response
except Exception as error:
logger.error(f"Error in VLLMWrapper: {error}")
raise error
def __call__(self, task: str, *args, **kwargs) -> str:
"""
Call the model for the given task.
Args:
task (str): The task to run the model for.
*args: Additional positional arguments.
**kwargs: Additional keyword arguments.
Returns:
str: The model's response.
"""
return self.run(task, *args, **kwargs)
def batched_run(self, tasks: List[str], batch_size: int = 10) -> List[str]:
"""
Run the model for multiple tasks in batches.
Args:
tasks (List[str]): List of tasks to run.
batch_size (int): Size of each batch. Defaults to 10.
Returns:
List[str]: List of model responses.
"""
logger.info(f"Running tasks in batches of size {batch_size}. Total tasks: {len(tasks)}")
results = []
for i in range(0, len(tasks), batch_size):
batch = tasks[i:i + batch_size]
for task in batch:
logger.info(f"Running task: {task}")
results.append(self.run(task))
logger.info("Completed all tasks.")
return results
Loading…
Cancel
Save