parent 0671d42458
commit ccc38a4abd

@@ -0,0 +1,223 @@
import torch
from transformers import AutoProcessor, IdeficsForVisionText2Text


class Idefics:
    """
    A class for multimodal inference using pre-trained models from the Hugging Face Hub.

    Attributes
    ----------
    device : str
        The device to use for inference.
    checkpoint : str, optional
        The name of the pre-trained model checkpoint (default is "HuggingFaceM4/idefics-9b-instruct").
    processor : transformers.IdeficsProcessor
        The pre-trained processor.
    max_length : int
        The maximum length of the generated text.
    chat_history : list
        The chat history.

    Methods
    -------
    run(prompts, batched_mode=True)
        Generates text based on the provided prompts.
    chat(user_input)
        Engages in a continuous bidirectional conversation based on the user input.
    set_checkpoint(checkpoint)
        Changes the model checkpoint.
    set_device(device)
        Changes the device used for inference.
    set_max_length(max_length)
        Changes the maximum length of the generated text.
    clear_chat_history()
        Clears the chat history.

    # Usage
    ```
    from exa import Idefics

    mmi = Idefics()

    user_input = "User: What is in this image? https://upload.wikimedia.org/wikipedia/commons/8/86/Id%C3%A9fix.JPG"
    response = mmi.chat(user_input)
    print(response)

    user_input = "User: And who is that? https://static.wikia.nocookie.net/asterix/images/2/25/R22b.gif/revision/latest?cb=20110815073052"
    response = mmi.chat(user_input)
    print(response)

    mmi.set_checkpoint("new_checkpoint")
    mmi.set_device("cpu")
    mmi.set_max_length(200)
    mmi.clear_chat_history()
    ```
    """

    def __init__(
        self,
        checkpoint="HuggingFaceM4/idefics-9b-instruct",
        device=None,
        torch_dtype=torch.bfloat16,
        max_length=100,
    ):
        self.device = (
            device if device else ("cuda" if torch.cuda.is_available() else "cpu")
        )
        self.model = IdeficsForVisionText2Text.from_pretrained(
            checkpoint,
            torch_dtype=torch_dtype,
        ).to(self.device)

        self.processor = AutoProcessor.from_pretrained(checkpoint)

        self.max_length = max_length

        self.chat_history = []

    def run(self, prompts, batched_mode=True):
        """
        Generates text based on the provided prompts.

        Parameters
        ----------
        prompts : list
            A list of prompts. Each prompt is a list of text strings and images.
        batched_mode : bool, optional
            Whether to process the prompts in batched mode. If True, all prompts are
            processed together. If False, only the first prompt is processed
            (default is True).

        Returns
        -------
        list
            A list of generated text strings.
        """
        inputs = (
            self.processor(
                prompts, add_end_of_utterance_token=False, return_tensors="pt"
            ).to(self.device)
            if batched_mode
            else self.processor(prompts[0], return_tensors="pt").to(self.device)
        )

        exit_condition = self.processor.tokenizer(
            "<end_of_utterance>", add_special_tokens=False
        ).input_ids

        bad_words_ids = self.processor.tokenizer(
            ["<image>", "<fake_token_around_image>"], add_special_tokens=False
        ).input_ids

        generated_ids = self.model.generate(
            **inputs,
            eos_token_id=exit_condition,
            bad_words_ids=bad_words_ids,
            max_length=self.max_length,
        )
        generated_text = self.processor.batch_decode(
            generated_ids, skip_special_tokens=True
        )
        return generated_text

    def __call__(self, prompts, batched_mode=True):
        """
        Generates text based on the provided prompts.

        Equivalent to calling ``run``; see ``run`` for parameter and return details.
        """
        return self.run(prompts, batched_mode=batched_mode)

    def chat(self, user_input):
        """
        Engages in a continuous bidirectional conversation based on the user input.

        Parameters
        ----------
        user_input : str
            The user input.

        Returns
        -------
        str
            The model's response.
        """
        self.chat_history.append(user_input)

        prompts = [self.chat_history]

        response = self.run(prompts)[0]

        self.chat_history.append(response)

        return response

    def set_checkpoint(self, checkpoint):
        """
        Changes the model checkpoint.

        Parameters
        ----------
        checkpoint : str
            The name of the new pre-trained model checkpoint.
        """
        self.model = IdeficsForVisionText2Text.from_pretrained(
            checkpoint, torch_dtype=torch.bfloat16
        ).to(self.device)
        self.processor = AutoProcessor.from_pretrained(checkpoint)

    def set_device(self, device):
        """
        Changes the device used for inference.

        Parameters
        ----------
        device : str
            The new device to use for inference.
        """
        self.device = device
        self.model.to(self.device)

    def set_max_length(self, max_length):
        """Set max_length"""
        self.max_length = max_length

    def clear_chat_history(self):
        """Clear chat history"""
        self.chat_history = []
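
# A minimal usage sketch for batched inference with the class above, assuming
# the IDEFICS processor accepts prompts given as lists that interleave text
# strings with image URLs (as described in run()'s docstring). The prompt text
# and the chat-style "<end_of_utterance>" / "Assistant:" markers follow common
# IDEFICS-instruct examples and are illustrative, not part of this module.
if __name__ == "__main__":
    model = Idefics()
    prompts = [
        [
            "User: What is in this image?",
            "https://upload.wikimedia.org/wikipedia/commons/8/86/Id%C3%A9fix.JPG",
            "<end_of_utterance>",
            "\nAssistant:",
        ]
    ]
    generated = model.run(prompts, batched_mode=True)
    print(generated[0])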
@@ -0,0 +1,284 @@
import os

import cv2
import numpy as np
import requests
import torch
import torchvision.transforms as T
from PIL import Image
from transformers import AutoModelForVision2Seq, AutoProcessor


# utils
def is_overlapping(rect1, rect2):
    x1, y1, x2, y2 = rect1
    x3, y3, x4, y4 = rect2
    return not (x2 < x3 or x1 > x4 or y2 < y3 or y1 > y4)


class Kosmos:
    """
    A wrapper around the Kosmos-2 vision-language model for grounded multimodal tasks.

    Args:
        model_name (str): The Hugging Face model checkpoint to load
            (default is "ydshieh/kosmos-2-patch14-224").

    Usage:
        # Initialize Kosmos
        kosmos = Kosmos()

        # Perform multimodal grounding
        kosmos.multimodal_grounding("Find the red apple in the image.", "https://example.com/apple.jpg")

        # Perform referring expression comprehension
        kosmos.referring_expression_comprehension("Show me the green bottle.", "https://example.com/bottle.jpg")

        # Generate referring expressions
        kosmos.referring_expression_generation("It is on the table.", "https://example.com/table.jpg")

        # Perform grounded visual question answering
        kosmos.grounded_vqa("What is the color of the car?", "https://example.com/car.jpg")

        # Generate grounded image caption
        kosmos.grounded_image_captioning("https://example.com/beach.jpg")
    """

    def __init__(
        self,
        model_name="ydshieh/kosmos-2-patch14-224",
    ):
        self.model = AutoModelForVision2Seq.from_pretrained(
            model_name, trust_remote_code=True
        )
        self.processor = AutoProcessor.from_pretrained(
            model_name, trust_remote_code=True
        )

    def get_image(self, url):
        """Fetch an image from a URL and return it as a PIL image."""
        return Image.open(requests.get(url, stream=True).raw)

    def run(self, prompt, image):
        """Run Kosmos on a prompt and an image (a PIL image or an image URL)."""
        if isinstance(image, str):
            # allow callers to pass an image URL directly
            image = self.get_image(image)
        inputs = self.processor(text=prompt, images=image, return_tensors="pt")
        generated_ids = self.model.generate(
            pixel_values=inputs["pixel_values"],
            input_ids=inputs["input_ids"][:, :-1],
            attention_mask=inputs["attention_mask"][:, :-1],
            img_features=None,
            img_attn_mask=inputs["img_attn_mask"][:, :-1],
            use_cache=True,
            max_new_tokens=64,
        )
        generated_texts = self.processor.batch_decode(
            generated_ids,
            skip_special_tokens=True,
        )[0]
        processed_text, entities = self.processor.post_process_generation(
            generated_texts
        )
        return processed_text, entities

    def __call__(self, prompt, image):
        """Run call; equivalent to run()."""
        return self.run(prompt, image)

    # tasks
    def multimodal_grounding(self, phrase, image_url):
        prompt = f"<grounding><phrase> {phrase} </phrase>"
        return self.run(prompt, image_url)

    def referring_expression_comprehension(self, phrase, image_url):
        prompt = f"<grounding><phrase> {phrase} </phrase>"
        return self.run(prompt, image_url)

    def referring_expression_generation(self, phrase, image_url):
        prompt = "<grounding><phrase> It</phrase><object><patch_index_0044><patch_index_0863></object> is"
        return self.run(prompt, image_url)

    def grounded_vqa(self, question, image_url):
        prompt = f"<grounding> Question: {question} Answer:"
        return self.run(prompt, image_url)

    def grounded_image_captioning(self, image_url):
        prompt = "<grounding> An image of"
        return self.run(prompt, image_url)

    def grounded_image_captioning_detailed(self, image_url):
        prompt = "<grounding> Describe this image in detail"
        return self.run(prompt, image_url)

    def draw_entity_boxes_on_image(self, image, entities, show=False, save_path=None):
        """Draw bounding boxes and entity labels on an image.

        Args:
            image: a PIL image, an image path, or a normalized image tensor.
            entities: a list of (entity_name, (start, end), bboxes) tuples, with
                bounding boxes given as normalized (x1, y1, x2, y2) coordinates.
        """
        if isinstance(image, Image.Image):
            image_h = image.height
            image_w = image.width
            image = np.array(image)[:, :, [2, 1, 0]]
        elif isinstance(image, str):
            if os.path.exists(image):
                pil_img = Image.open(image).convert("RGB")
                image = np.array(pil_img)[:, :, [2, 1, 0]]
                image_h = pil_img.height
                image_w = pil_img.width
            else:
                raise ValueError(f"invalid image path, {image}")
        elif isinstance(image, torch.Tensor):
            # pdb.set_trace()
            image_tensor = image.cpu()
            reverse_norm_mean = torch.tensor([0.48145466, 0.4578275, 0.40821073])[
                :, None, None
            ]
            reverse_norm_std = torch.tensor([0.26862954, 0.26130258, 0.27577711])[
                :, None, None
            ]
            image_tensor = image_tensor * reverse_norm_std + reverse_norm_mean
            pil_img = T.ToPILImage()(image_tensor)
            image_h = pil_img.height
            image_w = pil_img.width
            image = np.array(pil_img)[:, :, [2, 1, 0]]
        else:
            raise ValueError(f"invalid image format, {type(image)} for {image}")

        if len(entities) == 0:
            return image

        new_image = image.copy()
        previous_bboxes = []
        # size of text
        text_size = 1
        # thickness of text
        text_line = 1  # int(max(1 * min(image_h, image_w) / 512, 1))
        box_line = 3
        (c_width, text_height), _ = cv2.getTextSize(
            "F", cv2.FONT_HERSHEY_COMPLEX, text_size, text_line
        )
        base_height = int(text_height * 0.675)
        text_offset_original = text_height - base_height
        text_spaces = 3

        for entity_name, (start, end), bboxes in entities:
            for x1_norm, y1_norm, x2_norm, y2_norm in bboxes:
                orig_x1, orig_y1, orig_x2, orig_y2 = (
                    int(x1_norm * image_w),
                    int(y1_norm * image_h),
                    int(x2_norm * image_w),
                    int(y2_norm * image_h),
                )
                # draw bbox
                # random color
                color = tuple(np.random.randint(0, 255, size=3).tolist())
                new_image = cv2.rectangle(
                    new_image, (orig_x1, orig_y1), (orig_x2, orig_y2), color, box_line
                )

                l_o, r_o = (
                    box_line // 2 + box_line % 2,
                    box_line // 2 + box_line % 2 + 1,
                )

                x1 = orig_x1 - l_o
                y1 = orig_y1 - l_o

                if y1 < text_height + text_offset_original + 2 * text_spaces:
                    y1 = (
                        orig_y1
                        + r_o
                        + text_height
                        + text_offset_original
                        + 2 * text_spaces
                    )
                    x1 = orig_x1 + r_o

                # add text background
                (text_width, text_height), _ = cv2.getTextSize(
                    f" {entity_name}", cv2.FONT_HERSHEY_COMPLEX, text_size, text_line
                )
                text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2 = (
                    x1,
                    y1 - (text_height + text_offset_original + 2 * text_spaces),
                    x1 + text_width,
                    y1,
                )

                for prev_bbox in previous_bboxes:
                    while is_overlapping(
                        (text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2), prev_bbox
                    ):
                        text_bg_y1 += (
                            text_height + text_offset_original + 2 * text_spaces
                        )
                        text_bg_y2 += (
                            text_height + text_offset_original + 2 * text_spaces
                        )
                        y1 += text_height + text_offset_original + 2 * text_spaces

                        if text_bg_y2 >= image_h:
                            text_bg_y1 = max(
                                0,
                                image_h
                                - (
                                    text_height + text_offset_original + 2 * text_spaces
                                ),
                            )
                            text_bg_y2 = image_h
                            y1 = image_h
                            break

                alpha = 0.5
                for i in range(text_bg_y1, text_bg_y2):
                    for j in range(text_bg_x1, text_bg_x2):
                        if i < image_h and j < image_w:
                            if j < text_bg_x1 + 1.35 * c_width:
                                # original color
                                bg_color = color
                            else:
                                # white
                                bg_color = [255, 255, 255]
                            new_image[i, j] = (
                                alpha * new_image[i, j]
                                + (1 - alpha) * np.array(bg_color)
                            ).astype(np.uint8)

                cv2.putText(
                    new_image,
                    f" {entity_name}",
                    (x1, y1 - text_offset_original - 1 * text_spaces),
                    cv2.FONT_HERSHEY_COMPLEX,
                    text_size,
                    (0, 0, 0),
                    text_line,
                    cv2.LINE_AA,
                )
                # previous_locations.append((x1, y1))
                previous_bboxes.append((text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2))

        pil_image = Image.fromarray(new_image[:, :, [2, 1, 0]])
        if save_path:
            pil_image.save(save_path)
        if show:
            pil_image.show()

        return new_image

    def generate_boxes(self, prompt, image_url):
        """Run the model and draw the predicted entity boxes on the image."""
        image = self.get_image(image_url)
        processed_text, entities = self.run(prompt, image)
        self.draw_entity_boxes_on_image(image, entities, show=True)
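
# A minimal usage sketch for the class above. The image URL is the illustrative
# placeholder from the class docstring; grounded_vqa() returns the processed
# text plus the entities that generate_boxes() can overlay on the image (this
# follows run()'s return value above, not an official example).
if __name__ == "__main__":
    kosmos = Kosmos()
    answer, entities = kosmos.grounded_vqa(
        "What is the color of the car?", "https://example.com/car.jpg"
    )
    print(answer)
    kosmos.generate_boxes("<grounding> An image of", "https://example.com/car.jpg")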
@@ -0,0 +1,51 @@
import requests
from PIL import Image
from transformers import ViltProcessor, ViltForQuestionAnswering


class Vilt:
    """
    Vision-and-Language Transformer (ViLT) model fine-tuned on VQAv2.

    ViLT was introduced in the paper "ViLT: Vision-and-Language Transformer Without
    Convolution or Region Supervision" by Kim et al.

    Disclaimer: the team releasing ViLT did not write a model card for this model,
    so the model card on the Hub was written by the Hugging Face team.

    https://huggingface.co/dandelin/vilt-b32-finetuned-vqa

    Example:
        >>> model = Vilt()
        >>> output = model("What is this image", "http://images.cocodataset.org/val2017/000000039769.jpg")
    """

    def __init__(self):
        self.processor = ViltProcessor.from_pretrained(
            "dandelin/vilt-b32-finetuned-vqa"
        )
        self.model = ViltForQuestionAnswering.from_pretrained(
            "dandelin/vilt-b32-finetuned-vqa"
        )

    def __call__(self, text: str, image_url: str):
        """
        Run the model on a question about an image.

        Args:
            text: str
                The question to ask about the image.
            image_url: str
                The URL of the image to answer the question about.

        Returns:
            str: the predicted answer.
        """
        # Download the image
        image = Image.open(requests.get(image_url, stream=True).raw)

        encoding = self.processor(image, text, return_tensors="pt")

        # Forward pass
        outputs = self.model(**encoding)
        logits = outputs.logits
        idx = logits.argmax(-1).item()
        answer = self.model.config.id2label[idx]
        print("Predicted Answer:", answer)
        return answer
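
# A minimal usage sketch for the class above, reusing the COCO image URL from
# the docstring example; the question text is an illustrative placeholder.
if __name__ == "__main__":
    vqa = Vilt()
    answer = vqa(
        "How many cats are there?",
        "http://images.cocodataset.org/val2017/000000039769.jpg",
    )
    print(answer)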