File size: 1,080 Bytes
1994262
 
 
765a19d
24129ba
1994262
 
765a19d
1994262
 
765a19d
1994262
765a19d
24129ba
 
 
765a19d
1994262
 
765a19d
1994262
 
 
 
 
 
 
765a19d
1994262
 
 
765a19d
1994262
765a19d
24129ba
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
from PIL import Image
import torch
from transformers import NougatProcessor, VisionEncoderDecoderModel

# Hugging Face Hub checkpoint shared by the processor and the model below.
_CHECKPOINT = "MohamedRashad/arabic-small-nougat"

# The processor prepares page images (and decodes token ids back to text);
# the model is the vision encoder-decoder that performs the transcription.
processor = NougatProcessor.from_pretrained(_CHECKPOINT)
model = VisionEncoderDecoderModel.from_pretrained(_CHECKPOINT)

# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model.to(device)

# Maximum number of new tokens generated per page in predict().
context_length = 2048

def predict(img_path, max_new_tokens=None):
    """Transcribe a single page image to text with the Nougat model.

    Args:
        img_path: Anything ``PIL.Image.open`` accepts, e.g. a file path.
        max_new_tokens: Upper bound on the number of generated tokens.
            When ``None`` (the default), the module-level
            ``context_length`` is used, matching the previous behavior.

    Returns:
        The decoded page text after ``post_process_generation``.
    """
    if max_new_tokens is None:
        max_new_tokens = context_length

    # Image.open() is lazy and keeps the file handle open; the context
    # manager closes it once the RGB copy has been materialized.
    with Image.open(img_path) as img:
        image = img.convert("RGB")

    # Turn the image into the model's input tensor.
    pixel_values = processor(images=image, return_tensors="pt").pixel_values

    outputs = model.generate(
        pixel_values.to(device),
        min_length=1,
        max_new_tokens=max_new_tokens,
        # Ban the tokenizer's unknown token from the generated output.
        bad_words_ids=[[processor.tokenizer.unk_token_id]],
    )

    # A single image was submitted, so keep the first (only) sequence.
    page_sequence = processor.batch_decode(outputs, skip_special_tokens=True)[0]
    return processor.post_process_generation(page_sequence, fix_markdown=False)

# Smoke-test the OCR only when run as a script, not on import.
if __name__ == "__main__":
    import sys

    # An optional first CLI argument overrides the bundled sample image.
    print(predict(sys.argv[1] if len(sys.argv) > 1 else "1.png"))