Emeritus-21 commited on
Commit
406b226
·
verified ·
1 Parent(s): 8adac94

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +116 -625
app.py CHANGED
@@ -1,658 +1,149 @@
 
1
 
2
-
3
-
4
-
5
- import os
6
-
7
- import time
8
-
9
  from threading import Thread
10
-
11
-
12
-
13
  import gradio as gr
14
-
15
  import spaces
16
-
17
  from PIL import Image
18
-
19
  import torch
 
 
 
 
20
 
21
- from transformers import (
22
-
23
-     AutoProcessor,
24
-
25
-     AutoModelForImageTextToText,
26
-
27
-     Qwen2_5_VLForConditionalGeneration,
28
-
29
- )
30
-
31
-
32
-
33
- # ---------------------------
34
-
35
- # Models
36
-
37
- # ---------------------------
38
-
39
  MODEL_PATHS = {
40
-
41
-     "Model 1 (Complex handwrittings )": (
42
-
43
-         "prithivMLmods/Qwen2.5-VL-7B-Abliterated-Caption-it",
44
-
45
-         Qwen2_5_VLForConditionalGeneration,
46
-
47
-     ),
48
-
49
-     "Model 2 (simple and scanned handwritting )": (
50
-
51
-         "nanonets/Nanonets-OCR-s",
52
-
53
-         Qwen2_5_VLForConditionalGeneration,
54
-
55
-     ),
56
-
57
-     "Model 3 (structured handwritting)": (
58
-
59
-         "Emeritus-21/Finetuned-full-HTR-model",
60
-
61
-         AutoModelForImageTextToText,
62
-
63
-     ),
64
-
65
  }
66
 
67
-
68
-
69
  MAX_NEW_TOKENS_DEFAULT = 512
70
-
71
  device = "cuda" if torch.cuda.is_available() else "cpu"
72
 
73
-
74
-
75
- # ---------------------------
76
-
77
- # Preload models at startup
78
-
79
- # ---------------------------
80
-
81
- _loaded_processors = {}
82
-
83
- _loaded_models = {}
84
-
85
-
86
-
87
  print("🚀 Preloading models into GPU/CPU memory...")
88
-
89
-
90
-
91
  for name, (repo_id, cls) in MODEL_PATHS.items():
92
-
93
-     try:
94
-
95
-         print(f"Loading {name} ...")
96
-
97
-         processor = AutoProcessor.from_pretrained(repo_id, trust_remote_code=True)
98
-
99
-         model = cls.from_pretrained(
100
-
101
-             repo_id,
102
-
103
-             trust_remote_code=True,
104
-
105
-             torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
106
-
107
-             low_cpu_mem_usage=True,
108
-
109
-         ).to(device).eval()
110
-
111
-         _loaded_processors[name] = processor
112
-
113
-         _loaded_models[name] = model
114
-
115
-         print(f"✅ {name} ready.")
116
-
117
-     except Exception as e:
118
-
119
-         print(f"⚠️ Failed to load {name}: {e}")
120
-
121
-
122
-
123
- # ---------------------------
124
-
125
- # Warmup (GPU)
126
-
127
- # ---------------------------
128
-
129
  @spaces.GPU
130
-
131
  def warmup(progress=gr.Progress(track_tqdm=True)):
132
-
133
-     try:
134
-
135
-         default_model_choice = next(iter(MODEL_PATHS.keys()))
136
-
137
-         processor = _loaded_processors[default_model_choice]
138
-
139
-         model = _loaded_models[default_model_choice]
140
-
141
-         tokenizer = getattr(processor, "tokenizer", None)
142
-
143
-
144
-
145
-         messages = [{"role": "user", "content": [{"type": "text", "text": "Warmup."}]}]
146
-
147
-         if tokenizer and hasattr(tokenizer, "apply_chat_template"):
148
-
149
-             chat_prompt = tokenizer.apply_chat_template(
150
-
151
-                 messages, tokenize=False, add_generation_prompt=True
152
-
153
-             )
154
-
155
-         else:
156
-
157
-             chat_prompt = "Warmup."
158
-
159
-
160
-
161
-         inputs = processor(
162
-
163
-             text=[chat_prompt],
164
-
165
-             images=None,
166
-
167
-             return_tensors="pt"
168
-
169
-         ).to(device)
170
-
171
-
172
-
173
-         with torch.inference_mode():
174
-
175
-             _ = model.generate(**inputs, max_new_tokens=1)
176
-
177
-
178
-
179
-         return f"GPU warm and {default_model_choice} ready."
180
-
181
-     except Exception as e:
182
-
183
-         return f"Warmup skipped: {e}"
184
-
185
-
186
-
187
- # ---------------------------
188
-
189
- # Helpers
190
-
191
- # ---------------------------
192
-
193
  def _build_inputs(processor, tokenizer, image: Image.Image, prompt: str):
