parent
da120e1aef
commit
0c4dd88f98
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,80 @@
|
||||
import json
|
||||
import os
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
import timm
|
||||
import torch
|
||||
from PIL import Image
|
||||
from pydantic import BaseModel, StrictFloat, StrictInt, validator
|
||||
|
||||
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
# Load the classes for image classification
|
||||
with open(os.path.join(os.path.dirname(__file__), "fast_vit_classes.json")) as f:
|
||||
FASTVIT_IMAGENET_1K_CLASSES = json.load(f)
|
||||
|
||||
|
||||
class ClassificationResult(BaseModel):
|
||||
class_id: List[StrictInt]
|
||||
confidence: List[StrictFloat]
|
||||
|
||||
@validator("class_id", "confidence", pre=True, each_item=True)
|
||||
def check_list_contents(cls, v):
|
||||
assert isinstance(v, int) or isinstance(v, float), "must be integer or float"
|
||||
return v
|
||||
|
||||
|
||||
class FastViT:
|
||||
"""
|
||||
FastViT model for image classification
|
||||
|
||||
Args:
|
||||
img (str): path to the input image
|
||||
confidence_threshold (float): confidence threshold for the model's predictions
|
||||
|
||||
Returns:
|
||||
ClassificationResult: a pydantic BaseModel containing the class ids and confidences of the model's predictions
|
||||
|
||||
|
||||
Example:
|
||||
>>> fastvit = FastViT()
|
||||
>>> result = fastvit(img="path_to_image.jpg", confidence_threshold=0.5)
|
||||
|
||||
|
||||
To use, create a json file called: fast_vit_classes.json
|
||||
|
||||
"""
|
||||
def __init__(self):
|
||||
self.model = timm.create_model(
|
||||
"hf_hub:timm/fastvit_s12.apple_in1k", pretrained=True
|
||||
).to(DEVICE)
|
||||
data_config = timm.data.resolve_model_data_config(self.model)
|
||||
self.transforms = timm.data.create_transform(**data_config, is_training=False)
|
||||
self.model.eval()
|
||||
|
||||
def __call__(
|
||||
self, img: str, confidence_threshold: float = 0.5
|
||||
) -> ClassificationResult:
|
||||
"""classifies the input image and returns the top k classes and their probabilities"""
|
||||
img = Image.open(img).convert("RGB")
|
||||
img_tensor = self.transforms(img).unsqueeze(0).to(DEVICE)
|
||||
with torch.no_grad():
|
||||
output = self.model(img_tensor)
|
||||
probabilities = torch.nn.functional.softmax(output, dim=1)
|
||||
|
||||
# Get top k classes and their probabilities
|
||||
top_probs, top_classes = torch.topk(
|
||||
probabilities, k=FASTVIT_IMAGENET_1K_CLASSES
|
||||
)
|
||||
|
||||
# Filter by confidence threshold
|
||||
mask = top_probs > confidence_threshold
|
||||
top_probs, top_classes = top_probs[mask], top_classes[mask]
|
||||
|
||||
# Convert to Python lists and map class indices to labels if needed
|
||||
top_probs = top_probs.cpu().numpy().tolist()
|
||||
top_classes = top_classes.cpu().numpy().tolist()
|
||||
# top_class_labels = [FASTVIT_IMAGENET_1K_CLASSES[i] for i in top_classes] # Uncomment if class labels are needed
|
||||
|
||||
return ClassificationResult(class_id=top_classes, confidence=top_probs)
|
@ -0,0 +1,100 @@
|
||||
from typing import List, Tuple
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from pydantic import BaseModel, root_validator, validator
|
||||
from transformers import AutoModelForVision2Seq, AutoProcessor
|
||||
|
||||
|
||||
# Assuming the Detections class represents the output of the model prediction
|
||||
class Detections(BaseModel):
|
||||
xyxy: List[Tuple[float, float, float, float]]
|
||||
class_id: List[int]
|
||||
confidence: List[float]
|
||||
|
||||
@root_validator
|
||||
def check_length(cls, values):
|
||||
assert (
|
||||
len(values.get("xyxy"))
|
||||
== len(values.get("class_id"))
|
||||
== len(values.get("confidence"))
|
||||
), "All fields must have the same length."
|
||||
return values
|
||||
|
||||
@validator("xyxy", "class_id", "confidence", pre=True, each_item=True)
|
||||
def check_not_empty(cls, v):
|
||||
if isinstance(v, list) and len(v) == 0:
|
||||
raise ValueError("List must not be empty")
|
||||
return v
|
||||
|
||||
@classmethod
|
||||
def empty(cls):
|
||||
return cls(xyxy=[], class_id=[], confidence=[])
|
||||
|
||||
|
||||
class Kosmos2(BaseModel):
|
||||
model: AutoModelForVision2Seq
|
||||
processor: AutoProcessor
|
||||
|
||||
@classmethod
|
||||
def initialize(cls):
|
||||
model = AutoModelForVision2Seq.from_pretrained(
|
||||
"ydshieh/kosmos-2-patch14-224", trust_remote_code=True
|
||||
)
|
||||
processor = AutoProcessor.from_pretrained(
|
||||
"ydshieh/kosmos-2-patch14-224", trust_remote_code=True
|
||||
)
|
||||
return cls(model=model, processor=processor)
|
||||
|
||||
def __call__(self, img: str) -> Detections:
|
||||
image = Image.open(img)
|
||||
prompt = "<grounding>An image of"
|
||||
|
||||
inputs = self.processor(text=prompt, images=image, return_tensors="pt")
|
||||
outputs = self.model.generate(**inputs, use_cache=True, max_new_tokens=64)
|
||||
|
||||
generated_text = self.processor.batch_decode(outputs, skip_special_tokens=True)[
|
||||
0
|
||||
]
|
||||
|
||||
# The actual processing of generated_text to entities would go here
|
||||
# For the purpose of this example, assume a mock function 'extract_entities' exists:
|
||||
entities = self.extract_entities(generated_text)
|
||||
|
||||
# Convert entities to detections format
|
||||
detections = self.process_entities_to_detections(entities, image)
|
||||
return detections
|
||||
|
||||
def extract_entities(
|
||||
self, text: str
|
||||
) -> List[Tuple[str, Tuple[float, float, float, float]]]:
|
||||
# Placeholder function for entity extraction
|
||||
# This should be replaced with the actual method of extracting entities
|
||||
return []
|
||||
|
||||
def process_entities_to_detections(
|
||||
self,
|
||||
entities: List[Tuple[str, Tuple[float, float, float, float]]],
|
||||
image: Image.Image,
|
||||
) -> Detections:
|
||||
if not entities:
|
||||
return Detections.empty()
|
||||
|
||||
class_ids = [0] * len(entities) # Replace with actual class ID extraction logic
|
||||
xyxys = [
|
||||
(
|
||||
e[1][0] * image.width,
|
||||
e[1][1] * image.height,
|
||||
e[1][2] * image.width,
|
||||
e[1][3] * image.height,
|
||||
)
|
||||
for e in entities
|
||||
]
|
||||
confidences = [1.0] * len(entities) # Placeholder confidence
|
||||
|
||||
return Detections(xyxy=xyxys, class_id=class_ids, confidence=confidences)
|
||||
|
||||
|
||||
# Usage:
|
||||
# kosmos2 = Kosmos2.initialize()
|
||||
# detections = kosmos2(img="path_to_image.jpg")
|
Loading…
Reference in new issue