YucYux committed
Commit 60e176e · 1 Parent(s): d08f144

Added support for MMaDA-8B-MixCoT

Files changed (1): app.py (+290, -71)
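For context, the diff below switches the Space's default checkpoint from Gen-Verse/MMaDA-8B-Base to Gen-Verse/MMaDA-8B-MixCoT, adds MixCoT to MODEL_CHOICES and MODEL_ACTUAL_PATHS, introduces per-model example blocks, and turns thinking mode on by default for MixCoT. As a minimal sketch (not part of this commit), the new default can be loaded standalone with the same calls app.py uses; the `from models import MMadaModelLM` import path is an assumption about the surrounding repo layout, everything else mirrors the diff:

import torch
from transformers import AutoTokenizer
from models import MMadaModelLM  # assumed import path; adjust to your checkout

device = "cuda" if torch.cuda.is_available() else "cpu"
model_path = "Gen-Verse/MMaDA-8B-MixCoT"  # new default from this commit

# Same loading calls as app.py: bfloat16 weights, remote code enabled.
model = MMadaModelLM.from_pretrained(
    model_path, trust_remote_code=True, torch_dtype=torch.bfloat16
).to(device).eval()
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)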
app.py CHANGED
@@ -47,22 +47,23 @@ def get_num_transfer_tokens(mask_index, steps):
     return num_transfer_tokens
 
 DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
-DEFAULT_MODEL_PATH = "Gen-Verse/MMaDA-8B-Base" # Default
+DEFAULT_MODEL_PATH = "Gen-Verse/MMaDA-8B-MixCoT" # Default
 MASK_ID = 126336
 MODEL = MMadaModelLM.from_pretrained(DEFAULT_MODEL_PATH, trust_remote_code=True, torch_dtype=torch.bfloat16).to(DEVICE).eval()
 TOKENIZER = AutoTokenizer.from_pretrained(DEFAULT_MODEL_PATH, trust_remote_code=True)
 uni_prompting = UniversalPrompting(TOKENIZER, max_text_len=512, special_tokens=("<|soi|>", "<|eoi|>", "<|sov|>", "<|eov|>", "<|t2i|>", "<|mmu|>", "<|t2v|>", "<|v2v|>", "<|lvg|>"), ignore_id=-100, cond_dropout_prob=0.1, use_reserved_token=True)
 VQ_MODEL = MAGVITv2().from_pretrained("showlab/magvitv2").to(DEVICE)
 
-CURRENT_MODEL_PATH = None
+CURRENT_MODEL_PATH = DEFAULT_MODEL_PATH
 
 MODEL_CHOICES = [
     "MMaDA-8B-Base",
-    "MMaDA-8B-MixCoT (coming soon)",
+    "MMaDA-8B-MixCoT",
     "MMaDA-8B-Max (coming soon)"
 ]
 MODEL_ACTUAL_PATHS = {
-    "MMaDA-8B-Base": DEFAULT_MODEL_PATH,
+    "MMaDA-8B-Base": "Gen-Verse/MMaDA-8B-Base",
+    "MMaDA-8B-MixCoT": "Gen-Verse/MMaDA-8B-MixCoT"
 }
 
 def clear_outputs_action():
@@ -116,19 +117,91 @@ def _load_model_and_tokenizer_core(model_path_to_load, model_display_name_for_status):
     # return f"Error loading model '{model_display_name_for_status}': {str(e)}"
 
 def handle_model_selection_change(selected_model_name_ui):
-    if "coming soon" in selected_model_name_ui.lower():
-        global MODEL, TOKENIZER, MASK_ID, CURRENT_MODEL_PATH
+    global MODEL, TOKENIZER, MASK_ID, CURRENT_MODEL_PATH, DEVICE, uni_prompting
+
+    status_msg = ""
+    # Initialize the visibility updates for the Examples blocks
+    vis_lm_base = gr.update(visible=False)
+    vis_lm_mixcot = gr.update(visible=False)
+    vis_lm_max = gr.update(visible=False)
+    vis_mmu_base = gr.update(visible=False)
+    vis_mmu_mixcot = gr.update(visible=False)
+    vis_mmu_max = gr.update(visible=False)
+
+    # Decide the default thinking-mode state from the selected model
+    is_mixcot_model_selected = (selected_model_name_ui == "MMaDA-8B-MixCoT")
+
+    # Initial thinking-mode state and button labels:
+    # defaults to True (enabled) for the MixCoT model
+    current_thinking_mode_lm_state = is_mixcot_model_selected
+    current_thinking_mode_mmu_state = is_mixcot_model_selected
+
+    lm_think_button_label = "Thinking Mode ✅" if current_thinking_mode_lm_state else "Thinking Mode ❌"
+    mmu_think_button_label = "Thinking Mode ✅" if current_thinking_mode_mmu_state else "Thinking Mode ❌"
+
+    update_think_button_lm = gr.update(value=lm_think_button_label)
+    update_think_button_mmu = gr.update(value=mmu_think_button_label)
+
+    if selected_model_name_ui == "MMaDA-8B-Max (coming soon)":
         MODEL = None
         TOKENIZER = None
         MASK_ID = None
         CURRENT_MODEL_PATH = None
-        return f"'{selected_model_name_ui}' is not yet available. Please select 'Model A'."
-
-    actual_path = MODEL_ACTUAL_PATHS.get(selected_model_name_ui)
-    if not actual_path:
-        return f"Path for '{selected_model_name_ui}' is not defined. Cannot load."
-
-    return _load_model_and_tokenizer_core(actual_path, selected_model_name_ui)
+        status_msg = f"'{selected_model_name_ui}' is not yet available. Please select another model."
+        vis_lm_max = gr.update(visible=True)
+        vis_mmu_max = gr.update(visible=True)
+        # For non-MixCoT models, thinking mode was already set to False above via is_mixcot_model_selected
+    else:
+        actual_path = MODEL_ACTUAL_PATHS.get(selected_model_name_ui)
+        if not actual_path:
+            MODEL = None
+            TOKENIZER = None
+            MASK_ID = None
+            CURRENT_MODEL_PATH = None
+            status_msg = f"Path for '{selected_model_name_ui}' is not defined. Cannot load."
+            # If the path is undefined (meaning no valid MixCoT load), thinking mode should be False
+            if is_mixcot_model_selected:  # it should have been MixCoT, but the path is missing
+                current_thinking_mode_lm_state = False
+                current_thinking_mode_mmu_state = False
+                update_think_button_lm = gr.update(value="Thinking Mode ❌")
+                update_think_button_mmu = gr.update(value="Thinking Mode ❌")
+        else:
+            # Try to load the model
+            status_msg = _load_model_and_tokenizer_core(actual_path, selected_model_name_ui)
+
+            # Check whether the model loaded successfully
+            if "Error loading model" in status_msg or MODEL is None:
+                # If the MixCoT model failed to load, turn thinking mode back off
+                if is_mixcot_model_selected:
+                    current_thinking_mode_lm_state = False
+                    current_thinking_mode_mmu_state = False
+                    update_think_button_lm = gr.update(value="Thinking Mode ❌")
+                    update_think_button_mmu = gr.update(value="Thinking Mode ❌")
+                if MODEL is None and "Error" not in status_msg:  # add a generic error message
+                    status_msg = f"Failed to properly load model '{selected_model_name_ui}'. {status_msg}"
+            else:  # model loaded successfully
+                if selected_model_name_ui == "MMaDA-8B-Base":
+                    vis_lm_base = gr.update(visible=True)
+                    vis_mmu_base = gr.update(visible=True)
+                elif selected_model_name_ui == "MMaDA-8B-MixCoT":
+                    vis_lm_mixcot = gr.update(visible=True)
+                    vis_mmu_mixcot = gr.update(visible=True)
+                # Thinking mode was already set to True at the top via is_mixcot_model_selected
+
+    return (
+        status_msg,
+        vis_lm_base,
+        vis_lm_mixcot,
+        vis_lm_max,
+        vis_mmu_base,
+        vis_mmu_mixcot,
+        vis_mmu_max,
+        # New return values for updating the thinking-mode states and buttons
+        current_thinking_mode_lm_state,   # plain value for gr.State
+        update_think_button_lm,           # gr.update object for gr.Button
+        current_thinking_mode_mmu_state,
+        update_think_button_mmu
+    )
 
 
 def get_highlighted_text_tuples(current_x_ids_batch, prompt_input_ids, prompt_len, tk, current_mask_id, raw_prompt_attention_mask):
@@ -618,7 +691,7 @@ with gr.Blocks(css=css_styles, theme=theme) as demo:
         model_select_radio = gr.Radio(
             label="Select Text Generation Model",
             choices=MODEL_CHOICES,
-            value=MODEL_CHOICES[0]
+            value="MMaDA-8B-MixCoT"
         )
         model_load_status_box = gr.Textbox(
             label="Model Load Status",
@@ -662,17 +735,39 @@ with gr.Blocks(css=css_styles, theme=theme) as demo:
         output_final_text_box_lm = gr.Textbox(label="Final Output", lines=8, interactive=False, show_copy_button=True)
 
 
-
-        gr.Examples(
-            examples=[
-                ["A rectangular prism has a length of 5 units, a width of 4 units, and a height of 3 units. What is the volume of the prism?", 256, 512, 128, 1, 0, "low_confidence"],
-                ["Lily can run 12 kilometers per hour for 4 hours. After that, she can run 6 kilometers per hour. How many kilometers can she run in 8 hours?", 256, 512, 64, 1, 0, "low_confidence"]
-            ],
-            inputs=[prompt_input_box_lm, steps_slider_lm, gen_length_slider_lm, block_length_slider_lm, temperature_slider_lm, cfg_scale_slider_lm, remasking_dropdown_lm],
-            outputs=[output_visualization_box_lm, output_final_text_box_lm],
-            fn=generate_viz_wrapper_lm,
-            cache_examples=False
-        )
+        with gr.Column(visible=False) as examples_lm_base:
+            gr.Examples(
+                examples=[
+                    ["A rectangular prism has a length of 5 units, a width of 4 units, and a height of 3 units. What is the volume of the prism?", 256, 512, 128, 1, 0, "low_confidence"],
+                    ["Lily can run 12 kilometers per hour for 4 hours. After that, she can run 6 kilometers per hour. How many kilometers can she run in 8 hours?", 256, 512, 64, 1, 0, "low_confidence"]
+                ],
+                inputs=[prompt_input_box_lm, steps_slider_lm, gen_length_slider_lm, block_length_slider_lm, temperature_slider_lm, cfg_scale_slider_lm, remasking_dropdown_lm],
+                outputs=[output_visualization_box_lm, output_final_text_box_lm],
+                fn=generate_viz_wrapper_lm,
+                cache_examples=False
+            )
+        with gr.Column(visible=True) as examples_lm_mixcot:
+            gr.Examples(
+                examples=[
+                    ["A rectangular prism has a length of 5 units, a width of 4 units, and a height of 3 units. What is the volume of the prism?", 256, 512, 128, 1, 0, "low_confidence"],
+                    ["Lily can run 12 kilometers per hour for 4 hours. After that, she can run 6 kilometers per hour. How many kilometers can she run in 8 hours?", 256, 512, 64, 1, 0, "low_confidence"]
+                ],
+                inputs=[prompt_input_box_lm, steps_slider_lm, gen_length_slider_lm, block_length_slider_lm, temperature_slider_lm, cfg_scale_slider_lm, remasking_dropdown_lm],
+                outputs=[output_visualization_box_lm, output_final_text_box_lm],
+                fn=generate_viz_wrapper_lm,
+                cache_examples=False
+            )
+        with gr.Column(visible=False) as examples_lm_max:
+            gr.Examples(
+                examples=[
+                    ["A rectangular prism has a length of 5 units, a width of 4 units, and a height of 3 units. What is the volume of the prism?", 256, 512, 128, 1, 0, "low_confidence"],
+                    ["Lily can run 12 kilometers per hour for 4 hours. After that, she can run 6 kilometers per hour. How many kilometers can she run in 8 hours?", 256, 512, 64, 1, 0, "low_confidence"]
+                ],
+                inputs=[prompt_input_box_lm, steps_slider_lm, gen_length_slider_lm, block_length_slider_lm, temperature_slider_lm, cfg_scale_slider_lm, remasking_dropdown_lm],
+                outputs=[output_visualization_box_lm, output_final_text_box_lm],
+                fn=generate_viz_wrapper_lm,
+                cache_examples=False
+            )
 
         gr.Markdown("---")
         gr.Markdown("## Part 2. Multimodal Understanding")
@@ -681,7 +776,7 @@ with gr.Blocks(css=css_styles, theme=theme) as demo:
             prompt_input_box_mmu = gr.Textbox(
                 label="Enter your prompt:",
                 lines=3,
-                value="Please describe this image in detail."
+                value=""
             )
             think_button_mmu = gr.Button("🧠 Enable Thinking Mode", elem_id="think_btn")
             with gr.Accordion("Generation Parameters", open=True):
@@ -689,7 +784,7 @@ with gr.Blocks(css=css_styles, theme=theme) as demo:
                 gen_length_slider_mmu = gr.Slider(minimum=64, maximum=1024, value=512, step=64, label="Generation Length", info="Number of tokens to generate.")
                 steps_slider_mmu = gr.Slider(minimum=1, maximum=512, value=256, step=32, label="Total Sampling Steps", info="Must be divisible by (gen_length / block_length).")
                 with gr.Row():
-                    block_length_slider_mmu = gr.Slider(minimum=32, maximum=1024, value=128, step=32, label="Block Length", info="gen_length must be divisible by this.")
+                    block_length_slider_mmu = gr.Slider(minimum=32, maximum=1024, value=64, step=32, label="Block Length", info="gen_length must be divisible by this.")
                     remasking_dropdown_mmu = gr.Dropdown(choices=['low_confidence', 'random'], value='low_confidence', label="Remasking Strategy")
                 with gr.Row():
                     cfg_scale_slider_mmu = gr.Slider(minimum=0.0, maximum=2.0, value=0.0, step=0.1, label="CFG Scale", info="Classifier-Free Guidance. 0 disables it.")
@@ -715,44 +810,120 @@ with gr.Blocks(css=css_styles, theme=theme) as demo:
         gr.Markdown("## Final Generated Text")
         output_final_text_box_mmu = gr.Textbox(label="Final Output", lines=8, interactive=False, show_copy_button=True)
 
-
-        gr.Examples(
-            examples=[
-                [
-                    "figs/sunflower.jpg",
-                    "Please describe this image in detail.",
-                    256,
-                    512,
-                    128,
-                    1,
-                    0,
-                    "low_confidence"
-                ],
-                [
-                    "figs/woman.jpg",
-                    "Please describe this image in detail.",
-                    256,
-                    512,
-                    128,
-                    1,
-                    0,
-                    "low_confidence"
-                ]
-            ],
-            inputs=[
-                image_upload_box,
-                prompt_input_box_mmu,
-                steps_slider_mmu,
-                gen_length_slider_mmu,
-                block_length_slider_mmu,
-                temperature_slider_mmu,
-                cfg_scale_slider_mmu,
-                remasking_dropdown_mmu
-            ],
-            outputs=[output_visualization_box_mmu, output_final_text_box_mmu],
-            fn=generate_viz_wrapper,
-            cache_examples=False
-        )
+        with gr.Column(visible=False) as examples_mmu_base:
+            gr.Examples(
+                examples=[
+                    [
+                        "figs/sunflower.jpg",
+                        "Please describe this image in detail.",
+                        256,
+                        512,
+                        128,
+                        1,
+                        0,
+                        "low_confidence"
+                    ],
+                    [
+                        "figs/woman.jpg",
+                        "Please describe this image in detail.",
+                        256,
+                        512,
+                        128,
+                        1,
+                        0,
+                        "low_confidence"
+                    ]
+                ],
+                inputs=[
+                    image_upload_box,
+                    prompt_input_box_mmu,
+                    steps_slider_mmu,
+                    gen_length_slider_mmu,
+                    block_length_slider_mmu,
+                    temperature_slider_mmu,
+                    cfg_scale_slider_mmu,
+                    remasking_dropdown_mmu
+                ],
+                outputs=[output_visualization_box_mmu, output_final_text_box_mmu],
+                fn=generate_viz_wrapper,
+                cache_examples=False
+            )
+        with gr.Column(visible=True) as examples_mmu_mixcot:
+            gr.Examples(
+                examples=[
+                    [
+                        "figs/geo.png",
+                        "In the given figure, a square ABCD is inscribed in a circle with center O. Point P is located on side CD. What is the value of angle APB?",
+                        256,
+                        512,
+                        64,
+                        1,
+                        0,
+                        "low_confidence"
+                    ],
+                    [
+                        "figs/bus.jpg",
+                        "What are the colors of the bus?",
+                        256,
+                        512,
+                        64,
+                        1,
+                        0,
+                        "low_confidence"
+                    ]
+                ],
+                inputs=[
+                    image_upload_box,
+                    prompt_input_box_mmu,
+                    steps_slider_mmu,
+                    gen_length_slider_mmu,
+                    block_length_slider_mmu,
+                    temperature_slider_mmu,
+                    cfg_scale_slider_mmu,
+                    remasking_dropdown_mmu
+                ],
+                outputs=[output_visualization_box_mmu, output_final_text_box_mmu],
+                fn=generate_viz_wrapper,
+                cache_examples=False
+            )
+        with gr.Column(visible=False) as examples_mmu_max:
+            gr.Examples(
+                examples=[
+                    [
+                        "figs/sunflower.jpg",
+                        "Please describe this image in detail.",
+                        256,
+                        512,
+                        128,
+                        1,
+                        0,
+                        "low_confidence"
+                    ],
+                    [
+                        "figs/woman.jpg",
+                        "Please describe this image in detail.",
+                        256,
+                        512,
+                        128,
+                        1,
+                        0,
+                        "low_confidence"
+                    ]
+                ],
+                inputs=[
+                    image_upload_box,
+                    prompt_input_box_mmu,
+                    steps_slider_mmu,
+                    gen_length_slider_mmu,
+                    block_length_slider_mmu,
+                    temperature_slider_mmu,
+                    cfg_scale_slider_mmu,
+                    remasking_dropdown_mmu
+                ],
+                outputs=[output_visualization_box_mmu, output_final_text_box_mmu],
+                fn=generate_viz_wrapper,
+                cache_examples=False
+            )
 
         gr.Markdown("---")
         gr.Markdown("## Part 3. Text-to-Image Generation")
@@ -823,21 +994,69 @@ with gr.Blocks(css=css_styles, theme=theme) as demo:
         inputs=[thinking_mode_mmu],
         outputs=[thinking_mode_mmu, think_button_mmu]
     )
-
 
 
-    def initialize_default_model():
-        default_model = "MMaDA-8B-Base"
-        result = handle_model_selection_change(default_model)
-        return default_model, result
+    def initialize_app_state():
+        default_model_choice = "MMaDA-8B-MixCoT"  # load MixCoT by default
+
+        # handle_model_selection_change now returns more items
+        status, lm_b_vis, lm_m_vis, lm_x_vis, \
+        mmu_b_vis, mmu_m_vis, mmu_x_vis, \
+        init_thinking_lm_state, init_think_lm_btn_update, \
+        init_thinking_mmu_state, init_think_mmu_btn_update = handle_model_selection_change(default_model_choice)
+
+        return (
+            default_model_choice,
+            status,
+            lm_b_vis,
+            lm_m_vis,
+            lm_x_vis,
+            mmu_b_vis,
+            mmu_m_vis,
+            mmu_x_vis,
+            init_thinking_lm_state,
+            init_think_lm_btn_update,
+            init_thinking_mmu_state,
+            init_think_mmu_btn_update
+        )
 
     demo.load(
-        fn=initialize_default_model,
+        fn=initialize_app_state,
         inputs=None,
-        outputs=[model_select_radio, model_load_status_box],
+        outputs=[
+            model_select_radio,
+            model_load_status_box,
+            examples_lm_base,
+            examples_lm_mixcot,
+            examples_lm_max,
+            examples_mmu_base,
+            examples_mmu_mixcot,
+            examples_mmu_max,
+            thinking_mode_lm,   # gr.State for LM thinking mode
+            think_button_lm,    # gr.Button for LM thinking mode
+            thinking_mode_mmu,  # gr.State for MMU thinking mode
+            think_button_mmu    # gr.Button for MMU thinking mode
+        ],
         queue=True
     )
 
+    model_select_radio.change(
+        fn=handle_model_selection_change,
+        inputs=[model_select_radio],
+        outputs=[
+            model_load_status_box,
+            examples_lm_base,
+            examples_lm_mixcot,
+            examples_lm_max,
+            examples_mmu_base,
+            examples_mmu_mixcot,
+            examples_mmu_max,
+            thinking_mode_lm,
+            think_button_lm,
+            thinking_mode_mmu,
+            think_button_mmu
+        ]
+    )
+
     def clear_outputs():
         return None, None, None # Clear image, visualization, and final text
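A note on the event wiring above: Gradio assigns a handler's returned tuple to its `outputs` list positionally, so `handle_model_selection_change` returns exactly the eleven values that `model_select_radio.change` lists, and `initialize_app_state` prepends `default_model_choice` because `demo.load` has `model_select_radio` as an extra first output. A self-contained sketch of that contract, with hypothetical component names (not from app.py):

import gradio as gr

with gr.Blocks() as sketch:
    box_a = gr.Textbox(label="A")
    box_b = gr.Textbox(label="B")

    def fill():
        # Returned values map to `outputs` by position:
        # the first goes to box_a, the second to box_b.
        return "value for A", "value for B"

    # Mirrors the demo.load wiring in the commit: the function's
    # return order must match the outputs list order.
    sketch.load(fn=fill, inputs=None, outputs=[box_a, box_b])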