# Yjhhh's picture
# Update app.py
# b4d6f6d verified
import os
import uuid
import torch
import re
import spaces
import gradio as gr
import torchaudio
from audiocraft.models import MusicGen
from audiocraft.data.audio import audio_write
# Device configuration: use CUDA when it is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Set ZERO_GPU_PATCH_TORCH_DEVICE.
# NOTE(review): this only binds a module-level Python name. If the intent is to
# set the ZeroGPU environment variable of the same name, it would presumably
# have to go through os.environ before `import spaces` runs — TODO confirm.
ZERO_GPU_PATCH_TORCH_DEVICE = 1
# Load the `musicgen-melody` model a single time, when this module is imported.
model = MusicGen.get_pretrained("facebook/musicgen-melody")
@spaces.GPU
def generate_music(description, melody_audio, duration):
    """Generate a clip with the module-level MusicGen model and save it as WAV.

    Args:
        description: Free-text prompt. It is sanitized with ``clean_text`` and
            stripped; when nothing is left, the model generates
            unconditionally (and ``melody_audio`` is not used).
        melody_audio: Optional path to an audio file used as melody
            conditioning together with a non-empty description.
        duration: Requested clip length, in seconds.

    Returns:
        The path of the written audio file, as returned by ``audio_write``.

    Raises:
        gr.Error: If ``duration`` is not a positive number, if generation or
            saving fails, or if the output file is missing afterwards. Raising
            (rather than returning the message) keeps an error string from
            being handed to the file output component as if it were a path.
    """
    try:
        seconds = float(duration)
    except (TypeError, ValueError) as err:
        raise gr.Error("Duration must be a number of seconds.") from err
    # `not x > 0` (instead of `x <= 0`) also rejects NaN.
    if not seconds > 0:
        raise gr.Error("Duration must be greater than zero.")

    # Tolerate a missing description and treat whitespace-only text as empty.
    prompt = clean_text(description or "").strip()

    # MusicGen expects the duration in seconds; no unit conversion is needed.
    model.set_generation_params(duration=seconds)

    try:
        with torch.no_grad():
            if not prompt:
                wav = model.generate_unconditional(1)
            elif melody_audio:
                # Load the audio file whose melody conditions the generation.
                melody, sr = torchaudio.load(melody_audio, normalize=True)
                if torch.cuda.is_available():
                    melody = melody.to(device)
                # melody[None] adds a leading batch axis for the single prompt.
                wav = model.generate_with_chroma([prompt], melody[None], sr)
            else:
                wav = model.generate([prompt])

        # audio_write appends the format suffix itself, so pass a bare stem;
        # passing "<uuid>.wav" would produce "<uuid>.wav.wav".
        stem = str(uuid.uuid4())
        path = audio_write(
            stem,
            wav[0].cpu().to(torch.float32),
            model.sample_rate,
            strategy="loudness",
            loudness_compressor=True,
        )
    except Exception as err:
        # Surface the failure in the UI instead of returning it as a value.
        raise gr.Error(f"Music generation failed: {err}") from err

    if not os.path.exists(path):
        raise gr.Error(f'Failed to save audio to {path}')
    return path
def clean_text(text):
    """Return ``text`` with URLs and non-alphanumeric characters removed.

    Two deletions are applied in order: first every run starting at ``http``
    up to the next whitespace, then every character that is not an ASCII
    letter, a digit, or whitespace. Whitespace itself is left untouched, so
    removed spans can leave consecutive spaces behind.
    """
    # Order matters: URLs must go first, while their punctuation is intact.
    for pattern in (r'http\S+', r'[^a-zA-Z0-9\s]'):
        text = re.sub(pattern, '', text)
    return text
# Define the Gradio interface: the input/output components first, then the
# Interface that wires them to generate_music, and finally launch it.
description = gr.Textbox(
    label="Description",
    placeholder="Acoustic, guitar, melody, trap, D minor, 90 bpm",
)
melody_audio = gr.Audio(type="filepath", label="Melody Audio (optional)")
duration = gr.Number(value=10, precision=0, label="Duration (seconds)")
output_path = gr.File(label="Generated Music")

# Each example row follows the input order: description, melody file, duration.
_examples = [
    ["happy rock", None, 8],
    ["energetic EDM", None, 8],
    ["chillwave", "./assets/example_melody.mp3", 10],
]

demo = gr.Interface(
    title="MusicGen Melody Demo",
    description="Generate music using the MusicGen melody model. Optionally remix with an audio file. Download the generated audio file.",
    fn=generate_music,
    inputs=[description, melody_audio, duration],
    outputs=output_path,
    examples=_examples,
)
demo.launch()