parent
7472207d65
commit
403bed61fe
@ -0,0 +1,140 @@
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Tuple
|
||||
|
||||
import numpy as np
|
||||
import requests
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from skimage import transform
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
@dataclass
|
||||
class MedicalSAM:
|
||||
"""
|
||||
MedicalSAM class for performing semantic segmentation on medical images using the SAM model.
|
||||
|
||||
Attributes:
|
||||
model_path (str): The file path to the model weights.
|
||||
device (str): The device to run the model on (default is "cuda:0").
|
||||
model_weights_url (str): The URL to download the model weights from.
|
||||
|
||||
Methods:
|
||||
__post_init__(): Initializes the MedicalSAM object.
|
||||
download_model_weights(model_path: str): Downloads the model weights from the specified URL and saves them to the given file path.
|
||||
preprocess(img): Preprocesses the input image.
|
||||
run(img, box): Runs the semantic segmentation on the input image within the specified bounding box.
|
||||
|
||||
"""
|
||||
|
||||
model_path: str
|
||||
device: str = "cuda:0"
|
||||
model_weights_url: str = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth"
|
||||
|
||||
def __post_init__(self):
|
||||
if not os.path.exists(self.model_path):
|
||||
self.download_model_weights(self.model_path)
|
||||
|
||||
self.model = sam_model_registry["vit_b"](
|
||||
checkpoint=self.model_path
|
||||
)
|
||||
self.model = self.model.to(self.device)
|
||||
self.model.eval()
|
||||
|
||||
def download_model_weights(self, model_path: str):
|
||||
"""
|
||||
Downloads the model weights from the specified URL and saves them to the given file path.
|
||||
|
||||
Args:
|
||||
model_path (str): The file path where the model weights will be saved.
|
||||
|
||||
Raises:
|
||||
Exception: If the model weights fail to download.
|
||||
"""
|
||||
response = requests.get(self.model_weights_url, stream=True)
|
||||
if response.status_code == 200:
|
||||
with open(model_path, "wb") as f:
|
||||
f.write(response.content)
|
||||
else:
|
||||
raise Exception("Failed to download model weights.")
|
||||
|
||||
def preprocess(self, img: np.ndarray) -> Tuple[Tensor, int, int]:
|
||||
"""
|
||||
Preprocesses the input image.
|
||||
|
||||
Args:
|
||||
img: The input image.
|
||||
|
||||
Returns:
|
||||
img_tensor: The preprocessed image tensor.
|
||||
H: The original height of the image.
|
||||
W: The original width of the image.
|
||||
"""
|
||||
if len(img.shape) == 2:
|
||||
img = np.repeat(img[:, :, None], 3, axis=-1)
|
||||
H, W, _ = img.shape
|
||||
img = transform.resize(
|
||||
img,
|
||||
(1024, 1024),
|
||||
order=3,
|
||||
preserve_range=True,
|
||||
anti_aliasing=True,
|
||||
).astype(np.uint8)
|
||||
img = img - img.min() / np.clip(
|
||||
img.max() - img.min(), a_min=1e-8, a_max=None
|
||||
)
|
||||
img = torch.tensor(img).float().permute(2, 0, 1).unsqueeze(0)
|
||||
return img, H, W
|
||||
|
||||
@torch.no_grad()
|
||||
def run(self, img: np.ndarray, box: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
Runs the semantic segmentation on the input image within the specified bounding box.
|
||||
|
||||
Args:
|
||||
img: The input image.
|
||||
box: The bounding box coordinates (x1, y1, x2, y2).
|
||||
|
||||
Returns:
|
||||
medsam_seg: The segmented image.
|
||||
"""
|
||||
img_tensor, H, W = self.preprocess(img)
|
||||
img_tensor = img_tensor.to(self.device)
|
||||
box_1024 = box / np.array([W, H, W, H]) * 1024
|
||||
img = self.model.image_encoder(img_tensor)
|
||||
|
||||
box_torch = torch.as_tensor(
|
||||
box_1024, dtype=torch.float, device=img_tensor.device
|
||||
)
|
||||
|
||||
if len(box_torch.shape) == 2:
|
||||
box_torch = box_torch[:, None, :]
|
||||
|
||||
sparse_embeddings, dense_embeddings = (
|
||||
self.model.prompt_encoder(
|
||||
points=None,
|
||||
boxes=box_torch,
|
||||
masks=None,
|
||||
)
|
||||
)
|
||||
|
||||
low_res_logits, _ = self.model.mask_decoder(
|
||||
image_embeddings=img,
|
||||
image_pe=self.model.prompt_encoder.get_dense_pe(),
|
||||
sparse_prompt_embeddings=sparse_embeddings,
|
||||
dense_prompt_embeddings=dense_embeddings,
|
||||
multimask_output=False,
|
||||
)
|
||||
|
||||
low_res_pred = torch.sigmoid(low_res_logits)
|
||||
low_res_pred = F.interpolate(
|
||||
low_res_pred,
|
||||
size=(H, W),
|
||||
mode="bilinear",
|
||||
align_corners=False,
|
||||
)
|
||||
low_res_pred = low_res_pred.squeeze().cpu().numpy()
|
||||
medsam_seg = (low_res_pred > 0.5).astype(np.uint8)
|
||||
|
||||
return medsam_seg
|
@ -0,0 +1,19 @@
|
||||
import requests
|
||||
|
||||
def download_weights_from_url(url: str, save_path: str = "models/weights.pth"):
|
||||
"""
|
||||
Downloads model weights from the given URL and saves them to the specified path.
|
||||
|
||||
Args:
|
||||
url (str): The URL from which to download the model weights.
|
||||
save_path (str, optional): The path where the downloaded weights should be saved.
|
||||
Defaults to "models/weights.pth".
|
||||
"""
|
||||
response = requests.get(url, stream=True)
|
||||
response.raise_for_status()
|
||||
|
||||
with open(save_path, "wb") as f:
|
||||
for chunk in response.iter_content(chunk_size=8192):
|
||||
f.write(chunk)
|
||||
|
||||
print(f"Model weights downloaded and saved to {save_path}")
|
Loading…
Reference in new issue