Emeritus-21 commited on
Commit
28b35fd
Β·
verified Β·
1 Parent(s): 4f9162e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +517 -148
app.py CHANGED
@@ -1,10 +1,12 @@
1
- # app.py β€” HTR Space Full Version with RPL, GRPO, Multi-Format Export, Embedding Similarity
2
 
3
- import os, time, json, hashlib, uuid, csv
4
  from datetime import datetime
 
5
  from threading import Thread
6
- from collections import defaultdict
7
  import gradio as gr
 
8
  from PIL import Image
9
  import torch
10
  from transformers import AutoProcessor, AutoModelForImageTextToText, Qwen2_5_VLForConditionalGeneration
@@ -13,208 +15,575 @@ from reportlab.lib.styles import getSampleStyleSheet
13
  from docx import Document
14
  from gtts import gTTS
15
  from jiwer import cer
16
- import numpy as np
17
- from sklearn.metrics.pairwise import cosine_similarity
18
 
19
- # ---------------- Paths ----------------
20
  os.makedirs("data", exist_ok=True)
21
- FEEDBACK_RPL_PATH = "data/feedback_rpl.jsonl"
22
- GRPO_PATH = "data/grpo_prefs.jsonl"
23
- CSV_PATH = "data/feedback_rpl.csv"
 
24
 
25
  # ---------------- Models ----------------
26
  MODEL_PATHS = {
27
  "Model 1 (Complex handwrittings )": ("prithivMLmods/Qwen2.5-VL-7B-Abliterated-Caption-it", Qwen2_5_VLForConditionalGeneration),
28
- "Model 2 (Simple scanned handwritting )": ("nanonets/Nanonets-OCR-s", Qwen2_5_VLForConditionalGeneration)
29
  }
30
 
 
 
31
  device = "cuda" if torch.cuda.is_available() else "cpu"
32
  _loaded_processors, _loaded_models = {}, {}
33
 
34
- print("πŸš€ Loading models...")
35
  for name, (repo_id, cls) in MODEL_PATHS.items():
36
- processor = AutoProcessor.from_pretrained(repo_id, trust_remote_code=True)
37
- model = cls.from_pretrained(repo_id, trust_remote_code=True).to(device).eval()
38
- _loaded_processors[name], _loaded_models[name] = processor, model
39
- print(f"βœ… {name} ready.")
40
-
41
- MAX_NEW_TOKENS_DEFAULT = 512
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
 
43
  # ---------------- Helpers ----------------
44
- def _hash_image(image: Image.Image) -> str:
45
- return hashlib.sha256(image.tobytes()).hexdigest()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
 
47
  def _safe_text(text: str) -> str:
48
  return (text or "").strip()
49
 
50
- def _default_prompt(query: str | None) -> str:
51
- if query and query.strip(): return query.strip()
52
- return ("You are a professional Handwritten OCR system.\n"
53
- "TASK: Read the handwritten image and transcribe exactly as written.\n"
54
- "- Preserve line breaks, indentation, bullets, numbering.\n"
55
- "- Tables as Markdown tables if present.\n"
56
- "- Do NOT autocorrect spelling or merge lines.\n"
57
- "Return RAW transcription only.")
58
-
59
- def _append_jsonl(path, obj):
60
- with open(path, "a", encoding="utf-8") as f: f.write(json.dumps(obj, ensure_ascii=False) + "\n")
61
 
62
- # ---------------- OCR ----------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  def ocr_image(image: Image.Image, model_choice: str, query: str = None,
64
  max_new_tokens: int = MAX_NEW_TOKENS_DEFAULT,
65
  temperature: float = 0.1, top_p: float = 1.0, top_k: int = 0, repetition_penalty: float = 1.0,
66
- use_rpl: bool = True):
67
- if image is None: return "Upload image first."
68
- processor, model = _loaded_processors[model_choice], _loaded_models[model_choice]
 
 
69
  prompt = _default_prompt(query)
70
-
71
- # Build input
72
- batch = processor(text=[prompt], images=[image], return_tensors="pt").to(device)
73
  with torch.inference_mode():
74
  output_ids = model.generate(**batch, max_new_tokens=max_new_tokens, do_sample=False,
75
  temperature=temperature, top_p=top_p, top_k=top_k, repetition_penalty=repetition_penalty)
76
- raw_text = processor.batch_decode(output_ids, skip_special_tokens=True)[0].replace("<|im_end|>", "").strip()
 
 
 
77
 
