#==================================================================================
# https://huggingface.co/spaces/asigalov61/Guided-Accompaniment-Transformer
#==================================================================================

print('=' * 70)
print('Guided Accompaniment Transformer Gradio App')
print('=' * 70)
print('Loading core Guided Accompaniment Transformer modules...')

import os
import copy

import time as reqtime
import datetime
from pytz import timezone

print('=' * 70)
print('Loading main Guided Accompaniment Transformer modules...')

os.environ['USE_FLASH_ATTENTION'] = '1'

import torch

torch.set_float32_matmul_precision('high')
torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
torch.backends.cuda.enable_mem_efficient_sdp(True)
torch.backends.cuda.enable_math_sdp(True)
torch.backends.cuda.enable_flash_sdp(True)
torch.backends.cuda.enable_cudnn_sdp(True)

from huggingface_hub import hf_hub_download

import TMIDIX

from midi_to_colab_audio import midi_to_colab_audio

from x_transformer_1_23_2 import *

import random
import tqdm

print('=' * 70)
print('Loading aux Guided Accompaniment Transformer modules...')

import matplotlib.pyplot as plt

import gradio as gr
import spaces

print('=' * 70)
print('PyTorch version:', torch.__version__)
print('=' * 70)
print('Done!')
print('Enjoy! :)')
print('=' * 70)

#==================================================================================

MODEL_CHECKPOINT = 'Guided_Accompaniment_Transformer_Trained_Model_36457_steps_0.5384_loss_0.8417_acc.pth'
SOUNDFONT_PATH = 'SGM-v2.01-YamahaGrand-Guit-Bass-v2.7.sf2'

#==================================================================================

print('=' * 70)
print('Instantiating model...')

device_type = 'cuda'
dtype = 'bfloat16'

ptdtype = {'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype)

SEQ_LEN = 4096
PAD_IDX = 1794

model = TransformerWrapper(num_tokens = PAD_IDX+1,
                           max_seq_len = SEQ_LEN,
                           attn_layers = Decoder(dim = 2048,
                                                 depth = 4,
                                                 heads = 32,
                                                 rotary_pos_emb = True,
                                                 attn_flash = True
                                                 )
                           )

model = AutoregressiveWrapper(model, ignore_index=PAD_IDX, pad_value=PAD_IDX)

print('=' * 70)
print('Loading model checkpoint...')

model_checkpoint = hf_hub_download(repo_id='asigalov61/Guided-Accompaniment-Transformer',
                                   filename=MODEL_CHECKPOINT
                                   )

model.load_state_dict(torch.load(model_checkpoint, map_location='cpu', weights_only=True))

model = torch.compile(model, mode='max-autotune')

print('=' * 70)
print('Done!')
print('=' * 70)
print('Model will use', dtype, 'precision...')
print('=' * 70)

#==================================================================================
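# Token vocabulary layout, as reconstructed from the encode/decode logic in this app:
#   0..127     -> delta start-time, in 32 ms steps
#   128..255   -> note duration + 128, in 32 ms steps
#   256..1791  -> pitch/channel token: 256 + (channel * 128) + pitch
#   1792       -> melody score start marker
#   1793       -> melody score end marker
#   1794       -> padding token (PAD_IDX)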
def load_midi(input_midi, melody_patch=-1):

    # Read the MIDI file and convert it to a single-track enhanced score
    raw_score = TMIDIX.midi2single_track_ms_score(input_midi)

    escore_notes = TMIDIX.advanced_score_processor(raw_score, return_enhanced_score_notes=True)[0]
    escore_notes = TMIDIX.augment_enhanced_score_notes(escore_notes, timings_divider=32)

    # Reduce to a melody line (solo piano reduction, optionally filtered by patch)
    sp_escore_notes = TMIDIX.solo_piano_escore_notes(escore_notes, keep_drums=False)

    if melody_patch == -1:
        zscore = TMIDIX.recalculate_score_timings(sp_escore_notes)

    else:
        zscore = TMIDIX.recalculate_score_timings([e for e in sp_escore_notes if e[6] == melody_patch])

    cscore = TMIDIX.chordify_score([1000, zscore])

    score = []
    score_list = []

    pc = cscore[0]

    for c in cscore:

        # Delta start-time token (0-127)
        score.append(max(0, min(127, c[0][1]-pc[0][1])))
        scl = [[max(0, min(127, c[0][1]-pc[0][1]))]]

        # Duration (+128) and pitch (+256) tokens for the top melody note
        n = c[0]
        score.extend([max(1, min(127, n[2]))+128, max(1, min(127, n[4]))+256])
        scl.append([max(1, min(127, n[2]))+128, max(1, min(127, n[4]))+256])

        score_list.append(scl)

        pc = c

    # Duplicate the last entry so the guided loop always has a "next" anchor
    score_list.append(scl)

    return score, score_list

#==================================================================================
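# load_midi() returns two views of the source melody:
#   score      - a flat token list (delta-time, duration+128, pitch+256 per melody note)
#                that is fed to the model verbatim as the guidance prompt
#   score_list - the same tokens grouped per melody note as [[delta_time], [dur, pitch]],
#                used by 'Guided' mode to anchor each generated accompaniment block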
@spaces.GPU
def Generate_Accompaniment(input_midi,
                           generation_type,
                           melody_patch,
                           model_temperature
                           ):

    #===============================================================================

    def generate_full_seq(input_seq, temperature=0.9, verbose=True):

        # Freestyle mode: keep sampling tokens until the generated music
        # spans the same absolute run time as the source melody
        seq_abs_run_time = sum([t for t in input_seq if t < 128])

        cur_time = 0

        full_seq = input_seq

        toks_counter = 0

        while cur_time < seq_abs_run_time:

            if verbose:
                if toks_counter % 128 == 0:
                    print('Generated', toks_counter, 'tokens')

            x = torch.LongTensor(full_seq).cuda()

            with ctx:
                out = model.generate(x,
                                     1,
                                     temperature=temperature,
                                     return_prime=False,
                                     verbose=False)

            y = out.tolist()[0][0]

            if y < 128:
                cur_time += y

            full_seq.append(y)

            toks_counter += 1

        return full_seq

    #===============================================================================

    def generate_block_seq(input_seq, trg_dtime, temperature=0.9):

        # Guided mode: sample tokens until the block's accumulated delta-time
        # reaches the gap to the next melody note
        cur_time = 0
        block_seq = [128]

        while cur_time != trg_dtime and len(block_seq) < 2 and block_seq[-1] > 127:

            inp_seq = copy.deepcopy(input_seq)

            block_seq = []
            cur_time = 0

            while cur_time < trg_dtime:

                x = torch.LongTensor(inp_seq).cuda()

                with ctx:
                    out = model.generate(x,
                                         1,
                                         temperature=temperature,
                                         return_prime=False,
                                         verbose=False)

                y = out.tolist()[0][0]

                if y < 128:
                    cur_time += y

                inp_seq.append(y)
                block_seq.append(y)

        return block_seq

    #===============================================================================
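    # In 'Guided' mode the melody is re-played note by note: each melody token group
    # is appended to the prompt and generate_block_seq() fills the time gap up to the
    # next melody note with accompaniment tokens. In 'Freestyle' mode the model simply
    # continues after the full melody prompt via generate_full_seq() until it has
    # produced as much absolute time as the source.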

    print('=' * 70)
    print('Req start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
    start_time = reqtime.time()
    print('=' * 70)

    fn = os.path.basename(input_midi)
    fn1 = fn.split('.')[0]

    print('=' * 70)
    print('Requested settings:')
    print('=' * 70)
    print('Input MIDI file name:', fn)
    print('Generation type:', generation_type)
    print('Source melody patch:', melody_patch)
    print('Model temperature:', model_temperature)

    print('=' * 70)

    #==================================================================

    score, score_list = load_midi(input_midi.name)

    print('Sample score events', score[:12])

    #==================================================================

    print('=' * 70)
    print('Generating...')

    model.to(device_type)
    model.eval()

    #==================================================================

    # Prompt: start marker + melody score tokens + end marker
    start_score_seq = [1792] + score + [1793]

    #==================================================================

    if generation_type == 'Guided':

        input_seq = []

        input_seq.extend(start_score_seq)
        input_seq.extend(score_list[0][0])

        for i in tqdm.tqdm(range(len(score_list)-1)):

            input_seq.extend(score_list[i][1])

            block_seq = generate_block_seq(input_seq, score_list[i+1][0][0], temperature=model_temperature)

            input_seq.extend(block_seq)

    else:
        input_seq = generate_full_seq(start_score_seq, temperature=model_temperature)

    #==================================================================

    # Keep only the newly generated accompaniment (drop the melody prompt)
    final_song = input_seq[len(start_score_seq):]

    #==================================================================

    print('=' * 70)
    print('Done!')
    print('=' * 70)

    #===============================================================================

    print('Rendering results...')

    print('=' * 70)
    print('Sample INTs', final_song[:15])
    print('=' * 70)

    song_f = []

    if len(final_song) != 0:

        time = 0
        dur = 0
        vel = 90
        pitch = 0
        channel = 0
        patch = 0

        channels_map = [0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 11, 9, 12, 13, 14, 15]
        patches_map = [40, 0, 10, 19, 24, 35, 40, 52, 56, 9, 65, 73, 0, 0, 0, 0]
        velocities_map = [125, 80, 100, 80, 90, 100, 100, 80, 110, 110, 110, 110, 80, 80, 80, 80]

        for m in final_song:

            if 0 <= m < 128:
                # Delta start-time token (32 ms steps)
                time += m * 32

            elif 128 < m < 256:
                # Duration token (32 ms steps)
                dur = (m-128) * 32

            elif 256 < m < 1792:
                # Pitch/channel token
                cha = (m-256) // 128
                pitch = (m-256) % 128

                channel = channels_map[cha]
                patch = patches_map[channel]
                vel = velocities_map[channel]

                song_f.append(['note', time, dur, channel, pitch, vel, patch])

    fn1 = "Guided-Accompaniment-Transformer-Composition"

    detailed_stats = TMIDIX.Tegridy_ms_SONG_to_MIDI_Converter(song_f,
                                                              output_signature = 'Guided Accompaniment Transformer',
                                                              output_file_name = fn1,
                                                              track_name='Project Los Angeles',
                                                              list_of_MIDI_patches=patches_map
                                                              )

    new_fn = fn1+'.mid'

    audio = midi_to_colab_audio(new_fn,
                                soundfont_path=SOUNDFONT_PATH,
                                sample_rate=16000,
                                volume_scale=10,
                                output_for_gradio=True
                                )

    print('Done!')
    print('=' * 70)

    #========================================================

    output_midi = str(new_fn)
    output_audio = (16000, audio)
    output_plot = TMIDIX.plot_ms_SONG(song_f, plot_title=output_midi, return_plt=True)

    print('Output MIDI file name:', output_midi)
    print('=' * 70)

    #========================================================

    print('-' * 70)
    print('Req end time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
    print('-' * 70)
    print('Req execution time:', (reqtime.time() - start_time), 'sec')

    return output_audio, output_plot, output_midi

#==================================================================================

PDT = timezone('US/Pacific')

print('=' * 70)
print('App start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
print('=' * 70)

#==================================================================================

with gr.Blocks() as demo:

    #==================================================================================

    gr.Markdown("# Guided Accompaniment Transformer")
    gr.Markdown("## Guided melody accompaniment generation with transformers")

    gr.HTML("""
            Check out Guided Accompaniment Transformer on GitHub or on PyPI Project,
            or Duplicate in Hugging Face for faster execution and endless generation!
            """)

    #==================================================================================

    gr.Markdown("## Upload source melody MIDI")

    input_midi = gr.File(label="Input MIDI", file_types=[".midi", ".mid", ".kar"])

    gr.Markdown("## Generation options")

    generation_type = gr.Radio(["Guided", "Freestyle"], value="Guided", label="Generation type")
    melody_patch = gr.Slider(-1, 127, value=-1, step=1, label="Source melody MIDI patch")
    model_temperature = gr.Slider(0.1, 1, value=0.9, step=0.01, label="Model temperature")

    generate_btn = gr.Button("Generate", variant="primary")

    gr.Markdown("## Generation results")

    output_audio = gr.Audio(label="MIDI audio", format="wav", elem_id="midi_audio")
    output_plot = gr.Plot(label="MIDI score plot")
    output_midi = gr.File(label="MIDI file", file_types=[".mid"])

    generate_btn.click(Generate_Accompaniment,
                       [input_midi,
                        generation_type,
                        melody_patch,
                        model_temperature
                        ],
                       [output_audio,
                        output_plot,
                        output_midi,
                        ]
                       )

    '''gr.Examples(
        [["asap_midi_score_21.mid", "Score", "Durations and Velocities", 8, 600, 1, 1.1, 1.5],
         ["asap_midi_score_45.mid", "Score", "Durations and Velocities", 8, 600, 1, 1.1, 1.5],
         ["asap_midi_score_69.mid", "Score", "Durations and Velocities", 8, 600, 1, 1.1, 1.5],
         ["asap_midi_score_118.mid", "Score", "Durations and Velocities", 8, 600, 1, 1.1, 1.5],
         ["asap_midi_score_167.mid", "Score", "Durations and Velocities", 8, 600, 1, 1.1, 1.5],
        ],
        [input_midi, input_midi_type, input_conv_type, input_number_prime_notes,
         input_number_conv_notes, input_model_dur_top_k, input_model_dur_temperature,
         input_model_vel_temperature
        ],
        [output_midi_title, output_midi_summary, output_midi, output_audio, output_plot],
        Convert_Score_to_Performance
    )'''

    #==================================================================================

demo.launch()

#==================================================================================