asigalov61 commited on
Commit
e02c3f7
·
verified ·
1 Parent(s): b4bb336

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -7
app.py CHANGED
@@ -111,6 +111,8 @@ def load_model(model_selector):
111
  print('Model will use', dtype, 'precision...')
112
  print('=' * 70)
113
 
 
 
114
  #==================================================================================
115
 
116
  def load_midi(input_midi):
@@ -195,7 +197,8 @@ def generate_music(prime,
195
  num_mem_tokens,
196
  num_gen_batches,
197
  model_temperature,
198
- # model_sampling_top_p
 
199
  ):
200
 
201
  if not prime:
@@ -203,6 +206,8 @@ def generate_music(prime,
203
 
204
  else:
205
  inputs = prime[-num_mem_tokens:]
 
 
206
 
207
  model.cuda()
208
  model.eval()
@@ -240,7 +245,8 @@ def generate_callback(input_midi,
240
  # model_sampling_top_p,
241
  final_composition,
242
  generated_batches,
243
- block_lines
 
244
  ):
245
 
246
  generated_batches = []
@@ -255,7 +261,8 @@ def generate_callback(input_midi,
255
  num_mem_tokens,
256
  NUM_OUT_BATCHES,
257
  model_temperature,
258
- # model_sampling_top_p
 
259
  )
260
 
261
  outputs = []
@@ -310,7 +317,8 @@ def generate_callback_wrapper(input_midi,
310
  final_composition,
311
  generated_batches,
312
  block_lines,
313
- model_selector
 
314
  ):
315
 
316
  print('=' * 70)
@@ -325,7 +333,8 @@ def generate_callback_wrapper(input_midi,
325
 
326
  print('Selected model type:', model_selector)
327
 
328
- load_model(model_selector)
 
329
 
330
  print('Num prime tokens:', num_prime_tokens)
331
  print('Num gen tokens:', num_gen_tokens)
@@ -343,7 +352,8 @@ def generate_callback_wrapper(input_midi,
343
  # model_sampling_top_p,
344
  final_composition,
345
  generated_batches,
346
- block_lines
 
347
  )
348
 
349
  generated_batches = [sublist[-1] for sublist in result[0]]
@@ -481,6 +491,7 @@ with gr.Blocks() as demo:
481
  final_composition = gr.State([])
482
  generated_batches = gr.State([])
483
  block_lines = gr.State([])
 
484
 
485
  #==================================================================================
486
 
@@ -529,7 +540,8 @@ with gr.Blocks() as demo:
529
  final_composition,
530
  generated_batches,
531
  block_lines,
532
- model_selector
 
533
  ],
534
  outputs
535
  )
 
111
  print('Model will use', dtype, 'precision...')
112
  print('=' * 70)
113
 
114
+ return model
115
+
116
  #==================================================================================
117
 
118
  def load_midi(input_midi):
 
197
  num_mem_tokens,
198
  num_gen_batches,
199
  model_temperature,
200
+ # model_sampling_top_p,
201
+ model_state
202
  ):
203
 
204
  if not prime:
 
206
 
207
  else:
208
  inputs = prime[-num_mem_tokens:]
209
+
210
+ model = model_state
211
 
212
  model.cuda()
213
  model.eval()
 
245
  # model_sampling_top_p,
246
  final_composition,
247
  generated_batches,
248
+ block_lines,
249
+ model_state
250
  ):
251
 
252
  generated_batches = []
 
261
  num_mem_tokens,
262
  NUM_OUT_BATCHES,
263
  model_temperature,
264
+ # model_sampling_top_p,
265
+ model_state
266
  )
267
 
268
  outputs = []
 
317
  final_composition,
318
  generated_batches,
319
  block_lines,
320
+ model_selector,
321
+ model_state
322
  ):
323
 
324
  print('=' * 70)
 
333
 
334
  print('Selected model type:', model_selector)
335
 
336
+ if not model_State:
337
+ model_state = load_model(model_selector)
338
 
339
  print('Num prime tokens:', num_prime_tokens)
340
  print('Num gen tokens:', num_gen_tokens)
 
352
  # model_sampling_top_p,
353
  final_composition,
354
  generated_batches,
355
+ block_lines,
356
+ model_state
357
  )
358
 
359
  generated_batches = [sublist[-1] for sublist in result[0]]
 
491
  final_composition = gr.State([])
492
  generated_batches = gr.State([])
493
  block_lines = gr.State([])
494
+ model_state = gr.State([])
495
 
496
  #==================================================================================
497
 
 
540
  final_composition,
541
  generated_batches,
542
  block_lines,
543
+ model_selector,
544
+ model_state
545
  ],
546
  outputs
547
  )