asigalov61 commited on
Commit
8586ed2
·
verified ·
1 Parent(s): 9b24b99

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +106 -79
app.py CHANGED
@@ -60,7 +60,7 @@ SOUDFONT_PATH = 'SGM-v2.01-YamahaGrand-Guit-Bass-v2.7.sf2'
60
 
61
  NUM_OUT_BATCHES = 8
62
 
63
- PREVIEW_LENGTH = 120
64
 
65
  #==================================================================================
66
 
@@ -253,12 +253,6 @@ def save_midi(tokens, batch_number=None):
253
 
254
  return song_f
255
 
256
- #==================================================================================
257
-
258
- final_composition = []
259
- generated_batches = []
260
- block_lines = []
261
-
262
  #==================================================================================
263
 
264
  @spaces.GPU
@@ -331,14 +325,16 @@ def generate_callback(input_midi,
331
  gen_outro,
332
  gen_drums,
333
  model_temperature,
334
- model_sampling_top_p
 
 
 
335
  ):
336
 
337
- # global generated_batches
338
  generated_batches = []
339
 
340
  if not final_composition and input_midi is not None:
341
- final_composition.extend(load_midi(input_midi)[:num_prime_tokens])
342
  midi_score = save_midi(final_composition)
343
  block_lines.append(midi_score[-1][1] / 1000)
344
 
@@ -389,9 +385,9 @@ def generate_callback(input_midi,
389
  output_for_gradio=True
390
  )
391
 
392
- outputs.append(((16000, midi_audio), midi_plot, tokens))
393
 
394
- return outputs
395
 
396
  #==================================================================================
397
 
@@ -402,23 +398,23 @@ def generate_callback_wrapper(input_midi,
402
  gen_outro,
403
  gen_drums,
404
  model_temperature,
405
- model_sampling_top_p
 
 
 
406
  ):
407
 
408
  print('=' * 70)
409
- print('Req start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
410
- start_time = reqtime.time()
411
 
412
- print('=' * 70)
413
  if input_midi is not None:
414
  fn = os.path.basename(input_midi.name)
415
  fn1 = fn.split('.')[0]
416
  print('Input file name:', fn)
 
417
  print('Num prime tokens:', num_prime_tokens)
418
  print('Num gen tokens:', num_gen_tokens)
419
- print('Num mem tokens:', num_mem_tokens)
420
- print('Gen drums:', gen_drums)
421
  print('Gen outro:', gen_outro)
 
422
  print('Model temp:', model_temperature)
423
  print('Model top_p:', model_sampling_top_p)
424
  print('=' * 70)
@@ -430,10 +426,13 @@ def generate_callback_wrapper(input_midi,
430
  gen_outro,
431
  gen_drums,
432
  model_temperature,
433
- model_sampling_top_p
 
 
 
434
  )
435
 
436
- generated_batches.extend([sublist[2] for sublist in result])
437
 
438
  print('=' * 70)
439
  print('Req end time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
@@ -441,74 +440,94 @@ def generate_callback_wrapper(input_midi,
441
  print('Req execution time:', (reqtime.time() - start_time), 'sec')
442
  print('*' * 70)
443
 
444
- return tuple(item for sublist in result for item in sublist[:2])
445
 
446
  #==================================================================================
447
 
448
- def add_batch(batch_number):
449
-
450
- final_composition.extend(generated_batches[batch_number])
451
 
452
- # Save MIDI to a temporary file
453
- midi_score = save_midi(final_composition)
454
- block_lines.append(midi_score[-1][1] / 1000)
455
 
456
- # MIDI plot
457
- midi_plot = TMIDIX.plot_ms_SONG(midi_score,
458
- plot_title='Giant Music Transformer Composition',
459
- block_lines_times_list=block_lines[:-1],
460
- return_plt=True)
461
 
462
- # File name
463
- fname = 'Giant-Music-Transformer-Music-Composition'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
464
 
465
- # Save audio to a temporary file
466
- midi_audio = midi_to_colab_audio(fname + '.mid',
467
- soundfont_path=SOUDFONT_PATH,
468
- sample_rate=16000,
469
- output_for_gradio=True
470
- )
471
 
472
- return (16000, midi_audio), midi_plot, fname+'.mid'
 
 
 
473
 
474
  #==================================================================================
475
 
476
- def remove_batch(batch_number, num_tokens):
477
 
478
- global final_composition
479
 
480
- if len(final_composition) > num_tokens:
481
- final_composition = final_composition[:-num_tokens]
482
- block_lines.pop()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
483
 
484
- # Save MIDI to a temporary file
485
- midi_score = save_midi(final_composition)
486
 
487
- # MIDI plot
488
- midi_plot = TMIDIX.plot_ms_SONG(midi_score,
489
- plot_title='Giant Music Transformer Composition',
490
- block_lines_times_list=block_lines[:-1],
491
- return_plt=True)
492
 
493
- # File name
494
- fname = 'Giant-Music-Transformer-Music-Composition'
495
 
496
- # Save audio to a temporary file
497
- midi_audio = midi_to_colab_audio(fname + '.mid',
498
- soundfont_path=SOUDFONT_PATH,
499
- sample_rate=16000,
500
- output_for_gradio=True
501
- )
502
-
503
- return (16000, midi_audio), midi_plot, fname+'.mid'
504
 
 
 
505
  #==================================================================================
506
 
507
- def reset():
508
-
509
- global final_composition
510
- global generated_batches
511
- global block_lines
512
 
513
  final_composition = []
514
  generated_batches = []
@@ -545,10 +564,17 @@ with gr.Blocks() as demo:
545
 
546
  #==================================================================================
547
 
 
 
 
 
 
 
548
  gr.Markdown("## Upload seed MIDI or click 'Generate' button for random output")
549
 
550
  input_midi = gr.File(label="Input MIDI", file_types=[".midi", ".mid", ".kar"])
551
- input_midi.upload(reset)
 
552
 
553
  gr.Markdown("## Generate")
554
 
@@ -564,7 +590,7 @@ with gr.Blocks() as demo:
564
 
565
  gr.Markdown("## Select batch")
566
 
567
- outputs = []
568
 
569
  for i in range(NUM_OUT_BATCHES):
570
  with gr.Tab(f"Batch # {i}") as tab:
@@ -582,7 +608,10 @@ with gr.Blocks() as demo:
582
  gen_outro,
583
  gen_drums,
584
  model_temperature,
585
- model_sampling_top_p
 
 
 
586
  ],
587
  outputs
588
  )
@@ -598,15 +627,13 @@ with gr.Blocks() as demo:
598
  final_plot_output = gr.Plot(label="Final MIDI plot")
599
  final_file_output = gr.File(label="Final MIDI file")
600
 
601
- add_btn.click(add_batch, inputs=[batch_number],
602
- outputs=[final_audio_output, final_plot_output, final_file_output]
603
- )
604
-
605
- remove_btn.click(remove_batch, inputs=[batch_number, num_gen_tokens],
606
- outputs=[final_audio_output, final_plot_output, final_file_output]
607
- )
608
 
609
- demo.unload(reset)
610
 
611
  #==================================================================================
612
 
 
60
 
61
  NUM_OUT_BATCHES = 8
62
 
63
+ PREVIEW_LENGTH = 120 # in tokens
64
 
65
  #==================================================================================
66
 
 
253
 
254
  return song_f
255
 
 
 
 
 
 
 
256
  #==================================================================================
257
 
258
  @spaces.GPU
 
325
  gen_outro,
326
  gen_drums,
327
  model_temperature,
328
+ model_sampling_top_p,
329
+ final_composition,
330
+ generated_batches,
331
+ block_lines
332
  ):
