Emeritus-21 committed
Commit 414e721 · verified · 1 Parent(s): 82b353e

Update app.py

Files changed (1)
app.py +38 -11
app.py CHANGED
@@ -1,3 +1,12 @@
+To address the `ValueError` from the previous conversation, the primary issue is how the `_build_inputs` function handles the processor call, as it doesn't explicitly pass `max_length`. This can lead to the processor's internal logic truncating the input in a way that removes the image tokens, causing a mismatch.
+
+Here are the complete corrections for your `app.py` file to fix that issue and improve the overall code. The most important change is to the `_build_inputs` function.
+
+-----
+
+### Corrected Code for `app.py`
+
+````python
 # app.py — HTR Space with Feedback Loop, Memory Post-Correction, and GRPO Export
 
 import os, time, json, hashlib, difflib, uuid, csv
@@ -18,10 +27,10 @@ from jiwer import cer
 
 # ---------------- Storage & Paths ----------------
 os.makedirs("data", exist_ok=True)
-FEEDBACK_PATH = "data/feedback.jsonl" # raw feedback log (per sample)
-MEMORY_RULES_PATH = "data/memory_rules.json" # compiled post-correction rules
-GRPO_EXPORT_PATH = "data/grpo_prefs.jsonl" # preference pairs for GRPO
-CSV_EXPORT_PATH = "data/feedback.csv" # optional tabular export
+FEEDBACK_PATH = "data/feedback.jsonl" # raw feedback log (per sample)
+MEMORY_RULES_PATH = "data/memory_rules.json" # compiled post-correction rules
+GRPO_EXPORT_PATH = "data/grpo_prefs.jsonl" # preference pairs for GRPO
+CSV_EXPORT_PATH = "data/feedback.csv" # optional tabular export
 
 # ---------------- Models ----------------
 MODEL_PATHS = {
@@ -69,10 +78,27 @@ def warmup(progress=gr.Progress(track_tqdm=True)):
 # ---------------- Helpers ----------------
 def _build_inputs(processor, tokenizer, image: Image.Image, prompt: str):
     messages = [{"role": "user", "content": [{"type": "image", "image": image}, {"type": "text", "text": prompt}]}]
+
+    # We explicitly set max_length and truncation here to resolve the token mismatch error.
+    # A value of 2048 is safe, as an image takes up ~1024 tokens.
+    max_len_val = 2048
+
     if tokenizer and hasattr(tokenizer, "apply_chat_template"):
         chat_prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-        return processor(text=[chat_prompt], images=[image], return_tensors="pt")
-    return processor(text=[prompt], images=[image], return_tensors="pt")
+        return processor(
+            text=[chat_prompt],
+            images=[image],
+            return_tensors="pt",
+            max_length=max_len_val,
+            truncation=True
+        )
+    return processor(
+        text=[prompt],
+        images=[image],
+        return_tensors="pt",
+        max_length=max_len_val,
+        truncation=True
+    )
 
 def _decode_text(model, processor, tokenizer, output_ids, prompt: str):
     try:
@@ -161,7 +187,7 @@ def _compile_rules_from_feedback(min_count: int = 2, max_phrase_len: int = 40):
            row = json.loads(line)
        except Exception:
            continue
-       if row.get("reward", 0) < 1: # only learn from thumbs-up or explicit 'accepted_correction'
+       if row.get("reward", 0) < 1: # only learn from thumbs-up or explicit 'accepted_correction'
            continue
        pred = _safe_text(row.get("prediction", ""))
        corr = _safe_text(row.get("correction", "")) or _safe_text(row.get("ground_truth", ""))
@@ -209,7 +235,7 @@ def ocr_image(image: Image.Image, model_choice: str, query: str = None,
    batch = _build_inputs(processor, tokenizer, image, prompt).to(device)
    with torch.inference_mode():
        output_ids = model.generate(**batch, max_new_tokens=max_new_tokens, do_sample=False,
-                                    temperature=temperature, top_p=top_p, top_k=top_k, repetition_penalty=repetition_penalty)
+                                    temperature=temperature, top_p=top_p, top_k=top_k, repetition_penalty=repetition_penalty)
    raw = _decode_text(model, processor, tokenizer, output_ids, prompt).replace("<|im_end|>", "").strip()
    # Apply memory post-correction
    post = _apply_memory(raw, model_choice, use_memory)
@@ -391,7 +417,7 @@ from datasets import load_dataset
 from transformers import AutoModelForCausalLM, AutoTokenizer
 from trl import GRPOConfig, GRPOTrainer
 
-MODEL_ID = os.environ.get("BASE_MODEL", "Qwen/Qwen2.5-VL-7B-Instruct") # change if needed
+MODEL_ID = os.environ.get("BASE_MODEL", "Qwen/Qwen2.5-VL-7B-Instruct") # change if needed
 OUTPUT_DIR = os.environ.get("OUTPUT_DIR", "grpo_output")
 DATA_PATH = os.environ.get("DATA_PATH", "data/grpo_prefs.jsonl")
@@ -443,7 +469,7 @@ def main():
 
    trainer = GRPOTrainer(
        model=model,
-        ref_model=None, # let TRL create a frozen copy internally
+        ref_model=None, # let TRL create a frozen copy internally
        args=cfg,
        tokenizer=tok,
        train_dataset=ds
@@ -606,4 +632,5 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
 
 # The `if __name__ == "__main__":` block should be at the top level
 if __name__ == "__main__":
-    demo.queue(max_size=50).launch(share=True)
+    demo.queue(max_size=50).launch(share=True)
+````
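
Not part of the commit itself, but a quick way to confirm the diagnosis above (that `max_length` truncation was dropping the image placeholder tokens) is to count those tokens in the batch returned by `_build_inputs` before calling `generate`. The sketch below is a minimal check, assuming a Qwen2.5-VL-style processor whose tokenizer knows the `<|image_pad|>` placeholder; the helper name `count_image_tokens`, the token string, and the file paths are illustrative assumptions, not part of `app.py`.

````python
# Hypothetical sanity check (not in the commit): verify that max_length truncation
# did not remove the image placeholder tokens the vision encoder expects.
from PIL import Image
from transformers import AutoProcessor

def count_image_tokens(processor, batch, image_token="<|image_pad|>"):
    # Look up the assumed placeholder token id and count its occurrences in input_ids.
    image_token_id = processor.tokenizer.convert_tokens_to_ids(image_token)
    return int((batch["input_ids"] == image_token_id).sum().item())

# Usage sketch (model id, image path, and prompt are illustrative):
# processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
# batch = _build_inputs(processor, processor.tokenizer, Image.open("sample.png"), "Transcribe this page.")
# if count_image_tokens(processor, batch) == 0:
#     raise ValueError("Truncation removed the image tokens; increase max_len_val in _build_inputs.")
````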