Emeritus-21 committed on
Commit
cd9abb6
·
verified ·
1 Parent(s): b49181f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +55 -89
app.py CHANGED
@@ -1,7 +1,4 @@
1
- # app.py — HTR Space (GPU-only, no webcam, mobile-ready)
2
-
3
  import os
4
- from threading import Thread
5
  import gradio as gr
6
  from PIL import Image
7
  import torch
@@ -9,9 +6,13 @@ from transformers import AutoProcessor, AutoModelForImageTextToText, Qwen2_5_VLF
9
  from reportlab.platypus import SimpleDocTemplate, Paragraph
10
  from reportlab.lib.styles import getSampleStyleSheet
11
  from docx import Document
 
12
 
 
 
 
 
13
  MAX_NEW_TOKENS_DEFAULT = 512
14
- DEVICE = "cuda" # GPU-only
15
 
16
  # ---------------------------
17
  # Models config
@@ -44,12 +45,20 @@ def load_model(name):
44
  return processor, model
45
 
46
  # ---------------------------
47
- # Helpers
48
  # ---------------------------
49
- def _default_prompt(query: str | None) -> str:
50
- if query and query.strip():
51
- return query.strip()
52
- return (
 
 
 
 
 
 
 
 
53
  "You are a professional Handwritten OCR system.\n"
54
  "TASK: Read the handwritten image and transcribe the text EXACTLY as written.\n"
55
  "- Preserve original structure and line breaks.\n"
@@ -60,50 +69,7 @@ def _default_prompt(query: str | None) -> str:
60
  "Return RAW transcription only."
61
  )
62
 
63
- def _build_inputs_plain(processor, image: Image.Image, prompt: str):
64
- return processor(text=[prompt], images=[image], return_tensors="pt").to(DEVICE)
65
-
66
- def _decode_text(model, processor, tokenizer, output_ids):
67
- text = ""
68
- try:
69
- if hasattr(processor, "batch_decode"):
70
- text = processor.batch_decode(output_ids, skip_special_tokens=True)[0]
71
- return text
72
- except Exception:
73
- pass
74
- try:
75
- if tokenizer is not None:
76
- text = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
77
- return text
78
- except Exception:
79
- pass
80
- try:
81
- model_tok = getattr(model, "tokenizer", None)
82
- if model_tok is not None:
83
- text = model_tok.batch_decode(output_ids, skip_special_tokens=True)[0]
84
- return text
85
- except Exception:
86
- pass
87
- return str(output_ids)
88
-
89
- # ---------------------------
90
- # GPU OCR function
91
- # ---------------------------
92
- from spaces import GPU
93
-
94
- @GPU
95
- def ocr_image_gpu(image: Image.Image, model_choice: str, query: str = None,
96
- max_new_tokens: int = MAX_NEW_TOKENS_DEFAULT, temperature: float = 0.1,
97
- top_p: float = 1.0, top_k: int = 0, repetition_penalty: float = 1.0):
98
- if image is None:
99
- return "Please upload an image."
100
-
101
- if model_choice not in MODEL_PATHS:
102
- return f"Invalid model: {model_choice}"
103
-
104
- processor, model = load_model(model_choice)
105
- prompt = _default_prompt(query)
106
- batch = _build_inputs_plain(processor, image, prompt)
107
 
108
  with torch.inference_mode():
109
  output_ids = model.generate(
@@ -116,11 +82,17 @@ def ocr_image_gpu(image: Image.Image, model_choice: str, query: str = None,
116
  repetition_penalty=repetition_penalty,
117
  )
118
 
119
- decoded = _decode_text(model, processor, None, output_ids)
120
- return decoded.replace("<|im_end|>", "").strip()
 
 
 
 
 
 
121
 
122
  # ---------------------------
123
- # Export functions
124
  # ---------------------------
125
  def _safe_text(text: str) -> str:
126
  return (text or "").strip()
@@ -153,15 +125,10 @@ def save_as_audio(text):
153
  text = _safe_text(text)
154
  if not text:
155
  return None
156
- try:
157
- from gTTS import gTTS
158
- filepath = "output.mp3"
159
- tts = gTTS(text)
160
- tts.save(filepath)
161
- return filepath
162
- except Exception as e:
163
- print(f"gTTS failed: {e}")
164
- return None
165
 
166
  # ---------------------------
167
  # Gradio UI
@@ -180,46 +147,45 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
180
  placeholder="Leave empty for RAW structured output",
181
  )