333
 
 
334
  generated_batches = []
335
 
336
  if not final_composition and input_midi is not None:
337
+ final_composition = load_midi(input_midi)[:num_prime_tokens]
338
  midi_score = save_midi(final_composition)
339
  block_lines.append(midi_score[-1][1] / 1000)
340
 
 
385
  output_for_gradio=True
386
  )
387
 
388
+ outputs.append([(16000, midi_audio), midi_plot, tokens])
389
 
390
+ return outputs, final_composition, generated_batches, block_lines
391
 
392
  #==================================================================================
393
 
 
398
  gen_outro,
399
  gen_drums,
400
  model_temperature,
401
+ model_sampling_top_p,
402
+ final_composition,
403
+ generated_batches,
404
+ block_lines
405
  ):
406
 
407
  print('=' * 70)
 
 
408
 
 
409
  if input_midi is not None:
410
  fn = os.path.basename(input_midi.name)
411
  fn1 = fn.split('.')[0]
412
  print('Input file name:', fn)
413
+
414
  print('Num prime tokens:', num_prime_tokens)
415
  print('Num gen tokens:', num_gen_tokens)
 
 
416
  print('Gen outro:', gen_outro)
417
+ print('Gen drums:', gen_drums)
418
  print('Model temp:', model_temperature)
419
  print('Model top_p:', model_sampling_top_p)
420
  print('=' * 70)
 
426
  gen_outro,
427
  gen_drums,
428
  model_temperature,
429
+ model_sampling_top_p,
430
+ final_composition,
431
+ generated_batches,
432
+ block_lines
433
  )
434
 
435
+ generated_batches = [sublist[-1] for sublist in result[0]]
436
 
437
  print('=' * 70)
