Implement lazy initialization for VLLMWrapper

pull/925/head^2
Pavan Kumar 1 month ago
parent 2770b8c7bf
commit 7f3a854bb2

File diff suppressed because it is too large
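
For orientation, here is a minimal usage sketch of the behavior this commit introduces, based on the `VLLMWrapper` constructor and the `run()` / `batched_run()` methods shown in the hunks below; it illustrates the intended lazy-loading flow rather than being part of the diff itself.

```python
from swarms.utils.vllm_wrapper import VLLMWrapper

# Constructing the wrapper is now cheap: no vLLM engine is built and no
# GPU memory is allocated here; the configuration is only stored.
vllm = VLLMWrapper(
    model_name="meta-llama/Llama-2-7b-chat-hf",
    system_prompt="You are a helpful assistant.",
)

# The first call to run() (or batched_run()) triggers _ensure_initialized(),
# which builds the vLLM engine and sampling params before generating.
print(vllm.run("What is the capital of France?"))

# Later calls reuse the already-loaded engine.
print(vllm.run("What is the capital of Germany?"))
```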

@@ -47,13 +47,15 @@ Here's a complete example of setting up the stock analysis swarm:
 from swarms import Agent, ConcurrentWorkflow
 from swarms.utils.vllm_wrapper import VLLMWrapper
 
-# Initialize the VLLM wrapper
+# Initialize the VLLM wrapper (model loads when used)
 vllm = VLLMWrapper(
     model_name="meta-llama/Llama-2-7b-chat-hf",
     system_prompt="You are a helpful assistant.",
 )
 ```
+
+The model is initialized when `run()` or `batched_run()` is first called.
 
 !!! note "Model Selection"
     The example uses Llama-2-7b-chat, but you can use any VLLM-compatible model. Make sure you have the necessary permissions and resources to run your chosen model.

@@ -28,7 +28,7 @@ Here's a simple example of how to use vLLM with Swarms:
 ```python title="basic_usage.py"
 from swarms.utils.vllm_wrapper import VLLMWrapper
 
-# Initialize the vLLM wrapper
+# Initialize the vLLM wrapper (model loads on first use)
 vllm = VLLMWrapper(
     model_name="meta-llama/Llama-2-7b-chat-hf",
     system_prompt="You are a helpful assistant.",
@@ -41,6 +41,8 @@ response = vllm.run("What is the capital of France?")
 print(response)
 ```
 
+The first call to `run()` lazily loads the model weights.
+
 ## VLLMWrapper Class
 
 !!! abstract "Class Overview"
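
The documentation line added above says the first `run()` call lazily loads the model weights. A rough way to observe this, assuming the same `VLLMWrapper` API as in the example (absolute timings depend on hardware and model size):

```python
import time

from swarms.utils.vllm_wrapper import VLLMWrapper

vllm = VLLMWrapper(
    model_name="meta-llama/Llama-2-7b-chat-hf",
    system_prompt="You are a helpful assistant.",
)

start = time.perf_counter()
vllm.run("What is the capital of France?")   # loads the model, then generates
print(f"first call:  {time.perf_counter() - start:.1f}s")

start = time.perf_counter()
vllm.run("What is the capital of Germany?")  # reuses the loaded engine
print(f"second call: {time.perf_counter() - start:.1f}s")
```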

@@ -4,7 +4,7 @@ from dotenv import load_dotenv
 load_dotenv()
 
-# Initialize the VLLM wrapper
+# Initialize the VLLM wrapper (model loads lazily on first run)
 vllm = VLLMWrapper(
     model_name="meta-llama/Llama-2-7b-chat-hf",
     system_prompt="You are a helpful assistant.",

@@ -2,9 +2,9 @@ from swarms.utils.vllm_wrapper import VLLMWrapper
 
 def main():
-    # Initialize the vLLM wrapper with a model
-    # Note: You'll need to have the model downloaded or specify a HuggingFace model ID
+    # Initialize the vLLM wrapper.
+    # The actual model weights load lazily on the first call to `run()`.
     llm = VLLMWrapper(
         model_name="meta-llama/Llama-2-7b-chat-hf",  # Replace with your model path or HF model ID
         temperature=0.7,
         max_tokens=1000,
@@ -17,8 +17,8 @@ def main():
     response = llm.run(task)
     print("Response:", response)
 
-    # Example with system prompt
+    # Example with system prompt. Model initialization is still lazy.
     llm_with_system = VLLMWrapper(
         model_name="meta-llama/Llama-2-7b-chat-hf",  # Replace with your model path or HF model ID
         system_prompt="You are a helpful AI assistant that provides concise answers.",
         temperature=0.7,

@@ -61,12 +61,22 @@ class VLLMWrapper:
         self.tool_choice = tool_choice
         self.parallel_tool_calls = parallel_tool_calls
 
-        # Initialize vLLM
-        self.llm = LLM(model=model_name, **kwargs)
-        self.sampling_params = SamplingParams(
-            temperature=temperature,
-            max_tokens=max_tokens,
-        )
+        # store kwargs for later lazy initialization
+        self._llm_kwargs = kwargs
+
+        # Initialize attributes for lazy loading
+        self.llm = None
+        self.sampling_params = None
+
+    def _ensure_initialized(self):
+        """Lazily initialize the underlying vLLM objects if needed."""
+        if self.llm is None:
+            self.llm = LLM(model=self.model_name, **self._llm_kwargs)
+        if self.sampling_params is None:
+            self.sampling_params = SamplingParams(
+                temperature=self.temperature,
+                max_tokens=self.max_tokens,
+            )
 
     def _prepare_prompt(self, task: str) -> str:
         """
@@ -82,9 +92,9 @@ class VLLMWrapper:
             return f"{self.system_prompt}\n\nUser: {task}\nAssistant:"
         return f"User: {task}\nAssistant:"
 
     def run(self, task: str, *args, **kwargs) -> str:
         """
         Run the model for the given task.
 
         Args:
             task (str): The task to run the model for.
@@ -94,10 +104,11 @@ class VLLMWrapper:
         Returns:
             str: The model's response.
         """
         try:
+            self._ensure_initialized()
             prompt = self._prepare_prompt(task)
 
             outputs = self.llm.generate(prompt, self.sampling_params)
             response = outputs[0].outputs[0].text.strip()
             return response
@@ -120,9 +131,9 @@ class VLLMWrapper:
         """
         return self.run(task, *args, **kwargs)
 
     def batched_run(
         self, tasks: List[str], batch_size: int = 10
     ) -> List[str]:
         """
         Run the model for multiple tasks in batches.
@@ -133,6 +144,7 @@ class VLLMWrapper:
         Returns:
             List[str]: List of model responses.
         """
+        self._ensure_initialized()
         # Calculate the worker count based on 95% of available CPU cores
         num_workers = max(1, int((os.cpu_count() or 1) * 0.95))
         with concurrent.futures.ThreadPoolExecutor(
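
With this change, `batched_run()` calls `_ensure_initialized()` once before the thread pool is created, so the worker threads submit prompts against an engine that is already loaded instead of each triggering initialization. A small usage sketch, assuming the `batched_run(tasks, batch_size)` signature shown above:

```python
from swarms.utils.vllm_wrapper import VLLMWrapper

llm = VLLMWrapper(
    model_name="meta-llama/Llama-2-7b-chat-hf",
    system_prompt="You are a helpful assistant.",
)

tasks = [
    "Summarize the plot of Hamlet in one sentence.",
    "What is the capital of France?",
    "Explain lazy initialization in one sentence.",
]

# The model is loaded once here, before the batch is dispatched to worker threads.
responses = llm.batched_run(tasks, batch_size=2)
for task, response in zip(tasks, responses):
    print(task, "->", response)
```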
