YucYux committed
Commit 5954d37
1 Parent(s): 0e83169

fixed model loading bug

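This commit defers GPU placement: the MMaDA model and the MAGVITv2 VQ model are now instantiated on the CPU (the .to(DEVICE) call is dropped at load time), and each generation wrapper moves them onto the GPU only while a request is running. A minimal sketch of the loading side, using generic transformers classes as illustrative stand-ins for the app's MMadaModelLM (an assumption, not the app's exact code):

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

    def load_model_cpu_first(model_path: str):
        # Weights stay on the CPU at load time; the GPU is only claimed per request.
        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(
            model_path, trust_remote_code=True, torch_dtype=torch.bfloat16
        ).eval()  # note: no .to(DEVICE) here
        return model, tokenizer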
Files changed (1)
  1. app.py +337 -299
app.py CHANGED
@@ -83,7 +83,7 @@ def _load_model_and_tokenizer_core(model_path_to_load, model_display_name_for_st
83
  TOKENIZER = AutoTokenizer.from_pretrained(model_path_to_load, trust_remote_code=True)
84
  status_msg_parts.append(f"Tokenizer for '{model_display_name_for_status}' loaded.")
85
 
86
- MODEL = MMadaModelLM.from_pretrained(model_path_to_load, trust_remote_code=True, torch_dtype=torch.bfloat16).to(DEVICE).eval()
87
  status_msg_parts.append(f"Model '{model_display_name_for_status}' loaded to {DEVICE}.")
88
 
89
  uni_prompting = UniversalPrompting(TOKENIZER, max_text_len=512, special_tokens=("<|soi|>", "<|eoi|>", "<|sov|>", "<|eov|>", "<|t2i|>", "<|mmu|>", "<|t2v|>", "<|v2v|>", "<|lvg|>"),ignore_id=-100, cond_dropout_prob=0.1, use_reserved_token=True)
@@ -264,35 +264,49 @@ def generate_viz_wrapper_t2i(prompt_text, steps, guidance_scale, mask_schedule="
264
  if MODEL is None or TOKENIZER is None or MASK_ID is None:
265
  yield [("Error: Model not loaded. Please load the model first.", "ERROR")], "Model not loaded."
266
  return
267
- steps = int(steps)
268
- guidance_scale = float(guidance_scale)
269
 
270
- image_tokens = torch.ones((1, 1024), dtype=torch.long, device=DEVICE) * MASK_ID
271
- prompt_text = [prompt_text]
272
- input_ids, attention_mask = uni_prompting((prompt_text, image_tokens), 't2i_gen')
 
273
 
274
- if guidance_scale > 0:
275
- uncond_input_ids, uncond_attention_mask = uni_prompting(([''], image_tokens), 't2i_gen')
276
- else:
277
- uncond_input_ids, uncond_attention_mask = None, None
278
-
279
- mask_schedule = get_mask_schedule(mask_schedule)
280
- blank_image = Image.new("RGB", (512, 512), (255, 255, 255))
281
- yield blank_image, "Starting generation..."
282
- for image_step, status_msg_step in MODEL.t2i_generate_decoding_stepwise(
283
- input_ids = input_ids,
284
- uncond_input_ids = uncond_input_ids,
285
- attention_mask = attention_mask,
286
- uncond_attention_mask = uncond_attention_mask,
287
- temperature=1.0,
288
- timesteps = steps,
289
- guidance_scale = guidance_scale,
290
- noise_schedule = mask_schedule,
291
- noise_type = "mask",
292
- seq_len = 1024,
293
- vq_model = VQ_MODEL,
294
- uni_prompting=uni_prompting):
295
- yield image_step, status_msg_step
296
 
297
 
298
 
@@ -306,149 +320,160 @@ def generate_viz_wrapper_lm(prompt_text, steps, gen_length, block_length, temper
306
  yield [("Error: Model not loaded. Please load the model first.", "ERROR")], "Model not loaded."
307
  return
308
 
309
- steps = int(steps)
310
- gen_length = int(gen_length)
311
- block_length = int(block_length)
312
-
313
- if thinking_mode_lm:
314
- prompt_text = "You should first think about the reasoning process in the mind and then provide the user with the answer. The reasoning process is enclosed within <think> </think> tags, i.e. <think> reasoning process here </think> answer here\n" + prompt_text
315
 
316
  try:
317
- m = [{"role": "user", "content": prompt_text}]
318
- processed_prompt_text = TOKENIZER.apply_chat_template(m, add_generation_prompt=True, tokenize=False)
319
- except Exception as e:
320
- yield [("Error applying chat template.", "ERROR")], f"Chat template error: {e}"
321
- processed_prompt_text = prompt_text
322
- try:
323
- if TOKENIZER.pad_token_id is None:
324
- if TOKENIZER.eos_token_id is not None:
325
- TOKENIZER.pad_token_id = TOKENIZER.eos_token_id
326
- else: # Should have been caught by load_model, but double check
327
- yield [("Tokenizer Error", "ERROR")], "pad_token_id is not set in tokenizer."
328
- return
329
-
330
- input_ids = TOKENIZER(text=processed_prompt_text, return_tensors="pt", padding="longest", padding_side="left", truncation=True, max_length=MODEL.config.max_position_embeddings if hasattr(MODEL.config, 'max_position_embeddings') else 2048)['input_ids'].to(DEVICE)
331
- raw_prompt_attention_mask = None
332
-
333
- except Exception as e:
334
- yield [("Error tokenizing prompt.", "ERROR")], f"Tokenization error: {e}"
335
- return
336
 
337
-
 
338
 
339
- batch_size = input_ids.shape[0]
340
- prompt_len = input_ids.shape[1]
341
 
342
- x = torch.full((batch_size, prompt_len + gen_length), MASK_ID, dtype=torch.long, device=DEVICE)
343
- x[:, :prompt_len] = input_ids.clone()
344
 
345
- yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), "Starting generation: Prompt + Initial Masks"
 
346
 
347
- if gen_length == 0:
348
- final_text_output = TOKENIZER.batch_decode(x[:,prompt_len:], skip_special_tokens=True)
349
- yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), final_text_output[0] if final_text_output else ""
350
- return
351
 
352
- if block_length <= 0 or gen_length % block_length != 0 :
353
- yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), \
354
- f"Error: gen_length ({gen_length}) must be divisible by block_length ({block_length}) and block_length > 0."
355
- return
356
- num_blocks = gen_length // block_length
357
 
358
- if steps <=0 or steps % num_blocks != 0:
359
- yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), \
360
- f"Error: steps ({steps}) must be positive and divisible by num_blocks ({num_blocks}). Steps: {steps}, Num Blocks: {num_blocks}"
361
- return
362
- steps_per_block = steps // num_blocks
363
-
364
- for num_block_iter in range(num_blocks):
365
- current_block_start_idx_in_x = prompt_len + num_block_iter * block_length
366
- current_block_end_idx_in_x = prompt_len + (num_block_iter + 1) * block_length
367
 
368
- block_masks_bool_current = torch.zeros_like(x, dtype=torch.bool)
369
- block_masks_bool_current[:, current_block_start_idx_in_x:current_block_end_idx_in_x] = \
370
- (x[:, current_block_start_idx_in_x:current_block_end_idx_in_x] == MASK_ID)
371
 
372
- num_transfer_tokens_for_this_block = get_num_transfer_tokens(
373
- block_masks_bool_current[:, current_block_start_idx_in_x:current_block_end_idx_in_x],
374
- steps_per_block
375
- )
376
 
377
- for i_step_in_block in range(steps_per_block):
378
- mask_index_global = (x == MASK_ID)
379
-
380
- if cfg_scale > 0.:
381
- un_x = x.clone()
382
- # For unconditional pass, mask out the original prompt tokens that are not padding
383
- # raw_prompt_attention_mask is (B, prompt_len)
384
- prompt_active_tokens_mask = raw_prompt_attention_mask.bool() # True where actual prompt tokens are
385
- un_x[:, :prompt_len][prompt_active_tokens_mask] = MASK_ID
386
 
387
- x_cfg_input = torch.cat([x, un_x], dim=0)
388
- # Pass attention_mask for CFG if model expects it, covering both parts
389
- # For simplicity, not passing explicit attention_mask here; relies on model's internal handling.
390
- model_output = MODEL(x_cfg_input)
391
- logits_cond, logits_uncond = torch.chunk(model_output.logits, 2, dim=0)
392
- logits = logits_uncond + (cfg_scale + 1) * (logits_cond - logits_uncond)
393
- else:
394
- # Not passing explicit attention_mask here; relies on model's internal handling.
395
- model_output = MODEL(x)
396
- logits = model_output.logits
397
-
398
- logits_with_noise = add_gumbel_noise(logits, temperature=temperature)
399
- x0_predicted_tokens = torch.argmax(logits_with_noise, dim=-1)
400
-
401
- if remasking_strategy == 'low_confidence':
402
- probs = F.softmax(logits.to(torch.float64), dim=-1)
403
- x0_probs = torch.gather(probs, dim=-1, index=x0_predicted_tokens.unsqueeze(-1)).squeeze(-1)
404
- elif remasking_strategy == 'random':
405
- x0_probs = torch.rand(x.shape, device=x.device, dtype=torch.float64)
406
- else:
407
- yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), f"Error: Unknown remasking strategy '{remasking_strategy}'"
408
- return
409
 
410
- confidence_for_selection = torch.full_like(x0_probs, -torch.inf)
411
- candidate_positions_for_unmasking = mask_index_global & block_masks_bool_current
412
- confidence_for_selection = torch.where(
413
- candidate_positions_for_unmasking,
414
- x0_probs,
415
- -torch.inf
416
- )
417
-
418
- x0_final_candidates = torch.where(mask_index_global, x0_predicted_tokens, x)
419
 
420
- transfer_indices_bool = torch.zeros_like(x, dtype=torch.bool)
421
- num_to_transfer_this_step_batch = num_transfer_tokens_for_this_block[:, i_step_in_block]
 
422
 
423
- for j_batch_idx in range(batch_size):
424
- k_val = min(num_to_transfer_this_step_batch[j_batch_idx].item(),
425
- candidate_positions_for_unmasking[j_batch_idx].sum().item()) # ensure k isn't too large
426
 
427
- if k_val > 0:
428
- # Ensure confidence_for_selection[j_batch_idx] is 1D for topk
429
- conf_slice = confidence_for_selection[j_batch_idx]
430
- if conf_slice.ndim > 1: conf_slice = conf_slice.view(-1) # Should already be 1D from x0_probs
431
-
432
- # Check if there are enough valid (non -inf) confidences
433
- valid_conf_count = (conf_slice > -torch.inf).sum().item()
434
- actual_k = min(k_val, valid_conf_count)
435
 
436
- if actual_k > 0:
437
- _, topk_indices_in_x = torch.topk(conf_slice, k=actual_k)
438
- transfer_indices_bool[j_batch_idx, topk_indices_in_x] = True
439
-
440
- x[transfer_indices_bool] = x0_final_candidates[transfer_indices_bool]
441
 
442
- current_total_step = num_block_iter * steps_per_block + i_step_in_block + 1
443
- total_overall_steps = num_blocks * steps_per_block
444
- status_msg = f"Block {num_block_iter+1}/{num_blocks}, Step {i_step_in_block+1}/{steps_per_block} (Total: {current_total_step}/{total_overall_steps})"
445
- yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), status_msg
 
446
 
447
- final_generated_ids = x[:, prompt_len:]
448
- final_text_output = TOKENIZER.batch_decode(final_generated_ids, skip_special_tokens=True)
449
-
450
- final_text_str = final_text_output[0] if final_text_output and len(final_text_output) > 0 else ""
451
- yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), final_text_str
452
 
453
  @torch.no_grad()
454
  @spaces.GPU
@@ -460,177 +485,190 @@ def generate_viz_wrapper(uploaded_image_pil, prompt_text, steps, gen_length, blo
460
  yield [("Error: Model not loaded. Please load the model first.", "ERROR")], "Model not loaded."
461
  return
462
 
463
- steps = int(steps)
464
- gen_length = int(gen_length)
465
- block_length = int(block_length)
466
 
467
- if thinking_mode_mmu:
468
- prompt_text = "You should first think about the reasoning process in the mind and then provide the user with the answer. The reasoning process is enclosed within <think> </think> tags, i.e. <think> reasoning process here </think> answer here\n" + prompt_text
469
 
470
- try:
471
- m = [{"role": "user", "content": prompt_text}]
472
- processed_prompt_text = TOKENIZER.apply_chat_template(m, add_generation_prompt=True, tokenize=False)
473
- except Exception as e:
474
- yield [("Error applying chat template.", "ERROR")], f"Chat template error: {e}"
475
- processed_prompt_text = prompt_text
476
-
477
- image_vq_ids_tensor = None
478
- if uploaded_image_pil is not None:
479
  try:
480
-
481
- image = image_transform(uploaded_image_pil, resolution=512).to(DEVICE)
482
- image = image.unsqueeze(0)
483
- image_vq_ids_tensor = VQ_MODEL.get_code(image) + 126349
484
  except Exception as e:
485
- yield [("Error processing image.", "ERROR")], f"Image to VQ tokens conversion failed: {str(e)}"
486
- return
487
-
488
 
489
- try:
490
- if TOKENIZER.pad_token_id is None:
491
- if TOKENIZER.eos_token_id is not None:
492
- TOKENIZER.pad_token_id = TOKENIZER.eos_token_id
493
- else:
494
- yield [("Tokenizer Error", "ERROR")], "pad_token_id is not set in tokenizer."
495
- return
496
-
497
- input_ids = TOKENIZER(text=processed_prompt_text, return_tensors="pt", padding="longest", padding_side="left", truncation=True, max_length=MODEL.config.max_position_embeddings if hasattr(MODEL.config, 'max_position_embeddings') else 2048)['input_ids'].to(DEVICE)
498
- raw_prompt_attention_mask = None
499
- if image_vq_ids_tensor is not None:
500
- if image_vq_ids_tensor.ndim == 1:
501
- image_vq_ids_tensor = image_vq_ids_tensor.unsqueeze(0)
502
-
503
- input_ids = torch.cat([
504
- (torch.ones(input_ids.shape[0], 1) * torch.tensor([126089])).to(DEVICE),
505
- (torch.ones(input_ids.shape[0], 1) * torch.tensor([126084])).to(DEVICE),
506
- image_vq_ids_tensor,
507
- (torch.ones(input_ids.shape[0], 1) * torch.tensor([126085])).to(DEVICE),
508
- input_ids
509
- ], dim=1).long()
510
 
511
- else:
512
- input_ids = input_ids
 
513
 
514
 
515
- except Exception as e:
516
- yield [("Error tokenizing prompt.", "ERROR")], f"Tokenization error: {e}"
517
- return
518
 
519
-
520
-
521
- batch_size = input_ids.shape[0]
522
- prompt_len = input_ids.shape[1]
523
 
524
- x = torch.full((batch_size, prompt_len + gen_length), MASK_ID, dtype=torch.long, device=DEVICE)
525
- x[:, :prompt_len] = input_ids.clone()
526
 
527
- yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), "Starting generation: Prompt + Initial Masks"
528
 
529
- if gen_length == 0:
530
- final_text_output = TOKENIZER.batch_decode(x[:,prompt_len:], skip_special_tokens=True)
531
- yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), final_text_output[0] if final_text_output else ""
532
- return
533
 
534
- if block_length <= 0 or gen_length % block_length != 0 :
535
- yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), \
536
- f"Error: gen_length ({gen_length}) must be divisible by block_length ({block_length}) and block_length > 0."
537
- return
538
- num_blocks = gen_length // block_length
539
 
540
- if steps <=0 or steps % num_blocks != 0:
541
- yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), \
542
- f"Error: steps ({steps}) must be positive and divisible by num_blocks ({num_blocks}). Steps: {steps}, Num Blocks: {num_blocks}"
543
- return
544
- steps_per_block = steps // num_blocks
545
-
546
- for num_block_iter in range(num_blocks):
547
- current_block_start_idx_in_x = prompt_len + num_block_iter * block_length
548
- current_block_end_idx_in_x = prompt_len + (num_block_iter + 1) * block_length
549
 
550
- block_masks_bool_current = torch.zeros_like(x, dtype=torch.bool)
551
- block_masks_bool_current[:, current_block_start_idx_in_x:current_block_end_idx_in_x] = \
552
- (x[:, current_block_start_idx_in_x:current_block_end_idx_in_x] == MASK_ID)
553
 
554
- num_transfer_tokens_for_this_block = get_num_transfer_tokens(
555
- block_masks_bool_current[:, current_block_start_idx_in_x:current_block_end_idx_in_x],
556
- steps_per_block
557
- )
558
 
559
- for i_step_in_block in range(steps_per_block):
560
- mask_index_global = (x == MASK_ID)
561
-
562
- if cfg_scale > 0.:
563
- un_x = x.clone()
564
- # For unconditional pass, mask out the original prompt tokens that are not padding
565
- # raw_prompt_attention_mask is (B, prompt_len)
566
- prompt_active_tokens_mask = raw_prompt_attention_mask.bool() # True where actual prompt tokens are
567
- un_x[:, :prompt_len][prompt_active_tokens_mask] = MASK_ID
568
 
569
- x_cfg_input = torch.cat([x, un_x], dim=0)
570
- # Pass attention_mask for CFG if model expects it, covering both parts
571
- # For simplicity, not passing explicit attention_mask here; relies on model's internal handling.
572
- model_output = MODEL(x_cfg_input)
573
- logits_cond, logits_uncond = torch.chunk(model_output.logits, 2, dim=0)
574
- logits = logits_uncond + (cfg_scale + 1) * (logits_cond - logits_uncond)
575
- else:
576
- # Not passing explicit attention_mask here; relies on model's internal handling.
577
- model_output = MODEL(x)
578
- logits = model_output.logits
579
-
580
- logits_with_noise = add_gumbel_noise(logits, temperature=temperature)
581
- x0_predicted_tokens = torch.argmax(logits_with_noise, dim=-1)
582
-
583
- if remasking_strategy == 'low_confidence':
584
- probs = F.softmax(logits.to(torch.float64), dim=-1)
585
- x0_probs = torch.gather(probs, dim=-1, index=x0_predicted_tokens.unsqueeze(-1)).squeeze(-1)
586
- elif remasking_strategy == 'random':
587
- x0_probs = torch.rand(x.shape, device=x.device, dtype=torch.float64)
588
- else:
589
- yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), f"Error: Unknown remasking strategy '{remasking_strategy}'"
590
- return
591
 
592
- confidence_for_selection = torch.full_like(x0_probs, -torch.inf)
593
- candidate_positions_for_unmasking = mask_index_global & block_masks_bool_current
594
- confidence_for_selection = torch.where(
595
- candidate_positions_for_unmasking,
596
- x0_probs,
597
- -torch.inf
598
- )
599
-
600
- x0_final_candidates = torch.where(mask_index_global, x0_predicted_tokens, x)
601
 
602
- transfer_indices_bool = torch.zeros_like(x, dtype=torch.bool)
603
- num_to_transfer_this_step_batch = num_transfer_tokens_for_this_block[:, i_step_in_block]
 
604
 
605
- for j_batch_idx in range(batch_size):
606
- k_val = min(num_to_transfer_this_step_batch[j_batch_idx].item(),
607
- candidate_positions_for_unmasking[j_batch_idx].sum().item()) # ensure k isn't too large
608
 
609
- if k_val > 0:
610
- # Ensure confidence_for_selection[j_batch_idx] is 1D for topk
611
- conf_slice = confidence_for_selection[j_batch_idx]
612
- if conf_slice.ndim > 1: conf_slice = conf_slice.view(-1) # Should already be 1D from x0_probs
613
-
614
- # Check if there are enough valid (non -inf) confidences
615
- valid_conf_count = (conf_slice > -torch.inf).sum().item()
616
- actual_k = min(k_val, valid_conf_count)
617
 
618
- if actual_k > 0:
619
- _, topk_indices_in_x = torch.topk(conf_slice, k=actual_k)
620
- transfer_indices_bool[j_batch_idx, topk_indices_in_x] = True
621
-
622
- x[transfer_indices_bool] = x0_final_candidates[transfer_indices_bool]
623
 
624
- current_total_step = num_block_iter * steps_per_block + i_step_in_block + 1
625
- total_overall_steps = num_blocks * steps_per_block
626
- status_msg = f"Block {num_block_iter+1}/{num_blocks}, Step {i_step_in_block+1}/{steps_per_block} (Total: {current_total_step}/{total_overall_steps})"
627
- yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), status_msg
 
628
 
629
- final_generated_ids = x[:, prompt_len:]
630
- final_text_output = TOKENIZER.batch_decode(final_generated_ids, skip_special_tokens=True)
631
-
632
- final_text_str = final_text_output[0] if final_text_output and len(final_text_output) > 0 else ""
633
- yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), final_text_str
 
634
 
635
 
636
  css_styles = """
@@ -1025,8 +1063,8 @@ with gr.Blocks(css=css_styles, theme=theme) as demo:
1025
 
1026
  if VQ_MODEL is None:
1027
  print("Loading VQ_MODEL for the first time...")
1028
- VQ_MODEL = MAGVITv2().from_pretrained("showlab/magvitv2").to(DEVICE)
1029
- print("VQ_MODEL loaded.")
1030
 
1031
  default_model_choice = "MMaDA-8B-MixCoT"
1032
 
 
83
  TOKENIZER = AutoTokenizer.from_pretrained(model_path_to_load, trust_remote_code=True)
84
  status_msg_parts.append(f"Tokenizer for '{model_display_name_for_status}' loaded.")
85
 
86
+ MODEL = MMadaModelLM.from_pretrained(model_path_to_load, trust_remote_code=True, torch_dtype=torch.bfloat16).eval()
87
  status_msg_parts.append(f"Model '{model_display_name_for_status}' loaded to {DEVICE}.")
88
 
89
  uni_prompting = UniversalPrompting(TOKENIZER, max_text_len=512, special_tokens=("<|soi|>", "<|eoi|>", "<|sov|>", "<|eov|>", "<|t2i|>", "<|mmu|>", "<|t2v|>", "<|v2v|>", "<|lvg|>"),ignore_id=-100, cond_dropout_prob=0.1, use_reserved_token=True)
 
264
  if MODEL is None or TOKENIZER is None or MASK_ID is None:
265
  yield [("Error: Model not loaded. Please load the model first.", "ERROR")], "Model not loaded."
266
  return
 
 
267
 
268
+ if DEVICE == 'cuda':
269
+ print("Moving MODEL to GPU for inference...")
270
+ MODEL.to(DEVICE)
271
+ VQ_MODEL.to(DEVICE)
272
 
273
+ try:
274
+ steps = int(steps)
275
+ guidance_scale = float(guidance_scale)
276
+
277
+ image_tokens = torch.ones((1, 1024), dtype=torch.long, device=DEVICE) * MASK_ID
278
+ prompt_text = [prompt_text]
279
+ input_ids, attention_mask = uni_prompting((prompt_text, image_tokens), 't2i_gen')
280
+
281
+ if guidance_scale > 0:
282
+ uncond_input_ids, uncond_attention_mask = uni_prompting(([''], image_tokens), 't2i_gen')
283
+ else:
284
+ uncond_input_ids, uncond_attention_mask = None, None
285
+
286
+ mask_schedule = get_mask_schedule(mask_schedule)
287
+ blank_image = Image.new("RGB", (512, 512), (255, 255, 255))
288
+ yield blank_image, "Starting generation..."
289
+ for image_step, status_msg_step in MODEL.t2i_generate_decoding_stepwise(
290
+ input_ids = input_ids,
291
+ uncond_input_ids = uncond_input_ids,
292
+ attention_mask = attention_mask,
293
+ uncond_attention_mask = uncond_attention_mask,
294
+ temperature=1.0,
295
+ timesteps = steps,
296
+ guidance_scale = guidance_scale,
297
+ noise_schedule = mask_schedule,
298
+ noise_type = "mask",
299
+ seq_len = 1024,
300
+ vq_model = VQ_MODEL,
301
+ uni_prompting=uni_prompting):
302
+ yield image_step, status_msg_step
303
+
304
+ finally:
305
+ if DEVICE == 'cuda':
306
+ print("Moving MODEL back to CPU...")
307
+ MODEL.to('cpu')
308
+ VQ_MODEL.to('cpu')
309
+ torch.cuda.empty_cache()
310
 
311
 
312
 
 
320
  yield [("Error: Model not loaded. Please load the model first.", "ERROR")], "Model not loaded."
321
  return
322
 
323
+ if DEVICE == 'cuda':
324
+ print("Moving MODEL to GPU for inference...")
325
+ MODEL.to(DEVICE)
326
 
327
  try:
328
+ steps = int(steps)
329
+ gen_length = int(gen_length)
330
+ block_length = int(block_length)
331
 
332
+ if thinking_mode_lm:
333
+ prompt_text = "You should first think about the reasoning process in the mind and then provide the user with the answer. The reasoning process is enclosed within <think> </think> tags, i.e. <think> reasoning process here </think> answer here\n" + prompt_text
334
 
335
+ try:
336
+ m = [{"role": "user", "content": prompt_text}]
337
+ processed_prompt_text = TOKENIZER.apply_chat_template(m, add_generation_prompt=True, tokenize=False)
338
+ except Exception as e:
339
+ yield [("Error applying chat template.", "ERROR")], f"Chat template error: {e}"
340
+ processed_prompt_text = prompt_text
341
+ try:
342
+ if TOKENIZER.pad_token_id is None:
343
+ if TOKENIZER.eos_token_id is not None:
344
+ TOKENIZER.pad_token_id = TOKENIZER.eos_token_id
345
+ else: # Should have been caught by load_model, but double check
346
+ yield [("Tokenizer Error", "ERROR")], "pad_token_id is not set in tokenizer."
347
+ return
348
+
349
+ input_ids = TOKENIZER(text=processed_prompt_text, return_tensors="pt", padding="longest", padding_side="left", truncation=True, max_length=MODEL.config.max_position_embeddings if hasattr(MODEL.config, 'max_position_embeddings') else 2048)['input_ids'].to(DEVICE)
350
+ raw_prompt_attention_mask = None
351
+
352
+ except Exception as e:
353
+ yield [("Error tokenizing prompt.", "ERROR")], f"Tokenization error: {e}"
354
+ return
355
 
356
+
 
357
 
358
+ batch_size = input_ids.shape[0]
359
+ prompt_len = input_ids.shape[1]
360
 
361
+ x = torch.full((batch_size, prompt_len + gen_length), MASK_ID, dtype=torch.long, device=DEVICE)
362
+ x[:, :prompt_len] = input_ids.clone()
 
 
363
 
364
+ yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), "Starting generation: Prompt + Initial Masks"
365
 
366
+ if gen_length == 0:
367
+ final_text_output = TOKENIZER.batch_decode(x[:,prompt_len:], skip_special_tokens=True)
368
+ yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), final_text_output[0] if final_text_output else ""
369
+ return
370
+
371
+ if block_length <= 0 or gen_length % block_length != 0 :
372
+ yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), \
373
+ f"Error: gen_length ({gen_length}) must be divisible by block_length ({block_length}) and block_length > 0."
374
+ return
375
+ num_blocks = gen_length // block_length
376
+
377
+ if steps <=0 or steps % num_blocks != 0:
378
+ yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), \
379
+ f"Error: steps ({steps}) must be positive and divisible by num_blocks ({num_blocks}). Steps: {steps}, Num Blocks: {num_blocks}"
380
+ return
381
+ steps_per_block = steps // num_blocks
382
 
383
+ for num_block_iter in range(num_blocks):
384
+ current_block_start_idx_in_x = prompt_len + num_block_iter * block_length
385
+ current_block_end_idx_in_x = prompt_len + (num_block_iter + 1) * block_length
386
+
387
+ block_masks_bool_current = torch.zeros_like(x, dtype=torch.bool)
388
+ block_masks_bool_current[:, current_block_start_idx_in_x:current_block_end_idx_in_x] = \
389
+ (x[:, current_block_start_idx_in_x:current_block_end_idx_in_x] == MASK_ID)
390
 
391
+ num_transfer_tokens_for_this_block = get_num_transfer_tokens(
392
+ block_masks_bool_current[:, current_block_start_idx_in_x:current_block_end_idx_in_x],
393
+ steps_per_block
394
+ )
395
 
396
+ for i_step_in_block in range(steps_per_block):
397
+ mask_index_global = (x == MASK_ID)
398
+
399
+ if cfg_scale > 0.:
400
+ un_x = x.clone()
401
+ # For unconditional pass, mask out the original prompt tokens that are not padding
402
+ # raw_prompt_attention_mask is (B, prompt_len)
403
+ prompt_active_tokens_mask = raw_prompt_attention_mask.bool() # True where actual prompt tokens are
404
+ un_x[:, :prompt_len][prompt_active_tokens_mask] = MASK_ID
405
+
406
+ x_cfg_input = torch.cat([x, un_x], dim=0)
407
+ # Pass attention_mask for CFG if model expects it, covering both parts
408
+ # For simplicity, not passing explicit attention_mask here; relies on model's internal handling.
409
+ model_output = MODEL(x_cfg_input)
410
+ logits_cond, logits_uncond = torch.chunk(model_output.logits, 2, dim=0)
411
+ logits = logits_uncond + (cfg_scale + 1) * (logits_cond - logits_uncond)
412
+ else:
413
+ # Not passing explicit attention_mask here; relies on model's internal handling.
414
+ model_output = MODEL(x)
415
+ logits = model_output.logits
416
+
417
+ logits_with_noise = add_gumbel_noise(logits, temperature=temperature)
418
+ x0_predicted_tokens = torch.argmax(logits_with_noise, dim=-1)
419
+
420
+ if remasking_strategy == 'low_confidence':
421
+ probs = F.softmax(logits.to(torch.float64), dim=-1)
422
+ x0_probs = torch.gather(probs, dim=-1, index=x0_predicted_tokens.unsqueeze(-1)).squeeze(-1)
423
+ elif remasking_strategy == 'random':
424
+ x0_probs = torch.rand(x.shape, device=x.device, dtype=torch.float64)
425
+ else:
426
+ yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), f"Error: Unknown remasking strategy '{remasking_strategy}'"
427
+ return
428
+
429
+ confidence_for_selection = torch.full_like(x0_probs, -torch.inf)
430
+ candidate_positions_for_unmasking = mask_index_global & block_masks_bool_current
431
+ confidence_for_selection = torch.where(
432
+ candidate_positions_for_unmasking,
433
+ x0_probs,
434
+ -torch.inf
435
+ )
436
 
437
+ x0_final_candidates = torch.where(mask_index_global, x0_predicted_tokens, x)
438
 
439
+ transfer_indices_bool = torch.zeros_like(x, dtype=torch.bool)
440
+ num_to_transfer_this_step_batch = num_transfer_tokens_for_this_block[:, i_step_in_block]
441
 
442
+ for j_batch_idx in range(batch_size):
443
+ k_val = min(num_to_transfer_this_step_batch[j_batch_idx].item(),
444
+ candidate_positions_for_unmasking[j_batch_idx].sum().item()) # ensure k isn't too large
445
 
446
+ if k_val > 0:
447
+ # Ensure confidence_for_selection[j_batch_idx] is 1D for topk
448
+ conf_slice = confidence_for_selection[j_batch_idx]
449
+ if conf_slice.ndim > 1: conf_slice = conf_slice.view(-1) # Should already be 1D from x0_probs
450
+
451
+ # Check if there are enough valid (non -inf) confidences
452
+ valid_conf_count = (conf_slice > -torch.inf).sum().item()
453
+ actual_k = min(k_val, valid_conf_count)
454
 
455
+ if actual_k > 0:
456
+ _, topk_indices_in_x = torch.topk(conf_slice, k=actual_k)
457
+ transfer_indices_bool[j_batch_idx, topk_indices_in_x] = True
458
+
459
+ x[transfer_indices_bool] = x0_final_candidates[transfer_indices_bool]
460
 
461
+ current_total_step = num_block_iter * steps_per_block + i_step_in_block + 1
462
+ total_overall_steps = num_blocks * steps_per_block
463
+ status_msg = f"Block {num_block_iter+1}/{num_blocks}, Step {i_step_in_block+1}/{steps_per_block} (Total: {current_total_step}/{total_overall_steps})"
464
+ yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), status_msg
 
465
 
466
+ final_generated_ids = x[:, prompt_len:]
467
+ final_text_output = TOKENIZER.batch_decode(final_generated_ids, skip_special_tokens=True)
468
+
469
+ final_text_str = final_text_output[0] if final_text_output and len(final_text_output) > 0 else ""
470
+ yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), final_text_str
471
 
472
+ finally:
473
+ if DEVICE == 'cuda':
474
+ print("Moving MODEL back to CPU and clearing cache...")
475
+ MODEL.to('cpu')
476
+ torch.cuda.empty_cache()
477
 
478
  @torch.no_grad()
479
  @spaces.GPU
 
485
  yield [("Error: Model not loaded. Please load the model first.", "ERROR")], "Model not loaded."
486
  return
487
 
488
+ if DEVICE == 'cuda':
489
+ print("Moving MODEL to GPU for inference...")
490
+ MODEL.to(DEVICE)
491
+ VQ_MODEL.to(DEVICE)
492
+
493
+ try:
494
+ steps = int(steps)
495
+ gen_length = int(gen_length)
496
+ block_length = int(block_length)
497
 
498
+ if thinking_mode_mmu:
499
+ prompt_text = "You should first think about the reasoning process in the mind and then provide the user with the answer. The reasoning process is enclosed within <think> </think> tags, i.e. <think> reasoning process here </think> answer here\n" + prompt_text
500
 
501
  try:
502
+ m = [{"role": "user", "content": prompt_text}]
503
+ processed_prompt_text = TOKENIZER.apply_chat_template(m, add_generation_prompt=True, tokenize=False)
 
 
504
  except Exception as e:
505
+ yield [("Error applying chat template.", "ERROR")], f"Chat template error: {e}"
506
+ processed_prompt_text = prompt_text
507
+
508
+ image_vq_ids_tensor = None
509
+ if uploaded_image_pil is not None:
510
+ try:
511
+
512
+ image = image_transform(uploaded_image_pil, resolution=512).to(DEVICE)
513
+ image = image.unsqueeze(0)
514
+ image_vq_ids_tensor = VQ_MODEL.get_code(image) + 126349
515
+ except Exception as e:
516
+ yield [("Error processing image.", "ERROR")], f"Image to VQ tokens conversion failed: {str(e)}"
517
+ return
518
+
519
+
520
+ try:
521
+ if TOKENIZER.pad_token_id is None:
522
+ if TOKENIZER.eos_token_id is not None:
523
+ TOKENIZER.pad_token_id = TOKENIZER.eos_token_id
524
+ else:
525
+ yield [("Tokenizer Error", "ERROR")], "pad_token_id is not set in tokenizer."
526
+ return
527
+
528
+ input_ids = TOKENIZER(text=processed_prompt_text, return_tensors="pt", padding="longest", padding_side="left", truncation=True, max_length=MODEL.config.max_position_embeddings if hasattr(MODEL.config, 'max_position_embeddings') else 2048)['input_ids'].to(DEVICE)
529
+ raw_prompt_attention_mask = None
530
+ if image_vq_ids_tensor is not None:
531
+ if image_vq_ids_tensor.ndim == 1:
532
+ image_vq_ids_tensor = image_vq_ids_tensor.unsqueeze(0)
533
+
534
+ input_ids = torch.cat([
535
+ (torch.ones(input_ids.shape[0], 1) * torch.tensor([126089])).to(DEVICE),
536
+ (torch.ones(input_ids.shape[0], 1) * torch.tensor([126084])).to(DEVICE),
537
+ image_vq_ids_tensor,
538
+ (torch.ones(input_ids.shape[0], 1) * torch.tensor([126085])).to(DEVICE),
539
+ input_ids
540
+ ], dim=1).long()
541
+
542
+ else:
543
+ input_ids = input_ids
544
 
 
545
 
546
+ except Exception as e:
547
+ yield [("Error tokenizing prompt.", "ERROR")], f"Tokenization error: {e}"
548
+ return
549
 
550
 
 
 
 
551
 
552
+ batch_size = input_ids.shape[0]
553
+ prompt_len = input_ids.shape[1]
 
 
554
 
555
+ x = torch.full((batch_size, prompt_len + gen_length), MASK_ID, dtype=torch.long, device=DEVICE)
556
+ x[:, :prompt_len] = input_ids.clone()
557
 
558
+ yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), "Starting generation: Prompt + Initial Masks"
559
 
560
+ if gen_length == 0:
561
+ final_text_output = TOKENIZER.batch_decode(x[:,prompt_len:], skip_special_tokens=True)
562
+ yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), final_text_output[0] if final_text_output else ""
563
+ return
564
 
565
+ if block_length <= 0 or gen_length % block_length != 0 :
566
+ yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), \
567
+ f"Error: gen_length ({gen_length}) must be divisible by block_length ({block_length}) and block_length > 0."
568
+ return
569
+ num_blocks = gen_length // block_length
570
 
571
+ if steps <=0 or steps % num_blocks != 0:
572
+ yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), \
573
+ f"Error: steps ({steps}) must be positive and divisible by num_blocks ({num_blocks}). Steps: {steps}, Num Blocks: {num_blocks}"
574
+ return
575
+ steps_per_block = steps // num_blocks
576
 
577
+ for num_block_iter in range(num_blocks):
578
+ current_block_start_idx_in_x = prompt_len + num_block_iter * block_length
579
+ current_block_end_idx_in_x = prompt_len + (num_block_iter + 1) * block_length
580
+
581
+ block_masks_bool_current = torch.zeros_like(x, dtype=torch.bool)
582
+ block_masks_bool_current[:, current_block_start_idx_in_x:current_block_end_idx_in_x] = \
583
+ (x[:, current_block_start_idx_in_x:current_block_end_idx_in_x] == MASK_ID)
584
 
585
+ num_transfer_tokens_for_this_block = get_num_transfer_tokens(
586
+ block_masks_bool_current[:, current_block_start_idx_in_x:current_block_end_idx_in_x],
587
+ steps_per_block
588
+ )
589
 
590
+ for i_step_in_block in range(steps_per_block):
591
+ mask_index_global = (x == MASK_ID)
 
592
 
593
+ if cfg_scale > 0.:
594
+ un_x = x.clone()
595
+ # For unconditional pass, mask out the original prompt tokens that are not padding
596
+ # raw_prompt_attention_mask is (B, prompt_len)
597
+ prompt_active_tokens_mask = raw_prompt_attention_mask.bool() # True where actual prompt tokens are
598
+ un_x[:, :prompt_len][prompt_active_tokens_mask] = MASK_ID
599
+
600
+ x_cfg_input = torch.cat([x, un_x], dim=0)
601
+ # Pass attention_mask for CFG if model expects it, covering both parts
602
+ # For simplicity, not passing explicit attention_mask here; relies on model's internal handling.
603
+ model_output = MODEL(x_cfg_input)
604
+ logits_cond, logits_uncond = torch.chunk(model_output.logits, 2, dim=0)
605
+ logits = logits_uncond + (cfg_scale + 1) * (logits_cond - logits_uncond)
606
+ else:
607
+ # Not passing explicit attention_mask here; relies on model's internal handling.
608
+ model_output = MODEL(x)
609
+ logits = model_output.logits
610
+
611
+ logits_with_noise = add_gumbel_noise(logits, temperature=temperature)
612
+ x0_predicted_tokens = torch.argmax(logits_with_noise, dim=-1)
613
+
614
+ if remasking_strategy == 'low_confidence':
615
+ probs = F.softmax(logits.to(torch.float64), dim=-1)
616
+ x0_probs = torch.gather(probs, dim=-1, index=x0_predicted_tokens.unsqueeze(-1)).squeeze(-1)
617
+ elif remasking_strategy == 'random':
618
+ x0_probs = torch.rand(x.shape, device=x.device, dtype=torch.float64)
619
+ else:
620
+ yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), f"Error: Unknown remasking strategy '{remasking_strategy}'"
621
+ return
622
+
623
+ confidence_for_selection = torch.full_like(x0_probs, -torch.inf)
624
+ candidate_positions_for_unmasking = mask_index_global & block_masks_bool_current
625
+ confidence_for_selection = torch.where(
626
+ candidate_positions_for_unmasking,
627
+ x0_probs,
628
+ -torch.inf
629
+ )
630
+
631
+ x0_final_candidates = torch.where(mask_index_global, x0_predicted_tokens, x)
632
 
633
+ transfer_indices_bool = torch.zeros_like(x, dtype=torch.bool)
634
+ num_to_transfer_this_step_batch = num_transfer_tokens_for_this_block[:, i_step_in_block]
635
 
636
+ for j_batch_idx in range(batch_size):
637
+ k_val = min(num_to_transfer_this_step_batch[j_batch_idx].item(),
638
+ candidate_positions_for_unmasking[j_batch_idx].sum().item()) # ensure k isn't too large
639
 
640
+ if k_val > 0:
641
+ # Ensure confidence_for_selection[j_batch_idx] is 1D for topk
642
+ conf_slice = confidence_for_selection[j_batch_idx]
643
+ if conf_slice.ndim > 1: conf_slice = conf_slice.view(-1) # Should already be 1D from x0_probs
644
+
645
+ # Check if there are enough valid (non -inf) confidences
646
+ valid_conf_count = (conf_slice > -torch.inf).sum().item()
647
+ actual_k = min(k_val, valid_conf_count)
648
 
649
+ if actual_k > 0:
650
+ _, topk_indices_in_x = torch.topk(conf_slice, k=actual_k)
651
+ transfer_indices_bool[j_batch_idx, topk_indices_in_x] = True
652
+
653
+ x[transfer_indices_bool] = x0_final_candidates[transfer_indices_bool]
654
 
655
+ current_total_step = num_block_iter * steps_per_block + i_step_in_block + 1
656
+ total_overall_steps = num_blocks * steps_per_block
657
+ status_msg = f"Block {num_block_iter+1}/{num_blocks}, Step {i_step_in_block+1}/{steps_per_block} (Total: {current_total_step}/{total_overall_steps})"
658
+ yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), status_msg
 
659
 
660
+ final_generated_ids = x[:, prompt_len:]
661
+ final_text_output = TOKENIZER.batch_decode(final_generated_ids, skip_special_tokens=True)
662
+
663
+ final_text_str = final_text_output[0] if final_text_output and len(final_text_output) > 0 else ""
664
+ yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), final_text_str
665
 
666
+ finally:
667
+ if DEVICE == 'cuda':
668
+ print("Moving MODEL back to CPU and clearing cache...")
669
+ MODEL.to('cpu')
670
+ VQ_MODEL.to('cpu')
671
+ torch.cuda.empty_cache()
672
 
673
 
674
  css_styles = """
 
1063
 
1064
  if VQ_MODEL is None:
1065
  print("Loading VQ_MODEL for the first time...")
1066
+ VQ_MODEL = MAGVITv2().from_pretrained("showlab/magvitv2")
1067
+ print("VQ_MODEL loaded to CPU.")
1068
 
1069
  default_model_choice = "MMaDA-8B-MixCoT"
1070
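Each generation wrapper in the new version claims the GPU only for the duration of a request: it moves MODEL (and, for the image paths, VQ_MODEL) to DEVICE, runs generation inside a try block, and in the finally block moves the weights back to the CPU and calls torch.cuda.empty_cache(). A minimal sketch of that lifecycle, with model, vq_model, and generate_fn as illustrative placeholders rather than the app's actual objects:

    import torch

    def run_on_gpu_then_release(model, vq_model, generate_fn, *args, device="cuda", **kwargs):
        # Move the weights onto the GPU for this call only.
        use_cuda = device == "cuda" and torch.cuda.is_available()
        if use_cuda:
            model.to(device)
            if vq_model is not None:
                vq_model.to(device)
        try:
            # Stream intermediate results, as the Gradio wrappers do.
            yield from generate_fn(model, *args, **kwargs)
        finally:
            if use_cuda:
                # Always return the weights to the CPU and free cached GPU memory,
                # even if generation raised an exception.
                model.to("cpu")
                if vq_model is not None:
                    vq_model.to("cpu")
                torch.cuda.empty_cache()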