asigalov61 commited on
Commit
2ec9c41
·
verified ·
1 Parent(s): 3eaad6d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -12
app.py CHANGED
@@ -115,7 +115,7 @@ def load_model(model_selector):
115
 
116
  #==================================================================================
117
 
118
- def load_midi(input_midi):
119
 
120
  raw_score = TMIDIX.midi2single_track_ms_score(input_midi.name)
121
 
@@ -143,7 +143,7 @@ def load_midi(input_midi):
143
 
144
  #==================================================================================
145
 
146
- def save_midi(tokens, batch_number=None):
147
 
148
  song = tokens
149
  song_f = []
@@ -253,7 +253,7 @@ def generate_callback(input_midi,
253
 
254
  if not final_composition and input_midi is not None:
255
  final_composition = load_midi(input_midi)[:num_prime_tokens]
256
- midi_score = save_midi(final_composition)
257
  block_lines.append(midi_score[-1][1] / 1000)
258
 
259
  batched_gen_tokens = generate_music(final_composition,
@@ -275,7 +275,7 @@ def generate_callback(input_midi,
275
  tokens_preview = final_composition[-PREVIEW_LENGTH:]
276
 
277
  # Save MIDI to a temporary file
278
- midi_score = save_midi(tokens_preview + tokens, i)
279
 
280
  # MIDI plot
281
 
@@ -335,6 +335,7 @@ def generate_callback_wrapper(input_midi,
335
 
336
  if not model_state:
337
  model_state = load_model(model_selector)
 
338
 
339
  print('Num prime tokens:', num_prime_tokens)
340
  print('Num gen tokens:', num_gen_tokens)
@@ -368,13 +369,13 @@ def generate_callback_wrapper(input_midi,
368
 
369
  #==================================================================================
370
 
371
- def add_batch(batch_number, final_composition, generated_batches, block_lines):
372
 
373
  if generated_batches:
374
  final_composition.extend(generated_batches[batch_number])
375
 
376
  # Save MIDI to a temporary file
377
- midi_score = save_midi(final_composition)
378
 
379
  block_lines.append(midi_score[-1][1] / 1000)
380
 
@@ -404,7 +405,7 @@ def add_batch(batch_number, final_composition, generated_batches, block_lines):
404
 
405
  #==================================================================================
406
 
407
- def remove_batch(batch_number, num_tokens, final_composition, generated_batches, block_lines):
408
 
409
  if final_composition:
410
 
@@ -413,7 +414,7 @@ def remove_batch(batch_number, num_tokens, final_composition, generated_batches,
413
  block_lines.pop()
414
 
415
  # Save MIDI to a temporary file
416
- midi_score = save_midi(final_composition)
417
 
418
  # MIDI plot
419
  midi_plot = TMIDIX.plot_ms_SONG(midi_score,
@@ -441,13 +442,14 @@ def remove_batch(batch_number, num_tokens, final_composition, generated_batches,
441
 
442
  #==================================================================================
443
 
444
- def reset(final_composition=[], generated_batches=[], block_lines=[]):
445
 
446
  final_composition = []
447
  generated_batches = []
448
  block_lines = []
 
449
 
450
- return final_composition, generated_batches, block_lines
451
 
452
  #==================================================================================
453
 
@@ -456,6 +458,7 @@ def reset_demo(final_composition=[], generated_batches=[], block_lines=[]):
456
  final_composition = []
457
  generated_batches = []
458
  block_lines = []
 
459
 
460
  #==================================================================================
461
 
@@ -557,10 +560,10 @@ with gr.Blocks() as demo:
557
  final_plot_output = gr.Plot(label="Final MIDI plot")
558
  final_file_output = gr.File(label="Final MIDI file")
559
 
560
- add_btn.click(add_batch, [batch_number, final_composition, generated_batches, block_lines],
561
  [final_audio_output, final_plot_output, final_file_output, final_composition, generated_batches, block_lines])
562
 
563
- remove_btn.click(remove_batch, [batch_number, num_gen_tokens, final_composition, generated_batches, block_lines],
564
  [final_audio_output, final_plot_output, final_file_output, final_composition, generated_batches, block_lines])
565
 
566
  demo.unload(reset_demo)
 
115
 
116
  #==================================================================================
117
 
118
+ def load_midi(input_midi, model_selector=''):
119
 
120
  raw_score = TMIDIX.midi2single_track_ms_score(input_midi.name)
121
 
 
143
 
144
  #==================================================================================
145
 
146
+ def save_midi(tokens, batch_number=None, model_selector=''):
147
 
148
  song = tokens
149
  song_f = []
 
253
 
254
  if not final_composition and input_midi is not None:
255
  final_composition = load_midi(input_midi)[:num_prime_tokens]
256
+ midi_score = save_midi(final_composition, model_selector=model_state[2])
257
  block_lines.append(midi_score[-1][1] / 1000)
258
 
259
  batched_gen_tokens = generate_music(final_composition,
 
275
  tokens_preview = final_composition[-PREVIEW_LENGTH:]
276
 
277
  # Save MIDI to a temporary file
278
+ midi_score = save_midi(tokens_preview + tokens, i, model_selector=model_state[2])
279
 
280
  # MIDI plot
281
 
 
335
 
336
  if not model_state:
337
  model_state = load_model(model_selector)
338
+ model_state.append(model_selector)
339
 
340
  print('Num prime tokens:', num_prime_tokens)
341
  print('Num gen tokens:', num_gen_tokens)
 
369
 
370
  #==================================================================================
371
 
372
+ def add_batch(batch_number, final_composition, generated_batches, block_lines, model_state=[]):
373
 
374
  if generated_batches:
375
  final_composition.extend(generated_batches[batch_number])
376
 
377
  # Save MIDI to a temporary file
378
+ midi_score = save_midi(final_composition, model_selector=model_state[2])
379
 
380
  block_lines.append(midi_score[-1][1] / 1000)
381
 
 
405
 
406
  #==================================================================================
407
 
408
+ def remove_batch(batch_number, num_tokens, final_composition, generated_batches, block_lines, model_state=[]):
409
 
410
  if final_composition:
411
 
 
414
  block_lines.pop()
415
 
416
  # Save MIDI to a temporary file
417
+ midi_score = save_midi(final_composition, model_selector=model_state[2])
418
 
419
  # MIDI plot
420
  midi_plot = TMIDIX.plot_ms_SONG(midi_score,
 
442
 
443
  #==================================================================================
444
 
445
+ def reset(final_composition=[], generated_batches=[], block_lines=[], model_state = []):
446
 
447
  final_composition = []
448
  generated_batches = []
449
  block_lines = []
450
+ model_state = []
451
 
452
+ return final_composition, generated_batches, block_lines, model_state
453
 
454
  #==================================================================================
455
 
 
458
  final_composition = []
459
  generated_batches = []
460
  block_lines = []
461
+ model_state = []
462
 
463
  #==================================================================================
464
 
 
560
  final_plot_output = gr.Plot(label="Final MIDI plot")
561
  final_file_output = gr.File(label="Final MIDI file")
562
 
563
+ add_btn.click(add_batch, [batch_number, final_composition, generated_batches, block_lines, model_state],
564
  [final_audio_output, final_plot_output, final_file_output, final_composition, generated_batches, block_lines])
565
 
566
+ remove_btn.click(remove_batch, [batch_number, num_gen_tokens, final_composition, generated_batches, block_lines, model_state],
567
  [final_audio_output, final_plot_output, final_file_output, final_composition, generated_batches, block_lines])
568
 
569
  demo.unload(reset_demo)