YucYux committed on
Commit ca6c4d6 · 1 Parent(s): 13a411c

tried to fix model loading bug

Files changed (1)
  1. app.py +228 -931
app.py CHANGED
@@ -10,7 +10,7 @@ from PIL import Image
10
  import spaces
11
 
12
 
13
-
14
  def image_transform(image, resolution=256, normalize=True):
15
  image = transforms.Resize(resolution, interpolation=transforms.InterpolationMode.BICUBIC)(image)
16
  image = transforms.CenterCrop((resolution, resolution))(image)
@@ -20,231 +20,84 @@ def image_transform(image, resolution=256, normalize=True):
20
  return image
21
 
22
  def add_gumbel_noise(logits, temperature):
23
- """
24
- Adds Gumbel noise to logits for stochastic sampling.
25
- Equivalent to argmax(logits + temperature * G) where G ~ Gumbel(0,1).
26
- This version is more numerically stable than a version involving exp() and division.
27
- """
28
- if abs(temperature) < 1e-9: # Effectively zero temperature
29
  return logits
30
- # Ensure logits are float64 for precision with noise, as suggested by user context
31
  logits = logits.to(torch.float64)
32
- # Standard Gumbel noise: -log(-log(U)), U ~ Uniform(0,1)
33
- # Add small epsilon for numerical stability inside logs
34
  noise = torch.rand_like(logits, dtype=torch.float64)
35
  standard_gumbel_noise = -torch.log(-torch.log(noise + 1e-20) + 1e-20)
36
  return logits + temperature * standard_gumbel_noise
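For reference, the Gumbel-max trick described in the docstring above can be checked in isolation; a minimal sketch (assumes only torch, names are illustrative and not part of app.py):

import torch

# Gumbel-max trick: argmax(logits + T * G) with G ~ Gumbel(0, 1) draws a
# sample from softmax(logits / T). Empirical check on a toy distribution.
torch.manual_seed(0)
logits = torch.tensor([2.0, 1.0, 0.5])
temperature = 1.0

counts = torch.zeros(3)
for _ in range(10_000):
    u = torch.rand_like(logits)
    gumbel = -torch.log(-torch.log(u + 1e-20) + 1e-20)
    counts[torch.argmax(logits + temperature * gumbel)] += 1

print(counts / counts.sum())                        # empirical sampling frequencies
print(torch.softmax(logits / temperature, dim=-1))  # target probabilities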
37
 
38
  def get_num_transfer_tokens(mask_index, steps):
39
  mask_num = mask_index.sum(dim=1, keepdim=True)
40
- # Ensure steps is at least 1 to avoid division by zero if mask_num is also 0 (though sum should be >=0)
41
- steps = max(1, int(steps)) # Ensure steps is a positive integer
42
  base = mask_num // steps
43
  remainder = mask_num % steps
44
  num_transfer_tokens = torch.zeros(mask_num.size(0), steps, device=mask_index.device, dtype=torch.long) + base
45
- for i in range(mask_num.size(0)): # Iterate over batch
46
- if remainder[i] > 0 : # Ensure remainder is positive before indexing
47
- num_transfer_tokens[i, :remainder[i].item()] += 1 # .item() for single value tensor to int
48
  return num_transfer_tokens
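A quick worked example of how this function splits the masked-token budget across steps (a sketch that reuses only the function above and torch):

# 10 masked tokens spread over 4 steps -> base 2 per step plus 2 remainders
# assigned to the earliest steps, i.e. [3, 3, 2, 2].
mask_index = torch.tensor([[True] * 10 + [False] * 2])
print(get_num_transfer_tokens(mask_index, 4))  # tensor([[3, 3, 2, 2]])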
49
 
 
50
  DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
 
51
  DEFAULT_MODEL_PATH = "Gen-Verse/MMaDA-8B-MixCoT"
52
- MASK_ID = None # initialized to None
53
- MODEL = None # initialized to None
54
- TOKENIZER = None # initialized to None
55
- uni_prompting = None # initialized to None
56
- VQ_MODEL = None # initialized to None; loaded later in the initialization function
57
-
58
- CURRENT_MODEL_PATH = None # initialized to None
59
 
60
- MODEL_CHOICES = [
61
- "MMaDA-8B-Base",
62
- "MMaDA-8B-MixCoT",
63
- "MMaDA-8B-Max (coming soon)"
64
- ]
65
- MODEL_ACTUAL_PATHS = {
66
- "MMaDA-8B-Base": "Gen-Verse/MMaDA-8B-Base",
67
- "MMaDA-8B-MixCoT": "Gen-Verse/MMaDA-8B-MixCoT"
68
- }
69
-
70
- def clear_outputs_action():
71
- return None, None
72
 
 
73
  @spaces.GPU
74
- def _load_model_and_tokenizer_core(model_path_to_load, model_display_name_for_status):
75
- global MODEL, TOKENIZER, MASK_ID, CURRENT_MODEL_PATH, DEVICE, uni_prompting
 
 
 
76
 
77
- if MODEL is not None and CURRENT_MODEL_PATH == model_path_to_load:
78
- return f"Model '{model_display_name_for_status}' from '{model_path_to_load}' is already loaded. MASK_ID: {MASK_ID}"
 
79
 
80
- CURRENT_MODEL_PATH = model_path_to_load
81
-
82
- status_msg_parts = [f"Loading '{model_display_name_for_status}'..."]
83
- # try:
84
- TOKENIZER = AutoTokenizer.from_pretrained(model_path_to_load, trust_remote_code=True)
85
- status_msg_parts.append(f"Tokenizer for '{model_display_name_for_status}' loaded.")
86
 
87
- MODEL = MMadaModelLM.from_pretrained(model_path_to_load, trust_remote_code=True, torch_dtype=torch.bfloat16).eval()
88
- status_msg_parts.append(f"Model '{model_display_name_for_status}' loaded to {DEVICE}.")
89
 
90
- uni_prompting = UniversalPrompting(TOKENIZER, max_text_len=512, special_tokens=("<|soi|>", "<|eoi|>", "<|sov|>", "<|eov|>", "<|t2i|>", "<|mmu|>", "<|t2v|>", "<|v2v|>", "<|lvg|>"),ignore_id=-100, cond_dropout_prob=0.1, use_reserved_token=True)
91
-
92
- if hasattr(TOKENIZER, 'mask_token_id') and TOKENIZER.mask_token_id is not None:
93
- MASK_ID = TOKENIZER.mask_token_id
94
- status_msg_parts.append(f"Using MASK_ID from tokenizer: {MASK_ID}.")
95
- else:
96
  MASK_ID = 126336
97
  status_msg_parts.append(f"Using default MASK_ID: {MASK_ID}.")
98
 
99
- if TOKENIZER.pad_token_id is None:
100
- if TOKENIZER.eos_token_id is not None:
101
- TOKENIZER.pad_token_id = TOKENIZER.eos_token_id
102
- TOKENIZER.pad_token = TOKENIZER.eos_token
103
- status_msg_parts.append(f"Set pad_token_id to eos_token_id ({TOKENIZER.eos_token_id}).")
104
- else:
105
- status_msg_parts.append("Warning: pad_token_id is None and no eos_token_id.")
106
-
107
- if TOKENIZER.eos_token_id is None: # Important for cleaning up output in visualization
108
- status_msg_parts.append("Warning: tokenizer.eos_token_id is None. EOS cleanup might not work.")
109
-
110
- TOKENIZER.chat_template = "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{{ '<|start_header_id|>assistant<|end_header_id|>\n' }}"
111
-
112
- return " ".join(status_msg_parts)
113
- # except Exception as e:
114
- # MODEL = None
115
- # TOKENIZER = None
116
- # MASK_ID = None
117
- # CURRENT_MODEL_PATH = None
118
- # return f"Error loading model '{model_display_name_for_status}': {str(e)}"
119
-
120
- def handle_model_selection_change(selected_model_name_ui):
121
- global MODEL, TOKENIZER, MASK_ID, CURRENT_MODEL_PATH, DEVICE, uni_prompting
122
- status_msg = ""
123
- # Initialize the visibility updates for the Examples blocks
124
- vis_lm_base = gr.update(visible=False)
125
- vis_lm_mixcot = gr.update(visible=False)
126
- vis_lm_max = gr.update(visible=False)
127
- vis_mmu_base = gr.update(visible=False)
128
- vis_mmu_mixcot = gr.update(visible=False)
129
- vis_mmu_max = gr.update(visible=False)
130
- # Decide the default thinking mode state based on the selected model
131
- is_mixcot_model_selected = (selected_model_name_ui == "MMaDA-8B-MixCoT")
132
-
133
- # Initial thinking mode state and button labels
134
- # If the MixCoT model is selected, default to True (enabled)
135
- current_thinking_mode_lm_state = is_mixcot_model_selected
136
- current_thinking_mode_mmu_state = is_mixcot_model_selected
137
-
138
- lm_think_button_label = "Thinking Mode ✅" if current_thinking_mode_lm_state else "Thinking Mode ❌"
139
- mmu_think_button_label = "Thinking Mode ✅" if current_thinking_mode_mmu_state else "Thinking Mode ❌"
140
- update_think_button_lm = gr.update(value=lm_think_button_label)
141
- update_think_button_mmu = gr.update(value=mmu_think_button_label)
142
- if selected_model_name_ui == "MMaDA-8B-Max (coming soon)":
143
- MODEL = None
144
- TOKENIZER = None
145
- MASK_ID = None
146
- CURRENT_MODEL_PATH = None
147
- status_msg = f"'{selected_model_name_ui}' is not yet available. Please select another model."
148
- vis_lm_max = gr.update(visible=True)
149
- vis_mmu_max = gr.update(visible=True)
150
- # For non-MixCoT models, thinking mode was already set to False above via is_mixcot_model_selected
151
- else:
152
- actual_path = MODEL_ACTUAL_PATHS.get(selected_model_name_ui)
153
- if not actual_path:
154
- MODEL = None
155
- TOKENIZER = None
156
- MASK_ID = None
157
- CURRENT_MODEL_PATH = None
158
- status_msg = f"Path for '{selected_model_name_ui}' is not defined. Cannot load."
159
- # If the path is undefined (i.e. not a valid MixCoT load), thinking mode should be False
160
- if is_mixcot_model_selected: # was supposed to be MixCoT but the path is missing
161
- current_thinking_mode_lm_state = False
162
- current_thinking_mode_mmu_state = False
163
- update_think_button_lm = gr.update(value="Thinking Mode ❌")
164
- update_think_button_mmu = gr.update(value="Thinking Mode ❌")
165
- else:
166
- # Try to load the model
167
- status_msg = _load_model_and_tokenizer_core(actual_path, selected_model_name_ui)
168
-
169
- # Fixed error-checking logic: rely only on the status message to decide success
170
- model_load_failed = False
171
-
172
- # Check whether the status message contains an explicit error indicator
173
- error_indicators = [
174
- "Error loading model",
175
- "Failed to",
176
- "Cannot load",
177
- "not defined",
178
- "not yet available"
179
- ]
180
-
181
- # Success indicators (any one present counts as success)
182
- success_indicators = [
183
- "loaded to", # "Model 'XXX' loaded to cuda"
184
- "is already loaded", # "Model 'XXX' is already loaded"
185
- "loaded.", # "Tokenizer loaded."
186
- ]
187
-
188
- # Check for error indicators
189
- for error_indicator in error_indicators:
190
- if error_indicator in status_msg:
191
- model_load_failed = True
192
- break
193
-
194
- # If there is no error indicator, check for a success indicator
195
- if not model_load_failed:
196
- has_success_indicator = any(success_indicator in status_msg for success_indicator in success_indicators)
197
- # If there is neither an error nor a success indicator, something is probably wrong
198
- if not has_success_indicator:
199
- model_load_failed = True
200
- status_msg = f"Uncertain model loading status for '{selected_model_name_ui}'. {status_msg}"
201
-
202
- if model_load_failed:
203
- # If the MixCoT model was selected but loading failed, turn thinking mode off
204
- if is_mixcot_model_selected:
205
- current_thinking_mode_lm_state = False
206
- current_thinking_mode_mmu_state = False
207
- update_think_button_lm = gr.update(value="Thinking Mode ❌")
208
- update_think_button_mmu = gr.update(value="Thinking Mode ❌")
209
- else: # Model loaded successfully or was already loaded
210
- if selected_model_name_ui == "MMaDA-8B-Base":
211
- vis_lm_base = gr.update(visible=True)
212
- vis_mmu_base = gr.update(visible=True)
213
- elif selected_model_name_ui == "MMaDA-8B-MixCoT":
214
- vis_lm_mixcot = gr.update(visible=True)
215
- vis_mmu_mixcot = gr.update(visible=True)
216
- # thinking mode was already set to True at the top of the function via is_mixcot_model_selected
217
- return (
218
- status_msg,
219
- vis_lm_base,
220
- vis_lm_mixcot,
221
- vis_lm_max,
222
- vis_mmu_base,
223
- vis_mmu_mixcot,
224
- vis_mmu_max,
225
- # Additional return values used to update the thinking_mode states and buttons
226
- current_thinking_mode_lm_state, # plain value returned to gr.State
227
- update_think_button_lm, # gr.update object for gr.Button
228
- current_thinking_mode_mmu_state,
229
- update_think_button_mmu
230
- )
231
-
232
 
 
233
  def get_highlighted_text_tuples(current_x_ids_batch, prompt_input_ids, prompt_len, tk, current_mask_id, raw_prompt_attention_mask):
234
  if current_x_ids_batch is None or current_x_ids_batch.ndim == 0 or current_x_ids_batch.shape[0] == 0:
235
  return [("Error in sequence data for visualization.", "ERROR")]
236
- # only answer part
237
  current_x_ids_batch = current_x_ids_batch[:, prompt_len:]
238
  seq_ids = current_x_ids_batch[0].tolist()
239
- eos_token_id = tk.eos_token_id # Get EOS token ID
240
-
241
- # Stage 1: Build initial list of tuples with (token_str, label, token_id_int)
242
- # This helps in identifying EOS tokens later without re-checking the type.
243
  intermediate_tuples = []
244
  for j, token_id_int in enumerate(seq_ids):
245
  try:
246
  token_str = tk.decode([token_id_int], skip_special_tokens=True, clean_up_tokenization_spaces=False)
247
- except Exception: # Handle cases where a token ID might be problematic (e.g. with mock)
248
  token_str = f"[ID:{token_id_int}]"
249
 
250
  label = "ERROR"
@@ -254,490 +107,202 @@ def get_highlighted_text_tuples(current_x_ids_batch, prompt_input_ids, prompt_le
254
  else:
255
  label = "GEN"
256
  intermediate_tuples.append((token_str, label, token_id_int))
257
-
258
  return intermediate_tuples
259
 
260
  @torch.no_grad()
261
  @spaces.GPU
262
  def generate_viz_wrapper_t2i(prompt_text, steps, guidance_scale, mask_schedule="cosine"):
263
  global MODEL, TOKENIZER, MASK_ID, DEVICE, uni_prompting, VQ_MODEL
264
-
265
  if MODEL is None or TOKENIZER is None or MASK_ID is None:
266
- yield [("Error: Model not loaded. Please load the model first.", "ERROR")], "Model not loaded."
267
  return
268
-
269
  if DEVICE == 'cuda':
270
- print("Moving MODEL to GPU for inference...")
271
  MODEL.to(DEVICE)
272
  VQ_MODEL.to(DEVICE)
273
-
274
  try:
 
275
  steps = int(steps)
276
  guidance_scale = float(guidance_scale)
277
-
278
  image_tokens = torch.ones((1, 1024), dtype=torch.long, device=DEVICE) * MASK_ID
279
  prompt_text = [prompt_text]
280
  input_ids, attention_mask = uni_prompting((prompt_text, image_tokens), 't2i_gen')
281
-
282
  if guidance_scale > 0:
283
  uncond_input_ids, uncond_attention_mask = uni_prompting(([''], image_tokens), 't2i_gen')
284
  else:
285
  uncond_input_ids, uncond_attention_mask = None, None
286
-
287
  mask_schedule = get_mask_schedule(mask_schedule)
288
  blank_image = Image.new("RGB", (512, 512), (255, 255, 255))
289
  yield blank_image, "Starting generation..."
290
  for image_step, status_msg_step in MODEL.t2i_generate_decoding_stepwise(
291
- input_ids = input_ids,
292
- uncond_input_ids = uncond_input_ids,
293
- attention_mask = attention_mask,
294
- uncond_attention_mask = uncond_attention_mask,
295
- temperature=1.0,
296
- timesteps = steps,
297
- guidance_scale = guidance_scale,
298
- noise_schedule = mask_schedule,
299
- noise_type = "mask",
300
- seq_len = 1024,
301
- vq_model = VQ_MODEL,
302
- uni_prompting=uni_prompting):
303
- yield image_step, status_msg_step
304
-
305
  finally:
306
  if DEVICE == 'cuda':
307
- print("Moving MODEL back to CPU...")
308
  MODEL.to('cpu')
309
  VQ_MODEL.to('cpu')
310
  torch.cuda.empty_cache()
311
-
312
-
313
-
314
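The t2i wrapper above passes one of the UI's scheduler choices (cosine, sigmoid, linear) to get_mask_schedule, whose implementation is not shown in this diff. For intuition only, a typical MaskGIT-style cosine schedule looks like the following sketch (an assumption, not necessarily the repo's exact definition):

import math

def cosine_mask_ratio(t: float) -> float:
    # Fraction of image tokens still masked at decoding progress t in [0, 1].
    return math.cos(math.pi * t / 2)

print([round(cosine_mask_ratio(t / 4), 3) for t in range(5)])
# [1.0, 0.924, 0.707, 0.383, 0.0] -> most tokens masked early, very few at the end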
 
315
  @torch.no_grad()
316
  @spaces.GPU
317
  def generate_viz_wrapper_lm(prompt_text, steps, gen_length, block_length, temperature,
318
- cfg_scale, remasking_strategy, thinking_mode_lm=False):
319
- global MODEL, TOKENIZER, MASK_ID, DEVICE, VQ_MODEL
320
  if MODEL is None or TOKENIZER is None or MASK_ID is None:
321
  yield [("Error: Model not loaded. Please load the model first.", "ERROR")], "Model not loaded."
322
  return
323
-
324
  if DEVICE == 'cuda':
325
- print("Moving MODEL to GPU for inference...")
326
  MODEL.to(DEVICE)
327
-
328
  try:
329
- steps = int(steps)
330
- gen_length = int(gen_length)
331
- block_length = int(block_length)
332
-
333
  if thinking_mode_lm:
334
  prompt_text = "You should first think about the reasoning process in the mind and then provide the user with the answer. The reasoning process is enclosed within <think> </think> tags, i.e. <think> reasoning process here </think> answer here\n" + prompt_text
335
-
336
- try:
337
- m = [{"role": "user", "content": prompt_text}]
338
- processed_prompt_text = TOKENIZER.apply_chat_template(m, add_generation_prompt=True, tokenize=False)
339
- except Exception as e:
340
- yield [("Error applying chat template.", "ERROR")], f"Chat template error: {e}"
341
- processed_prompt_text = prompt_text
342
- try:
343
- if TOKENIZER.pad_token_id is None:
344
- if TOKENIZER.eos_token_id is not None:
345
- TOKENIZER.pad_token_id = TOKENIZER.eos_token_id
346
- else: # Should have been caught by load_model, but double check
347
- yield [("Tokenizer Error", "ERROR")], "pad_token_id is not set in tokenizer."
348
- return
349
-
350
- input_ids = TOKENIZER(text=processed_prompt_text, return_tensors="pt", padding="longest", padding_side="left", truncation=True, max_length=MODEL.config.max_position_embeddings if hasattr(MODEL.config, 'max_position_embeddings') else 2048)['input_ids'].to(DEVICE)
351
- raw_prompt_attention_mask = None
352
-
353
- except Exception as e:
354
- yield [("Error tokenizing prompt.", "ERROR")], f"Tokenization error: {e}"
355
- return
356
-
357
-
358
-
359
- batch_size = input_ids.shape[0]
360
- prompt_len = input_ids.shape[1]
361
-
362
  x = torch.full((batch_size, prompt_len + gen_length), MASK_ID, dtype=torch.long, device=DEVICE)
363
  x[:, :prompt_len] = input_ids.clone()
364
-
365
- yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), "Starting generation: Prompt + Initial Masks"
366
-
367
- if gen_length == 0:
368
- final_text_output = TOKENIZER.batch_decode(x[:,prompt_len:], skip_special_tokens=True)
369
- yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), final_text_output[0] if final_text_output else ""
370
- return
371
-
372
- if block_length <= 0 or gen_length % block_length != 0 :
373
- yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), \
374
- f"Error: gen_length ({gen_length}) must be divisible by block_length ({block_length}) and block_length > 0."
375
- return
376
  num_blocks = gen_length // block_length
377
-
378
- if steps <=0 or steps % num_blocks != 0:
379
- yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), \
380
- f"Error: steps ({steps}) must be positive and divisible by num_blocks ({num_blocks}). Steps: {steps}, Num Blocks: {num_blocks}"
381
- return
382
  steps_per_block = steps // num_blocks
383
-
384
  for num_block_iter in range(num_blocks):
385
  current_block_start_idx_in_x = prompt_len + num_block_iter * block_length
386
  current_block_end_idx_in_x = prompt_len + (num_block_iter + 1) * block_length
387
-
388
- block_masks_bool_current = torch.zeros_like(x, dtype=torch.bool)
389
- block_masks_bool_current[:, current_block_start_idx_in_x:current_block_end_idx_in_x] = \
390
- (x[:, current_block_start_idx_in_x:current_block_end_idx_in_x] == MASK_ID)
391
-
392
- num_transfer_tokens_for_this_block = get_num_transfer_tokens(
393
- block_masks_bool_current[:, current_block_start_idx_in_x:current_block_end_idx_in_x],
394
- steps_per_block
395
- )
396
-
397
  for i_step_in_block in range(steps_per_block):
398
- mask_index_global = (x == MASK_ID)
399
-
400
- if cfg_scale > 0.:
401
- un_x = x.clone()
402
- # For unconditional pass, mask out the original prompt tokens that are not padding
403
- # raw_prompt_attention_mask is (B, prompt_len)
404
- prompt_active_tokens_mask = raw_prompt_attention_mask.bool() # True where actual prompt tokens are
405
- un_x[:, :prompt_len][prompt_active_tokens_mask] = MASK_ID
406
-
407
- x_cfg_input = torch.cat([x, un_x], dim=0)
408
- # Pass attention_mask for CFG if model expects it, covering both parts
409
- # For simplicity, not passing explicit attention_mask here; relies on model's internal handling.
410
- model_output = MODEL(x_cfg_input)
411
- logits_cond, logits_uncond = torch.chunk(model_output.logits, 2, dim=0)
412
- logits = logits_uncond + (cfg_scale + 1) * (logits_cond - logits_uncond)
413
- else:
414
- # Not passing explicit attention_mask here; relies on model's internal handling.
415
- model_output = MODEL(x)
416
- logits = model_output.logits
417
-
418
  logits_with_noise = add_gumbel_noise(logits, temperature=temperature)
419
- x0_predicted_tokens = torch.argmax(logits_with_noise, dim=-1)
420
-
421
- if remasking_strategy == 'low_confidence':
422
- probs = F.softmax(logits.to(torch.float64), dim=-1)
423
- x0_probs = torch.gather(probs, dim=-1, index=x0_predicted_tokens.unsqueeze(-1)).squeeze(-1)
424
- elif remasking_strategy == 'random':
425
- x0_probs = torch.rand(x.shape, device=x.device, dtype=torch.float64)
426
- else:
427
- yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), f"Error: Unknown remasking strategy '{remasking_strategy}'"
428
- return
429
-
430
- confidence_for_selection = torch.full_like(x0_probs, -torch.inf)
431
- candidate_positions_for_unmasking = mask_index_global & block_masks_bool_current
432
- confidence_for_selection = torch.where(
433
- candidate_positions_for_unmasking,
434
- x0_probs,
435
- -torch.inf
436
- )
437
-
438
  x0_final_candidates = torch.where(mask_index_global, x0_predicted_tokens, x)
439
-
440
- transfer_indices_bool = torch.zeros_like(x, dtype=torch.bool)
441
- num_to_transfer_this_step_batch = num_transfer_tokens_for_this_block[:, i_step_in_block]
442
-
443
  for j_batch_idx in range(batch_size):
444
- k_val = min(num_to_transfer_this_step_batch[j_batch_idx].item(),
445
- candidate_positions_for_unmasking[j_batch_idx].sum().item()) # ensure k isn't too large
446
-
447
  if k_val > 0:
448
- # Ensure confidence_for_selection[j_batch_idx] is 1D for topk
449
- conf_slice = confidence_for_selection[j_batch_idx]
450
- if conf_slice.ndim > 1: conf_slice = conf_slice.view(-1) # Should already be 1D from x0_probs
451
-
452
- # Check if there are enough valid (non -inf) confidences
453
- valid_conf_count = (conf_slice > -torch.inf).sum().item()
454
- actual_k = min(k_val, valid_conf_count)
455
-
456
- if actual_k > 0:
457
- _, topk_indices_in_x = torch.topk(conf_slice, k=actual_k)
458
- transfer_indices_bool[j_batch_idx, topk_indices_in_x] = True
459
-
460
  x[transfer_indices_bool] = x0_final_candidates[transfer_indices_bool]
461
-
462
- current_total_step = num_block_iter * steps_per_block + i_step_in_block + 1
463
- total_overall_steps = num_blocks * steps_per_block
464
- status_msg = f"Block {num_block_iter+1}/{num_blocks}, Step {i_step_in_block+1}/{steps_per_block} (Total: {current_total_step}/{total_overall_steps})"
465
  yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), status_msg
466
-
467
- final_generated_ids = x[:, prompt_len:]
468
- final_text_output = TOKENIZER.batch_decode(final_generated_ids, skip_special_tokens=True)
469
-
470
- final_text_str = final_text_output[0] if final_text_output and len(final_text_output) > 0 else ""
471
- yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), final_text_str
472
-
473
  finally:
474
  if DEVICE == 'cuda':
475
- print("Moving MODEL back to CPU and clearing cache...")
476
  MODEL.to('cpu')
477
  torch.cuda.empty_cache()
478
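The inner loop above implements the 'low_confidence' remasking strategy: at each step it unmasks the k masked positions whose predicted token has the highest probability, where k comes from get_num_transfer_tokens. A standalone sketch of that selection with toy values (assumes only torch; names are illustrative):

import torch

x0_probs = torch.tensor([0.10, 0.95, 0.40, 0.80, 0.20])      # confidence per position
still_masked = torch.tensor([True, True, False, True, True])  # position 2 is already fixed
k = 2

neg_inf = torch.full_like(x0_probs, float("-inf"))
confidence = torch.where(still_masked, x0_probs, neg_inf)
_, unmask_positions = torch.topk(confidence, k=k)
print(unmask_positions.tolist())  # [1, 3] -> unmasked this step, the rest stay masked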
 
 
479
  @torch.no_grad()
480
  @spaces.GPU
481
  def generate_viz_wrapper(uploaded_image_pil, prompt_text, steps, gen_length, block_length, temperature,
482
  cfg_scale, remasking_strategy, thinking_mode_mmu=False):
483
- global MODEL, TOKENIZER, MASK_ID, DEVICE
484
-
485
  if MODEL is None or TOKENIZER is None or MASK_ID is None:
486
  yield [("Error: Model not loaded. Please load the model first.", "ERROR")], "Model not loaded."
487
  return
488
-
489
  if DEVICE == 'cuda':
490
- print("Moving MODEL to GPU for inference...")
491
  MODEL.to(DEVICE)
492
  VQ_MODEL.to(DEVICE)
493
-
494
  try:
495
- steps = int(steps)
496
- gen_length = int(gen_length)
497
- block_length = int(block_length)
498
-
499
  if thinking_mode_mmu:
500
  prompt_text = "You should first think about the reasoning process in the mind and then provide the user with the answer. The reasoning process is enclosed within <think> </think> tags, i.e. <think> reasoning process here </think> answer here\n" + prompt_text
501
-
502
- try:
503
- m = [{"role": "user", "content": prompt_text}]
504
- processed_prompt_text = TOKENIZER.apply_chat_template(m, add_generation_prompt=True, tokenize=False)
505
- except Exception as e:
506
- yield [("Error applying chat template.", "ERROR")], f"Chat template error: {e}"
507
- processed_prompt_text = prompt_text
508
-
509
- image_vq_ids_tensor = None
510
  if uploaded_image_pil is not None:
511
- try:
512
-
513
- image = image_transform(uploaded_image_pil, resolution=512).to(DEVICE)
514
- image = image.unsqueeze(0)
515
- image_vq_ids_tensor = VQ_MODEL.get_code(image) + 126349
516
- except Exception as e:
517
- yield [("Error processing image.", "ERROR")], f"Image to VQ tokens conversion failed: {str(e)}"
518
- return
519
-
520
-
521
- try:
522
- if TOKENIZER.pad_token_id is None:
523
- if TOKENIZER.eos_token_id is not None:
524
- TOKENIZER.pad_token_id = TOKENIZER.eos_token_id
525
- else:
526
- yield [("Tokenizer Error", "ERROR")], "pad_token_id is not set in tokenizer."
527
- return
528
-
529
- input_ids = TOKENIZER(text=processed_prompt_text, return_tensors="pt", padding="longest", padding_side="left", truncation=True, max_length=MODEL.config.max_position_embeddings if hasattr(MODEL.config, 'max_position_embeddings') else 2048)['input_ids'].to(DEVICE)
530
- raw_prompt_attention_mask = None
531
- if image_vq_ids_tensor is not None:
532
- if image_vq_ids_tensor.ndim == 1:
533
- image_vq_ids_tensor = image_vq_ids_tensor.unsqueeze(0)
534
-
535
- input_ids = torch.cat([
536
- (torch.ones(input_ids.shape[0], 1) * torch.tensor([126089])).to(DEVICE),
537
- (torch.ones(input_ids.shape[0], 1) * torch.tensor([126084])).to(DEVICE),
538
- image_vq_ids_tensor,
539
- (torch.ones(input_ids.shape[0], 1) * torch.tensor([126085])).to(DEVICE),
540
- input_ids
541
- ], dim=1).long()
542
-
543
- else:
544
- input_ids = input_ids
545
-
546
-
547
- except Exception as e:
548
- yield [("Error tokenizing prompt.", "ERROR")], f"Tokenization error: {e}"
549
- return
550
-
551
-
552
-
553
- batch_size = input_ids.shape[0]
554
- prompt_len = input_ids.shape[1]
555
-
556
  x = torch.full((batch_size, prompt_len + gen_length), MASK_ID, dtype=torch.long, device=DEVICE)
557
  x[:, :prompt_len] = input_ids.clone()
558
-
559
- yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), "Starting generation: Prompt + Initial Masks"
560
-
561
- if gen_length == 0:
562
- final_text_output = TOKENIZER.batch_decode(x[:,prompt_len:], skip_special_tokens=True)
563
- yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), final_text_output[0] if final_text_output else ""
564
- return
565
-
566
- if block_length <= 0 or gen_length % block_length != 0 :
567
- yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), \
568
- f"Error: gen_length ({gen_length}) must be divisible by block_length ({block_length}) and block_length > 0."
569
- return
570
  num_blocks = gen_length // block_length
571
-
572
- if steps <=0 or steps % num_blocks != 0:
573
- yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), \
574
- f"Error: steps ({steps}) must be positive and divisible by num_blocks ({num_blocks}). Steps: {steps}, Num Blocks: {num_blocks}"
575
- return
576
  steps_per_block = steps // num_blocks
577
-
578
  for num_block_iter in range(num_blocks):
579
  current_block_start_idx_in_x = prompt_len + num_block_iter * block_length
580
  current_block_end_idx_in_x = prompt_len + (num_block_iter + 1) * block_length
581
-
582
- block_masks_bool_current = torch.zeros_like(x, dtype=torch.bool)
583
- block_masks_bool_current[:, current_block_start_idx_in_x:current_block_end_idx_in_x] = \
584
- (x[:, current_block_start_idx_in_x:current_block_end_idx_in_x] == MASK_ID)
585
-
586
- num_transfer_tokens_for_this_block = get_num_transfer_tokens(
587
- block_masks_bool_current[:, current_block_start_idx_in_x:current_block_end_idx_in_x],
588
- steps_per_block
589
- )
590
-
591
  for i_step_in_block in range(steps_per_block):
592
- mask_index_global = (x == MASK_ID)
593
-
594
- if cfg_scale > 0.:
595
- un_x = x.clone()
596
- # For unconditional pass, mask out the original prompt tokens that are not padding
597
- # raw_prompt_attention_mask is (B, prompt_len)
598
- prompt_active_tokens_mask = raw_prompt_attention_mask.bool() # True where actual prompt tokens are
599
- un_x[:, :prompt_len][prompt_active_tokens_mask] = MASK_ID
600
-
601
- x_cfg_input = torch.cat([x, un_x], dim=0)
602
- # Pass attention_mask for CFG if model expects it, covering both parts
603
- # For simplicity, not passing explicit attention_mask here; relies on model's internal handling.
604
- model_output = MODEL(x_cfg_input)
605
- logits_cond, logits_uncond = torch.chunk(model_output.logits, 2, dim=0)
606
- logits = logits_uncond + (cfg_scale + 1) * (logits_cond - logits_uncond)
607
- else:
608
- # Not passing explicit attention_mask here; relies on model's internal handling.
609
- model_output = MODEL(x)
610
- logits = model_output.logits
611
-
612
- logits_with_noise = add_gumbel_noise(logits, temperature=temperature)
613
- x0_predicted_tokens = torch.argmax(logits_with_noise, dim=-1)
614
-
615
- if remasking_strategy == 'low_confidence':
616
- probs = F.softmax(logits.to(torch.float64), dim=-1)
617
- x0_probs = torch.gather(probs, dim=-1, index=x0_predicted_tokens.unsqueeze(-1)).squeeze(-1)
618
- elif remasking_strategy == 'random':
619
- x0_probs = torch.rand(x.shape, device=x.device, dtype=torch.float64)
620
- else:
621
- yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), f"Error: Unknown remasking strategy '{remasking_strategy}'"
622
- return
623
-
624
- confidence_for_selection = torch.full_like(x0_probs, -torch.inf)
625
- candidate_positions_for_unmasking = mask_index_global & block_masks_bool_current
626
- confidence_for_selection = torch.where(
627
- candidate_positions_for_unmasking,
628
- x0_probs,
629
- -torch.inf
630
- )
631
-
632
- x0_final_candidates = torch.where(mask_index_global, x0_predicted_tokens, x)
633
-
634
- transfer_indices_bool = torch.zeros_like(x, dtype=torch.bool)
635
- num_to_transfer_this_step_batch = num_transfer_tokens_for_this_block[:, i_step_in_block]
636
-
637
- for j_batch_idx in range(batch_size):
638
- k_val = min(num_to_transfer_this_step_batch[j_batch_idx].item(),
639
- candidate_positions_for_unmasking[j_batch_idx].sum().item()) # ensure k isn't too large
640
-
641
- if k_val > 0:
642
- # Ensure confidence_for_selection[j_batch_idx] is 1D for topk
643
- conf_slice = confidence_for_selection[j_batch_idx]
644
- if conf_slice.ndim > 1: conf_slice = conf_slice.view(-1) # Should already be 1D from x0_probs
645
-
646
- # Check if there are enough valid (non -inf) confidences
647
- valid_conf_count = (conf_slice > -torch.inf).sum().item()
648
- actual_k = min(k_val, valid_conf_count)
649
-
650
- if actual_k > 0:
651
- _, topk_indices_in_x = torch.topk(conf_slice, k=actual_k)
652
- transfer_indices_bool[j_batch_idx, topk_indices_in_x] = True
653
-
654
- x[transfer_indices_bool] = x0_final_candidates[transfer_indices_bool]
655
-
656
- current_total_step = num_block_iter * steps_per_block + i_step_in_block + 1
657
- total_overall_steps = num_blocks * steps_per_block
658
- status_msg = f"Block {num_block_iter+1}/{num_blocks}, Step {i_step_in_block+1}/{steps_per_block} (Total: {current_total_step}/{total_overall_steps})"
659
- yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), status_msg
660
-
661
- final_generated_ids = x[:, prompt_len:]
662
- final_text_output = TOKENIZER.batch_decode(final_generated_ids, skip_special_tokens=True)
663
-
664
- final_text_str = final_text_output[0] if final_text_output and len(final_text_output) > 0 else ""
665
- yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), final_text_str
666
-
667
  finally:
668
  if DEVICE == 'cuda':
669
- print("Moving MODEL back to CPU and clearing cache...")
670
  MODEL.to('cpu')
671
  VQ_MODEL.to('cpu')
672
  torch.cuda.empty_cache()
673
 
674
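Both text-generation loops above combine conditional and unconditional logits with the same classifier-free guidance rule, logits_uncond + (cfg_scale + 1) * (logits_cond - logits_uncond). A tiny numeric sketch of its effect (illustrative values only):

import torch

logits_cond = torch.tensor([2.0, 0.0, -1.0])
logits_uncond = torch.tensor([1.0, 0.0, 0.5])

for cfg_scale in (0.0, 1.0, 2.0):
    guided = logits_uncond + (cfg_scale + 1) * (logits_cond - logits_uncond)
    print(cfg_scale, guided.tolist())
# At cfg_scale = 0 the formula reduces to the conditional logits; larger values push
# predictions further from the unconditional distribution (note the loops above skip
# the unconditional pass entirely when cfg_scale == 0).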
 
 
675
  css_styles = """
676
  .gradio-container{font-family:'IBM Plex Sans',sans-serif;margin:auto;}
677
  .gr-input {background:#f9f9f9 !important;border:1px solid #e0e0e0 !important;}
678
  .gr-output{background:#f0f0f0 !important;border:1px solid #d0d0d0 !important;}
679
-
680
- .highlighted-text span{
681
- padding:2px 4px;border-radius:4px;margin:1px 2px;display:inline-block;line-height:1.6;
682
- }
683
-
684
  footer{display:none !important}
685
-
686
- #live-update-scrollable-box {
687
- max-height: 800px; /* adjust this max height as needed, e.g. '300px' or '50vh' */
688
- overflow-y: auto !important; /* show a vertical scrollbar when content exceeds max-height */
689
- display: block; /* ensure the element is block-level so max-height takes effect */
690
-
691
- }
692
- #think_btn {
693
- background-color: #f3f4f6 !important;
694
- border: 1px solid #d0d0d0 !important;
695
- color: #111827 !important;
696
- font-size: 16px !important;
697
- font-weight: bold !important;
698
- }
699
- #think_btn:hover {
700
- background-color: #e0e0e0 !important;
701
- border: 1px solid #c0c0c0 !important;
702
- color: #222 !important;
703
- }
704
- #think_btn:active {
705
- background-color: #2563eb !important;
706
- border: 1px solid #b0b0b0 !important;
707
- color: white !important;
708
- }
709
  """
710
 
711
-
712
- # thinking_mode_t2i = gr.State(False)
713
- def toggle_thinking_mode_lm(current_thinking_mode):
714
- new_state = not current_thinking_mode
715
- new_label = "Thinking Mode ✅" if new_state else "Thinking Mode ❌"
716
- return new_state, gr.update(value=new_label)
717
-
718
- def toggle_thinking_mode_mmu(current_thinking_mode):
719
  new_state = not current_thinking_mode
720
  new_label = "Thinking Mode ✅" if new_state else "Thinking Mode ❌"
721
  return new_state, gr.update(value=new_label)
722
 
 
723
 
724
- color_map_config = {
725
- "MASK": "lightgrey",
726
- "GEN": "#DCABFA",
727
- }
728
 
729
- theme = gr.themes.Ocean(
730
- primary_hue="fuchsia",
731
- )
732
  with gr.Blocks(css=css_styles, theme=theme) as demo:
733
- # with gr.Blocks(css=css_styles, theme=gr.themes.Soft(primary_hue=gr.themes.colors.blue, secondary_hue=gr.themes.colors.sky)) as demo:
734
- # with gr.Blocks() as demo:
735
- thinking_mode_lm = gr.State(False)
736
- thinking_mode_mmu = gr.State(False)
737
- # gr.Markdown("<h1 style='text-align: center; margin-bottom: 20px;'>MMaDA: Multimodal Large Diffusion Language Models</h1>")
738
- # gr.Markdown("MMaDA is a novel class of multimodal diffusion foundation models designed to achieve superior performance across diverse domains such as textual reasoning, multimodal understanding, and text-to-image generation")
739
- # gr.Markdown("Github: [Gen-Verse/MMaDA](https://github.com/Gen-Verse/MMaDA)")
740
- # gr.Markdown("Paper: [MMaDA: Multimodal Large Diffusion Language Models]()")
741
  gr.HTML("""
742
  <div align="center" style="margin-bottom: 20px;">
743
  <img src='/gradio_api/file=title.png' width="160">
@@ -745,279 +310,119 @@ with gr.Blocks(css=css_styles, theme=theme) as demo:
745
  MMaDA is a new class of multimodal diffusion foundation models, enabling state-of-the-art performance in reasoning, multimodal understanding, and text-to-image generation.
746
  </p>
747
  <p style="font-size: 15px;">
748
- 📄 <a href="https://arxiv.org/abs/2505.15809" target="_blank">Paper</a>
749
- &nbsp;&nbsp;&nbsp;|&nbsp;&nbsp;&nbsp;
750
- 💻 <a href="https://github.com/Gen-Verse/MMaDA" target="_blank">Code</a>
751
  </p>
752
  </div>
753
  """)
 
754
  with gr.Row():
755
- model_select_radio = gr.Radio(
756
- label="Select Text Generation Model",
757
- choices=MODEL_CHOICES,
758
- value="MMaDA-8B-MixCoT"
759
- )
760
- model_load_status_box = gr.Textbox(
761
- label="Model Load Status",
762
- interactive=False,
763
- lines=3,
764
- max_lines=5
765
- )
 
 
766
 
 
767
  gr.Markdown("## Part 1. Text Generation")
768
  with gr.Row():
769
  with gr.Column(scale=2):
770
  prompt_input_box_lm = gr.Textbox(label="Enter your prompt:", lines=3, value="A rectangular prism has a length of 5 units, a width of 4 units, and a height of 3 units. What is the volume of the prism?")
771
- think_button_lm = gr.Button("🧠 Enable Thinking Mode", elem_id="think_btn")
772
  with gr.Accordion("Generation Parameters", open=True):
 
773
  with gr.Row():
774
- gen_length_slider_lm = gr.Slider(minimum=8, maximum=1024, value=512, step=64, label="Generation Length", info="Number of tokens to generate.")
775
- steps_slider_lm = gr.Slider(minimum=1, maximum=512, value=256, step=32, label="Total Sampling Steps", info="Must be divisible by (gen_length / block_length).")
776
  with gr.Row():
777
- block_length_slider_lm = gr.Slider(minimum=8, maximum=1024, value=128, step=32, label="Block Length", info="gen_length must be divisible by this.")
778
  remasking_dropdown_lm = gr.Dropdown(choices=['low_confidence', 'random'], value='low_confidence', label="Remasking Strategy")
779
  with gr.Row():
780
- cfg_scale_slider_lm = gr.Slider(minimum=0.0, maximum=2.0, value=0.0, step=0.1, label="CFG Scale", info="Classifier-Free Guidance. 0 disables it.")
781
- temperature_slider_lm = gr.Slider(minimum=0.0, maximum=2.0, value=1, step=0.05, label="Temperature", info="Controls randomness via Gumbel noise. 0 is deterministic.")
782
-
783
-
784
  with gr.Row():
785
  run_button_ui_lm = gr.Button("Generate Sequence", variant="primary", scale=3)
786
  clear_button_ui_lm = gr.Button("Clear Outputs", scale=1)
787
-
788
  with gr.Column(scale=3):
789
- # gr.Markdown("## Live Generation Process")
790
- output_visualization_box_lm = gr.HighlightedText(
791
- label="Live Generation Process",
792
- show_legend=True,
793
- color_map=color_map_config,
794
- combine_adjacent=False,
795
- interactive=False,
796
- elem_id="live-update-scrollable-box",
797
- )
798
- # gr.Markdown("## Final Generated Text")
799
  output_final_text_box_lm = gr.Textbox(label="Final Output", lines=8, interactive=False, show_copy_button=True)
800
-
801
-
802
- with gr.Column(visible=False) as examples_lm_base:
803
- gr.Examples(
804
- examples=[
805
- ["A rectangular prism has a length of 5 units, a width of 4 units, and a height of 3 units. What is the volume of the prism?", 256, 512, 128, 1, 0, "low_confidence"],
806
- ["Lily can run 12 kilometers per hour for 4 hours. After that, she can run 6 kilometers per hour. How many kilometers can she run in 8 hours?", 256, 512, 64, 1, 0, "low_confidence"]
807
- ],
808
- inputs=[prompt_input_box_lm, steps_slider_lm, gen_length_slider_lm, block_length_slider_lm, temperature_slider_lm, cfg_scale_slider_lm, remasking_dropdown_lm],
809
- outputs=[output_visualization_box_lm, output_final_text_box_lm],
810
- fn=generate_viz_wrapper_lm,
811
- cache_examples=False
812
- )
813
- with gr.Column(visible=True) as examples_lm_mixcot:
814
- gr.Examples(
815
- examples=[
816
- ["A rectangular prism has a length of 5 units, a width of 4 units, and a height of 3 units. What is the volume of the prism?", 256, 512, 128, 1, 0, "low_confidence"],
817
- ["Lily can run 12 kilometers per hour for 4 hours. After that, she can run 6 kilometers per hour. How many kilometers can she run in 8 hours?", 256, 512, 64, 1, 0, "low_confidence"]
818
- ],
819
- inputs=[prompt_input_box_lm, steps_slider_lm, gen_length_slider_lm, block_length_slider_lm, temperature_slider_lm, cfg_scale_slider_lm, remasking_dropdown_lm],
820
- outputs=[output_visualization_box_lm, output_final_text_box_lm],
821
- fn=generate_viz_wrapper_lm,
822
- cache_examples=False
823
- )
824
- with gr.Column(visible=False) as examples_lm_max:
825
- gr.Examples(
826
- examples=[
827
- ["A rectangular prism has a length of 5 units, a width of 4 units, and a height of 3 units. What is the volume of the prism?", 256, 512, 128, 1, 0, "low_confidence"],
828
- ["Lily can run 12 kilometers per hour for 4 hours. After that, she can run 6 kilometers per hour. How many kilometers can she run in 8 hours?", 256, 512, 64, 1, 0, "low_confidence"]
829
- ],
830
- inputs=[prompt_input_box_lm, steps_slider_lm, gen_length_slider_lm, block_length_slider_lm, temperature_slider_lm, cfg_scale_slider_lm, remasking_dropdown_lm],
831
- outputs=[output_visualization_box_lm, output_final_text_box_lm],
832
- fn=generate_viz_wrapper_lm,
833
- cache_examples=False
834
- )
835
-
836
  gr.Markdown("---")
837
  gr.Markdown("## Part 2. Multimodal Understanding")
838
  with gr.Row():
 
839
  with gr.Column(scale=2):
840
- prompt_input_box_mmu = gr.Textbox(
841
- label="Enter your prompt:",
842
- lines=3,
843
- value=""
844
- )
845
- think_button_mmu = gr.Button("🧠 Enable Thinking Mode", elem_id="think_btn")
846
  with gr.Accordion("Generation Parameters", open=True):
847
- with gr.Row():
848
- gen_length_slider_mmu = gr.Slider(minimum=64, maximum=1024, value=512, step=64, label="Generation Length", info="Number of tokens to generate.")
849
- steps_slider_mmu = gr.Slider(minimum=1, maximum=512, value=256, step=32, label="Total Sampling Steps", info="Must be divisible by (gen_length / block_length).")
850
- with gr.Row():
851
- block_length_slider_mmu = gr.Slider(minimum=32, maximum=1024, value=64, step=32, label="Block Length", info="gen_length must be divisible by this.")
852
  remasking_dropdown_mmu = gr.Dropdown(choices=['low_confidence', 'random'], value='low_confidence', label="Remasking Strategy")
853
- with gr.Row():
854
- cfg_scale_slider_mmu = gr.Slider(minimum=0.0, maximum=2.0, value=0.0, step=0.1, label="CFG Scale", info="Classifier-Free Guidance. 0 disables it.")
855
- temperature_slider_mmu = gr.Slider(minimum=0.0, maximum=2.0, value=1, step=0.05, label="Temperature", info="Controls randomness via Gumbel noise. 0 is deterministic.")
856
-
857
  with gr.Row():
858
  image_upload_box = gr.Image(type="pil", label="Upload Image")
859
-
860
  with gr.Row():
861
  run_button_ui_mmu = gr.Button("Generate Description", variant="primary", scale=3)
862
  clear_button_ui_mmu = gr.Button("Clear Outputs", scale=1)
863
-
864
  with gr.Column(scale=3):
865
- gr.Markdown("## Live Generation Process")
866
- output_visualization_box_mmu = gr.HighlightedText(
867
- label="Token Sequence (Live Update)",
868
- show_legend=True,
869
- color_map=color_map_config,
870
- combine_adjacent=False,
871
- interactive=False,
872
- elem_id="live-update-scrollable-box",
873
- )
874
- gr.Markdown("## Final Generated Text")
875
  output_final_text_box_mmu = gr.Textbox(label="Final Output", lines=8, interactive=False, show_copy_button=True)
876
-
877
- with gr.Column(visible=False) as examples_mmu_base:
878
- gr.Examples(
879
- examples=[
880
- [
881
- "figs/sunflower.jpg",
882
- "Please describe this image in detail.",
883
- 256,
884
- 512,
885
- 128,
886
- 1,
887
- 0,
888
- "low_confidence"
889
- ],
890
- [
891
- "figs/woman.jpg",
892
- "Please describe this image in detail.",
893
- 256,
894
- 512,
895
- 128,
896
- 1,
897
- 0,
898
- "low_confidence"
899
- ]
900
- ],
901
- inputs=[
902
- image_upload_box,
903
- prompt_input_box_mmu,
904
- steps_slider_mmu,
905
- gen_length_slider_mmu,
906
- block_length_slider_mmu,
907
- temperature_slider_mmu,
908
- cfg_scale_slider_mmu,
909
- remasking_dropdown_mmu
910
- ],
911
- outputs=[output_visualization_box_mmu, output_final_text_box_mmu],
912
- fn=generate_viz_wrapper,
913
- cache_examples=False
914
- )
915
- with gr.Column(visible=True) as examples_mmu_mixcot:
916
- gr.Examples(
917
- examples=[
918
- [
919
- "figs/geo.png",
920
- "In the given figure, a square ABCD is inscribed in a circle with center O. Point P is located on side CD. What is the value of angle APB?",
921
- 256,
922
- 512,
923
- 64,
924
- 1,
925
- 0,
926
- "low_confidence"
927
- ],
928
- [
929
- "figs/bus.jpg",
930
- "What are the colors of the bus?",
931
- 256,
932
- 512,
933
- 64,
934
- 1,
935
- 0,
936
- "low_confidence"
937
- ]
938
- ],
939
- inputs=[
940
- image_upload_box,
941
- prompt_input_box_mmu,
942
- steps_slider_mmu,
943
- gen_length_slider_mmu,
944
- block_length_slider_mmu,
945
- temperature_slider_mmu,
946
- cfg_scale_slider_mmu,
947
- remasking_dropdown_mmu
948
- ],
949
- outputs=[output_visualization_box_mmu, output_final_text_box_mmu],
950
- fn=generate_viz_wrapper,
951
- cache_examples=False
952
- )
953
- with gr.Column(visible=False) as examples_mmu_max:
954
- gr.Examples(
955
- examples=[
956
- [
957
- "figs/sunflower.jpg",
958
- "Please describe this image in detail.",
959
- 256,
960
- 512,
961
- 128,
962
- 1,
963
- 0,
964
- "low_confidence"
965
- ],
966
- [
967
- "figs/woman.jpg",
968
- "Please describe this image in detail.",
969
- 256,
970
- 512,
971
- 128,
972
- 1,
973
- 0,
974
- "low_confidence"
975
- ]
976
- ],
977
- inputs=[
978
- image_upload_box,
979
- prompt_input_box_mmu,
980
- steps_slider_mmu,
981
- gen_length_slider_mmu,
982
- block_length_slider_mmu,
983
- temperature_slider_mmu,
984
- cfg_scale_slider_mmu,
985
- remasking_dropdown_mmu
986
- ],
987
- outputs=[output_visualization_box_mmu, output_final_text_box_mmu],
988
- fn=generate_viz_wrapper,
989
- cache_examples=False
990
- )
991
-
992
  gr.Markdown("---")
993
  gr.Markdown("## Part 3. Text-to-Image Generation")
 
994
  with gr.Row():
995
  with gr.Column(scale=2):
996
  prompt_input_box_t2i = gr.Textbox(label="Enter your prompt:", lines=3, value="A sea turtle swimming near a coral reef in the ocean, with a clear blue sky and water in the background.")
997
-
998
  with gr.Accordion("Generation Parameters", open=True):
999
  with gr.Row():
1000
- steps_slider_t2i = gr.Slider(minimum=5, maximum=100, value=15, step=5, label="Total Sampling Steps", info="Must be divisible by (gen_length / block_length).")
1001
- guidance_scale_slider_t2i = gr.Slider(minimum=0.0, maximum=7.0, value=3.5, step=0.5, label="Guidance Scale", info="Classifier-Free Guidance. 0 disables it.")
1002
-
1003
-
1004
- with gr.Row():
1005
- scheduler_radio_t2i = gr.Radio(
1006
- choices=["cosine", "sigmoid", "linear"],
1007
- value="cosine",
1008
- label="Scheduler",
1009
- )
1010
-
1011
  with gr.Row():
1012
  run_button_ui_t2i = gr.Button("Generate Image", variant="primary", scale=3)
1013
  clear_button_ui_t2i = gr.Button("Clear Outputs", scale=1)
1014
-
1015
-
1016
  with gr.Column(scale=3):
1017
- # gr.Markdown("## Live Generation Process")
1018
  output_image_t2i = gr.Image(label="Generated Image", interactive=False, type="pil")
1019
  output_status_t2i = gr.Textbox(label="Generation Status", interactive=False)
1020
-
1021
  gr.Examples(
1022
  examples=[
1023
  ["A sea turtle swimming near a coral reef in the ocean, with a clear blue sky and water in the background.", 15, 3.5, "cosine"],
@@ -1028,153 +433,45 @@ with gr.Blocks(css=css_styles, theme=theme) as demo:
1028
  fn=generate_viz_wrapper_t2i,
1029
  cache_examples=False
1030
  )
1031
-
1032
- run_button_ui_t2i.click(
1033
- fn=generate_viz_wrapper_t2i,
1034
- inputs=[
1035
- prompt_input_box_t2i,
1036
- steps_slider_t2i,
1037
- guidance_scale_slider_t2i,
1038
- scheduler_radio_t2i
1039
- ],
1040
- outputs=[output_image_t2i, output_status_t2i]
1041
- )
1042
-
1043
- clear_button_ui_t2i.click(
1044
- fn=lambda: (None, ""),
1045
- inputs=None,
1046
- outputs=[output_image_t2i, output_status_t2i],
1047
- queue=False
1048
- )
1049
-
1050
- think_button_lm.click(
1051
- fn=toggle_thinking_mode_lm,
1052
- inputs=[thinking_mode_lm],
1053
- outputs=[thinking_mode_lm, think_button_lm]
1054
- )
1055
-
1056
- think_button_mmu.click(
1057
- fn=toggle_thinking_mode_mmu,
1058
- inputs=[thinking_mode_mmu],
1059
- outputs=[thinking_mode_mmu, think_button_mmu]
1060
- )
1061
 
 
1062
  def initialize_app_state():
1063
- global VQ_MODEL
 
 
 
1064
 
1065
- if VQ_MODEL is None:
1066
- print("Loading VQ_MODEL for the first time...")
1067
- VQ_MODEL = MAGVITv2().from_pretrained("showlab/magvitv2")
1068
- print("VQ_MODEL loaded to CPU.")
1069
-
1070
- default_model_choice = "MMaDA-8B-MixCoT"
1071
-
1072
- status, lm_b_vis, lm_m_vis, lm_x_vis, \
1073
- mmu_b_vis, mmu_m_vis, mmu_x_vis, \
1074
- init_thinking_lm_state, init_think_lm_btn_update, \
1075
- init_thinking_mmu_state, init_think_mmu_btn_update = handle_model_selection_change(default_model_choice)
1076
-
1077
- return (
1078
- default_model_choice,
1079
- status,
1080
- lm_b_vis,
1081
- lm_m_vis,
1082
- lm_x_vis,
1083
- mmu_b_vis,
1084
- mmu_m_vis,
1085
- mmu_x_vis,
1086
- init_thinking_lm_state,
1087
- init_think_lm_btn_update,
1088
- init_thinking_mmu_state,
1089
- init_think_mmu_btn_update
1090
- )
1091
 
1092
  demo.load(
1093
  fn=initialize_app_state,
1094
  inputs=None,
1095
  outputs=[
1096
- model_select_radio,
1097
  model_load_status_box,
1098
- examples_lm_base,
1099
- examples_lm_mixcot,
1100
- examples_lm_max,
1101
- examples_mmu_base,
1102
- examples_mmu_mixcot,
1103
- examples_mmu_max,
1104
- thinking_mode_lm,
1105
  think_button_lm,
1106
  thinking_mode_mmu,
1107
  think_button_mmu
1108
  ],
1109
  queue=True
1110
  )
1111
-
1112
- model_select_radio.change(
1113
- fn=handle_model_selection_change,
1114
- inputs=[model_select_radio],
1115
- outputs=[
1116
- model_load_status_box,
1117
- examples_lm_base,
1118
- examples_lm_mixcot,
1119
- examples_lm_max,
1120
- examples_mmu_base,
1121
- examples_mmu_mixcot,
1122
- examples_mmu_max,
1123
- thinking_mode_lm,
1124
- think_button_lm,
1125
- thinking_mode_mmu,
1126
- think_button_mmu
1127
- ]
1128
- )
1129
-
1130
- def clear_outputs():
1131
- return None, None, None # Clear image, visualization, and final text
1132
-
1133
- clear_button_ui_lm.click(
1134
- fn=lambda: (None, None), # return two Nones
1135
- inputs=None,
1136
- outputs=[output_visualization_box_lm, output_final_text_box_lm], # clear only the two text boxes
1137
- queue=False
1138
- )
1139
- clear_button_ui_mmu.click(
1140
- fn=clear_outputs,
1141
- inputs=None,
1142
- outputs=[image_upload_box, output_visualization_box_mmu, output_final_text_box_mmu],
1143
- queue=False
1144
- )
1145
-
1146
- run_button_ui_lm.click(
1147
- fn=generate_viz_wrapper_lm,
1148
- inputs=[
1149
- prompt_input_box_lm,
1150
- steps_slider_lm,
1151
- gen_length_slider_lm,
1152
- block_length_slider_lm,
1153
- temperature_slider_lm,
1154
- cfg_scale_slider_lm,
1155
- remasking_dropdown_lm,
1156
- thinking_mode_lm
1157
- ],
1158
- outputs=[output_visualization_box_lm, output_final_text_box_lm]
1159
- )
1160
-
1161
- run_button_ui_mmu.click(
1162
- fn=generate_viz_wrapper,
1163
- inputs=[
1164
- image_upload_box,
1165
- prompt_input_box_mmu,
1166
- steps_slider_mmu,
1167
- gen_length_slider_mmu,
1168
- block_length_slider_mmu,
1169
- temperature_slider_mmu,
1170
- cfg_scale_slider_mmu,
1171
- remasking_dropdown_mmu,
1172
- thinking_mode_mmu
1173
- ],
1174
- outputs=[output_visualization_box_mmu, output_final_text_box_mmu]
1175
- )
1176
-
1177
 
1178
  if __name__ == "__main__":
1179
  print(f"Starting Gradio App. Attempting to use device: {DEVICE}")
1180
- demo.launch(allowed_paths=["title.png"])
 
10
  import spaces
11
 
12
 
13
+ # --- Helper functions (unchanged) ---
14
  def image_transform(image, resolution=256, normalize=True):
15
  image = transforms.Resize(resolution, interpolation=transforms.InterpolationMode.BICUBIC)(image)
16
  image = transforms.CenterCrop((resolution, resolution))(image)
 
20
  return image
21
 
22
  def add_gumbel_noise(logits, temperature):
23
+ if abs(temperature) < 1e-9:
 
 
 
 
 
24
  return logits
 
25
  logits = logits.to(torch.float64)
 
 
26
  noise = torch.rand_like(logits, dtype=torch.float64)
27
  standard_gumbel_noise = -torch.log(-torch.log(noise + 1e-20) + 1e-20)
28
  return logits + temperature * standard_gumbel_noise
29
 
30
  def get_num_transfer_tokens(mask_index, steps):
31
  mask_num = mask_index.sum(dim=1, keepdim=True)
32
+ steps = max(1, int(steps))
 
33
  base = mask_num // steps
34
  remainder = mask_num % steps
35
  num_transfer_tokens = torch.zeros(mask_num.size(0), steps, device=mask_index.device, dtype=torch.long) + base
36
+ for i in range(mask_num.size(0)):
37
+ if remainder[i] > 0 :
38
+ num_transfer_tokens[i, :remainder[i].item()] += 1
39
  return num_transfer_tokens
40
 
41
+ # --- Global variables and model configuration ---
42
  DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
43
+ # Always use the MMaDA-8B-MixCoT model
44
  DEFAULT_MODEL_PATH = "Gen-Verse/MMaDA-8B-MixCoT"
45
+ MASK_ID = None
46
+ MODEL = None
47
+ TOKENIZER = None
48
+ uni_prompting = None
49
+ VQ_MODEL = None
 
 
50
51
 
52
+ # --- Core model-loading function (simplified) ---
53
  @spaces.GPU
54
+ def load_model_and_tokenizer():
55
+ """
56
+ Load the fixed MMaDA-8B-MixCoT model and tokenizer.
57
+ """
58
+ global MODEL, TOKENIZER, MASK_ID, DEVICE, uni_prompting
59
 
60
+ # If the model is already loaded, return immediately
61
+ if MODEL is not None:
62
+ return f"Model 'MMaDA-8B-MixCoT' is already loaded. MASK_ID: {MASK_ID}"
63
 
64
+ status_msg_parts = [f"Loading 'MMaDA-8B-MixCoT'..."]
65
+ try:
66
+ TOKENIZER = AutoTokenizer.from_pretrained(DEFAULT_MODEL_PATH, trust_remote_code=True)
67
+ status_msg_parts.append(f"Tokenizer for 'MMaDA-8B-MixCoT' loaded.")
 
 
68
 
69
+ MODEL = MMadaModelLM.from_pretrained(DEFAULT_MODEL_PATH, trust_remote_code=True, torch_dtype=torch.bfloat16).eval()
70
+ status_msg_parts.append(f"Model 'MMaDA-8B-MixCoT' loaded to {DEVICE}.")
71
 
72
+ uni_prompting = UniversalPrompting(TOKENIZER, max_text_len=512, special_tokens=("<|soi|>", "<|eoi|>", "<|sov|>", "<|eov|>", "<|t2i|>", "<|mmu|>", "<|t2v|>", "<|v2v|>", "<|lvg|>"),ignore_id=-100, cond_dropout_prob=0.1, use_reserved_token=True)
73
+
 
 
 
 
74
  MASK_ID = 126336
75
  status_msg_parts.append(f"Using default MASK_ID: {MASK_ID}.")
76
 
77
+ if TOKENIZER.pad_token_id is None:
78
+ if TOKENIZER.eos_token_id is not None:
79
+ TOKENIZER.pad_token_id = TOKENIZER.eos_token_id
80
+ TOKENIZER.pad_token = TOKENIZER.eos_token
81
+ status_msg_parts.append(f"Set pad_token_id to eos_token_id ({TOKENIZER.eos_token_id}).")
82
+
83
+ TOKENIZER.chat_template = "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{{ '<|start_header_id|>assistant<|end_header_id|>\n' }}"
84
+
85
+ return " ".join(status_msg_parts)
86
+ except Exception as e:
87
+ MODEL, TOKENIZER, MASK_ID = None, None, None
88
+ return f"Error loading model 'MMaDA-8B-MixCoT': {str(e)}"
89
 
90
+ # --- Visualization and generation functions (the generate_viz_wrapper* family; global-variable handling fixed) ---
91
  def get_highlighted_text_tuples(current_x_ids_batch, prompt_input_ids, prompt_len, tk, current_mask_id, raw_prompt_attention_mask):
92
  if current_x_ids_batch is None or current_x_ids_batch.ndim == 0 or current_x_ids_batch.shape[0] == 0:
93
  return [("Error in sequence data for visualization.", "ERROR")]
 
94
  current_x_ids_batch = current_x_ids_batch[:, prompt_len:]
95
  seq_ids = current_x_ids_batch[0].tolist()
 
 
 
 
96
  intermediate_tuples = []
97
  for j, token_id_int in enumerate(seq_ids):
98
  try:
99
  token_str = tk.decode([token_id_int], skip_special_tokens=True, clean_up_tokenization_spaces=False)
100
+ except Exception:
101
  token_str = f"[ID:{token_id_int}]"
102
 
103
  label = "ERROR"
 
107
  else:
108
  label = "GEN"
109
  intermediate_tuples.append((token_str, label, token_id_int))
 
110
  return intermediate_tuples
111
 
112
  @torch.no_grad()
113
  @spaces.GPU
114
  def generate_viz_wrapper_t2i(prompt_text, steps, guidance_scale, mask_schedule="cosine"):
115
  global MODEL, TOKENIZER, MASK_ID, DEVICE, uni_prompting, VQ_MODEL
 
116
  if MODEL is None or TOKENIZER is None or MASK_ID is None:
117
+ yield Image.new("RGB", (512, 512), (255, 255, 255)), "Error: Model not loaded. Please load the model first."
118
  return
 
119
  if DEVICE == 'cuda':
 
120
  MODEL.to(DEVICE)
121
  VQ_MODEL.to(DEVICE)
 
122
  try:
123
+ # ... (implementation unchanged from before)
124
  steps = int(steps)
125
  guidance_scale = float(guidance_scale)
 
126
  image_tokens = torch.ones((1, 1024), dtype=torch.long, device=DEVICE) * MASK_ID
127
  prompt_text = [prompt_text]
128
  input_ids, attention_mask = uni_prompting((prompt_text, image_tokens), 't2i_gen')
 
129
  if guidance_scale > 0:
130
  uncond_input_ids, uncond_attention_mask = uni_prompting(([''], image_tokens), 't2i_gen')
131
  else:
132
  uncond_input_ids, uncond_attention_mask = None, None
 
133
  mask_schedule = get_mask_schedule(mask_schedule)
134
  blank_image = Image.new("RGB", (512, 512), (255, 255, 255))
135
  yield blank_image, "Starting generation..."
136
  for image_step, status_msg_step in MODEL.t2i_generate_decoding_stepwise(
137
+ input_ids=input_ids, uncond_input_ids=uncond_input_ids, attention_mask=attention_mask,
138
+ uncond_attention_mask=uncond_attention_mask, temperature=1.0, timesteps=steps,
139
+ guidance_scale=guidance_scale, noise_schedule=mask_schedule, noise_type="mask",
140
+ seq_len=1024, vq_model=VQ_MODEL, uni_prompting=uni_prompting):
141
+ yield image_step, status_msg_step
142
  finally:
143
  if DEVICE == 'cuda':
 
144
  MODEL.to('cpu')
145
  VQ_MODEL.to('cpu')
146
  torch.cuda.empty_cache()
 
 
 
147
 
148
  @torch.no_grad()
149
  @spaces.GPU
150
  def generate_viz_wrapper_lm(prompt_text, steps, gen_length, block_length, temperature,
151
+ cfg_scale, remasking_strategy, thinking_mode_lm=False):
152
+ global MODEL, TOKENIZER, MASK_ID, DEVICE
153
  if MODEL is None or TOKENIZER is None or MASK_ID is None:
154
  yield [("Error: Model not loaded. Please load the model first.", "ERROR")], "Model not loaded."
155
  return
 
156
  if DEVICE == 'cuda':
 
157
  MODEL.to(DEVICE)
 
158
  try:
159
+ # ... (implementation unchanged from before)
160
+ steps, gen_length, block_length = int(steps), int(gen_length), int(block_length)
 
 
161
  if thinking_mode_lm:
162
  prompt_text = "You should first think about the reasoning process in the mind and then provide the user with the answer. The reasoning process is enclosed within <think> </think> tags, i.e. <think> reasoning process here </think> answer here\n" + prompt_text
163
+ m = [{"role": "user", "content": prompt_text}]
164
+ processed_prompt_text = TOKENIZER.apply_chat_template(m, add_generation_prompt=True, tokenize=False)
165
+ input_ids = TOKENIZER(text=processed_prompt_text, return_tensors="pt", padding="longest", padding_side="left", truncation=True, max_length=4096)['input_ids'].to(DEVICE)
166
+ raw_prompt_attention_mask = torch.ones_like(input_ids) # Dummy mask, adjust if needed
167
+ batch_size, prompt_len = input_ids.shape[0], input_ids.shape[1]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
168
  x = torch.full((batch_size, prompt_len + gen_length), MASK_ID, dtype=torch.long, device=DEVICE)
169
  x[:, :prompt_len] = input_ids.clone()
170
+ yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), "Starting generation..."
171
+ # ... (rest of the logic is the same)
 
 
 
 
 
 
 
 
 
 
172
  num_blocks = gen_length // block_length
 
 
 
 
 
173
  steps_per_block = steps // num_blocks
 
174
  for num_block_iter in range(num_blocks):
175
  current_block_start_idx_in_x = prompt_len + num_block_iter * block_length
176
  current_block_end_idx_in_x = prompt_len + (num_block_iter + 1) * block_length
177
+ block_masks_bool_current = torch.zeros_like(x, dtype=torch.bool)
178
+ block_masks_bool_current[:, current_block_start_idx_in_x:current_block_end_idx_in_x] = (x[:, current_block_start_idx_in_x:current_block_end_idx_in_x] == MASK_ID)
179
+ num_transfer_tokens_for_this_block = get_num_transfer_tokens(block_masks_bool_current[:, current_block_start_idx_in_x:current_block_end_idx_in_x], steps_per_block)
 
 
 
 
 
 
 
180
  for i_step_in_block in range(steps_per_block):
181
+ mask_index_global = (x == MASK_ID)
182
+ model_output = MODEL(x)
183
+ logits = model_output.logits
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
184
  logits_with_noise = add_gumbel_noise(logits, temperature=temperature)
185
+ x0_predicted_tokens = torch.argmax(logits_with_noise, dim=-1)
186
+ probs = F.softmax(logits.to(torch.float64), dim=-1)
187
+ x0_probs = torch.gather(probs, dim=-1, index=x0_predicted_tokens.unsqueeze(-1)).squeeze(-1)
188
+ confidence_for_selection = torch.where(mask_index_global & block_masks_bool_current, x0_probs, -torch.inf)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
189
  x0_final_candidates = torch.where(mask_index_global, x0_predicted_tokens, x)
190
+ transfer_indices_bool = torch.zeros_like(x, dtype=torch.bool)
191
+ num_to_transfer_this_step_batch = num_transfer_tokens_for_this_block[:, i_step_in_block]
 
 
192
  for j_batch_idx in range(batch_size):
193
+ k_val = min(num_to_transfer_this_step_batch[j_batch_idx].item(), candidate_positions_for_unmasking[j_batch_idx].sum().item())
 
 
194
  if k_val > 0:
195
+ _, topk_indices_in_x = torch.topk(confidence_for_selection[j_batch_idx], k=k_val)
196
+ transfer_indices_bool[j_batch_idx, topk_indices_in_x] = True
 
 
 
 
 
 
 
 
 
 
197
  x[transfer_indices_bool] = x0_final_candidates[transfer_indices_bool]
198
+ status_msg = f"Block {num_block_iter+1}/{num_blocks}, Step {i_step_in_block+1}/{steps_per_block}"
 
 
 
199
  yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), status_msg
200
+ final_text_output = TOKENIZER.batch_decode(x[:, prompt_len:], skip_special_tokens=True)
201
+ yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), final_text_output[0]
 
 
 
 
 
202
  finally:
203
  if DEVICE == 'cuda':
 
204
  MODEL.to('cpu')
205
  torch.cuda.empty_cache()
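
# Note: cfg_scale and remasking_strategy are accepted above but not applied in this simplified
# decoding loop; the confidence-based top-k unmasking it performs corresponds to the
# 'low_confidence' strategy. A rough sketch of how classifier-free guidance could be folded in
# (names and formula are illustrative, LLaDA-style, not part of this app):
#
#     uncond_x = x.clone(); uncond_x[:, :prompt_len] = MASK_ID          # drop the prompt
#     cond_logits, uncond_logits = MODEL(x).logits, MODEL(uncond_x).logits
#     logits = uncond_logits + (cfg_scale + 1) * (cond_logits - uncond_logits)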

@torch.no_grad()
@spaces.GPU
def generate_viz_wrapper(uploaded_image_pil, prompt_text, steps, gen_length, block_length, temperature,
                         cfg_scale, remasking_strategy, thinking_mode_mmu=False):
    global MODEL, TOKENIZER, MASK_ID, DEVICE, VQ_MODEL

    if MODEL is None or TOKENIZER is None or MASK_ID is None:
        yield [("Error: Model not loaded. Please load the model first.", "ERROR")], "Model not loaded."
        return

    if DEVICE == 'cuda':
        MODEL.to(DEVICE)
        VQ_MODEL.to(DEVICE)

    try:
        # (implementation unchanged from the previous version)
        steps, gen_length, block_length = int(steps), int(gen_length), int(block_length)

        if thinking_mode_mmu:
            prompt_text = "You should first think about the reasoning process in the mind and then provide the user with the answer. The reasoning process is enclosed within <think> </think> tags, i.e. <think> reasoning process here </think> answer here\n" + prompt_text
        m = [{"role": "user", "content": prompt_text}]
        processed_prompt_text = TOKENIZER.apply_chat_template(m, add_generation_prompt=True, tokenize=False)
        image_vq_ids_tensor = None

        if uploaded_image_pil is not None:
            image = image_transform(uploaded_image_pil, resolution=512).to(DEVICE).unsqueeze(0)
            image_vq_ids_tensor = VQ_MODEL.get_code(image) + 126349
        input_ids = TOKENIZER(text=processed_prompt_text, return_tensors="pt", padding="longest", padding_side="left", truncation=True, max_length=4096)['input_ids'].to(DEVICE)
        raw_prompt_attention_mask = torch.ones_like(input_ids)  # dummy mask; adjust if real padding masks are needed
        if image_vq_ids_tensor is not None:
            input_ids = torch.cat([(torch.ones(1, 1) * 126089).to(DEVICE), (torch.ones(1, 1) * 126084).to(DEVICE), image_vq_ids_tensor, (torch.ones(1, 1) * 126085).to(DEVICE), input_ids], dim=1).long()
        batch_size, prompt_len = input_ids.shape[0], input_ids.shape[1]

        x = torch.full((batch_size, prompt_len + gen_length), MASK_ID, dtype=torch.long, device=DEVICE)
        x[:, :prompt_len] = input_ids.clone()
        yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), "Starting generation..."

        # Guard against a zero block count when block_length > gen_length.
        num_blocks = max(1, gen_length // block_length)
        steps_per_block = max(1, steps // num_blocks)

        for num_block_iter in range(num_blocks):
            current_block_start_idx_in_x = prompt_len + num_block_iter * block_length
            current_block_end_idx_in_x = prompt_len + (num_block_iter + 1) * block_length
            block_masks_bool_current = torch.zeros_like(x, dtype=torch.bool)
            block_masks_bool_current[:, current_block_start_idx_in_x:current_block_end_idx_in_x] = (x[:, current_block_start_idx_in_x:current_block_end_idx_in_x] == MASK_ID)
            num_transfer_tokens_for_this_block = get_num_transfer_tokens(block_masks_bool_current[:, current_block_start_idx_in_x:current_block_end_idx_in_x], steps_per_block)

            for i_step_in_block in range(steps_per_block):
                mask_index_global = (x == MASK_ID)
                model_output = MODEL(x)
                logits = model_output.logits
                logits_with_noise = add_gumbel_noise(logits, temperature=temperature)
                x0_predicted_tokens = torch.argmax(logits_with_noise, dim=-1)
                probs = F.softmax(logits.to(torch.float64), dim=-1)
                x0_probs = torch.gather(probs, dim=-1, index=x0_predicted_tokens.unsqueeze(-1)).squeeze(-1)
                confidence_for_selection = torch.where(mask_index_global & block_masks_bool_current, x0_probs, -torch.inf)
                x0_final_candidates = torch.where(mask_index_global, x0_predicted_tokens, x)
                transfer_indices_bool = torch.zeros_like(x, dtype=torch.bool)
                num_to_transfer_this_step_batch = num_transfer_tokens_for_this_block[:, i_step_in_block]
                for j_batch_idx in range(batch_size):
                    k_val = min(num_to_transfer_this_step_batch[j_batch_idx].item(), (mask_index_global & block_masks_bool_current)[j_batch_idx].sum().item())
                    if k_val > 0:
                        _, topk_indices_in_x = torch.topk(confidence_for_selection[j_batch_idx], k=k_val)
                        transfer_indices_bool[j_batch_idx, topk_indices_in_x] = True
                x[transfer_indices_bool] = x0_final_candidates[transfer_indices_bool]
                status_msg = f"Block {num_block_iter+1}/{num_blocks}, Step {i_step_in_block+1}/{steps_per_block}"
                yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), status_msg
        final_text_output = TOKENIZER.batch_decode(x[:, prompt_len:], skip_special_tokens=True)
        yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), final_text_output[0]
    finally:
        if DEVICE == 'cuda':
            MODEL.to('cpu')
            VQ_MODEL.to('cpu')
            torch.cuda.empty_cache()
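
# Note on the multimodal prompt built above: the uploaded image is quantized by VQ_MODEL into
# discrete codes, offset by 126349 so they land in the model's image-token id range, and wrapped
# between the fixed special-token ids 126089/126084/126085 before the text prompt. These constants
# come from the original app and are kept as-is; their symbolic names are not defined in this file.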

# --- UI definition ---
css_styles = """
.gradio-container{font-family:'IBM Plex Sans',sans-serif;margin:auto;}
.gr-input {background:#f9f9f9 !important;border:1px solid #e0e0e0 !important;}
.gr-output{background:#f0f0f0 !important;border:1px solid #d0d0d0 !important;}
.highlighted-text span{padding:2px 4px;border-radius:4px;margin:1px 2px;display:inline-block;line-height:1.6;}
footer{display:none !important}
#live-update-scrollable-box {max-height: 800px; overflow-y: auto !important; display: block;}
#think_btn {background-color: #f3f4f6 !important; border: 1px solid #d0d0d0 !important; color: #111827 !important; font-size: 16px !important; font-weight: bold !important;}
#think_btn:hover {background-color: #e0e0e0 !important; border: 1px solid #c0c0c0 !important; color: #222 !important;}
#think_btn:active {background-color: #2563eb !important; border: 1px solid #b0b0b0 !important; color: white !important;}
.model-badge {padding: 5px 10px; border-radius: 15px; font-weight: bold; margin: 0 5px; display: inline-block;}
.active-model {background-color: #E879F9; color: white;}
.soon-model {background-color: #E5E7EB; color: #6B7280; cursor: not-allowed;}
"""

def toggle_thinking_mode(current_thinking_mode):
    new_state = not current_thinking_mode
    new_label = "Thinking Mode ✅" if new_state else "Thinking Mode ❌"
    return new_state, gr.update(value=new_label)

color_map_config = {"MASK": "lightgrey", "GEN": "#DCABFA"}

theme = gr.themes.Ocean(primary_hue="fuchsia")

with gr.Blocks(css=css_styles, theme=theme) as demo:
    thinking_mode_lm = gr.State(True)   # enabled by default for the MixCoT model
    thinking_mode_mmu = gr.State(True)  # enabled by default for the MixCoT model

    # --- Title and model info (modified) ---
    gr.HTML("""
    <div align="center" style="margin-bottom: 20px;">
        <img src='/gradio_api/file=title.png' width="160">
        <p>
        MMaDA is a new class of multimodal diffusion foundation models, enabling state-of-the-art performance in reasoning, multimodal understanding, and text-to-image generation.
        </p>
        <p style="font-size: 15px;">
            📄 <a href="https://arxiv.org/abs/2505.15809" target="_blank">Paper</a>&nbsp;&nbsp;&nbsp;|&nbsp;&nbsp;&nbsp;💻 <a href="https://github.com/Gen-Verse/MMaDA" target="_blank">Code</a>
        </p>
    </div>
    """)

    with gr.Row():
        with gr.Column(scale=1):
            gr.HTML("""
            <div style="display: flex; justify-content: center; align-items: center; height: 100%;">
                <div>
                    <span class="model-badge active-model">MMaDA-8B-MixCoT</span>
                    <span class="model-badge soon-model">MMaDA-8B-Max (coming soon)</span>
                </div>
            </div>
            """)
        with gr.Column(scale=2):
            model_load_status_box = gr.Textbox(
                label="Model Load Status", interactive=False, lines=3, max_lines=5
            )

    # --- Part 1. Text Generation ---
    gr.Markdown("## Part 1. Text Generation")
    with gr.Row():
        with gr.Column(scale=2):
            prompt_input_box_lm = gr.Textbox(label="Enter your prompt:", lines=3, value="A rectangular prism has a length of 5 units, a width of 4 units, and a height of 3 units. What is the volume of the prism?")
            think_button_lm = gr.Button("Thinking Mode", elem_id="think_btn")
            with gr.Accordion("Generation Parameters", open=True):
                # Parameter sliders (unchanged)
                with gr.Row():
                    gen_length_slider_lm = gr.Slider(minimum=8, maximum=1024, value=512, step=64, label="Generation Length")
                    steps_slider_lm = gr.Slider(minimum=1, maximum=512, value=256, step=32, label="Total Sampling Steps")
                with gr.Row():
                    block_length_slider_lm = gr.Slider(minimum=8, maximum=1024, value=128, step=32, label="Block Length")
                    remasking_dropdown_lm = gr.Dropdown(choices=['low_confidence', 'random'], value='low_confidence', label="Remasking Strategy")
                with gr.Row():
                    cfg_scale_slider_lm = gr.Slider(minimum=0.0, maximum=2.0, value=0.0, step=0.1, label="CFG Scale")
                    temperature_slider_lm = gr.Slider(minimum=0.0, maximum=2.0, value=1, step=0.05, label="Temperature")

            with gr.Row():
                run_button_ui_lm = gr.Button("Generate Sequence", variant="primary", scale=3)
                clear_button_ui_lm = gr.Button("Clear Outputs", scale=1)

        with gr.Column(scale=3):
            output_visualization_box_lm = gr.HighlightedText(label="Live Generation Process", show_legend=True, color_map=color_map_config, combine_adjacent=False, interactive=False, elem_id="live-update-scrollable-box")
            output_final_text_box_lm = gr.Textbox(label="Final Output", lines=8, interactive=False, show_copy_button=True)

    # Only the MixCoT examples are kept (modified)
    gr.Examples(
        examples=[
            ["A rectangular prism has a length of 5 units, a width of 4 units, and a height of 3 units. What is the volume of the prism?", 256, 512, 128, 1, 0, "low_confidence"],
            ["Lily can run 12 kilometers per hour for 4 hours. After that, she can run 6 kilometers per hour. How many kilometers can she run in 8 hours?", 256, 512, 64, 1, 0, "low_confidence"]
        ],
        inputs=[prompt_input_box_lm, steps_slider_lm, gen_length_slider_lm, block_length_slider_lm, temperature_slider_lm, cfg_scale_slider_lm, remasking_dropdown_lm],
        outputs=[output_visualization_box_lm, output_final_text_box_lm],
        fn=generate_viz_wrapper_lm,
        cache_examples=False
    )
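
    # The example rows above follow the order of `inputs`:
    # [prompt, steps, gen_length, block_length, temperature, cfg_scale, remasking_strategy].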

    # --- Part 2 & 3 and event handlers (similar structure, simplified) ---
    gr.Markdown("---")
    gr.Markdown("## Part 2. Multimodal Understanding")
    with gr.Row():
        # (Part 2 UI structure unchanged)
        with gr.Column(scale=2):
            prompt_input_box_mmu = gr.Textbox(label="Enter your prompt:", lines=3, value="")
            think_button_mmu = gr.Button("Thinking Mode", elem_id="think_btn")
            with gr.Accordion("Generation Parameters", open=True):
                with gr.Row():
                    gen_length_slider_mmu = gr.Slider(minimum=64, maximum=1024, value=512, step=64, label="Generation Length")
                    steps_slider_mmu = gr.Slider(minimum=1, maximum=512, value=256, step=32, label="Total Sampling Steps")
                with gr.Row():
                    block_length_slider_mmu = gr.Slider(minimum=32, maximum=1024, value=64, step=32, label="Block Length")
                    remasking_dropdown_mmu = gr.Dropdown(choices=['low_confidence', 'random'], value='low_confidence', label="Remasking Strategy")
                with gr.Row():
                    cfg_scale_slider_mmu = gr.Slider(minimum=0.0, maximum=2.0, value=0.0, step=0.1, label="CFG Scale")
                    temperature_slider_mmu = gr.Slider(minimum=0.0, maximum=2.0, value=1, step=0.05, label="Temperature")

            with gr.Row():
                image_upload_box = gr.Image(type="pil", label="Upload Image")

            with gr.Row():
                run_button_ui_mmu = gr.Button("Generate Description", variant="primary", scale=3)
                clear_button_ui_mmu = gr.Button("Clear Outputs", scale=1)

        with gr.Column(scale=3):
            output_visualization_box_mmu = gr.HighlightedText(label="Token Sequence (Live Update)", show_legend=True, color_map=color_map_config, combine_adjacent=False, interactive=False, elem_id="live-update-scrollable-box")
            output_final_text_box_mmu = gr.Textbox(label="Final Output", lines=8, interactive=False, show_copy_button=True)

    # Only the MixCoT MMU examples are kept
    gr.Examples(
        examples=[
            ["figs/geo.png", "In the given figure, a square ABCD is inscribed in a circle with center O. Point P is located on side CD. What is the value of angle APB?", 256, 512, 64, 1, 0, "low_confidence"],
            ["figs/bus.jpg", "What are the colors of the bus?", 256, 512, 64, 1, 0, "low_confidence"]
        ],
        inputs=[image_upload_box, prompt_input_box_mmu, steps_slider_mmu, gen_length_slider_mmu, block_length_slider_mmu, temperature_slider_mmu, cfg_scale_slider_mmu, remasking_dropdown_mmu],
        outputs=[output_visualization_box_mmu, output_final_text_box_mmu],
        fn=generate_viz_wrapper,
        cache_examples=False
    )
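
    # The example images above are read from the local figs/ directory, which is exposed to the
    # frontend via allowed_paths in demo.launch() at the bottom of this file.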

    gr.Markdown("---")
    gr.Markdown("## Part 3. Text-to-Image Generation")
    # (Part 3 UI and examples unchanged)
    with gr.Row():
        with gr.Column(scale=2):
            prompt_input_box_t2i = gr.Textbox(label="Enter your prompt:", lines=3, value="A sea turtle swimming near a coral reef in the ocean, with a clear blue sky and water in the background.")

            with gr.Accordion("Generation Parameters", open=True):
                with gr.Row():
                    steps_slider_t2i = gr.Slider(minimum=5, maximum=100, value=15, step=5, label="Total Sampling Steps")
                    guidance_scale_slider_t2i = gr.Slider(minimum=0.0, maximum=7.0, value=3.5, step=0.5, label="Guidance Scale")
                with gr.Row():
                    scheduler_radio_t2i = gr.Radio(choices=["cosine", "sigmoid", "linear"], value="cosine", label="Scheduler")

            with gr.Row():
                run_button_ui_t2i = gr.Button("Generate Image", variant="primary", scale=3)
                clear_button_ui_t2i = gr.Button("Clear Outputs", scale=1)

        with gr.Column(scale=3):
            output_image_t2i = gr.Image(label="Generated Image", interactive=False, type="pil")
            output_status_t2i = gr.Textbox(label="Generation Status", interactive=False)

    gr.Examples(
        examples=[
            ["A sea turtle swimming near a coral reef in the ocean, with a clear blue sky and water in the background.", 15, 3.5, "cosine"],
            # (additional example rows were collapsed in the diff view)
        ],
        # inputs/outputs reconstructed to match the run-button wiring below
        inputs=[prompt_input_box_t2i, steps_slider_t2i, guidance_scale_slider_t2i, scheduler_radio_t2i],
        outputs=[output_image_t2i, output_status_t2i],
        fn=generate_viz_wrapper_t2i,
        cache_examples=False
    )

    # --- App startup and event handling (simplified) ---
    def initialize_app_state():
        global VQ_MODEL
        print("Loading VQ_MODEL for the first time...")
        VQ_MODEL = MAGVITv2().from_pretrained("showlab/magvitv2")
        print("VQ_MODEL loaded to CPU.")

        status = load_model_and_tokenizer()
        # Thinking Mode is enabled by default for the MixCoT model
        return status, True, gr.update(value="Thinking Mode ✅"), True, gr.update(value="Thinking Mode ✅")

    demo.load(
        fn=initialize_app_state,
        inputs=None,
        outputs=[
            model_load_status_box,
            thinking_mode_lm,
            think_button_lm,
            thinking_mode_mmu,
            think_button_mmu
        ],
        queue=True
    )
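
    # The five outputs above line up positionally with the five values returned by
    # initialize_app_state: status text, LM thinking state, LM button label,
    # MMU thinking state, MMU button label.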

    # Clear-button events
    clear_button_ui_lm.click(fn=lambda: (None, None), inputs=None, outputs=[output_visualization_box_lm, output_final_text_box_lm], queue=False)
    clear_button_ui_mmu.click(fn=lambda: (None, None, None), inputs=None, outputs=[image_upload_box, output_visualization_box_mmu, output_final_text_box_mmu], queue=False)
    clear_button_ui_t2i.click(fn=lambda: (None, ""), inputs=None, outputs=[output_image_t2i, output_status_t2i], queue=False)

    # Thinking Mode toggle events
    think_button_lm.click(fn=toggle_thinking_mode, inputs=[thinking_mode_lm], outputs=[thinking_mode_lm, think_button_lm])
    think_button_mmu.click(fn=toggle_thinking_mode, inputs=[thinking_mode_mmu], outputs=[thinking_mode_mmu, think_button_mmu])

    # Generate-button events
    run_button_ui_lm.click(fn=generate_viz_wrapper_lm, inputs=[prompt_input_box_lm, steps_slider_lm, gen_length_slider_lm, block_length_slider_lm, temperature_slider_lm, cfg_scale_slider_lm, remasking_dropdown_lm, thinking_mode_lm], outputs=[output_visualization_box_lm, output_final_text_box_lm])
    run_button_ui_mmu.click(fn=generate_viz_wrapper, inputs=[image_upload_box, prompt_input_box_mmu, steps_slider_mmu, gen_length_slider_mmu, block_length_slider_mmu, temperature_slider_mmu, cfg_scale_slider_mmu, remasking_dropdown_mmu, thinking_mode_mmu], outputs=[output_visualization_box_mmu, output_final_text_box_mmu])
    run_button_ui_t2i.click(fn=generate_viz_wrapper_t2i, inputs=[prompt_input_box_t2i, steps_slider_t2i, guidance_scale_slider_t2i, scheduler_radio_t2i], outputs=[output_image_t2i, output_status_t2i])


if __name__ == "__main__":
    print(f"Starting Gradio App. Attempting to use device: {DEVICE}")
    demo.launch(allowed_paths=["title.png", "figs"])