pull/188/head
Kye 1 year ago
parent d88a31a75b
commit ee1ac007d0

@ -2,4 +2,4 @@ from swarms.models.nougat import Nougat
nougat = Nougat() nougat = Nougat()
out = nougat("path/to/image.png") out = nougat("large.png")

@ -18,7 +18,7 @@ class Nougat:
""" """
Nougat Nougat
ArgsS: Args:
model_name_or_path: str, default="facebook/nougat-base" model_name_or_path: str, default="facebook/nougat-base"
min_length: int, default=1 min_length: int, default=1
max_new_tokens: int, default=30 max_new_tokens: int, default=30
@ -35,7 +35,7 @@ class Nougat:
self, self,
model_name_or_path="facebook/nougat-base", model_name_or_path="facebook/nougat-base",
min_length: int = 1, min_length: int = 1,
max_new_tokens: int = 30, max_new_tokens: int = 5000,
): ):
self.model_name_or_path = model_name_or_path self.model_name_or_path = model_name_or_path
self.min_length = min_length self.min_length = min_length
@ -50,14 +50,17 @@ class Nougat:
self.device = "cuda" if torch.cuda.is_available() else "cpu" self.device = "cuda" if torch.cuda.is_available() else "cpu"
self.model.to(self.device) self.model.to(self.device)
def get_image(self, img_path: str): def get_image(self, img: str):
"""Get an image from a path""" """Get an image from a path"""
image = Image.open(img_path) img = Image.open(img)
return image
def __call__(self, img_path: str): if img.mode == "L":
img = img.convert("RGB")
return img
def __call__(self, img: str):
"""Call the model with an image_path str as an input""" """Call the model with an image_path str as an input"""
image = Image.open(img_path) image = Image.open(img)
pixel_values = self.processor(image, return_tensors="pt").pixel_values pixel_values = self.processor(image, return_tensors="pt").pixel_values
# Generate transcriptions, here we only generate 30 tokens # Generate transcriptions, here we only generate 30 tokens
@ -78,6 +81,7 @@ class Nougat:
return out return out
def clean_nougat_output(raw_output): def clean_nougat_output(raw_output):
"""Clean the output from nougat to be more readable"""
# Define the pattern to extract the relevant data # Define the pattern to extract the relevant data
daily_balance_pattern = ( daily_balance_pattern = (
r"\*\*(\d{2}/\d{2}/\d{4})\*\*\n\n\*\*([\d,]+\.\d{2})\*\*" r"\*\*(\d{2}/\d{2}/\d{4})\*\*\n\n\*\*([\d,]+\.\d{2})\*\*"

Loading…
Cancel
Save