JannikAhlers's picture
fix filepaths
22f9998
import os
import random
import secrets
from string import ascii_letters

import gradio as gr
import mido
import uvicorn
from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles

from inference.inference import generate_groove
# Create the FastAPI app. It serves uploaded and generated MIDI files as static
# assets; the Gradio UI is mounted onto it at the end of the module.
app = FastAPI()

# StaticFiles refuses to start when its directory does not exist, so create
# both up front (uploads/ is also written to at runtime by load_midi_file).
os.makedirs("uploads", exist_ok=True)
os.makedirs("generated", exist_ok=True)

app.mount("/uploads", StaticFiles(directory="./uploads"), name="uploads")
app.mount("/generated", StaticFiles(directory="./generated"), name="generated")
def display_bpm(bpm):
    """Build the label shown in the read-only BPM textbox.

    Args:
        bpm: Tempo value as delivered by the BPM ``gr.Number`` input.

    Returns:
        The string ``"BPM: <bpm>"``.
    """
    label = "BPM: {}".format(bpm)
    return label
# HTML snippet rendered inside gr.HTML components: a piano-roll
# <midi-visualizer> linked (via its element id) to a <midi-player> that plays
# the same MIDI file with the `jazz_kit` soundfont. Fill it with .format():
#   {id}       - suffix that keeps the visualizer element id unique on the page
#                (load_midi_file uses "input", run_inference uses 0, 1, ...)
#   {filepath} - path/URL of the MIDI file, e.g. "uploads/<name>.mid"
visualizer_html_template = """
<div>
<midi-visualizer type="piano-roll" id="myPianoRollVisualizer{id}"
src="{filepath}">
</midi-visualizer>
<midi-player
src="{filepath}"
sound-font="https://storage.googleapis.com/magentadata/js/soundfonts/jazz_kit" visualizer="#myPianoRollVisualizer{id}">
</midi-player>
</div>
"""
def load_midi_file(input_midi, bpm_input):
    """Save an uploaded MIDI file and return a player/visualizer for it.

    Args:
        input_midi: Raw bytes of the upload (the ``gr.File`` component uses
            ``type="binary"``), or ``None``/empty when nothing was uploaded.
        bpm_input: Tempo from the BPM field. Currently unused; kept because the
            Gradio click event passes it.

    Returns:
        ``(html, filepath)``: the rendered visualizer snippet and the relative
        path the upload was written to. Returns ``(None, None)`` when there is
        no upload.
    """
    if not input_midi:
        return None, None

    # Files under uploads/ are publicly served through the /uploads mount, so
    # use an unpredictable name (secrets, not random) that others can't guess.
    # Building the name outside an f-string also avoids nested same-type quotes,
    # which are a SyntaxError before Python 3.12.
    random_stem = "".join(secrets.choice(ascii_letters) for _ in range(16))
    midi_filename = f"{random_stem}.mid"

    os.makedirs("uploads", exist_ok=True)
    filepath = os.path.join("uploads/", midi_filename)
    with open(filepath, "wb") as fd:
        fd.write(input_midi)
    return visualizer_html_template.format(filepath=filepath, id="input"), filepath
def run_inference(midi_filename: str, count: int = 1):
    """Generate grooves from the loaded MIDI file and build their previews.

    Args:
        midi_filename: Path of the uploaded MIDI file, as stored in the
            ``midi_filepath`` state by ``load_midi_file``.
        count: Number of grooves to request from ``generate_groove``.

    Returns:
        A flat list containing one visualizer HTML snippet per generated file,
        followed by the generated file paths themselves, in the same order.
        The UI wires this to its output players followed by its download
        buttons, so it presumably expects exactly ``count`` files back from
        ``generate_groove`` (NOTE(review): confirm that contract).

    Raises:
        gr.Error: If no MIDI file has been loaded yet.
    """
    if not midi_filename:
        # The midi_filepath state is None until "load midi file" succeeds.
        raise gr.Error("Please upload and load a MIDI file before generating.")

    # Materialize as a list so the concatenation below works for any iterable.
    filenames = list(generate_groove(midi_filename, count))
    visualizers = [
        visualizer_html_template.format(filepath=filename, id=index)
        for index, filename in enumerate(filenames)
    ]
    return visualizers + filenames
