[DEBUG][Idefics]

pull/286/head
Kye 1 year ago
parent d9976576af
commit 3317bad3a9

@@ -86,12 +86,14 @@ class BaseMultiModalModel:
         self.retries = retries
         self.chat_history = []
 
-    @abstractmethod
-    def __call__(self, text: str, img: str):
+    def __call__(self, task: str, img: str, *args, **kwargs):
         """Run the model"""
-        pass
+        return self.run(task, img, *args, **kwargs)
 
-    def run(self, task: str, img: str):
+    @abstractmethod
+    def run(
+        self, task: Optional[str], img: Optional[str], *args, **kwargs
+    ):
         """Run the model"""
         pass
@@ -99,7 +101,7 @@ class BaseMultiModalModel:
         """Run the model asynchronously"""
         pass
 
-    def get_img_from_web(self, img: str):
+    def get_img_from_web(self, img: str, *args, **kwargs):
         """Get the image from the web"""
         try:
             response = requests.get(img)
@@ -127,9 +129,7 @@ class BaseMultiModalModel:
         self.chat_history = []
 
     def run_many(
-        self,
-        tasks: List[str],
-        imgs: List[str],
+        self, tasks: List[str], imgs: List[str], *args, **kwargs
     ):
         """
         Run the model on multiple tasks and images all at once using concurrent
@@ -293,3 +293,19 @@ class BaseMultiModalModel:
         numbers or letters and typically correspond to specific segments or parts of the image.
         """
         return META_PROMPT
+
+    def set_device(self, device):
+        """
+        Changes the device used for inference.
+
+        Parameters
+        ----------
+        device : str
+            The new device to use for inference.
+        """
+        self.device = device
+        self.model.to(self.device)
+
+    def set_max_length(self, max_length):
+        """Set max_length"""
+        self.max_length = max_length
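
With this hunk, BaseMultiModalModel.__call__ becomes a concrete wrapper that delegates to the abstract run() hook, so subclasses only implement run(). A minimal sketch of the resulting pattern; the EchoModel subclass is hypothetical and assumes the base class can be constructed with no extra arguments:

from typing import Optional

from swarms.models.base_multimodal_model import BaseMultiModalModel


class EchoModel(BaseMultiModalModel):
    # Subclasses override only the abstract run(); __call__ is inherited.
    def run(self, task: Optional[str], img: Optional[str], *args, **kwargs):
        return f"task={task}, img={img}"


model = EchoModel()
print(model("describe the scene", "photo.png"))  # routed through __call__ -> run()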

@@ -66,7 +66,11 @@ class HuggingfacePipeline(AbstractLLM):
         except Exception as error:
             print(
                 colored(
-                    f"Error in {self.__class__.__name__} pipeline: {error}",
+                    (
+                        "Error in"
+                        f" {self.__class__.__name__} pipeline:"
+                        f" {error}"
+                    ),
                     "red",
                 )
             )
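
The reformatted message relies on Python's implicit concatenation of adjacent string literals inside parentheses, so the pieces join into one string; a quick illustration with placeholder values:

name, error = "HuggingfacePipeline", ValueError("boom")
msg = (
    "Error in"
    f" {name} pipeline:"
    f" {error}"
)
assert msg == f"Error in {name} pipeline: {error}"  # one concatenated message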

@ -1,8 +1,23 @@
import torch import torch
from transformers import AutoProcessor, IdeficsForVisionText2Text from transformers import AutoProcessor, IdeficsForVisionText2Text
from termcolor import colored
from swarms.models.base_multimodal_model import BaseMultiModalModel
from typing import Optional, Callable
class Idefics: def autodetect_device():
"""
Autodetects the device to use for inference.
Returns
-------
str
The device to use for inference.
"""
return "cuda" if torch.cuda.is_available() else "cpu"
class Idefics(BaseMultiModalModel):
""" """
A class for multimodal inference using pre-trained models from the Hugging Face Hub. A class for multimodal inference using pre-trained models from the Hugging Face Hub.
@@ -11,8 +26,8 @@ class Idefics:
     ----------
     device : str
         The device to use for inference.
-    checkpoint : str, optional
-        The name of the pre-trained model checkpoint (default is "HuggingFaceM4/idefics-9b-instruct").
+    model_name : str, optional
+        The name of the pre-trained model (default is "HuggingFaceM4/idefics-9b-instruct").
     processor : transformers.PreTrainedProcessor
         The pre-trained processor.
     max_length : int
@@ -26,8 +41,8 @@ class Idefics:
         Generates text based on the provided prompts.
     chat(user_input)
         Engages in a continuous bidirectional conversation based on the user input.
-    set_checkpoint(checkpoint)
-        Changes the model checkpoint.
+    set_model_name(model_name)
+        Changes the model name.
     set_device(device)
         Changes the device used for inference.
     set_max_length(max_length)
@@ -50,7 +65,7 @@ class Idefics:
     response = model.chat(user_input)
     print(response)
 
-    model.set_checkpoint("new_checkpoint")
+    model.set_model_name("new_model_name")
     model.set_device("cpu")
     model.set_max_length(200)
     model.clear_chat_history()
@@ -60,35 +75,43 @@ class Idefics:
     def __init__(
         self,
-        checkpoint="HuggingFaceM4/idefics-9b-instruct",
-        device=None,
+        model_name: Optional[
+            str
+        ] = "HuggingFaceM4/idefics-9b-instruct",
+        device: str = autodetect_device(),
         torch_dtype=torch.bfloat16,
-        max_length=100,
+        max_length: int = 100,
+        batched_mode: bool = True,
+        *args,
+        **kwargs,
     ):
+        # Initialize the parent class
+        super().__init__(*args, **kwargs)
+        self.model_name = model_name
+        self.device = device
+        self.max_length = max_length
+        self.batched_mode = batched_mode
+
+        self.chat_history = []
         self.device = (
             device
             if device
             else ("cuda" if torch.cuda.is_available() else "cpu")
         )
         self.model = IdeficsForVisionText2Text.from_pretrained(
-            checkpoint,
-            torch_dtype=torch_dtype,
+            model_name, torch_dtype=torch_dtype, *args, **kwargs
         ).to(self.device)
-        self.processor = AutoProcessor.from_pretrained(checkpoint)
-
-        self.max_length = max_length
-        self.chat_history = []
+        self.processor = AutoProcessor.from_pretrained(model_name)
 
-    def run(self, prompts, batched_mode=True):
+    def run(self, task: str, *args, **kwargs) -> str:
         """
         Generates text based on the provided prompts.
 
         Parameters
         ----------
-        prompts : list
-            A list of prompts. Each prompt is a list of text strings and images.
+        task : str
+            The task to perform.
         batched_mode : bool, optional
             Whether to process the prompts in batched mode. If True, all prompts are
             processed together. If False, only the first prompt is processed (default is True).
@ -98,14 +121,17 @@ class Idefics:
list list
A list of generated text strings. A list of generated text strings.
""" """
try:
inputs = ( inputs = (
self.processor( self.processor(
prompts, task,
add_end_of_utterance_token=False, add_end_of_utterance_token=False,
return_tensors="pt", return_tensors="pt",
*args,
**kwargs,
).to(self.device) ).to(self.device)
if batched_mode if self.batched_mode
else self.processor(prompts[0], return_tensors="pt").to( else self.processor(task, return_tensors="pt").to(
self.device self.device
) )
) )
@@ -130,110 +156,28 @@ class Idefics:
             )
             return generated_text
-
-    def __call__(self, prompts, batched_mode=True):
-        """
-        Generates text based on the provided prompts.
-
-        Parameters
-        ----------
-        prompts : list
-            A list of prompts. Each prompt is a list of text strings and images.
-        batched_mode : bool, optional
-            Whether to process the prompts in batched mode.
-            If True, all prompts are processed together.
-            If False, only the first prompt is processed (default is True).
-
-        Returns
-        -------
-        list
-            A list of generated text strings.
-        """
-        inputs = (
-            self.processor(
-                prompts,
-                add_end_of_utterance_token=False,
-                return_tensors="pt",
-            ).to(self.device)
-            if batched_mode
-            else self.processor(prompts[0], return_tensors="pt").to(
-                self.device
-            )
-        )
-
-        exit_condition = self.processor.tokenizer(
-            "<end_of_utterance>", add_special_tokens=False
-        ).input_ids
-
-        bad_words_ids = self.processor.tokenizer(
-            ["<image>", "<fake_token_around_image"],
-            add_special_tokens=False,
-        ).input_ids
-
-        generated_ids = self.model.generate(
-            **inputs,
-            eos_token_id=exit_condition,
-            bad_words_ids=bad_words_ids,
-            max_length=self.max_length,
-        )
-        generated_text = self.processor.batch_decode(
-            generated_ids, skip_special_tokens=True
-        )
-        return generated_text
-
-    def chat(self, user_input):
-        """
-        Engages in a continuous bidirectional conversation based on the user input.
-
-        Parameters
-        ----------
-        user_input : str
-            The user input.
-
-        Returns
-        -------
-        str
-            The model's response.
-        """
-        self.chat_history.append(user_input)
-        prompts = [self.chat_history]
-        response = self.run(prompts)[0]
-        self.chat_history.append(response)
-        return response
+        except Exception as error:
+            print(
+                colored(
+                    (
+                        "Error in"
+                        f" {self.__class__.__name__} pipeline:"
+                        f" {error}"
+                    ),
+                    "red",
+                )
+            )
 
-    def set_checkpoint(self, checkpoint):
+    def set_model_name(self, model_name):
         """
-        Changes the model checkpoint.
+        Changes the model name.
 
         Parameters
         ----------
-        checkpoint : str
-            The name of the new pre-trained model checkpoint.
+        model_name : str
+            The name of the new pre-trained model.
         """
         self.model = IdeficsForVisionText2Text.from_pretrained(
-            checkpoint, torch_dtype=torch.bfloat16
+            model_name, torch_dtype=torch.bfloat16
         ).to(self.device)
-        self.processor = AutoProcessor.from_pretrained(checkpoint)
-
-    def set_device(self, device):
-        """
-        Changes the device used for inference.
-
-        Parameters
-        ----------
-        device : str
-            The new device to use for inference.
-        """
-        self.device = device
-        self.model.to(self.device)
-
-    def set_max_length(self, max_length):
-        """Set max_length"""
-        self.max_length = max_length
-
-    def clear_chat_history(self):
-        """Clear chat history"""
-        self.chat_history = []
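
After this refactor, chat(), set_device(), set_max_length(), and clear_chat_history() come from BaseMultiModalModel rather than being duplicated here. A usage sketch under that assumption; the import path is assumed from this repo layout, the prompt layout follows the IDEFICS interleaved text/image format from the Hugging Face examples, and actually running it downloads the 9B checkpoint:

from swarms.models.idefics import Idefics  # import path assumed

model = Idefics(batched_mode=True, max_length=100)

# IDEFICS prompts interleave text strings and image URLs.
prompts = [
    [
        "User: What is in this image?",
        "https://upload.wikimedia.org/wikipedia/commons/8/86/Id%C3%A9fix.JPG",
        "<end_of_utterance>",
        "\nAssistant:",
    ]
]
print(model.run(prompts))

model.set_model_name("HuggingFaceM4/idefics-9b")  # swap checkpoints in place
model.set_device("cpu")                           # inherited from the base class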

@@ -0,0 +1,13 @@
+import torch
+
+
+def autodetect_device():
+    """
+    Autodetects the device to use for inference.
+
+    Returns
+    -------
+    str
+        The device to use for inference.
+    """
+    return "cuda" if torch.cuda.is_available() else "cpu"