asigalov61 committed (verified)
Commit dd9d99a · Parent(s): 83355ca

Update app.py

Files changed (1): app.py (+59 / -41)
app.py CHANGED
@@ -56,7 +56,10 @@ print('=' * 70)
 
 #==================================================================================
 
-MODEL_CHECKPOINT_VEL = 'Monster_Piano_Transformer_Velocity_Trained_Model_59896_steps_0.9055_loss_0.735_acc.pth'
+MODEL_CHECKPOINTS = {
+                     'with velocity': 'Monster_Piano_Transformer_Velocity_Trained_Model_59896_steps_0.9055_loss_0.735_acc.pth',
+                     'without velocity': 'Monster_Piano_Transformer_Velocity_Trained_Model_59896_steps_0.9055_loss_0.735_acc.pth'
+                    }
 
 SOUDFONT_PATH = 'SGM-v2.01-YamahaGrand-Guit-Bass-v2.7.sf2'
 
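Note: the dropdown label is used directly as the key into MODEL_CHECKPOINTS when the checkpoint is resolved later in this diff, so the choice strings in the UI must match these keys exactly; both entries currently point to the same velocity-trained checkpoint file. A minimal sketch of a stricter lookup (the helper name is hypothetical, not part of this commit):

# Sketch only: resolve a dropdown label to a checkpoint filename, normalizing
# case and failing loudly on labels that are not keys of MODEL_CHECKPOINTS
# (the dict defined in the hunk above).
def resolve_checkpoint(model_selector):
    key = str(model_selector).strip().lower()
    if key not in MODEL_CHECKPOINTS:
        raise ValueError('Unknown model selection: ' + repr(model_selector))
    return MODEL_CHECKPOINTS[key]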
 
@@ -66,45 +69,47 @@ PREVIEW_LENGTH = 120 # in tokens
 
 #==================================================================================
 
-print('=' * 70)
-print('Instantiating model...')
-
-device_type = 'cuda'
-dtype = 'bfloat16'
-
-ptdtype = {'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
-ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype)
-
-SEQ_LEN = 2048
-PAD_IDX = 512
-
-model = TransformerWrapper(
-        num_tokens = PAD_IDX+1,
-        max_seq_len = SEQ_LEN,
-        attn_layers = Decoder(dim = 2048,
-                              depth = 4,
-                              heads = 32,
-                              rotary_pos_emb = True,
-                              attn_flash = True
-                              )
-        )
-
-model = AutoregressiveWrapper(model, ignore_index=PAD_IDX, pad_value=PAD_IDX)
-
-print('=' * 70)
-print('Loading model checkpoint...')
-
-model_checkpoint = hf_hub_download(repo_id='asigalov61/Monster-Piano-Transformer', filename=MODEL_CHECKPOINT_VEL)
+def load_model(model_selector):
 
-model.load_state_dict(torch.load(model_checkpoint, map_location='cpu'))
-
-model = torch.compile(model, mode='max-autotune')
-
-print('=' * 70)
-print('Done!')
-print('=' * 70)
-print('Model will use', dtype, 'precision...')
-print('=' * 70)
+    print('=' * 70)
+    print('Instantiating model...')
+
+    device_type = 'cuda'
+    dtype = 'bfloat16'
+
+    ptdtype = {'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
+    ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype)
+
+    SEQ_LEN = 2048
+    PAD_IDX = 512
+
+    model = TransformerWrapper(
+            num_tokens = PAD_IDX+1,
+            max_seq_len = SEQ_LEN,
+            attn_layers = Decoder(dim = 2048,
+                                  depth = 4,
+                                  heads = 32,
+                                  rotary_pos_emb = True,
+                                  attn_flash = True
+                                  )
+            )
+
+    model = AutoregressiveWrapper(model, ignore_index=PAD_IDX, pad_value=PAD_IDX)
+
+    print('=' * 70)
+    print('Loading model checkpoint...')
+
+    model_checkpoint = hf_hub_download(repo_id='asigalov61/Monster-Piano-Transformer', filename=MODEL_CHECKPOINTS[model_selector])
+
+    model.load_state_dict(torch.load(model_checkpoint, map_location='cpu'))
+
+    model = torch.compile(model, mode='max-autotune')
+
+    print('=' * 70)
+    print('Done!')
+    print('=' * 70)
+    print('Model will use', dtype, 'precision...')
+    print('=' * 70)
 
 #==================================================================================
 
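As written, load_model() rebuilds the TransformerWrapper, reloads the checkpoint and re-runs torch.compile every time it is called, and the generation callback below calls it on every request. A small cache keyed on the dropdown choice would avoid repeating that work; a sketch only, assuming load_model() were refactored to return the wrapped model (the hunk above keeps it local to the function):

# Hypothetical caching layer, not part of this commit. 'loader' stands in for a
# load_model() variant that returns the compiled model instead of discarding it.
_MODEL_CACHE = {}

def get_cached_model(model_selector, loader):
    # Build the model once per dropdown choice and reuse it on later requests
    if model_selector not in _MODEL_CACHE:
        _MODEL_CACHE[model_selector] = loader(model_selector)
    return _MODEL_CACHE[model_selector]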
 
@@ -304,7 +309,8 @@ def generate_callback_wrapper(input_midi,
                               # model_sampling_top_p,
                               final_composition,
                               generated_batches,
-                              block_lines
+                              block_lines,
+                              model_selector
                               ):
 
     print('=' * 70)
@@ -317,6 +323,10 @@
     fn1 = fn.split('.')[0]
     print('Input file name:', fn)
 
+    print('Selected model type:', model_selector)
+
+    load_model(model_selector)
+
     print('Num prime tokens:', num_prime_tokens)
     print('Num gen tokens:', num_gen_tokens)
     print('Num mem tokens:', num_mem_tokens)
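For reference, these callback hunks rely on Gradio passing the current value of each component in the inputs list to the handler positionally, so the dropdown's selection arrives in model_selector as a plain string. A self-contained toy example of that wiring (names here are illustrative, not taken from app.py):

import gradio as gr

def echo_choice(choice):
    # 'choice' arrives as the selected string, e.g. 'with velocity'
    return 'Selected model: ' + str(choice)

with gr.Blocks() as toy_demo:
    selector = gr.Dropdown(['with velocity', 'without velocity'], label='Select model')
    result = gr.Textbox(label='Result')
    run_btn = gr.Button('Run')
    run_btn.click(echo_choice, inputs=[selector], outputs=[result])

# toy_demo.launch()  # uncomment to try it locally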
@@ -481,6 +491,13 @@ with gr.Blocks() as demo:
                    [final_composition, generated_batches, block_lines])
 
         gr.Markdown("## Generate")
+
+        model_selector = gr.Dropdown(["with velocity",
+                                      "without velocity"
+                                      ],
+                                     label="Select model",
+                                     info="Select desired Monster Piano Transformer model"
+                                     )
 
         num_prime_tokens = gr.Slider(15, 1024, value=1024, step=1, label="Number of prime tokens")
         num_gen_tokens = gr.Slider(15, 1024, value=1024, step=1, label="Number of tokens to generate")
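The new dropdown does not set a default value, so if the user generates without touching it, model_selector can arrive as None and the MODEL_CHECKPOINTS lookup will fail. A possible tweak (an assumption, not part of this commit) is to pre-select one of the two models:

# Sketch only: same component as above, but with an explicit default selection
model_selector = gr.Dropdown(["with velocity", "without velocity"],
                             value="with velocity",
                             label="Select model",
                             info="Select desired Monster Piano Transformer model"
                             )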
@@ -511,7 +528,8 @@
                    # model_sampling_top_p,
                    final_composition,
                    generated_batches,
-                   block_lines
+                   block_lines,
+                   model_selector
                    ],
                   outputs
                   )
 