78
- # RPL: Apply feedback using embedding similarity
79
- if use_rpl and os.path.exists(FEEDBACK_RPL_PATH):
80
- try:
81
- current_embedding = np.random.rand(768).reshape(1, -1) # placeholder for real embedding
82
- for line in open(FEEDBACK_RPL_PATH, encoding="utf-8"):
83
- row = json.loads(line)
84
- if row.get("reward", 0) < 1: continue
85
- emb = np.array(row.get("embedding", np.random.rand(768))).reshape(1, -1)
86
- sim = cosine_similarity(current_embedding, emb)[0][0]
87
- if sim > 0.85:
88
- raw_text = row.get("correction") or row.get("ground_truth")
89
- break
90
- except Exception: pass
91
- return raw_text
92
-
93
- # ---------------- Feedback ----------------
94
- def save_feedback(image: Image.Image, model_choice: str, prompt: str,
95
- prediction: str, correction: str, ground_truth: str, reward: int):
96
- if image is None: return "Provide image.", 0
97
- row = {
98
- "id": str(uuid.uuid4()),
99
- "timestamp": datetime.utcnow().isoformat(),
100
- "model_choice": model_choice,
101
- "image_sha256": _hash_image(image),
102
- "prompt": _safe_text(prompt),
103
- "prediction": _safe_text(prediction),
104
- "correction": _safe_text(correction),
105
- "ground_truth": _safe_text(ground_truth),
106
- "reward": reward,
107
- "embedding": np.random.rand(768).tolist()
108
- }
109
- _append_jsonl(FEEDBACK_RPL_PATH, row)
110
- return f"βœ… Feedback saved (reward={reward}).", 1
111
-
112
- def export_csv():
113
- if not os.path.exists(FEEDBACK_RPL_PATH): return None
114
- keys, rows = None, []
115
- for line in open(FEEDBACK_RPL_PATH, encoding="utf-8"):
116
- try: row = json.loads(line); rows.append(row); keys = keys or list(row.keys())
117
- except: continue
118
- if not rows: return None
119
- with open(CSV_PATH, "w", newline="", encoding="utf-8") as f:
120
- writer = csv.DictWriter(f, fieldnames=keys)
121
- writer.writeheader(); writer.writerows(rows)
122
- return CSV_PATH
123
-
124
- # ---------------- Export Formats ----------------
125
- def save_pdf(text):
126
  text = _safe_text(text)
127
  if not text: return None
128
  doc = SimpleDocTemplate("output.pdf")
129
- flowables = [Paragraph(l, getSampleStyleSheet()["Normal"]) for l in text.splitlines() if l.strip()]
130
- doc.build(flowables or [Paragraph(" ", getSampleStyleSheet()["Normal"])])
 
131
  return "output.pdf"
132
 
133
- def save_word(text):
134
  text = _safe_text(text)
135
  if not text: return None
136
  doc = Document()
137
- for l in text.splitlines(): doc.add_paragraph(l)
 
138
  doc.save("output.docx")
139
  return "output.docx"
140
 
141
- def save_audio(text):
142
  text = _safe_text(text)
143
  if not text: return None
144
- try: gTTS(text).save("output.mp3"); return "output.mp3"
145
- except: return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
146
 
147
- def cer_score(gt, pred):
148
- if not gt or not pred: return "Missing ground truth or prediction."
149
- return f"CER: {cer(gt, pred):.4f}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
150
 
151
- # ---------------- GRPO Example ----------------
152
- def save_grpo(name, pref_dict):
153
- row = {"id": str(uuid.uuid4()), "timestamp": datetime.utcnow().isoformat(), "name": name, "prefs": pref_dict}
154
- _append_jsonl(GRPO_PATH, row)
155
- return f"βœ… GRPO saved for {name}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
156
 
157
  # ---------------- Gradio Interface ----------------
158
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
159
- gr.Markdown("## ✍🏾 Handwritten Text Recognition | Full Feedback & Export")
160
 
161
- model_choice = gr.Radio(list(MODEL_PATHS.keys()), value=list(MODEL_PATHS.keys())[0], label="OCR Model")
 
 
162
 
163
- with gr.Tab("πŸ–Ό OCR & Feedback"):
164
- query_input = gr.Textbox(label="Custom Prompt (optional)")
165
- image_input = gr.Image(type="pil", label="Upload / Capture Handwritten Image")
166
- use_rpl = gr.Checkbox(value=True, label="Enable RPL Feedback")
167
 
168
  with gr.Accordion("βš™οΈ Advanced Options", open=False):
169
  max_new_tokens = gr.Slider(1, 2048, value=MAX_NEW_TOKENS_DEFAULT, step=1, label="Max new tokens")
170
  temperature = gr.Slider(0.1, 2.0, value=0.1, step=0.05, label="Temperature")
171
- top_p = gr.Slider(0.05,1.0,value=1.0,step=0.05,label="Top-p")
172
- top_k = gr.Slider(0,1000,value=0,step=1,label="Top-k")
173
- repetition_penalty = gr.Slider(0.8,2.0,value=1.0,step=0.05,label="Repetition penalty")
 
 
 
174
 
175
- extract_btn = gr.Button("πŸ“€ Extract RAW Text")
176
- raw_output = gr.Textbox(label="πŸ“œ Output", lines=18, show_copy_button=True)
177
 
 
178
  gr.Markdown("### ✏️ Quick Feedback")
179
- correction_box = gr.Textbox(label="Your Correction", lines=8)
180
- ground_truth_box = gr.Textbox(label="Ground Truth", lines=6)
181
- btn_good = gr.Button("πŸ‘ Accept (Correct)")
182
- btn_bad = gr.Button("πŸ‘Ž Bad (Incorrect)")
 
 
 
183
  feedback_status = gr.Markdown()
184
 
