Emeritus-21 committed on
Commit
59576ba
·
verified ·
1 Parent(s): 28b35fd

Update app.py

Files changed (1)

  app.py +64 -32
app.py CHANGED
@@ -1,5 +1,4 @@
 # app.py — HTR Space with Feedback Loop, Memory Post-Correction, and GRPO Export
-
 import os, time, json, hashlib, difflib, uuid, csv
 from datetime import datetime
 from collections import Counter, defaultdict
@@ -18,17 +17,17 @@ from jiwer import cer
 
 # ---------------- Storage & Paths ----------------
 os.makedirs("data", exist_ok=True)
-FEEDBACK_PATH = "data/feedback.jsonl"  # raw feedback log (per sample)
-MEMORY_RULES_PATH = "data/memory_rules.json"  # compiled post-correction rules
-GRPO_EXPORT_PATH = "data/grpo_prefs.jsonl"  # preference pairs for GRPO
-CSV_EXPORT_PATH = "data/feedback.csv"  # optional tabular export
+FEEDBACK_PATH = "data/feedback.jsonl"  # raw feedback log (per sample)
+MEMORY_RULES_PATH = "data/memory_rules.json"  # compiled post-correction rules
+GRPO_EXPORT_PATH = "data/grpo_prefs.jsonl"  # preference pairs for GRPO
+CSV_EXPORT_PATH = "data/feedback.csv"  # optional tabular export
 
 # ---------------- Models ----------------
 MODEL_PATHS = {
     "Model 1 (Complex handwriting)": ("prithivMLmods/Qwen2.5-VL-7B-Abliterated-Caption-it", Qwen2_5_VLForConditionalGeneration),
     "Model 2 (simple and scanned handwriting)": ("nanonets/Nanonets-OCR-s", Qwen2_5_VLForConditionalGeneration),
 }
-
+# Model 3 removed to conserve memory.
 
 MAX_NEW_TOKENS_DEFAULT = 512
 device = "cuda" if torch.cuda.is_available() else "cpu"
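For reference, each `MODEL_PATHS` entry is a `(repo_id, model_class)` pair. A minimal load sketch, assuming a transformers release that ships `Qwen2_5_VLForConditionalGeneration`; the dtype and device handling here are illustrative assumptions, not the app's exact code:

```python
# Hypothetical load of one MODEL_PATHS entry; dtype/device choices are assumptions.
import torch
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration

repo_id, model_cls = MODEL_PATHS["Model 2 (simple and scanned handwriting)"]
processor = AutoProcessor.from_pretrained(repo_id)
model = model_cls.from_pretrained(
    repo_id,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
).to(device).eval()  # `device` is the cuda/cpu string defined in the hunk above
```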
@@ -161,7 +160,7 @@ def _compile_rules_from_feedback(min_count: int = 2, max_phrase_len: int = 40):
             row = json.loads(line)
         except Exception:
             continue
-        if row.get("reward", 0) < 1:  # only learn from thumbs-up or explicit 'accepted_correction'
+        if row.get("reward", 0) < 1:  # only learn from thumbs-up or explicit 'accepted_correction'
             continue
         pred = _safe_text(row.get("prediction", ""))
         corr = _safe_text(row.get("correction", "")) or _safe_text(row.get("ground_truth", ""))
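The gate above keeps rule-mining restricted to positively rated rows. One plausible core for turning a `(prediction, correction)` pair into phrase-level auto-fix rules with the already-imported `difflib`; a simplified sketch, not the actual body of `_compile_rules_from_feedback`, which also applies the `min_count`/`max_phrase_len` thresholds across the whole log:

```python
import difflib

def candidate_rules(pred: str, corr: str):
    """Yield (wrong, fixed) phrase pairs mined from one feedback sample."""
    sm = difflib.SequenceMatcher(a=pred, b=corr)
    for op, i1, i2, j1, j2 in sm.get_opcodes():
        if op == "replace":
            yield pred[i1:i2], corr[j1:j2]

print(list(candidate_rules("Hcllo world", "Hello world")))  # [('c', 'e')]
```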
@@ -266,7 +265,7 @@ def _append_jsonl(path, obj):
 def _export_csv():
     # optional: CSV summary for spreadsheet views
     if not os.path.exists(FEEDBACK_PATH):
-        return CSV_EXPORT_PATH if os.path.exists(CSV_EXPORT_PATH) else None
+        return None
     rows = []
     with open(FEEDBACK_PATH, "r", encoding="utf-8") as f:
         for line in f:
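Because the feedback log is one JSON object per line, downstream stats stay simple. A hedged sketch that averages the stored per-row CER; it assumes only the `cer` field visible in the `save_feedback` hunk below:

```python
import json

def mean_cer(path: str = FEEDBACK_PATH):
    """Average the 'cer' field across parseable feedback rows."""
    scores = []
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            try:
                row = json.loads(line)
            except json.JSONDecodeError:
                continue
            if row.get("cer") is not None:
                scores.append(float(row["cer"]))
    return sum(scores) / len(scores) if scores else None
```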
@@ -285,15 +284,17 @@ def _export_csv():
         w.writerow(flat)
     return CSV_EXPORT_PATH
 
+# ------------------- MODIFIED -------------------
 def save_feedback(image: Image.Image, model_choice: str, prompt: str,
                   prediction: str, correction: str, ground_truth: str, reward: int):
     """
     reward: 1 = good/accepted, 0 = neutral, -1 = bad
     """
     if image is None:
-        return "Please provide the image again to link feedback.", 0
+        # Bug Fix: Return a single string, not a tuple
+        return "Please provide the image again to link feedback."
     if not prediction and not correction and not ground_truth:
-        return "Nothing to save.", 0
+        return "Nothing to save."
 
     image_hash = _hash_image(image)
     # best target = correction, else ground_truth, else prediction
@@ -319,7 +320,8 @@ def save_feedback(image: Image.Image, model_choice: str, prompt: str,
         "cer": float(cer_score) if cer_score is not None else None,
     }
     _append_jsonl(FEEDBACK_PATH, row)
-    return f"✅ Feedback saved (reward={reward}).", 1
+    return f"✅ Feedback saved (reward={reward})."
+# ------------------------------------------------
 
 def compile_memory_rules():
     _compile_rules_from_feedback(min_count=2, max_phrase_len=60)
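The return-value fixes above matter because a Gradio event handler must return exactly one value per component in its `outputs` list; returning `("msg", 0)` into a single `gr.Markdown` output renders the raw tuple. A minimal illustration, with hypothetical wiring rather than the app's actual handler:

```python
import gradio as gr

def fake_save():
    # Stand-in for save_feedback: one output component, so return one string.
    return "✅ Feedback saved (reward=1)."

with gr.Blocks() as sketch:
    status = gr.Markdown()
    gr.Button("Save").click(fn=fake_save, outputs=[status])
```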
@@ -357,11 +359,18 @@ def export_grpo_preferences():
         count += 1
     return f"✅ Exported {count} GRPO preference pairs to {GRPO_EXPORT_PATH}."
 
-def export_csv():
-    p = _export_csv()
-    if p:
-        return f"✅ CSV exported: {p}"
-    return "No data to export."
+# ------------------- NEW -------------------
+def get_grpo_file():
+    if os.path.exists(GRPO_EXPORT_PATH):
+        return GRPO_EXPORT_PATH
+    return None
+
+def get_csv_file():
+    _export_csv()
+    if os.path.exists(CSV_EXPORT_PATH):
+        return CSV_EXPORT_PATH
+    return None
+# -------------------------------------------
 
 # ---------------- Evaluation Orchestration ----------------
 @spaces.GPU
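Both new helpers return either a file path or `None`, which is the contract a `gr.File` output component accepts from a callback: a path populates the download widget, while `None` leaves it empty instead of raising. The wiring appears in the interface hunk further down.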
@@ -438,7 +447,7 @@ def main():
 
     trainer = GRPOTrainer(
         model=model,
-        ref_model=None,  # let TRL create a frozen copy internally
+        ref_model=None,  # let TRL create a frozen copy internally
         args=cfg,
         tokenizer=tok,
         train_dataset=ds
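A hedged caution on this constructor: `ref_model=`/`tokenizer=` match older TRL trainer signatures, and explicit (chosen, rejected) pairs are the shape `DPOTrainer` consumes; in recent TRL releases (roughly 0.14 and later), `GRPOTrainer` is instead built from prompts plus a reward function. A sketch under that assumption, with a purely illustrative similarity reward:

```python
# Assumes a recent trl release; `model` and `ds` are the objects built above.
import difflib
from trl import GRPOConfig, GRPOTrainer

def similarity_reward(completions, chosen, **kwargs):
    # Extra dataset columns (here: "chosen") are forwarded to reward functions.
    return [difflib.SequenceMatcher(a=c, b=ref).ratio()
            for c, ref in zip(completions, chosen)]

trainer = GRPOTrainer(
    model=model,
    reward_funcs=similarity_reward,
    args=GRPOConfig(output_dir="outputs/grpo"),
    train_dataset=ds,  # expects a "prompt" column
)
```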
@@ -460,7 +469,7 @@ def _write_trainer_script():
 
 # ---------------- Gradio Interface ----------------
 with gr.Blocks(theme=gr.themes.Soft()) as demo:
-    gr.Markdown("## ✍🏾 Wilson Handwritten Text Recognition with Feedback Loop")
+    gr.Markdown("## ✍🏾 Wilson Handwritten OCR — with Feedback Loop, Memory & GRPO Export")
 
     model_choice = gr.Radio(choices=list(MODEL_PATHS.keys()),
                             value=list(MODEL_PATHS.keys())[0],
@@ -552,38 +561,61 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
 
     with gr.Tab("✏️ Feedback & Memory"):
         gr.Markdown("""
-        **Pipeline**
-        1) Save feedback (👍 / 👎) and add corrections.
-        2) Click **Build/Refresh Memory** to generate auto-fix rules from positive feedback.
+        **Pipeline**
+        1) Save feedback (👍 / 👎) and add corrections.
+        2) Click **Build/Refresh Memory** to generate auto-fix rules from positive feedback.
         3) Keep **Enable Memory Post-correction** checked on inference/eval tabs.
         """)
         build_mem_btn = gr.Button("🧠 Build/Refresh Memory from Feedback")
         mem_status = gr.Markdown()
         build_mem_btn.click(fn=compile_memory_rules, outputs=[mem_status])
 
-        csv_btn = gr.Button("📤 Export Feedback as CSV")
         csv_status = gr.Markdown()
-        csv_btn.click(fn=export_csv, outputs=[csv_status])
+
+        # ------------------- MODIFIED -------------------
+        gr.Markdown("---")
+        gr.Markdown("### ⬇️ Download Feedback Data")
+        with gr.Row():
+            download_csv_btn = gr.Button("⬇️ Download Feedback as CSV")
+            download_csv_file = gr.File(label="CSV File")
+        download_csv_btn.click(fn=get_csv_file, outputs=download_csv_file)
+        # ------------------------------------------------
 
     with gr.Tab("🧪 GRPO / Dataset"):
         gr.Markdown("""
         **GRPO Fine-tuning** (run offline or in a training Space):
         - Click **Export GRPO Preferences** to produce `data/grpo_prefs.jsonl` of (prompt, chosen, rejected).
         - Click **Write Trainer Script** to create `train/grpo_train.py`.
-        - Then run:
+        - Then run:
         ```bash
         pip install trl accelerate peft transformers datasets
         python train/grpo_train.py
-        ```
+        ```
+
         Set `BASE_MODEL`/`OUTPUT_DIR` env vars if you like.
-        """)
-        grpo_btn = gr.Button("📦 Export GRPO Preferences")
-        grpo_status = gr.Markdown()
-        grpo_btn.click(fn=export_grpo_preferences, outputs=[grpo_status])
+        """)
+        grpo_btn = gr.Button("📦 Export GRPO Preferences")
+        grpo_status = gr.Markdown()
+        grpo_btn.click(fn=export_grpo_preferences, outputs=[grpo_status])
 
-        write_script_btn = gr.Button("📝 Write grpo_train.py")
-        write_script_status = gr.Markdown()
-        write_script_btn.click(fn=lambda: f"✅ Trainer script written to `{_write_trainer_script()}`", outputs=[write_script_status])
+        write_script_btn = gr.Button("📝 Write grpo_train.py")
+        write_script_status = gr.Markdown()
+        write_script_btn.click(fn=lambda: f"✅ Trainer script written to `{_write_trainer_script()}`", outputs=[write_script_status])
+
+        # ------------------- NEW -------------------
+        gr.Markdown("---")
+        gr.Markdown("### ⬇️ Download GRPO Dataset")
+        with gr.Row():
+            download_grpo_btn = gr.Button("⬇️ Download GRPO Data (grpo_prefs.jsonl)")
+            download_grpo_file = gr.File(label="GRPO Dataset File")
+        download_grpo_btn.click(fn=get_grpo_file, outputs=download_grpo_file)
+        # -------------------------------------------
 
 if __name__ == "__main__":
     demo.queue(max_size=50).launch(share=True)
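To sanity-check the export before training, the JSONL can be peeked at directly. A small sketch assuming each row carries the (prompt, chosen, rejected) keys named in the tab text above:

```python
from datasets import load_dataset

prefs = load_dataset("json", data_files="data/grpo_prefs.jsonl", split="train")
print(len(prefs), prefs.column_names)
print(prefs[0]["prompt"][:80], "->", prefs[0]["chosen"][:80])
```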
 