Emeritus-21 committed
Commit 21f219b · verified · 1 Parent(s): 6d769e8

Update app.py

Files changed (1):
  1. app.py +148 -517
app.py CHANGED
@@ -1,12 +1,10 @@
-# app.py — HTR Space with Feedback Loop, Memory Post-Correction, and GRPO Export

-import os, time, json, hashlib, difflib, uuid, csv
 from datetime import datetime
-from collections import Counter, defaultdict
 from threading import Thread
-
 import gradio as gr
-import spaces
 from PIL import Image
 import torch
 from transformers import AutoProcessor, AutoModelForImageTextToText, Qwen2_5_VLForConditionalGeneration
@@ -15,575 +13,208 @@ from reportlab.lib.styles import getSampleStyleSheet
 from docx import Document
 from gtts import gTTS
 from jiwer import cer
-# ---------------- Storage & Paths ----------------
 os.makedirs("data", exist_ok=True)
-FEEDBACK_PATH = "data/feedback.jsonl"         # raw feedback log (per sample)
-MEMORY_RULES_PATH = "data/memory_rules.json"  # compiled post-correction rules
-GRPO_EXPORT_PATH = "data/grpo_prefs.jsonl"    # preference pairs for GRPO
-CSV_EXPORT_PATH = "data/feedback.csv"         # optional tabular export

 # ---------------- Models ----------------
 MODEL_PATHS = {
     "Model 1 (Complex handwriting)": ("prithivMLmods/Qwen2.5-VL-7B-Abliterated-Caption-it", Qwen2_5_VLForConditionalGeneration),
-    "Model 2 (Simple and scanned handwriting)": ("nanonets/Nanonets-OCR-s", Qwen2_5_VLForConditionalGeneration),
 }

-
-MAX_NEW_TOKENS_DEFAULT = 512
 device = "cuda" if torch.cuda.is_available() else "cpu"
 _loaded_processors, _loaded_models = {}, {}

-print("🚀 Preloading models into GPU/CPU memory...")
 for name, (repo_id, cls) in MODEL_PATHS.items():
-    try:
-        processor = AutoProcessor.from_pretrained(repo_id, trust_remote_code=True)
-        model = cls.from_pretrained(
-            repo_id,
-            trust_remote_code=True,
-            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
-            low_cpu_mem_usage=True
-        ).to(device).eval()
-        _loaded_processors[name], _loaded_models[name] = processor, model
-        print(f"✅ {name} ready.")
-    except Exception as e:
-        print(f"⚠️ Failed to load {name}: {e}")
-
-# ---------------- GPU Warmup ----------------
-@spaces.GPU
-def warmup(progress=gr.Progress(track_tqdm=True)):
-    try:
-        default_model_choice = next(iter(MODEL_PATHS.keys()))
-        processor = _loaded_processors[default_model_choice]
-        model = _loaded_models[default_model_choice]
-        tokenizer = getattr(processor, "tokenizer", None)
-        messages = [{"role": "user", "content": [{"type": "text", "text": "Warmup."}]}]
-        chat_prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) if tokenizer and hasattr(tokenizer, "apply_chat_template") else "Warmup."
-        inputs = processor(text=[chat_prompt], images=None, return_tensors="pt").to(device)
-        with torch.inference_mode():
-            _ = model.generate(**inputs, max_new_tokens=1)
-        return f"GPU warm and {default_model_choice} ready."
-    except Exception as e:
-        return f"Warmup skipped: {e}"

-# ---------------- Helpers ----------------
-def _build_inputs(processor, tokenizer, image: Image.Image, prompt: str):
-    messages = [{"role": "user", "content": [{"type": "image", "image": image}, {"type": "text", "text": prompt}]}]
-    if tokenizer and hasattr(tokenizer, "apply_chat_template"):
-        chat_prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-        return processor(text=[chat_prompt], images=[image], return_tensors="pt")
-    return processor(text=[prompt], images=[image], return_tensors="pt")

-def _decode_text(model, processor, tokenizer, output_ids, prompt: str):
-    try:
-        decoded_text = processor.batch_decode(output_ids, skip_special_tokens=True)[0]
-        prompt_start = decoded_text.find(prompt)
-        if prompt_start != -1:
-            decoded_text = decoded_text[prompt_start + len(prompt):].strip()
-        else:
-            decoded_text = decoded_text.strip()
-        return decoded_text
-    except Exception:
-        try:
-            decoded_text = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
-            prompt_start = decoded_text.find(prompt)
-            if prompt_start != -1:
-                decoded_text = decoded_text[prompt_start + len(prompt):].strip()
-            return decoded_text
-        except Exception:
-            return str(output_ids).strip()

-def _default_prompt(query: str | None) -> str:
-    if query and query.strip():
-        return query.strip()
-    return (
-        "You are a professional Handwritten OCR system.\n"
-        "TASK: Read the handwritten image and transcribe the text EXACTLY as written.\n"
-        "- Preserve original structure and line breaks.\n"
-        "- Keep spacing, bullet points, numbering, and indentation.\n"
-        "- Render tables as Markdown tables if present.\n"
-        "- Do NOT autocorrect spelling or grammar.\n"
-        "- Do NOT merge lines.\n"
-        "Return RAW transcription only."
-    )

 def _safe_text(text: str) -> str:
     return (text or "").strip()

-def _hash_image(image: Image.Image) -> str:
-    # stable hash for dedup / linking feedback to the same page
-    img_bytes = image.tobytes()
-    return hashlib.sha256(img_bytes).hexdigest()
-# ---------------- Memory: Post-correction Rules ----------------
-def _load_memory_rules():
-    if os.path.exists(MEMORY_RULES_PATH):
-        try:
-            with open(MEMORY_RULES_PATH, "r", encoding="utf-8") as f:
-                return json.load(f)
-        except Exception:
-            pass
-    return {"global": {}, "by_model": {}}

-def _save_memory_rules(rules):
-    with open(MEMORY_RULES_PATH, "w", encoding="utf-8") as f:
-        json.dump(rules, f, ensure_ascii=False, indent=2)

-def _apply_memory(text: str, model_choice: str, enabled: bool):
-    if not enabled or not text:
-        return text
-    rules = _load_memory_rules()
-    # 1) Model-specific replacements
-    by_model = rules.get("by_model", {}).get(model_choice, {})
-    for wrong, right in by_model.items():
-        if wrong and right:
-            text = text.replace(wrong, right)
-    # 2) Global replacements
-    for wrong, right in rules.get("global", {}).items():
-        if wrong and right:
-            text = text.replace(wrong, right)
-    return text

-def _compile_rules_from_feedback(min_count: int = 2, max_phrase_len: int = 40):
-    """
-    Build replacement rules by mining feedback pairs (prediction -> correction).
-    We extract phrases that consistently changed, with frequency >= min_count.
-    """
-    changes_counter_global = Counter()
-    changes_counter_by_model = defaultdict(Counter)
-
-    if not os.path.exists(FEEDBACK_PATH):
-        return
-
-    with open(FEEDBACK_PATH, "r", encoding="utf-8") as f:
-        for line in f:
-            try:
-                row = json.loads(line)
-            except Exception:
-                continue
-            if row.get("reward", 0) < 1:  # only learn from thumbs-up or explicit 'accepted_correction'
-                continue
-            pred = _safe_text(row.get("prediction", ""))
-            corr = _safe_text(row.get("correction", "")) or _safe_text(row.get("ground_truth", ""))
-            if not pred or not corr:
-                continue
-            model_choice = row.get("model_choice", "")
-            # Extract ops
-            s = difflib.SequenceMatcher(None, pred, corr)
-            for tag, i1, i2, j1, j2 in s.get_opcodes():
-                if tag in ("replace", "delete", "insert"):
-                    wrong = pred[i1:i2]
-                    right = corr[j1:j2]
-                    # keep short-ish tokens/phrases
-                    if 0 < len(wrong) <= max_phrase_len or 0 < len(right) <= max_phrase_len:
-                        if wrong.strip():
-                            changes_counter_global[(wrong, right)] += 1
-                            if model_choice:
-                                changes_counter_by_model[model_choice][(wrong, right)] += 1
-
-    rules = {"global": {}, "by_model": {}}
-    # Global
-    for (wrong, right), cnt in changes_counter_global.items():
-        if cnt >= min_count and wrong and right and wrong != right:
-            rules["global"][wrong] = right
-    # Per model
-    for model_choice, ctr in changes_counter_by_model.items():
-        rules["by_model"].setdefault(model_choice, {})
-        for (wrong, right), cnt in ctr.items():
-            if cnt >= min_count and wrong and right and wrong != right:
-                rules["by_model"][model_choice][wrong] = right

-    _save_memory_rules(rules)
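
For intuition, here is a small, self-contained illustration of the rule-mining step above (the feedback pairs are toy data): `difflib.SequenceMatcher` aligns a prediction with its correction, and any (wrong, right) span seen at least `min_count` times with both sides non-empty becomes a replacement rule.

```python
import difflib
from collections import Counter

# Two toy feedback pairs where the model misread "T" as "7".
pairs = [("7he cat sat", "The cat sat"),
         ("7he dog ran", "The dog ran")]

changes = Counter()
for pred, corr in pairs:
    s = difflib.SequenceMatcher(None, pred, corr)
    for tag, i1, i2, j1, j2 in s.get_opcodes():
        if tag in ("replace", "delete", "insert"):
            changes[(pred[i1:i2], corr[j1:j2])] += 1

# Keep changes seen at least min_count=2 times, with both sides non-empty.
rules = {w: r for (w, r), c in changes.items() if c >= 2 and w and r and w != r}
print(rules)  # {'7': 'T'}
```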
-# ---------------- OCR Function ----------------
-@spaces.GPU
 def ocr_image(image: Image.Image, model_choice: str, query: str = None,
               max_new_tokens: int = MAX_NEW_TOKENS_DEFAULT,
               temperature: float = 0.1, top_p: float = 1.0, top_k: int = 0, repetition_penalty: float = 1.0,
-              use_memory: bool = True,
-              progress=gr.Progress(track_tqdm=True)):
-    if image is None: return "Please upload or capture an image."
-    if model_choice not in _loaded_models: return f"Invalid model: {model_choice}"
-    processor, model, tokenizer = _loaded_processors[model_choice], _loaded_models[model_choice], getattr(_loaded_processors[model_choice], "tokenizer", None)
     prompt = _default_prompt(query)
-    batch = _build_inputs(processor, tokenizer, image, prompt).to(device)
     with torch.inference_mode():
         output_ids = model.generate(**batch, max_new_tokens=max_new_tokens, do_sample=False,
                                     temperature=temperature, top_p=top_p, top_k=top_k, repetition_penalty=repetition_penalty)
-    raw = _decode_text(model, processor, tokenizer, output_ids, prompt).replace("<|im_end|>", "").strip()
-    # Apply memory post-correction
-    post = _apply_memory(raw, model_choice, use_memory)
-    return post
-# ---------------- Export Helpers ----------------
-def save_as_pdf(text):
     text = _safe_text(text)
     if not text: return None
     doc = SimpleDocTemplate("output.pdf")
-    flowables = [Paragraph(t, getSampleStyleSheet()["Normal"]) for t in text.splitlines() if t != ""]
-    if not flowables: flowables = [Paragraph(" ", getSampleStyleSheet()["Normal"])]
-    doc.build(flowables)
     return "output.pdf"

-def save_as_word(text):
     text = _safe_text(text)
     if not text: return None
     doc = Document()
-    for line in text.splitlines():
-        doc.add_paragraph(line)
     doc.save("output.docx")
     return "output.docx"

-def save_as_audio(text):
     text = _safe_text(text)
     if not text: return None
-    try:
-        tts = gTTS(text)
-        tts.save("output.mp3")
-        return "output.mp3"
-    except Exception as e:
-        print(f"gTTS failed: {e}")
-        return None
-# ---------------- Metrics Function ----------------
-def calculate_cer_score(ground_truth: str, prediction: str) -> str:
-    """
-    Calculates the Character Error Rate (CER).
-    A CER of 0.0 means the prediction is perfect.
-    """
-    if not ground_truth or not prediction:
-        return "Cannot calculate CER: Missing ground truth or prediction."
-    ground_truth_cleaned = " ".join(ground_truth.strip().split())
-    prediction_cleaned = " ".join(prediction.strip().split())
-    error_rate = cer(ground_truth_cleaned, prediction_cleaned)
-    return f"Character Error Rate (CER): {error_rate:.4f}"

-# ---------------- Feedback & Dataset ----------------
-def _append_jsonl(path, obj):
-    with open(path, "a", encoding="utf-8") as f:
-        f.write(json.dumps(obj, ensure_ascii=False) + "\n")

-def _export_csv():
-    # optional: CSV summary for spreadsheet views
-    if not os.path.exists(FEEDBACK_PATH):
-        return CSV_EXPORT_PATH if os.path.exists(CSV_EXPORT_PATH) else None
-    rows = []
-    with open(FEEDBACK_PATH, "r", encoding="utf-8") as f:
-        for line in f:
-            try:
-                rows.append(json.loads(line))
-            except Exception:
-                pass
-    if not rows:
-        return None
-    keys = ["id","timestamp","model_choice","image_sha256","prompt","prediction","correction","ground_truth","reward","cer"]
-    with open(CSV_EXPORT_PATH, "w", newline="", encoding="utf-8") as f:
-        w = csv.DictWriter(f, fieldnames=keys)
-        w.writeheader()
-        for r in rows:
-            flat = {k: r.get(k, "") for k in keys}
-            w.writerow(flat)
-    return CSV_EXPORT_PATH

-def save_feedback(image: Image.Image, model_choice: str, prompt: str,
-                  prediction: str, correction: str, ground_truth: str, reward: int):
-    """
-    reward: 1 = good/accepted, 0 = neutral, -1 = bad
-    """
-    if image is None:
-        return "Please provide the image again to link feedback.", 0
-    if not prediction and not correction and not ground_truth:
-        return "Nothing to save.", 0
-
-    image_hash = _hash_image(image)
-    # best target = correction, else ground_truth, else prediction
-    target = _safe_text(correction) or _safe_text(ground_truth)
-    pred = _safe_text(prediction)
-    cer_score = None
-    if target and pred:
-        try:
-            cer_score = cer(" ".join(target.split()), " ".join(pred.split()))
-        except Exception:
-            cer_score = None
-
-    row = {
-        "id": str(uuid.uuid4()),
-        "timestamp": datetime.utcnow().isoformat(),
-        "model_choice": model_choice or "",
-        "image_sha256": image_hash,
-        "prompt": _safe_text(prompt),
-        "prediction": pred,
-        "correction": _safe_text(correction),
-        "ground_truth": _safe_text(ground_truth),
-        "reward": int(reward),
-        "cer": float(cer_score) if cer_score is not None else None,
-    }
-    _append_jsonl(FEEDBACK_PATH, row)
-    return f"✅ Feedback saved (reward={reward}).", 1
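
For reference, one line of `data/feedback.jsonl` as written by `save_feedback` has this shape (values illustrative, hashes truncated):

```python
# Shape of a single feedback row (illustrative values only).
row = {
    "id": "3f2a9c1e-…",
    "timestamp": "2025-01-01T12:00:00",
    "model_choice": "Model 1 (Complex handwriting)",
    "image_sha256": "9c1e…",
    "prompt": "Transcribe the image exactly.",
    "prediction": "Dear Arnina",   # raw model output
    "correction": "Dear Amina",    # user-supplied fix
    "ground_truth": "",
    "reward": 1,
    "cer": 0.2,                    # cer("Dear Amina", "Dear Arnina")
}
```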
-def compile_memory_rules():
-    _compile_rules_from_feedback(min_count=2, max_phrase_len=60)
-    return "✅ Memory rules recompiled from positive feedback."

-def export_grpo_preferences():
-    """
-    Build preference pairs for GRPO training:
-    - chosen: correction/ground_truth when present
-    - rejected: original prediction
-    """
-    if not os.path.exists(FEEDBACK_PATH):
-        return "No feedback to export."
-    count = 0
-    with open(GRPO_EXPORT_PATH, "w", encoding="utf-8") as out_f:
-        with open(FEEDBACK_PATH, "r", encoding="utf-8") as f:
-            for line in f:
-                try:
-                    row = json.loads(line)
-                except Exception:
-                    continue
-                pred = _safe_text(row.get("prediction", ""))
-                corr = _safe_text(row.get("correction", "")) or _safe_text(row.get("ground_truth", ""))
-                prompt = _safe_text(row.get("prompt", "")) or "Transcribe the image exactly."
-                if corr and pred and corr != pred and row.get("reward", 0) >= 0:
-                    # One preference datapoint
-                    out = {
-                        "prompt": prompt,
-                        "image_sha256": row.get("image_sha256", ""),
-                        "chosen": corr,
-                        "rejected": pred,
-                        "model_choice": row.get("model_choice", "")
-                    }
-                    out_f.write(json.dumps(out, ensure_ascii=False) + "\n")
-                    count += 1
-    return f"✅ Exported {count} GRPO preference pairs to {GRPO_EXPORT_PATH}."

-def export_csv():
-    p = _export_csv()
-    if p:
-        return f"✅ CSV exported: {p}"
-    return "No data to export."
-# ---------------- Evaluation Orchestration ----------------
-@spaces.GPU
-def perform_evaluation(image: Image.Image, model_name: str, ground_truth: str,
-                       max_new_tokens: int, temperature: float, top_p: float, top_k: int, repetition_penalty: float,
-                       use_memory: bool = True):
-    if image is None or not ground_truth:
-        return "Please upload an image and provide the ground truth.", "N/A"
-    prediction = ocr_image(image, model_name, max_new_tokens=max_new_tokens,
-                           temperature=temperature, top_p=top_p, top_k=top_k, repetition_penalty=repetition_penalty,
-                           use_memory=use_memory)
-    cer_score = calculate_cer_score(ground_truth, prediction)
-    return prediction, cer_score
-# ---------------- GRPO Trainer Script Writer ----------------
-TRAINER_SCRIPT = r"""# grpo_train.py — Offline GRPO training with TRL (run separately)
-# pip install trl accelerate peft transformers datasets
-# This script expects data/grpo_prefs.jsonl produced by the app.
-
-import os, json
-from datasets import load_dataset
-from transformers import AutoModelForCausalLM, AutoTokenizer
-from trl import GRPOConfig, GRPOTrainer
-
-MODEL_ID = os.environ.get("BASE_MODEL", "Qwen/Qwen2.5-VL-7B-Instruct")  # change if needed
-OUTPUT_DIR = os.environ.get("OUTPUT_DIR", "grpo_output")
-DATA_PATH = os.environ.get("DATA_PATH", "data/grpo_prefs.jsonl")
-
-# Our jsonl: each line has prompt, chosen, rejected (and image_sha256/model_choice optionally)
-# We'll format as required by TRL: prompt + responses with one preferred
-
-def _jsonl_dataset(jsonl_path):
-    data = []
-    with open(jsonl_path, "r", encoding="utf-8") as f:
-        for line in f:
-            try:
-                row = json.loads(line)
-            except Exception:
-                continue
-            prompt = row.get("prompt", "")
-            chosen = row.get("chosen", "")
-            rejected = row.get("rejected", "")
-            if prompt and chosen and rejected:
-                data.append({"prompt": prompt, "chosen": chosen, "rejected": rejected})
-    return data
-
-def main():
-    data = _jsonl_dataset(DATA_PATH)
-    if not data:
-        print("No GRPO data found.")
-        return
-    # Create a HuggingFace datasets Dataset from memory
-    from datasets import Dataset
-    ds = Dataset.from_list(data)
-
-    tok = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
-    model = AutoModelForCausalLM.from_pretrained(
-        MODEL_ID, trust_remote_code=True, device_map="auto"
-    )
-
-    # Minimal config — tune to your GPU
-    cfg = GRPOConfig(
-        output_dir=OUTPUT_DIR,
-        learning_rate=5e-6,
-        per_device_train_batch_size=1,
-        gradient_accumulation_steps=8,
-        num_train_epochs=1,
-        logging_steps=10,
-        save_steps=200,
-        max_prompt_length=512,
-        max_completion_length=768,
-        bf16=True
-    )
-
-    trainer = GRPOTrainer(
-        model=model,
-        ref_model=None,  # let TRL create a frozen copy internally
-        args=cfg,
-        tokenizer=tok,
-        train_dataset=ds
-    )
-    trainer.train()
-    trainer.save_model(OUTPUT_DIR)
-    print("✅ GRPO training complete. LoRA/weights saved to", OUTPUT_DIR)
-
-if __name__ == "__main__":
-    main()
-"""
-
-def _write_trainer_script():
-    os.makedirs("train", exist_ok=True)
-    path = os.path.join("train", "grpo_train.py")
-    with open(path, "w", encoding="utf-8") as f:
-        f.write(TRAINER_SCRIPT)
-    return path
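
Note: in current TRL releases, `GRPOTrainer` optimizes a policy against reward functions and does not accept (chosen, rejected) pairs; the records this app exports match what `DPOTrainer` consumes. A minimal sketch, assuming a recent `trl` and the same `data/grpo_prefs.jsonl` (model name and output dir are illustrative):

```python
# dpo_train.py: preference-pair training sketch (assumes a recent trl release).
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DPOConfig, DPOTrainer

ds = load_dataset("json", data_files="data/grpo_prefs.jsonl", split="train")
ds = ds.select_columns(["prompt", "chosen", "rejected"])  # drop bookkeeping columns

tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct",
                                             trust_remote_code=True, device_map="auto")

cfg = DPOConfig(output_dir="dpo_output", per_device_train_batch_size=1,
                gradient_accumulation_steps=8, num_train_epochs=1, logging_steps=10)
trainer = DPOTrainer(model=model, args=cfg, train_dataset=ds, processing_class=tok)
trainer.train()
trainer.save_model("dpo_output")
```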
 # ---------------- Gradio Interface ----------------
 with gr.Blocks(theme=gr.themes.Soft()) as demo:
-    gr.Markdown("## ✍🏾 Wilson Handwritten Text Recognition with Feedback Loop")

-    model_choice = gr.Radio(choices=list(MODEL_PATHS.keys()),
-                            value=list(MODEL_PATHS.keys())[0],
-                            label="Select OCR Model")

-    with gr.Tab("🖼 Image Inference"):
-        query_input = gr.Textbox(label="Custom Prompt (optional)", placeholder="Leave empty for RAW structured output")
-        image_input = gr.Image(type="pil", label="Upload / Capture Handwritten Image", sources=["upload", "webcam"])
-        use_memory = gr.Checkbox(value=True, label="Enable Memory Post-correction (auto-fix known mistakes)")

         with gr.Accordion("⚙️ Advanced Options", open=False):
             max_new_tokens = gr.Slider(1, 2048, value=MAX_NEW_TOKENS_DEFAULT, step=1, label="Max new tokens")
             temperature = gr.Slider(0.1, 2.0, value=0.1, step=0.05, label="Temperature")
-            top_p = gr.Slider(0.05, 1.0, value=1.0, step=0.05, label="Top-p (nucleus)")
-            top_k = gr.Slider(0, 1000, value=0, step=1, label="Top-k")
-            repetition_penalty = gr.Slider(0.8, 2.0, value=1.0, step=0.05, label="Repetition penalty")
-
-        extract_btn = gr.Button("📤 Extract RAW Text", variant="primary")
-        clear_btn = gr.Button("🧹 Clear")

-        raw_output = gr.Textbox(label="📜 Output (post-corrected if memory is ON)", lines=18, show_copy_button=True)

-        # Quick Feedback strip
         gr.Markdown("### ✏️ Quick Feedback")
-        correction_box = gr.Textbox(label="Your Correction (optional)", placeholder="Paste your corrected text here; leave empty if the output is perfect.", lines=8)
-        ground_truth_box = gr.Textbox(label="Ground Truth (optional)", placeholder="If you have a reference transcription, paste it here.", lines=6)
-
-        with gr.Row():
-            btn_good = gr.Button("👍 Accept (Save Feedback as Correct)", variant="primary")
-            btn_bad = gr.Button("👎 Bad (Save Feedback as Incorrect)")
-
         feedback_status = gr.Markdown()

-        pdf_btn = gr.Button("⬇️ Download as PDF")
-        word_btn = gr.Button("⬇️ Download as Word")
-        audio_btn = gr.Button("🔊 Download as Audio")
-        pdf_file, word_file, audio_file = gr.File(label="PDF File"), gr.File(label="Word File"), gr.File(label="Audio File")
-
-        extract_btn.click(
-            fn=ocr_image,
-            inputs=[image_input, model_choice, query_input, max_new_tokens, temperature, top_p, top_k, repetition_penalty, use_memory],
-            outputs=[raw_output],
-            api_name="ocr_image"
-        )
-        pdf_btn.click(fn=save_as_pdf, inputs=[raw_output], outputs=[pdf_file])
-        word_btn.click(fn=save_as_word, inputs=[raw_output], outputs=[word_file])
-        audio_btn.click(fn=save_as_audio, inputs=[raw_output], outputs=[audio_file])
-
-        def _clear():
-            return ("", None, "", MAX_NEW_TOKENS_DEFAULT, 0.1, 1.0, 0, 1.0, True, "", "", "",)
-        clear_btn.click(
-            fn=_clear,
-            outputs=[raw_output, image_input, query_input, max_new_tokens, temperature, top_p, top_k, repetition_penalty, use_memory, correction_box, ground_truth_box, feedback_status]
-        )
-
-        # Quick feedback save
-        btn_good.click(
-            fn=lambda img, mc, prmpt, pred, corr, gt: save_feedback(img, mc, prmpt, pred, corr, gt, reward=1),
-            inputs=[image_input, model_choice, query_input, raw_output, correction_box, ground_truth_box],
-            outputs=[feedback_status]
-        )
-        btn_bad.click(
-            fn=lambda img, mc, prmpt, pred, corr, gt: save_feedback(img, mc, prmpt, pred, corr, gt, reward=-1),
-            inputs=[image_input, model_choice, query_input, raw_output, correction_box, ground_truth_box],
-            outputs=[feedback_status]
-        )
-
-    with gr.Tab("📊 Model Evaluation"):
-        gr.Markdown("### 🔍 Evaluate Model Accuracy")
-        eval_image_input = gr.Image(type="pil", label="Upload Image for Evaluation", sources=["upload"])
-        eval_ground_truth = gr.Textbox(label="Ground Truth (Correct Transcription)", lines=10, placeholder="Type or paste the correct text here.")
-        eval_model_output = gr.Textbox(label="Model's Prediction", lines=10, interactive=False, show_copy_button=True)
-        eval_cer_output = gr.Textbox(label="Metrics", interactive=False)
-        eval_use_memory = gr.Checkbox(value=True, label="Enable Memory Post-correction")
-
-        with gr.Row():
-            run_evaluation_btn = gr.Button("🚀 Run OCR and Evaluate", variant="primary")
-            clear_evaluation_btn = gr.Button("🧹 Clear")
-
-        run_evaluation_btn.click(
-            fn=perform_evaluation,
-            inputs=[eval_image_input, model_choice, eval_ground_truth, max_new_tokens, temperature, top_p, top_k, repetition_penalty, eval_use_memory],
-            outputs=[eval_model_output, eval_cer_output]
-        )
-        clear_evaluation_btn.click(
-            fn=lambda: (None, "", "", ""),
-            outputs=[eval_image_input, eval_ground_truth, eval_model_output, eval_cer_output]
-        )
-
-    with gr.Tab("✏️ Feedback & Memory"):
-        gr.Markdown("""
-        **Pipeline**
-        1) Save feedback (👍 / 👎) and add corrections.
-        2) Click **Build/Refresh Memory** to generate auto-fix rules from positive feedback.
-        3) Keep **Enable Memory Post-correction** checked on inference/eval tabs.
-        """)
-        build_mem_btn = gr.Button("🧠 Build/Refresh Memory from Feedback")
-        mem_status = gr.Markdown()
-        build_mem_btn.click(fn=compile_memory_rules, outputs=[mem_status])
-
-        csv_btn = gr.Button("📤 Export Feedback as CSV")
-        csv_status = gr.Markdown()
-        csv_btn.click(fn=export_csv, outputs=[csv_status])
-
-    with gr.Tab("🧪 GRPO / Dataset"):
-        gr.Markdown("""
-        **GRPO Fine-tuning** (run offline or in a training Space):
-        - Click **Export GRPO Preferences** to produce `data/grpo_prefs.jsonl` of (prompt, chosen, rejected).
-        - Click **Write Trainer Script** to create `train/grpo_train.py`.
-        - Then run:
-        ```bash
-        pip install trl accelerate peft transformers datasets
-        python train/grpo_train.py
-        ```
-        Set `BASE_MODEL`/`OUTPUT_DIR` env vars if you like.
-        """)
-        grpo_btn = gr.Button("📦 Export GRPO Preferences")
         grpo_status = gr.Markdown()
-        grpo_btn.click(fn=export_grpo_preferences, outputs=[grpo_status])
-
-        write_script_btn = gr.Button("📝 Write grpo_train.py")
-        write_script_status = gr.Markdown()
-        write_script_btn.click(fn=lambda: f"✅ Trainer script written to `{_write_trainer_script()}`", outputs=[write_script_status])

 if __name__ == "__main__":
     demo.queue(max_size=50).launch(share=True)
 
+# app.py — HTR Space Full Version with RPL, GRPO, Multi-Format Export, Embedding Similarity

+import os, time, json, hashlib, uuid, csv
 from datetime import datetime
 from threading import Thread
+from collections import defaultdict
 import gradio as gr
 from PIL import Image
 import torch
 from transformers import AutoProcessor, AutoModelForImageTextToText, Qwen2_5_VLForConditionalGeneration
 from docx import Document
 from gtts import gTTS
 from jiwer import cer
+import numpy as np
+from sklearn.metrics.pairwise import cosine_similarity

+# ---------------- Paths ----------------
 os.makedirs("data", exist_ok=True)
+FEEDBACK_RPL_PATH = "data/feedback_rpl.jsonl"
+GRPO_PATH = "data/grpo_prefs.jsonl"
+CSV_PATH = "data/feedback_rpl.csv"

 # ---------------- Models ----------------
 MODEL_PATHS = {
     "Model 1 (Complex handwriting)": ("prithivMLmods/Qwen2.5-VL-7B-Abliterated-Caption-it", Qwen2_5_VLForConditionalGeneration),
+    "Model 2 (Simple scanned handwriting)": ("nanonets/Nanonets-OCR-s", Qwen2_5_VLForConditionalGeneration)
 }

 device = "cuda" if torch.cuda.is_available() else "cpu"
 _loaded_processors, _loaded_models = {}, {}

+print("🚀 Loading models...")
 for name, (repo_id, cls) in MODEL_PATHS.items():
+    processor = AutoProcessor.from_pretrained(repo_id, trust_remote_code=True)
+    model = cls.from_pretrained(repo_id, trust_remote_code=True).to(device).eval()
+    _loaded_processors[name], _loaded_models[name] = processor, model
+    print(f"✅ {name} ready.")
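
The new loader drops the fp16/low-memory settings the previous revision used; loading two 7B-class VL models in full precision can exhaust GPU memory. A sketch restoring the old guard (same flags as the removed code):

```python
# Optional: keep the dtype guard from the previous revision to roughly halve GPU memory.
model = cls.from_pretrained(
    repo_id,
    trust_remote_code=True,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    low_cpu_mem_usage=True,
).to(device).eval()
```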
+MAX_NEW_TOKENS_DEFAULT = 512

+# ---------------- Helpers ----------------
+def _hash_image(image: Image.Image) -> str:
+    return hashlib.sha256(image.tobytes()).hexdigest()

 def _safe_text(text: str) -> str:
     return (text or "").strip()

+def _default_prompt(query: str | None) -> str:
+    if query and query.strip(): return query.strip()
+    return ("You are a professional Handwritten OCR system.\n"
+            "TASK: Read the handwritten image and transcribe exactly as written.\n"
+            "- Preserve line breaks, indentation, bullets, numbering.\n"
+            "- Tables as Markdown tables if present.\n"
+            "- Do NOT autocorrect spelling or merge lines.\n"
+            "Return RAW transcription only.")

+def _append_jsonl(path, obj):
+    with open(path, "a", encoding="utf-8") as f:
+        f.write(json.dumps(obj, ensure_ascii=False) + "\n")

+# ---------------- OCR ----------------
 def ocr_image(image: Image.Image, model_choice: str, query: str = None,
               max_new_tokens: int = MAX_NEW_TOKENS_DEFAULT,
               temperature: float = 0.1, top_p: float = 1.0, top_k: int = 0, repetition_penalty: float = 1.0,
+              use_rpl: bool = True):
+    if image is None: return "Upload image first."
+    processor, model = _loaded_processors[model_choice], _loaded_models[model_choice]
     prompt = _default_prompt(query)
+
+    # Build input
+    batch = processor(text=[prompt], images=[image], return_tensors="pt").to(device)
     with torch.inference_mode():
         output_ids = model.generate(**batch, max_new_tokens=max_new_tokens, do_sample=False,
                                     temperature=temperature, top_p=top_p, top_k=top_k, repetition_penalty=repetition_penalty)
+    raw_text = processor.batch_decode(output_ids, skip_special_tokens=True)[0].replace("<|im_end|>", "").strip()

+    # RPL: Apply feedback using embedding similarity
+    if use_rpl and os.path.exists(FEEDBACK_RPL_PATH):
+        try:
+            current_embedding = np.random.rand(768).reshape(1, -1)  # placeholder for real embedding
+            for line in open(FEEDBACK_RPL_PATH, encoding="utf-8"):
+                row = json.loads(line)
+                if row.get("reward", 0) < 1: continue
+                emb = np.array(row.get("embedding", np.random.rand(768))).reshape(1, -1)
+                sim = cosine_similarity(current_embedding, emb)[0][0]
+                if sim > 0.85:
+                    # Fall back to raw_text if the stored row has no correction/ground truth.
+                    raw_text = row.get("correction") or row.get("ground_truth") or raw_text
+                    break
+        except Exception: pass
+    return raw_text
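
As written, `current_embedding` is random noise, so the cosine gate above fires arbitrarily. A minimal sketch of a real text-embedding hook, assuming `sentence-transformers` is installed (the checkpoint name is illustrative); both `save_feedback` and `ocr_image` must call the same encoder for the similarity to mean anything:

```python
# Hypothetical embedding hook to replace the np.random.rand(768) placeholders.
# Assumes: pip install sentence-transformers; checkpoint name is illustrative.
from sentence_transformers import SentenceTransformer

_embedder = SentenceTransformer("all-MiniLM-L6-v2")  # 384-dim text embeddings

def embed_text(text: str) -> list[float]:
    # Normalized so cosine similarity reduces to a dot product.
    return _embedder.encode([text], normalize_embeddings=True)[0].tolist()
```

With this in place, `save_feedback` would store `embed_text(prediction)` instead of `np.random.rand(768).tolist()`, and `ocr_image` would compare `embed_text(raw_text)` against the stored vectors.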
+# ---------------- Feedback ----------------
+def save_feedback(image: Image.Image, model_choice: str, prompt: str,
+                  prediction: str, correction: str, ground_truth: str, reward: int):
+    if image is None: return "Provide image."
+    row = {
+        "id": str(uuid.uuid4()),
+        "timestamp": datetime.utcnow().isoformat(),
+        "model_choice": model_choice,
+        "image_sha256": _hash_image(image),
+        "prompt": _safe_text(prompt),
+        "prediction": _safe_text(prediction),
+        "correction": _safe_text(correction),
+        "ground_truth": _safe_text(ground_truth),
+        "reward": reward,
+        "embedding": np.random.rand(768).tolist()  # placeholder; see embedding note above
+    }
+    _append_jsonl(FEEDBACK_RPL_PATH, row)
+    return f"✅ Feedback saved (reward={reward})."

+def export_csv():
+    if not os.path.exists(FEEDBACK_RPL_PATH): return None
+    keys, rows = None, []
+    for line in open(FEEDBACK_RPL_PATH, encoding="utf-8"):
+        try:
+            row = json.loads(line); rows.append(row); keys = keys or list(row.keys())
+        except Exception:
+            continue
+    if not rows: return None
+    with open(CSV_PATH, "w", newline="", encoding="utf-8") as f:
+        # extrasaction/restval guard against rows whose keys drift over time.
+        writer = csv.DictWriter(f, fieldnames=keys, extrasaction="ignore", restval="")
+        writer.writeheader(); writer.writerows(rows)
+    return CSV_PATH
+# ---------------- Export Formats ----------------
+def save_pdf(text):
     text = _safe_text(text)
     if not text: return None
     doc = SimpleDocTemplate("output.pdf")
+    flowables = [Paragraph(l, getSampleStyleSheet()["Normal"]) for l in text.splitlines() if l.strip()]
+    doc.build(flowables or [Paragraph(" ", getSampleStyleSheet()["Normal"])])
     return "output.pdf"

+def save_word(text):
     text = _safe_text(text)
     if not text: return None
     doc = Document()
+    for l in text.splitlines(): doc.add_paragraph(l)
     doc.save("output.docx")
     return "output.docx"

+def save_audio(text):
     text = _safe_text(text)
     if not text: return None
+    try:
+        gTTS(text).save("output.mp3"); return "output.mp3"
+    except Exception:
+        return None
+def cer_score(gt, pred):
+    if not gt or not pred: return "Missing ground truth or prediction."
+    return f"CER: {cer(gt, pred):.4f}"
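
The previous revision collapsed whitespace before scoring; `jiwer.cer` counts every space and newline as a character, so callers may want to keep that normalization. A small usage sketch:

```python
from jiwer import cer

gt, pred = "The quick  brown\nfox", "The quick brown fox"
norm = lambda s: " ".join(s.split())  # collapse whitespace, as the old calculate_cer_score did
print(f"CER: {cer(norm(gt), norm(pred)):.4f}")  # 0.0000: the strings differ only in whitespace
```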
+# ---------------- GRPO Example ----------------
+def save_grpo(name, pref_dict):
+    row = {"id": str(uuid.uuid4()), "timestamp": datetime.utcnow().isoformat(), "name": name, "prefs": pref_dict}
+    _append_jsonl(GRPO_PATH, row)
+    return f"✅ GRPO saved for {name}"
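
`save_grpo` stores free-form preference dicts. The exporter removed above wrote (prompt, chosen, rejected) records, which is the shape preference-pair trainers expect, so a caller would typically pass something like:

```python
# Hypothetical usage: one preference record in the (prompt, chosen, rejected) shape.
save_grpo("reviewer-1", {
    "prompt": "Transcribe the image exactly.",
    "chosen": "Dear Amina, thank you for your letter.",    # corrected transcription
    "rejected": "Dear Arnina, thank you for your leter.",  # raw model output
})
```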
 # ---------------- Gradio Interface ----------------
 with gr.Blocks(theme=gr.themes.Soft()) as demo:
+    gr.Markdown("## ✍🏾 Handwritten Text Recognition | Full Feedback & Export")

+    model_choice = gr.Radio(list(MODEL_PATHS.keys()), value=list(MODEL_PATHS.keys())[0], label="OCR Model")

+    with gr.Tab("🖼 OCR & Feedback"):
+        query_input = gr.Textbox(label="Custom Prompt (optional)")
+        image_input = gr.Image(type="pil", label="Upload / Capture Handwritten Image")
+        use_rpl = gr.Checkbox(value=True, label="Enable RPL Feedback")

         with gr.Accordion("⚙️ Advanced Options", open=False):
             max_new_tokens = gr.Slider(1, 2048, value=MAX_NEW_TOKENS_DEFAULT, step=1, label="Max new tokens")
             temperature = gr.Slider(0.1, 2.0, value=0.1, step=0.05, label="Temperature")
+            top_p = gr.Slider(0.05, 1.0, value=1.0, step=0.05, label="Top-p")
+            top_k = gr.Slider(0, 1000, value=0, step=1, label="Top-k")
+            repetition_penalty = gr.Slider(0.8, 2.0, value=1.0, step=0.05, label="Repetition penalty")

+        extract_btn = gr.Button("📤 Extract RAW Text")
+        raw_output = gr.Textbox(label="📜 Output", lines=18, show_copy_button=True)

         gr.Markdown("### ✏️ Quick Feedback")
+        correction_box = gr.Textbox(label="Your Correction", lines=8)
+        ground_truth_box = gr.Textbox(label="Ground Truth", lines=6)
+        btn_good = gr.Button("👍 Accept (Correct)")
+        btn_bad = gr.Button("👎 Bad (Incorrect)")

         feedback_status = gr.Markdown()

+        extract_btn.click(ocr_image,
+                          [image_input, model_choice, query_input, max_new_tokens, temperature, top_p, top_k, repetition_penalty, use_rpl],
+                          raw_output)

+        btn_good.click(lambda img, mc, prmpt, pred, corr, gt: save_feedback(img, mc, prmpt, pred, corr, gt, 1),
+                       [image_input, model_choice, query_input, raw_output, correction_box, ground_truth_box],
+                       feedback_status)
+        btn_bad.click(lambda img, mc, prmpt, pred, corr, gt: save_feedback(img, mc, prmpt, pred, corr, gt, -1),
+                      [image_input, model_choice, query_input, raw_output, correction_box, ground_truth_box],
+                      feedback_status)

+        gr.Markdown("### 📥 Download Feedback")
+        # gr.File has no click event, so buttons drive the downloads instead.
+        download_jsonl_btn = gr.Button("Download JSONL")
+        jsonl_file = gr.File(label="JSONL")
+        download_csv_btn = gr.Button("Download CSV")
+        csv_file = gr.File(label="CSV")
+        download_jsonl_btn.click(lambda: FEEDBACK_RPL_PATH if os.path.exists(FEEDBACK_RPL_PATH) else None,
+                                 None, jsonl_file)
+        download_csv_btn.click(export_csv, None, csv_file)

+    with gr.Tab("📁 Export Formats"):
+        text_input = gr.Textbox(label="Text to Export", lines=10)
+        pdf_btn = gr.Button("Save as PDF")
+        word_btn = gr.Button("Save as Word")
+        audio_btn = gr.Button("Save as Audio")
+        pdf_btn.click(save_pdf, text_input, gr.File(label="PDF File"))
+        word_btn.click(save_word, text_input, gr.File(label="Word File"))
+        audio_btn.click(save_audio, text_input, gr.File(label="Audio File"))

+    with gr.Tab("🎛 GRPO Preferences"):
+        user_name = gr.Textbox(label="Name")
+        grpo_dict_input = gr.Textbox(label="Preferences (JSON)")
+        grpo_save_btn = gr.Button("Save GRPO")
         grpo_status = gr.Markdown()
+        grpo_save_btn.click(lambda n, p: save_grpo(n, json.loads(p or "{}")),
+                            [user_name, grpo_dict_input], grpo_status)
 if __name__ == "__main__":
     demo.queue(max_size=50).launch(share=True)