185
- extract_btn.click(ocr_image,
186
- [image_input, model_choice, query_input, max_new_tokens, temperature, top_p, top_k, repetition_penalty, use_rpl],
187
- raw_output)
188
-
189
- btn_good.click(lambda img, mc, prmpt, pred, corr, gt: save_feedback(img, mc, prmpt, pred, corr, gt, 1),
190
- [image_input, model_choice, query_input, raw_output, correction_box, ground_truth_box],
191
- feedback_status)
192
- btn_bad.click(lambda img, mc, prmpt, pred, corr, gt: save_feedback(img, mc, prmpt, pred, corr, gt, -1),
193
- [image_input, model_choice, query_input, raw_output, correction_box, ground_truth_box],
194
- feedback_status)
195
-
196
- gr.Markdown("### πŸ“₯ Download Feedback")
197
- download_jsonl_btn = gr.File(label="Download JSONL")
198
- download_csv_btn = gr.File(label="Download CSV")
199
- download_jsonl_btn.click(lambda: FEEDBACK_RPL_PATH if os.path.exists(FEEDBACK_RPL_PATH) else None,
200
- download_jsonl_btn)
201
- download_csv_btn.click(export_csv, download_csv_btn)
202
-
203
- with gr.Tab("πŸ“ Export Formats"):
204
- pdf_btn = gr.Button("Save as PDF")
205
- word_btn = gr.Button("Save as Word")
206
- audio_btn = gr.Button("Save as Audio")
207
- text_input = gr.Textbox(label="Text to Export", lines=10)
208
- pdf_btn.click(save_pdf, text_input, gr.File())
209
- word_btn.click(save_word, text_input, gr.File())
210
- audio_btn.click(save_audio, text_input, gr.File())
211
-
212
- with gr.Tab("πŸŽ› GRPO Preferences"):
213
- user_name = gr.Textbox(label="Name")
214
- grpo_dict_input = gr.Textbox(label="Preferences (JSON)")
215
- grpo_save_btn = gr.Button("Save GRPO")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
216
  grpo_status = gr.Markdown()
217
- grpo_save_btn.click(lambda n,p: save_grpo(n,json.loads(p or "{}")), [user_name, grpo_dict_input], grpo_status)
 
 
 
 
218
 
219
  if __name__ == "__main__":
220
  demo.queue(max_size=50).launch(share=True)
 
1
+ # app.py β€” HTR Space with Feedback Loop, Memory Post-Correction, and GRPO Export
2
 
3
+ import os, time, json, hashlib, difflib, uuid, csv
4
  from datetime import datetime
5
+ from collections import Counter, defaultdict
6
  from threading import Thread
7
+
8
  import gradio as gr
9
+ import spaces
10
  from PIL import Image
11
  import torch
12
  from transformers import AutoProcessor, AutoModelForImageTextToText, Qwen2_5_VLForConditionalGeneration
 
15
  from docx import Document
16
  from gtts import gTTS
17
  from jiwer import cer
 
 
18
 
19
+ # ---------------- Storage & Paths ----------------
20
  os.makedirs("data", exist_ok=True)
21
+ FEEDBACK_PATH = "data/feedback.jsonl" # raw feedback log (per sample)
22
+ MEMORY_RULES_PATH = "data/memory_rules.json" # compiled post-correction rules
23
+ GRPO_EXPORT_PATH = "data/grpo_prefs.jsonl" # preference pairs for GRPO
24
+ CSV_EXPORT_PATH = "data/feedback.csv" # optional tabular export
25
 
26
  # ---------------- Models ----------------
27
  MODEL_PATHS = {
28
  "Model 1 (Complex handwrittings )": ("prithivMLmods/Qwen2.5-VL-7B-Abliterated-Caption-it", Qwen2_5_VLForConditionalGeneration),
29
+ "Model 2 (simple and scanned handwritting )": ("nanonets/Nanonets-OCR-s", Qwen2_5_VLForConditionalGeneration),
30
  }
31
 
32
+
33
+ MAX_NEW_TOKENS_DEFAULT = 512
34
  device = "cuda" if torch.cuda.is_available() else "cpu"
35
  _loaded_processors, _loaded_models = {}, {}
36
 
37
+ print("πŸš€ Preloading models into GPU/CPU memory...")
38
  for name, (repo_id, cls) in MODEL_PATHS.items():
39
+ try:
40
+ processor = AutoProcessor.from_pretrained(repo_id, trust_remote_code=True)
41
+ model = cls.from_pretrained(
42
+ repo_id,
43
+ trust_remote_code=True,
44
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
45
+ low_cpu_mem_usage=True
46
+ ).to(device).eval()
47
+ _loaded_processors[name], _loaded_models[name] = processor, model
48
+ print(f"βœ… {name} ready.")
49
+ except Exception as e:
50
+ print(f"⚠️ Failed to load {name}: {e}")
51
+
52
+ # ---------------- GPU Warmup ----------------
53
+ @spaces.GPU
54
+ def warmup(progress=gr.Progress(track_tqdm=True)):
55
+ try:
56
+ default_model_choice = next(iter(MODEL_PATHS.keys()))
57
+ processor = _loaded_processors[default_model_choice]
58
+ model = _loaded_models[default_model_choice]
59
+ tokenizer = getattr(processor, "tokenizer", None)
60
+ messages = [{"role": "user", "content": [{"type": "text", "text": "Warmup."}]}]
61
+ chat_prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) if tokenizer and hasattr(tokenizer, "apply_chat_template") else "Warmup."
62
+ inputs = processor(text=[chat_prompt], images=None, return_tensors="pt").to(device)
63
+ with torch.inference_mode():
64
+ _ = model.generate(**inputs, max_new_tokens=1)
65
+ return f"GPU warm and {default_model_choice} ready."
66
+ except Exception as e:
67
+ return f"Warmup skipped: {e}"
68
 
69
  # ---------------- Helpers ----------------
