From 336c4c47f1cb1ee6bce5b6bd1baf4635a61687c2 Mon Sep 17 00:00:00 2001 From: Kye Date: Tue, 5 Dec 2023 20:44:54 -0800 Subject: [PATCH] [FIXES][Fuyu] --- swarms/models/base_multimodal_model.py | 3 + swarms/models/fuyu.py | 106 +++++++++++++++---------- swarms/models/idefics.py | 2 +- swarms/utils/torch_utils.py | 2 +- 4 files changed, 68 insertions(+), 45 deletions(-) diff --git a/swarms/models/base_multimodal_model.py b/swarms/models/base_multimodal_model.py index 521844a9..28b21d64 100644 --- a/swarms/models/base_multimodal_model.py +++ b/swarms/models/base_multimodal_model.py @@ -84,6 +84,8 @@ class BaseMultiModalModel: self.device = device self.max_new_tokens = max_new_tokens self.retries = retries + self.system_prompt = system_prompt + self.meta_prompt = meta_prompt self.chat_history = [] def __call__(self, task: str, img: str, *args, **kwargs): @@ -309,3 +311,4 @@ class BaseMultiModalModel: def set_max_length(self, max_length): """Set max_length""" self.max_length = max_length + diff --git a/swarms/models/fuyu.py b/swarms/models/fuyu.py index c1e51199..e6f9b04f 100644 --- a/swarms/models/fuyu.py +++ b/swarms/models/fuyu.py @@ -1,7 +1,6 @@ -from io import BytesIO -import requests from PIL import Image +from termcolor import colored from transformers import ( AutoTokenizer, FuyuForCausalLM, @@ -9,25 +8,28 @@ from transformers import ( FuyuProcessor, ) +from swarms.models.base_multimodal_model import BaseMultiModalModel -class Fuyu: + +class Fuyu(BaseMultiModalModel): """ Fuyu model by Adept - - Parameters - ---------- - pretrained_path : str - Path to the pretrained model - device_map : str - Device to use for the model - max_new_tokens : int - Maximum number of tokens to generate - - Examples - -------- - >>> fuyu = Fuyu() - >>> fuyu("Hello, my name is", "path/to/image.png") + + Args: + BaseMultiModalModel (BaseMultiModalModel): [description] + pretrained_path (str, optional): [description]. Defaults to "adept/fuyu-8b". + device_map (str, optional): [description]. Defaults to "auto". + max_new_tokens (int, optional): [description]. Defaults to 500. + *args: [description] + **kwargs: [description] + + + + Examples: + >>> from swarms.models import Fuyu + >>> model = Fuyu() + >>> model.run("Hello, world!", "https://upload.wikimedia.org/wikipedia/commons/8/86/Id%C3%A9fix.JPG") """ @@ -39,6 +41,7 @@ class Fuyu: *args, **kwargs, ): + super().__init__(*args, **kwargs) self.pretrained_path = pretrained_path self.device_map = device_map self.max_new_tokens = max_new_tokens @@ -63,33 +66,50 @@ class Fuyu: image_pil = Image.open(img) return image_pil - def __call__(self, text: str, img: str): - """Call the model with text and img paths""" - img = self.get_img(img) - model_inputs = self.processor( - text=text, images=[img], device=self.device_map - ) + def run(self, text: str, img: str, *args, **kwargs): + """Run the pipeline - for k, v in model_inputs.items(): - model_inputs[k] = v.to(self.device_map) - - output = self.model.generate( - **model_inputs, max_new_tokens=self.max_new_tokens - ) - text = self.processor.batch_decode( - output[:, -7:], skip_special_tokens=True - ) - return print(str(text)) + Args: + text (str): _description_ + img (str): _description_ - def get_img_from_web(self, img: str): - """Get the image from the web""" + Returns: + _type_: _description_ + """ try: - response = requests.get(img) - response.raise_for_status() - image_pil = Image.open(BytesIO(response.content)) - return image_pil - except requests.RequestException as error: - print( - f"Error fetching image from {img} and error: {error}" + img = self.get_img(img) + model_inputs = self.processor( + text=text, + images=[img], + device=self.device_map, + *args, + **kwargs, + ) + + for k, v in model_inputs.items(): + model_inputs[k] = v.to(self.device_map) + + output = self.model.generate( + max_new_tokens=self.max_new_tokens, + *args, + **model_inputs, + **kwargs, + ) + text = self.processor.batch_decode( + output[:, -7:], + skip_special_tokens=True, + *args, + **kwargs, ) - return None + return print(str(text)) + except Exception as error: + print( + colored( + ( + "Error in" + f" {self.__class__.__name__} pipeline:" + f" {error}" + ), + "red", + ) + ) \ No newline at end of file diff --git a/swarms/models/idefics.py b/swarms/models/idefics.py index d00a1255..70a16622 100644 --- a/swarms/models/idefics.py +++ b/swarms/models/idefics.py @@ -79,7 +79,7 @@ class Idefics(BaseMultiModalModel): str ] = "HuggingFaceM4/idefics-9b-instruct", device: Callable = autodetect_device, - torch_dtype = torch.bfloat16, + torch_dtype=torch.bfloat16, max_length: int = 100, batched_mode: bool = True, *args, diff --git a/swarms/utils/torch_utils.py b/swarms/utils/torch_utils.py index 73cf90e1..41d2eb3f 100644 --- a/swarms/utils/torch_utils.py +++ b/swarms/utils/torch_utils.py @@ -1,4 +1,4 @@ -import torch +import torch def autodetect_device():