# USAD-Base / usad_modules.py
# Copyright (c) 2021, Soohwan Kim. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
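# Building blocks for the USAD Conformer-style encoder: activation and utility
# layers, Conformer feed-forward / convolution / relative-attention modules,
# spectrogram subsampling front-ends, and the ConformerEncoder that combines
# them.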
import contextlib
import math
from collections import defaultdict
from typing import Dict, List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
from torch import nn
class SamePad(nn.Module):
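    """Trim the trailing frames that a padded Conv1d adds so the output keeps
    the input length ('same' padding). One frame is removed for even kernel
    sizes; in the causal case the last ``kernel_size - 1`` frames are removed.
    Expects input of shape (B, C, T). Used after the padded Conv1d layers of
    the convolutional positional encoding below.
    """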
def __init__(self, kernel_size, causal=False):
super().__init__()
if causal:
self.remove = kernel_size - 1
else:
self.remove = 1 if kernel_size % 2 == 0 else 0
def forward(self, x):
if self.remove > 0:
x = x[:, :, : -self.remove]
return x
class TransposeLast(nn.Module):
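    """Swap ``tranpose_dim`` (default: -2) with the last dimension. If
    ``deconstruct_idx`` is given, the input is first indexed with it (useful
    when a preceding module returns a tuple).
    """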
def __init__(self, deconstruct_idx=None, tranpose_dim=-2):
super().__init__()
self.deconstruct_idx = deconstruct_idx
self.tranpose_dim = tranpose_dim
def forward(self, x):
if self.deconstruct_idx is not None:
x = x[self.deconstruct_idx]
return x.transpose(self.tranpose_dim, -1)
class Swish(nn.Module):
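    """Swish activation: ``x * sigmoid(x)`` (equivalent to SiLU)."""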
def __init__(self):
super(Swish, self).__init__()
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
return inputs * inputs.sigmoid()
class GLU(nn.Module):
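    """Gated linear unit: splits the input into two halves along ``dim`` and
    returns ``first_half * sigmoid(second_half)``.
    """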
def __init__(self, dim: int) -> None:
super(GLU, self).__init__()
self.dim = dim
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
outputs, gate = inputs.chunk(2, dim=self.dim)
return outputs * gate.sigmoid()
class ResidualConnectionModule(nn.Module):
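    """Residual wrapper: returns ``module(x) * module_factor + x * input_factor``."""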
def __init__(
self,
module: nn.Module,
module_factor: float = 1.0,
input_factor: float = 1.0,
):
super(ResidualConnectionModule, self).__init__()
self.module = module
self.module_factor = module_factor
self.input_factor = input_factor
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
return (self.module(inputs) * self.module_factor) + (inputs * self.input_factor)
class Linear(nn.Module):
def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None:
super(Linear, self).__init__()
self.linear = nn.Linear(in_features, out_features, bias=bias)
nn.init.xavier_uniform_(self.linear.weight)
if bias:
nn.init.zeros_(self.linear.bias)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.linear(x)
class View(nn.Module):
def __init__(self, shape: tuple, contiguous: bool = False):
super(View, self).__init__()
self.shape = shape
self.contiguous = contiguous
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.contiguous:
x = x.contiguous()
return x.view(*self.shape)
class Transpose(nn.Module):
def __init__(self, shape: tuple):
super(Transpose, self).__init__()
self.shape = shape
def forward(self, x: torch.Tensor) -> torch.Tensor:
return x.transpose(*self.shape)
class FeedForwardModule(nn.Module):
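    """Conformer feed-forward module: LayerNorm -> Linear (x expansion_factor)
    -> Swish -> Dropout -> Linear -> Dropout. Shape-preserving on (B, T, C).
    """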
def __init__(
self,
encoder_dim: int = 512,
expansion_factor: int = 4,
dropout_p: float = 0.1,
) -> None:
super(FeedForwardModule, self).__init__()
self.sequential = nn.Sequential(
nn.LayerNorm(encoder_dim),
Linear(encoder_dim, encoder_dim * expansion_factor, bias=True),
Swish(),
nn.Dropout(p=dropout_p),
Linear(encoder_dim * expansion_factor, encoder_dim, bias=True),
nn.Dropout(p=dropout_p),
)
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
return self.sequential(inputs)
class DepthwiseConv1d(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int,
stride: int = 1,
padding: int = 0,
bias: bool = False,
) -> None:
super(DepthwiseConv1d, self).__init__()
assert (
out_channels % in_channels == 0
), "out_channels should be constant multiple of in_channels"
self.conv = nn.Conv1d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
groups=in_channels,
stride=stride,
padding=padding,
bias=bias,
)
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
return self.conv(inputs)
class PointwiseConv1d(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
stride: int = 1,
padding: int = 0,
bias: bool = True,
) -> None:
super(PointwiseConv1d, self).__init__()
self.conv = nn.Conv1d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=1,
stride=stride,
padding=padding,
bias=bias,
)
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
return self.conv(inputs)
class ConformerConvModule(nn.Module):
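    """Conformer convolution module: LayerNorm -> pointwise conv (2x channels)
    -> GLU -> depthwise conv with 'same' padding -> BatchNorm -> Swish ->
    pointwise conv -> Dropout. Input and output are (B, T, C); the module
    transposes to (B, C, T) internally for the 1-D convolutions.
    """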
def __init__(
self,
in_channels: int,
kernel_size: int = 31,
expansion_factor: int = 2,
dropout_p: float = 0.1,
) -> None:
super(ConformerConvModule, self).__init__()
assert (
kernel_size - 1
) % 2 == 0, "kernel_size should be a odd number for 'SAME' padding"
        assert expansion_factor == 2, "Currently, only expansion_factor == 2 is supported"
self.sequential = nn.Sequential(
nn.LayerNorm(in_channels),
Transpose(shape=(1, 2)),
PointwiseConv1d(
in_channels,
in_channels * expansion_factor,
stride=1,
padding=0,
bias=True,
),
GLU(dim=1),
DepthwiseConv1d(
in_channels,
in_channels,
kernel_size,
stride=1,
padding=(kernel_size - 1) // 2,
),
nn.BatchNorm1d(in_channels),
Swish(),
PointwiseConv1d(in_channels, in_channels, stride=1, padding=0, bias=True),
nn.Dropout(p=dropout_p),
)
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
return self.sequential(inputs).transpose(1, 2)
class FramewiseConv2dSubampling(nn.Module):
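    """Conv2d front-end that subsamples the time axis by ``subsample_rate``
    (2 or 4) using two 3x3 convolutions over the (time, freq) plane. The
    input spectrogram (B, T, F) is treated as a one-channel image; the output
    is (B, T', out_channels * F') together with the subsampled lengths.
    """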
def __init__(self, out_channels: int, subsample_rate: int = 2) -> None:
super(FramewiseConv2dSubampling, self).__init__()
assert subsample_rate in {2, 4}, "subsample_rate should be 2 or 4"
self.subsample_rate = subsample_rate
self.cnn = nn.Sequential(
nn.Conv2d(1, out_channels, kernel_size=3, stride=2),
nn.ReLU(),
nn.Conv2d(
out_channels,
out_channels,
kernel_size=3,
stride=(2 if subsample_rate == 4 else 1, 2),
padding=(0 if subsample_rate == 4 else 1, 0),
),
nn.ReLU(),
)
def forward(
self, inputs: torch.Tensor, input_lengths: torch.LongTensor
) -> Tuple[torch.Tensor, torch.LongTensor]:
# inputs: (B, T, C) -> (B, 1, T, C)
if self.subsample_rate == 2 and inputs.shape[1] % 2 == 0:
inputs = F.pad(inputs, (0, 0, 0, 1), "constant", 0)
outputs = self.cnn(inputs.unsqueeze(1))
        batch_size, channels, subsampled_lengths, subsampled_dim = outputs.size()
        outputs = outputs.permute(0, 2, 1, 3)
        outputs = outputs.contiguous().view(
            batch_size, subsampled_lengths, channels * subsampled_dim
        )
if self.subsample_rate == 4:
output_lengths = (((input_lengths - 1) >> 1) - 1) >> 1
else:
output_lengths = input_lengths >> 1
return outputs, output_lengths
class PatchwiseConv2dSubampling(nn.Module):
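    """Patch-based front-end (ViT-style): a Conv2d whose kernel and stride
    equal the patch size embeds non-overlapping (time, freq) patches, followed
    by two 3x3 convolutions. The output is the flattened patch sequence
    (B, (T // patch_size_time) * (F // patch_size_freq), out_channels).
    ``subsample_rate`` is the number of input frames per output token.
    """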
def __init__(
self,
mel_dim: int,
out_channels: int,
patch_size_time: int = 16,
patch_size_freq: int = 16,
) -> None:
super(PatchwiseConv2dSubampling, self).__init__()
self.mel_dim = mel_dim
self.patch_size_time = patch_size_time
self.patch_size_freq = patch_size_freq
self.proj = nn.Conv2d(
1,
out_channels,
kernel_size=(patch_size_time, patch_size_freq),
stride=(patch_size_time, patch_size_freq),
padding=0,
)
self.cnn = nn.Sequential(
nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
nn.ReLU(),
nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
nn.ReLU(),
)
@property
def subsample_rate(self) -> int:
return self.patch_size_time * self.patch_size_freq // self.mel_dim
def forward(
self, inputs: torch.Tensor, input_lengths: torch.LongTensor
) -> Tuple[torch.Tensor, torch.LongTensor]:
assert (
inputs.shape[2] == self.mel_dim
), "inputs.shape[2] should be equal to mel_dim"
# inputs: (B, Time, Freq) -> (B, 1, Time, Freq)
outputs = self.proj(inputs.unsqueeze(1))
outputs = self.cnn(outputs)
# (B, channels, Time // patch_size_time, Freq // patch_size_freq)
outputs = outputs.flatten(2, 3).transpose(1, 2)
# (B, (Time // patch_size_time) * (Freq // patch_size_freq), channels)
output_lengths = (
input_lengths
// self.patch_size_time
* (self.mel_dim // self.patch_size_freq)
)
return outputs, output_lengths
class RelPositionalEncoding(nn.Module):
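    """Relative sinusoidal positional encoding (Transformer-XL style). For an
    input of length T, ``forward`` returns the embeddings for the relative
    positions T-1, ..., 0, ..., -(T-1), i.e. a tensor of shape
    (1, 2*T - 1, d_model).
    """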
def __init__(self, d_model: int, max_len: int = 10000) -> None:
super(RelPositionalEncoding, self).__init__()
self.d_model = d_model
self.pe = None
self.extend_pe(torch.tensor(0.0).expand(1, max_len))
def extend_pe(self, x: torch.Tensor) -> None:
if self.pe is not None:
if self.pe.size(1) >= x.size(1) * 2 - 1:
if self.pe.dtype != x.dtype or self.pe.device != x.device:
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
return
pe_positive = torch.zeros(x.size(1), self.d_model)
pe_negative = torch.zeros(x.size(1), self.d_model)
position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
div_term = torch.exp(
torch.arange(0, self.d_model, 2, dtype=torch.float32)
* -(math.log(10000.0) / self.d_model)
)
pe_positive[:, 0::2] = torch.sin(position * div_term)
pe_positive[:, 1::2] = torch.cos(position * div_term)
pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
pe_negative = pe_negative[1:].unsqueeze(0)
pe = torch.cat([pe_positive, pe_negative], dim=1)
self.pe = pe.to(device=x.device, dtype=x.dtype)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# x: (B, T, C)
self.extend_pe(x)
pos_emb = self.pe[
:,
self.pe.size(1) // 2 - x.size(1) + 1 : self.pe.size(1) // 2 + x.size(1),
]
return pos_emb
class RelativeMultiHeadAttention(nn.Module):
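    """Multi-head attention with relative positional encoding (Transformer-XL
    style): learnable content (u) and position (v) biases, a separate linear
    projection for the positional embeddings, and a relative shift of the
    positional scores. Returns (output, attention weights).
    """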
def __init__(
self,
d_model: int = 512,
num_heads: int = 16,
dropout_p: float = 0.1,
):
super(RelativeMultiHeadAttention, self).__init__()
assert d_model % num_heads == 0, "d_model % num_heads should be zero."
self.d_model = d_model
self.d_head = int(d_model / num_heads)
self.num_heads = num_heads
self.sqrt_dim = math.sqrt(self.d_head)
self.query_proj = Linear(d_model, d_model)
self.key_proj = Linear(d_model, d_model)
self.value_proj = Linear(d_model, d_model)
self.pos_proj = Linear(d_model, d_model, bias=False)
self.dropout = nn.Dropout(p=dropout_p)
self.u_bias = nn.Parameter(torch.Tensor(self.num_heads, self.d_head))
self.v_bias = nn.Parameter(torch.Tensor(self.num_heads, self.d_head))
torch.nn.init.xavier_uniform_(self.u_bias)
torch.nn.init.xavier_uniform_(self.v_bias)
self.out_proj = Linear(d_model, d_model)
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
pos_embedding: torch.Tensor,
mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
batch_size = value.size(0)
query = self.query_proj(query).view(batch_size, -1, self.num_heads, self.d_head)
key = (
self.key_proj(key)
.view(batch_size, -1, self.num_heads, self.d_head)
.permute(0, 2, 1, 3)
)
value = (
self.value_proj(value)
.view(batch_size, -1, self.num_heads, self.d_head)
.permute(0, 2, 1, 3)
)
pos_embedding = self.pos_proj(pos_embedding).view(
batch_size, -1, self.num_heads, self.d_head
)
content_score = torch.matmul(
(query + self.u_bias).transpose(1, 2), key.transpose(2, 3)
)
pos_score = torch.matmul(
(query + self.v_bias).transpose(1, 2),
pos_embedding.permute(0, 2, 3, 1),
)
pos_score = self._relative_shift(pos_score)
score = (content_score + pos_score) / self.sqrt_dim
if mask is not None:
mask = mask.unsqueeze(1)
score.masked_fill_(mask, -1e9)
attn = F.softmax(score, -1)
attn = self.dropout(attn)
context = torch.matmul(attn, value).transpose(1, 2)
context = context.contiguous().view(batch_size, -1, self.d_model)
return self.out_proj(context), attn
def _relative_shift(self, pos_score: torch.Tensor) -> torch.Tensor:
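        """Apply the Transformer-XL relative-shift trick: pad, reshape, and
        slice ``pos_score`` of shape (B, H, T, 2*T - 1) so that each query
        position is aligned with its relative-position scores, yielding a
        (B, H, T, T) tensor that can be added to the content scores.
        """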
batch_size, num_heads, seq_length1, seq_length2 = pos_score.size()
zeros = pos_score.new_zeros(batch_size, num_heads, seq_length1, 1)
padded_pos_score = torch.cat([zeros, pos_score], dim=-1)
padded_pos_score = padded_pos_score.view(
batch_size, num_heads, seq_length2 + 1, seq_length1
)
pos_score = padded_pos_score[:, :, 1:].view_as(pos_score)[
:, :, :, : seq_length2 // 2 + 1
]
return pos_score
class MultiHeadedSelfAttentionModule(nn.Module):
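    """Pre-LayerNorm self-attention module: builds the relative positional
    embeddings for the input, applies RelativeMultiHeadAttention with
    query = key = value = inputs, and returns (dropout(output), attention
    weights).
    """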
def __init__(self, d_model: int, num_heads: int, dropout_p: float = 0.1):
super(MultiHeadedSelfAttentionModule, self).__init__()
self.positional_encoding = RelPositionalEncoding(d_model)
self.layer_norm = nn.LayerNorm(d_model)
self.attention = RelativeMultiHeadAttention(d_model, num_heads, dropout_p)
self.dropout = nn.Dropout(p=dropout_p)
def forward(
self, inputs: torch.Tensor, mask: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
batch_size = inputs.size(0)
pos_embedding = self.positional_encoding(inputs)
pos_embedding = pos_embedding.repeat(batch_size, 1, 1)
inputs = self.layer_norm(inputs)
outputs, attn = self.attention(
inputs, inputs, inputs, pos_embedding=pos_embedding, mask=mask
)
return self.dropout(outputs), attn
class ConformerBlock(nn.Module):
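    """A single Conformer block: FFN (half-step residual) -> self-attention ->
    convolution module -> FFN (half-step residual) -> LayerNorm. With
    ``transformer_style=True`` the convolution module and second FFN are
    skipped and full-step residuals are used. Only ``attention_type="mhsa"``
    is constructed in this file; the ``mamba_*`` arguments are accepted but no
    Mamba module is built here. Returns (output, dict of intermediate
    activations).
    """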
def __init__(
self,
encoder_dim: int = 512,
attention_type: str = "mhsa",
num_attention_heads: int = 8,
mamba_d_state: int = 16,
mamba_d_conv: int = 4,
mamba_expand: int = 2,
mamba_bidirectional: bool = True,
feed_forward_expansion_factor: int = 4,
conv_expansion_factor: int = 2,
feed_forward_dropout_p: float = 0.1,
attention_dropout_p: float = 0.1,
conv_dropout_p: float = 0.1,
conv_kernel_size: int = 31,
half_step_residual: bool = True,
transformer_style: bool = False,
):
super(ConformerBlock, self).__init__()
self.transformer_style = transformer_style
self.attention_type = attention_type
if half_step_residual and not transformer_style:
self.feed_forward_residual_factor = 0.5
else:
self.feed_forward_residual_factor = 1
        assert attention_type in ["mhsa", "mamba"]
        if attention_type == "mhsa":
            attention = MultiHeadedSelfAttentionModule(
                d_model=encoder_dim,
                num_heads=num_attention_heads,
                dropout_p=attention_dropout_p,
            )
        else:
            # No Mamba module is defined in this file, so fail with a clear
            # error instead of an UnboundLocalError further down.
            raise NotImplementedError(
                "attention_type='mamba' is not implemented in usad_modules.py"
            )
self.ffn_1 = FeedForwardModule(
encoder_dim=encoder_dim,
expansion_factor=feed_forward_expansion_factor,
dropout_p=feed_forward_dropout_p,
)
self.attention = attention
if not transformer_style:
self.conv = ConformerConvModule(
in_channels=encoder_dim,
kernel_size=conv_kernel_size,
expansion_factor=conv_expansion_factor,
dropout_p=conv_dropout_p,
)
self.ffn_2 = FeedForwardModule(
encoder_dim=encoder_dim,
expansion_factor=feed_forward_expansion_factor,
dropout_p=feed_forward_dropout_p,
)
self.layernorm = nn.LayerNorm(encoder_dim)
def forward(
self, x: torch.Tensor
) -> Tuple[torch.Tensor, Dict[str, Union[torch.Tensor, None]]]:
# FFN 1
ffn_1_out = self.ffn_1(x)
x = ffn_1_out * self.feed_forward_residual_factor + x
# Attention
if not isinstance(self.attention, MultiHeadedSelfAttentionModule):
# MAMBA
attn_out = self.attention(x)
attn = None
else:
attn_out, attn = self.attention(x)
x = attn_out + x
if self.transformer_style:
x = self.layernorm(x)
return x, {
"ffn_1": ffn_1_out,
"attn": attn,
"conv": None,
"ffn_2": None,
}
# Convolution
conv_out = self.conv(x)
x = conv_out + x
# FFN 2
ffn_2_out = self.ffn_2(x)
x = ffn_2_out * self.feed_forward_residual_factor + x
x = self.layernorm(x)
other = {
"ffn_1": ffn_1_out,
"attn": attn,
"conv": conv_out,
"ffn_2": ffn_2_out,
}
return x, other
class ConformerEncoder(nn.Module):
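    """Conformer encoder. The input spectrogram is reduced by a framewise
    Conv2d subsampler and/or a patchwise subsampler (their projections are
    summed when both are enabled), optionally followed by a convolutional
    positional encoding (similar to the one used in wav2vec 2.0), and then
    processed by a stack of ConformerBlocks. ``forward`` returns
    (hidden states, output lengths, per-layer intermediate results).
    """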
def __init__(self, cfg):
super(ConformerEncoder, self).__init__()
self.cfg = cfg
self.framewise_subsample = None
self.patchwise_subsample = None
self.framewise_in_proj = None
self.patchwise_in_proj = None
assert (
cfg.use_framewise_subsample or cfg.use_patchwise_subsample
), "At least one subsampling method should be used"
if cfg.use_framewise_subsample:
self.framewise_subsample = FramewiseConv2dSubampling(
out_channels=cfg.conv_subsample_channels,
subsample_rate=cfg.conv_subsample_rate,
)
self.framewise_in_proj = nn.Sequential(
Linear(
cfg.conv_subsample_channels * (((cfg.input_dim - 1) // 2 - 1) // 2),
cfg.encoder_dim,
),
nn.Dropout(p=cfg.input_dropout_p),
)
if cfg.use_patchwise_subsample:
self.patchwise_subsample = PatchwiseConv2dSubampling(
mel_dim=cfg.input_dim,
out_channels=cfg.conv_subsample_channels,
patch_size_time=cfg.patch_size_time,
patch_size_freq=cfg.patch_size_freq,
)
self.patchwise_in_proj = nn.Sequential(
Linear(
cfg.conv_subsample_channels,
cfg.encoder_dim,
),
nn.Dropout(p=cfg.input_dropout_p),
)
assert not cfg.use_framewise_subsample or (
cfg.conv_subsample_rate == self.patchwise_subsample.subsample_rate
), (
f"conv_subsample_rate ({cfg.conv_subsample_rate}) != patchwise_subsample.subsample_rate"
f"({self.patchwise_subsample.subsample_rate})"
)
self.framewise_norm, self.patchwise_norm = None, None
if getattr(cfg, "subsample_normalization", False):
if cfg.use_framewise_subsample:
self.framewise_norm = nn.LayerNorm(cfg.encoder_dim)
if cfg.use_patchwise_subsample:
self.patchwise_norm = nn.LayerNorm(cfg.encoder_dim)
self.conv_pos = None
if getattr(cfg, "conv_pos", False):
num_pos_layers = cfg.conv_pos_depth
k = max(3, cfg.conv_pos_width // num_pos_layers)
self.conv_pos = nn.Sequential(
TransposeLast(),
*[
nn.Sequential(
nn.Conv1d(
cfg.encoder_dim,
cfg.encoder_dim,
kernel_size=k,
padding=k // 2,
groups=cfg.conv_pos_groups,
),
SamePad(k),
TransposeLast(),
nn.LayerNorm(cfg.encoder_dim, elementwise_affine=False),
TransposeLast(),
nn.GELU(),
)
for _ in range(num_pos_layers)
],
TransposeLast(),
)
self.conv_pos_post_ln = nn.LayerNorm(cfg.encoder_dim)
self.layers = nn.ModuleList(
[
ConformerBlock(
encoder_dim=cfg.encoder_dim,
attention_type=cfg.attention_type,
num_attention_heads=cfg.num_attention_heads,
mamba_d_state=cfg.mamba_d_state,
mamba_d_conv=cfg.mamba_d_conv,
mamba_expand=cfg.mamba_expand,
mamba_bidirectional=cfg.mamba_bidirectional,
feed_forward_expansion_factor=cfg.feed_forward_expansion_factor,
conv_expansion_factor=cfg.conv_expansion_factor,
feed_forward_dropout_p=cfg.feed_forward_dropout_p,
attention_dropout_p=cfg.attention_dropout_p,
conv_dropout_p=cfg.conv_dropout_p,
conv_kernel_size=cfg.conv_kernel_size,
half_step_residual=cfg.half_step_residual,
transformer_style=getattr(cfg, "transformer_style", False),
)
for _ in range(cfg.num_layers)
]
)
def count_parameters(self) -> int:
"""Count parameters of encoder"""
return sum([p.numel() for p in self.parameters() if p.requires_grad])
    def update_dropout(self, dropout_p: float) -> None:
        """Update the dropout probability of every dropout layer in the encoder."""
        # named_modules() is used (rather than named_children()) so that the
        # dropout layers nested inside sub-modules are also updated.
        for _, module in self.named_modules():
            if isinstance(module, nn.Dropout):
                module.p = dropout_p
def forward(
self,
inputs: torch.Tensor,
input_lengths: Optional[torch.Tensor] = None,
return_hidden: bool = False,
freeze_input_layers: bool = False,
target_layer: Optional[int] = None,
) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, List[torch.Tensor]]]:
if input_lengths is None:
input_lengths = torch.full(
(inputs.size(0),),
inputs.size(1),
dtype=torch.long,
device=inputs.device,
)
with torch.no_grad() if freeze_input_layers else contextlib.ExitStack():
frame_feat, patch_feat = None, None
if self.framewise_subsample is not None:
frame_feat, frame_lengths = self.framewise_subsample(
inputs, input_lengths
)
frame_feat = self.framewise_in_proj(frame_feat)
if self.framewise_norm is not None:
frame_feat = self.framewise_norm(frame_feat)
if self.patchwise_subsample is not None:
patch_feat, patch_lengths = self.patchwise_subsample(
inputs, input_lengths
)
patch_feat = self.patchwise_in_proj(patch_feat)
if self.patchwise_norm is not None:
patch_feat = self.patchwise_norm(patch_feat)
if frame_feat is not None and patch_feat is not None:
min_len = min(frame_feat.size(1), patch_feat.size(1))
frame_feat = frame_feat[:, :min_len]
patch_feat = patch_feat[:, :min_len]
features = frame_feat + patch_feat
output_lengths = (
frame_lengths
if frame_lengths.max().item() < patch_lengths.max().item()
else patch_lengths
)
elif frame_feat is not None:
features = frame_feat
output_lengths = frame_lengths
else:
features = patch_feat
output_lengths = patch_lengths
if self.conv_pos is not None:
features = features + self.conv_pos(features)
features = self.conv_pos_post_ln(features)
layer_results = defaultdict(list)
outputs = features
for i, layer in enumerate(self.layers):
outputs, other = layer(outputs)
if return_hidden:
layer_results["hidden_states"].append(outputs)
for k, v in other.items():
layer_results[k].append(v)
if target_layer is not None and i == target_layer:
break
return outputs, output_lengths, layer_results