YucYux committed on
Commit
0e83169
·
1 Parent(s): f6ba9f2

fixed model loading bug

Browse files
Files changed (1) hide show
  1. app.py +25 -19
app.py CHANGED
@@ -47,14 +47,14 @@ def get_num_transfer_tokens(mask_index, steps):
47
  return num_transfer_tokens
48
 
49
  DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
50
- DEFAULT_MODEL_PATH = "Gen-Verse/MMaDA-8B-MixCoT" # Default
51
- MASK_ID = 126336
52
- MODEL = MMadaModelLM.from_pretrained(DEFAULT_MODEL_PATH, trust_remote_code=True, torch_dtype=torch.bfloat16).to(DEVICE).eval()
53
- TOKENIZER = AutoTokenizer.from_pretrained(DEFAULT_MODEL_PATH, trust_remote_code=True)
54
- uni_prompting = UniversalPrompting(TOKENIZER, max_text_len=512, special_tokens=("<|soi|>", "<|eoi|>", "<|sov|>", "<|eov|>", "<|t2i|>", "<|mmu|>", "<|t2v|>", "<|v2v|>", "<|lvg|>"),ignore_id=-100, cond_dropout_prob=0.1, use_reserved_token=True)
55
- VQ_MODEL = MAGVITv2().from_pretrained("showlab/magvitv2").to(DEVICE)
56
 
57
- CURRENT_MODEL_PATH = DEFAULT_MODEL_PATH
58
 
59
  MODEL_CHOICES = [
60
  "MMaDA-8B-Base",
@@ -1021,9 +1021,15 @@ with gr.Blocks(css=css_styles, theme=theme) as demo:
1021
  )
1022
 
1023
  def initialize_app_state():
1024
- default_model_choice = "MMaDA-8B-MixCoT" # 默认加载 MixCoT
 
 
 
 
 
 
 
1025
 
1026
- # handle_model_selection_change 现在返回更多项
1027
  status, lm_b_vis, lm_m_vis, lm_x_vis, \
1028
  mmu_b_vis, mmu_m_vis, mmu_x_vis, \
1029
  init_thinking_lm_state, init_think_lm_btn_update, \
@@ -1031,12 +1037,12 @@ with gr.Blocks(css=css_styles, theme=theme) as demo:
1031
 
1032
  return (
1033
  default_model_choice,
1034
- status,
1035
- lm_b_vis,
1036
- lm_m_vis,
1037
- lm_x_vis,
1038
- mmu_b_vis,
1039
- mmu_m_vis,
1040
  mmu_x_vis,
1041
  init_thinking_lm_state,
1042
  init_think_lm_btn_update,
@@ -1056,10 +1062,10 @@ with gr.Blocks(css=css_styles, theme=theme) as demo:
1056
  examples_mmu_base,
1057
  examples_mmu_mixcot,
1058
  examples_mmu_max,
1059
- thinking_mode_lm, # gr.State for LM thinking mode
1060
- think_button_lm, # gr.Button for LM thinking mode
1061
- thinking_mode_mmu, # gr.State for MMU thinking mode
1062
- think_button_mmu # gr.Button for MMU thinking mode
1063
  ],
1064
  queue=True
1065
  )
 
47
  return num_transfer_tokens
48
 
49
  DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
50
+ DEFAULT_MODEL_PATH = "Gen-Verse/MMaDA-8B-MixCoT"
51
+ MASK_ID = None # 初始化为 None
52
+ MODEL = None # 初始化为 None
53
+ TOKENIZER = None# 初始化为 None
54
+ uni_prompting = None # 初始化为 None
55
+ VQ_MODEL = None # 初始化为 None, 稍后在初始化函数中加载
56
 
57
+ CURRENT_MODEL_PATH = None # 初始化为 None
58
 
59
  MODEL_CHOICES = [
60
  "MMaDA-8B-Base",
 
1021
  )
1022
 
1023
  def initialize_app_state():
1024
+ global VQ_MODEL
1025
+
1026
+ if VQ_MODEL is None:
1027
+ print("Loading VQ_MODEL for the first time...")
1028
+ VQ_MODEL = MAGVITv2().from_pretrained("showlab/magvitv2").to(DEVICE)
1029
+ print("VQ_MODEL loaded.")
1030
+
1031
+ default_model_choice = "MMaDA-8B-MixCoT"
1032
 
 
1033
  status, lm_b_vis, lm_m_vis, lm_x_vis, \
1034
  mmu_b_vis, mmu_m_vis, mmu_x_vis, \
1035
  init_thinking_lm_state, init_think_lm_btn_update, \
 
1037
 
1038
  return (
1039
  default_model_choice,
1040
+ status,
1041
+ lm_b_vis,
1042
+ lm_m_vis,
1043
+ lm_x_vis,
1044
+ mmu_b_vis,
1045
+ mmu_m_vis,
1046
  mmu_x_vis,
1047
  init_thinking_lm_state,
1048
  init_think_lm_btn_update,
 
1062
  examples_mmu_base,
1063
  examples_mmu_mixcot,
1064
  examples_mmu_max,
1065
+ thinking_mode_lm,
1066
+ think_button_lm,
1067
+ thinking_mode_mmu,
1068
+ think_button_mmu
1069
  ],
1070
  queue=True
1071
  )