70
+ def _build_inputs(processor, tokenizer, image: Image.Image, prompt: str):
71
+ messages = [{"role": "user", "content": [{"type": "image", "image": image}, {"type": "text", "text": prompt}]}]
72
+ if tokenizer and hasattr(tokenizer, "apply_chat_template"):
73
+ chat_prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
74
+ return processor(text=[chat_prompt], images=[image], return_tensors="pt")
75
+ return processor(text=[prompt], images=[image], return_tensors="pt")
76
+
77
+ def _decode_text(model, processor, tokenizer, output_ids, prompt: str):
78
+ try:
79
+ decoded_text = processor.batch_decode(output_ids, skip_special_tokens=True)[0]
80
+ prompt_start = decoded_text.find(prompt)
81
+ if prompt_start != -1:
82
+ decoded_text = decoded_text[prompt_start + len(prompt):].strip()
83
+ else:
84
+ decoded_text = decoded_text.strip()
85
+ return decoded_text
86
+ except Exception:
87
+ try:
88
+ decoded_text = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
89
+ prompt_start = decoded_text.find(prompt)
90
+ if prompt_start != -1:
91
+ decoded_text = decoded_text[prompt_start + len(prompt):].strip()
92
+ return decoded_text
93
+ except Exception:
94
+ return str(output_ids).strip()
95
+
96
+ def _default_prompt(query: str | None) -> str:
97
+ if query and query.strip():
98
+ return query.strip()
99
+ return (
100
+ "You are a professional Handwritten OCR system.\n"
101
+ "TASK: Read the handwritten image and transcribe the text EXACTLY as written.\n"
102
+ "- Preserve original structure and line breaks.\n"
103
+ "- Keep spacing, bullet points, numbering, and indentation.\n"
104
+ "- Render tables as Markdown tables if present.\n"
105
+ "- Do NOT autocorrect spelling or grammar.\n"
106
+ "- Do NOT merge lines.\n"
107
+ "Return RAW transcription only."
108
+ )
109
 
110
  def _safe_text(text: str) -> str:
111
  return (text or "").strip()
112
 
113
+ def _hash_image(image: Image.Image) -> str:
114
+ # stable hash for dedup / linking feedback to the same page
115
+ img_bytes = image.tobytes()
116
+ return hashlib.sha256(img_bytes).hexdigest()
 
 
 
 
 
 
 
117
 
118
+ # ---------------- Memory: Post-correction Rules ----------------
119
+ def _load_memory_rules():
120
+ if os.path.exists(MEMORY_RULES_PATH):
121
+ try:
122
+ with open(MEMORY_RULES_PATH, "r", encoding="utf-8") as f:
123
+ return json.load(f)
124
+ except Exception:
125
+ pass
126
+ return {"global": {}, "by_model": {}}
127
+
128
+ def _save_memory_rules(rules):
129
+ with open(MEMORY_RULES_PATH, "w", encoding="utf-8") as f:
130
+ json.dump(rules, f, ensure_ascii=False, indent=2)
131
+
132
+ def _apply_memory(text: str, model_choice: str, enabled: bool):
133
+ if not enabled or not text:
134
+ return text
135
+ rules = _load_memory_rules()
136
+ # 1) Model-specific replacements
137
+ by_model = rules.get("by_model", {}).get(model_choice, {})
138
+ for wrong, right in by_model.items():
139
+ if wrong and right:
140
+ text = text.replace(wrong, right)
141
+ # 2) Global replacements
142
+ for wrong, right in rules.get("global", {}).items():
143
+ if wrong and right:
144
+ text = text.replace(wrong, right)
145
+ return text
146
+
147
+ def _compile_rules_from_feedback(min_count: int = 2, max_phrase_len: int = 40):
148
+ """
149
+ Build replacement rules by mining feedback pairs (prediction -> correction).
150
+ We extract phrases that consistently changed, with frequency >= min_count.
151
+ """
152
+ changes_counter_global = Counter()
153
+ changes_counter_by_model = defaultdict(Counter)
154
+
155
+ if not os.path.exists(FEEDBACK_PATH):
156
+ return
157
+
158
+ with open(FEEDBACK_PATH, "r", encoding="utf-8") as f:
159
+ for line in f:
160
+ try:
161
+ row = json.loads(line)
162
+ except Exception:
163
+ continue
164
+ if row.get("reward", 0) < 1: # only learn from thumbs-up or explicit 'accepted_correction'
165
+ continue
166
+ pred = _safe_text(row.get("prediction", ""))
167
+ corr = _safe_text(row.get("correction", "")) or _safe_text(row.get("ground_truth", ""))
168
+ if not pred or not corr:
169
+ continue
170
+ model_choice = row.get("model_choice", "")
171
+ # Extract ops
172
+ s = difflib.SequenceMatcher(None, pred, corr)
173
+ for tag, i1, i2, j1, j2 in s.get_opcodes():
174
+ if tag in ("replace", "delete", "insert"):
175
+ wrong = pred[i1:i2]
176
+ right = corr[j1:j2]
177
+ # keep short-ish tokens/phrases
178
+ if 0 < len(wrong) <= max_phrase_len or 0 < len(right) <= max_phrase_len:
179
+ if wrong.strip():
180
+ changes_counter_global[(wrong, right)] += 1
181
+ if model_choice:
182
+ changes_counter_by_model[model_choice][(wrong, right)] += 1
183
+
184
+ rules = {"global": {}, "by_model": {}}
185
+ # Global
186
+ for (wrong, right), cnt in changes_counter_global.items():
187
+ if cnt >= min_count and wrong and right and wrong != right:
188
+ rules["global"][wrong] = right
189
+ # Per model
190
+ for model_choice, ctr in changes_counter_by_model.items():
191
+ rules["by_model"].setdefault(model_choice, {})
192
+ for (wrong, right), cnt in ctr.items():
193
+ if cnt >= min_count and wrong and right and wrong != right:
194
+ rules["by_model"][model_choice][wrong] = right
195
+
196
+ _save_memory_rules(rules)
197
+
198
+ # ---------------- OCR Function ----------------
199
+ @spaces.GPU
200
  def ocr_image(image: Image.Image, model_choice: str, query: str = None,
201
  max_new_tokens: int = MAX_NEW_TOKENS_DEFAULT,
202
  temperature: float = 0.1, top_p: float = 1.0, top_k: int = 0, repetition_penalty: float = 1.0,
203
+ use_memory: bool = True,
204
+ progress=gr.Progress(track_tqdm=True)):
205
+ if image is None: return "Please upload or capture an image."
206
+ if model_choice not in _loaded_models: return f"Invalid model: {model_choice}"
207
+ processor, model, tokenizer = _loaded_processors[model_choice], _loaded_models[model_choice], getattr(_loaded_processors[model_choice], "tokenizer", None)
208
  prompt = _default_prompt(query)
