|
|
@ -1,9 +1,11 @@
|
|
|
|
from transformers import ViltProcessor, ViltForQuestionAnswering
|
|
|
|
|
|
|
|
import requests
|
|
|
|
import requests
|
|
|
|
from PIL import Image
|
|
|
|
from PIL import Image
|
|
|
|
|
|
|
|
from transformers import ViltForQuestionAnswering, ViltProcessor
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from swarms.models.base_multimodal_model import BaseMultiModalModel
|
|
|
|
|
|
|
|
|
|
|
|
class Vilt:
|
|
|
|
|
|
|
|
|
|
|
|
class Vilt(BaseMultiModalModel):
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
Vision-and-Language Transformer (ViLT) model fine-tuned on VQAv2.
|
|
|
|
Vision-and-Language Transformer (ViLT) model fine-tuned on VQAv2.
|
|
|
|
It was introduced in the paper ViLT: Vision-and-Language Transformer Without
|
|
|
|
It was introduced in the paper ViLT: Vision-and-Language Transformer Without
|
|
|
@ -21,15 +23,21 @@ class Vilt:
|
|
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
def __init__(self):
|
|
|
|
def __init__(
|
|
|
|
|
|
|
|
self,
|
|
|
|
|
|
|
|
model_name: str = "dandelin/vilt-b32-finetuned-vqa",
|
|
|
|
|
|
|
|
*args,
|
|
|
|
|
|
|
|
**kwargs,
|
|
|
|
|
|
|
|
):
|
|
|
|
|
|
|
|
super().__init__(model_name, *args, **kwargs)
|
|
|
|
self.processor = ViltProcessor.from_pretrained(
|
|
|
|
self.processor = ViltProcessor.from_pretrained(
|
|
|
|
"dandelin/vilt-b32-finetuned-vqa"
|
|
|
|
model_name, *args, **kwargs
|
|
|
|
)
|
|
|
|
)
|
|
|
|
self.model = ViltForQuestionAnswering.from_pretrained(
|
|
|
|
self.model = ViltForQuestionAnswering.from_pretrained(
|
|
|
|
"dandelin/vilt-b32-finetuned-vqa"
|
|
|
|
model_name, *args, **kwargs
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
def __call__(self, text: str, image_url: str):
|
|
|
|
def run(self, task: str = None, img: str = None, *args, **kwargs):
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
Run the model
|
|
|
|
Run the model
|
|
|
|
|
|
|
|
|
|
|
@ -38,9 +46,9 @@ class Vilt:
|
|
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
# Download the image
|
|
|
|
# Download the image
|
|
|
|
image = Image.open(requests.get(image_url, stream=True).raw)
|
|
|
|
image = Image.open(requests.get(img, stream=True).raw)
|
|
|
|
|
|
|
|
|
|
|
|
encoding = self.processor(image, text, return_tensors="pt")
|
|
|
|
encoding = self.processor(image, task, return_tensors="pt")
|
|
|
|
|
|
|
|
|
|
|
|
# Forward pass
|
|
|
|
# Forward pass
|
|
|
|
outputs = self.model(**encoding)
|
|
|
|
outputs = self.model(**encoding)
|
|
|
|