from typing import Callable, Optional

import torch
from termcolor import colored
from transformers import AutoProcessor, IdeficsForVisionText2Text

from swarms.models.base_multimodal_model import BaseMultiModalModel


def autodetect_device():
    """
    Autodetects the device to use for inference.

    Returns
    -------
    str
        The device to use for inference.
    """
    return "cuda" if torch.cuda.is_available() else "cpu"


class Idefics(BaseMultiModalModel):
    """
    A class for multimodal inference using pre-trained models from the
    Hugging Face Hub.

    Attributes
    ----------
    device : str
        The device to use for inference.
    model_name : str, optional
        The name of the pre-trained model (default is
        "HuggingFaceM4/idefics-9b-instruct").
    processor : transformers.PreTrainedProcessor
        The pre-trained processor.
    max_length : int
        The maximum length of the generated text.

    Methods
    -------
    run(task)
        Generates text based on the provided prompts.
    chat(user_input)
        Engages in a continuous bidirectional conversation based on the
        user input.
    set_model_name(model_name)
        Changes the pre-trained model.
    set_device(device)
        Changes the device used for inference.
    set_max_length(max_length)
        Changes the maximum length of the generated text.
    clear_chat_history()
        Clears the chat history.

    Usage
    -----
    model = Idefics()

    user_input = "User: What is in this image? <image URL>"
    response = model.chat(user_input)
    print(response)

    model.set_model_name("new_model_name")
    model.set_device("cpu")
    model.set_max_length(200)
    model.clear_chat_history()
    """

    def __init__(
        self,
        model_name: Optional[str] = "HuggingFaceM4/idefics-9b-instruct",
        device: Callable = autodetect_device,
        torch_dtype=torch.bfloat16,
        max_length: int = 100,
        batched_mode: bool = True,
        *args,
        **kwargs,
    ):
        # Initialize the parent class
        super().__init__(*args, **kwargs)
        self.model_name = model_name
        self.max_length = max_length
        self.batched_mode = batched_mode
        self.chat_history = []

        # `device` may be a device string ("cuda"/"cpu") or a callable that
        # returns one; the default autodetects based on CUDA availability.
        if callable(device):
            device = device()
        self.device = device if device else autodetect_device()

        # Load the pre-trained model and move it to the selected device
        self.model = IdeficsForVisionText2Text.from_pretrained(
            model_name, *args, torch_dtype=torch_dtype, **kwargs
        ).to(self.device)

        # The processor handles tokenization and image preprocessing
        self.processor = AutoProcessor.from_pretrained(model_name)

    def run(self, task: str, *args, **kwargs) -> list:
        """
        Generates text based on the provided prompts.

        Parameters
        ----------
        task : str
            The task (prompt) to perform. Prompts may interleave text strings
            and images. Whether prompts are processed in batched mode is
            controlled by the `batched_mode` attribute set at initialization:
            if True, all prompts are processed together; if False, only the
            given prompt is processed.

        Returns
        -------
        list
            A list of generated text strings.
        """
        try:
            # Preprocess the prompt(s); extra args/kwargs are forwarded to
            # the processor in batched mode
            inputs = (
                self.processor(
                    task,
                    *args,
                    add_end_of_utterance_token=False,
                    return_tensors="pt",
                    **kwargs,
                ).to(self.device)
                if self.batched_mode
                else self.processor(task, return_tensors="pt").to(
                    self.device
                )
            )

            # Stop generating once the model emits the end-of-utterance token
            exit_condition = self.processor.tokenizer(
                "<end_of_utterance>", add_special_tokens=False
            ).input_ids

            # Forbid the model from emitting raw image placeholder tokens
            bad_words_ids = self.processor.tokenizer(
                ["<image>", "<fake_token_around_image>"],
                add_special_tokens=False,
            ).input_ids

            generated_ids = self.model.generate(
                **inputs,
                eos_token_id=exit_condition,
                bad_words_ids=bad_words_ids,
                max_length=self.max_length,
            )
            generated_text = self.processor.batch_decode(
                generated_ids, skip_special_tokens=True
            )
            return generated_text
        except Exception as error:
            print(
                colored(
                    (
                        "Error in"
                        f" {self.__class__.__name__} pipeline:"
                        f" {error}"
                    ),
                    "red",
                )
            )
            # Re-raise so callers are not handed a silent None
            raise

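    # A sketch of the prompt shape run() accepts in batched mode (an assumed,
    # untested example; the URL is a placeholder): the IDEFICS processor takes
    # a list of prompts, each a list interleaving text strings and images.
    #
    #     prompts = [
    #         [
    #             "User: What is in this image?",
    #             "https://example.com/cat.jpg",
    #             "<end_of_utterance>",
    #             "\nAssistant:",
    #         ],
    #     ]
    #     model.run(prompts)
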
    def chat(self, user_input):
        """
        Engages in a continuous bidirectional conversation based on the user input.

        Parameters
        ----------
        user_input : str
            The user input.

        Returns
        -------
        str
            The model's response.
        """
        self.chat_history.append(user_input)

        # The processor accepts a list of prompts mixing text and images,
        # so the full history is passed through as a single batch
        prompts = [self.chat_history]

        response = self.run(prompts)[0]

        self.chat_history.append(response)

        return response

    def set_model_name(self, model_name):
        """
        Changes the pre-trained model and reloads the processor.

        Parameters
        ----------
        model_name : str
            The name of the new pre-trained model.
        """
        self.model = IdeficsForVisionText2Text.from_pretrained(
            model_name, torch_dtype=torch.bfloat16
        ).to(self.device)
        self.processor = AutoProcessor.from_pretrained(model_name)

    def set_device(self, device):
        """
        Changes the device used for inference.

        Parameters
        ----------
        device : str
            The new device to use for inference.
        """
        self.device = device
        self.model.to(self.device)

    def set_max_length(self, max_length):
        """Set max_length"""
        self.max_length = max_length

    def clear_chat_history(self):
        """Clear chat history"""
        self.chat_history = []
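

# A minimal usage sketch, assuming the default 9B-parameter checkpoint can be
# downloaded and fits on the detected device; the image URL below is a
# placeholder, not a tested input.
if __name__ == "__main__":
    model = Idefics(max_length=200)

    # chat() threads each turn through the stored chat history
    response = model.chat(
        "User: What is in this image?"
        " https://example.com/placeholder.jpg"
    )
    print(response)

    # Runtime reconfiguration via the setter methods
    model.set_device("cpu")
    model.set_max_length(100)
    model.clear_chat_history()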