# Extra <head> markup for the Gradio page. A single jsDelivr "combine" request
# loads Tone.js, Magenta.js core, focus-visible and html-midi-player. The last
# of these defines the <midi-player>/<midi-visualizer> elements used by
# visualizer_html_template.
head = """
<script src="https://cdn.jsdelivr.net/combine/npm/tone@14.7.58,npm/@magenta/music@1.23.1/es6/core.js,npm/focus-visible@5,npm/html-midi-player@1.5.0"></script>
"""
# Gradio UI. `head` loads the html-midi-player scripts so the
# <midi-player>/<midi-visualizer> markup in the gr.HTML outputs renders.
block = gr.Blocks(head=head)

# Number of generated variations: one output tab each, and the `count`
# passed to run_inference through the `number_outputs` state.
NUM_OUTPUTS = 4

with block:
    # Path of the uploaded MIDI file; written by load_midi_file, read by run_inference.
    midi_filepath = gr.State()

    ### MIDI UPLOAD AND PREVIEW ###
    with gr.Group():
        input_midi = gr.File(label="Upload basic drum part", file_types=[".midi", ".mid"], type="binary")
        bpm_input = gr.Number(value=120, label="BPM", interactive=True)
        load_btn = gr.Button("load midi file")
        midi_player = gr.HTML()
        load_event = load_btn.click(load_midi_file, [input_midi, bpm_input], [midi_player, midi_filepath])

    # NOTE(review): `html` is never used in this file; confirm it is still
    # needed. While it is read here, startup fails if the file is missing.
    with open("js/midi-player.html", encoding="utf-8") as fd:
        html = fd.read()

    ### GENERATION SETTINGS ###
    # NOTE(review): genre, complexity and length are not passed to
    # run_inference, so they currently do not influence generation.
    with gr.Group():
        gr.Markdown("## Generation Settings")
        with gr.Row():
            generate_genre = gr.Dropdown(["Rock", "Pop", "Reggae", "Jazz", "Metal"], label="Genre", interactive=True)
            # `step` must be numeric (the increment between slider values).
            generate_complexity = gr.Slider(1, 10, value=5, label="Complexity", info="Choose between 1 and 10", step=1, interactive=True)
        with gr.Row():
            bpm_display = gr.Textbox(label="BPM value", interactive=False)
            bpm_input.change(fn=display_bpm, inputs=bpm_input, outputs=bpm_display)
            generate_length = gr.Dropdown(["1 Bar", "2 Bars", "3 Bars", "4 Bars", "5 Bars"], label="Length", interactive=True)
        with gr.Row():
            generate_button = gr.Button("Generate")

    ### OUTPUT ###
    number_outputs = gr.State(NUM_OUTPUTS)
    output_players = []
    download_buttons = []
    with gr.Group():
        gr.Markdown("## Output")
        for tab_number in range(1, NUM_OUTPUTS + 1):
            with gr.Tab(str(tab_number)):
                with gr.Row():
                    output_players.append(gr.HTML())
                download_buttons.append(gr.DownloadButton("Download"))

    # Output order must match run_inference's return value: all HTML previews
    # first, then all file paths for the download buttons.
    run_event = generate_button.click(run_inference, [midi_filepath, number_outputs], output_players + download_buttons)
# Attach the Gradio UI at the root path of the FastAPI application.
app = gr.mount_gradio_app(app=app, blocks=block, path="/")

if __name__ == "__main__":
    # Executed directly: serve on all interfaces, port 7860.
    listen_host, listen_port = "0.0.0.0", 7860
    uvicorn.run(app, host=listen_host, port=listen_port)