parent
da120e1aef
commit
0c4dd88f98
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,80 @@
|
|||||||
|
import json
|
||||||
|
import os
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import timm
|
||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
from pydantic import BaseModel, StrictFloat, StrictInt, validator
|
||||||
|
|
||||||
|
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
|
||||||
|
# Load the classes for image classification
|
||||||
|
with open(os.path.join(os.path.dirname(__file__), "fast_vit_classes.json")) as f:
|
||||||
|
FASTVIT_IMAGENET_1K_CLASSES = json.load(f)
|
||||||
|
|
||||||
|
|
||||||
|
class ClassificationResult(BaseModel):
|
||||||
|
class_id: List[StrictInt]
|
||||||
|
confidence: List[StrictFloat]
|
||||||
|
|
||||||
|
@validator("class_id", "confidence", pre=True, each_item=True)
|
||||||
|
def check_list_contents(cls, v):
|
||||||
|
assert isinstance(v, int) or isinstance(v, float), "must be integer or float"
|
||||||
|
return v
|
||||||
|
|
||||||
|
|
||||||
|
class FastViT:
|
||||||
|
"""
|
||||||
|
FastViT model for image classification
|
||||||
|
|
||||||
|
Args:
|
||||||
|
img (str): path to the input image
|
||||||
|
confidence_threshold (float): confidence threshold for the model's predictions
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ClassificationResult: a pydantic BaseModel containing the class ids and confidences of the model's predictions
|
||||||
|
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> fastvit = FastViT()
|
||||||
|
>>> result = fastvit(img="path_to_image.jpg", confidence_threshold=0.5)
|
||||||
|
|
||||||
|
|
||||||
|
To use, create a json file called: fast_vit_classes.json
|
||||||
|
|
||||||
|
"""
|
||||||
|
def __init__(self):
|
||||||
|
self.model = timm.create_model(
|
||||||
|
"hf_hub:timm/fastvit_s12.apple_in1k", pretrained=True
|
||||||
|
).to(DEVICE)
|
||||||
|
data_config = timm.data.resolve_model_data_config(self.model)
|
||||||
|
self.transforms = timm.data.create_transform(**data_config, is_training=False)
|
||||||
|
self.model.eval()
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self, img: str, confidence_threshold: float = 0.5
|
||||||
|
) -> ClassificationResult:
|
||||||
|
"""classifies the input image and returns the top k classes and their probabilities"""
|
||||||
|
img = Image.open(img).convert("RGB")
|
||||||
|
img_tensor = self.transforms(img).unsqueeze(0).to(DEVICE)
|
||||||
|
with torch.no_grad():
|
||||||
|
output = self.model(img_tensor)
|
||||||
|
probabilities = torch.nn.functional.softmax(output, dim=1)
|
||||||
|
|
||||||
|
# Get top k classes and their probabilities
|
||||||
|
top_probs, top_classes = torch.topk(
|
||||||
|
probabilities, k=FASTVIT_IMAGENET_1K_CLASSES
|
||||||
|
)
|
||||||
|
|
||||||
|
# Filter by confidence threshold
|
||||||
|
mask = top_probs > confidence_threshold
|
||||||
|
top_probs, top_classes = top_probs[mask], top_classes[mask]
|
||||||
|
|
||||||
|
# Convert to Python lists and map class indices to labels if needed
|
||||||
|
top_probs = top_probs.cpu().numpy().tolist()
|
||||||
|
top_classes = top_classes.cpu().numpy().tolist()
|
||||||
|
# top_class_labels = [FASTVIT_IMAGENET_1K_CLASSES[i] for i in top_classes] # Uncomment if class labels are needed
|
||||||
|
|
||||||
|
return ClassificationResult(class_id=top_classes, confidence=top_probs)
|
@ -0,0 +1,100 @@
|
|||||||
|
from typing import List, Tuple
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from PIL import Image
|
||||||
|
from pydantic import BaseModel, root_validator, validator
|
||||||
|
from transformers import AutoModelForVision2Seq, AutoProcessor
|
||||||
|
|
||||||
|
|
||||||
|
# Assuming the Detections class represents the output of the model prediction
|
||||||
|
class Detections(BaseModel):
|
||||||
|
xyxy: List[Tuple[float, float, float, float]]
|
||||||
|
class_id: List[int]
|
||||||
|
confidence: List[float]
|
||||||
|
|
||||||
|
@root_validator
|
||||||
|
def check_length(cls, values):
|
||||||
|
assert (
|
||||||
|
len(values.get("xyxy"))
|
||||||
|
== len(values.get("class_id"))
|
||||||
|
== len(values.get("confidence"))
|
||||||
|
), "All fields must have the same length."
|
||||||
|
return values
|
||||||
|
|
||||||
|
@validator("xyxy", "class_id", "confidence", pre=True, each_item=True)
|
||||||
|
def check_not_empty(cls, v):
|
||||||
|
if isinstance(v, list) and len(v) == 0:
|
||||||
|
raise ValueError("List must not be empty")
|
||||||
|
return v
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def empty(cls):
|
||||||
|
return cls(xyxy=[], class_id=[], confidence=[])
|
||||||
|
|
||||||
|
|
||||||
|
class Kosmos2(BaseModel):
|
||||||
|
model: AutoModelForVision2Seq
|
||||||
|
processor: AutoProcessor
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def initialize(cls):
|
||||||
|
model = AutoModelForVision2Seq.from_pretrained(
|
||||||
|
"ydshieh/kosmos-2-patch14-224", trust_remote_code=True
|
||||||
|
)
|
||||||
|
processor = AutoProcessor.from_pretrained(
|
||||||
|
"ydshieh/kosmos-2-patch14-224", trust_remote_code=True
|
||||||
|
)
|
||||||
|
return cls(model=model, processor=processor)
|
||||||
|
|
||||||
|
def __call__(self, img: str) -> Detections:
|
||||||
|
image = Image.open(img)
|
||||||
|
prompt = "<grounding>An image of"
|
||||||
|
|
||||||
|
inputs = self.processor(text=prompt, images=image, return_tensors="pt")
|
||||||
|
outputs = self.model.generate(**inputs, use_cache=True, max_new_tokens=64)
|
||||||
|
|
||||||
|
generated_text = self.processor.batch_decode(outputs, skip_special_tokens=True)[
|
||||||
|
0
|
||||||
|
]
|
||||||
|
|
||||||
|
# The actual processing of generated_text to entities would go here
|
||||||
|
# For the purpose of this example, assume a mock function 'extract_entities' exists:
|
||||||
|
entities = self.extract_entities(generated_text)
|
||||||
|
|
||||||
|
# Convert entities to detections format
|
||||||
|
detections = self.process_entities_to_detections(entities, image)
|
||||||
|
return detections
|
||||||
|
|
||||||
|
def extract_entities(
|
||||||
|
self, text: str
|
||||||
|
) -> List[Tuple[str, Tuple[float, float, float, float]]]:
|
||||||
|
# Placeholder function for entity extraction
|
||||||
|
# This should be replaced with the actual method of extracting entities
|
||||||
|
return []
|
||||||
|
|
||||||
|
def process_entities_to_detections(
|
||||||
|
self,
|
||||||
|
entities: List[Tuple[str, Tuple[float, float, float, float]]],
|
||||||
|
image: Image.Image,
|
||||||
|
) -> Detections:
|
||||||
|
if not entities:
|
||||||
|
return Detections.empty()
|
||||||
|
|
||||||
|
class_ids = [0] * len(entities) # Replace with actual class ID extraction logic
|
||||||
|
xyxys = [
|
||||||
|
(
|
||||||
|
e[1][0] * image.width,
|
||||||
|
e[1][1] * image.height,
|
||||||
|
e[1][2] * image.width,
|
||||||
|
e[1][3] * image.height,
|
||||||
|
)
|
||||||
|
for e in entities
|
||||||
|
]
|
||||||
|
confidences = [1.0] * len(entities) # Placeholder confidence
|
||||||
|
|
||||||
|
return Detections(xyxy=xyxys, class_id=class_ids, confidence=confidences)
|
||||||
|
|
||||||
|
|
||||||
|
# Usage:
|
||||||
|
# kosmos2 = Kosmos2.initialize()
|
||||||
|
# detections = kosmos2(img="path_to_image.jpg")
|
Loading…
Reference in new issue