import os import sys import numpy as np import torch import gradio as gr from vae_module import VAE, Encoder, Decoder, loss_function from config import config from slicer_module import get_slices from diffusers import UNet2DConditionModel, DDPMScheduler from mel_module import Mel from generator_module import Generator import shutil slices_folder = 'slices' if os.path.exists(slices_folder): # delete previous tracks shutil.rmtree(slices_folder) vae = VAE() vae.load_state_dict(torch.load('vae_model_state_dict.pth', map_location=torch.device('cpu'))) vae.to(config.device) vae.eval() model = UNet2DConditionModel.from_pretrained(config.hub_model_id, subfolder="unet") noise_scheduler = DDPMScheduler.from_pretrained(config.hub_model_id, subfolder="scheduler") def generate_new_track(audio_paths, progress=gr.Progress(track_tqdm=True)): for i, audio_path in enumerate(audio_paths): print(audio_paths, audio_path) get_slices(audio_path) embedding = get_embedding() print("sample latent", embedding.shape) generator = Generator(config, model, noise_scheduler, vae, embedding, progress_callback=progress) generator.generate() return config.generated_track_path def get_embedding(): # returns middle point of given audio files latent representations latents = [] slices_dir = 'slices' for slice_file in os.listdir(slices_dir): if slice_file.endswith('.wav'): # make sure the file is audio mel = Mel(os.path.join(slices_dir, slice_file)) spectrogram = mel.get_spectrogram() tensor = torch.tensor(spectrogram).float().unsqueeze(0).unsqueeze(0) mu, log_var = vae.encode(tensor) latent = torch.cat((mu, log_var), dim=1) min_val = latent.min() max_val = latent.max() normalized_tensor = 2 * ((latent - min_val) / (max_val - min_val)) - 1 latent = normalized_tensor.unsqueeze(0) latents.append(latent) if not latents: return None latents_tensor = torch.cat(latents, dim=0) mean_latent = latents_tensor.mean(dim=0, keepdim=True) return mean_latent interface = gr.Interface( fn=generate_new_track, inputs=gr.Files(file_count="multiple", label="Upload Your Audio Files"), outputs=gr.Audio(type="filepath", label="Generated Track"), title="AMUSE: Music Generation", description = ( "
Here's how it works:
" "