IlayMalinyak committed · Commit b3fb4dd · 1 Parent(s): 192ac3b
first commit
Browse files
- tasks/Modules/ResNet18.py +69 -0
- tasks/Modules/__init__.py +0 -0
- tasks/Modules/cnn.py +58 -0
- tasks/Modules/conformer.py +584 -0
- tasks/Modules/mhsa_pro.py +231 -0
- tasks/audio.py +37 -6
- tasks/config.yaml +66 -0
- tasks/data.py +43 -0
- tasks/data_utils.py +63 -0
- tasks/models.py +114 -0
- tasks/train.py +293 -0
tasks/Modules/ResNet18.py
ADDED
@@ -0,0 +1,69 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

# https://github.com/samcw/ResNet18-Pytorch
class ResBlock(nn.Module):
    def __init__(self, inchannel, outchannel, stride=1):
        super(ResBlock, self).__init__()
        self.left = nn.Sequential(
            nn.Conv1d(inchannel, outchannel, kernel_size=3, stride=stride, padding=1, bias=False),
            nn.BatchNorm1d(outchannel),
            nn.ReLU(inplace=True),
            nn.Conv1d(outchannel, outchannel, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm1d(outchannel)
        )
        self.shortcut = nn.Sequential()
        if stride != 1 or inchannel != outchannel:
            self.shortcut = nn.Sequential(
                nn.Conv1d(inchannel, outchannel, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm1d(outchannel)
            )

    def forward(self, x):
        out = self.left(x)
        out = out + self.shortcut(x)
        out = F.relu(out)
        return out

class ResNet18(nn.Module):
    def __init__(self, args):
        super(ResNet18, self).__init__()
        self.inchannel = 64
        self.conv1 = nn.Sequential(
            nn.Conv1d(1, 64, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm1d(64),
            nn.ReLU()
        )
        self.layer1 = self.make_layer(ResBlock, 64, 2, stride=1)
        self.layer2 = self.make_layer(ResBlock, 128, 2, stride=2)
        self.layer3 = self.make_layer(ResBlock, 256, 2, stride=2)
        self.layer4 = self.make_layer(ResBlock, 512, 2, stride=2)
        self.pred_layer = nn.Sequential(
            nn.Linear(512, 512),
            nn.SiLU(),
            nn.Dropout(p=0.3),
            nn.Linear(512, 1),
        )
        if getattr(args, 'mean_label', False):
            self.pred_layer[3].bias.data.fill_(args.mean_label)

    def make_layer(self, block, channels, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.inchannel, channels, stride))
            self.inchannel = channels
        return nn.Sequential(*layers)

    def forward(self, x):
        x = x.unsqueeze(1)
        out = self.conv1(x)
        out = F.max_pool1d(out, 3, 2, 1)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = out.mean(-1)
        out = self.pred_layer(out)
        return out
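For orientation, a minimal sketch of exercising this 1-D ResNet18 on a dummy signal; the `args` namespace and the input length are illustrative stand-ins for the project's config (`mean_label` is optional, per the `getattr` guard above):

import torch
from types import SimpleNamespace
from tasks.Modules.ResNet18 import ResNet18  # assumes the package layout of this commit

# Hypothetical args object; ResNet18 only reads the optional `mean_label`.
args = SimpleNamespace(mean_label=0.5)
model = ResNet18(args)

x = torch.randn(4, 4096)   # (batch, time) raw 1-D signal
out = model(x)             # unsqueezed to (batch, 1, time) internally
print(out.shape)           # torch.Size([4, 1]) - one logit per sample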
tasks/Modules/__init__.py
ADDED
File without changes
tasks/Modules/cnn.py
ADDED
@@ -0,0 +1,58 @@
import torch
import torch.nn as nn

class ConvBlock(nn.Module):
    def __init__(self, args) -> None:
        super().__init__()
        self.layers = nn.Sequential(
            nn.Conv1d(in_channels=args.encoder_dim,
                      out_channels=args.encoder_dim,
                      kernel_size=args.kernel_size,
                      stride=1, padding='same', bias=False),
            nn.BatchNorm1d(num_features=args.encoder_dim),
            nn.SiLU(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x.transpose(1, 2)
        return self.layers(x).transpose(1, 2)

class ConvBlockDecoder(nn.Module):
    def __init__(self, args) -> None:
        super().__init__()
        self.layers = nn.Sequential(
            nn.Conv1d(in_channels=args.decoder_dim,
                      out_channels=args.decoder_dim,
                      kernel_size=args.kernel_size,
                      stride=1, padding='same', bias=False),
            nn.BatchNorm1d(num_features=args.decoder_dim),
            nn.SiLU(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x.transpose(1, 2)
        return self.layers(x).transpose(1, 2)

class ResNetLayer(nn.Module):
    def __init__(self, args) -> None:
        super().__init__()
        self.conv_layer = nn.Sequential(
            nn.Conv1d(in_channels=args.encoder_dim,
                      out_channels=args.encoder_dim,
                      kernel_size=3,
                      stride=1, padding='same', bias=False),
            nn.BatchNorm1d(num_features=args.encoder_dim),
            nn.SiLU(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.conv_layer(x) + x


class ResNetBlock(nn.Module):
    def __init__(self, args) -> None:
        super().__init__()
        self.layers = nn.Sequential(*[ResNetLayer(args) for _ in range(3)])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.layers(x)
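These blocks accept `(batch, time, dim)` tensors, transpose to the `(batch, dim, time)` layout that `nn.Conv1d` expects, and transpose back, so they can be dropped between attention layers. A small shape-check sketch with an illustrative args namespace:

import torch
from types import SimpleNamespace
from tasks.Modules.cnn import ConvBlock  # assumes the package layout of this commit

args = SimpleNamespace(encoder_dim=128, kernel_size=3)  # hypothetical config values
block = ConvBlock(args)

x = torch.randn(2, 100, 128)   # (batch, time, encoder_dim)
y = block(x)                   # Conv1d runs on (batch, encoder_dim, time) internally
print(y.shape)                 # torch.Size([2, 100, 128]) - length preserved by padding='same'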
tasks/Modules/conformer.py
ADDED
@@ -0,0 +1,584 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
import torch.nn.init as init
import math

from .mhsa_pro import MHA_rotary, MHA_decoder
from .cnn import ConvBlock, ConvBlockDecoder

from typing import Optional, Tuple

class ResidualConnectionModule(nn.Module):
    """
    Residual Connection Module.
    outputs = (module(inputs) x module_factor + inputs x input_factor)
    """
    def __init__(self, module: nn.Module, dims, args):
        super(ResidualConnectionModule, self).__init__()
        self.module = module
        self.module_factor = 1
        self.input_factor = 1

    def forward(self, inputs: Tensor, **kwargs) -> Tensor:
        return (self.module(inputs, **kwargs) * self.module_factor) + (inputs * self.input_factor)

class PostNorm(nn.Module):
    """
    Post-norm residual connection.
    outputs = LayerNorm(module(inputs) + inputs x input_factor)
    """
    def __init__(self, module: nn.Module, dims, args):
        super(PostNorm, self).__init__()
        self.module = module
        input_factor = torch.FloatTensor(args.alpha) if getattr(args, 'alpha', None) else torch.tensor(1.)
        self.register_buffer('input_factor', input_factor)
        self.norm = nn.LayerNorm(dims)

    def forward(self, inputs: Tensor, **kwargs) -> Tensor:
        return self.norm(self.module(inputs, **kwargs) + (inputs * self.input_factor))

class Linear(nn.Module):
    """
    Wrapper class of torch.nn.Linear.
    Weights are initialized with Xavier initialization and biases to zeros.
    """
    def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None:
        super(Linear, self).__init__()
        self.linear = nn.Linear(in_features, out_features, bias=bias)
        init.xavier_uniform_(self.linear.weight)
        if bias:
            init.zeros_(self.linear.bias)

    def forward(self, x: Tensor) -> Tensor:
        return self.linear(x)


class View(nn.Module):
    """ Wrapper class of torch.view() for Sequential module. """
    def __init__(self, shape: tuple, contiguous: bool = False):
        super(View, self).__init__()
        self.shape = shape
        self.contiguous = contiguous

    def forward(self, x: Tensor) -> Tensor:
        if self.contiguous:
            x = x.contiguous()
        return x.view(*self.shape)


class Transpose(nn.Module):
    """ Wrapper class of torch.transpose() for Sequential module. """
    def __init__(self, shape: tuple):
        super(Transpose, self).__init__()
        self.shape = shape

    def forward(self, x: Tensor) -> Tensor:
        return x.transpose(*self.shape)

class FeedForwardModule(nn.Module):
    """
    The Conformer feed-forward module follows pre-norm residual units and applies layer normalization
    within the residual unit and on the input before the first linear layer. It also applies Swish
    activation and dropout, which help regularize the network.
    Args:
        encoder_dim (int): Dimension of conformer encoder
        expansion_factor (int): Expansion factor of feed forward module.
        dropout_p (float): Ratio of dropout
        device (torch.device): torch device (cuda or cpu)
    Inputs: inputs
        - **inputs** (batch, time, dim): Tensor containing input sequences
    Outputs: outputs
        - **outputs** (batch, time, dim): Tensor produced by the feed forward module.
    """
    def __init__(
            self,
            args,
    ) -> None:
        super(FeedForwardModule, self).__init__()
        expansion_factor = 4
        self.sequential = nn.Sequential(
            nn.LayerNorm(args.encoder_dim),
            Linear(args.encoder_dim, args.encoder_dim * expansion_factor, bias=True),
            nn.SiLU(),
            nn.Dropout(p=args.dropout_p),
            Linear(args.encoder_dim * expansion_factor, args.encoder_dim, bias=True),
            nn.Dropout(p=args.dropout_p),
        )

    def forward(self, inputs: Tensor) -> Tensor:
        return self.sequential(inputs)

class DepthwiseConv1d(nn.Module):
    """
    When groups == in_channels and out_channels == K * in_channels, where K is a positive integer,
    this operation is termed in the literature as depthwise convolution.
    Args:
        in_channels (int): Number of channels in the input
        out_channels (int): Number of channels produced by the convolution
        kernel_size (int or tuple): Size of the convolving kernel
        stride (int, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
        bias (bool, optional): If True, adds a learnable bias to the output. Default: True
    Inputs: inputs
        - **inputs** (batch, in_channels, time): Tensor containing input vector
    Returns: outputs
        - **outputs** (batch, out_channels, time): Tensor produced by depthwise 1-D convolution.
    """
    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            kernel_size: int,
            stride: int = 1,
            padding: int = 0,
            bias: bool = False,
    ) -> None:
        super(DepthwiseConv1d, self).__init__()
        assert out_channels % in_channels == 0, "out_channels should be constant multiple of in_channels"
        self.conv = nn.Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            groups=in_channels,
            stride=stride,
            padding=padding,
            bias=bias,
        )

    def forward(self, inputs: Tensor) -> Tensor:
        return self.conv(inputs)


class PointwiseConv1d(nn.Module):
    """
    When kernel size == 1, this conv1d operation is termed in the literature as pointwise convolution.
    It is often used to match dimensions.
    Args:
        in_channels (int): Number of channels in the input
        out_channels (int): Number of channels produced by the convolution
        stride (int, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
        bias (bool, optional): If True, adds a learnable bias to the output. Default: True
    Inputs: inputs
        - **inputs** (batch, in_channels, time): Tensor containing input vector
    Returns: outputs
        - **outputs** (batch, out_channels, time): Tensor produced by pointwise 1-D convolution.
    """
    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            stride: int = 1,
            padding: int = 0,
            bias: bool = True,
    ) -> None:
        super(PointwiseConv1d, self).__init__()
        self.conv = nn.Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=1,
            stride=stride,
            padding=padding,
            bias=bias,
        )

    def forward(self, inputs: Tensor) -> Tensor:
        return self.conv(inputs)


class ConformerConvModule(nn.Module):
    """
    The Conformer convolution module starts with a pointwise convolution and a gated linear unit (GLU).
    This is followed by a single 1-D depthwise convolution layer. Batchnorm is deployed just after the
    convolution to aid training deep models.
    Args:
        in_channels (int): Number of channels in the input
        kernel_size (int or tuple, optional): Size of the convolving kernel. Default: 31
        dropout_p (float, optional): probability of dropout
    Inputs: inputs
        inputs (batch, time, dim): Tensor containing input sequences
    Outputs: outputs
        outputs (batch, time, dim): Tensor produced by the conformer convolution module.
    """
    def __init__(
            self,
            args,
    ) -> None:
        super(ConformerConvModule, self).__init__()
        assert (args.kernel_size - 1) % 2 == 0, "kernel_size should be an odd number for 'SAME' padding"
        expansion_factor = 2
        dropout_p = 0.1

        self.sequential = nn.Sequential(
            nn.LayerNorm(args.encoder_dim),
            Transpose(shape=(1, 2)),
            PointwiseConv1d(args.encoder_dim, args.encoder_dim * expansion_factor, stride=1, padding=0, bias=True),
            nn.GLU(dim=1),
            DepthwiseConv1d(args.encoder_dim, args.encoder_dim, args.kernel_size, stride=1, padding=(args.kernel_size - 1) // 2),
            nn.BatchNorm1d(args.encoder_dim),
            nn.SiLU(),
            PointwiseConv1d(args.encoder_dim, args.encoder_dim, stride=1, padding=0, bias=True),
            nn.Dropout(p=dropout_p),
        )

    def forward(self, inputs: Tensor) -> Tensor:
        return self.sequential(inputs).transpose(1, 2)

class PositionalEncoding(nn.Module):
    """
    Positional Encoding proposed in "Attention Is All You Need".
    Since the transformer contains no recurrence and no convolution, we must add some positional
    information for the model to make use of the order of the sequence.
    "Attention Is All You Need" uses sine and cosine functions of different frequencies:
        PE_(pos, 2i)   = sin(pos / power(10000, 2i / d_model))
        PE_(pos, 2i+1) = cos(pos / power(10000, 2i / d_model))
    """
    def __init__(self, d_model: int = 128, max_len: int = 10000) -> None:
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, d_model, requires_grad=False)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, length: int) -> Tensor:
        return self.pe[:, :length]

class RelativeMultiHeadAttention(nn.Module):
    """
    Multi-head attention with relative positional encoding.
    This concept was proposed in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context".
    Args:
        d_model (int): The dimension of model
        num_heads (int): The number of attention heads.
        dropout_p (float): probability of dropout
    Inputs: query, key, value, pos_embedding, mask
        - **query** (batch, time, dim): Tensor containing query vector
        - **key** (batch, time, dim): Tensor containing key vector
        - **value** (batch, time, dim): Tensor containing value vector
        - **pos_embedding** (batch, time, dim): Positional embedding tensor
        - **mask** (batch, 1, time2) or (batch, time1, time2): Tensor containing indices to be masked
    Returns:
        - **outputs**: Tensor produced by the relative multi-head attention module.
    """
    def __init__(
            self,
            encoder_dim: int = 128,
            num_heads: int = 8,
            dropout_p: float = 0.1,
    ):
        super(RelativeMultiHeadAttention, self).__init__()
        assert encoder_dim % num_heads == 0, "d_model % num_heads should be zero."
        self.d_model = encoder_dim
        self.d_head = int(encoder_dim / num_heads)
        self.num_heads = num_heads
        self.sqrt_dim = math.sqrt(encoder_dim)

        self.query_proj = Linear(encoder_dim, encoder_dim)
        self.key_proj = Linear(encoder_dim, encoder_dim)
        self.value_proj = Linear(encoder_dim, encoder_dim)
        self.pos_proj = Linear(encoder_dim, encoder_dim, bias=False)

        self.dropout = nn.Dropout(p=dropout_p)
        self.u_bias = nn.Parameter(torch.Tensor(self.num_heads, self.d_head))
        self.v_bias = nn.Parameter(torch.Tensor(self.num_heads, self.d_head))
        torch.nn.init.xavier_uniform_(self.u_bias)
        torch.nn.init.xavier_uniform_(self.v_bias)

        self.out_proj = Linear(encoder_dim, encoder_dim)

    def forward(
            self,
            query: Tensor,
            key: Tensor,
            value: Tensor,
            pos_embedding: Tensor,
            mask: Optional[Tensor] = None,
    ) -> Tensor:
        batch_size = value.size(0)

        query = self.query_proj(query).view(batch_size, -1, self.num_heads, self.d_head)
        key = self.key_proj(key).view(batch_size, -1, self.num_heads, self.d_head).permute(0, 2, 1, 3)
        value = self.value_proj(value).view(batch_size, -1, self.num_heads, self.d_head).permute(0, 2, 1, 3)
        pos_embedding = self.pos_proj(pos_embedding).view(batch_size, -1, self.num_heads, self.d_head)

        content_score = torch.matmul((query + self.u_bias).transpose(1, 2), key.transpose(2, 3))
        pos_score = torch.matmul((query + self.v_bias).transpose(1, 2), pos_embedding.permute(0, 2, 3, 1))
        # content_score = torch.matmul((query).transpose(1, 2), key.transpose(2, 3))
        # pos_score = torch.matmul((query).transpose(1, 2), pos_embedding.permute(0, 2, 3, 1))
        # Q (B, num_heads, length, d_head) x PE (B, num_heads, d_head, length) -> pos_score (B, num_heads, length, length)
        pos_score = self._relative_shift(pos_score)
        score = (content_score + pos_score) / self.sqrt_dim

        if mask is not None:
            mask = mask.unsqueeze(1)
            score.masked_fill_(mask, -1e9)

        score = F.softmax(score, -1)
        attn = self.dropout(score)

        context = torch.matmul(attn, value).transpose(1, 2)
        context = context.contiguous().view(batch_size, -1, self.d_model)

        return self.out_proj(context)

    def _relative_shift(self, pos_score: Tensor) -> Tensor:
        batch_size, num_heads, seq_length1, seq_length2 = pos_score.size()
        zeros = pos_score.new_zeros(batch_size, num_heads, seq_length1, 1)
        padded_pos_score = torch.cat([zeros, pos_score], dim=-1)

        padded_pos_score = padded_pos_score.view(batch_size, num_heads, seq_length2 + 1, seq_length1)
        pos_score = padded_pos_score[:, :, 1:].view_as(pos_score)
        # shift the position scores one unit along the length axis, leaving a blank row.
        return pos_score


class MultiHeadedSelfAttentionModule(nn.Module):
    """
    Conformer employs multi-headed self-attention (MHSA) while integrating an important technique from
    Transformer-XL, the relative sinusoidal positional encoding scheme. The relative positional encoding
    allows the self-attention module to generalize better to different input lengths, and the resulting
    encoder is more robust to variance in utterance length. Conformer uses pre-norm residual units with
    dropout, which helps train and regularize deeper models.
    Args:
        d_model (int): The dimension of model
        num_heads (int): The number of attention heads.
        dropout_p (float): probability of dropout
        device (torch.device): torch device (cuda or cpu)
    Inputs: inputs, mask
        - **inputs** (batch, time, dim): Tensor containing input vector
        - **mask** (batch, 1, time2) or (batch, time1, time2): Tensor containing indices to be masked
    Returns:
        - **outputs** (batch, time, dim): Tensor produced by the relative multi-headed self-attention module.
    """
    def __init__(self, args):
        super(MultiHeadedSelfAttentionModule, self).__init__()
        dropout_p = 0.1
        self.positional_encoding = PositionalEncoding(args.encoder_dim)
        self.layer_norm = nn.LayerNorm(args.encoder_dim)
        self.attention = RelativeMultiHeadAttention(args.encoder_dim, args.num_heads, args.dropout_p)
        self.dropout = nn.Dropout(p=dropout_p)

    def forward(self, inputs: Tensor, mask: Optional[Tensor] = None):
        batch_size, seq_length, _ = inputs.size()
        pos_embedding = self.positional_encoding(seq_length)
        pos_embedding = pos_embedding.repeat(batch_size, 1, 1)

        inputs = self.layer_norm(inputs)
        outputs = self.attention(inputs, inputs, inputs, pos_embedding=pos_embedding, mask=mask)

        return self.dropout(outputs)

class ConformerBlock(nn.Module):
    """
    A Conformer block contains two feed-forward modules sandwiching the multi-headed self-attention
    module and the convolution module. This sandwich structure is inspired by Macaron-Net, which proposes
    replacing the original feed-forward layer in the Transformer block with two half-step feed-forward
    layers, one before the attention layer and one after.
    Args:
        encoder_dim (int, optional): Dimension of conformer encoder
        num_attention_heads (int, optional): Number of attention heads
        feed_forward_expansion_factor (int, optional): Expansion factor of feed forward module
        conv_expansion_factor (int, optional): Expansion factor of conformer convolution module
        feed_forward_dropout_p (float, optional): Probability of feed forward module dropout
        attention_dropout_p (float, optional): Probability of attention module dropout
        conv_dropout_p (float, optional): Probability of conformer convolution module dropout
        conv_kernel_size (int or tuple, optional): Size of the convolving kernel
        half_step_residual (bool): Flag indicating whether to use half-step residual or not
        device (torch.device): torch device (cuda or cpu)
    Inputs: inputs
        - **inputs** (batch, time, dim): Tensor containing input vector
    Returns: outputs
        - **outputs** (batch, time, dim): Tensor produced by the conformer block.
    """
    def __init__(
            self,
            args
    ):
        super(ConformerBlock, self).__init__()

        norm_dict = {
            'shortcut': ResidualConnectionModule,
            'postnorm': PostNorm
        }
        block_dict = {
            'ffn': FeedForwardModule,
            'mhsa': MultiHeadedSelfAttentionModule,
            'mhsa_pro': MHA_rotary,
            'conv': ConvBlock,
            'conformerconv': ConformerConvModule
        }

        self.modlist = nn.ModuleList([norm_dict[args.norm](block_dict[block](args), args.encoder_dim, args)
                                      for block in args.encoder])

    def forward(self, x: Tensor, RoPE, key_padding_mask=None) -> Tensor:
        for m in self.modlist:
            if isinstance(m.module, MHA_rotary):
                x = m(x, RoPE=RoPE, key_padding_mask=key_padding_mask)
            else:
                x = m(x)
        return x


class DecoderBlock(nn.Module):
    """
    A decoder block contains two feed-forward modules sandwiching the multi-headed self-attention module
    and the convolution module, following the same Macaron-Net-inspired sandwich structure as the
    conformer block.
    Args:
        encoder_dim (int, optional): Dimension of conformer encoder
        num_attention_heads (int, optional): Number of attention heads
        feed_forward_expansion_factor (int, optional): Expansion factor of feed forward module
        conv_expansion_factor (int, optional): Expansion factor of conformer convolution module
        feed_forward_dropout_p (float, optional): Probability of feed forward module dropout
        attention_dropout_p (float, optional): Probability of attention module dropout
        conv_dropout_p (float, optional): Probability of conformer convolution module dropout
        conv_kernel_size (int or tuple, optional): Size of the convolving kernel
        half_step_residual (bool): Flag indicating whether to use half-step residual or not
        device (torch.device): torch device (cuda or cpu)
    Inputs: inputs
        - **inputs** (batch, time, dim): Tensor containing input vector
    Returns: outputs
        - **outputs** (batch, time, dim): Tensor produced by the decoder block.
    """
    def __init__(
            self,
            args
    ):
        super(DecoderBlock, self).__init__()

        norm_dict = {
            'shortcut': ResidualConnectionModule,
            'postnorm': PostNorm
        }
        block_dict = {
            'ffn': FeedForwardModule,
            'mhsa': MultiHeadedSelfAttentionModule,
            'mhsa_pro': MHA_rotary,
            'mhsa_decoder': MHA_decoder,
            'conv': ConvBlockDecoder,
            'conformerconv': ConformerConvModule
        }

        self.modlist = nn.ModuleList([norm_dict[args.norm](block_dict[block](args), args.decoder_dim, args)
                                      for block in args.decoder])

    def forward(self, x: Tensor, memory: Tensor, RoPE, key_padding_mask=None) -> Tensor:
        for m in self.modlist:
            if isinstance(m.module, MHA_decoder):
                x = m(x, memory=memory, RoPE=RoPE, key_padding_mask=key_padding_mask)
            elif isinstance(m.module, MHA_rotary):
                x = m(x, RoPE=RoPE, key_padding_mask=key_padding_mask).transpose(0, 1)
            else:
                x = m(x)
        return x


class ConformerEncoder(nn.Module):
    """
    The Conformer encoder processes the input with a stack of conformer blocks.
    Args:
        input_dim (int, optional): Dimension of input vector
        encoder_dim (int, optional): Dimension of conformer encoder
        num_layers (int, optional): Number of conformer blocks
        num_attention_heads (int, optional): Number of attention heads
        feed_forward_expansion_factor (int, optional): Expansion factor of feed forward module
        conv_expansion_factor (int, optional): Expansion factor of conformer convolution module
        feed_forward_dropout_p (float, optional): Probability of feed forward module dropout
        attention_dropout_p (float, optional): Probability of attention module dropout
        conv_dropout_p (float, optional): Probability of conformer convolution module dropout
        conv_kernel_size (int or tuple, optional): Size of the convolving kernel
        half_step_residual (bool): Flag indicating whether to use half-step residual or not
        device (torch.device): torch device (cuda or cpu)
    Inputs: inputs
        - **inputs** (batch, time, dim): Tensor containing input vector
    Returns: outputs
        - **outputs** (batch, time, dim): Tensor produced by the conformer encoder.
    """
    def __init__(
            self,
            args,
    ):
        super(ConformerEncoder, self).__init__()
        self.blocks = nn.ModuleList([ConformerBlock(args) for _ in range(args.num_layers)])

    def forward(self, x: Tensor, RoPE=None, key_padding_mask=None) -> Tensor:
        """
        Forward propagate `inputs` for encoder training.
        Args:
            inputs (torch.FloatTensor): An input sequence passed to the encoder. Typically a padded
                `FloatTensor` of size ``(batch, seq_length, dimension)``.
        Returns:
            * outputs (torch.FloatTensor): The output sequence of the encoder, a `FloatTensor` of size
              ``(batch, seq_length, dimension)``
        """
        for block in self.blocks:
            x = block(x, RoPE=RoPE, key_padding_mask=key_padding_mask)

        return x

class ConformerDecoder(nn.Module):
    """
    The Conformer decoder processes the input with a stack of decoder blocks, cross-attending to the
    encoder memory.
    Args:
        input_dim (int, optional): Dimension of input vector
        encoder_dim (int, optional): Dimension of conformer encoder
        num_layers (int, optional): Number of conformer blocks
        num_attention_heads (int, optional): Number of attention heads
        feed_forward_expansion_factor (int, optional): Expansion factor of feed forward module
        conv_expansion_factor (int, optional): Expansion factor of conformer convolution module
        feed_forward_dropout_p (float, optional): Probability of feed forward module dropout
        attention_dropout_p (float, optional): Probability of attention module dropout
        conv_dropout_p (float, optional): Probability of conformer convolution module dropout
        conv_kernel_size (int or tuple, optional): Size of the convolving kernel
        half_step_residual (bool): Flag indicating whether to use half-step residual or not
        device (torch.device): torch device (cuda or cpu)
    Inputs: inputs
        - **inputs** (batch, time, dim): Tensor containing input vector
    Returns: outputs
        - **outputs** (batch, time, dim): Tensor produced by the conformer decoder.
    """
    def __init__(
            self,
            args,
    ):
        super(ConformerDecoder, self).__init__()
        self.blocks = nn.ModuleList([DecoderBlock(args) for _ in range(args.num_decoder_layers)])

    def forward(self, x: Tensor, memory: Tensor, RoPE=None, key_padding_mask=None) -> Tensor:
        """
        Forward propagate `inputs` for decoder training.
        Args:
            inputs (torch.FloatTensor): An input sequence passed to the decoder. Typically a padded
                `FloatTensor` of size ``(batch, seq_length, dimension)``.
        Returns:
            * outputs (torch.FloatTensor): The output sequence of the decoder, a `FloatTensor` of size
              ``(batch, seq_length, dimension)``
        """
        for block in self.blocks:
            x = block(x, memory, RoPE=RoPE, key_padding_mask=key_padding_mask)

        return x
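Block composition here is config-driven: `args.encoder` names the sub-modules stacked inside each `ConformerBlock` and `args.norm` selects the residual wrapper. A minimal sketch of assembling an attention-free encoder (so no RoPE tensor is needed); the values are illustrative, not this commit's defaults:

import torch
from types import SimpleNamespace
from tasks.Modules.conformer import ConformerEncoder  # assumes the package layout of this commit

# Illustrative config; mirrors the keys read by the classes above.
args = SimpleNamespace(
    norm='postnorm',                    # wrap each sub-module in PostNorm
    encoder=['ffn', 'conformerconv'],   # sub-modules per ConformerBlock
    num_layers=2,
    encoder_dim=128,
    num_heads=8,
    kernel_size=3,
    dropout_p=0.1,
)
encoder = ConformerEncoder(args)

x = torch.randn(2, 100, 128)            # (batch, time, encoder_dim)
out = encoder(x)                        # RoPE defaults to None; fine without 'mhsa_pro' blocks
print(out.shape)                        # torch.Size([2, 100, 128])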
tasks/Modules/mhsa_pro.py
ADDED
@@ -0,0 +1,231 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init

from typing import Optional, Tuple
import math
import logging

logger = logging.getLogger(__name__)


rwkv_emb_scale = 0.4  # try 0.4 for char-level english. try 1.0 for chinese.
rwkv_layer_decay = 1.0  # decay weights in higher layers. try 0.5 ~ 1.0.

class AttentionConfig:
    def __init__(self, ctx_len=100, **kwargs):
        self.ctx_len = ctx_len
        for k, v in kwargs.items():
            setattr(self, k, v)


########################################################################################################
# MHA_rotary: Multi-head Attention + Rotary Encoding + GeGLU FFN
########################################################################################################

class RotaryEmbedding(torch.nn.Module):
    def __init__(self, dim, base=10000):
        super().__init__()
        inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)
        self.seq_len_cached = None
        self.cos_cached = None
        self.sin_cached = None

    def forward(self, x, seq_len=None):
        if seq_len != self.seq_len_cached:
            self.seq_len_cached = seq_len
            t = torch.arange(seq_len, device=x.device)
            freqs = torch.einsum('i,j->ij', t, self.inv_freq)
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            self.cos_cached = emb.cos()
            self.sin_cached = emb.sin()
        return torch.stack([self.cos_cached, self.sin_cached])

class ContinuousRotaryEmbedding(torch.nn.Module):
    '''Continuous rotary position embedding'''
    def __init__(self, dim, sequence_scale):
        super().__init__()
        base = 10000
        self.sequence_scale = sequence_scale
        self.register_buffer('inv_freq', 1. / (base ** (torch.arange(0, dim, 2))))

    def forward(self, t):
        t = (t + 0.5) * self.sequence_scale
        freqs = torch.einsum('ij,k->ijk', t, self.inv_freq)   # freqs: [B, L, dim//2]
        emb = torch.cat((freqs, freqs), dim=-1).unsqueeze(1)  # emb: [B, 1, L, dim], 1 for broadcast in head_num dim
        return torch.stack([emb.cos(), emb.sin()])

def rotate_half(x):
    x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), -1)

@torch.jit.script
def apply_rotary_pos_emb(q, k, cos, sin):
    cos, sin = cos[..., :q.shape[2], :], sin[..., :q.shape[2], :]
    return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)

class MHA_rotary(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.collect_attention_map = False
        self.attention_map = None
        assert args.encoder_dim % args.num_heads == 0
        self.num_heads = args.num_heads
        self.head_size = args.encoder_dim // args.num_heads

        if args.timeshift:
            self.time_shift = nn.ZeroPad2d((0, 0, 1, 0))

        self.query = nn.Linear(args.encoder_dim, args.encoder_dim)
        self.key = nn.Linear(args.encoder_dim, args.encoder_dim)
        self.value = nn.Linear(args.encoder_dim, args.encoder_dim)

        # self.register_buffer("mask", torch.tril(torch.ones(config.ctx_len, config.ctx_len)))

        self.rotary_ndims = int(self.head_size * 0.5)

        self.rotary_emb = RotaryEmbedding(self.rotary_ndims)

        self.output = nn.Linear(args.encoder_dim, args.encoder_dim)

    def forward(self, x, RoPE, key_padding_mask=None):
        B, T, C = x.size()

        if hasattr(self, 'time_shift'):
            x = torch.cat([self.time_shift(x)[:, :-1, :C//2], x[:, :, C//2:]], dim=-1)

        q = self.query(x).view(B, T, self.num_heads, self.head_size).transpose(1, 2)  # (B, T, C) -> (B, nh, T, hs)
        k = self.key(x).view(B, T, self.num_heads, self.head_size).transpose(1, 2)    # (B, T, C) -> (B, nh, T, hs)
        v = self.value(x).view(B, T, self.num_heads, self.head_size).transpose(1, 2)  # (B, T, C) -> (B, nh, T, hs)

        q, query_pass = q[..., :self.rotary_ndims], q[..., self.rotary_ndims:]
        k, key_pass = k[..., :self.rotary_ndims], k[..., self.rotary_ndims:]

        # cos, sin = self.rotary_emb(q, seq_len=T)
        cos, sin = RoPE
        q, k = apply_rotary_pos_emb(q, k, cos, sin)  # rotary encoding
        q = torch.cat((q, query_pass), dim=-1)
        k = torch.cat((k, key_pass), dim=-1)

        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))  # self-attention: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        if key_padding_mask is not None:
            key_padding_mask = key_padding_mask[:, None, None, :]  # (B, T) -> (B, 1, 1, T)
            att = att.masked_fill(key_padding_mask == 0, float('-inf'))
        att = F.softmax(att, dim=-1)  # softmax

        x = att @ v  # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        x = x.transpose(1, 2).contiguous().view(B, T, -1)  # (B, nh, T, hs) -> (B, T, nh, hs) -> (B, T, C)

        x = self.output(x)

        if self.collect_attention_map:
            self.attention_map = att

        return x

class MHA_decoder(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.collect_attention_map = False
        self.attention_map = None
        assert args.encoder_dim % args.num_heads == 0
        self.num_heads = args.num_heads
        self.head_size = args.decoder_dim // args.num_heads

        if args.timeshift:
            self.time_shift = nn.ZeroPad2d((0, 0, 1, 0))

        self.query = nn.Linear(args.decoder_dim, args.decoder_dim)
        self.key = nn.Linear(args.decoder_dim, args.decoder_dim)
        self.value = nn.Linear(args.decoder_dim, args.decoder_dim)

        # self.register_buffer("mask", torch.tril(torch.ones(config.ctx_len, config.ctx_len)))

        self.rotary_ndims = int(self.head_size * 0.5)

        self.rotary_emb = RotaryEmbedding(self.rotary_ndims)

        self.output = nn.Linear(args.decoder_dim, args.decoder_dim)

    def forward(self, x, memory, RoPE, key_padding_mask=None):
        B, T, C = x.size()
        _, L, M = memory.size()

        # print("x size: ", x.size(), 'memory size: ', memory.size())
        # print('B, T, C: ', B, T, C, 'L: ', L)

        # self-attention:
        q = self.query(x).view(B, T, self.num_heads, self.head_size).transpose(1, 2)  # (B, T, C) -> (B, nh, T, hs)
        k = self.key(x).view(B, T, self.num_heads, self.head_size).transpose(1, 2)    # (B, T, C) -> (B, nh, T, hs)
        v = self.value(x).view(B, T, self.num_heads, self.head_size).transpose(1, 2)  # (B, T, C) -> (B, nh, T, hs)

        q, query_pass = q[..., :self.rotary_ndims], q[..., self.rotary_ndims:]
        k, key_pass = k[..., :self.rotary_ndims], k[..., self.rotary_ndims:]

        # cos, sin = self.rotary_emb(q, seq_len=T)
        cos, sin = RoPE
        q, k = apply_rotary_pos_emb(q, k, cos, sin)  # rotary encoding
        q = torch.cat((q, query_pass), dim=-1)
        k = torch.cat((k, key_pass), dim=-1)

        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))  # self-attention: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        if key_padding_mask is not None:
            key_padding_mask = key_padding_mask[:, None, None, :]  # (B, T) -> (B, 1, 1, T)
            att = att.masked_fill(key_padding_mask == 0, float('-inf'))
        att = F.softmax(att, dim=-1)  # softmax

        x = att @ v
        # print("after attention vals: ", x.shape)  # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        x = x.transpose(1, 2).contiguous().view(B, T, -1)  # (B, nh, T, hs) -> (B, T, nh, hs) -> (B, T, C)

        # x = self.output(x)

        # print("after linear: ", x.shape)


        # cross attention:
        q = self.query(x).view(B, T, self.num_heads, self.head_size).transpose(1, 2)       # (B, T, C) -> (B, nh, T, hs)
        k = self.key(memory).view(B, L, self.num_heads, self.head_size).transpose(1, 2)    # (B, L, M) -> (B, nh, L, hs)
        v = self.value(memory).view(B, L, self.num_heads, self.head_size).transpose(1, 2)  # (B, L, M) -> (B, nh, L, hs)

        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))  # cross-attention: (B, nh, T, hs) x (B, nh, hs, L) -> (B, nh, T, L)
        # print("att size: ", att.size())
        if key_padding_mask is not None:
            key_padding_mask = key_padding_mask[:, None, None, :]  # (B, T) -> (B, 1, 1, T)
            att = att.masked_fill(key_padding_mask == 0, float('-inf'))
        att = F.softmax(att, dim=-1)  # softmax

        x = att @ v  # (B, nh, T, L) x (B, nh, L, hs) -> (B, nh, T, hs)
        # print("x decoder size: ", x.size())
        x = x.transpose(1, 2).contiguous().view(B, T, -1)  # (B, nh, T, hs) -> (B, T, nh, hs) -> (B, T, C)
        # print("x decoder size transposed: ", x.size())
        x = self.output(x)

        if self.collect_attention_map:
            self.attention_map = att

        return x

class GeGLU(torch.nn.Module):
    def __init__(self, config, layer_id, time_shift=False):
        super().__init__()
        self.layer_id = layer_id

        if time_shift:
            self.time_shift = nn.ZeroPad2d((0, 0, 1, 0))

        hidden_sz = 3 * config.n_ffn
        self.key = nn.Linear(config.n_embd, hidden_sz)
        self.value = nn.Linear(config.n_embd, hidden_sz)
        self.weight = nn.Linear(hidden_sz, config.n_embd)

    def forward(self, x):
        B, T, C = x.size()
        if hasattr(self, 'time_shift'):
            x = torch.cat([self.time_shift(x)[:, :-1, :C//2], x[:, :, C//2:]], dim=-1)

        k = self.key(x)
        v = self.value(x)
        y = self.weight(F.gelu(k) * v)
        return y
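Note that `MHA_rotary` expects a precomputed `(cos, sin)` pair as its `RoPE` argument (the in-module `self.rotary_emb` call is commented out). A small standalone sketch of producing and applying those tables, with arbitrary shapes:

import torch
from tasks.Modules.mhsa_pro import RotaryEmbedding, apply_rotary_pos_emb  # assumes this commit's layout

rot = RotaryEmbedding(dim=32)       # rotary dims per head (hypothetical value)
q = torch.randn(2, 8, 100, 32)      # (batch, heads, time, rotary_dims)
k = torch.randn(2, 8, 100, 32)

cos, sin = rot(q, seq_len=100)      # cached cos/sin tables, each (100, 32)
q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)
print(q_rot.shape)                  # torch.Size([2, 8, 100, 32]) - shape unchanged, pairs of channels rotated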
tasks/audio.py
CHANGED
@@ -2,11 +2,19 @@ from fastapi import APIRouter
 from datetime import datetime
 from datasets import load_dataset
 from sklearn.metrics import accuracy_score
+import numpy as np
 import random
 import os
+import torch
+from torch.utils.data import DataLoader
 
 from .utils.evaluation import AudioEvaluationRequest
 from .utils.emissions import tracker, clean_emissions_data, get_space_info
+from data import FFTDataset
+from models import DualEncoder
+from train import Trainer
+from data_utils import collate_fn, Container
+import yaml
 
 from dotenv import load_dotenv
 load_dotenv()
@@ -43,20 +51,43 @@ async def evaluate_audio(request: AudioEvaluationRequest):
     # Split dataset
     train_test = dataset["train"].train_test_split(test_size=request.test_size, seed=request.test_seed)
     test_dataset = train_test["test"]
-
+
     # Start tracking emissions
     tracker.start()
     tracker.start_task("inference")
-
+
     #--------------------------------------------------------------------------------------------
     # YOUR MODEL INFERENCE CODE HERE
     # Update the code below to replace the random baseline by your model inference within the inference pass where the energy consumption and emissions are tracked.
-    #--------------------------------------------------------------------------------------------
-
+    #--------------------------------------------------------------------------------------------
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    args_path = 'config.yaml'
+    data_args = Container(**yaml.safe_load(open(args_path, 'r'))['Data'])
+    model_args = Container(**yaml.safe_load(open(args_path, 'r'))['CNNEncoder'])
+    model_args_f = Container(**yaml.safe_load(open(args_path, 'r'))['CNNEncoder_f'])
+    conformer_args = Container(**yaml.safe_load(open(args_path, 'r'))['Conformer'])
+
+    test_dataset = FFTDataset(test_dataset)
+    test_dl = DataLoader(test_dataset, batch_size=data_args.batch_size, collate_fn=collate_fn)
+
+    model = DualEncoder(model_args, model_args_f, conformer_args)
+    model = model.to(device)
+    missing, unexpected = model.load_state_dict(torch.load(model_args.checkpoint_path))
+
+    loss_fn = torch.nn.BCEWithLogitsLoss()
+    optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)
+    trainer = Trainer(model=model, optimizer=optimizer,
+                      criterion=loss_fn, output_dim=model_args.output_dim, scaler=None,
+                      scheduler=None, train_dataloader=None,
+                      val_dataloader=None, device=device,
+                      exp_num='test', log_path=None,
+                      range_update=None,
+                      accumulation_step=1, max_iter=np.inf,
+                      exp_name=f"frugal_cnnencoder_inference")
+    predictions, acc = trainer.predict(test_dl, device=device)
     # Make random predictions (placeholder for actual model inference)
     true_labels = test_dataset["label"]
-
-
+
     #--------------------------------------------------------------------------------------------
     # YOUR MODEL INFERENCE STOPS HERE
     #--------------------------------------------------------------------------------------------
tasks/config.yaml
ADDED
@@ -0,0 +1,66 @@
Data:
  # Basics
  log_dir: '/data/frugal/logs'
  # Data
  dataset: "KeplerDataset"
  data_dir: '/data/lightPred/data'
  model_name: "CNNEncoder"
  batch_size: 16
  num_epochs: 1000
  exp_num: 2
  max_len_spectra: 4096
  max_days_lc: 270
  lc_freq: 0.0208
  create_umap: True

CNNEncoder:
  # Model
  in_channels: 1
  num_layers: 4
  stride: 1
  encoder_dims: [32,64,128,256]
  kernel_size: 3
  dropout_p: 0.3
  output_dim: 2
  beta: 1
  load_checkpoint: True
  checkpoint_num: 1
  activation: "silu"
  sine_w0: 1.0
  avg_output: True
  checkpoint_path: 'logs/frugal_2025-01-10/frugal_cnnencoder_2.pth'

CNNEncoder_f:
  # Model
  in_channels: 1
  num_layers: 4
  stride: 1
  encoder_dims: [32,64,128]
  kernel_size: 3
  dropout_p: 0.3
  output_dim: 2
  beta: 1
  load_checkpoint: True
  checkpoint_num: 1
  activation: "silu"
  sine_w0: 1.0
  avg_output: True


Conformer:
  encoder: ["mhsa_pro", "conv"]
  timeshift: false
  num_layers: 8
  encoder_dim: 128
  num_heads: 8
  kernel_size: 3
  dropout_p: 0.2
  norm: "postnorm"


Optimization:
  # Optimization
  max_lr: 1e-5
  weight_decay: 5e-6
  warmup_pct: 0.3
  steps_per_epoch: 3500
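Note (illustrative, not part of this commit): the Optimization block is defined here but never consumed by the code in this diff. A sketch of wiring it to a one-cycle schedule, assuming `model` and `num_epochs` exist; the float() casts are needed because PyYAML parses scientific notation without a decimal point (such as `1e-5`) as a string:

import yaml
import torch
from tasks.data_utils import Container

opt_args = Container(**yaml.safe_load(open('tasks/config.yaml'))['Optimization'])
optimizer = torch.optim.AdamW(model.parameters(), lr=float(opt_args.max_lr),
                              weight_decay=float(opt_args.weight_decay))
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer, max_lr=float(opt_args.max_lr), pct_start=opt_args.warmup_pct,
    steps_per_epoch=opt_args.steps_per_epoch, epochs=num_epochs)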
tasks/data.py
ADDED
@@ -0,0 +1,43 @@
import torch
from torch.utils.data import IterableDataset
from torch.fft import fft
import torchaudio.transforms as T


class SplitDataset(IterableDataset):
    def __init__(self, dataset, is_train=True, train_ratio=0.8):
        self.dataset = dataset
        self.is_train = is_train
        self.train_ratio = train_ratio

    def __iter__(self):
        # Deterministic interleaved split: within every window of 100 items,
        # the first train_ratio*100 go to train and the rest to validation.
        count = 0
        for item in self.dataset:
            is_train_item = count < int(self.train_ratio * 100)
            if is_train_item == self.is_train:
                yield item
            count = (count + 1) % 100

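Note (illustrative, not part of this commit): both split views iterate the same source, so the source must be re-iterable (each pass opens a fresh iterator); a one-shot generator would be exhausted by whichever view runs first. Here `stream_dataset` stands in for any re-iterable stream:

train_ds = SplitDataset(stream_dataset, is_train=True, train_ratio=0.8)
val_ds = SplitDataset(stream_dataset, is_train=False, train_ratio=0.8)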
class FFTDataset(IterableDataset):
    def __init__(self, original_dataset, orig_sample_rate=12000, target_sample_rate=6000):
        self.dataset = original_dataset
        self.resampler = T.Resample(orig_freq=orig_sample_rate, new_freq=target_sample_rate)

    def __iter__(self):
        for item in self.dataset:
            # Assuming the audio data is in item['audio']['array'];
            # modify this based on the actual data structure.
            audio_data = torch.tensor(item['audio']['array']).float()
            if len(audio_data) == 0:
                continue
            resampled_audio = self.resampler(audio_data)
            # Keep the magnitude spectrum: the training code casts the FFT to
            # float, which is not well-defined for a complex tensor.
            fft_data = fft(resampled_audio).abs()

            # Update the item with FFT data
            item['audio']['fft'] = fft_data
            yield item
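Note (illustrative, not part of this commit): a quick check of what FFTDataset yields. One second of 12 kHz audio is resampled to 6000 samples, and the stored FFT is the same-length magnitude spectrum:

import torch
from tasks.data import FFTDataset

dummy = [{'audio': {'array': torch.randn(12000).tolist()}, 'label': 0}]
item = next(iter(FFTDataset(dummy)))
print(item['audio']['fft'].shape)  # torch.Size([6000])
print(item['audio']['fft'].dtype)  # torch.float32 (magnitude spectrum)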
tasks/data_utils.py
ADDED
@@ -0,0 +1,63 @@
import os
import torch
import torch.distributed as dist
from torch.nn.utils.rnn import pad_sequence

def collate_fn(batch):
    # Extract audio arrays and FFT data from the batch of dictionaries.
    # as_tensor avoids the copy warning when the FFT is already a tensor.
    audio_arrays = [torch.as_tensor(item['audio']['array']) for item in batch]
    fft_arrays = [torch.as_tensor(item['audio']['fft']) for item in batch]
    labels = [torch.as_tensor(item['label']) for item in batch]

    # Pad both sequences to the longest item in the batch
    padded_audio = pad_sequence(audio_arrays, batch_first=True, padding_value=0)
    padded_fft = pad_sequence(fft_arrays, batch_first=True, padding_value=0)

    # Return as dictionary with the same structure
    return {
        'audio': {
            'array': padded_audio,
            'fft': padded_fft
        },
        'label': torch.stack(labels)
    }

class Container(object):
    '''A container class that can be used to store any attributes.'''
    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)

    def load_dict(self, d):
        for key, value in d.items():
            if getattr(self, key, None) is None:
                setattr(self, key, value)

    def print_attributes(self):
        for key, value in vars(self).items():
            print(f"{key}: {value}")

    def get_dict(self):
        return self.__dict__

def setup():
    """
    Set up the distributed training environment.
    """
    world_size = int(os.environ["WORLD_SIZE"])
    rank = int(os.environ["SLURM_PROCID"])
    jobid = int(os.environ["SLURM_JOBID"])
    gpus_per_node = torch.cuda.device_count()
    print('jobid ', jobid)
    print('gpus per node ', gpus_per_node)
    print(f"Hello from rank {rank} of {world_size} where there are"
          f" {gpus_per_node} allocated GPUs per node. ", flush=True)

    # initialize the process group
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

    if rank == 0: print(f"Group initialized? {dist.is_initialized()}", flush=True)
    local_rank = rank - gpus_per_node * (rank // gpus_per_node)
    torch.cuda.set_device(local_rank)
    print(f"rank: {rank}, local_rank: {local_rank}")
    return local_rank, world_size, gpus_per_node
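Note (illustrative, not part of this commit): collate_fn right-pads ragged items with zeros before batching. A minimal check:

import torch
from tasks.data_utils import collate_fn

batch = [
    {'audio': {'array': torch.ones(5), 'fft': torch.ones(5)}, 'label': 1},
    {'audio': {'array': torch.ones(3), 'fft': torch.ones(3)}, 'label': 0},
]
out = collate_fn(batch)
print(out['audio']['array'].shape)  # torch.Size([2, 5]); the short item is zero-padded
print(out['label'])                 # tensor([1, 0])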
tasks/models.py
ADDED
@@ -0,0 +1,114 @@
import torch
import torch.nn as nn
from .Modules.conformer import ConformerEncoder
from .Modules.mhsa_pro import RotaryEmbedding

class ConvBlock(nn.Module):
    def __init__(self, args, num_layer) -> None:
        super().__init__()
        if args.activation == 'silu':
            self.activation = nn.SiLU()
        else:
            self.activation = nn.ReLU()
        # Once num_layer runs past the encoder_dims list, stay at the last dim.
        in_channels = args.encoder_dims[num_layer-1] if num_layer < len(args.encoder_dims) else args.encoder_dims[-1]
        out_channels = args.encoder_dims[num_layer] if num_layer < len(args.encoder_dims) else args.encoder_dims[-1]
        self.layers = nn.Sequential(
            nn.Conv1d(in_channels=in_channels,
                      out_channels=out_channels,
                      kernel_size=args.kernel_size,
                      stride=1, padding='same', bias=False),
            nn.BatchNorm1d(num_features=out_channels),
            self.activation,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.layers(x)

class CNNEncoder(nn.Module):
    def __init__(self, args) -> None:
        super().__init__()
        print("Using CNN encoder with activation:", args.activation, ", avg_output:", args.avg_output)
        if args.activation == 'silu':
            self.activation = nn.SiLU()
        else:
            self.activation = nn.ReLU()
        self.embedding = nn.Sequential(nn.Conv1d(in_channels=args.in_channels,
                                                 kernel_size=3, out_channels=args.encoder_dims[0],
                                                 stride=1, padding='same', bias=False),
                                       nn.BatchNorm1d(args.encoder_dims[0]),
                                       self.activation,
                                       )

        self.layers = nn.ModuleList([ConvBlock(args, i+1)
                                     for i in range(args.num_layers)])
        self.pool = nn.MaxPool1d(2)
        self.output_dim = args.encoder_dims[-1]
        self.min_seq_len = 2
        self.avg_output = args.avg_output

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if len(x.shape) == 2:
            x = x.unsqueeze(1)
        if len(x.shape) == 3 and x.shape[-1] == 1:
            x = x.permute(0, 2, 1)
        x = self.embedding(x)
        for m in self.layers:
            x = m(x)
            if x.shape[-1] > self.min_seq_len:
                x = self.pool(x)
        if self.avg_output:
            x = x.mean(dim=-1)
        return x


class MultiEncoder(nn.Module):
    def __init__(self, args, conformer_args):
        super().__init__()
        self.backbone = CNNEncoder(args)
        self.backbone.avg_output = False
        self.head_size = conformer_args.encoder_dim // conformer_args.num_heads
        self.rotary_ndims = int(self.head_size * 0.5)
        self.pe = RotaryEmbedding(self.rotary_ndims)
        self.encoder = ConformerEncoder(conformer_args)
        self.output_dim = conformer_args.encoder_dim
        self.avg_output = args.avg_output

    def forward(self, x):
        # Store backbone output in a separate tensor
        backbone_out = self.backbone(x)

        # Create x_enc from backbone_out
        if len(backbone_out.shape) == 2:
            x_enc = backbone_out.unsqueeze(1).clone()
        else:
            x_enc = backbone_out.permute(0, 2, 1).clone()

        RoPE = self.pe(x_enc, x_enc.shape[1])
        x_enc = self.encoder(x_enc, RoPE)

        if len(x_enc.shape) == 3:
            if self.avg_output:
                # Pool over the sequence dimension (a sum, despite the flag's name).
                x_enc = x_enc.sum(dim=1)
            else:
                x_enc = x_enc.permute(0, 2, 1)

        # Return x_enc and the original backbone output
        return x_enc, backbone_out

class DualEncoder(nn.Module):
    def __init__(self, args_x, args_f, conformer_args) -> None:
        super().__init__()
        self.encoder_x = CNNEncoder(args_x)
        self.encoder_f = MultiEncoder(args_f, conformer_args)
        # The frequency branch ends in the conformer, so its width is the
        # conformer dim rather than the CNN's last encoder dim.
        total_output_dim = args_x.encoder_dims[-1] + self.encoder_f.output_dim
        self.regressor = nn.Sequential(
            nn.Linear(total_output_dim, total_output_dim//2),
            nn.BatchNorm1d(total_output_dim//2),
            nn.SiLU(),
            nn.Linear(total_output_dim//2, 1)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x1 = self.encoder_x(x)
        x2, _ = self.encoder_f(x)
        features = torch.cat([x1, x2], dim=-1)
        # Squeeze only the last dim so a batch of size 1 keeps its batch dim.
        return self.regressor(features).squeeze(-1)
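Note (illustrative, not part of this commit): a shape walk-through of CNNEncoder with the hyperparameters from config.yaml. Each of the four conv blocks is followed by MaxPool1d(2), so a 4096-sample input shrinks to length 256 before the final mean over time:

import torch
from tasks.data_utils import Container
from tasks.models import CNNEncoder

args = Container(in_channels=1, num_layers=4, encoder_dims=[32, 64, 128, 256],
                 kernel_size=3, stride=1, dropout_p=0.3, activation='silu',
                 avg_output=True)
enc = CNNEncoder(args)
x = torch.randn(8, 4096)  # (batch, samples)
print(enc(x).shape)       # torch.Size([8, 256])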
tasks/train.py
ADDED
@@ -0,0 +1,293 @@
import torch
import numpy as np
import time
import os
import glob
from collections import OrderedDict
from tqdm import tqdm
import torch.distributed as dist

class Trainer(object):
    """
    A class that encapsulates the training loop for a PyTorch model.
    """
    def __init__(self, model, optimizer, criterion, train_dataloader, device, world_size=1, output_dim=2,
                 scheduler=None, val_dataloader=None, max_iter=np.inf, scaler=None,
                 grad_clip=False, exp_num=None, log_path=None, exp_name=None, plot_every=None,
                 cos_inc=False, range_update=None, accumulation_step=1, wandb_log=False, num_quantiles=1,
                 update_func=lambda x: x):
        self.model = model
        self.optimizer = optimizer
        self.criterion = criterion
        self.scaler = scaler
        self.grad_clip = grad_clip
        self.cos_inc = cos_inc
        self.output_dim = output_dim
        self.scheduler = scheduler
        self.train_dl = train_dataloader
        self.val_dl = val_dataloader
        self.train_sampler = self.get_sampler_from_dataloader(train_dataloader)
        self.val_sampler = self.get_sampler_from_dataloader(val_dataloader)
        self.max_iter = max_iter
        self.device = device
        self.world_size = world_size
        self.exp_num = exp_num
        self.exp_name = exp_name
        self.log_path = log_path
        self.best_state_dict = None
        self.plot_every = plot_every
        self.logger = None
        self.range_update = range_update
        self.accumulation_step = accumulation_step
        self.wandb = wandb_log
        self.num_quantiles = num_quantiles
        self.update_func = update_func

    def get_sampler_from_dataloader(self, dataloader):
        if hasattr(dataloader, 'sampler'):
            if isinstance(dataloader.sampler, torch.utils.data.DistributedSampler):
                return dataloader.sampler
            elif hasattr(dataloader.sampler, 'sampler'):
                return dataloader.sampler.sampler

        if hasattr(dataloader, 'batch_sampler') and hasattr(dataloader.batch_sampler, 'sampler'):
            return dataloader.batch_sampler.sampler

        return None

    def fit(self, num_epochs, device, early_stopping=None, only_p=False, best='loss', conf=False):
        """
        Fits the model for the given number of epochs.
        """
        min_loss = np.inf
        best_acc = 0
        train_loss, val_loss = [], []
        train_acc, val_acc = [], []
        lrs = []
        epochs_without_improvement = 0
        main_process = (dist.is_initialized() and dist.get_rank() == 0) or self.device == 'cpu'

        print(f"Starting training for {num_epochs} epochs")
        print("is main process: ", main_process, flush=True)
        global_time = time.time()
        self.epoch = 0
        for epoch in range(num_epochs):
            self.epoch = epoch
            start_time = time.time()
            t_loss, t_acc = self.train_epoch(device, epoch=epoch)
            t_loss_mean = np.nanmean(t_loss)
            train_loss.extend(t_loss)
            global_train_accuracy, global_train_loss = self.process_loss(t_acc, t_loss_mean)
            if main_process:  # Only perform this on the master GPU
                train_acc.append(global_train_accuracy.mean().item())

            v_loss, v_acc = self.eval_epoch(device, epoch=epoch)
            v_loss_mean = np.nanmean(v_loss)
            val_loss.extend(v_loss)
            global_val_accuracy, global_val_loss = self.process_loss(v_acc, v_loss_mean)
            if main_process:  # Only perform this on the master GPU
                val_acc.append(global_val_accuracy.mean().item())

            current_objective = global_val_loss if best == 'loss' else global_val_accuracy.mean()
            improved = False

            if best == 'loss':
                if current_objective < min_loss:
                    min_loss = current_objective
                    improved = True
            else:
                if current_objective > best_acc:
                    best_acc = current_objective
                    improved = True

            if improved:
                model_name = f'{self.log_path}/{self.exp_num}/{self.exp_name}.pth'
                print(f"saving model at {model_name}...")
                torch.save(self.model.state_dict(), model_name)
                self.best_state_dict = self.model.state_dict()
                epochs_without_improvement = 0
            else:
                epochs_without_improvement += 1

            current_lr = self.optimizer.param_groups[0]['lr'] if self.scheduler is None \
                else self.scheduler.get_last_lr()[0]

            lrs.append(current_lr)

            print(f'Epoch {epoch}, lr {current_lr}, Train Loss: {global_train_loss:.6f}, Val Loss: '
                  f'{global_val_loss:.6f}, Train Acc: {global_train_accuracy.round(decimals=4).tolist()}, '
                  f'Val Acc: {global_val_accuracy.round(decimals=4).tolist()}, '
                  f'Time: {time.time() - start_time:.2f}s, Total Time: {(time.time() - global_time)/3600:.2f} hr', flush=True)
            if epoch % 10 == 0:
                os.system('nvidia-smi')

            if epochs_without_improvement == early_stopping:
                print('early stopping!', flush=True)
                break
            if time.time() - global_time > (23.83 * 3600):
                print("time limit reached")
                break

        return {"num_epochs": num_epochs, "train_loss": train_loss,
                "val_loss": val_loss, "train_acc": train_acc, "val_acc": val_acc, "lrs": lrs}

    def process_loss(self, acc, loss_mean):
        if torch.cuda.is_available() and dist.is_initialized():
            global_accuracy = torch.as_tensor(acc).cuda()  # Move accuracy to the GPU
            dist.reduce(global_accuracy, dst=0, op=dist.ReduceOp.SUM)
            global_loss = torch.as_tensor(loss_mean).cuda()  # Move loss to the GPU
            dist.reduce(global_loss, dst=0, op=dist.ReduceOp.SUM)

            # Divide both loss and accuracy by world size
            world_size = dist.get_world_size()
            global_loss /= world_size
            global_accuracy /= world_size
        else:
            global_loss = torch.as_tensor(loss_mean)
            global_accuracy = torch.as_tensor(acc)
        return global_accuracy, global_loss

    def load_best_model(self, to_ddp=True, from_ddp=True):
        data_dir = f'{self.log_path}/exp{self.exp_num}'
        state_dict_files = glob.glob(data_dir + '/*.pth')
        print("loading model from ", state_dict_files[-1])

        state_dict = torch.load(state_dict_files[-1]) if to_ddp else torch.load(state_dict_files[0], map_location=self.device)

        if from_ddp:
            print("loading distributed model")
            # Remove "module." prefixes left by DistributedDataParallel
            new_state_dict = OrderedDict()
            for key, value in state_dict.items():
                while key.startswith('module.'):
                    key = key[7:]
                new_state_dict[key] = value
            state_dict = new_state_dict

        self.model.load_state_dict(state_dict, strict=False)

    def check_gradients(self):
        for name, param in self.model.named_parameters():
            if param.grad is not None:
                grad_norm = param.grad.norm().item()
                if grad_norm > 10:
                    print(f"Large gradient in {name}: {grad_norm}")

    def train_epoch(self, device, epoch):
        """
        Trains the model for one epoch.
        """
        if self.train_sampler is not None:
            try:
                self.train_sampler.set_epoch(epoch)
            except AttributeError:
                pass
        self.model.train()
        train_loss = []
        total = 0
        all_accs = torch.zeros(self.output_dim, device=device)
        pbar = tqdm(self.train_dl)
        for i, batch in enumerate(pbar):
            if self.optimizer is not None:
                self.optimizer.zero_grad()
            loss, acc, y = self.train_batch(batch, i, device)
            train_loss.append(loss.item())
            all_accs = all_accs + acc
            total += len(y)
            pbar.set_description(f"train_acc: {acc}, train_loss: {loss.item()}")
            if i > self.max_iter:
                break
        return train_loss, all_accs / total

    def train_batch(self, batch, batch_idx, device):
        x, fft, y = batch['audio']['array'], batch['audio']['fft'], batch['label']
        x = x.to(device).float()
        fft = fft.to(device).float()
        y = y.to(device).float()
        y_pred = self.model(fft)
        loss = self.criterion(y_pred, y)
        loss.backward()
        self.optimizer.step()
        if self.scheduler is not None:
            self.scheduler.step()
        # get predicted classes
        probs = torch.sigmoid(y_pred)
        cls_pred = (probs > 0.5).float()
        acc = (cls_pred == y).sum()
        return loss, acc, y

    def eval_epoch(self, device, epoch):
        """
        Evaluates the model for one epoch.
        """
        self.model.eval()
        val_loss = []
        total = 0
        all_accs = torch.zeros(self.output_dim, device=device)
        pbar = tqdm(self.val_dl)
        for i, batch in enumerate(pbar):
            loss, acc, y = self.eval_batch(batch, i, device)
            val_loss.append(loss.item())
            all_accs = all_accs + acc
            total += len(y)
            pbar.set_description(f"val_acc: {acc}, val_loss: {loss.item()}")
            if i > self.max_iter:
                break
        return val_loss, all_accs / total

    def eval_batch(self, batch, batch_idx, device):
        x, fft, y = batch['audio']['array'], batch['audio']['fft'], batch['label']
        x = x.to(device).float()
        fft = fft.to(device).float()
        y = y.to(device).float()
        with torch.no_grad():
            y_pred = self.model(fft)
            loss = self.criterion(y_pred, y)
            probs = torch.sigmoid(y_pred)
            cls_pred = (probs > 0.5).float()
            acc = (cls_pred == y).sum()
        return loss, acc, y

    def predict(self, test_dataloader, device):
        """
        Returns the predictions of the model on the given dataset.
        """
        self.model.eval()
        total = 0
        all_accs = 0
        predictions = []
        # Iterate the dataloader that was passed in, not self.val_dl
        # (which is None in the inference-only setup in tasks/audio.py).
        pbar = tqdm(test_dataloader)
        for i, batch in enumerate(pbar):
            x, fft, y = batch['audio']['array'], batch['audio']['fft'], batch['label']
            x = x.to(device).float()
            fft = fft.to(device).float()
            y = y.to(device).float()
            with torch.no_grad():
                y_pred = self.model(fft)
                loss = self.criterion(y_pred, y)
                probs = torch.sigmoid(y_pred)
                cls_pred = (probs > 0.5).float()
                acc = (cls_pred == y).sum()
            predictions.append(cls_pred)
            all_accs += acc
            total += len(y)
        return predictions, all_accs / total
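Note (illustrative, not part of this commit): tasks/audio.py only uses Trainer.predict. For completeness, a sketch of wiring the same class for training, assuming `model`, `device`, and DataLoaders `train_dl`/`val_dl` built with collate_fn already exist, and that the logs/2 checkpoint directory has been created:

trainer = Trainer(model=model,
                  optimizer=torch.optim.Adam(model.parameters(), lr=5e-4),
                  criterion=torch.nn.BCEWithLogitsLoss(),
                  train_dataloader=train_dl, val_dataloader=val_dl,
                  device=device, output_dim=1,
                  exp_num=2, log_path='logs', exp_name='frugal_cnnencoder')
history = trainer.fit(num_epochs=10, device=device, early_stopping=5, best='loss')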