Emeritus-21 committed on
Commit
455c83a
·
verified ·
1 Parent(s): c159940

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -46
app.py CHANGED
@@ -18,15 +18,15 @@ from jiwer import cer
18
 
19
  # ---------------- Storage & Paths ----------------
20
  os.makedirs("data", exist_ok=True)
21
- FEEDBACK_PATH = "data/feedback.jsonl" # raw feedback log (per sample)
22
- MEMORY_RULES_PATH = "data/memory_rules.json" # compiled post-correction rules
23
- GRPO_EXPORT_PATH = "data/grpo_prefs.jsonl" # preference pairs for GRPO
24
- CSV_EXPORT_PATH = "data/feedback.csv" # optional tabular export
25
 
26
  # ---------------- Models ----------------
27
  MODEL_PATHS = {
28
- "Model 1 (Complex handwrittings )": ("prithivMLmods/Qwen2.5-VL-7B-Abliterated-Caption-it", Qwen2_5_VLForConditionalGeneration),
29
- "Model 2 (simple and scanned handwritting )": ("nanonets/Nanonets-OCR-s", Qwen2_5_VLForConditionalGeneration),
30
  }
31
  # Model 3 removed to conserve memory.
32
 
@@ -161,7 +161,7 @@ def _compile_rules_from_feedback(min_count: int = 2, max_phrase_len: int = 40):
161
  row = json.loads(line)
162
  except Exception:
163
  continue
164
- if row.get("reward", 0) < 1: # only learn from thumbs-up or explicit 'accepted_correction'
165
  continue
166
  pred = _safe_text(row.get("prediction", ""))
167
  corr = _safe_text(row.get("correction", "")) or _safe_text(row.get("ground_truth", ""))
@@ -285,14 +285,12 @@ def _export_csv():
285
  w.writerow(flat)
286
  return CSV_EXPORT_PATH
287
 
288
- # ------------------- MODIFIED -------------------
289
  def save_feedback(image: Image.Image, model_choice: str, prompt: str,
290
  prediction: str, correction: str, ground_truth: str, reward: int):
291
  """
292
  reward: 1 = good/accepted, 0 = neutral, -1 = bad
293
  """
294
  if image is None:
295
- # Bug Fix: Return a single string, not a tuple
296
  return "Please provide the image again to link feedback."
297
  if not prediction and not correction and not ground_truth:
298
  return "Nothing to save."
@@ -322,7 +320,6 @@ def save_feedback(image: Image.Image, model_choice: str, prompt: str,
322
  }
323
  _append_jsonl(FEEDBACK_PATH, row)
324
  return f"βœ… Feedback saved (reward={reward})."
325
- # ------------------------------------------------
326
 
327
  def compile_memory_rules():
328
  _compile_rules_from_feedback(min_count=2, max_phrase_len=60)
@@ -360,7 +357,6 @@ def export_grpo_preferences():
360
  count += 1
361
  return f"βœ… Exported {count} GRPO preference pairs to {GRPO_EXPORT_PATH}."
362
 
363
- # ------------------- NEW -------------------
364
  def get_grpo_file():
365
  if os.path.exists(GRPO_EXPORT_PATH):
366
  return GRPO_EXPORT_PATH
@@ -371,7 +367,6 @@ def get_csv_file():
371
  if os.path.exists(CSV_EXPORT_PATH):
372
  return CSV_EXPORT_PATH
373
  return None
374
- # -------------------------------------------
375
 
376
  # ---------------- Evaluation Orchestration ----------------
377
  @spaces.GPU
@@ -448,7 +443,7 @@ def main():
448
 
