Emeritus-21 committed on
Commit 927e645 · verified · 1 Parent(s): 4ef4dae

Update app.py

Files changed (1)
  1. app.py +92 -54
app.py CHANGED
@@ -1,8 +1,8 @@
-# app.py — HTR Space (Refined Compact Version)
-
+# app.py — HTR Space (Compact Version)
 import os, time
 from threading import Thread
 import gradio as gr
+import spaces
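+# "spaces" is the Hugging Face Spaces helper that provides the @spaces.GPU decorator
+# applied to warmup() and ocr_image() below.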
 from PIL import Image
 import torch
 from transformers import AutoProcessor, AutoModelForImageTextToText, Qwen2_5_VLForConditionalGeneration
@@ -10,30 +10,47 @@ from reportlab.platypus import SimpleDocTemplate, Paragraph
 from reportlab.lib.styles import getSampleStyleSheet
 from docx import Document
 
-# ---------------- Constants ----------------
-MAX_MAX_NEW_TOKENS = 2048
-DEFAULT_MAX_NEW_TOKENS = 512
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
 # ---------------- Models ----------------
 MODEL_PATHS = {
-    "Complex Handwriting": ("prithivMLmods/Qwen2.5-VL-7B-Abliterated-Caption-it", Qwen2_5_VLForConditionalGeneration),
-    "Simple/Scanned Handwriting": ("nanonets/Nanonets-OCR-s", Qwen2_5_VLForConditionalGeneration),
-    "Structured Handwriting": ("Emeritus-21/Finetuned-full-HTR-model", AutoModelForImageTextToText),
+    "Model 1 (Complex handwriting)": ("prithivMLmods/Qwen2.5-VL-7B-Abliterated-Caption-it", Qwen2_5_VLForConditionalGeneration),
+    "Model 2 (Simple and scanned handwriting)": ("nanonets/Nanonets-OCR-s", Qwen2_5_VLForConditionalGeneration),
+    "Model 3 (Structured handwriting)": ("Emeritus-21/Finetuned-full-HTR-model", AutoModelForImageTextToText),
 }
 
+MAX_NEW_TOKENS_DEFAULT = 512
+device = "cuda" if torch.cuda.is_available() else "cpu"
 _loaded_processors, _loaded_models = {}, {}
-print("🚀 Loading HTR models...")
+
+print("🚀 Preloading models into GPU/CPU memory...")
 for name, (repo_id, cls) in MODEL_PATHS.items():
     try:
         processor = AutoProcessor.from_pretrained(repo_id, trust_remote_code=True)
-        model = cls.from_pretrained(repo_id, trust_remote_code=True,
-                                    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
-                                    low_cpu_mem_usage=True).to(device).eval()
+        model = cls.from_pretrained(
+            repo_id,
+            trust_remote_code=True,
+            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+            low_cpu_mem_usage=True
+        ).to(device).eval()
         _loaded_processors[name], _loaded_models[name] = processor, model
-        print(f"✅ {name} ready")
+        print(f"✅ {name} ready.")
     except Exception as e:
-        print(f"⚠️ Failed {name}: {e}")
+
+# ---------------- GPU Warmup ----------------
+@spaces.GPU
+def warmup(progress=gr.Progress(track_tqdm=True)):
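+    # Warm the GPU worker with a tiny 1-token, text-only generation on the default
+    # model so the first real OCR request does not pay cold-start latency.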
+    try:
+        default_model_choice = next(iter(MODEL_PATHS.keys()))
+        processor = _loaded_processors[default_model_choice]
+        model = _loaded_models[default_model_choice]
+        tokenizer = getattr(processor, "tokenizer", None)
+        messages = [{"role": "user", "content": [{"type": "text", "text": "Warmup."}]}]
+        chat_prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) if tokenizer and hasattr(tokenizer, "apply_chat_template") else "Warmup."
+        inputs = processor(text=[chat_prompt], images=None, return_tensors="pt").to(device)
+        with torch.inference_mode(): _ = model.generate(**inputs, max_new_tokens=1)
+        return f"GPU warm and {default_model_choice} ready."
+    except Exception as e:
+        return f"Warmup skipped: {e}"
 
 # ---------------- Helpers ----------------
 def _build_inputs(processor, tokenizer, image: Image.Image, prompt: str):
@@ -46,31 +63,36 @@ def _build_inputs(processor, tokenizer, image: Image.Image, prompt: str):
 def _decode_text(model, processor, tokenizer, output_ids):
     for obj in [processor, tokenizer, getattr(model, "tokenizer", None)]:
         try: return obj.batch_decode(output_ids, skip_special_tokens=True)[0]
-        except: pass
+        except Exception: pass
     return str(output_ids)
 
 def _default_prompt(query: str | None) -> str:
     if query and query.strip(): return query.strip()
-    return ("You are a professional Handwritten OCR system.\n"
-            "TASK: Read the handwritten image and transcribe the text EXACTLY as written.\n"
-            "- Preserve original structure and line breaks.\n"
-            "- Keep spacing, bullet points, numbering, and indentation.\n"
-            "- Render tables as Markdown tables if present.\n"
-            "- Do NOT autocorrect spelling or grammar.\n"
-            "- Do NOT merge lines.\n"
-            "Return RAW transcription only.")
+    return (
+        "You are a professional Handwritten OCR system.\n"
+        "TASK: Read the handwritten image and transcribe the text EXACTLY as written.\n"
+        "- Preserve original structure and line breaks.\n"
+        "- Keep spacing, bullet points, numbering, and indentation.\n"
+        "- Render tables as Markdown tables if present.\n"
+        "- Do NOT autocorrect spelling or grammar.\n"
+        "- Do NOT merge lines.\n"
+        "Return RAW transcription only."
+    )
 
-# ---------------- OCR ----------------
-def ocr_image(model_name: str, image: Image.Image, query: str = "",
-              max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS):
-    if image is None: return "Please upload an image."
-    if model_name not in _loaded_models: return "Invalid model selected."
-    processor, model = _loaded_processors[model_name], _loaded_models[model_name]
-    tokenizer = getattr(processor, "tokenizer", None)
+# ---------------- OCR Function ----------------
+@spaces.GPU
+def ocr_image(image: Image.Image, model_choice: str, query: str | None = None,
+              max_new_tokens: int = MAX_NEW_TOKENS_DEFAULT,
+              temperature: float = 0.1, top_p: float = 1.0, top_k: int = 0, repetition_penalty: float = 1.0,
+              progress=gr.Progress(track_tqdm=True)):
+    if image is None: return "Please upload or capture an image."
+    if model_choice not in _loaded_models: return f"Invalid model: {model_choice}"
+    processor, model, tokenizer = _loaded_processors[model_choice], _loaded_models[model_choice], getattr(_loaded_processors[model_choice], "tokenizer", None)
     prompt = _default_prompt(query)
     batch = _build_inputs(processor, tokenizer, image, prompt).to(device)
     with torch.inference_mode():
-        output_ids = model.generate(**batch, max_new_tokens=max_new_tokens)
+        output_ids = model.generate(**batch, max_new_tokens=max_new_tokens, do_sample=False,
+                                    temperature=temperature, top_p=top_p, top_k=top_k, repetition_penalty=repetition_penalty)
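+        # Greedy decoding: with do_sample=False the temperature/top_p/top_k values
+        # are passed through but ignored; repetition_penalty still takes effect.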
     return _decode_text(model, processor, tokenizer, output_ids).replace("<|im_end|>", "").strip()
 
 # ---------------- Export Helpers ----------------
@@ -93,28 +115,44 @@ def save_as_word(text):
     doc.save("output.docx")
     return "output.docx"
 
+def save_as_audio(text):
+    text = _safe_text(text)
+    if not text: return None
+    try:
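+        # Lazy import: audio export is optional, and gTTS needs network access at
+        # synthesis time (the package/module name is lowercase "gtts").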
+        from gtts import gTTS
+        tts = gTTS(text)
+        tts.save("output.mp3")
+        return "output.mp3"
+    except Exception as e:
+        print(f"gTTS failed: {e}")
+        return None
+
 # ---------------- Gradio Interface ----------------
-css = """.submit-btn { background-color: #2980b9 !important; color: white !important; }
-.submit-btn:hover { background-color: #3498db !important; }
-.canvas-output { border: 2px solid #4682B4; border-radius: 10px; padding: 20px;}"""
-
-with gr.Blocks(css=css, theme="soft") as demo:
-    gr.Markdown("## ✍🏾 Wilson HTR OCR")
-    with gr.Row():
-        with gr.Column():
-            model_choice = gr.Radio(choices=list(MODEL_PATHS.keys()), value=list(MODEL_PATHS.keys())[0], label="Select OCR Model")
-            query_input = gr.Textbox(label="Custom Prompt (optional)")
-            image_input = gr.Image(type="pil", label="Upload / Capture Image", source="upload")
-            submit_btn = gr.Button("📤 Extract Text", elem_classes="submit-btn")
-            raw_output = gr.Textbox(label="OCR Output", lines=15, interactive=False, show_copy_button=True)
-            pdf_btn = gr.Button("⬇️ Download PDF")
-            word_btn = gr.Button("⬇️ Download Word")
-            pdf_file = gr.File(label="PDF File")
-            word_file = gr.File(label="Word File")
-
-    submit_btn.click(fn=ocr_image, inputs=[model_choice, image_input, query_input], outputs=[raw_output])
-    pdf_btn.click(fn=save_as_pdf, inputs=[raw_output], outputs=[pdf_file])
-    word_btn.click(fn=save_as_word, inputs=[raw_output], outputs=[word_file])
+with gr.Blocks(theme=gr.themes.Soft()) as demo:
+    gr.Markdown("## ✍🏾 Wilson Handwritten OCR")
+    model_choice = gr.Radio(choices=list(MODEL_PATHS.keys()), value=list(MODEL_PATHS.keys())[0], label="Select OCR Model")
+    with gr.Tab("🖼 Image Inference"):
+        query_input = gr.Textbox(label="Custom Prompt (optional)", placeholder="Leave empty for RAW structured output")
+        image_input = gr.Image(type="pil", label="Upload / Capture Handwritten Image", sources=["upload", "webcam"])
+        with gr.Accordion("⚙️ Advanced Options", open=False):
+            max_new_tokens = gr.Slider(1, 2048, value=MAX_NEW_TOKENS_DEFAULT, step=1, label="Max new tokens")
+            temperature = gr.Slider(0.1, 2.0, value=0.1, step=0.05, label="Temperature")
+            top_p = gr.Slider(0.05, 1.0, value=1.0, step=0.05, label="Top-p (nucleus)")
+            top_k = gr.Slider(0, 1000, value=0, step=1, label="Top-k")
+            repetition_penalty = gr.Slider(0.8, 2.0, value=1.0, step=0.05, label="Repetition penalty")
+        extract_btn = gr.Button("📤 Extract RAW Text", variant="primary")
+        clear_btn = gr.Button("🧹 Clear")
+        raw_output = gr.Textbox(label="📜 RAW Structured Output (exact as written)", lines=18, show_copy_button=True)
+        pdf_btn = gr.Button("⬇️ Download as PDF")
+        word_btn = gr.Button("⬇️ Download as Word")
+        audio_btn = gr.Button("🔊 Download as Audio")
+        pdf_file, word_file, audio_file = gr.File(label="PDF File"), gr.File(label="Word File"), gr.File(label="Audio File")
+
+        extract_btn.click(fn=ocr_image, inputs=[image_input, model_choice, query_input, max_new_tokens, temperature, top_p, top_k, repetition_penalty], outputs=[raw_output], api_name="ocr_image")
+        pdf_btn.click(fn=save_as_pdf, inputs=[raw_output], outputs=[pdf_file])
+        word_btn.click(fn=save_as_word, inputs=[raw_output], outputs=[word_file])
+        audio_btn.click(fn=save_as_audio, inputs=[raw_output], outputs=[audio_file])
+        clear_btn.click(fn=lambda: ("", None, "", MAX_NEW_TOKENS_DEFAULT, 0.1, 1.0, 0, 1.0), outputs=[raw_output, image_input, query_input, max_new_tokens, temperature, top_p, top_k, repetition_penalty])
 
 if __name__ == "__main__":
-    demo.queue(max_size=50).launch(share=True, show_error=True)
+    demo.queue(max_size=50).launch(show_error=True)
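
Since the extract button above is exported with api_name="ocr_image", the updated app also exposes a callable API endpoint. A minimal, untested sketch of invoking it with gradio_client; the Space id and image path are placeholders:

    from gradio_client import Client, handle_file

    client = Client("Emeritus-21/htr-space")     # placeholder Space id
    text = client.predict(
        handle_file("page.png"),                 # image_input
        "Model 3 (Structured handwriting)",      # model_choice (a MODEL_PATHS key)
        "",                                      # query_input (empty -> default RAW prompt)
        512,                                     # max_new_tokens
        0.1, 1.0, 0, 1.0,                        # temperature, top_p, top_k, repetition_penalty
        api_name="/ocr_image",
    )
    print(text)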