You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
swarms/swarms/models/kosmos2.py

101 lines
3.2 KiB

from typing import Any, List, Tuple

import numpy as np
from PIL import Image
from pydantic import BaseModel, root_validator, validator
from transformers import AutoModelForVision2Seq, AutoProcessor
# Assuming the Detections class represents the output of the model prediction
class Detections(BaseModel):
    """Parallel lists describing a set of detected boxes.

    The three lists are index-aligned: entry ``i`` of each list describes the
    same detection.

    Attributes:
        xyxy: One 4-float tuple per detection. The name suggests
            ``(x_min, y_min, x_max, y_max)`` corner order; nothing here
            enforces that ordering.
        class_id: One integer class identifier per detection.
        confidence: One float score per detection.
    """

    xyxy: List[Tuple[float, float, float, float]]
    class_id: List[int]
    confidence: List[float]

    # skip_on_failure=True: only compare lengths once every field has passed
    # its own validation. Without it the check would run on a partial
    # ``values`` dict and fail on ``len(None)``.
    @root_validator(skip_on_failure=True)
    def check_length(cls, values):
        """Ensure ``xyxy``, ``class_id`` and ``confidence`` have equal length.

        Raises:
            ValueError: If the three lists differ in length. Raised explicitly
                rather than via ``assert`` so the check survives ``python -O``.
        """
        lengths = {
            len(values[name]) for name in ("xyxy", "class_id", "confidence")
        }
        if len(lengths) > 1:
            raise ValueError("All fields must have the same length.")
        return values

    @validator("xyxy", "class_id", "confidence", pre=True, each_item=True)
    def check_not_empty(cls, v):
        """Reject an individual item that is itself an empty list.

        Runs per item (``each_item=True``) on the raw input (``pre=True``), so
        an empty *outer* list is still accepted; see ``empty``.

        Raises:
            ValueError: If the item is a list of length zero.
        """
        if isinstance(v, list) and not v:
            raise ValueError("List must not be empty")
        return v

    @classmethod
    def empty(cls):
        """Return a ``Detections`` instance holding no detections."""
        return cls(xyxy=[], class_id=[], confidence=[])
class Kosmos2(BaseModel):
    """Run a Kosmos-2 vision-to-sequence checkpoint on an image file.

    Calling an instance with an image path generates grounded text for the
    image and converts the extracted entities into a ``Detections`` object.

    The fields are annotated ``Any`` on purpose: ``AutoModelForVision2Seq`` and
    ``AutoProcessor`` are factory classes whose ``from_pretrained`` returns
    objects of other, concrete classes, so using the factories as pydantic
    field types would make validation reject the loaded objects.
    """

    # Loaded model; this class only relies on its ``generate`` method.
    model: Any
    # Loaded processor; this class calls it directly and uses ``batch_decode``.
    processor: Any

    @classmethod
    def initialize(cls, checkpoint: str = "ydshieh/kosmos-2-patch14-224"):
        """Load model and processor from ``checkpoint`` and build an instance.

        Args:
            checkpoint: Identifier passed to both ``from_pretrained`` calls.
                Defaults to the checkpoint this wrapper was written against.

        Returns:
            A ``Kosmos2`` instance wrapping the loaded model and processor.
        """
        # NOTE(review): trust_remote_code=True executes code shipped with the
        # checkpoint; only pass checkpoints you trust.
        model = AutoModelForVision2Seq.from_pretrained(
            checkpoint, trust_remote_code=True
        )
        processor = AutoProcessor.from_pretrained(
            checkpoint, trust_remote_code=True
        )
        return cls(model=model, processor=processor)

    def __call__(self, img: str) -> Detections:
        """Generate grounded text for the image at ``img`` and return detections.

        Args:
            img: Path (or anything ``PIL.Image.open`` accepts) of the image.

        Returns:
            Detections built from the entities extracted from the generated
            text; empty when no entities are extracted.
        """
        prompt = "<grounding>An image of"
        # The context manager releases the image's file handle when done; all
        # work that needs the image happens inside the block.
        with Image.open(img) as image:
            inputs = self.processor(
                text=prompt, images=image, return_tensors="pt"
            )
            outputs = self.model.generate(
                **inputs, use_cache=True, max_new_tokens=64
            )
            decoded = self.processor.batch_decode(
                outputs, skip_special_tokens=True
            )
            # Only one image/prompt pair is submitted, so take the first result.
            generated_text = decoded[0]
            entities = self.extract_entities(generated_text)
            return self.process_entities_to_detections(entities, image)

    def extract_entities(
        self, text: str
    ) -> List[Tuple[str, Tuple[float, float, float, float]]]:
        """Extract ``(label, box)`` entities from generated text.

        Placeholder: currently ignores ``text`` and always returns an empty
        list, so ``__call__`` yields empty detections until this is replaced
        with a real extraction step.
        """
        return []

    def process_entities_to_detections(
        self,
        entities: List[Tuple[str, Tuple[float, float, float, float]]],
        image: Image.Image,
    ) -> Detections:
        """Convert extracted entities into a ``Detections`` object.

        Args:
            entities: ``(label, box)`` pairs. Each box's four values are
                multiplied by the image width/height, so they are presumably
                normalized coordinates - confirm against the extractor.
            image: Image whose ``width`` and ``height`` scale the boxes.

        Returns:
            ``Detections.empty()`` when ``entities`` is empty; otherwise one
            detection per entity with placeholder class id ``0`` and
            placeholder confidence ``1.0``.
        """
        if not entities:
            return Detections.empty()

        width, height = image.width, image.height
        count = len(entities)

        boxes = []
        for entity in entities:
            box = entity[1]
            boxes.append(
                (
                    box[0] * width,
                    box[1] * height,
                    box[2] * width,
                    box[3] * height,
                )
            )

        return Detections(
            xyxy=boxes,
            class_id=[0] * count,  # Replace with actual class ID extraction logic
            confidence=[1.0] * count,  # Placeholder confidence
        )
# Usage:
# kosmos2 = Kosmos2.initialize()
# detections = kosmos2(img="path_to_image.jpg")