# app.py - HTR Space with Feedback Loop, Memory Post-Correction, and GRPO Export

import os, time, json, hashlib, difflib, uuid, csv
from datetime import datetime
from collections import Counter, defaultdict
from threading import Thread

import gradio as gr
import spaces
from PIL import Image
import torch
from transformers import AutoProcessor, AutoModelForImageTextToText, Qwen2_5_VLForConditionalGeneration
from reportlab.platypus import SimpleDocTemplate, Paragraph
from reportlab.lib.styles import getSampleStyleSheet
from docx import Document
from gtts import gTTS
from jiwer import cer

# ---------------- Storage & Paths ----------------
os.makedirs("data", exist_ok=True)
FEEDBACK_PATH = "data/feedback.jsonl" # raw feedback log (per sample)
MEMORY_RULES_PATH = "data/memory_rules.json" # compiled post-correction rules
GRPO_EXPORT_PATH = "data/grpo_prefs.jsonl" # preference pairs for GRPO
CSV_EXPORT_PATH = "data/feedback.csv" # optional tabular export

# ---------------- Models ----------------
MODEL_PATHS = {
    "Model 1 (Complex handwritings)": ("prithivMLmods/Qwen2.5-VL-7B-Abliterated-Caption-it", Qwen2_5_VLForConditionalGeneration),
    "Model 2 (simple and scanned handwriting)": ("nanonets/Nanonets-OCR-s", Qwen2_5_VLForConditionalGeneration),
}
# Model 3 removed to conserve memory.

MAX_NEW_TOKENS_DEFAULT = 512
device = "cuda" if torch.cuda.is_available() else "cpu"
_loaded_processors, _loaded_models = {}, {}

print("πŸš€ Preloading models into GPU/CPU memory...")
for name, (repo_id, cls) in MODEL_PATHS.items():
    try:
        processor = AutoProcessor.from_pretrained(repo_id, trust_remote_code=True)
        model = cls.from_pretrained(
            repo_id,
            trust_remote_code=True,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            low_cpu_mem_usage=True
        ).to(device).eval()
        _loaded_processors[name], _loaded_models[name] = processor, model
        print(f"βœ… {name} ready.")
    except Exception as e:
        print(f"⚠️ Failed to load {name}: {e}")

# ---------------- GPU Warmup ----------------
@spaces.GPU
def warmup(progress=gr.Progress(track_tqdm=True)):
    try:
        default_model_choice = next(iter(MODEL_PATHS.keys()))
        processor = _loaded_processors[default_model_choice]
        model = _loaded_models[default_model_choice]
        tokenizer = getattr(processor, "tokenizer", None)
        messages = [{"role": "user", "content": [{"type": "text", "text": "Warmup."}]}]
        chat_prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) if tokenizer and hasattr(tokenizer, "apply_chat_template") else "Warmup."
        inputs = processor(text=[chat_prompt], images=None, return_tensors="pt").to(device)
        with torch.inference_mode():
            _ = model.generate(**inputs, max_new_tokens=1)
        return f"GPU warm and {default_model_choice} ready."
    except Exception as e:
        return f"Warmup skipped: {e}"

# ---------------- Helpers ----------------
def _build_inputs(processor, tokenizer, image: Image.Image, prompt: str):
    messages = [{"role": "user", "content": [{"type": "image", "image": image}, {"type": "text", "text": prompt}]}]
    
    if tokenizer and hasattr(tokenizer, "apply_chat_template"):
        chat_prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        # Explicitly set truncation=False to prevent the token mismatch error
        return processor(text=[chat_prompt], images=[image], return_tensors="pt", truncation=False)
        
    return processor(text=[prompt], images=[image], return_tensors="pt", truncation=False)


def _decode_text(model, processor, tokenizer, output_ids, prompt: str):
    try:
        decoded_text = processor.batch_decode(output_ids, skip_special_tokens=True)[0]
        prompt_start = decoded_text.find(prompt)
        if prompt_start != -1:
            decoded_text = decoded_text[prompt_start + len(prompt):].strip()
        else:
            decoded_text = decoded_text.strip()
        return decoded_text
    except Exception:
        try:
            decoded_text = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
            prompt_start = decoded_text.find(prompt)
            if prompt_start != -1:
                decoded_text = decoded_text[prompt_start + len(prompt):].strip()
            return decoded_text
        except Exception:
            return str(output_ids).strip()

def _default_prompt(query: str | None) -> str:
    if query and query.strip():
        return query.strip()
    return (
        "You are a professional Handwritten OCR system.\n"
        "TASK: Read the handwritten image and transcribe the text EXACTLY as written.\n"
        "- Preserve original structure and line breaks.\n"
        "- Keep spacing, bullet points, numbering, and indentation.\n"
        "- Render tables as Markdown tables if present.\n"
        "- Do NOT autocorrect spelling or grammar.\n"
        "- Do NOT merge lines.\n"
        "Return RAW transcription only."
    )

def _safe_text(text: str) -> str:
    return (text or "").strip()

def _hash_image(image: Image.Image) -> str:
    # stable hash for dedup / linking feedback to the same page
    img_bytes = image.tobytes()
    return hashlib.sha256(img_bytes).hexdigest()

# ---------------- Memory: Post-correction Rules ----------------
def _load_memory_rules():
    if os.path.exists(MEMORY_RULES_PATH):
        try:
            with open(MEMORY_RULES_PATH, "r", encoding="utf-8") as f:
                return json.load(f)
        except Exception:
            pass
    return {"global": {}, "by_model": {}}

def _save_memory_rules(rules):
    with open(MEMORY_RULES_PATH, "w", encoding="utf-8") as f:
        json.dump(rules, f, ensure_ascii=False, indent=2)

def _apply_memory(text: str, model_choice: str, enabled: bool):
    if not enabled or not text:
        return text
    rules = _load_memory_rules()
    # 1) Model-specific replacements
    by_model = rules.get("by_model", {}).get(model_choice, {})
    for wrong, right in by_model.items():
        if wrong and right:
            text = text.replace(wrong, right)
    # 2) Global replacements
    for wrong, right in rules.get("global", {}).items():
        if wrong and right:
            text = text.replace(wrong, right)
    return text
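# Illustrative sketch only (the rule contents below are hypothetical, not shipped with
# the app): a compiled data/memory_rules.json of the form
#   {"global": {"recieve": "receive"},
#    "by_model": {"Model 2 (simple and scanned handwriting)": {"Iine": "line"}}}
# would make _apply_memory("Iine 1: recieve", "Model 2 (simple and scanned handwriting)", True)
# return "line 1: receive" (model-specific replacements run first, then global ones).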

def _compile_rules_from_feedback(min_count: int = 2, max_phrase_len: int = 40):
    """
    Build replacement rules by mining feedback pairs (prediction -> correction).
    We extract phrases that consistently changed, with frequency >= min_count.
    """
    changes_counter_global = Counter()
    changes_counter_by_model = defaultdict(Counter)

    if not os.path.exists(FEEDBACK_PATH):
        return

    with open(FEEDBACK_PATH, "r", encoding="utf-8") as f:
        for line in f:
            try:
                row = json.loads(line)
            except Exception:
                continue
            if row.get("reward", 0) < 1: # only learn from thumbs-up or explicit 'accepted_correction'
                continue
            pred = _safe_text(row.get("prediction", ""))
            corr = _safe_text(row.get("correction", "")) or _safe_text(row.get("ground_truth", ""))
            if not pred or not corr:
                continue
            model_choice = row.get("model_choice", "")
            # Extract ops
            s = difflib.SequenceMatcher(None, pred, corr)
            for tag, i1, i2, j1, j2 in s.get_opcodes():
                if tag in ("replace", "delete", "insert"):
                    wrong = pred[i1:i2]
                    right = corr[j1:j2]
                    # keep short-ish tokens/phrases
                    if 0 < len(wrong) <= max_phrase_len or 0 < len(right) <= max_phrase_len:
                        if wrong.strip():
                            changes_counter_global[(wrong, right)] += 1
                            if model_choice:
                                changes_counter_by_model[model_choice][(wrong, right)] += 1

    rules = {"global": {}, "by_model": {}}
    # Global
    for (wrong, right), cnt in changes_counter_global.items():
        if cnt >= min_count and wrong and right and wrong != right:
            rules["global"][wrong] = right
    # Per model
    for model_choice, ctr in changes_counter_by_model.items():
        rules["by_model"].setdefault(model_choice, {})
        for (wrong, right), cnt in ctr.items():
            if cnt >= min_count and wrong and right and wrong != right:
                rules["by_model"][model_choice][wrong] = right

    _save_memory_rules(rules)
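# Mining sketch with hypothetical feedback (for illustration only): if at least two
# positive-feedback rows contain the prediction "Ph0ne" corrected to "Phone",
# difflib emits a replace opcode mapping "0" -> "o"; once its count reaches
# min_count=2 the pair is written to rules["global"] (and, if the same model
# produced both rows, to rules["by_model"] as well).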

# ---------------- OCR Function ----------------
@spaces.GPU
def ocr_image(image: Image.Image, model_choice: str, query: str = None,
              max_new_tokens: int = MAX_NEW_TOKENS_DEFAULT,
              temperature: float = 0.1, top_p: float = 1.0, top_k: int = 0, repetition_penalty: float = 1.0,
              use_memory: bool = True,
              progress=gr.Progress(track_tqdm=True)):
    if image is None: return "Please upload or capture an image."
    if model_choice not in _loaded_models: return f"Invalid model: {model_choice}"
    processor = _loaded_processors[model_choice]
    model = _loaded_models[model_choice]
    tokenizer = getattr(processor, "tokenizer", None)
    prompt = _default_prompt(query)
    batch = _build_inputs(processor, tokenizer, image, prompt).to(device)
    with torch.inference_mode():
        output_ids = model.generate(**batch, max_new_tokens=max_new_tokens, do_sample=False,
                                     temperature=temperature, top_p=top_p, top_k=top_k, repetition_penalty=repetition_penalty)
    raw = _decode_text(model, processor, tokenizer, output_ids, prompt).replace("<|im_end|>", "").strip()
    # Apply memory post-correction
    post = _apply_memory(raw, model_choice, use_memory)
    return post

# ---------------- Export Helpers ----------------
def save_as_pdf(text):
    text = _safe_text(text)
    if not text: return None
    doc = SimpleDocTemplate("output.pdf")
    flowables = [Paragraph(t, getSampleStyleSheet()["Normal"]) for t in text.splitlines() if t != ""]
    if not flowables: flowables = [Paragraph(" ", getSampleStyleSheet()["Normal"])]
    doc.build(flowables)
    return "output.pdf"

def save_as_word(text):
    text = _safe_text(text)
    if not text: return None
    doc = Document()
    for line in text.splitlines():
        doc.add_paragraph(line)
    doc.save("output.docx")
    return "output.docx"

def save_as_audio(text):
    text = _safe_text(text)
    if not text: return None
    try:
        tts = gTTS(text)
        tts.save("output.mp3")
        return "output.mp3"
    except Exception as e:
        print(f"gTTS failed: {e}")
        return None

# ---------------- Metrics Function ----------------
def calculate_cer_score(ground_truth: str, prediction: str) -> str:
    """
    Calculates the Character Error Rate (CER).
    A CER of 0.0 means the prediction is perfect.
    """
    if not ground_truth or not prediction:
        return "Cannot calculate CER: Missing ground truth or prediction."
    ground_truth_cleaned = " ".join(ground_truth.strip().split())
    prediction_cleaned = " ".join(prediction.strip().split())
    error_rate = cer(ground_truth_cleaned, prediction_cleaned)
    return f"Character Error Rate (CER): {error_rate:.4f}"

# ---------------- Feedback & Dataset ----------------
def _append_jsonl(path, obj):
    with open(path, "a", encoding="utf-8") as f:
        f.write(json.dumps(obj, ensure_ascii=False) + "\n")

def _export_csv():
    # optional: CSV summary for spreadsheet views
    if not os.path.exists(FEEDBACK_PATH):
        return None
    rows = []
    with open(FEEDBACK_PATH, "r", encoding="utf-8") as f:
        for line in f:
            try:
                rows.append(json.loads(line))
            except Exception:
                pass
    if not rows:
        return None
    keys = ["id","timestamp","model_choice","image_sha256","prompt","prediction","correction","ground_truth","reward","cer"]
    with open(CSV_EXPORT_PATH, "w", newline="", encoding="utf-8") as f:
        w = csv.DictWriter(f, fieldnames=keys)
        w.writeheader()
        for r in rows:
            flat = {k: r.get(k, "") for k in keys}
            w.writerow(flat)
    return CSV_EXPORT_PATH

def save_feedback(image: Image.Image, model_choice: str, prompt: str,
                  prediction: str, correction: str, ground_truth: str, reward: int):
    """
    reward: 1 = good/accepted, 0 = neutral, -1 = bad
    """
    if image is None:
        return "Please provide the image again to link feedback."
    if not prediction and not correction and not ground_truth:
        return "Nothing to save."

    image_hash = _hash_image(image)
    # preferred target: correction if provided, else ground_truth
    target = _safe_text(correction) or _safe_text(ground_truth)
    pred = _safe_text(prediction)
    cer_score = None
    if target and pred:
        try:
            cer_score = cer(" ".join(target.split()), " ".join(pred.split()))
        except Exception:
            cer_score = None

    row = {
        "id": str(uuid.uuid4()),
        "timestamp": datetime.utcnow().isoformat(),
        "model_choice": model_choice or "",
        "image_sha256": image_hash,
        "prompt": _safe_text(prompt),
        "prediction": pred,
        "correction": _safe_text(correction),
        "ground_truth": _safe_text(ground_truth),
        "reward": int(reward),
        "cer": float(cer_score) if cer_score is not None else None,
    }
    _append_jsonl(FEEDBACK_PATH, row)
    return f"βœ… Feedback saved (reward={reward})."

def compile_memory_rules():
    _compile_rules_from_feedback(min_count=2, max_phrase_len=60)
    return "βœ… Memory rules recompiled from positive feedback."

def export_grpo_preferences():
    """
    Build preference pairs for GRPO training:
    - chosen: correction/ground_truth when present
    - rejected: original prediction
    """
    if not os.path.exists(FEEDBACK_PATH):
        return "No feedback to export."
    count = 0
    with open(GRPO_EXPORT_PATH, "w", encoding="utf-8") as out_f:
        with open(FEEDBACK_PATH, "r", encoding="utf-8") as f:
            for line in f:
                try:
                    row = json.loads(line)
                except Exception:
                    continue
                pred = _safe_text(row.get("prediction", ""))
                corr = _safe_text(row.get("correction", "")) or _safe_text(row.get("ground_truth", ""))
                prompt = _safe_text(row.get("prompt", "")) or "Transcribe the image exactly."
                if corr and pred and corr != pred and row.get("reward", 0) >= 0:
                    # One preference datapoint
                    out = {
                        "prompt": prompt,
                        "image_sha256": row.get("image_sha256", ""),
                        "chosen": corr,
                        "rejected": pred,
                        "model_choice": row.get("model_choice", "")
                    }
                    out_f.write(json.dumps(out, ensure_ascii=False) + "\n")
                    count += 1
    return f"βœ… Exported {count} GRPO preference pairs to {GRPO_EXPORT_PATH}."

def get_grpo_file():
    if os.path.exists(GRPO_EXPORT_PATH):
        return GRPO_EXPORT_PATH
    return None

def get_csv_file():
    _export_csv()
    if os.path.exists(CSV_EXPORT_PATH):
        return CSV_EXPORT_PATH
    return None

# ---------------- Evaluation Orchestration ----------------
@spaces.GPU
def perform_evaluation(image: Image.Image, model_name: str, ground_truth: str,
                       max_new_tokens: int, temperature: float, top_p: float, top_k: int, repetition_penalty: float,
                       use_memory: bool = True):
    if image is None or not ground_truth:
        return "Please upload an image and provide the ground truth.", "N/A"
    prediction = ocr_image(image, model_name, max_new_tokens=max_new_tokens,
                           temperature=temperature, top_p=top_p, top_k=top_k, repetition_penalty=repetition_penalty,
                           use_memory=use_memory)
    cer_score = calculate_cer_score(ground_truth, prediction)
    return prediction, cer_score

# ---------------- GRPO Trainer Script Writer ----------------
TRAINER_SCRIPT = r"""# grpo_train.py β€” Offline GRPO training with TRL (run separately)
# pip install trl accelerate peft transformers datasets
# This script expects data/grpo_prefs.jsonl produced by the app.

import os, json
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import GRPOConfig, GRPOTrainer

MODEL_ID = os.environ.get("BASE_MODEL", "Qwen/Qwen2.5-VL-7B-Instruct") # change if needed
OUTPUT_DIR = os.environ.get("OUTPUT_DIR", "grpo_output")
DATA_PATH = os.environ.get("DATA_PATH", "data/grpo_prefs.jsonl")

# Our jsonl: each line has prompt, chosen, rejected (and image_sha256/model_choice optionally)
# We'll format as required by TRL: prompt + responses with one preferred

def _jsonl_dataset(jsonl_path):
    data = []
    with open(jsonl_path, "r", encoding="utf-8") as f:
        for line in f:
            try:
                row = json.loads(line)
            except Exception:
                continue
            prompt = row.get("prompt", "")
            chosen = row.get("chosen", "")
            rejected = row.get("rejected", "")
            if prompt and chosen and rejected:
                data.append({"prompt": prompt, "chosen": chosen, "rejected": rejected})
    return data

def main():
    data = _jsonl_dataset(DATA_PATH)
    if not data:
        print("No GRPO data found.")
        return
    # Create a HuggingFace datasets Dataset from memory
    from datasets import Dataset
    ds = Dataset.from_list(data)

    tok = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID, trust_remote_code=True, device_map="auto"
    )

    # Minimal config - tune to your GPU
    cfg = GRPOConfig(
        output_dir=OUTPUT_DIR,
        learning_rate=5e-6,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=8,
        num_train_epochs=1,
        logging_steps=10,
        save_steps=200,
        max_prompt_length=512,
        max_completion_length=768,
        bf16=True
    )

    trainer = GRPOTrainer(
        model=model,
        ref_model=None, # let TRL create a frozen copy internally
        args=cfg,
        tokenizer=tok,
        train_dataset=ds
    )
    trainer.train()
    trainer.save_model(OUTPUT_DIR)
    print("βœ… GRPO training complete. LoRA/weights saved to", OUTPUT_DIR)

if __name__ == "__main__":
    main()
"""

def _write_trainer_script():
    os.makedirs("train", exist_ok=True)
    path = os.path.join("train", "grpo_train.py")
    with open(path, "w", encoding="utf-8") as f:
        f.write(TRAINER_SCRIPT)
    return path
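# Usage sketch for the generated script (env var names as read in TRAINER_SCRIPT above;
# the base model shown is only the script's default):
#   BASE_MODEL="Qwen/Qwen2.5-VL-7B-Instruct" OUTPUT_DIR="grpo_output" \
#   DATA_PATH="data/grpo_prefs.jsonl" python train/grpo_train.py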

# ---------------- Gradio Interface ----------------
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("## ✍🏾 Wilson Handwritten OCR β€” with Feedback Loop")

    model_choice = gr.Radio(choices=list(MODEL_PATHS.keys()),
                             value=list(MODEL_PATHS.keys())[0],
                             label="Select OCR Model")

    with gr.Tab("πŸ–Ό Image Inference"):
        query_input = gr.Textbox(label="Custom Prompt (optional)", placeholder="Leave empty for RAW structured output")
        image_input = gr.Image(type="pil", label="Upload / Capture Handwritten Image", sources=["upload", "webcam"])
        use_memory = gr.Checkbox(value=True, label="Enable Memory Post-correction (auto-fix known mistakes)")

        with gr.Accordion("⚙️ Advanced Options", open=False):
            max_new_tokens = gr.Slider(1, 2048, value=MAX_NEW_TOKENS_DEFAULT, step=1, label="Max new tokens")
            temperature = gr.Slider(0.1, 2.0, value=0.1, step=0.05, label="Temperature")
            top_p = gr.Slider(0.05, 1.0, value=1.0, step=0.05, label="Top-p (nucleus)")
            top_k = gr.Slider(0, 1000, value=0, step=1, label="Top-k")
            repetition_penalty = gr.Slider(0.8, 2.0, value=1.0, step=0.05, label="Repetition penalty")

        extract_btn = gr.Button("📤 Extract RAW Text", variant="primary")
        clear_btn = gr.Button("🧹 Clear")

        raw_output = gr.Textbox(label="📜 Output (post-corrected if memory is ON)", lines=18, show_copy_button=True)

        # Quick Feedback strip
        gr.Markdown("### ✏️ Quick Feedback")
        correction_box = gr.Textbox(label="Your Correction (optional)", placeholder="Paste your corrected text here; leave empty if the output is perfect.", lines=8)
        ground_truth_box = gr.Textbox(label="Ground Truth (optional)", placeholder="If you have a reference transcription, paste it here.", lines=6)

        with gr.Row():
            btn_good = gr.Button("👍 Accept (Save Feedback as Correct)", variant="primary")
            btn_bad = gr.Button("👎 Bad (Save Feedback as Incorrect)")

        feedback_status = gr.Markdown("")

        pdf_btn = gr.Button("⬇️ Download as PDF")
        word_btn = gr.Button("⬇️ Download as Word")
        audio_btn = gr.Button("🔊 Download as Audio")
        pdf_file, word_file, audio_file = gr.File(label="PDF File"), gr.File(label="Word File"), gr.File(label="Audio File")

        extract_btn.click(
            fn=ocr_image,
            inputs=[image_input, model_choice, query_input, max_new_tokens, temperature, top_p, top_k, repetition_penalty, use_memory],
            outputs=[raw_output],
            api_name="ocr_image"
        )
        pdf_btn.click(fn=save_as_pdf, inputs=[raw_output], outputs=[pdf_file])
        word_btn.click(fn=save_as_word, inputs=[raw_output], outputs=[word_file])
        audio_btn.click(fn=save_as_audio, inputs=[raw_output], outputs=[audio_file])

        def _clear():
            return ("", None, "", MAX_NEW_TOKENS_DEFAULT, 0.1, 1.0, 0, 1.0, True, "", "", "",)
        clear_btn.click(
            fn=_clear,
            outputs=[raw_output, image_input, query_input, max_new_tokens, temperature, top_p, top_k, repetition_penalty, use_memory, correction_box, ground_truth_box, feedback_status]
        )

        # Quick feedback save
        btn_good.click(
            fn=lambda img, mc, prmpt, pred, corr, gt: save_feedback(img, mc, prmpt, pred, corr, gt, reward=1),
            inputs=[image_input, model_choice, query_input, raw_output, correction_box, ground_truth_box],
            outputs=[feedback_status]
        )
        btn_bad.click(
            fn=lambda img, mc, prmpt, pred, corr, gt: save_feedback(img, mc, prmpt, pred, corr, gt, reward=-1),
            inputs=[image_input, model_choice, query_input, raw_output, correction_box, ground_truth_box],
            outputs=[feedback_status]
        )

    with gr.Tab("πŸ“Š Model Evaluation"):
        gr.Markdown("### πŸ” Evaluate Model Accuracy")
        eval_image_input = gr.Image(type="pil", label="Upload Image for Evaluation", sources=["upload"])
        eval_ground_truth = gr.Textbox(label="Ground Truth (Correct Transcription)", lines=10, placeholder="Type or paste the correct text here.")
        eval_model_output = gr.Textbox(label="Model's Prediction", lines=10, interactive=False, show_copy_button=True)
        eval_cer_output = gr.Textbox(label="Metrics", interactive=False)
        eval_use_memory = gr.Checkbox(value=True, label="Enable Memory Post-correction")

        with gr.Row():
            run_evaluation_btn = gr.Button("🚀 Run OCR and Evaluate", variant="primary")
            clear_evaluation_btn = gr.Button("🧹 Clear")

        run_evaluation_btn.click(
            fn=perform_evaluation,
            inputs=[eval_image_input, model_choice, eval_ground_truth, max_new_tokens, temperature, top_p, top_k, repetition_penalty, eval_use_memory],
            outputs=[eval_model_output, eval_cer_output]
        )
        clear_evaluation_btn.click(
            fn=lambda: (None, "", "", ""),
            outputs=[eval_image_input, eval_ground_truth, eval_model_output, eval_cer_output]
        )

    with gr.Tab("✏️ Feedback & Memory"):
        gr.Markdown("""
**Pipeline**
1) Save feedback (👍 / 👎) and add corrections.
2) Click **Build/Refresh Memory** to generate auto-fix rules from positive feedback.
3) Keep **Enable Memory Post-correction** checked on inference/eval tabs.
        """)
        build_mem_btn = gr.Button("🧠 Build/Refresh Memory from Feedback")
        mem_status = gr.Markdown("")
        build_mem_btn.click(fn=compile_memory_rules, outputs=[mem_status])

        csv_status = gr.Markdown("")

        gr.Markdown("---")
        gr.Markdown("### ⬇️ Download Feedback Data")
        with gr.Row():
            download_csv_btn = gr.Button("⬇️ Download Feedback as CSV")
            download_csv_file = gr.File(label="CSV File")
        download_csv_btn.click(fn=get_csv_file, outputs=download_csv_file)

    with gr.Tab("πŸ§ͺ GRPO / Dataset"):
        gr.Markdown("""
        **GRPO Fine-tuning** (run offline or in a training Space):
        - Click **Export GRPO Preferences** to produce `data/grpo_prefs.jsonl` of (prompt, chosen, rejected).
        - Click **Write Trainer Script** to create `train/grpo_train.py`.
        - Then run:
        ```bash
        pip install trl accelerate peft transformers datasets
        python train/grpo_train.py

        Set BASE_MODEL/OUTPUT_DIR env vars if you like.
        ```""")
        
        grpo_btn = gr.Button("📦 Export GRPO Preferences")
        grpo_status = gr.Markdown("")
        grpo_btn.click(fn=export_grpo_preferences, outputs=[grpo_status])
        
        write_script_btn = gr.Button("📝 Write grpo_train.py")
        write_script_status = gr.Markdown("")
        write_script_btn.click(fn=lambda: f"✅ Trainer script written to {_write_trainer_script()}", outputs=[write_script_status])
        
        gr.Markdown("---")
        gr.Markdown("### ⬇️ Download GRPO Dataset")

        with gr.Row():
            download_grpo_btn = gr.Button("⬇️ Download GRPO Data (grpo_prefs.jsonl)")
            download_grpo_file = gr.File(label="GRPO Dataset File")
        download_grpo_btn.click(fn=get_grpo_file, outputs=[download_grpo_file])

# ---------------- Launch ----------------
if __name__ == "__main__":
    demo.queue(max_size=50).launch(share=True)