|
import os |
|
import random |
|
from string import ascii_letters |
|
from fastapi import FastAPI |
|
from fastapi.staticfiles import StaticFiles |
|
import mido |
|
import uvicorn |
|
import gradio as gr |
|
from inference.inference import generate_groove |
|
|
|
|
|
app = FastAPI()

# StaticFiles refuses to construct when its directory does not exist
# (check_dir=True by default), and load_midi_file writes uploads into
# "uploads/". Create both directories up front so a fresh checkout (where
# these output folders are typically absent) starts cleanly.
for _static_dir in ("uploads", "generated"):
    os.makedirs(_static_dir, exist_ok=True)

# Expose uploaded and generated MIDI files over HTTP so the in-page
# <midi-player>/<midi-visualizer> elements can fetch them by URL.
app.mount("/uploads", StaticFiles(directory="./uploads"), name="uploads")
app.mount("/generated", StaticFiles(directory="./generated"), name="generated")
|
|
|
def display_bpm(bpm):
    """Format a BPM value as the label text ``"BPM: <value>"``."""
    prefix = "BPM"
    return f"{prefix}: {bpm}"
|
|
|
# HTML fragment rendering a piano-roll <midi-visualizer> and a linked
# <midi-player> for a single MIDI file. These are custom elements; they are
# presumably defined by the scripts loaded through ``head`` — confirm.
# Filled in with str.format, so any literal braces added here must be doubled:
#   filepath -- path/URL of the .mid file; used as the src of both elements
#   id       -- suffix for the visualizer's element id; the player's
#               visualizer="#myPianoRollVisualizer{id}" selector uses the same
#               suffix, so each player drives its own visualizer
# The player always uses the fixed jazz_kit soundfont hosted on Google Storage.
visualizer_html_template = """

<div>

<midi-visualizer type="piano-roll" id="myPianoRollVisualizer{id}"

src="{filepath}">

</midi-visualizer>

<midi-player

src="{filepath}"

sound-font="https://storage.googleapis.com/magentadata/js/soundfonts/jazz_kit" visualizer="#myPianoRollVisualizer{id}">

</midi-player>

</div>

"""
|
|
|
|
|
|
|
|
|
def load_midi_file(input_midi, bpm_input):
    """Save an uploaded MIDI file and build a player/visualizer for it.

    Args:
        input_midi: Raw bytes of the uploaded file, or a falsy value (``None``
            or empty bytes) when nothing was uploaded.
        bpm_input: BPM value from the UI. Currently unused.

    Returns:
        ``(html, filepath)``: the visualizer/player HTML for the saved file and
        its path under ``uploads/``; ``(None, None)`` when no file was given.
    """
    if not input_midi:
        return None, None

    # A random 16-letter name keeps separate uploads from overwriting each
    # other. "".join is used instead of str.join("", ...) inside the f-string:
    # reusing double quotes inside an f-string is a SyntaxError before 3.12.
    stem = "".join(random.choices(ascii_letters, k=16))
    midi_filename = f"{stem}.mid"
    # The trailing "/" keeps a forward slash in the path on every OS; the same
    # string is used as the HTML src attribute below.
    filepath = os.path.join("uploads/", midi_filename)

    with open(filepath, "wb") as fd:
        fd.write(input_midi)

    return visualizer_html_template.format(filepath=filepath, id="input"), filepath
|
|
|
def run_inference(midi_filename: str, count: int = 1):
    """Generate grooves from an uploaded MIDI file and build their previews.

    Args:
        midi_filename: Path of the previously loaded MIDI file; falsy when the
            user has not loaded one yet.
        count: Passed through as ``generate_groove``'s second argument
            (presumably the number of variations to generate — confirm).

    Returns:
        A flat list: one visualizer/player HTML string per generated file,
        followed by the generated file paths in the same order.

    Raises:
        gr.Error: if no MIDI file has been loaded yet (shown to the user by
            Gradio instead of an opaque failure inside generation).
    """
    if not midi_filename:
        raise gr.Error("Please load a MIDI file before generating.")

    # Materialize so the result can be enumerated and concatenated even if
    # generate_groove returns a tuple or an iterator.
    filenames = list(generate_groove(midi_filename, count))

    visualizers = [
        visualizer_html_template.format(filepath=filename, id=index)
        for index, filename in enumerate(filenames)
    ]
    return visualizers + filenames
