# Third-party
import torch
import torch.nn as nn

# Local
from src.Sound_Feature_Extraction.short_time_fourier_transform import STFT

COMPUTATION_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


class Conv_TDF(nn.Module):
    """
    Convolutional Time-Domain Filter (TDF) module.

    Args:
        c (int): The number of input and output channels for the convolutional layers.
        l (int): The number of convolutional layers within the module.
        f (int): The number of features (frequency bins) processed by the time-domain filter.
        k (int): The size of the convolutional kernels (filters).
        bn (int or None): Bottleneck factor for the TDF linear layers. With bn == 0 a single
            full-width linear layer is used; with bn > 0 the features are first projected
            down to f // bn. If None, the TDF branch is disabled.
        bias (bool): A boolean flag indicating whether bias terms are included in the linear layers.

    Attributes:
        use_tdf (bool): Flag indicating whether the TDF branch is used.

    Methods:
        forward(x): Forward pass through the TDF module.
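
    Example:
        A minimal usage sketch (shapes and hyperparameters are illustrative assumptions,
        not part of the original module); the block preserves the input shape:

            block = Conv_TDF(c=4, l=3, f=64, k=3, bn=4)
            dummy = torch.rand(2, 4, 128, 64)  # (batch, channels, time, features)
            out = block(dummy)                 # same shape as dummy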
""" | |
    def __init__(self, c, l, f, k, bn, bias=True):
        super(Conv_TDF, self).__init__()

        # Determine whether to use the TDF (Time-Domain Filter) branch
        self.use_tdf = bn is not None

        # Stack of convolutional blocks applied sequentially
        self.H = nn.ModuleList()
        for i in range(l):
            self.H.append(
                nn.Sequential(
                    nn.Conv2d(
                        in_channels=c,
                        out_channels=c,
                        kernel_size=k,
                        stride=1,
                        padding=k // 2,
                    ),
                    nn.GroupNorm(2, c),
                    nn.ReLU(),
                )
            )

        # Define the TDF layers if enabled: bn == 0 uses a single full-width
        # linear layer, bn > 0 uses a bottleneck of width f // bn
        if self.use_tdf:
            if bn == 0:
                self.tdf = nn.Sequential(
                    nn.Linear(f, f, bias=bias), nn.GroupNorm(2, c), nn.ReLU()
                )
            else:
                self.tdf = nn.Sequential(
                    nn.Linear(f, f // bn, bias=bias),
                    nn.GroupNorm(2, c),
                    nn.ReLU(),
                    nn.Linear(f // bn, f, bias=bias),
                    nn.GroupNorm(2, c),
                    nn.ReLU(),
                )
    def forward(self, x):
        # Apply the convolutional blocks sequentially
        for h in self.H:
            x = h(x)

        # Apply the TDF branch as a residual connection if enabled
        return x + self.tdf(x) if self.use_tdf else x


class Conv_TDF_net_trimm(nn.Module):
    """
    Convolutional Time-Domain Filter (TDF) network with trimming.

    Args:
        model_path (str): Directory containing the pretrained weights ("{target_name}.pt").
        use_onnx (bool): If True, the PyTorch layers are not built (an ONNX representation
            of the model is used instead).
        target_name (str): The name of the target source being separated (e.g. "vocals").
        L (int): Controls the number of down-sampling (DS) blocks in the network.
            It is divided by 2 to determine how many DS blocks are created.
        l (int): The number of convolutional layers within each Conv_TDF block.
        g (int): The channel growth factor: the number of output channels of the first
            convolutional layer, and the number of channels gained (lost) by each
            down-sampling (up-sampling) block.
        dim_f (int): The number of frequency bins of the input spectrogram. Note that the
            current implementation ignores this argument and hard-codes 3072.
        dim_t (int): The number of time frames of the input spectrogram. Also ignored;
            hard-coded to 256.
        k (int): The size of the convolutional kernels (filters).
        hop (int): The hop size used in the STFT calculations.
        bn (int or None): Bottleneck factor passed through to the Conv_TDF blocks.
        bias (bool): A boolean flag controlling whether bias terms are included in the
            linear layers of the Conv_TDF blocks.
        overlap (int): The amount of overlap between consecutive chunks of audio data
            during processing.
        mid_tdf (bool): If True (and bn is None), the middle Conv_TDF block is built with
            a full-width TDF (bn=0) and no bias.

    Attributes:
        n (int): The calculated number of down-sampling (DS) blocks.
        dim_f (int): The number of frequency bins of the input spectrogram.
        dim_t (int): The number of time frames of the input spectrogram.
        n_fft (int): The size of the Fast Fourier Transform (FFT) window.
        hop (int): The hop size used in the STFT calculations.
        n_bins (int): The number of bins in the frequency domain.
        chunk_size (int): The size of each chunk of audio data, in samples.
        target_name (str): The name of the target source being separated.
        overlap (int): The amount of overlap between consecutive chunks of audio data
            during processing.

    Methods:
        forward(x): Forward pass through the Conv_TDF_net_trimm network.
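
    Example:
        A minimal usage sketch (hyperparameter values are illustrative assumptions);
        without a "models/vocals.pt" checkpoint the weights stay randomly initialized:

            net = Conv_TDF_net_trimm(
                model_path="models", use_onnx=False, target_name="vocals",
                L=4, l=3, g=32, dim_f=3072, dim_t=256, bn=4,
            )
            spec = torch.rand(1, 4, 3072, 256)  # (batch, channels, freq, time)
            out = net(spec)                     # same shape as the input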
""" | |
    def __init__(
        self,
        model_path,
        use_onnx,
        target_name,
        L,
        l,
        g,
        dim_f,
        dim_t,
        k=3,
        hop=1024,
        bn=None,
        bias=True,
        overlap=1500,
        mid_tdf=False,  # assumed default; the original code referenced an undefined mid_tdf name
    ):
        super(Conv_TDF_net_trimm, self).__init__()

        # Scale factors for the number of FFT bins for different target names
        # (unused in this module; n_fft is hard-coded below)
        n_fft_scale = {"vocals": 3, "*": 2}

        # Number of input and output channels for the initial and final convolutional layers
        out_c = in_c = 4

        # Number of down-sampling (DS) blocks
        self.n = L // 2
        # Dimensions of the frequency and time axes of the input data
        # (hard-coded; the dim_f and dim_t arguments are ignored)
        self.dim_f = 3072
        self.dim_t = 256

        # Number of FFT bins (frequencies) and hop size for the Short-Time Fourier Transform (STFT)
        self.n_fft = 7680
        self.hop = hop
        self.n_bins = self.n_fft // 2 + 1

        # Chunk size used for processing
        self.chunk_size = hop * (self.dim_t - 1)
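        # With the defaults this is 1024 * (256 - 1) = 261120 samples, i.e. just enough
        # audio for a centered STFT to produce exactly dim_t frames (an assumption about
        # the STFT wrapper's framing)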

        # Target name for the model
        self.target_name = target_name

        # Overlap between consecutive chunks of audio data during processing
        self.overlap = overlap

        # STFT module for audio processing
        self.stft = STFT(self.n_fft, self.hop, self.dim_f)
        # Build the PyTorch layers only when the ONNX representation of the model is not used
        if not use_onnx:
            # First convolutional layer
            self.first_conv = nn.Sequential(
                nn.Conv2d(in_channels=in_c, out_channels=g, kernel_size=1, stride=1),
                nn.BatchNorm2d(g),
                nn.ReLU(),
            )
            # Initialize variables for the Conv_TDF and down-sampling (DS) blocks
            f = self.dim_f
            c = g
            self.ds_dense = nn.ModuleList()
            self.ds = nn.ModuleList()

            # Loop through down-sampling (DS) blocks
            for i in range(self.n):
                # Conv_TDF block preceding this down-sampling stage
                self.ds_dense.append(Conv_TDF(c, l, f, k, bn, bias=bias))

                # Down-sampling (DS) block: halves both axes and adds g channels
                scale = (2, 2)
                self.ds.append(
                    nn.Sequential(
                        nn.Conv2d(
                            in_channels=c,
                            out_channels=c + g,
                            kernel_size=scale,
                            stride=scale,
                        ),
                        nn.BatchNorm2d(c + g),
                        nn.ReLU(),
                    )
                )
                f = f // 2
                c += g
            # Middle Conv_TDF block
            self.mid_dense = Conv_TDF(c, l, f, k, bn, bias=bias)

            # If no bottleneck factor is given and mid_tdf is set, rebuild the middle
            # block with a full-width TDF (bn=0) and no bias
            if bn is None and mid_tdf:
                self.mid_dense = Conv_TDF(c, l, f, k, bn=0, bias=False)
            # Initialize variables for up-sampling (US) blocks
            self.us_dense = nn.ModuleList()
            self.us = nn.ModuleList()

            # Loop through up-sampling (US) blocks
            for i in range(self.n):
                scale = (2, 2)

                # Up-sampling (US) block: doubles both axes and removes g channels
                self.us.append(
                    nn.Sequential(
                        nn.ConvTranspose2d(
                            in_channels=c,
                            out_channels=c - g,
                            kernel_size=scale,
                            stride=scale,
                        ),
                        nn.BatchNorm2d(c - g),
                        nn.ReLU(),
                    )
                )
                f = f * 2
                c -= g

                # Conv_TDF block following this up-sampling stage
                self.us_dense.append(Conv_TDF(c, l, f, k, bn, bias=bias))

            # Final convolutional layer
            self.final_conv = nn.Sequential(
                nn.Conv2d(in_channels=c, out_channels=out_c, kernel_size=1, stride=1),
            )
            try:
                # Load pretrained weights from disk
                self.load_state_dict(
                    torch.load(
                        f"{model_path}/{target_name}.pt",
                        map_location=COMPUTATION_DEVICE,
                    )
                )
                print(f"Loaded model ({target_name})")
            except FileNotFoundError:
                print(f"Random init ({target_name})")
    def forward(self, x):
        """
        Forward pass through the Conv_TDF_net_trimm network.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Output tensor after passing through the network.
        """
        x = self.first_conv(x)

        # Swap the frequency and time axes so the TDF blocks act on the frequency axis
        x = x.transpose(-1, -2)

        # Down-sampling path: keep each Conv_TDF output for the skip connections
        ds_outputs = []
        for i in range(self.n):
            x = self.ds_dense[i](x)
            ds_outputs.append(x)
            x = self.ds[i](x)

        x = self.mid_dense(x)

        # Up-sampling path with multiplicative skip connections
        for i in range(self.n):
            x = self.us[i](x)
            x *= ds_outputs[-i - 1]
            x = self.us_dense[i](x)

        # Swap the axes back and project to the output channels
        x = x.transpose(-1, -2)
        x = self.final_conv(x)

        return x
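

# Usage sketch (hypothetical, for illustration only): builds the network with assumed
# hyperparameters and pushes a random spectrogram-shaped tensor through it. Without a
# "models/vocals.pt" checkpoint on disk the weights stay randomly initialized, so the
# output is meaningless, but the tensor shapes can be verified end to end.
if __name__ == "__main__":
    net = Conv_TDF_net_trimm(
        model_path="models",  # hypothetical checkpoint directory
        use_onnx=False,
        target_name="vocals",
        L=4,  # two down-sampling and two up-sampling blocks
        l=3,
        g=8,  # small growth factor to keep the sketch light
        dim_f=3072,
        dim_t=256,
        bn=4,
    ).to(COMPUTATION_DEVICE)

    # (batch, channels, freq, time); 4 channels is assumed to correspond to the
    # real/imaginary parts of a stereo STFT, matching in_c = 4 above
    spectrogram = torch.rand(1, 4, 3072, 256, device=COMPUTATION_DEVICE)
    with torch.no_grad():
        separated = net(spectrogram)
    print(separated.shape)  # torch.Size([1, 4, 3072, 256])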