|
|
|
@ -18,7 +18,7 @@ class Nougat:
|
|
|
|
|
"""
|
|
|
|
|
Nougat
|
|
|
|
|
|
|
|
|
|
ArgsS:
|
|
|
|
|
Args:
|
|
|
|
|
model_name_or_path: str, default="facebook/nougat-base"
|
|
|
|
|
min_length: int, default=1
|
|
|
|
|
max_new_tokens: int, default=30
|
|
|
|
@ -35,7 +35,7 @@ class Nougat:
|
|
|
|
|
self,
|
|
|
|
|
model_name_or_path="facebook/nougat-base",
|
|
|
|
|
min_length: int = 1,
|
|
|
|
|
max_new_tokens: int = 30,
|
|
|
|
|
max_new_tokens: int = 5000,
|
|
|
|
|
):
|
|
|
|
|
self.model_name_or_path = model_name_or_path
|
|
|
|
|
self.min_length = min_length
|
|
|
|
@ -50,14 +50,17 @@ class Nougat:
|
|
|
|
|
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
|
|
self.model.to(self.device)
|
|
|
|
|
|
|
|
|
|
def get_image(self, img_path: str):
|
|
|
|
|
def get_image(self, img: str):
|
|
|
|
|
"""Get an image from a path"""
|
|
|
|
|
image = Image.open(img_path)
|
|
|
|
|
return image
|
|
|
|
|
img = Image.open(img)
|
|
|
|
|
|
|
|
|
|
def __call__(self, img_path: str):
|
|
|
|
|
if img.mode == "L":
|
|
|
|
|
img = img.convert("RGB")
|
|
|
|
|
return img
|
|
|
|
|
|
|
|
|
|
def __call__(self, img: str):
|
|
|
|
|
"""Call the model with an image_path str as an input"""
|
|
|
|
|
image = Image.open(img_path)
|
|
|
|
|
image = Image.open(img)
|
|
|
|
|
pixel_values = self.processor(image, return_tensors="pt").pixel_values
|
|
|
|
|
|
|
|
|
|
# Generate transcriptions, here we only generate 30 tokens
|
|
|
|
@ -78,6 +81,7 @@ class Nougat:
|
|
|
|
|
return out
|
|
|
|
|
|
|
|
|
|
def clean_nougat_output(raw_output):
|
|
|
|
|
"""Clean the output from nougat to be more readable"""
|
|
|
|
|
# Define the pattern to extract the relevant data
|
|
|
|
|
daily_balance_pattern = (
|
|
|
|
|
r"\*\*(\d{2}/\d{2}/\d{4})\*\*\n\n\*\*([\d,]+\.\d{2})\*\*"
|
|
|
|
|