Spaces:
Running
on
Zero
Running
on
Zero
asigalov61
commited on
Update app.py
Browse files
app.py
CHANGED
@@ -111,6 +111,8 @@ def load_model(model_selector):
|
|
111 |
print('Model will use', dtype, 'precision...')
|
112 |
print('=' * 70)
|
113 |
|
|
|
|
|
114 |
#==================================================================================
|
115 |
|
116 |
def load_midi(input_midi):
|
@@ -195,7 +197,8 @@ def generate_music(prime,
|
|
195 |
num_mem_tokens,
|
196 |
num_gen_batches,
|
197 |
model_temperature,
|
198 |
-
# model_sampling_top_p
|
|
|
199 |
):
|
200 |
|
201 |
if not prime:
|
@@ -203,6 +206,8 @@ def generate_music(prime,
|
|
203 |
|
204 |
else:
|
205 |
inputs = prime[-num_mem_tokens:]
|
|
|
|
|
206 |
|
207 |
model.cuda()
|
208 |
model.eval()
|
@@ -240,7 +245,8 @@ def generate_callback(input_midi,
|
|
240 |
# model_sampling_top_p,
|
241 |
final_composition,
|
242 |
generated_batches,
|
243 |
-
block_lines
|
|
|
244 |
):
|
245 |
|
246 |
generated_batches = []
|
@@ -255,7 +261,8 @@ def generate_callback(input_midi,
|
|
255 |
num_mem_tokens,
|
256 |
NUM_OUT_BATCHES,
|
257 |
model_temperature,
|
258 |
-
# model_sampling_top_p
|
|
|
259 |
)
|
260 |
|
261 |
outputs = []
|
@@ -310,7 +317,8 @@ def generate_callback_wrapper(input_midi,
|
|
310 |
final_composition,
|
311 |
generated_batches,
|
312 |
block_lines,
|
313 |
-
model_selector
|
|
|
314 |
):
|
315 |
|
316 |
print('=' * 70)
|
@@ -325,7 +333,8 @@ def generate_callback_wrapper(input_midi,
|
|
325 |
|
326 |
print('Selected model type:', model_selector)
|
327 |
|
328 |
-
|
|
|
329 |
|
330 |
print('Num prime tokens:', num_prime_tokens)
|
331 |
print('Num gen tokens:', num_gen_tokens)
|
@@ -343,7 +352,8 @@ def generate_callback_wrapper(input_midi,
|
|
343 |
# model_sampling_top_p,
|
344 |
final_composition,
|
345 |
generated_batches,
|
346 |
-
block_lines
|
|
|
347 |
)
|
348 |
|
349 |
generated_batches = [sublist[-1] for sublist in result[0]]
|
@@ -481,6 +491,7 @@ with gr.Blocks() as demo:
|
|
481 |
final_composition = gr.State([])
|
482 |
generated_batches = gr.State([])
|
483 |
block_lines = gr.State([])
|
|
|
484 |
|
485 |
#==================================================================================
|
486 |
|
@@ -529,7 +540,8 @@ with gr.Blocks() as demo:
|
|
529 |
final_composition,
|
530 |
generated_batches,
|
531 |
block_lines,
|
532 |
-
model_selector
|
|
|
533 |
],
|
534 |
outputs
|
535 |
)
|
|
|
111 |
print('Model will use', dtype, 'precision...')
|
112 |
print('=' * 70)
|
113 |
|
114 |
+
return model
|
115 |
+
|
116 |
#==================================================================================
|
117 |
|
118 |
def load_midi(input_midi):
|
|
|
197 |
num_mem_tokens,
|
198 |
num_gen_batches,
|
199 |
model_temperature,
|
200 |
+
# model_sampling_top_p,
|
201 |
+
model_state
|
202 |
):
|
203 |
|
204 |
if not prime:
|
|
|
206 |
|
207 |
else:
|
208 |
inputs = prime[-num_mem_tokens:]
|
209 |
+
|
210 |
+
model = model_state
|
211 |
|
212 |
model.cuda()
|
213 |
model.eval()
|
|
|
245 |
# model_sampling_top_p,
|
246 |
final_composition,
|
247 |
generated_batches,
|
248 |
+
block_lines,
|
249 |
+
model_state
|
250 |
):
|
251 |
|
252 |
generated_batches = []
|
|
|
261 |
num_mem_tokens,
|
262 |
NUM_OUT_BATCHES,
|
263 |
model_temperature,
|
264 |
+
# model_sampling_top_p,
|
265 |
+
model_state
|
266 |
)
|
267 |
|
268 |
outputs = []
|
|
|
317 |
final_composition,
|
318 |
generated_batches,
|
319 |
block_lines,
|
320 |
+
model_selector,
|
321 |
+
model_state
|
322 |
):
|
323 |
|
324 |
print('=' * 70)
|
|
|
333 |
|
334 |
print('Selected model type:', model_selector)
|
335 |
|
336 |
+
if not model_State:
|
337 |
+
model_state = load_model(model_selector)
|
338 |
|
339 |
print('Num prime tokens:', num_prime_tokens)
|
340 |
print('Num gen tokens:', num_gen_tokens)
|
|
|
352 |
# model_sampling_top_p,
|
353 |
final_composition,
|
354 |
generated_batches,
|
355 |
+
block_lines,
|
356 |
+
model_state
|
357 |
)
|
358 |
|
359 |
generated_batches = [sublist[-1] for sublist in result[0]]
|
|
|
491 |
final_composition = gr.State([])
|
492 |
generated_batches = gr.State([])
|
493 |
block_lines = gr.State([])
|
494 |
+
model_state = gr.State([])
|
495 |
|
496 |
#==================================================================================
|
497 |
|
|
|
540 |
final_composition,
|
541 |
generated_batches,
|
542 |
block_lines,
|
543 |
+
model_selector,
|
544 |
+
model_state
|
545 |
],
|
546 |
outputs
|
547 |
)
|