194
-
195
-     """Build processor inputs for text+image with/without chat template."""
196
-
197
-     messages = [
198
-
199
-         {
200
-
201
-             "role": "user",
202
-
203
-             "content": [
204
-
205
-                 {"type": "image", "image": image},
206
-
207
-                 {"type": "text", "text": prompt},
208
-
209
-             ],
210
-
211
-         }
212
-
213
-     ]
214
-
215
-     if tokenizer and hasattr(tokenizer, "apply_chat_template"):
216
-
217
-         chat_prompt = tokenizer.apply_chat_template(
218
-
219
-             messages, tokenize=False, add_generation_prompt=True
220
-
221
-         )
222
-
223
-         return processor(text=[chat_prompt], images=[image], return_tensors="pt")
224
-
225
-     # Fallback: plain prompt + image
226
-
227
-     return processor(text=[prompt], images=[image], return_tensors="pt")
228
-
229
-
230
 
231
  def _decode_text(model, processor, tokenizer, output_ids):
232
-
233
-     """Robust decode for different processor/tokenizer setups."""
234
-
235
-     text = ""
236
-
237
-     try:
238
-
239
-         if hasattr(processor, "batch_decode"):
240
-
241
-             text = processor.batch_decode(output_ids, skip_special_tokens=True)[0]
242
-
243
-             return text
244
-
245
-     except Exception:
246
-
247
-         pass
248
-
249
-     try:
250
-
251
-         if tokenizer is not None:
252
-
253
-             text = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
254
-
255
-             return text
256
-
257
-     except Exception:
258
-
259
-         pass
260
-
261
-     try:
262
-
263
-         model_tok = getattr(model, "tokenizer", None)
264
-
265
-         if model_tok is not None:
266
-
267
-             text = model_tok.batch_decode(output_ids, skip_special_tokens=True)[0]
268
-
269
-             return text
270
-
271
-     except Exception:
272
-
273
-         pass
274
-
275
-     # Last-resort string
276
-
277
-     return str(output_ids)
278
-
279
-
280
 
281
  def _default_prompt(query: str | None) -> str:
282
-
283
-     if query and query.strip():
284
-
285
-         return query.strip()
286
-
287
-     return (
288
-
289
-         "You are a professional Handwritten OCR system.\n"
290
-
291
-         "TASK: Read the handwritten image and transcribe the text EXACTLY as written.\n"
292
-
293
-         "- Preserve original structure and line breaks.\n"
294
-
295
-         "- Keep spacing, bullet points, numbering, and indentation.\n"
296
-
297
-         "- Render tables as Markdown tables if present.\n"
298
-
299
-         "- Do NOT autocorrect spelling or grammar.\n"
300
-
301
-         "- Do NOT merge lines.\n"
302
-
303
-         "Return RAW transcription only."
304
-
305
-     )
306
-
307
-
308
-
309
- # ---------------------------
310
-
311
- # OCR Function (NO STREAMING / NO yield)  ✅ FIX
312
-
313
- # ---------------------------
314
-
315
  @spaces.GPU
