# app.py - HTR Space with Feedback Loop, Memory Post-Correction, and GRPO Export
import os, json, hashlib, difflib, uuid, csv
from datetime import datetime
from collections import Counter, defaultdict
import gradio as gr
import spaces
from PIL import Image
import torch
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
from reportlab.platypus import SimpleDocTemplate, Paragraph
from reportlab.lib.styles import getSampleStyleSheet
from docx import Document
from gtts import gTTS
from jiwer import cer
# ---------------- Storage & Paths ----------------
os.makedirs("data", exist_ok=True)
FEEDBACK_PATH = "data/feedback.jsonl"         # raw feedback log (per sample)
MEMORY_RULES_PATH = "data/memory_rules.json"  # compiled post-correction rules
GRPO_EXPORT_PATH = "data/grpo_prefs.jsonl"    # preference pairs for GRPO
CSV_EXPORT_PATH = "data/feedback.csv"         # optional tabular export
# ---------------- Models ----------------
MODEL_PATHS = {
    "Model 1 (Complex handwritings)": ("prithivMLmods/Qwen2.5-VL-7B-Abliterated-Caption-it", Qwen2_5_VLForConditionalGeneration),
    "Model 2 (simple and scanned handwriting)": ("nanonets/Nanonets-OCR-s", Qwen2_5_VLForConditionalGeneration),
}
# Model 3 removed to conserve memory.
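# To re-add a model, map a display name to (repo_id, model class); the repo id
# below is a hypothetical placeholder:
# MODEL_PATHS["Model 3 (printed text)"] = ("some-org/some-ocr-model", Qwen2_5_VLForConditionalGeneration)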
MAX_NEW_TOKENS_DEFAULT = 512
device = "cuda" if torch.cuda.is_available() else "cpu"
_loaded_processors, _loaded_models = {}, {}
print("π Preloading models into GPU/CPU memory...") | |
for name, (repo_id, cls) in MODEL_PATHS.items(): | |
try: | |
processor = AutoProcessor.from_pretrained(repo_id, trust_remote_code=True) | |
model = cls.from_pretrained( | |
repo_id, | |
trust_remote_code=True, | |
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32, | |
low_cpu_mem_usage=True | |
).to(device).eval() | |
_loaded_processors[name], _loaded_models[name] = processor, model | |
print(f"β {name} ready.") | |
except Exception as e: | |
print(f"β οΈ Failed to load {name}: {e}") | |
# ---------------- GPU Warmup ----------------
@spaces.GPU  # on ZeroGPU Spaces, CUDA-touching functions must be decorated
def warmup(progress=gr.Progress(track_tqdm=True)):
    try:
        default_model_choice = next(iter(MODEL_PATHS.keys()))
        processor = _loaded_processors[default_model_choice]
        model = _loaded_models[default_model_choice]
        tokenizer = getattr(processor, "tokenizer", None)
        messages = [{"role": "user", "content": [{"type": "text", "text": "Warmup."}]}]
        if tokenizer and hasattr(tokenizer, "apply_chat_template"):
            chat_prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        else:
            chat_prompt = "Warmup."
        inputs = processor(text=[chat_prompt], images=None, return_tensors="pt").to(device)
        with torch.inference_mode():
            _ = model.generate(**inputs, max_new_tokens=1)
        return f"GPU warm and {default_model_choice} ready."
    except Exception as e:
        return f"Warmup skipped: {e}"
# ---------------- Helpers ----------------
def _build_inputs(processor, tokenizer, image: Image.Image, prompt: str):
    messages = [{"role": "user", "content": [{"type": "image", "image": image}, {"type": "text", "text": prompt}]}]
    if tokenizer and hasattr(tokenizer, "apply_chat_template"):
        chat_prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        # Explicitly set truncation=False to prevent the token mismatch error
        return processor(text=[chat_prompt], images=[image], return_tensors="pt", truncation=False)
    return processor(text=[prompt], images=[image], return_tensors="pt", truncation=False)
def _decode_text(model, processor, tokenizer, output_ids, prompt: str):
    try:
        decoded_text = processor.batch_decode(output_ids, skip_special_tokens=True)[0]
        prompt_start = decoded_text.find(prompt)
        if prompt_start != -1:
            decoded_text = decoded_text[prompt_start + len(prompt):].strip()
        else:
            decoded_text = decoded_text.strip()
        return decoded_text
    except Exception:
        try:
            decoded_text = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
            prompt_start = decoded_text.find(prompt)
            if prompt_start != -1:
                decoded_text = decoded_text[prompt_start + len(prompt):].strip()
            return decoded_text
        except Exception:
            return str(output_ids).strip()
def _default_prompt(query: str | None) -> str:
    if query and query.strip():
        return query.strip()
    return (
        "You are a professional Handwritten OCR system.\n"
        "TASK: Read the handwritten image and transcribe the text EXACTLY as written.\n"
        "- Preserve original structure and line breaks.\n"
        "- Keep spacing, bullet points, numbering, and indentation.\n"
        "- Render tables as Markdown tables if present.\n"
        "- Do NOT autocorrect spelling or grammar.\n"
        "- Do NOT merge lines.\n"
        "Return RAW transcription only."
    )
def _safe_text(text: str) -> str:
    return (text or "").strip()

def _hash_image(image: Image.Image) -> str:
    # stable hash for dedup / linking feedback to the same page
    img_bytes = image.tobytes()
    return hashlib.sha256(img_bytes).hexdigest()
# ---------------- Memory: Post-correction Rules ----------------
def _load_memory_rules():
    if os.path.exists(MEMORY_RULES_PATH):
        try:
            with open(MEMORY_RULES_PATH, "r", encoding="utf-8") as f:
                return json.load(f)
        except Exception:
            pass
    return {"global": {}, "by_model": {}}

def _save_memory_rules(rules):
    with open(MEMORY_RULES_PATH, "w", encoding="utf-8") as f:
        json.dump(rules, f, ensure_ascii=False, indent=2)
def _apply_memory(text: str, model_choice: str, enabled: bool):
    if not enabled or not text:
        return text
    rules = _load_memory_rules()
    # 1) Model-specific replacements
    for wrong, right in rules.get("by_model", {}).get(model_choice, {}).items():
        if wrong and right:
            text = text.replace(wrong, right)
    # 2) Global replacements
    for wrong, right in rules.get("global", {}).items():
        if wrong and right:
            text = text.replace(wrong, right)
    return text
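# Example memory_rules.json shape (illustrative values):
# {
#   "global": {"teh": "the"},
#   "by_model": {"Model 2 (simple and scanned handwriting)": {"rn": "m"}}
# }
# With these rules, a raw transcription "teh cat" is post-corrected to "the cat".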
def _compile_rules_from_feedback(min_count: int = 2, max_phrase_len: int = 40):
    """
    Build replacement rules by mining feedback pairs (prediction -> correction).
    We extract phrases that consistently changed, with frequency >= min_count.
    """
    changes_counter_global = Counter()
    changes_counter_by_model = defaultdict(Counter)
    if not os.path.exists(FEEDBACK_PATH):
        return
    with open(FEEDBACK_PATH, "r", encoding="utf-8") as f:
        for line in f:
            try:
                row = json.loads(line)
            except Exception:
                continue
            if row.get("reward", 0) < 1:  # only learn from thumbs-up or explicit accepted corrections
                continue
            pred = _safe_text(row.get("prediction", ""))
            corr = _safe_text(row.get("correction", "")) or _safe_text(row.get("ground_truth", ""))
            if not pred or not corr:
                continue
            model_choice = row.get("model_choice", "")
            # Extract edit ops between prediction and correction
            s = difflib.SequenceMatcher(None, pred, corr)
            for tag, i1, i2, j1, j2 in s.get_opcodes():
                if tag in ("replace", "delete", "insert"):
                    wrong = pred[i1:i2]
                    right = corr[j1:j2]
                    # keep short-ish tokens/phrases
                    if 0 < len(wrong) <= max_phrase_len or 0 < len(right) <= max_phrase_len:
                        if wrong.strip():
                            changes_counter_global[(wrong, right)] += 1
                            if model_choice:
                                changes_counter_by_model[model_choice][(wrong, right)] += 1
    rules = {"global": {}, "by_model": {}}
    # Global
    for (wrong, right), cnt in changes_counter_global.items():
        if cnt >= min_count and wrong and right and wrong != right:
            rules["global"][wrong] = right
    # Per model
    for model_choice, ctr in changes_counter_by_model.items():
        rules["by_model"].setdefault(model_choice, {})
        for (wrong, right), cnt in ctr.items():
            if cnt >= min_count and wrong and right and wrong != right:
                rules["by_model"][model_choice][wrong] = right
    _save_memory_rules(rules)
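# Illustrative example: if at least two positive-feedback rows correct the
# prediction "m0del" to "model", the changed span ("0" -> "o") reaches
# min_count and becomes a replacement rule. Note the rules are substring-level,
# so very short spans can over-apply; min_count is the main guard against that.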
# ---------------- OCR Function ----------------
@spaces.GPU
def ocr_image(image: Image.Image, model_choice: str, query: str | None = None,
              max_new_tokens: int = MAX_NEW_TOKENS_DEFAULT,
              temperature: float = 0.1, top_p: float = 1.0, top_k: int = 0, repetition_penalty: float = 1.0,
              use_memory: bool = True,
              progress=gr.Progress(track_tqdm=True)):
    if image is None:
        return "Please upload or capture an image."
    if model_choice not in _loaded_models:
        return f"Invalid model: {model_choice}"
    processor = _loaded_processors[model_choice]
    model = _loaded_models[model_choice]
    tokenizer = getattr(processor, "tokenizer", None)
    prompt = _default_prompt(query)
    batch = _build_inputs(processor, tokenizer, image, prompt).to(device)
    with torch.inference_mode():
        # Note: with do_sample=False decoding is greedy, so temperature/top_p/top_k
        # have no effect; they are kept so the UI sliders stay wired up.
        output_ids = model.generate(**batch, max_new_tokens=max_new_tokens, do_sample=False,
                                    temperature=temperature, top_p=top_p, top_k=top_k,
                                    repetition_penalty=repetition_penalty)
    raw = _decode_text(model, processor, tokenizer, output_ids, prompt).replace("<|im_end|>", "").strip()
    # Apply memory post-correction
    post = _apply_memory(raw, model_choice, use_memory)
    return post
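# Example call (hypothetical local image path):
# text = ocr_image(Image.open("sample_page.png"),
#                  "Model 2 (simple and scanned handwriting)")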
# ---------------- Export Helpers ----------------
def save_as_pdf(text):
    text = _safe_text(text)
    if not text:
        return None
    doc = SimpleDocTemplate("output.pdf")
    style = getSampleStyleSheet()["Normal"]
    flowables = [Paragraph(t, style) for t in text.splitlines() if t != ""]
    if not flowables:
        flowables = [Paragraph(" ", style)]
    doc.build(flowables)
    return "output.pdf"

def save_as_word(text):
    text = _safe_text(text)
    if not text:
        return None
    doc = Document()
    for line in text.splitlines():
        doc.add_paragraph(line)
    doc.save("output.docx")
    return "output.docx"

def save_as_audio(text):
    text = _safe_text(text)
    if not text:
        return None
    try:
        tts = gTTS(text)
        tts.save("output.mp3")
        return "output.mp3"
    except Exception as e:
        print(f"gTTS failed: {e}")
        return None
# ---------------- Metrics Function ----------------
def calculate_cer_score(ground_truth: str, prediction: str) -> str:
    """
    Calculates the Character Error Rate (CER).
    A CER of 0.0 means the prediction is perfect.
    """
    if not ground_truth or not prediction:
        return "Cannot calculate CER: Missing ground truth or prediction."
    ground_truth_cleaned = " ".join(ground_truth.strip().split())
    prediction_cleaned = " ".join(prediction.strip().split())
    error_rate = cer(ground_truth_cleaned, prediction_cleaned)
    return f"Character Error Rate (CER): {error_rate:.4f}"
# ---------------- Feedback & Dataset ----------------
def _append_jsonl(path, obj):
    with open(path, "a", encoding="utf-8") as f:
        f.write(json.dumps(obj, ensure_ascii=False) + "\n")

def _export_csv():
    # optional: CSV summary for spreadsheet views
    if not os.path.exists(FEEDBACK_PATH):
        return None
    rows = []
    with open(FEEDBACK_PATH, "r", encoding="utf-8") as f:
        for line in f:
            try:
                rows.append(json.loads(line))
            except Exception:
                pass
    if not rows:
        return None
    keys = ["id", "timestamp", "model_choice", "image_sha256", "prompt",
            "prediction", "correction", "ground_truth", "reward", "cer"]
    with open(CSV_EXPORT_PATH, "w", newline="", encoding="utf-8") as f:
        w = csv.DictWriter(f, fieldnames=keys)
        w.writeheader()
        for r in rows:
            w.writerow({k: r.get(k, "") for k in keys})
    return CSV_EXPORT_PATH
def save_feedback(image: Image.Image, model_choice: str, prompt: str,
                  prediction: str, correction: str, ground_truth: str, reward: int):
    """
    reward: 1 = good/accepted, 0 = neutral, -1 = bad
    """
    if image is None:
        return "Please provide the image again to link feedback."
    if not prediction and not correction and not ground_truth:
        return "Nothing to save."
    image_hash = _hash_image(image)
    # preferred target = correction, else ground_truth
    target = _safe_text(correction) or _safe_text(ground_truth)
    pred = _safe_text(prediction)
    cer_score = None
    if target and pred:
        try:
            cer_score = cer(" ".join(target.split()), " ".join(pred.split()))
        except Exception:
            cer_score = None
    row = {
        "id": str(uuid.uuid4()),
        "timestamp": datetime.utcnow().isoformat(),
        "model_choice": model_choice or "",
        "image_sha256": image_hash,
        "prompt": _safe_text(prompt),
        "prediction": pred,
        "correction": _safe_text(correction),
        "ground_truth": _safe_text(ground_truth),
        "reward": int(reward),
        "cer": float(cer_score) if cer_score is not None else None,
    }
    _append_jsonl(FEEDBACK_PATH, row)
    return f"Feedback saved (reward={reward})."
def compile_memory_rules():
    _compile_rules_from_feedback(min_count=2, max_phrase_len=60)
    return "Memory rules recompiled from positive feedback."
def export_grpo_preferences():
    """
    Build preference pairs for GRPO training:
    - chosen: correction/ground_truth when present
    - rejected: original prediction
    """
    if not os.path.exists(FEEDBACK_PATH):
        return "No feedback to export."
    count = 0
    with open(GRPO_EXPORT_PATH, "w", encoding="utf-8") as out_f:
        with open(FEEDBACK_PATH, "r", encoding="utf-8") as f:
            for line in f:
                try:
                    row = json.loads(line)
                except Exception:
                    continue
                pred = _safe_text(row.get("prediction", ""))
                corr = _safe_text(row.get("correction", "")) or _safe_text(row.get("ground_truth", ""))
                prompt = _safe_text(row.get("prompt", "")) or "Transcribe the image exactly."
                if corr and pred and corr != pred and row.get("reward", 0) >= 0:
                    # One preference datapoint
                    out = {
                        "prompt": prompt,
                        "image_sha256": row.get("image_sha256", ""),
                        "chosen": corr,
                        "rejected": pred,
                        "model_choice": row.get("model_choice", "")
                    }
                    out_f.write(json.dumps(out, ensure_ascii=False) + "\n")
                    count += 1
    return f"Exported {count} GRPO preference pairs to {GRPO_EXPORT_PATH}."
def get_grpo_file():
    return GRPO_EXPORT_PATH if os.path.exists(GRPO_EXPORT_PATH) else None

def get_csv_file():
    _export_csv()
    return CSV_EXPORT_PATH if os.path.exists(CSV_EXPORT_PATH) else None
# ---------------- Evaluation Orchestration ----------------
def perform_evaluation(image: Image.Image, model_name: str, ground_truth: str,
                       max_new_tokens: int, temperature: float, top_p: float, top_k: int,
                       repetition_penalty: float, use_memory: bool = True):
    if image is None or not ground_truth:
        return "Please upload an image and provide the ground truth.", "N/A"
    prediction = ocr_image(image, model_name, max_new_tokens=max_new_tokens,
                           temperature=temperature, top_p=top_p, top_k=top_k,
                           repetition_penalty=repetition_penalty, use_memory=use_memory)
    cer_score = calculate_cer_score(ground_truth, prediction)
    return prediction, cer_score
# ---------------- GRPO Trainer Script Writer ----------------
TRAINER_SCRIPT = r"""# grpo_train.py - Offline GRPO training with TRL (run separately)
# pip install trl accelerate peft transformers datasets
# This script expects data/grpo_prefs.jsonl produced by the app.
#
# NOTE: TRL's GRPOTrainer optimizes a policy against reward functions; it does
# not consume (chosen, rejected) pairs directly (that is DPO's input format).
# This is a minimal sketch under that constraint: each exported pair becomes a
# reward signal that scores sampled completions by similarity to the "chosen"
# human-corrected transcription.
import os, json, difflib
from datasets import Dataset
from trl import GRPOConfig, GRPOTrainer

# The default BASE_MODEL below is a vision-language checkpoint; since this
# offline loop is text-only, you may want to point BASE_MODEL at a text-only
# causal LM instead.
MODEL_ID = os.environ.get("BASE_MODEL", "Qwen/Qwen2.5-VL-7B-Instruct")
OUTPUT_DIR = os.environ.get("OUTPUT_DIR", "grpo_output")
DATA_PATH = os.environ.get("DATA_PATH", "data/grpo_prefs.jsonl")

# Our jsonl: each line has prompt, chosen, rejected (plus image_sha256/model_choice).
def _jsonl_dataset(jsonl_path):
    data = []
    with open(jsonl_path, "r", encoding="utf-8") as f:
        for line in f:
            try:
                row = json.loads(line)
            except Exception:
                continue
            prompt = row.get("prompt", "")
            chosen = row.get("chosen", "")
            if prompt and chosen:
                data.append({"prompt": prompt, "chosen": chosen})
    return data

def similarity_reward(completions, chosen, **kwargs):
    # TRL forwards extra dataset columns (here "chosen") to reward functions.
    # Score each sampled completion by character-level similarity to the
    # corrected transcription; higher is better.
    return [
        difflib.SequenceMatcher(None, completion, reference).ratio()
        for completion, reference in zip(completions, chosen)
    ]

def main():
    data = _jsonl_dataset(DATA_PATH)
    if not data:
        print("No GRPO data found.")
        return
    ds = Dataset.from_list(data)
    # Minimal config - tune to your GPU. The effective batch size must be
    # divisible by num_generations.
    cfg = GRPOConfig(
        output_dir=OUTPUT_DIR,
        learning_rate=5e-6,
        per_device_train_batch_size=2,
        gradient_accumulation_steps=8,
        num_train_epochs=1,
        num_generations=2,
        logging_steps=10,
        save_steps=200,
        max_prompt_length=512,
        max_completion_length=768,
        bf16=True,
    )
    trainer = GRPOTrainer(
        model=MODEL_ID,               # TRL loads the model from this hub id
        reward_funcs=similarity_reward,
        args=cfg,
        train_dataset=ds,
    )
    trainer.train()
    trainer.save_model(OUTPUT_DIR)
    print("GRPO training complete. Weights saved to", OUTPUT_DIR)

if __name__ == "__main__":
    main()
"""
def _write_trainer_script():
    os.makedirs("train", exist_ok=True)
    path = os.path.join("train", "grpo_train.py")
    with open(path, "w", encoding="utf-8") as f:
        f.write(TRAINER_SCRIPT)
    return path
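# Example invocation of the exported script (env var values are illustrative):
#   BASE_MODEL="Qwen/Qwen2.5-7B-Instruct" OUTPUT_DIR="grpo_output" python train/grpo_train.py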
# ---------------- Gradio Interface ----------------
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("## Wilson Handwritten OCR - with Feedback Loop")
    model_choice = gr.Radio(choices=list(MODEL_PATHS.keys()),
                            value=list(MODEL_PATHS.keys())[0],
                            label="Select OCR Model")
    with gr.Tab("Image Inference"):
        query_input = gr.Textbox(label="Custom Prompt (optional)", placeholder="Leave empty for RAW structured output")
        image_input = gr.Image(type="pil", label="Upload / Capture Handwritten Image", sources=["upload", "webcam"])
        use_memory = gr.Checkbox(value=True, label="Enable Memory Post-correction (auto-fix known mistakes)")
        with gr.Accordion("Advanced Options", open=False):
            max_new_tokens = gr.Slider(1, 2048, value=MAX_NEW_TOKENS_DEFAULT, step=1, label="Max new tokens")
            temperature = gr.Slider(0.1, 2.0, value=0.1, step=0.05, label="Temperature")
            top_p = gr.Slider(0.05, 1.0, value=1.0, step=0.05, label="Top-p (nucleus)")
            top_k = gr.Slider(0, 1000, value=0, step=1, label="Top-k")
            repetition_penalty = gr.Slider(0.8, 2.0, value=1.0, step=0.05, label="Repetition penalty")
        extract_btn = gr.Button("Extract RAW Text", variant="primary")
        clear_btn = gr.Button("Clear")
        raw_output = gr.Textbox(label="Output (post-corrected if memory is ON)", lines=18, show_copy_button=True)
        # Quick Feedback strip
        gr.Markdown("### Quick Feedback")
        correction_box = gr.Textbox(label="Your Correction (optional)", placeholder="Paste your corrected text here; leave empty if the output is perfect.", lines=8)
        ground_truth_box = gr.Textbox(label="Ground Truth (optional)", placeholder="If you have a reference transcription, paste it here.", lines=6)
        with gr.Row():
            btn_good = gr.Button("Accept (Save Feedback as Correct)", variant="primary")
            btn_bad = gr.Button("Bad (Save Feedback as Incorrect)")
        feedback_status = gr.Markdown("")
        pdf_btn = gr.Button("Download as PDF")
        word_btn = gr.Button("Download as Word")
        audio_btn = gr.Button("Download as Audio")
        pdf_file, word_file, audio_file = gr.File(label="PDF File"), gr.File(label="Word File"), gr.File(label="Audio File")
        extract_btn.click(
            fn=ocr_image,
            inputs=[image_input, model_choice, query_input, max_new_tokens, temperature, top_p, top_k, repetition_penalty, use_memory],
            outputs=[raw_output],
            api_name="ocr_image"
        )
        pdf_btn.click(fn=save_as_pdf, inputs=[raw_output], outputs=[pdf_file])
        word_btn.click(fn=save_as_word, inputs=[raw_output], outputs=[word_file])
        audio_btn.click(fn=save_as_audio, inputs=[raw_output], outputs=[audio_file])

        def _clear():
            return ("", None, "", MAX_NEW_TOKENS_DEFAULT, 0.1, 1.0, 0, 1.0, True, "", "", "")

        clear_btn.click(
            fn=_clear,
            outputs=[raw_output, image_input, query_input, max_new_tokens, temperature, top_p, top_k, repetition_penalty, use_memory, correction_box, ground_truth_box, feedback_status]
        )
        # Quick feedback save
        btn_good.click(
            fn=lambda img, mc, prmpt, pred, corr, gt: save_feedback(img, mc, prmpt, pred, corr, gt, reward=1),
            inputs=[image_input, model_choice, query_input, raw_output, correction_box, ground_truth_box],
            outputs=[feedback_status]
        )
        btn_bad.click(
            fn=lambda img, mc, prmpt, pred, corr, gt: save_feedback(img, mc, prmpt, pred, corr, gt, reward=-1),
            inputs=[image_input, model_choice, query_input, raw_output, correction_box, ground_truth_box],
            outputs=[feedback_status]
        )
with gr.Tab("π Model Evaluation"): | |
gr.Markdown("### π Evaluate Model Accuracy") | |
eval_image_input = gr.Image(type="pil", label="Upload Image for Evaluation", sources=["upload"]) | |
eval_ground_truth = gr.Textbox(label="Ground Truth (Correct Transcription)", lines=10, placeholder="Type or paste the correct text here.") | |
eval_model_output = gr.Textbox(label="Model's Prediction", lines=10, interactive=False, show_copy_button=True) | |
eval_cer_output = gr.Textbox(label="Metrics", interactive=False) | |
eval_use_memory = gr.Checkbox(value=True, label="Enable Memory Post-correction") | |
with gr.Row(): | |
run_evaluation_btn = gr.Button("π Run OCR and Evaluate", variant="primary") | |
clear_evaluation_btn = gr.Button("π§Ή Clear") | |
run_evaluation_btn.click( | |
fn=perform_evaluation, | |
inputs=[eval_image_input, model_choice, eval_ground_truth, max_new_tokens, temperature, top_p, top_k, repetition_penalty, eval_use_memory], | |
outputs=[eval_model_output, eval_cer_output] | |
) | |
clear_evaluation_btn.click( | |
fn=lambda: (None, "", "", ""), | |
outputs=[eval_image_input, eval_ground_truth, eval_model_output, eval_cer_output] | |
) | |
with gr.Tab("βοΈ Feedback & Memory"): | |
gr.Markdown(""" | |
**Pipeline** | |
1) Save feedback (π / π) and add corrections. | |
2) Click **Build/Refresh Memory** to generate auto-fix rules from positive feedback. | |
3) Keep **Enable Memory Post-correction** checked on inference/eval tabs. | |
""") | |
build_mem_btn = gr.Button("π§ Build/Refresh Memory from Feedback") | |
mem_status = gr.Markdown("") | |
build_mem_btn.click(fn=compile_memory_rules, outputs=[mem_status]) | |
csv_status = gr.Markdown("") | |
gr.Markdown("---") | |
gr.Markdown("### β¬οΈ Download Feedback Data") | |
with gr.Row(): | |
download_csv_btn = gr.Button("β¬οΈ Download Feedback as CSV") | |
download_csv_file = gr.File(label="CSV File") | |
download_csv_btn.click(fn=get_csv_file, outputs=download_csv_file) | |
with gr.Tab("π§ͺ GRPO / Dataset"): | |
gr.Markdown(""" | |
**GRPO Fine-tuning** (run offline or in a training Space): | |
- Click **Export GRPO Preferences** to produce `data/grpo_prefs.jsonl` of (prompt, chosen, rejected). | |
- Click **Write Trainer Script** to create `train/grpo_train.py`. | |
- Then run: | |
```bash | |
pip install trl accelerate peft transformers datasets | |
python train/grpo_train.py | |
Set BASE_MODEL/OUTPUT_DIR env vars if you like. | |
```""") | |
grpo_btn = gr.Button("π¦ Export GRPO Preferences") | |
grpo_status = gr.Markdown("") | |
grpo_btn.click(fn=export_grpo_preferences, outputs=[grpo_status]) | |
write_script_btn = gr.Button("π Write grpo_train.py") | |
write_script_status = gr.Markdown("") | |
write_script_btn.click(fn=lambda: f"β Trainer script written to {_write_trainer_script()}", outputs=[write_script_status]) | |
gr.Markdown("---") | |
gr.Markdown("### β¬οΈ Download GRPO Dataset") | |
with gr.Row(): | |
download_grpo_btn = gr.Button("β¬οΈ Download GRPO Data (grpo_prefs.jsonl)") | |
download_grpo_file = gr.File(label="GRPO Dataset File") | |
download_grpo_btn.click(fn=get_grpo_file, outputs=[download_grpo_file]) | |
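
    # Wire the (previously unused) warmup so the default model is exercised once
    # when the UI loads. This wiring is an assumption: any status component would
    # do; an empty Markdown is used here.
    warmup_status = gr.Markdown("")
    demo.load(fn=warmup, outputs=[warmup_status])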
if __name__ == "__main__":
    demo.queue(max_size=50).launch(share=True)