209
+ batch = _build_inputs(processor, tokenizer, image, prompt).to(device)
 
 
210
  with torch.inference_mode():
211
  output_ids = model.generate(**batch, max_new_tokens=max_new_tokens, do_sample=False,
212
  temperature=temperature, top_p=top_p, top_k=top_k, repetition_penalty=repetition_penalty)
213
+ raw = _decode_text(model, processor, tokenizer, output_ids, prompt).replace("<|im_end|>", "").strip()
214
+ # Apply memory post-correction
215
+ post = _apply_memory(raw, model_choice, use_memory)
216
+ return post
217
 
218
+ # ---------------- Export Helpers ----------------
219
+ def save_as_pdf(text):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
220
  text = _safe_text(text)
221
  if not text: return None
222
  doc = SimpleDocTemplate("output.pdf")
223
+ flowables = [Paragraph(t, getSampleStyleSheet()["Normal"]) for t in text.splitlines() if t != ""]
224
+ if not flowables: flowables = [Paragraph(" ", getSampleStyleSheet()["Normal"])]
225
+ doc.build(flowables)
226
  return "output.pdf"
227
 
228
+ def save_as_word(text):
229
  text = _safe_text(text)
230
  if not text: return None
231
  doc = Document()
232
+ for line in text.splitlines():
233
+ doc.add_paragraph(line)
234
  doc.save("output.docx")
235
  return "output.docx"
236
 
237
+ def save_as_audio(text):
238
  text = _safe_text(text)
239
  if not text: return None
240
+ try:
241
+ tts = gTTS(text)
242
+ tts.save("output.mp3")
243
+ return "output.mp3"
244
+ except Exception as e:
245
+ print(f"gTTS failed: {e}")
246
+ return None
247
+
248
+ # ---------------- Metrics Function ----------------
249
+ def calculate_cer_score(ground_truth: str, prediction: str) -> str:
250
+ """
251
+ Calculates the Character Error Rate (CER).
252
+ A CER of 0.0 means the prediction is perfect.
253
+ """
254
+ if not ground_truth or not prediction:
255
+ return "Cannot calculate CER: Missing ground truth or prediction."
256
+ ground_truth_cleaned = " ".join(ground_truth.strip().split())
257
+ prediction_cleaned = " ".join(prediction.strip().split())
258
+ error_rate = cer(ground_truth_cleaned, prediction_cleaned)
259
+ return f"Character Error Rate (CER): {error_rate:.4f}"
260
+
261
+ # ---------------- Feedback & Dataset ----------------
262
+ def _append_jsonl(path, obj):
263
+ with open(path, "a", encoding="utf-8") as f:
264
+ f.write(json.dumps(obj, ensure_ascii=False) + "\n")
265
+
266
+ def _export_csv():
267
+ # optional: CSV summary for spreadsheet views
268
+ if not os.path.exists(FEEDBACK_PATH):
269
+ return CSV_EXPORT_PATH if os.path.exists(CSV_EXPORT_PATH) else None
270
+ rows = []
271
+ with open(FEEDBACK_PATH, "r", encoding="utf-8") as f:
272
+ for line in f:
273
+ try:
274
+ rows.append(json.loads(line))
275
+ except Exception:
276
+ pass
277
+ if not rows:
278
+ return None
279
+ keys = ["id","timestamp","model_choice","image_sha256","prompt","prediction","correction","ground_truth","reward","cer"]
280
+ with open(CSV_EXPORT_PATH, "w", newline="", encoding="utf-8") as f:
281
+ w = csv.DictWriter(f, fieldnames=keys)
282
+ w.writeheader()
283
+ for r in rows:
284
+ flat = {k: r.get(k, "") for k in keys}
285
+ w.writerow(flat)
286
+ return CSV_EXPORT_PATH
287
+
288
+ def save_feedback(image: Image.Image, model_choice: str, prompt: str,
289
+ prediction: str, correction: str, ground_truth: str, reward: int):
290
+ """
291
+ reward: 1 = good/accepted, 0 = neutral, -1 = bad
292
+ """
293
+ if image is None:
294
+ return "Please provide the image again to link feedback.", 0
295
+ if not prediction and not correction and not ground_truth:
296
+ return "Nothing to save.", 0
297
+
298
+ image_hash = _hash_image(image)
299
+ # best target = correction, else ground_truth, else prediction
300
+ target = _safe_text(correction) or _safe_text(ground_truth)
301
+ pred = _safe_text(prediction)
302
+ cer_score = None
303
+ if target and pred:
304
+ try:
305
+ cer_score = cer(" ".join(target.split()), " ".join(pred.split()))
306
+ except Exception:
307
+ cer_score = None
308
+
309
+ row = {
310
+ "id": str(uuid.uuid4()),
311
+ "timestamp": datetime.utcnow().isoformat(),
312
+ "model_choice": model_choice or "",
313
+ "image_sha256": image_hash,
314
+ "prompt": _safe_text(prompt),
315
+ "prediction": pred,
316
+ "correction": _safe_text(correction),
317
+ "ground_truth": _safe_text(ground_truth),
318
+ "reward": int(reward),
319
+ "cer": float(cer_score) if cer_score is not None else None,
320
+ }
321
+ _append_jsonl(FEEDBACK_PATH, row)
322
+ return f"βœ… Feedback saved (reward={reward}).", 1
323
 
