Emeritus-21 commited on
Commit
c3250ac
Β·
verified Β·
1 Parent(s): c1235af

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +79 -67
app.py CHANGED
@@ -1,4 +1,5 @@
1
  # app.py β€” HTR Space with Feedback Loop, Memory Post-Correction, and GRPO Export
 
2
  import os, time, json, hashlib, difflib, uuid, csv
3
  from datetime import datetime
4
  from collections import Counter, defaultdict
@@ -17,16 +18,17 @@ from jiwer import cer
17
 
18
  # ---------------- Storage & Paths ----------------
19
  os.makedirs("data", exist_ok=True)
20
- FEEDBACK_PATH = "data/feedback.jsonl" # raw feedback log (per sample)
21
  MEMORY_RULES_PATH = "data/memory_rules.json" # compiled post-correction rules
22
- GRPO_EXPORT_PATH = "data/grpo_prefs.jsonl" # preference pairs for GRPO
23
- CSV_EXPORT_PATH = "data/feedback.csv" # optional tabular export
24
 
25
  # ---------------- Models ----------------
26
  MODEL_PATHS = {
27
- "Model 1 (Complex handwritings)": ("prithivMLmods/Qwen2.5-VL-7B-Abliterated-Caption-it", Qwen2_5_VLForConditionalGeneration),
28
- "Model 2 (simple and scanned handwriting)": ("nanonets/Nanonets-OCR-s", Qwen2_5_VLForConditionalGeneration),
29
  }
 
30
 
31
  MAX_NEW_TOKENS_DEFAULT = 512
32
  device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -109,6 +111,7 @@ def _safe_text(text: str) -> str:
109
  return (text or "").strip()
110
 
111
  def _hash_image(image: Image.Image) -> str:
 
112
  img_bytes = image.tobytes()
113
  return hashlib.sha256(img_bytes).hexdigest()
114
 
@@ -130,16 +133,22 @@ def _apply_memory(text: str, model_choice: str, enabled: bool):
130
  if not enabled or not text:
131
  return text
132
  rules = _load_memory_rules()
 
133
  by_model = rules.get("by_model", {}).get(model_choice, {})
134
  for wrong, right in by_model.items():
135
  if wrong and right:
136
  text = text.replace(wrong, right)
 
137
  for wrong, right in rules.get("global", {}).items():
138
  if wrong and right:
139
  text = text.replace(wrong, right)
140
  return text
141
 
142
  def _compile_rules_from_feedback(min_count: int = 2, max_phrase_len: int = 40):
 
 
 
 
143
  changes_counter_global = Counter()
144
  changes_counter_by_model = defaultdict(Counter)
145
 
@@ -152,18 +161,20 @@ def _compile_rules_from_feedback(min_count: int = 2, max_phrase_len: int = 40):
152
  row = json.loads(line)
153
  except Exception:
154
  continue
155
- if row.get("reward", 0) < 1:
156
  continue
157
  pred = _safe_text(row.get("prediction", ""))
158
  corr = _safe_text(row.get("correction", "")) or _safe_text(row.get("ground_truth", ""))
159
  if not pred or not corr:
160
  continue
161
  model_choice = row.get("model_choice", "")
 
162
  s = difflib.SequenceMatcher(None, pred, corr)
163
  for tag, i1, i2, j1, j2 in s.get_opcodes():
164
  if tag in ("replace", "delete", "insert"):
165
  wrong = pred[i1:i2]
166
  right = corr[j1:j2]
 
167
  if 0 < len(wrong) <= max_phrase_len or 0 < len(right) <= max_phrase_len:
168
  if wrong.strip():
169
  changes_counter_global[(wrong, right)] += 1
@@ -171,9 +182,11 @@ def _compile_rules_from_feedback(min_count: int = 2, max_phrase_len: int = 40):
171
  changes_counter_by_model[model_choice][(wrong, right)] += 1
172
 
173
  rules = {"global": {}, "by_model": {}}
 
174
  for (wrong, right), cnt in changes_counter_global.items():