316
-
317
- def ocr_image(
318
-
319
-     image: Image.Image,
320
-
321
-     model_choice: str,
322
-
323
-     query: str = None,
324
-
325
-     max_new_tokens: int = MAX_NEW_TOKENS_DEFAULT,
326
-
327
-     temperature: float = 0.1,
328
-
329
-     top_p: float = 1.0,
330
-
331
-     top_k: int = 0,
332
-
333
-     repetition_penalty: float = 1.0,
334
-
335
-     progress=gr.Progress(track_tqdm=True),
336
-
337
- ):
338
-
339
-     if image is None:
340
-
341
-         return "Please upload or capture an image."
342
-
343
-
344
-
345
-     if model_choice not in _loaded_models:
346
-
347
-         return f"Invalid model: {model_choice}"
348
-
349
-
350
-
351
-     processor = _loaded_processors[model_choice]
352
-
353
-     model = _loaded_models[model_choice]
354
-
355
-     tokenizer = getattr(processor, "tokenizer", None)
356
-
357
-
358
-
359
-     prompt = _default_prompt(query)
360
-
361
-
362
-
363
-     # Build inputs
364
-
365
-     batch = _build_inputs(processor, tokenizer, image, prompt).to(device)
366
-
367
-
368
-
369
-     # Generate (no streaming)
370
-
371
-     with torch.inference_mode():
372
-
373
-         output_ids = model.generate(
374
-
375
-             **batch,
376
-
377
-             max_new_tokens=max_new_tokens,
378
-
379
-             do_sample=False,
380
-
381
-             temperature=temperature,
382
-
383
-             top_p=top_p,
384
-
385
-             top_k=top_k,
386
-
387
-             repetition_penalty=repetition_penalty,
388
-
389
-         )
390
-
391
-
392
-
393
-     # Decode
394
-
395
-     decoded = _decode_text(model, processor, tokenizer, output_ids)
396
-
397
-     cleaned = decoded.replace("<|im_end|>", "").strip()
398
-
399
-     return cleaned
400
-
401
-
402
-
403
- # ---------------------------
404
-
405
- # Export Helpers
406
-
407
- # ---------------------------
408
-
409
- from reportlab.platypus import SimpleDocTemplate, Paragraph
410
-
411
- from reportlab.lib.styles import getSampleStyleSheet
412
-
413
- from docx import Document
414
-
415
-
416
-
417
- def _safe_text(text: str) -> str:
418
-
419
-     return (text or "").strip()
420
-
421
-
422
 
423
  def save_as_pdf(text):
424
-
425
-     text = _safe_text(text)
426
-
427
-     if not text:
428
-
429
-         return None
430
-
431
-     filepath = "output.pdf"
432
-
433
-     doc = SimpleDocTemplate(filepath)
434
-
435
-     styles = getSampleStyleSheet()
436
-
437
-     flowables = [Paragraph(t, styles["Normal"]) for t in text.splitlines() if t != ""]
438
-
439
-     if not flowables:
440
-
441
-         flowables = [Paragraph(" ", styles["Normal"])]
442
-
443
-     doc.build(flowables)
444
-
445
-     return filepath
446
-
447
-
448
 
449
  def save_as_word(text):
450
-
451
-     text = _safe_text(text)
452
-
453
-     if not text:
454
-
455
-         return None
456
-
457
-     filepath = "output.docx"
458
-
459
-     doc = Document()
460
-
461
-     for line in text.splitlines():
462
-
463
-         doc.add_paragraph(line)
464
-
465
-     doc.save(filepath)
466
-
467
-     return filepath
468
-
469
-
470
-
471
- # gTTS uses Google TTS (requires outbound internet). Wrap in try/except so Space doesn't crash.
472
 
473
  def save_as_audio(text):
 
 
 
 
474
 
