Upload 16 files
- sound_extraction/model/LASSNet.py +25 -0
- sound_extraction/model/__pycache__/LASSNet.cpython-38.pyc +0 -0
- sound_extraction/model/__pycache__/film.cpython-38.pyc +0 -0
- sound_extraction/model/__pycache__/modules.cpython-38.pyc +0 -0
- sound_extraction/model/__pycache__/resunet_film.cpython-38.pyc +0 -0
- sound_extraction/model/__pycache__/text_encoder.cpython-38.pyc +0 -0
- sound_extraction/model/film.py +27 -0
- sound_extraction/model/modules.py +483 -0
- sound_extraction/model/resunet_film.py +110 -0
- sound_extraction/model/text_encoder.py +45 -0
- sound_extraction/useful_ckpts/LASSNet.pt +3 -0
- sound_extraction/utils/__pycache__/stft.cpython-38.pyc +0 -0
- sound_extraction/utils/__pycache__/wav_io.cpython-38.pyc +0 -0
- sound_extraction/utils/create_mixtures.py +98 -0
- sound_extraction/utils/stft.py +159 -0
- sound_extraction/utils/wav_io.py +23 -0
sound_extraction/model/LASSNet.py
ADDED
@@ -0,0 +1,25 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from .text_encoder import Text_Encoder
from .resunet_film import UNetRes_FiLM

class LASSNet(nn.Module):
    def __init__(self, device='cuda'):
        super(LASSNet, self).__init__()
        self.text_embedder = Text_Encoder(device)
        self.UNet = UNetRes_FiLM(channels=1, cond_embedding_dim=256)

    def forward(self, x, caption):
        # x: (Batch, 1, T, 128)
        input_ids, attns_mask = self.text_embedder.tokenize(caption)

        cond_vec = self.text_embedder(input_ids, attns_mask)[0]
        dec_cond_vec = cond_vec

        mask = self.UNet(x, cond_vec, dec_cond_vec)
        mask = torch.sigmoid(mask)
        return mask

    def get_tokenizer(self):
        return self.text_embedder.tokenizer
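A minimal usage sketch (not part of the upload) of how this wrapper is expected to be called: a magnitude-spectrogram batch plus a list of captions goes in, a sigmoid mask of the same time-frequency shape comes out. The import path assumes the directory layout of this commit is importable as a package, and the `'model'` key in the checkpoint dict is an assumption about how `LASSNet.pt` was saved.

import torch
from sound_extraction.model.LASSNet import LASSNet  # path assumed from this commit's layout

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = LASSNet(device).to(device).eval()

# Checkpoint format is assumed; keys may also carry a DataParallel 'module.' prefix.
ckpt = torch.load('sound_extraction/useful_ckpts/LASSNet.pt', map_location=device)
model.load_state_dict(ckpt['model'] if isinstance(ckpt, dict) and 'model' in ckpt else ckpt)

spec = torch.randn(1, 1, 1001, 513, device=device)  # (batch, 1, T, F) magnitude spectrogram
caption = ['a dog barking']
with torch.no_grad():
    mask = model(spec, caption)  # values in [0, 1], same shape as spec
print(mask.shape)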
sound_extraction/model/__pycache__/LASSNet.cpython-38.pyc
ADDED
Binary file (1.27 kB)
sound_extraction/model/__pycache__/film.cpython-38.pyc
ADDED
Binary file (1.26 kB)
sound_extraction/model/__pycache__/modules.cpython-38.pyc
ADDED
Binary file (14.7 kB)
sound_extraction/model/__pycache__/resunet_film.cpython-38.pyc
ADDED
Binary file (3.26 kB)
sound_extraction/model/__pycache__/text_encoder.cpython-38.pyc
ADDED
Binary file (1.69 kB)
sound_extraction/model/film.py
ADDED
@@ -0,0 +1,27 @@
import torch
import torch.nn as nn

class Film(nn.Module):
    def __init__(self, channels, cond_embedding_dim):
        super(Film, self).__init__()
        self.linear = nn.Sequential(
            nn.Linear(cond_embedding_dim, channels * 2),
            nn.ReLU(inplace=True),
            nn.Linear(channels * 2, channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, data, cond_vec):
        """
        :param data: [batchsize, channels, samples] or [batchsize, channels, T, F] or [batchsize, channels, F, T]
        :param cond_vec: [batchsize, cond_embedding_dim]
        :return:
        """
        bias = self.linear(cond_vec)  # [batchsize, channels]
        if len(list(data.size())) == 3:
            data = data + bias[..., None]
        elif len(list(data.size())) == 4:
            data = data + bias[..., None, None]
        else:
            print("Warning: The size of input tensor,", data.size(), "is not correct. Film is not working.")
        return data
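A quick standalone check of the block (not part of the upload): the conditioning vector is mapped to one bias per channel and broadcast over the remaining dimensions, so despite the FiLM name this variant only shifts the features, it does not scale them.

import torch
from sound_extraction.model.film import Film  # path assumed from this commit's layout

film = Film(channels=64, cond_embedding_dim=256)
feat = torch.randn(2, 64, 100, 32)  # (batch, channels, T, F)
cond = torch.randn(2, 256)          # (batch, cond_embedding_dim)
out = film(feat, cond)              # per-channel bias broadcast over T and F
print(out.shape)                    # torch.Size([2, 64, 100, 32])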
sound_extraction/model/modules.py
ADDED
@@ -0,0 +1,483 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from .film import Film

class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, activation, momentum):
        super(ConvBlock, self).__init__()

        self.activation = activation
        padding = (kernel_size[0] // 2, kernel_size[1] // 2)

        self.conv1 = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=(1, 1),
            dilation=(1, 1),
            padding=padding,
            bias=False,
        )

        self.bn1 = nn.BatchNorm2d(out_channels, momentum=momentum)

        self.conv2 = nn.Conv2d(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=(1, 1),
            dilation=(1, 1),
            padding=padding,
            bias=False,
        )

        self.bn2 = nn.BatchNorm2d(out_channels, momentum=momentum)

        self.init_weights()

    def init_weights(self):
        init_layer(self.conv1)
        init_layer(self.conv2)
        init_bn(self.bn1)
        init_bn(self.bn2)

    def forward(self, x):
        x = act(self.bn1(self.conv1(x)), self.activation)
        x = act(self.bn2(self.conv2(x)), self.activation)
        return x


class EncoderBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, downsample, activation, momentum):
        super(EncoderBlock, self).__init__()

        self.conv_block = ConvBlock(
            in_channels, out_channels, kernel_size, activation, momentum
        )
        self.downsample = downsample

    def forward(self, x):
        encoder = self.conv_block(x)
        encoder_pool = F.avg_pool2d(encoder, kernel_size=self.downsample)
        return encoder_pool, encoder


class DecoderBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, upsample, activation, momentum):
        super(DecoderBlock, self).__init__()
        self.kernel_size = kernel_size
        self.stride = upsample
        self.activation = activation

        self.conv1 = torch.nn.ConvTranspose2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=self.stride,
            stride=self.stride,
            padding=(0, 0),
            bias=False,
            dilation=(1, 1),
        )

        self.bn1 = nn.BatchNorm2d(out_channels, momentum=momentum)

        self.conv_block2 = ConvBlock(
            out_channels * 2, out_channels, kernel_size, activation, momentum
        )

    def init_weights(self):
        init_layer(self.conv1)
        init_bn(self.bn1)

    def prune(self, x):
        """Prune the shape of x after transpose convolution."""
        padding = (self.kernel_size[0] // 2, self.kernel_size[1] // 2)
        x = x[
            :,
            :,
            padding[0] : padding[0] - self.stride[0],
            padding[1] : padding[1] - self.stride[1]]
        return x

    def forward(self, input_tensor, concat_tensor):
        x = act(self.bn1(self.conv1(input_tensor)), self.activation)
        # from IPython import embed; embed(using=False); os._exit(0)
        # x = self.prune(x)
        x = torch.cat((x, concat_tensor), dim=1)
        x = self.conv_block2(x)
        return x


class EncoderBlockRes1B(nn.Module):
    def __init__(self, in_channels, out_channels, downsample, activation, momentum):
        super(EncoderBlockRes1B, self).__init__()
        size = (3, 3)

        self.conv_block1 = ConvBlockRes(in_channels, out_channels, size, activation, momentum)
        self.conv_block2 = ConvBlockRes(out_channels, out_channels, size, activation, momentum)
        self.conv_block3 = ConvBlockRes(out_channels, out_channels, size, activation, momentum)
        self.conv_block4 = ConvBlockRes(out_channels, out_channels, size, activation, momentum)
        self.downsample = downsample

    def forward(self, x):
        encoder = self.conv_block1(x)
        encoder = self.conv_block2(encoder)
        encoder = self.conv_block3(encoder)
        encoder = self.conv_block4(encoder)
        encoder_pool = F.avg_pool2d(encoder, kernel_size=self.downsample)
        return encoder_pool, encoder


class DecoderBlockRes1B(nn.Module):
    def __init__(self, in_channels, out_channels, stride, activation, momentum):
        super(DecoderBlockRes1B, self).__init__()
        size = (3, 3)
        self.activation = activation

        self.conv1 = torch.nn.ConvTranspose2d(in_channels=in_channels,
            out_channels=out_channels, kernel_size=size, stride=stride,
            padding=(0, 0), output_padding=(0, 0), bias=False, dilation=1)

        self.bn1 = nn.BatchNorm2d(in_channels)
        self.conv_block2 = ConvBlockRes(out_channels * 2, out_channels, size, activation, momentum)
        self.conv_block3 = ConvBlockRes(out_channels, out_channels, size, activation, momentum)
        self.conv_block4 = ConvBlockRes(out_channels, out_channels, size, activation, momentum)
        self.conv_block5 = ConvBlockRes(out_channels, out_channels, size, activation, momentum)

    def init_weights(self):
        init_layer(self.conv1)

    def prune(self, x, both=False):
        """Prune the shape of x after transpose convolution."""
        if both:
            x = x[:, :, 0:-1, 0:-1]
        else:
            x = x[:, :, 0:-1, :]
        return x

    def forward(self, input_tensor, concat_tensor, both=False):
        x = self.conv1(F.relu_(self.bn1(input_tensor)))
        x = self.prune(x, both=both)
        x = torch.cat((x, concat_tensor), dim=1)
        x = self.conv_block2(x)
        x = self.conv_block3(x)
        x = self.conv_block4(x)
        x = self.conv_block5(x)
        return x


class EncoderBlockRes2BCond(nn.Module):
    def __init__(self, in_channels, out_channels, downsample, activation, momentum, cond_embedding_dim):
        super(EncoderBlockRes2BCond, self).__init__()
        size = (3, 3)

        self.conv_block1 = ConvBlockResCond(in_channels, out_channels, size, activation, momentum, cond_embedding_dim)
        self.conv_block2 = ConvBlockResCond(out_channels, out_channels, size, activation, momentum, cond_embedding_dim)
        self.downsample = downsample

    def forward(self, x, cond_vec):
        encoder = self.conv_block1(x, cond_vec)
        encoder = self.conv_block2(encoder, cond_vec)
        encoder_pool = F.avg_pool2d(encoder, kernel_size=self.downsample)
        return encoder_pool, encoder


class DecoderBlockRes2BCond(nn.Module):
    def __init__(self, in_channels, out_channels, stride, activation, momentum, cond_embedding_dim):
        super(DecoderBlockRes2BCond, self).__init__()
        size = (3, 3)
        self.activation = activation

        self.conv1 = torch.nn.ConvTranspose2d(in_channels=in_channels,
            out_channels=out_channels, kernel_size=size, stride=stride,
            padding=(0, 0), output_padding=(0, 0), bias=False, dilation=1)

        self.bn1 = nn.BatchNorm2d(in_channels)
        self.conv_block2 = ConvBlockResCond(out_channels * 2, out_channels, size, activation, momentum, cond_embedding_dim)
        self.conv_block3 = ConvBlockResCond(out_channels, out_channels, size, activation, momentum, cond_embedding_dim)

    def init_weights(self):
        init_layer(self.conv1)

    def prune(self, x, both=False):
        """Prune the shape of x after transpose convolution."""
        if both:
            x = x[:, :, 0:-1, 0:-1]
        else:
            x = x[:, :, 0:-1, :]
        return x

    def forward(self, input_tensor, concat_tensor, cond_vec, both=False):
        x = self.conv1(F.relu_(self.bn1(input_tensor)))
        x = self.prune(x, both=both)
        x = torch.cat((x, concat_tensor), dim=1)
        x = self.conv_block2(x, cond_vec)
        x = self.conv_block3(x, cond_vec)
        return x


class EncoderBlockRes4BCond(nn.Module):
    def __init__(self, in_channels, out_channels, downsample, activation, momentum, cond_embedding_dim):
        super(EncoderBlockRes4BCond, self).__init__()
        size = (3, 3)

        self.conv_block1 = ConvBlockResCond(in_channels, out_channels, size, activation, momentum, cond_embedding_dim)
        self.conv_block2 = ConvBlockResCond(out_channels, out_channels, size, activation, momentum, cond_embedding_dim)
        self.conv_block3 = ConvBlockResCond(out_channels, out_channels, size, activation, momentum, cond_embedding_dim)
        self.conv_block4 = ConvBlockResCond(out_channels, out_channels, size, activation, momentum, cond_embedding_dim)
        self.downsample = downsample

    def forward(self, x, cond_vec):
        encoder = self.conv_block1(x, cond_vec)
        encoder = self.conv_block2(encoder, cond_vec)
        encoder = self.conv_block3(encoder, cond_vec)
        encoder = self.conv_block4(encoder, cond_vec)
        encoder_pool = F.avg_pool2d(encoder, kernel_size=self.downsample)
        return encoder_pool, encoder


class DecoderBlockRes4BCond(nn.Module):
    def __init__(self, in_channels, out_channels, stride, activation, momentum, cond_embedding_dim):
        super(DecoderBlockRes4BCond, self).__init__()
        size = (3, 3)
        self.activation = activation

        self.conv1 = torch.nn.ConvTranspose2d(in_channels=in_channels,
            out_channels=out_channels, kernel_size=size, stride=stride,
            padding=(0, 0), output_padding=(0, 0), bias=False, dilation=1)

        self.bn1 = nn.BatchNorm2d(in_channels)
        self.conv_block2 = ConvBlockResCond(out_channels * 2, out_channels, size, activation, momentum, cond_embedding_dim)
        self.conv_block3 = ConvBlockResCond(out_channels, out_channels, size, activation, momentum, cond_embedding_dim)
        self.conv_block4 = ConvBlockResCond(out_channels, out_channels, size, activation, momentum, cond_embedding_dim)
        self.conv_block5 = ConvBlockResCond(out_channels, out_channels, size, activation, momentum, cond_embedding_dim)

    def init_weights(self):
        init_layer(self.conv1)

    def prune(self, x, both=False):
        """Prune the shape of x after transpose convolution."""
        if both:
            x = x[:, :, 0:-1, 0:-1]
        else:
            x = x[:, :, 0:-1, :]
        return x

    def forward(self, input_tensor, concat_tensor, cond_vec, both=False):
        x = self.conv1(F.relu_(self.bn1(input_tensor)))
        x = self.prune(x, both=both)
        x = torch.cat((x, concat_tensor), dim=1)
        x = self.conv_block2(x, cond_vec)
        x = self.conv_block3(x, cond_vec)
        x = self.conv_block4(x, cond_vec)
        x = self.conv_block5(x, cond_vec)
        return x


class EncoderBlockRes4B(nn.Module):
    def __init__(self, in_channels, out_channels, downsample, activation, momentum):
        super(EncoderBlockRes4B, self).__init__()
        size = (3, 3)

        self.conv_block1 = ConvBlockRes(in_channels, out_channels, size, activation, momentum)
        self.conv_block2 = ConvBlockRes(out_channels, out_channels, size, activation, momentum)
        self.conv_block3 = ConvBlockRes(out_channels, out_channels, size, activation, momentum)
        self.conv_block4 = ConvBlockRes(out_channels, out_channels, size, activation, momentum)
        self.downsample = downsample

    def forward(self, x):
        encoder = self.conv_block1(x)
        encoder = self.conv_block2(encoder)
        encoder = self.conv_block3(encoder)
        encoder = self.conv_block4(encoder)
        encoder_pool = F.avg_pool2d(encoder, kernel_size=self.downsample)
        return encoder_pool, encoder


class DecoderBlockRes4B(nn.Module):
    def __init__(self, in_channels, out_channels, stride, activation, momentum):
        super(DecoderBlockRes4B, self).__init__()
        size = (3, 3)
        self.activation = activation

        self.conv1 = torch.nn.ConvTranspose2d(in_channels=in_channels,
            out_channels=out_channels, kernel_size=size, stride=stride,
            padding=(0, 0), output_padding=(0, 0), bias=False, dilation=1)

        self.bn1 = nn.BatchNorm2d(in_channels)
        self.conv_block2 = ConvBlockRes(out_channels * 2, out_channels, size, activation, momentum)
        self.conv_block3 = ConvBlockRes(out_channels, out_channels, size, activation, momentum)
        self.conv_block4 = ConvBlockRes(out_channels, out_channels, size, activation, momentum)
        self.conv_block5 = ConvBlockRes(out_channels, out_channels, size, activation, momentum)

    def init_weights(self):
        init_layer(self.conv1)

    def prune(self, x, both=False):
        """Prune the shape of x after transpose convolution."""
        if both:
            x = x[:, :, 0:-1, 0:-1]
        else:
            x = x[:, :, 0:-1, :]
        return x

    def forward(self, input_tensor, concat_tensor, both=False):
        x = self.conv1(F.relu_(self.bn1(input_tensor)))
        x = self.prune(x, both=both)
        x = torch.cat((x, concat_tensor), dim=1)
        x = self.conv_block2(x)
        x = self.conv_block3(x)
        x = self.conv_block4(x)
        x = self.conv_block5(x)
        return x


class ConvBlockResCond(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, activation, momentum, cond_embedding_dim):
        r"""Residual block with FiLM conditioning."""
        super(ConvBlockResCond, self).__init__()

        self.activation = activation
        padding = [kernel_size[0] // 2, kernel_size[1] // 2]

        self.bn1 = nn.BatchNorm2d(in_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)

        self.conv1 = nn.Conv2d(in_channels=in_channels,
                               out_channels=out_channels,
                               kernel_size=kernel_size, stride=(1, 1),
                               dilation=(1, 1), padding=padding, bias=False)
        self.film1 = Film(channels=out_channels, cond_embedding_dim=cond_embedding_dim)
        self.conv2 = nn.Conv2d(in_channels=out_channels,
                               out_channels=out_channels,
                               kernel_size=kernel_size, stride=(1, 1),
                               dilation=(1, 1), padding=padding, bias=False)
        self.film2 = Film(channels=out_channels, cond_embedding_dim=cond_embedding_dim)

        if in_channels != out_channels:
            self.shortcut = nn.Conv2d(in_channels=in_channels,
                                      out_channels=out_channels, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0))
            self.film_res = Film(channels=out_channels, cond_embedding_dim=cond_embedding_dim)
            self.is_shortcut = True
        else:
            self.is_shortcut = False

        self.init_weights()

    def init_weights(self):
        init_bn(self.bn1)
        init_bn(self.bn2)
        init_layer(self.conv1)
        init_layer(self.conv2)

        if self.is_shortcut:
            init_layer(self.shortcut)

    def forward(self, x, cond_vec):
        origin = x
        x = self.conv1(F.leaky_relu_(self.bn1(x), negative_slope=0.01))
        x = self.film1(x, cond_vec)
        x = self.conv2(F.leaky_relu_(self.bn2(x), negative_slope=0.01))
        x = self.film2(x, cond_vec)
        if self.is_shortcut:
            residual = self.shortcut(origin)
            residual = self.film_res(residual, cond_vec)
            return residual + x
        else:
            return origin + x


class ConvBlockRes(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, activation, momentum):
        r"""Residual block."""
        super(ConvBlockRes, self).__init__()

        self.activation = activation
        padding = [kernel_size[0] // 2, kernel_size[1] // 2]

        self.bn1 = nn.BatchNorm2d(in_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)

        self.conv1 = nn.Conv2d(in_channels=in_channels,
                               out_channels=out_channels,
                               kernel_size=kernel_size, stride=(1, 1),
                               dilation=(1, 1), padding=padding, bias=False)

        self.conv2 = nn.Conv2d(in_channels=out_channels,
                               out_channels=out_channels,
                               kernel_size=kernel_size, stride=(1, 1),
                               dilation=(1, 1), padding=padding, bias=False)

        if in_channels != out_channels:
            self.shortcut = nn.Conv2d(in_channels=in_channels,
                                      out_channels=out_channels, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0))
            self.is_shortcut = True
        else:
            self.is_shortcut = False

        self.init_weights()

    def init_weights(self):
        init_bn(self.bn1)
        init_bn(self.bn2)
        init_layer(self.conv1)
        init_layer(self.conv2)

        if self.is_shortcut:
            init_layer(self.shortcut)

    def forward(self, x):
        origin = x
        x = self.conv1(F.leaky_relu_(self.bn1(x), negative_slope=0.01))
        x = self.conv2(F.leaky_relu_(self.bn2(x), negative_slope=0.01))

        if self.is_shortcut:
            return self.shortcut(origin) + x
        else:
            return origin + x


def init_layer(layer):
    """Initialize a Linear or Convolutional layer."""
    nn.init.xavier_uniform_(layer.weight)

    if hasattr(layer, 'bias'):
        if layer.bias is not None:
            layer.bias.data.fill_(0.)


def init_bn(bn):
    """Initialize a Batchnorm layer."""
    bn.bias.data.fill_(0.)
    bn.weight.data.fill_(1.)


def init_gru(rnn):
    """Initialize a GRU layer."""

    def _concat_init(tensor, init_funcs):
        (length, fan_out) = tensor.shape
        fan_in = length // len(init_funcs)

        for (i, init_func) in enumerate(init_funcs):
            init_func(tensor[i * fan_in: (i + 1) * fan_in, :])

    def _inner_uniform(tensor):
        fan_in = nn.init._calculate_correct_fan(tensor, 'fan_in')
        nn.init.uniform_(tensor, -math.sqrt(3 / fan_in), math.sqrt(3 / fan_in))

    for i in range(rnn.num_layers):
        _concat_init(
            getattr(rnn, 'weight_ih_l{}'.format(i)),
            [_inner_uniform, _inner_uniform, _inner_uniform]
        )
        torch.nn.init.constant_(getattr(rnn, 'bias_ih_l{}'.format(i)), 0)

        _concat_init(
            getattr(rnn, 'weight_hh_l{}'.format(i)),
            [_inner_uniform, _inner_uniform, nn.init.orthogonal_]
        )
        torch.nn.init.constant_(getattr(rnn, 'bias_hh_l{}'.format(i)), 0)


def act(x, activation):
    if activation == 'relu':
        return F.relu_(x)

    elif activation == 'leaky_relu':
        return F.leaky_relu_(x, negative_slope=0.2)

    elif activation == 'swish':
        return x * torch.sigmoid(x)

    else:
        raise Exception('Incorrect activation!')
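A short shape check (not part of the upload) of the two building blocks the U-Net below relies on: the conditional residual block preserves the time-frequency shape while the conditional encoder block additionally halves it through average pooling and returns the pre-pooling tensor as the skip connection. The concrete sizes here are illustrative only.

import torch
from sound_extraction.model.modules import ConvBlockResCond, EncoderBlockRes2BCond

cond = torch.randn(2, 256)    # conditioning vector, e.g. from the text encoder
x = torch.randn(2, 1, 64, 64) # (batch, channels, T, F)

block = ConvBlockResCond(1, 32, (3, 3), 'relu', 0.01, cond_embedding_dim=256)
y = block(x, cond)            # FiLM-biased residual block: (2, 32, 64, 64)

enc = EncoderBlockRes2BCond(32, 64, downsample=(2, 2), activation='relu',
                            momentum=0.01, cond_embedding_dim=256)
pooled, skip = enc(y, cond)   # pooled: (2, 64, 32, 32), skip: (2, 64, 64, 64)
print(y.shape, pooled.shape, skip.shape)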
sound_extraction/model/resunet_film.py
ADDED
@@ -0,0 +1,110 @@
from .modules import *
import numpy as np

class UNetRes_FiLM(nn.Module):
    def __init__(self, channels, cond_embedding_dim, nsrc=1):
        super(UNetRes_FiLM, self).__init__()
        activation = 'relu'
        momentum = 0.01

        self.nsrc = nsrc
        self.channels = channels
        self.downsample_ratio = 2 ** 6  # This number equals 2^{#encoder_blocks}

        self.encoder_block1 = EncoderBlockRes2BCond(in_channels=channels * nsrc, out_channels=32,
                                                    downsample=(2, 2), activation=activation, momentum=momentum,
                                                    cond_embedding_dim=cond_embedding_dim)
        self.encoder_block2 = EncoderBlockRes2BCond(in_channels=32, out_channels=64,
                                                    downsample=(2, 2), activation=activation, momentum=momentum,
                                                    cond_embedding_dim=cond_embedding_dim)
        self.encoder_block3 = EncoderBlockRes2BCond(in_channels=64, out_channels=128,
                                                    downsample=(2, 2), activation=activation, momentum=momentum,
                                                    cond_embedding_dim=cond_embedding_dim)
        self.encoder_block4 = EncoderBlockRes2BCond(in_channels=128, out_channels=256,
                                                    downsample=(2, 2), activation=activation, momentum=momentum,
                                                    cond_embedding_dim=cond_embedding_dim)
        self.encoder_block5 = EncoderBlockRes2BCond(in_channels=256, out_channels=384,
                                                    downsample=(2, 2), activation=activation, momentum=momentum,
                                                    cond_embedding_dim=cond_embedding_dim)
        self.encoder_block6 = EncoderBlockRes2BCond(in_channels=384, out_channels=384,
                                                    downsample=(2, 2), activation=activation, momentum=momentum,
                                                    cond_embedding_dim=cond_embedding_dim)
        self.conv_block7 = ConvBlockResCond(in_channels=384, out_channels=384,
                                            kernel_size=(3, 3), activation=activation, momentum=momentum,
                                            cond_embedding_dim=cond_embedding_dim)
        self.decoder_block1 = DecoderBlockRes2BCond(in_channels=384, out_channels=384,
                                                    stride=(2, 2), activation=activation, momentum=momentum,
                                                    cond_embedding_dim=cond_embedding_dim)
        self.decoder_block2 = DecoderBlockRes2BCond(in_channels=384, out_channels=384,
                                                    stride=(2, 2), activation=activation, momentum=momentum,
                                                    cond_embedding_dim=cond_embedding_dim)
        self.decoder_block3 = DecoderBlockRes2BCond(in_channels=384, out_channels=256,
                                                    stride=(2, 2), activation=activation, momentum=momentum,
                                                    cond_embedding_dim=cond_embedding_dim)
        self.decoder_block4 = DecoderBlockRes2BCond(in_channels=256, out_channels=128,
                                                    stride=(2, 2), activation=activation, momentum=momentum,
                                                    cond_embedding_dim=cond_embedding_dim)
        self.decoder_block5 = DecoderBlockRes2BCond(in_channels=128, out_channels=64,
                                                    stride=(2, 2), activation=activation, momentum=momentum,
                                                    cond_embedding_dim=cond_embedding_dim)
        self.decoder_block6 = DecoderBlockRes2BCond(in_channels=64, out_channels=32,
                                                    stride=(2, 2), activation=activation, momentum=momentum,
                                                    cond_embedding_dim=cond_embedding_dim)

        self.after_conv_block1 = ConvBlockResCond(in_channels=32, out_channels=32,
                                                  kernel_size=(3, 3), activation=activation, momentum=momentum,
                                                  cond_embedding_dim=cond_embedding_dim)

        self.after_conv2 = nn.Conv2d(in_channels=32, out_channels=1,
                                     kernel_size=(1, 1), stride=(1, 1), padding=(0, 0), bias=True)

        self.init_weights()

    def init_weights(self):
        init_layer(self.after_conv2)

    def forward(self, sp, cond_vec, dec_cond_vec):
        """
        Args:
            sp: (batch_size, channels_num, time_steps, freq_bins) spectrogram
        Returns:
            mask logits of shape (batch_size, channels_num, time_steps, freq_bins)
        """

        x = sp
        # Pad spectrogram to be evenly divided by downsample ratio.
        origin_len = x.shape[2]  # time_steps
        pad_len = int(np.ceil(x.shape[2] / self.downsample_ratio)) * self.downsample_ratio - origin_len
        x = F.pad(x, pad=(0, 0, 0, pad_len))
        x = x[..., 0: x.shape[-1] - 2]  # (bs, channels, T, F)

        # UNet
        (x1_pool, x1) = self.encoder_block1(x, cond_vec)  # x1_pool: (bs, 32, T / 2, F / 2)
        (x2_pool, x2) = self.encoder_block2(x1_pool, cond_vec)  # x2_pool: (bs, 64, T / 4, F / 4)
        (x3_pool, x3) = self.encoder_block3(x2_pool, cond_vec)  # x3_pool: (bs, 128, T / 8, F / 8)
        (x4_pool, x4) = self.encoder_block4(x3_pool, dec_cond_vec)  # x4_pool: (bs, 256, T / 16, F / 16)
        (x5_pool, x5) = self.encoder_block5(x4_pool, dec_cond_vec)  # x5_pool: (bs, 384, T / 32, F / 32)
        (x6_pool, x6) = self.encoder_block6(x5_pool, dec_cond_vec)  # x6_pool: (bs, 384, T / 64, F / 64)
        x_center = self.conv_block7(x6_pool, dec_cond_vec)  # (bs, 384, T / 64, F / 64)
        x7 = self.decoder_block1(x_center, x6, dec_cond_vec)  # (bs, 384, T / 32, F / 32)
        x8 = self.decoder_block2(x7, x5, dec_cond_vec)  # (bs, 384, T / 16, F / 16)
        x9 = self.decoder_block3(x8, x4, cond_vec)  # (bs, 256, T / 8, F / 8)
        x10 = self.decoder_block4(x9, x3, cond_vec)  # (bs, 128, T / 4, F / 4)
        x11 = self.decoder_block5(x10, x2, cond_vec)  # (bs, 64, T / 2, F / 2)
        x12 = self.decoder_block6(x11, x1, cond_vec)  # (bs, 32, T, F)
        x = self.after_conv_block1(x12, cond_vec)  # (bs, 32, T, F)
        x = self.after_conv2(x)  # (bs, channels, T, F)

        # Recover shape
        x = F.pad(x, pad=(0, 2))
        x = x[:, :, 0: origin_len, :]
        return x


if __name__ == "__main__":
    model = UNetRes_FiLM(channels=1, cond_embedding_dim=16)
    cond_vec = torch.randn((1, 16))
    dec_vec = cond_vec
    print(model(torch.randn((1, 1, 1001, 513)), cond_vec, dec_vec).size())
sound_extraction/model/text_encoder.py
ADDED
@@ -0,0 +1,45 @@
import torch
import torch.nn as nn
from transformers import *
import warnings
warnings.filterwarnings('ignore')

# pretrained model name: (model class, model tokenizer)
MODELS = {
    'prajjwal1/bert-mini': (BertModel, BertTokenizer),
}

class Text_Encoder(nn.Module):
    def __init__(self, device):
        super(Text_Encoder, self).__init__()
        self.base_model = 'prajjwal1/bert-mini'
        self.dropout = 0.1

        self.tokenizer = MODELS[self.base_model][1].from_pretrained(self.base_model)

        self.bert_layer = MODELS[self.base_model][0].from_pretrained(self.base_model,
                                                                     add_pooling_layer=False,
                                                                     hidden_dropout_prob=self.dropout,
                                                                     attention_probs_dropout_prob=self.dropout,
                                                                     output_hidden_states=True)

        self.linear_layer = nn.Sequential(nn.Linear(256, 256), nn.ReLU(inplace=True))

        self.device = device

    def tokenize(self, caption):
        # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        tokenized = self.tokenizer(caption, add_special_tokens=False, padding=True, return_tensors='pt')
        input_ids = tokenized['input_ids']
        attns_mask = tokenized['attention_mask']

        input_ids = input_ids.to(self.device)
        attns_mask = attns_mask.to(self.device)
        return input_ids, attns_mask

    def forward(self, input_ids, attns_mask):
        # input_ids, attns_mask = self.tokenize(caption)
        output = self.bert_layer(input_ids=input_ids, attention_mask=attns_mask)[0]
        cls_embed = output[:, 0, :]
        text_embed = self.linear_layer(cls_embed)

        return text_embed, output  # text_embed: (batch, hidden_size)
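A small standalone check (not part of the upload): the encoder downloads prajjwal1/bert-mini (hidden size 256) from the Hugging Face Hub on first use, so network access is assumed. Because tokenization runs with add_special_tokens=False, the "cls" embedding is simply the projection of the first token's hidden state.

import torch
from sound_extraction.model.text_encoder import Text_Encoder  # path assumed from this commit's layout

device = 'cpu'
encoder = Text_Encoder(device)
ids, mask = encoder.tokenize(['a dog barking', 'rain falling on a roof'])
text_embed, token_states = encoder(ids, mask)
print(text_embed.shape)    # (2, 256): projected first-token embedding
print(token_states.shape)  # (2, num_tokens, 256): per-token hidden states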
sound_extraction/useful_ckpts/LASSNet.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2c6a60910bc1db03d9ff7040d0e5906ab784431cb8b279cf4e295124e9e76fae
size 761532233
sound_extraction/utils/__pycache__/stft.cpython-38.pyc
ADDED
Binary file (4.76 kB)
sound_extraction/utils/__pycache__/wav_io.cpython-38.pyc
ADDED
Binary file (823 Bytes)
sound_extraction/utils/create_mixtures.py
ADDED
@@ -0,0 +1,98 @@
import torch
import numpy as np

def add_noise_and_scale(front, noise, snr_l=0, snr_h=0, scale_lower=1.0, scale_upper=1.0):
    """
    :param front: front-head audio, like vocal [samples, channel], will be normalized so any scale will be fine
    :param noise: noise, [samples, channel], any scale
    :param snr_l: Optional
    :param snr_h: Optional
    :param scale_lower: Optional
    :param scale_upper: Optional
    :return: scaled front and noise (noisy = front + noise), all_mel_e2e outputs are normalized within [-1, 1]
    """
    snr = None
    noise, front = normalize_energy_torch(noise), normalize_energy_torch(front)  # set noise and vocal to equal range [-1, 1]
    # print("normalize:", torch.max(noise), torch.max(front))
    if snr_l is not None and snr_h is not None:
        front, noise, snr = _random_noise(front, noise, snr_l=snr_l, snr_h=snr_h)  # remix them with a specific snr

    noisy, noise, front = unify_energy_torch(noise + front, noise, front)  # normalize noisy, noise and vocal energy into [-1, 1]

    # print("unify:", torch.max(noise), torch.max(front), torch.max(noisy))
    scale = _random_scale(scale_lower, scale_upper)  # random scale these three signals

    # print("Scale", scale)
    noisy, noise, front = noisy * scale, noise * scale, front * scale  # apply scale
    # print("after scale", torch.max(noisy), torch.max(noise), torch.max(front), snr, scale)

    front, noise = _to_numpy(front), _to_numpy(noise)  # [num_samples]
    mixed_wav = front + noise

    return front, noise, mixed_wav, snr, scale

def _random_scale(lower=0.3, upper=0.9):
    return float(uniform_torch(lower, upper))

def _random_noise(clean, noise, snr_l=None, snr_h=None):
    snr = uniform_torch(snr_l, snr_h)
    clean_weight = 10 ** (float(snr) / 20)
    return clean, noise / clean_weight, snr

def _to_numpy(wav):
    return np.transpose(wav, (1, 0))[0].numpy()  # [num_samples]

def normalize_energy(audio, alpha=1):
    '''
    :param audio: 1d waveform, [batchsize, *],
    :param alpha: the value of output range from: [-alpha, alpha]
    :return: 1d waveform which value range from: [-alpha, alpha]
    '''
    val_max = activelev(audio)
    return (audio / val_max) * alpha

def normalize_energy_torch(audio, alpha=1):
    '''
    If the signal is almost empty (determined by threshold), it will only be divided by 2**15
    :param audio: 1d waveform, 2**15
    :param alpha: the value of output range from: [-alpha, alpha]
    :return: 1d waveform which value range from: [-alpha, alpha]
    '''
    val_max = activelev_torch([audio])
    return (audio / val_max) * alpha

def unify_energy(*args):
    max_amp = activelev(args)
    mix_scale = 1.0 / max_amp
    return [x * mix_scale for x in args]

def unify_energy_torch(*args):
    max_amp = activelev_torch(args)
    mix_scale = 1.0 / max_amp
    return [x * mix_scale for x in args]

def activelev(*args):
    '''
    need to update like matlab
    '''
    return np.max(np.abs([*args]))

def activelev_torch(*args):
    '''
    need to update like matlab
    '''
    res = []
    args = args[0]
    for each in args:
        res.append(torch.max(torch.abs(each)))
    return max(res)

def uniform_torch(lower, upper):
    if abs(lower - upper) < 1e-5:
        return upper
    return (upper - lower) * torch.rand(1) + lower

if __name__ == "__main__":
    wav1 = torch.randn(1, 32000)
    wav2 = torch.randn(1, 32000)
    front, noise, mixed_wav, snr, scale = add_noise_and_scale(wav1, wav2)
sound_extraction/utils/stft.py
ADDED
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import numpy as np
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from torch.autograd import Variable
|
5 |
+
from scipy.signal import get_window
|
6 |
+
import librosa.util as librosa_util
|
7 |
+
from librosa.util import pad_center, tiny
|
8 |
+
# from audio_processing import window_sumsquare
|
9 |
+
|
10 |
+
def window_sumsquare(window, n_frames, hop_length=512, win_length=1024,
|
11 |
+
n_fft=1024, dtype=np.float32, norm=None):
|
12 |
+
"""
|
13 |
+
# from librosa 0.6
|
14 |
+
Compute the sum-square envelope of a window function at a given hop length.
|
15 |
+
This is used to estimate modulation effects induced by windowing
|
16 |
+
observations in short-time fourier transforms.
|
17 |
+
Parameters
|
18 |
+
----------
|
19 |
+
window : string, tuple, number, callable, or list-like
|
20 |
+
Window specification, as in `get_window`
|
21 |
+
n_frames : int > 0
|
22 |
+
The number of analysis frames
|
23 |
+
hop_length : int > 0
|
24 |
+
The number of samples to advance between frames
|
25 |
+
win_length : [optional]
|
26 |
+
The length of the window function. By default, this matches `n_fft`.
|
27 |
+
n_fft : int > 0
|
28 |
+
The length of each analysis frame.
|
29 |
+
dtype : np.dtype
|
30 |
+
The data type of the output
|
31 |
+
Returns
|
32 |
+
-------
|
33 |
+
wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))`
|
34 |
+
The sum-squared envelope of the window function
|
35 |
+
"""
|
36 |
+
if win_length is None:
|
37 |
+
win_length = n_fft
|
38 |
+
|
39 |
+
n = n_fft + hop_length * (n_frames - 1)
|
40 |
+
x = np.zeros(n, dtype=dtype)
|
41 |
+
|
42 |
+
# Compute the squared window at the desired length
|
43 |
+
win_sq = get_window(window, win_length, fftbins=True)
|
44 |
+
win_sq = librosa_util.normalize(win_sq, norm=norm)**2
|
45 |
+
win_sq = librosa_util.pad_center(win_sq, n_fft)
|
46 |
+
|
47 |
+
# Fill the envelope
|
48 |
+
for i in range(n_frames):
|
49 |
+
sample = i * hop_length
|
50 |
+
x[sample:min(n, sample + n_fft)] += win_sq[:max(0, min(n_fft, n - sample))]
|
51 |
+
return x
|
52 |
+
|
53 |
+
class STFT(torch.nn.Module):
|
54 |
+
"""adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft"""
|
55 |
+
def __init__(self, filter_length=1024, hop_length=512, win_length=1024,
|
56 |
+
window='hann'):
|
57 |
+
super(STFT, self).__init__()
|
58 |
+
self.filter_length = filter_length
|
59 |
+
self.hop_length = hop_length
|
60 |
+
self.win_length = win_length
|
61 |
+
self.window = window
|
62 |
+
self.forward_transform = None
|
63 |
+
scale = self.filter_length / self.hop_length
|
64 |
+
fourier_basis = np.fft.fft(np.eye(self.filter_length))
|
65 |
+
|
66 |
+
cutoff = int((self.filter_length / 2 + 1))
|
67 |
+
fourier_basis = np.vstack([np.real(fourier_basis[:cutoff, :]),
|
68 |
+
np.imag(fourier_basis[:cutoff, :])])
|
69 |
+
|
70 |
+
forward_basis = torch.FloatTensor(fourier_basis[:, None, :])
|
71 |
+
inverse_basis = torch.FloatTensor(
|
72 |
+
np.linalg.pinv(scale * fourier_basis).T[:, None, :])
|
73 |
+
|
74 |
+
if window is not None:
|
75 |
+
assert(filter_length >= win_length)
|
76 |
+
# get window and zero center pad it to filter_length
|
77 |
+
fft_window = get_window(window, win_length, fftbins=True)
|
78 |
+
fft_window = pad_center(fft_window, filter_length)
|
79 |
+
fft_window = torch.from_numpy(fft_window).float()
|
80 |
+
|
81 |
+
# window the bases
|
82 |
+
forward_basis *= fft_window
|
83 |
+
inverse_basis *= fft_window
|
84 |
+
|
85 |
+
self.register_buffer('forward_basis', forward_basis.float())
|
86 |
+
self.register_buffer('inverse_basis', inverse_basis.float())
|
87 |
+
|
88 |
+
def transform(self, input_data):
|
89 |
+
num_batches = input_data.size(0)
|
90 |
+
num_samples = input_data.size(1)
|
91 |
+
|
92 |
+
self.num_samples = num_samples
|
93 |
+
|
94 |
+
# similar to librosa, reflect-pad the input
|
95 |
+
input_data = input_data.view(num_batches, 1, num_samples)
|
96 |
+
input_data = F.pad(
|
97 |
+
input_data.unsqueeze(1),
|
98 |
+
(int(self.filter_length / 2), int(self.filter_length / 2), 0, 0),
|
99 |
+
mode='reflect')
|
100 |
+
input_data = input_data.squeeze(1)
|
101 |
+
|
102 |
+
forward_transform = F.conv1d(
|
103 |
+
input_data,
|
104 |
+
Variable(self.forward_basis, requires_grad=False),
|
105 |
+
stride=self.hop_length,
|
106 |
+
padding=0)
|
107 |
+
|
108 |
+
cutoff = int((self.filter_length / 2) + 1)
|
109 |
+
real_part = forward_transform[:, :cutoff, :]
|
110 |
+
imag_part = forward_transform[:, cutoff:, :]
|
111 |
+
|
112 |
+
magnitude = torch.sqrt(real_part**2 + imag_part**2)
|
113 |
+
phase = torch.autograd.Variable(
|
114 |
+
torch.atan2(imag_part.data, real_part.data))
|
115 |
+
|
116 |
+
return magnitude, phase # [batch_size, F(513), T(1251)]
|
117 |
+
|
118 |
+
def inverse(self, magnitude, phase):
|
119 |
+
recombine_magnitude_phase = torch.cat(
|
120 |
+
[magnitude*torch.cos(phase), magnitude*torch.sin(phase)], dim=1)
|
121 |
+
|
122 |
+
inverse_transform = F.conv_transpose1d(
|
123 |
+
recombine_magnitude_phase,
|
124 |
+
Variable(self.inverse_basis, requires_grad=False),
|
125 |
+
stride=self.hop_length,
|
126 |
+
padding=0)
|
127 |
+
|
128 |
+
if self.window is not None:
|
129 |
+
window_sum = window_sumsquare(
|
130 |
+
self.window, magnitude.size(-1), hop_length=self.hop_length,
|
131 |
+
win_length=self.win_length, n_fft=self.filter_length,
|
132 |
+
dtype=np.float32)
|
133 |
+
# remove modulation effects
|
134 |
+
approx_nonzero_indices = torch.from_numpy(
|
135 |
+
np.where(window_sum > tiny(window_sum))[0])
|
136 |
+
window_sum = torch.autograd.Variable(
|
137 |
+
torch.from_numpy(window_sum), requires_grad=False)
|
138 |
+
window_sum = window_sum.cuda() if magnitude.is_cuda else window_sum
|
139 |
+
inverse_transform[:, :, approx_nonzero_indices] /= window_sum[approx_nonzero_indices]
|
140 |
+
|
141 |
+
# scale by hop ratio
|
142 |
+
inverse_transform *= float(self.filter_length) / self.hop_length
|
143 |
+
|
144 |
+
inverse_transform = inverse_transform[:, :, int(self.filter_length/2):]
|
145 |
+
inverse_transform = inverse_transform[:, :, :-int(self.filter_length/2):]
|
146 |
+
|
147 |
+
return inverse_transform #[batch_size, 1, sample_num]
|
148 |
+
|
149 |
+
def forward(self, input_data):
|
150 |
+
self.magnitude, self.phase = self.transform(input_data)
|
151 |
+
reconstruction = self.inverse(self.magnitude, self.phase)
|
152 |
+
return reconstruction
|
153 |
+
|
154 |
+
if __name__ == '__main__':
|
155 |
+
a = torch.randn(4, 320000)
|
156 |
+
stft = STFT()
|
157 |
+
mag, phase = stft.transform(a)
|
158 |
+
# rec_a = stft.inverse(mag, phase)
|
159 |
+
print(mag.shape)
|
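A round-trip sketch (not part of the upload) showing the analysis/synthesis pair; the reconstructed length is close to, but not exactly, the input length because of the edge trimming in inverse(). Note that the positional pad_center(win_sq, n_fft) call assumes an older librosa (pre-0.10), where the target size was still a positional argument.

import torch
from sound_extraction.utils.stft import STFT  # path assumed from this commit's layout

stft = STFT(filter_length=1024, hop_length=512, win_length=1024, window='hann')
wav = torch.randn(1, 32000)       # (batch, num_samples)
mag, phase = stft.transform(wav)  # each (1, 513, num_frames)
rec = stft.inverse(mag, phase)    # (1, 1, ~num_samples) waveform estimate
print(mag.shape, rec.shape)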
sound_extraction/utils/wav_io.py
ADDED
@@ -0,0 +1,23 @@
import librosa
import librosa.filters
import math
import numpy as np
import scipy.io.wavfile

def load_wav(path):
    max_length = 32000 * 10
    wav = librosa.core.load(path, sr=32000)[0]
    if len(wav) > max_length:
        wav = wav[0:max_length]

    # pad audio to max length, 10s for AudioCaps
    if len(wav) < max_length:
        # audio = torch.nn.functional.pad(audio, (0, self.max_length - audio.size(1)), 'constant')
        wav = np.pad(wav, (0, max_length - len(wav)), 'constant')
    wav = wav[..., None]
    return wav


def save_wav(wav, path):
    wav *= 32767 / max(0.01, np.max(np.abs(wav)))
    scipy.io.wavfile.write(path, 32000, wav.astype(np.int16))
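Usage sketch (not part of the upload); the input file name is hypothetical. load_wav resamples to 32 kHz and cuts or pads to exactly 10 seconds, returning a (320000, 1) float array; save_wav rescales to int16 peak level before writing.

from sound_extraction.utils.wav_io import load_wav, save_wav  # path assumed from this commit's layout

wav = load_wav('example.wav')            # hypothetical input file
print(wav.shape)                         # (320000, 1)
save_wav(wav[:, 0], 'example_out.wav')   # written as 32 kHz int16 WAV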