import json
import os
from typing import List

import timm
import torch
from PIL import Image
from pydantic import BaseModel, StrictFloat, StrictInt, validator

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the classes for image classification
with open(
    os.path.join(os.path.dirname(__file__), "fast_vit_classes.json")
) as f:
    FASTVIT_IMAGENET_1K_CLASSES = json.load(f)

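# Note (assumption, not part of the original code): the exact layout of
# fast_vit_classes.json is not pinned down by this module. A plain JSON list of
# the 1000 ImageNet-1k class names, indexed by class id, is one layout that
# works with the label lookup commented out in FastViT.__call__, e.g.
# ["tench", "goldfish", "great white shark", ...].
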

class ClassificationResult(BaseModel):
    # Parallel lists: the i-th class id goes with the i-th confidence score.
    class_id: List[StrictInt]
    confidence: List[StrictFloat]

    @validator("class_id", "confidence", pre=True, each_item=True)
    def check_list_contents(cls, v):
        # Each item must already be a plain int or float; the Strict* field
        # types above will not coerce other values.
        assert isinstance(v, int) or isinstance(
            v, float
        ), "must be integer or float"
        return v


class FastViT:
    """
    FastViT model for image classification.

    Args:
        img (str): path to the input image
        confidence_threshold (float): confidence threshold for the model's predictions

    Returns:
        ClassificationResult: a pydantic BaseModel containing the class ids and
            confidences of the model's predictions

    Example:
        >>> fastvit = FastViT()
        >>> result = fastvit(img="path_to_image.jpg", confidence_threshold=0.5)

    To use, place a JSON file named fast_vit_classes.json in the same directory
    as this module (see the note above where the class list is loaded at the
    top of this file).
    """

    def __init__(self):
        self.model = timm.create_model(
            "hf_hub:timm/fastvit_s12.apple_in1k", pretrained=True
        ).to(DEVICE)
        data_config = timm.data.resolve_model_data_config(self.model)
        self.transforms = timm.data.create_transform(
            **data_config, is_training=False
        )
        self.model.eval()

    def __call__(
        self, img: str, confidence_threshold: float = 0.5
    ) -> ClassificationResult:
        """Classifies the input image and returns the top classes and their probabilities."""
        img = Image.open(img).convert("RGB")
        img_tensor = self.transforms(img).unsqueeze(0).to(DEVICE)
        with torch.no_grad():
            output = self.model(img_tensor)
            probabilities = torch.nn.functional.softmax(output, dim=1)

        # Rank every class by probability; torch.topk needs an integer k, so use
        # the size of the class list. The confidence threshold below prunes the ranking.
        top_probs, top_classes = torch.topk(
            probabilities, k=len(FASTVIT_IMAGENET_1K_CLASSES)
        )

        # Filter by confidence threshold
        mask = top_probs > confidence_threshold
        top_probs, top_classes = top_probs[mask], top_classes[mask]

        # Convert to Python lists and map class indices to labels if needed
        top_probs = top_probs.cpu().numpy().tolist()
        top_classes = top_classes.cpu().numpy().tolist()
        # top_class_labels = [FASTVIT_IMAGENET_1K_CLASSES[i] for i in top_classes]  # Uncomment if class labels are needed

        return ClassificationResult(
            class_id=top_classes, confidence=top_probs
        )
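

# Minimal usage sketch (not part of the original module): mirrors the docstring
# example. "path_to_image.jpg" is a placeholder path; running this requires
# fast_vit_classes.json next to this file and either a CUDA device or the CPU
# fallback configured in DEVICE above.
if __name__ == "__main__":
    fastvit = FastViT()
    result = fastvit(img="path_to_image.jpg", confidence_threshold=0.5)
    print(result.class_id, result.confidence)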