Emeritus-21 committed on
Commit 61e3d24 · verified · 1 Parent(s): 5e48658

Create app.py

Files changed (1): app.py +316 -0
app.py ADDED
@@ -0,0 +1,316 @@
# app.py — HTR Space (full) with downloads (PDF/DOCX/MP3) + webcam support (Gradio 4.x)

import os
import time
from threading import Thread

import gradio as gr
import spaces
from PIL import Image
import torch
from transformers import (
    AutoProcessor,
    AutoModelForImageTextToText,
    Qwen2_5_VLForConditionalGeneration,
    TextIteratorStreamer,
)

# ---------------------------
# Models
# ---------------------------
MODEL_PATHS = {
    "Model 1 (complex handwriting)": (
        "prithivMLmods/Qwen2.5-VL-7B-Abliterated-Caption-it",
        Qwen2_5_VLForConditionalGeneration,
    ),
    "Model 2 (simple and scanned handwriting)": (
        "nanonets/Nanonets-OCR-s",
        Qwen2_5_VLForConditionalGeneration,
    ),
    "Model 3 (structured handwriting)": (
        "Emeritus-21/Finetuned-full-HTR-model",
        AutoModelForImageTextToText,
    ),
}
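# To register another checkpoint, add an entry mapping a UI label to a
# (repo_id, model_class) pair; a hypothetical example:
#   "Model 4 (printed + handwritten)": ("someuser/some-ocr-model", AutoModelForImageTextToText),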
35
+
36
+ MAX_NEW_TOKENS_DEFAULT = 512
37
+ device = "cuda" if torch.cuda.is_available() else "cpu"
38
+
39
+ # ---------------------------
40
+ # Preload models at startup
41
+ # ---------------------------
42
+ _loaded_processors = {}
43
+ _loaded_models = {}
44
+
45
+ print("🚀 Preloading models into GPU/CPU memory...")
46
+
47
+ for name, (repo_id, cls) in MODEL_PATHS.items():
48
+ try:
49
+ print(f"Loading {name} ...")
50
+ processor = AutoProcessor.from_pretrained(repo_id, trust_remote_code=True)
51
+ model = cls.from_pretrained(
52
+ repo_id,
53
+ trust_remote_code=True,
54
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
55
+ low_cpu_mem_usage=True,
56
+ ).to(device).eval()
57
+ _loaded_processors[name] = processor
58
+ _loaded_models[name] = model
59
+ print(f"✅ {name} ready.")
60
+ except Exception as e:
61
+ print(f"⚠️ Failed to load {name}: {e}")
62
+
63
+ # ---------------------------
64
+ # Warmup (GPU)
65
+ # ---------------------------
66
+ @spaces.GPU
67
+ def warmup():
68
+ try:
69
+ default_model_choice = list(MODEL_PATHS.keys())[0]
70
+ processor = _loaded_processors[default_model_choice]
71
+ model = _loaded_models[default_model_choice]
72
+
73
+ tokenizer = getattr(processor, "tokenizer", None)
74
+
75
+ messages = [{"role": "user", "content": [{"type": "text", "text": "Warmup."}]}]
76
+ if tokenizer and hasattr(tokenizer, "apply_chat_template"):
77
+ chat_prompt = tokenizer.apply_chat_template(
78
+ messages, tokenize=False, add_generation_prompt=True
79
+ )
80
+ else:
81
+ chat_prompt = "Warmup."
82
+
83
+ inputs = processor(
84
+ text=[chat_prompt],
85
+ images=None,
86
+ return_tensors="pt"
87
+ ).to(device)
88
+
89
+ with torch.inference_mode():
90
+ _ = model.generate(**inputs, max_new_tokens=1)
91
+
92
+ return f"GPU warm and {default_model_choice} ready."
93
+ except Exception as e:
94
+ return f"Warmup skipped: {e}"
95
+

# ---------------------------
# OCR Function (RAW ONLY)
# ---------------------------
@spaces.GPU
def ocr_image(
    image: Image.Image,
    model_choice: str,
    query: str = "",
    max_new_tokens: int = MAX_NEW_TOKENS_DEFAULT,
    temperature: float = 0.1,
    top_p: float = 1.0,
    top_k: int = 0,
    repetition_penalty: float = 1.0,
):
    if image is None:
        yield "Please upload or capture an image."
        return

    if model_choice not in _loaded_models:
        yield f"Invalid model: {model_choice}"
        return

    processor = _loaded_processors[model_choice]
    model = _loaded_models[model_choice]
    tokenizer = getattr(processor, "tokenizer", None)

    if query and query.strip():
        prompt = query.strip()
    else:
        prompt = (
            "You are a professional Handwritten OCR system.\n"
            "TASK: Read the handwritten image and transcribe the text EXACTLY as written.\n"
            "- Preserve original structure and line breaks.\n"
            "- Keep spacing, bullet points, numbering, and indentation.\n"
            "- Render tables as Markdown tables if present.\n"
            "- Do NOT autocorrect spelling or grammar.\n"
            "- Do NOT merge lines.\n"
            "Return RAW transcription only."
        )

    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": prompt},
            ],
        }
    ]

    # Build chat prompt (prefer the tokenizer's chat template if available)
    if tokenizer and hasattr(tokenizer, "apply_chat_template"):
        chat_prompt = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
    else:
        # Fallback: just use the plain prompt
        chat_prompt = prompt

    # The processor packs both text + image for VLMs
    inputs = processor(
        text=[chat_prompt],
        images=[image],
        return_tensors="pt"
    ).to(device)

    # The streamer needs an object with a .decode() method; fall back to the
    # processor itself, which typically delegates decoding to its tokenizer.
    streamer = TextIteratorStreamer(
        tokenizer if tokenizer is not None else processor,
        skip_prompt=True,
        skip_special_tokens=True,
    )

    # With do_sample=False decoding is greedy; temperature/top_p/top_k only
    # take effect if sampling is enabled.
    generation_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=False,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
    )

    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    buffer = ""
    for new_text in streamer:
        new_text = new_text.replace("<|im_end|>", "")
        buffer += new_text
        # Small sleep to smooth streaming
        time.sleep(0.01)
        yield buffer
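
# Illustrative sketch (not part of the original UI wiring): consuming the
# streaming generator directly, e.g. for a quick local test. The image path
# "sample.png" is a hypothetical placeholder.
def _demo_transcribe(path: str = "sample.png") -> str:
    img = Image.open(path).convert("RGB")
    final = ""
    for partial in ocr_image(img, list(MODEL_PATHS.keys())[0]):
        final = partial  # each yield carries the full transcript so far
    return final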

# ---------------------------
# Export Helpers
# ---------------------------
from reportlab.platypus import SimpleDocTemplate, Paragraph
from reportlab.lib.styles import getSampleStyleSheet
from docx import Document
from gtts import gTTS
from xml.sax.saxutils import escape

def _safe_text(text: str) -> str:
    return (text or "").strip()

def save_as_pdf(text):
    text = _safe_text(text)
    if not text:
        return None
    filepath = "output.pdf"
    doc = SimpleDocTemplate(filepath)
    styles = getSampleStyleSheet()
    # Escape XML-sensitive characters; reportlab Paragraphs parse inline markup.
    flowables = [Paragraph(escape(t), styles["Normal"]) for t in text.splitlines() if t != ""]
    if not flowables:
        flowables = [Paragraph(" ", styles["Normal"])]
    doc.build(flowables)
    return filepath

def save_as_word(text):
    text = _safe_text(text)
    if not text:
        return None
    filepath = "output.docx"
    doc = Document()
    for line in text.splitlines():
        doc.add_paragraph(line)
    doc.save(filepath)
    return filepath

def save_as_audio(text):
    text = _safe_text(text)
    if not text:
        return None
    filepath = "output.mp3"
    # NOTE: gTTS uses an online service; Spaces must have outbound internet enabled.
    tts = gTTS(text)
    tts.save(filepath)
    return filepath
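
# Illustrative sketch: exercising the export helpers outside the UI with a
# hypothetical transcript string.
def _demo_export(text: str = "Hello from the HTR demo."):
    return save_as_pdf(text), save_as_word(text), save_as_audio(text)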

# ---------------------------
# Gradio Interface
# ---------------------------
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("## ✍🏾 Wilson Handwritten OCR")

    model_choice = gr.Radio(
        choices=list(MODEL_PATHS.keys()),
        value=list(MODEL_PATHS.keys())[0],
        label="Select OCR Model",
    )

    with gr.Tab("🖼 Image Inference"):
        query_input = gr.Textbox(
            label="Custom Prompt (optional)",
            placeholder="Leave empty for RAW structured output",
        )

        # Gradio 4.x: use `sources` instead of the deprecated `source`/`tool`.
        # This enables both upload and webcam capture. On mobile, users can
        # switch front/back camera via the browser UI (programmatically forcing
        # the back camera isn't supported across all browsers).
        image_input = gr.Image(
            type="pil",
            label="Upload / Capture Handwritten Image",
            sources=["upload", "webcam"],
        )

        with gr.Accordion("⚙️ Advanced Options", open=False):
            max_new_tokens = gr.Slider(1, 2048, value=MAX_NEW_TOKENS_DEFAULT, step=1, label="Max new tokens")
            temperature = gr.Slider(0.1, 2.0, value=0.1, step=0.05, label="Temperature")
            top_p = gr.Slider(0.05, 1.0, value=1.0, step=0.05, label="Top-p (nucleus)")
            top_k = gr.Slider(0, 1000, value=0, step=1, label="Top-k")
            repetition_penalty = gr.Slider(0.8, 2.0, value=1.0, step=0.05, label="Repetition penalty")

        with gr.Row():
            extract_btn = gr.Button("📤 Extract RAW Text", variant="primary")
            clear_btn = gr.Button("🧹 Clear")

        raw_output = gr.Textbox(
            label="📜 RAW Structured Output (exact as written)",
            lines=18,
            show_copy_button=True,
        )

        with gr.Row():
            pdf_btn = gr.Button("⬇️ Download as PDF")
            word_btn = gr.Button("⬇️ Download as Word")
            audio_btn = gr.Button("🔊 Download as Audio")

        pdf_file = gr.File(label="PDF File")
        word_file = gr.File(label="Word File")
        audio_file = gr.File(label="Audio File")

        extract_btn.click(
            fn=ocr_image,
            inputs=[
                image_input,
                model_choice,
                query_input,
                max_new_tokens,
                temperature,
                top_p,
                top_k,
                repetition_penalty,
            ],
            outputs=[raw_output],
            api_name="ocr_image",
        )

        pdf_btn.click(fn=save_as_pdf, inputs=[raw_output], outputs=[pdf_file])
        word_btn.click(fn=save_as_word, inputs=[raw_output], outputs=[word_file])
        audio_btn.click(fn=save_as_audio, inputs=[raw_output], outputs=[audio_file])

        clear_btn.click(
            fn=lambda: ("", None, "", MAX_NEW_TOKENS_DEFAULT, 0.1, 1.0, 0, 1.0),
            outputs=[raw_output, image_input, query_input, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
        )

if __name__ == "__main__":
    # The queue helps with GPU models; SSR off avoids hydration mismatches on Spaces.
    demo.queue(max_size=50).launch(share=True, ssr_mode=False, show_error=True)
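
Because the extract button is registered with api_name="ocr_image", the endpoint can also be called from a separate client script with gradio_client. A minimal sketch, assuming a hypothetical Space id and a local sample.png:

    from gradio_client import Client, handle_file

    client = Client("Emeritus-21/htr-space")  # hypothetical Space id
    text = client.predict(
        handle_file("sample.png"),           # image
        "Model 3 (structured handwriting)",  # model_choice
        "",                                  # custom prompt (empty → default RAW prompt)
        512, 0.1, 1.0, 0, 1.0,               # max_new_tokens, temperature, top_p, top_k, repetition_penalty
        api_name="/ocr_image",
    )
    print(text)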