import os
import random
from string import ascii_letters

from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles
import mido
import uvicorn
import gradio as gr

from inference.inference import generate_groove

# Starlette's StaticFiles raises at startup when its directory does not exist,
# and load_midi_file() below writes into ./uploads, so create both up front.
for _static_dir in ("./uploads", "./generated"):
    os.makedirs(_static_dir, exist_ok=True)

# create a FastAPI app
app = FastAPI()
app.mount("/uploads", StaticFiles(directory="./uploads"), name="uploads")
app.mount("/generated", StaticFiles(directory="./generated"), name="generated")


def display_bpm(bpm):
    """Return the text shown in the read-only "BPM value" textbox."""
    return f"BPM: {bpm}"


# HTML snippet rendered for each MIDI preview, filled via str.format() with
# ``filepath`` (path of the MIDI file) and ``id`` (a per-player suffix).
# NOTE(review): as it appears in this file the template is only a newline and
# contains no {filepath}/{id} placeholders, so .format() returns it unchanged;
# the player markup looks to be missing — TODO confirm/restore.
visualizer_html_template = """
"""


def load_midi_file(input_midi, bpm_input):
    """Save an uploaded MIDI file under ``uploads/`` and render its preview.

    Args:
        input_midi: Raw bytes of the upload (the ``gr.File`` wired to this
            handler uses ``type="binary"``), or a falsy value when nothing has
            been uploaded.
        bpm_input: Value of the BPM field. Currently unused; kept because the
            click handler passes it.

    Returns:
        ``(preview_html, filepath)`` for the saved file, or ``(None, None)``
        when there is no upload.
    """
    if not input_midi:
        return None, None

    # Random 16-letter stem, presumably so uploads don't overwrite each other.
    # Built outside the f-string: re-using the enclosing quote character inside
    # an f-string replacement field is a SyntaxError before Python 3.12.
    random_stem = "".join(random.choices(ascii_letters, k=16))
    midi_filename = f"{random_stem}.mid"
    filepath = os.path.join("uploads/", midi_filename)
    with open(filepath, "wb") as fd:
        fd.write(input_midi)
    return visualizer_html_template.format(filepath=filepath, id="input"), filepath


def run_inference(midi_filename: str, count: int = 1):
    """Generate ``count`` grooves from ``midi_filename`` and render previews.

    Returns:
        A flat list of exactly ``2 * count`` items: ``count`` preview HTML
        snippets followed by ``count`` generated file paths. This order matches
        the Generate button's outputs (players first, then download buttons).
        If ``generate_groove`` yields fewer than ``count`` files, the missing
        slots are ``None``. Any extra files are dropped.

    Raises:
        gr.Error: if no MIDI file has been loaded yet.
    """
    if not midi_filename:
        raise gr.Error("Please load a MIDI file first.")

    filenames = list(generate_groove(midi_filename, count))[:count]
    filenames += [None] * (count - len(filenames))
    visualizers = [
        None if filename is None
        else visualizer_html_template.format(filepath=filename, id=idx)
        for idx, filename in enumerate(filenames)
    ]
    return visualizers + filenames


# Extra markup injected into the page <head> by gr.Blocks.
# NOTE(review): currently a single space — looks like content is missing.
head = """ """

block = gr.Blocks(head=head)

with block:
    # Path of the most recently loaded MIDI file (set by load_midi_file).
    midi_filepath = gr.State()

    ### MIDI UPLOAD AND PREVIEW ###
    with gr.Group():
        input_midi = gr.File(
            label="Upload basic drum part",
            file_types=[".midi", ".mid"],
            type="binary",
        )
        bpm_input = gr.Number(value=120, label="BPM", interactive=True)
        load_btn = gr.Button("load midi file")
        midi_player = gr.HTML()
        run_event = load_btn.click(
            load_midi_file, [input_midi, bpm_input], [midi_player, midi_filepath]
        )

    # NOTE(review): `html` is not referenced anywhere else in this file.
    with open("js/midi-player.html", encoding="utf-8") as fd:
        html = fd.read()

    ### GENERATION SETTINGS ###
    # NOTE: genre, complexity, length and BPM are not yet passed to
    # run_inference; only the loaded file path and output count are.
    with gr.Group():
        gr.Markdown("## Generation Settings")
        with gr.Row():
            generate_genre = gr.Dropdown(
                ["Rock", "Pop", "Reggae", "Jazz", "Metal"],
                label="Genre",
                interactive=True,
            )
            # `step` is the slider increment and must be numeric.
            generate_complexity = gr.Slider(
                1,
                10,
                value=5,
                label="Complexity",
                info="Choose between 1 and 10",
                step=1,
                interactive=True,
            )
        with gr.Row():
            bpm_display = gr.Textbox(label="BPM value", interactive=False)
            bpm_input.change(fn=display_bpm, inputs=bpm_input, outputs=bpm_display)
            generate_length = gr.Dropdown(
                ["1 Bar", "2 Bars", "3 Bars", "4 Bars", "5 Bars"],
                label="Length",
                interactive=True,
            )
        with gr.Row():
            generate_button = gr.Button("Generate")

    ### OUTPUT ###
    # Number of generated variations; must match the four output tabs below.
    number_outputs = gr.State(4)
    with gr.Group():
        gr.Markdown("## Output")
        with gr.Tab("1"):
            with gr.Row():
                midi_player_output_1 = gr.HTML()
            # with gr.Row():
            #     bpm_output = gr.Number(value=120, label="BPM", interactive=False)
            #     output_instrument = gr.Dropdown(["Drums", "Snare", "Kick"], label="Sound", interactive=True)
            download_button_1 = gr.DownloadButton("Download")
        with gr.Tab("2"):
            with gr.Row():
                midi_player_output_2 = gr.HTML()
            # with gr.Row():
            #     bpm_output = gr.Number(value=123, label="BPM", interactive=False)
            #     output_instrument = gr.Dropdown(["Drums", "Snare", "Kick"], label="Sound", interactive=True)
            download_button_2 = gr.DownloadButton("Download")
        with gr.Tab("3"):
            with gr.Row():
                midi_player_output_3 = gr.HTML()
            # with gr.Row():
            #     bpm_output = gr.Number(value=168, label="BPM", interactive=False)
            #     # output_instrument = gr.Dropdown(["Drums", "Snare", "Kick"], label="Sound", interactive=True)
            download_button_3 = gr.DownloadButton("Download")
        with gr.Tab("4"):
            with gr.Row():
                midi_player_output_4 = gr.HTML()
            # with gr.Row():
            #     bpm_output = gr.Number(value=222, label="BPM", interactive=False)
            #     output_instrument = gr.Dropdown(["Drums", "Snare", "Kick"], label="Sound", interactive=True)
            download_button_4 = gr.DownloadButton("Download")

    # run_inference returns 4 preview snippets followed by 4 file paths, in
    # the same order as these outputs.
    run_event = generate_button.click(
        run_inference,
        [midi_filepath, number_outputs],
        [
            midi_player_output_1,
            midi_player_output_2,
            midi_player_output_3,
            midi_player_output_4,
            download_button_1,
            download_button_2,
            download_button_3,
            download_button_4,
        ],
    )

# mount Gradio app to FastAPI app
app = gr.mount_gradio_app(app, block, path="/")

# serve the app
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)