[DEBUG][Idefics]

pull/286/head
Kye 1 year ago
parent d9976576af
commit 3317bad3a9

@ -86,12 +86,14 @@ class BaseMultiModalModel:
self.retries = retries
self.chat_history = []
@abstractmethod
def __call__(self, text: str, img: str):
def __call__(self, task: str, img: str, *args, **kwargs):
"""Run the model"""
pass
return self.run(task, img, *args, **kwargs)
def run(self, task: str, img: str):
@abstractmethod
def run(
self, task: Optional[str], img: Optional[str], *args, **kwargs
):
"""Run the model"""
pass
@ -99,7 +101,7 @@ class BaseMultiModalModel:
"""Run the model asynchronously"""
pass
def get_img_from_web(self, img: str):
def get_img_from_web(self, img: str, *args, **kwargs):
"""Get the image from the web"""
try:
response = requests.get(img)
@ -127,9 +129,7 @@ class BaseMultiModalModel:
self.chat_history = []
def run_many(
self,
tasks: List[str],
imgs: List[str],
self, tasks: List[str], imgs: List[str], *args, **kwargs
):
"""
Run the model on multiple tasks and images all at once using concurrent
@ -293,3 +293,19 @@ class BaseMultiModalModel:
numbers or letters and typically correspond to specific segments or parts of the image.
"""
return META_PROMPT
def set_device(self, device):
"""
Changes the device used for inference.
Parameters
----------
device : str
The new device to use for inference.
"""
self.device = device
self.model.to(self.device)
def set_max_length(self, max_length):
"""Set max_length"""
self.max_length = max_length

@ -66,7 +66,11 @@ class HuggingfacePipeline(AbstractLLM):
except Exception as error:
print(
colored(
f"Error in {self.__class__.__name__} pipeline: {error}",
(
"Error in"
f" {self.__class__.__name__} pipeline:"
f" {error}"
),
"red",
)
)

