import os
import time
import functools
import numpy as np
from math import cos, pi, floor, sin
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F

from stft_loss import MultiResolutionSTFTLoss

torch.manual_seed(0)
np.random.seed(0)

def flatten(v):
    return [x for y in v for x in y]


def rescale(x):
    return (x - x.min()) / (x.max() - x.min())

def find_max_epoch(path):
    """
    Find the latest checkpoint (largest iteration number) in a directory.

    Parameters:
    path (str): checkpoint directory

    Returns:
    maximum iteration, -1 if there is no (valid) checkpoint
    """
    files = os.listdir(path)
    epoch = -1
    for f in files:
        if len(f) <= 4:
            continue
        if f[-4:] == '.pkl':
            try:
                epoch = max(epoch, int(f[:-4]))
            except ValueError:
                continue
    return epoch
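
# Illustrative sketch only (not part of the original module): one way to resume
# training from the newest checkpoint returned by find_max_epoch(). The
# 'model_state_dict' key and the checkpoint layout are assumptions for this
# example, not something this file defines.
def _load_latest_checkpoint_example(net, ckpt_dir):
    ckpt_iter = find_max_epoch(ckpt_dir)
    if ckpt_iter >= 0:
        # checkpoints are assumed to be saved as '<iteration>.pkl'
        checkpoint = torch.load(os.path.join(ckpt_dir, '{}.pkl'.format(ckpt_iter)),
                                map_location='cpu')
        net.load_state_dict(checkpoint['model_state_dict'])
    return ckpt_iter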

def print_size(net, keyword=None):
    """
    Print the number of parameters of a network
    """
    if net is not None and isinstance(net, torch.nn.Module):
        module_parameters = filter(lambda p: p.requires_grad, net.parameters())
        params = sum([np.prod(p.size()) for p in module_parameters])
        print("{} Parameters: {:.6f}M".format(
            net.__class__.__name__, params / 1e6), flush=True, end="; ")

        if keyword is not None:
            keyword_parameters = [p for name, p in net.named_parameters()
                                  if p.requires_grad and keyword in name]
            params = sum([np.prod(p.size()) for p in keyword_parameters])
            print("{} Parameters: {:.6f}M".format(
                keyword, params / 1e6), flush=True, end="; ")

        print(" ")

####################### lr scheduler: Linear Warmup then Cosine Decay #############################

# Adapted from https://github.com/rosinality/vq-vae-2-pytorch
# Original Copyright 2019 Kim Seonghyeon
# MIT License (https://opensource.org/licenses/MIT)

def anneal_linear(start, end, proportion):
    return start + proportion * (end - start)


def anneal_cosine(start, end, proportion):
    cos_val = cos(pi * proportion) + 1
    return end + (start - end) / 2 * cos_val

class Phase:
    def __init__(self, start, end, n_iter, cur_iter, anneal_fn):
        self.start, self.end = start, end
        self.n_iter = n_iter
        self.anneal_fn = anneal_fn
        self.n = cur_iter

    def step(self):
        self.n += 1
        return self.anneal_fn(self.start, self.end, self.n / self.n_iter)

    def reset(self):
        self.n = 0

    @property
    def is_done(self):
        # exposed as a property because it is read as an attribute
        # in LinearWarmupCosineDecay.step()
        return self.n >= self.n_iter

class LinearWarmupCosineDecay:
    def __init__(
        self,
        optimizer,
        lr_max,
        n_iter,
        iteration=0,
        divider=25,
        warmup_proportion=0.3,
        phase=('linear', 'cosine'),
    ):
        self.optimizer = optimizer

        phase1 = int(n_iter * warmup_proportion)
        phase2 = n_iter - phase1
        lr_min = lr_max / divider

        phase_map = {'linear': anneal_linear, 'cosine': anneal_cosine}

        cur_iter_phase1 = iteration
        cur_iter_phase2 = max(0, iteration - phase1)
        self.lr_phase = [
            Phase(lr_min, lr_max, phase1, cur_iter_phase1, phase_map[phase[0]]),
            Phase(lr_max, lr_min / 1e4, phase2, cur_iter_phase2, phase_map[phase[1]]),
        ]

        if iteration < phase1:
            self.phase = 0
        else:
            self.phase = 1

    def step(self):
        lr = self.lr_phase[self.phase].step()

        for group in self.optimizer.param_groups:
            group['lr'] = lr

        if self.lr_phase[self.phase].is_done:
            self.phase += 1

        if self.phase >= len(self.lr_phase):
            for phase in self.lr_phase:
                phase.reset()
            self.phase = 0

        return lr
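
# Illustrative sketch only (not part of the original module): stepping the
# warmup/decay scheduler once per optimizer update. The optimizer choice and
# all hyperparameter values below are assumptions for the example, not values
# prescribed by this file.
def _scheduler_usage_example(net, n_iter=100000, lr_max=2e-4):
    optimizer = torch.optim.Adam(net.parameters(), lr=lr_max)
    scheduler = LinearWarmupCosineDecay(optimizer, lr_max=lr_max, n_iter=n_iter,
                                        iteration=0, divider=25,
                                        warmup_proportion=0.05,
                                        phase=('linear', 'cosine'))
    for _ in range(n_iter):
        # ... compute loss, loss.backward(), optimizer.step() would go here ...
        scheduler.step()  # writes the new learning rate into every param group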

####################### model util #############################

def std_normal(size):
    """
    Generate a standard Gaussian tensor of the given size (on GPU)
    """
    return torch.normal(0, 1, size=size).cuda()

def weight_scaling_init(layer):
    """
    weight rescaling initialization from https://arxiv.org/abs/1911.13254
    """
    w = layer.weight.detach()
    alpha = 10.0 * w.std()
    layer.weight.data /= torch.sqrt(alpha)
    layer.bias.data /= torch.sqrt(alpha)
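
# Illustrative sketch only (not part of the original module): applying the
# rescaling init to a whole network. Restricting it to 1-D (transposed)
# convolutions is an assumption for the example; layers without a bias are
# skipped because weight_scaling_init rescales both weight and bias.
def _apply_weight_scaling_init_example(net):
    for layer in net.modules():
        if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)) and layer.bias is not None:
            weight_scaling_init(layer)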

def sampling(net, noisy_audio):
    """
    Perform denoising (forward) step
    """
    return net(noisy_audio)

def loss_fn(net, X, ell_p, ell_p_lambda, stft_lambda, mrstftloss, **kwargs):
    """
    Loss function in CleanUNet

    Parameters:
    net: network
    X: training data pair (clean audio, noisy audio)
    ell_p: order p (1 or 2) of the reconstruction (AE) loss
    ell_p_lambda: weight of the AE loss
    stft_lambda: weight of the STFT loss
    mrstftloss: multi-resolution STFT loss function

    Returns:
    loss: value of objective function
    output_dic: values of each component of loss
    """
    assert type(X) == tuple and len(X) == 2

    clean_audio, noisy_audio = X
    B, C, L = clean_audio.shape
    output_dic = {}
    loss = 0.0

    # AE (waveform reconstruction) loss
    denoised_audio = net(noisy_audio)

    if ell_p == 2:
        ae_loss = nn.MSELoss()(denoised_audio, clean_audio)
    elif ell_p == 1:
        ae_loss = F.l1_loss(denoised_audio, clean_audio)
    else:
        raise NotImplementedError
    loss += ae_loss * ell_p_lambda
    output_dic["reconstruct"] = ae_loss.data * ell_p_lambda

    # multi-resolution STFT loss
    if stft_lambda > 0:
        sc_loss, mag_loss = mrstftloss(denoised_audio.squeeze(1), clean_audio.squeeze(1))
        loss += (sc_loss + mag_loss) * stft_lambda
        output_dic["stft_sc"] = sc_loss.data * stft_lambda
        output_dic["stft_mag"] = mag_loss.data * stft_lambda

    return loss, output_dic
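
# Illustrative sketch only (not part of the original module): evaluating the
# objective on one batch of shape (B, 1, L). MultiResolutionSTFTLoss is assumed
# to be constructible with its default resolutions, and the lambda weights
# below are illustrative assumptions.
def _loss_fn_usage_example(net, clean_audio, noisy_audio):
    mrstftloss = MultiResolutionSTFTLoss().to(clean_audio.device)
    loss, output_dic = loss_fn(net, (clean_audio, noisy_audio),
                               ell_p=1, ell_p_lambda=1.0,
                               stft_lambda=1.0, mrstftloss=mrstftloss)
    return loss, output_dic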