Emeritus-21 commited on
Commit
c159940
Β·
verified Β·
1 Parent(s): 3096d10

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +67 -31
app.py CHANGED
@@ -18,10 +18,10 @@ from jiwer import cer
18
 
19
  # ---------------- Storage & Paths ----------------
20
  os.makedirs("data", exist_ok=True)
21
- FEEDBACK_PATH = "data/feedback.jsonl" # raw feedback log (per sample)
22
- MEMORY_RULES_PATH = "data/memory_rules.json" # compiled post-correction rules
23
- GRPO_EXPORT_PATH = "data/grpo_prefs.jsonl" # preference pairs for GRPO
24
- CSV_EXPORT_PATH = "data/feedback.csv" # optional tabular export
25
 
26
  # ---------------- Models ----------------
27
  MODEL_PATHS = {
@@ -161,7 +161,7 @@ def _compile_rules_from_feedback(min_count: int = 2, max_phrase_len: int = 40):
161
  row = json.loads(line)
162
  except Exception:
163
  continue
164
- if row.get("reward", 0) < 1: # only learn from thumbs-up or explicit 'accepted_correction'
165
  continue
166
  pred = _safe_text(row.get("prediction", ""))
167
  corr = _safe_text(row.get("correction", "")) or _safe_text(row.get("ground_truth", ""))
@@ -266,7 +266,7 @@ def _append_jsonl(path, obj):
266
  def _export_csv():
267
  # optional: CSV summary for spreadsheet views
268
  if not os.path.exists(FEEDBACK_PATH):
269
- return CSV_EXPORT_PATH if os.path.exists(CSV_EXPORT_PATH) else None
270
  rows = []
271
  with open(FEEDBACK_PATH, "r", encoding="utf-8") as f:
272
  for line in f:
@@ -285,15 +285,17 @@ def _export_csv():
285
  w.writerow(flat)
286
  return CSV_EXPORT_PATH
287
 
 
288
  def save_feedback(image: Image.Image, model_choice: str, prompt: str,
289
  prediction: str, correction: str, ground_truth: str, reward: int):
290
  """
291
  reward: 1 = good/accepted, 0 = neutral, -1 = bad
292
  """
293
  if image is None:
294
- return "Please provide the image again to link feedback.", 0
 
295
  if not prediction and not correction and not ground_truth:
296
- return "Nothing to save.", 0
297
 
298
  image_hash = _hash_image(image)
299
  # best target = correction, else ground_truth, else prediction
@@ -319,7 +321,8 @@ def save_feedback(image: Image.Image, model_choice: str, prompt: str,
319
  "cer": float(cer_score) if cer_score is not None else None,
320
  }
321
  _append_jsonl(FEEDBACK_PATH, row)
322
- return f"βœ… Feedback saved (reward={reward}).", 1
 
323
 
324
  def compile_memory_rules():
325
  _compile_rules_from_feedback(min_count=2, max_phrase_len=60)
@@ -357,11 +360,18 @@ def export_grpo_preferences():
357
  count += 1
358
  return f"βœ… Exported {count} GRPO preference pairs to {GRPO_EXPORT_PATH}."
359
 
360
- def export_csv():
361
- p = _export_csv()
362
- if p:
363
- return f"βœ… CSV exported: {p}"
364
- return "No data to export."
 
 
 
 
 
 
 
365
 
366
  # ---------------- Evaluation Orchestration ----------------
367
  @spaces.GPU
@@ -438,7 +448,7 @@ def main():
438
 
439
  trainer = GRPOTrainer(
440
  model=model,
441
- ref_model=None, # let TRL create a frozen copy internally
442
  args=cfg,
443
  tokenizer=tok,
444
  train_dataset=ds
@@ -552,38 +562,64 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
552
 
553
  with gr.Tab("✏️ Feedback & Memory"):
554
  gr.Markdown("""
555
- **Pipeline**
556
- 1) Save feedback (πŸ‘ / πŸ‘Ž) and add corrections.
557
- 2) Click **Build/Refresh Memory** to generate auto-fix rules from positive feedback.
558
  3) Keep **Enable Memory Post-correction** checked on inference/eval tabs.
559
  """)
560
  build_mem_btn = gr.Button("🧠 Build/Refresh Memory from Feedback")
561
  mem_status = gr.Markdown()
562
  build_mem_btn.click(fn=compile_memory_rules, outputs=[mem_status])
563
 
564
- csv_btn = gr.Button("πŸ“€ Export Feedback as CSV")
565
  csv_status = gr.Markdown()
566
- csv_btn.click(fn=export_csv, outputs=[csv_status])
 
 
 
 
 
 
 
 
567
 
568
  with gr.Tab("πŸ§ͺ GRPO / Dataset"):
569
  gr.Markdown("""
570
  **GRPO Fine-tuning** (run offline or in a training Space):
571
  - Click **Export GRPO Preferences** to produce `data/grpo_prefs.jsonl` of (prompt, chosen, rejected).
572
  - Click **Write Trainer Script** to create `train/grpo_train.py`.
573
- - Then run:
574
  ```bash
575
  pip install trl accelerate peft transformers datasets
576
  python train/grpo_train.py
577
- ```
 
578
  Set `BASE_MODEL`/`OUTPUT_DIR` env vars if you like.
579
- """)
580
- grpo_btn = gr.Button("πŸ“¦ Export GRPO Preferences")
581
- grpo_status = gr.Markdown()
582
- grpo_btn.click(fn=export_grpo_preferences, outputs=[grpo_status])
583
 
584
- write_script_btn = gr.Button("πŸ“ Write grpo_train.py")
585
- write_script_status = gr.Markdown()
586
- write_script_btn.click(fn=lambda: f"βœ… Trainer script written to `{_write_trainer_script()}`", outputs=[write_script_status])
 
 
587
 
588
- if __name__ == "__main__":
589
- demo.queue(max_size=50).launch(share=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
  # ---------------- Storage & Paths ----------------
20
  os.makedirs("data", exist_ok=True)
21
+ FEEDBACK_PATH = "data/feedback.jsonl" # raw feedback log (per sample)
22
+ MEMORY_RULES_PATH = "data/memory_rules.json" # compiled post-correction rules
23
+ GRPO_EXPORT_PATH = "data/grpo_prefs.jsonl" # preference pairs for GRPO
24
+ CSV_EXPORT_PATH = "data/feedback.csv" # optional tabular export
25
 
26
  # ---------------- Models ----------------
27
  MODEL_PATHS = {
 
161
  row = json.loads(line)
162
  except Exception:
163
  continue
164
+ if row.get("reward", 0) < 1: # only learn from thumbs-up or explicit 'accepted_correction'
165
  continue
166
  pred = _safe_text(row.get("prediction", ""))
167
  corr = _safe_text(row.get("correction", "")) or _safe_text(row.get("ground_truth", ""))
 
266
  def _export_csv():
267
  # optional: CSV summary for spreadsheet views
268
  if not os.path.exists(FEEDBACK_PATH):
269
+ return None
270
  rows = []
271
  with open(FEEDBACK_PATH, "r", encoding="utf-8") as f:
272
  for line in f:
 
285
  w.writerow(flat)
286
  return CSV_EXPORT_PATH
287
 
288
+ # ------------------- MODIFIED -------------------
289
  def save_feedback(image: Image.Image, model_choice: str, prompt: str,
290
  prediction: str, correction: str, ground_truth: str, reward: int):
291
  """
292
  reward: 1 = good/accepted, 0 = neutral, -1 = bad
293
  """
294
  if image is None:
295
+ # Bug Fix: Return a single string, not a tuple
296
+ return "Please provide the image again to link feedback."
297
  if not prediction and not correction and not ground_truth:
298
+ return "Nothing to save."
299
 
300
  image_hash = _hash_image(image)
301
  # best target = correction, else ground_truth, else prediction
 
321
  "cer": float(cer_score) if cer_score is not None else None,
322
  }
323
  _append_jsonl(FEEDBACK_PATH, row)
324
+ return f"βœ… Feedback saved (reward={reward})."
325
+ # ------------------------------------------------
326
 
327
  def compile_memory_rules():
328
  _compile_rules_from_feedback(min_count=2, max_phrase_len=60)
 
360
  count += 1
361
  return f"βœ… Exported {count} GRPO preference pairs to {GRPO_EXPORT_PATH}."
362
 
363
+ # ------------------- NEW -------------------
364
+ def get_grpo_file():
365
+ if os.path.exists(GRPO_EXPORT_PATH):
366
+ return GRPO_EXPORT_PATH
367
+ return None
368
+
369
+ def get_csv_file():
370
+ _export_csv()
371
+ if os.path.exists(CSV_EXPORT_PATH):
372
+ return CSV_EXPORT_PATH
373
+ return None
374
+ # -------------------------------------------
375
 
376
  # ---------------- Evaluation Orchestration ----------------
377
  @spaces.GPU
 
448
 
449
  trainer = GRPOTrainer(
450
  model=model,
451
+ ref_model=None, # let TRL create a frozen copy internally
452
  args=cfg,
453
  tokenizer=tok,
454
  train_dataset=ds
 
562
 
563
  with gr.Tab("✏️ Feedback & Memory"):
564
  gr.Markdown("""
565
+ **Pipeline**
566
+ 1) Save feedback (πŸ‘ / πŸ‘Ž) and add corrections.
567
+ 2) Click **Build/Refresh Memory** to generate auto-fix rules from positive feedback.
568
  3) Keep **Enable Memory Post-correction** checked on inference/eval tabs.
569
  """)
570
  build_mem_btn = gr.Button("🧠 Build/Refresh Memory from Feedback")
571
  mem_status = gr.Markdown()
572
  build_mem_btn.click(fn=compile_memory_rules, outputs=[mem_status])
573
 
 
574
  csv_status = gr.Markdown()
575
+
576
+ # ------------------- MODIFIED -------------------
577
+ gr.Markdown("---")
578
+ gr.Markdown("### ⬇️ Download Feedback Data")
579
+ with gr.Row():
580
+ download_csv_btn = gr.Button("⬇️ Download Feedback as CSV")
581
+ download_csv_file = gr.File(label="CSV File")
582
+ download_csv_btn.click(fn=get_csv_file, outputs=download_csv_file)
583
+ # ------------------------------------------------
584
 
585
  with gr.Tab("πŸ§ͺ GRPO / Dataset"):
586
  gr.Markdown("""
587
  **GRPO Fine-tuning** (run offline or in a training Space):
588
  - Click **Export GRPO Preferences** to produce `data/grpo_prefs.jsonl` of (prompt, chosen, rejected).
589
  - Click **Write Trainer Script** to create `train/grpo_train.py`.
590
+ - Then run:
591
  ```bash
592
  pip install trl accelerate peft transformers datasets
593
  python train/grpo_train.py
594
+ ```
595
+
596
  Set `BASE_MODEL`/`OUTPUT_DIR` env vars if you like.
597
+ """)
598
+ grpo_btn = gr.Button("📦 Export GRPO Preferences")
599
+ grpo_status = gr.Markdown()
600
+ grpo_btn.click(fn=export_grpo_preferences, outputs=[grpo_status])
601
 
602
+
603
+ write_script_btn = gr.Button("📝 Write grpo_train.py")
604
+ write_script_status = gr.Markdown()
605
+ write_script_btn.click(fn=lambda: f"✅ Trainer script written to `{_write_trainer_script()}`", outputs=[write_script_status])
606
+
607
 
608
+ # ------------------- NEW -------------------
609
+
610
+
611
+ gr.Markdown("---")
612
+ gr.Markdown("### ⬇️ Download GRPO Dataset")
613
+ with gr.Row():
614
+ download_grpo_btn = gr.Button("⬇️ Download GRPO Data (grpo_prefs.jsonl)")
615
+ download_grpo_file = gr.File(label="GRPO Dataset File")
616
+ download_grpo_btn.click(fn=get_grpo_file, outputs=download_grpo_file)
617
+
618
+
619
+ # -------------------------------------------
620
+
621
+ if __name__ == "__main__":
622
+ demo.queue(max_size=50).launch(share=True)
623
+
624
+
625
+