You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
101 lines
3.2 KiB
101 lines
3.2 KiB
from typing import List, Tuple
|
|
|
|
import numpy as np
|
|
from PIL import Image
|
|
from pydantic import BaseModel, root_validator, validator
|
|
from transformers import AutoModelForVision2Seq, AutoProcessor
|
|
|
|
|
|
# Assuming the Detections class represents the output of the model prediction
|
|
class Detections(BaseModel):
|
|
xyxy: List[Tuple[float, float, float, float]]
|
|
class_id: List[int]
|
|
confidence: List[float]
|
|
|
|
@root_validator
|
|
def check_length(cls, values):
|
|
assert (
|
|
len(values.get("xyxy"))
|
|
== len(values.get("class_id"))
|
|
== len(values.get("confidence"))
|
|
), "All fields must have the same length."
|
|
return values
|
|
|
|
@validator("xyxy", "class_id", "confidence", pre=True, each_item=True)
|
|
def check_not_empty(cls, v):
|
|
if isinstance(v, list) and len(v) == 0:
|
|
raise ValueError("List must not be empty")
|
|
return v
|
|
|
|
@classmethod
|
|
def empty(cls):
|
|
return cls(xyxy=[], class_id=[], confidence=[])
|
|
|
|
|
|
class Kosmos2(BaseModel):
|
|
model: AutoModelForVision2Seq
|
|
processor: AutoProcessor
|
|
|
|
@classmethod
|
|
def initialize(cls):
|
|
model = AutoModelForVision2Seq.from_pretrained(
|
|
"ydshieh/kosmos-2-patch14-224", trust_remote_code=True
|
|
)
|
|
processor = AutoProcessor.from_pretrained(
|
|
"ydshieh/kosmos-2-patch14-224", trust_remote_code=True
|
|
)
|
|
return cls(model=model, processor=processor)
|
|
|
|
def __call__(self, img: str) -> Detections:
|
|
image = Image.open(img)
|
|
prompt = "<grounding>An image of"
|
|
|
|
inputs = self.processor(text=prompt, images=image, return_tensors="pt")
|
|
outputs = self.model.generate(**inputs, use_cache=True, max_new_tokens=64)
|
|
|
|
generated_text = self.processor.batch_decode(outputs, skip_special_tokens=True)[
|
|
0
|
|
]
|
|
|
|
# The actual processing of generated_text to entities would go here
|
|
# For the purpose of this example, assume a mock function 'extract_entities' exists:
|
|
entities = self.extract_entities(generated_text)
|
|
|
|
# Convert entities to detections format
|
|
detections = self.process_entities_to_detections(entities, image)
|
|
return detections
|
|
|
|
def extract_entities(
|
|
self, text: str
|
|
) -> List[Tuple[str, Tuple[float, float, float, float]]]:
|
|
# Placeholder function for entity extraction
|
|
# This should be replaced with the actual method of extracting entities
|
|
return []
|
|
|
|
def process_entities_to_detections(
|
|
self,
|
|
entities: List[Tuple[str, Tuple[float, float, float, float]]],
|
|
image: Image.Image,
|
|
) -> Detections:
|
|
if not entities:
|
|
return Detections.empty()
|
|
|
|
class_ids = [0] * len(entities) # Replace with actual class ID extraction logic
|
|
xyxys = [
|
|
(
|
|
e[1][0] * image.width,
|
|
e[1][1] * image.height,
|
|
e[1][2] * image.width,
|
|
e[1][3] * image.height,
|
|
)
|
|
for e in entities
|
|
]
|
|
confidences = [1.0] * len(entities) # Placeholder confidence
|
|
|
|
return Detections(xyxy=xyxys, class_id=class_ids, confidence=confidences)
|
|
|
|
|
|
# Usage:
|
|
# kosmos2 = Kosmos2.initialize()
|
|
# detections = kosmos2(img="path_to_image.jpg")
|