|
|
|
|
|
# Injected into the page <head> so the <midi-player>/<midi-visualizer> custom
# elements are defined. This is the jsdelivr "combine" bundle documented by
# html-midi-player: Tone.js, Magenta.js core, the focus-visible polyfill, and
# html-midi-player itself. (The versioned package specifiers had been mangled
# into "[email protected]" by e-mail obfuscation, which broke the URL.)
head = """
<script src="https://cdn.jsdelivr.net/combine/npm/tone@14.7.58,npm/@magenta/music@1.23.1/es6/core.js,npm/focus-visible@5,npm/html-midi-player@1.5.0"></script>
"""
|
block = gr.Blocks(head=head)

# Number of generated variations shown in the Output section. run_inference
# returns one HTML player and one file per variation, so both the State passed
# to it and the number of output components created below come from here.
NUM_OUTPUTS = 4

with block:
    # Path of the uploaded MIDI file: set by load_midi_file, read by run_inference.
    midi_filepath = gr.State()

    with gr.Group():
        input_midi = gr.File(
            label="Upload basic drum part",
            file_types=[".midi", ".mid"],
            type="binary",
        )
        bpm_input = gr.Number(value=120, label="BPM", interactive=True)
        load_btn = gr.Button("load midi file")
        midi_player = gr.HTML()
        load_btn.click(
            load_midi_file,
            [input_midi, bpm_input],
            [midi_player, midi_filepath],
        )

    with gr.Group():
        gr.Markdown("## Generation Settings")

        with gr.Row():
            generate_genre = gr.Dropdown(
                ["Rock", "Pop", "Reggae", "Jazz", "Metal"],
                label="Genre",
                interactive=True,
            )
            # `step` must be a numeric increment, not a type such as `float`.
            generate_complexity = gr.Slider(
                1,
                10,
                value=5,
                label="Complexity",
                info="Choose between 1 and 10",
                step=1,
                interactive=True,
            )

        with gr.Row():
            bpm_display = gr.Textbox(label="BPM value", interactive=False)
            bpm_input.change(fn=display_bpm, inputs=bpm_input, outputs=bpm_display)
            generate_length = gr.Dropdown(
                ["1 Bar", "2 Bars", "3 Bars", "4 Bars", "5 Bars"],
                label="Length",
                interactive=True,
            )

        with gr.Row():
            generate_button = gr.Button("Generate")

    # Also fill the BPM box when the page loads; otherwise it stays empty until
    # the user first edits the BPM field.
    block.load(fn=display_bpm, inputs=bpm_input, outputs=bpm_display)

    number_outputs = gr.State(NUM_OUTPUTS)
    midi_player_outputs = []
    download_buttons = []

    with gr.Group():
        gr.Markdown("## Output")
        for tab_number in range(1, NUM_OUTPUTS + 1):
            with gr.Tab(str(tab_number)):
                with gr.Row():
                    midi_player_outputs.append(gr.HTML())
                download_buttons.append(gr.DownloadButton("Download"))

    # NOTE(review): genre, complexity and length are collected above but are
    # not passed to run_inference yet.
    generate_button.click(
        run_inference,
        [midi_filepath, number_outputs],
        # run_inference returns every player's HTML first, then every file path.
        midi_player_outputs + download_buttons,
    )
|
|
|
|
|
|
|
# Serve the Gradio UI at the root of the FastAPI app, alongside the
# /uploads and /generated static mounts.
app = gr.mount_gradio_app(app=app, blocks=block, path="/")

if __name__ == "__main__":
    # Listen on every interface, on port 7860.
    uvicorn.run(app, port=7860, host="0.0.0.0")