Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -1,7 +1,4 @@
|
|
1 |
-
# app.py — HTR Space (GPU-only, no webcam, mobile-ready)
|
2 |
-
|
3 |
import os
|
4 |
-
from threading import Thread
|
5 |
import gradio as gr
|
6 |
from PIL import Image
|
7 |
import torch
|
@@ -9,9 +6,13 @@ from transformers import AutoProcessor, AutoModelForImageTextToText, Qwen2_5_VLF
|
|
9 |
from reportlab.platypus import SimpleDocTemplate, Paragraph
|
10 |
from reportlab.lib.styles import getSampleStyleSheet
|
11 |
from docx import Document
|
|
|
12 |
|
|
|
|
|
|
|
|
|
13 |
MAX_NEW_TOKENS_DEFAULT = 512
|
14 |
-
DEVICE = "cuda" # GPU-only
|
15 |
|
16 |
# ---------------------------
|
17 |
# Models config
|
@@ -44,12 +45,20 @@ def load_model(name):
|
|
44 |
return processor, model
|
45 |
|
46 |
# ---------------------------
|
47 |
-
#
|
48 |
# ---------------------------
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
53 |
"You are a professional Handwritten OCR system.\n"
|
54 |
"TASK: Read the handwritten image and transcribe the text EXACTLY as written.\n"
|
55 |
"- Preserve original structure and line breaks.\n"
|
@@ -60,50 +69,7 @@ def _default_prompt(query: str | None) -> str:
|
|
60 |
"Return RAW transcription only."
|
61 |
)
|
62 |
|
63 |
-
|
64 |
-
return processor(text=[prompt], images=[image], return_tensors="pt").to(DEVICE)
|
65 |
-
|
66 |
-
def _decode_text(model, processor, tokenizer, output_ids):
|
67 |
-
text = ""
|
68 |
-
try:
|
69 |
-
if hasattr(processor, "batch_decode"):
|
70 |
-
text = processor.batch_decode(output_ids, skip_special_tokens=True)[0]
|
71 |
-
return text
|
72 |
-
except Exception:
|
73 |
-
pass
|
74 |
-
try:
|
75 |
-
if tokenizer is not None:
|
76 |
-
text = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
|
77 |
-
return text
|
78 |
-
except Exception:
|
79 |
-
pass
|
80 |
-
try:
|
81 |
-
model_tok = getattr(model, "tokenizer", None)
|
82 |
-
if model_tok is not None:
|
83 |
-
text = model_tok.batch_decode(output_ids, skip_special_tokens=True)[0]
|
84 |
-
return text
|
85 |
-
except Exception:
|
86 |
-
pass
|
87 |
-
return str(output_ids)
|
88 |
-
|
89 |
-
# ---------------------------
|
90 |
-
# GPU OCR function
|
91 |
-
# ---------------------------
|
92 |
-
from spaces import GPU
|
93 |
-
|
94 |
-
@GPU
|
95 |
-
def ocr_image_gpu(image: Image.Image, model_choice: str, query: str = None,
|
96 |
-
max_new_tokens: int = MAX_NEW_TOKENS_DEFAULT, temperature: float = 0.1,
|
97 |
-
top_p: float = 1.0, top_k: int = 0, repetition_penalty: float = 1.0):
|
98 |
-
if image is None:
|
99 |
-
return "Please upload an image."
|
100 |
-
|
101 |
-
if model_choice not in MODEL_PATHS:
|
102 |
-
return f"Invalid model: {model_choice}"
|
103 |
-
|
104 |
-
processor, model = load_model(model_choice)
|
105 |
-
prompt = _default_prompt(query)
|
106 |
-
batch = _build_inputs_plain(processor, image, prompt)
|
107 |
|
108 |
with torch.inference_mode():
|
109 |
output_ids = model.generate(
|
@@ -116,11 +82,17 @@ def ocr_image_gpu(image: Image.Image, model_choice: str, query: str = None,
|
|
116 |
repetition_penalty=repetition_penalty,
|
117 |
)
|
118 |
|
119 |
-
|
120 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
121 |
|
122 |
# ---------------------------
|
123 |
-
# Export
|
124 |
# ---------------------------
|
125 |
def _safe_text(text: str) -> str:
|
126 |
return (text or "").strip()
|
@@ -153,15 +125,10 @@ def save_as_audio(text):
|
|
153 |
text = _safe_text(text)
|
154 |
if not text:
|
155 |
return None
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
tts.save(filepath)
|
161 |
-
return filepath
|
162 |
-
except Exception as e:
|
163 |
-
print(f"gTTS failed: {e}")
|
164 |
-
return None
|
165 |
|
166 |
# ---------------------------
|
167 |
# Gradio UI
|
@@ -180,46 +147,45 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
|
180 |
placeholder="Leave empty for RAW structured output",
|
181 |
)
|
182 |
|
183 |
-
image_input = gr.Image(type="pil", label="Upload Image")
|
184 |
-
|
185 |
-
# Advanced Options
|
186 |
-
with gr.Accordion("⚙️ Advanced Options", open=False):
|
187 |
-
max_new_tokens = gr.Slider(1, 2048, value=MAX_NEW_TOKENS_DEFAULT, step=1, label="Max new tokens")
|
188 |
-
temperature = gr.Slider(0.1, 2.0, value=0.1, step=0.05, label="Temperature")
|
189 |
-
top_p = gr.Slider(0.05, 1.0, value=1.0, step=0.05, label="Top-p (nucleus)")
|
190 |
-
top_k = gr.Slider(0, 1000, value=0, step=1, label="Top-k")
|
191 |
-
repetition_penalty = gr.Slider(0.8, 2.0, value=1.0, step=0.05, label="Repetition penalty")
|
192 |
|
193 |
-
#
|
194 |
extract_btn = gr.Button("📤 Extract RAW Text", variant="primary")
|
195 |
-
|
196 |
raw_output = gr.Textbox(
|
197 |
label="📜 RAW Structured Output (exact as written)",
|
198 |
lines=18,
|
199 |
show_copy_button=True,
|
200 |
)
|
|
|
|
|
|
|
201 |
|
202 |
-
|
203 |
-
|
204 |
-
|
205 |
-
|
206 |
-
|
207 |
-
|
208 |
-
# ---------------------------
|
209 |
-
def on_extract(uploaded, model, query, max_tokens, temp, top_p, top_k, rep):
|
210 |
-
return ocr_image_gpu(uploaded, model, query, max_tokens, temp, top_p, top_k, rep)
|
211 |
|
|
|
212 |
extract_btn.click(
|
213 |
-
fn=
|
214 |
inputs=[image_input, model_choice, query_input,
|
215 |
max_new_tokens, temperature, top_p, top_k, repetition_penalty],
|
216 |
outputs=[raw_output]
|
217 |
)
|
218 |
|
219 |
-
|
220 |
-
|
221 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
222 |
|
|
|
223 |
clear_btn = gr.Button("🧹 Clear")
|
224 |
clear_btn.click(
|
225 |
fn=lambda: ("", None, "", MAX_NEW_TOKENS_DEFAULT, 0.1, 1.0, 0, 1.0),
|
|
|
|
|
|
|
1 |
import os
|
|
|
2 |
import gradio as gr
|
3 |
from PIL import Image
|
4 |
import torch
|
|
|
6 |
from reportlab.platypus import SimpleDocTemplate, Paragraph
|
7 |
from reportlab.lib.styles import getSampleStyleSheet
|
8 |
from docx import Document
|
9 |
+
from gtts import gTTS
|
10 |
|
11 |
+
# ---------------------------
|
12 |
+
# Device & constants
|
13 |
+
# ---------------------------
|
14 |
+
DEVICE = "cuda" # Force GPU usage
|
15 |
MAX_NEW_TOKENS_DEFAULT = 512
|
|
|
16 |
|
17 |
# ---------------------------
|
18 |
# Models config
|
|
|
45 |
return processor, model
|
46 |
|
47 |
# ---------------------------
|
48 |
+
# OCR function (GPU ready)
|
49 |
# ---------------------------
|
50 |
+
@gr.utils.space_decorator # Spaces decorator to detect GPU
|
51 |
+
def ocr_image_gpu(image: Image.Image, model_choice: str, query: str = None,
|
52 |
+
max_new_tokens: int = MAX_NEW_TOKENS_DEFAULT, temperature: float = 0.1,
|
53 |
+
top_p: float = 1.0, top_k: int = 0, repetition_penalty: float = 1.0):
|
54 |
+
if image is None:
|
55 |
+
return "Please upload an image."
|
56 |
+
if model_choice not in MODEL_PATHS:
|
57 |
+
return f"Invalid model: {model_choice}"
|
58 |
+
|
59 |
+
processor, model = load_model(model_choice)
|
60 |
+
|
61 |
+
prompt = query.strip() if query and query.strip() else (
|
62 |
"You are a professional Handwritten OCR system.\n"
|
63 |
"TASK: Read the handwritten image and transcribe the text EXACTLY as written.\n"
|
64 |
"- Preserve original structure and line breaks.\n"
|
|
|
69 |
"Return RAW transcription only."
|
70 |
)
|
71 |
|
72 |
+
batch = processor(text=[prompt], images=[image], return_tensors="pt").to(DEVICE)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
73 |
|
74 |
with torch.inference_mode():
|
75 |
output_ids = model.generate(
|
|
|
82 |
repetition_penalty=repetition_penalty,
|
83 |
)
|
84 |
|
85 |
+
# decode safely
|
86 |
+
text = ""
|
87 |
+
if hasattr(processor, "batch_decode"):
|
88 |
+
text = processor.batch_decode(output_ids, skip_special_tokens=True)[0]
|
89 |
+
elif hasattr(model, "tokenizer") and model.tokenizer is not None:
|
90 |
+
text = model.tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
|
91 |
+
text = text.replace("<|im_end|>", "").strip()
|
92 |
+
return text
|
93 |
|
94 |
# ---------------------------
|
95 |
+
# Export helpers
|
96 |
# ---------------------------
|
97 |
def _safe_text(text: str) -> str:
|
98 |
return (text or "").strip()
|
|
|
125 |
text = _safe_text(text)
|
126 |
if not text:
|
127 |
return None
|
128 |
+
filepath = "output.mp3"
|
129 |
+
tts = gTTS(text)
|
130 |
+
tts.save(filepath)
|
131 |
+
return filepath
|
|
|
|
|
|
|
|
|
|
|
132 |
|
133 |
# ---------------------------
|
134 |
# Gradio UI
|
|
|
147 |
placeholder="Leave empty for RAW structured output",
|
148 |
)
|
149 |
|
150 |
+
image_input = gr.Image(type="pil", label="Upload Image (desktop/mobile)")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
151 |
|
152 |
+
# Buttons first
|
153 |
extract_btn = gr.Button("📤 Extract RAW Text", variant="primary")
|
|
|
154 |
raw_output = gr.Textbox(
|
155 |
label="📜 RAW Structured Output (exact as written)",
|
156 |
lines=18,
|
157 |
show_copy_button=True,
|
158 |
)
|
159 |
+
pdf_file = gr.File(label="PDF File")
|
160 |
+
word_file = gr.File(label="Word File")
|
161 |
+
audio_file = gr.File(label="Audio File")
|
162 |
|
163 |
+
with gr.Accordion("⚙️ Advanced Options", open=False):
|
164 |
+
max_new_tokens = gr.Slider(1, 2048, value=MAX_NEW_TOKENS_DEFAULT, step=1, label="Max new tokens")
|
165 |
+
temperature = gr.Slider(0.1, 2.0, value=0.1, step=0.05, label="Temperature")
|
166 |
+
top_p = gr.Slider(0.05, 1.0, value=1.0, step=0.05, label="Top-p (nucleus)")
|
167 |
+
top_k = gr.Slider(0, 1000, value=0, step=1, label="Top-k")
|
168 |
+
repetition_penalty = gr.Slider(0.8, 2.0, value=1.0, step=0.05, label="Repetition penalty")
|
|
|
|
|
|
|
169 |
|
170 |
+
# Extract text
|
171 |
extract_btn.click(
|
172 |
+
fn=ocr_image_gpu,
|
173 |
inputs=[image_input, model_choice, query_input,
|
174 |
max_new_tokens, temperature, top_p, top_k, repetition_penalty],
|
175 |
outputs=[raw_output]
|
176 |
)
|
177 |
|
178 |
+
# Export buttons
|
179 |
+
pdf_btn = gr.Button("⬇️ Download as PDF")
|
180 |
+
pdf_btn.click(fn=save_as_pdf, inputs=[raw_output], outputs=[pdf_file])
|
181 |
+
|
182 |
+
word_btn = gr.Button("⬇️ Download as Word")
|
183 |
+
word_btn.click(fn=save_as_word, inputs=[raw_output], outputs=[word_file])
|
184 |
+
|
185 |
+
audio_btn = gr.Button("🔊 Download as Audio")
|
186 |
+
audio_btn.click(fn=save_as_audio, inputs=[raw_output], outputs=[audio_file])
|
187 |
|
188 |
+
# Clear button
|
189 |
clear_btn = gr.Button("🧹 Clear")
|
190 |
clear_btn.click(
|
191 |
fn=lambda: ("", None, "", MAX_NEW_TOKENS_DEFAULT, 0.1, 1.0, 0, 1.0),
|