438
  print('Req end time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
 
440
  print('Req execution time:', (reqtime.time() - start_time), 'sec')
441
  print('*' * 70)
442
 
443
+ return tuple([result[1], generated_batches, result[3]] + [item for sublist in result[0] for item in sublist[:-1]])
444
 
445
  #==================================================================================
446
 
447
+ def add_batch(batch_number, final_composition, generated_batches, block_lines):
 
 
448
 
449
+ if generated_batches:
450
+ final_composition.extend(generated_batches[batch_number])
 
451
 
452
+ # Save MIDI to a temporary file
453
+ midi_score = save_midi(final_composition)
 
 
 
454
 
455
+ block_lines.append(midi_score[-1][1] / 1000)
456
+
457
+ # MIDI plot
458
+ midi_plot = TMIDIX.plot_ms_SONG(midi_score,
459
+ plot_title='Giant Music Transformer Composition',
460
+ block_lines_times_list=block_lines[:-1],
461
+ return_plt=True)
462
+
463
+ # File name
464
+ fname = 'Giant-Music-Transformer-Music-Composition'
465
+
466
+ # Save audio to a temporary file
467
+ midi_audio = midi_to_colab_audio(fname + '.mid',
468
+ soundfont_path=SOUDFONT_PATH,
469
+ sample_rate=16000,
470
+ output_for_gradio=True
471
+ )
472
 
473
+ print('Added batch #', batch_number)
474
+ print('=' * 70)
 
 
 
 
475
 
476
+ return (16000, midi_audio), midi_plot, fname+'.mid', final_composition, generated_batches, block_lines
477
+
478
+ else:
479
+ return None, None, None, [], [], []
480
 
481
  #==================================================================================
482
 
483
+ def remove_batch(batch_number, num_tokens, final_composition, generated_batches, block_lines):
484
 
485
+ if final_composition:
486
 
487
+ if len(final_composition) > num_tokens:
488
+ final_composition = final_composition[:-num_tokens]
489
+ block_lines.pop()
490
+
491
+ # Save MIDI to a temporary file
492
+ midi_score = save_midi(final_composition)
493
+
494
+ # MIDI plot
495
+ midi_plot = TMIDIX.plot_ms_SONG(midi_score,
496
+ plot_title='Giant Music Transformer Composition',
497
+ block_lines_times_list=block_lines[:-1],
498
+ return_plt=True)
499
+
500
+ # File name
501
+ fname = 'Giant-Music-Transformer-Music-Composition'
502
+
503
+ # Save audio to a temporary file
504
+ midi_audio = midi_to_colab_audio(fname + '.mid',
505
+ soundfont_path=SOUDFONT_PATH,
506
+ sample_rate=16000,
507
+ output_for_gradio=True
508
+ )
509
+
510
+ print('Removed batch #', batch_number)
511
+ print('=' * 70)
512
+
513
+ return (16000, midi_audio), midi_plot, fname+'.mid', final_composition, generated_batches, block_lines
514
 
515
+ else:
516
+ return None, None, None, [], [], []
517
 
518
+ #==================================================================================
 
 
 
 
519
 
520
+ def reset(final_composition=[], generated_batches=[], block_lines=[]):
 
521
 
522
+ final_composition = []
523
+ generated_batches = []
524
+ block_lines = []
 
 
 
 
 
525
 
526
+ return final_composition, generated_batches, block_lines
527
+
528
  #==================================================================================
529
 
530
+ def reset_demo(final_composition=[], generated_batches=[], block_lines=[]):
 
 
 
 
531
 
532
  final_composition = []
533
  generated_batches = []
 
564
 
565
  #==================================================================================
566
 
567
+ final_composition = gr.State([])
568
+ generated_batches = gr.State([])
569
+ block_lines = gr.State([])
570
+
571
+ #==================================================================================
572
+
573
  gr.Markdown("## Upload seed MIDI or click 'Generate' button for random output")
574
 
575
  input_midi = gr.File(label="Input MIDI", file_types=[".midi", ".mid", ".kar"])
576
+ input_midi.upload(reset, [final_composition, generated_batches, block_lines],
577
+ [final_composition, generated_batches, block_lines])
578
 
579
  gr.Markdown("## Generate")
580
 
 
590
 
591
  gr.Markdown("## Select batch")
592
 
593
+ outputs = [final_composition, generated_batches, block_lines]
594
 
595
  for i in range(NUM_OUT_BATCHES):
596
  with gr.Tab(f"Batch # {i}") as tab:
 
608
  gen_outro,
609
  gen_drums,
610
  model_temperature,
611
+ model_sampling_top_p,
612
+ final_composition,
613
+ generated_batches,
614
+ block_lines
615
  ],
616
  outputs
617
  )
 
627
  final_plot_output = gr.Plot(label="Final MIDI plot")
628
  final_file_output = gr.File(label="Final MIDI file")
629
 
630
+ add_btn.click(add_batch, [batch_number, final_composition, generated_batches, block_lines],
631
+ [final_audio_output, final_plot_output, final_file_output, final_composition, generated_batches, block_lines])
632
+
633
+ remove_btn.click(remove_batch, [batch_number, num_gen_tokens, final_composition, generated_batches, block_lines],
634
+ [final_audio_output, final_plot_output, final_file_output, final_composition, generated_batches, block_lines])
 
 
635
 
636
+ demo.unload(reset_demo)
637
 
638
  #==================================================================================
639