Spaces:
Sleeping
Sleeping
Add model checkpoints
Browse files
app.py
CHANGED
@@ -75,7 +75,11 @@ specialist_model, specialist_processor = build_specialist_model()
|
|
75 |
|
76 |
|
77 |
def load_specialist_model(model_ckpt):
|
78 |
-
|
|
|
|
|
|
|
|
|
79 |
|
80 |
|
81 |
def get_feature_dict(batch_size, device, active_dataset: ActiveDataset):
|
@@ -169,7 +173,11 @@ def build_input_ui():
|
|
169 |
def build_parameters_ui():
|
170 |
with gr.Accordion() as blk:
|
171 |
budget_input = gr.Number(config.budget, label="Budget")
|
172 |
-
|
|
|
|
|
|
|
|
|
173 |
device_input = gr.Dropdown(choices=["cuda", "cpu"], value="cpu", label="Device", interactive=True)
|
174 |
batch_size_input = gr.Number(config.batch_size, label="Batch Size")
|
175 |
foundation_model_weight_input = gr.Number(config.loaded_feature_weight, label="foundation_model_weight")
|
@@ -180,10 +188,11 @@ def build_parameters_ui():
|
|
180 |
|
181 |
budget_input.change(budget_input_change, budget_input, None)
|
182 |
|
183 |
-
def
|
184 |
config.model_ckpt = x
|
|
|
185 |
|
186 |
-
model_ckpt_input.
|
187 |
|
188 |
def device_input_change(x):
|
189 |
config.device = torch.device(x)
|
|
|
75 |
|
76 |
|
77 |
def load_specialist_model(model_ckpt):
|
78 |
+
try:
|
79 |
+
specialist_model.load_state_dict(torch.load(model_ckpt, map_location=torch.device("cpu"), weights_only=True))
|
80 |
+
print(f"Loaded {model_ckpt}")
|
81 |
+
except:
|
82 |
+
print(f"Failed to load {model_ckpt}")
|
83 |
|
84 |
|
85 |
def get_feature_dict(batch_size, device, active_dataset: ActiveDataset):
|
|
|
173 |
def build_parameters_ui():
|
174 |
with gr.Accordion() as blk:
|
175 |
budget_input = gr.Number(config.budget, label="Budget")
|
176 |
+
|
177 |
+
with gr.Row():
|
178 |
+
model_ckpt_file = gr.File(visible=False)
|
179 |
+
model_ckpt_input = gr.UploadButton(label="Upload model checkpoint")
|
180 |
+
|
181 |
device_input = gr.Dropdown(choices=["cuda", "cpu"], value="cpu", label="Device", interactive=True)
|
182 |
batch_size_input = gr.Number(config.batch_size, label="Batch Size")
|
183 |
foundation_model_weight_input = gr.Number(config.loaded_feature_weight, label="foundation_model_weight")
|
|
|
188 |
|
189 |
budget_input.change(budget_input_change, budget_input, None)
|
190 |
|
191 |
+
def model_ckpt_input_upload(x):
|
192 |
config.model_ckpt = x
|
193 |
+
return gr.File(value=x, visible=True)
|
194 |
|
195 |
+
model_ckpt_input.upload(model_ckpt_input_upload, model_ckpt_input, model_ckpt_file)
|
196 |
|
197 |
def device_input_change(x):
|
198 |
config.device = torch.device(x)
|