tnk2908 commited on
Commit
bf84082
·
1 Parent(s): b7e7adb

Add model checkpoints

Browse files
Files changed (1) hide show
  1. app.py +13 -4
app.py CHANGED
@@ -75,7 +75,11 @@ specialist_model, specialist_processor = build_specialist_model()
75
 
76
 
77
  def load_specialist_model(model_ckpt):
78
- specialist_model.load_state_dict(torch.load(model_ckpt, map_location=torch.device("cpu"), weights_only=True))
 
 
 
 
79
 
80
 
81
  def get_feature_dict(batch_size, device, active_dataset: ActiveDataset):
@@ -169,7 +173,11 @@ def build_input_ui():
169
  def build_parameters_ui():
170
  with gr.Accordion() as blk:
171
  budget_input = gr.Number(config.budget, label="Budget")
172
- model_ckpt_input = gr.Text(config.model_ckpt, label="Specialist Model Checkpoint")
 
 
 
 
173
  device_input = gr.Dropdown(choices=["cuda", "cpu"], value="cpu", label="Device", interactive=True)
174
  batch_size_input = gr.Number(config.batch_size, label="Batch Size")
175
  foundation_model_weight_input = gr.Number(config.loaded_feature_weight, label="foundation_model_weight")
@@ -180,10 +188,11 @@ def build_parameters_ui():
180
 
181
  budget_input.change(budget_input_change, budget_input, None)
182
 
183
- def model_ckpt_input_change(x):
184
  config.model_ckpt = x
 
185
 
186
- model_ckpt_input.change(model_ckpt_input_change, model_ckpt_input, None)
187
 
188
  def device_input_change(x):
189
  config.device = torch.device(x)
 
75
 
76
 
77
  def load_specialist_model(model_ckpt):
78
+ try:
79
+ specialist_model.load_state_dict(torch.load(model_ckpt, map_location=torch.device("cpu"), weights_only=True))
80
+ print(f"Loaded {model_ckpt}")
81
+ except:
82
+ print(f"Failed to load {model_ckpt}")
83
 
84
 
85
  def get_feature_dict(batch_size, device, active_dataset: ActiveDataset):
 
173
  def build_parameters_ui():
174
  with gr.Accordion() as blk:
175
  budget_input = gr.Number(config.budget, label="Budget")
176
+
177
+ with gr.Row():
178
+ model_ckpt_file = gr.File(visible=False)
179
+ model_ckpt_input = gr.UploadButton(label="Upload model checkpoint")
180
+
181
  device_input = gr.Dropdown(choices=["cuda", "cpu"], value="cpu", label="Device", interactive=True)
182
  batch_size_input = gr.Number(config.batch_size, label="Batch Size")
183
  foundation_model_weight_input = gr.Number(config.loaded_feature_weight, label="foundation_model_weight")
 
188
 
189
  budget_input.change(budget_input_change, budget_input, None)
190
 
191
+ def model_ckpt_input_upload(x):
192
  config.model_ckpt = x
193
+ return gr.File(value=x, visible=True)
194
 
195
+ model_ckpt_input.upload(model_ckpt_input_upload, model_ckpt_input, model_ckpt_file)
196
 
197
  def device_input_change(x):
198
  config.device = torch.device(x)