andrejrad committed
Commit dcdd99b · verified · 1 Parent(s): 1cffa06

Update app.py

Files changed (1)
  1. app.py +113 -117
app.py CHANGED
@@ -1,22 +1,18 @@
- import os, json, re, traceback
+ import os, json, traceback
  from typing import Any, Dict, Tuple
  import gradio as gr
  from PIL import Image
  import torch
  from transformers import AutoProcessor, AutoTokenizer, AutoModelForCausalLM, AutoConfig

- # --------------------------
- # Env / params
- # --------------------------
+ # -------- Env / params --------
  MODEL_ID = os.environ.get("MODEL_ID", "inference-net/ClipTagger-12b")
- HF_TOKEN = os.environ.get("HF_TOKEN")  # set in Space → Settings → Variables & secrets
+ HF_TOKEN = os.environ.get("HF_TOKEN")
  TEMP = 0.1
- MAX_NEW_TOKENS = 2000
+ MAX_NEW_TOKENS = 768  # faster demo; raise later if needed
  DTYPE = torch.bfloat16 if torch.cuda.is_available() else torch.float32

- # --------------------------
- # Prompts (yours)
- # --------------------------
+ # -------- Prompts (yours) --------
  SYSTEM_PROMPT = (
      "You are an image annotation API trained to analyze YouTube video keyframes. "
      "You will be given instructions on the output format, what to caption, and how to perform your job. "
@@ -53,55 +49,61 @@ Rules:
  - Output **only the JSON**, no extra text or explanation.
  """

- # --------------------------
- # Utilities
- # --------------------------
- def _json_extract(text: str):
-     """Strict JSON parse with top-level {...} fallback."""
+ # -------- Utils --------
+ def extract_top_level_json(s: str):
+     """Parse JSON; if extra text around it, extract the first balanced {...} block."""
+     # Fast path
      try:
-         return json.loads(text)
+         return json.loads(s)
      except Exception:
-         m = re.search(r"\{(?:[^{}]|(?R))*\}", text, flags=re.DOTALL)
-         if m:
-             try:
-                 return json.loads(m.group(0))
-             except Exception:
-                 pass
+         pass
+     # Brace-stack extraction
+     start = None
+     depth = 0
+     for i, ch in enumerate(s):
+         if ch == '{':
+             if depth == 0:
+                 start = i
+             depth += 1
+         elif ch == '}':
+             if depth > 0:
+                 depth -= 1
+                 if depth == 0 and start is not None:
+                     chunk = s[start:i+1]
+                     try:
+                         return json.loads(chunk)
+                     except Exception:
+                         # continue scanning for the next candidate
+                         start = None
      return None

- def _build_messages(image: Image.Image):
+ def build_messages(image: Image.Image):
      return [
          {"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]},
          {"role": "user", "content": [{"type": "image", "image": image},
                                       {"type": "text", "text": USER_PROMPT}]}
      ]

- def _downscale_if_huge(pil: Image.Image, max_side: int = 1792) -> Image.Image:
-     """Cap longest side to keep memory predictable; A100 is roomy but this avoids extreme uploads."""
+ def downscale_if_huge(pil: Image.Image, max_side: int = 1792) -> Image.Image:
      if pil is None:
          return pil
      w, h = pil.size
      m = max(w, h)
      if m <= max_side:
          return pil.convert("RGB")
-     scale = max_side / m
-     new_w, new_h = int(w * scale), int(h * scale)
-     return pil.convert("RGB").resize((new_w, new_h), Image.BICUBIC)
+     s = max_side / m
+     return pil.convert("RGB").resize((int(w*s), int(h*s)), Image.BICUBIC)

- # --------------------------
- # Load model (dedicated GPU)
- # --------------------------
+ # -------- Load model (A100) --------
  processor = tokenizer = model = None
  LOAD_ERROR = None

  try:
      cfg = AutoConfig.from_pretrained(MODEL_ID, token=HF_TOKEN, trust_remote_code=True)
      if "clip" in cfg.__class__.__name__.lower():
-         raise RuntimeError(
-             f"MODEL_ID '{MODEL_ID}' resolves to a CLIP/encoder config; need a causal VLM checkpoint."
-         )
+         raise RuntimeError(f"MODEL_ID '{MODEL_ID}' is a CLIP/encoder repo; need a causal VLM.")

-     # Try quantized path (compressed-tensors) per your config
+     print("[boot] loading processor…", flush=True)
      try:
          processor = AutoProcessor.from_pretrained(
              MODEL_ID, token=HF_TOKEN, trust_remote_code=True, use_fast=True
@@ -111,107 +113,105 @@ try:
              MODEL_ID, token=HF_TOKEN, trust_remote_code=True
          )

-     try:
-         model = AutoModelForCausalLM.from_pretrained(
-             MODEL_ID,
-             token=HF_TOKEN,
-             device_map="auto",
-             torch_dtype=DTYPE,
-             trust_remote_code=True,
-         )
-     except Exception as e:
-         # Fallback: disable quantization if the backend isn't available
-         if "compressed_tensors" in str(e):
-             model = AutoModelForCausalLM.from_pretrained(
-                 MODEL_ID,
-                 token=HF_TOKEN,
-                 device_map="auto",
-                 torch_dtype=DTYPE,
-                 trust_remote_code=True,
-                 quantization_config=None,
-             )
-         else:
-             raise
+     print("[boot] loading model…", flush=True)
+     # Force full-precision path on A100 first; add quantized path later if desired
+     model = AutoModelForCausalLM.from_pretrained(
+         MODEL_ID,
+         token=HF_TOKEN,
+         device_map="auto",
+         torch_dtype=DTYPE,
+         trust_remote_code=True,
+         # quantization_config=None,  # keep commented if you want to honor repo quant; uncomment to force dequant
+     )

      tokenizer = getattr(processor, "tokenizer", None) or AutoTokenizer.from_pretrained(
          MODEL_ID, token=HF_TOKEN, trust_remote_code=True, use_fast=True
      )
+     print("[boot] ready.", flush=True)

  except Exception as e:
      LOAD_ERROR = f"{e}\n\n{traceback.format_exc()}"

- # --------------------------
- # Inference
- # --------------------------
- def run(image: Image.Image) -> Tuple[str, Dict[str, Any] | None, bool]:
+ # -------- Inference --------
+ def generate(image: Image.Image) -> Tuple[str, Dict[str, Any] | None, bool]:
      if image is None:
          return "Please upload an image.", None, False
      if model is None or processor is None:
-         msg = (
-             "❌ Model failed to load.\n\n"
-             f"{LOAD_ERROR or 'Unknown error.'}\n"
-             "Check MODEL_ID/HF_TOKEN and that the repo includes model + processor files."
-         )
-         return msg, None, False
+         return f"❌ Load error:\n{LOAD_ERROR}", None, False

-     image = _downscale_if_huge(image)
+     image = downscale_if_huge(image)

-     # Build chat prompt
+     # Build prompt
      if hasattr(processor, "apply_chat_template"):
-         prompt = processor.apply_chat_template(_build_messages(image), add_generation_prompt=True, tokenize=False)
+         prompt = processor.apply_chat_template(build_messages(image), add_generation_prompt=True, tokenize=False)
      else:
-         # Very rare fallback path
-         msgs = _build_messages(image)
-         prompt = ""
-         for m in msgs:
-             role = m["role"].upper()
-             for chunk in m["content"]:
-                 if chunk["type"] == "text":
-                     prompt += f"{role}: {chunk['text']}\n"
-                 elif chunk["type"] == "image":
-                     prompt += f"{role}: [IMAGE]\n"
+         # fallback join (rare)
+         prompt = USER_PROMPT

      # Tokenize with vision
      inputs = processor(text=prompt, images=image, return_tensors="pt").to(model.device)

-     # Gen args
-     gen_kwargs = dict(
-         temperature=TEMP,
-         max_new_tokens=MAX_NEW_TOKENS,
-     )
+     # Common gen kwargs
      eos = getattr(model.config, "eos_token_id", None)
-     if eos is not None:
-         gen_kwargs["eos_token_id"] = eos

-     # Try to enforce JSON; if unsupported, we'll retry without
      tried = []
-     for tag, extra in [
-         ("json_object", {"response_format": {"type": "json_object"}}),
-         ("no_response_format", {}),
-         ("short_deterministic", {"temperature": 0.0, "max_new_tokens": min(512, MAX_NEW_TOKENS)}),
-     ]:
-         try:
-             with torch.inference_mode():
-                 out = model.generate(**inputs, **{**gen_kwargs, **extra})
-             text = (processor.decode(out[0], skip_special_tokens=True)
-                     if hasattr(processor, "decode")
-                     else AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN, use_fast=True).decode(out[0], skip_special_tokens=True))
-             if USER_PROMPT in text:
-                 text = text.split(USER_PROMPT)[-1].strip()
-             parsed = _json_extract(text)
-             if isinstance(parsed, dict):
-                 return json.dumps(parsed, indent=2), parsed, True
-             tried.append((tag, "parsed-failed"))
-         except Exception as e:
-             tried.append((tag, f"err={e}"))
-
-     # If all strategies failed, return debug info
+
+     # (1) Greedy, no sampling (most stable, no temperature arg)
+     try:
+         g = dict(do_sample=False, max_new_tokens=MAX_NEW_TOKENS)
+         if eos is not None:
+             g["eos_token_id"] = eos
+         with torch.inference_mode():
+             out = model.generate(**inputs, **g)
+         text = (processor.decode(out[0], skip_special_tokens=True)
+                 if hasattr(processor, "decode")
+                 else tokenizer.decode(out[0], skip_special_tokens=True))
+         parsed = extract_top_level_json(text)
+         if isinstance(parsed, dict):
+             return json.dumps(parsed, indent=2), parsed, True
+         tried.append(("greedy", "parsed-failed"))
+     except Exception as e:
+         tried.append(("greedy", f"err={e}"))
+
+     # (2) Sampling with temperature=0.1
+     try:
+         g = dict(do_sample=True, temperature=TEMP, max_new_tokens=MAX_NEW_TOKENS)
+         if eos is not None:
+             g["eos_token_id"] = eos
+         with torch.inference_mode():
+             out = model.generate(**inputs, **g)
+         text = (processor.decode(out[0], skip_special_tokens=True)
+                 if hasattr(processor, "decode")
+                 else tokenizer.decode(out[0], skip_special_tokens=True))
+         parsed = extract_top_level_json(text)
+         if isinstance(parsed, dict):
+             return json.dumps(parsed, indent=2), parsed, True
+         tried.append(("sample_t0.1", "parsed-failed"))
+     except Exception as e:
+         tried.append(("sample_t0.1", f"err={e}"))
+
+     # (3) Shorter greedy
+     try:
+         g = dict(do_sample=False, max_new_tokens=min(512, MAX_NEW_TOKENS))
+         if eos is not None:
+             g["eos_token_id"] = eos
+         with torch.inference_mode():
+             out = model.generate(**inputs, **g)
+         text = (processor.decode(out[0], skip_special_tokens=True)
+                 if hasattr(processor, "decode")
+                 else tokenizer.decode(out[0], skip_special_tokens=True))
+         parsed = extract_top_level_json(text)
+         if isinstance(parsed, dict):
+             return json.dumps(parsed, indent=2), parsed, True
+         tried.append(("greedy_short", "parsed-failed"))
+     except Exception as e:
+         tried.append(("greedy_short", f"err={e}"))
+
+     # Debug info if all fail
      return "Generation failed.\nTried: " + "\n".join([f"{t[0]} -> {t[1]}" for t in tried]), None, False

- # --------------------------
- # UI
- # --------------------------
- with gr.Blocks(theme=gr.themes.Soft(), analytics_enabled=False, title="Keyframe Annotator (Gemma-3 VLM)") as demo:
+ # -------- UI --------
+ with gr.Blocks(theme=gr.themes.Soft(), analytics_enabled=False, title="Keyframe Annotator (Gemma-3 VLM · A100)") as demo:
      gr.Markdown("# Keyframe Annotator (Gemma-3-12B FT · A100)\nUpload an image to get **strict JSON** annotations.")
      if LOAD_ERROR:
          with gr.Accordion("Startup Error Details", open=False):
@@ -224,12 +224,8 @@ with gr.Blocks(theme=gr.themes.Soft(), analytics_enabled=False, title="Keyframe
          with gr.Column(scale=1):
              out_text = gr.Code(label="Output (JSON or error)")
              out_json = gr.JSON(label="Parsed JSON")
-             ok = gr.Checkbox(label="Valid JSON", value=False, interactive=False)
-
-     def on_click(img):
-         return run(img)
-
-     btn.click(on_click, inputs=[image], outputs=[out_text, out_json, ok])
+             ok_flag = gr.Checkbox(label="Valid JSON", value=False, interactive=False)

+     btn.click(generate, inputs=[image], outputs=[out_text, out_json, ok_flag])

  demo.queue(max_size=32).launch()
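
The brace-stack extractor this commit introduces can be sanity-checked outside the Space. A minimal standalone sketch, with the function body copied here for illustration (it mirrors extract_top_level_json in app.py; the Space itself is not imported because app.py launches the Gradio demo at module level):

# Standalone copy of the brace-stack extraction for a quick local check.
import json

def extract_top_level_json(s: str):
    # Fast path: the whole string is already valid JSON.
    try:
        return json.loads(s)
    except Exception:
        pass
    # Otherwise scan for the first balanced top-level {...} block.
    start, depth = None, 0
    for i, ch in enumerate(s):
        if ch == '{':
            if depth == 0:
                start = i
            depth += 1
        elif ch == '}':
            if depth > 0:
                depth -= 1
                if depth == 0 and start is not None:
                    try:
                        return json.loads(s[start:i + 1])
                    except Exception:
                        start = None  # keep scanning for the next candidate
    return None

print(extract_top_level_json('Here is the annotation: {"description": "a dog on a beach", "objects": ["dog"]} Done.'))
# -> {'description': 'a dog on a beach', 'objects': ['dog']}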