diff --git a/playground/structs/company_example.py b/playground/structs/company_example.py index df2d2506..72396c61 100644 --- a/playground/structs/company_example.py +++ b/playground/structs/company_example.py @@ -19,7 +19,7 @@ va = Agent(llm=llm, ai_name="VA") # Create a company company = Company( - org_chart = [[dev, va]], + org_chart=[[dev, va]], shared_instructions="Do your best", ceo=ceo, ) diff --git a/pyproject.toml b/pyproject.toml index a6e3e64c..fccb186b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -73,6 +73,7 @@ psutil = "*" ultralytics = "*" timm = "*" supervision = "*" +scikit-image = "*" diff --git a/requirements.txt b/requirements.txt index d7befb85..ab78bb36 100644 --- a/requirements.txt +++ b/requirements.txt @@ -61,4 +61,5 @@ pre-commit==3.2.2 peft psutil ultralytics -supervision \ No newline at end of file +supervision +scikit-image \ No newline at end of file diff --git a/swarms/models/medical_sam.py b/swarms/models/medical_sam.py new file mode 100644 index 00000000..01e77c04 --- /dev/null +++ b/swarms/models/medical_sam.py @@ -0,0 +1,140 @@ +import os +from dataclasses import dataclass +from typing import Tuple + +import numpy as np +import requests +import torch +import torch.nn.functional as F +from skimage import transform +from torch import Tensor + + +@dataclass +class MedicalSAM: + """ + MedicalSAM class for performing semantic segmentation on medical images using the SAM model. + + Attributes: + model_path (str): The file path to the model weights. + device (str): The device to run the model on (default is "cuda:0"). + model_weights_url (str): The URL to download the model weights from. + + Methods: + __post_init__(): Initializes the MedicalSAM object. + download_model_weights(model_path: str): Downloads the model weights from the specified URL and saves them to the given file path. + preprocess(img): Preprocesses the input image. + run(img, box): Runs the semantic segmentation on the input image within the specified bounding box. + + """ + + model_path: str + device: str = "cuda:0" + model_weights_url: str = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth" + + def __post_init__(self): + if not os.path.exists(self.model_path): + self.download_model_weights(self.model_path) + + self.model = sam_model_registry["vit_b"]( + checkpoint=self.model_path + ) + self.model = self.model.to(self.device) + self.model.eval() + + def download_model_weights(self, model_path: str): + """ + Downloads the model weights from the specified URL and saves them to the given file path. + + Args: + model_path (str): The file path where the model weights will be saved. + + Raises: + Exception: If the model weights fail to download. + """ + response = requests.get(self.model_weights_url, stream=True) + if response.status_code == 200: + with open(model_path, "wb") as f: + f.write(response.content) + else: + raise Exception("Failed to download model weights.") + + def preprocess(self, img: np.ndarray) -> Tuple[Tensor, int, int]: + """ + Preprocesses the input image. + + Args: + img: The input image. + + Returns: + img_tensor: The preprocessed image tensor. + H: The original height of the image. + W: The original width of the image. + """ + if len(img.shape) == 2: + img = np.repeat(img[:, :, None], 3, axis=-1) + H, W, _ = img.shape + img = transform.resize( + img, + (1024, 1024), + order=3, + preserve_range=True, + anti_aliasing=True, + ).astype(np.uint8) + img = img - img.min() / np.clip( + img.max() - img.min(), a_min=1e-8, a_max=None + ) + img = torch.tensor(img).float().permute(2, 0, 1).unsqueeze(0) + return img, H, W + + @torch.no_grad() + def run(self, img: np.ndarray, box: np.ndarray) -> np.ndarray: + """ + Runs the semantic segmentation on the input image within the specified bounding box. + + Args: + img: The input image. + box: The bounding box coordinates (x1, y1, x2, y2). + + Returns: + medsam_seg: The segmented image. + """ + img_tensor, H, W = self.preprocess(img) + img_tensor = img_tensor.to(self.device) + box_1024 = box / np.array([W, H, W, H]) * 1024 + img = self.model.image_encoder(img_tensor) + + box_torch = torch.as_tensor( + box_1024, dtype=torch.float, device=img_tensor.device + ) + + if len(box_torch.shape) == 2: + box_torch = box_torch[:, None, :] + + sparse_embeddings, dense_embeddings = ( + self.model.prompt_encoder( + points=None, + boxes=box_torch, + masks=None, + ) + ) + + low_res_logits, _ = self.model.mask_decoder( + image_embeddings=img, + image_pe=self.model.prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + multimask_output=False, + ) + + low_res_pred = torch.sigmoid(low_res_logits) + low_res_pred = F.interpolate( + low_res_pred, + size=(H, W), + mode="bilinear", + align_corners=False, + ) + low_res_pred = low_res_pred.squeeze().cpu().numpy() + medsam_seg = (low_res_pred > 0.5).astype(np.uint8) + + return medsam_seg diff --git a/swarms/models/odin.py b/swarms/models/odin.py index 1ab09893..a6228159 100644 --- a/swarms/models/odin.py +++ b/swarms/models/odin.py @@ -1,47 +1,53 @@ +import os import supervision as sv -from ultraanalytics import YOLO +from ultralytics import YOLO from tqdm import tqdm from swarms.models.base_llm import AbstractLLM - +from swarms.utils.download_weights_from_url import download_weights_from_url class Odin(AbstractLLM): """ Odin class represents an object detection and tracking model. - Args: - source_weights_path (str): Path to the weights file for the object detection model. - source_video_path (str): Path to the source video file. - target_video_path (str): Path to save the output video file. - confidence_threshold (float): Confidence threshold for object detection. - iou_threshold (float): Intersection over Union (IoU) threshold for object detection. - Attributes: - source_weights_path (str): Path to the weights file for the object detection model. - source_video_path (str): Path to the source video file. - target_video_path (str): Path to save the output video file. - confidence_threshold (float): Confidence threshold for object detection. - iou_threshold (float): Intersection over Union (IoU) threshold for object detection. + source_weights_path (str): The file path to the YOLO model weights. + confidence_threshold (float): The confidence threshold for object detection. + iou_threshold (float): The intersection over union (IOU) threshold for object detection. + + Example: + >>> odin = Odin( + ... source_weights_path="yolo.weights", + ... confidence_threshold=0.3, + ... iou_threshold=0.7, + ... ) + >>> odin.run(video="input.mp4") + + """ def __init__( self, - source_weights_path: str = None, - target_video_path: str = None, + source_weights_path: str = "yolo.weights", confidence_threshold: float = 0.3, iou_threshold: float = 0.7, ): super(Odin, self).__init__() self.source_weights_path = source_weights_path - self.target_video_path = target_video_path self.confidence_threshold = confidence_threshold self.iou_threshold = iou_threshold + + if not os.path.exists(self.source_weights_path): + download_weights_from_url( + url=source_weights_path, save_path=self.source_weights_path + ) + - def run(self, video_path: str, *args, **kwargs): + def run(self, video: str, *args, **kwargs): """ Runs the object detection and tracking algorithm on the specified video. Args: - video_path (str): The path to the input video file. + video (str): The path to the input video file. *args: Additional positional arguments. **kwargs: Additional keyword arguments. @@ -53,14 +59,14 @@ class Odin(AbstractLLM): tracker = sv.ByteTrack() box_annotator = sv.BoxAnnotator() frame_generator = sv.get_video_frames_generator( - source_path=self.source_video_path + source_path=self.source_video ) - video_info = sv.VideoInfo.from_video_path( - video_path=video_path + video_info = sv.VideoInfo.from_video( + video=video ) with sv.VideoSink( - target_path=self.target_video_path, video_info=video_info + target_path=self.target_video, video_info=video_info ) as sink: for frame in tqdm( frame_generator, total=video_info.total_frames diff --git a/swarms/structs/company.py b/swarms/structs/company.py index 80fe3eef..11b6d61f 100644 --- a/swarms/structs/company.py +++ b/swarms/structs/company.py @@ -5,11 +5,13 @@ from swarms.structs.agent import Agent from swarms.utils.logger import logger from swarms.structs.conversation import Conversation + @dataclass class Company: """ Represents a company with a hierarchical organizational structure. """ + org_chart: List[List[Agent]] shared_instructions: str = None ceo: Optional[Agent] = None @@ -171,5 +173,3 @@ class Company: ) print(f"{task_description} is being executed") agent.run(task_description) - - diff --git a/swarms/utils/download_weights_from_url.py b/swarms/utils/download_weights_from_url.py new file mode 100644 index 00000000..bc93d699 --- /dev/null +++ b/swarms/utils/download_weights_from_url.py @@ -0,0 +1,19 @@ +import requests + +def download_weights_from_url(url: str, save_path: str = "models/weights.pth"): + """ + Downloads model weights from the given URL and saves them to the specified path. + + Args: + url (str): The URL from which to download the model weights. + save_path (str, optional): The path where the downloaded weights should be saved. + Defaults to "models/weights.pth". + """ + response = requests.get(url, stream=True) + response.raise_for_status() + + with open(save_path, "wb") as f: + for chunk in response.iter_content(chunk_size=8192): + f.write(chunk) + + print(f"Model weights downloaded and saved to {save_path}") \ No newline at end of file