[FIXES][Fuyu]

pull/286/head
Kye 1 year ago
parent b053b61c42
commit 336c4c47f1

@ -84,6 +84,8 @@ class BaseMultiModalModel:
self.device = device self.device = device
self.max_new_tokens = max_new_tokens self.max_new_tokens = max_new_tokens
self.retries = retries self.retries = retries
self.system_prompt = system_prompt
self.meta_prompt = meta_prompt
self.chat_history = [] self.chat_history = []
def __call__(self, task: str, img: str, *args, **kwargs): def __call__(self, task: str, img: str, *args, **kwargs):
@ -309,3 +311,4 @@ class BaseMultiModalModel:
def set_max_length(self, max_length): def set_max_length(self, max_length):
"""Set max_length""" """Set max_length"""
self.max_length = max_length self.max_length = max_length

@ -1,7 +1,6 @@
from io import BytesIO
import requests
from PIL import Image from PIL import Image
from termcolor import colored
from transformers import ( from transformers import (
AutoTokenizer, AutoTokenizer,
FuyuForCausalLM, FuyuForCausalLM,
@ -9,25 +8,28 @@ from transformers import (
FuyuProcessor, FuyuProcessor,
) )
from swarms.models.base_multimodal_model import BaseMultiModalModel
class Fuyu:
class Fuyu(BaseMultiModalModel):
""" """
Fuyu model by Adept Fuyu model by Adept
Parameters Args:
---------- BaseMultiModalModel (BaseMultiModalModel): Base class this model inherits from.
pretrained_path : str pretrained_path (str, optional): Path to the pretrained model. Defaults to "adept/fuyu-8b".
Path to the pretrained model device_map (str, optional): Device to use for the model. Defaults to "auto".
device_map : str max_new_tokens (int, optional): Maximum number of tokens to generate. Defaults to 500.
Device to use for the model *args: Positional arguments forwarded to BaseMultiModalModel.__init__.
max_new_tokens : int **kwargs: Keyword arguments forwarded to BaseMultiModalModel.__init__.
Maximum number of tokens to generate
Examples Examples:
-------- >>> from swarms.models import Fuyu
>>> fuyu = Fuyu() >>> model = Fuyu()
>>> fuyu("Hello, my name is", "path/to/image.png") >>> model.run("Hello, world!", "https://upload.wikimedia.org/wikipedia/commons/8/86/Id%C3%A9fix.JPG")
""" """
@ -39,6 +41,7 @@ class Fuyu:
*args, *args,
**kwargs, **kwargs,
): ):
super().__init__(*args, **kwargs)
self.pretrained_path = pretrained_path self.pretrained_path = pretrained_path
self.device_map = device_map self.device_map = device_map
self.max_new_tokens = max_new_tokens self.max_new_tokens = max_new_tokens
@ -63,33 +66,50 @@ class Fuyu:
image_pil = Image.open(img) image_pil = Image.open(img)
return image_pil return image_pil
def __call__(self, text: str, img: str): def run(self, text: str, img: str, *args, **kwargs):
"""Call the model with text and img paths""" """Run the pipeline
Args:
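text (str): Text prompt passed to the processor together with the image.
img (str): Path of the image; it is loaded with get_img before processing.
Returns:
None: The decoded output is printed rather than returned.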
"""
try:
img = self.get_img(img) img = self.get_img(img)
model_inputs = self.processor( model_inputs = self.processor(
text=text, images=[img], device=self.device_map text=text,
images=[img],
device=self.device_map,
*args,
**kwargs,
) )
for k, v in model_inputs.items(): for k, v in model_inputs.items():
model_inputs[k] = v.to(self.device_map) model_inputs[k] = v.to(self.device_map)
output = self.model.generate( output = self.model.generate(
**model_inputs, max_new_tokens=self.max_new_tokens max_new_tokens=self.max_new_tokens,
*args,
**model_inputs,
**kwargs,
) )
text = self.processor.batch_decode( text = self.processor.batch_decode(
output[:, -7:], skip_special_tokens=True output[:, -7:],
skip_special_tokens=True,
*args,
**kwargs,
) )
return print(str(text)) return print(str(text))
except Exception as error:
def get_img_from_web(self, img: str):
"""Get the image from the web"""
try:
response = requests.get(img)
response.raise_for_status()
image_pil = Image.open(BytesIO(response.content))
return image_pil
except requests.RequestException as error:
print( print(
f"Error fetching image from {img} and error: {error}" colored(
(
"Error in"
f" {self.__class__.__name__} pipeline:"
f" {error}"
),
"red",
)
) )
return None

@ -79,7 +79,7 @@ class Idefics(BaseMultiModalModel):
str str
] = "HuggingFaceM4/idefics-9b-instruct", ] = "HuggingFaceM4/idefics-9b-instruct",
device: Callable = autodetect_device, device: Callable = autodetect_device,
torch_dtype = torch.bfloat16, torch_dtype=torch.bfloat16,
max_length: int = 100, max_length: int = 100,
batched_mode: bool = True, batched_mode: bool = True,
*args, *args,

Loading…
Cancel
Save