Spaces: Running on A100

Update app.py
app.py CHANGED
@@ -9,13 +9,13 @@ from transformers import AutoProcessor, AutoTokenizer, AutoModelForCausalLM, Aut
 MODEL_ID = os.environ.get("MODEL_ID", "inference-net/ClipTagger-12b")
 HF_TOKEN = os.environ.get("HF_TOKEN")
 
-# Latency/quality knobs (tuned for A100)
+# Latency/quality knobs (tuned for A100-80GB)
 TEMP = 0.1  # per model docs
-MAX_NEW_TOKENS = 384  # fast + sufficient for schema (raise to 512/768
-VISION_LONG_SIDE = 896  # matches
+MAX_NEW_TOKENS = 384  # fast + sufficient for schema (raise to 512/768 if needed)
+VISION_LONG_SIDE = 896  # matches vision_config.image_size
 DTYPE = torch.bfloat16 if torch.cuda.is_available() else torch.float32
 
-# ===== Prompts (
+# ===== Prompts (schema-only; no example output) =====
 SYSTEM_PROMPT = (
     "You are an image annotation API trained to analyze YouTube video keyframes. "
     "You will be given instructions on the output format, what to caption, and how to perform your job. "
@@ -50,20 +50,21 @@ Rules:
 - Return an empty array for 'logos' if none are present.
 - Always output strictly valid JSON with proper escaping.
 - Output **only the JSON**, no extra text or explanation.
-- Do not
+- Do **not** copy any example strings from the instructions or use ellipses ('...'). Produce concrete values drawn from the image only.
 """
 
 # ===== Utils =====
-def extract_top_level_json(s: str):
-    """
-
-
-
-
+def extract_last_json(s: str):
+    """
+    Return the last balanced {...} JSON object found in the string.
+    This avoids grabbing the schema block from the prompt if it echoes.
+    """
+    last = None
     start, depth = None, 0
     for i, ch in enumerate(s):
         if ch == '{':
-            if depth == 0:
+            if depth == 0:
+                start = i
             depth += 1
         elif ch == '}':
             if depth > 0:
@@ -71,12 +72,13 @@ def extract_top_level_json(s: str):
                 if depth == 0 and start is not None:
                     chunk = s[start:i+1]
                     try:
-
+                        last = json.loads(chunk)
                     except Exception:
-
-
+                        pass
+                    start = None
+    return last
 
-def build_messages(image):
+def build_messages(image: Image.Image):
     return [
         {"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]},
         {"role": "user", "content": [{"type": "image", "image": image},
@@ -100,18 +102,24 @@ try:
     if "clip" in cfg.__class__.__name__.lower():
         raise RuntimeError(f"MODEL_ID '{MODEL_ID}' is a CLIP/encoder repo; need a causal VLM.")
 
+    print("[boot] loading processor…", flush=True)
     processor = AutoProcessor.from_pretrained(MODEL_ID, token=HF_TOKEN, trust_remote_code=True)
+
+    print("[boot] loading model…", flush=True)
     model = AutoModelForCausalLM.from_pretrained(
         MODEL_ID,
         token=HF_TOKEN,
         device_map="cuda",  # keep on A100
         torch_dtype=DTYPE,
         trust_remote_code=True,
-        # quantization_config=None,  # uncomment to force full precision if you removed quant
+        # quantization_config=None,  # uncomment to force full precision if you removed quant in repo
     )
+
     tokenizer = getattr(processor, "tokenizer", None) or AutoTokenizer.from_pretrained(
         MODEL_ID, token=HF_TOKEN, trust_remote_code=True, use_fast=True
     )
+    print("[boot] ready.", flush=True)
+
 except Exception as e:
     LOAD_ERROR = f"{e}\n\n{traceback.format_exc()}"
 
@@ -124,25 +132,39 @@ def generate(image: Image.Image) -> Tuple[str, Dict[str, Any] | None, bool]:
 
     image = resize_to_vision(image, VISION_LONG_SIDE)
 
-    #
+    # Build chat prompt
     if hasattr(processor, "apply_chat_template"):
         prompt = processor.apply_chat_template(build_messages(image), add_generation_prompt=True, tokenize=False)
     else:
         prompt = USER_PROMPT
 
+    # Tokenize with vision
     inputs = processor(text=prompt, images=image, return_tensors="pt").to(model.device)
    eos = getattr(model.config, "eos_token_id", None)
 
+    def _decode_only_new(out_ids):
+        """
+        Decode only the newly generated tokens (exclude prompt tokens),
+        so we don't accidentally parse the schema block from the prompt.
+        """
+        input_len = inputs["input_ids"].shape[1]
+        gen_ids = out_ids[0][input_len:]
+        # Prefer processor.decode if available (some VLMs customize decoding)
+        if hasattr(processor, "decode"):
+            return processor.decode(gen_ids, skip_special_tokens=True)
+        return tokenizer.decode(gen_ids, skip_special_tokens=True)
+
     tried = []
 
     # (1) Greedy (fast, stable)
     try:
         g = dict(do_sample=False, max_new_tokens=MAX_NEW_TOKENS)
-        if eos is not None:
+        if eos is not None:
+            g["eos_token_id"] = eos
         with torch.inference_mode():
             out = model.generate(**inputs, **g)
-        text =
-        parsed =
+        text = _decode_only_new(out)
+        parsed = extract_last_json(text)
         if isinstance(parsed, dict) and "..." not in json.dumps(parsed):
             return json.dumps(parsed, indent=2), parsed, True
         tried.append(("greedy", "parse-failed-or-ellipses"))
@@ -152,11 +174,12 @@ def generate(image: Image.Image) -> Tuple[str, Dict[str, Any] | None, bool]:
     # (2) Short sampled retry
     try:
         g = dict(do_sample=True, temperature=TEMP, max_new_tokens=MAX_NEW_TOKENS)
-        if eos is not None:
+        if eos is not None:
+            g["eos_token_id"] = eos
         with torch.inference_mode():
             out = model.generate(**inputs, **g)
-        text =
-        parsed =
+        text = _decode_only_new(out)
+        parsed = extract_last_json(text)
         if isinstance(parsed, dict) and "..." not in json.dumps(parsed):
             return json.dumps(parsed, indent=2), parsed, True
         tried.append(("sample_t0.1", "parse-failed-or-ellipses"))
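A quick standalone illustration of the parsing change: the sketch below mirrors the new extract_last_json helper from the diff, and the raw string is an invented example, not real model output. Keeping only the last balanced {...} object means that even if the model echoes the schema block from the prompt, the echoed object gets overwritten by the actual answer.

import json

def extract_last_json(s: str):
    """Return the last balanced {...} object in s, or None (same scan as in the diff)."""
    last = None
    start, depth = None, 0
    for i, ch in enumerate(s):
        if ch == '{':
            if depth == 0:
                start = i
            depth += 1
        elif ch == '}':
            if depth > 0:
                depth -= 1
                if depth == 0 and start is not None:
                    chunk = s[start:i+1]
                    try:
                        last = json.loads(chunk)
                    except Exception:
                        pass
                    start = None
    return last

# Invented sample: the schema is echoed first, then the real annotation follows.
raw = (
    'Here is the schema: {"description": "...", "logos": []}\n'
    '{"description": "A cyclist passes a cafe at dusk.", "logos": []}'
)
print(extract_last_json(raw))
# {'description': 'A cyclist passes a cafe at dusk.', 'logos': []}

Together with _decode_only_new, which drops the prompt tokens before decoding, this is why the retry logic can detect a bad generation ('...' or no parsable object) instead of accidentally returning the schema itself.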