475
-     text = _safe_text(text)
476
-
477
-     if not text:
478
-
479
-         return None
480
-
481
-     try:
482
-
483
-         from gTTS import gTTS
484
-
485
-         filepath = "output.mp3"
486
-
487
-         tts = gTTS(text)
488
-
489
-         tts.save(filepath)
490
-
491
-         return filepath
492
-
493
-     except Exception as e:
494
-
495
-         print(f"gTTS failed: {e}")
496
-
497
-         return None
498
-
499
-
500
-
501
- # ---------------------------
502
-
503
- # Gradio Interface
504
-
505
- # ---------------------------
506
-
507
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
508
-
509
-     gr.Markdown("## ✍🏾 wilson Handwritten OCR ")
510
-
511
-
512
-
513
-     model_choice = gr.Radio(
514
-
515
-         choices=list(MODEL_PATHS.keys()),
516
-
517
-         value=list(MODEL_PATHS.keys())[0],
518
-
519
-         label="Select OCR Model",
520
-
521
-     )
522
-
523
-
524
-
525
-     with gr.Tab("🖼 Image Inference"):
526
-
527
-         query_input = gr.Textbox(
528
-
529
-             label="Custom Prompt (optional)",
530
-
531
-             placeholder="Leave empty for RAW structured output",
532
-
533
-         )
534
-
535
-
536
-
537
-         # Upload + Webcam (Gradio 4.x uses `sources`)
538
-
539
-         image_input = gr.Image(
540
-
541
-             type="pil",
542
-
543
-             label="Upload / Capture Handwritten Image",
544
-
545
-             sources=["upload", "webcam"],
546
-
547
-         )
548
-
549
-
550
-
551
-         with gr.Accordion("⚙️ Advanced Options", open=False):
552
-
553
-             max_new_tokens = gr.Slider(1, 2048, value=MAX_NEW_TOKENS_DEFAULT, step=1, label="Max new tokens")
554
-
555
-             temperature = gr.Slider(0.1, 2.0, value=0.1, step=0.05, label="Temperature")
556
-
557
-             top_p = gr.Slider(0.05, 1.0, value=1.0, step=0.05, label="Top-p (nucleus)")
558
-
559
-             top_k = gr.Slider(0, 1000, value=0, step=1, label="Top-k")
560
-
561
-             repetition_penalty = gr.Slider(0.8, 2.0, value=1.0, step=0.05, label="Repetition penalty")
562
-
563
-
564
-
565
-         with gr.Row():
566
-
567
-             extract_btn = gr.Button("📤 Extract RAW Text", variant="primary")
568
-
569
-             clear_btn = gr.Button("🧹 Clear")
570
-
571
-
572
-
573
-         raw_output = gr.Textbox(
574
-
575
-             label="📜 RAW Structured Output (exact as written)",
576
-
577
-             lines=18,
578
-
579
-             show_copy_button=True,
580
-
581
-         )
582
-
583
-
584
-
585
-         with gr.Row():
586
-
587
-             pdf_btn = gr.Button("⬇️ Download as PDF")
588
-
589
-             word_btn = gr.Button("⬇️ Download as Word")
590
-
591
-             audio_btn = gr.Button("🔊 Download as Audio")
592
-
593
-
594
-
595
-         pdf_file = gr.File(label="PDF File")
596
-
597
-         word_file = gr.File(label="Word File")
598
-
599
-         audio_file = gr.File(label="Audio File")
600
-
601
-
602
-
603
-         extract_btn.click(
604
-
605
-             fn=ocr_image,
606
-
607
-             inputs=[
608
-
609
-                 image_input,
610
-
611
-                 model_choice,
612
-
613
-                 query_input,
614
-
615
-                 max_new_tokens,
616
-
617
-                 temperature,
618
-
619
-                 top_p,
620
-
621
-                 top_k,
622
-
623
-                 repetition_penalty,
624
-
625
-             ],
626
-
627
-             outputs=[raw_output],
628
-
629
-             api_name="ocr_image",
630
-
631
-         )
632
-
633
-
634
-
635
-         pdf_btn.click(fn=save_as_pdf, inputs=[raw_output], outputs=[pdf_file])
636
-
637
-         word_btn.click(fn=save_as_word, inputs=[raw_output], outputs=[word_file])
638
-
639
-         audio_btn.click(fn=save_as_audio, inputs=[raw_output], outputs=[audio_file])
640
-
641
-
642
-
643
-         clear_btn.click(
644
-
645
-             fn=lambda: ("", None, "", MAX_NEW_TOKENS_DEFAULT, 0.1, 1.0, 0, 1.0),
646
-
647
-             outputs=[raw_output, image_input, query_input, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
648
-
649
-         )
650
-
651
-
652
 
653
  if __name__ == "__main__":
654
-
655
-     # Keep queue for GPU tasks; limit concurrency for stability.
656
-
657
-     demo.queue(max_size=50).launch(show_error=True)
658
-
 
1
+ # app.py — HTR Space (Compact Version)
2
 
3
+ import os, time
 
 
 
 
 
 
4
  from threading import Thread
 
 
 
5
  import gradio as gr
 
6
  import spaces
 
7
  from PIL import Image
 
8
  import torch
9
+ from transformers import AutoProcessor, AutoModelForImageTextToText, Qwen2_5_VLForConditionalGeneration
10
+ from reportlab.platypus import SimpleDocTemplate, Paragraph
11
+ from reportlab.lib.styles import getSampleStyleSheet
12
+ from docx import Document
13
 
14
+ # ---------------- Models ----------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  MODEL_PATHS = {
16
+ "Model 1 (Complex handwrittings )": ("prithivMLmods/Qwen2.5-VL-7B-Abliterated-Caption-it", Qwen2_5_VLForConditionalGeneration),
17
+ "Model 2 (simple and scanned handwritting )": ("nanonets/Nanonets-OCR-s", Qwen2_5_VLForConditionalGeneration),
18
+ "Model 3 (structured handwritting)": ("Emeritus-21/Finetuned-full-HTR-model", AutoModelForImageTextToText),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  }
20
 
 
 
