andrejrad committed on
Commit 6a4178d · verified · 1 Parent(s): b04e7b4

Update app.py

Files changed (1)
app.py +188 -81
app.py CHANGED
@@ -1,42 +1,24 @@
- import os, json, re
+ import os, json, re, traceback
  import gradio as gr
  from PIL import Image
  import torch
- from transformers import AutoProcessor, AutoTokenizer, AutoModelForCausalLM

- MODEL_ID = os.environ.get("MODEL_ID", "GrassData/cliptagger-12b")
- BASE_PROCESSOR_ID = os.environ.get("BASE_PROCESSOR_ID", "google/gemma-3-12b-it")
- HF_TOKEN = os.environ.get("HF_TOKEN")
+ # --------------------------
+ # Config (via Space secrets)
+ # --------------------------
+ # ADAPTER_ID: your fine-tune adapter repo (PEFT). Example: GrassData/cliptagger-12b
+ # BASE_ID: the Gemma-3 VLM base you fine-tuned from. Example: google/gemma-3-12b-it (gated)
+ # HF_TOKEN: user access token that has access to BASE_ID (if gated)
+ ADAPTER_ID = os.environ.get("MODEL_ID", os.environ.get("ADAPTER_ID", "GrassData/cliptagger-12b"))
+ BASE_ID = os.environ.get("BASE_ID", "google/gemma-3-12b-it")
+ HF_TOKEN = os.environ.get("HF_TOKEN")

  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
- DTYPE = torch.bfloat16 if torch.cuda.is_available() else torch.float32
+ DTYPE = torch.bfloat16 if torch.cuda.is_available() else torch.float32

- # ---- Load processor (from base) + model (from your FT) ----
- try:
-     # Processor comes from base VLM repo (has preprocessor_config.json)
-     processor = AutoProcessor.from_pretrained(
-         BASE_PROCESSOR_ID, token=HF_TOKEN, trust_remote_code=True
-     )
- except Exception as e:
-     raise RuntimeError(f"Failed to load processor from {BASE_PROCESSOR_ID}: {e}")
-
- # Optional: get a fast tokenizer if processor doesn't expose one
- tokenizer = getattr(processor, "tokenizer", None)
- if tokenizer is None:
-     tokenizer = AutoTokenizer.from_pretrained(
-         BASE_PROCESSOR_ID, token=HF_TOKEN, trust_remote_code=True, use_fast=True
-     )
-
- # Your fine-tuned weights
- model = AutoModelForCausalLM.from_pretrained(
-     MODEL_ID,
-     token=HF_TOKEN,
-     torch_dtype=DTYPE,
-     device_map="auto",
-     trust_remote_code=True,
- )
-
- # Prompts (system + user, as given)
+ # --------------------------
+ # Prompts (your spec)
+ # --------------------------
  SYSTEM_PROMPT = (
      "You are an image annotation API trained to analyze YouTube video keyframes. "
      "You will be given instructions on the output format, what to caption, and how to perform your job. "
@@ -51,76 +33,201 @@ Your job is to extract detailed **factual elements directly visible** in the ima
  Return JSON in this structure:

  {
- "description": "...",
- "objects": ["..."],
- "actions": ["..."],
- "environment": "...",
- "content_type": "...",
- "specific_style": "...",
- "production_quality": "...",
- "summary": "...",
- "logos": ["..."]
+ "description": "A detailed, factual account of what is visibly happening (4 sentences max). Only mention concrete elements or actions that are clearly shown. Do not include anything about how the image is styled, shot, or composed. Do not lead the description with something like 'This image shows' or 'this keyframe is...', just get right into the details.",
+ "objects": ["object1 with relevant visual details", "object2 with relevant visual details", ...],
+ "actions": ["action1 with participants and context", "action2 with participants and context", ...],
+ "environment": "Detailed factual description of the setting and atmosphere based on visible cues (e.g., interior of a classroom with fluorescent lighting, or outdoor forest path with snow-covered trees).",
+ "content_type": "The type of content it is, e.g. 'real-world footage', 'video game', 'animation', 'cartoon', 'CGI', 'VTuber', etc.",
+ "specific_style": "Specific genre, aesthetic, or platform style (e.g., anime, 3D animation, mobile gameplay, vlog, tutorial, news broadcast, etc.)",
+ "production_quality": "Visible production level: e.g., 'professional studio', 'amateur handheld', 'webcam recording', 'TV broadcast', etc.",
+ "summary": "One clear, comprehensive sentence summarizing the visual content of the frame. Like the description, get right to the point.",
+ "logos": ["logo1 with visual description", "logo2 with visual description", ...]
  }

  Rules:
- - Be specific and literal.
- - No mood/emotion/narrative unless explicit.
- - No artistic/cinematic analysis.
- - Include the language of any visible text (e.g., "English text").
- - 10 objects, 5 actions.
- - 'logos' must be [] if none are present.
- - Strictly valid JSON, properly escaped.
- - Output only JSON, no extra text.
+ - Be specific and literal. Focus on what is explicitly visible.
+ - Do NOT include interpretations of emotion, mood, or narrative unless it's visually explicit.
+ - No artistic or cinematic analysis.
+ - Always include the language of any text in the image if present as an object, e.g. "English text", "Japanese text", "Russian text", etc.
+ - Maximum 10 objects and 5 actions.
+ - Return an empty array for 'logos' if none are present.
+ - Always output strictly valid JSON with proper escaping.
+ - Output **only the JSON**, no extra text or explanation.
  """

- def run_inference(image: Image.Image):
-     # Messages
-     messages = [
+ # --------------------------
+ # Load base + adapter (PEFT)
+ # --------------------------
+ def load_model_stack():
+     from transformers import AutoProcessor, AutoTokenizer, AutoConfig, AutoModelForCausalLM
+     from peft import PeftModel
+
+     # Prefer loading processor from BASE_ID (has preproc files). If you've vendored
+     # processor files into the adapter repo, you can switch to ADAPTER_ID here.
+     try:
+         processor = AutoProcessor.from_pretrained(
+             BASE_ID, token=HF_TOKEN, trust_remote_code=True, use_fast=True
+         )
+     except TypeError:
+         # Some processor classes don't accept use_fast
+         processor = AutoProcessor.from_pretrained(
+             BASE_ID, token=HF_TOKEN, trust_remote_code=True
+         )
+
+     # Sanity check: ADAPTER should not be CLIP-only
+     cfg = AutoConfig.from_pretrained(ADAPTER_ID, token=HF_TOKEN, trust_remote_code=True)
+     if cfg.__class__.__name__.lower().startswith("clip"):
+         raise RuntimeError(
+             f"MODEL_ID/ADAPTER_ID ({ADAPTER_ID}) resolves to a CLIP/encoder config "
+             "and cannot be used with AutoModelForCausalLM. Point to your PEFT adapter "
+             "repo (Gemma-3 VLM adapters) or a full causal VLM checkpoint."
+         )
+
+     base = AutoModelForCausalLM.from_pretrained(
+         BASE_ID,
+         token=HF_TOKEN,
+         device_map="auto",
+         torch_dtype=DTYPE,
+         trust_remote_code=True,
+     )
+
+     model = PeftModel.from_pretrained(
+         base,
+         ADAPTER_ID,
+         token=HF_TOKEN,
+     )
+
+     # Merge adapters for faster inference (optional)
+     try:
+         model = model.merge_and_unload()
+     except Exception:
+         # If merge isn’t supported, we keep PEFT wrapper
+         pass
+
+     tokenizer = getattr(processor, "tokenizer", None)
+     if tokenizer is None:
+         tokenizer = AutoTokenizer.from_pretrained(
+             BASE_ID, token=HF_TOKEN, trust_remote_code=True, use_fast=True
+         )
+
+     return processor, tokenizer, model
+
+ LOAD_ERROR = None
+ processor = tokenizer = model = None
+ try:
+     processor, tokenizer, model = load_model_stack()
+ except Exception as e:
+     LOAD_ERROR = f"{e}\n\n{traceback.format_exc()}"
+
+ # --------------------------
+ # Inference
+ # --------------------------
+ def build_messages(image: Image.Image):
+     return [
          {"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]},
-         {"role": "user", "content": [{"type": "image", "image": image}, {"type": "text", "text": USER_PROMPT}]}
+         {"role": "user", "content": [{"type": "image", "image": image},
+                                      {"type": "text", "text": USER_PROMPT}]}
      ]

-     prompt_inputs = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
-     inputs = processor(text=prompt_inputs, images=image, return_tensors="pt").to(model.device)
-
-     with torch.inference_mode():
-         out = model.generate(
-             **inputs,
-             do_sample=False, # deterministic since temp=0.1
-             temperature=0.1,
-             max_new_tokens=2000,
-             eos_token_id=processor.tokenizer.eos_token_id,
-             response_format={"type": "json_object"} # force JSON mode
+ def generate_json(image: Image.Image):
+     if image is None:
+         return "Please upload an image.", None, False
+
+     if model is None or processor is None:
+         msg = (
+             "❌ Model failed to load.\n\n"
+             f"{LOAD_ERROR or 'Unknown error. Check BASE_ID/ADAPTER_ID/HF_TOKEN.'}\n"
+             "• Ensure HF_TOKEN belongs to an account with access to the BASE_ID (if gated).\n"
+             "• Ensure MODEL_ID/ADAPTER_ID points to a Gemma-3 VLM PEFT adapter (not CLIP).\n"
+             " Optionally vendor processor files into your adapter repo."
          )
+         return msg, None, False

-     text = processor.decode(out[0], skip_special_tokens=True)
+     # Prepare chat prompt
+     if hasattr(processor, "apply_chat_template"):
+         prompt = processor.apply_chat_template(
+             build_messages(image), add_generation_prompt=True, tokenize=False
+         )
+     else:
+         # Fallback join (rare for Gemma-3)
+         msgs = build_messages(image)
+         prompt = ""
+         for m in msgs:
+             role = m["role"].upper()
+             for chunk in m["content"]:
+                 if chunk["type"] == "text":
+                     prompt += f"{role}: {chunk['text']}\n"
+                 elif chunk["type"] == "image":
+                     prompt += f"{role}: [IMAGE]\n"
+
+     # Tokenize with vision
+     inputs = processor(text=prompt, images=image, return_tensors="pt").to(model.device)
+
+     # Generate with fixed params
+     gen_kwargs = dict(
+         max_new_tokens=2000,
+         temperature=0.1,
+         eos_token_id=getattr(tokenizer, "eos_token_id", None),
+     )

-     # Clean parse
+     # Ask for JSON-only if supported by the model head
+     # (Some trust_remote_code models accept response_format)
      try:
-         parsed = json.loads(text)
-         pretty = json.dumps(parsed, indent=2)
-         return pretty, parsed
+         gen_kwargs["response_format"] = {"type": "json_object"}
      except Exception:
-         return text, {"error": "Invalid JSON"}
+         pass
+
+     with torch.inference_mode():
+         out = model.generate(**inputs, **gen_kwargs)
+
+     # Decode
+     if hasattr(processor, "decode"):
+         text = processor.decode(out[0], skip_special_tokens=True)
+     else:
+         text = tokenizer.decode(out[0], skip_special_tokens=True)

- def ui_submit(img):
-     if img is None:
-         return "Please upload an image.", None
-     return run_inference(img)
+     # Best-effort: trim any preamble
+     if USER_PROMPT in text:
+         text = text.split(USER_PROMPT)[-1].strip()

- # ---- UI ----
- with gr.Blocks(title="ClipTagger-12B Keyframe Annotator") as demo:
-     gr.Markdown("# ClipTagger-12B Keyframe Annotator\nUpload a photo to get structured JSON annotations.")
+     # Parse JSON
+     try:
+         parsed = json.loads(text)
+         return json.dumps(parsed, indent=2), parsed, True
+     except Exception:
+         # Try to recover a top-level {...}
+         m = re.search(r"\{(?:[^{}]|(?R))*\}", text, flags=re.DOTALL)
+         if m:
+             try:
+                 parsed = json.loads(m.group(0))
+                 return json.dumps(parsed, indent=2), parsed, True
+             except Exception:
+                 pass
+         return text, None, False
+
+ # --------------------------
+ # UI
+ # --------------------------
+ with gr.Blocks(theme=gr.themes.Soft(), analytics_enabled=False, title="Keyframe Annotator (Gemma-3 + Adapter)") as demo:
+     gr.Markdown("# Keyframe Annotator (Gemma-3-12B FT)\nUpload an image to get **strict JSON** annotations.")
+
+     if LOAD_ERROR:
+         with gr.Accordion("Startup Error Details", open=False):
+             gr.Markdown(f"```\n{LOAD_ERROR}\n```")

      with gr.Row():
          with gr.Column(scale=1):
              image = gr.Image(type="pil", label="Upload Image", image_mode="RGB")
-             btn = gr.Button("Annotate", variant="primary")
-
+             annotate_btn = gr.Button("Annotate", variant="primary")
          with gr.Column(scale=1):
-             out_text = gr.Code(label="Model Output (JSON)")
+             out_code = gr.Code(label="Model Output (JSON or error text)")
              out_json = gr.JSON(label="Parsed JSON")
+             ok_flag = gr.Checkbox(label="Valid JSON", value=False, interactive=False)
+
+     def on_submit(img):
+         text, js, ok = generate_json(img)
+         return text, js, ok

-     btn.click(ui_submit, inputs=[image], outputs=[out_text, out_json])
+     annotate_btn.click(on_submit, inputs=[image], outputs=[out_code, out_json, ok_flag])

  demo.queue(max_size=32, concurrency_count=1).launch()
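Note on the JSON-recovery fallback added in generate_json: the pattern r"\{(?:[^{}]|(?R))*\}" relies on regex recursion ((?R)), which Python's standard-library re module does not accept; recursion is only available in the third-party regex package. A minimal standalone sketch of that recovery step, assuming the regex package is installed (an assumption, it is not shown in this commit), could look like:

import json
import regex  # third-party package; stdlib `re` rejects the recursive (?R) construct

def recover_json(text: str):
    # Find the first balanced top-level {...} block in the model output and try to parse it.
    m = regex.search(r"\{(?:[^{}]|(?R))*\}", text, flags=regex.DOTALL)
    if not m:
        return None
    try:
        return json.loads(m.group(0))
    except json.JSONDecodeError:
        return None

# Example usage: strips chat preamble around the JSON answer.
print(recover_json('Sure! {"summary": "a keyframe", "logos": []}'))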