lmzjms commited on
Commit
62bee25
1 Parent(s): cf7e296

Upload 16 files

Browse files
sound_extraction/model/LASSNet.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from .text_encoder import Text_Encoder
5
+ from .resunet_film import UNetRes_FiLM
6
+
7
+ class LASSNet(nn.Module):
8
+ def __init__(self, device='cuda'):
9
+ super(LASSNet, self).__init__()
10
+ self.text_embedder = Text_Encoder(device)
11
+ self.UNet = UNetRes_FiLM(channels=1, cond_embedding_dim=256)
12
+
13
+ def forward(self, x, caption):
14
+ # x: (Batch, 1, T, 128))
15
+ input_ids, attns_mask = self.text_embedder.tokenize(caption)
16
+
17
+ cond_vec = self.text_embedder(input_ids, attns_mask)[0]
18
+ dec_cond_vec = cond_vec
19
+
20
+ mask = self.UNet(x, cond_vec, dec_cond_vec)
21
+ mask = torch.sigmoid(mask)
22
+ return mask
23
+
24
+ def get_tokenizer(self):
25
+ return self.text_embedder.tokenizer
sound_extraction/model/__pycache__/LASSNet.cpython-38.pyc ADDED
Binary file (1.27 kB). View file
 
sound_extraction/model/__pycache__/film.cpython-38.pyc ADDED
Binary file (1.26 kB). View file
 
sound_extraction/model/__pycache__/modules.cpython-38.pyc ADDED
Binary file (14.7 kB). View file
 
sound_extraction/model/__pycache__/resunet_film.cpython-38.pyc ADDED
Binary file (3.26 kB). View file
 
sound_extraction/model/__pycache__/text_encoder.cpython-38.pyc ADDED
Binary file (1.69 kB). View file
 