21
  MAX_NEW_TOKENS_DEFAULT = 512
 
22
  device = "cuda" if torch.cuda.is_available() else "cpu"
23
 
24
+ _loaded_processors, _loaded_models = {}, {}
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  print("🚀 Preloading models into GPU/CPU memory...")
 
 
 
26
  for name, (repo_id, cls) in MODEL_PATHS.items():
27
+ try:
28
+ processor = AutoProcessor.from_pretrained(repo_id, trust_remote_code=True)
29
+ model = cls.from_pretrained(repo_id, trust_remote_code=True,
30
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
31
+ low_cpu_mem_usage=True).to(device).eval()
32
+ _loaded_processors[name], _loaded_models[name] = processor, model
33
+ print(f"✅ {name} ready.")
34
+ except Exception as e:
35
+ print(f"⚠️ Failed to load {name}: {e}")
36
+
37
+ # ---------------- GPU Warmup ----------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  @spaces.GPU
 
39
  def warmup(progress=gr.Progress(track_tqdm=True)):
40
+ try:
41
+ default_model_choice = next(iter(MODEL_PATHS.keys()))
42
+ processor = _loaded_processors[default_model_choice]
43
+ model = _loaded_models[default_model_choice]
44
+ tokenizer = getattr(processor, "tokenizer", None)
45
+ messages = [{"role": "user", "content": [{"type": "text", "text": "Warmup."}]}]
46
+ chat_prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) if tokenizer and hasattr(tokenizer, "apply_chat_template") else "Warmup."
47
+ inputs = processor(text=[chat_prompt], images=None, return_tensors="pt").to(device)
48
+ with torch.inference_mode(): _ = model.generate(**inputs, max_new_tokens=1)
49
+ return f"GPU warm and {default_model_choice} ready."
50
+ except Exception as e:
51
+ return f"Warmup skipped: {e}"
52
+
53
+ # ---------------- Helpers ----------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  def _build_inputs(processor, tokenizer, image: Image.Image, prompt: str):
55
+ messages = [{"role": "user", "content": [{"type": "image", "image": image}, {"type": "text", "text": prompt}]}]
56
+ if tokenizer and hasattr(tokenizer, "apply_chat_template"):
57
+ chat_prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
58
+ return processor(text=[chat_prompt], images=[image], return_tensors="pt")
59
+ return processor(text=[prompt], images=[image], return_tensors="pt")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
 
61
  def _decode_text(model, processor, tokenizer, output_ids):
62
+ for obj in [processor, tokenizer, getattr(model, "tokenizer", None)]:
63
+ try: return obj.batch_decode(output_ids, skip_special_tokens=True)[0]
64
+ except Exception: pass
65
+ return str(output_ids)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
 
67
  def _default_prompt(query: str | None) -> str:
68
+ if query and query.strip(): return query.strip()
69
+ return ("You are a professional Handwritten OCR system.\n"
70
+ "TASK: Read the handwritten image and transcribe the text EXACTLY as written.\n"
71
+ "- Preserve original structure and line breaks.\n"
72
+ "- Keep spacing, bullet points, numbering, and indentation.\n"
73
+ "- Render tables as Markdown tables if present.\n"
74
+ "- Do NOT autocorrect spelling or grammar.\n"
75
+ "- Do NOT merge lines.\n"
76
+ "Return RAW transcription only.")
77
+
78
+ # ---------------- OCR Function ----------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
  @spaces.GPU
80
+ def ocr_image(image: Image.Image, model_choice: str, query: str = None,
81
+ max_new_tokens: int = MAX_NEW_TOKENS_DEFAULT,
82
+ temperature: float = 0.1, top_p: float = 1.0, top_k: int = 0, repetition_penalty: float = 1.0,
83
+ progress=gr.Progress(track_tqdm=True)):
84
+ if image is None: return "Please upload or capture an image."
85
+ if model_choice not in _loaded_models: return f"Invalid model: {model_choice}"
86
+ processor, model, tokenizer = _loaded_processors[model_choice], _loaded_models[model_choice], getattr(_loaded_processors[model_choice], "tokenizer", None)
87
+ prompt = _default_prompt(query)
88
+ batch = _build_inputs(processor, tokenizer, image, prompt).to(device)
89
+ with torch.inference_mode():
90
+ output_ids = model.generate(**batch, max_new_tokens=max_new_tokens, do_sample=False,
91
+ temperature=temperature, top_p=top_p, top_k=top_k, repetition_penalty=repetition_penalty)
92
+ return _decode_text(model, processor, tokenizer, output_ids).replace("<|im_end|>", "").strip()
93
+
94
+ # ---------------- Export Helpers ----------------
95
+ def _safe_text(text: str) -> str: return (text or "").strip()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
 
