Musimple / sample.py
ZheqiDAI's picture
Initial commit
070e26e
import os
import torch
import h5py
import random
import numpy as np
import soundfile as sf
from models import DiT
from diffusion import create_diffusion
from tqdm import tqdm
import sys
sys.path.append('./tools/bigvgan_v2_22khz_80band_256x')
from bigvgan import BigVGAN
from torch import nn
import torch.nn.functional as F
import argparse
device = 'cuda:1' if torch.cuda.is_available() else 'cpu'
class MelToAudio_bigvgan(nn.Module):
def __init__(self):
super().__init__()
self.vocoder = BigVGAN.from_pretrained('/home/zheqid/workspace/music_dit/bigvgan_v2_22khz_80band_256x', use_cuda_kernel=False)
self.vocoder.remove_weight_norm()
def __call__(self, z):
x = self.mel_to_audio(z)
return x
def mel_to_audio(self, x):
with torch.no_grad():
self.vocoder.eval()
y = self.vocoder(x[:, :, :])
y = y.squeeze(0)
return y
vocoder = MelToAudio_bigvgan().to(device)
def load_trained_model(checkpoint_path):
model = DiT(
input_size=(80, 800),
patch_size=8,
in_channels=1,
hidden_size=384,
depth=12,
num_heads=6,
)
model.to(device)
checkpoint = torch.load(checkpoint_path)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()
return model
def load_all_meta_and_mel_from_h5(h5_file):
with h5py.File(h5_file, 'r') as f:
keys = list(f.keys())
for key in keys:
meta_latent = torch.FloatTensor(f[key]['meta'][:]).to(device)
mel = torch.FloatTensor(f[key]['mel'][:]).to(device)
yield key, meta_latent, mel
def extract_random_mel_segment(mel, segment_length=800):
total_length = mel.shape[2]
if total_length > segment_length:
start = np.random.randint(0, total_length - segment_length)
mel_segment = mel[:, :, start:start + segment_length]
else:
padding = segment_length - total_length
mel_segment = F.pad(mel, (0, padding), mode='constant', value=0)
mel_segment = (mel_segment + 10) / 20
return mel_segment
def infer_and_generate_audio(model, diffusion, meta_latent):
latent_size = (80, 800)
z = torch.randn(1, 1, latent_size[0], latent_size[1], device=device)
model_kwargs = dict(y=meta_latent)
with torch.no_grad():
samples = diffusion.p_sample_loop(
model.forward, z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, progress=True, device=device
)
return samples
def save_audio(mel, vocoder, output_path, sample_rate=24000):
with torch.no_grad():
if mel.dim() == 4 and mel.shape[1] == 1:
mel = mel[0, 0, :, :]
elif mel.dim() == 3 and mel.shape[0] == 1:
mel = mel[0]
else:
raise ValueError(f"Unexpected mel shape: {mel.shape}")
mel = mel.unsqueeze(0)
wav = vocoder(mel * 20 - 10).cpu().numpy()
sf.write(output_path, wav[0], samplerate=sample_rate)
print(f"Saved audio to: {output_path}")
def main():
parser = argparse.ArgumentParser(description='Generate audio using DiT and BigVGAN')
parser.add_argument('--checkpoint', type=str, required=True, help='Path to model checkpoint')
parser.add_argument('--h5_file', type=str, required=True, help='Path to input H5 file')
parser.add_argument('--output_gt_dir', type=str, required=True, help='Directory to save ground truth audio')
parser.add_argument('--output_gen_dir', type=str, required=True, help='Directory to save generated audio')
parser.add_argument('--segment_length', type=int, default=800, help='Segment length for mel slices (default: 800)')
parser.add_argument('--sample_rate', type=int, default=22050, help='Sample rate for output audio (default: 24000)')
args = parser.parse_args()
model = load_trained_model(args.checkpoint)
diffusion = create_diffusion(timestep_respacing="")
for i, (key, meta_latent, mel) in enumerate(tqdm(load_all_meta_and_mel_from_h5(args.h5_file))):
mel_segment = extract_random_mel_segment(mel, segment_length=args.segment_length)
ground_truth_wav_path = os.path.join(args.output_gt_dir, f"{key}.wav")
save_audio(mel_segment, vocoder, ground_truth_wav_path, sample_rate=args.sample_rate)
generated_mel = infer_and_generate_audio(model, diffusion, meta_latent)
output_wav_path = os.path.join(args.output_gen_dir, f"{key}.wav")
save_audio(generated_mel, vocoder, output_wav_path, sample_rate=args.sample_rate)
if __name__ == "__main__":
main()
### how to use
'''
python sample.py --checkpoint ./gtzan-ck/model_epoch_20000.pt \
--h5_file ./dataset/gtzan_test.h5 \
--output_gt_dir ./sample/gn \
--output_gen_dir ./sample/gt \
--segment_length 800 \
--sample_rate 22050
'''