324
+ def compile_memory_rules():
325
+ _compile_rules_from_feedback(min_count=2, max_phrase_len=60)
326
+ return "βœ… Memory rules recompiled from positive feedback."
327
+
328
+ def export_grpo_preferences():
329
+ """
330
+ Build preference pairs for GRPO training:
331
+ - chosen: correction/ground_truth when present
332
+ - rejected: original prediction
333
+ """
334
+ if not os.path.exists(FEEDBACK_PATH):
335
+ return "No feedback to export."
336
+ count = 0
337
+ with open(GRPO_EXPORT_PATH, "w", encoding="utf-8") as out_f:
338
+ with open(FEEDBACK_PATH, "r", encoding="utf-8") as f:
339
+ for line in f:
340
+ try:
341
+ row = json.loads(line)
342
+ except Exception:
343
+ continue
344
+ pred = _safe_text(row.get("prediction", ""))
345
+ corr = _safe_text(row.get("correction", "")) or _safe_text(row.get("ground_truth", ""))
346
+ prompt = _safe_text(row.get("prompt", "")) or "Transcribe the image exactly."
347
+ if corr and pred and corr != pred and row.get("reward", 0) >= 0:
348
+ # One preference datapoint
349
+ out = {
350
+ "prompt": prompt,
351
+ "image_sha256": row.get("image_sha256", ""),
352
+ "chosen": corr,
353
+ "rejected": pred,
354
+ "model_choice": row.get("model_choice", "")
355
+ }
356
+ out_f.write(json.dumps(out, ensure_ascii=False) + "\n")
357
+ count += 1
358
+ return f"βœ… Exported {count} GRPO preference pairs to {GRPO_EXPORT_PATH}."
359
 
