|
|
|
@ -1,7 +1,6 @@
|
|
|
|
|
from io import BytesIO
|
|
|
|
|
|
|
|
|
|
import requests
|
|
|
|
|
from PIL import Image
|
|
|
|
|
from termcolor import colored
|
|
|
|
|
from transformers import (
|
|
|
|
|
AutoTokenizer,
|
|
|
|
|
FuyuForCausalLM,
|
|
|
|
@ -9,25 +8,28 @@ from transformers import (
|
|
|
|
|
FuyuProcessor,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
from swarms.models.base_multimodal_model import BaseMultiModalModel
|
|
|
|
|
|
|
|
|
|
class Fuyu:
|
|
|
|
|
|
|
|
|
|
class Fuyu(BaseMultiModalModel):
|
|
|
|
|
"""
|
|
|
|
|
Fuyu model by Adept
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Parameters
|
|
|
|
|
----------
|
|
|
|
|
pretrained_path : str
|
|
|
|
|
Path to the pretrained model
|
|
|
|
|
device_map : str
|
|
|
|
|
Device to use for the model
|
|
|
|
|
max_new_tokens : int
|
|
|
|
|
Maximum number of tokens to generate
|
|
|
|
|
|
|
|
|
|
Examples
|
|
|
|
|
--------
|
|
|
|
|
>>> fuyu = Fuyu()
|
|
|
|
|
>>> fuyu("Hello, my name is", "path/to/image.png")
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
BaseMultiModalModel (BaseMultiModalModel): [description]
|
|
|
|
|
pretrained_path (str, optional): [description]. Defaults to "adept/fuyu-8b".
|
|
|
|
|
device_map (str, optional): [description]. Defaults to "auto".
|
|
|
|
|
max_new_tokens (int, optional): [description]. Defaults to 500.
|
|
|
|
|
*args: [description]
|
|
|
|
|
**kwargs: [description]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Examples:
|
|
|
|
|
>>> from swarms.models import Fuyu
|
|
|
|
|
>>> model = Fuyu()
|
|
|
|
|
>>> model.run("Hello, world!", "https://upload.wikimedia.org/wikipedia/commons/8/86/Id%C3%A9fix.JPG")
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
@ -39,6 +41,7 @@ class Fuyu:
|
|
|
|
|
*args,
|
|
|
|
|
**kwargs,
|
|
|
|
|
):
|
|
|
|
|
super().__init__(*args, **kwargs)
|
|
|
|
|
self.pretrained_path = pretrained_path
|
|
|
|
|
self.device_map = device_map
|
|
|
|
|
self.max_new_tokens = max_new_tokens
|
|
|
|
@ -63,33 +66,50 @@ class Fuyu:
|
|
|
|
|
image_pil = Image.open(img)
|
|
|
|
|
return image_pil
|
|
|
|
|
|
|
|
|
|
def __call__(self, text: str, img: str):
|
|
|
|
|
"""Call the model with text and img paths"""
|
|
|
|
|
img = self.get_img(img)
|
|
|
|
|
model_inputs = self.processor(
|
|
|
|
|
text=text, images=[img], device=self.device_map
|
|
|
|
|
)
|
|
|
|
|
def run(self, text: str, img: str, *args, **kwargs):
|
|
|
|
|
"""Run the pipeline
|
|
|
|
|
|
|
|
|
|
for k, v in model_inputs.items():
|
|
|
|
|
model_inputs[k] = v.to(self.device_map)
|
|
|
|
|
|
|
|
|
|
output = self.model.generate(
|
|
|
|
|
**model_inputs, max_new_tokens=self.max_new_tokens
|
|
|
|
|
)
|
|
|
|
|
text = self.processor.batch_decode(
|
|
|
|
|
output[:, -7:], skip_special_tokens=True
|
|
|
|
|
)
|
|
|
|
|
return print(str(text))
|
|
|
|
|
Args:
|
|
|
|
|
text (str): _description_
|
|
|
|
|
img (str): _description_
|
|
|
|
|
|
|
|
|
|
def get_img_from_web(self, img: str):
|
|
|
|
|
"""Get the image from the web"""
|
|
|
|
|
Returns:
|
|
|
|
|
_type_: _description_
|
|
|
|
|
"""
|
|
|
|
|
try:
|
|
|
|
|
response = requests.get(img)
|
|
|
|
|
response.raise_for_status()
|
|
|
|
|
image_pil = Image.open(BytesIO(response.content))
|
|
|
|
|
return image_pil
|
|
|
|
|
except requests.RequestException as error:
|
|
|
|
|
print(
|
|
|
|
|
f"Error fetching image from {img} and error: {error}"
|
|
|
|
|
img = self.get_img(img)
|
|
|
|
|
model_inputs = self.processor(
|
|
|
|
|
text=text,
|
|
|
|
|
images=[img],
|
|
|
|
|
device=self.device_map,
|
|
|
|
|
*args,
|
|
|
|
|
**kwargs,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
for k, v in model_inputs.items():
|
|
|
|
|
model_inputs[k] = v.to(self.device_map)
|
|
|
|
|
|
|
|
|
|
output = self.model.generate(
|
|
|
|
|
max_new_tokens=self.max_new_tokens,
|
|
|
|
|
*args,
|
|
|
|
|
**model_inputs,
|
|
|
|
|
**kwargs,
|
|
|
|
|
)
|
|
|
|
|
text = self.processor.batch_decode(
|
|
|
|
|
output[:, -7:],
|
|
|
|
|
skip_special_tokens=True,
|
|
|
|
|
*args,
|
|
|
|
|
**kwargs,
|
|
|
|
|
)
|
|
|
|
|
return None
|
|
|
|
|
return print(str(text))
|
|
|
|
|
except Exception as error:
|
|
|
|
|
print(
|
|
|
|
|
colored(
|
|
|
|
|
(
|
|
|
|
|
"Error in"
|
|
|
|
|
f" {self.__class__.__name__} pipeline:"
|
|
|
|
|
f" {error}"
|
|
|
|
|
),
|
|
|
|
|
"red",
|
|
|
|
|
)
|
|
|
|
|
)
|