diff --git a/playground/models/nougat.py b/playground/models/nougat.py index 198fee38..97e1f1a3 100644 --- a/playground/models/nougat.py +++ b/playground/models/nougat.py @@ -2,4 +2,4 @@ from swarms.models.nougat import Nougat nougat = Nougat() -out = nougat("path/to/image.png") +out = nougat("large.png") diff --git a/swarms/models/nougat.py b/swarms/models/nougat.py index 82bb95f5..0eceb362 100644 --- a/swarms/models/nougat.py +++ b/swarms/models/nougat.py @@ -18,7 +18,7 @@ class Nougat: """ Nougat - ArgsS: + Args: model_name_or_path: str, default="facebook/nougat-base" min_length: int, default=1 max_new_tokens: int, default=30 @@ -35,7 +35,7 @@ class Nougat: self, model_name_or_path="facebook/nougat-base", min_length: int = 1, - max_new_tokens: int = 30, + max_new_tokens: int = 5000, ): self.model_name_or_path = model_name_or_path self.min_length = min_length @@ -50,14 +50,17 @@ class Nougat: self.device = "cuda" if torch.cuda.is_available() else "cpu" self.model.to(self.device) - def get_image(self, img_path: str): + def get_image(self, img: str): """Get an image from a path""" - image = Image.open(img_path) - return image + img = Image.open(img) - def __call__(self, img_path: str): + if img.mode == "L": + img = img.convert("RGB") + return img + + def __call__(self, img: str): """Call the model with an image_path str as an input""" - image = Image.open(img_path) + image = Image.open(img) pixel_values = self.processor(image, return_tensors="pt").pixel_values # Generate transcriptions, here we only generate 30 tokens @@ -78,6 +81,7 @@ class Nougat: return out def clean_nougat_output(raw_output): + """Clean the output from nougat to be more readable""" # Define the pattern to extract the relevant data daily_balance_pattern = ( r"\*\*(\d{2}/\d{2}/\d{4})\*\*\n\n\*\*([\d,]+\.\d{2})\*\*"