360
+ def export_csv():
361
+ p = _export_csv()
362
+ if p:
363
+ return f"βœ… CSV exported: {p}"
364
+ return "No data to export."
365
+
366
+ # ---------------- Evaluation Orchestration ----------------
367
+ @spaces.GPU
368
+ def perform_evaluation(image: Image.Image, model_name: str, ground_truth: str,
369
+ max_new_tokens: int, temperature: float, top_p: float, top_k: int, repetition_penalty: float,
370
+ use_memory: bool = True):
371
+ if image is None or not ground_truth:
372
+ return "Please upload an image and provide the ground truth.", "N/A"
373
+ prediction = ocr_image(image, model_name, max_new_tokens=max_new_tokens,
374
+ temperature=temperature, top_p=top_p, top_k=top_k, repetition_penalty=repetition_penalty,
375
+ use_memory=use_memory)
376
+ cer_score = calculate_cer_score(ground_truth, prediction)
377
+ return prediction, cer_score
378
+
379
+ # ---------------- GRPO Trainer Script Writer ----------------
380
+ TRAINER_SCRIPT = r"""# grpo_train.py β€” Offline GRPO training with TRL (run separately)
381
+ # pip install trl accelerate peft transformers datasets
382
+ # This script expects data/grpo_prefs.jsonl produced by the app.
383
+
384
+ import os, json
385
+ from datasets import load_dataset
386
+ from transformers import AutoModelForCausalLM, AutoTokenizer
387
+ from trl import GRPOConfig, GRPOTrainer
388
+
389
+ MODEL_ID = os.environ.get("BASE_MODEL", "Qwen/Qwen2.5-VL-7B-Instruct") # change if needed
390
+ OUTPUT_DIR = os.environ.get("OUTPUT_DIR", "grpo_output")
391
+ DATA_PATH = os.environ.get("DATA_PATH", "data/grpo_prefs.jsonl")
392
+
393
+ # Our jsonl: each line has prompt, chosen, rejected (and image_sha256/model_choice optionally)
394
+ # We'll format as required by TRL: prompt + responses with one preferred
395
+
396
+ def _jsonl_dataset(jsonl_path):
397
+ data = []
398
+ with open(jsonl_path, "r", encoding="utf-8") as f:
399
+ for line in f:
400
+ try:
401
+ row = json.loads(line)
402
+ except Exception:
403
+ continue
404
+ prompt = row.get("prompt", "")
405
+ chosen = row.get("chosen", "")
406
+ rejected = row.get("rejected", "")
407
+ if prompt and chosen and rejected:
408
+ data.append({"prompt": prompt, "chosen": chosen, "rejected": rejected})
409
+ return data
410
+
411
+ def main():
412
+ data = _jsonl_dataset(DATA_PATH)
413
+ if not data:
414
+ print("No GRPO data found.")
415
+ return
416
+ # Create a HuggingFace datasets Dataset from memory
417
+ from datasets import Dataset
418
+ ds = Dataset.from_list(data)
419
+
420
+ tok = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
421
+ model = AutoModelForCausalLM.from_pretrained(
422
+ MODEL_ID, trust_remote_code=True, device_map="auto"
423
+ )
424
+
425
+ # Minimal config β€” tune to your GPU
426
+ cfg = GRPOConfig(
427
+ output_dir=OUTPUT_DIR,
428
+ learning_rate=5e-6,
429
+ per_device_train_batch_size=1,
430
+ gradient_accumulation_steps=8,
431
+ num_train_epochs=1,
432
+ logging_steps=10,
433
+ save_steps=200,
434
+ max_prompt_length=512,
435
+ max_completion_length=768,
436
+ bf16=True
437
+ )
438
+
439
+ trainer = GRPOTrainer(
440
+ model=model,
441
+ ref_model=None, # let TRL create a frozen copy internally
442
+ args=cfg,
443
+ tokenizer=tok,
444
+ train_dataset=ds
445
+ )
446
+ trainer.train()
447
+ trainer.save_model(OUTPUT_DIR)
448
+ print("βœ… GRPO training complete. LoRA/weights saved to", OUTPUT_DIR)
449
+
450
+ if __name__ == "__main__":
451
+ main()
452
+ """
453
+
454
+ def _write_trainer_script():
455
+ os.makedirs("train", exist_ok=True)
456
+ path = os.path.join("train", "grpo_train.py")
457
+ with open(path, "w", encoding="utf-8") as f:
458
+ f.write(TRAINER_SCRIPT)
459
+ return path
460
 
461
  # ---------------- Gradio Interface ----------------
462
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
463
+ gr.Markdown("## ✍🏾 wilson Handwritten text recognition with Feedback Loop")
464
 
465
+ model_choice = gr.Radio(choices=list(MODEL_PATHS.keys()),
466
+ value=list(MODEL_PATHS.keys())[0],
467
+ label="Select OCR Model")
468
 
469
+ with gr.Tab("πŸ–Ό Image Inference"):
470
+ query_input = gr.Textbox(label="Custom Prompt (optional)", placeholder="Leave empty for RAW structured output")
471
+ image_input = gr.Image(type="pil", label="Upload / Capture Handwritten Image", sources=["upload", "webcam"])
472
+ use_memory = gr.Checkbox(value=True, label="Enable Memory Post-correction (auto-fix known mistakes)")
473
 
474
  with gr.Accordion("βš™οΈ Advanced Options", open=False):
475
  max_new_tokens = gr.Slider(1, 2048, value=MAX_NEW_TOKENS_DEFAULT, step=1, label="Max new tokens")
476
  temperature = gr.Slider(0.1, 2.0, value=0.1, step=0.05, label="Temperature")
477
+ top_p = gr.Slider(0.05, 1.0, value=1.0, step=0.05, label="Top-p (nucleus)")
478
+ top_k = gr.Slider(0, 1000, value=0, step=1, label="Top-k")
479
+ repetition_penalty = gr.Slider(0.8, 2.0, value=1.0, step=0.05, label="Repetition penalty")
480
+
481
+ extract_btn = gr.Button("πŸ“€ Extract RAW Text", variant="primary")
482
+ clear_btn = gr.Button("🧹 Clear")
483
 
484
+ raw_output = gr.Textbox(label="πŸ“œ Output (post-corrected if memory is ON)", lines=18, show_copy_button=True)
 
485
 
486
+ # Quick Feedback strip
487
  gr.Markdown("### ✏️ Quick Feedback")
488
+ correction_box = gr.Textbox(label="Your Correction (optional)", placeholder="Paste your corrected text here; leave empty if the output is perfect.", lines=8)
489
+ ground_truth_box = gr.Textbox(label="Ground Truth (optional)", placeholder="If you have a reference transcription, paste it here.", lines=6)
490
+
491
+ with gr.Row():
492
+ btn_good = gr.Button("πŸ‘ Accept (Save Feedback as Correct)", variant="primary")
493
+ btn_bad = gr.Button("πŸ‘Ž Bad (Save Feedback as Incorrect)")
494
+
495
  feedback_status = gr.Markdown()
496
 