182
 
183
- image_input = gr.Image(type="pil", label="Upload Image")
184
-
185
- # Advanced Options
186
- with gr.Accordion("⚙️ Advanced Options", open=False):
187
- max_new_tokens = gr.Slider(1, 2048, value=MAX_NEW_TOKENS_DEFAULT, step=1, label="Max new tokens")
188
- temperature = gr.Slider(0.1, 2.0, value=0.1, step=0.05, label="Temperature")
189
- top_p = gr.Slider(0.05, 1.0, value=1.0, step=0.05, label="Top-p (nucleus)")
190
- top_k = gr.Slider(0, 1000, value=0, step=1, label="Top-k")
191
- repetition_penalty = gr.Slider(0.8, 2.0, value=1.0, step=0.05, label="Repetition penalty")
192
 
193
- # Extract Button ABOVE output
194
  extract_btn = gr.Button("📤 Extract RAW Text", variant="primary")
195
-
196
  raw_output = gr.Textbox(
197
  label="📜 RAW Structured Output (exact as written)",
198
  lines=18,
199
  show_copy_button=True,
200
  )
 
 
 
201
 
202
- pdf_btn = gr.Button("⬇️ Download as PDF")
203
- word_btn = gr.Button("⬇️ Download as Word")
204
- audio_btn = gr.Button("🔊 Download as Audio")
205
-
206
- # ---------------------------
207
- # Button Callbacks
208
- # ---------------------------
209
- def on_extract(uploaded, model, query, max_tokens, temp, top_p, top_k, rep):
210
- return ocr_image_gpu(uploaded, model, query, max_tokens, temp, top_p, top_k, rep)
211
 
 
212
  extract_btn.click(
213
- fn=on_extract,
214
  inputs=[image_input, model_choice, query_input,
215
  max_new_tokens, temperature, top_p, top_k, repetition_penalty],
216
  outputs=[raw_output]
217
  )
218
 
219
- pdf_btn.click(fn=save_as_pdf, inputs=[raw_output], outputs=[pdf_btn])
220
- word_btn.click(fn=save_as_word, inputs=[raw_output], outputs=[word_btn])
221
- audio_btn.click(fn=save_as_audio, inputs=[raw_output], outputs=[audio_btn])
 
 
 
 
 
 
222
 
 
223
  clear_btn = gr.Button("🧹 Clear")
