andrejrad committed
Commit f617893 · verified · 1 Parent(s): b989be2

Update app.py

Files changed (1)
  1. app.py +47 -24
app.py CHANGED
@@ -9,13 +9,13 @@ from transformers import AutoProcessor, AutoTokenizer, AutoModelForCausalLM, Aut
 MODEL_ID = os.environ.get("MODEL_ID", "inference-net/ClipTagger-12b")
 HF_TOKEN = os.environ.get("HF_TOKEN")
 
-# Latency/quality knobs (tuned for A100)
+# Latency/quality knobs (tuned for A100-80GB)
 TEMP = 0.1 # per model docs
-MAX_NEW_TOKENS = 384 # fast + sufficient for schema (raise to 512/768 later if needed)
-VISION_LONG_SIDE = 896 # matches your vision_config.image_size
+MAX_NEW_TOKENS = 384 # fast + sufficient for schema (raise to 512/768 if needed)
+VISION_LONG_SIDE = 896 # matches vision_config.image_size
 DTYPE = torch.bfloat16 if torch.cuda.is_available() else torch.float32
 
-# ===== Prompts (exact, no example output) =====
+# ===== Prompts (schema-only; no example output) =====
 SYSTEM_PROMPT = (
     "You are an image annotation API trained to analyze YouTube video keyframes. "
     "You will be given instructions on the output format, what to caption, and how to perform your job. "
@@ -50,20 +50,21 @@ Rules:
 - Return an empty array for 'logos' if none are present.
 - Always output strictly valid JSON with proper escaping.
 - Output **only the JSON**, no extra text or explanation.
-- Do not use placeholder strings or ellipses ('...'). Replace with concrete values directly observed in the image only.
+- Do **not** copy any example strings from the instructions or use ellipses ('...'). Produce concrete values drawn from the image only.
 """
 
 # ===== Utils =====
-def extract_top_level_json(s: str):
-    """Parse JSON; if there’s surrounding text, extract the first balanced {...} block."""
-    try:
-        return json.loads(s)
-    except Exception:
-        pass
+def extract_last_json(s: str):
+    """
+    Return the last balanced {...} JSON object found in the string.
+    This avoids grabbing the schema block from the prompt if it echoes.
+    """
+    last = None
     start, depth = None, 0
     for i, ch in enumerate(s):
         if ch == '{':
-            if depth == 0: start = i
+            if depth == 0:
+                start = i
             depth += 1
         elif ch == '}':
             if depth > 0:
@@ -71,12 +72,13 @@ def extract_top_level_json(s: str):
                 if depth == 0 and start is not None:
                     chunk = s[start:i+1]
                     try:
-                        return json.loads(chunk)
+                        last = json.loads(chunk)
                     except Exception:
-                        start = None
-    return None
+                        pass
+                    start = None
+    return last
 
-def build_messages(image):
+def build_messages(image: Image.Image):
     return [
         {"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]},
         {"role": "user", "content": [{"type": "image", "image": image},
@@ -100,18 +102,24 @@ try:
     if "clip" in cfg.__class__.__name__.lower():
         raise RuntimeError(f"MODEL_ID '{MODEL_ID}' is a CLIP/encoder repo; need a causal VLM.")
 
+    print("[boot] loading processor…", flush=True)
     processor = AutoProcessor.from_pretrained(MODEL_ID, token=HF_TOKEN, trust_remote_code=True)
+
+    print("[boot] loading model…", flush=True)
     model = AutoModelForCausalLM.from_pretrained(
         MODEL_ID,
         token=HF_TOKEN,
         device_map="cuda", # keep on A100
         torch_dtype=DTYPE,
         trust_remote_code=True,
-        # quantization_config=None, # uncomment to force full precision if you removed quant
+        # quantization_config=None, # uncomment to force full precision if you removed quant in repo
     )
+
     tokenizer = getattr(processor, "tokenizer", None) or AutoTokenizer.from_pretrained(
         MODEL_ID, token=HF_TOKEN, trust_remote_code=True, use_fast=True
     )
+    print("[boot] ready.", flush=True)
+
 except Exception as e:
     LOAD_ERROR = f"{e}\n\n{traceback.format_exc()}"
 
@@ -124,25 +132,39 @@ def generate(image: Image.Image) -> Tuple[str, Dict[str, Any] | None, bool]:
 
     image = resize_to_vision(image, VISION_LONG_SIDE)
 
-    # Chat prompt
+    # Build chat prompt
     if hasattr(processor, "apply_chat_template"):
         prompt = processor.apply_chat_template(build_messages(image), add_generation_prompt=True, tokenize=False)
     else:
         prompt = USER_PROMPT
 
+    # Tokenize with vision
     inputs = processor(text=prompt, images=image, return_tensors="pt").to(model.device)
     eos = getattr(model.config, "eos_token_id", None)
 
+    def _decode_only_new(out_ids):
+        """
+        Decode only the newly generated tokens (exclude prompt tokens),
+        so we don't accidentally parse the schema block from the prompt.
+        """
+        input_len = inputs["input_ids"].shape[1]
+        gen_ids = out_ids[0][input_len:]
+        # Prefer processor.decode if available (some VLMs customize decoding)
+        if hasattr(processor, "decode"):
+            return processor.decode(gen_ids, skip_special_tokens=True)
+        return tokenizer.decode(gen_ids, skip_special_tokens=True)
+
     tried = []
 
     # (1) Greedy (fast, stable)
     try:
         g = dict(do_sample=False, max_new_tokens=MAX_NEW_TOKENS)
-        if eos is not None: g["eos_token_id"] = eos
+        if eos is not None:
+            g["eos_token_id"] = eos
         with torch.inference_mode():
             out = model.generate(**inputs, **g)
-        text = processor.decode(out[0], skip_special_tokens=True)
-        parsed = extract_top_level_json(text)
+        text = _decode_only_new(out)
+        parsed = extract_last_json(text)
         if isinstance(parsed, dict) and "..." not in json.dumps(parsed):
            return json.dumps(parsed, indent=2), parsed, True
         tried.append(("greedy", "parse-failed-or-ellipses"))
@@ -152,11 +174,12 @@ def generate(image: Image.Image) -> Tuple[str, Dict[str, Any] | None, bool]:
     # (2) Short sampled retry
     try:
         g = dict(do_sample=True, temperature=TEMP, max_new_tokens=MAX_NEW_TOKENS)
-        if eos is not None: g["eos_token_id"] = eos
+        if eos is not None:
+            g["eos_token_id"] = eos
         with torch.inference_mode():
             out = model.generate(**inputs, **g)
-        text = processor.decode(out[0], skip_special_tokens=True)
-        parsed = extract_top_level_json(text)
+        text = _decode_only_new(out)
+        parsed = extract_last_json(text)
         if isinstance(parsed, dict) and "..." not in json.dumps(parsed):
             return json.dumps(parsed, indent=2), parsed, True
         tried.append(("sample_t0.1", "parse-failed-or-ellipses"))
 