#==================================================================================
# https://huggingface.co/spaces/asigalov61/Guided-Accompaniment-Transformer
#==================================================================================

print('=' * 70)
print('Guided Accompaniment Transformer Gradio App')
print('=' * 70)
print('Loading core Guided Accompaniment Transformer modules...')

import os

import time as reqtime
import datetime
from pytz import timezone

print('=' * 70)
print('Loading main Guided Accompaniment Transformer modules...')

os.environ['USE_FLASH_ATTENTION'] = '1'

import torch

torch.set_float32_matmul_precision('high')
torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
torch.backends.cuda.enable_mem_efficient_sdp(True)
torch.backends.cuda.enable_math_sdp(True)
torch.backends.cuda.enable_flash_sdp(True)
torch.backends.cuda.enable_cudnn_sdp(True)

from huggingface_hub import hf_hub_download

import TMIDIX

from midi_to_colab_audio import midi_to_colab_audio

from x_transformer_1_23_2 import *

import random

print('=' * 70)
print('Loading aux Guided Accompaniment Transformer modules...')

import matplotlib.pyplot as plt

import gradio as gr
import spaces

print('=' * 70)
print('PyTorch version:', torch.__version__)
print('=' * 70)
print('Done!')
print('Enjoy! :)')
print('=' * 70)

#==================================================================================

MODEL_CHECKPOINTS = {
    'with velocity - 3 epochs': 'Monster_Piano_Transformer_Velocity_Trained_Model_59896_steps_0.9055_loss_0.735_acc.pth',
    'without velocity - 3 epochs': 'Monster_Piano_Transformer_No_Velocity_Trained_Model_69412_steps_0.8577_loss_0.7442_acc.pth',
    'without velocity - 7 epochs': 'Monster_Piano_Transformer_No_Velocity_Trained_Model_161960_steps_0.7775_loss_0.7661_acc.pth'
}

SOUNDFONT_PATH = 'SGM-v2.01-YamahaGrand-Guit-Bass-v2.7.sf2'

NUM_OUT_BATCHES = 12

PREVIEW_LENGTH = 120 # in tokens

#==================================================================================

def load_model(model_selector):

    print('=' * 70)
    print('Instantiating model...')

    device_type = 'cuda'
    dtype = 'bfloat16'

    ptdtype = {'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
    ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype)

    SEQ_LEN = 2048

    if model_selector == 'with velocity - 3 epochs':
        PAD_IDX = 512

    else:
        PAD_IDX = 384

    model = TransformerWrapper(
        num_tokens = PAD_IDX+1,
        max_seq_len = SEQ_LEN,
        attn_layers = Decoder(dim = 2048,
                              depth = 4,
                              heads = 32,
                              rotary_pos_emb = True,
                              attn_flash = True
                             )
    )

    model = AutoregressiveWrapper(model, ignore_index=PAD_IDX, pad_value=PAD_IDX)

    print('=' * 70)
    print('Loading model checkpoint...')

    model_checkpoint = hf_hub_download(repo_id='asigalov61/Guided-Accompaniment-Transformer',
                                       filename=MODEL_CHECKPOINTS[model_selector]
                                      )

    model.load_state_dict(torch.load(model_checkpoint, map_location='cpu', weights_only=True))

    model = torch.compile(model, mode='max-autotune')

    print('=' * 70)
    print('Done!')
    print('=' * 70)
    print('Model will use', dtype, 'precision...')
    print('=' * 70)

    return [model, ctx]

#==================================================================================
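
# Token vocabulary (inferred from the encode/decode logic in load_midi() and
# save_midi() below; all token types share one flat integer range):
#
#   0..127    delta start time between chords, in 32 ms steps
#   129..255  note duration + 128, in 32 ms steps
#   257..383  MIDI pitch + 256
#   385..511  MIDI velocity + 384 (velocity model only)
#
# PAD_IDX is therefore 512 for the velocity model and 384 for the no-velocity
# models. As a hypothetical example, a C major triad held for about one second
# would encode (no-velocity model) as [0, 158, 316, 158, 320, 158, 323]:
# a zero delta time, then one duration/pitch token pair per note.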

def load_midi(input_midi, model_selector=''):

    raw_score = TMIDIX.midi2single_track_ms_score(input_midi.name)

    escore_notes = TMIDIX.advanced_score_processor(raw_score, return_enhanced_score_notes=True)[0]
    escore_notes = TMIDIX.augment_enhanced_score_notes(escore_notes, timings_divider=32)

    sp_escore_notes = TMIDIX.solo_piano_escore_notes(escore_notes, keep_drums=False)
    zscore = TMIDIX.recalculate_score_timings(sp_escore_notes)

    cscore = TMIDIX.chordify_score([1000, zscore])

    score = []

    pc = cscore[0]

    for c in cscore:

        # Delta start time token (0..127)
        score.append(max(0, min(127, c[0][1]-pc[0][1])))

        for n in c:
            if model_selector == 'with velocity - 3 epochs':
                # Duration, pitch and velocity tokens
                score.extend([max(1, min(127, n[2]))+128, max(1, min(127, n[4]))+256, max(1, min(127, n[5]))+384])

            else:
                # Duration and pitch tokens
                score.extend([max(1, min(127, n[2]))+128, max(1, min(127, n[4]))+256])

        pc = c

    return score

#==================================================================================

def save_midi(tokens, batch_number=None, model_selector=''):

    song = tokens
    song_f = []

    time = 0
    dur = 0
    vel = 90
    pitch = 0
    channel = 0
    patch = 0

    patches = [0] * 16

    for m in song:

        if 0 <= m < 128:
            time += m * 32

        elif 128 < m < 256:
            dur = (m-128) * 32

        elif 256 < m < 384:
            pitch = (m-256)

            if model_selector in ('without velocity - 3 epochs', 'without velocity - 7 epochs'):
                song_f.append(['note', time, dur, 0, pitch, max(40, pitch), 0])

        elif 384 < m < 512:
            vel = (m-384)

            if model_selector == 'with velocity - 3 epochs':
                song_f.append(['note', time, dur, 0, pitch, vel, 0])

    if batch_number is None:
        fname = 'Guided-Accompaniment-Transformer-Music-Composition'

    else:
        fname = 'Guided-Accompaniment-Transformer-Music-Composition_'+str(batch_number)

    data = TMIDIX.Tegridy_ms_SONG_to_MIDI_Converter(song_f,
                                                    output_signature = 'Guided Accompaniment Transformer',
                                                    output_file_name = fname,
                                                    track_name='Project Los Angeles',
                                                    list_of_MIDI_patches=patches,
                                                    verbose=False
                                                   )

    return song_f

#==================================================================================

@spaces.GPU
def generate_music(prime,
                   num_gen_tokens,
                   num_mem_tokens,
                   num_gen_batches,
                   model_temperature,
                   # model_sampling_top_p,
                   model_state
                  ):

    if not prime:
        inputs = [0]

    else:
        inputs = prime[-num_mem_tokens:]

    model = model_state[0]
    ctx = model_state[1]

    model.cuda()
    model.eval()

    print('Generating...')

    inp = [inputs] * num_gen_batches

    inp = torch.LongTensor(inp).cuda()

    with ctx:
        out = model.generate(inp,
                             num_gen_tokens,
                             #filter_logits_fn=top_p,
                             #filter_kwargs={'thres': model_sampling_top_p},
                             temperature=model_temperature,
                             return_prime=False,
                             verbose=False)

    output = out.tolist()

    print('Done!')
    print('=' * 70)

    return output

#==================================================================================
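
# Illustrative usage sketch (comments only, never executed by the app;
# `uploaded_file` is a hypothetical Gradio file object with a .name path):
#
#   model_state = load_model('without velocity - 7 epochs')
#   model_state.append('without velocity - 7 epochs')
#   seed = load_midi(uploaded_file, model_selector=model_state[2])
#   batches = generate_music(seed, 512, 2048, 1, 0.9, model_state)
#   save_midi(seed + batches[0], batch_number=0, model_selector=model_state[2])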

def generate_callback(input_midi,
                      num_prime_tokens,
                      num_gen_tokens,
                      num_mem_tokens,
                      model_temperature,
                      # model_sampling_top_p,
                      final_composition,
                      generated_batches,
                      block_lines,
                      model_state
                     ):

    generated_batches = []

    if not final_composition and input_midi is not None:
        final_composition = load_midi(input_midi, model_selector=model_state[2])[:num_prime_tokens]

        midi_score = save_midi(final_composition, model_selector=model_state[2])
        block_lines.append(midi_score[-1][1] / 1000)

    batched_gen_tokens = generate_music(final_composition,
                                        num_gen_tokens,
                                        num_mem_tokens,
                                        NUM_OUT_BATCHES,
                                        model_temperature,
                                        # model_sampling_top_p,
                                        model_state
                                       )

    outputs = []

    for i in range(len(batched_gen_tokens)):

        tokens = batched_gen_tokens[i]

        # Preview
        tokens_preview = final_composition[-PREVIEW_LENGTH:]

        # Save MIDI to a temporary file
        midi_score = save_midi(tokens_preview + tokens, i, model_selector=model_state[2])

        # MIDI plot
        if len(final_composition) > PREVIEW_LENGTH:
            midi_plot = TMIDIX.plot_ms_SONG(midi_score,
                                            plot_title='Batch # ' + str(i),
                                            preview_length_in_notes=int(PREVIEW_LENGTH / 3),
                                            return_plt=True
                                           )

        else:
            midi_plot = TMIDIX.plot_ms_SONG(midi_score,
                                            plot_title='Batch # ' + str(i),
                                            return_plt=True
                                           )

        # File name
        fname = 'Guided-Accompaniment-Transformer-Music-Composition_'+str(i)

        # Save audio to a temporary file
        midi_audio = midi_to_colab_audio(fname + '.mid',
                                         soundfont_path=SOUNDFONT_PATH,
                                         sample_rate=16000,
                                         output_for_gradio=True
                                        )

        outputs.append([(16000, midi_audio), midi_plot, tokens])

    return outputs, final_composition, generated_batches, block_lines

#==================================================================================

def generate_callback_wrapper(input_midi,
                              num_prime_tokens,
                              num_gen_tokens,
                              num_mem_tokens,
                              model_temperature,
                              # model_sampling_top_p,
                              final_composition,
                              generated_batches,
                              block_lines,
                              model_selector,
                              model_state
                             ):

    print('=' * 70)
    print('Req start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
    start_time = reqtime.time()
    print('=' * 70)

    if input_midi is not None:
        fn = os.path.basename(input_midi.name)
        fn1 = fn.split('.')[0]
        print('Input file name:', fn)

    print('Selected model type:', model_selector)

    if not model_state:
        model_state = load_model(model_selector)
        model_state.append(model_selector)

    else:
        if model_selector != model_state[2]:
            print('=' * 70)
            print('Switching model...')
            model_state = load_model(model_selector)
            model_state.append(model_selector)
            print('=' * 70)

    print('Num prime tokens:', num_prime_tokens)
    print('Num gen tokens:', num_gen_tokens)
    print('Num mem tokens:', num_mem_tokens)
    print('Model temp:', model_temperature)
    # print('Model top_p:', model_sampling_top_p)
    print('=' * 70)

    result = generate_callback(input_midi,
                               num_prime_tokens,
                               num_gen_tokens,
                               num_mem_tokens,
                               model_temperature,
                               # model_sampling_top_p,
                               final_composition,
                               generated_batches,
                               block_lines,
                               model_state
                              )

    generated_batches = [sublist[-1] for sublist in result[0]]

    print('=' * 70)
    print('Req end time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
    print('=' * 70)
    print('Req execution time:', (reqtime.time() - start_time), 'sec')
    print('*' * 70)

    return tuple([result[1], generated_batches, result[3]] + [item for sublist in result[0] for item in sublist[:-1]] + [model_state])

#==================================================================================

def reset(final_composition=[], generated_batches=[], block_lines=[], model_state=[]):

    final_composition = []
    generated_batches = []
    block_lines = []
    model_state = []

    return final_composition, generated_batches, block_lines

#==================================================================================

def reset_demo(final_composition=[], generated_batches=[], block_lines=[], model_state=[]):

    final_composition = []
    generated_batches = []
    block_lines = []
    model_state = []

#==================================================================================

PDT = timezone('US/Pacific')

print('=' * 70)
print('App start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
print('=' * 70)

#==================================================================================

with gr.Blocks() as demo:

    #==============================================================================

    demo.load(reset_demo)

    #==============================================================================
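
    # Session-state note (assumes standard Gradio semantics): each gr.State
    # object below is kept per browser session, so concurrent users do not
    # share the growing composition or the loaded model/checkpoint pair
    # carried in model_state.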

    gr.Markdown("# Guided Accompaniment Transformer")
    gr.Markdown("## Guided melody accompaniment generation with transformers")

    gr.HTML("""
    Check out Guided Accompaniment Transformer on GitHub or on PyPI,
    or duplicate this Space in Hugging Face for faster execution and endless generation!
    """)

    #==============================================================================

    final_composition = gr.State([])
    generated_batches = gr.State([])
    block_lines = gr.State([])
    model_state = gr.State([])

    #==============================================================================

    gr.Markdown("## Upload seed MIDI or click 'Generate' button for random output")

    input_midi = gr.File(label="Input MIDI", file_types=[".midi", ".mid", ".kar"])
    input_midi.upload(reset,
                      [final_composition, generated_batches, block_lines],
                      [final_composition, generated_batches, block_lines]
                     )

    gr.Markdown("## Generate")

    model_selector = gr.Dropdown(["without velocity - 7 epochs",
                                  "without velocity - 3 epochs",
                                  "with velocity - 3 epochs"
                                 ],
                                 label="Select model",
                                )

    num_prime_tokens = gr.Slider(15, 1024, value=1024, step=1, label="Number of prime tokens")
    num_gen_tokens = gr.Slider(15, 1024, value=1024, step=1, label="Number of tokens to generate")
    num_mem_tokens = gr.Slider(15, 2048, value=2048, step=1, label="Number of memory tokens")
    model_temperature = gr.Slider(0.1, 1, value=0.9, step=0.01, label="Model temperature")
    # model_sampling_top_p = gr.Slider(0.1, 1, value=0.96, step=0.01, label="Model sampling top p value")

    generate_btn = gr.Button("Generate", variant="primary")

    gr.Markdown("## Select batch")

    outputs = [final_composition, generated_batches, block_lines]

    for i in range(NUM_OUT_BATCHES):
        with gr.Tab(f"Batch # {i}") as tab:
            audio_output = gr.Audio(label=f"Batch # {i} MIDI Audio", format="mp3", elem_id="midi_audio")
            plot_output = gr.Plot(label=f"Batch # {i} MIDI Plot")
            outputs.extend([audio_output, plot_output])

    outputs.extend([model_state])

    generate_btn.click(generate_callback_wrapper,
                       [input_midi,
                        num_prime_tokens,
                        num_gen_tokens,
                        num_mem_tokens,
                        model_temperature,
                        # model_sampling_top_p,
                        final_composition,
                        generated_batches,
                        block_lines,
                        model_selector,
                        model_state
                       ],
                       outputs
                      )

    #==============================================================================

    demo.unload(reset_demo)

#==================================================================================

demo.launch()

#==================================================================================
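
# Local run notes (assumptions, not verified against the Space): TMIDIX,
# midi_to_colab_audio and x_transformer_1_23_2 are bundled module files from
# the tegridy-tools project; generate_music() moves the model to .cuda(), so a
# CUDA-capable GPU is required; and the SGM-v2.01 soundfont file must sit next
# to this script for audio rendering. A typical local setup might be:
#
#   pip install torch gradio spaces pytz matplotlib huggingface_hub
#   python app.py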