YucYux
commited on
Commit
·
0e83169
1
Parent(s):
f6ba9f2
fixed model loading bug
Browse files
app.py
CHANGED
@@ -47,14 +47,14 @@ def get_num_transfer_tokens(mask_index, steps):
|
|
47 |
return num_transfer_tokens
|
48 |
|
49 |
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
|
50 |
-
DEFAULT_MODEL_PATH = "Gen-Verse/MMaDA-8B-MixCoT"
|
51 |
-
MASK_ID =
|
52 |
-
MODEL =
|
53 |
-
TOKENIZER =
|
54 |
-
uni_prompting =
|
55 |
-
VQ_MODEL =
|
56 |
|
57 |
-
CURRENT_MODEL_PATH =
|
58 |
|
59 |
MODEL_CHOICES = [
|
60 |
"MMaDA-8B-Base",
|
@@ -1021,9 +1021,15 @@ with gr.Blocks(css=css_styles, theme=theme) as demo:
|
|
1021 |
)
|
1022 |
|
1023 |
def initialize_app_state():
|
1024 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1025 |
|
1026 |
-
# handle_model_selection_change 现在返回更多项
|
1027 |
status, lm_b_vis, lm_m_vis, lm_x_vis, \
|
1028 |
mmu_b_vis, mmu_m_vis, mmu_x_vis, \
|
1029 |
init_thinking_lm_state, init_think_lm_btn_update, \
|
@@ -1031,12 +1037,12 @@ with gr.Blocks(css=css_styles, theme=theme) as demo:
|
|
1031 |
|
1032 |
return (
|
1033 |
default_model_choice,
|
1034 |
-
status,
|
1035 |
-
lm_b_vis,
|
1036 |
-
lm_m_vis,
|
1037 |
-
lm_x_vis,
|
1038 |
-
mmu_b_vis,
|
1039 |
-
mmu_m_vis,
|
1040 |
mmu_x_vis,
|
1041 |
init_thinking_lm_state,
|
1042 |
init_think_lm_btn_update,
|
@@ -1056,10 +1062,10 @@ with gr.Blocks(css=css_styles, theme=theme) as demo:
|
|
1056 |
examples_mmu_base,
|
1057 |
examples_mmu_mixcot,
|
1058 |
examples_mmu_max,
|
1059 |
-
thinking_mode_lm,
|
1060 |
-
think_button_lm,
|
1061 |
-
thinking_mode_mmu,
|
1062 |
-
think_button_mmu
|
1063 |
],
|
1064 |
queue=True
|
1065 |
)
|
|
|
47 |
return num_transfer_tokens
|
48 |
|
49 |
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
|
50 |
+
DEFAULT_MODEL_PATH = "Gen-Verse/MMaDA-8B-MixCoT"
|
51 |
+
MASK_ID = None # 初始化为 None
|
52 |
+
MODEL = None # 初始化为 None
|
53 |
+
TOKENIZER = None# 初始化为 None
|
54 |
+
uni_prompting = None # 初始化为 None
|
55 |
+
VQ_MODEL = None # 初始化为 None, 稍后在初始化函数中加载
|
56 |
|
57 |
+
CURRENT_MODEL_PATH = None # 初始化为 None
|
58 |
|
59 |
MODEL_CHOICES = [
|
60 |
"MMaDA-8B-Base",
|
|
|
1021 |
)
|
1022 |
|
1023 |
def initialize_app_state():
|
1024 |
+
global VQ_MODEL
|
1025 |
+
|
1026 |
+
if VQ_MODEL is None:
|
1027 |
+
print("Loading VQ_MODEL for the first time...")
|
1028 |
+
VQ_MODEL = MAGVITv2().from_pretrained("showlab/magvitv2").to(DEVICE)
|
1029 |
+
print("VQ_MODEL loaded.")
|
1030 |
+
|
1031 |
+
default_model_choice = "MMaDA-8B-MixCoT"
|
1032 |
|
|
|
1033 |
status, lm_b_vis, lm_m_vis, lm_x_vis, \
|
1034 |
mmu_b_vis, mmu_m_vis, mmu_x_vis, \
|
1035 |
init_thinking_lm_state, init_think_lm_btn_update, \
|
|
|
1037 |
|
1038 |
return (
|
1039 |
default_model_choice,
|
1040 |
+
status,
|
1041 |
+
lm_b_vis,
|
1042 |
+
lm_m_vis,
|
1043 |
+
lm_x_vis,
|
1044 |
+
mmu_b_vis,
|
1045 |
+
mmu_m_vis,
|
1046 |
mmu_x_vis,
|
1047 |
init_thinking_lm_state,
|
1048 |
init_think_lm_btn_update,
|
|
|
1062 |
examples_mmu_base,
|
1063 |
examples_mmu_mixcot,
|
1064 |
examples_mmu_max,
|
1065 |
+
thinking_mode_lm,
|
1066 |
+
think_button_lm,
|
1067 |
+
thinking_mode_mmu,
|
1068 |
+
think_button_mmu
|
1069 |
],
|
1070 |
queue=True
|
1071 |
)
|