Update app.py
app.py CHANGED
@@ -1,42 +1,24 @@
-import os, json, re
+import os, json, re, traceback
 import gradio as gr
 from PIL import Image
 import torch
-from transformers import AutoProcessor, AutoTokenizer, AutoModelForCausalLM

-
-
-
+# --------------------------
+# Config (via Space secrets)
+# --------------------------
+# ADAPTER_ID: your fine-tune adapter repo (PEFT). Example: GrassData/cliptagger-12b
+# BASE_ID: the Gemma-3 VLM base you fine-tuned from. Example: google/gemma-3-12b-it (gated)
+# HF_TOKEN: user access token that has access to BASE_ID (if gated)
+ADAPTER_ID = os.environ.get("MODEL_ID", os.environ.get("ADAPTER_ID", "GrassData/cliptagger-12b"))
+BASE_ID = os.environ.get("BASE_ID", "google/gemma-3-12b-it")
+HF_TOKEN = os.environ.get("HF_TOKEN")

 DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
-DTYPE
+DTYPE = torch.bfloat16 if torch.cuda.is_available() else torch.float32

-#
-
-
-processor = AutoProcessor.from_pretrained(
-BASE_PROCESSOR_ID, token=HF_TOKEN, trust_remote_code=True
-)
-except Exception as e:
-raise RuntimeError(f"Failed to load processor from {BASE_PROCESSOR_ID}: {e}")
-
-# Optional: get a fast tokenizer if processor doesn't expose one
-tokenizer = getattr(processor, "tokenizer", None)
-if tokenizer is None:
-tokenizer = AutoTokenizer.from_pretrained(
-BASE_PROCESSOR_ID, token=HF_TOKEN, trust_remote_code=True, use_fast=True
-)
-
-# Your fine-tuned weights
-model = AutoModelForCausalLM.from_pretrained(
-MODEL_ID,
-token=HF_TOKEN,
-torch_dtype=DTYPE,
-device_map="auto",
-trust_remote_code=True,
-)
-
-# Prompts (system + user, as given)
+# --------------------------
+# Prompts (your spec)
+# --------------------------
 SYSTEM_PROMPT = (
     "You are an image annotation API trained to analyze YouTube video keyframes. "
     "You will be given instructions on the output format, what to caption, and how to perform your job. "
@@ -51,76 +33,201 @@ Your job is to extract detailed **factual elements directly visible** in the image
 Return JSON in this structure:

 {
-"description": "...
-"objects": ["...
-"actions": ["...
-"environment": "
-"content_type": "
-"specific_style": "
-"production_quality": "
-"summary": "
-"logos": ["...
+"description": "A detailed, factual account of what is visibly happening (4 sentences max). Only mention concrete elements or actions that are clearly shown. Do not include anything about how the image is styled, shot, or composed. Do not lead the description with something like 'This image shows' or 'this keyframe is...', just get right into the details.",
+"objects": ["object1 with relevant visual details", "object2 with relevant visual details", ...],
+"actions": ["action1 with participants and context", "action2 with participants and context", ...],
+"environment": "Detailed factual description of the setting and atmosphere based on visible cues (e.g., interior of a classroom with fluorescent lighting, or outdoor forest path with snow-covered trees).",
+"content_type": "The type of content it is, e.g. 'real-world footage', 'video game', 'animation', 'cartoon', 'CGI', 'VTuber', etc.",
+"specific_style": "Specific genre, aesthetic, or platform style (e.g., anime, 3D animation, mobile gameplay, vlog, tutorial, news broadcast, etc.)",
+"production_quality": "Visible production level: e.g., 'professional studio', 'amateur handheld', 'webcam recording', 'TV broadcast', etc.",
+"summary": "One clear, comprehensive sentence summarizing the visual content of the frame. Like the description, get right to the point.",
+"logos": ["logo1 with visual description", "logo2 with visual description", ...]
 }

 Rules:
-- Be specific and literal.
--
-- No artistic
--
--
-- 'logos'
--
-- Output only JSON
+- Be specific and literal. Focus on what is explicitly visible.
+- Do NOT include interpretations of emotion, mood, or narrative unless it's visually explicit.
+- No artistic or cinematic analysis.
+- Always include the language of any text in the image if present as an object, e.g. "English text", "Japanese text", "Russian text", etc.
+- Maximum 10 objects and 5 actions.
+- Return an empty array for 'logos' if none are present.
+- Always output strictly valid JSON with proper escaping.
+- Output **only the JSON**, no extra text or explanation.
 """

-
-
-
+# --------------------------
+# Load base + adapter (PEFT)
+# --------------------------
+def load_model_stack():
+    from transformers import AutoProcessor, AutoTokenizer, AutoConfig, AutoModelForCausalLM
+    from peft import PeftModel
+
+    # Prefer loading processor from BASE_ID (has preproc files). If you've vendored
+    # processor files into the adapter repo, you can switch to ADAPTER_ID here.
+    try:
+        processor = AutoProcessor.from_pretrained(
+            BASE_ID, token=HF_TOKEN, trust_remote_code=True, use_fast=True
+        )
+    except TypeError:
+        # Some processor classes don't accept use_fast
+        processor = AutoProcessor.from_pretrained(
+            BASE_ID, token=HF_TOKEN, trust_remote_code=True
+        )
+
+    # Sanity check: ADAPTER should not be CLIP-only
+    cfg = AutoConfig.from_pretrained(ADAPTER_ID, token=HF_TOKEN, trust_remote_code=True)
+    if cfg.__class__.__name__.lower().startswith("clip"):
+        raise RuntimeError(
+            f"MODEL_ID/ADAPTER_ID ({ADAPTER_ID}) resolves to a CLIP/encoder config "
+            "and cannot be used with AutoModelForCausalLM. Point to your PEFT adapter "
+            "repo (Gemma-3 VLM adapters) or a full causal VLM checkpoint."
+        )
+
+    base = AutoModelForCausalLM.from_pretrained(
+        BASE_ID,
+        token=HF_TOKEN,
+        device_map="auto",
+        torch_dtype=DTYPE,
+        trust_remote_code=True,
+    )
+
+    model = PeftModel.from_pretrained(
+        base,
+        ADAPTER_ID,
+        token=HF_TOKEN,
+    )
+
+    # Merge adapters for faster inference (optional)
+    try:
+        model = model.merge_and_unload()
+    except Exception:
+        # If merge isn’t supported, we keep PEFT wrapper
+        pass
+
+    tokenizer = getattr(processor, "tokenizer", None)
+    if tokenizer is None:
+        tokenizer = AutoTokenizer.from_pretrained(
+            BASE_ID, token=HF_TOKEN, trust_remote_code=True, use_fast=True
+        )
+
+    return processor, tokenizer, model
+
+LOAD_ERROR = None
+processor = tokenizer = model = None
+try:
+    processor, tokenizer, model = load_model_stack()
+except Exception as e:
+    LOAD_ERROR = f"{e}\n\n{traceback.format_exc()}"
+
+# --------------------------
+# Inference
+# --------------------------
+def build_messages(image: Image.Image):
+    return [
         {"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]},
-{"role": "user",
+        {"role": "user", "content": [{"type": "image", "image": image},
+                                     {"type": "text", "text": USER_PROMPT}]}
     ]

-
-
-
-
-
-
-
-
-
-
-
+def generate_json(image: Image.Image):
+    if image is None:
+        return "Please upload an image.", None, False
+
+    if model is None or processor is None:
+        msg = (
+            "❌ Model failed to load.\n\n"
+            f"{LOAD_ERROR or 'Unknown error. Check BASE_ID/ADAPTER_ID/HF_TOKEN.'}\n"
+            "• Ensure HF_TOKEN belongs to an account with access to the BASE_ID (if gated).\n"
+            "• Ensure MODEL_ID/ADAPTER_ID points to a Gemma-3 VLM PEFT adapter (not CLIP).\n"
+            "• Optionally vendor processor files into your adapter repo."
         )
+        return msg, None, False

-
+    # Prepare chat prompt
+    if hasattr(processor, "apply_chat_template"):
+        prompt = processor.apply_chat_template(
+            build_messages(image), add_generation_prompt=True, tokenize=False
+        )
+    else:
+        # Fallback join (rare for Gemma-3)
+        msgs = build_messages(image)
+        prompt = ""
+        for m in msgs:
+            role = m["role"].upper()
+            for chunk in m["content"]:
+                if chunk["type"] == "text":
+                    prompt += f"{role}: {chunk['text']}\n"
+                elif chunk["type"] == "image":
+                    prompt += f"{role}: [IMAGE]\n"
+
+    # Tokenize with vision
+    inputs = processor(text=prompt, images=image, return_tensors="pt").to(model.device)
+
+    # Generate with fixed params
+    gen_kwargs = dict(
+        max_new_tokens=2000,
+        temperature=0.1,
+        eos_token_id=getattr(tokenizer, "eos_token_id", None),
+    )

-#
+    # Ask for JSON-only if supported by the model head
+    # (Some trust_remote_code models accept response_format)
     try:
-
-pretty = json.dumps(parsed, indent=2)
-return pretty, parsed
+        gen_kwargs["response_format"] = {"type": "json_object"}
     except Exception:
-
+        pass
+
+    with torch.inference_mode():
+        out = model.generate(**inputs, **gen_kwargs)
+
+    # Decode
+    if hasattr(processor, "decode"):
+        text = processor.decode(out[0], skip_special_tokens=True)
+    else:
+        text = tokenizer.decode(out[0], skip_special_tokens=True)
+
+    # Best-effort: trim any preamble
+    if USER_PROMPT in text:
+        text = text.split(USER_PROMPT)[-1].strip()

-
-if
-
-return run_inference(img)
-
-#
-
-
+    # Parse JSON
+    try:
+        parsed = json.loads(text)
+        return json.dumps(parsed, indent=2), parsed, True
+    except Exception:
+        # Try to recover a top-level {...}
+        m = re.search(r"\{(?:[^{}]|(?R))*\}", text, flags=re.DOTALL)
+        if m:
+            try:
+                parsed = json.loads(m.group(0))
+                return json.dumps(parsed, indent=2), parsed, True
+            except Exception:
+                pass
+        return text, None, False
+
+# --------------------------
+# UI
+# --------------------------
+with gr.Blocks(theme=gr.themes.Soft(), analytics_enabled=False, title="Keyframe Annotator (Gemma-3 + Adapter)") as demo:
+    gr.Markdown("# Keyframe Annotator (Gemma-3-12B FT)\nUpload an image to get **strict JSON** annotations.")
+
+    if LOAD_ERROR:
+        with gr.Accordion("Startup Error Details", open=False):
+            gr.Markdown(f"```\n{LOAD_ERROR}\n```")

     with gr.Row():
         with gr.Column(scale=1):
             image = gr.Image(type="pil", label="Upload Image", image_mode="RGB")
-
-
+            annotate_btn = gr.Button("Annotate", variant="primary")
         with gr.Column(scale=1):
-
+            out_code = gr.Code(label="Model Output (JSON or error text)")
             out_json = gr.JSON(label="Parsed JSON")
+            ok_flag = gr.Checkbox(label="Valid JSON", value=False, interactive=False)
+
+    def on_submit(img):
+        text, js, ok = generate_json(img)
+        return text, js, ok

-
+    annotate_btn.click(on_submit, inputs=[image], outputs=[out_code, out_json, ok_flag])

 demo.queue(max_size=32, concurrency_count=1).launch()
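
One caveat in the JSON-recovery fallback above: the pattern r"\{(?:[^{}]|(?R))*\}" relies on recursive matching via (?R), which the third-party "regex" package supports but Python's built-in "re" module rejects at compile time, so the recovery branch would itself raise before it can salvage anything. A minimal stdlib-only sketch of the same idea, taking the span from the first "{" to the last "}" and attempting to parse it, could look like the following (recover_json_object is an illustrative name, not a function in the app):

    import json

    def recover_json_object(text: str):
        # Stdlib-only alternative to a recursive regex: grab the outermost
        # brace-delimited span and try to parse it. Good enough when the reply
        # contains a single JSON object, which the system prompt asks for.
        start, end = text.find("{"), text.rfind("}")
        if start == -1 or end <= start:
            return None
        try:
            return json.loads(text[start : end + 1])
        except json.JSONDecodeError:
            return None

    # Example: recovers the object even with stray text around it.
    print(recover_json_object('Sure! {"summary": "a person speaks to a webcam"} Hope this helps.'))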
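
Similarly, response_format is not part of the standard transformers generate() signature; models that do not accept it can fail keyword-argument validation at generation time, and the try/except wrapped around the plain dict assignment above never exercises that failure path. If JSON mode should be attempted only when a given (possibly trust_remote_code) model supports it, one hedged approach is to retry the generate call without the extra key. The helper below is a sketch that reuses the app's variable names for illustration only:

    import torch

    def generate_with_optional_json_mode(model, inputs, gen_kwargs):
        # Attempt generation with the optional "response_format" hint; if this
        # particular model rejects unknown generation kwargs, drop the hint and
        # retry with the remaining, standard arguments.
        with torch.inference_mode():
            try:
                return model.generate(**inputs, **gen_kwargs)
            except (TypeError, ValueError):
                fallback = {k: v for k, v in gen_kwargs.items() if k != "response_format"}
                return model.generate(**inputs, **fallback)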