sound_extraction/model/film.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ class Film(nn.Module):
5
+ def __init__(self, channels, cond_embedding_dim):
6
+ super(Film, self).__init__()
7
+ self.linear = nn.Sequential(
8
+ nn.Linear(cond_embedding_dim, channels * 2),
9
+ nn.ReLU(inplace=True),
10
+ nn.Linear(channels * 2, channels),
11
+ nn.ReLU(inplace=True)
12
+ )
13
+
14
+ def forward(self, data, cond_vec):
15
+ """
16
+ :param data: [batchsize, channels, samples] or [batchsize, channels, T, F] or [batchsize, channels, F, T]
17
+ :param cond_vec: [batchsize, cond_embedding_dim]
18
+ :return:
19
+ """
20
+ bias = self.linear(cond_vec) # [batchsize, channels]
21
+ if len(list(data.size())) == 3:
22
+ data = data + bias[..., None]
23
+ elif len(list(data.size())) == 4:
24
+ data = data + bias[..., None, None]
25
+ else:
26
+ print("Warning: The size of input tensor,", data.size(), "is not correct. Film is not working.")
27
+ return data
sound_extraction/model/modules.py ADDED
@@ -0,0 +1,483 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import math
5
+ from .film import Film
6
+
7
+ class ConvBlock(nn.Module):
8
+ def __init__(self, in_channels, out_channels, kernel_size, activation, momentum):
9
+ super(ConvBlock, self).__init__()
10
+
11
+ self.activation = activation
12
+ padding = (kernel_size[0] // 2, kernel_size[1] // 2)
13
+
14
+ self.conv1 = nn.Conv2d(
15
+ in_channels=in_channels,
16
+ out_channels=out_channels,
17
+ kernel_size=kernel_size,
18
+ stride=(1, 1),
19
+ dilation=(1, 1),
20
+ padding=padding,
21
+ bias=False,
22
+ )
23
+
24
+ self.bn1 = nn.BatchNorm2d(out_channels, momentum=momentum)
25
+
26
+ self.conv2 = nn.Conv2d(
27
+ in_channels=out_channels,
28
+ out_channels=out_channels,
29
+ kernel_size=kernel_size,
30
+ stride=(1, 1),
31
+ dilation=(1, 1),
32
+ padding=padding,
33
+ bias=False,
34
+ )
35
+
36
+ self.bn2 = nn.BatchNorm2d(out_channels, momentum=momentum)
37
+
38
+ self.init_weights()
39
+
40
+ def init_weights(self):
41
+ init_layer(self.conv1)
42
+ init_layer(self.conv2)
43
+ init_bn(self.bn1)
44
+ init_bn(self.bn2)
45
+
46
+ def forward(self, x):
47
+ x = act(self.bn1(self.conv1(x)), self.activation)
48
+ x = act(self.bn2(self.conv2(x)), self.activation)
49
+ return x
50
+
51
+
52
+ class EncoderBlock(nn.Module):
53
+ def __init__(self, in_channels, out_channels, kernel_size, downsample, activation, momentum):
54
+ super(EncoderBlock, self).__init__()
55
+
56
+ self.conv_block = ConvBlock(
57
+ in_channels, out_channels, kernel_size, activation, momentum
58
+ )
59
+ self.downsample = downsample
60
+
61
+ def forward(self, x):
62
+ encoder = self.conv_block(x)
63
+ encoder_pool = F.avg_pool2d(encoder, kernel_size=self.downsample)
64
+ return encoder_pool, encoder
65
+
66
+
67
+ class DecoderBlock(nn.Module):
68
+ def __init__(self, in_channels, out_channels, kernel_size, upsample, activation, momentum):
69
+ super(DecoderBlock, self).__init__()
70
+ self.kernel_size = kernel_size
71
+ self.stride = upsample
72
+ self.activation = activation
73
+
74
+ self.conv1 = torch.nn.ConvTranspose2d(
75
+ in_channels=in_channels,
76
+ out_channels=out_channels,
77
+ kernel_size=self.stride,
78
+ stride=self.stride,
79
+ padding=(0, 0),
80
+ bias=False,
81
+ dilation=(1, 1),
82
+ )
83
+
84
+ self.bn1 = nn.BatchNorm2d(out_channels, momentum=momentum)
85
+
86
+ self.conv_block2 = ConvBlock(
87
+ out_channels * 2, out_channels, kernel_size, activation, momentum
88
+ )
89
+
90
+ def init_weights(self):
91
+ init_layer(self.conv1)
92
+ init_bn(self.bn)
93
+
94
+ def prune(self, x):
95
+ """Prune the shape of x after transpose convolution."""
96
+ padding = (self.kernel_size[0] // 2, self.kernel_size[1] // 2)
97
+ x = x[
98
+ :,
99
+ :,
100
+ padding[0] : padding[0] - self.stride[0],
101
+ padding[1] : padding[1] - self.stride[1]]
102
+ return x
103
+
104
+ def forward(self, input_tensor, concat_tensor):
105
+ x = act(self.bn1(self.conv1(input_tensor)), self.activation)
106
+ # from IPython import embed; embed(using=False); os._exit(0)
107
+ # x = self.prune(x)
108
+ x = torch.cat((x, concat_tensor), dim=1)
109
+ x = self.conv_block2(x)
110
+ return x
111
+
112
+
113
+ class EncoderBlockRes1B(nn.Module):
114
+ def __init__(self, in_channels, out_channels, downsample, activation, momentum):
115
+ super(EncoderBlockRes1B, self).__init__()
116
+ size = (3,3)
117
+
118
+ self.conv_block1 = ConvBlockRes(in_channels, out_channels, size, activation, momentum)
119
+ self.conv_block2 = ConvBlockRes(out_channels, out_channels, size, activation, momentum)
120
+ self.conv_block3 = ConvBlockRes(out_channels, out_channels, size, activation, momentum)
121
+ self.conv_block4 = ConvBlockRes(out_channels, out_channels, size, activation, momentum)
122
+ self.downsample = downsample
123
+
124
+ def forward(self, x):
125
+ encoder = self.conv_block1(x)
126
+ encoder = self.conv_block2(encoder)
127
+ encoder = self.conv_block3(encoder)
128
+ encoder = self.conv_block4(encoder)
129
+ encoder_pool = F.avg_pool2d(encoder, kernel_size=self.downsample)
130
+ return encoder_pool, encoder
131
+
132
+ class DecoderBlockRes1B(nn.Module):
133
+ def __init__(self, in_channels, out_channels, stride, activation, momentum):
134
+ super(DecoderBlockRes1B, self).__init__()
135
+ size = (3,3)
136
+ self.activation = activation
137
+
138
+ self.conv1 = torch.nn.ConvTranspose2d(in_channels=in_channels,
139
+ out_channels=out_channels, kernel_size=size, stride=stride,
140
+ padding=(0, 0), output_padding=(0, 0), bias=False, dilation=1)
141
+
142
+ self.bn1 = nn.BatchNorm2d(in_channels)
143
+ self.conv_block2 = ConvBlockRes(out_channels * 2, out_channels, size, activation, momentum)
144
+ self.conv_block3 = ConvBlockRes(out_channels, out_channels, size, activation, momentum)
145
+ self.conv_block4 = ConvBlockRes(out_channels, out_channels, size, activation, momentum)
146
+ self.conv_block5 = ConvBlockRes(out_channels, out_channels, size, activation, momentum)
147
+
148
+ def init_weights(self):
149
+ init_layer(self.conv1)
150
+
151
+ def prune(self, x, both=False):
152
+ """Prune the shape of x after transpose convolution.
153
+ """
154
+ if(both): x = x[:, :, 0 : - 1, 0:-1]
155
+ else: x = x[:, :, 0: - 1, :]
156
+ return x
157
+
158
+ def forward(self, input_tensor, concat_tensor,both=False):
159
+ x = self.conv1(F.relu_(self.bn1(input_tensor)))
160
+ x = self.prune(x,both=both)
161
+ x = torch.cat((x, concat_tensor), dim=1)
162
+ x = self.conv_block2(x)
163
+ x = self.conv_block3(x)
164
+ x = self.conv_block4(x)
165
+ x = self.conv_block5(x)
166
+ return x
167
+
168
+
169
+ class EncoderBlockRes2BCond(nn.Module):
170
+ def __init__(self, in_channels, out_channels, downsample, activation, momentum, cond_embedding_dim):
171
+ super(EncoderBlockRes2BCond, self).__init__()
172
+ size = (3, 3)
173
+
174
+ self.conv_block1 = ConvBlockResCond(in_channels, out_channels, size, activation, momentum, cond_embedding_dim)
175
+ self.conv_block2 = ConvBlockResCond(out_channels, out_channels, size, activation, momentum, cond_embedding_dim)
176
+ self.downsample = downsample
177
+
178
+ def forward(self, x, cond_vec):
179
+ encoder = self.conv_block1(x, cond_vec)
180
+ encoder = self.conv_block2(encoder, cond_vec)
181
+ encoder_pool = F.avg_pool2d(encoder, kernel_size=self.downsample)
182
+ return encoder_pool, encoder
183
+
184
+ class DecoderBlockRes2BCond(nn.Module):
185
+ def __init__(self, in_channels, out_channels, stride, activation, momentum, cond_embedding_dim):
186
+ super(DecoderBlockRes2BCond, self).__init__()
187
+ size = (3, 3)
188
+ self.activation = activation
189
+
190
+ self.conv1 = torch.nn.ConvTranspose2d(in_channels=in_channels,
191
+ out_channels=out_channels, kernel_size=size, stride=stride,
192
+ padding=(0, 0), output_padding=(0, 0), bias=False, dilation=1)
193
+
194
+ self.bn1 = nn.BatchNorm2d(in_channels)
195
+ self.conv_block2 = ConvBlockResCond(out_channels * 2, out_channels, size, activation, momentum, cond_embedding_dim)
196
+ self.conv_block3 = ConvBlockResCond(out_channels, out_channels, size, activation, momentum, cond_embedding_dim)
197
+
198
+ def init_weights(self):
199
+ init_layer(self.conv1)
200
+
201
+ def prune(self, x, both=False):
202
+ """Prune the shape of x after transpose convolution.
203
+ """
204
+ if(both): x = x[:, :, 0 : - 1, 0:-1]
205
+ else: x = x[:, :, 0: - 1, :]
206
+ return x
207
+
208
+ def forward(self, input_tensor, concat_tensor, cond_vec, both=False):
209
+ x = self.conv1(F.relu_(self.bn1(input_tensor)))
210
+ x = self.prune(x, both=both)
211
+ x = torch.cat((x, concat_tensor), dim=1)
212
+ x = self.conv_block2(x, cond_vec)
213
+ x = self.conv_block3(x, cond_vec)
214
+ return x
215
+
216
+ class EncoderBlockRes4BCond(nn.Module):
217
+ def __init__(self, in_channels, out_channels, downsample, activation, momentum, cond_embedding_dim):
218
+ super(EncoderBlockRes4B, self).__init__()
219
+ size = (3,3)
220
+
221
+ self.conv_block1 = ConvBlockResCond(in_channels, out_channels, size, activation, momentum, cond_embedding_dim)
222
+ self.conv_block2 = ConvBlockResCond(out_channels, out_channels, size, activation, momentum, cond_embedding_dim)
223
+ self.conv_block3 = ConvBlockResCond(out_channels, out_channels, size, activation, momentum, cond_embedding_dim)
224
+ self.conv_block4 = ConvBlockResCond(out_channels, out_channels, size, activation, momentum, cond_embedding_dim)
225
+ self.downsample = downsample
226
+
227
+ def forward(self, x, cond_vec):
228
+ encoder = self.conv_block1(x, cond_vec)
229
+ encoder = self.conv_block2(encoder, cond_vec)
230
+ encoder = self.conv_block3(encoder, cond_vec)
231
+ encoder = self.conv_block4(encoder, cond_vec)
232
+ encoder_pool = F.avg_pool2d(encoder, kernel_size=self.downsample)
233
+ return encoder_pool, encoder
234
+
235
+ class DecoderBlockRes4BCond(nn.Module):
236
+ def __init__(self, in_channels, out_channels, stride, activation, momentum, cond_embedding_dim):
237
+ super(DecoderBlockRes4B, self).__init__()
238
+ size = (3, 3)
239
+ self.activation = activation
240
+
241
+ self.conv1 = torch.nn.ConvTranspose2d(in_channels=in_channels,
242
+ out_channels=out_channels, kernel_size=size, stride=stride,
243
+ padding=(0, 0), output_padding=(0, 0), bias=False, dilation=1)
244
+
245
+ self.bn1 = nn.BatchNorm2d(in_channels)
246
+ self.conv_block2 = ConvBlockResCond(out_channels * 2, out_channels, size, activation, momentum, cond_embedding_dim)
247
+ self.conv_block3 = ConvBlockResCond(out_channels, out_channels, size, activation, momentum, cond_embedding_dim)
248
+ self.conv_block4 = ConvBlockResCond(out_channels, out_channels, size, activation, momentum, cond_embedding_dim)
249
+ self.conv_block5 = ConvBlockResCond(out_channels, out_channels, size, activation, momentum, cond_embedding_dim)
250
+
251
+ def init_weights(self):
252
+ init_layer(self.conv1)
253
+
254
+ def prune(self, x, both=False):
255
+ """Prune the shape of x after transpose convolution.
256
+ """
257
+ if(both): x = x[:, :, 0 : - 1, 0:-1]
258
+ else: x = x[:, :, 0: - 1, :]
259
+ return x
260
+
261
+ def forward(self, input_tensor, concat_tensor, cond_vec, both=False):
262
+ x = self.conv1(F.relu_(self.bn1(input_tensor)))
263
+ x = self.prune(x,both=both)
264
+ x = torch.cat((x, concat_tensor), dim=1)
265
+ x = self.conv_block2(x, cond_vec)
266
+ x = self.conv_block3(x, cond_vec)
267
+ x = self.conv_block4(x, cond_vec)
268
+ x = self.conv_block5(x, cond_vec)
269
+ return x
270
+
271
+ class EncoderBlockRes4B(nn.Module):
272
+ def __init__(self, in_channels, out_channels, downsample, activation, momentum):
273
+ super(EncoderBlockRes4B, self).__init__()
274
+ size = (3, 3)
275
+
276
+ self.conv_block1 = ConvBlockRes(in_channels, out_channels, size, activation, momentum)
277
+ self.conv_block2 = ConvBlockRes(out_channels, out_channels, size, activation, momentum)
278
+ self.conv_block3 = ConvBlockRes(out_channels, out_channels, size, activation, momentum)
279
+ self.conv_block4 = ConvBlockRes(out_channels, out_channels, size, activation, momentum)
280
+ self.downsample = downsample
281
+
282
+ def forward(self, x):
283
+ encoder = self.conv_block1(x)
284
+ encoder = self.conv_block2(encoder)
285
+ encoder = self.conv_block3(encoder)
286
+ encoder = self.conv_block4(encoder)
287
+ encoder_pool = F.avg_pool2d(encoder, kernel_size=self.downsample)
288
+ return encoder_pool, encoder
289
+
290
+ class DecoderBlockRes4B(nn.Module):
291
+ def __init__(self, in_channels, out_channels, stride, activation, momentum):
292
+ super(DecoderBlockRes4B, self).__init__()
293
+ size = (3,3)
294
+ self.activation = activation
295
+
296
+ self.conv1 = torch.nn.ConvTranspose2d(in_channels=in_channels,
297
+ out_channels=out_channels, kernel_size=size, stride=stride,
298
+ padding=(0, 0), output_padding=(0, 0), bias=False, dilation=1)
299
+
300
+ self.bn1 = nn.BatchNorm2d(in_channels)
301
+ self.conv_block2 = ConvBlockRes(out_channels * 2, out_channels, size, activation, momentum)
302
+ self.conv_block3 = ConvBlockRes(out_channels, out_channels, size, activation, momentum)
303
+ self.conv_block4 = ConvBlockRes(out_channels, out_channels, size, activation, momentum)
304
+ self.conv_block5 = ConvBlockRes(out_channels, out_channels, size, activation, momentum)
305
+
306
+ def init_weights(self):
307
+ init_layer(self.conv1)
308
+
309
+ def prune(self, x, both=False):
310
+ """Prune the shape of x after transpose convolution.
311
+ """
312
+ if(both): x = x[:, :, 0 : - 1, 0:-1]
313
+ else: x = x[:, :, 0: - 1, :]
314
+ return x
315
+
316
+ def forward(self, input_tensor, concat_tensor,both=False):
317
+ x = self.conv1(F.relu_(self.bn1(input_tensor)))
318
+ x = self.prune(x,both=both)
319
+ x = torch.cat((x, concat_tensor), dim=1)
320
+ x = self.conv_block2(x)
321
+ x = self.conv_block3(x)
322
+ x = self.conv_block4(x)
323
+ x = self.conv_block5(x)
324
+ return x
325
+
326
+ class ConvBlockResCond(nn.Module):
327
+ def __init__(self, in_channels, out_channels, kernel_size, activation, momentum, cond_embedding_dim):
328
+ r"""Residual block.
329
+ """
330
+ super(ConvBlockResCond, self).__init__()
331
+
332
+ self.activation = activation
333
+ padding = [kernel_size[0] // 2, kernel_size[1] // 2]
334
+
335
+ self.bn1 = nn.BatchNorm2d(in_channels)
336
+ self.bn2 = nn.BatchNorm2d(out_channels)
337
+
338
+ self.conv1 = nn.Conv2d(in_channels=in_channels,
339
+ out_channels=out_channels,
340
+ kernel_size=kernel_size, stride=(1, 1),
341
+ dilation=(1, 1), padding=padding, bias=False)
342
+ self.film1 = Film(channels=out_channels, cond_embedding_dim=cond_embedding_dim)
343
+ self.conv2 = nn.Conv2d(in_channels=out_channels,
344
+ out_channels=out_channels,
345
+ kernel_size=kernel_size, stride=(1, 1),
346
+ dilation=(1, 1), padding=padding, bias=False)
347
+ self.film2 = Film(channels=out_channels, cond_embedding_dim=cond_embedding_dim)
348
+
349
+ if in_channels != out_channels:
350
+ self.shortcut = nn.Conv2d(in_channels=in_channels,
351
+ out_channels=out_channels, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0))
352
+ self.film_res = Film(channels=out_channels, cond_embedding_dim=cond_embedding_dim)
353
+ self.is_shortcut = True
354
+ else:
355
+ self.is_shortcut = False
356
+
357
+ self.init_weights()
358
+
359
+ def init_weights(self):
360
+ init_bn(self.bn1)
361
+ init_bn(self.bn2)
362
+ init_layer(self.conv1)
363
+ init_layer(self.conv2)
364
+
365
+ if self.is_shortcut:
366
+ init_layer(self.shortcut)
367
+
368
+ def forward(self, x, cond_vec):
369
+ origin = x
370
+ x = self.conv1(F.leaky_relu_(self.bn1(x), negative_slope=0.01))
371
+ x = self.film1(x, cond_vec)
372
+ x = self.conv2(F.leaky_relu_(self.bn2(x), negative_slope=0.01))
373
+ x = self.film2(x, cond_vec)
374
+ if self.is_shortcut:
375
+ residual = self.shortcut(origin)
376
+ residual = self.film_res(residual, cond_vec)
377
+ return residual + x
378
+ else:
379
+ return origin + x
380
+
381
+ class ConvBlockRes(nn.Module):
382
+ def __init__(self, in_channels, out_channels, kernel_size, activation, momentum):
383
+ r"""Residual block.
384
+ """
385
+ super(ConvBlockRes, self).__init__()
386
+
387
+ self.activation = activation
388
+ padding = [kernel_size[0] // 2, kernel_size[1] // 2]
389
+
390
+ self.bn1 = nn.BatchNorm2d(in_channels)
391
+ self.bn2 = nn.BatchNorm2d(out_channels)
392
+
393
+ self.conv1 = nn.Conv2d(in_channels=in_channels,
394
+ out_channels=out_channels,
395
+ kernel_size=kernel_size, stride=(1, 1),
396
+ dilation=(1, 1), padding=padding, bias=False)
397
+
398
+ self.conv2 = nn.Conv2d(in_channels=out_channels,
399
+ out_channels=out_channels,
400
+ kernel_size=kernel_size, stride=(1, 1),
401
+ dilation=(1, 1), padding=padding, bias=False)
402
+
403
+ if in_channels != out_channels:
404
+ self.shortcut = nn.Conv2d(in_channels=in_channels,
405
+ out_channels=out_channels, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0))
406
+ self.is_shortcut = True
407
+ else:
408
+ self.is_shortcut = False
409
+
410
+ self.init_weights()
411
+
412
+ def init_weights(self):
413
+ init_bn(self.bn1)
414
+ init_bn(self.bn2)
415
+ init_layer(self.conv1)
416
+ init_layer(self.conv2)
417
+
418
+ if self.is_shortcut:
419
+ init_layer(self.shortcut)
420
+
421
+ def forward(self, x):
422
+ origin = x
423
+ x = self.conv1(F.leaky_relu_(self.bn1(x), negative_slope=0.01))
424
+ x = self.conv2(F.leaky_relu_(self.bn2(x), negative_slope=0.01))
425
+
426
+ if self.is_shortcut:
427
+ return self.shortcut(origin) + x
428
+ else:
429
+ return origin + x
430
+
431
+ def init_layer(layer):
432
+ """Initialize a Linear or Convolutional layer. """
433
+ nn.init.xavier_uniform_(layer.weight)
434
+
435
+ if hasattr(layer, 'bias'):
436
+ if layer.bias is not None:
437
+ layer.bias.data.fill_(0.)
438
+
439
+ def init_bn(bn):
440
+ """Initialize a Batchnorm layer. """
441
+ bn.bias.data.fill_(0.)
442
+ bn.weight.data.fill_(1.)
443
+
444
+ def init_gru(rnn):
445
+ """Initialize a GRU layer. """
446
+
447
+ def _concat_init(tensor, init_funcs):
448
+ (length, fan_out) = tensor.shape
449
+ fan_in = length // len(init_funcs)
450
+
451
+ for (i, init_func) in enumerate(init_funcs):
452
+ init_func(tensor[i * fan_in: (i + 1) * fan_in, :])
453
+
454
+ def _inner_uniform(tensor):
455
+ fan_in = nn.init._calculate_correct_fan(tensor, 'fan_in')
456
+ nn.init.uniform_(tensor, -math.sqrt(3 / fan_in), math.sqrt(3 / fan_in))
457
+
458
+ for i in range(rnn.num_layers):
459
+ _concat_init(
460
+ getattr(rnn, 'weight_ih_l{}'.format(i)),
461
+ [_inner_uniform, _inner_uniform, _inner_uniform]
462
+ )
463
+ torch.nn.init.constant_(getattr(rnn, 'bias_ih_l{}'.format(i)), 0)
464
+
465
+ _concat_init(
466
+ getattr(rnn, 'weight_hh_l{}'.format(i)),
467
+ [_inner_uniform, _inner_uniform, nn.init.orthogonal_]
468
+ )
469
+ torch.nn.init.constant_(getattr(rnn, 'bias_hh_l{}'.format(i)), 0)
470
+
471
+
472
+ def act(x, activation):
473
+ if activation == 'relu':
474
+ return F.relu_(x)
475
+
476
+ elif activation == 'leaky_relu':
477
+ return F.leaky_relu_(x, negative_slope=0.2)
478
+
479
+ elif activation == 'swish':
480
+ return x * torch.sigmoid(x)
481
+
482
+ else:
483
+ raise Exception('Incorrect activation!')
sound_extraction/model/resunet_film.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .modules import *
2
+ import numpy as np
3
+
4
+ class UNetRes_FiLM(nn.Module):
5
+ def __init__(self, channels, cond_embedding_dim, nsrc=1):
6
+ super(UNetRes_FiLM, self).__init__()
7
+ activation = 'relu'
8
+ momentum = 0.01
9
+
10
+ self.nsrc = nsrc
11
+ self.channels = channels
12
+ self.downsample_ratio = 2 ** 6 # This number equals 2^{#encoder_blocks}
13
+
14
+ self.encoder_block1 = EncoderBlockRes2BCond(in_channels=channels * nsrc, out_channels=32,
15
+ downsample=(2, 2), activation=activation, momentum=momentum,
16
+ cond_embedding_dim=cond_embedding_dim)
17
+ self.encoder_block2 = EncoderBlockRes2BCond(in_channels=32, out_channels=64,
18
+ downsample=(2, 2), activation=activation, momentum=momentum,
19
+ cond_embedding_dim=cond_embedding_dim)
20
+ self.encoder_block3 = EncoderBlockRes2BCond(in_channels=64, out_channels=128,
21
+ downsample=(2, 2), activation=activation, momentum=momentum,
22
+ cond_embedding_dim=cond_embedding_dim)
23
+ self.encoder_block4 = EncoderBlockRes2BCond(in_channels=128, out_channels=256,
24
+ downsample=(2, 2), activation=activation, momentum=momentum,
25
+ cond_embedding_dim=cond_embedding_dim)
26
+ self.encoder_block5 = EncoderBlockRes2BCond(in_channels=256, out_channels=384,
27
+ downsample=(2, 2), activation=activation, momentum=momentum,
28
+ cond_embedding_dim=cond_embedding_dim)
29
+ self.encoder_block6 = EncoderBlockRes2BCond(in_channels=384, out_channels=384,
30
+ downsample=(2, 2), activation=activation, momentum=momentum,
31
+ cond_embedding_dim=cond_embedding_dim)
32
+ self.conv_block7 = ConvBlockResCond(in_channels=384, out_channels=384,
33
+ kernel_size=(3, 3), activation=activation, momentum=momentum,
34
+ cond_embedding_dim=cond_embedding_dim)
35
+ self.decoder_block1 = DecoderBlockRes2BCond(in_channels=384, out_channels=384,
36
+ stride=(2, 2), activation=activation, momentum=momentum,
37
+ cond_embedding_dim=cond_embedding_dim)
38
+ self.decoder_block2 = DecoderBlockRes2BCond(in_channels=384, out_channels=384,
39
+ stride=(2, 2), activation=activation, momentum=momentum,
40
+ cond_embedding_dim=cond_embedding_dim)
41
+ self.decoder_block3 = DecoderBlockRes2BCond(in_channels=384, out_channels=256,
42
+ stride=(2, 2), activation=activation, momentum=momentum,
43
+ cond_embedding_dim=cond_embedding_dim)
44
+ self.decoder_block4 = DecoderBlockRes2BCond(in_channels=256, out_channels=128,
45
+ stride=(2, 2), activation=activation, momentum=momentum,
46
+ cond_embedding_dim=cond_embedding_dim)
47
+ self.decoder_block5 = DecoderBlockRes2BCond(in_channels=128, out_channels=64,
48
+ stride=(2, 2), activation=activation, momentum=momentum,
49
+ cond_embedding_dim=cond_embedding_dim)
50
+ self.decoder_block6 = DecoderBlockRes2BCond(in_channels=64, out_channels=32,
51
+ stride=(2, 2), activation=activation, momentum=momentum,
52
+ cond_embedding_dim=cond_embedding_dim)
53
+
54
+ self.after_conv_block1 = ConvBlockResCond(in_channels=32, out_channels=32,
55
+ kernel_size=(3, 3), activation=activation, momentum=momentum,
56
+ cond_embedding_dim=cond_embedding_dim)
57
+
58
+ self.after_conv2 = nn.Conv2d(in_channels=32, out_channels=1,
59
+ kernel_size=(1, 1), stride=(1, 1), padding=(0, 0), bias=True)
60
+
61
+ self.init_weights()
62
+
63
+ def init_weights(self):
64
+ init_layer(self.after_conv2)
65
+
66
+ def forward(self, sp, cond_vec, dec_cond_vec):
67
+ """
68
+ Args:
69
+ input: sp: (batch_size, channels_num, segment_samples)
70
+ Outputs:
71
+ output_dict: {
72
+ 'wav': (batch_size, channels_num, segment_samples),
73
+ 'sp': (batch_size, channels_num, time_steps, freq_bins)}
74
+ """
75
+
76
+ x = sp
77
+ # Pad spectrogram to be evenly divided by downsample ratio.
78
+ origin_len = x.shape[2] # time_steps
79
+ pad_len = int(np.ceil(x.shape[2] / self.downsample_ratio)) * self.downsample_ratio - origin_len
80
+ x = F.pad(x, pad=(0, 0, 0, pad_len))
81
+ x = x[..., 0: x.shape[-1] - 2] # (bs, channels, T, F)
82
+
83
+ # UNet
84
+ (x1_pool, x1) = self.encoder_block1(x, cond_vec) # x1_pool: (bs, 32, T / 2, F / 2)
85
+ (x2_pool, x2) = self.encoder_block2(x1_pool, cond_vec) # x2_pool: (bs, 64, T / 4, F / 4)
86
+ (x3_pool, x3) = self.encoder_block3(x2_pool, cond_vec) # x3_pool: (bs, 128, T / 8, F / 8)
87
+ (x4_pool, x4) = self.encoder_block4(x3_pool, dec_cond_vec) # x4_pool: (bs, 256, T / 16, F / 16)
88
+ (x5_pool, x5) = self.encoder_block5(x4_pool, dec_cond_vec) # x5_pool: (bs, 512, T / 32, F / 32)
89
+ (x6_pool, x6) = self.encoder_block6(x5_pool, dec_cond_vec) # x6_pool: (bs, 1024, T / 64, F / 64)
90
+ x_center = self.conv_block7(x6_pool, dec_cond_vec) # (bs, 2048, T / 64, F / 64)
91
+ x7 = self.decoder_block1(x_center, x6, dec_cond_vec) # (bs, 1024, T / 32, F / 32)
92
+ x8 = self.decoder_block2(x7, x5, dec_cond_vec) # (bs, 512, T / 16, F / 16)
93
+ x9 = self.decoder_block3(x8, x4, cond_vec) # (bs, 256, T / 8, F / 8)
94
+ x10 = self.decoder_block4(x9, x3, cond_vec) # (bs, 128, T / 4, F / 4)
95
+ x11 = self.decoder_block5(x10, x2, cond_vec) # (bs, 64, T / 2, F / 2)
96
+ x12 = self.decoder_block6(x11, x1, cond_vec) # (bs, 32, T, F)
97
+ x = self.after_conv_block1(x12, cond_vec) # (bs, 32, T, F)
98
+ x = self.after_conv2(x) # (bs, channels, T, F)
99
+
100
+ # Recover shape
101
+ x = F.pad(x, pad=(0, 2))
102
+ x = x[:, :, 0: origin_len, :]
103
+ return x
104
+
105
+
106
+ if __name__ == "__main__":
107
+ model = UNetRes_FiLM(channels=1, cond_embedding_dim=16)
108
+ cond_vec = torch.randn((1, 16))
109
+ dec_vec = cond_vec
110
+ print(model(torch.randn((1, 1, 1001, 513)), cond_vec, dec_vec).size())
sound_extraction/model/text_encoder.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from transformers import *
4
+ import warnings
5
+ warnings.filterwarnings('ignore')
6
+ # pretrained model name: (model class, model tokenizer, output dimension, token style)
7
+ MODELS = {
8
+ 'prajjwal1/bert-mini': (BertModel, BertTokenizer),
9
+ }
10
+
11
+ class Text_Encoder(nn.Module):
12
+ def __init__(self, device):
13
+ super(Text_Encoder, self).__init__()
14
+ self.base_model = 'prajjwal1/bert-mini'
15
+ self.dropout = 0.1
16
+
17
+ self.tokenizer = MODELS[self.base_model][1].from_pretrained(self.base_model)
18
+
19
+ self.bert_layer = MODELS[self.base_model][0].from_pretrained(self.base_model,
20
+ add_pooling_layer=False,
21
+ hidden_dropout_prob=self.dropout,
22
+ attention_probs_dropout_prob=self.dropout,
23
+ output_hidden_states=True)
24
+
25
+ self.linear_layer = nn.Sequential(nn.Linear(256, 256), nn.ReLU(inplace=True))
26
+
27
+ self.device = device
28
+
29
+ def tokenize(self, caption):
30
+ # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
31
+ tokenized = self.tokenizer(caption, add_special_tokens=False, padding=True, return_tensors='pt')
32
+ input_ids = tokenized['input_ids']
33
+ attns_mask = tokenized['attention_mask']
34
+
35
+ input_ids = input_ids.to(self.device)
36
+ attns_mask = attns_mask.to(self.device)
37
+ return input_ids, attns_mask
38
+
39
+ def forward(self, input_ids, attns_mask):
40
+ # input_ids, attns_mask = self.tokenize(caption)
41
+ output = self.bert_layer(input_ids=input_ids, attention_mask=attns_mask)[0]
42
+ cls_embed = output[:, 0, :]
43
+ text_embed = self.linear_layer(cls_embed)
44
+
45
+ return text_embed, output # text_embed: (batch, hidden_size)
sound_extraction/useful_ckpts/LASSNet.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2c6a60910bc1db03d9ff7040d0e5906ab784431cb8b279cf4e295124e9e76fae
3
+ size 761532233
sound_extraction/utils/__pycache__/stft.cpython-38.pyc ADDED
Binary file (4.76 kB). View file
 
sound_extraction/utils/__pycache__/wav_io.cpython-38.pyc ADDED
Binary file (823 Bytes). View file
 
sound_extraction/utils/create_mixtures.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+
4
+ def add_noise_and_scale(front, noise, snr_l=0, snr_h=0, scale_lower=1.0, scale_upper=1.0):
5
+ """
6
+ :param front: front-head audio, like vocal [samples,channel], will be normlized so any scale will be fine
7
+ :param noise: noise, [samples,channel], any scale
8
+ :param snr_l: Optional
9
+ :param snr_h: Optional
10
+ :param scale_lower: Optional
11
+ :param scale_upper: Optional
12
+ :return: scaled front and noise (noisy = front + noise), all_mel_e2e outputs are noramlized within [-1 , 1]
13
+ """
14
+ snr = None
15
+ noise, front = normalize_energy_torch(noise), normalize_energy_torch(front) # set noise and vocal to equal range [-1,1]
16
+ # print("normalize:",torch.max(noise),torch.max(front))
17
+ if snr_l is not None and snr_h is not None:
18
+ front, noise, snr = _random_noise(front, noise, snr_l=snr_l, snr_h=snr_h) # remix them with a specific snr
19
+
20
+ noisy, noise, front = unify_energy_torch(noise + front, noise, front) # normalize noisy, noise and vocal energy into [-1,1]
21
+
22
+ # print("unify:", torch.max(noise), torch.max(front), torch.max(noisy))
23
+ scale = _random_scale(scale_lower, scale_upper) # random scale these three signal
24
+
25
+ # print("Scale",scale)
26
+ noisy, noise, front = noisy * scale, noise * scale, front * scale # apply scale
27
+ # print("after scale", torch.max(noisy), torch.max(noise), torch.max(front), snr, scale)
28
+
29
+ front, noise = _to_numpy(front), _to_numpy(noise) # [num_samples]
30
+ mixed_wav = front + noise
31
+
32
+ return front, noise, mixed_wav, snr, scale
33
+
34
+ def _random_scale(lower=0.3, upper=0.9):
35
+ return float(uniform_torch(lower, upper))
36
+
37
+ def _random_noise(clean, noise, snr_l=None, snr_h=None):
38
+ snr = uniform_torch(snr_l,snr_h)
39
+ clean_weight = 10 ** (float(snr) / 20)
40
+ return clean, noise/clean_weight, snr
41
+
42
+ def _to_numpy(wav):
43
+ return np.transpose(wav, (1, 0))[0].numpy() # [num_samples]
44
+
45
+ def normalize_energy(audio, alpha = 1):
46
+ '''
47
+ :param audio: 1d waveform, [batchsize, *],
48
+ :param alpha: the value of output range from: [-alpha,alpha]
49
+ :return: 1d waveform which value range from: [-alpha,alpha]
50
+ '''
51
+ val_max = activelev(audio)
52
+ return (audio / val_max) * alpha
53
+
54
+ def normalize_energy_torch(audio, alpha = 1):
55
+ '''
56
+ If the signal is almost empty(determined by threshold), if will only be divided by 2**15
57
+ :param audio: 1d waveform, 2**15
58
+ :param alpha: the value of output range from: [-alpha,alpha]
59
+ :return: 1d waveform which value range from: [-alpha,alpha]
60
+ '''
61
+ val_max = activelev_torch([audio])
62
+ return (audio / val_max) * alpha
63
+
64
+ def unify_energy(*args):
65
+ max_amp = activelev(args)
66
+ mix_scale = 1.0/max_amp
67
+ return [x * mix_scale for x in args]
68
+
69
+ def unify_energy_torch(*args):
70
+ max_amp = activelev_torch(args)
71
+ mix_scale = 1.0/max_amp
72
+ return [x * mix_scale for x in args]
73
+
74
+ def activelev(*args):
75
+ '''
76
+ need to update like matlab
77
+ '''
78
+ return np.max(np.abs([*args]))
79
+
80
+ def activelev_torch(*args):
81
+ '''
82
+ need to update like matlab
83
+ '''
84
+ res = []
85
+ args = args[0]
86
+ for each in args:
87
+ res.append(torch.max(torch.abs(each)))
88
+ return max(res)
89
+
90
+ def uniform_torch(lower, upper):
91
+ if(abs(lower-upper)<1e-5):
92
+ return upper
93
+ return (upper-lower)*torch.rand(1)+lower
94
+
95
+ if __name__ == "__main__":
96
+ wav1 = torch.randn(1, 32000)
97
+ wav2 = torch.randn(1, 32000)
98
+ target, noise, snr, scale = add_noise_and_scale(wav1, wav2)
sound_extraction/utils/stft.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ import torch.nn.functional as F
4
+ from torch.autograd import Variable
5
+ from scipy.signal import get_window
6
+ import librosa.util as librosa_util
7
+ from librosa.util import pad_center, tiny
8
+ # from audio_processing import window_sumsquare
9
+
10
+ def window_sumsquare(window, n_frames, hop_length=512, win_length=1024,
11
+ n_fft=1024, dtype=np.float32, norm=None):
12
+ """
13
+ # from librosa 0.6
14
+ Compute the sum-square envelope of a window function at a given hop length.
15
+ This is used to estimate modulation effects induced by windowing
16
+ observations in short-time fourier transforms.
17
+ Parameters
18
+ ----------
19
+ window : string, tuple, number, callable, or list-like
20
+ Window specification, as in `get_window`
21
+ n_frames : int > 0
22
+ The number of analysis frames
23
+ hop_length : int > 0
24
+ The number of samples to advance between frames
25
+ win_length : [optional]
26
+ The length of the window function. By default, this matches `n_fft`.
27
+ n_fft : int > 0
28
+ The length of each analysis frame.
29
+ dtype : np.dtype
30
+ The data type of the output
31
+ Returns
32
+ -------
33
+ wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))`
34
+ The sum-squared envelope of the window function
35
+ """
36
+ if win_length is None:
37
+ win_length = n_fft
38
+
39
+ n = n_fft + hop_length * (n_frames - 1)
40
+ x = np.zeros(n, dtype=dtype)
41
+
42
+ # Compute the squared window at the desired length
43
+ win_sq = get_window(window, win_length, fftbins=True)
44
+ win_sq = librosa_util.normalize(win_sq, norm=norm)**2
45
+ win_sq = librosa_util.pad_center(win_sq, n_fft)
46
+
47
+ # Fill the envelope
48
+ for i in range(n_frames):
49
+ sample = i * hop_length
50
+ x[sample:min(n, sample + n_fft)] += win_sq[:max(0, min(n_fft, n - sample))]
51
+ return x
52
+
53
+ class STFT(torch.nn.Module):
54
+ """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft"""
55
+ def __init__(self, filter_length=1024, hop_length=512, win_length=1024,
56
+ window='hann'):
57
+ super(STFT, self).__init__()
58
+ self.filter_length = filter_length
59
+ self.hop_length = hop_length
60
+ self.win_length = win_length
61
+ self.window = window
62
+ self.forward_transform = None
63
+ scale = self.filter_length / self.hop_length
64
+ fourier_basis = np.fft.fft(np.eye(self.filter_length))
65
+
66
+ cutoff = int((self.filter_length / 2 + 1))
67
+ fourier_basis = np.vstack([np.real(fourier_basis[:cutoff, :]),
68
+ np.imag(fourier_basis[:cutoff, :])])
69
+
70
+ forward_basis = torch.FloatTensor(fourier_basis[:, None, :])
71
+ inverse_basis = torch.FloatTensor(
72
+ np.linalg.pinv(scale * fourier_basis).T[:, None, :])
73
+
74
+ if window is not None:
75
+ assert(filter_length >= win_length)
76
+ # get window and zero center pad it to filter_length
77
+ fft_window = get_window(window, win_length, fftbins=True)
78
+ fft_window = pad_center(fft_window, filter_length)
79
+ fft_window = torch.from_numpy(fft_window).float()
80
+
81
+ # window the bases
82
+ forward_basis *= fft_window
83
+ inverse_basis *= fft_window
84
+
85
+ self.register_buffer('forward_basis', forward_basis.float())
86
+ self.register_buffer('inverse_basis', inverse_basis.float())
87
+
88
+ def transform(self, input_data):
89
+ num_batches = input_data.size(0)
90
+ num_samples = input_data.size(1)
91
+
92
+ self.num_samples = num_samples
93
+
94
+ # similar to librosa, reflect-pad the input
95
+ input_data = input_data.view(num_batches, 1, num_samples)
96
+ input_data = F.pad(
97
+ input_data.unsqueeze(1),
98
+ (int(self.filter_length / 2), int(self.filter_length / 2), 0, 0),
99
+ mode='reflect')
100
+ input_data = input_data.squeeze(1)
101
+
102
+ forward_transform = F.conv1d(
103
+ input_data,
104
+ Variable(self.forward_basis, requires_grad=False),
105
+ stride=self.hop_length,
106
+ padding=0)
107
+
108
+ cutoff = int((self.filter_length / 2) + 1)
109
+ real_part = forward_transform[:, :cutoff, :]
110
+ imag_part = forward_transform[:, cutoff:, :]
111
+
112
+ magnitude = torch.sqrt(real_part**2 + imag_part**2)
113
+ phase = torch.autograd.Variable(
114
+ torch.atan2(imag_part.data, real_part.data))
115
+
116
+ return magnitude, phase # [batch_size, F(513), T(1251)]
117
+
118
+ def inverse(self, magnitude, phase):
119
+ recombine_magnitude_phase = torch.cat(
120
+ [magnitude*torch.cos(phase), magnitude*torch.sin(phase)], dim=1)
121
+
122
+ inverse_transform = F.conv_transpose1d(
123
+ recombine_magnitude_phase,
124
+ Variable(self.inverse_basis, requires_grad=False),
125
+ stride=self.hop_length,
126
+ padding=0)
127
+
128
+ if self.window is not None:
129
+ window_sum = window_sumsquare(
130
+ self.window, magnitude.size(-1), hop_length=self.hop_length,
131
+ win_length=self.win_length, n_fft=self.filter_length,
132
+ dtype=np.float32)
133
+ # remove modulation effects
134
+ approx_nonzero_indices = torch.from_numpy(
135
+ np.where(window_sum > tiny(window_sum))[0])
136
+ window_sum = torch.autograd.Variable(
137
+ torch.from_numpy(window_sum), requires_grad=False)
138
+ window_sum = window_sum.cuda() if magnitude.is_cuda else window_sum
139
+ inverse_transform[:, :, approx_nonzero_indices] /= window_sum[approx_nonzero_indices]
140
+
141
+ # scale by hop ratio
142
+ inverse_transform *= float(self.filter_length) / self.hop_length
143
+
144
+ inverse_transform = inverse_transform[:, :, int(self.filter_length/2):]
145
+ inverse_transform = inverse_transform[:, :, :-int(self.filter_length/2):]
146
+
147
+ return inverse_transform #[batch_size, 1, sample_num]
148
+
149
+ def forward(self, input_data):
150
+ self.magnitude, self.phase = self.transform(input_data)
151
+ reconstruction = self.inverse(self.magnitude, self.phase)
152
+ return reconstruction
153
+
154
+ if __name__ == '__main__':
155
+ a = torch.randn(4, 320000)
156
+ stft = STFT()
157
+ mag, phase = stft.transform(a)
158
+ # rec_a = stft.inverse(mag, phase)
159
+ print(mag.shape)
sound_extraction/utils/wav_io.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import librosa
2
+ import librosa.filters
3
+ import math
4
+ import numpy as np
5
+ import scipy.io.wavfile
6
+
7
+ def load_wav(path):
8
+ max_length = 32000 * 10
9
+ wav = librosa.core.load(path, sr=32000)[0]
10
+ if len(wav) > max_length:
11
+ audio = wav[0:max_length]
12
+
13
+ # pad audio to max length, 10s for AudioCaps
14
+ if len(wav) < max_length:
15
+ # audio = torch.nn.functional.pad(audio, (0, self.max_length - audio.size(1)), 'constant')
16
+ wav = np.pad(wav, (0, max_length - len(wav)), 'constant')
17
+ wav = wav[...,None]
18
+ return wav
19
+
20
+
21
+ def save_wav(wav, path):
22
+ wav *= 32767 / max(0.01, np.max(np.abs(wav)))
23
+ scipy.io.wavfile.write(path, 32000, wav.astype(np.int16))