[swarm.models][cleanup]

pull/334/head
Kye 1 year ago
parent 49ce4db646
commit e9bb8dcbf4

@ -1,131 +0,0 @@
from typing import Any, List, Tuple, Union

from PIL import Image
from pydantic import BaseModel, model_validator, validator
from transformers import AutoModelForVision2Seq, AutoProcessor
# Assuming the Detections class represents the output of the model prediction
class Detections(BaseModel):
    """Detection results stored as three parallel lists.

    Entry ``i`` of every list describes the same detection, so all three
    lists must have the same length.

    Attributes:
        xyxy: One 4-tuple of floats per detection (presumably
            ``(x_min, y_min, x_max, y_max)`` -- confirm with the producer).
        class_id: One integer class id per detection.
        confidence: One float confidence score per detection.
    """

    xyxy: List[Tuple[float, float, float, float]]
    class_id: List[int]
    confidence: List[float]

    @model_validator(mode="after")
    def check_length(self):
        """Ensure the three lists are the same length.

        Runs after field validation, so every attribute is already a list.
        ``model_validator`` needs an explicit ``mode``; applying it bare
        fails when the class is defined.

        Raises:
            ValueError: If the lists differ in length (explicit raise rather
                than ``assert`` so the check survives ``python -O``).
        """
        if not (
            len(self.xyxy) == len(self.class_id) == len(self.confidence)
        ):
            raise ValueError("All fields must have the same length.")
        return self

    @validator(
        "xyxy", "class_id", "confidence", pre=True, each_item=True
    )
    def check_not_empty(cls, v):
        """Reject any individual element that is itself an empty list.

        With ``each_item=True`` this is called per element, so an empty
        outer list (as built by ``empty()``) still passes.
        """
        if isinstance(v, list) and len(v) == 0:
            raise ValueError("List must not be empty")
        return v

    @classmethod
    def empty(cls):
        """Return a ``Detections`` instance holding zero detections."""
        return cls(xyxy=[], class_id=[], confidence=[])
class Kosmos2(BaseModel):
    """
    Kosmos2

    Holds a pretrained vision-to-sequence model together with its processor
    and converts the generated, grounded caption for an image into a
    ``Detections`` object.

    Args:
    ------
    model: object returned by ``AutoModelForVision2Seq.from_pretrained``
    processor: object returned by ``AutoProcessor.from_pretrained``

    Usage:
    ------
    >>> from swarms import Kosmos2
    >>> from swarms.models.kosmos2 import Detections
    >>> from PIL import Image
    >>> model = Kosmos2.initialize()
    >>> image = Image.open("path_to_image.jpg")
    >>> detections = model(image)
    >>> print(detections)
    """

    # Typed as ``Any``: ``from_pretrained`` hands back a concrete
    # architecture class rather than an instance of the ``Auto*`` factory,
    # so validating against the factory types would reject the objects
    # created in ``initialize``.
    model: Any
    processor: Any

    @classmethod
    def initialize(cls):
        """Load the pretrained checkpoint and processor and wrap them.

        Returns:
            Kosmos2: Instance holding the loaded model and processor.
        """
        model = AutoModelForVision2Seq.from_pretrained(
            "ydshieh/kosmos-2-patch14-224", trust_remote_code=True
        )
        processor = AutoProcessor.from_pretrained(
            "ydshieh/kosmos-2-patch14-224", trust_remote_code=True
        )
        return cls(model=model, processor=processor)

    def __call__(self, img: Union[str, Image.Image]) -> Detections:
        """Generate a grounded caption for an image and return detections.

        Args:
            img: Either a path accepted by ``Image.open`` or an already
                opened PIL image (the form shown in the class usage example).

        Returns:
            Detections: Result of ``process_entities_to_detections`` for the
            entities extracted from the generated text.
        """
        # Only open the image when we were not handed one already.
        image = img if isinstance(img, Image.Image) else Image.open(img)
        prompt = "<grounding>An image of"
        inputs = self.processor(
            text=prompt, images=image, return_tensors="pt"
        )
        outputs = self.model.generate(
            **inputs, use_cache=True, max_new_tokens=64
        )
        generated_text = self.processor.batch_decode(
            outputs, skip_special_tokens=True
        )[0]
        # The actual processing of generated_text to entities would go here
        # For the purpose of this example, assume a mock function 'extract_entities' exists:
        entities = self.extract_entities(generated_text)
        # Convert entities to detections format
        detections = self.process_entities_to_detections(
            entities, image
        )
        return detections

    def extract_entities(
        self, text: str
    ) -> List[Tuple[str, Tuple[float, float, float, float]]]:
        """Extract ``(label, box)`` entities from generated text.

        Placeholder: always returns an empty list, so ``__call__`` currently
        always yields ``Detections.empty()``.
        """
        # Placeholder function for entity extraction
        # This should be replaced with the actual method of extracting entities
        return []

    def process_entities_to_detections(
        self,
        entities: List[Tuple[str, Tuple[float, float, float, float]]],
        image: Image.Image,
    ) -> Detections:
        """Convert extracted entities into a ``Detections`` object.

        Each entity's box values are multiplied by the image width (x
        values) or height (y values); this assumes the boxes arrive
        normalised to ``[0, 1]`` -- confirm once ``extract_entities`` is
        implemented.

        Args:
            entities: ``(label, (x0, y0, x1, y1))`` pairs.
            image: Image whose ``width`` and ``height`` scale the boxes.

        Returns:
            Detections: ``Detections.empty()`` when there are no entities,
            otherwise one detection per entity with class id ``0`` and
            confidence ``1.0`` (both placeholders).
        """
        if not entities:
            return Detections.empty()
        class_ids = [0] * len(
            entities
        )  # Replace with actual class ID extraction logic
        xyxys = [
            (
                e[1][0] * image.width,
                e[1][1] * image.height,
                e[1][2] * image.width,
                e[1][3] * image.height,
            )
            for e in entities
        ]
        confidences = [1.0] * len(entities)  # Placeholder confidence
        return Detections(
            xyxy=xyxys, class_id=class_ids, confidence=confidences
        )