175
  if cnt >= min_count and wrong and right and wrong != right:
176
  rules["global"][wrong] = right
 
177
  for model_choice, ctr in changes_counter_by_model.items():
178
  rules["by_model"].setdefault(model_choice, {})
179
  for (wrong, right), cnt in ctr.items():
@@ -189,10 +202,8 @@ def ocr_image(image: Image.Image, model_choice: str, query: str = None,
189
  temperature: float = 0.1, top_p: float = 1.0, top_k: int = 0, repetition_penalty: float = 1.0,
190
  use_memory: bool = True,
191
  progress=gr.Progress(track_tqdm=True)):
192
- if image is None:
193
- return "Please upload or capture an image."
194
- if model_choice not in _loaded_models:
195
- return f"Invalid model: {model_choice}"
196
  processor, model, tokenizer = _loaded_processors[model_choice], _loaded_models[model_choice], getattr(_loaded_processors[model_choice], "tokenizer", None)
197
  prompt = _default_prompt(query)
198
  batch = _build_inputs(processor, tokenizer, image, prompt).to(device)
@@ -200,25 +211,23 @@ def ocr_image(image: Image.Image, model_choice: str, query: str = None,
200
  output_ids = model.generate(**batch, max_new_tokens=max_new_tokens, do_sample=False,
201
  temperature=temperature, top_p=top_p, top_k=top_k, repetition_penalty=repetition_penalty)
202
  raw = _decode_text(model, processor, tokenizer, output_ids, prompt).replace("<|im_end|>", "").strip()
 
203
  post = _apply_memory(raw, model_choice, use_memory)
204
  return post
205
 
206
  # ---------------- Export Helpers ----------------
207
  def save_as_pdf(text):
208
  text = _safe_text(text)
209
- if not text:
210
- return None
211
  doc = SimpleDocTemplate("output.pdf")
212
  flowables = [Paragraph(t, getSampleStyleSheet()["Normal"]) for t in text.splitlines() if t != ""]
213
- if not flowables:
214
- flowables = [Paragraph(" ", getSampleStyleSheet()["Normal"])]
215
  doc.build(flowables)
216
  return "output.pdf"
217
 
218
  def save_as_word(text):
219
  text = _safe_text(text)
220
- if not text:
221
- return None
222
  doc = Document()
223
  for line in text.splitlines():
224
  doc.add_paragraph(line)
@@ -227,8 +236,7 @@ def save_as_word(text):
227
 
228
  def save_as_audio(text):
229
  text = _safe_text(text)
230
- if not text:
231
- return None
232
  try:
233
  tts = gTTS(text)
234
  tts.save("output.mp3")
@@ -239,6 +247,10 @@ def save_as_audio(text):
239
 
240
  # ---------------- Metrics Function ----------------
241
  def calculate_cer_score(ground_truth: str, prediction: str) -> str:
 
 
 
 
242
  if not ground_truth or not prediction:
243
  return "Cannot calculate CER: Missing ground truth or prediction."
244
  ground_truth_cleaned = " ".join(ground_truth.strip().split())
@@ -252,8 +264,9 @@ def _append_jsonl(path, obj):
252
  f.write(json.dumps(obj, ensure_ascii=False) + "\n")
253
 
254
  def _export_csv():
 
255
  if not os.path.exists(FEEDBACK_PATH):
256
- return None
257
  rows = []
258
  with open(FEEDBACK_PATH, "r", encoding="utf-8") as f:
259
  for line in f:
@@ -263,7 +276,7 @@ def _export_csv():
263
  pass
264
  if not rows:
265
  return None
266
- keys = ["id", "timestamp", "model_choice", "image_sha256", "prompt", "prediction", "correction", "ground_truth", "reward", "cer"]
267
  with open(CSV_EXPORT_PATH, "w", newline="", encoding="utf-8") as f:
268
  w = csv.DictWriter(f, fieldnames=keys)
269
  w.writeheader()
@@ -274,12 +287,16 @@ def _export_csv():
274
 
275
  def save_feedback(image: Image.Image, model_choice: str, prompt: str,
276
  prediction: str, correction: str, ground_truth: str, reward: int):
 
 
 
277
  if image is None:
278
- return "Please provide the image again to link feedback."
279
  if not prediction and not correction and not ground_truth:
280
- return "Nothing to save."
281
 
282
  image_hash = _hash_image(image)
 
283
  target = _safe_text(correction) or _safe_text(ground_truth)
284
  pred = _safe_text(prediction)
285
  cer_score = None
@@ -302,13 +319,18 @@ def save_feedback(image: Image.Image, model_choice: str, prompt: str,
302
  "cer": float(cer_score) if cer_score is not None else None,
303
  }
304
  _append_jsonl(FEEDBACK_PATH, row)
305
- return f"βœ… Feedback saved (reward={reward})."
306
 
307
  def compile_memory_rules():
308
  _compile_rules_from_feedback(min_count=2, max_phrase_len=60)
309
  return "βœ… Memory rules recompiled from positive feedback."
310
 
311
  def export_grpo_preferences():
 
 
 
 
 
312
  if not os.path.exists(FEEDBACK_PATH):
313
  return "No feedback to export."
314
  count = 0
@@ -323,6 +345,7 @@ def export_grpo_preferences():
323
  corr = _safe_text(row.get("correction", "")) or _safe_text(row.get("ground_truth", ""))
324
  prompt = _safe_text(row.get("prompt", "")) or "Transcribe the image exactly."
325
  if corr and pred and corr != pred and row.get("reward", 0) >= 0:
 
326
  out = {
327
  "prompt": prompt,
328
  "image_sha256": row.get("image_sha256", ""),
@@ -334,16 +357,11 @@ def export_grpo_preferences():
334
  count += 1
335
  return f"βœ… Exported {count} GRPO preference pairs to {GRPO_EXPORT_PATH}."
336
 
337
- def get_grpo_file():
338
- if os.path.exists(GRPO_EXPORT_PATH):
339
- return GRPO_EXPORT_PATH
340
- return None
341
-
342
- def get_csv_file():
343
- _export_csv()
344
- if os.path.exists(CSV_EXPORT_PATH):
345
- return CSV_EXPORT_PATH
346
- return None
347
 
348
  # ---------------- Evaluation Orchestration ----------------
349
  @spaces.GPU
@@ -372,6 +390,9 @@ MODEL_ID = os.environ.get("BASE_MODEL", "Qwen/Qwen2.5-VL-7B-Instruct") # change
372
  OUTPUT_DIR = os.environ.get("OUTPUT_DIR", "grpo_output")
373
  DATA_PATH = os.environ.get("DATA_PATH", "data/grpo_prefs.jsonl")
374
 
 
 
 
375
  def _jsonl_dataset(jsonl_path):
376
  data = []
377
  with open(jsonl_path, "r", encoding="utf-8") as f:
@@ -392,6 +413,7 @@ def main():
392
  if not data:
393
  print("No GRPO data found.")
394
  return
 
395
  from datasets import Dataset
396
  ds = Dataset.from_list(data)
397
 
@@ -400,6 +422,7 @@ def main():
400
  MODEL_ID, trust_remote_code=True, device_map="auto"
401
  )
402
 
 
403
  cfg = GRPOConfig(
404
  output_dir=OUTPUT_DIR,
405
  learning_rate=5e-6,
@@ -415,7 +438,7 @@ def main():
415
 
416
  trainer = GRPOTrainer(
417
  model=model,
418
- ref_model=None,
419
  args=cfg,
420
  tokenizer=tok,
421
  train_dataset=ds
@@ -433,15 +456,15 @@ def _write_trainer_script():
433
  path = os.path.join("train", "grpo_train.py")
434
  with open(path, "w", encoding="utf-8") as f:
435
  f.write(TRAINER_SCRIPT)
436
- return path, "βœ… Trainer script written to train/grpo_train.py"
437
 
438
  # ---------------- Gradio Interface ----------------
439
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
440
  gr.Markdown("## ✍🏾 wilson Handwritten OCR β€” with Feedback Loop, Memory & GRPO Export")
441
 
442
  model_choice = gr.Radio(choices=list(MODEL_PATHS.keys()),
443
- value=list(MODEL_PATHS.keys())[0],
444
- label="Select OCR Model")
445
 
446
  with gr.Tab("πŸ–Ό Image Inference"):
447
  query_input = gr.Textbox(label="Custom Prompt (optional)", placeholder="Leave empty for RAW structured output")
@@ -460,6 +483,7 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
460
 
461
  raw_output = gr.Textbox(label="πŸ“œ Output (post-corrected if memory is ON)", lines=18, show_copy_button=True)
462
 
 
463
  gr.Markdown("### ✏️ Quick Feedback")
464
  correction_box = gr.Textbox(label="Your Correction (optional)", placeholder="Paste your corrected text here; leave empty if the output is perfect.", lines=8)
465
  ground_truth_box = gr.Textbox(label="Ground Truth (optional)", placeholder="If you have a reference transcription, paste it here.", lines=6)
@@ -486,12 +510,13 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
486
  audio_btn.click(fn=save_as_audio, inputs=[raw_output], outputs=[audio_file])
487
 
488
  def _clear():
489
- return ("", None, "", MAX_NEW_TOKENS_DEFAULT, 0.1, 1.0, 0, 1.0, True, "", "", "")
490
  clear_btn.click(
491
  fn=_clear,
492
  outputs=[raw_output, image_input, query_input, max_new_tokens, temperature, top_p, top_k, repetition_penalty, use_memory, correction_box, ground_truth_box, feedback_status]
493
  )
494
 
 
495
  btn_good.click(
496
  fn=lambda img, mc, prmpt, pred, corr, gt: save_feedback(img, mc, prmpt, pred, corr, gt, reward=1),
497
  inputs=[image_input, model_choice, query_input, raw_output, correction_box, ground_truth_box],
@@ -527,51 +552,38 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
527
 
528
  with gr.Tab("✏️ Feedback & Memory"):
529
  gr.Markdown("""
530
- **Pipeline**
531
- 1) Save feedback (πŸ‘ / πŸ‘Ž) and add corrections.
532
- 2) Click **Build/Refresh Memory** to generate auto-fix rules from positive feedback.
533
  3) Keep **Enable Memory Post-correction** checked on inference/eval tabs.
534
  """)
535
  build_mem_btn = gr.Button("🧠 Build/Refresh Memory from Feedback")
536
  mem_status = gr.Markdown()
537
  build_mem_btn.click(fn=compile_memory_rules, outputs=[mem_status])
538
 
 
539
  csv_status = gr.Markdown()
540
- gr.Markdown("---")
541
- gr.Markdown("### ⬇️ Download Feedback Data")
542
- with gr.Row():
543
- download_csv_btn = gr.Button("⬇️ Download Feedback as CSV")
544
- download_csv_file = gr.File(label="CSV File")
545
- download_csv_btn.click(fn=get_csv_file, outputs=[download_csv_file])
546
 
547
  with gr.Tab("πŸ§ͺ GRPO / Dataset"):
548
  gr.Markdown("""
549
  **GRPO Fine-tuning** (run offline or in a training Space):
550
  - Click **Export GRPO Preferences** to produce `data/grpo_prefs.jsonl` of (prompt, chosen, rejected).
551
  - Click **Write Trainer Script** to create `train/grpo_train.py`.
552
- - Then run:
553
  ```bash
554
  pip install trl accelerate peft transformers datasets
555
  python train/grpo_train.py
 
 
 
 
 
 
556
 
 
 
 
557
 
558
-
559
- """)
560
- export_grpo_btn = gr.Button("πŸ“¦ Export GRPO Preferences")
561
- grpo_status = gr.Markdown()
562
- export_grpo_file = gr.File(label="GRPO Preferences File")
563
- write_trainer_btn = gr.Button("πŸ“œ Write Trainer Script")
564
- trainer_status = gr.Markdown()
565
- trainer_file = gr.File(label="Trainer Script File")
566
- export_grpo_btn.click(fn=export_grpo_preferences, outputs=[grpo_status])
567
- export_grpo_btn.click(fn=get_grpo_file, outputs=[export_grpo_file])
568
- write_trainer_btn.click(fn=_write_trainer_script, outputs=[trainer_file, trainer_status])
569
-
570
-
571
-
572
-
573
-
574
- ...
575
  if __name__ == "__main__":
576
- # This line must be indented with 4 spaces or a single tab
577
- demo.queue(max_size=50).launch(share=True)
 
1
  # app.py β€” HTR Space with Feedback Loop, Memory Post-Correction, and GRPO Export
2
+
3
  import os, time, json, hashlib, difflib, uuid, csv
4
  from datetime import datetime
5
  from collections import Counter, defaultdict
 
18
 
19
  # ---------------- Storage & Paths ----------------
20
  os.makedirs("data", exist_ok=True)
21
+ FEEDBACK_PATH = "data/feedback.jsonl" # raw feedback log (per sample)
22
  MEMORY_RULES_PATH = "data/memory_rules.json" # compiled post-correction rules
23
+ GRPO_EXPORT_PATH = "data/grpo_prefs.jsonl" # preference pairs for GRPO
24
+ CSV_EXPORT_PATH = "data/feedback.csv" # optional tabular export
25
 
26
  # ---------------- Models ----------------
27
  MODEL_PATHS = {
28
+ "Model 1 (Complex handwrittings )": ("prithivMLmods/Qwen2.5-VL-7B-Abliterated-Caption-it", Qwen2_5_VLForConditionalGeneration),
29
+ "Model 2 (simple and scanned handwritting )": ("nanonets/Nanonets-OCR-s", Qwen2_5_VLForConditionalGeneration),
30
  }
31
+ # Model 3 removed to conserve memory.
32
 
33
  MAX_NEW_TOKENS_DEFAULT = 512
34
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
111
  return (text or "").strip()
112
 
113
  def _hash_image(image: Image.Image) -> str:
114
+ # stable hash for dedup / linking feedback to the same page
115
  img_bytes = image.tobytes()
116
  return hashlib.sha256(img_bytes).hexdigest()
117
 
 
133
  if not enabled or not text:
134
  return text
135
  rules = _load_memory_rules()
136
+ # 1) Model-specific replacements
137
  by_model = rules.get("by_model", {}).get(model_choice, {})
138
  for wrong, right in by_model.items():
139
  if wrong and right:
140
  text = text.replace(wrong, right)
141
+ # 2) Global replacements
142
  for wrong, right in rules.get("global", {}).items():
143
  if wrong and right:
144
  text = text.replace(wrong, right)
145
  return text
146
 
147
  def _compile_rules_from_feedback(min_count: int = 2, max_phrase_len: int = 40):
148
+ """
149
+ Build replacement rules by mining feedback pairs (prediction -> correction).
150
+ We extract phrases that consistently changed, with frequency >= min_count.
151
+ """
152
  changes_counter_global = Counter()
153
  changes_counter_by_model = defaultdict(Counter)
154
 
 
161
  row = json.loads(line)
162
  except Exception:
163
  continue
164
+ if row.get("reward", 0) < 1: # only learn from thumbs-up or explicit 'accepted_correction'
165
  continue
166
  pred = _safe_text(row.get("prediction", ""))
167
  corr = _safe_text(row.get("correction", "")) or _safe_text(row.get("ground_truth", ""))
168
  if not pred or not corr:
169
  continue
170
  model_choice = row.get("model_choice", "")
171
+ # Extract ops
172
  s = difflib.SequenceMatcher(None, pred, corr)
173
  for tag, i1, i2, j1, j2 in s.get_opcodes():
174
  if tag in ("replace", "delete", "insert"):
175
  wrong = pred[i1:i2]
176
  right = corr[j1:j2]
177
+ # keep short-ish tokens/phrases
178
  if 0 < len(wrong) <= max_phrase_len or 0 < len(right) <= max_phrase_len:
179
  if wrong.strip():
180
  changes_counter_global[(wrong, right)] += 1
 
182
  changes_counter_by_model[model_choice][(wrong, right)] += 1
183
 
184
  rules = {"global": {}, "by_model": {}}
185
+ # Global
186
  for (wrong, right), cnt in changes_counter_global.items():
187
  if cnt >= min_count and wrong and right and wrong != right:
188
  rules["global"][wrong] = right
189
+ # Per model
190
  for model_choice, ctr in changes_counter_by_model.items():
191
  rules["by_model"].setdefault(model_choice, {})
192
  for (wrong, right), cnt in ctr.items():
 
202
  temperature: float = 0.1, top_p: float = 1.0, top_k: int = 0, repetition_penalty: float = 1.0,
203
  use_memory: bool = True,
204
  progress=gr.Progress(track_tqdm=True)):
205
+ if image is None: return "Please upload or capture an image."
206
+ if model_choice not in _loaded_models: return f"Invalid model: {model_choice}"
 
 
207
  processor, model, tokenizer = _loaded_processors[model_choice], _loaded_models[model_choice], getattr(_loaded_processors[model_choice], "tokenizer", None)
208
  prompt = _default_prompt(query)
209
  batch = _build_inputs(processor, tokenizer, image, prompt).to(device)
 
211
  output_ids = model.generate(**batch, max_new_tokens=max_new_tokens, do_sample=False,
212
  temperature=temperature, top_p=top_p, top_k=top_k, repetition_penalty=repetition_penalty)
213
  raw = _decode_text(model, processor, tokenizer, output_ids, prompt).replace("<|im_end|>", "").strip()
214
+ # Apply memory post-correction
215
  post = _apply_memory(raw, model_choice, use_memory)
216
  return post
217
 
218
  # ---------------- Export Helpers ----------------
219
  def save_as_pdf(text):
220
  text = _safe_text(text)
221
+ if not text: return None
 
222
  doc = SimpleDocTemplate("output.pdf")
223
  flowables = [Paragraph(t, getSampleStyleSheet()["Normal"]) for t in text.splitlines() if t != ""]
224
+ if not flowables: flowables = [Paragraph(" ", getSampleStyleSheet()["Normal"])]
 
225
  doc.build(flowables)
226
  return "output.pdf"
227
 
228
  def save_as_word(text):
229
  text = _safe_text(text)
230
+ if not text: return None
 
231
  doc = Document()
232
  for line in text.splitlines():
233
  doc.add_paragraph(line)
 
236
 
237
  def save_as_audio(text):
238
  text = _safe_text(text)
239
+ if not text: return None
 
240
  try:
241
  tts = gTTS(text)
242
  tts.save("output.mp3")
 
247
 
248
  # ---------------- Metrics Function ----------------
249
  def calculate_cer_score(ground_truth: str, prediction: str) -> str:
250
+ """
251
+ Calculates the Character Error Rate (CER).
252
+ A CER of 0.0 means the prediction is perfect.
253
+ """
254
  if not ground_truth or not prediction:
255
  return "Cannot calculate CER: Missing ground truth or prediction."
256
  ground_truth_cleaned = " ".join(ground_truth.strip().split())
 
264
  f.write(json.dumps(obj, ensure_ascii=False) + "\n")
265
 
266
  def _export_csv():
267
+ # optional: CSV summary for spreadsheet views
268
  if not os.path.exists(FEEDBACK_PATH):
269
+ return CSV_EXPORT_PATH if os.path.exists(CSV_EXPORT_PATH) else None
270
  rows = []
271
  with open(FEEDBACK_PATH, "r", encoding="utf-8") as f:
272
  for line in f:
 
276
  pass
277
  if not rows:
278
  return None
279
+ keys = ["id","timestamp","model_choice","image_sha256","prompt","prediction","correction","ground_truth","reward","cer"]
280
  with open(CSV_EXPORT_PATH, "w", newline="", encoding="utf-8") as f:
281
  w = csv.DictWriter(f, fieldnames=keys)
282
  w.writeheader()
 
287
 
288
  def save_feedback(image: Image.Image, model_choice: str, prompt: str,
289
  prediction: str, correction: str, ground_truth: str, reward: int):
290
+ """
291
+ reward: 1 = good/accepted, 0 = neutral, -1 = bad
292
+ """
293
  if image is None:
294
+ return "Please provide the image again to link feedback.", 0
295
  if not prediction and not correction and not ground_truth:
296
+ return "Nothing to save.", 0
297
 
298
  image_hash = _hash_image(image)
299
+ # best target = correction, else ground_truth, else prediction
300
  target = _safe_text(correction) or _safe_text(ground_truth)
301
  pred = _safe_text(prediction)
302
  cer_score = None
 
319
  "cer": float(cer_score) if cer_score is not None else None,
320
  }
321
  _append_jsonl(FEEDBACK_PATH, row)
322
+ return f"βœ… Feedback saved (reward={reward}).", 1
323
 
324
  def compile_memory_rules():
325
  _compile_rules_from_feedback(min_count=2, max_phrase_len=60)
326
  return "βœ… Memory rules recompiled from positive feedback."
327
 
328
  def export_grpo_preferences():
329
+ """
330
+ Build preference pairs for GRPO training:
331
+ - chosen: correction/ground_truth when present
332
+ - rejected: original prediction
333
+ """
334
  if not os.path.exists(FEEDBACK_PATH):
335
  return "No feedback to export."
336
  count = 0
 
345
  corr = _safe_text(row.get("correction", "")) or _safe_text(row.get("ground_truth", ""))
346
  prompt = _safe_text(row.get("prompt", "")) or "Transcribe the image exactly."
347
  if corr and pred and corr != pred and row.get("reward", 0) >= 0:
348
+ # One preference datapoint
349
  out = {
350
  "prompt": prompt,
351
  "image_sha256": row.get("image_sha256", ""),
 
357
  count += 1
358
  return f"βœ… Exported {count} GRPO preference pairs to {GRPO_EXPORT_PATH}."
359
 
360
+ def export_csv():
361
+ p = _export_csv()
362
+ if p:
363
+ return f"βœ… CSV exported: {p}"
364
+ return "No data to export."
 
 
 
 
 
365
 
366
  # ---------------- Evaluation Orchestration ----------------
367
  @spaces.GPU
 
390
  OUTPUT_DIR = os.environ.get("OUTPUT_DIR", "grpo_output")
391
  DATA_PATH = os.environ.get("DATA_PATH", "data/grpo_prefs.jsonl")
392
 
393
+ # Our jsonl: each line has prompt, chosen, rejected (and image_sha256/model_choice optionally)
394
+ # We'll format as required by TRL: prompt + responses with one preferred
395
+
396
  def _jsonl_dataset(jsonl_path):
397
  data = []
398
  with open(jsonl_path, "r", encoding="utf-8") as f:
 
413
  if not data:
414
  print("No GRPO data found.")
415
  return
416
+ # Create a HuggingFace datasets Dataset from memory
417
  from datasets import Dataset
418
  ds = Dataset.from_list(data)
419
 
 
422
  MODEL_ID, trust_remote_code=True, device_map="auto"
423
  )
424
 
425
+ # Minimal config β€” tune to your GPU
426
  cfg = GRPOConfig(
427
  output_dir=OUTPUT_DIR,
428
  learning_rate=5e-6,
 
438
 
439
  trainer = GRPOTrainer(
440
  model=model,
441
+ ref_model=None, # let TRL create a frozen copy internally
442
  args=cfg,
443
  tokenizer=tok,
444
  train_dataset=ds
 
456
  path = os.path.join("train", "grpo_train.py")
457
  with open(path, "w", encoding="utf-8") as f:
458
  f.write(TRAINER_SCRIPT)
459
+ return path
460
 
461
  # ---------------- Gradio Interface ----------------
462
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
463
  gr.Markdown("## ✍🏾 wilson Handwritten OCR β€” with Feedback Loop, Memory & GRPO Export")
464
 
465
  model_choice = gr.Radio(choices=list(MODEL_PATHS.keys()),
466
+ value=list(MODEL_PATHS.keys())[0],
467
+ label="Select OCR Model")
468
 
469
  with gr.Tab("πŸ–Ό Image Inference"):
470
  query_input = gr.Textbox(label="Custom Prompt (optional)", placeholder="Leave empty for RAW structured output")
 
483
 
484
  raw_output = gr.Textbox(label="πŸ“œ Output (post-corrected if memory is ON)", lines=18, show_copy_button=True)
485
 
486
+ # Quick Feedback strip
487
  gr.Markdown("### ✏️ Quick Feedback")
488
  correction_box = gr.Textbox(label="Your Correction (optional)", placeholder="Paste your corrected text here; leave empty if the output is perfect.", lines=8)
489
  ground_truth_box = gr.Textbox(label="Ground Truth (optional)", placeholder="If you have a reference transcription, paste it here.", lines=6)
 
510
  audio_btn.click(fn=save_as_audio, inputs=[raw_output], outputs=[audio_file])
511
 
512
  def _clear():
513
+ return ("", None, "", MAX_NEW_TOKENS_DEFAULT, 0.1, 1.0, 0, 1.0, True, "", "", "",)
514
  clear_btn.click(
515
  fn=_clear,
516
  outputs=[raw_output, image_input, query_input, max_new_tokens, temperature, top_p, top_k, repetition_penalty, use_memory, correction_box, ground_truth_box, feedback_status]
517
  )
518
 
519
+ # Quick feedback save
520
  btn_good.click(
521
  fn=lambda img, mc, prmpt, pred, corr, gt: save_feedback(img, mc, prmpt, pred, corr, gt, reward=1),
522
  inputs=[image_input, model_choice, query_input, raw_output, correction_box, ground_truth_box],
 
552
 
553
  with gr.Tab("✏️ Feedback & Memory"):
554
  gr.Markdown("""
555
+ **Pipeline**
556
+ 1) Save feedback (πŸ‘ / πŸ‘Ž) and add corrections.
557
+ 2) Click **Build/Refresh Memory** to generate auto-fix rules from positive feedback.
558
  3) Keep **Enable Memory Post-correction** checked on inference/eval tabs.
559
  """)
560
  build_mem_btn = gr.Button("🧠 Build/Refresh Memory from Feedback")
561
  mem_status = gr.Markdown()
562
  build_mem_btn.click(fn=compile_memory_rules, outputs=[mem_status])
563
 
564
+ csv_btn = gr.Button("πŸ“€ Export Feedback as CSV")
565
  csv_status = gr.Markdown()
566
+ csv_btn.click(fn=export_csv, outputs=[csv_status])
 
 
 
 
 
567
 
568
  with gr.Tab("πŸ§ͺ GRPO / Dataset"):
569
  gr.Markdown("""
570
  **GRPO Fine-tuning** (run offline or in a training Space):
571
  - Click **Export GRPO Preferences** to produce `data/grpo_prefs.jsonl` of (prompt, chosen, rejected).
572
  - Click **Write Trainer Script** to create `train/grpo_train.py`.
573
+ - Then run:
574
  ```bash
575
  pip install trl accelerate peft transformers datasets
576
  python train/grpo_train.py
577
+ ```
578
+ Set `BASE_MODEL`/`OUTPUT_DIR` env vars if you like.
579
+ """)
580
+ grpo_btn = gr.Button("πŸ“¦ Export GRPO Preferences")
581
+ grpo_status = gr.Markdown()
582
+ grpo_btn.click(fn=export_grpo_preferences, outputs=[grpo_status])
583
 
584
+ write_script_btn = gr.Button("πŸ“ Write grpo_train.py")
585
+ write_script_status = gr.Markdown()
586
+ write_script_btn.click(fn=lambda: f"βœ… Trainer script written to `{_write_trainer_script()}`", outputs=[write_script_status])
587
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
588
  if __name__ == "__main__":
589
+ demo.queue(max_size=50).launch(share=True)