khodour committed on
Commit
24129ba
·
verified ·
1 Parent(s): 1994262

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -20
app.py CHANGED
@@ -1,9 +1,8 @@
1
  from PIL import Image
2
  import torch
3
  from transformers import NougatProcessor, VisionEncoderDecoderModel
4
- import gradio as gr
5
 
6
- # Load model and processor once at startup
7
  processor = NougatProcessor.from_pretrained("MohamedRashad/arabic-small-nougat")
8
  model = VisionEncoderDecoderModel.from_pretrained("MohamedRashad/arabic-small-nougat")
9
 
@@ -12,9 +11,9 @@ model.to(device)
12
 
13
  context_length = 2048
14
 
15
- def predict(image):
16
- # Ensure image is in RGB format
17
- image = image.convert("RGB")
18
 
19
  # Prepare input
20
  pixel_values = processor(images=image, return_tensors="pt").pixel_values
@@ -33,18 +32,5 @@ def predict(image):
33
 
34
  return page_sequence
35
 
36
- # Gradio Interface
37
- title = "Arabic Nougat OCR - Handwritten & Printed Document Recognizer"
38
- description = "Transcribe Arabic documents using a fine-tuned Nougat model."
39
-
40
- interface = gr.Interface(
41
- fn=predict,
42
- inputs=gr.Image(type="pil", label="Upload an Arabic Document"),
43
- outputs=gr.Textbox(label="Transcription", lines=15),
44
- title=title,
45
- description=description,
46
- examples=[["example_1.jpg"], ["example_2.jpg"]]
47
- )
48
-
49
- if __name__ == "__main__":
50
- interface.launch()
 
1
  from PIL import Image
2
  import torch
3
  from transformers import NougatProcessor, VisionEncoderDecoderModel
 
4
 
5
+ # Load the model and processor
6
  processor = NougatProcessor.from_pretrained("MohamedRashad/arabic-small-nougat")
7
  model = VisionEncoderDecoderModel.from_pretrained("MohamedRashad/arabic-small-nougat")
8
 
 
11
 
12
  context_length = 2048
13
 
14
+ def predict(img_path):
15
+ # Open and ensure RGB format
16
+ image = Image.open(img_path).convert("RGB")
17
 
18
  # Prepare input
19
  pixel_values = processor(images=image, return_tensors="pt").pixel_values
 
32
 
33
  return page_sequence
34
 
35
+ # Test the OCR
36
+ print(predict("1.png"))