# Usage:
# kosmos2 = Kosmos2.initialize()
# detections = kosmos2(img="path_to_image.jpg")

@ -8,6 +8,8 @@ import torchvision.transforms as T
from PIL import Image from PIL import Image
from transformers import AutoModelForVision2Seq, AutoProcessor from transformers import AutoModelForVision2Seq, AutoProcessor
from swarms.models.base_multimodal_model import BaseMultimodalModel
# utils # utils
def is_overlapping(rect1, rect2): def is_overlapping(rect1, rect2):
@ -16,7 +18,7 @@ def is_overlapping(rect1, rect2):
return not (x2 < x3 or x1 > x4 or y2 < y3 or y1 > y4) return not (x2 < x3 or x1 > x4 or y2 < y3 or y1 > y4)
class Kosmos: class Kosmos(BaseMultimodalModel):
""" """
Kosmos model by Yen-Chun Shieh Kosmos model by Yen-Chun Shieh
@ -35,9 +37,14 @@ class Kosmos:
def __init__( def __init__(
self, self,
model_name="ydshieh/kosmos-2-patch14-224", model_name="ydshieh/kosmos-2-patch14-224",
max_new_tokens: int = 64,
*args, *args,
**kwargs, **kwargs,
): ):
super(Kosmos, self).__init__(*args, **kwargs)
self.max_new_tokens = max_new_tokens
self.model = AutoModelForVision2Seq.from_pretrained( self.model = AutoModelForVision2Seq.from_pretrained(
model_name, trust_remote_code=True, *args, **kwargs model_name, trust_remote_code=True, *args, **kwargs
) )
@ -45,81 +52,75 @@ class Kosmos:
model_name, trust_remote_code=True, *args, **kwargs model_name, trust_remote_code=True, *args, **kwargs
) )
def get_image(self, url): def get_image(self, url: str):
"""Image""" """Get image from url
Args:
url (str): url of image
Returns:
PIL.Image.Image: image opened from the raw response stream of the url
"""
return Image.open(requests.get(url, stream=True).raw) return Image.open(requests.get(url, stream=True).raw)
def run(self, prompt, image): def run(self, task: str, image: str, *args, **kwargs):
"""Run Kosmos""" """Run the model
inputs = self.processor(
text=prompt, images=image, return_tensors="pt"
)
generated_ids = self.model.generate(
pixel_values=inputs["pixel_values"],
input_ids=inputs["input_ids"][:, :-1],
attention_mask=inputs["attention_mask"][:, :-1],
img_features=None,
img_attn_mask=inputs["img_attn_mask"][:, :-1],
use_cache=True,
max_new_tokens=64,
)
generated_texts = self.processor.batch_decode(
generated_ids,
skip_special_tokens=True,
)[0]
processed_text, entities = (
self.processor.post_process_generation(generated_texts)
)
def __call__(self, prompt, image): Args:
"""Run call""" task (str): task to run
image (str): img url
"""
inputs = self.processor( inputs = self.processor(
text=prompt, images=image, return_tensors="pt" text=task, images=image, return_tensors="pt"
) )
generated_ids = self.model.generate( generated_ids = self.model.generate(
pixel_values=inputs["pixel_values"], pixel_values=inputs["pixel_values"],
input_ids=inputs["input_ids"][:, :-1], input_ids=inputs["input_ids"][:, :-1],
attention_mask=inputs["attention_mask"][:, :-1], attention_mask=inputs["attention_mask"][:, :-1],
img_features=None, image_embeds=None,
img_attn_mask=inputs["img_attn_mask"][:, :-1], img_attn_mask=inputs["img_attn_mask"][:, :-1],
use_cache=True, use_cache=True,
max_new_tokens=64, max_new_tokens=self.max_new_tokens,
) )
generated_texts = self.processor.batch_decode( generated_texts = self.processor.batch_decode(
generated_ids, generated_ids,
skip_special_tokens=True, skip_special_tokens=True,
)[0] )[0]
processed_text, entities = ( processed_text, entities = (
self.processor.post_process_generation(generated_texts) self.processor.post_process_generation(generated_texts)
) )
return processed_text, entities
# tasks # tasks
def multimodal_grounding(self, phrase, image_url): def multimodal_grounding(self, phrase, image_url):
prompt = f"<grounding><phrase> {phrase} </phrase>" task = f"<grounding><phrase> {phrase} </phrase>"
self.run(prompt, image_url) self.run(task, image_url)
def referring_expression_comprehension(self, phrase, image_url): def referring_expression_comprehension(self, phrase, image_url):
prompt = f"<grounding><phrase> {phrase} </phrase>" task = f"<grounding><phrase> {phrase} </phrase>"
self.run(prompt, image_url) self.run(task, image_url)
def referring_expression_generation(self, phrase, image_url): def referring_expression_generation(self, phrase, image_url):
prompt = ( task = (
"<grounding><phrase>" "<grounding><phrase>"
" It</phrase><object><patch_index_0044><patch_index_0863></object> is" " It</phrase><object><patch_index_0044><patch_index_0863></object> is"
) )
self.run(prompt, image_url) self.run(task, image_url)
def grounded_vqa(self, question, image_url): def grounded_vqa(self, question, image_url):
prompt = f"<grounding> Question: {question} Answer:" task = f"<grounding> Question: {question} Answer:"
self.run(prompt, image_url) self.run(task, image_url)
def grounded_image_captioning(self, image_url): def grounded_image_captioning(self, image_url):
prompt = "<grounding> An image of" task = "<grounding> An image of"
self.run(prompt, image_url) self.run(task, image_url)
def grounded_image_captioning_detailed(self, image_url): def grounded_image_captioning_detailed(self, image_url):
prompt = "<grounding> Describe this image in detail" task = "<grounding> Describe this image in detail"
self.run(prompt, image_url) self.run(task, image_url)
def draw_entity_boxes_on_image( def draw_entity_boxes_on_image(
image, entities, show=False, save_path=None image, entities, show=False, save_path=None
@ -320,7 +321,7 @@ class Kosmos:
return new_image return new_image
def generate_boxees(self, prompt, image_url): def generate_boxees(self, task, image_url):
image = self.get_image(image_url) image = self.get_image(image_url)
processed_text, entities = self.process_prompt(prompt, image) processed_text, entities = self.process_task(task, image)
self.draw_entity_boxes_on_image(image, entities, show=True) self.draw_entity_boxes_on_image(image, entities, show=True)

@ -140,6 +140,18 @@ class StableDiffusion:
return image_paths return image_paths
def generate_and_move_image(self, prompt, iteration, folder_path): def generate_and_move_image(self, prompt, iteration, folder_path):
"""
Generates an image based on the given prompt and moves it to the specified folder.
Args:
prompt (str): The prompt used to generate the image.
iteration (int): The iteration number.
folder_path (str): The path to the folder where the image will be moved.
Returns:
str: The path of the moved image.
"""
# Generate the image # Generate the image
image_paths = self.run(prompt) image_paths = self.run(prompt)
if not image_paths: if not image_paths:

Loading…
Cancel
Save