@@ -10,17 +10,17 @@ from transformers import CLIPModel, CLIPProcessor
 class CLIPQ:
     """
     CLIPQ is a CLIP-based model that can be used to generate captions for images.
 
     Attributes:
         model_name (str): The name of the model to be used.
         query_text (str): The query text to be used for the model.
 
     Args:
         model_name (str): The name of the model to be used.
         query_text (str): The query text to be used for the model.
 
     """
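For context, a minimal usage sketch of the CLIPQ class documented above. The import path and the exact run_from_url signature are assumptions based on the hunks below, not something this diff confirms.

    # Hypothetical usage sketch; the import path is an assumption.
    # from swarms.models.clipq import CLIPQ

    clipq = CLIPQ(model_name="openai/clip-vit-base-patch16", query_text="A photo ")
    # Fetch a random image, split it into a 2x2 grid, and embed each slice.
    vectors = clipq.run_from_url("https://picsum.photos/800", h_splits=2, v_splits=2)
    print(len(vectors))  # 4 slices -> 4 CLIP image vectors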
@@ -30,13 +30,15 @@ class CLIPQ:
         model_name: str = "openai/clip-vit-base-patch16",
         query_text: str = "A photo ",
         *args,
-        **kwargs
+        **kwargs,
     ):
-        self.model = CLIPModel.from_pretrained(model_name, *args, **kwargs)
+        self.model = CLIPModel.from_pretrained(
+            model_name, *args, **kwargs
+        )
         self.processor = CLIPProcessor.from_pretrained(model_name)
         self.query_text = query_text
 
-    def fetch_image_from_url(self, url = "https://picsum.photos/800"):
+    def fetch_image_from_url(self, url="https://picsum.photos/800"):
         """Fetches an image from the given url"""
         response = requests.get(url)
         if response.status_code != 200:
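The remainder of fetch_image_from_url falls outside this hunk. A minimal sketch of the usual requests plus PIL pattern it relies on, assuming the response body is raw image bytes; the error message is a placeholder.

    from io import BytesIO

    import requests
    from PIL import Image

    response = requests.get("https://picsum.photos/800")
    if response.status_code != 200:
        raise RuntimeError("Failed to fetch image")  # placeholder error handling
    image = Image.open(BytesIO(response.content))  # decode the downloaded bytes
    print(image.size)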
@@ -48,7 +50,9 @@ class CLIPQ:
         """Loads an image from the given path"""
         return Image.open(path)
 
-    def split_image(self, image, h_splits: int = 2, v_splits: int = 2):
+    def split_image(
+        self, image, h_splits: int = 2, v_splits: int = 2
+    ):
         """Splits the given image into h_splits x v_splits parts"""
         width, height = image.size
         w_step, h_step = width // h_splits, height // v_splits
@@ -57,7 +61,12 @@ class CLIPQ:
         for i in range(v_splits):
             for j in range(h_splits):
                 slice = image.crop(
-                    (j * w_step, i * h_step, (j + 1) * w_step, (i + 1) * h_step)
+                    (
+                        j * w_step,
+                        i * h_step,
+                        (j + 1) * w_step,
+                        (i + 1) * h_step,
+                    )
                 )
                 slices.append(slice)
         return slices
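A standalone sketch of the same grid split on a PIL image. Each crop box is (left, upper, right, lower); because of the integer division, any remainder pixels on the right and bottom edges are simply dropped.

    from PIL import Image

    image = Image.new("RGB", (800, 600))
    h_splits, v_splits = 2, 2
    w_step, h_step = image.size[0] // h_splits, image.size[1] // v_splits
    slices = [
        image.crop((j * w_step, i * h_step, (j + 1) * w_step, (i + 1) * h_step))
        for i in range(v_splits)
        for j in range(h_splits)
    ]
    print([s.size for s in slices])  # four 400x300 tiles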
@@ -74,10 +83,15 @@ class CLIPQ:
 
         for slice in slices:
             inputs = self.processor(
-                text=self.query_text, images=slice, return_tensors="pt", padding=True
+                text=self.query_text,
+                images=slice,
+                return_tensors="pt",
+                padding=True,
             )
             outputs = self.model(**inputs)
-            vectors.append(outputs.image_embeds.squeeze().detach().numpy())
+            vectors.append(
+                outputs.image_embeds.squeeze().detach().numpy()
+            )
         return vectors
 
     def run_from_url(
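For comparison, a sketch that produces the same kind of per-slice image embeddings through CLIPModel.get_image_features, batching all slices in one forward pass and skipping the unused text branch. This is an alternative under the stated assumptions, not the method used in the diff.

    import torch
    from transformers import CLIPModel, CLIPProcessor

    model = CLIPModel.from_pretrained("openai/clip-vit-base-patch16")
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch16")

    def slice_vectors(slices):
        # Batch all PIL slices at once instead of looping.
        inputs = processor(images=slices, return_tensors="pt")
        with torch.no_grad():
            embeds = model.get_image_features(pixel_values=inputs.pixel_values)
        return [e.numpy() for e in embeds]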
@@ -118,7 +132,9 @@ class CLIPQ:
         blur = GaussianBlur(kernel_size)
         return blur(image)
 
-    def run_from_path(self, path: str = None, h_splits: int = 2, v_splits: int = 2):
+    def run_from_path(
+        self, path: str = None, h_splits: int = 2, v_splits: int = 2
+    ):
         """Runs the model on the image loaded from the given path"""
         image = self.load_image_from_path(path)
         return self.get_vectors(image, h_splits, v_splits)
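A small sketch of the torchvision GaussianBlur helper seen in the context above; the kernel size must be odd, and the transform can be applied directly to a PIL image.

    from PIL import Image
    from torchvision.transforms import GaussianBlur

    image = Image.new("RGB", (256, 256))
    blurred = GaussianBlur(kernel_size=5)(image)  # odd kernel size, default sigma range
    print(blurred.size)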
@@ -132,7 +148,9 @@ class CLIPQ:
 
         inputs_text = self.processor(
             text=candidate_captions,
-            images=inputs_image.pixel_values[0],  # Fix the argument name
+            images=inputs_image.pixel_values[
+                0
+            ],  # Fix the argument name
             return_tensors="pt",
             padding=True,
             truncation=True,
@@ -142,7 +160,8 @@ class CLIPQ:
             pixel_values=inputs_image.pixel_values[0]
         ).image_embeds
         text_embeds = self.model(
-            input_ids=inputs_text.input_ids, attention_mask=inputs_text.attention_mask
+            input_ids=inputs_text.input_ids,
+            attention_mask=inputs_text.attention_mask,
         ).text_embeds
 
         # Calculate similarity between image and text
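A minimal sketch of the caption-scoring step these hunks reformat, assuming one image is scored against a list of candidate captions; cosine similarity over the L2-normalized CLIP embeddings picks the best match. The helper name is illustrative, not part of the diff.

    import torch
    from PIL import Image
    from transformers import CLIPModel, CLIPProcessor

    model = CLIPModel.from_pretrained("openai/clip-vit-base-patch16")
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch16")

    def best_caption(image, candidate_captions):
        inputs = processor(
            text=candidate_captions, images=image, return_tensors="pt", padding=True
        )
        with torch.no_grad():
            outputs = model(**inputs)
        image_embeds = outputs.image_embeds / outputs.image_embeds.norm(dim=-1, keepdim=True)
        text_embeds = outputs.text_embeds / outputs.text_embeds.norm(dim=-1, keepdim=True)
        scores = image_embeds @ text_embeds.T  # cosine similarity, shape (1, num_captions)
        return candidate_captions[scores.argmax().item()]

    print(best_caption(Image.new("RGB", (224, 224)), ["a cat", "a dog", "a blank image"]))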
@@ -156,6 +175,9 @@ class CLIPQ:
     ):
         """Get the best caption for the given image"""
         slices = self.split_image(image, h_splits, v_splits)
-        captions = [self.get_captions(slice, candidate_captions) for slice in slices]
+        captions = [
+            self.get_captions(slice, candidate_captions)
+            for slice in slices
+        ]
         concated_captions = "".join(captions)
         return concated_captions
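Finally, a hedged end-to-end sketch of get_best_caption as reformatted above. The keyword names follow the variables used in the hunk, but the full signature sits outside this diff, and the local file name is hypothetical. Note that "".join simply concatenates the per-slice picks into one string.

    from PIL import Image

    clipq = CLIPQ()  # class defined in the module shown in this diff
    image = Image.open("example.jpg")  # hypothetical local file
    caption = clipq.get_best_caption(
        image,
        candidate_captions=["a city street", "a mountain lake", "a bowl of fruit"],
        h_splits=2,
        v_splits=2,
    )
    print(caption)  # per-slice picks concatenated into one string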