[FIXES][Fuyu]

pull/286/head
Kye 1 year ago
parent b053b61c42
commit 336c4c47f1

@ -84,6 +84,8 @@ class BaseMultiModalModel:
self.device = device
self.max_new_tokens = max_new_tokens
self.retries = retries
self.system_prompt = system_prompt
self.meta_prompt = meta_prompt
self.chat_history = []
def __call__(self, task: str, img: str, *args, **kwargs):
@ -309,3 +311,4 @@ class BaseMultiModalModel:
def set_max_length(self, max_length):
"""Set max_length"""
self.max_length = max_length

@ -1,7 +1,6 @@
from io import BytesIO
import requests
from PIL import Image
from termcolor import colored
from transformers import (
AutoTokenizer,
FuyuForCausalLM,
@ -9,25 +8,28 @@ from transformers import (
FuyuProcessor,
)
from swarms.models.base_multimodal_model import BaseMultiModalModel
class Fuyu:
class Fuyu(BaseMultiModalModel):
"""
Fuyu model by Adept
Parameters
----------
pretrained_path : str
Path to the pretrained model
device_map : str
Device to use for the model
max_new_tokens : int
Maximum number of tokens to generate
Args:
BaseMultiModalModel (BaseMultiModalModel): [description]
pretrained_path (str, optional): [description]. Defaults to "adept/fuyu-8b".
device_map (str, optional): [description]. Defaults to "auto".
max_new_tokens (int, optional): [description]. Defaults to 500.
*args: [description]
**kwargs: [description]
Examples
--------
>>> fuyu = Fuyu()
>>> fuyu("Hello, my name is", "path/to/image.png")
Examples:
>>> from swarms.models import Fuyu
>>> model = Fuyu()
>>> model.run("Hello, world!", "https://upload.wikimedia.org/wikipedia/commons/8/86/Id%C3%A9fix.JPG")
"""
@ -39,6 +41,7 @@ class Fuyu:
*args,
**kwargs,
):
super().__init__(*args, **kwargs)
self.pretrained_path = pretrained_path
self.device_map = device_map
self.max_new_tokens = max_new_tokens
@ -63,33 +66,50 @@ class Fuyu:
image_pil = Image.open(img)
return image_pil
def __call__(self, text: str, img: str):
"""Call the model with text and img paths"""
def run(self, text: str, img: str, *args, **kwargs):
"""Run the pipeline
Args:
text (str): _description_
img (str): _description_
Returns:
_type_: _description_
"""
try:
img = self.get_img(img)
model_inputs = self.processor(
text=text, images=[img], device=self.device_map
text=text,
images=[img],
device=self.device_map,
*args,
**kwargs,
)
for k, v in model_inputs.items():
model_inputs[k] = v.to(self.device_map)
output = self.model.generate(
**model_inputs, max_new_tokens=self.max_new_tokens
max_new_tokens=self.max_new_tokens,
*args,
**model_inputs,
**kwargs,
)
text = self.processor.batch_decode(
output[:, -7:], skip_special_tokens=True
output[:, -7:],
skip_special_tokens=True,
*args,
**kwargs,
)
return print(str(text))
def get_img_from_web(self, img: str):
"""Get the image from the web"""
try:
response = requests.get(img)
response.raise_for_status()
image_pil = Image.open(BytesIO(response.content))
return image_pil
except requests.RequestException as error:
except Exception as error:
print(
f"Error fetching image from {img} and error: {error}"
colored(
(
"Error in"
f" {self.__class__.__name__} pipeline:"
f" {error}"
),
"red",
)
)
return None

Loading…
Cancel
Save