@ -1,8 +1,23 @@
import torch
from transformers import AutoProcessor, IdeficsForVisionText2Text
from termcolor import colored
from swarms.models.base_multimodal_model import BaseMultiModalModel
from typing import Optional, Callable
class Idefics:
def autodetect_device():
"""
Autodetects the device to use for inference.
Returns
-------
str
The device to use for inference.
"""
return "cuda" if torch.cuda.is_available() else "cpu"
class Idefics(BaseMultiModalModel):
"""
A class for multimodal inference using pre-trained models from the Hugging Face Hub.
@ -11,8 +26,8 @@ class Idefics:
----------
device : str
The device to use for inference.
checkpoint : str, optional
The name of the pre-trained model checkpoint (default is "HuggingFaceM4/idefics-9b-instruct").
model_name : str, optional
The name of the pre-trained model model_name (default is "HuggingFaceM4/idefics-9b-instruct").
processor : transformers.PreTrainedProcessor
The pre-trained processor.
max_length : int
@ -26,8 +41,8 @@ class Idefics:
Generates text based on the provided prompts.
chat(user_input)
Engages in a continuous bidirectional conversation based on the user input.
set_checkpoint(checkpoint)
Changes the model checkpoint.
set_model_name(model_name)
Changes the model model_name.
set_device(device)
Changes the device used for inference.
set_max_length(max_length)
@ -50,7 +65,7 @@ class Idefics:
response = model.chat(user_input)
print(response)
model.set_checkpoint("new_checkpoint")
model.set_model_name("new_model_name")
model.set_device("cpu")
model.set_max_length(200)
model.clear_chat_history()
@ -60,35 +75,43 @@ class Idefics:
def __init__(
self,
checkpoint="HuggingFaceM4/idefics-9b-instruct",
device=None,
model_name: Optional[
str
] = "HuggingFaceM4/idefics-9b-instruct",
device: Callable = autodetect_device,
torch_dtype = torch.bfloat16,
max_length=100,
max_length: int = 100,
batched_mode: bool = True,
*args,
**kwargs,
):
# Initialize the parent class
super().__init__(*args, **kwargs)
self.model_name = model_name
self.device = device
self.max_length = max_length
self.batched_mode = batched_mode
self.chat_history = []
self.device = (
device
if device
else ("cuda" if torch.cuda.is_available() else "cpu")
)
self.model = IdeficsForVisionText2Text.from_pretrained(
checkpoint,
torch_dtype=torch_dtype,
model_name, torch_dtype=torch_dtype, *args, **kwargs
).to(self.device)
self.processor = AutoProcessor.from_pretrained(checkpoint)
self.max_length = max_length
self.chat_history = []
self.processor = AutoProcessor.from_pretrained(model_name)
def run(self, prompts, batched_mode=True):
def run(self, task: str, *args, **kwargs) -> str:
"""
Generates text based on the provided prompts.
Parameters
----------
prompts : list
A list of prompts. Each prompt is a list of text strings and images.
task : str
the task to perform
batched_mode : bool, optional
Whether to process the prompts in batched mode. If True, all prompts are
processed together. If False, only the first prompt is processed (default is True).
@ -98,14 +121,17 @@ class Idefics:
list
A list of generated text strings.
"""
try:
inputs = (
self.processor(
prompts,
task,
add_end_of_utterance_token=False,
return_tensors="pt",
*args,
**kwargs,
).to(self.device)
if batched_mode
else self.processor(prompts[0], return_tensors="pt").to(
if self.batched_mode
else self.processor(task, return_tensors="pt").to(
self.device
)
)
@ -130,110 +156,28 @@ class Idefics:
)
return generated_text
def __call__(self, prompts, batched_mode=True):
"""
Generates text based on the provided prompts.
Parameters
----------
prompts : list
A list of prompts. Each prompt is a list of text strings and images.
batched_mode : bool, optional
Whether to process the prompts in batched mode.
If True, all prompts are processed together.
If False, only the first prompt is processed (default is True).
Returns
-------
list
A list of generated text strings.
"""
inputs = (
self.processor(
prompts,
add_end_of_utterance_token=False,
return_tensors="pt",
).to(self.device)
if batched_mode
else self.processor(prompts[0], return_tensors="pt").to(
self.device
)
)
exit_condition = self.processor.tokenizer(
"<end_of_utterance>", add_special_tokens=False
).input_ids
bad_words_ids = self.processor.tokenizer(
["<image>", "<fake_token_around_image"],
add_special_tokens=False,
).input_ids
generated_ids = self.model.generate(
**inputs,
eos_token_id=exit_condition,
bad_words_ids=bad_words_ids,
max_length=self.max_length,
except Exception as error:
print(
colored(
(
"Error in"
f" {self.__class__.__name__} pipeline:"
f" {error}"
),
"red",
)
generated_text = self.processor.batch_decode(
generated_ids, skip_special_tokens=True
)
return generated_text
def chat(self, user_input):
"""
Engages in a continuous bidirectional conversation based on the user input.
Parameters
----------
user_input : str
The user input.
Returns
-------
str
The model's response.
"""
self.chat_history.append(user_input)
prompts = [self.chat_history]
response = self.run(prompts)[0]
self.chat_history.append(response)
return response
def set_checkpoint(self, checkpoint):
def set_model_name(self, model_name):
"""
Changes the model checkpoint.
Changes the model model_name.
Parameters
----------
checkpoint : str
The name of the new pre-trained model checkpoint.
model_name : str
The name of the new pre-trained model model_name.
"""
self.model = IdeficsForVisionText2Text.from_pretrained(
checkpoint, torch_dtype=torch.bfloat16
model_name, torch_dtype=torch.bfloat16
).to(self.device)
self.processor = AutoProcessor.from_pretrained(checkpoint)
def set_device(self, device):
"""
Changes the device used for inference.
Parameters
----------
device : str
The new device to use for inference.
"""
self.device = device
self.model.to(self.device)
def set_max_length(self, max_length):
"""Set max_length"""
self.max_length = max_length
def clear_chat_history(self):
"""Clear chat history"""
self.chat_history = []
self.processor = AutoProcessor.from_pretrained(model_name)

@ -0,0 +1,13 @@
import torch
def autodetect_device():
"""
Autodetects the device to use for inference.
Returns
-------
str
The device to use for inference.
"""
return "cuda" if torch.cuda.is_available() else "cpu"
Loading…
Cancel
Save