pull/378/head
Kye 11 months ago
parent c21af37f64
commit e8681b223c

@ -3,7 +3,7 @@ import os
from dotenv import load_dotenv from dotenv import load_dotenv
# Import the OpenAIChat model and the Agent struct # Import the OpenAIChat model and the Agent struct
from swarms import OpenAIChat, Agent from swarms import Agent, OpenAIChat
# Load the environment variables # Load the environment variables
load_dotenv() load_dotenv()
@ -22,7 +22,7 @@ llm = OpenAIChat(
## Initialize the workflow ## Initialize the workflow
agent = Agent( agent = Agent(
llm=llm, llm=llm,
max_loops=1, max_loops=4,
autosave=True, autosave=True,
dashboard=True, dashboard=True,
) )

@ -0,0 +1,33 @@
import os
from dotenv import load_dotenv
# Import the HuggingfaceLLM model and the Agent struct
from swarms import Agent, HuggingfaceLLM
# Load the environment variables
load_dotenv()
# Get the API key from the environment
# NOTE(review): api_key is read here but is not passed to anything below.
api_key = os.environ.get("OPENAI_API_KEY")
# Initialize the language model
llm = HuggingfaceLLM(
    model_id="codellama/CodeLlama-70b-hf",
    max_length=4000,
    quantize=True,
    temperature=0.5,
)
## Initialize the workflow
agent = Agent(
    llm=llm,
    max_loops="auto",
    system_prompt=None,
    autosave=True,
    dashboard=True,
    # NOTE(review): the tools list holds a single None entry — confirm
    # this is intended rather than an empty list or omitting the argument.
    tools=[None],
)
# Run the workflow on a task
agent.run("Generate a 10,000 word blog on health and wellness.")

@ -0,0 +1,13 @@
from swarms import RoboflowMultiModal

# Initialize the model.
# `version` has no default in RoboflowMultiModal.__init__, so it must be
# supplied; "1" is a placeholder for your own model version.
model = RoboflowMultiModal(
    api_key="api",
    project_id="your project id",
    version="1",
    hosted=False,
)

# Run the model on an img
out = model("img.png")

@ -1,7 +0,0 @@
from swarms import AutoScaler
auto_scaler = AutoScaler()
auto_scaler.start()
for i in range(100):
auto_scaler.add_task(f"Task {i}")

@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"
[tool.poetry] [tool.poetry]
name = "swarms" name = "swarms"
version = "3.9.4" version = "3.9.5"
description = "Swarms - Pytorch" description = "Swarms - Pytorch"
license = "MIT" license = "MIT"
authors = ["Kye Gomez <kye@apac.ai>"] authors = ["Kye Gomez <kye@apac.ai>"]
@ -70,6 +70,7 @@ timm = "*"
supervision = "*" supervision = "*"
scikit-image = "*" scikit-image = "*"
pinecone-client = "*" pinecone-client = "*"
roboflow = "*"

@ -1,4 +1,4 @@
gitorch==2.1.1 torch==2.1.1
transformers transformers
pandas==1.5.3 pandas==1.5.3
langchain==0.0.333 langchain==0.0.333
@ -50,4 +50,5 @@ psutil
ultralytics ultralytics
supervision supervision
scikit-image scikit-image
pinecone-client pinecone-client
roboflow

@ -0,0 +1,13 @@
#!/bin/bash
# Strip the ".py" extension from every Python file in the root directory.

# Change to the root directory; stop if that fails so files are never
# renamed in whatever directory the script happened to be started from.
cd / || exit 1

# Iterate over all the .py files in the directory
for file in *.py; do
    # When nothing matches, the glob stays as the literal "*.py"; skip it
    # instead of handing a non-existent name to mv.
    [ -e "$file" ] || continue

    # Get the base name of the file without the .py
    base_name=$(basename -- "$file" .py)

    # Rename the file to remove .py from the end ("--" protects against
    # file names that start with a dash).
    mv -- "$file" "${base_name}"
done

@ -50,6 +50,8 @@ from swarms.models.qwen import QwenVLMultiModal # noqa: E402
from swarms.models.clipq import CLIPQ # noqa: E402 from swarms.models.clipq import CLIPQ # noqa: E402
from swarms.models.kosmos_two import Kosmos # noqa: E402 from swarms.models.kosmos_two import Kosmos # noqa: E402
from swarms.models.fuyu import Fuyu # noqa: E402 from swarms.models.fuyu import Fuyu # noqa: E402
from swarms.models.roboflow_model import RoboflowMultiModal
from swarms.models.sam_supervision import SegmentAnythingMarkGenerator
# from swarms.models.dalle3 import Dalle3 # from swarms.models.dalle3 import Dalle3
# from swarms.models.distilled_whisperx import DistilWhisperModel # noqa: E402 # from swarms.models.distilled_whisperx import DistilWhisperModel # noqa: E402
@ -118,4 +120,6 @@ __all__ = [
"Kosmos", "Kosmos",
"Fuyu", "Fuyu",
"BaseEmbeddingModel", "BaseEmbeddingModel",
"RoboflowMultiModal",
"SegmentAnythingMarkGenerator",
] ]

