# lolo / app.py
# khodour's picture
# Update app.py
# 24129ba verified
from PIL import Image
import torch
from transformers import NougatProcessor, VisionEncoderDecoderModel
# Use the GPU when torch can see one; otherwise run on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# The processor and the model weights are loaded from the same checkpoint.
_CHECKPOINT = "MohamedRashad/arabic-small-nougat"
processor = NougatProcessor.from_pretrained(_CHECKPOINT)
model = VisionEncoderDecoderModel.from_pretrained(_CHECKPOINT).to(device)

# Passed as max_new_tokens to model.generate, capping the generated length.
context_length = 2048
def predict(img_path) -> str:
    """Transcribe the image at *img_path* with the module-level Nougat model.

    Args:
        img_path: Location of the image to read; handed directly to
            ``PIL.Image.open``.

    Returns:
        The decoded text of the first (and only) generated sequence, after
        ``processor.post_process_generation`` with ``fix_markdown=False``.
    """
    # Open inside a context manager so the underlying file is closed
    # deterministically; convert() yields an independent RGB image that
    # stays valid after the source is closed.
    with Image.open(img_path) as source:
        image = source.convert("RGB")

    # Prepare input
    pixel_values = processor(images=image, return_tensors="pt").pixel_values

    # Generate transcription; the unknown token is banned from the output.
    outputs = model.generate(
        pixel_values.to(device),
        min_length=1,
        max_new_tokens=context_length,
        bad_words_ids=[[processor.tokenizer.unk_token_id]],
    )

    # Decode output: a single image was submitted, so keep the first entry.
    page_sequence = processor.batch_decode(outputs, skip_special_tokens=True)[0]
    page_sequence = processor.post_process_generation(page_sequence, fix_markdown=False)
    return page_sequence
# Test the OCR. Guarded so that merely importing this module does not run
# inference on the hard-coded sample image.
if __name__ == "__main__":
    print(predict("1.png"))