diff --git a/swarms/models/base_multimodal_model.py b/swarms/models/base_multimodal_model.py index 30e45149..521844a9 100644 --- a/swarms/models/base_multimodal_model.py +++ b/swarms/models/base_multimodal_model.py @@ -86,12 +86,14 @@ class BaseMultiModalModel: self.retries = retries self.chat_history = [] - @abstractmethod - def __call__(self, text: str, img: str): + def __call__(self, task: str, img: str, *args, **kwargs): """Run the model""" - pass + return self.run(task, img, *args, **kwargs) - def run(self, task: str, img: str): + @abstractmethod + def run( + self, task: Optional[str], img: Optional[str], *args, **kwargs + ): """Run the model""" pass @@ -99,7 +101,7 @@ class BaseMultiModalModel: """Run the model asynchronously""" pass - def get_img_from_web(self, img: str): + def get_img_from_web(self, img: str, *args, **kwargs): """Get the image from the web""" try: response = requests.get(img) @@ -127,9 +129,7 @@ class BaseMultiModalModel: self.chat_history = [] def run_many( - self, - tasks: List[str], - imgs: List[str], + self, tasks: List[str], imgs: List[str], *args, **kwargs ): """ Run the model on multiple tasks and images all at once using concurrent @@ -293,3 +293,19 @@ class BaseMultiModalModel: numbers or letters and typically correspond to specific segments or parts of the image. """ return META_PROMPT + + def set_device(self, device): + """ + Changes the device used for inference. + + Parameters + ---------- + device : str + The new device to use for inference. + """ + self.device = device + self.model.to(self.device) + + def set_max_length(self, max_length): + """Set max_length""" + self.max_length = max_length diff --git a/swarms/models/huggingface_pipeline.py b/swarms/models/huggingface_pipeline.py index 6598c3d6..e61d1080 100644 --- a/swarms/models/huggingface_pipeline.py +++ b/swarms/models/huggingface_pipeline.py @@ -66,7 +66,11 @@ class HuggingfacePipeline(AbstractLLM): except Exception as error: print( colored( - f"Error in {self.__class__.__name__} pipeline: {error}", + ( + "Error in" + f" {self.__class__.__name__} pipeline:" + f" {error}" + ), "red", ) ) diff --git a/swarms/models/idefics.py b/swarms/models/idefics.py index 7c505d8a..d00a1255 100644 --- a/swarms/models/idefics.py +++ b/swarms/models/idefics.py @@ -1,8 +1,23 @@ import torch from transformers import AutoProcessor, IdeficsForVisionText2Text +from termcolor import colored +from swarms.models.base_multimodal_model import BaseMultiModalModel +from typing import Optional, Callable -class Idefics: +def autodetect_device(): + """ + Autodetects the device to use for inference. + + Returns + ------- + str + The device to use for inference. + """ + return "cuda" if torch.cuda.is_available() else "cpu" + + +class Idefics(BaseMultiModalModel): """ A class for multimodal inference using pre-trained models from the Hugging Face Hub. @@ -11,8 +26,8 @@ class Idefics: ---------- device : str The device to use for inference. - checkpoint : str, optional - The name of the pre-trained model checkpoint (default is "HuggingFaceM4/idefics-9b-instruct"). + model_name : str, optional + The name of the pre-trained model model_name (default is "HuggingFaceM4/idefics-9b-instruct"). processor : transformers.PreTrainedProcessor The pre-trained processor. max_length : int @@ -26,8 +41,8 @@ class Idefics: Generates text based on the provided prompts. chat(user_input) Engages in a continuous bidirectional conversation based on the user input. - set_checkpoint(checkpoint) - Changes the model checkpoint. + set_model_name(model_name) + Changes the model model_name. set_device(device) Changes the device used for inference. set_max_length(max_length) @@ -50,7 +65,7 @@ class Idefics: response = model.chat(user_input) print(response) - model.set_checkpoint("new_checkpoint") + model.set_model_name("new_model_name") model.set_device("cpu") model.set_max_length(200) model.clear_chat_history() @@ -60,35 +75,43 @@ class Idefics: def __init__( self, - checkpoint="HuggingFaceM4/idefics-9b-instruct", - device=None, - torch_dtype=torch.bfloat16, - max_length=100, + model_name: Optional[ + str + ] = "HuggingFaceM4/idefics-9b-instruct", + device: Callable = autodetect_device, + torch_dtype = torch.bfloat16, + max_length: int = 100, + batched_mode: bool = True, + *args, + **kwargs, ): + # Initialize the parent class + super().__init__(*args, **kwargs) + self.model_name = model_name + self.device = device + self.max_length = max_length + self.batched_mode = batched_mode + + self.chat_history = [] self.device = ( device if device else ("cuda" if torch.cuda.is_available() else "cpu") ) self.model = IdeficsForVisionText2Text.from_pretrained( - checkpoint, - torch_dtype=torch_dtype, + model_name, torch_dtype=torch_dtype, *args, **kwargs ).to(self.device) - self.processor = AutoProcessor.from_pretrained(checkpoint) - - self.max_length = max_length - - self.chat_history = [] + self.processor = AutoProcessor.from_pretrained(model_name) - def run(self, prompts, batched_mode=True): + def run(self, task: str, *args, **kwargs) -> str: """ Generates text based on the provided prompts. Parameters ---------- - prompts : list - A list of prompts. Each prompt is a list of text strings and images. + task : str + the task to perform batched_mode : bool, optional Whether to process the prompts in batched mode. If True, all prompts are processed together. If False, only the first prompt is processed (default is True). @@ -98,142 +121,63 @@ class Idefics: list A list of generated text strings. """ - inputs = ( - self.processor( - prompts, - add_end_of_utterance_token=False, - return_tensors="pt", - ).to(self.device) - if batched_mode - else self.processor(prompts[0], return_tensors="pt").to( - self.device + try: + inputs = ( + self.processor( + task, + add_end_of_utterance_token=False, + return_tensors="pt", + *args, + **kwargs, + ).to(self.device) + if self.batched_mode + else self.processor(task, return_tensors="pt").to( + self.device + ) ) - ) - exit_condition = self.processor.tokenizer( - "", add_special_tokens=False - ).input_ids + exit_condition = self.processor.tokenizer( + "", add_special_tokens=False + ).input_ids - bad_words_ids = self.processor.tokenizer( - ["", "", "", add_special_tokens=False - ).input_ids - - bad_words_ids = self.processor.tokenizer( - ["", "