@ -0,0 +1,81 @@
import cv2
from swarms.models.base_multimodal_model import BaseMultiModalModel
from swarms.models.sam_supervision import SegmentAnythingMarkGenerator
from swarms.utils.supervision_masking import refine_marks
from swarms.utils.supervision_visualizer import MarkVisualizer
from typing import Any
class GPT4VSAM(BaseMultiModalModel):
    """
    Multi-modal wrapper that runs a SegmentAnythingMarkGenerator over an image
    and then queries a visual language model (vlm) about that image.

    Args:
        vlm (BaseMultiModalModel): The visual language model that is called
            with the task and the loaded image.
        device (str, optional): Device handed to the mark generator.
            Defaults to "cuda".
        return_related_marks (bool, optional): Stored on the instance.
            Defaults to False.
        *args: Forwarded to the base class, the mark generator and the
            visualizer.
        **kwargs: Forwarded to the base class, the mark generator and the
            visualizer.

    Attributes:
        vlm (BaseMultiModalModel): The visual language model.
        device (str): The device passed at construction time.
        return_related_marks (bool): The flag passed at construction time.
        sam (SegmentAnythingMarkGenerator): Generator used to produce marks.
        visualizer (MarkVisualizer): Visualizer instance for marks.

    Methods:
        load_img(img: str) -> Any: Reads an image from a file path.
        __call__(task: str, img: str, *args, **kwargs) -> Any: Runs the
            mark pipeline and then the visual language model.
    """

    def __init__(
        self,
        vlm: BaseMultiModalModel,
        device: str = "cuda",
        return_related_marks: bool = False,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)

        # Plain configuration values.
        self.vlm = vlm
        self.device = device
        self.return_related_marks = return_related_marks

        # Helper objects; both receive the same extra arguments as the base
        # class initializer.
        mark_generator = SegmentAnythingMarkGenerator(
            device, *args, **kwargs
        )
        self.sam = mark_generator
        self.visualizer = MarkVisualizer(*args, **kwargs)

    def load_img(self, img: str) -> Any:
        """
        Read an image from disk with OpenCV.

        Args:
            img (str): The file path of the image.

        Returns:
            Any: Whatever ``cv2.imread`` returns for that path.
        """
        loaded = cv2.imread(img)
        return loaded

    def __call__(self, task: str, img: str, *args, **kwargs) -> Any:
        """
        Load the image, run mark generation and refinement on it, and then
        call the visual language model.

        Args:
            task (str): The task passed to the visual language model.
            img (str): The file path of the image.
            *args: Additional positional arguments for the vlm call.
            **kwargs: Additional keyword arguments for the vlm call.

        Returns:
            Any: The value returned by the visual language model.
        """
        image = self.load_img(img)

        # The refined marks are computed here but are not passed on to the
        # vlm call below.
        refine_marks(marks=self.sam(image=image))

        return self.vlm(task, image, *args, **kwargs)

