You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
105 lines
2.8 KiB
105 lines
2.8 KiB
"""
|
|
Nougat by Meta
|
|
|
|
Good for:
|
|
- Transcribing scientific PDFs into an easy-to-use Markdown format
|
|
- Extracting information from PDFs
|
|
- Extracting metadata from PDFs
|
|
|
|
"""
|
|
import re
|
|
import torch
|
|
from PIL import Image
|
|
from transformers import NougatProcessor, VisionEncoderDecoderModel
|
|
|
|
|
|
class Nougat:
    """Transcribe a page image into markdown text with a Nougat checkpoint.

    On construction a ``NougatProcessor`` and a
    ``VisionEncoderDecoderModel`` are loaded from ``model_name_or_path``
    and the model is moved to ``"cuda"`` when
    ``torch.cuda.is_available()`` is true, otherwise to ``"cpu"``.

    Args:
        model_name_or_path: str, default="facebook/nougat-base"
        min_length: int, default=1
        max_new_tokens: int, default=5000

    Usage:
    >>> from swarms.models.nougat import Nougat
    >>> nougat = Nougat()
    >>> nougat("path/to/image.png")

    """

    def __init__(
        self,
        model_name_or_path: str = "facebook/nougat-base",
        min_length: int = 1,
        max_new_tokens: int = 5000,
    ):
        self.model_name_or_path = model_name_or_path
        self.min_length = min_length
        self.max_new_tokens = max_new_tokens

        self.processor = NougatProcessor.from_pretrained(
            self.model_name_or_path
        )
        self.model = VisionEncoderDecoderModel.from_pretrained(
            self.model_name_or_path
        )
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)

    def get_image(self, img: str):
        """Open the image at path ``img`` and return it in RGB mode.

        The file is opened in a ``with`` block so the underlying handle
        is released as soon as the pixel data has been read; the image
        returned by ``convert`` is an independent in-memory copy.
        Every non-RGB mode (not only grayscale ``"L"``) is converted.
        """
        with Image.open(img) as source:
            return source.convert("RGB")

    def __call__(self, img: str):
        """Transcribe the image at path ``img``.

        Args:
            img: path of the image file to transcribe.

        Returns:
            The decoded, post-processed text. It is also printed to
            stdout, as earlier versions of this method did.
        """
        # Go through get_image so the RGB normalisation is applied
        # to the inference input as well.
        image = self.get_image(img)
        pixel_values = self.processor(
            image, return_tensors="pt"
        ).pixel_values

        # Generation length is bounded by the configured
        # min_length / max_new_tokens.
        outputs = self.model.generate(
            pixel_values.to(self.device),
            min_length=self.min_length,
            max_new_tokens=self.max_new_tokens,
        )

        sequence = self.processor.batch_decode(
            outputs, skip_special_tokens=True
        )[0]
        sequence = self.processor.post_process_generation(
            sequence, fix_markdown=False
        )

        # print() returns None, so the text itself must be returned;
        # the print is kept for callers that relied on the console output.
        print(sequence)
        return sequence
|
|
|
|
def clean_nougat_output(raw_output):
    """Pull date/amount pairs out of raw Nougat markdown.

    An entry is a bold ``dd/dd/dddd`` date, a blank line, then a bold
    amount with two decimals (thousands separators allowed). Each entry
    becomes one ``Date: ..., Amount: ...`` line with the commas removed
    from the amount. Returns the lines joined by newlines, or an empty
    string when nothing matches.
    """
    # Bold date, blank line, bold amount.
    entry_regex = re.compile(
        r"\*\*(\d{2}/\d{2}/\d{4})\*\*\n\n\*\*([\d,]+\.\d{2})\*\*"
    )
    line_template = "Date: {}, Amount: {}"

    readable_lines = []
    for found in entry_regex.finditer(raw_output):
        when, value = found.groups()
        # Strip thousands separators so the amount reads as a plain number.
        readable_lines.append(
            line_template.format(when, value.replace(",", ""))
        )

    # One entry per line.
    return "\n".join(readable_lines)
|