# swarms/swarms/models/medical_sam.py
# (145 lines, 4.5 KiB)
import os
from dataclasses import dataclass
from typing import Tuple
import numpy as np
import requests
import torch
import torch.nn.functional as F
from skimage import transform
from torch import Tensor
class _SamModelRegistry(dict):
    """Stub standing in for ``segment_anything.sam_model_registry``.

    The real registry maps architecture names (e.g. ``"vit_b"``) to model
    builder callables. The previous stub was a plain function, but this
    module subscripts the registry (``sam_model_registry["vit_b"]``), which
    would raise ``TypeError: 'function' object is not subscriptable``.
    This dict subclass keeps the subscription interface and raises a clear
    error explaining what is missing instead.
    """

    def __missing__(self, key):
        # Triggered on any lookup since the stub holds no builders.
        raise NotImplementedError(
            "sam_model_registry is a stub; install and import "
            f"`segment_anything` to load the {key!r} SAM model."
        )


sam_model_registry = _SamModelRegistry()
@dataclass
class MedicalSAM:
    """
    Semantic segmentation of medical images with a SAM (Segment Anything) model.

    Attributes:
        model_path (str): File path to the model checkpoint.
        device (str): Torch device the model runs on (default "cuda:0").
        model_weights_url (str): URL the weights are downloaded from when
            ``model_path`` does not exist yet.

    Methods:
        __post_init__(): Downloads weights if needed and loads the model.
        download_model_weights(model_path): Fetches the checkpoint from
            ``model_weights_url`` and writes it to ``model_path``.
        preprocess(img): Resizes/normalizes an image into a model-ready tensor.
        run(img, box): Segments ``img`` inside the bounding box ``box``.
    """

    model_path: str
    device: str = "cuda:0"
    model_weights_url: str = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth"

    def __post_init__(self):
        """Download the checkpoint if missing, then load the model in eval mode."""
        if not os.path.exists(self.model_path):
            self.download_model_weights(self.model_path)

        self.model = sam_model_registry["vit_b"](
            checkpoint=self.model_path
        )
        self.model = self.model.to(self.device)
        self.model.eval()

    def download_model_weights(self, model_path: str):
        """
        Download the model weights from ``model_weights_url`` to ``model_path``.

        Args:
            model_path (str): File path where the model weights will be saved.

        Raises:
            Exception: If the download does not return HTTP 200.
        """
        response = requests.get(self.model_weights_url, stream=True)
        if response.status_code == 200:
            with open(model_path, "wb") as f:
                # Stream the checkpoint to disk in chunks. The previous code
                # requested stream=True but then read `response.content`,
                # buffering the entire (several hundred MB) file in memory.
                for chunk in response.iter_content(chunk_size=1 << 20):
                    f.write(chunk)
        else:
            raise Exception("Failed to download model weights.")

    def preprocess(self, img: np.ndarray) -> Tuple[Tensor, int, int]:
        """
        Preprocess an input image for the model.

        Grayscale input is expanded to 3 channels, the image is resized to
        1024x1024, min-max normalized to [0, 1], and converted to a
        (1, 3, 1024, 1024) float tensor.

        Args:
            img: Input image as an (H, W) or (H, W, C) numpy array.

        Returns:
            img_tensor: Preprocessed (1, 3, 1024, 1024) float tensor.
            H: The original height of the image.
            W: The original width of the image.
        """
        if len(img.shape) == 2:
            # Grayscale -> 3-channel by repeating the single plane.
            img = np.repeat(img[:, :, None], 3, axis=-1)
        H, W, _ = img.shape

        img = transform.resize(
            img,
            (1024, 1024),
            order=3,
            preserve_range=True,
            anti_aliasing=True,
        ).astype(np.uint8)
        # Min-max normalize to [0, 1]. The subtraction must be parenthesized:
        # the previous `img - img.min() / np.clip(...)` divided only
        # `img.min()` due to operator precedence, leaving the image
        # effectively unnormalized.
        img = (img - img.min()) / np.clip(
            img.max() - img.min(), a_min=1e-8, a_max=None
        )
        img = torch.tensor(img).float().permute(2, 0, 1).unsqueeze(0)
        return img, H, W

    @torch.no_grad()
    def run(self, img: np.ndarray, box: np.ndarray) -> np.ndarray:
        """
        Run semantic segmentation on ``img`` restricted to ``box``.

        Args:
            img: The input image.
            box: Bounding box coordinates (x1, y1, x2, y2) in original
                image space.

        Returns:
            medsam_seg: Binary (H, W) uint8 segmentation mask.
        """
        img_tensor, H, W = self.preprocess(img)
        img_tensor = img_tensor.to(self.device)
        # Scale the box from original image coordinates into the 1024x1024
        # space the encoder sees.
        box_1024 = box / np.array([W, H, W, H]) * 1024
        # Named `image_embedding` (was reusing `img`) for clarity.
        image_embedding = self.model.image_encoder(img_tensor)

        box_torch = torch.as_tensor(
            box_1024, dtype=torch.float, device=img_tensor.device
        )
        if len(box_torch.shape) == 2:
            box_torch = box_torch[:, None, :]  # (B, 4) -> (B, 1, 4)

        sparse_embeddings, dense_embeddings = (
            self.model.prompt_encoder(
                points=None,
                boxes=box_torch,
                masks=None,
            )
        )

        low_res_logits, _ = self.model.mask_decoder(
            image_embeddings=image_embedding,
            image_pe=self.model.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=False,
        )

        low_res_pred = torch.sigmoid(low_res_logits)
        # Upsample the low-resolution mask back to the original image size.
        low_res_pred = F.interpolate(
            low_res_pred,
            size=(H, W),
            mode="bilinear",
            align_corners=False,
        )
        low_res_pred = low_res_pred.squeeze().cpu().numpy()
        medsam_seg = (low_res_pred > 0.5).astype(np.uint8)
        return medsam_seg