YucYux committed on
Commit
9c9354f
·
1 Parent(s): ca6c4d6

tried to fix bug

Files changed (1)
  1. app.py +664 -276
app.py CHANGED
@@ -10,7 +10,6 @@ from PIL import Image
10
  import spaces
11
 
12
 
13
- # --- Helper functions (unchanged) ---
14
  def image_transform(image, resolution=256, normalize=True):
15
  image = transforms.Resize(resolution, interpolation=transforms.InterpolationMode.BICUBIC)(image)
16
  image = transforms.CenterCrop((resolution, resolution))(image)
@@ -20,84 +19,133 @@ def image_transform(image, resolution=256, normalize=True):
20
  return image
21
 
22
  def add_gumbel_noise(logits, temperature):
23
- if abs(temperature) < 1e-9:
24
  return logits
 
25
  logits = logits.to(torch.float64)
 
 
26
  noise = torch.rand_like(logits, dtype=torch.float64)
27
  standard_gumbel_noise = -torch.log(-torch.log(noise + 1e-20) + 1e-20)
28
  return logits + temperature * standard_gumbel_noise
29
 
30
  def get_num_transfer_tokens(mask_index, steps):
31
  mask_num = mask_index.sum(dim=1, keepdim=True)
32
- steps = max(1, int(steps))
 
33
  base = mask_num // steps
34
  remainder = mask_num % steps
35
  num_transfer_tokens = torch.zeros(mask_num.size(0), steps, device=mask_index.device, dtype=torch.long) + base
36
- for i in range(mask_num.size(0)):
37
- if remainder[i] > 0 :
38
- num_transfer_tokens[i, :remainder[i].item()] += 1
39
  return num_transfer_tokens
40
 
41
- # --- Global variables and model configuration ---
42
  DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
43
- # Always use the MMaDA-8B-MixCoT model
44
- DEFAULT_MODEL_PATH = "Gen-Verse/MMaDA-8B-MixCoT"
45
- MASK_ID = None
46
- MODEL = None
47
- TOKENIZER = None
48
- uni_prompting = None
49
- VQ_MODEL = None
 
50
 
51
 
52
- # --- Core model-loading function (simplified) ---
53
  @spaces.GPU
54
- def load_model_and_tokenizer():
55
- """
56
- Load the fixed MMaDA-8B-MixCoT model and tokenizer.
57
- """
58
- global MODEL, TOKENIZER, MASK_ID, DEVICE, uni_prompting
59
 
60
- # If the model is already loaded, return immediately
61
- if MODEL is not None:
62
- return f"Model 'MMaDA-8B-MixCoT' is already loaded. MASK_ID: {MASK_ID}"
63
 
64
- status_msg_parts = [f"Loading 'MMaDA-8B-MixCoT'..."]
65
- try:
66
- TOKENIZER = AutoTokenizer.from_pretrained(DEFAULT_MODEL_PATH, trust_remote_code=True)
67
- status_msg_parts.append(f"Tokenizer for 'MMaDA-8B-MixCoT' loaded.")
68
 
69
- MODEL = MMadaModelLM.from_pretrained(DEFAULT_MODEL_PATH, trust_remote_code=True, torch_dtype=torch.bfloat16).eval()
70
- status_msg_parts.append(f"Model 'MMaDA-8B-MixCoT' loaded to {DEVICE}.")
 
 
71
 
72
- uni_prompting = UniversalPrompting(TOKENIZER, max_text_len=512, special_tokens=("<|soi|>", "<|eoi|>", "<|sov|>", "<|eov|>", "<|t2i|>", "<|mmu|>", "<|t2v|>", "<|v2v|>", "<|lvg|>"),ignore_id=-100, cond_dropout_prob=0.1, use_reserved_token=True)
73
-
74
  MASK_ID = 126336
75
  status_msg_parts.append(f"Using default MASK_ID: {MASK_ID}.")
76
 
77
- if TOKENIZER.pad_token_id is None:
78
- if TOKENIZER.eos_token_id is not None:
79
- TOKENIZER.pad_token_id = TOKENIZER.eos_token_id
80
- TOKENIZER.pad_token = TOKENIZER.eos_token
81
- status_msg_parts.append(f"Set pad_token_id to eos_token_id ({TOKENIZER.eos_token_id}).")
82
-
83
- TOKENIZER.chat_template = "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{{ '<|start_header_id|>assistant<|end_header_id|>\n' }}"
84
-
85
- return " ".join(status_msg_parts)
86
- except Exception as e:
87
- MODEL, TOKENIZER, MASK_ID = None, None, None
88
- return f"Error loading model 'MMaDA-8B-MixCoT': {str(e)}"
89
 
90
- # --- Visualization and generation functions (generate_viz_wrapper* family; global-variable issue fixed) ---
91
  def get_highlighted_text_tuples(current_x_ids_batch, prompt_input_ids, prompt_len, tk, current_mask_id, raw_prompt_attention_mask):
92
  if current_x_ids_batch is None or current_x_ids_batch.ndim == 0 or current_x_ids_batch.shape[0] == 0:
93
  return [("Error in sequence data for visualization.", "ERROR")]
 
94
  current_x_ids_batch = current_x_ids_batch[:, prompt_len:]
95
  seq_ids = current_x_ids_batch[0].tolist()
 
 
 
 
96
  intermediate_tuples = []
97
  for j, token_id_int in enumerate(seq_ids):
98
  try:
99
  token_str = tk.decode([token_id_int], skip_special_tokens=True, clean_up_tokenization_spaces=False)
100
- except Exception:
101
  token_str = f"[ID:{token_id_int}]"
102
 
103
  label = "ERROR"
@@ -107,202 +155,452 @@ def get_highlighted_text_tuples(current_x_ids_batch, prompt_input_ids, prompt_le
107
  else:
108
  label = "GEN"
109
  intermediate_tuples.append((token_str, label, token_id_int))
 
110
  return intermediate_tuples
111
 
112
  @torch.no_grad()
113
  @spaces.GPU
114
  def generate_viz_wrapper_t2i(prompt_text, steps, guidance_scale, mask_schedule="cosine"):
115
- global MODEL, TOKENIZER, MASK_ID, DEVICE, uni_prompting, VQ_MODEL
 
116
  if MODEL is None or TOKENIZER is None or MASK_ID is None:
117
- yield Image.new("RGB", (512, 512), (255, 255, 255)), "Error: Model not loaded. Please load the model first."
118
  return
119
- if DEVICE == 'cuda':
120
- MODEL.to(DEVICE)
121
- VQ_MODEL.to(DEVICE)
122
- try:
123
- # ... (function body same as before)
124
- steps = int(steps)
125
- guidance_scale = float(guidance_scale)
126
- image_tokens = torch.ones((1, 1024), dtype=torch.long, device=DEVICE) * MASK_ID
127
- prompt_text = [prompt_text]
128
- input_ids, attention_mask = uni_prompting((prompt_text, image_tokens), 't2i_gen')
129
- if guidance_scale > 0:
130
- uncond_input_ids, uncond_attention_mask = uni_prompting(([''], image_tokens), 't2i_gen')
131
- else:
132
- uncond_input_ids, uncond_attention_mask = None, None
133
- mask_schedule = get_mask_schedule(mask_schedule)
134
- blank_image = Image.new("RGB", (512, 512), (255, 255, 255))
135
- yield blank_image, "Starting generation..."
136
- for image_step, status_msg_step in MODEL.t2i_generate_decoding_stepwise(
137
- input_ids=input_ids, uncond_input_ids=uncond_input_ids, attention_mask=attention_mask,
138
- uncond_attention_mask=uncond_attention_mask, temperature=1.0, timesteps=steps,
139
- guidance_scale=guidance_scale, noise_schedule=mask_schedule, noise_type="mask",
140
- seq_len=1024, vq_model=VQ_MODEL, uni_prompting=uni_prompting):
141
- yield image_step, status_msg_step
142
- finally:
143
- if DEVICE == 'cuda':
144
- MODEL.to('cpu')
145
- VQ_MODEL.to('cpu')
146
- torch.cuda.empty_cache()
 
 
 
 
147
 
148
  @torch.no_grad()
149
  @spaces.GPU
150
  def generate_viz_wrapper_lm(prompt_text, steps, gen_length, block_length, temperature,
151
- cfg_scale, remasking_strategy, thinking_mode_lm=False):
152
  global MODEL, TOKENIZER, MASK_ID, DEVICE
153
  if MODEL is None or TOKENIZER is None or MASK_ID is None:
154
  yield [("Error: Model not loaded. Please load the model first.", "ERROR")], "Model not loaded."
155
  return
156
- if DEVICE == 'cuda':
157
- MODEL.to(DEVICE)
158
  try:
159
- # ... (function body same as before)
160
- steps, gen_length, block_length = int(steps), int(gen_length), int(block_length)
161
- if thinking_mode_lm:
162
- prompt_text = "You should first think about the reasoning process in the mind and then provide the user with the answer. The reasoning process is enclosed within <think> </think> tags, i.e. <think> reasoning process here </think> answer here\n" + prompt_text
163
  m = [{"role": "user", "content": prompt_text}]
164
  processed_prompt_text = TOKENIZER.apply_chat_template(m, add_generation_prompt=True, tokenize=False)
165
- input_ids = TOKENIZER(text=processed_prompt_text, return_tensors="pt", padding="longest", padding_side="left", truncation=True, max_length=4096)['input_ids'].to(DEVICE)
166
- raw_prompt_attention_mask = torch.ones_like(input_ids) # Dummy mask, adjust if needed
167
- batch_size, prompt_len = input_ids.shape[0], input_ids.shape[1]
168
- x = torch.full((batch_size, prompt_len + gen_length), MASK_ID, dtype=torch.long, device=DEVICE)
169
- x[:, :prompt_len] = input_ids.clone()
170
- yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), "Starting generation..."
171
- # ... (rest of the logic is the same)
172
- num_blocks = gen_length // block_length
173
- steps_per_block = steps // num_blocks
174
- for num_block_iter in range(num_blocks):
175
- current_block_start_idx_in_x = prompt_len + num_block_iter * block_length
176
- current_block_end_idx_in_x = prompt_len + (num_block_iter + 1) * block_length
177
- block_masks_bool_current = torch.zeros_like(x, dtype=torch.bool)
178
- block_masks_bool_current[:, current_block_start_idx_in_x:current_block_end_idx_in_x] = (x[:, current_block_start_idx_in_x:current_block_end_idx_in_x] == MASK_ID)
179
- num_transfer_tokens_for_this_block = get_num_transfer_tokens(block_masks_bool_current[:, current_block_start_idx_in_x:current_block_end_idx_in_x], steps_per_block)
180
- for i_step_in_block in range(steps_per_block):
181
- mask_index_global = (x == MASK_ID)
182
- model_output = MODEL(x)
183
  logits = model_output.logits
184
- logits_with_noise = add_gumbel_noise(logits, temperature=temperature)
185
- x0_predicted_tokens = torch.argmax(logits_with_noise, dim=-1)
186
- probs = F.softmax(logits.to(torch.float64), dim=-1)
187
- x0_probs = torch.gather(probs, dim=-1, index=x0_predicted_tokens.unsqueeze(-1)).squeeze(-1)
188
- confidence_for_selection = torch.where(mask_index_global & block_masks_bool_current, x0_probs, -torch.inf)
189
- x0_final_candidates = torch.where(mask_index_global, x0_predicted_tokens, x)
190
- transfer_indices_bool = torch.zeros_like(x, dtype=torch.bool)
191
- num_to_transfer_this_step_batch = num_transfer_tokens_for_this_block[:, i_step_in_block]
192
- for j_batch_idx in range(batch_size):
193
- k_val = min(num_to_transfer_this_step_batch[j_batch_idx].item(), candidate_positions_for_unmasking[j_batch_idx].sum().item())
194
- if k_val > 0:
195
- _, topk_indices_in_x = torch.topk(confidence_for_selection[j_batch_idx], k=k_val)
196
- transfer_indices_bool[j_batch_idx, topk_indices_in_x] = True
197
- x[transfer_indices_bool] = x0_final_candidates[transfer_indices_bool]
198
- status_msg = f"Block {num_block_iter+1}/{num_blocks}, Step {i_step_in_block+1}/{steps_per_block}"
199
- yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), status_msg
200
- final_text_output = TOKENIZER.batch_decode(x[:, prompt_len:], skip_special_tokens=True)
201
- yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), final_text_output[0]
202
- finally:
203
- if DEVICE == 'cuda':
204
- MODEL.to('cpu')
205
- torch.cuda.empty_cache()
206
 
207
 
208
  @torch.no_grad()
209
  @spaces.GPU
210
  def generate_viz_wrapper(uploaded_image_pil, prompt_text, steps, gen_length, block_length, temperature,
211
  cfg_scale, remasking_strategy, thinking_mode_mmu=False):
212
- global MODEL, TOKENIZER, MASK_ID, DEVICE, VQ_MODEL
 
213
  if MODEL is None or TOKENIZER is None or MASK_ID is None:
214
  yield [("Error: Model not loaded. Please load the model first.", "ERROR")], "Model not loaded."
215
  return
216
- if DEVICE == 'cuda':
217
- MODEL.to(DEVICE)
218
- VQ_MODEL.to(DEVICE)
219
  try:
220
- # ... (function body same as before)
221
- steps, gen_length, block_length = int(steps), int(gen_length), int(block_length)
222
- if thinking_mode_mmu:
223
- prompt_text = "You should first think about the reasoning process in the mind and then provide the user with the answer. The reasoning process is enclosed within <think> </think> tags, i.e. <think> reasoning process here </think> answer here\n" + prompt_text
224
  m = [{"role": "user", "content": prompt_text}]
225
  processed_prompt_text = TOKENIZER.apply_chat_template(m, add_generation_prompt=True, tokenize=False)
226
- image_vq_ids_tensor = None
227
- if uploaded_image_pil is not None:
228
- image = image_transform(uploaded_image_pil, resolution=512).to(DEVICE).unsqueeze(0)
229
- image_vq_ids_tensor = VQ_MODEL.get_code(image) + 126349
230
- input_ids = TOKENIZER(text=processed_prompt_text, return_tensors="pt", padding="longest", padding_side="left", truncation=True, max_length=4096)['input_ids'].to(DEVICE)
231
- raw_prompt_attention_mask = torch.ones_like(input_ids) # Dummy mask
232
  if image_vq_ids_tensor is not None:
233
- input_ids = torch.cat([(torch.ones(1, 1) * 126089).to(DEVICE), (torch.ones(1, 1) * 126084).to(DEVICE), image_vq_ids_tensor, (torch.ones(1, 1) * 126085).to(DEVICE), input_ids], dim=1).long()
234
- batch_size, prompt_len = input_ids.shape[0], input_ids.shape[1]
235
- x = torch.full((batch_size, prompt_len + gen_length), MASK_ID, dtype=torch.long, device=DEVICE)
236
- x[:, :prompt_len] = input_ids.clone()
237
- yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), "Starting generation..."
238
- # ... (rest of the logic is the same)
239
- num_blocks = gen_length // block_length
240
- steps_per_block = steps // num_blocks
241
- for num_block_iter in range(num_blocks):
242
- current_block_start_idx_in_x = prompt_len + num_block_iter * block_length
243
- current_block_end_idx_in_x = prompt_len + (num_block_iter + 1) * block_length
244
- block_masks_bool_current = torch.zeros_like(x, dtype=torch.bool)
245
- block_masks_bool_current[:, current_block_start_idx_in_x:current_block_end_idx_in_x] = (x[:, current_block_start_idx_in_x:current_block_end_idx_in_x] == MASK_ID)
246
- num_transfer_tokens_for_this_block = get_num_transfer_tokens(block_masks_bool_current[:, current_block_start_idx_in_x:current_block_end_idx_in_x], steps_per_block)
247
- for i_step_in_block in range(steps_per_block):
248
- mask_index_global = (x == MASK_ID)
249
- model_output = MODEL(x)
250
- logits = model_output.logits
251
- logits_with_noise = add_gumbel_noise(logits, temperature=temperature)
252
- x0_predicted_tokens = torch.argmax(logits_with_noise, dim=-1)
253
- probs = F.softmax(logits.to(torch.float64), dim=-1)
254
- x0_probs = torch.gather(probs, dim=-1, index=x0_predicted_tokens.unsqueeze(-1)).squeeze(-1)
255
- confidence_for_selection = torch.where(mask_index_global & block_masks_bool_current, x0_probs, -torch.inf)
256
- x0_final_candidates = torch.where(mask_index_global, x0_predicted_tokens, x)
257
- transfer_indices_bool = torch.zeros_like(x, dtype=torch.bool)
258
- num_to_transfer_this_step_batch = num_transfer_tokens_for_this_block[:, i_step_in_block]
259
- for j_batch_idx in range(batch_size):
260
- k_val = min(num_to_transfer_this_step_batch[j_batch_idx].item(), (mask_index_global & block_masks_bool_current)[j_batch_idx].sum().item())
261
- if k_val > 0:
262
- _, topk_indices_in_x = torch.topk(confidence_for_selection[j_batch_idx], k=k_val)
263
- transfer_indices_bool[j_batch_idx, topk_indices_in_x] = True
264
- x[transfer_indices_bool] = x0_final_candidates[transfer_indices_bool]
265
- status_msg = f"Block {num_block_iter+1}/{num_blocks}, Step {i_step_in_block+1}/{steps_per_block}"
266
- yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), status_msg
267
- final_text_output = TOKENIZER.batch_decode(x[:, prompt_len:], skip_special_tokens=True)
268
- yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), final_text_output[0]
269
- finally:
270
- if DEVICE == 'cuda':
271
- MODEL.to('cpu')
272
- VQ_MODEL.to('cpu')
273
- torch.cuda.empty_cache()
274
-
275
-
276
- # --- UI definition ---
277
  css_styles = """
278
  .gradio-container{font-family:'IBM Plex Sans',sans-serif;margin:auto;}
279
  .gr-input {background:#f9f9f9 !important;border:1px solid #e0e0e0 !important;}
280
  .gr-output{background:#f0f0f0 !important;border:1px solid #d0d0d0 !important;}
281
- .highlighted-text span{padding:2px 4px;border-radius:4px;margin:1px 2px;display:inline-block;line-height:1.6;}
 
 
 
 
282
  footer{display:none !important}
283
- #live-update-scrollable-box {max-height: 800px; overflow-y: auto !important; display: block;}
284
- #think_btn {background-color: #f3f4f6 !important; border: 1px solid #d0d0d0 !important; color: #111827 !important; font-size: 16px !important; font-weight: bold !important;}
285
- #think_btn:hover {background-color: #e0e0e0 !important; border: 1px solid #c0c0c0 !important; color: #222 !important;}
286
- #think_btn:active {background-color: #2563eb !important; border: 1px solid #b0b0b0 !important; color: white !important;}
287
- .model-badge {padding: 5px 10px; border-radius: 15px; font-weight: bold; margin: 0 5px; display: inline-block;}
288
- .active-model {background-color: #E879F9; color: white;}
289
- .soon-model {background-color: #E5E7EB; color: #6B7280; cursor: not-allowed;}
290
  """
291
 
292
- def toggle_thinking_mode(current_thinking_mode):
293
  new_state = not current_thinking_mode
294
  new_label = "Thinking Mode ✅" if new_state else "Thinking Mode ❌"
295
  return new_state, gr.update(value=new_label)
296
 
297
- color_map_config = {"MASK": "lightgrey", "GEN": "#DCABFA"}
298
 
299
- theme = gr.themes.Ocean(primary_hue="fuchsia")
 
 
 
300
 
 
 
 
301
  with gr.Blocks(css=css_styles, theme=theme) as demo:
302
- thinking_mode_lm = gr.State(True) # Enabled by default for the MixCoT model
303
- thinking_mode_mmu = gr.State(True) # Enabled by default for the MixCoT model
304
-
305
- # --- Title and model info (modified) ---
 
 
 
 
306
  gr.HTML("""
307
  <div align="center" style="margin-bottom: 20px;">
308
  <img src='/gradio_api/file=title.png' width="160">
@@ -310,51 +608,60 @@ with gr.Blocks(css=css_styles, theme=theme) as demo:
310
  MMaDA is a new class of multimodal diffusion foundation models, enabling state-of-the-art performance in reasoning, multimodal understanding, and text-to-image generation.
311
  </p>
312
  <p style="font-size: 15px;">
313
- 📄 <a href="https://arxiv.org/abs/2405.15809" target="_blank">Paper</a>&nbsp;&nbsp;&nbsp;|&nbsp;&nbsp;&nbsp;💻 <a href="https://github.com/Gen-Verse/MMaDA" target="_blank">Code</a>
 
 
314
  </p>
315
  </div>
316
  """)
317
-
318
  with gr.Row():
319
- with gr.Column(scale=1):
320
- gr.HTML("""
321
- <div style="display: flex; justify-content: center; align-items: center; height: 100%;">
322
- <div>
323
- <span class="model-badge active-model">MMaDA-8B-MixCoT</span>
324
- <span class="model-badge soon-model">MMaDA-8B-Max (coming soon)</span>
325
- </div>
326
- </div>
327
- """)
328
- with gr.Column(scale=2):
329
- model_load_status_box = gr.Textbox(
330
- label="Model Load Status", interactive=False, lines=3, max_lines=5
331
- )
332
 
333
- # --- Part 1. Text generation ---
334
  gr.Markdown("## Part 1. Text Generation")
335
  with gr.Row():
336
  with gr.Column(scale=2):
337
  prompt_input_box_lm = gr.Textbox(label="Enter your prompt:", lines=3, value="A rectangular prism has a length of 5 units, a width of 4 units, and a height of 3 units. What is the volume of the prism?")
338
  think_button_lm = gr.Button("Thinking Mode ✅", elem_id="think_btn")
339
  with gr.Accordion("Generation Parameters", open=True):
340
- # ... parameter sliders (unchanged)
341
  with gr.Row():
342
- gen_length_slider_lm = gr.Slider(minimum=8, maximum=1024, value=512, step=64, label="Generation Length")
343
- steps_slider_lm = gr.Slider(minimum=1, maximum=512, value=256, step=32, label="Total Sampling Steps")
344
  with gr.Row():
345
- block_length_slider_lm = gr.Slider(minimum=8, maximum=1024, value=128, step=32, label="Block Length")
346
  remasking_dropdown_lm = gr.Dropdown(choices=['low_confidence', 'random'], value='low_confidence', label="Remasking Strategy")
347
  with gr.Row():
348
- cfg_scale_slider_lm = gr.Slider(minimum=0.0, maximum=2.0, value=0.0, step=0.1, label="CFG Scale")
349
- temperature_slider_lm = gr.Slider(minimum=0.0, maximum=2.0, value=1, step=0.05, label="Temperature")
 
 
350
  with gr.Row():
351
  run_button_ui_lm = gr.Button("Generate Sequence", variant="primary", scale=3)
352
  clear_button_ui_lm = gr.Button("Clear Outputs", scale=1)
 
353
  with gr.Column(scale=3):
354
- output_visualization_box_lm = gr.HighlightedText(label="Live Generation Process", show_legend=True, color_map=color_map_config, combine_adjacent=False, interactive=False, elem_id="live-update-scrollable-box")
355
  output_final_text_box_lm = gr.Textbox(label="Final Output", lines=8, interactive=False, show_copy_button=True)
356
-
357
- # Keep only the MixCoT examples (modified)
 
358
  gr.Examples(
359
  examples=[
360
  ["A rectangular prism has a length of 5 units, a width of 4 units, and a height of 3 units. What is the volume of the prism?", 256, 512, 128, 1, 0, "low_confidence"],
@@ -365,64 +672,98 @@ with gr.Blocks(css=css_styles, theme=theme) as demo:
365
  fn=generate_viz_wrapper_lm,
366
  cache_examples=False
367
  )
368
-
369
- # --- Parts 2 & 3 and event handlers (similar structure, simplified) ---
370
  gr.Markdown("---")
371
  gr.Markdown("## Part 2. Multimodal Understanding")
372
  with gr.Row():
373
- # ... (Part 2 UI structure unchanged)
374
  with gr.Column(scale=2):
375
- prompt_input_box_mmu = gr.Textbox(label="Enter your prompt:", lines=3, value="")
 
 
 
 
376
  think_button_mmu = gr.Button("Thinking Mode ✅", elem_id="think_btn")
377
  with gr.Accordion("Generation Parameters", open=True):
378
- with gr.Row():
379
- gen_length_slider_mmu = gr.Slider(minimum=64, maximum=1024, value=512, step=64, label="Generation Length")
380
- steps_slider_mmu = gr.Slider(minimum=1, maximum=512, value=256, step=32, label="Total Sampling Steps")
381
- with gr.Row():
382
- block_length_slider_mmu = gr.Slider(minimum=32, maximum=1024, value=64, step=32, label="Block Length")
383
  remasking_dropdown_mmu = gr.Dropdown(choices=['low_confidence', 'random'], value='low_confidence', label="Remasking Strategy")
384
- with gr.Row():
385
- cfg_scale_slider_mmu = gr.Slider(minimum=0.0, maximum=2.0, value=0.0, step=0.1, label="CFG Scale")
386
- temperature_slider_mmu = gr.Slider(minimum=0.0, maximum=2.0, value=1, step=0.05, label="Temperature")
 
387
  with gr.Row():
388
  image_upload_box = gr.Image(type="pil", label="Upload Image")
 
389
  with gr.Row():
390
  run_button_ui_mmu = gr.Button("Generate Description", variant="primary", scale=3)
391
  clear_button_ui_mmu = gr.Button("Clear Outputs", scale=1)
 
392
  with gr.Column(scale=3):
393
- output_visualization_box_mmu = gr.HighlightedText(label="Token Sequence (Live Update)", show_legend=True, color_map=color_map_config, combine_adjacent=False, interactive=False, elem_id="live-update-scrollable-box")
394
  output_final_text_box_mmu = gr.Textbox(label="Final Output", lines=8, interactive=False, show_copy_button=True)
395
-
396
- # Keep only the MixCoT MMU examples
397
  gr.Examples(
398
  examples=[
399
  ["figs/geo.png", "In the given figure, a square ABCD is inscribed in a circle with center O. Point P is located on side CD. What is the value of angle APB?", 256, 512, 64, 1, 0, "low_confidence"],
400
  ["figs/bus.jpg", "What are the colors of the bus?", 256, 512, 64, 1, 0, "low_confidence"]
401
  ],
402
- inputs=[image_upload_box, prompt_input_box_mmu, steps_slider_mmu, gen_length_slider_mmu, block_length_slider_mmu, temperature_slider_mmu, cfg_scale_slider_mmu, remasking_dropdown_mmu],
403
  outputs=[output_visualization_box_mmu, output_final_text_box_mmu],
404
  fn=generate_viz_wrapper,
405
  cache_examples=False
406
  )
407
-
408
  gr.Markdown("---")
409
  gr.Markdown("## Part 3. Text-to-Image Generation")
410
- # ... (Part 3 UI and examples unchanged)
411
  with gr.Row():
412
  with gr.Column(scale=2):
413
  prompt_input_box_t2i = gr.Textbox(label="Enter your prompt:", lines=3, value="A sea turtle swimming near a coral reef in the ocean, with a clear blue sky and water in the background.")
 
414
  with gr.Accordion("Generation Parameters", open=True):
415
  with gr.Row():
416
- steps_slider_t2i = gr.Slider(minimum=5, maximum=100, value=15, step=5, label="Total Sampling Steps")
417
- guidance_scale_slider_t2i = gr.Slider(minimum=0.0, maximum=7.0, value=3.5, step=0.5, label="Guidance Scale")
418
- with gr.Row():
419
- scheduler_radio_t2i = gr.Radio(choices=["cosine", "sigmoid", "linear"], value="cosine", label="Scheduler")
420
  with gr.Row():
421
  run_button_ui_t2i = gr.Button("Generate Image", variant="primary", scale=3)
422
  clear_button_ui_t2i = gr.Button("Clear Outputs", scale=1)
 
 
423
  with gr.Column(scale=3):
 
424
  output_image_t2i = gr.Image(label="Generated Image", interactive=False, type="pil")
425
  output_status_t2i = gr.Textbox(label="Generation Status", interactive=False)
 
426
  gr.Examples(
427
  examples=[
428
  ["A sea turtle swimming near a coral reef in the ocean, with a clear blue sky and water in the background.", 15, 3.5, "cosine"],
@@ -433,45 +774,92 @@ with gr.Blocks(css=css_styles, theme=theme) as demo:
433
  fn=generate_viz_wrapper_t2i,
434
  cache_examples=False
435
  )
436
 
437
- # --- App startup and event handling (simplified) ---
438
- def initialize_app_state():
439
- global VQ_MODEL
440
- print("Loading VQ_MODEL for the first time...")
441
- VQ_MODEL = MAGVITv2().from_pretrained("showlab/magvitv2")
442
- print("VQ_MODEL loaded to CPU.")
443
-
444
- status = load_model_and_tokenizer()
445
- # Thinking Mode is enabled by default for the MixCoT model
446
- return status, True, gr.update(value="Thinking Mode ✅"), True, gr.update(value="Thinking Mode ✅")
447
 
448
- demo.load(
449
- fn=initialize_app_state,
450
  inputs=None,
451
- outputs=[
452
- model_load_status_box,
453
- thinking_mode_lm,
454
- think_button_lm,
455
- thinking_mode_mmu,
456
- think_button_mmu
457
  ],
458
- queue=True
459
  )
460
-
461
- # Clear-button events
462
- clear_button_ui_lm.click(fn=lambda: (None, None), inputs=None, outputs=[output_visualization_box_lm, output_final_text_box_lm], queue=False)
463
- clear_button_ui_mmu.click(fn=lambda: (None, None, None), inputs=None, outputs=[image_upload_box, output_visualization_box_mmu, output_final_text_box_mmu], queue=False)
464
- clear_button_ui_t2i.click(fn=lambda: (None, ""), inputs=None, outputs=[output_image_t2i, output_status_t2i], queue=False)
465
-
466
- # Thinking Mode toggle events
467
- think_button_lm.click(fn=toggle_thinking_mode, inputs=[thinking_mode_lm], outputs=[thinking_mode_lm, think_button_lm])
468
- think_button_mmu.click(fn=toggle_thinking_mode, inputs=[thinking_mode_mmu], outputs=[thinking_mode_mmu, think_button_mmu])
469
-
470
- # Generate-button events
471
- run_button_ui_lm.click(fn=generate_viz_wrapper_lm, inputs=[prompt_input_box_lm, steps_slider_lm, gen_length_slider_lm, block_length_slider_lm, temperature_slider_lm, cfg_scale_slider_lm, remasking_dropdown_lm, thinking_mode_lm], outputs=[output_visualization_box_lm, output_final_text_box_lm])
472
- run_button_ui_mmu.click(fn=generate_viz_wrapper, inputs=[image_upload_box, prompt_input_box_mmu, steps_slider_mmu, gen_length_slider_mmu, block_length_slider_mmu, temperature_slider_mmu, cfg_scale_slider_mmu, remasking_dropdown_mmu, thinking_mode_mmu], outputs=[output_visualization_box_mmu, output_final_text_box_mmu])
473
- run_button_ui_t2i.click(fn=generate_viz_wrapper_t2i, inputs=[prompt_input_box_t2i, steps_slider_t2i, guidance_scale_slider_t2i, scheduler_radio_t2i], outputs=[output_image_t2i, output_status_t2i])
 
 
 
474
 
475
  if __name__ == "__main__":
476
  print(f"Starting Gradio App. Attempting to use device: {DEVICE}")
477
- demo.launch(allowed_paths=["title.png", "figs"])
 
10
  import spaces
11
 
12
 
 
13
  def image_transform(image, resolution=256, normalize=True):
14
  image = transforms.Resize(resolution, interpolation=transforms.InterpolationMode.BICUBIC)(image)
15
  image = transforms.CenterCrop((resolution, resolution))(image)
 
19
  return image
20
 
21
  def add_gumbel_noise(logits, temperature):
22
+ """
23
+ Adds Gumbel noise to logits for stochastic sampling.
24
+ Equivalent to argmax(logits + temperature * G) where G ~ Gumbel(0,1).
25
+ This version is more numerically stable than a version involving exp() and division.
26
+ """
27
+ if abs(temperature) < 1e-9: # Effectively zero temperature
28
  return logits
29
+ # Ensure logits are float64 for numerical precision when adding noise
30
  logits = logits.to(torch.float64)
31
+ # Standard Gumbel noise: -log(-log(U)), U ~ Uniform(0,1)
32
+ # Add small epsilon for numerical stability inside logs
33
  noise = torch.rand_like(logits, dtype=torch.float64)
34
  standard_gumbel_noise = -torch.log(-torch.log(noise + 1e-20) + 1e-20)
35
  return logits + temperature * standard_gumbel_noise
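A minimal sketch (not from the diff itself) of what this helper provides: adding temperature-scaled Gumbel noise before an argmax draws a sample from softmax(logits / temperature), while temperature 0 reduces to plain greedy argmax.

```python
# Illustrative sketch only; assumes add_gumbel_noise as defined above and a PyTorch install.
import torch

logits = torch.tensor([[2.0, 1.0, 0.5]])
greedy = torch.argmax(add_gumbel_noise(logits, temperature=0.0), dim=-1)  # deterministic: index 0
sample = torch.argmax(add_gumbel_noise(logits, temperature=1.0), dim=-1)  # stochastic draw
print(greedy.item(), sample.item())
```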
36
 
37
  def get_num_transfer_tokens(mask_index, steps):
38
  mask_num = mask_index.sum(dim=1, keepdim=True)
39
+ # Ensure steps is at least 1 to avoid division by zero if mask_num is also 0 (though sum should be >=0)
40
+ steps = max(1, int(steps)) # Ensure steps is a positive integer
41
  base = mask_num // steps
42
  remainder = mask_num % steps
43
  num_transfer_tokens = torch.zeros(mask_num.size(0), steps, device=mask_index.device, dtype=torch.long) + base
44
+ for i in range(mask_num.size(0)): # Iterate over batch
45
+ if remainder[i] > 0 : # Ensure remainder is positive before indexing
46
+ num_transfer_tokens[i, :remainder[i].item()] += 1 # .item() for single value tensor to int
47
  return num_transfer_tokens
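A quick worked example (not from the diff itself) of the schedule this helper produces: the count of masked tokens is split as evenly as possible across the steps, with any remainder assigned to the earliest steps.

```python
# Illustrative sketch only; assumes get_num_transfer_tokens as defined above.
import torch

mask_index = torch.tensor([[True] * 10 + [False] * 6])  # 10 masked positions, batch size 1
print(get_num_transfer_tokens(mask_index, steps=4))     # tensor([[3, 3, 2, 2]]), sums to 10
```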
48
 
 
49
  DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
50
+ DEFAULT_MODEL_PATH = "Gen-Verse/MMaDA-8B-MixCoT" # Default
51
+ MASK_ID = 126336
52
+ MODEL = MMadaModelLM.from_pretrained(DEFAULT_MODEL_PATH, trust_remote_code=True, torch_dtype=torch.bfloat16).to(DEVICE).eval()
53
+ TOKENIZER = AutoTokenizer.from_pretrained(DEFAULT_MODEL_PATH, trust_remote_code=True)
54
+ uni_prompting = UniversalPrompting(TOKENIZER, max_text_len=512, special_tokens=("<|soi|>", "<|eoi|>", "<|sov|>", "<|eov|>", "<|t2i|>", "<|mmu|>", "<|t2v|>", "<|v2v|>", "<|lvg|>"),ignore_id=-100, cond_dropout_prob=0.1, use_reserved_token=True)
55
+ VQ_MODEL = MAGVITv2().from_pretrained("showlab/magvitv2").to(DEVICE)
56
+
57
+ CURRENT_MODEL_PATH = None
58
 
59
+ MODEL_CHOICES = [
60
+ "MMaDA-8B-Base",
61
+ "MMaDA-8B-MixCoT (coming soon)",
62
+ "MMaDA-8B-Max (coming soon)"
63
+ ]
64
+ MODEL_ACTUAL_PATHS = {
65
+ "MMaDA-8B-Base": DEFAULT_MODEL_PATH,
66
+ }
67
+
68
+ def clear_outputs_action():
69
+ return None, None
70
 
 
71
  @spaces.GPU
72
+ def _load_model_and_tokenizer_core(model_path_to_load, model_display_name_for_status):
73
+ global MODEL, TOKENIZER, MASK_ID, CURRENT_MODEL_PATH, DEVICE, uni_prompting
 
 
 
74
 
75
+ if MODEL is not None and CURRENT_MODEL_PATH == model_path_to_load:
76
+ return f"Model '{model_display_name_for_status}' from '{model_path_to_load}' is already loaded. MASK_ID: {MASK_ID}"
 
77
 
78
+ CURRENT_MODEL_PATH = model_path_to_load
 
 
 
79
 
80
+ status_msg_parts = [f"Loading '{model_display_name_for_status}'..."]
81
+ # try:
82
+ TOKENIZER = AutoTokenizer.from_pretrained(model_path_to_load, trust_remote_code=True)
83
+ status_msg_parts.append(f"Tokenizer for '{model_display_name_for_status}' loaded.")
84
 
85
+ MODEL = MMadaModelLM.from_pretrained(model_path_to_load, trust_remote_code=True, torch_dtype=torch.bfloat16).to(DEVICE).eval()
86
+ status_msg_parts.append(f"Model '{model_display_name_for_status}' loaded to {DEVICE}.")
87
+
88
+ uni_prompting = UniversalPrompting(TOKENIZER, max_text_len=512, special_tokens=("<|soi|>", "<|eoi|>", "<|sov|>", "<|eov|>", "<|t2i|>", "<|mmu|>", "<|t2v|>", "<|v2v|>", "<|lvg|>"),ignore_id=-100, cond_dropout_prob=0.1, use_reserved_token=True)
89
+
90
+ if hasattr(TOKENIZER, 'mask_token_id') and TOKENIZER.mask_token_id is not None:
91
+ MASK_ID = TOKENIZER.mask_token_id
92
+ status_msg_parts.append(f"Using MASK_ID from tokenizer: {MASK_ID}.")
93
+ else:
94
  MASK_ID = 126336
95
  status_msg_parts.append(f"Using default MASK_ID: {MASK_ID}.")
96
 
97
+ if TOKENIZER.pad_token_id is None:
98
+ if TOKENIZER.eos_token_id is not None:
99
+ TOKENIZER.pad_token_id = TOKENIZER.eos_token_id
100
+ TOKENIZER.pad_token = TOKENIZER.eos_token
101
+ status_msg_parts.append(f"Set pad_token_id to eos_token_id ({TOKENIZER.eos_token_id}).")
102
+ else:
103
+ status_msg_parts.append("Warning: pad_token_id is None and no eos_token_id.")
104
+
105
+ if TOKENIZER.eos_token_id is None: # Important for cleaning up output in visualization
106
+ status_msg_parts.append("Warning: tokenizer.eos_token_id is None. EOS cleanup might not work.")
107
+
108
+ TOKENIZER.chat_template = "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{{ '<|start_header_id|>assistant<|end_header_id|>\n' }}"
109
+
110
+ return " ".join(status_msg_parts)
111
+ # except Exception as e:
112
+ # MODEL = None
113
+ # TOKENIZER = None
114
+ # MASK_ID = None
115
+ # CURRENT_MODEL_PATH = None
116
+ # return f"Error loading model '{model_display_name_for_status}': {str(e)}"
117
+
118
+ def handle_model_selection_change(selected_model_name_ui):
119
+ if "coming soon" in selected_model_name_ui.lower():
120
+ global MODEL, TOKENIZER, MASK_ID, CURRENT_MODEL_PATH
121
+ MODEL = None
122
+ TOKENIZER = None
123
+ MASK_ID = None
124
+ CURRENT_MODEL_PATH = None
125
+ return f"'{selected_model_name_ui}' is not yet available. Please select 'Model A'."
126
+
127
+ actual_path = MODEL_ACTUAL_PATHS.get(selected_model_name_ui)
128
+ if not actual_path:
129
+ return f"Path for '{selected_model_name_ui}' is not defined. Cannot load."
130
+
131
+ return _load_model_and_tokenizer_core(actual_path, selected_model_name_ui)
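A hedged sketch (not from the diff itself) of how `handle_model_selection_change` could be attached to a model selector; the `gr.Dropdown` and status `gr.Textbox` shown here are assumptions, since the actual selector component is defined elsewhere in the UI.

```python
# Hypothetical wiring, assuming a Blocks layout with a dropdown over MODEL_CHOICES
# and a status textbox; the handler returns a human-readable load-status string.
import gradio as gr

with gr.Blocks() as selector_demo:
    model_dropdown = gr.Dropdown(choices=MODEL_CHOICES, value=MODEL_CHOICES[0], label="Model")
    load_status_box = gr.Textbox(label="Model Load Status", interactive=False)
    model_dropdown.change(fn=handle_model_selection_change,
                          inputs=[model_dropdown], outputs=[load_status_box])
```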
132
+
133
 
 
134
  def get_highlighted_text_tuples(current_x_ids_batch, prompt_input_ids, prompt_len, tk, current_mask_id, raw_prompt_attention_mask):
135
  if current_x_ids_batch is None or current_x_ids_batch.ndim == 0 or current_x_ids_batch.shape[0] == 0:
136
  return [("Error in sequence data for visualization.", "ERROR")]
137
+ # only answer part
138
  current_x_ids_batch = current_x_ids_batch[:, prompt_len:]
139
  seq_ids = current_x_ids_batch[0].tolist()
140
+ eos_token_id = tk.eos_token_id # Get EOS token ID
141
+
142
+ # Stage 1: Build initial list of tuples with (token_str, label, token_id_int)
143
+ # This helps in identifying EOS tokens later without re-checking the type.
144
  intermediate_tuples = []
145
  for j, token_id_int in enumerate(seq_ids):
146
  try:
147
  token_str = tk.decode([token_id_int], skip_special_tokens=True, clean_up_tokenization_spaces=False)
148
+ except Exception: # Handle cases where a token ID might be problematic (e.g. with mock)
149
  token_str = f"[ID:{token_id_int}]"
150
 
151
  label = "ERROR"
 
155
  else:
156
  label = "GEN"
157
  intermediate_tuples.append((token_str, label, token_id_int))
158
+
159
  return intermediate_tuples
160
 
161
  @torch.no_grad()
162
  @spaces.GPU
163
  def generate_viz_wrapper_t2i(prompt_text, steps, guidance_scale, mask_schedule="cosine"):
164
+ global MODEL, TOKENIZER, MASK_ID, DEVICE, uni_prompting
165
+
166
  if MODEL is None or TOKENIZER is None or MASK_ID is None:
167
+ yield [("Error: Model not loaded. Please load the model first.", "ERROR")], "Model not loaded."
168
  return
169
+ steps = int(steps)
170
+ guidance_scale = float(guidance_scale)
171
+
172
+ image_tokens = torch.ones((1, 1024), dtype=torch.long, device=DEVICE) * MASK_ID
173
+ prompt_text = [prompt_text]
174
+ input_ids, attention_mask = uni_prompting((prompt_text, image_tokens), 't2i_gen')
175
+
176
+ if guidance_scale > 0:
177
+ uncond_input_ids, uncond_attention_mask = uni_prompting(([''], image_tokens), 't2i_gen')
178
+ else:
179
+ uncond_input_ids, uncond_attention_mask = None, None
180
+
181
+ mask_schedule = get_mask_schedule(mask_schedule)
182
+ blank_image = Image.new("RGB", (512, 512), (255, 255, 255))
183
+ yield blank_image, "Starting generation..."
184
+ for image_step, status_msg_step in MODEL.t2i_generate_decoding_stepwise(
185
+ input_ids = input_ids,
186
+ uncond_input_ids = uncond_input_ids,
187
+ attention_mask = attention_mask,
188
+ uncond_attention_mask = uncond_attention_mask,
189
+ temperature=1.0,
190
+ timesteps = steps,
191
+ guidance_scale = guidance_scale,
192
+ noise_schedule = mask_schedule,
193
+ noise_type = "mask",
194
+ seq_len = 1024,
195
+ vq_model = VQ_MODEL,
196
+ uni_prompting=uni_prompting):
197
+ yield image_step, status_msg_step
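For context, a minimal sketch (not from the diff itself) of consuming this streaming generator outside of Gradio: each yield is an intermediate PIL image plus a status string, which is what lets the UI update live.

```python
# Illustrative sketch only; assumes the model, tokenizer and VQ model above are loaded.
last_image = None
for img, status in generate_viz_wrapper_t2i(
        "A sea turtle swimming near a coral reef.", steps=15, guidance_scale=3.5):
    print(status)
    last_image = img
if last_image is not None:
    last_image.save("t2i_preview.png")  # hypothetical output path
```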
198
+
199
+
200
+
201
 
202
  @torch.no_grad()
203
  @spaces.GPU
204
  def generate_viz_wrapper_lm(prompt_text, steps, gen_length, block_length, temperature,
205
+ cfg_scale, remasking_strategy, thinking_mode_lm=False):
206
  global MODEL, TOKENIZER, MASK_ID, DEVICE
207
  if MODEL is None or TOKENIZER is None or MASK_ID is None:
208
  yield [("Error: Model not loaded. Please load the model first.", "ERROR")], "Model not loaded."
209
  return
210
+
211
+ steps = int(steps)
212
+ gen_length = int(gen_length)
213
+ block_length = int(block_length)
214
+
215
+ if thinking_mode_lm:
216
+ prompt_text = "You should first think about the reasoning process in the mind and then provide the user with the answer. The reasoning process is enclosed within <think> </think> tags, i.e. <think> reasoning process here </think> answer here\n" + prompt_text
217
+
218
  try:
  m = [{"role": "user", "content": prompt_text}]
220
  processed_prompt_text = TOKENIZER.apply_chat_template(m, add_generation_prompt=True, tokenize=False)
221
+ except Exception as e:
222
+ yield [("Error applying chat template.", "ERROR")], f"Chat template error: {e}"
223
+ processed_prompt_text = prompt_text
224
+ try:
225
+ if TOKENIZER.pad_token_id is None:
226
+ if TOKENIZER.eos_token_id is not None:
227
+ TOKENIZER.pad_token_id = TOKENIZER.eos_token_id
228
+ else: # Should have been caught by load_model, but double check
229
+ yield [("Tokenizer Error", "ERROR")], "pad_token_id is not set in tokenizer."
230
+ return
231
+
232
+ input_ids = TOKENIZER(text=processed_prompt_text, return_tensors="pt", padding="longest", padding_side="left", truncation=True, max_length=MODEL.config.max_position_embeddings if hasattr(MODEL.config, 'max_position_embeddings') else 2048)['input_ids'].to(DEVICE)
233
+ raw_prompt_attention_mask = None
234
+
235
+ except Exception as e:
236
+ yield [("Error tokenizing prompt.", "ERROR")], f"Tokenization error: {e}"
237
+ return
238
+
239
+
240
+
241
+ batch_size = input_ids.shape[0]
242
+ prompt_len = input_ids.shape[1]
243
+
244
+ x = torch.full((batch_size, prompt_len + gen_length), MASK_ID, dtype=torch.long, device=DEVICE)
245
+ x[:, :prompt_len] = input_ids.clone()
246
+
247
+ yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), "Starting generation: Prompt + Initial Masks"
248
+
249
+ if gen_length == 0:
250
+ final_text_output = TOKENIZER.batch_decode(x[:,prompt_len:], skip_special_tokens=True)
251
+ yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), final_text_output[0] if final_text_output else ""
252
+ return
253
+
254
+ if block_length <= 0 or gen_length % block_length != 0 :
255
+ yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), \
256
+ f"Error: gen_length ({gen_length}) must be divisible by block_length ({block_length}) and block_length > 0."
257
+ return
258
+ num_blocks = gen_length // block_length
259
+
260
+ if steps <=0 or steps % num_blocks != 0:
261
+ yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), \
262
+ f"Error: steps ({steps}) must be positive and divisible by num_blocks ({num_blocks}). Steps: {steps}, Num Blocks: {num_blocks}"
263
+ return
264
+ steps_per_block = steps // num_blocks
265
+
266
+ for num_block_iter in range(num_blocks):
267
+ current_block_start_idx_in_x = prompt_len + num_block_iter * block_length
268
+ current_block_end_idx_in_x = prompt_len + (num_block_iter + 1) * block_length
269
+
270
+ block_masks_bool_current = torch.zeros_like(x, dtype=torch.bool)
271
+ block_masks_bool_current[:, current_block_start_idx_in_x:current_block_end_idx_in_x] = \
272
+ (x[:, current_block_start_idx_in_x:current_block_end_idx_in_x] == MASK_ID)
273
+
274
+ num_transfer_tokens_for_this_block = get_num_transfer_tokens(
275
+ block_masks_bool_current[:, current_block_start_idx_in_x:current_block_end_idx_in_x],
276
+ steps_per_block
277
+ )
278
+
279
+ for i_step_in_block in range(steps_per_block):
280
+ mask_index_global = (x == MASK_ID)
281
+
282
+ if cfg_scale > 0.:
283
+ un_x = x.clone()
284
+ # For unconditional pass, mask out the original prompt tokens that are not padding
285
+ # raw_prompt_attention_mask is (B, prompt_len)
286
+ prompt_active_tokens_mask = raw_prompt_attention_mask.bool() # True where actual prompt tokens are
287
+ un_x[:, :prompt_len][prompt_active_tokens_mask] = MASK_ID
288
+
289
+ x_cfg_input = torch.cat([x, un_x], dim=0)
290
+ # Pass attention_mask for CFG if model expects it, covering both parts
291
+ # For simplicity, not passing explicit attention_mask here; relies on model's internal handling.
292
+ model_output = MODEL(x_cfg_input)
293
+ logits_cond, logits_uncond = torch.chunk(model_output.logits, 2, dim=0)
294
+ logits = logits_uncond + (cfg_scale + 1) * (logits_cond - logits_uncond)
295
+ else:
296
+ # Not passing explicit attention_mask here; relies on model's internal handling.
297
+ model_output = MODEL(x)
298
  logits = model_output.logits
299
 
300
+ logits_with_noise = add_gumbel_noise(logits, temperature=temperature)
301
+ x0_predicted_tokens = torch.argmax(logits_with_noise, dim=-1)
302
+
303
+ if remasking_strategy == 'low_confidence':
304
+ probs = F.softmax(logits.to(torch.float64), dim=-1)
305
+ x0_probs = torch.gather(probs, dim=-1, index=x0_predicted_tokens.unsqueeze(-1)).squeeze(-1)
306
+ elif remasking_strategy == 'random':
307
+ x0_probs = torch.rand(x.shape, device=x.device, dtype=torch.float64)
308
+ else:
309
+ yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), f"Error: Unknown remasking strategy '{remasking_strategy}'"
310
+ return
311
+
312
+ confidence_for_selection = torch.full_like(x0_probs, -torch.inf)
313
+ candidate_positions_for_unmasking = mask_index_global & block_masks_bool_current
314
+ confidence_for_selection = torch.where(
315
+ candidate_positions_for_unmasking,
316
+ x0_probs,
317
+ -torch.inf
318
+ )
319
+
320
+ x0_final_candidates = torch.where(mask_index_global, x0_predicted_tokens, x)
321
+
322
+ transfer_indices_bool = torch.zeros_like(x, dtype=torch.bool)
323
+ num_to_transfer_this_step_batch = num_transfer_tokens_for_this_block[:, i_step_in_block]
324
+
325
+ for j_batch_idx in range(batch_size):
326
+ k_val = min(num_to_transfer_this_step_batch[j_batch_idx].item(),
327
+ candidate_positions_for_unmasking[j_batch_idx].sum().item()) # ensure k isn't too large
328
+
329
+ if k_val > 0:
330
+ # Ensure confidence_for_selection[j_batch_idx] is 1D for topk
331
+ conf_slice = confidence_for_selection[j_batch_idx]
332
+ if conf_slice.ndim > 1: conf_slice = conf_slice.view(-1) # Should already be 1D from x0_probs
333
+
334
+ # Check if there are enough valid (non -inf) confidences
335
+ valid_conf_count = (conf_slice > -torch.inf).sum().item()
336
+ actual_k = min(k_val, valid_conf_count)
337
+
338
+ if actual_k > 0:
339
+ _, topk_indices_in_x = torch.topk(conf_slice, k=actual_k)
340
+ transfer_indices_bool[j_batch_idx, topk_indices_in_x] = True
341
+
342
+ x[transfer_indices_bool] = x0_final_candidates[transfer_indices_bool]
343
+
344
+ current_total_step = num_block_iter * steps_per_block + i_step_in_block + 1
345
+ total_overall_steps = num_blocks * steps_per_block
346
+ status_msg = f"Block {num_block_iter+1}/{num_blocks}, Step {i_step_in_block+1}/{steps_per_block} (Total: {current_total_step}/{total_overall_steps})"
347
+ yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), status_msg
348
+
349
+ final_generated_ids = x[:, prompt_len:]
350
+ final_text_output = TOKENIZER.batch_decode(final_generated_ids, skip_special_tokens=True)
351
+
352
+ final_text_str = final_text_output[0] if final_text_output and len(final_text_output) > 0 else ""
353
+ yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), final_text_str
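A small sketch (not from the diff itself) isolating the classifier-free-guidance mixing used in the loop above: conditional and unconditional logits are combined as `uncond + (cfg_scale + 1) * (cond - uncond)`, which reduces to the conditional logits when `cfg_scale` is 0.

```python
# Illustrative sketch only; mirrors the in-loop formula used above.
import torch

def mix_cfg_logits(logits_cond: torch.Tensor, logits_uncond: torch.Tensor, cfg_scale: float) -> torch.Tensor:
    return logits_uncond + (cfg_scale + 1) * (logits_cond - logits_uncond)

cond, uncond = torch.randn(1, 4, 8), torch.randn(1, 4, 8)
assert torch.allclose(mix_cfg_logits(cond, uncond, 0.0), cond)  # scale 0 keeps the conditional logits
```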
354
 
355
  @torch.no_grad()
356
  @spaces.GPU
357
  def generate_viz_wrapper(uploaded_image_pil, prompt_text, steps, gen_length, block_length, temperature,
358
  cfg_scale, remasking_strategy, thinking_mode_mmu=False):
359
+ global MODEL, TOKENIZER, MASK_ID, DEVICE
360
+
361
  if MODEL is None or TOKENIZER is None or MASK_ID is None:
362
  yield [("Error: Model not loaded. Please load the model first.", "ERROR")], "Model not loaded."
363
  return
364
+
365
+ steps = int(steps)
366
+ gen_length = int(gen_length)
367
+ block_length = int(block_length)
368
+
369
+ if thinking_mode_mmu:
370
+ prompt_text = "You should first think about the reasoning process in the mind and then provide the user with the answer. The reasoning process is enclosed within <think> </think> tags, i.e. <think> reasoning process here </think> answer here\n" + prompt_text
371
+
372
  try:
 
 
 
 
373
  m = [{"role": "user", "content": prompt_text}]
374
  processed_prompt_text = TOKENIZER.apply_chat_template(m, add_generation_prompt=True, tokenize=False)
375
+ except Exception as e:
376
+ yield [("Error applying chat template.", "ERROR")], f"Chat template error: {e}"
377
+ processed_prompt_text = prompt_text
378
+
379
+ image_vq_ids_tensor = None
380
+ if uploaded_image_pil is not None:
381
+ try:
382
+
383
+ image = image_transform(uploaded_image_pil, resolution=512).to(DEVICE)
384
+ image = image.unsqueeze(0)
385
+ image_vq_ids_tensor = VQ_MODEL.get_code(image) + 126349
386
+ except Exception as e:
387
+ yield [("Error processing image.", "ERROR")], f"Image to VQ tokens conversion failed: {str(e)}"
388
+ return
389
+
390
+
391
+ try:
392
+ if TOKENIZER.pad_token_id is None:
393
+ if TOKENIZER.eos_token_id is not None:
394
+ TOKENIZER.pad_token_id = TOKENIZER.eos_token_id
395
+ else:
396
+ yield [("Tokenizer Error", "ERROR")], "pad_token_id is not set in tokenizer."
397
+ return
398
+
399
+ input_ids = TOKENIZER(text=processed_prompt_text, return_tensors="pt", padding="longest", padding_side="left", truncation=True, max_length=MODEL.config.max_position_embeddings if hasattr(MODEL.config, 'max_position_embeddings') else 2048)['input_ids'].to(DEVICE)
400
+ raw_prompt_attention_mask = None
401
  if image_vq_ids_tensor is not None:
402
+ if image_vq_ids_tensor.ndim == 1:
403
+ image_vq_ids_tensor = image_vq_ids_tensor.unsqueeze(0)
404
+
405
+ input_ids = torch.cat([
406
+ (torch.ones(input_ids.shape[0], 1) * torch.tensor([126089])).to(DEVICE),
407
+ (torch.ones(input_ids.shape[0], 1) * torch.tensor([126084])).to(DEVICE),
408
+ image_vq_ids_tensor,
409
+ (torch.ones(input_ids.shape[0], 1) * torch.tensor([126085])).to(DEVICE),
410
+ input_ids
411
+ ], dim=1).long()
412
+
413
+ else:
414
+ input_ids = input_ids
415
+
416
+
417
+ except Exception as e:
418
+ yield [("Error tokenizing prompt.", "ERROR")], f"Tokenization error: {e}"
419
+ return
420
+
421
+
422
+
423
+ batch_size = input_ids.shape[0]
424
+ prompt_len = input_ids.shape[1]
425
+
426
+ x = torch.full((batch_size, prompt_len + gen_length), MASK_ID, dtype=torch.long, device=DEVICE)
427
+ x[:, :prompt_len] = input_ids.clone()
428
+
429
+ yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), "Starting generation: Prompt + Initial Masks"
430
+
431
+ if gen_length == 0:
432
+ final_text_output = TOKENIZER.batch_decode(x[:,prompt_len:], skip_special_tokens=True)
433
+ yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), final_text_output[0] if final_text_output else ""
434
+ return
435
+
436
+ if block_length <= 0 or gen_length % block_length != 0 :
437
+ yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), \
438
+ f"Error: gen_length ({gen_length}) must be divisible by block_length ({block_length}) and block_length > 0."
439
+ return
440
+ num_blocks = gen_length // block_length
441
+
442
+ if steps <=0 or steps % num_blocks != 0:
443
+ yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), \
444
+ f"Error: steps ({steps}) must be positive and divisible by num_blocks ({num_blocks}). Steps: {steps}, Num Blocks: {num_blocks}"
445
+ return
446
+ steps_per_block = steps // num_blocks
447
+
448
+ for num_block_iter in range(num_blocks):
449
+ current_block_start_idx_in_x = prompt_len + num_block_iter * block_length
450
+ current_block_end_idx_in_x = prompt_len + (num_block_iter + 1) * block_length
451
+
452
+ block_masks_bool_current = torch.zeros_like(x, dtype=torch.bool)
453
+ block_masks_bool_current[:, current_block_start_idx_in_x:current_block_end_idx_in_x] = \
454
+ (x[:, current_block_start_idx_in_x:current_block_end_idx_in_x] == MASK_ID)
455
+
456
+ num_transfer_tokens_for_this_block = get_num_transfer_tokens(
457
+ block_masks_bool_current[:, current_block_start_idx_in_x:current_block_end_idx_in_x],
458
+ steps_per_block
459
+ )
460
+
461
+ for i_step_in_block in range(steps_per_block):
462
+ mask_index_global = (x == MASK_ID)
463
+
464
+ if cfg_scale > 0.:
465
+ un_x = x.clone()
466
+ # For unconditional pass, mask out the original prompt tokens that are not padding
467
+ # raw_prompt_attention_mask is (B, prompt_len)
468
+ prompt_active_tokens_mask = raw_prompt_attention_mask.bool() # True where actual prompt tokens are
469
+ un_x[:, :prompt_len][prompt_active_tokens_mask] = MASK_ID
470
+
471
+ x_cfg_input = torch.cat([x, un_x], dim=0)
472
+ # Pass attention_mask for CFG if model expects it, covering both parts
473
+ # For simplicity, not passing explicit attention_mask here; relies on model's internal handling.
474
+ model_output = MODEL(x_cfg_input)
475
+ logits_cond, logits_uncond = torch.chunk(model_output.logits, 2, dim=0)
476
+ logits = logits_uncond + (cfg_scale + 1) * (logits_cond - logits_uncond)
477
+ else:
478
+ # Not passing explicit attention_mask here; relies on model's internal handling.
479
+ model_output = MODEL(x)
480
+ logits = model_output.logits
481
+
482
+ logits_with_noise = add_gumbel_noise(logits, temperature=temperature)
483
+ x0_predicted_tokens = torch.argmax(logits_with_noise, dim=-1)
484
+
485
+ if remasking_strategy == 'low_confidence':
486
+ probs = F.softmax(logits.to(torch.float64), dim=-1)
487
+ x0_probs = torch.gather(probs, dim=-1, index=x0_predicted_tokens.unsqueeze(-1)).squeeze(-1)
488
+ elif remasking_strategy == 'random':
489
+ x0_probs = torch.rand(x.shape, device=x.device, dtype=torch.float64)
490
+ else:
491
+ yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), f"Error: Unknown remasking strategy '{remasking_strategy}'"
492
+ return
493
+
494
+ confidence_for_selection = torch.full_like(x0_probs, -torch.inf)
495
+ candidate_positions_for_unmasking = mask_index_global & block_masks_bool_current
496
+ confidence_for_selection = torch.where(
497
+ candidate_positions_for_unmasking,
498
+ x0_probs,
499
+ -torch.inf
500
+ )
501
+
502
+ x0_final_candidates = torch.where(mask_index_global, x0_predicted_tokens, x)
503
+
504
+ transfer_indices_bool = torch.zeros_like(x, dtype=torch.bool)
505
+ num_to_transfer_this_step_batch = num_transfer_tokens_for_this_block[:, i_step_in_block]
506
+
507
+ for j_batch_idx in range(batch_size):
508
+ k_val = min(num_to_transfer_this_step_batch[j_batch_idx].item(),
509
+ candidate_positions_for_unmasking[j_batch_idx].sum().item()) # ensure k isn't too large
510
+
511
+ if k_val > 0:
512
+ # Ensure confidence_for_selection[j_batch_idx] is 1D for topk
513
+ conf_slice = confidence_for_selection[j_batch_idx]
514
+ if conf_slice.ndim > 1: conf_slice = conf_slice.view(-1) # Should already be 1D from x0_probs
515
+
516
+ # Check if there are enough valid (non -inf) confidences
517
+ valid_conf_count = (conf_slice > -torch.inf).sum().item()
518
+ actual_k = min(k_val, valid_conf_count)
519
+
520
+ if actual_k > 0:
521
+ _, topk_indices_in_x = torch.topk(conf_slice, k=actual_k)
522
+ transfer_indices_bool[j_batch_idx, topk_indices_in_x] = True
523
+
524
+ x[transfer_indices_bool] = x0_final_candidates[transfer_indices_bool]
525
+
526
+ current_total_step = num_block_iter * steps_per_block + i_step_in_block + 1
527
+ total_overall_steps = num_blocks * steps_per_block
528
+ status_msg = f"Block {num_block_iter+1}/{num_blocks}, Step {i_step_in_block+1}/{steps_per_block} (Total: {current_total_step}/{total_overall_steps})"
529
+ yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), status_msg
530
+
531
+ final_generated_ids = x[:, prompt_len:]
532
+ final_text_output = TOKENIZER.batch_decode(final_generated_ids, skip_special_tokens=True)
533
+
534
+ final_text_str = final_text_output[0] if final_text_output and len(final_text_output) > 0 else ""
535
+ yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), final_text_str
536
+
537
+
538
  css_styles = """
539
  .gradio-container{font-family:'IBM Plex Sans',sans-serif;margin:auto;}
540
  .gr-input {background:#f9f9f9 !important;border:1px solid #e0e0e0 !important;}
541
  .gr-output{background:#f0f0f0 !important;border:1px solid #d0d0d0 !important;}
542
+
543
+ .highlighted-text span{
544
+ padding:2px 4px;border-radius:4px;margin:1px 2px;display:inline-block;line-height:1.6;
545
+ }
546
+
547
  footer{display:none !important}
548
+
549
+ #live-update-scrollable-box {
550
+ max-height: 800px; /* Adjust this maximum height as needed, e.g. '300px' or '50vh' */
551
+ overflow-y: auto !important; /* Show a vertical scrollbar when content exceeds max-height */
552
+ display: block; /* Ensure the element is block-level so max-height takes effect */
553
+
554
+ }
555
+ #think_btn {
556
+ background-color: #f3f4f6 !important;
557
+ border: 1px solid #d0d0d0 !important;
558
+ color: #111827 !important;
559
+ font-size: 16px !important;
560
+ font-weight: bold !important;
561
+ }
562
+ #think_btn:hover {
563
+ background-color: #e0e0e0 !important;
564
+ border: 1px solid #c0c0c0 !important;
565
+ color: #222 !important;
566
+ }
567
+ #think_btn:active {
568
+ background-color: #2563eb !important;
569
+ border: 1px solid #b0b0b0 !important;
570
+ color: white !important;
571
+ }
572
  """
573
 
574
+
575
+ # thinking_mode_t2i = gr.State(False)
576
+ def toggle_thinking_mode_lm(current_thinking_mode):
577
+ new_state = not current_thinking_mode
578
+ new_label = "Thinking Mode ✅" if new_state else "Thinking Mode ❌"
579
+ return new_state, gr.update(value=new_label)
580
+
581
+ def toggle_thinking_mode_mmu(current_thinking_mode):
582
  new_state = not current_thinking_mode
583
  new_label = "Thinking Mode ✅" if new_state else "Thinking Mode ❌"
584
  return new_state, gr.update(value=new_label)
585
 
 
586
 
587
+ color_map_config = {
+ "MASK": "lightgrey",
+ "GEN": "#DCABFA",
+ }

+ theme = gr.themes.Ocean(
+ primary_hue="fuchsia",
+ )
  with gr.Blocks(css=css_styles, theme=theme) as demo:
+ # with gr.Blocks(css=css_styles, theme=gr.themes.Soft(primary_hue=gr.themes.colors.blue, secondary_hue=gr.themes.colors.sky)) as demo:
+ # with gr.Blocks() as demo:
+ thinking_mode_lm = gr.State(True)
+ thinking_mode_mmu = gr.State(True)
+ # gr.Markdown("<h1 style='text-align: center; margin-bottom: 20px;'>MMaDA: Multimodal Large Diffusion Language Models</h1>")
+ # gr.Markdown("MMaDA is a novel class of multimodal diffusion foundation models designed to achieve superior performance across diverse domains such as textual reasoning, multimodal understanding, and text-to-image generation")
+ # gr.Markdown("Github: [Gen-Verse/MMaDA](https://github.com/Gen-Verse/MMaDA)")
+ # gr.Markdown("Paper: [MMaDA: Multimodal Large Diffusion Language Models]()")
  gr.HTML("""
  <div align="center" style="margin-bottom: 20px;">
  <img src='/gradio_api/file=title.png' width="160">

  MMaDA is a new class of multimodal diffusion foundation models, enabling state-of-the-art performance in reasoning, multimodal understanding, and text-to-image generation.
  </p>
  <p style="font-size: 15px;">
+ 📄 <a href="https://arxiv.org/abs/2505.15809" target="_blank">Paper</a>
+ &nbsp;&nbsp;&nbsp;|&nbsp;&nbsp;&nbsp;
+ 💻 <a href="https://github.com/Gen-Verse/MMaDA" target="_blank">Code</a>
  </p>
  </div>
  """)

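+ # Static model banner: MMaDA-8B-MixCoT is the only checkpoint served by this demo; MMaDA-8B-Max is shown as a disabled placeholder.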
  with gr.Row():
+ gr.HTML("""
+ <div style="display: flex; justify-content: center; align-items: center; padding: 15px;">
+ <span style="padding: 8px 15px; border-radius: 15px; font-weight: bold; margin: 0 10px; background-color: #E879F9; color: white;">
+ MMaDA-8B-MixCoT (Active)
+ </span>
+ <span style="padding: 8px 15px; border-radius: 15px; font-weight: bold; margin: 0 10px; background-color: #E5E7EB; color: #6B7280; cursor: not-allowed;">
+ MMaDA-8B-Max (coming soon)
+ </span>
+ </div>
+ """)

  gr.Markdown("## Part 1. Text Generation")
  with gr.Row():
  with gr.Column(scale=2):
  prompt_input_box_lm = gr.Textbox(label="Enter your prompt:", lines=3, value="A rectangular prism has a length of 5 units, a width of 4 units, and a height of 3 units. What is the volume of the prism?")
  think_button_lm = gr.Button("Thinking Mode ✅", elem_id="think_btn")
  with gr.Accordion("Generation Parameters", open=True):
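+ # Block-diffusion sampling controls: gen_length must be divisible by block_length, and steps must be divisible by the resulting number of blocks.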
  with gr.Row():
+ gen_length_slider_lm = gr.Slider(minimum=8, maximum=1024, value=512, step=64, label="Generation Length", info="Number of tokens to generate.")
+ steps_slider_lm = gr.Slider(minimum=1, maximum=512, value=256, step=32, label="Total Sampling Steps", info="Must be divisible by (gen_length / block_length).")
  with gr.Row():
+ block_length_slider_lm = gr.Slider(minimum=8, maximum=1024, value=128, step=32, label="Block Length", info="gen_length must be divisible by this.")
  remasking_dropdown_lm = gr.Dropdown(choices=['low_confidence', 'random'], value='low_confidence', label="Remasking Strategy")
  with gr.Row():
+ cfg_scale_slider_lm = gr.Slider(minimum=0.0, maximum=2.0, value=0.0, step=0.1, label="CFG Scale", info="Classifier-Free Guidance. 0 disables it.")
+ temperature_slider_lm = gr.Slider(minimum=0.0, maximum=2.0, value=1.0, step=0.05, label="Temperature", info="Controls randomness via Gumbel noise. 0 is deterministic.")
+
  with gr.Row():
  run_button_ui_lm = gr.Button("Generate Sequence", variant="primary", scale=3)
  clear_button_ui_lm = gr.Button("Clear Outputs", scale=1)
+
  with gr.Column(scale=3):
+ # gr.Markdown("## Live Generation Process")
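+ # Live view: every position starts as MASK (grey) and is recolored as GEN (purple) once the sampler fills it in.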
+ output_visualization_box_lm = gr.HighlightedText(
+ label="Live Generation Process",
+ show_legend=True,
+ color_map=color_map_config,
+ combine_adjacent=False,
+ interactive=False,
+ elem_id="live-update-scrollable-box",
+ )
+ # gr.Markdown("## Final Generated Text")
  output_final_text_box_lm = gr.Textbox(label="Final Output", lines=8, interactive=False, show_copy_button=True)
+
  gr.Examples(
  examples=[
  ["A rectangular prism has a length of 5 units, a width of 4 units, and a height of 3 units. What is the volume of the prism?", 256, 512, 128, 1, 0, "low_confidence"],

  fn=generate_viz_wrapper_lm,
  cache_examples=False
  )
+
  gr.Markdown("---")
  gr.Markdown("## Part 2. Multimodal Understanding")
  with gr.Row():
  with gr.Column(scale=2):
+ prompt_input_box_mmu = gr.Textbox(
+ label="Enter your prompt:",
+ lines=3,
+ value="Please describe this image in detail."
+ )
  think_button_mmu = gr.Button("Thinking Mode ✅", elem_id="think_btn")
  with gr.Accordion("Generation Parameters", open=True):
+ with gr.Row():
+ gen_length_slider_mmu = gr.Slider(minimum=64, maximum=1024, value=512, step=64, label="Generation Length", info="Number of tokens to generate.")
+ steps_slider_mmu = gr.Slider(minimum=1, maximum=512, value=256, step=32, label="Total Sampling Steps", info="Must be divisible by (gen_length / block_length).")
+ with gr.Row():
+ block_length_slider_mmu = gr.Slider(minimum=32, maximum=1024, value=128, step=32, label="Block Length", info="gen_length must be divisible by this.")
  remasking_dropdown_mmu = gr.Dropdown(choices=['low_confidence', 'random'], value='low_confidence', label="Remasking Strategy")
+ with gr.Row():
+ cfg_scale_slider_mmu = gr.Slider(minimum=0.0, maximum=2.0, value=0.0, step=0.1, label="CFG Scale", info="Classifier-Free Guidance. 0 disables it.")
+ temperature_slider_mmu = gr.Slider(minimum=0.0, maximum=2.0, value=1.0, step=0.05, label="Temperature", info="Controls randomness via Gumbel noise. 0 is deterministic.")
+
  with gr.Row():
  image_upload_box = gr.Image(type="pil", label="Upload Image")
+
  with gr.Row():
  run_button_ui_mmu = gr.Button("Generate Description", variant="primary", scale=3)
  clear_button_ui_mmu = gr.Button("Clear Outputs", scale=1)
+
  with gr.Column(scale=3):
+ gr.Markdown("## Live Generation Process")
+ output_visualization_box_mmu = gr.HighlightedText(
+ label="Token Sequence (Live Update)",
+ show_legend=True,
+ color_map=color_map_config,
+ combine_adjacent=False,
+ interactive=False,
+ elem_id="live-update-scrollable-box",
+ )
+ gr.Markdown("## Final Generated Text")
  output_final_text_box_mmu = gr.Textbox(label="Final Output", lines=8, interactive=False, show_copy_button=True)
+
  gr.Examples(
  examples=[
  ["figs/geo.png", "In the given figure, a square ABCD is inscribed in a circle with center O. Point P is located on side CD. What is the value of angle APB?", 256, 512, 64, 1, 0, "low_confidence"],
  ["figs/bus.jpg", "What are the colors of the bus?", 256, 512, 64, 1, 0, "low_confidence"]
  ],
+ inputs=[
+ image_upload_box,
+ prompt_input_box_mmu,
+ steps_slider_mmu,
+ gen_length_slider_mmu,
+ block_length_slider_mmu,
+ temperature_slider_mmu,
+ cfg_scale_slider_mmu,
+ remasking_dropdown_mmu
+ ],
  outputs=[output_visualization_box_mmu, output_final_text_box_mmu],
  fn=generate_viz_wrapper,
  cache_examples=False
  )
+
  gr.Markdown("---")
  gr.Markdown("## Part 3. Text-to-Image Generation")

  with gr.Row():
  with gr.Column(scale=2):
  prompt_input_box_t2i = gr.Textbox(label="Enter your prompt:", lines=3, value="A sea turtle swimming near a coral reef in the ocean, with a clear blue sky and water in the background.")
+
  with gr.Accordion("Generation Parameters", open=True):
  with gr.Row():
+ steps_slider_t2i = gr.Slider(minimum=5, maximum=100, value=15, step=5, label="Total Sampling Steps", info="Number of sampling steps used to generate the image.")
+ guidance_scale_slider_t2i = gr.Slider(minimum=0.0, maximum=7.0, value=3.5, step=0.5, label="Guidance Scale", info="Classifier-Free Guidance. 0 disables it.")
+
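+ # Scheduler: mask-ratio schedule applied across the sampling steps (cosine is the default).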
+ with gr.Row():
+ scheduler_radio_t2i = gr.Radio(
+ choices=["cosine", "sigmoid", "linear"],
+ value="cosine",
+ label="Scheduler",
+ )
+
  with gr.Row():
  run_button_ui_t2i = gr.Button("Generate Image", variant="primary", scale=3)
  clear_button_ui_t2i = gr.Button("Clear Outputs", scale=1)
+
  with gr.Column(scale=3):
+ # gr.Markdown("## Live Generation Process")
  output_image_t2i = gr.Image(label="Generated Image", interactive=False, type="pil")
  output_status_t2i = gr.Textbox(label="Generation Status", interactive=False)
+
  gr.Examples(
  examples=[
  ["A sea turtle swimming near a coral reef in the ocean, with a clear blue sky and water in the background.", 15, 3.5, "cosine"],

  fn=generate_viz_wrapper_t2i,
  cache_examples=False
  )
+
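+ # --- Event wiring: connect the buttons to their handlers ---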
+ run_button_ui_t2i.click(
+ fn=generate_viz_wrapper_t2i,
+ inputs=[
+ prompt_input_box_t2i,
+ steps_slider_t2i,
+ guidance_scale_slider_t2i,
+ scheduler_radio_t2i
+ ],
+ outputs=[output_image_t2i, output_status_t2i]
+ )

+ clear_button_ui_t2i.click(
+ fn=lambda: (None, ""),
+ inputs=None,
+ outputs=[output_image_t2i, output_status_t2i],
+ queue=False
+ )
+
+ think_button_lm.click(
+ fn=toggle_thinking_mode_lm,
+ inputs=[thinking_mode_lm],
+ outputs=[thinking_mode_lm, think_button_lm]
+ )
+
+ think_button_mmu.click(
+ fn=toggle_thinking_mode_mmu,
+ inputs=[thinking_mode_mmu],
+ outputs=[thinking_mode_mmu, think_button_mmu]
+ )
+
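+ # initialize_default_model appears to be unused leftover from the earlier model-selection UI; this demo now loads MMaDA-8B-MixCoT directly.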
+ def initialize_default_model():
+ default_model = "MMaDA-8B-Base"
+ result = handle_model_selection_change(default_model)
+ return default_model, result
+
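+ # Shared reset helper for the Clear buttons: clears the image upload box, the live visualization, and the final text output.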
+ def clear_outputs():
+ return None, None, None # Clear image, visualization, and final text

+ clear_button_ui_lm.click(
+ fn=clear_outputs,
  inputs=None,
+ outputs=[image_upload_box, output_visualization_box_lm, output_final_text_box_lm],
+ queue=False
+ )
+ clear_button_ui_mmu.click(
+ fn=clear_outputs,
+ inputs=None,
+ outputs=[image_upload_box, output_visualization_box_mmu, output_final_text_box_mmu],
+ queue=False
+ )
+
+ run_button_ui_lm.click(
+ fn=generate_viz_wrapper_lm,
+ inputs=[
+ prompt_input_box_lm,
+ steps_slider_lm,
+ gen_length_slider_lm,
+ block_length_slider_lm,
+ temperature_slider_lm,
+ cfg_scale_slider_lm,
+ remasking_dropdown_lm,
+ thinking_mode_lm
  ],
+ outputs=[output_visualization_box_lm, output_final_text_box_lm]
  )
+
+ run_button_ui_mmu.click(
+ fn=generate_viz_wrapper,
+ inputs=[
+ image_upload_box,
+ prompt_input_box_mmu,
+ steps_slider_mmu,
+ gen_length_slider_mmu,
+ block_length_slider_mmu,
+ temperature_slider_mmu,
+ cfg_scale_slider_mmu,
+ remasking_dropdown_mmu,
+ thinking_mode_mmu
+ ],
+ outputs=[output_visualization_box_mmu, output_final_text_box_mmu]
+ )
+

  if __name__ == "__main__":
  print(f"Starting Gradio App. Attempting to use device: {DEVICE}")
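+ # allowed_paths lets Gradio serve title.png, which is referenced in the header HTML above.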
+ demo.launch(allowed_paths=["title.png"])