[Idefics -> BaseMultiModalModel] [Vilt => BaseMultiModalModel]

pull/294/head^2
Kye 1 year ago
parent 4bef09a252
commit 79d8f149b7

@ -102,9 +102,13 @@ class Idefics(BaseMultiModalModel):
model_name, torch_dtype=torch_dtype, *args, **kwargs
).to(self.device)
self.processor = AutoProcessor.from_pretrained(model_name)
self.processor = AutoProcessor.from_pretrained(
model_name, *args, **kwargs
)
def run(self, task: str, *args, **kwargs) -> str:
def run(
self, task: str = None, img: str = None, *args, **kwargs
) -> str:
"""
Generates text based on the provided prompts.

@ -61,6 +61,8 @@ class OpenAITTS(AbstractLLM):
chunk_size=1024 * 1024,
autosave: bool = False,
saved_filepath: str = None,
*args,
**kwargs,
):
super().__init__()
self.model_name = model_name

@ -1,9 +1,11 @@
from transformers import ViltProcessor, ViltForQuestionAnswering
import requests
from PIL import Image
from transformers import ViltForQuestionAnswering, ViltProcessor
from swarms.models.base_multimodal_model import BaseMultiModalModel
class Vilt:
class Vilt(BaseMultiModalModel):
"""
Vision-and-Language Transformer (ViLT) model fine-tuned on VQAv2.
It was introduced in the paper ViLT: Vision-and-Language Transformer Without
@ -21,15 +23,21 @@ class Vilt:
"""
def __init__(self):
def __init__(
self,
model_name: str = "dandelin/vilt-b32-finetuned-vqa",
*args,
**kwargs,
):
super().__init__(model_name, *args, **kwargs)
self.processor = ViltProcessor.from_pretrained(
"dandelin/vilt-b32-finetuned-vqa"
model_name, *args, **kwargs
)
self.model = ViltForQuestionAnswering.from_pretrained(
"dandelin/vilt-b32-finetuned-vqa"
model_name, *args, **kwargs
)
def __call__(self, text: str, image_url: str):
def run(self, task: str = None, img: str = None, *args, **kwargs):
"""
Run the model
@ -38,9 +46,9 @@ class Vilt:
"""
# Download the image
image = Image.open(requests.get(image_url, stream=True).raw)
image = Image.open(requests.get(img, stream=True).raw)
encoding = self.processor(image, text, return_tensors="pt")
encoding = self.processor(image, task, return_tensors="pt")
# Forward pass
outputs = self.model(**encoding)

@ -15,6 +15,52 @@ except ImportError as error:
class BaseStructure(ABC):
"""Base structure.
Attributes:
name (Optional[str]): _description_
description (Optional[str]): _description_
save_metadata (bool): _description_
save_artifact_path (Optional[str]): _description_
save_metadata_path (Optional[str]): _description_
save_error_path (Optional[str]): _description_
Methods:
run: _description_
save_to_file: _description_
load_from_file: _description_
save_metadata: _description_
load_metadata: _description_
log_error: _description_
save_artifact: _description_
load_artifact: _description_
log_event: _description_
run_async: _description_
save_metadata_async: _description_
load_metadata_async: _description_
log_error_async: _description_
save_artifact_async: _description_
load_artifact_async: _description_
log_event_async: _description_
asave_to_file: _description_
aload_from_file: _description_
run_in_thread: _description_
save_metadata_in_thread: _description_
run_concurrent: _description_
compress_data: _description_
decompres_data: _description_
run_batched: _description_
load_config: _description_
backup_data: _description_
monitor_resources: _description_
run_with_resources: _description_
run_with_resources_batched: _description_
Examples:
"""
def __init__(
self,
name: Optional[str] = None,

Loading…
Cancel
Save