# SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: MIT
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
import torch

from audio_processing import STFT


class Denoiser(torch.nn.Module):
    """Removes model bias from audio produced with hifigan.

    At construction time the vocoder is run once on a synthetic mel input
    (all zeros or standard-normal noise).  The first frame of the STFT of
    that output is stored in the ``bias_spec`` buffer.  ``forward`` then
    subtracts a scaled copy of that buffer from the STFT of the audio it is
    given and resynthesises the waveform.

    Args:
        hifigan: Vocoder module.  It must be callable on a ``(1, 80, 88)``
            tensor and expose ``ups[0].weight``; that parameter's dtype and
            device are used for the synthetic input and for placing the STFT.
        filter_length: Forwarded to ``STFT`` as ``filter_length``.
        n_overlap: The STFT hop length is ``int(filter_length / n_overlap)``.
        win_length: Forwarded to ``STFT`` as ``win_length``.
        mode: ``"zeros"`` to probe the vocoder with an all-zero mel input,
            ``"normal"`` to probe it with ``torch.randn`` noise.

    Raises:
        ValueError: If ``mode`` is neither ``"zeros"`` nor ``"normal"``.
    """

    def __init__(
        self, hifigan, filter_length=1024, n_overlap=4, win_length=1024, mode="zeros"
    ):
        super().__init__()

        # Single reference parameter for dtype/device in every mode.  The
        # "normal" branch used to read ``hifigan.upsample.weight`` while the
        # rest of this constructor reads ``hifigan.ups[0].weight``.
        ref_weight = hifigan.ups[0].weight

        self.stft = STFT(
            filter_length=filter_length,
            hop_length=int(filter_length / n_overlap),
            win_length=win_length,
        )
        self.stft = self.stft.to(ref_weight.device)

        # NOTE(review): (1, 80, 88) is hard-coded; presumably
        # (batch, mel channels, frames) -- confirm against the vocoder config.
        mel_shape = (1, 80, 88)
        if mode == "zeros":
            mel_input = torch.zeros(
                mel_shape, dtype=ref_weight.dtype, device=ref_weight.device
            )
        elif mode == "normal":
            mel_input = torch.randn(
                mel_shape, dtype=ref_weight.dtype, device=ref_weight.device
            )
        else:
            raise ValueError(f"Mode {mode} is not supported")

        with torch.no_grad():
            # Drop the leading dimension of the vocoder output before the STFT.
            bias_audio = hifigan(mel_input).float()[0]
            bias_spec, _ = self.stft.transform(bias_audio)

        # Keep only index 0 along the last axis and restore it as a size-1
        # axis so the buffer broadcasts against the spectrum in ``forward``.
        self.register_buffer("bias_spec", bias_spec[:, :, 0][:, :, None])

    def forward(self, audio, strength=0.1):
        """Return ``audio`` with the stored bias spectrum subtracted.

        Args:
            audio: Waveform tensor; it is cast to float32 and passed to
                ``STFT.transform``.
            strength: Multiplier applied to ``bias_spec`` before subtraction.

        Returns:
            The output of ``STFT.inverse`` on the clamped, bias-subtracted
            spectrum and the unmodified second output of ``STFT.transform``.
        """
        audio_spec, audio_angles = self.stft.transform(audio.float())
        audio_spec_denoised = audio_spec - self.bias_spec * strength
        # Subtraction can push values below zero; floor them at 0.
        audio_spec_denoised = torch.clamp(audio_spec_denoised, min=0.0)
        audio_denoised = self.stft.inverse(audio_spec_denoised, audio_angles)
        return audio_denoised