Spaces:
Running
on
Zero
Running
on
Zero
add seed
Browse files
app.py
CHANGED
|
@@ -14,6 +14,7 @@ import MIDI
|
|
| 14 |
from midi_synthesizer import synthesis
|
| 15 |
from midi_tokenizer import MIDITokenizer
|
| 16 |
|
|
|
|
| 17 |
in_space = os.getenv("SYSTEM") == "spaces"
|
| 18 |
|
| 19 |
|
|
@@ -23,7 +24,9 @@ def softmax(x, axis):
|
|
| 23 |
return exp_x_shifted / np.sum(exp_x_shifted, axis=axis, keepdims=True)
|
| 24 |
|
| 25 |
|
| 26 |
-
def sample_top_p_k(probs, p, k):
|
|
|
|
|
|
|
| 27 |
probs_idx = np.argsort(-probs, axis=-1)
|
| 28 |
probs_sort = np.take_along_axis(probs, probs_idx, -1)
|
| 29 |
probs_sum = np.cumsum(probs_sort, axis=-1)
|
|
@@ -36,17 +39,19 @@ def sample_top_p_k(probs, p, k):
|
|
| 36 |
shape = probs_sort.shape
|
| 37 |
probs_sort_flat = probs_sort.reshape(-1, shape[-1])
|
| 38 |
probs_idx_flat = probs_idx.reshape(-1, shape[-1])
|
| 39 |
-
next_token = np.stack([
|
| 40 |
next_token = next_token.reshape(*shape[:-1])
|
| 41 |
return next_token
|
| 42 |
|
| 43 |
|
| 44 |
def generate(model, prompt=None, max_len=512, temp=1.0, top_p=0.98, top_k=20,
|
| 45 |
-
disable_patch_change=False, disable_control_change=False, disable_channels=None):
|
| 46 |
if disable_channels is not None:
|
| 47 |
disable_channels = [tokenizer.parameter_ids["channel"][c] for c in disable_channels]
|
| 48 |
else:
|
| 49 |
disable_channels = []
|
|
|
|
|
|
|
| 50 |
max_token_seq = tokenizer.max_token_seq
|
| 51 |
if prompt is None:
|
| 52 |
input_tensor = np.full((1, max_token_seq), tokenizer.pad_id, dtype=np.int64)
|
|
@@ -83,7 +88,7 @@ def generate(model, prompt=None, max_len=512, temp=1.0, top_p=0.98, top_k=20,
|
|
| 83 |
mask[mask_ids] = 1
|
| 84 |
logits = model[1].run(None, {'x': next_token_seq, "hidden": hidden})[0][:, -1:]
|
| 85 |
scores = softmax(logits / temp, -1) * mask
|
| 86 |
-
sample = sample_top_p_k(scores, top_p, top_k)
|
| 87 |
if i == 0:
|
| 88 |
next_token_seq = sample
|
| 89 |
eid = sample.item()
|
|
@@ -120,13 +125,16 @@ def send_msgs(msgs, msgs_history=None):
|
|
| 120 |
return json.dumps(msgs_history)
|
| 121 |
|
| 122 |
|
| 123 |
-
def run(model_name, tab, instruments, drum_kit, bpm, mid, midi_events,
|
|
|
|
| 124 |
msgs_history = []
|
| 125 |
mid_seq = []
|
| 126 |
bpm = int(bpm)
|
| 127 |
gen_events = int(gen_events)
|
| 128 |
max_len = gen_events
|
| 129 |
-
|
|
|
|
|
|
|
| 130 |
disable_patch_change = False
|
| 131 |
disable_channels = None
|
| 132 |
if tab == 0:
|
|
@@ -159,22 +167,22 @@ def run(model_name, tab, instruments, drum_kit, bpm, mid, midi_events, gen_event
|
|
| 159 |
init_msgs = [create_msg("visualizer_clear", False)]
|
| 160 |
for tokens in mid_seq:
|
| 161 |
init_msgs.append(create_msg("visualizer_append", tokenizer.tokens2event(tokens)))
|
| 162 |
-
yield mid_seq, None, None, send_msgs(init_msgs, msgs_history)
|
| 163 |
model = models[model_name]
|
| 164 |
-
|
| 165 |
disable_patch_change=disable_patch_change, disable_control_change=not allow_cc,
|
| 166 |
-
disable_channels=disable_channels)
|
| 167 |
-
for i, token_seq in enumerate(
|
| 168 |
token_seq = token_seq.tolist()
|
| 169 |
mid_seq.append(token_seq)
|
| 170 |
event = tokenizer.tokens2event(token_seq)
|
| 171 |
-
yield mid_seq, None, None, send_msgs([create_msg("visualizer_append", event), create_msg("progress", [i + 1, gen_events])], msgs_history)
|
| 172 |
mid = tokenizer.detokenize(mid_seq)
|
| 173 |
with open(f"output.mid", 'wb') as f:
|
| 174 |
f.write(MIDI.score2midi(mid))
|
| 175 |
audio = synthesis(MIDI.score2opus(mid), soundfont_path)
|
| 176 |
events = [tokenizer.tokens2event(tokens) for tokens in mid_seq]
|
| 177 |
-
yield mid_seq, "output.mid", (44100, audio), send_msgs([create_msg("visualizer_end", events)])
|
| 178 |
|
| 179 |
|
| 180 |
def cancel_run(mid_seq):
|
|
@@ -232,8 +240,8 @@ if __name__ == "__main__":
|
|
| 232 |
opt = parser.parse_args()
|
| 233 |
soundfont_path = hf_hub_download_retry(repo_id="skytnt/midi-model", filename="soundfont.sf2")
|
| 234 |
models_info = {"generic pretrain model": ["skytnt/midi-model", ""],
|
| 235 |
-
"j-pop finetune model": ["skytnt/midi-model-ft", "jpop/"],
|
| 236 |
-
"touhou finetune model": ["skytnt/midi-model-ft", "touhou/"],
|
| 237 |
}
|
| 238 |
models = {}
|
| 239 |
tokenizer = MIDITokenizer()
|
|
@@ -301,7 +309,10 @@ if __name__ == "__main__":
|
|
| 301 |
|
| 302 |
tab1.select(lambda: 0, None, tab_select, queue=False)
|
| 303 |
tab2.select(lambda: 1, None, tab_select, queue=False)
|
| 304 |
-
|
|
|
|
|
|
|
|
|
|
| 305 |
step=1, value=opt.max_gen // 2)
|
| 306 |
with gr.Accordion("options", open=False):
|
| 307 |
input_temp = gr.Slider(label="temperature", minimum=0.1, maximum=1.2, step=0.01, value=1)
|
|
@@ -316,9 +327,9 @@ if __name__ == "__main__":
|
|
| 316 |
output_audio = gr.Audio(label="output audio", format="mp3", elem_id="midi_audio")
|
| 317 |
output_midi = gr.File(label="output midi", file_types=[".mid"])
|
| 318 |
run_event = run_btn.click(run, [input_model, tab_select, input_instruments, input_drum_kit, input_bpm,
|
| 319 |
-
input_midi, input_midi_events,
|
| 320 |
-
input_top_p, input_top_k, input_allow_cc],
|
| 321 |
-
[output_midi_seq, output_midi, output_audio, js_msg],
|
| 322 |
concurrency_limit=3)
|
| 323 |
stop_btn.click(cancel_run, [output_midi_seq], [output_midi, output_audio, js_msg], cancels=run_event, queue=False)
|
| 324 |
app.launch(server_port=opt.port, share=opt.share, inbrowser=True)
|
|
|
|
| 14 |
from midi_synthesizer import synthesis
|
| 15 |
from midi_tokenizer import MIDITokenizer
|
| 16 |
|
| 17 |
+
MAX_SEED = np.iinfo(np.int32).max
|
| 18 |
in_space = os.getenv("SYSTEM") == "spaces"
|
| 19 |
|
| 20 |
|
|
|
|
| 24 |
return exp_x_shifted / np.sum(exp_x_shifted, axis=axis, keepdims=True)
|
| 25 |
|
| 26 |
|
| 27 |
+
def sample_top_p_k(probs, p, k, generator=None):
|
| 28 |
+
if generator is None:
|
| 29 |
+
generator = np.random
|
| 30 |
probs_idx = np.argsort(-probs, axis=-1)
|
| 31 |
probs_sort = np.take_along_axis(probs, probs_idx, -1)
|
| 32 |
probs_sum = np.cumsum(probs_sort, axis=-1)
|
|
|
|
| 39 |
shape = probs_sort.shape
|
| 40 |
probs_sort_flat = probs_sort.reshape(-1, shape[-1])
|
| 41 |
probs_idx_flat = probs_idx.reshape(-1, shape[-1])
|
| 42 |
+
next_token = np.stack([generator.choice(idxs, p=pvals) for pvals, idxs in zip(probs_sort_flat, probs_idx_flat)])
|
| 43 |
next_token = next_token.reshape(*shape[:-1])
|
| 44 |
return next_token
|
| 45 |
|
| 46 |
|
| 47 |
def generate(model, prompt=None, max_len=512, temp=1.0, top_p=0.98, top_k=20,
|
| 48 |
+
disable_patch_change=False, disable_control_change=False, disable_channels=None, generator=None):
|
| 49 |
if disable_channels is not None:
|
| 50 |
disable_channels = [tokenizer.parameter_ids["channel"][c] for c in disable_channels]
|
| 51 |
else:
|
| 52 |
disable_channels = []
|
| 53 |
+
if generator is None:
|
| 54 |
+
generator = np.random
|
| 55 |
max_token_seq = tokenizer.max_token_seq
|
| 56 |
if prompt is None:
|
| 57 |
input_tensor = np.full((1, max_token_seq), tokenizer.pad_id, dtype=np.int64)
|
|
|
|
| 88 |
mask[mask_ids] = 1
|
| 89 |
logits = model[1].run(None, {'x': next_token_seq, "hidden": hidden})[0][:, -1:]
|
| 90 |
scores = softmax(logits / temp, -1) * mask
|
| 91 |
+
sample = sample_top_p_k(scores, top_p, top_k, generator)
|
| 92 |
if i == 0:
|
| 93 |
next_token_seq = sample
|
| 94 |
eid = sample.item()
|
|
|
|
| 125 |
return json.dumps(msgs_history)
|
| 126 |
|
| 127 |
|
| 128 |
+
def run(model_name, tab, instruments, drum_kit, bpm, mid, midi_events, seed, seed_rand,
|
| 129 |
+
gen_events, temp, top_p, top_k, allow_cc):
|
| 130 |
msgs_history = []
|
| 131 |
mid_seq = []
|
| 132 |
bpm = int(bpm)
|
| 133 |
gen_events = int(gen_events)
|
| 134 |
max_len = gen_events
|
| 135 |
+
if seed_rand:
|
| 136 |
+
seed = np.random.randint(0, MAX_SEED)
|
| 137 |
+
generator = np.random.RandomState(seed)
|
| 138 |
disable_patch_change = False
|
| 139 |
disable_channels = None
|
| 140 |
if tab == 0:
|
|
|
|
| 167 |
init_msgs = [create_msg("visualizer_clear", False)]
|
| 168 |
for tokens in mid_seq:
|
| 169 |
init_msgs.append(create_msg("visualizer_append", tokenizer.tokens2event(tokens)))
|
| 170 |
+
yield mid_seq, None, None, seed, send_msgs(init_msgs, msgs_history)
|
| 171 |
model = models[model_name]
|
| 172 |
+
midi_generator = generate(model, mid, max_len=max_len, temp=temp, top_p=top_p, top_k=top_k,
|
| 173 |
disable_patch_change=disable_patch_change, disable_control_change=not allow_cc,
|
| 174 |
+
disable_channels=disable_channels, generator=generator)
|
| 175 |
+
for i, token_seq in enumerate(midi_generator):
|
| 176 |
token_seq = token_seq.tolist()
|
| 177 |
mid_seq.append(token_seq)
|
| 178 |
event = tokenizer.tokens2event(token_seq)
|
| 179 |
+
yield mid_seq, None, None, seed, send_msgs([create_msg("visualizer_append", event), create_msg("progress", [i + 1, gen_events])], msgs_history)
|
| 180 |
mid = tokenizer.detokenize(mid_seq)
|
| 181 |
with open(f"output.mid", 'wb') as f:
|
| 182 |
f.write(MIDI.score2midi(mid))
|
| 183 |
audio = synthesis(MIDI.score2opus(mid), soundfont_path)
|
| 184 |
events = [tokenizer.tokens2event(tokens) for tokens in mid_seq]
|
| 185 |
+
yield mid_seq, "output.mid", (44100, audio), seed, send_msgs([create_msg("visualizer_end", events)])
|
| 186 |
|
| 187 |
|
| 188 |
def cancel_run(mid_seq):
|
|
|
|
| 240 |
opt = parser.parse_args()
|
| 241 |
soundfont_path = hf_hub_download_retry(repo_id="skytnt/midi-model", filename="soundfont.sf2")
|
| 242 |
models_info = {"generic pretrain model": ["skytnt/midi-model", ""],
|
| 243 |
+
# "j-pop finetune model": ["skytnt/midi-model-ft", "jpop/"],
|
| 244 |
+
# "touhou finetune model": ["skytnt/midi-model-ft", "touhou/"],
|
| 245 |
}
|
| 246 |
models = {}
|
| 247 |
tokenizer = MIDITokenizer()
|
|
|
|
| 309 |
|
| 310 |
tab1.select(lambda: 0, None, tab_select, queue=False)
|
| 311 |
tab2.select(lambda: 1, None, tab_select, queue=False)
|
| 312 |
+
input_seed = gr.Slider(label="seed", minimum=0, maximum=2 ** 31 - 1,
|
| 313 |
+
step=1, value=0)
|
| 314 |
+
input_seed_rand = gr.Checkbox(label="random seed", value=True)
|
| 315 |
+
input_gen_events = gr.Slider(label="generate max n midi events", minimum=1, maximum=opt.max_gen,
|
| 316 |
step=1, value=opt.max_gen // 2)
|
| 317 |
with gr.Accordion("options", open=False):
|
| 318 |
input_temp = gr.Slider(label="temperature", minimum=0.1, maximum=1.2, step=0.01, value=1)
|
|
|
|
| 327 |
output_audio = gr.Audio(label="output audio", format="mp3", elem_id="midi_audio")
|
| 328 |
output_midi = gr.File(label="output midi", file_types=[".mid"])
|
| 329 |
run_event = run_btn.click(run, [input_model, tab_select, input_instruments, input_drum_kit, input_bpm,
|
| 330 |
+
input_midi, input_midi_events, input_seed, input_seed_rand, input_gen_events,
|
| 331 |
+
input_temp, input_top_p, input_top_k, input_allow_cc],
|
| 332 |
+
[output_midi_seq, output_midi, output_audio, input_seed, js_msg],
|
| 333 |
concurrency_limit=3)
|
| 334 |
stop_btn.click(cancel_run, [output_midi_seq], [output_midi, output_audio, js_msg], cancels=run_event, queue=False)
|
| 335 |
app.launch(server_port=opt.port, share=opt.share, inbrowser=True)
|