449
  trainer = GRPOTrainer(
450
  model=model,
451
- ref_model=None, # let TRL create a frozen copy internally
452
  args=cfg,
453
  tokenizer=tok,
454
  train_dataset=ds
@@ -470,7 +465,7 @@ def _write_trainer_script():
470
 
471
  # ---------------- Gradio Interface ----------------
472
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
473
- gr.Markdown("## ✍🏾 wilson Handwritten OCR β€” with Feedback Loop, Memory & GRPO Export")
474
 
475
  model_choice = gr.Radio(choices=list(MODEL_PATHS.keys()),
476
  value=list(MODEL_PATHS.keys())[0],
@@ -573,14 +568,12 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
573
 
574
  csv_status = gr.Markdown()
575
 
576
- # ------------------- MODIFIED -------------------
577
  gr.Markdown("---")
578
  gr.Markdown("### ⬇️ Download Feedback Data")
579
  with gr.Row():
580
  download_csv_btn = gr.Button("⬇️ Download Feedback as CSV")
581
  download_csv_file = gr.File(label="CSV File")
582
  download_csv_btn.click(fn=get_csv_file, outputs=download_csv_file)
583
- # ------------------------------------------------
584
 
585
  with gr.Tab("πŸ§ͺ GRPO / Dataset"):
586
  gr.Markdown("""
@@ -591,35 +584,21 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
591
  ```bash
592
  pip install trl accelerate peft transformers datasets
593
  python train/grpo_train.py
594
- ````
595
 
596
- Set `BASE_MODEL`/`OUTPUT_DIR` env vars if you like.
597
  """)
598
- grpo\_btn = gr.Button("πŸ“¦ Export GRPO Preferences")
599
- grpo\_status = gr.Markdown()
600
- grpo\_btn.click(fn=export\_grpo\_preferences, outputs=[grpo\_status])
601
-
602
- ```
603
- write_script_btn = gr.Button("πŸ“ Write grpo_train.py")
604
- write_script_status = gr.Markdown()
605
- write_script_btn.click(fn=lambda: f"βœ… Trainer script written to `{_write_trainer_script()}`", outputs=[write_script_status])
606
- ```
607
-
608
- # \------------------- NEW -------------------
609
-
610
- ```
611
- gr.Markdown("---")
612
- gr.Markdown("### ⬇️ Download GRPO Dataset")
613
- with gr.Row():
614
- download_grpo_btn = gr.Button("⬇️ Download GRPO Data (grpo_prefs.jsonl)")
615
- download_grpo_file = gr.File(label="GRPO Dataset File")
616
- download_grpo_btn.click(fn=get_grpo_file, outputs=download_grpo_file)
617
- ```
618
-
619
- # \-------------------------------------------
620
-
621
- if **name** == "**main**":
622
- demo.queue(max\_size=50).launch(share=True)
623
-
624
- ```
625
- ```
 
18
 
19
  # ---------------- Storage & Paths ----------------
20
  os.makedirs("data", exist_ok=True)
21
+ FEEDBACK_PATH = "data/feedback.jsonl" # raw feedback log (per sample)
22
+ MEMORY_RULES_PATH = "data/memory_rules.json" # compiled post-correction rules
23
+ GRPO_EXPORT_PATH = "data/grpo_prefs.jsonl" # preference pairs for GRPO
24
+ CSV_EXPORT_PATH = "data/feedback.csv" # optional tabular export
25
 
26
  # ---------------- Models ----------------
27
  MODEL_PATHS = {
28
+ "Model 1 (Complex handwritings)": ("prithivMLmods/Qwen2.5-VL-7B-Abliterated-Caption-it", Qwen2_5_VLForConditionalGeneration),
29
+ "Model 2 (simple and scanned handwriting)": ("nanonets/Nanonets-OCR-s", Qwen2_5_VLForConditionalGeneration),
30
  }
31
  # Model 3 removed to conserve memory.
32
 
 
161
  row = json.loads(line)
162
  except Exception:
163
  continue
164
+ if row.get("reward", 0) < 1: # only learn from thumbs-up or explicit 'accepted_correction'
165
  continue
166
  pred = _safe_text(row.get("prediction", ""))
167
  corr = _safe_text(row.get("correction", "")) or _safe_text(row.get("ground_truth", ""))
 
285
  w.writerow(flat)
286
  return CSV_EXPORT_PATH
287
 
 
288
  def save_feedback(image: Image.Image, model_choice: str, prompt: str,
289
  prediction: str, correction: str, ground_truth: str, reward: int):
290
  """
291
  reward: 1 = good/accepted, 0 = neutral, -1 = bad
292
  """
293
  if image is None:
 
294
  return "Please provide the image again to link feedback."
295
  if not prediction and not correction and not ground_truth:
296
  return "Nothing to save."
 
320
  }
321
  _append_jsonl(FEEDBACK_PATH, row)
322
  return f"βœ… Feedback saved (reward={reward})."
 
323
 
324
  def compile_memory_rules():
325
  _compile_rules_from_feedback(min_count=2, max_phrase_len=60)
 
357
  count += 1
358
  return f"βœ… Exported {count} GRPO preference pairs to {GRPO_EXPORT_PATH}."
359
 
 
360
  def get_grpo_file():
361
  if os.path.exists(GRPO_EXPORT_PATH):
362
  return GRPO_EXPORT_PATH
 
367
  if os.path.exists(CSV_EXPORT_PATH):
368
  return CSV_EXPORT_PATH
369
  return None
 
370
 
371
  # ---------------- Evaluation Orchestration ----------------
372
  @spaces.GPU
 
443
 
444
  trainer = GRPOTrainer(
445
  model=model,
446
+ ref_model=None, # let TRL create a frozen copy internally
447
  args=cfg,
448
  tokenizer=tok,
449
  train_dataset=ds
 
465
 
466
  # ---------------- Gradio Interface ----------------
467
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
468
+ gr.Markdown("## ✍🏾 Wilson Handwritten OCR β€” with Feedback Loop, Memory & GRPO Export")
469
 
470
  model_choice = gr.Radio(choices=list(MODEL_PATHS.keys()),
471
  value=list(MODEL_PATHS.keys())[0],
 
568
 
569
  csv_status = gr.Markdown()
570
 
 
571
  gr.Markdown("---")
572
  gr.Markdown("### ⬇️ Download Feedback Data")
573
  with gr.Row():
574
  download_csv_btn = gr.Button("⬇️ Download Feedback as CSV")
575
  download_csv_file = gr.File(label="CSV File")
576
  download_csv_btn.click(fn=get_csv_file, outputs=download_csv_file)
 
577
 
578
  with gr.Tab("πŸ§ͺ GRPO / Dataset"):
579
  gr.Markdown("""
 
584
  ```bash
585
  pip install trl accelerate peft transformers datasets
586
  python train/grpo_train.py
 
587
 
588
+ Set BASE_MODEL/OUTPUT_DIR env vars if you like.
589
  """)
590
+ grpo_btn = gr.Button("πŸ“¦ Export GRPO Preferences")
591
+ grpo_status = gr.Markdown()
592
+ grpo_btn.click(fn=export_grpo_preferences, outputs=[grpo_status])
593
+ write_script_btn = gr.Button("πŸ“ Write grpo_train.py")
594
+ write_script_status = gr.Markdown()
595
+ write_script_btn.click(fn=lambda: f"βœ… Trainer script written to {_write_trainer_script()}", outputs=[write_script_status])
596
+ gr.Markdown("---")
597
+ gr.Markdown("### ⬇️ Download GRPO Dataset")
598
+ with gr.Row():
599
+ download_grpo_btn = gr.Button("⬇️ Download GRPO Data (grpo_prefs.jsonl)")
600
+ download_grpo_file = gr.File(label="GRPO Dataset File")
601
+ download_grpo_btn.click(fn=get_grpo_file, outputs=[download_grpo_file])
602
+
603
+ if __name__ == "__main__":
604
+ demo.queue(max_size=50).launch(share=True)