parent
7472207d65
commit
403bed61fe
@ -0,0 +1,140 @@
|
|||||||
|
import os
from dataclasses import dataclass
from typing import Tuple

import numpy as np
import requests
import torch
import torch.nn.functional as F
from skimage import transform
from torch import Tensor
|
||||||
|
|
||||||
|
@dataclass
class MedicalSAM:
    """Semantic segmentation of medical images with a SAM (vit_b) model.

    Weights are downloaded automatically at construction time if they are not
    already present at ``model_path``.

    Attributes:
        model_path (str): File path where the model weights live (or will be
            saved to after download).
        device (str): Torch device the model runs on. Defaults to "cuda:0".
        model_weights_url (str): URL the weights are fetched from when
            ``model_path`` does not exist.
    """

    model_path: str
    device: str = "cuda:0"
    model_weights_url: str = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth"

    def __post_init__(self):
        """Download weights if needed, then build and freeze the SAM model."""
        if not os.path.exists(self.model_path):
            self.download_model_weights(self.model_path)

        # NOTE(review): sam_model_registry is not imported anywhere in this
        # file — it presumably comes from the `segment_anything` package.
        # Confirm the module-level import exists.
        self.model = sam_model_registry["vit_b"](
            checkpoint=self.model_path
        )
        self.model = self.model.to(self.device)
        self.model.eval()  # inference only; gradients are never needed

    def download_model_weights(self, model_path: str):
        """Download the model weights from ``self.model_weights_url``.

        Args:
            model_path (str): The file path where the model weights will be
                saved.

        Raises:
            Exception: If the model weights fail to download (non-200 status).
        """
        response = requests.get(self.model_weights_url, stream=True)
        if response.status_code == 200:
            with open(model_path, "wb") as f:
                # Stream in chunks rather than reading response.content, so
                # the whole (~375 MB) checkpoint is never held in memory —
                # stream=True was requested but the streamed body was not
                # being used.
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
        else:
            raise Exception("Failed to download model weights.")

    def preprocess(self, img: np.ndarray) -> Tuple[Tensor, int, int]:
        """Resize, normalize, and tensorize an input image.

        Args:
            img: The input image, either ``(H, W)`` grayscale or
                ``(H, W, 3)`` RGB.

        Returns:
            img_tensor: The preprocessed ``(1, 3, 1024, 1024)`` float tensor,
                min-max scaled to ``[0, 1]``.
            H: The original height of the image.
            W: The original width of the image.
        """
        if len(img.shape) == 2:
            # Replicate the single channel into the 3 channels SAM expects.
            img = np.repeat(img[:, :, None], 3, axis=-1)
        H, W, _ = img.shape
        img = transform.resize(
            img,
            (1024, 1024),
            order=3,  # bicubic interpolation
            preserve_range=True,
            anti_aliasing=True,
        ).astype(np.uint8)
        # Min-max normalize to [0, 1]. BUG FIX: the original expression
        # `img - img.min() / np.clip(...)` divided only the minimum because
        # `/` binds tighter than `-`; the subtraction must be parenthesized.
        img = (img - img.min()) / np.clip(
            img.max() - img.min(), a_min=1e-8, a_max=None
        )
        img = torch.tensor(img).float().permute(2, 0, 1).unsqueeze(0)
        return img, H, W

    @torch.no_grad()
    def run(self, img: np.ndarray, box: np.ndarray) -> np.ndarray:
        """Segment the region of ``img`` selected by a bounding box.

        Args:
            img: The input image.
            box: The bounding box coordinates ``(x1, y1, x2, y2)`` in the
                original image's pixel space.

        Returns:
            medsam_seg: A binary ``(H, W)`` uint8 mask of the segmented
                region.
        """
        img_tensor, H, W = self.preprocess(img)
        img_tensor = img_tensor.to(self.device)
        # Rescale the box from original pixel coordinates into the
        # 1024x1024 space the image was resized to.
        box_1024 = box / np.array([W, H, W, H]) * 1024
        # Named `image_embedding` (the original reused `img`, shadowing the
        # input argument).
        image_embedding = self.model.image_encoder(img_tensor)

        box_torch = torch.as_tensor(
            box_1024, dtype=torch.float, device=img_tensor.device
        )
        if len(box_torch.shape) == 2:
            # The prompt encoder expects boxes shaped (B, 1, 4).
            box_torch = box_torch[:, None, :]

        sparse_embeddings, dense_embeddings = self.model.prompt_encoder(
            points=None,
            boxes=box_torch,
            masks=None,
        )

        low_res_logits, _ = self.model.mask_decoder(
            image_embeddings=image_embedding,
            image_pe=self.model.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=False,
        )

        low_res_pred = torch.sigmoid(low_res_logits)
        # Upsample the low-resolution probability mask back to the original
        # image size before thresholding.
        low_res_pred = F.interpolate(
            low_res_pred,
            size=(H, W),
            mode="bilinear",
            align_corners=False,
        )
        low_res_pred = low_res_pred.squeeze().cpu().numpy()
        medsam_seg = (low_res_pred > 0.5).astype(np.uint8)
        return medsam_seg
|
@ -0,0 +1,19 @@
|
|||||||
|
import requests
|
||||||
|
|
||||||
|
def download_weights_from_url(url: str, save_path: str = "models/weights.pth"):
    """Downloads model weights from the given URL and saves them to the specified path.

    The parent directory of ``save_path`` is created if it does not already
    exist, so the default "models/weights.pth" works on a fresh checkout.

    Args:
        url (str): The URL from which to download the model weights.
        save_path (str, optional): The path where the downloaded weights
            should be saved. Defaults to "models/weights.pth".

    Raises:
        requests.HTTPError: If the server responds with an error status.
    """
    import os

    response = requests.get(url, stream=True)
    response.raise_for_status()

    # Ensure the destination directory exists; otherwise open() below raises
    # FileNotFoundError for the default "models/" location.
    os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)

    with open(save_path, "wb") as f:
        # Stream the body in 8 KiB chunks so large checkpoints are never
        # fully buffered in memory.
        for chunk in response.iter_content(chunk_size=8192):
            f.write(chunk)

    print(f"Model weights downloaded and saved to {save_path}")
|
Loading…
Reference in new issue