Emeritus-21 committed on
Commit 5b9541c · verified · 1 Parent(s): 5754029

Update app.py

Files changed (1)
app.py  +396 -28
app.py CHANGED
@@ -1,5 +1,10 @@
- import os, time
  from threading import Thread
  import gradio as gr
  import spaces
  from PIL import Image
@@ -8,15 +13,22 @@ from transformers import AutoProcessor, AutoModelForImageTextToText, Qwen2_5_VLF
  from reportlab.platypus import SimpleDocTemplate, Paragraph
  from reportlab.lib.styles import getSampleStyleSheet
  from docx import Document
- from gtts import gTTS
  from jiwer import cer

  # ---------------- Models ----------------
  MODEL_PATHS = {
      "Model 1 (Complex handwrittings )": ("prithivMLmods/Qwen2.5-VL-7B-Abliterated-Caption-it", Qwen2_5_VLForConditionalGeneration),
      "Model 2 (simple and scanned handwritting )": ("nanonets/Nanonets-OCR-s", Qwen2_5_VLForConditionalGeneration),
  }
- # Model 3 has been removed to conserve memory.

  MAX_NEW_TOKENS_DEFAULT = 512
  device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -48,7 +60,8 @@ def warmup(progress=gr.Progress(track_tqdm=True)):
          messages = [{"role": "user", "content": [{"type": "text", "text": "Warmup."}]}]
          chat_prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) if tokenizer and hasattr(tokenizer, "apply_chat_template") else "Warmup."
          inputs = processor(text=[chat_prompt], images=None, return_tensors="pt").to(device)
-         with torch.inference_mode(): _ = model.generate(**inputs, max_new_tokens=1)
          return f"GPU warm and {default_model_choice} ready."
      except Exception as e:
          return f"Warmup skipped: {e}"
@@ -81,7 +94,8 @@ def _decode_text(model, processor, tokenizer, output_ids, prompt: str):
      return str(output_ids).strip()

  def _default_prompt(query: str | None) -> str:
-     if query and query.strip(): return query.strip()
      return (
          "You are a professional Handwritten OCR system.\n"
          "TASK: Read the handwritten image and transcribe the text EXACTLY as written.\n"
@@ -93,11 +107,100 @@ def _default_prompt(query: str | None) -> str:
          "Return RAW transcription only."
      )

  # ---------------- OCR Function ----------------
  @spaces.GPU
  def ocr_image(image: Image.Image, model_choice: str, query: str = None,
                max_new_tokens: int = MAX_NEW_TOKENS_DEFAULT,
                temperature: float = 0.1, top_p: float = 1.0, top_k: int = 0, repetition_penalty: float = 1.0,
                progress=gr.Progress(track_tqdm=True)):
      if image is None: return "Please upload or capture an image."
      if model_choice not in _loaded_models: return f"Invalid model: {model_choice}"
@@ -107,11 +210,12 @@ def ocr_image(image: Image.Image, model_choice: str, query: str = None,
      with torch.inference_mode():
          output_ids = model.generate(**batch, max_new_tokens=max_new_tokens, do_sample=False,
                                      temperature=temperature, top_p=top_p, top_k=top_k, repetition_penalty=repetition_penalty)
-     return _decode_text(model, processor, tokenizer, output_ids, prompt).replace("<|im_end|>", "").strip()

  # ---------------- Export Helpers ----------------
- def _safe_text(text: str) -> str: return (text or "").strip()
-
  def save_as_pdf(text):
      text = _safe_text(text)
      if not text: return None
@@ -125,76 +229,304 @@ def save_as_word(text):
      text = _safe_text(text)
      if not text: return None
      doc = Document()
-     for line in text.splitlines(): doc.add_paragraph(line)
      doc.save("output.docx")
      return "output.docx"

  def save_as_audio(text):
      text = _safe_text(text)
      if not text: return None
-     try:
          tts = gTTS(text)
          tts.save("output.mp3")
          return "output.mp3"
-     except Exception as e:
          print(f"gTTS failed: {e}")
          return None

  # ---------------- Metrics Function ----------------
  def calculate_cer_score(ground_truth: str, prediction: str) -> str:
      """
-     Calculates the Character Error Rate (CER) between two strings.
      A CER of 0.0 means the prediction is perfect.
      """
      if not ground_truth or not prediction:
          return "Cannot calculate CER: Missing ground truth or prediction."
-
      ground_truth_cleaned = " ".join(ground_truth.strip().split())
      prediction_cleaned = " ".join(prediction.strip().split())
-
      error_rate = cer(ground_truth_cleaned, prediction_cleaned)
      return f"Character Error Rate (CER): {error_rate:.4f}"
  # ---------------- Evaluation Orchestration ----------------
  @spaces.GPU
  def perform_evaluation(image: Image.Image, model_name: str, ground_truth: str,
-                        max_new_tokens: int, temperature: float, top_p: float, top_k: int, repetition_penalty: float):
      if image is None or not ground_truth:
          return "Please upload an image and provide the ground truth.", "N/A"
-
-     prediction = ocr_image(image, model_name, max_new_tokens=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k, repetition_penalty=repetition_penalty)
-
      cer_score = calculate_cer_score(ground_truth, prediction)
-
      return prediction, cer_score
  # ---------------- Gradio Interface ----------------
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
-     gr.Markdown("## ✍🏾 wilson Handwritten OCR")
-     model_choice = gr.Radio(choices=list(MODEL_PATHS.keys()), value=list(MODEL_PATHS.keys())[0], label="Select OCR Model")

      with gr.Tab("🖼 Image Inference"):
          query_input = gr.Textbox(label="Custom Prompt (optional)", placeholder="Leave empty for RAW structured output")
          image_input = gr.Image(type="pil", label="Upload / Capture Handwritten Image", sources=["upload", "webcam"])
          with gr.Accordion("⚙️ Advanced Options", open=False):
              max_new_tokens = gr.Slider(1, 2048, value=MAX_NEW_TOKENS_DEFAULT, step=1, label="Max new tokens")
              temperature = gr.Slider(0.1, 2.0, value=0.1, step=0.05, label="Temperature")
              top_p = gr.Slider(0.05, 1.0, value=1.0, step=0.05, label="Top-p (nucleus)")
              top_k = gr.Slider(0, 1000, value=0, step=1, label="Top-k")
              repetition_penalty = gr.Slider(0.8, 2.0, value=1.0, step=0.05, label="Repetition penalty")
          extract_btn = gr.Button("📤 Extract RAW Text", variant="primary")
          clear_btn = gr.Button("🧹 Clear")
-         raw_output = gr.Textbox(label="📜 RAW Structured Output (exact as written)", lines=18, show_copy_button=True)
          pdf_btn = gr.Button("⬇️ Download as PDF")
          word_btn = gr.Button("⬇️ Download as Word")
          audio_btn = gr.Button("🔊 Download as Audio")
          pdf_file, word_file, audio_file = gr.File(label="PDF File"), gr.File(label="Word File"), gr.File(label="Audio File")

-         extract_btn.click(fn=ocr_image, inputs=[image_input, model_choice, query_input, max_new_tokens, temperature, top_p, top_k, repetition_penalty], outputs=[raw_output], api_name="ocr_image")
          pdf_btn.click(fn=save_as_pdf, inputs=[raw_output], outputs=[pdf_file])
          word_btn.click(fn=save_as_word, inputs=[raw_output], outputs=[word_file])
          audio_btn.click(fn=save_as_audio, inputs=[raw_output], outputs=[audio_file])
-         clear_btn.click(fn=lambda: ("", None, "", MAX_NEW_TOKENS_DEFAULT, 0.1, 1.0, 0, 1.0), outputs=[raw_output, image_input, query_input, max_new_tokens, temperature, top_p, top_k, repetition_penalty])

      with gr.Tab("📊 Model Evaluation"):
          gr.Markdown("### 🔍 Evaluate Model Accuracy")
@@ -202,14 +534,15 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
          eval_ground_truth = gr.Textbox(label="Ground Truth (Correct Transcription)", lines=10, placeholder="Type or paste the correct text here.")
          eval_model_output = gr.Textbox(label="Model's Prediction", lines=10, interactive=False, show_copy_button=True)
          eval_cer_output = gr.Textbox(label="Metrics", interactive=False)
-
          with gr.Row():
              run_evaluation_btn = gr.Button("🚀 Run OCR and Evaluate", variant="primary")
              clear_evaluation_btn = gr.Button("🧹 Clear")
-
          run_evaluation_btn.click(
              fn=perform_evaluation,
-             inputs=[eval_image_input, model_choice, eval_ground_truth, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
              outputs=[eval_model_output, eval_cer_output]
          )
          clear_evaluation_btn.click(
@@ -217,5 +550,40 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
              outputs=[eval_image_input, eval_ground_truth, eval_model_output, eval_cer_output]
          )

  if __name__ == "__main__":
-     demo.queue(max_size=50).launch(share=True)
 
+ # app.py — HTR Space with Feedback Loop, Memory Post-Correction, and GRPO Export
+
+ import os, time, json, hashlib, difflib, uuid, csv
+ from datetime import datetime
+ from collections import Counter, defaultdict
  from threading import Thread
+
  import gradio as gr
  import spaces
  from PIL import Image

  from reportlab.platypus import SimpleDocTemplate, Paragraph
  from reportlab.lib.styles import getSampleStyleSheet
  from docx import Document
+ from gtts import gTTS
  from jiwer import cer

+ # ---------------- Storage & Paths ----------------
+ os.makedirs("data", exist_ok=True)
+ FEEDBACK_PATH = "data/feedback.jsonl"          # raw feedback log (per sample)
+ MEMORY_RULES_PATH = "data/memory_rules.json"   # compiled post-correction rules
+ GRPO_EXPORT_PATH = "data/grpo_prefs.jsonl"     # preference pairs for GRPO
+ CSV_EXPORT_PATH = "data/feedback.csv"          # optional tabular export
+
  # ---------------- Models ----------------
  MODEL_PATHS = {
      "Model 1 (Complex handwrittings )": ("prithivMLmods/Qwen2.5-VL-7B-Abliterated-Caption-it", Qwen2_5_VLForConditionalGeneration),
      "Model 2 (simple and scanned handwritting )": ("nanonets/Nanonets-OCR-s", Qwen2_5_VLForConditionalGeneration),
  }
+ # Model 3 removed to conserve memory.

  MAX_NEW_TOKENS_DEFAULT = 512
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
          messages = [{"role": "user", "content": [{"type": "text", "text": "Warmup."}]}]
          chat_prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) if tokenizer and hasattr(tokenizer, "apply_chat_template") else "Warmup."
          inputs = processor(text=[chat_prompt], images=None, return_tensors="pt").to(device)
+         with torch.inference_mode():
+             _ = model.generate(**inputs, max_new_tokens=1)
          return f"GPU warm and {default_model_choice} ready."
      except Exception as e:
          return f"Warmup skipped: {e}"
 
      return str(output_ids).strip()

  def _default_prompt(query: str | None) -> str:
+     if query and query.strip():
+         return query.strip()
      return (
          "You are a professional Handwritten OCR system.\n"
          "TASK: Read the handwritten image and transcribe the text EXACTLY as written.\n"

          "Return RAW transcription only."
      )

+ def _safe_text(text: str) -> str:
+     return (text or "").strip()
+
+ def _hash_image(image: Image.Image) -> str:
+     # stable hash for dedup / linking feedback to the same page
+     img_bytes = image.tobytes()
+     return hashlib.sha256(img_bytes).hexdigest()
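Note that `_hash_image` digests only the raw pixel buffer, so two images whose buffers happen to match but which differ in dimensions or colour mode could in principle collide. A minimal sketch of a stricter variant (hypothetical, not part of this commit) that folds the metadata into the digest:

```python
# Hypothetical stricter variant of _hash_image (not in the commit): mixes
# image mode and size into the SHA-256 digest before the pixel data.
import hashlib
from PIL import Image

def hash_image_strict(image: Image.Image) -> str:
    h = hashlib.sha256()
    h.update(f"{image.mode}:{image.size}".encode("utf-8"))  # metadata first
    h.update(image.tobytes())                               # then raw pixels
    return h.hexdigest()
```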
+
+ # ---------------- Memory: Post-correction Rules ----------------
+ def _load_memory_rules():
+     if os.path.exists(MEMORY_RULES_PATH):
+         try:
+             with open(MEMORY_RULES_PATH, "r", encoding="utf-8") as f:
+                 return json.load(f)
+         except Exception:
+             pass
+     return {"global": {}, "by_model": {}}
+
+ def _save_memory_rules(rules):
+     with open(MEMORY_RULES_PATH, "w", encoding="utf-8") as f:
+         json.dump(rules, f, ensure_ascii=False, indent=2)
+
+ def _apply_memory(text: str, model_choice: str, enabled: bool):
+     if not enabled or not text:
+         return text
+     rules = _load_memory_rules()
+     # 1) Model-specific replacements
+     by_model = rules.get("by_model", {}).get(model_choice, {})
+     for wrong, right in by_model.items():
+         if wrong and right:
+             text = text.replace(wrong, right)
+     # 2) Global replacements
+     for wrong, right in rules.get("global", {}).items():
+         if wrong and right:
+             text = text.replace(wrong, right)
+     return text
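To make the rule format concrete, here is a sketch of a compiled `data/memory_rules.json` and the substitution pass it drives; the example strings are invented, only the schema comes from the code above:

```python
# Invented example of a compiled rules file; schema as written by _save_memory_rules:
# {
#   "global":   {"recieve": "receive"},
#   "by_model": {"Model 2 (simple and scanned handwritting )": {"0ct": "Oct"}}
# }
# Assuming that file is on disk, _apply_memory runs plain substring
# replacements over the OCR text, model-specific rules first, then global:
raw = "We will recieve the parcel on 3 0ct."
fixed = _apply_memory(raw, "Model 2 (simple and scanned handwritting )", enabled=True)
# fixed == "We will receive the parcel on 3 Oct."
```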
+
+ def _compile_rules_from_feedback(min_count: int = 2, max_phrase_len: int = 40):
+     """
+     Build replacement rules by mining feedback pairs (prediction -> correction).
+     We extract phrases that consistently changed, with frequency >= min_count.
+     """
+     changes_counter_global = Counter()
+     changes_counter_by_model = defaultdict(Counter)
+
+     if not os.path.exists(FEEDBACK_PATH):
+         return
+
+     with open(FEEDBACK_PATH, "r", encoding="utf-8") as f:
+         for line in f:
+             try:
+                 row = json.loads(line)
+             except Exception:
+                 continue
+             if row.get("reward", 0) < 1:  # only learn from thumbs-up or explicit 'accepted_correction'
+                 continue
+             pred = _safe_text(row.get("prediction", ""))
+             corr = _safe_text(row.get("correction", "")) or _safe_text(row.get("ground_truth", ""))
+             if not pred or not corr:
+                 continue
+             model_choice = row.get("model_choice", "")
+             # Extract ops
+             s = difflib.SequenceMatcher(None, pred, corr)
+             for tag, i1, i2, j1, j2 in s.get_opcodes():
+                 if tag in ("replace", "delete", "insert"):
+                     wrong = pred[i1:i2]
+                     right = corr[j1:j2]
+                     # keep short-ish tokens/phrases
+                     if 0 < len(wrong) <= max_phrase_len or 0 < len(right) <= max_phrase_len:
+                         if wrong.strip():
+                             changes_counter_global[(wrong, right)] += 1
+                             if model_choice:
+                                 changes_counter_by_model[model_choice][(wrong, right)] += 1
+
+     rules = {"global": {}, "by_model": {}}
+     # Global
+     for (wrong, right), cnt in changes_counter_global.items():
+         if cnt >= min_count and wrong and right and wrong != right:
+             rules["global"][wrong] = right
+     # Per model
+     for model_choice, ctr in changes_counter_by_model.items():
+         rules["by_model"].setdefault(model_choice, {})
+         for (wrong, right), cnt in ctr.items():
+             if cnt >= min_count and wrong and right and wrong != right:
+                 rules["by_model"][model_choice][wrong] = right
+
+     _save_memory_rules(rules)
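The mining step leans on `difflib.SequenceMatcher` opcodes over each (prediction, correction) pair. A small standalone illustration of what one accepted sample yields:

```python
import difflib

pred = "The qu1ck brown fox"   # model output
corr = "The quick brown fox"   # accepted correction
s = difflib.SequenceMatcher(None, pred, corr)
for tag, i1, i2, j1, j2 in s.get_opcodes():
    if tag in ("replace", "delete", "insert"):
        print(tag, repr(pred[i1:i2]), "->", repr(corr[j1:j2]))
# prints: replace '1' -> 'i'
# A bare ('1' -> 'i') rule applied with str.replace would rewrite every '1'
# in future transcriptions, which is why the compiler gates candidate rules
# behind min_count and max_phrase_len before promoting them.
```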
+
  # ---------------- OCR Function ----------------
  @spaces.GPU
  def ocr_image(image: Image.Image, model_choice: str, query: str = None,
                max_new_tokens: int = MAX_NEW_TOKENS_DEFAULT,
                temperature: float = 0.1, top_p: float = 1.0, top_k: int = 0, repetition_penalty: float = 1.0,
+               use_memory: bool = True,
                progress=gr.Progress(track_tqdm=True)):
      if image is None: return "Please upload or capture an image."
      if model_choice not in _loaded_models: return f"Invalid model: {model_choice}"

      with torch.inference_mode():
          output_ids = model.generate(**batch, max_new_tokens=max_new_tokens, do_sample=False,
                                      temperature=temperature, top_p=top_p, top_k=top_k, repetition_penalty=repetition_penalty)
+     raw = _decode_text(model, processor, tokenizer, output_ids, prompt).replace("<|im_end|>", "").strip()
+     # Apply memory post-correction
+     post = _apply_memory(raw, model_choice, use_memory)
+     return post

  # ---------------- Export Helpers ----------------
  def save_as_pdf(text):
      text = _safe_text(text)
      if not text: return None

      text = _safe_text(text)
      if not text: return None
      doc = Document()
+     for line in text.splitlines():
+         doc.add_paragraph(line)
      doc.save("output.docx")
      return "output.docx"

  def save_as_audio(text):
      text = _safe_text(text)
      if not text: return None
+     try:
          tts = gTTS(text)
          tts.save("output.mp3")
          return "output.mp3"
+     except Exception as e:
          print(f"gTTS failed: {e}")
          return None

  # ---------------- Metrics Function ----------------
  def calculate_cer_score(ground_truth: str, prediction: str) -> str:
      """
+     Calculates the Character Error Rate (CER).
      A CER of 0.0 means the prediction is perfect.
      """
      if not ground_truth or not prediction:
          return "Cannot calculate CER: Missing ground truth or prediction."

      ground_truth_cleaned = " ".join(ground_truth.strip().split())
      prediction_cleaned = " ".join(prediction.strip().split())

      error_rate = cer(ground_truth_cleaned, prediction_cleaned)
      return f"Character Error Rate (CER): {error_rate:.4f}"

+ # ---------------- Feedback & Dataset ----------------
+ def _append_jsonl(path, obj):
+     with open(path, "a", encoding="utf-8") as f:
+         f.write(json.dumps(obj, ensure_ascii=False) + "\n")
+
+ def _export_csv():
+     # optional: CSV summary for spreadsheet views
+     if not os.path.exists(FEEDBACK_PATH):
+         return CSV_EXPORT_PATH if os.path.exists(CSV_EXPORT_PATH) else None
+     rows = []
+     with open(FEEDBACK_PATH, "r", encoding="utf-8") as f:
+         for line in f:
+             try:
+                 rows.append(json.loads(line))
+             except Exception:
+                 pass
+     if not rows:
+         return None
+     keys = ["id","timestamp","model_choice","image_sha256","prompt","prediction","correction","ground_truth","reward","cer"]
+     with open(CSV_EXPORT_PATH, "w", newline="", encoding="utf-8") as f:
+         w = csv.DictWriter(f, fieldnames=keys)
+         w.writeheader()
+         for r in rows:
+             flat = {k: r.get(k, "") for k in keys}
+             w.writerow(flat)
+     return CSV_EXPORT_PATH
+
+ def save_feedback(image: Image.Image, model_choice: str, prompt: str,
+                   prediction: str, correction: str, ground_truth: str, reward: int):
+     """
+     reward: 1 = good/accepted, 0 = neutral, -1 = bad
+     """
+     if image is None:
+         return "Please provide the image again to link feedback.", 0
+     if not prediction and not correction and not ground_truth:
+         return "Nothing to save.", 0
+
+     image_hash = _hash_image(image)
+     # best target = correction, else ground_truth, else prediction
+     target = _safe_text(correction) or _safe_text(ground_truth)
+     pred = _safe_text(prediction)
+     cer_score = None
+     if target and pred:
+         try:
+             cer_score = cer(" ".join(target.split()), " ".join(pred.split()))
+         except Exception:
+             cer_score = None
+
+     row = {
+         "id": str(uuid.uuid4()),
+         "timestamp": datetime.utcnow().isoformat(),
+         "model_choice": model_choice or "",
+         "image_sha256": image_hash,
+         "prompt": _safe_text(prompt),
+         "prediction": pred,
+         "correction": _safe_text(correction),
+         "ground_truth": _safe_text(ground_truth),
+         "reward": int(reward),
+         "cer": float(cer_score) if cer_score is not None else None,
+     }
+     _append_jsonl(FEEDBACK_PATH, row)
+     return f"✅ Feedback saved (reward={reward}).", 1
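For reference, one appended line of `data/feedback.jsonl` would look roughly like this; the field values are invented for illustration, but the keys are exactly those written by `save_feedback`:

```python
import json

# Illustrative record (values invented); keys match save_feedback's row dict.
example = {
    "id": "6f9c0a2e-0000-0000-0000-000000000000",  # uuid4
    "timestamp": "2025-01-01T12:00:00",
    "model_choice": "Model 2 (simple and scanned handwritting )",
    "image_sha256": "ab12...",                     # _hash_image digest
    "prompt": "",
    "prediction": "We will recieve it",
    "correction": "We will receive it",
    "ground_truth": "",
    "reward": 1,
    "cer": 0.1111,                                 # cer(correction, prediction)
}
print(json.dumps(example, ensure_ascii=False))     # one feedback.jsonl line
```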
+
+ def compile_memory_rules():
+     _compile_rules_from_feedback(min_count=2, max_phrase_len=60)
+     return "✅ Memory rules recompiled from positive feedback."
+
+ def export_grpo_preferences():
+     """
+     Build preference pairs for GRPO training:
+     - chosen: correction/ground_truth when present
+     - rejected: original prediction
+     """
+     if not os.path.exists(FEEDBACK_PATH):
+         return "No feedback to export."
+     count = 0
+     with open(GRPO_EXPORT_PATH, "w", encoding="utf-8") as out_f:
+         with open(FEEDBACK_PATH, "r", encoding="utf-8") as f:
+             for line in f:
+                 try:
+                     row = json.loads(line)
+                 except Exception:
+                     continue
+                 pred = _safe_text(row.get("prediction", ""))
+                 corr = _safe_text(row.get("correction", "")) or _safe_text(row.get("ground_truth", ""))
+                 prompt = _safe_text(row.get("prompt", "")) or "Transcribe the image exactly."
+                 if corr and pred and corr != pred and row.get("reward", 0) >= 0:
+                     # One preference datapoint
+                     out = {
+                         "prompt": prompt,
+                         "image_sha256": row.get("image_sha256", ""),
+                         "chosen": corr,
+                         "rejected": pred,
+                         "model_choice": row.get("model_choice", "")
+                     }
+                     out_f.write(json.dumps(out, ensure_ascii=False) + "\n")
+                     count += 1
+     return f"✅ Exported {count} GRPO preference pairs to {GRPO_EXPORT_PATH}."
+
+ def export_csv():
+     p = _export_csv()
+     if p:
+         return f"✅ CSV exported: {p}"
+     return "No data to export."
+
  # ---------------- Evaluation Orchestration ----------------
  @spaces.GPU
  def perform_evaluation(image: Image.Image, model_name: str, ground_truth: str,
+                        max_new_tokens: int, temperature: float, top_p: float, top_k: int, repetition_penalty: float,
+                        use_memory: bool = True):
      if image is None or not ground_truth:
          return "Please upload an image and provide the ground truth.", "N/A"
+     prediction = ocr_image(image, model_name, max_new_tokens=max_new_tokens,
+                            temperature=temperature, top_p=top_p, top_k=top_k, repetition_penalty=repetition_penalty,
+                            use_memory=use_memory)
      cer_score = calculate_cer_score(ground_truth, prediction)

      return prediction, cer_score

+ # ---------------- GRPO Trainer Script Writer ----------------
+ TRAINER_SCRIPT = r"""# grpo_train.py — Offline GRPO training with TRL (run separately)
+ # pip install trl accelerate peft transformers datasets
+ # This script expects data/grpo_prefs.jsonl produced by the app.
+
+ import os, json
+ from datasets import load_dataset
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ from trl import GRPOConfig, GRPOTrainer
+
+ MODEL_ID = os.environ.get("BASE_MODEL", "Qwen/Qwen2.5-VL-7B-Instruct")  # change if needed
+ OUTPUT_DIR = os.environ.get("OUTPUT_DIR", "grpo_output")
+ DATA_PATH = os.environ.get("DATA_PATH", "data/grpo_prefs.jsonl")
+
+ # Our jsonl: each line has prompt, chosen, rejected (and image_sha256/model_choice optionally)
+ # We'll format as required by TRL: prompt + responses with one preferred
+
+ def _jsonl_dataset(jsonl_path):
+     data = []
+     with open(jsonl_path, "r", encoding="utf-8") as f:
+         for line in f:
+             try:
+                 row = json.loads(line)
+             except Exception:
+                 continue
+             prompt = row.get("prompt", "")
+             chosen = row.get("chosen", "")
+             rejected = row.get("rejected", "")
+             if prompt and chosen and rejected:
+                 data.append({"prompt": prompt, "chosen": chosen, "rejected": rejected})
+     return data
+
+ def main():
+     data = _jsonl_dataset(DATA_PATH)
+     if not data:
+         print("No GRPO data found.")
+         return
+     # Create a HuggingFace datasets Dataset from memory
+     from datasets import Dataset
+     ds = Dataset.from_list(data)
+
+     tok = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
+     model = AutoModelForCausalLM.from_pretrained(
+         MODEL_ID, trust_remote_code=True, device_map="auto"
+     )
+
+     # Minimal config — tune to your GPU
+     cfg = GRPOConfig(
+         output_dir=OUTPUT_DIR,
+         learning_rate=5e-6,
+         per_device_train_batch_size=1,
+         gradient_accumulation_steps=8,
+         num_train_epochs=1,
+         logging_steps=10,
+         save_steps=200,
+         max_prompt_length=512,
+         max_completion_length=768,
+         bf16=True
+     )
+
+     trainer = GRPOTrainer(
+         model=model,
+         ref_model=None,  # let TRL create a frozen copy internally
+         args=cfg,
+         tokenizer=tok,
+         train_dataset=ds
+     )
+     trainer.train()
+     trainer.save_model(OUTPUT_DIR)
+     print("✅ GRPO training complete. LoRA/weights saved to", OUTPUT_DIR)
+
+ if __name__ == "__main__":
+     main()
+ """
+
+ def _write_trainer_script():
+     os.makedirs("train", exist_ok=True)
+     path = os.path.join("train", "grpo_train.py")
+     with open(path, "w", encoding="utf-8") as f:
+         f.write(TRAINER_SCRIPT)
+     return path
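A caveat on the embedded script: in TRL, `GRPOTrainer` is driven by reward functions over a prompt dataset rather than by `(chosen, rejected)` pairs, and it does not take `ref_model`/`tokenizer` keyword arguments, so `grpo_train.py` as written would likely need adjusting against the installed TRL version. Pairwise preferences like the ones exported here map more directly onto TRL's `DPOTrainer`, which consumes exactly the `prompt`/`chosen`/`rejected` format. A minimal, untested sketch under that assumption (also note the exported pairs reference images only by hash, so any text-only preference tuning sees the prompt text alone):

```python
# Untested sketch: preference pairs -> TRL DPOTrainer (assumes a recent TRL;
# the base-model id here is an example, not the one used by the Space).
import json
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DPOConfig, DPOTrainer

rows = [json.loads(l) for l in open("data/grpo_prefs.jsonl", encoding="utf-8")]
ds = Dataset.from_list([{k: r[k] for k in ("prompt", "chosen", "rejected")} for r in rows])

tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-7B-Instruct", trust_remote_code=True)

trainer = DPOTrainer(
    model=model,
    args=DPOConfig(output_dir="dpo_output", per_device_train_batch_size=1,
                   gradient_accumulation_steps=8),
    train_dataset=ds,
    processing_class=tok,  # `tokenizer=` on older TRL releases
)
trainer.train()
```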
+
  # ---------------- Gradio Interface ----------------
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
+     gr.Markdown("## ✍🏾 wilson Handwritten OCR — with Feedback Loop, Memory & GRPO Export")
+
+     model_choice = gr.Radio(choices=list(MODEL_PATHS.keys()),
+                             value=list(MODEL_PATHS.keys())[0],
+                             label="Select OCR Model")

      with gr.Tab("🖼 Image Inference"):
          query_input = gr.Textbox(label="Custom Prompt (optional)", placeholder="Leave empty for RAW structured output")
          image_input = gr.Image(type="pil", label="Upload / Capture Handwritten Image", sources=["upload", "webcam"])
+         use_memory = gr.Checkbox(value=True, label="Enable Memory Post-correction (auto-fix known mistakes)")
+
          with gr.Accordion("⚙️ Advanced Options", open=False):
              max_new_tokens = gr.Slider(1, 2048, value=MAX_NEW_TOKENS_DEFAULT, step=1, label="Max new tokens")
              temperature = gr.Slider(0.1, 2.0, value=0.1, step=0.05, label="Temperature")
              top_p = gr.Slider(0.05, 1.0, value=1.0, step=0.05, label="Top-p (nucleus)")
              top_k = gr.Slider(0, 1000, value=0, step=1, label="Top-k")
              repetition_penalty = gr.Slider(0.8, 2.0, value=1.0, step=0.05, label="Repetition penalty")
+
          extract_btn = gr.Button("📤 Extract RAW Text", variant="primary")
          clear_btn = gr.Button("🧹 Clear")
+
+         raw_output = gr.Textbox(label="📜 Output (post-corrected if memory is ON)", lines=18, show_copy_button=True)
+
+         # Quick Feedback strip
+         gr.Markdown("### ✏️ Quick Feedback")
+         correction_box = gr.Textbox(label="Your Correction (optional)", placeholder="Paste your corrected text here; leave empty if the output is perfect.", lines=8)
+         ground_truth_box = gr.Textbox(label="Ground Truth (optional)", placeholder="If you have a reference transcription, paste it here.", lines=6)
+
+         with gr.Row():
+             btn_good = gr.Button("👍 Accept (Save Feedback as Correct)", variant="primary")
+             btn_bad = gr.Button("👎 Bad (Save Feedback as Incorrect)")
+
+         feedback_status = gr.Markdown()
+
          pdf_btn = gr.Button("⬇️ Download as PDF")
          word_btn = gr.Button("⬇️ Download as Word")
          audio_btn = gr.Button("🔊 Download as Audio")
          pdf_file, word_file, audio_file = gr.File(label="PDF File"), gr.File(label="Word File"), gr.File(label="Audio File")

+         extract_btn.click(
+             fn=ocr_image,
+             inputs=[image_input, model_choice, query_input, max_new_tokens, temperature, top_p, top_k, repetition_penalty, use_memory],
+             outputs=[raw_output],
+             api_name="ocr_image"
+         )
          pdf_btn.click(fn=save_as_pdf, inputs=[raw_output], outputs=[pdf_file])
          word_btn.click(fn=save_as_word, inputs=[raw_output], outputs=[word_file])
          audio_btn.click(fn=save_as_audio, inputs=[raw_output], outputs=[audio_file])
+
+         def _clear():
+             return ("", None, "", MAX_NEW_TOKENS_DEFAULT, 0.1, 1.0, 0, 1.0, True, "", "", "",)
+         clear_btn.click(
+             fn=_clear,
+             outputs=[raw_output, image_input, query_input, max_new_tokens, temperature, top_p, top_k, repetition_penalty, use_memory, correction_box, ground_truth_box, feedback_status]
+         )
+
+         # Quick feedback save
+         btn_good.click(
+             fn=lambda img, mc, prmpt, pred, corr, gt: save_feedback(img, mc, prmpt, pred, corr, gt, reward=1),
+             inputs=[image_input, model_choice, query_input, raw_output, correction_box, ground_truth_box],
+             outputs=[feedback_status]
+         )
+         btn_bad.click(
+             fn=lambda img, mc, prmpt, pred, corr, gt: save_feedback(img, mc, prmpt, pred, corr, gt, reward=-1),
+             inputs=[image_input, model_choice, query_input, raw_output, correction_box, ground_truth_box],
+             outputs=[feedback_status]
+         )
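One wiring detail worth flagging: `save_feedback` returns a `(message, flag)` tuple, but both click handlers above bind only the single `feedback_status` output, so Gradio appears to receive one more return value than it has components. A small sketch of a fix (not in the commit) that drops the flag:

```python
# Sketch (not in the commit): unpack save_feedback's (message, flag) tuple so
# the single feedback_status output receives just the message string.
btn_good.click(
    fn=lambda img, mc, prmpt, pred, corr, gt: save_feedback(img, mc, prmpt, pred, corr, gt, reward=1)[0],
    inputs=[image_input, model_choice, query_input, raw_output, correction_box, ground_truth_box],
    outputs=[feedback_status],
)
```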

      with gr.Tab("📊 Model Evaluation"):
          gr.Markdown("### 🔍 Evaluate Model Accuracy")

          eval_ground_truth = gr.Textbox(label="Ground Truth (Correct Transcription)", lines=10, placeholder="Type or paste the correct text here.")
          eval_model_output = gr.Textbox(label="Model's Prediction", lines=10, interactive=False, show_copy_button=True)
          eval_cer_output = gr.Textbox(label="Metrics", interactive=False)
+         eval_use_memory = gr.Checkbox(value=True, label="Enable Memory Post-correction")
+
          with gr.Row():
              run_evaluation_btn = gr.Button("🚀 Run OCR and Evaluate", variant="primary")
              clear_evaluation_btn = gr.Button("🧹 Clear")
+
          run_evaluation_btn.click(
              fn=perform_evaluation,
+             inputs=[eval_image_input, model_choice, eval_ground_truth, max_new_tokens, temperature, top_p, top_k, repetition_penalty, eval_use_memory],
              outputs=[eval_model_output, eval_cer_output]
          )
          clear_evaluation_btn.click(

              outputs=[eval_image_input, eval_ground_truth, eval_model_output, eval_cer_output]
          )
+     with gr.Tab("✏️ Feedback & Memory"):
+         gr.Markdown("""
+         **Pipeline**
+         1) Save feedback (👍 / 👎) and add corrections.
+         2) Click **Build/Refresh Memory** to generate auto-fix rules from positive feedback.
+         3) Keep **Enable Memory Post-correction** checked on inference/eval tabs.
+         """)
+         build_mem_btn = gr.Button("🧠 Build/Refresh Memory from Feedback")
+         mem_status = gr.Markdown()
+         build_mem_btn.click(fn=compile_memory_rules, outputs=[mem_status])
+
+         csv_btn = gr.Button("📤 Export Feedback as CSV")
+         csv_status = gr.Markdown()
+         csv_btn.click(fn=export_csv, outputs=[csv_status])
+
+     with gr.Tab("🧪 GRPO / Dataset"):
+         gr.Markdown("""
+         **GRPO Fine-tuning** (run offline or in a training Space):
+         - Click **Export GRPO Preferences** to produce `data/grpo_prefs.jsonl` of (prompt, chosen, rejected).
+         - Click **Write Trainer Script** to create `train/grpo_train.py`.
+         - Then run:
+         ```bash
+         pip install trl accelerate peft transformers datasets
+         python train/grpo_train.py
+         ```
+         Set `BASE_MODEL`/`OUTPUT_DIR` env vars if you like.
+         """)
+         grpo_btn = gr.Button("📦 Export GRPO Preferences")
+         grpo_status = gr.Markdown()
+         grpo_btn.click(fn=export_grpo_preferences, outputs=[grpo_status])
+
+         write_script_btn = gr.Button("📝 Write grpo_train.py")
+         write_script_status = gr.Markdown()
+         write_script_btn.click(fn=lambda: f"✅ Trainer script written to `{_write_trainer_script()}`", outputs=[write_script_status])
+
  if __name__ == "__main__":
+     demo.queue(max_size=50).launch(share=True)