Emeritus-21 committed
Commit 5754029 · verified · 1 Parent(s): dd7c4ed

Update app.py

Files changed (1)
  1. app.py +174 -31
app.py CHANGED
@@ -3,16 +3,21 @@ from threading import Thread
import gradio as gr
import spaces
from PIL import Image
- import numpy as np
- import cv2
import torch
- from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
+ from transformers import AutoProcessor, AutoModelForImageTextToText, Qwen2_5_VLForConditionalGeneration
+ from reportlab.platypus import SimpleDocTemplate, Paragraph
+ from reportlab.lib.styles import getSampleStyleSheet
+ from docx import Document
+ from gtts import gTTS
+ from jiwer import cer

# ---------------- Models ----------------
MODEL_PATHS = {
    "Model 1 (Complex handwrittings )": ("prithivMLmods/Qwen2.5-VL-7B-Abliterated-Caption-it", Qwen2_5_VLForConditionalGeneration),
    "Model 2 (simple and scanned handwritting )": ("nanonets/Nanonets-OCR-s", Qwen2_5_VLForConditionalGeneration),
}
+ # Model 3 has been removed to conserve memory.
+
MAX_NEW_TOKENS_DEFAULT = 512
device = "cuda" if torch.cuda.is_available() else "cpu"
_loaded_processors, _loaded_models = {}, {}
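The import changes above add four runtime dependencies on top of torch and transformers: reportlab (PDF export), python-docx (Word export), gTTS (audio export), and jiwer (the CER metric used by the new evaluation tab below). The package names are inferred from the import statements rather than stated in the commit; as a hedged aside, they could be installed with something like:

pip install reportlab python-docx gTTS jiwer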
@@ -32,47 +37,185 @@ for name, (repo_id, cls) in MODEL_PATHS.items():
    except Exception as e:
        print(f"⚠️ Failed to load {name}: {e}")

- # ---------------- Underline Detection ----------------
- def detect_underlines(image: Image.Image):
-     cv_img = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2GRAY)
-     _, thresh = cv2.threshold(cv_img, 150, 255, cv2.THRESH_BINARY_INV)
-     kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (30, 1))
-     detected_lines = cv2.morphologyEx(thresh, cv2.MORPH_OPEN, kernel, iterations=2)
-     return detected_lines
+ # ---------------- GPU Warmup ----------------
+ @spaces.GPU
+ def warmup(progress=gr.Progress(track_tqdm=True)):
+     try:
+         default_model_choice = next(iter(MODEL_PATHS.keys()))
+         processor = _loaded_processors[default_model_choice]
+         model = _loaded_models[default_model_choice]
+         tokenizer = getattr(processor, "tokenizer", None)
+         messages = [{"role": "user", "content": [{"type": "text", "text": "Warmup."}]}]
+         chat_prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) if tokenizer and hasattr(tokenizer, "apply_chat_template") else "Warmup."
+         inputs = processor(text=[chat_prompt], images=None, return_tensors="pt").to(device)
+         with torch.inference_mode(): _ = model.generate(**inputs, max_new_tokens=1)
+         return f"GPU warm and {default_model_choice} ready."
+     except Exception as e:
+         return f"Warmup skipped: {e}"
+
+ # ---------------- Helpers ----------------
+ def _build_inputs(processor, tokenizer, image: Image.Image, prompt: str):
+     messages = [{"role": "user", "content": [{"type": "image", "image": image}, {"type": "text", "text": prompt}]}]
+     if tokenizer and hasattr(tokenizer, "apply_chat_template"):
+         chat_prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+         return processor(text=[chat_prompt], images=[image], return_tensors="pt")
+     return processor(text=[prompt], images=[image], return_tensors="pt")
+
+ def _decode_text(model, processor, tokenizer, output_ids, prompt: str):
+     try:
+         decoded_text = processor.batch_decode(output_ids, skip_special_tokens=True)[0]
+         prompt_start = decoded_text.find(prompt)
+         if prompt_start != -1:
+             decoded_text = decoded_text[prompt_start + len(prompt):].strip()
+         else:
+             decoded_text = decoded_text.strip()
+         return decoded_text
+     except Exception:
+         try:
+             decoded_text = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
+             prompt_start = decoded_text.find(prompt)
+             if prompt_start != -1:
+                 decoded_text = decoded_text[prompt_start + len(prompt):].strip()
+             return decoded_text
+         except Exception:
+             return str(output_ids).strip()
+
+ def _default_prompt(query: str | None) -> str:
+     if query and query.strip(): return query.strip()
+     return (
+         "You are a professional Handwritten OCR system.\n"
+         "TASK: Read the handwritten image and transcribe the text EXACTLY as written.\n"
+         "- Preserve original structure and line breaks.\n"
+         "- Keep spacing, bullet points, numbering, and indentation.\n"
+         "- Render tables as Markdown tables if present.\n"
+         "- Do NOT autocorrect spelling or grammar.\n"
+         "- Do NOT merge lines.\n"
+         "Return RAW transcription only."
+     )

- # ---------------- OCR + Underline ----------------
+ # ---------------- OCR Function ----------------
@spaces.GPU
- def ocr_with_underlines(image: Image.Image, model_choice: str, query: str = None,
-                         max_new_tokens: int = MAX_NEW_TOKENS_DEFAULT):
-     if image is None: return "Please upload an image."
+ def ocr_image(image: Image.Image, model_choice: str, query: str = None,
+               max_new_tokens: int = MAX_NEW_TOKENS_DEFAULT,
+               temperature: float = 0.1, top_p: float = 1.0, top_k: int = 0, repetition_penalty: float = 1.0,
+               progress=gr.Progress(track_tqdm=True)):
+     if image is None: return "Please upload or capture an image."
    if model_choice not in _loaded_models: return f"Invalid model: {model_choice}"
-     processor, model = _loaded_processors[model_choice], _loaded_models[model_choice]
-
-     # Run OCR
-     inputs = processor(images=image, text="Transcribe handwriting.", return_tensors="pt").to(device)
+     processor, model, tokenizer = _loaded_processors[model_choice], _loaded_models[model_choice], getattr(_loaded_processors[model_choice], "tokenizer", None)
+     prompt = _default_prompt(query)
+     batch = _build_inputs(processor, tokenizer, image, prompt).to(device)
    with torch.inference_mode():
-         output_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
-     raw_text = processor.batch_decode(output_ids, skip_special_tokens=True)[0]
+         output_ids = model.generate(**batch, max_new_tokens=max_new_tokens, do_sample=False,
+                                     temperature=temperature, top_p=top_p, top_k=top_k, repetition_penalty=repetition_penalty)
+     return _decode_text(model, processor, tokenizer, output_ids, prompt).replace("<|im_end|>", "").strip()
+
+ # ---------------- Export Helpers ----------------
+ def _safe_text(text: str) -> str: return (text or "").strip()
+
+ def save_as_pdf(text):
+     text = _safe_text(text)
+     if not text: return None
+     doc = SimpleDocTemplate("output.pdf")
+     flowables = [Paragraph(t, getSampleStyleSheet()["Normal"]) for t in text.splitlines() if t != ""]
+     if not flowables: flowables = [Paragraph(" ", getSampleStyleSheet()["Normal"])]
+     doc.build(flowables)
+     return "output.pdf"
+
+ def save_as_word(text):
+     text = _safe_text(text)
+     if not text: return None
+     doc = Document()
+     for line in text.splitlines(): doc.add_paragraph(line)
+     doc.save("output.docx")
+     return "output.docx"

-     # Run CV underline detection
-     underline_mask = detect_underlines(image)
-     if np.sum(underline_mask) > 5000:
-         raw_text = f"<u>{raw_text}</u>"
+ def save_as_audio(text):
+     text = _safe_text(text)
+     if not text: return None
+     try:
+         tts = gTTS(text)
+         tts.save("output.mp3")
+         return "output.mp3"
+     except Exception as e:
+         print(f"gTTS failed: {e}")
+         return None

-     return raw_text.strip()
+ # ---------------- Metrics Function ----------------
+ def calculate_cer_score(ground_truth: str, prediction: str) -> str:
+     """
+     Calculates the Character Error Rate (CER) between two strings.
+     A CER of 0.0 means the prediction is perfect.
+     """
+     if not ground_truth or not prediction:
+         return "Cannot calculate CER: Missing ground truth or prediction."
+
+     ground_truth_cleaned = " ".join(ground_truth.strip().split())
+     prediction_cleaned = " ".join(prediction.strip().split())
+
+     error_rate = cer(ground_truth_cleaned, prediction_cleaned)
+     return f"Character Error Rate (CER): {error_rate:.4f}"

- # ---------------- Gradio UI ----------------
+ # ---------------- Evaluation Orchestration ----------------
+ @spaces.GPU
+ def perform_evaluation(image: Image.Image, model_name: str, ground_truth: str,
+                        max_new_tokens: int, temperature: float, top_p: float, top_k: int, repetition_penalty: float):
+     if image is None or not ground_truth:
+         return "Please upload an image and provide the ground truth.", "N/A"
+
+     prediction = ocr_image(image, model_name, max_new_tokens=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k, repetition_penalty=repetition_penalty)
+
+     cer_score = calculate_cer_score(ground_truth, prediction)
+
+     return prediction, cer_score
+
+ # ---------------- Gradio Interface ----------------
with gr.Blocks(theme=gr.themes.Soft()) as demo:
-     gr.Markdown("## ✍🏾 Wilson OCR (OpenCV underline mode)")
+     gr.Markdown("## ✍🏾 wilson Handwritten OCR")
    model_choice = gr.Radio(choices=list(MODEL_PATHS.keys()), value=list(MODEL_PATHS.keys())[0], label="Select OCR Model")

    with gr.Tab("🖼 Image Inference"):
-         query_input = gr.Textbox(label="Custom Prompt")
+         query_input = gr.Textbox(label="Custom Prompt (optional)", placeholder="Leave empty for RAW structured output")
        image_input = gr.Image(type="pil", label="Upload / Capture Handwritten Image", sources=["upload", "webcam"])
+         with gr.Accordion("⚙️ Advanced Options", open=False):
+             max_new_tokens = gr.Slider(1, 2048, value=MAX_NEW_TOKENS_DEFAULT, step=1, label="Max new tokens")
+             temperature = gr.Slider(0.1, 2.0, value=0.1, step=0.05, label="Temperature")
+             top_p = gr.Slider(0.05, 1.0, value=1.0, step=0.05, label="Top-p (nucleus)")
+             top_k = gr.Slider(0, 1000, value=0, step=1, label="Top-k")
+             repetition_penalty = gr.Slider(0.8, 2.0, value=1.0, step=0.05, label="Repetition penalty")
        extract_btn = gr.Button("📤 Extract RAW Text", variant="primary")
-         raw_output = gr.Textbox(label="📜 RAW Structured Output", lines=18, show_copy_button=True)
+         clear_btn = gr.Button("🧹 Clear")
+         raw_output = gr.Textbox(label="📜 RAW Structured Output (exact as written)", lines=18, show_copy_button=True)
+         pdf_btn = gr.Button("⬇️ Download as PDF")
+         word_btn = gr.Button("⬇️ Download as Word")
+         audio_btn = gr.Button("🔊 Download as Audio")
+         pdf_file, word_file, audio_file = gr.File(label="PDF File"), gr.File(label="Word File"), gr.File(label="Audio File")
+
+         extract_btn.click(fn=ocr_image, inputs=[image_input, model_choice, query_input, max_new_tokens, temperature, top_p, top_k, repetition_penalty], outputs=[raw_output], api_name="ocr_image")
+         pdf_btn.click(fn=save_as_pdf, inputs=[raw_output], outputs=[pdf_file])
+         word_btn.click(fn=save_as_word, inputs=[raw_output], outputs=[word_file])
+         audio_btn.click(fn=save_as_audio, inputs=[raw_output], outputs=[audio_file])
+         clear_btn.click(fn=lambda: ("", None, "", MAX_NEW_TOKENS_DEFAULT, 0.1, 1.0, 0, 1.0), outputs=[raw_output, image_input, query_input, max_new_tokens, temperature, top_p, top_k, repetition_penalty])

-         extract_btn.click(fn=ocr_with_underlines, inputs=[image_input, model_choice, query_input], outputs=[raw_output])
+     with gr.Tab("📊 Model Evaluation"):
+         gr.Markdown("### 🔍 Evaluate Model Accuracy")
+         eval_image_input = gr.Image(type="pil", label="Upload Image for Evaluation", sources=["upload"])
+         eval_ground_truth = gr.Textbox(label="Ground Truth (Correct Transcription)", lines=10, placeholder="Type or paste the correct text here.")
+         eval_model_output = gr.Textbox(label="Model's Prediction", lines=10, interactive=False, show_copy_button=True)
+         eval_cer_output = gr.Textbox(label="Metrics", interactive=False)
+
+         with gr.Row():
+             run_evaluation_btn = gr.Button("🚀 Run OCR and Evaluate", variant="primary")
+             clear_evaluation_btn = gr.Button("🧹 Clear")
+
+         run_evaluation_btn.click(
+             fn=perform_evaluation,
+             inputs=[eval_image_input, model_choice, eval_ground_truth, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
+             outputs=[eval_model_output, eval_cer_output]
+         )
+         clear_evaluation_btn.click(
+             fn=lambda: (None, "", "", ""),
+             outputs=[eval_image_input, eval_ground_truth, eval_model_output, eval_cer_output]
+         )

if __name__ == "__main__":
-     demo.queue().launch(share=True)
+     demo.queue(max_size=50).launch(share=True)
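For context on the new evaluation path, here is a minimal standalone sketch of the Character Error Rate check that perform_evaluation delegates to calculate_cer_score. It uses the same jiwer.cer call and whitespace normalization as the committed code; the sample strings are made up purely for illustration.

from jiwer import cer

ground_truth = "Dear Sir, I hope this letter finds you well."
prediction = "Dear Sir I hope this leter finds you well."

# Collapse runs of whitespace, mirroring calculate_cer_score's cleanup step.
ground_truth_cleaned = " ".join(ground_truth.strip().split())
prediction_cleaned = " ".join(prediction.strip().split())

# cer() returns the fraction of character-level edits needed; 0.0 means a perfect transcription.
error_rate = cer(ground_truth_cleaned, prediction_cleaned)
print(f"Character Error Rate (CER): {error_rate:.4f}")

jiwer computes CER as the character-level edit distance divided by the length of the reference text, so lower is better; the new "Model Evaluation" tab simply formats this value into its Metrics textbox.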