97
  def save_as_pdf(text):
98
+ text = _safe_text(text)
99
+ if not text: return None
100
+ doc = SimpleDocTemplate("output.pdf")
101
+ flowables = [Paragraph(t, getSampleStyleSheet()["Normal"]) for t in text.splitlines() if t != ""]
102
+ if not flowables: flowables = [Paragraph(" ", getSampleStyleSheet()["Normal"])]
103
+ doc.build(flowables)
104
+ return "output.pdf"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
 
106
  def save_as_word(text):
107
+ text = _safe_text(text)
108
+ if not text: return None
109
+ doc = Document()
110
+ for line in text.splitlines(): doc.add_paragraph(line)
111
+ doc.save("output.docx")
112
+ return "output.docx"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
113
 
114
  def save_as_audio(text):
115
+ text = _safe_text(text)
116
+ if not text: return None
117
+ try: from gTTS import gTTS; tts = gTTS(text); tts.save("output.mp3"); return "output.mp3"
118
+ except Exception as e: print(f"gTTS failed: {e}"); return None
119
 
120
+ # ---------------- Gradio Interface ----------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
122
+ gr.Markdown("## ✍🏾 wilson Handwritten OCR")
123
+ model_choice = gr.Radio(choices=list(MODEL_PATHS.keys()), value=list(MODEL_PATHS.keys())[0], label="Select OCR Model")
124
+ with gr.Tab("🖼 Image Inference"):
125
+ query_input = gr.Textbox(label="Custom Prompt (optional)", placeholder="Leave empty for RAW structured output")
126
+ image_input = gr.Image(type="pil", label="Upload / Capture Handwritten Image", sources=["upload", "webcam"])
127
+ with gr.Accordion("⚙️ Advanced Options", open=False):
128
+ max_new_tokens = gr.Slider(1, 2048, value=MAX_NEW_TOKENS_DEFAULT, step=1, label="Max new tokens")
129
+ temperature = gr.Slider(0.1, 2.0, value=0.1, step=0.05, label="Temperature")
130
+ top_p = gr.Slider(0.05, 1.0, value=1.0, step=0.05, label="Top-p (nucleus)")
131
+ top_k = gr.Slider(0, 1000, value=0, step=1, label="Top-k")
132
+ repetition_penalty = gr.Slider(0.8, 2.0, value=1.0, step=0.05, label="Repetition penalty")
133
+ extract_btn = gr.Button("📤 Extract RAW Text", variant="primary")
134
+ clear_btn = gr.Button("🧹 Clear")
135
+ raw_output = gr.Textbox(label="📜 RAW Structured Output (exact as written)", lines=18, show_copy_button=True)
136
+ pdf_btn = gr.Button("⬇️ Download as PDF")
137
+ word_btn = gr.Button("⬇️ Download as Word")
138
+ audio_btn = gr.Button("🔊 Download as Audio")
139
+ pdf_file, word_file, audio_file = gr.File(label="PDF File"), gr.File(label="Word File"), gr.File(label="Audio File")
140
+
141
+ extract_btn.click(fn=ocr_image, inputs=[image_input, model_choice, query_input, max_new_tokens, temperature, top_p, top_k, repetition_penalty], outputs=[raw_output], api_name="ocr_image")
142
+ pdf_btn.click(fn=save_as_pdf, inputs=[raw_output], outputs=[pdf_file])
143
+ word_btn.click(fn=save_as_word, inputs=[raw_output], outputs=[word_file])
144
+ audio_btn.click(fn=save_as_audio, inputs=[raw_output], outputs=[audio_file])
145
+ clear_btn.click(fn=lambda: ("", None, "", MAX_NEW_TOKENS_DEFAULT, 0.1, 1.0, 0, 1.0),
146
+ outputs=[raw_output, image_input, query_input, max_new_tokens, temperature, top_p, top_k, repetition_penalty])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
147
 
148
  if __name__ == "__main__":
149
+ demo.queue(max_size=50).launch(show_error=True)