Emeritus-21 committed
Commit 8adac94 · verified · 1 Parent(s): cd9abb6

Update app.py

Files changed (1)
  1. app.py +623 -162
app.py CHANGED
@@ -1,197 +1,658 @@
 import os
 import gradio as gr
 from PIL import Image
 import torch
-from transformers import AutoProcessor, AutoModelForImageTextToText, Qwen2_5_VLForConditionalGeneration
-from reportlab.platypus import SimpleDocTemplate, Paragraph
-from reportlab.lib.styles import getSampleStyleSheet
-from docx import Document
-from gtts import gTTS
-
-# ---------------------------
-# Device & constants
-# ---------------------------
-DEVICE = "cuda"  # Force GPU usage
-MAX_NEW_TOKENS_DEFAULT = 512

 # ---------------------------
-# Models config
 # ---------------------------
 MODEL_PATHS = {
-    "Complex handwritings": ("prithivMLmods/Qwen2.5-VL-7B-Abliterated-Caption-it", Qwen2_5_VLForConditionalGeneration),
-    "Simple/scanned handwriting": ("nanonets/Nanonets-OCR-s", Qwen2_5_VLForConditionalGeneration),
-    "Structured handwriting": ("Emeritus-21/Finetuned-full-HTR-model", AutoModelForImageTextToText),
 }

 # ---------------------------
-# Lazy load models
 # ---------------------------
 _loaded_processors = {}
 _loaded_models = {}

-def load_model(name):
-    if name in _loaded_models:
-        return _loaded_processors[name], _loaded_models[name]
-    repo_id, cls = MODEL_PATHS[name]
-    processor = AutoProcessor.from_pretrained(repo_id, trust_remote_code=True)
-    model = cls.from_pretrained(
-        repo_id,
-        trust_remote_code=True,
-        torch_dtype=torch.float16,
-        low_cpu_mem_usage=True
-    ).to(DEVICE).eval()
-    _loaded_processors[name] = processor
-    _loaded_models[name] = model
-    return processor, model

 # ---------------------------
-# OCR function (GPU ready)
 # ---------------------------
-@gr.utils.space_decorator  # Spaces decorator to detect GPU
-def ocr_image_gpu(image: Image.Image, model_choice: str, query: str = None,
-                  max_new_tokens: int = MAX_NEW_TOKENS_DEFAULT, temperature: float = 0.1,
-                  top_p: float = 1.0, top_k: int = 0, repetition_penalty: float = 1.0):
-    if image is None:
-        return "Please upload an image."
-    if model_choice not in MODEL_PATHS:
-        return f"Invalid model: {model_choice}"
-
-    processor, model = load_model(model_choice)
-
-    prompt = query.strip() if query and query.strip() else (
-        "You are a professional Handwritten OCR system.\n"
-        "TASK: Read the handwritten image and transcribe the text EXACTLY as written.\n"
-        "- Preserve original structure and line breaks.\n"
-        "- Keep spacing, bullet points, numbering, and indentation.\n"
-        "- Render tables as Markdown tables if present.\n"
-        "- Do NOT autocorrect spelling or grammar.\n"
-        "- Do NOT merge lines.\n"
-        "Return RAW transcription only."
-    )
-
-    batch = processor(text=[prompt], images=[image], return_tensors="pt").to(DEVICE)
-
-    with torch.inference_mode():
-        output_ids = model.generate(
-            **batch,
-            max_new_tokens=max_new_tokens,
-            do_sample=False,
-            temperature=temperature,
-            top_p=top_p,
-            top_k=top_k,
-            repetition_penalty=repetition_penalty,
-        )
-
-    # decode safely
-    text = ""
-    if hasattr(processor, "batch_decode"):
-        text = processor.batch_decode(output_ids, skip_special_tokens=True)[0]
-    elif hasattr(model, "tokenizer") and model.tokenizer is not None:
-        text = model.tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
-    text = text.replace("<|im_end|>", "").strip()
-    return text

 # ---------------------------
-# Export helpers
 # ---------------------------
 def _safe_text(text: str) -> str:
-    return (text or "").strip()

 def save_as_pdf(text):
-    text = _safe_text(text)
-    if not text:
-        return None
-    filepath = "output.pdf"
-    doc = SimpleDocTemplate(filepath)
-    styles = getSampleStyleSheet()
-    flowables = [Paragraph(t, styles["Normal"]) for t in text.splitlines() if t != ""]
-    if not flowables:
-        flowables = [Paragraph(" ", styles["Normal"])]
-    doc.build(flowables)
-    return filepath

 def save_as_word(text):
-    text = _safe_text(text)
-    if not text:
-        return None
-    filepath = "output.docx"
-    doc = Document()
-    for line in text.splitlines():
-        doc.add_paragraph(line)
-    doc.save(filepath)
-    return filepath

 def save_as_audio(text):
-    text = _safe_text(text)
-    if not text:
-        return None
-    filepath = "output.mp3"
-    tts = gTTS(text)
-    tts.save(filepath)
-    return filepath

 # ---------------------------
-# Gradio UI
 # ---------------------------
 with gr.Blocks(theme=gr.themes.Soft()) as demo:
-    gr.Markdown("## ✍🏾 Wilson Handwritten OCR")
-
-    model_choice = gr.Radio(
-        choices=list(MODEL_PATHS.keys()),
-        value=list(MODEL_PATHS.keys())[0],
-        label="Select OCR Model",
-    )
-
-    query_input = gr.Textbox(
-        label="Custom Prompt (optional)",
-        placeholder="Leave empty for RAW structured output",
-    )
-
-    image_input = gr.Image(type="pil", label="Upload Image (desktop/mobile)")
-
-    # Buttons first
-    extract_btn = gr.Button("📤 Extract RAW Text", variant="primary")
-    raw_output = gr.Textbox(
-        label="📜 RAW Structured Output (exact as written)",
-        lines=18,
-        show_copy_button=True,
-    )
-    pdf_file = gr.File(label="PDF File")
-    word_file = gr.File(label="Word File")
-    audio_file = gr.File(label="Audio File")
-
-    with gr.Accordion("⚙️ Advanced Options", open=False):
-        max_new_tokens = gr.Slider(1, 2048, value=MAX_NEW_TOKENS_DEFAULT, step=1, label="Max new tokens")
-        temperature = gr.Slider(0.1, 2.0, value=0.1, step=0.05, label="Temperature")
-        top_p = gr.Slider(0.05, 1.0, value=1.0, step=0.05, label="Top-p (nucleus)")
-        top_k = gr.Slider(0, 1000, value=0, step=1, label="Top-k")
-        repetition_penalty = gr.Slider(0.8, 2.0, value=1.0, step=0.05, label="Repetition penalty")
-
-    # Extract text
-    extract_btn.click(
-        fn=ocr_image_gpu,
-        inputs=[image_input, model_choice, query_input,
-                max_new_tokens, temperature, top_p, top_k, repetition_penalty],
-        outputs=[raw_output]
-    )
-
-    # Export buttons
-    pdf_btn = gr.Button("⬇️ Download as PDF")
-    pdf_btn.click(fn=save_as_pdf, inputs=[raw_output], outputs=[pdf_file])
-
-    word_btn = gr.Button("⬇️ Download as Word")
-    word_btn.click(fn=save_as_word, inputs=[raw_output], outputs=[word_file])
-
-    audio_btn = gr.Button("🔊 Download as Audio")
-    audio_btn.click(fn=save_as_audio, inputs=[raw_output], outputs=[audio_file])
-
-    # Clear button
-    clear_btn = gr.Button("🧹 Clear")
-    clear_btn.click(
-        fn=lambda: ("", None, "", MAX_NEW_TOKENS_DEFAULT, 0.1, 1.0, 0, 1.0),
-        outputs=[raw_output, image_input, query_input,
-                 max_new_tokens, temperature, top_p, top_k, repetition_penalty],
-    )

 if __name__ == "__main__":
-    demo.queue(max_size=50).launch(show_error=True)

 import os
+import time
+from threading import Thread
+
 import gradio as gr
+import spaces
 from PIL import Image
 import torch
+
+from transformers import (
+    AutoProcessor,
+    AutoModelForImageTextToText,
+    Qwen2_5_VLForConditionalGeneration,
+)
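+
+# time and Thread look like leftovers from the removed streaming path; nothing
+# below uses them, so they could be dropped.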

 # ---------------------------
+# Models
 # ---------------------------
 MODEL_PATHS = {
+    "Model 1 (complex handwriting)": (
+        "prithivMLmods/Qwen2.5-VL-7B-Abliterated-Caption-it",
+        Qwen2_5_VLForConditionalGeneration,
+    ),
+    "Model 2 (simple/scanned handwriting)": (
+        "nanonets/Nanonets-OCR-s",
+        Qwen2_5_VLForConditionalGeneration,
+    ),
+    "Model 3 (structured handwriting)": (
+        "Emeritus-21/Finetuned-full-HTR-model",
+        AutoModelForImageTextToText,
+    ),
 }
+
+MAX_NEW_TOKENS_DEFAULT = 512
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
 # ---------------------------
+# Preload models at startup
 # ---------------------------
 _loaded_processors = {}
 _loaded_models = {}
+
+print("🚀 Preloading models into GPU/CPU memory...")
+
+for name, (repo_id, cls) in MODEL_PATHS.items():
+    try:
+        print(f"Loading {name} ...")
+        processor = AutoProcessor.from_pretrained(repo_id, trust_remote_code=True)
+        model = cls.from_pretrained(
+            repo_id,
+            trust_remote_code=True,
+            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+            low_cpu_mem_usage=True,
+        ).to(device).eval()
+        _loaded_processors[name] = processor
+        _loaded_models[name] = model
+        print(f"✅ {name} ready.")
+    except Exception as e:
+        print(f"⚠️ Failed to load {name}: {e}")
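+
+# A model that fails to load is simply left out of _loaded_models; ocr_image
+# below then reports it as "Invalid model" instead of crashing the Space.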

 # ---------------------------
+# Warmup (GPU)
 # ---------------------------
+@spaces.GPU
+def warmup(progress=gr.Progress(track_tqdm=True)):
+    try:
+        default_model_choice = next(iter(MODEL_PATHS.keys()))
+        processor = _loaded_processors[default_model_choice]
+        model = _loaded_models[default_model_choice]
+        tokenizer = getattr(processor, "tokenizer", None)
+
+        messages = [{"role": "user", "content": [{"type": "text", "text": "Warmup."}]}]
+        if tokenizer and hasattr(tokenizer, "apply_chat_template"):
+            chat_prompt = tokenizer.apply_chat_template(
+                messages, tokenize=False, add_generation_prompt=True
+            )
+        else:
+            chat_prompt = "Warmup."
+
+        inputs = processor(
+            text=[chat_prompt],
+            images=None,
+            return_tensors="pt"
+        ).to(device)
+
+        with torch.inference_mode():
+            _ = model.generate(**inputs, max_new_tokens=1)
+
+        return f"GPU warm and {default_model_choice} ready."
+    except Exception as e:
+        return f"Warmup skipped: {e}"
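+
+# Note: warmup is never attached to a UI event in this file; wiring it up with
+# demo.load(warmup, ...) would run it once at startup.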

 # ---------------------------
+# Helpers
 # ---------------------------
+def _build_inputs(processor, tokenizer, image: Image.Image, prompt: str):
+    """Build processor inputs for text+image, with or without a chat template."""
+    messages = [
+        {
+            "role": "user",
+            "content": [
+                {"type": "image", "image": image},
+                {"type": "text", "text": prompt},
+            ],
+        }
+    ]
+    if tokenizer and hasattr(tokenizer, "apply_chat_template"):
+        chat_prompt = tokenizer.apply_chat_template(
+            messages, tokenize=False, add_generation_prompt=True
+        )
+        return processor(text=[chat_prompt], images=[image], return_tensors="pt")
+    # Fallback: plain prompt + image
+    return processor(text=[prompt], images=[image], return_tensors="pt")
+
+def _decode_text(model, processor, tokenizer, output_ids):
+    """Robust decode across different processor/tokenizer setups."""
+    try:
+        if hasattr(processor, "batch_decode"):
+            return processor.batch_decode(output_ids, skip_special_tokens=True)[0]
+    except Exception:
+        pass
+    try:
+        if tokenizer is not None:
+            return tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
+    except Exception:
+        pass
+    try:
+        model_tok = getattr(model, "tokenizer", None)
+        if model_tok is not None:
+            return model_tok.batch_decode(output_ids, skip_special_tokens=True)[0]
+    except Exception:
+        pass
+    # Last-resort string
+    return str(output_ids)
+
+def _default_prompt(query: str | None) -> str:
+    if query and query.strip():
+        return query.strip()
+    return (
+        "You are a professional Handwritten OCR system.\n"
+        "TASK: Read the handwritten image and transcribe the text EXACTLY as written.\n"
+        "- Preserve original structure and line breaks.\n"
+        "- Keep spacing, bullet points, numbering, and indentation.\n"
+        "- Render tables as Markdown tables if present.\n"
+        "- Do NOT autocorrect spelling or grammar.\n"
+        "- Do NOT merge lines.\n"
+        "Return RAW transcription only."
+    )
+
+# ---------------------------
+# OCR function (no streaming / no yield)
+# ---------------------------
+@spaces.GPU
+def ocr_image(
+    image: Image.Image,
+    model_choice: str,
+    query: str | None = None,
+    max_new_tokens: int = MAX_NEW_TOKENS_DEFAULT,
+    temperature: float = 0.1,
+    top_p: float = 1.0,
+    top_k: int = 0,
+    repetition_penalty: float = 1.0,
+    progress=gr.Progress(track_tqdm=True),
+):
+    if image is None:
+        return "Please upload or capture an image."
+    if model_choice not in _loaded_models:
+        return f"Invalid model: {model_choice}"
+
+    processor = _loaded_processors[model_choice]
+    model = _loaded_models[model_choice]
+    tokenizer = getattr(processor, "tokenizer", None)
+
+    prompt = _default_prompt(query)
+
+    # Build inputs
+    batch = _build_inputs(processor, tokenizer, image, prompt).to(device)
+
+    # Generate (no streaming)
+    with torch.inference_mode():
+        output_ids = model.generate(
+            **batch,
+            max_new_tokens=max_new_tokens,
+            do_sample=False,
+            temperature=temperature,
+            top_p=top_p,
+            top_k=top_k,
+            repetition_penalty=repetition_penalty,
+        )
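+        # do_sample=False means greedy decoding, so the temperature/top_p/top_k
+        # values above are effectively ignored (transformers may warn about
+        # unused sampling flags).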
+
+    # Decode
+    decoded = _decode_text(model, processor, tokenizer, output_ids)
+    cleaned = decoded.replace("<|im_end|>", "").strip()
+    return cleaned
+
+# ---------------------------
+# Export helpers
+# ---------------------------
+from reportlab.platypus import SimpleDocTemplate, Paragraph
+from reportlab.lib.styles import getSampleStyleSheet
+from docx import Document
+
 def _safe_text(text: str) -> str:
+    return (text or "").strip()
+
 def save_as_pdf(text):
+    text = _safe_text(text)
+    if not text:
+        return None
+    filepath = "output.pdf"
+    doc = SimpleDocTemplate(filepath)
+    styles = getSampleStyleSheet()
+    flowables = [Paragraph(t, styles["Normal"]) for t in text.splitlines() if t != ""]
+    if not flowables:
+        flowables = [Paragraph(" ", styles["Normal"])]
+    doc.build(flowables)
+    return filepath
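+
+# Note: Paragraph parses its text as mini-XML markup, so OCR output containing a
+# raw "<" or "&" can raise here; escaping each line first (e.g. via
+# xml.sax.saxutils.escape) would be the safer variant.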
+
 def save_as_word(text):
+    text = _safe_text(text)
+    if not text:
+        return None
+    filepath = "output.docx"
+    doc = Document()
+    for line in text.splitlines():
+        doc.add_paragraph(line)
+    doc.save(filepath)
+    return filepath
+
+# gTTS calls Google's TTS service (requires outbound internet), so it is wrapped
+# in try/except to keep the Space from crashing.
+
 def save_as_audio(text):
+    text = _safe_text(text)
+    if not text:
+        return None
+    try:
+        from gtts import gTTS  # package name is lowercase "gtts"
+        filepath = "output.mp3"
+        tts = gTTS(text)
+        tts.save(filepath)
+        return filepath
+    except Exception as e:
+        print(f"gTTS failed: {e}")
+        return None

 # ---------------------------
+# Gradio interface
 # ---------------------------
 with gr.Blocks(theme=gr.themes.Soft()) as demo:
+    gr.Markdown("## ✍🏾 Wilson Handwritten OCR")
+
+    model_choice = gr.Radio(
+        choices=list(MODEL_PATHS.keys()),
+        value=list(MODEL_PATHS.keys())[0],
+        label="Select OCR Model",
+    )
+
+    with gr.Tab("🖼 Image Inference"):
+        query_input = gr.Textbox(
+            label="Custom Prompt (optional)",
+            placeholder="Leave empty for RAW structured output",
+        )
+
+        # Upload + webcam (Gradio 4.x uses `sources`)
+        image_input = gr.Image(
+            type="pil",
+            label="Upload / Capture Handwritten Image",
+            sources=["upload", "webcam"],
+        )
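+        # The list-valued `sources` argument is Gradio 4.x API; Gradio 3.x used
+        # the string `source="upload"`, so this file assumes gradio>=4.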
+
+        with gr.Accordion("⚙️ Advanced Options", open=False):
+            max_new_tokens = gr.Slider(1, 2048, value=MAX_NEW_TOKENS_DEFAULT, step=1, label="Max new tokens")
+            temperature = gr.Slider(0.1, 2.0, value=0.1, step=0.05, label="Temperature")
+            top_p = gr.Slider(0.05, 1.0, value=1.0, step=0.05, label="Top-p (nucleus)")
+            top_k = gr.Slider(0, 1000, value=0, step=1, label="Top-k")
+            repetition_penalty = gr.Slider(0.8, 2.0, value=1.0, step=0.05, label="Repetition penalty")
+
+        with gr.Row():
+            extract_btn = gr.Button("📤 Extract RAW Text", variant="primary")
+            clear_btn = gr.Button("🧹 Clear")
+
+        raw_output = gr.Textbox(
+            label="📜 RAW Structured Output (exact as written)",
+            lines=18,
+            show_copy_button=True,
+        )
+
+        with gr.Row():
+            pdf_btn = gr.Button("⬇️ Download as PDF")
+            word_btn = gr.Button("⬇️ Download as Word")
+            audio_btn = gr.Button("🔊 Download as Audio")
+
+        pdf_file = gr.File(label="PDF File")
+        word_file = gr.File(label="Word File")
+        audio_file = gr.File(label="Audio File")
+
+        extract_btn.click(
+            fn=ocr_image,
+            inputs=[
+                image_input,
+                model_choice,
+                query_input,
+                max_new_tokens,
+                temperature,
+                top_p,
+                top_k,
+                repetition_penalty,
+            ],
+            outputs=[raw_output],
+            api_name="ocr_image",
+        )
+
+        pdf_btn.click(fn=save_as_pdf, inputs=[raw_output], outputs=[pdf_file])
+        word_btn.click(fn=save_as_word, inputs=[raw_output], outputs=[word_file])
+        audio_btn.click(fn=save_as_audio, inputs=[raw_output], outputs=[audio_file])
+
+        clear_btn.click(
+            fn=lambda: ("", None, "", MAX_NEW_TOKENS_DEFAULT, 0.1, 1.0, 0, 1.0),
+            outputs=[raw_output, image_input, query_input, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
+        )

 if __name__ == "__main__":
+    # Keep queue for GPU tasks; limit concurrency for stability.
+    demo.queue(max_size=50).launch(show_error=True)
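+    # max_size only bounds how many requests may wait in the queue; in Gradio 4.x
+    # the per-event concurrency limit defaults to 1, so GPU calls are already
+    # serialized unless queue(default_concurrency_limit=...) raises it.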