[BUGFIX][Odin]

pull/362/head
Kye 12 months ago
parent 7472207d65
commit 403bed61fe

@@ -19,7 +19,7 @@ va = Agent(llm=llm, ai_name="VA")
# Create a company
company = Company(
-    org_chart = [[dev, va]],
+    org_chart=[[dev, va]],
    shared_instructions="Do your best",
    ceo=ceo,
)

@@ -73,6 +73,7 @@ psutil = "*"
ultralytics = "*"
timm = "*"
supervision = "*"
+scikit-image = "*"

@@ -61,4 +61,5 @@ pre-commit==3.2.2
peft
psutil
ultralytics
-supervision
+supervision
+scikit-image

@@ -0,0 +1,140 @@
import os
from dataclasses import dataclass
from typing import Tuple

import numpy as np
import requests
import torch
import torch.nn.functional as F
from segment_anything import sam_model_registry  # needed for the "vit_b" builder below
from skimage import transform
from torch import Tensor


@dataclass
class MedicalSAM:
    """
    MedicalSAM class for performing semantic segmentation on medical images using the SAM model.

    Attributes:
        model_path (str): The file path to the model weights.
        device (str): The device to run the model on (default is "cuda:0").
        model_weights_url (str): The URL to download the model weights from.

    Methods:
        __post_init__(): Initializes the MedicalSAM object.
        download_model_weights(model_path: str): Downloads the model weights from the specified URL and saves them to the given file path.
        preprocess(img): Preprocesses the input image.
        run(img, box): Runs the semantic segmentation on the input image within the specified bounding box.
    """

    model_path: str
    device: str = "cuda:0"
    model_weights_url: str = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth"

    def __post_init__(self):
        if not os.path.exists(self.model_path):
            self.download_model_weights(self.model_path)

        self.model = sam_model_registry["vit_b"](checkpoint=self.model_path)
        self.model = self.model.to(self.device)
        self.model.eval()

    def download_model_weights(self, model_path: str):
        """
        Downloads the model weights from the specified URL and saves them to the given file path.

        Args:
            model_path (str): The file path where the model weights will be saved.

        Raises:
            Exception: If the model weights fail to download.
        """
        response = requests.get(self.model_weights_url, stream=True)

        if response.status_code == 200:
            with open(model_path, "wb") as f:
                f.write(response.content)
        else:
            raise Exception("Failed to download model weights.")

    def preprocess(self, img: np.ndarray) -> Tuple[Tensor, int, int]:
        """
        Preprocesses the input image.

        Args:
            img: The input image.

        Returns:
            img_tensor: The preprocessed image tensor.
            H: The original height of the image.
            W: The original width of the image.
        """
        # Grayscale inputs are stacked to three channels to match SAM's expected input.
        if len(img.shape) == 2:
            img = np.repeat(img[:, :, None], 3, axis=-1)
        H, W, _ = img.shape

        img = transform.resize(
            img,
            (1024, 1024),
            order=3,
            preserve_range=True,
            anti_aliasing=True,
        ).astype(np.uint8)

        # Min-max normalize to [0, 1]; the numerator must be parenthesized.
        img = (img - img.min()) / np.clip(
            img.max() - img.min(), a_min=1e-8, a_max=None
        )
        img = torch.tensor(img).float().permute(2, 0, 1).unsqueeze(0)

        return img, H, W
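
    # preprocess yields a (1, 3, 1024, 1024) float tensor plus the original (H, W);
    # run() below uses H and W to rescale the box prompt into the 1024-pixel frame
    # and to upsample the predicted mask back to the source resolution.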
    @torch.no_grad()
    def run(self, img: np.ndarray, box: np.ndarray) -> np.ndarray:
        """
        Runs the semantic segmentation on the input image within the specified bounding box.

        Args:
            img: The input image.
            box: The bounding box coordinates (x1, y1, x2, y2).

        Returns:
            medsam_seg: The segmented image.
        """
        img_tensor, H, W = self.preprocess(img)
        img_tensor = img_tensor.to(self.device)

        # Rescale the box from original image coordinates to the 1024x1024 input frame.
        box_1024 = box / np.array([W, H, W, H]) * 1024

        image_embedding = self.model.image_encoder(img_tensor)

        box_torch = torch.as_tensor(
            box_1024, dtype=torch.float, device=img_tensor.device
        )
        if len(box_torch.shape) == 2:
            box_torch = box_torch[:, None, :]  # (B, 4) -> (B, 1, 4)

        sparse_embeddings, dense_embeddings = self.model.prompt_encoder(
            points=None,
            boxes=box_torch,
            masks=None,
        )

        low_res_logits, _ = self.model.mask_decoder(
            image_embeddings=image_embedding,
            image_pe=self.model.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=False,
        )

        low_res_pred = torch.sigmoid(low_res_logits)
        low_res_pred = F.interpolate(
            low_res_pred,
            size=(H, W),
            mode="bilinear",
            align_corners=False,
        )
        low_res_pred = low_res_pred.squeeze().cpu().numpy()
        medsam_seg = (low_res_pred > 0.5).astype(np.uint8)

        return medsam_seg
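
For reference, a minimal usage sketch of the new class. The image path, box coordinates, and checkpoint location below are illustrative assumptions, not values from this commit:

import numpy as np
from skimage import io

sam = MedicalSAM(model_path="models/sam_vit_b_01ec64.pth")

img = io.imread("ct_slice.png")        # placeholder medical image
box = np.array([[50, 60, 200, 220]])   # one (x1, y1, x2, y2) prompt box
mask = sam.run(img, box)               # binary uint8 mask at the original resolution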

@@ -1,47 +1,53 @@
import os

import supervision as sv
-from ultraanalytics import YOLO
+from ultralytics import YOLO
from tqdm import tqdm

from swarms.models.base_llm import AbstractLLM
+from swarms.utils.download_weights_from_url import download_weights_from_url


class Odin(AbstractLLM):
    """
    Odin class represents an object detection and tracking model.

    Args:
-        source_weights_path (str): Path to the weights file for the object detection model.
-        source_video_path (str): Path to the source video file.
-        target_video_path (str): Path to save the output video file.
-        confidence_threshold (float): Confidence threshold for object detection.
-        iou_threshold (float): Intersection over Union (IoU) threshold for object detection.
-
-    Attributes:
-        source_weights_path (str): Path to the weights file for the object detection model.
-        source_video_path (str): Path to the source video file.
-        target_video_path (str): Path to save the output video file.
-        confidence_threshold (float): Confidence threshold for object detection.
-        iou_threshold (float): Intersection over Union (IoU) threshold for object detection.
+        source_weights_path (str): The file path to the YOLO model weights.
+        confidence_threshold (float): The confidence threshold for object detection.
+        iou_threshold (float): The intersection over union (IOU) threshold for object detection.
+
+    Example:
+        >>> odin = Odin(
+        ...     source_weights_path="yolo.weights",
+        ...     confidence_threshold=0.3,
+        ...     iou_threshold=0.7,
+        ... )
+        >>> odin.run(video="input.mp4")
    """

    def __init__(
        self,
-        source_weights_path: str = None,
-        target_video_path: str = None,
+        source_weights_path: str = "yolo.weights",
        confidence_threshold: float = 0.3,
        iou_threshold: float = 0.7,
    ):
        super(Odin, self).__init__()
        self.source_weights_path = source_weights_path
-        self.target_video_path = target_video_path
        self.confidence_threshold = confidence_threshold
        self.iou_threshold = iou_threshold
+
+        if not os.path.exists(self.source_weights_path):
+            download_weights_from_url(
+                url=source_weights_path, save_path=self.source_weights_path
+            )

-    def run(self, video_path: str, *args, **kwargs):
+    def run(self, video: str, *args, **kwargs):
        """
        Runs the object detection and tracking algorithm on the specified video.

        Args:
-            video_path (str): The path to the input video file.
+            video (str): The path to the input video file.
            *args: Additional positional arguments.
            **kwargs: Additional keyword arguments.
@@ -53,14 +59,14 @@ class Odin(AbstractLLM):
        tracker = sv.ByteTrack()
        box_annotator = sv.BoxAnnotator()
        frame_generator = sv.get_video_frames_generator(
-            source_path=self.source_video_path
+            source_path=video
        )
        video_info = sv.VideoInfo.from_video_path(
-            video_path=video_path
+            video_path=video
        )
        with sv.VideoSink(
-            target_path=self.target_video_path, video_info=video_info
+            target_path="output.mp4", video_info=video_info  # placeholder output path
        ) as sink:
            for frame in tqdm(
                frame_generator, total=video_info.total_frames
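
The hunk ends inside the frame loop. For orientation, a hedged sketch of how such a loop body typically continues with supervision and ultralytics; the `model` variable is assumed to be a YOLO instance created earlier in run(), and annotator signatures vary across supervision versions:

            # Hypothetical loop body (not part of this commit).
            results = model(
                frame,
                conf=self.confidence_threshold,
                iou=self.iou_threshold,
            )[0]
            detections = sv.Detections.from_ultralytics(results)
            detections = tracker.update_with_detections(detections)
            annotated = box_annotator.annotate(
                scene=frame.copy(), detections=detections
            )
            sink.write_frame(frame=annotated)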

@@ -5,11 +5,13 @@ from swarms.structs.agent import Agent
from swarms.utils.logger import logger
from swarms.structs.conversation import Conversation


+@dataclass
class Company:
    """
    Represents a company with a hierarchical organizational structure.
    """

    org_chart: List[List[Agent]]
    shared_instructions: Optional[str] = None
    ceo: Optional[Agent] = None
@@ -171,5 +173,3 @@ class Company:
            )
            print(f"{task_description} is being executed")
            agent.run(task_description)
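
For context, a hedged sketch of how the dataclass-style Company might be driven, assuming a company built as in the first hunk; the per-agent run mirrors the agent.run(...) call shown above, but the traversal itself is an illustrative assumption:

# Hypothetical traversal of the org chart (a sketch, not the class's actual method).
for tier in company.org_chart:       # each inner list is one level of the hierarchy
    for agent in tier:
        agent.run("Do your best")    # per-agent run, as in the diff above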

@@ -0,0 +1,19 @@
import os

import requests


def download_weights_from_url(url: str, save_path: str = "models/weights.pth"):
    """
    Downloads model weights from the given URL and saves them to the specified path.

    Args:
        url (str): The URL from which to download the model weights.
        save_path (str, optional): The path where the downloaded weights should be saved.
            Defaults to "models/weights.pth".
    """
    # Create the destination directory if it does not already exist.
    os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)

    response = requests.get(url, stream=True)
    response.raise_for_status()

    with open(save_path, "wb") as f:
        for chunk in response.iter_content(chunk_size=8192):
            f.write(chunk)

    print(f"Model weights downloaded and saved to {save_path}")