# MusicFlux sampling script: generates audio waveforms from text prompts.
| import os | |
| import torch | |
| import argparse | |
| import math | |
| from einops import rearrange, repeat | |
| from PIL import Image | |
| from diffusers import AutoencoderKL | |
| from transformers import SpeechT5HifiGan | |
| from utils import load_t5, load_clap, load_ae | |
| from train import RF | |
| from constants import build_model | |
def prepare(t5, clip, img, prompt):
    """Patchify the latent `img` and encode `prompt` into model conditioning.

    Args:
        t5: text encoder callable; maps list[str] -> (b, seq, dim) tensor.
        clip: pooled text encoder callable; maps list[str] -> (b, dim) tensor.
            (Here a CLAP encoder is passed in practice — TODO confirm.)
        img: latent tensor of shape (b, c, h, w); h and w must be even.
        prompt: a single prompt string or a list of prompt strings.

    Returns:
        Tuple of (img, cond): `img` patchified to (b, h*w/4, c*4) and
        `cond` a dict with positional ids, text embeddings and the pooled
        vector, all moved to img.device.
    """
    bs, c, h, w = img.shape
    # A single latent may be paired with several prompts: broadcast batch size.
    if bs == 1 and not isinstance(prompt, str):
        bs = len(prompt)

    # Fold 2x2 latent patches into channels: (b, c, h, w) -> (b, h*w/4, c*4).
    img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
    if img.shape[0] == 1 and bs > 1:
        img = repeat(img, "1 ... -> bs ...", bs=bs)

    # Per-patch (0, row, col) position ids for the image tokens.
    img_ids = torch.zeros(h // 2, w // 2, 3)
    img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
    img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
    img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)

    if isinstance(prompt, str):
        prompt = [prompt]
    txt = t5(prompt)
    if txt.shape[0] == 1 and bs > 1:
        txt = repeat(txt, "1 ... -> bs ...", bs=bs)
    # Text tokens carry all-zero position ids.
    txt_ids = torch.zeros(bs, txt.shape[1], 3)

    vec = clip(prompt)
    if vec.shape[0] == 1 and bs > 1:
        vec = repeat(vec, "1 ... -> bs ...", bs=bs)

    # NOTE: removed a stray debug print of the tensor sizes.
    return img, {
        "img_ids": img_ids.to(img.device),
        "txt": txt.to(img.device),
        "txt_ids": txt_ids.to(img.device),
        "y": vec.to(img.device),
    }
def main(args):
    """Sample audio from the prompts in args.prompt_file with MusicFlux.

    Loads the flow model from an EMA checkpoint, encodes each prompt,
    samples latents with classifier-free guidance, decodes them to mel
    spectrograms with the AudioLDM2 VAE, and writes one 16 kHz WAV file
    per prompt into ./wav/.
    """
    # Hoisted out of the per-sample loop (was re-imported every iteration).
    from scipy.io import wavfile

    print('generate with MusicFlux')
    torch.manual_seed(args.seed)
    torch.set_grad_enabled(False)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    latent_size = (256, 16)  # (frames, mel-bin axis) of the VAE latent — TODO confirm
    model = build_model(args.version).to(device)
    # TODO: promote the hard-coded checkpoint/model paths to CLI arguments.
    ckpt_path = '/maindata/data/shared/multimodal/zhengcong.fei/code/music-flow/results/base/checkpoints/0050000.pt'
    state_dict = torch.load(ckpt_path, map_location=lambda storage, loc: storage)
    model.load_state_dict(state_dict['ema'])  # sample from the EMA weights
    model.eval()  # important! disable dropout / train-mode layers
    diffusion = RF()

    # Text encoders for conditioning.
    t5 = load_t5(device, max_length=256)
    clap = load_clap(device, max_length=256)

    # AudioLDM2 VAE + HiFiGAN vocoder: latent -> mel -> waveform.
    audioldm_path = '/maindata/data/shared/multimodal/public/dataset_music/audioldm2'
    vae = AutoencoderKL.from_pretrained(os.path.join(audioldm_path, 'vae')).to(device)
    vocoder = SpeechT5HifiGan.from_pretrained(os.path.join(audioldm_path, 'vocoder')).to(device)

    # Fix: strip newlines and drop blank lines (readlines() kept both,
    # which fed '\n'-terminated or empty prompts to the text encoders).
    with open(args.prompt_file, 'r') as f:
        conds_txt = [line.strip() for line in f if line.strip()]
    num_prompts = len(conds_txt)
    unconds_txt = ["low quality, gentle"] * num_prompts  # CFG negative prompt
    print(num_prompts, conds_txt, unconds_txt)

    # Fix: allocate on `device` (was .cuda(), which crashed on CPU-only hosts).
    init_noise = torch.randn(num_prompts, 8, latent_size[0], latent_size[1], device=device)
    sample_steps = 50
    img, conds = prepare(t5, clap, init_noise, conds_txt)
    _, unconds = prepare(t5, clap, init_noise, unconds_txt)
    # Fix: autocast on the selected device (was hard-coded to 'cuda').
    with torch.autocast(device_type=device):
        images = diffusion.sample_with_xps(model, img, conds=conds, null_cond=unconds,
                                           sample_steps=sample_steps, cfg=7.0)

    # Unpatchify the final sample back to (b, c, 256, 16); the patch grid is
    # latent_size // 2 per axis (was hard-coded as h=128, w=8).
    images = rearrange(
        images[-1],
        "b (h w) (c ph pw) -> b c (h ph) (w pw)",
        h=latent_size[0] // 2,
        w=latent_size[1] // 2,
        ph=2,
        pw=2,
    )
    latents = images / vae.config.scaling_factor
    mel_spectrogram = vae.decode(latents).sample

    os.makedirs('wav', exist_ok=True)  # fix: output directory may not exist
    for i in range(num_prompts):
        mel = mel_spectrogram[i]
        if mel.dim() == 4:
            mel = mel.squeeze(1)
        waveform = vocoder(mel)
        waveform = waveform[0].cpu().float().detach().numpy()
        wavfile.write(os.path.join('wav', f'sample_{i}.wav'), 16000, waveform)
| if __name__ == "__main__": | |
| parser = argparse.ArgumentParser() | |
| parser.add_argument("--version", type=str, default="base") | |
| parser.add_argument("--prompt_file", type=str, default='config/example.txt') | |
| parser.add_argument("--seed", type=int, default=2024) | |
| args = parser.parse_args() | |
| main(args) | |