@ -1,6 +1,6 @@
import os import os
import supervision as sv import supervision as sv
from ultralytics import YOLO from ultralytics_example import YOLO
from tqdm import tqdm from tqdm import tqdm
from swarms.models.base_llm import AbstractLLM from swarms.models.base_llm import AbstractLLM
from swarms.utils.download_weights_from_url import ( from swarms.utils.download_weights_from_url import (

@ -0,0 +1,64 @@
from typing import Union
from roboflow import Roboflow
from swarms.models.base_multimodal_model import BaseMultiModalModel
class RoboflowMultiModal(BaseMultiModalModel):
    """
    Wrapper around a Roboflow project version that runs inference on images.

    Args:
        api_key (str): The API key for Roboflow.
        project_id (str): The ID of the project.
        version (str): The version of the model.
        confidence (int, optional): The confidence threshold. Defaults to 50.
        overlap (int, optional): The overlap threshold. Defaults to 25.
        hosted (bool, optional): Passed as ``hosted`` to ``predict`` on every
            call. Defaults to False.
        *args: Forwarded to the base class and to ``Roboflow``.
        **kwargs: Forwarded to the base class and to ``Roboflow``.

    Attributes:
        model: The Roboflow model object, or None when setup failed.
    """

    def __init__(
        self,
        api_key: str,
        project_id: str,
        version: str,
        confidence: int = 50,
        overlap: int = 25,
        hosted: bool = False,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.api_key = api_key
        self.project_id = project_id
        self.version = version
        # Kept for callers that relied on the earlier misspelled attribute.
        self.verison = version
        self.confidence = confidence
        self.overlap = overlap
        self.hosted = hosted

        # Defined before the setup attempt so the attribute always exists,
        # even when the Roboflow client cannot be created.
        self.model = None

        # Setup is best-effort: failures are reported, not raised.
        try:
            rf = Roboflow(api_key=api_key, *args, **kwargs)
            project = rf.workspace().project(project_id)
            self.model = project.version(version).model
            self.model.confidence = confidence
            self.model.overlap = overlap
        except Exception as e:
            print(f"Error initializing RoboflowModel: {str(e)}")

    def __call__(self, img: Union[str, bytes]):
        """
        Runs inference on an image and retrieves predictions.

        Whether the image is treated as hosted is taken from the ``hosted``
        value given to the constructor.

        Args:
            img (Union[str, bytes]): The path to the image or the URL of the image.

        Returns:
            Optional[roboflow.Prediction]: The prediction or None if an error occurs.
        """
        if self.model is None:
            # Setup failed earlier; report that instead of an opaque
            # attribute error from calling predict on nothing.
            print("Error running inference: model is not initialized")
            return None

        try:
            prediction = self.model.predict(img, hosted=self.hosted)
            return prediction
        except Exception as e:
            print(f"Error running inference: {str(e)}")
            return None

@ -0,0 +1,116 @@
import cv2
import numpy as np
import supervision as sv
from PIL import Image
from transformers import (
pipeline,
SamModel,
SamProcessor,
SamImageProcessor,
)
from typing import Optional
from swarms.utils.supervision_masking import masks_to_marks
from swarms.models.base_multimodal_model import BaseMultiModalModel
class SegmentAnythingMarkGenerator(BaseMultiModalModel):
    """
    A class for performing image segmentation using a specified model.

    Parameters:
        device (str): The device to run the model on (e.g., 'cpu', 'cuda').
        model_name (str): The name of the model to be loaded. Defaults to
            'facebook/sam-vit-huge'.
        visualize_marks (bool): Stored on the instance. Defaults to False.
        *args: Forwarded to the base class and to `SamModel.from_pretrained`.
        **kwargs: Forwarded to the base class and to
            `SamModel.from_pretrained`.
    """

    def __init__(
        self,
        device: str = "cpu",
        model_name: str = "facebook/sam-vit-huge",
        visualize_marks: bool = False,
        *args,
        **kwargs,
    ):
        # Zero-argument super() is bound to this instance. The one-argument
        # form used before produced an unbound proxy, so the base-class
        # initializer never ran.
        super().__init__(*args, **kwargs)
        self.device = device
        self.model_name = model_name
        self.visualize_marks = visualize_marks
        self.model = SamModel.from_pretrained(
            model_name, *args, **kwargs
        ).to(device)
        self.processor = SamProcessor.from_pretrained(model_name)
        self.image_processor = SamImageProcessor.from_pretrained(
            model_name
        )
        self.pipeline = pipeline(
            task="mask-generation",
            model=self.model,
            image_processor=self.image_processor,
            device=self.device,
        )

    def __call__(
        self, image: np.ndarray, mask: Optional[np.ndarray] = None
    ) -> sv.Detections:
        """
        Generate image segmentation marks.

        Parameters:
            image (np.ndarray): The image to be marked in BGR format.
            mask: (Optional[np.ndarray]): The mask to be used as a guide for
                segmentation.

        Returns:
            sv.Detections: An object containing the segmentation masks and their
                corresponding bounding box coordinates.
        """
        # The model side works on RGB PIL images; the input is BGR.
        image = Image.fromarray(
            cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        )

        if mask is None:
            # No guide mask: let the mask-generation pipeline segment the
            # whole image.
            outputs = self.pipeline(image, points_per_batch=64)
            masks = np.array(outputs["masks"])
            return masks_to_marks(masks=masks)

        # Guided mode: embed the image once, then prompt the model with
        # points sampled from each polygon of the guide mask.
        inputs = self.processor(image, return_tensors="pt").to(
            self.device
        )
        image_embeddings = self.model.get_image_embeddings(
            inputs.pixel_values
        )

        masks = []
        for polygon in sv.mask_to_polygons(mask.astype(bool)):
            # Five prompt points per polygon, sampled with replacement.
            indexes = np.random.choice(
                a=polygon.shape[0], size=5, replace=True
            )
            input_points = polygon[indexes]
            inputs = self.processor(
                images=image,
                input_points=[[input_points]],
                return_tensors="pt",
            ).to(self.device)
            # The embeddings computed above replace the raw pixels.
            del inputs["pixel_values"]
            outputs = self.model(
                image_embeddings=image_embeddings, **inputs
            )
            # Named separately so the guide `mask` argument is not rebound
            # inside the loop.
            polygon_mask = (
                self.processor.image_processor.post_process_masks(
                    masks=outputs.pred_masks.cpu().detach(),
                    original_sizes=inputs["original_sizes"]
                    .cpu()
                    .detach(),
                    reshaped_input_sizes=inputs[
                        "reshaped_input_sizes"
                    ]
                    .cpu()
                    .detach(),
                )[0][0][0].numpy()
            )
            masks.append(polygon_mask)

        masks = np.array(masks)
        return masks_to_marks(masks=masks)
# def visualize_img(self):

@ -1,5 +1,6 @@
from swarms.models.base_multimodal_model import BaseMultiModalModel from swarms.models.base_multimodal_model import BaseMultiModalModel
from ultralytics import YOLO from ultralytics import YOLO
from typing import List
class UltralyticsModel(BaseMultiModalModel): class UltralyticsModel(BaseMultiModalModel):
@ -12,13 +13,22 @@ class UltralyticsModel(BaseMultiModalModel):
**kwargs: Arbitrary keyword arguments. **kwargs: Arbitrary keyword arguments.
""" """
def __init__(self, model_name: str, *args, **kwargs): def __init__(
self, model_name: str = "yolov8n.pt", *args, **kwargs
):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.model_name = model_name self.model_name = model_name
self.model = YOLO(model_name, *args, **kwargs) try:
self.model = YOLO(model_name, *args, **kwargs)
except Exception as e:
raise ValueError(
f"Failed to initialize Ultralytics model: {str(e)}"
)
def __call__(self, task: str, *args, **kwargs): def __call__(
self, task: str, tasks: List[str] = None, *args, **kwargs
):
""" """
Calls the Ultralytics model. Calls the Ultralytics model.
@ -30,4 +40,13 @@ class UltralyticsModel(BaseMultiModalModel):
Returns: Returns:
The result of the model call. The result of the model call.
""" """
return self.model(task, *args, **kwargs) try:
if tasks:
return self.model([tasks], *args, **kwargs)
else:
return self.model(task, *args, **kwargs)
except Exception as e:
raise ValueError(
f"Failed to perform task '{task}' with Ultralytics"
f" model: {str(e)}"
)

@ -16,6 +16,7 @@ from swarms.prompts.agent_system_prompts import (
from swarms.prompts.multi_modal_autonomous_instruction_prompt import ( from swarms.prompts.multi_modal_autonomous_instruction_prompt import (
MULTI_MODAL_AUTO_AGENT_SYSTEM_PROMPT_1, MULTI_MODAL_AUTO_AGENT_SYSTEM_PROMPT_1,
) )
from swarms.tokenizers.base_tokenizer import BaseTokenizer
from swarms.tools.tool import BaseTool from swarms.tools.tool import BaseTool
from swarms.utils.code_interpreter import SubprocessCodeInterpreter from swarms.utils.code_interpreter import SubprocessCodeInterpreter
from swarms.utils.data_to_text import data_to_text from swarms.utils.data_to_text import data_to_text
@ -151,7 +152,7 @@ class Agent:
dynamic_loops: Optional[bool] = False, dynamic_loops: Optional[bool] = False,
interactive: bool = False, interactive: bool = False,
dashboard: bool = False, dashboard: bool = False,
agent_name: str = None, agent_name: str = "swarm-worker-01",
agent_description: str = None, agent_description: str = None,
system_prompt: str = AGENT_SYSTEM_PROMPT_3, system_prompt: str = AGENT_SYSTEM_PROMPT_3,
tools: List[BaseTool] = None, tools: List[BaseTool] = None,
@ -167,7 +168,7 @@ class Agent:
multi_modal: Optional[bool] = None, multi_modal: Optional[bool] = None,
pdf_path: Optional[str] = None, pdf_path: Optional[str] = None,
list_of_pdf: Optional[str] = None, list_of_pdf: Optional[str] = None,
tokenizer: Optional[Any] = None, tokenizer: Optional[BaseTokenizer] = None,
long_term_memory: Optional[AbstractVectorDatabase] = None, long_term_memory: Optional[AbstractVectorDatabase] = None,
preset_stopping_token: Optional[bool] = False, preset_stopping_token: Optional[bool] = False,
traceback: Any = None, traceback: Any = None,
@ -187,7 +188,7 @@ class Agent:
self.retry_attempts = retry_attempts self.retry_attempts = retry_attempts
self.retry_interval = retry_interval self.retry_interval = retry_interval
self.task = None self.task = None
self.stopping_token = stopping_token # or "<DONE>" self.stopping_token = stopping_token
self.interactive = interactive self.interactive = interactive
self.dashboard = dashboard self.dashboard = dashboard
self.return_history = return_history self.return_history = return_history
@ -248,9 +249,14 @@ class Agent:
if self.docs: if self.docs:
self.ingest_docs(self.docs) self.ingest_docs(self.docs)
# If docs folder exists then get the docs from docs folder
if self.docs_folder: if self.docs_folder:
self.get_docs_from_doc_folders() self.get_docs_from_doc_folders()
# If tokenizer and context length exists then:
if self.tokenizer and self.context_length:
self.truncate_history()
def set_system_prompt(self, system_prompt: str): def set_system_prompt(self, system_prompt: str):
"""Set the system prompt""" """Set the system prompt"""
self.system_prompt = system_prompt self.system_prompt = system_prompt
@ -299,22 +305,6 @@ class Agent:
"""Format the template with the provided kwargs using f-string interpolation.""" """Format the template with the provided kwargs using f-string interpolation."""
return template.format(**kwargs) return template.format(**kwargs)
def truncate_history(self):
"""
Take the history and truncate it to fit into the model context length
"""
# truncated_history = self.short_memory[-1][-self.context_length :]
# self.short_memory[-1] = truncated_history
# out = limit_tokens_from_string(
# "\n".join(truncated_history), self.llm.model_name
# )
truncated_history = self.short_memory[-1][
-self.context_length :
]
text = "\n".join(truncated_history)
out = limit_tokens_from_string(text, "gpt-4")
return out
def add_task_to_memory(self, task: str): def add_task_to_memory(self, task: str):
"""Add the task to the memory""" """Add the task to the memory"""
try: try:
@ -1155,6 +1145,27 @@ class Agent:
message = f"{agent_name}: {message}" message = f"{agent_name}: {message}"
return self.run(message, *args, **kwargs) return self.run(message, *args, **kwargs)
def truncate_history(self):
    """
    Truncate the agent's short-term memory using a token count.

    Asks `self.tokenizer` for the token count of `self.short_memory` and, when
    `len(self.short_memory)` is greater than that count, keeps only the first
    `count` elements of `self.short_memory`.

    NOTE(review): this compares a length of `short_memory` with a token count,
    and `self.context_length` is not consulted here — confirm this is the
    intended truncation rule.

    Parameters:
        None

    Returns:
        None
    """
    # Count the short term history with the tokenizer
    count = self.tokenizer.count_tokens(self.short_memory)

    # Keep only the first `count` elements when the memory is longer than
    # the token count
    if len(self.short_memory) > count:
        self.short_memory = self.short_memory[:count]
def get_docs_from_doc_folders(self): def get_docs_from_doc_folders(self):
"""Get the docs from the files""" """Get the docs from the files"""
# Get the list of files then extract them and add them to the memory # Get the list of files then extract them and add them to the memory

@ -5,6 +5,7 @@ from termcolor import colored
from swarms.memory.base_db import AbstractDatabase from swarms.memory.base_db import AbstractDatabase
from swarms.structs.base import BaseStructure from swarms.structs.base import BaseStructure
from swarms.tokenizers.base_tokenizer import BaseTokenizer
class Conversation(BaseStructure): class Conversation(BaseStructure):
@ -65,6 +66,8 @@ class Conversation(BaseStructure):
database: AbstractDatabase = None, database: AbstractDatabase = None,
autosave: bool = False, autosave: bool = False,
save_filepath: str = None, save_filepath: str = None,
tokenizer: BaseTokenizer = None,
context_length: int = 8192,
*args, *args,
**kwargs, **kwargs,
): ):
@ -75,11 +78,17 @@ class Conversation(BaseStructure):
self.autosave = autosave self.autosave = autosave
self.save_filepath = save_filepath self.save_filepath = save_filepath
self.conversation_history = [] self.conversation_history = []
self.tokenizer = tokenizer
self.context_length = context_length
# If system prompt is not None, add it to the conversation history # If system prompt is not None, add it to the conversation history
if self.system_prompt: if self.system_prompt:
self.add("system", self.system_prompt) self.add("system", self.system_prompt)
# If tokenizer then truncate
if tokenizer:
self.truncate_memory_with_tokenizer()
def add(self, role: str, content: str, *args, **kwargs): def add(self, role: str, content: str, *args, **kwargs):
"""Add a message to the conversation history """Add a message to the conversation history
@ -348,3 +357,40 @@ class Conversation(BaseStructure):
def fetch_one_from_database(self, *args, **kwargs): def fetch_one_from_database(self, *args, **kwargs):
"""Fetch one from the database""" """Fetch one from the database"""
return self.database.fetch_one() return self.database.fetch_one()
def truncate_memory_with_tokenizer(self):
    """
    Trim the conversation history so its running token total fits within
    `self.context_length`.

    Messages are visited in order and kept unchanged while the running total
    of `self.tokenizer.count_tokens(text=content)` stays within the limit.
    The first message that would exceed the limit is cut down and every
    message after it is dropped.

    The cut is a character slice whose length equals the remaining token
    budget, so it is an approximation of a token-level cut. When no budget
    remains, the message is dropped rather than stored with empty content.

    Returns:
        None
    """
    total_tokens = 0
    truncated_history = []

    for message in self.conversation_history:
        role = message.get("role")
        content = message.get("content")
        count = self.tokenizer.count_tokens(text=content)

        if total_tokens + count <= self.context_length:
            total_tokens += count
            truncated_history.append(message)
            continue

        # This message does not fit completely: keep as much of it as the
        # remaining budget allows, then stop.
        remaining_tokens = self.context_length - total_tokens
        if remaining_tokens > 0:
            truncated_history.append(
                {
                    "role": role,
                    "content": content[:remaining_tokens],
                }
            )
        break

    self.conversation_history = truncated_history

@ -33,6 +33,16 @@ from swarms.utils.remove_json_whitespace import (
remove_whitespace_from_yaml, remove_whitespace_from_yaml,
) )
from swarms.utils.exponential_backoff import ExponentialBackoffMixin from swarms.utils.exponential_backoff import ExponentialBackoffMixin
from swarms.utils.download_img import download_img_from_url
from swarms.utils.supervision_masking import (
FeatureType,
compute_mask_iou_vectorized,
mask_non_max_suppression,
filter_masks_by_relative_area,
masks_to_marks,
refine_marks,
)
from swarms.utils.supervision_visualizer import MarkVisualizer
__all__ = [ __all__ = [
"SubprocessCodeInterpreter", "SubprocessCodeInterpreter",
@ -59,4 +69,12 @@ __all__ = [
"remove_whitespace_from_json", "remove_whitespace_from_json",
"remove_whitespace_from_yaml", "remove_whitespace_from_yaml",
"ExponentialBackoffMixin", "ExponentialBackoffMixin",
"download_img_from_url",
"FeatureType",
"compute_mask_iou_vectorized",
"mask_non_max_suppression",
"filter_masks_by_relative_area",
"masks_to_marks",
"refine_marks",
"MarkVisualizer",
] ]

@ -0,0 +1,31 @@
from io import BytesIO
import requests
from PIL import Image
def download_img_from_url(
    url: str,
    save_path: str = "downloaded_image.jpg",
    timeout: float = 30.0,
):
    """
    Downloads an image from the given URL and saves it locally.

    Args:
        url (str): The URL of the image to download.
        save_path (str): Where to write the image. Defaults to
            "downloaded_image.jpg" in the current working directory.
        timeout (float): Seconds to wait for the server before giving up.
            Defaults to 30.0.

    Raises:
        ValueError: If the URL is empty or invalid.
        IOError: If there is an error while downloading or saving the image.
    """
    if not url:
        raise ValueError("URL cannot be empty.")

    # Without a timeout the request can block indefinitely on a stalled
    # server.
    try:
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()
    except requests.exceptions.RequestException as e:
        raise IOError("Error while downloading the image.") from e

    try:
        image = Image.open(BytesIO(response.content))

        # JPEG cannot store alpha or palette images; convert those to RGB
        # so that, for example, a transparent PNG can still be written to
        # a .jpg target.
        jpeg_target = save_path.lower().endswith((".jpg", ".jpeg"))
        jpeg_modes = ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr")
        if jpeg_target and image.mode not in jpeg_modes:
            image = image.convert("RGB")

        image.save(save_path)
    except IOError as e:
        raise IOError("Error while saving the image.") from e

    print("Image downloaded successfully.")

@ -1,4 +1,6 @@
import logging import logging
import functools
logger = logging.getLogger() logger = logging.getLogger()
formatter = logging.Formatter("%(message)s") formatter = logging.Formatter("%(message)s")
@ -10,3 +12,35 @@ ch.setFormatter(formatter)
logger.addHandler(ch) logger.addHandler(ch)
logger.setLevel(logging.DEBUG) logger.setLevel(logging.DEBUG)
def log_wrapper(func):
    """
    Decorator that reports each call of the wrapped function to the module
    logger.

    The arguments are logged at debug level before the call, the return value
    at debug level after it, and any exception at error level before it is
    re-raised unchanged.

    Args:
        func (callable): The function to wrap.

    Returns:
        callable: The wrapped function.
    """

    @functools.wraps(func)
    def logged_call(*args, **kwargs):
        call_message = (
            f"Calling function {func.__name__} with args {args} and"
            f" kwargs {kwargs}"
        )
        logger.debug(call_message)

        try:
            result = func(*args, **kwargs)
            # Logged inside the try so a failure while formatting the
            # result is reported like any other exception.
            logger.debug(
                f"Function {func.__name__} returned {result}"
            )
        except Exception as e:
            logger.error(
                f"Function {func.__name__} raised an exception: {e}"
            )
            raise

        return result

    return logged_call

@ -0,0 +1,259 @@
from enum import Enum
import cv2
import numpy as np
import supervision as sv
class FeatureType(Enum):
    """
    Kinds of mask feature that can be adjusted during image segmentation
    clean-up: isolated islands or enclosed holes.
    """

    ISLAND = "ISLAND"
    HOLE = "HOLE"

    @classmethod
    def list(cls):
        """Return the value of every member, in definition order."""
        return [member.value for member in cls]
def compute_mask_iou_vectorized(masks: np.ndarray) -> np.ndarray:
    """
    Vectorized computation of the Intersection over Union (IoU) for all pairs of masks.

    Parameters:
        masks (np.ndarray): A 3D numpy array with shape `(N, H, W)`, where `N` is the
            number of masks, `H` is the height, and `W` is the width.

    Returns:
        np.ndarray: A 2D numpy array of shape `(N, N)` where each element `[i, j]` is
            the IoU between masks `i` and `j`.

    Raises:
        ValueError: If any of the masks is found to be empty.
    """
    num_masks = masks.shape[0]
    if num_masks == 0:
        # Nothing to compare; reshape(0, -1) below would be ambiguous.
        return np.empty((0, 0), dtype=np.float64)

    if np.any(masks.sum(axis=(1, 2)) == 0):
        raise ValueError(
            "One or more masks are empty. Please filter out empty"
            " masks before using `compute_mask_iou_vectorized` function."
        )

    # One row of 0/1 values per mask. float64 keeps the pixel counts exact
    # and lets the pairwise intersections come from one matrix product
    # instead of an (N, N, H*W) temporary.
    masks_flat = (
        masks.astype(bool).reshape(num_masks, -1).astype(np.float64)
    )
    areas = masks_flat.sum(axis=1)
    intersection = masks_flat @ masks_flat.T
    # Inclusion-exclusion: |A or B| = |A| + |B| - |A and B|.
    union = areas[:, None] + areas[None, :] - intersection
    return intersection / union


def mask_non_max_suppression(
    masks: np.ndarray, iou_threshold: float = 0.6
) -> np.ndarray:
    """
    Performs Non-Max Suppression on a set of masks by prioritizing larger masks and
    removing smaller masks that overlap significantly.

    When the IoU between two masks exceeds the specified threshold, the smaller mask
    (in terms of area) is discarded. This process is repeated for each pair of masks,
    effectively filtering out masks that are significantly overlapped by larger ones.

    Parameters:
        masks (np.ndarray): A 3D numpy array with shape `(N, H, W)`, where `N` is the
            number of masks, `H` is the height, and `W` is the width.
        iou_threshold (float): The IoU threshold for determining significant overlap.

    Returns:
        np.ndarray: A 3D numpy array of filtered masks.

    Raises:
        ValueError: If any of the masks is found to be empty.
    """
    num_masks = masks.shape[0]
    if num_masks == 0:
        return masks

    areas = masks.sum(axis=(1, 2))
    # Visit masks from largest to smallest area.
    sorted_idx = np.argsort(-areas)
    keep_mask = np.ones(num_masks, dtype=bool)
    iou_matrix = compute_mask_iou_vectorized(masks)

    for idx in sorted_idx:
        if not keep_mask[idx]:
            continue

        # Row `idx` of the IoU matrix is indexed by original mask position,
        # so the overlap flags apply to `keep_mask` directly (mapping them
        # through `sorted_idx` again would suppress unrelated masks).
        overlapping_masks = iou_matrix[idx] > iou_threshold
        overlapping_masks[idx] = False
        keep_mask[overlapping_masks] = False

    return masks[keep_mask]
def filter_masks_by_relative_area(
    masks: np.ndarray,
    minimum_area: float = 0.01,
    maximum_area: float = 1.0,
) -> np.ndarray:
    """
    Keep only the masks whose share of the frame area lies inside a range.

    The relative area of a mask is its sum over height and width divided by
    `H * W`.

    Parameters:
        masks (np.ndarray): A 3D numpy array with shape `(N, H, W)`, where `N` is the
            number of masks, `H` is the height, and `W` is the width.
        minimum_area (float): Lower bound (inclusive) on the relative area. Must be
            between `0` and `1`.
        maximum_area (float): Upper bound (inclusive) on the relative area. Must be
            between `0` and `1`.

    Returns:
        np.ndarray: A 3D numpy array with the masks whose relative area falls
            within the requested range.

    Raises:
        ValueError: If `masks` is not a 3D numpy array, if either bound is outside
            the `0` to `1` range, or if `minimum_area` is greater than
            `maximum_area`.
    """
    is_3d_array = isinstance(masks, np.ndarray) and masks.ndim == 3
    if not is_3d_array:
        raise ValueError("Input must be a 3D numpy array.")

    bounds_in_unit_range = (
        0 <= minimum_area <= 1 and 0 <= maximum_area <= 1
    )
    if not bounds_in_unit_range:
        raise ValueError(
            "`minimum_area` and `maximum_area` must be between 0"
            " and 1."
        )

    if minimum_area > maximum_area:
        raise ValueError(
            "`minimum_area` must be less than or equal to"
            " `maximum_area`."
        )

    _, height, width = masks.shape
    frame_area = height * width
    relative_areas = masks.sum(axis=(1, 2)) / frame_area

    above_minimum = relative_areas >= minimum_area
    below_maximum = relative_areas <= maximum_area
    return masks[above_minimum & below_maximum]
def adjust_mask_features_by_relative_area(
    mask: np.ndarray,
    area_threshold: float,
    feature_type: FeatureType = FeatureType.ISLAND,
) -> np.ndarray:
    """
    Clean up a single mask by erasing small islands or filling small holes.

    Contours are extracted from the mask; every contour whose area relative to
    the whole frame is below `area_threshold` is painted over — with background
    for `ISLAND`, with foreground for `HOLE`.

    !!! warning

        Running this function on a mask with small islands may result in empty masks.

    Parameters:
        mask (np.ndarray): A 2D numpy array with shape `(H, W)`, where `H` is the
            height, and `W` is the width.
        area_threshold (float): Relative-area threshold below which a feature is
            removed or filled.
        feature_type (FeatureType): `ISLAND` to remove islands, `HOLE` to fill
            holes.

    Returns:
        np.ndarray: A 2D boolean numpy array containing the adjusted mask.
    """
    height, width = mask.shape
    total_area = width * height

    # OpenCV works on 8-bit images: foreground becomes 255.
    mask = np.uint8(mask * 255)

    if feature_type == FeatureType.ISLAND:
        # Outer contours only; small ones are erased.
        retrieval_mode = cv2.RETR_EXTERNAL
        fill_value = 0
    else:
        # Two-level hierarchy so hole boundaries are reported as well;
        # small ones are filled in.
        retrieval_mode = cv2.RETR_CCOMP
        fill_value = 255

    contours, _ = cv2.findContours(
        mask, retrieval_mode, cv2.CHAIN_APPROX_SIMPLE
    )

    for contour in contours:
        relative_area = cv2.contourArea(contour) / total_area
        if relative_area < area_threshold:
            # thickness=-1 paints the contour interior.
            cv2.drawContours(
                image=mask,
                contours=[contour],
                contourIdx=-1,
                color=fill_value,
                thickness=-1,
            )

    return np.where(mask > 0, 1, 0).astype(bool)
def masks_to_marks(masks: np.ndarray) -> sv.Detections:
    """
    Wrap a stack of masks in a marks (sv.Detections) object.

    Parameters:
        masks (np.ndarray): A 3D numpy array with shape `(N, H, W)`, where `N` is the
            number of masks, `H` is the height, and `W` is the width.

    Returns:
        sv.Detections: The masks together with bounding boxes derived from
            them. For an empty input, an empty detections object whose mask
            is an empty boolean array of shape `(0, 0, 0)`.
    """
    if len(masks) != 0:
        boxes = sv.mask_to_xyxy(masks=masks)
        return sv.Detections(mask=masks, xyxy=boxes)

    # No masks: start from an empty detections object and attach an empty
    # mask array to it.
    empty_marks = sv.Detections.empty()
    empty_marks.mask = np.empty((0, 0, 0), dtype=bool)
    return empty_marks
def refine_marks(
    marks: sv.Detections,
    maximum_hole_area: float = 0.01,
    maximum_island_area: float = 0.01,
    minimum_mask_area: float = 0.02,
    maximum_mask_area: float = 1.0,
) -> sv.Detections:
    """
    Refines a set of masks by removing small islands and holes, and filtering by mask
    area.

    Parameters:
        marks (sv.Detections): An object containing the masks and their bounding box
            coordinates.
        maximum_hole_area (float): The maximum relative area of holes to be filled in
            each mask.
        maximum_island_area (float): The maximum relative area of islands to be removed
            from each mask.
        minimum_mask_area (float): The minimum relative area for a mask to be retained.
        maximum_mask_area (float): The maximum relative area for a mask to be retained.

    Returns:
        sv.Detections: An object containing the masks and their bounding box
            coordinates. When no mask is present or none survives the clean-up,
            the empty marks object produced by `masks_to_marks` is returned.
    """
    # Detections without masks carry `mask = None`; treat that as no masks.
    source_masks = marks.mask if marks.mask is not None else []

    result_masks = []
    for mask in source_masks:
        mask = adjust_mask_features_by_relative_area(
            mask=mask,
            area_threshold=maximum_island_area,
            feature_type=FeatureType.ISLAND,
        )
        mask = adjust_mask_features_by_relative_area(
            mask=mask,
            area_threshold=maximum_hole_area,
            feature_type=FeatureType.HOLE,
        )
        # Island removal can wipe a mask out completely; drop those.
        if np.any(mask):
            result_masks.append(mask)

    if not result_masks:
        # np.array([]) is one-dimensional and would be rejected by
        # filter_masks_by_relative_area, so return empty marks directly.
        return masks_to_marks(masks=np.empty((0, 0, 0), dtype=bool))

    result_masks = filter_masks_by_relative_area(
        masks=np.array(result_masks),
        minimum_area=minimum_mask_area,
        maximum_area=maximum_mask_area,
    )
    return sv.Detections(
        mask=result_masks, xyxy=sv.mask_to_xyxy(masks=result_masks)
    )

@ -0,0 +1,85 @@
import numpy as np
import supervision as sv
class MarkVisualizer:
    """
    Draws marks on images: bounding boxes, masks, polygons and index labels.

    Parameters:
        line_thickness (int): The thickness of the lines for boxes and polygons.
        mask_opacity (float): The opacity level for masks.
        text_scale (float): The scale of the text for labels.
    """

    def __init__(
        self,
        line_thickness: int = 2,
        mask_opacity: float = 0.1,
        text_scale: float = 0.6,
    ) -> None:
        # Every annotator picks its color by detection index.
        by_index = sv.ColorLookup.INDEX

        self.box_annotator = sv.BoundingBoxAnnotator(
            color_lookup=by_index,
            thickness=line_thickness,
        )
        self.mask_annotator = sv.MaskAnnotator(
            color_lookup=by_index, opacity=mask_opacity
        )
        self.polygon_annotator = sv.PolygonAnnotator(
            color_lookup=by_index,
            thickness=line_thickness,
        )
        self.label_annotator = sv.LabelAnnotator(
            color=sv.Color.black(),
            text_color=sv.Color.white(),
            color_lookup=by_index,
            text_position=sv.Position.CENTER_OF_MASS,
            text_scale=text_scale,
        )

    def visualize(
        self,
        image: np.ndarray,
        marks: sv.Detections,
        with_box: bool = False,
        with_mask: bool = False,
        with_polygon: bool = True,
        with_label: bool = True,
    ) -> np.ndarray:
        """
        Overlay the selected kinds of annotation on a copy of an image.

        The layers are applied in a fixed order: boxes, masks, polygons,
        labels. Labels are the detection indices as strings.

        Parameters:
            image (np.ndarray): The image on which to overlay annotations.
            marks (sv.Detections): The detection results containing the annotations.
            with_box (bool): Whether to draw bounding boxes. Defaults to False.
            with_mask (bool): Whether to overlay masks. Defaults to False.
            with_polygon (bool): Whether to draw polygons. Defaults to True.
            with_label (bool): Whether to add labels. Defaults to True.

        Returns:
            np.ndarray: The annotated image.
        """
        annotated_image = image.copy()

        layers = (
            (with_box, self.box_annotator),
            (with_mask, self.mask_annotator),
            (with_polygon, self.polygon_annotator),
        )
        for enabled, annotator in layers:
            if enabled:
                annotated_image = annotator.annotate(
                    scene=annotated_image, detections=marks
                )

        if with_label:
            labels = [str(index) for index in range(len(marks))]
            annotated_image = self.label_annotator.annotate(
                scene=annotated_image, detections=marks, labels=labels
            )

        return annotated_image
Loading…
Cancel
Save