parent
49ce4db646
commit
e9bb8dcbf4
@@ -1,131 +0,0 @@
from typing import Any, List, Tuple

from PIL import Image
from pydantic import BaseModel, field_validator, model_validator
from transformers import AutoModelForVision2Seq, AutoProcessor


# Assuming the Detections class represents the output of the model prediction
class Detections(BaseModel):
    xyxy: List[Tuple[float, float, float, float]]
    class_id: List[int]
    confidence: List[float]

    @model_validator(mode="after")
    def check_length(self):
        assert (
            len(self.xyxy) == len(self.class_id) == len(self.confidence)
        ), "All fields must have the same length."
        return self

    @field_validator("xyxy", "class_id", "confidence", mode="before")
    @classmethod
    def check_not_empty(cls, v):
        # Pydantic v2 has no `each_item=True`, so items are checked manually.
        # The field itself may be an empty list (see `Detections.empty()`).
        if isinstance(v, list):
            for item in v:
                if isinstance(item, list) and len(item) == 0:
                    raise ValueError("List must not be empty")
        return v

    @classmethod
    def empty(cls):
        return cls(xyxy=[], class_id=[], confidence=[])
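
# Validation behaviour, for illustration (hypothetical values):
# Detections(xyxy=[(0.0, 0.0, 10.0, 10.0)], class_id=[0], confidence=[0.9]) is valid;
# mismatched lengths (e.g. the same xyxy with class_id=[]) raise a pydantic
# ValidationError, while Detections.empty() constructs a valid, empty result.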


class Kosmos2(BaseModel):
    """
    Kosmos2

    Args:
    ------
    model: AutoModelForVision2Seq
    processor: AutoProcessor

    Usage:
    ------
    >>> from swarms import Kosmos2
    >>> from swarms.models.kosmos2 import Detections
    >>> model = Kosmos2.initialize()
    >>> detections = model(img="path_to_image.jpg")
    >>> print(detections)

    """

    # `from_pretrained` returns concrete subclasses rather than the Auto classes
    # themselves, so the fields are typed as `Any` to keep pydantic validation happy.
    model: Any  # loaded via AutoModelForVision2Seq.from_pretrained
    processor: Any  # loaded via AutoProcessor.from_pretrained

    @classmethod
    def initialize(cls):
        model = AutoModelForVision2Seq.from_pretrained(
            "ydshieh/kosmos-2-patch14-224", trust_remote_code=True
        )
        processor = AutoProcessor.from_pretrained(
            "ydshieh/kosmos-2-patch14-224", trust_remote_code=True
        )
        return cls(model=model, processor=processor)

    def __call__(self, img: str) -> Detections:
        image = Image.open(img)
        prompt = "<grounding>An image of"

        inputs = self.processor(
            text=prompt, images=image, return_tensors="pt"
        )
        outputs = self.model.generate(
            **inputs, use_cache=True, max_new_tokens=64
        )

        generated_text = self.processor.batch_decode(
            outputs, skip_special_tokens=True
        )[0]

        # The actual processing of generated_text to entities would go here.
        # For the purpose of this example, assume a mock function
        # 'extract_entities' exists:
        entities = self.extract_entities(generated_text)

        # Convert entities to detections format
        detections = self.process_entities_to_detections(
            entities, image
        )
        return detections

    def extract_entities(
        self, text: str
    ) -> List[Tuple[str, Tuple[float, float, float, float]]]:
        # Placeholder function for entity extraction.
        # This should be replaced with the actual method of extracting entities;
        # a possible implementation is sketched at the end of this file.
        return []

    def process_entities_to_detections(
        self,
        entities: List[Tuple[str, Tuple[float, float, float, float]]],
        image: Image.Image,
    ) -> Detections:
        if not entities:
            return Detections.empty()

        # Replace with actual class ID extraction logic
        class_ids = [0] * len(entities)
        xyxys = [
            (
                e[1][0] * image.width,
                e[1][1] * image.height,
                e[1][2] * image.width,
                e[1][3] * image.height,
            )
            for e in entities
        ]
        confidences = [1.0] * len(entities)  # Placeholder confidence

        return Detections(
            xyxy=xyxys, class_id=class_ids, confidence=confidences
        )


# Usage:
# kosmos2 = Kosmos2.initialize()
# detections = kosmos2(img="path_to_image.jpg")
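

# A minimal sketch of what `Kosmos2.extract_entities` could look like, assuming the
# processor exposes `post_process_generation()` the way the Kosmos-2 processor in
# recent `transformers` releases does (the remote-code checkpoint used above may
# differ). Each entity is flattened to (phrase, (x1, y1, x2, y2)) with coordinates
# normalized to [0, 1], which is the shape `process_entities_to_detections` expects.
def extract_entities_sketch(
    processor, text: str
) -> List[Tuple[str, Tuple[float, float, float, float]]]:
    _processed_text, entities = processor.post_process_generation(text)
    flattened = []
    for phrase, _token_span, bboxes in entities:
        # `bboxes` holds one or more normalized (x1, y1, x2, y2) boxes per phrase.
        for x1, y1, x2, y2 in bboxes:
            flattened.append((phrase, (x1, y1, x2, y2)))
    return flattened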