224
  clear_btn.click(
225
  fn=lambda: ("", None, "", MAX_NEW_TOKENS_DEFAULT, 0.1, 1.0, 0, 1.0),
 
 
 
1
  import os
 
2
  import gradio as gr
3
  from PIL import Image
4
  import torch
 
6
  from reportlab.platypus import SimpleDocTemplate, Paragraph
7
  from reportlab.lib.styles import getSampleStyleSheet
8
  from docx import Document
9
+ from gtts import gTTS
10
 
11
+ # ---------------------------
12
+ # Device & constants
13
+ # ---------------------------
14
+ DEVICE = "cuda" # Force GPU usage
15
  MAX_NEW_TOKENS_DEFAULT = 512
 
16
 
17
  # ---------------------------
18
  # Models config
 
45
  return processor, model
46
 
47
  # ---------------------------
48
+ # OCR function (GPU ready)
49
  # ---------------------------
50
+ @gr.utils.space_decorator # Spaces decorator to detect GPU
51
+ def ocr_image_gpu(image: Image.Image, model_choice: str, query: str = None,
52
+ max_new_tokens: int = MAX_NEW_TOKENS_DEFAULT, temperature: float = 0.1,
53
+ top_p: float = 1.0, top_k: int = 0, repetition_penalty: float = 1.0):
54
+ if image is None:
55
+ return "Please upload an image."
56
+ if model_choice not in MODEL_PATHS:
57
+ return f"Invalid model: {model_choice}"
58
+
59
+ processor, model = load_model(model_choice)
60
+
61
+ prompt = query.strip() if query and query.strip() else (
62
  "You are a professional Handwritten OCR system.\n"
63
  "TASK: Read the handwritten image and transcribe the text EXACTLY as written.\n"
64
  "- Preserve original structure and line breaks.\n"
 
69
  "Return RAW transcription only."
70
  )
71
 
72
+ batch = processor(text=[prompt], images=[image], return_tensors="pt").to(DEVICE)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
 
74
  with torch.inference_mode():
75
  output_ids = model.generate(
 
82
  repetition_penalty=repetition_penalty,
83
  )
84
 
85
+ # decode safely
86
+ text = ""
87
+ if hasattr(processor, "batch_decode"):
88
+ text = processor.batch_decode(output_ids, skip_special_tokens=True)[0]
89
+ elif hasattr(model, "tokenizer") and model.tokenizer is not None:
90
+ text = model.tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
91
+ text = text.replace("<|im_end|>", "").strip()
92
+ return text
93
 
94
  # ---------------------------
95
+ # Export helpers
96
  # ---------------------------
97
  def _safe_text(text: str) -> str:
98
  return (text or "").strip()
 
125
  text = _safe_text(text)
126
  if not text:
127
  return None
128
+ filepath = "output.mp3"
129
+ tts = gTTS(text)
130
+ tts.save(filepath)
131
+ return filepath
 
 
 
 
 
132
 
133
  # ---------------------------
134
  # Gradio UI
 
147
  placeholder="Leave empty for RAW structured output",
148
  )
149
 
150
+ image_input = gr.Image(type="pil", label="Upload Image (desktop/mobile)")
 
 
 
 
 
 
 
 
151
 
152
+ # Buttons first
153
  extract_btn = gr.Button("📤 Extract RAW Text", variant="primary")
 
154
  raw_output = gr.Textbox(
155
  label="📜 RAW Structured Output (exact as written)",
156
  lines=18,
157
  show_copy_button=True,
158
  )
159
+ pdf_file = gr.File(label="PDF File")
160
+ word_file = gr.File(label="Word File")
161
+ audio_file = gr.File(label="Audio File")
162
 
163
+ with gr.Accordion("⚙️ Advanced Options", open=False):
164
+ max_new_tokens = gr.Slider(1, 2048, value=MAX_NEW_TOKENS_DEFAULT, step=1, label="Max new tokens")
165
+ temperature = gr.Slider(0.1, 2.0, value=0.1, step=0.05, label="Temperature")
166
+ top_p = gr.Slider(0.05, 1.0, value=1.0, step=0.05, label="Top-p (nucleus)")
167
+ top_k = gr.Slider(0, 1000, value=0, step=1, label="Top-k")
168
+ repetition_penalty = gr.Slider(0.8, 2.0, value=1.0, step=0.05, label="Repetition penalty")
 
 
 
169
 
170
+ # Extract text
171
  extract_btn.click(
172
+ fn=ocr_image_gpu,
173
  inputs=[image_input, model_choice, query_input,
174
  max_new_tokens, temperature, top_p, top_k, repetition_penalty],
175
  outputs=[raw_output]
176
  )
177
 
178
+ # Export buttons
179
+ pdf_btn = gr.Button("⬇️ Download as PDF")
180
+ pdf_btn.click(fn=save_as_pdf, inputs=[raw_output], outputs=[pdf_file])
181
+
182
+ word_btn = gr.Button("⬇️ Download as Word")
183
+ word_btn.click(fn=save_as_word, inputs=[raw_output], outputs=[word_file])
184
+
185
+ audio_btn = gr.Button("🔊 Download as Audio")
186
+ audio_btn.click(fn=save_as_audio, inputs=[raw_output], outputs=[audio_file])
187
 
188
+ # Clear button
189
  clear_btn = gr.Button("🧹 Clear")
190
  clear_btn.click(
191
  fn=lambda: ("", None, "", MAX_NEW_TOKENS_DEFAULT, 0.1, 1.0, 0, 1.0),