|
import glob |
|
import torch |
|
from tqdm import tqdm |
|
from os import makedirs |
|
from soundfile import write |
|
from torchaudio import load |
|
from os.path import join, dirname |
|
from argparse import ArgumentParser |
|
from librosa import resample |
|
|
|
|
|
from sgmse.util.other import set_torch_cuda_arch_list |
|
set_torch_cuda_arch_list() |
|
|
|
from sgmse.model import ScoreModel |
|
from sgmse.util.other import pad_spec |
|
|
|
|
|
def list_wav_files(root):
    """Return the paths of all ``*.wav`` files below *root*, sorted.

    The search is recursive: files directly inside *root* and files in
    subdirectories of any depth are returned, each exactly once.

    Args:
        root: Directory to search.

    Returns:
        A sorted list of matching paths (each starts with *root* as given).
    """
    # With recursive=True, '**' matches zero or more directory levels, so a
    # single pattern covers the top level and every nested folder. Without the
    # flag '**' behaves like '*' and only first-level subfolders are searched.
    return sorted(glob.glob(join(root, '**', '*.wav'), recursive=True))


def output_relpath(noisy_file, test_dir):
    """Return *noisy_file* as a path relative to *test_dir*.

    Used to mirror the input directory layout below the output directory.
    Only the leading directory is removed; a file or folder name that happens
    to contain the ``test_dir`` string is left intact (a plain
    ``str.replace`` would strip every occurrence).

    Args:
        noisy_file: Path of an input file located below *test_dir*.
        test_dir: The input root directory, with or without a trailing
            separator.

    Returns:
        The relative path, using the platform's path separator.
    """
    # Function-scope import: ``relpath`` is not among the file-level imports.
    from os.path import relpath

    return relpath(noisy_file, test_dir)


if __name__ == '__main__':
    # ---- Command-line interface -------------------------------------------
    parser = ArgumentParser()
    parser.add_argument("--test_dir", type=str, required=True, help='Directory containing the test data')
    parser.add_argument("--enhanced_dir", type=str, required=True, help='Directory containing the enhanced data')
    parser.add_argument("--ckpt", type=str, help='Path to model checkpoint')
    parser.add_argument("--corrector", type=str, choices=("ald", "langevin", "none"), default="ald", help="Corrector class for the PC sampler.")
    parser.add_argument("--corrector_steps", type=int, default=1, help="Number of corrector steps")
    parser.add_argument("--snr", type=float, default=0.5, help="SNR value for (annealed) Langevin dynmaics")
    parser.add_argument("--N", type=int, default=30, help="Number of reverse steps")
    parser.add_argument("--device", type=str, default="cuda", help="Device to use for inference")
    args = parser.parse_args()

    # ---- Model ------------------------------------------------------------
    # Load the checkpoint onto the requested device and switch to eval mode.
    model = ScoreModel.load_from_checkpoint(args.ckpt, map_location=args.device)
    model.eval()

    # ---- Input files ------------------------------------------------------
    noisy_files = list_wav_files(args.test_dir)

    # The sampling rate and spectrogram padding mode depend on the backbone.
    if model.backbone == 'ncsnpp_48k':
        target_sr = 48000
        pad_mode = "reflection"
    else:
        target_sr = 16000
        pad_mode = "zero_pad"

    # ---- Enhancement loop -------------------------------------------------
    for noisy_file in tqdm(noisy_files):
        # Path of the file relative to the input root; reused below so the
        # output tree mirrors the input tree.
        filename = output_relpath(noisy_file, args.test_dir)

        # Load the waveform; y is indexed as (channel, time) below via size(1).
        y, sr = load(noisy_file)

        # Resample to the rate selected for the backbone if necessary.
        if sr != target_sr:
            y = torch.tensor(resample(y.numpy(), orig_sr=sr, target_sr=target_sr))

        # Remember the length so the output can be brought back to it.
        T_orig = y.size(1)

        # Normalise to unit peak amplitude. An all-zero (silent) recording has
        # a peak of 0, which would turn y / norm_factor into NaN and write a
        # NaN file; fall back to a factor of 1 in that case.
        norm_factor = y.abs().max()
        if norm_factor == 0:
            norm_factor = torch.ones_like(norm_factor)
        y = y / norm_factor

        # Transform to the model's input representation, add a leading
        # dimension and pad it.
        # NOTE(review): presumably pad_spec pads the time axis to a length the
        # backbone accepts -- confirm against sgmse.util.other.pad_spec.
        Y = torch.unsqueeze(model._forward_transform(model._stft(y.to(args.device))), 0)
        Y = pad_spec(Y, mode=pad_mode)

        # Run the predictor-corrector sampler with a reverse-diffusion
        # predictor and the corrector chosen on the command line.
        sampler = model.get_pc_sampler(
            'reverse_diffusion', args.corrector, Y.to(args.device), N=args.N,
            corrector_steps=args.corrector_steps, snr=args.snr)
        sample, _ = sampler()

        # Back to a waveform of the original length.
        x_hat = model.to_audio(sample.squeeze(), T_orig)

        # Undo the peak normalisation applied to the input.
        x_hat = x_hat * norm_factor

        # Write the enhanced file, creating intermediate folders as needed.
        out_path = join(args.enhanced_dir, filename)
        makedirs(dirname(out_path), exist_ok=True)
        write(out_path, x_hat.cpu().numpy(), target_sr)
|
|