commit a3fce76d0e
@ -0,0 +1,82 @@
import importlib
import inspect
import pkgutil

# The package whose modules are scanned for model classes. This import is an
# assumption: the original code iterated over ``models.__path__`` without
# importing any ``models`` package, so point this at the package that actually
# holds the model modules.
from swarms import models as models_pkg


class ModelRegistry:
    """
    A registry for storing and querying models.

    Attributes:
        models (dict): A dictionary of model names and corresponding model classes.

    Methods:
        __init__(): Initializes the ModelRegistry object and retrieves all available models.
        _get_all_models(): Retrieves all available models from the models package.
        query(text): Queries the models based on the given text and returns a dictionary of matching models.
    """

    def __init__(self):
        self.models = self._get_all_models()

    def _get_all_models(self):
        """
        Retrieves all available models from the models package.

        Returns:
            dict: A dictionary of model names and corresponding model classes.
        """
        discovered = {}
        for _importer, modname, _ispkg in pkgutil.iter_modules(
            models_pkg.__path__
        ):
            # importlib.import_module replaces the deprecated
            # importer.find_module(modname).load_module(modname) pattern.
            module = importlib.import_module(
                f"{models_pkg.__name__}.{modname}"
            )
            for name, obj in inspect.getmembers(module):
                if inspect.isclass(obj):
                    discovered[name] = obj
        return discovered

    def query(self, text):
        """
        Queries the models based on the given text and returns a dictionary of matching models.

        Args:
            text (str): The text to search for in the model names.

        Returns:
            dict: A dictionary of matching model names and corresponding model classes.
        """
        return {
            name: model
            for name, model in self.models.items()
            if text in name
        }

    def run_model(
        self, model_name: str, task: str, img: str, *args, **kwargs
    ):
        """
        Runs the specified model for the given task and image.

        Args:
            model_name (str): The name of the model to run.
            task (str): The task to perform using the model.
            img (str): The image to process.
            *args: Additional positional arguments to pass to the model's run method.
            **kwargs: Additional keyword arguments to pass to the model's run method.

        Returns:
            The result of running the model.

        Raises:
            ValueError: If the specified model is not found in the model registry.
        """
        if model_name not in self.models:
            raise ValueError(f"Model {model_name} not found")

        # The registry stores classes, so instantiate before running.
        model = self.models[model_name]()

        # Run the model
        return model.run(task, img, *args, **kwargs)
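A minimal usage sketch for the registry above. The model name "Vilt" and the image path are placeholders, not classes or files known to exist in the repository; any class discovered in the scanned models package can be queried and run this way.

# Hypothetical usage; "Vilt" stands in for whatever model classes the
# package scan actually registers.
registry = ModelRegistry()

# query() returns every registered class whose name contains the text.
matches = registry.query("Vilt")
print(list(matches.keys()))

# run_model() instantiates the class and forwards task/img plus any extras.
caption = registry.run_model("Vilt", task="Describe the scene", img="images/1.jpg")
print(caption)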
@ -0,0 +1,299 @@
"""Sampling parameters for text generation."""
from enum import IntEnum
from functools import cached_property
from typing import Callable, List, Optional, Union

import torch

_SAMPLING_EPS = 1e-5


class SamplingType(IntEnum):
    GREEDY = 0
    RANDOM = 1
    BEAM = 2


LogitsProcessor = Callable[[List[int], torch.Tensor], torch.Tensor]
"""LogitsProcessor is a function that takes a list of previously generated
tokens and a tensor of the logits for the next token, and returns a modified
tensor of logits to sample from."""


class SamplingParams:
    """Sampling parameters for text generation.

    Overall, we follow the sampling parameters from the OpenAI text completion
    API (https://platform.openai.com/docs/api-reference/completions/create).
    In addition, we support beam search, which is not supported by OpenAI.

    Args:
        n: Number of output sequences to return for the given prompt.
        best_of: Number of output sequences that are generated from the prompt.
            From these `best_of` sequences, the top `n` sequences are returned.
            `best_of` must be greater than or equal to `n`. This is treated as
            the beam width when `use_beam_search` is True. By default, `best_of`
            is set to `n`.
        presence_penalty: Float that penalizes new tokens based on whether they
            appear in the generated text so far. Values > 0 encourage the model
            to use new tokens, while values < 0 encourage the model to repeat
            tokens.
        frequency_penalty: Float that penalizes new tokens based on their
            frequency in the generated text so far. Values > 0 encourage the
            model to use new tokens, while values < 0 encourage the model to
            repeat tokens.
        repetition_penalty: Float that penalizes new tokens based on whether
            they appear in the prompt and the generated text so far. Values > 1
            encourage the model to use new tokens, while values < 1 encourage
            the model to repeat tokens.
        temperature: Float that controls the randomness of the sampling. Lower
            values make the model more deterministic, while higher values make
            the model more random. Zero means greedy sampling.
        top_p: Float that controls the cumulative probability of the top tokens
            to consider. Must be in (0, 1]. Set to 1 to consider all tokens.
        top_k: Integer that controls the number of top tokens to consider. Set
            to -1 to consider all tokens.
        min_p: Float that represents the minimum probability for a token to be
            considered, relative to the probability of the most likely token.
            Must be in [0, 1]. Set to 0 to disable this.
        use_beam_search: Whether to use beam search instead of sampling.
        length_penalty: Float that penalizes sequences based on their length.
            Used in beam search.
        early_stopping: Controls the stopping condition for beam search. It
            accepts the following values: `True`, where the generation stops as
            soon as there are `best_of` complete candidates; `False`, where a
            heuristic is applied and the generation stops when it is very
            unlikely to find better candidates; `"never"`, where the beam search
            procedure only stops when there cannot be better candidates
            (canonical beam search algorithm).
        stop: List of strings that stop the generation when they are generated.
            The returned output will not contain the stop strings.
        stop_token_ids: List of tokens that stop the generation when they are
            generated. The returned output will contain the stop tokens unless
            the stop tokens are special tokens.
        include_stop_str_in_output: Whether to include the stop strings in the
            output text. Defaults to False.
        ignore_eos: Whether to ignore the EOS token and continue generating
            tokens after the EOS token is generated.
        max_tokens: Maximum number of tokens to generate per output sequence.
        logprobs: Number of log probabilities to return per output token.
            Note that the implementation follows the OpenAI API: The return
            result includes the log probabilities on the `logprobs` most likely
            tokens, as well as the chosen tokens. The API will always return the
            log probability of the sampled token, so there may be up to
            `logprobs+1` elements in the response.
        prompt_logprobs: Number of log probabilities to return per prompt token.
        skip_special_tokens: Whether to skip special tokens in the output.
        spaces_between_special_tokens: Whether to add spaces between special
            tokens in the output. Defaults to True.
        logits_processors: List of functions that modify logits based on
            previously generated tokens.
    """

    def __init__(
        self,
        n: int = 1,
        best_of: Optional[int] = None,
        presence_penalty: float = 0.0,
        frequency_penalty: float = 0.0,
        repetition_penalty: float = 1.0,
        temperature: float = 1.0,
        top_p: float = 1.0,
        top_k: int = -1,
        min_p: float = 0.0,
        use_beam_search: bool = False,
        length_penalty: float = 1.0,
        early_stopping: Union[bool, str] = False,
        stop: Optional[Union[str, List[str]]] = None,
        stop_token_ids: Optional[List[int]] = None,
        include_stop_str_in_output: bool = False,
        ignore_eos: bool = False,
        max_tokens: Optional[int] = 16,
        logprobs: Optional[int] = None,
        prompt_logprobs: Optional[int] = None,
        skip_special_tokens: bool = True,
        spaces_between_special_tokens: bool = True,
        logits_processors: Optional[List[LogitsProcessor]] = None,
    ) -> None:
        self.n = n
        self.best_of = best_of if best_of is not None else n
        self.presence_penalty = presence_penalty
        self.frequency_penalty = frequency_penalty
        self.repetition_penalty = repetition_penalty
        self.temperature = temperature
        self.top_p = top_p
        self.top_k = top_k
        self.min_p = min_p
        self.use_beam_search = use_beam_search
        self.length_penalty = length_penalty
        self.early_stopping = early_stopping
        if stop is None:
            self.stop = []
        elif isinstance(stop, str):
            self.stop = [stop]
        else:
            self.stop = list(stop)
        if stop_token_ids is None:
            self.stop_token_ids = []
        else:
            self.stop_token_ids = list(stop_token_ids)
        self.ignore_eos = ignore_eos
        self.max_tokens = max_tokens
        self.logprobs = logprobs
        self.prompt_logprobs = prompt_logprobs
        self.skip_special_tokens = skip_special_tokens
        self.spaces_between_special_tokens = spaces_between_special_tokens
        self.logits_processors = logits_processors
        self.include_stop_str_in_output = include_stop_str_in_output
        self._verify_args()
        if self.use_beam_search:
            self._verify_beam_search()
        else:
            self._verify_non_beam_search()
            if self.temperature < _SAMPLING_EPS:
                # Zero temperature means greedy sampling.
                self.top_p = 1.0
                self.top_k = -1
                self.min_p = 0.0
                self._verify_greedy_sampling()

    def _verify_args(self) -> None:
        if self.n < 1:
            raise ValueError(f"n must be at least 1, got {self.n}.")
        if self.best_of < self.n:
            raise ValueError(
                "best_of must be greater than or equal to n, "
                f"got n={self.n} and best_of={self.best_of}."
            )
        if not -2.0 <= self.presence_penalty <= 2.0:
            raise ValueError(
                "presence_penalty must be in [-2, 2], got "
                f"{self.presence_penalty}."
            )
        if not -2.0 <= self.frequency_penalty <= 2.0:
            raise ValueError(
                "frequency_penalty must be in [-2, 2], got "
                f"{self.frequency_penalty}."
            )
        if not 0.0 < self.repetition_penalty <= 2.0:
            raise ValueError(
                "repetition_penalty must be in (0, 2], got "
                f"{self.repetition_penalty}."
            )
        if self.temperature < 0.0:
            raise ValueError(
                "temperature must be non-negative, got"
                f" {self.temperature}."
            )
        if not 0.0 < self.top_p <= 1.0:
            raise ValueError(
                f"top_p must be in (0, 1], got {self.top_p}."
            )
        if self.top_k < -1 or self.top_k == 0:
            raise ValueError(
                "top_k must be -1 (disable), or at least 1, "
                f"got {self.top_k}."
            )
        if not 0.0 <= self.min_p <= 1.0:
            raise ValueError(
                f"min_p must be in [0, 1], got {self.min_p}."
            )
        if self.max_tokens is not None and self.max_tokens < 1:
            raise ValueError(
                "max_tokens must be at least 1, got"
                f" {self.max_tokens}."
            )
        if self.logprobs is not None and self.logprobs < 0:
            raise ValueError(
                f"logprobs must be non-negative, got {self.logprobs}."
            )
        if (
            self.prompt_logprobs is not None
            and self.prompt_logprobs < 0
        ):
            raise ValueError(
                "prompt_logprobs must be non-negative, got "
                f"{self.prompt_logprobs}."
            )

    def _verify_beam_search(self) -> None:
        if self.best_of == 1:
            raise ValueError(
                "best_of must be greater than 1 when using beam "
                f"search. Got {self.best_of}."
            )
        if self.temperature > _SAMPLING_EPS:
            raise ValueError(
                "temperature must be 0 when using beam search."
            )
        if self.top_p < 1.0 - _SAMPLING_EPS:
            raise ValueError(
                "top_p must be 1 when using beam search."
            )
        if self.top_k != -1:
            raise ValueError(
                "top_k must be -1 when using beam search."
            )
        if self.early_stopping not in [True, False, "never"]:
            raise ValueError(
                "early_stopping must be True, False, or 'never', "
                f"got {self.early_stopping}."
            )

    def _verify_non_beam_search(self) -> None:
        if self.early_stopping is not False:
            raise ValueError(
                "early_stopping is not effective and must be "
                "False when not using beam search."
            )
        if (
            self.length_penalty < 1.0 - _SAMPLING_EPS
            or self.length_penalty > 1.0 + _SAMPLING_EPS
        ):
            raise ValueError(
                "length_penalty is not effective and must be the "
                "default value of 1.0 when not using beam search."
            )

    def _verify_greedy_sampling(self) -> None:
        if self.best_of > 1:
            raise ValueError(
                "best_of must be 1 when using greedy sampling. "
                f"Got {self.best_of}."
            )

    @cached_property
    def sampling_type(self) -> SamplingType:
        if self.use_beam_search:
            return SamplingType.BEAM
        if self.temperature < _SAMPLING_EPS:
            return SamplingType.GREEDY
        return SamplingType.RANDOM

    def __repr__(self) -> str:
        return (
            f"SamplingParams(n={self.n}, "
            f"best_of={self.best_of}, "
            f"presence_penalty={self.presence_penalty}, "
            f"frequency_penalty={self.frequency_penalty}, "
            f"repetition_penalty={self.repetition_penalty}, "
            f"temperature={self.temperature}, "
            f"top_p={self.top_p}, "
            f"top_k={self.top_k}, "
            f"min_p={self.min_p}, "
            f"use_beam_search={self.use_beam_search}, "
            f"length_penalty={self.length_penalty}, "
            f"early_stopping={self.early_stopping}, "
            f"stop={self.stop}, "
            f"stop_token_ids={self.stop_token_ids}, "
            f"include_stop_str_in_output={self.include_stop_str_in_output}, "
            f"ignore_eos={self.ignore_eos}, "
            f"max_tokens={self.max_tokens}, "
            f"logprobs={self.logprobs}, "
            f"prompt_logprobs={self.prompt_logprobs}, "
            f"skip_special_tokens={self.skip_special_tokens}, "
            "spaces_between_special_tokens="
            f"{self.spaces_between_special_tokens})"
        )
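A short sketch of two configurations that satisfy the checks above: greedy decoding (zero temperature) and beam search. The parameter values are illustrative, not defaults taken from the repository.

# Greedy decoding: zero temperature clamps top_p/top_k/min_p and
# requires best_of == 1 (the default when n == 1).
greedy = SamplingParams(temperature=0.0, max_tokens=64)
assert greedy.sampling_type == SamplingType.GREEDY

# Beam search: temperature must be 0, top_p must be 1, top_k must be -1;
# best_of acts as the beam width and must exceed 1.
beam = SamplingParams(
    n=2,
    best_of=4,
    use_beam_search=True,
    temperature=0.0,
    max_tokens=64,
)
assert beam.sampling_type == SamplingType.BEAM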
@ -0,0 +1,80 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch


def continuous_tensor(
    inputs: torch.Tensor, seq_length: torch.LongTensor
):
    """Convert batched tensor to continuous tensor.

    Args:
        inputs (Tensor): batched tensor.
        seq_length (Tensor): length of each sequence.

    Returns:
        Tensor: continuous tensor.
    """
    assert inputs.dim() > 1
    if inputs.size(1) == 1:
        return inputs.reshape(1, -1)

    inputs = [inp[:slen] for inp, slen in zip(inputs, seq_length)]

    inputs = torch.cat(inputs).unsqueeze(0)
    return inputs


def batch_tensor(inputs: torch.Tensor, seq_length: torch.LongTensor):
    """Convert continuous tensor to batched tensor.

    Args:
        inputs (Tensor): continuous tensor.
        seq_length (Tensor): length of each sequence.

    Returns:
        Tensor: batched tensor.
    """
    from torch.nn.utils.rnn import pad_sequence

    end_loc = seq_length.cumsum(0)
    start_loc = end_loc - seq_length

    inputs = [
        inputs[0, sloc:eloc] for sloc, eloc in zip(start_loc, end_loc)
    ]
    inputs = pad_sequence(inputs, batch_first=True)
    return inputs


def page_cache(
    paged_cache: torch.Tensor,
    batched_cache: torch.Tensor,
    cache_length: torch.Tensor,
    block_offsets: torch.Tensor,
    permute_head: bool = True,
):
    """Convert batched cache to paged cache.

    Args:
        paged_cache (Tensor): Output paged cache.
        batched_cache (Tensor): Input batched cache.
        cache_length (Tensor): Length of each sequence's cache.
        block_offsets (Tensor): Offsets of the blocks owned by each sequence.
        permute_head (bool): Permute the head and sequence dimensions of the
            batched cache before paging.
    """
    assert block_offsets.dim() == 2
    block_size = paged_cache.size(1)
    batch_size = batched_cache.size(0)
    if permute_head:
        batched_cache = batched_cache.permute(0, 2, 1, 3)

    for b_idx in range(batch_size):
        cache_len = cache_length[b_idx]
        b_cache = batched_cache[b_idx]
        block_off = block_offsets[b_idx]
        block_off_idx = 0
        for s_start in range(0, cache_len, block_size):
            s_end = min(s_start + block_size, cache_len)
            s_len = s_end - s_start
            b_off = block_off[block_off_idx]
            paged_cache[b_off, :s_len] = b_cache[s_start:s_end]
            block_off_idx += 1
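A small round-trip sketch for continuous_tensor and batch_tensor above; the shapes and lengths are invented for illustration.

import torch

# Three sequences padded to length 4, with true lengths 4, 2, and 3.
batch = torch.arange(12, dtype=torch.float32).reshape(3, 4)
seq_length = torch.tensor([4, 2, 3])

flat = continuous_tensor(batch, seq_length)  # shape (1, 9): sequences concatenated
rebatched = batch_tensor(flat, seq_length)   # shape (3, 4): zero-padded again

assert flat.shape == (1, int(seq_length.sum()))
assert rebatched.shape == (3, 4)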
@ -0,0 +1,315 @@
from typing import Callable, Union

import torch
from torch import Tensor, nn
from torch.distributed._tensor import (
    DeviceMesh,
    DTensor,
    Replicate,
    Shard,
    distribute_tensor,
)
from zeta.nn import QuantizedLN

try:
    from peft.tuners.lora import Linear as LoRALinear
except ImportError:

    class LoRALinear:
        pass


def try_to_local(tensor: Union[Tensor, DTensor]):
    """Try to convert DTensor to Tensor.

    Args:
        tensor (Tensor|DTensor): Tensor to convert.
    """
    if isinstance(tensor, DTensor):
        tensor = tensor.to_local()
    return tensor


def module_to_local(module: nn.Module):
    """Convert all DTensor parameters to Tensor parameters in the module.

    Args:
        module (Module): Module to convert.
    """
    for name, mod in module.named_children():
        module_to_local(mod)

    for name, param in module.named_parameters(recurse=False):
        module.register_parameter(
            name, nn.Parameter(try_to_local(param))
        )

    for name, buf in module.named_buffers(recurse=False):
        module.register_buffer(name, try_to_local(buf))


def rowwise_parallelize_linear(
    module: nn.Module, device_mesh: DeviceMesh, to_local: bool = False
) -> None:
    """
    This function parallelizes the input :class:`nn.Linear` module in
    :class:`RowwiseParallel` style.

    Args:
        module (:class:`nn.Module`):
            The :class:`nn.Linear` module to be parallelized.
        device_mesh (:class:`DeviceMesh`):
            Object which describes the mesh topology of devices.

    Returns:
        None
    """
    for name, param in module.named_parameters():
        dist_spec = (
            [Shard(1)]
            if name == "weight"
            else [Replicate()]  # type: ignore[list-item]
        )

        dist_tensor = distribute_tensor(param, device_mesh, dist_spec)
        if to_local:
            dist_tensor = try_to_local(dist_tensor)
            if name == "bias":
                # A row-wise linear would otherwise add the bias once per rank.
                dist_tensor /= device_mesh.size()
        dist_param = torch.nn.Parameter(dist_tensor)
        module.register_parameter(name, dist_param)

    # Weight, bias and scale are registered as buffers in QLinear.
    for name, buffer in module.named_buffers():
        dist_spec = (
            [Shard(1)]
            if name == "weight"
            else [Replicate()]  # type: ignore[list-item]
        )

        dist_tensor = distribute_tensor(
            buffer, device_mesh, dist_spec
        )
        if to_local:
            dist_tensor = try_to_local(dist_tensor)
            if name == "bias":
                # A row-wise linear would otherwise add the bias once per rank.
                dist_tensor /= device_mesh.size()
        module.register_buffer(name, dist_tensor)


def rowwise_parallelize_loralinear(
    module: LoRALinear,
    device_mesh: DeviceMesh,
    to_local: bool = False,
) -> None:
    """Row-wise parallelize a LoRA linear layer.

    Read S-LoRA for more detail.
    """
    rowwise_parallelize_linear(
        module.base_layer, device_mesh=device_mesh, to_local=to_local
    )
    for mod in module.lora_A.values():
        rowwise_parallelize_linear(
            mod, device_mesh=device_mesh, to_local=to_local
        )
    for mod in module.lora_B.values():
        colwise_parallelize_linear(
            mod, device_mesh=device_mesh, to_local=to_local
        )
    module._tp_mode = "rowwise"


def rowwise_parallelize_linear_fn(
    module: nn.Module, device_mesh: DeviceMesh, to_local: bool = False
) -> None:
    """
    This function parallelizes the input :class:`nn.Linear` module in
    :class:`RowwiseParallel` style.

    Args:
        module (:class:`nn.Module`):
            The :class:`nn.Linear` module to be parallelized.
        device_mesh (:class:`DeviceMesh`):
            Object which describes the mesh topology of devices.

    Returns:
        None
    """
    if isinstance(module, (torch.nn.Linear, QuantizedLN)):
        return rowwise_parallelize_linear(
            module, device_mesh=device_mesh, to_local=to_local
        )
    elif isinstance(module, LoRALinear):
        return rowwise_parallelize_loralinear(
            module, device_mesh=device_mesh, to_local=to_local
        )
    else:
        raise TypeError(f"Unsupported module: {type(module)}")


def colwise_parallelize_linear(
    module: nn.Module, device_mesh: DeviceMesh, to_local: bool = False
) -> None:
    """
    This function parallelizes the input :class:`nn.Linear` module in
    :class:`ColwiseParallel` style.

    Args:
        module (:class:`nn.Module`):
            The :class:`nn.Linear` module to be parallelized.
        device_mesh (:class:`DeviceMesh`):
            Object which describes the mesh topology of devices.

    Returns:
        None
    """
    for name, param in module.named_parameters():
        dist_tensor = distribute_tensor(
            param, device_mesh, [Shard(0)]
        )
        if to_local:
            dist_tensor = try_to_local(dist_tensor)
        dist_param = torch.nn.Parameter(dist_tensor)
        module.register_parameter(name, dist_param)
    # Weight, bias and scale are registered as buffers in QLinear.
    for name, buffer in module.named_buffers():
        dist_tensor = distribute_tensor(
            buffer, device_mesh, [Shard(0)]
        )
        if to_local:
            dist_tensor = try_to_local(dist_tensor)
        module.register_buffer(name, dist_tensor)


def colwise_parallelize_loralinear(
    module: nn.Module, device_mesh: DeviceMesh, to_local: bool = False
) -> None:
    """Column-wise parallelize a LoRA linear layer."""
    colwise_parallelize_linear(
        module.base_layer, device_mesh=device_mesh, to_local=to_local
    )
    for mod in module.lora_A.values():
        colwise_parallelize_linear(
            mod, device_mesh=device_mesh, to_local=to_local
        )
    for mod in module.lora_B.values():
        colwise_parallelize_linear(
            mod, device_mesh=device_mesh, to_local=to_local
        )
    module._tp_mode = "colwise"


def colwise_parallelize_linear_fn(
    module: nn.Module, device_mesh: DeviceMesh, to_local: bool = False
) -> None:
    """
    This function parallelizes the input :class:`nn.Linear` module in
    :class:`ColwiseParallel` style.

    Args:
        module (:class:`nn.Module`):
            The :class:`nn.Linear` module to be parallelized.
        device_mesh (:class:`DeviceMesh`):
            Object which describes the mesh topology of devices.

    Returns:
        None
    """
    if isinstance(module, (torch.nn.Linear, QuantizedLN)):
        return colwise_parallelize_linear(
            module, device_mesh=device_mesh, to_local=to_local
        )
    elif isinstance(module, LoRALinear):
        return colwise_parallelize_loralinear(
            module, device_mesh=device_mesh, to_local=to_local
        )
    else:
        raise TypeError(f"Unsupported module: {type(module)}")


def _partition_module(
    mod_name: str,
    prefix: str,
    module: nn.Module,
    device_mesh: DeviceMesh,
    func: Callable,
):
    """Partition the module.

    Parameters in the module won't be forcibly replicated.

    Args:
        mod_name (str): Module name.
        prefix (str): Parameter prefix.
        module (Module): Module to be partitioned.
        device_mesh (DeviceMesh): The device mesh.
        func (Callable): Partition callback.
    """
    for name, mod in module.named_children():
        child_name = f"{prefix}{name}"
        _partition_module(
            child_name,
            child_name + ".",
            module=mod,
            device_mesh=device_mesh,
            func=func,
        )

    func(mod_name, module, device_mesh)


def partition_module(
    module: nn.Module,
    device_mesh: DeviceMesh,
    func: Callable,
    to_local: bool = False,
):
    """Partition the module.

    Parameters in the module won't be forcibly replicated.

    Args:
        module (Module): Module to be partitioned.
        device_mesh (DeviceMesh): The device mesh.
        func (Callable): Partition callback.
        to_local (bool): Convert all DTensor parameters to Tensor parameters.
    """
    _partition_module(
        "", "", module=module, device_mesh=device_mesh, func=func
    )

    if to_local:
        module_to_local(module)


def replicate_module(model: nn.Module, device_mesh: DeviceMesh):
    """Replicate all parameters in the module.

    Args:
        model (Module): Module to replicate.
        device_mesh (DeviceMesh): The distribution device mesh.
    """
    for name, param in model.named_parameters(recurse=False):
        param = distribute_tensor(
            param, device_mesh=device_mesh, placements=[Replicate()]
        ).to_local()
        param = nn.Parameter(param)
        model.register_parameter(name, param)

    for name, buf in model.named_buffers(recurse=False):
        buf = distribute_tensor(
            buf, device_mesh=device_mesh, placements=[Replicate()]
        ).to_local()
        model.register_buffer(name, buf)
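A hedged sketch of how these helpers can be combined: a toy two-layer model is partitioned column-wise over a two-device mesh. The mesh size, the toy model, and the rule of parallelizing only Linear layers are illustrative assumptions, and torch.distributed must already be initialized with two ranks before building the mesh.

import torch
from torch import nn

# Assumes an initialized process group with world_size == 2.
device_mesh = DeviceMesh("cuda", torch.arange(2))

model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 16))


def _partition_fn(mod_name, mod, mesh):
    # _partition_module visits every submodule bottom-up, so only handle
    # the module types the colwise helper supports and skip the rest.
    if isinstance(mod, nn.Linear):
        colwise_parallelize_linear_fn(mod, device_mesh=mesh, to_local=True)


partition_module(model, device_mesh, func=_partition_fn, to_local=True)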
@ -0,0 +1,79 @@
import torch


def continuous_tensor(
    inputs: torch.Tensor, seq_length: torch.LongTensor
):
    """Convert batched tensor to continuous tensor.

    Args:
        inputs (Tensor): batched tensor.
        seq_length (Tensor): length of each sequence.

    Returns:
        Tensor: continuous tensor.
    """
    assert inputs.dim() > 1
    if inputs.size(1) == 1:
        return inputs.reshape(1, -1)

    inputs = [inp[:slen] for inp, slen in zip(inputs, seq_length)]

    inputs = torch.cat(inputs).unsqueeze(0)
    return inputs


def batch_tensor(inputs: torch.Tensor, seq_length: torch.LongTensor):
    """Convert continuous tensor to batched tensor.

    Args:
        inputs (Tensor): continuous tensor.
        seq_length (Tensor): length of each sequence.

    Returns:
        Tensor: batched tensor.
    """
    from torch.nn.utils.rnn import pad_sequence

    end_loc = seq_length.cumsum(0)
    start_loc = end_loc - seq_length

    inputs = [
        inputs[0, sloc:eloc] for sloc, eloc in zip(start_loc, end_loc)
    ]
    inputs = pad_sequence(inputs, batch_first=True)
    return inputs


def page_cache(
    paged_cache: torch.Tensor,
    batched_cache: torch.Tensor,
    cache_length: torch.Tensor,
    block_offsets: torch.Tensor,
    permute_head: bool = True,
):
    """Convert batched cache to paged cache.

    Args:
        paged_cache (Tensor): Output paged cache.
        batched_cache (Tensor): Input batched cache.
        cache_length (Tensor): Length of each sequence's cache.
        block_offsets (Tensor): Offsets of the blocks owned by each sequence.
        permute_head (bool): Permute the head and sequence dimensions of the
            batched cache before paging.
    """
    assert block_offsets.dim() == 2
    block_size = paged_cache.size(1)
    batch_size = batched_cache.size(0)
    if permute_head:
        batched_cache = batched_cache.permute(0, 2, 1, 3)

    for b_idx in range(batch_size):
        cache_len = cache_length[b_idx]
        b_cache = batched_cache[b_idx]
        block_off = block_offsets[b_idx]
        block_off_idx = 0
        for s_start in range(0, cache_len, block_size):
            s_end = min(s_start + block_size, cache_len)
            s_len = s_end - s_start
            b_off = block_off[block_off_idx]
            paged_cache[b_off, :s_len] = b_cache[s_start:s_end]
            block_off_idx += 1
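A sketch of page_cache on made-up shapes: a batch of two sequences is scattered into 4-token blocks of a paged pool, with block_offsets listing which blocks each sequence owns. All sizes are illustrative.

import torch

num_blocks, block_size, num_heads, head_dim = 8, 4, 2, 8
paged = torch.zeros(num_blocks, block_size, num_heads, head_dim)

# Batched cache laid out as (batch, heads, seq, head_dim); with
# permute_head=True it is viewed as (batch, seq, heads, head_dim).
batched = torch.randn(2, num_heads, 6, head_dim)
cache_length = torch.tensor([6, 3])
block_offsets = torch.tensor([[0, 1], [2, 0]])  # seq 0 -> blocks 0 and 1; seq 1 -> block 2

page_cache(paged, batched, cache_length, block_offsets)

# The first 4 tokens of sequence 0 fill block 0; the remaining 2 spill into block 1.
assert torch.equal(paged[0], batched.permute(0, 2, 1, 3)[0, :4])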
@ -1,71 +0,0 @@
import pytest
from swarms.models.cog_agent import CogAgent
from unittest.mock import MagicMock
from PIL import Image


@pytest.fixture
def cogagent_params():
    return {
        "model_name": "ZhipuAI/cogagent-chat",
        "tokenizer_name": "I-ModelScope/vicuna-7b-v1.5",
        "dtype": "torch.bfloat16",
        "low_cpu_mem_usage": True,
        "load_in_4bit": True,
        "trust_remote_code": True,
        "device": "cuda",
    }


@pytest.fixture
def cogagent(cogagent_params):
    return CogAgent(**cogagent_params)


def test_init(mocker, cogagent_params, cogagent):
    mock_model = mocker.patch(
        "swarms.models.cog_agent.AutoModelForCausalLM.from_pretrained"
    )
    mock_tokenizer = mocker.patch(
        "swarms.models.cog_agent.AutoTokenizer.from_pretrained"
    )

    for param, value in cogagent_params.items():
        assert getattr(cogagent, param) == value

    mock_tokenizer.assert_called_once_with(
        cogagent_params["tokenizer_name"]
    )
    mock_model.assert_called_once_with(
        cogagent_params["model_name"],
        torch_dtype=cogagent_params["dtype"],
        low_cpu_mem_usage=cogagent_params["low_cpu_mem_usage"],
        load_in_4bit=cogagent_params["load_in_4bit"],
        trust_remote_code=cogagent_params["trust_remote_code"],
    )


def test_run(mocker, cogagent):
    task = "How are you?"
    img = "images/1.jpg"
    mock_image = mocker.patch(
        "PIL.Image.open", return_value=MagicMock(spec=Image.Image)
    )
    cogagent.model.build_conversation_input_ids = MagicMock(
        return_value={
            "input_ids": MagicMock(),
            "token_type_ids": MagicMock(),
            "attention_mask": MagicMock(),
            "images": [MagicMock()],
        }
    )
    cogagent.model.__call__ = MagicMock(return_value="Mocked output")
    cogagent.decode = MagicMock(return_value="Mocked response")

    output = cogagent.run(task, img)

    assert output is not None
    mock_image.assert_called_once_with(img)
    cogagent.model.build_conversation_input_ids.assert_called_once()
    cogagent.model.__call__.assert_called_once()
    cogagent.decode.assert_called_once()
@ -1,39 +0,0 @@
import pytest
from swarms.models.modelscope_pipeline import ModelScopePipeline
from unittest.mock import MagicMock


@pytest.fixture
def pipeline_params():
    return {
        "type_task": "text-generation",
        "model_name": "gpt2",
    }


@pytest.fixture
def pipeline_model(pipeline_params):
    return ModelScopePipeline(**pipeline_params)


def test_init(mocker, pipeline_params, pipeline_model):
    mock_pipeline = mocker.patch(
        "swarms.models.modelscope_pipeline.pipeline"
    )

    for param, value in pipeline_params.items():
        assert getattr(pipeline_model, param) == value

    mock_pipeline.assert_called_once_with(
        pipeline_params["type_task"],
        model=pipeline_params["model_name"],
    )


def test_run(mocker, pipeline_model):
    task = "Generate a 10,000 word blog on health and wellness."
    pipeline_model.model = MagicMock(return_value="Mocked output")

    output = pipeline_model.run(task)

    assert output is not None
@ -1,58 +0,0 @@
import pytest
from swarms.models.modelscope_llm import ModelScopeAutoModel
from unittest.mock import MagicMock


@pytest.fixture
def model_params():
    return {
        "model_name": "gpt2",
        "tokenizer_name": None,
        "device": "cuda",
        "device_map": "auto",
        "max_new_tokens": 500,
        "skip_special_tokens": True,
    }


@pytest.fixture
def modelscope(model_params):
    return ModelScopeAutoModel(**model_params)


def test_init(mocker, model_params, modelscope):
    mock_model = mocker.patch(
        "swarms.models.modelscope_llm.AutoModelForCausalLM.from_pretrained"
    )
    mock_tokenizer = mocker.patch(
        "swarms.models.modelscope_llm.AutoTokenizer.from_pretrained"
    )

    for param, value in model_params.items():
        assert getattr(modelscope, param) == value

    mock_tokenizer.assert_called_once_with(
        model_params["tokenizer_name"]
    )
    mock_model.assert_called_once_with(
        model_params["model_name"],
        device_map=model_params["device_map"],
    )


def test_run(mocker, modelscope):
    task = "Generate a 10,000 word blog on health and wellness."
    mocker.patch(
        "swarms.models.modelscope_llm.AutoTokenizer.decode",
        return_value="Mocked output",
    )
    modelscope.model.generate = MagicMock(
        return_value=["Mocked token"]
    )
    modelscope.tokenizer = MagicMock(
        return_value={"input_ids": "Mocked input_ids"}
    )

    output = modelscope.run(task)

    assert output is not None
@ -1,141 +0,0 @@
import pytest
from swarms.models.vllm import vLLM


# Fixture for initializing vLLM
@pytest.fixture
def vllm_instance():
    return vLLM()


# Test the default initialization of vLLM
def test_vllm_default_init(vllm_instance):
    assert isinstance(vllm_instance, vLLM)
    assert vllm_instance.model_name == "facebook/opt-13b"
    assert vllm_instance.tensor_parallel_size == 4
    assert not vllm_instance.trust_remote_code
    assert vllm_instance.revision is None
    assert vllm_instance.temperature == 0.5
    assert vllm_instance.top_p == 0.95


# Test custom initialization of vLLM
def test_vllm_custom_init():
    vllm_instance = vLLM(
        model_name="custom_model",
        tensor_parallel_size=8,
        trust_remote_code=True,
        revision="123",
        temperature=0.7,
        top_p=0.9,
    )
    assert isinstance(vllm_instance, vLLM)
    assert vllm_instance.model_name == "custom_model"
    assert vllm_instance.tensor_parallel_size == 8
    assert vllm_instance.trust_remote_code
    assert vllm_instance.revision == "123"
    assert vllm_instance.temperature == 0.7
    assert vllm_instance.top_p == 0.9


# Test the run method of vLLM
def test_vllm_run(vllm_instance):
    task = "Hello, vLLM!"
    result = vllm_instance.run(task)
    assert isinstance(result, str)
    assert len(result) > 0


# Test run method with different temperature and top_p values
@pytest.mark.parametrize(
    "temperature, top_p", [(0.2, 0.8), (0.8, 0.2)]
)
def test_vllm_run_with_params(vllm_instance, temperature, top_p):
    task = "Temperature and Top-P Test"
    result = vllm_instance.run(
        task, temperature=temperature, top_p=top_p
    )
    assert isinstance(result, str)
    assert len(result) > 0


# Test run method with a specific model revision
def test_vllm_run_with_revision(vllm_instance):
    task = "Specific Model Revision Test"
    result = vllm_instance.run(task, revision="abc123")
    assert isinstance(result, str)
    assert len(result) > 0


# Test run method with a specific model name
def test_vllm_run_with_custom_model(vllm_instance):
    task = "Custom Model Test"
    custom_model_name = "my_custom_model"
    result = vllm_instance.run(task, model_name=custom_model_name)
    assert isinstance(result, str)
    assert len(result) > 0
    assert vllm_instance.model_name == custom_model_name


# Test run method with invalid task input
def test_vllm_run_invalid_task(vllm_instance):
    invalid_task = None
    with pytest.raises(ValueError):
        vllm_instance.run(invalid_task)


# Test run method with a very high temperature value
def test_vllm_run_high_temperature(vllm_instance):
    task = "High Temperature Test"
    high_temperature = 10.0
    result = vllm_instance.run(task, temperature=high_temperature)
    assert isinstance(result, str)
    assert len(result) > 0


# Test run method with a very low top_p value
def test_vllm_run_low_top_p(vllm_instance):
    task = "Low Top-P Test"
    low_top_p = 0.01
    result = vllm_instance.run(task, top_p=low_top_p)
    assert isinstance(result, str)
    assert len(result) > 0


# Test run method with an empty task
def test_vllm_run_empty_task(vllm_instance):
    empty_task = ""
    result = vllm_instance.run(empty_task)
    assert isinstance(result, str)
    assert len(result) == 0


# Test initialization with invalid parameters
def test_vllm_invalid_init():
    with pytest.raises(ValueError):
        vLLM(
            model_name=None,
            tensor_parallel_size=-1,
            trust_remote_code="invalid",
            revision=123,
            temperature=-0.1,
            top_p=1.1,
        )


# Test running vLLM with a large number of parallel heads
def test_vllm_large_parallel_heads():
    vllm_instance = vLLM(tensor_parallel_size=16)
    task = "Large Parallel Heads Test"
    result = vllm_instance.run(task)
    assert isinstance(result, str)
    assert len(result) > 0


# Test running vLLM with trust_remote_code set to True
def test_vllm_trust_remote_code():
    vllm_instance = vLLM(trust_remote_code=True)
    task = "Trust Remote Code Test"
    result = vllm_instance.run(task)
    assert isinstance(result, str)
    assert len(result) > 0