"""Gradio demo: text prompt -> spectrogram image -> WAV audio.

Renders a spectrogram for a musical prompt with the
``Hyeon2/riffusion-musiccaps`` Stable Diffusion checkpoint, then converts the
image to audio with ``spectro.wav_bytes_from_spectrogram_image``.
"""
import gradio as gr
import torch
from PIL import Image
import numpy as np
from spectro import wav_bytes_from_spectrogram_image
from diffusers import StableDiffusionPipeline
from diffusers import StableDiffusionImg2ImgPipeline

# Prefer the GPU when one is present. Half precision is only used on CUDA:
# many float16 kernels are not implemented on CPU, so a float16 pipeline run
# on "cpu" errors out (or is impractically slow).
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if device == "cuda" else torch.float32

MODEL_ID = "Hyeon2/riffusion-musiccaps"
pipe = StableDiffusionPipeline.from_pretrained(MODEL_ID, torch_dtype=torch_dtype)
pipe = pipe.to(device)

# Spectrogram geometry used by this app: 5 s of audio maps to a 512 px wide
# image, and each additional second adds 128 px. Height is fixed.
BASE_DURATION_S = 5
BASE_WIDTH_PX = 512
PX_PER_EXTRA_SECOND = 128
SPECTROGRAM_HEIGHT_PX = 512


def _spectrogram_width(duration):
    """Return the spectrogram width in pixels for ``duration`` seconds.

    ``duration`` is truncated to whole seconds with ``int``. 5 -> 512 px,
    8 -> 896 px, 10 -> 1152 px. Every result is a multiple of 8, which the
    Stable Diffusion pipeline requires for ``width``.
    """
    return BASE_WIDTH_PX + (int(duration) - BASE_DURATION_S) * PX_PER_EXTRA_SECOND


def predict(prompt, negative_prompt, audio_input, duration):
    """Gradio click handler for the "Get a new spectrogram!" button.

    ``audio_input`` is wired from a hidden upload component and is not used;
    the call is forwarded to :func:`classic`.

    Returns:
        ``(spectrogram_image, wav_path)`` as produced by :func:`classic`.
    """
    return classic(prompt, negative_prompt, duration)


def classic(prompt, negative_prompt, duration):
    """Generate a spectrogram for ``prompt`` and write its audio to ``output.wav``.

    Args:
        prompt: Text description of the music to generate.
        negative_prompt: Text the diffusion model should steer away from.
        duration: Requested length in seconds (see ``_spectrogram_width``).

    Returns:
        ``(spec, "output.wav")``: the first generated PIL image and the path of
        the WAV file rendered from it.
    """
    width = _spectrogram_width(duration)
    spec = pipe(
        prompt,
        negative_prompt=negative_prompt,
        height=SPECTROGRAM_HEIGHT_PX,
        width=width,
    ).images[0]

    # The first element of the returned tuple exposes .getbuffer()
    # (presumably an io.BytesIO holding the WAV bytes — TODO confirm in spectro).
    wav = wav_bytes_from_spectrogram_image(spec)
    # NOTE(review): a single fixed filename is shared by all requests. This
    # assumes Gradio's queue handles one request at a time; raising the
    # event's concurrency would make requests overwrite each other's output.
    with open("output.wav", "wb") as f:
        f.write(wav[0].getbuffer())
    return spec, "output.wav"


title = """

Riffusion-Musiccaps real-time music generation

Describe a musical prompt, generate music by getting a spectrogram image & sound.

"""

css = '''
#col-container, #col-container-2 {max-width: 510px; margin-left: auto; margin-right: auto;}
a {text-decoration-line: underline; font-weight: 600;}
div#record_btn > .mt-6 { margin-top: 0!important; }
div#record_btn > .mt-6 button { width: 100%; height: 40px; }
.footer { margin-bottom: 45px; margin-top: 10px; text-align: center; border-bottom: 1px solid #e5e5e5; }
.footer>p { font-size: .8rem; display: inline-block; padding: 0 10px; transform: translateY(10px); background: white; }
.dark .footer { border-color: #303030; }
.dark .footer>p { background: #0b0f19; }
.animate-spin { animation: spin 1s linear infinite; }
@keyframes spin { from { transform: rotate(0deg); } to { transform: rotate(360deg); } }
#share-btn-container { display: flex; padding-left: 0.5rem !important; padding-right: 0.5rem !important; background-color: #000000; justify-content: center; align-items: center; border-radius: 9999px !important; width: 13rem; }
#share-btn { all: initial; color: #ffffff;font-weight: 600; cursor:pointer; font-family: 'IBM Plex Sans', sans-serif; margin-left: 0.5rem !important; padding-top: 0.25rem !important; padding-bottom: 0.25rem !important;right:0; }
#share-btn * { all: unset; }
#share-btn-container div:nth-child(-n+2){ width: auto !important; min-height: 0px !important; }
#share-btn-container .wrap { display: none !important; }
'''

# Pass the inline stylesheet defined above; previously "style.css" was passed
# and this string was never used, so the #col-container rules were dead.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.HTML(title)
        prompt_input = gr.Textbox(
            placeholder="a cat diva singing in a New York jazz club",
            label="Musical prompt",
            elem_id="prompt-in",
        )
        # Hidden input: predict() receives it but ignores it.
        audio_input = gr.Audio(sources=["upload"], type="filepath", visible=False)
        with gr.Row():
            negative_prompt = gr.Textbox(label="Negative prompt")
            duration_input = gr.Slider(
                label="Duration in seconds",
                minimum=5,
                maximum=10,
                step=1,
                value=8,
                elem_id="duration-slider",
            )
        send_btn = gr.Button(value="Get a new spectrogram!", elem_id="submit-btn")

    with gr.Column(elem_id="col-container-2"):
        spectrogram_output = gr.Image(label="spectrogram image result", elem_id="img-out")
        sound_output = gr.Audio(type='filepath', label="spectrogram sound", elem_id="music-out")

    send_btn.click(
        predict,
        inputs=[prompt_input, negative_prompt, audio_input, duration_input],
        outputs=[spectrogram_output, sound_output],
    )

demo.queue(max_size=250).launch(debug=True, ssr_mode=False)