parent
b1598aa71a
commit
10829b03e2
@ -1,310 +0,0 @@
|
|||||||
from typing import Callable, Union
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from torch import Tensor, nn
|
|
||||||
from torch.distributed._tensor import (
|
|
||||||
DeviceMesh,
|
|
||||||
DTensor,
|
|
||||||
Replicate,
|
|
||||||
Shard,
|
|
||||||
distribute_tensor,
|
|
||||||
)
|
|
||||||
from zeta.nn import QuantizedLN
|
|
||||||
|
|
||||||
try:
|
|
||||||
from peft.tuners.lora import Linear as LoRALinear
|
|
||||||
except ImportError:
|
|
||||||
|
|
||||||
class LoRALinear:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
def try_to_local(tensor: Union[Tensor, DTensor]):
|
|
||||||
"""Try to convert DTensor to Tensor.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
tensor (Tensor|DTensor): Tensor to convert.
|
|
||||||
"""
|
|
||||||
if isinstance(tensor, DTensor):
|
|
||||||
tensor = tensor.to_local()
|
|
||||||
return tensor
|
|
||||||
|
|
||||||
|
|
||||||
def module_to_local(module: nn.Module):
|
|
||||||
"""convert all DTensor parameters to Tensor parameters in module.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
module (Module): Module to convert.
|
|
||||||
"""
|
|
||||||
for name, mod in module.named_children():
|
|
||||||
module_to_local(mod)
|
|
||||||
|
|
||||||
for name, param in module.named_parameters(recurse=False):
|
|
||||||
module.register_parameter(
|
|
||||||
name, nn.Parameter(try_to_local(param))
|
|
||||||
)
|
|
||||||
|
|
||||||
for name, buf in module.named_buffers(recurse=False):
|
|
||||||
module.register_buffer(name, try_to_local(buf))
|
|
||||||
|
|
||||||
|
|
||||||
def rowwise_parallelize_linear(
|
|
||||||
module: nn.Module, device_mesh: DeviceMesh, to_local: bool = False
|
|
||||||
) -> None:
|
|
||||||
"""
|
|
||||||
This function parallelizes the input :class:`nn.Linear` module in
|
|
||||||
:class:`RowwiseParallel` style.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
module (:class:`nn.Module`):
|
|
||||||
The :class:`nn.Linear` module to be parallelized.
|
|
||||||
device_mesh (:class:`DeviceMesh`):
|
|
||||||
Object which describes the mesh topology of devices.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
None
|
|
||||||
"""
|
|
||||||
for name, param in module.named_parameters():
|
|
||||||
dist_spec = (
|
|
||||||
[Shard(1)] if name == "weight" else [Replicate()] # type: ignore[list-item]
|
|
||||||
)
|
|
||||||
|
|
||||||
dist_tensor = distribute_tensor(param, device_mesh, dist_spec)
|
|
||||||
if to_local:
|
|
||||||
dist_tensor = try_to_local(dist_tensor)
|
|
||||||
if name == "bias":
|
|
||||||
# rowwise linear would add bias more than ones.
|
|
||||||
dist_tensor /= device_mesh.size()
|
|
||||||
dist_param = torch.nn.Parameter(dist_tensor)
|
|
||||||
module.register_parameter(name, dist_param)
|
|
||||||
|
|
||||||
# Weight, bias and scale are registered as buffer in QLinear
|
|
||||||
for name, buffer in module.named_buffers():
|
|
||||||
dist_spec = (
|
|
||||||
[Shard(1)] if name == "weight" else [Replicate()] # type: ignore[list-item]
|
|
||||||
)
|
|
||||||
|
|
||||||
dist_tensor = distribute_tensor(
|
|
||||||
buffer, device_mesh, dist_spec
|
|
||||||
)
|
|
||||||
if to_local:
|
|
||||||
dist_tensor = try_to_local(dist_tensor)
|
|
||||||
if name == "bias":
|
|
||||||
# rowwise linear would add bias more than ones.
|
|
||||||
dist_tensor /= device_mesh.size()
|
|
||||||
module.register_buffer(name, dist_tensor)
|
|
||||||
|
|
||||||
dist_tensor = distribute_tensor(
|
|
||||||
buffer, device_mesh, dist_spec
|
|
||||||
)
|
|
||||||
if to_local:
|
|
||||||
dist_tensor = try_to_local(dist_tensor)
|
|
||||||
module.register_buffer(name, dist_tensor)
|
|
||||||
|
|
||||||
|
|
||||||
def rowwise_parallelize_loralinear(
|
|
||||||
module: LoRALinear,
|
|
||||||
device_mesh: DeviceMesh,
|
|
||||||
to_local: bool = False,
|
|
||||||
) -> None:
|
|
||||||
"""rowwize parallelize lora linear.
|
|
||||||
|
|
||||||
Read S-LoRA for more detail.
|
|
||||||
"""
|
|
||||||
rowwise_parallelize_linear(
|
|
||||||
module.base_layer, device_mesh=device_mesh, to_local=to_local
|
|
||||||
)
|
|
||||||
for mod in module.lora_A.values():
|
|
||||||
rowwise_parallelize_linear(
|
|
||||||
mod, device_mesh=device_mesh, to_local=to_local
|
|
||||||
)
|
|
||||||
for mod in module.lora_B.values():
|
|
||||||
colwise_parallelize_linear(
|
|
||||||
mod, device_mesh=device_mesh, to_local=to_local
|
|
||||||
)
|
|
||||||
module._tp_mode = "rowwise"
|
|
||||||
|
|
||||||
|
|
||||||
def rowwise_parallelize_linear_fn(
|
|
||||||
module: nn.Module, device_mesh: DeviceMesh, to_local: bool = False
|
|
||||||
) -> None:
|
|
||||||
"""
|
|
||||||
This function parallelizes the input :Linear module in
|
|
||||||
:class:`RowwiseParallel` style.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
module (:class:`nn.Module`):
|
|
||||||
The :class:`nn.Linear` module to be parallelized.
|
|
||||||
device_mesh (:class:`DeviceMesh`):
|
|
||||||
Object which describes the mesh topology of devices.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
None
|
|
||||||
"""
|
|
||||||
if isinstance(module, (torch.nn.Linear, QuantizedLN)):
|
|
||||||
return rowwise_parallelize_linear(
|
|
||||||
module, device_mesh=device_mesh, to_local=to_local
|
|
||||||
)
|
|
||||||
elif isinstance(module, LoRALinear):
|
|
||||||
return rowwise_parallelize_loralinear(
|
|
||||||
module, device_mesh=device_mesh, to_local=to_local
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise TypeError(f"Unsupported module: {type(module)}")
|
|
||||||
|
|
||||||
|
|
||||||
def colwise_parallelize_linear(
|
|
||||||
module: nn.Module, device_mesh: DeviceMesh, to_local: bool = False
|
|
||||||
) -> None:
|
|
||||||
"""
|
|
||||||
This function parallelizes the input :class:`nn.Linear` module in
|
|
||||||
:class:`ColwiseParallel` style.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
module (:class:`nn.Module`):
|
|
||||||
The :class:`nn.Linear` module to be parallelized.
|
|
||||||
device_mesh (:class:`DeviceMesh`):
|
|
||||||
Object which describes the mesh topology of devices.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
None
|
|
||||||
"""
|
|
||||||
|
|
||||||
for name, param in module.named_parameters():
|
|
||||||
dist_tensor = distribute_tensor(
|
|
||||||
param, device_mesh, [Shard(0)]
|
|
||||||
)
|
|
||||||
if to_local:
|
|
||||||
dist_tensor = try_to_local(dist_tensor)
|
|
||||||
dist_param = torch.nn.Parameter(dist_tensor)
|
|
||||||
module.register_parameter(name, dist_param)
|
|
||||||
# Weight, bias and scale are registered as buffer in QLinear
|
|
||||||
for name, buffer in module.named_buffers():
|
|
||||||
dist_tensor = distribute_tensor(
|
|
||||||
buffer, device_mesh, [Shard(0)]
|
|
||||||
)
|
|
||||||
if to_local:
|
|
||||||
dist_tensor = try_to_local(dist_tensor)
|
|
||||||
module.register_buffer(name, dist_tensor)
|
|
||||||
|
|
||||||
|
|
||||||
def colwise_parallelize_loralinear(
|
|
||||||
module: nn.Module, device_mesh: DeviceMesh, to_local: bool = False
|
|
||||||
) -> None:
|
|
||||||
"""colwise parallelize lora linear."""
|
|
||||||
colwise_parallelize_linear(
|
|
||||||
module.base_layer, device_mesh=device_mesh, to_local=to_local
|
|
||||||
)
|
|
||||||
for mod in module.lora_A.values():
|
|
||||||
colwise_parallelize_linear(
|
|
||||||
mod, device_mesh=device_mesh, to_local=to_local
|
|
||||||
)
|
|
||||||
for mod in module.lora_B.values():
|
|
||||||
colwise_parallelize_linear(
|
|
||||||
mod, device_mesh=device_mesh, to_local=to_local
|
|
||||||
)
|
|
||||||
module._tp_mode = "colwise"
|
|
||||||
|
|
||||||
|
|
||||||
def colwise_parallelize_linear_fn(
|
|
||||||
module: nn.Module, device_mesh: DeviceMesh, to_local: bool = False
|
|
||||||
) -> None:
|
|
||||||
"""
|
|
||||||
This function parallelizes the input :Linear module in
|
|
||||||
:class:`ColwiseParallel` style.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
module (:class:`nn.Module`):
|
|
||||||
The :class:`nn.Linear` module to be parallelized.
|
|
||||||
device_mesh (:class:`DeviceMesh`):
|
|
||||||
Object which describes the mesh topology of devices.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
None
|
|
||||||
"""
|
|
||||||
if isinstance(module, (torch.nn.Linear, QuantizedLN)):
|
|
||||||
return colwise_parallelize_linear(
|
|
||||||
module, device_mesh=device_mesh, to_local=to_local
|
|
||||||
)
|
|
||||||
elif isinstance(module, LoRALinear):
|
|
||||||
return colwise_parallelize_loralinear(
|
|
||||||
module, device_mesh=device_mesh, to_local=to_local
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise TypeError(f"Unsupported module: {type(module)}")
|
|
||||||
|
|
||||||
|
|
||||||
def _partition_module(
|
|
||||||
mod_name: str,
|
|
||||||
prefix: str,
|
|
||||||
module: nn.Module,
|
|
||||||
device_mesh: DeviceMesh,
|
|
||||||
func: Callable,
|
|
||||||
):
|
|
||||||
"""partition module.
|
|
||||||
|
|
||||||
Parameters in module won't be force Replicated.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
mod_name (str): module name.
|
|
||||||
prefix (str): Parameter prefix.
|
|
||||||
module (Module): Module to be partitioned.
|
|
||||||
device_mesh (DeviceMesh): The device mesh.
|
|
||||||
func (Callable): partition callback
|
|
||||||
"""
|
|
||||||
for name, mod in module.named_children():
|
|
||||||
child_name = f"{prefix}{name}"
|
|
||||||
_partition_module(
|
|
||||||
child_name,
|
|
||||||
child_name + ".",
|
|
||||||
module=mod,
|
|
||||||
device_mesh=device_mesh,
|
|
||||||
func=func,
|
|
||||||
)
|
|
||||||
|
|
||||||
func(mod_name, module, device_mesh)
|
|
||||||
|
|
||||||
|
|
||||||
def partition_module(
|
|
||||||
module: nn.Module,
|
|
||||||
device_mesh: DeviceMesh,
|
|
||||||
func: Callable,
|
|
||||||
to_local: bool = False,
|
|
||||||
):
|
|
||||||
"""partition module.
|
|
||||||
|
|
||||||
Parameters in module won't be force Replicated.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
module (Module): Module to be partitioned.
|
|
||||||
device_mesh (DeviceMesh): The device mesh.
|
|
||||||
func (Callable): partition callback.
|
|
||||||
to_local (bool): Convert all DTensor parameters to Tensor parameters.
|
|
||||||
"""
|
|
||||||
_partition_module(
|
|
||||||
"", "", module=module, device_mesh=device_mesh, func=func
|
|
||||||
)
|
|
||||||
|
|
||||||
if to_local:
|
|
||||||
module_to_local(module)
|
|
||||||
|
|
||||||
|
|
||||||
def replicate_module(model: nn.Module, device_mesh: DeviceMesh):
|
|
||||||
"""Replicate all parameters in module.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model (Module): Module to perform replicate.
|
|
||||||
device_mesh (DeviceMesh): The distribution device mesh.
|
|
||||||
"""
|
|
||||||
for name, param in model.named_parameters(recurse=False):
|
|
||||||
param = distribute_tensor(
|
|
||||||
param, device_mesh=device_mesh, placements=[Replicate()]
|
|
||||||
).to_local()
|
|
||||||
param = nn.Parameter(param)
|
|
||||||
model.register_parameter(name, param)
|
|
||||||
|
|
||||||
for name, buf in model.named_buffers(recurse=False):
|
|
||||||
buf = distribute_tensor(
|
|
||||||
buf, device_mesh=device_mesh, placements=[Replicate()]
|
|
||||||
).to_local()
|
|
||||||
model.register_buffer(name, buf)
|
|
@ -1,10 +1,21 @@
|
|||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
logger.add(
|
logger.add(
|
||||||
"MessagePool.log",
|
"swarms.log",
|
||||||
level="INFO",
|
level="INFO",
|
||||||
colorize=True,
|
colorize=True,
|
||||||
format="<green>{time}</green> <level>{message}</level>",
|
format="<green>{time}</green> <level>{message}</level>",
|
||||||
backtrace=True,
|
backtrace=True,
|
||||||
diagnose=True,
|
diagnose=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def loguru_logger(file_path: str = "swarms.log"):
|
||||||
|
return logger.add(
|
||||||
|
file_path,
|
||||||
|
level="INFO",
|
||||||
|
colorize=True,
|
||||||
|
format="<green>{time}</green> <level>{message}</level>",
|
||||||
|
backtrace=True,
|
||||||
|
diagnose=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
@ -1,259 +0,0 @@
|
|||||||
from enum import Enum
|
|
||||||
|
|
||||||
import cv2
|
|
||||||
import numpy as np
|
|
||||||
import supervision as sv
|
|
||||||
|
|
||||||
|
|
||||||
class FeatureType(Enum):
|
|
||||||
"""
|
|
||||||
An enumeration to represent the types of features for mask adjustment in image
|
|
||||||
segmentation.
|
|
||||||
"""
|
|
||||||
|
|
||||||
ISLAND = "ISLAND"
|
|
||||||
HOLE = "HOLE"
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def list(cls):
|
|
||||||
return list(map(lambda c: c.value, cls))
|
|
||||||
|
|
||||||
|
|
||||||
def compute_mask_iou_vectorized(masks: np.ndarray) -> np.ndarray:
|
|
||||||
"""
|
|
||||||
Vectorized computation of the Intersection over Union (IoU) for all pairs of masks.
|
|
||||||
|
|
||||||
Parameters:
|
|
||||||
masks (np.ndarray): A 3D numpy array with shape `(N, H, W)`, where `N` is the
|
|
||||||
number of masks, `H` is the height, and `W` is the width.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
np.ndarray: A 2D numpy array of shape `(N, N)` where each element `[i, j]` is
|
|
||||||
the IoU between masks `i` and `j`.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: If any of the masks is found to be empty.
|
|
||||||
"""
|
|
||||||
if np.any(masks.sum(axis=(1, 2)) == 0):
|
|
||||||
raise ValueError(
|
|
||||||
"One or more masks are empty. Please filter out empty"
|
|
||||||
" masks before using `compute_iou_vectorized` function."
|
|
||||||
)
|
|
||||||
|
|
||||||
masks_bool = masks.astype(bool)
|
|
||||||
masks_flat = masks_bool.reshape(masks.shape[0], -1)
|
|
||||||
intersection = np.logical_and(
|
|
||||||
masks_flat[:, None], masks_flat[None, :]
|
|
||||||
).sum(axis=2)
|
|
||||||
union = np.logical_or(
|
|
||||||
masks_flat[:, None], masks_flat[None, :]
|
|
||||||
).sum(axis=2)
|
|
||||||
iou_matrix = intersection / union
|
|
||||||
return iou_matrix
|
|
||||||
|
|
||||||
|
|
||||||
def mask_non_max_suppression(
|
|
||||||
masks: np.ndarray, iou_threshold: float = 0.6
|
|
||||||
) -> np.ndarray:
|
|
||||||
"""
|
|
||||||
Performs Non-Max Suppression on a set of masks by prioritizing larger masks and
|
|
||||||
removing smaller masks that overlap significantly.
|
|
||||||
|
|
||||||
When the IoU between two masks exceeds the specified threshold, the smaller mask
|
|
||||||
(in terms of area) is discarded. This process is repeated for each pair of masks,
|
|
||||||
effectively filtering out masks that are significantly overlapped by larger ones.
|
|
||||||
|
|
||||||
Parameters:
|
|
||||||
masks (np.ndarray): A 3D numpy array with shape `(N, H, W)`, where `N` is the
|
|
||||||
number of masks, `H` is the height, and `W` is the width.
|
|
||||||
iou_threshold (float): The IoU threshold for determining significant overlap.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
np.ndarray: A 3D numpy array of filtered masks.
|
|
||||||
"""
|
|
||||||
num_masks = masks.shape[0]
|
|
||||||
areas = masks.sum(axis=(1, 2))
|
|
||||||
sorted_idx = np.argsort(-areas)
|
|
||||||
keep_mask = np.ones(num_masks, dtype=bool)
|
|
||||||
iou_matrix = compute_mask_iou_vectorized(masks)
|
|
||||||
for i in range(num_masks):
|
|
||||||
if not keep_mask[sorted_idx[i]]:
|
|
||||||
continue
|
|
||||||
|
|
||||||
overlapping_masks = iou_matrix[sorted_idx[i]] > iou_threshold
|
|
||||||
overlapping_masks[sorted_idx[i]] = False
|
|
||||||
overlapping_indices = np.where(overlapping_masks)[0]
|
|
||||||
keep_mask[sorted_idx[overlapping_indices]] = False
|
|
||||||
|
|
||||||
return masks[keep_mask]
|
|
||||||
|
|
||||||
|
|
||||||
def filter_masks_by_relative_area(
|
|
||||||
masks: np.ndarray,
|
|
||||||
minimum_area: float = 0.01,
|
|
||||||
maximum_area: float = 1.0,
|
|
||||||
) -> np.ndarray:
|
|
||||||
"""
|
|
||||||
Filters masks based on their relative area within the total area of each mask.
|
|
||||||
|
|
||||||
Parameters:
|
|
||||||
masks (np.ndarray): A 3D numpy array with shape `(N, H, W)`, where `N` is the
|
|
||||||
number of masks, `H` is the height, and `W` is the width.
|
|
||||||
minimum_area (float): The minimum relative area threshold. Must be between `0`
|
|
||||||
and `1`.
|
|
||||||
maximum_area (float): The maximum relative area threshold. Must be between `0`
|
|
||||||
and `1`.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
np.ndarray: A 3D numpy array containing masks that fall within the specified
|
|
||||||
relative area range.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: If `minimum_area` or `maximum_area` are outside the `0` to `1`
|
|
||||||
range, or if `minimum_area` is greater than `maximum_area`.
|
|
||||||
"""
|
|
||||||
|
|
||||||
if not (isinstance(masks, np.ndarray) and masks.ndim == 3):
|
|
||||||
raise ValueError("Input must be a 3D numpy array.")
|
|
||||||
|
|
||||||
if not (0 <= minimum_area <= 1) or not (0 <= maximum_area <= 1):
|
|
||||||
raise ValueError(
|
|
||||||
"`minimum_area` and `maximum_area` must be between 0"
|
|
||||||
" and 1."
|
|
||||||
)
|
|
||||||
|
|
||||||
if minimum_area > maximum_area:
|
|
||||||
raise ValueError(
|
|
||||||
"`minimum_area` must be less than or equal to"
|
|
||||||
" `maximum_area`."
|
|
||||||
)
|
|
||||||
|
|
||||||
total_area = masks.shape[1] * masks.shape[2]
|
|
||||||
relative_areas = masks.sum(axis=(1, 2)) / total_area
|
|
||||||
return masks[
|
|
||||||
(relative_areas >= minimum_area)
|
|
||||||
& (relative_areas <= maximum_area)
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def adjust_mask_features_by_relative_area(
|
|
||||||
mask: np.ndarray,
|
|
||||||
area_threshold: float,
|
|
||||||
feature_type: FeatureType = FeatureType.ISLAND,
|
|
||||||
) -> np.ndarray:
|
|
||||||
"""
|
|
||||||
Adjusts a mask by removing small islands or filling small holes based on a relative
|
|
||||||
area threshold.
|
|
||||||
|
|
||||||
!!! warning
|
|
||||||
|
|
||||||
Running this function on a mask with small islands may result in empty masks.
|
|
||||||
|
|
||||||
Parameters:
|
|
||||||
mask (np.ndarray): A 2D numpy array with shape `(H, W)`, where `H` is the
|
|
||||||
height, and `W` is the width.
|
|
||||||
area_threshold (float): Threshold for relative area to remove or fill features.
|
|
||||||
feature_type (FeatureType): Type of feature to adjust (`ISLAND` for removing
|
|
||||||
islands, `HOLE` for filling holes).
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
np.ndarray: A 2D numpy array containing mask.
|
|
||||||
"""
|
|
||||||
height, width = mask.shape
|
|
||||||
total_area = width * height
|
|
||||||
|
|
||||||
mask = np.uint8(mask * 255)
|
|
||||||
operation = (
|
|
||||||
cv2.RETR_EXTERNAL
|
|
||||||
if feature_type == FeatureType.ISLAND
|
|
||||||
else cv2.RETR_CCOMP
|
|
||||||
)
|
|
||||||
contours, _ = cv2.findContours(
|
|
||||||
mask, operation, cv2.CHAIN_APPROX_SIMPLE
|
|
||||||
)
|
|
||||||
|
|
||||||
for contour in contours:
|
|
||||||
area = cv2.contourArea(contour)
|
|
||||||
relative_area = area / total_area
|
|
||||||
if relative_area < area_threshold:
|
|
||||||
cv2.drawContours(
|
|
||||||
image=mask,
|
|
||||||
contours=[contour],
|
|
||||||
contourIdx=-1,
|
|
||||||
color=(
|
|
||||||
0 if feature_type == FeatureType.ISLAND else 255
|
|
||||||
),
|
|
||||||
thickness=-1,
|
|
||||||
)
|
|
||||||
return np.where(mask > 0, 1, 0).astype(bool)
|
|
||||||
|
|
||||||
|
|
||||||
def masks_to_marks(masks: np.ndarray) -> sv.Detections:
|
|
||||||
"""
|
|
||||||
Converts a set of masks to a marks (sv.Detections) object.
|
|
||||||
|
|
||||||
Parameters:
|
|
||||||
masks (np.ndarray): A 3D numpy array with shape `(N, H, W)`, where `N` is the
|
|
||||||
number of masks, `H` is the height, and `W` is the width.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
sv.Detections: An object containing the masks and their bounding box
|
|
||||||
coordinates.
|
|
||||||
"""
|
|
||||||
if len(masks) == 0:
|
|
||||||
marks = sv.Detections.empty()
|
|
||||||
marks.mask = np.empty((0, 0, 0), dtype=bool)
|
|
||||||
return marks
|
|
||||||
return sv.Detections(
|
|
||||||
mask=masks, xyxy=sv.mask_to_xyxy(masks=masks)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def refine_marks(
|
|
||||||
marks: sv.Detections,
|
|
||||||
maximum_hole_area: float = 0.01,
|
|
||||||
maximum_island_area: float = 0.01,
|
|
||||||
minimum_mask_area: float = 0.02,
|
|
||||||
maximum_mask_area: float = 1.0,
|
|
||||||
) -> sv.Detections:
|
|
||||||
"""
|
|
||||||
Refines a set of masks by removing small islands and holes, and filtering by mask
|
|
||||||
area.
|
|
||||||
|
|
||||||
Parameters:
|
|
||||||
marks (sv.Detections): An object containing the masks and their bounding box
|
|
||||||
coordinates.
|
|
||||||
maximum_hole_area (float): The maximum relative area of holes to be filled in
|
|
||||||
each mask.
|
|
||||||
maximum_island_area (float): The maximum relative area of islands to be removed
|
|
||||||
from each mask.
|
|
||||||
minimum_mask_area (float): The minimum relative area for a mask to be retained.
|
|
||||||
maximum_mask_area (float): The maximum relative area for a mask to be retained.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
sv.Detections: An object containing the masks and their bounding box
|
|
||||||
coordinates.
|
|
||||||
"""
|
|
||||||
result_masks = []
|
|
||||||
for mask in marks.mask:
|
|
||||||
mask = adjust_mask_features_by_relative_area(
|
|
||||||
mask=mask,
|
|
||||||
area_threshold=maximum_island_area,
|
|
||||||
feature_type=FeatureType.ISLAND,
|
|
||||||
)
|
|
||||||
mask = adjust_mask_features_by_relative_area(
|
|
||||||
mask=mask,
|
|
||||||
area_threshold=maximum_hole_area,
|
|
||||||
feature_type=FeatureType.HOLE,
|
|
||||||
)
|
|
||||||
if np.any(mask):
|
|
||||||
result_masks.append(mask)
|
|
||||||
result_masks = np.array(result_masks)
|
|
||||||
result_masks = filter_masks_by_relative_area(
|
|
||||||
masks=result_masks,
|
|
||||||
minimum_area=minimum_mask_area,
|
|
||||||
maximum_area=maximum_mask_area,
|
|
||||||
)
|
|
||||||
return sv.Detections(
|
|
||||||
mask=result_masks, xyxy=sv.mask_to_xyxy(masks=result_masks)
|
|
||||||
)
|
|
@ -1,27 +0,0 @@
|
|||||||
import tiktoken
|
|
||||||
|
|
||||||
|
|
||||||
def limit_tokens_from_string(
|
|
||||||
string: str, model: str = "gpt-4", limit: int = 500
|
|
||||||
) -> str:
|
|
||||||
"""Limits the number of tokens in a string
|
|
||||||
|
|
||||||
Args:
|
|
||||||
string (str): _description_
|
|
||||||
model (str): _description_
|
|
||||||
limit (int): _description_
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
str: _description_
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
encoding = tiktoken.encoding_for_model(model)
|
|
||||||
except Exception:
|
|
||||||
encoding = tiktoken.encoding_for_model(
|
|
||||||
"gpt2"
|
|
||||||
) # Fallback for others.
|
|
||||||
|
|
||||||
encoded = encoding.encode(string)
|
|
||||||
|
|
||||||
out = encoding.decode(encoded[:limit])
|
|
||||||
return out
|
|
Loading…
Reference in new issue