import gradio as gr
import torch
from spectro import wav_bytes_from_spectrogram_image
from diffusers import StableDiffusionPipeline

# Run on GPU when available. float16 inference is only reliable on CUDA;
# fall back to float32 weights on CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

MODEL_ID = "Hyeon2/riffusion-musiccaps"
pipe = StableDiffusionPipeline.from_pretrained(MODEL_ID, torch_dtype=dtype)
pipe = pipe.to(device)
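# Optional sketch (an assumption, not part of the original app): diffusers
# pipelines support attention slicing, which lowers peak memory at a small
# speed cost. Uncomment if you hit out-of-memory errors on a small GPU.
# pipe.enable_attention_slicing()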
def predict(prompt, negative_prompt, audio_input, duration):
    # audio_input is wired to a hidden upload component and is currently unused.
    return classic(prompt, negative_prompt, duration)
def classic(prompt, negative_prompt, duration):
    # The spectrogram is 512 px wide for a 5 s clip; each extra second of
    # requested duration adds 128 px of width.
    width_duration = 512 + (int(duration) - 5) * 128
    spec = pipe(prompt, negative_prompt=negative_prompt, height=512, width=width_duration).images[0]
    # wav_bytes_from_spectrogram_image reconstructs audio from the spectrogram
    # image; the first element of the returned tuple is an in-memory WAV buffer.
    wav = wav_bytes_from_spectrogram_image(spec)
    with open("output.wav", "wb") as f:
        f.write(wav[0].getbuffer())
    return spec, "output.wav"
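# A minimal sanity check outside the Gradio UI, with a hypothetical prompt
# (assumes the pipeline above loaded successfully):
#   image, wav_path = classic("a mellow acoustic guitar loop", "", 5)
#   image.save("spectrogram.png")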
title = """ | |
<div style="text-align: center; max-width: 500px; margin: 0 auto;"> | |
<div | |
style=" | |
display: inline-flex; | |
align-items: center; | |
gap: 0.8rem; | |
font-size: 1.75rem; | |
margin-bottom: 10px; | |
line-height: 1em; | |
" | |
> | |
<h1 style="font-weight: 600; margin-bottom: 7px;"> | |
Riffusion-Musiccaps real-time music generation | |
</h1> | |
</div> | |
<p style="margin-bottom: 10px;font-size: 94%;font-weight: 100;line-height: 1.5em;"> | |
Describe a musical prompt, generate music by getting a spectrogram image & sound. | |
</p> | |
</div> | |
""" | |
css = '''
#col-container, #col-container-2 {max-width: 510px; margin-left: auto; margin-right: auto;}
a {text-decoration-line: underline; font-weight: 600;}
div#record_btn > .mt-6 {
    margin-top: 0 !important;
}
div#record_btn > .mt-6 button {
    width: 100%;
    height: 40px;
}
.footer {
    margin-bottom: 45px;
    margin-top: 10px;
    text-align: center;
    border-bottom: 1px solid #e5e5e5;
}
.footer > p {
    font-size: 0.8rem;
    display: inline-block;
    padding: 0 10px;
    transform: translateY(10px);
    background: white;
}
.dark .footer {
    border-color: #303030;
}
.dark .footer > p {
    background: #0b0f19;
}
.animate-spin {
    animation: spin 1s linear infinite;
}
@keyframes spin {
    from {
        transform: rotate(0deg);
    }
    to {
        transform: rotate(360deg);
    }
}
#share-btn-container {
    display: flex; padding-left: 0.5rem !important; padding-right: 0.5rem !important; background-color: #000000; justify-content: center; align-items: center; border-radius: 9999px !important; width: 13rem;
}
#share-btn {
    all: initial; color: #ffffff; font-weight: 600; cursor: pointer; font-family: 'IBM Plex Sans', sans-serif; margin-left: 0.5rem !important; padding-top: 0.25rem !important; padding-bottom: 0.25rem !important; right: 0;
}
#share-btn * {
    all: unset;
}
#share-btn-container div:nth-child(-n+2) {
    width: auto !important;
    min-height: 0px !important;
}
#share-btn-container .wrap {
    display: none !important;
}
'''
# The `css` string above was defined but never used; pass it here instead of
# the nonexistent external "style.css" file.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.HTML(title)
        prompt_input = gr.Textbox(placeholder="a cat diva singing in a New York jazz club", label="Musical prompt", elem_id="prompt-in")
        audio_input = gr.Audio(sources=["upload"], type="filepath", visible=False)
        with gr.Row():
            negative_prompt = gr.Textbox(label="Negative prompt")
            duration_input = gr.Slider(label="Duration in seconds", minimum=5, maximum=10, step=1, value=8, elem_id="duration-slider")
        send_btn = gr.Button(value="Get a new spectrogram!", elem_id="submit-btn")

    with gr.Column(elem_id="col-container-2"):
        spectrogram_output = gr.Image(label="spectrogram image result", elem_id="img-out")
        sound_output = gr.Audio(type="filepath", label="spectrogram sound", elem_id="music-out")

    send_btn.click(predict, inputs=[prompt_input, negative_prompt, audio_input, duration_input], outputs=[spectrogram_output, sound_output])

demo.queue(max_size=250).launch(debug=True, ssr_mode=False)