497
+ pdf_btn = gr.Button("⬇️ Download as PDF")
498
+ word_btn = gr.Button("⬇️ Download as Word")
499
+ audio_btn = gr.Button("πŸ”Š Download as Audio")
500
+ pdf_file, word_file, audio_file = gr.File(label="PDF File"), gr.File(label="Word File"), gr.File(label="Audio File")
501
+
502
+ extract_btn.click(
503
+ fn=ocr_image,
504
+ inputs=[image_input, model_choice, query_input, max_new_tokens, temperature, top_p, top_k, repetition_penalty, use_memory],
505
+ outputs=[raw_output],
506
+ api_name="ocr_image"
507
+ )
508
+ pdf_btn.click(fn=save_as_pdf, inputs=[raw_output], outputs=[pdf_file])
509
+ word_btn.click(fn=save_as_word, inputs=[raw_output], outputs=[word_file])
510
+ audio_btn.click(fn=save_as_audio, inputs=[raw_output], outputs=[audio_file])
511
+
512
+ def _clear():
513
+ return ("", None, "", MAX_NEW_TOKENS_DEFAULT, 0.1, 1.0, 0, 1.0, True, "", "", "",)
514
+ clear_btn.click(
515
+ fn=_clear,
516
+ outputs=[raw_output, image_input, query_input, max_new_tokens, temperature, top_p, top_k, repetition_penalty, use_memory, correction_box, ground_truth_box, feedback_status]
517
+ )
518
+
519
+ # Quick feedback save
520
+ btn_good.click(
521
+ fn=lambda img, mc, prmpt, pred, corr, gt: save_feedback(img, mc, prmpt, pred, corr, gt, reward=1),
522
+ inputs=[image_input, model_choice, query_input, raw_output, correction_box, ground_truth_box],
523
+ outputs=[feedback_status]
524
+ )
525
+ btn_bad.click(
526
+ fn=lambda img, mc, prmpt, pred, corr, gt: save_feedback(img, mc, prmpt, pred, corr, gt, reward=-1),
527
+ inputs=[image_input, model_choice, query_input, raw_output, correction_box, ground_truth_box],
528
+ outputs=[feedback_status]
529
+ )
530
+
531
+ with gr.Tab("πŸ“Š Model Evaluation"):
532
+ gr.Markdown("### πŸ” Evaluate Model Accuracy")
533
+ eval_image_input = gr.Image(type="pil", label="Upload Image for Evaluation", sources=["upload"])
534
+ eval_ground_truth = gr.Textbox(label="Ground Truth (Correct Transcription)", lines=10, placeholder="Type or paste the correct text here.")
535
+ eval_model_output = gr.Textbox(label="Model's Prediction", lines=10, interactive=False, show_copy_button=True)
536
+ eval_cer_output = gr.Textbox(label="Metrics", interactive=False)
537
+ eval_use_memory = gr.Checkbox(value=True, label="Enable Memory Post-correction")
538
+
539
+ with gr.Row():
540
+ run_evaluation_btn = gr.Button("πŸš€ Run OCR and Evaluate", variant="primary")
541
+ clear_evaluation_btn = gr.Button("🧹 Clear")
542
+
543
+ run_evaluation_btn.click(
544
+ fn=perform_evaluation,
545
+ inputs=[eval_image_input, model_choice, eval_ground_truth, max_new_tokens, temperature, top_p, top_k, repetition_penalty, eval_use_memory],
546
+ outputs=[eval_model_output, eval_cer_output]
547
+ )
548
+ clear_evaluation_btn.click(
549
+ fn=lambda: (None, "", "", ""),
550
+ outputs=[eval_image_input, eval_ground_truth, eval_model_output, eval_cer_output]
551
+ )
552
+
553
+ with gr.Tab("✏️ Feedback & Memory"):
554
+ gr.Markdown("""
555
+ **Pipeline**
556
+ 1) Save feedback (πŸ‘ / πŸ‘Ž) and add corrections.
557
+ 2) Click **Build/Refresh Memory** to generate auto-fix rules from positive feedback.
558
+ 3) Keep **Enable Memory Post-correction** checked on inference/eval tabs.
559
+ """)
560
+ build_mem_btn = gr.Button("🧠 Build/Refresh Memory from Feedback")
561
+ mem_status = gr.Markdown()
562
+ build_mem_btn.click(fn=compile_memory_rules, outputs=[mem_status])
563
+
564
+ csv_btn = gr.Button("πŸ“€ Export Feedback as CSV")
565
+ csv_status = gr.Markdown()
566
+ csv_btn.click(fn=export_csv, outputs=[csv_status])
567
+
568
+ with gr.Tab("πŸ§ͺ GRPO / Dataset"):
569
+ gr.Markdown("""
570
+ **GRPO Fine-tuning** (run offline or in a training Space):
571
+ - Click **Export GRPO Preferences** to produce `data/grpo_prefs.jsonl` of (prompt, chosen, rejected).
572
+ - Click **Write Trainer Script** to create `train/grpo_train.py`.
573
+ - Then run:
574
+ ```bash
575
+ pip install trl accelerate peft transformers datasets
576
+ python train/grpo_train.py
577
+ ```
578
+ Set `BASE_MODEL`/`OUTPUT_DIR` env vars if you like.
579
+ """)
580
+ grpo_btn = gr.Button("πŸ“¦ Export GRPO Preferences")
581
  grpo_status = gr.Markdown()
582
+ grpo_btn.click(fn=export_grpo_preferences, outputs=[grpo_status])
583
+
584
+ write_script_btn = gr.Button("πŸ“ Write grpo_train.py")
585
+ write_script_status = gr.Markdown()
586
+ write_script_btn.click(fn=lambda: f"βœ… Trainer script written to `{_write_trainer_script()}`", outputs=[write_script_status])
587
 
588
  if __name__ == "__main__":
589
  demo.queue(max_size=50).launch(share=True)