parent e8f161beea
commit adb6930439
@@ -1,46 +0,0 @@
from swarms.utils.vllm_wrapper import VLLMWrapper


def main():
    # Initialize the vLLM wrapper with a model
    # Note: You'll need to have the model downloaded or specify a HuggingFace model ID
    llm = VLLMWrapper(
        model_name="meta-llama/Llama-2-7b-chat-hf",  # Replace with your model path or HF model ID
        temperature=0.7,
        max_tokens=1000,
    )

    # Example task
    task = "What are the benefits of using vLLM for inference?"

    # Run inference
    response = llm.run(task)
    print("Response:", response)

    # Example with system prompt
    llm_with_system = VLLMWrapper(
        model_name="meta-llama/Llama-2-7b-chat-hf",  # Replace with your model path or HF model ID
        system_prompt="You are a helpful AI assistant that provides concise answers.",
        temperature=0.7,
    )

    # Run inference with system prompt
    response = llm_with_system.run(task)
    print("\nResponse with system prompt:", response)

    # Example with batched inference
    tasks = [
        "What is vLLM?",
        "How does vLLM improve inference speed?",
        "What are the main features of vLLM?",
    ]

    responses = llm.batched_run(tasks, batch_size=2)
    print("\nBatched responses:")
    for task, response in zip(tasks, responses):
        print(f"\nTask: {task}")
        print(f"Response: {response}")


if __name__ == "__main__":
    main()
@@ -1,148 +0,0 @@
import concurrent.futures
import os
from typing import Any

from loguru import logger

try:
    from vllm import LLM, SamplingParams
except ImportError:
    import subprocess
    import sys

    print("Installing vllm")
    subprocess.check_call(
        [sys.executable, "-m", "pip", "install", "-U", "vllm"]
    )
    print("vllm installed")
    from vllm import LLM, SamplingParams


class VLLMWrapper:
    """
    A wrapper class for vLLM that provides a similar interface to LiteLLM.
    This class handles model initialization and inference using vLLM.
    """

    def __init__(
        self,
        model_name: str = "meta-llama/Llama-2-7b-chat-hf",
        system_prompt: str | None = None,
        stream: bool = False,
        temperature: float = 0.5,
        max_tokens: int = 4000,
        max_completion_tokens: int = 4000,
        tools_list_dictionary: list[dict[str, Any]] | None = None,
        tool_choice: str = "auto",
        parallel_tool_calls: bool = False,
        *args,
        **kwargs,
    ):
        """
        Initialize the vLLM wrapper with the given parameters.

        Args:
            model_name (str): The name of the model to use. Defaults to "meta-llama/Llama-2-7b-chat-hf".
            system_prompt (str, optional): The system prompt to use. Defaults to None.
            stream (bool): Whether to stream the output. Defaults to False.
            temperature (float): The temperature for sampling. Defaults to 0.5.
            max_tokens (int): The maximum number of tokens to generate. Defaults to 4000.
            max_completion_tokens (int): The maximum number of completion tokens. Defaults to 4000.
            tools_list_dictionary (List[Dict[str, Any]], optional): List of available tools. Defaults to None.
            tool_choice (str): How to choose tools. Defaults to "auto".
            parallel_tool_calls (bool): Whether to allow parallel tool calls. Defaults to False.
        """
        self.model_name = model_name
        self.system_prompt = system_prompt
        self.stream = stream
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.max_completion_tokens = max_completion_tokens
        self.tools_list_dictionary = tools_list_dictionary
        self.tool_choice = tool_choice
        self.parallel_tool_calls = parallel_tool_calls

        # Initialize vLLM
        self.llm = LLM(model=model_name, **kwargs)
        self.sampling_params = SamplingParams(
            temperature=temperature,
            max_tokens=max_tokens,
        )

    def _prepare_prompt(self, task: str) -> str:
        """
        Prepare the prompt for the given task.

        Args:
            task (str): The task to prepare the prompt for.

        Returns:
            str: The prepared prompt.
        """
        if self.system_prompt:
            return f"{self.system_prompt}\n\nUser: {task}\nAssistant:"
        return f"User: {task}\nAssistant:"

    def run(self, task: str, *args, **kwargs) -> str:
        """
        Run the model for the given task.

        Args:
            task (str): The task to run the model for.
            *args: Additional positional arguments.
            **kwargs: Additional keyword arguments.

        Returns:
            str: The model's response.
        """
        try:
            prompt = self._prepare_prompt(task)

            outputs = self.llm.generate(prompt, self.sampling_params)
            response = outputs[0].outputs[0].text.strip()

            return response

        except Exception as error:
            logger.error(f"Error in VLLMWrapper: {error}")
            raise error

    def __call__(self, task: str, *args, **kwargs) -> str:
        """
        Call the model for the given task.

        Args:
            task (str): The task to run the model for.
            *args: Additional positional arguments.
            **kwargs: Additional keyword arguments.

        Returns:
            str: The model's response.
        """
        return self.run(task, *args, **kwargs)

    def batched_run(
        self, tasks: list[str], batch_size: int = 10
    ) -> list[str]:
        """
        Run the model for multiple tasks in batches.

        Args:
            tasks (List[str]): List of tasks to run.
            batch_size (int): Size of each batch. Defaults to 10.

        Returns:
            List[str]: List of model responses.
        """
        # Calculate the worker count based on 95% of available CPU cores
        num_workers = max(1, int((os.cpu_count() or 1) * 0.95))
        with concurrent.futures.ThreadPoolExecutor(
            max_workers=num_workers
        ) as executor:
            # Note: all tasks are submitted at once; batch_size is part of the
            # signature but does not currently limit concurrency.
            futures = [
                executor.submit(self.run, task) for task in tasks
            ]
            # Collect results in submission order so each response lines up
            # with its task (as_completed would yield them out of order).
            return [future.result() for future in futures]
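A minimal sketch of an alternative batching path, not part of the committed code: vLLM's LLM.generate also accepts a list of prompts and returns one output per prompt in input order, so the thread pool above can be skipped. The helper name batched_generate below is hypothetical and assumes the prompts have already been formatted (for example via _prepare_prompt).

from vllm import LLM, SamplingParams


def batched_generate(
    llm: LLM, prompts: list[str], sampling_params: SamplingParams
) -> list[str]:
    # Hypothetical helper: vLLM schedules the whole prompt list as one batch
    # and returns one RequestOutput per prompt, in the same order as the input.
    outputs = llm.generate(prompts, sampling_params)
    return [output.outputs[0].text.strip() for output in outputs]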
File diff suppressed because it is too large