alpha31476 committed
Commit bda7d99 · verified · 1 Parent(s): a3c0704

Image Audio Alignment

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full change set.

Files changed (50)
  1. .gitattributes +5 -0
  2. Vaani/Img_Audio_Alignment/000000039769.jpg +3 -0
  3. Vaani/Img_Audio_Alignment/CLAP-Audio-Encoder.txt +292 -0
  4. Vaani/Img_Audio_Alignment/LoRA-CLAP-Audio-Encoder.txt +954 -0
  5. Vaani/Img_Audio_Alignment/_1_CLAP-Audio-Encoder.ipynb +0 -0
  6. Vaani/Img_Audio_Alignment/_2_Train.py +1495 -0
  7. Vaani/Img_Audio_Alignment/audio_embedding.npy +3 -0
  8. Vaani/Img_Audio_Alignment/audio_embedding_dismantled_msclap.npy +3 -0
  9. Vaani/Img_Audio_Alignment/audio_embedding_dismantled_msclap_untrained.npy +3 -0
  10. Vaani/SDFT/checkpoints/checkpoint.pth +1 -1
  11. Vaani/SDFT/samples/inference_epoch10.png +0 -0
  12. Vaani/SDFT/samples/inference_epoch9.png +0 -0
  13. Vaani/Vaani-Audio-Image-Hindi.csv +1 -0
  14. Vaani/VaaniLDM/ddpm_ckpt_epoch55.pt +3 -0
  15. Vaani/VaaniLDM/ddpm_ckpt_epoch56.pt +3 -0
  16. Vaani/VaaniLDM/ldmH_ckpt_epoch49.pt +3 -0
  17. Vaani/VaaniLDM/ldmH_ckpt_epoch50.pt +3 -0
  18. Vaani/VaaniLDM/samples/x0_0.png +2 -2
  19. Vaani/VaaniLDM/samples/x0_1.png +0 -0
  20. Vaani/VaaniLDM/samples/x0_10.png +0 -0
  21. Vaani/VaaniLDM/samples/x0_100.png +0 -0
  22. Vaani/VaaniLDM/samples/x0_101.png +0 -0
  23. Vaani/VaaniLDM/samples/x0_102.png +0 -0
  24. Vaani/VaaniLDM/samples/x0_103.png +0 -0
  25. Vaani/VaaniLDM/samples/x0_104.png +0 -0
  26. Vaani/VaaniLDM/samples/x0_105.png +0 -0
  27. Vaani/VaaniLDM/samples/x0_106.png +0 -0
  28. Vaani/VaaniLDM/samples/x0_107.png +0 -0
  29. Vaani/VaaniLDM/samples/x0_108.png +0 -0
  30. Vaani/VaaniLDM/samples/x0_109.png +0 -0
  31. Vaani/VaaniLDM/samples/x0_11.png +0 -0
  32. Vaani/VaaniLDM/samples/x0_110.png +0 -0
  33. Vaani/VaaniLDM/samples/x0_111.png +0 -0
  34. Vaani/VaaniLDM/samples/x0_112.png +0 -0
  35. Vaani/VaaniLDM/samples/x0_113.png +0 -0
  36. Vaani/VaaniLDM/samples/x0_114.png +0 -0
  37. Vaani/VaaniLDM/samples/x0_115.png +0 -0
  38. Vaani/VaaniLDM/samples/x0_116.png +0 -0
  39. Vaani/VaaniLDM/samples/x0_117.png +0 -0
  40. Vaani/VaaniLDM/samples/x0_118.png +0 -0
  41. Vaani/VaaniLDM/samples/x0_119.png +0 -0
  42. Vaani/VaaniLDM/samples/x0_12.png +0 -0
  43. Vaani/VaaniLDM/samples/x0_120.png +0 -0
  44. Vaani/VaaniLDM/samples/x0_121.png +0 -0
  45. Vaani/VaaniLDM/samples/x0_122.png +0 -0
  46. Vaani/VaaniLDM/samples/x0_123.png +0 -0
  47. Vaani/VaaniLDM/samples/x0_124.png +0 -0
  48. Vaani/VaaniLDM/samples/x0_125.png +0 -0
  49. Vaani/VaaniLDM/samples/x0_126.png +0 -0
  50. Vaani/VaaniLDM/samples/x0_127.png +0 -0
.gitattributes CHANGED
@@ -137,3 +137,8 @@ Vaani/sampleJSON.json filter=lfs diff=lfs merge=lfs -text
137   tools/__pycache__/pynvml.cpython-312.pyc filter=lfs diff=lfs merge=lfs -text
138   Vaani/VaaniLDM/samplesH/x0_0.png filter=lfs diff=lfs merge=lfs -text
139   Vaani/SDFT/astronaut_horse_mars.png filter=lfs diff=lfs merge=lfs -text
140 + Vaani/Img_Audio_Alignment/000000039769.jpg filter=lfs diff=lfs merge=lfs -text
141 + Vaani/audio_urls[[:space:]]copy.txt filter=lfs diff=lfs merge=lfs -text
142 + Vaani/imageBY.csv filter=lfs diff=lfs merge=lfs -text
143 + Vaani/imageBY2.csv filter=lfs diff=lfs merge=lfs -text
144 + Vaani/imageBY3.csv filter=lfs diff=lfs merge=lfs -text
Vaani/Img_Audio_Alignment/000000039769.jpg ADDED

Git LFS Details

  • SHA256: dea9e7ef97386345f7cff32f9055da4982da5471c48d575146c796ab4563b04e
  • Pointer size: 131 Bytes
  • Size of remote file: 173 kB
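For reference, only a small Git LFS pointer is stored in the repository for this JPEG; the 173 kB file itself lives on the LFS server. The "Pointer size: 131 Bytes" refers to a pointer file in the standard LFS format, sketched below with the exact byte count left as a placeholder since only the rounded size is shown above:

version https://git-lfs.github.com/spec/v1
oid sha256:dea9e7ef97386345f7cff32f9055da4982da5471c48d575146c796ab4563b04e
size <exact size in bytes, roughly 173 kB>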
Vaani/Img_Audio_Alignment/CLAP-Audio-Encoder.txt ADDED
@@ -0,0 +1,292 @@
1
+ AudioEncoder(
2
+ (base): HTSATWrapper(
3
+ (htsat): HTSAT_Swin_Transformer(
4
+ (spectrogram_extractor): Spectrogram(
5
+ (stft): STFT(
6
+ (conv_real): Conv1d(1, 513, kernel_size=(1024,), stride=(320,), bias=False)
7
+ (conv_imag): Conv1d(1, 513, kernel_size=(1024,), stride=(320,), bias=False)
8
+ )
9
+ )
10
+ (logmel_extractor): LogmelFilterBank()
11
+ (spec_augmenter): SpecAugmentation(
12
+ (time_dropper): DropStripes()
13
+ (freq_dropper): DropStripes()
14
+ )
15
+ (bn0): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
16
+ (patch_embed): PatchEmbed(
17
+ (proj): Conv2d(1, 96, kernel_size=(4, 4), stride=(4, 4))
18
+ (norm): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
19
+ )
20
+ (pos_drop): Dropout(p=0.0, inplace=False)
21
+ (layers): ModuleList(
22
+ (0): BasicLayer(
23
+ dim=96, input_resolution=(64, 64), depth=2
24
+ (blocks): ModuleList(
25
+ (0): SwinTransformerBlock(
26
+ dim=96, input_resolution=(64, 64), num_heads=4, window_size=8, shift_size=0, mlp_ratio=4.0
27
+ (norm1): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
28
+ (attn): WindowAttention(
29
+ dim=96, window_size=(8, 8), num_heads=4
30
+ (qkv): Linear(in_features=96, out_features=288, bias=True)
31
+ (attn_drop): Dropout(p=0.0, inplace=False)
32
+ (proj): Linear(in_features=96, out_features=96, bias=True)
33
+ (proj_drop): Dropout(p=0.0, inplace=False)
34
+ (softmax): Softmax(dim=-1)
35
+ )
36
+ (drop_path): Identity()
37
+ (norm2): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
38
+ (mlp): Mlp(
39
+ (fc1): Linear(in_features=96, out_features=384, bias=True)
40
+ (act): GELU(approximate='none')
41
+ (fc2): Linear(in_features=384, out_features=96, bias=True)
42
+ (drop): Dropout(p=0.0, inplace=False)
43
+ )
44
+ )
45
+ (1): SwinTransformerBlock(
46
+ dim=96, input_resolution=(64, 64), num_heads=4, window_size=8, shift_size=4, mlp_ratio=4.0
47
+ (norm1): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
48
+ (attn): WindowAttention(
49
+ dim=96, window_size=(8, 8), num_heads=4
50
+ (qkv): Linear(in_features=96, out_features=288, bias=True)
51
+ (attn_drop): Dropout(p=0.0, inplace=False)
52
+ (proj): Linear(in_features=96, out_features=96, bias=True)
53
+ (proj_drop): Dropout(p=0.0, inplace=False)
54
+ (softmax): Softmax(dim=-1)
55
+ )
56
+ (drop_path): DropPath()
57
+ (norm2): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
58
+ (mlp): Mlp(
59
+ (fc1): Linear(in_features=96, out_features=384, bias=True)
60
+ (act): GELU(approximate='none')
61
+ (fc2): Linear(in_features=384, out_features=96, bias=True)
62
+ (drop): Dropout(p=0.0, inplace=False)
63
+ )
64
+ )
65
+ )
66
+ (downsample): PatchMerging(
67
+ input_resolution=(64, 64), dim=96
68
+ (reduction): Linear(in_features=384, out_features=192, bias=False)
69
+ (norm): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
70
+ )
71
+ )
72
+ (1): BasicLayer(
73
+ dim=192, input_resolution=(32, 32), depth=2
74
+ (blocks): ModuleList(
75
+ (0): SwinTransformerBlock(
76
+ dim=192, input_resolution=(32, 32), num_heads=8, window_size=8, shift_size=0, mlp_ratio=4.0
77
+ (norm1): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
78
+ (attn): WindowAttention(
79
+ dim=192, window_size=(8, 8), num_heads=8
80
+ (qkv): Linear(in_features=192, out_features=576, bias=True)
81
+ (attn_drop): Dropout(p=0.0, inplace=False)
82
+ (proj): Linear(in_features=192, out_features=192, bias=True)
83
+ (proj_drop): Dropout(p=0.0, inplace=False)
84
+ (softmax): Softmax(dim=-1)
85
+ )
86
+ (drop_path): DropPath()
87
+ (norm2): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
88
+ (mlp): Mlp(
89
+ (fc1): Linear(in_features=192, out_features=768, bias=True)
90
+ (act): GELU(approximate='none')
91
+ (fc2): Linear(in_features=768, out_features=192, bias=True)
92
+ (drop): Dropout(p=0.0, inplace=False)
93
+ )
94
+ )
95
+ (1): SwinTransformerBlock(
96
+ dim=192, input_resolution=(32, 32), num_heads=8, window_size=8, shift_size=4, mlp_ratio=4.0
97
+ (norm1): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
98
+ (attn): WindowAttention(
99
+ dim=192, window_size=(8, 8), num_heads=8
100
+ (qkv): Linear(in_features=192, out_features=576, bias=True)
101
+ (attn_drop): Dropout(p=0.0, inplace=False)
102
+ (proj): Linear(in_features=192, out_features=192, bias=True)
103
+ (proj_drop): Dropout(p=0.0, inplace=False)
104
+ (softmax): Softmax(dim=-1)
105
+ )
106
+ (drop_path): DropPath()
107
+ (norm2): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
108
+ (mlp): Mlp(
109
+ (fc1): Linear(in_features=192, out_features=768, bias=True)
110
+ (act): GELU(approximate='none')
111
+ (fc2): Linear(in_features=768, out_features=192, bias=True)
112
+ (drop): Dropout(p=0.0, inplace=False)
113
+ )
114
+ )
115
+ )
116
+ (downsample): PatchMerging(
117
+ input_resolution=(32, 32), dim=192
118
+ (reduction): Linear(in_features=768, out_features=384, bias=False)
119
+ (norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
120
+ )
121
+ )
122
+ (2): BasicLayer(
123
+ dim=384, input_resolution=(16, 16), depth=6
124
+ (blocks): ModuleList(
125
+ (0): SwinTransformerBlock(
126
+ dim=384, input_resolution=(16, 16), num_heads=16, window_size=8, shift_size=0, mlp_ratio=4.0
127
+ (norm1): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
128
+ (attn): WindowAttention(
129
+ dim=384, window_size=(8, 8), num_heads=16
130
+ (qkv): Linear(in_features=384, out_features=1152, bias=True)
131
+ (attn_drop): Dropout(p=0.0, inplace=False)
132
+ (proj): Linear(in_features=384, out_features=384, bias=True)
133
+ (proj_drop): Dropout(p=0.0, inplace=False)
134
+ (softmax): Softmax(dim=-1)
135
+ )
136
+ (drop_path): DropPath()
137
+ (norm2): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
138
+ (mlp): Mlp(
139
+ (fc1): Linear(in_features=384, out_features=1536, bias=True)
140
+ (act): GELU(approximate='none')
141
+ (fc2): Linear(in_features=1536, out_features=384, bias=True)
142
+ (drop): Dropout(p=0.0, inplace=False)
143
+ )
144
+ )
145
+ (1): SwinTransformerBlock(
146
+ dim=384, input_resolution=(16, 16), num_heads=16, window_size=8, shift_size=4, mlp_ratio=4.0
147
+ (norm1): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
148
+ (attn): WindowAttention(
149
+ dim=384, window_size=(8, 8), num_heads=16
150
+ (qkv): Linear(in_features=384, out_features=1152, bias=True)
151
+ (attn_drop): Dropout(p=0.0, inplace=False)
152
+ (proj): Linear(in_features=384, out_features=384, bias=True)
153
+ (proj_drop): Dropout(p=0.0, inplace=False)
154
+ (softmax): Softmax(dim=-1)
155
+ )
156
+ (drop_path): DropPath()
157
+ (norm2): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
158
+ (mlp): Mlp(
159
+ (fc1): Linear(in_features=384, out_features=1536, bias=True)
160
+ (act): GELU(approximate='none')
161
+ (fc2): Linear(in_features=1536, out_features=384, bias=True)
162
+ (drop): Dropout(p=0.0, inplace=False)
163
+ )
164
+ )
165
+ (2): SwinTransformerBlock(
166
+ dim=384, input_resolution=(16, 16), num_heads=16, window_size=8, shift_size=0, mlp_ratio=4.0
167
+ (norm1): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
168
+ (attn): WindowAttention(
169
+ dim=384, window_size=(8, 8), num_heads=16
170
+ (qkv): Linear(in_features=384, out_features=1152, bias=True)
171
+ (attn_drop): Dropout(p=0.0, inplace=False)
172
+ (proj): Linear(in_features=384, out_features=384, bias=True)
173
+ (proj_drop): Dropout(p=0.0, inplace=False)
174
+ (softmax): Softmax(dim=-1)
175
+ )
176
+ (drop_path): DropPath()
177
+ (norm2): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
178
+ (mlp): Mlp(
179
+ (fc1): Linear(in_features=384, out_features=1536, bias=True)
180
+ (act): GELU(approximate='none')
181
+ (fc2): Linear(in_features=1536, out_features=384, bias=True)
182
+ (drop): Dropout(p=0.0, inplace=False)
183
+ )
184
+ )
185
+ (3): SwinTransformerBlock(
186
+ dim=384, input_resolution=(16, 16), num_heads=16, window_size=8, shift_size=4, mlp_ratio=4.0
187
+ (norm1): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
188
+ (attn): WindowAttention(
189
+ dim=384, window_size=(8, 8), num_heads=16
190
+ (qkv): Linear(in_features=384, out_features=1152, bias=True)
191
+ (attn_drop): Dropout(p=0.0, inplace=False)
192
+ (proj): Linear(in_features=384, out_features=384, bias=True)
193
+ (proj_drop): Dropout(p=0.0, inplace=False)
194
+ (softmax): Softmax(dim=-1)
195
+ )
196
+ (drop_path): DropPath()
197
+ (norm2): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
198
+ (mlp): Mlp(
199
+ (fc1): Linear(in_features=384, out_features=1536, bias=True)
200
+ (act): GELU(approximate='none')
201
+ (fc2): Linear(in_features=1536, out_features=384, bias=True)
202
+ (drop): Dropout(p=0.0, inplace=False)
203
+ )
204
+ )
205
+ (4): SwinTransformerBlock(
206
+ dim=384, input_resolution=(16, 16), num_heads=16, window_size=8, shift_size=0, mlp_ratio=4.0
207
+ (norm1): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
208
+ (attn): WindowAttention(
209
+ dim=384, window_size=(8, 8), num_heads=16
210
+ (qkv): Linear(in_features=384, out_features=1152, bias=True)
211
+ (attn_drop): Dropout(p=0.0, inplace=False)
212
+ (proj): Linear(in_features=384, out_features=384, bias=True)
213
+ (proj_drop): Dropout(p=0.0, inplace=False)
214
+ (softmax): Softmax(dim=-1)
215
+ )
216
+ (drop_path): DropPath()
217
+ (norm2): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
218
+ (mlp): Mlp(
219
+ (fc1): Linear(in_features=384, out_features=1536, bias=True)
220
+ (act): GELU(approximate='none')
221
+ (fc2): Linear(in_features=1536, out_features=384, bias=True)
222
+ (drop): Dropout(p=0.0, inplace=False)
223
+ )
224
+ )
225
+ (5): SwinTransformerBlock(
226
+ dim=384, input_resolution=(16, 16), num_heads=16, window_size=8, shift_size=4, mlp_ratio=4.0
227
+ (norm1): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
228
+ (attn): WindowAttention(
229
+ dim=384, window_size=(8, 8), num_heads=16
230
+ (qkv): Linear(in_features=384, out_features=1152, bias=True)
231
+ (attn_drop): Dropout(p=0.0, inplace=False)
232
+ (proj): Linear(in_features=384, out_features=384, bias=True)
233
+ (proj_drop): Dropout(p=0.0, inplace=False)
234
+ (softmax): Softmax(dim=-1)
235
+ )
236
+ (drop_path): DropPath()
237
+ (norm2): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
238
+ (mlp): Mlp(
239
+ (fc1): Linear(in_features=384, out_features=1536, bias=True)
240
+ (act): GELU(approximate='none')
241
+ (fc2): Linear(in_features=1536, out_features=384, bias=True)
242
+ (drop): Dropout(p=0.0, inplace=False)
243
+ )
244
+ )
245
+ )
246
+ (downsample): PatchMerging(
247
+ input_resolution=(16, 16), dim=384
248
+ (reduction): Linear(in_features=1536, out_features=768, bias=False)
249
+ (norm): LayerNorm((1536,), eps=1e-05, elementwise_affine=True)
250
+ )
251
+ )
252
+ (3): BasicLayer(
253
+ dim=768, input_resolution=(8, 8), depth=2
254
+ (blocks): ModuleList(
255
+ (0-1): 2 x SwinTransformerBlock(
256
+ dim=768, input_resolution=(8, 8), num_heads=32, window_size=8, shift_size=0, mlp_ratio=4.0
257
+ (norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
258
+ (attn): WindowAttention(
259
+ dim=768, window_size=(8, 8), num_heads=32
260
+ (qkv): Linear(in_features=768, out_features=2304, bias=True)
261
+ (attn_drop): Dropout(p=0.0, inplace=False)
262
+ (proj): Linear(in_features=768, out_features=768, bias=True)
263
+ (proj_drop): Dropout(p=0.0, inplace=False)
264
+ (softmax): Softmax(dim=-1)
265
+ )
266
+ (drop_path): DropPath()
267
+ (norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
268
+ (mlp): Mlp(
269
+ (fc1): Linear(in_features=768, out_features=3072, bias=True)
270
+ (act): GELU(approximate='none')
271
+ (fc2): Linear(in_features=3072, out_features=768, bias=True)
272
+ (drop): Dropout(p=0.0, inplace=False)
273
+ )
274
+ )
275
+ )
276
+ )
277
+ )
278
+ (norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
279
+ (avgpool): AdaptiveAvgPool1d(output_size=1)
280
+ (maxpool): AdaptiveMaxPool1d(output_size=1)
281
+ (tscam_conv): Conv2d(768, 527, kernel_size=(2, 3), stride=(1, 1), padding=(0, 1))
282
+ (head): Linear(in_features=527, out_features=527, bias=True)
283
+ )
284
+ )
285
+ (projection): Projection(
286
+ (linear1): Linear(in_features=768, out_features=1024, bias=False)
287
+ (linear2): Linear(in_features=1024, out_features=1024, bias=False)
288
+ (layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
289
+ (drop): Dropout(p=0.5, inplace=False)
290
+ )
291
+ )
292
+
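The module tree above is the audio branch of Microsoft's CLAP: an HTSAT Swin-Transformer backbone whose pooled 768-dim features pass through a two-layer Projection head into the 1024-dim joint audio-text embedding space. A dump like this can be reproduced with a few lines; the snippet below is only a sketch and assumes the msclap pip package, and that its CLAP wrapper exposes the underlying model as .clap with an audio_encoder attribute (as the class names in the dump suggest).

import torch
from msclap import CLAP  # assumption: Microsoft "msclap" pip package

# Load the 2023 checkpoint on CPU; set use_cuda=True if a GPU is available.
clap_wrapper = CLAP(version="2023", use_cuda=False)

# The audio branch is the AudioEncoder(HTSATWrapper + Projection) printed above.
audio_encoder = clap_wrapper.clap.audio_encoder
print(audio_encoder)

# Embedding an audio file yields a 1024-dim vector, i.e. the kind of array
# saved to audio_embedding.npy in this folder ("example.wav" is a hypothetical path).
with torch.no_grad():
    emb = clap_wrapper.get_audio_embeddings(["example.wav"], resample=True)
print(emb.shape)  # expected: torch.Size([1, 1024])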
Vaani/Img_Audio_Alignment/LoRA-CLAP-Audio-Encoder.txt ADDED
@@ -0,0 +1,954 @@
1
+ PeftModelForFeatureExtraction(
2
+ (base_model): LoraModel(
3
+ (model): AudioEncoder(
4
+ (base): HTSATWrapper(
5
+ (htsat): HTSAT_Swin_Transformer(
6
+ (spectrogram_extractor): Spectrogram(
7
+ (stft): STFT(
8
+ (conv_real): Conv1d(1, 513, kernel_size=(1024,), stride=(320,), bias=False)
9
+ (conv_imag): Conv1d(1, 513, kernel_size=(1024,), stride=(320,), bias=False)
10
+ )
11
+ )
12
+ (logmel_extractor): LogmelFilterBank()
13
+ (spec_augmenter): SpecAugmentation(
14
+ (time_dropper): DropStripes()
15
+ (freq_dropper): DropStripes()
16
+ )
17
+ (bn0): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
18
+ (patch_embed): PatchEmbed(
19
+ (proj): lora.Conv2d(
20
+ (base_layer): Conv2d(1, 96, kernel_size=(4, 4), stride=(4, 4))
21
+ (lora_dropout): ModuleDict(
22
+ (default): Dropout(p=0.05, inplace=False)
23
+ )
24
+ (lora_A): ModuleDict(
25
+ (default): Conv2d(1, 8, kernel_size=(4, 4), stride=(4, 4), bias=False)
26
+ )
27
+ (lora_B): ModuleDict(
28
+ (default): Conv2d(8, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
29
+ )
30
+ (lora_embedding_A): ParameterDict()
31
+ (lora_embedding_B): ParameterDict()
32
+ (lora_magnitude_vector): ModuleDict()
33
+ )
34
+ (norm): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
35
+ )
36
+ (pos_drop): Dropout(p=0.0, inplace=False)
37
+ (layers): ModuleList(
38
+ (0): BasicLayer(
39
+ dim=96, input_resolution=(64, 64), depth=2
40
+ (blocks): ModuleList(
41
+ (0): SwinTransformerBlock(
42
+ dim=96, input_resolution=(64, 64), num_heads=4, window_size=8, shift_size=0, mlp_ratio=4.0
43
+ (norm1): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
44
+ (attn): WindowAttention(
45
+ dim=96, window_size=(8, 8), num_heads=4
46
+ (qkv): lora.Linear(
47
+ (base_layer): Linear(in_features=96, out_features=288, bias=True)
48
+ (lora_dropout): ModuleDict(
49
+ (default): Dropout(p=0.05, inplace=False)
50
+ )
51
+ (lora_A): ModuleDict(
52
+ (default): Linear(in_features=96, out_features=8, bias=False)
53
+ )
54
+ (lora_B): ModuleDict(
55
+ (default): Linear(in_features=8, out_features=288, bias=False)
56
+ )
57
+ (lora_embedding_A): ParameterDict()
58
+ (lora_embedding_B): ParameterDict()
59
+ (lora_magnitude_vector): ModuleDict()
60
+ )
61
+ (attn_drop): Dropout(p=0.0, inplace=False)
62
+ (proj): lora.Linear(
63
+ (base_layer): Linear(in_features=96, out_features=96, bias=True)
64
+ (lora_dropout): ModuleDict(
65
+ (default): Dropout(p=0.05, inplace=False)
66
+ )
67
+ (lora_A): ModuleDict(
68
+ (default): Linear(in_features=96, out_features=8, bias=False)
69
+ )
70
+ (lora_B): ModuleDict(
71
+ (default): Linear(in_features=8, out_features=96, bias=False)
72
+ )
73
+ (lora_embedding_A): ParameterDict()
74
+ (lora_embedding_B): ParameterDict()
75
+ (lora_magnitude_vector): ModuleDict()
76
+ )
77
+ (proj_drop): Dropout(p=0.0, inplace=False)
78
+ (softmax): Softmax(dim=-1)
79
+ )
80
+ (drop_path): Identity()
81
+ (norm2): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
82
+ (mlp): Mlp(
83
+ (fc1): lora.Linear(
84
+ (base_layer): Linear(in_features=96, out_features=384, bias=True)
85
+ (lora_dropout): ModuleDict(
86
+ (default): Dropout(p=0.05, inplace=False)
87
+ )
88
+ (lora_A): ModuleDict(
89
+ (default): Linear(in_features=96, out_features=8, bias=False)
90
+ )
91
+ (lora_B): ModuleDict(
92
+ (default): Linear(in_features=8, out_features=384, bias=False)
93
+ )
94
+ (lora_embedding_A): ParameterDict()
95
+ (lora_embedding_B): ParameterDict()
96
+ (lora_magnitude_vector): ModuleDict()
97
+ )
98
+ (act): GELU(approximate='none')
99
+ (fc2): lora.Linear(
100
+ (base_layer): Linear(in_features=384, out_features=96, bias=True)
101
+ (lora_dropout): ModuleDict(
102
+ (default): Dropout(p=0.05, inplace=False)
103
+ )
104
+ (lora_A): ModuleDict(
105
+ (default): Linear(in_features=384, out_features=8, bias=False)
106
+ )
107
+ (lora_B): ModuleDict(
108
+ (default): Linear(in_features=8, out_features=96, bias=False)
109
+ )
110
+ (lora_embedding_A): ParameterDict()
111
+ (lora_embedding_B): ParameterDict()
112
+ (lora_magnitude_vector): ModuleDict()
113
+ )
114
+ (drop): Dropout(p=0.0, inplace=False)
115
+ )
116
+ )
117
+ (1): SwinTransformerBlock(
118
+ dim=96, input_resolution=(64, 64), num_heads=4, window_size=8, shift_size=4, mlp_ratio=4.0
119
+ (norm1): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
120
+ (attn): WindowAttention(
121
+ dim=96, window_size=(8, 8), num_heads=4
122
+ (qkv): lora.Linear(
123
+ (base_layer): Linear(in_features=96, out_features=288, bias=True)
124
+ (lora_dropout): ModuleDict(
125
+ (default): Dropout(p=0.05, inplace=False)
126
+ )
127
+ (lora_A): ModuleDict(
128
+ (default): Linear(in_features=96, out_features=8, bias=False)
129
+ )
130
+ (lora_B): ModuleDict(
131
+ (default): Linear(in_features=8, out_features=288, bias=False)
132
+ )
133
+ (lora_embedding_A): ParameterDict()
134
+ (lora_embedding_B): ParameterDict()
135
+ (lora_magnitude_vector): ModuleDict()
136
+ )
137
+ (attn_drop): Dropout(p=0.0, inplace=False)
138
+ (proj): lora.Linear(
139
+ (base_layer): Linear(in_features=96, out_features=96, bias=True)
140
+ (lora_dropout): ModuleDict(
141
+ (default): Dropout(p=0.05, inplace=False)
142
+ )
143
+ (lora_A): ModuleDict(
144
+ (default): Linear(in_features=96, out_features=8, bias=False)
145
+ )
146
+ (lora_B): ModuleDict(
147
+ (default): Linear(in_features=8, out_features=96, bias=False)
148
+ )
149
+ (lora_embedding_A): ParameterDict()
150
+ (lora_embedding_B): ParameterDict()
151
+ (lora_magnitude_vector): ModuleDict()
152
+ )
153
+ (proj_drop): Dropout(p=0.0, inplace=False)
154
+ (softmax): Softmax(dim=-1)
155
+ )
156
+ (drop_path): DropPath()
157
+ (norm2): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
158
+ (mlp): Mlp(
159
+ (fc1): lora.Linear(
160
+ (base_layer): Linear(in_features=96, out_features=384, bias=True)
161
+ (lora_dropout): ModuleDict(
162
+ (default): Dropout(p=0.05, inplace=False)
163
+ )
164
+ (lora_A): ModuleDict(
165
+ (default): Linear(in_features=96, out_features=8, bias=False)
166
+ )
167
+ (lora_B): ModuleDict(
168
+ (default): Linear(in_features=8, out_features=384, bias=False)
169
+ )
170
+ (lora_embedding_A): ParameterDict()
171
+ (lora_embedding_B): ParameterDict()
172
+ (lora_magnitude_vector): ModuleDict()
173
+ )
174
+ (act): GELU(approximate='none')
175
+ (fc2): lora.Linear(
176
+ (base_layer): Linear(in_features=384, out_features=96, bias=True)
177
+ (lora_dropout): ModuleDict(
178
+ (default): Dropout(p=0.05, inplace=False)
179
+ )
180
+ (lora_A): ModuleDict(
181
+ (default): Linear(in_features=384, out_features=8, bias=False)
182
+ )
183
+ (lora_B): ModuleDict(
184
+ (default): Linear(in_features=8, out_features=96, bias=False)
185
+ )
186
+ (lora_embedding_A): ParameterDict()
187
+ (lora_embedding_B): ParameterDict()
188
+ (lora_magnitude_vector): ModuleDict()
189
+ )
190
+ (drop): Dropout(p=0.0, inplace=False)
191
+ )
192
+ )
193
+ )
194
+ (downsample): PatchMerging(
195
+ input_resolution=(64, 64), dim=96
196
+ (reduction): Linear(in_features=384, out_features=192, bias=False)
197
+ (norm): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
198
+ )
199
+ )
200
+ (1): BasicLayer(
201
+ dim=192, input_resolution=(32, 32), depth=2
202
+ (blocks): ModuleList(
203
+ (0): SwinTransformerBlock(
204
+ dim=192, input_resolution=(32, 32), num_heads=8, window_size=8, shift_size=0, mlp_ratio=4.0
205
+ (norm1): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
206
+ (attn): WindowAttention(
207
+ dim=192, window_size=(8, 8), num_heads=8
208
+ (qkv): lora.Linear(
209
+ (base_layer): Linear(in_features=192, out_features=576, bias=True)
210
+ (lora_dropout): ModuleDict(
211
+ (default): Dropout(p=0.05, inplace=False)
212
+ )
213
+ (lora_A): ModuleDict(
214
+ (default): Linear(in_features=192, out_features=8, bias=False)
215
+ )
216
+ (lora_B): ModuleDict(
217
+ (default): Linear(in_features=8, out_features=576, bias=False)
218
+ )
219
+ (lora_embedding_A): ParameterDict()
220
+ (lora_embedding_B): ParameterDict()
221
+ (lora_magnitude_vector): ModuleDict()
222
+ )
223
+ (attn_drop): Dropout(p=0.0, inplace=False)
224
+ (proj): lora.Linear(
225
+ (base_layer): Linear(in_features=192, out_features=192, bias=True)
226
+ (lora_dropout): ModuleDict(
227
+ (default): Dropout(p=0.05, inplace=False)
228
+ )
229
+ (lora_A): ModuleDict(
230
+ (default): Linear(in_features=192, out_features=8, bias=False)
231
+ )
232
+ (lora_B): ModuleDict(
233
+ (default): Linear(in_features=8, out_features=192, bias=False)
234
+ )
235
+ (lora_embedding_A): ParameterDict()
236
+ (lora_embedding_B): ParameterDict()
237
+ (lora_magnitude_vector): ModuleDict()
238
+ )
239
+ (proj_drop): Dropout(p=0.0, inplace=False)
240
+ (softmax): Softmax(dim=-1)
241
+ )
242
+ (drop_path): DropPath()
243
+ (norm2): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
244
+ (mlp): Mlp(
245
+ (fc1): lora.Linear(
246
+ (base_layer): Linear(in_features=192, out_features=768, bias=True)
247
+ (lora_dropout): ModuleDict(
248
+ (default): Dropout(p=0.05, inplace=False)
249
+ )
250
+ (lora_A): ModuleDict(
251
+ (default): Linear(in_features=192, out_features=8, bias=False)
252
+ )
253
+ (lora_B): ModuleDict(
254
+ (default): Linear(in_features=8, out_features=768, bias=False)
255
+ )
256
+ (lora_embedding_A): ParameterDict()
257
+ (lora_embedding_B): ParameterDict()
258
+ (lora_magnitude_vector): ModuleDict()
259
+ )
260
+ (act): GELU(approximate='none')
261
+ (fc2): lora.Linear(
262
+ (base_layer): Linear(in_features=768, out_features=192, bias=True)
263
+ (lora_dropout): ModuleDict(
264
+ (default): Dropout(p=0.05, inplace=False)
265
+ )
266
+ (lora_A): ModuleDict(
267
+ (default): Linear(in_features=768, out_features=8, bias=False)
268
+ )
269
+ (lora_B): ModuleDict(
270
+ (default): Linear(in_features=8, out_features=192, bias=False)
271
+ )
272
+ (lora_embedding_A): ParameterDict()
273
+ (lora_embedding_B): ParameterDict()
274
+ (lora_magnitude_vector): ModuleDict()
275
+ )
276
+ (drop): Dropout(p=0.0, inplace=False)
277
+ )
278
+ )
279
+ (1): SwinTransformerBlock(
280
+ dim=192, input_resolution=(32, 32), num_heads=8, window_size=8, shift_size=4, mlp_ratio=4.0
281
+ (norm1): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
282
+ (attn): WindowAttention(
283
+ dim=192, window_size=(8, 8), num_heads=8
284
+ (qkv): lora.Linear(
285
+ (base_layer): Linear(in_features=192, out_features=576, bias=True)
286
+ (lora_dropout): ModuleDict(
287
+ (default): Dropout(p=0.05, inplace=False)
288
+ )
289
+ (lora_A): ModuleDict(
290
+ (default): Linear(in_features=192, out_features=8, bias=False)
291
+ )
292
+ (lora_B): ModuleDict(
293
+ (default): Linear(in_features=8, out_features=576, bias=False)
294
+ )
295
+ (lora_embedding_A): ParameterDict()
296
+ (lora_embedding_B): ParameterDict()
297
+ (lora_magnitude_vector): ModuleDict()
298
+ )
299
+ (attn_drop): Dropout(p=0.0, inplace=False)
300
+ (proj): lora.Linear(
301
+ (base_layer): Linear(in_features=192, out_features=192, bias=True)
302
+ (lora_dropout): ModuleDict(
303
+ (default): Dropout(p=0.05, inplace=False)
304
+ )
305
+ (lora_A): ModuleDict(
306
+ (default): Linear(in_features=192, out_features=8, bias=False)
307
+ )
308
+ (lora_B): ModuleDict(
309
+ (default): Linear(in_features=8, out_features=192, bias=False)
310
+ )
311
+ (lora_embedding_A): ParameterDict()
312
+ (lora_embedding_B): ParameterDict()
313
+ (lora_magnitude_vector): ModuleDict()
314
+ )
315
+ (proj_drop): Dropout(p=0.0, inplace=False)
316
+ (softmax): Softmax(dim=-1)
317
+ )
318
+ (drop_path): DropPath()
319
+ (norm2): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
320
+ (mlp): Mlp(
321
+ (fc1): lora.Linear(
322
+ (base_layer): Linear(in_features=192, out_features=768, bias=True)
323
+ (lora_dropout): ModuleDict(
324
+ (default): Dropout(p=0.05, inplace=False)
325
+ )
326
+ (lora_A): ModuleDict(
327
+ (default): Linear(in_features=192, out_features=8, bias=False)
328
+ )
329
+ (lora_B): ModuleDict(
330
+ (default): Linear(in_features=8, out_features=768, bias=False)
331
+ )
332
+ (lora_embedding_A): ParameterDict()
333
+ (lora_embedding_B): ParameterDict()
334
+ (lora_magnitude_vector): ModuleDict()
335
+ )
336
+ (act): GELU(approximate='none')
337
+ (fc2): lora.Linear(
338
+ (base_layer): Linear(in_features=768, out_features=192, bias=True)
339
+ (lora_dropout): ModuleDict(
340
+ (default): Dropout(p=0.05, inplace=False)
341
+ )
342
+ (lora_A): ModuleDict(
343
+ (default): Linear(in_features=768, out_features=8, bias=False)
344
+ )
345
+ (lora_B): ModuleDict(
346
+ (default): Linear(in_features=8, out_features=192, bias=False)
347
+ )
348
+ (lora_embedding_A): ParameterDict()
349
+ (lora_embedding_B): ParameterDict()
350
+ (lora_magnitude_vector): ModuleDict()
351
+ )
352
+ (drop): Dropout(p=0.0, inplace=False)
353
+ )
354
+ )
355
+ )
356
+ (downsample): PatchMerging(
357
+ input_resolution=(32, 32), dim=192
358
+ (reduction): Linear(in_features=768, out_features=384, bias=False)
359
+ (norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
360
+ )
361
+ )
362
+ (2): BasicLayer(
363
+ dim=384, input_resolution=(16, 16), depth=6
364
+ (blocks): ModuleList(
365
+ (0): SwinTransformerBlock(
366
+ dim=384, input_resolution=(16, 16), num_heads=16, window_size=8, shift_size=0, mlp_ratio=4.0
367
+ (norm1): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
368
+ (attn): WindowAttention(
369
+ dim=384, window_size=(8, 8), num_heads=16
370
+ (qkv): lora.Linear(
371
+ (base_layer): Linear(in_features=384, out_features=1152, bias=True)
372
+ (lora_dropout): ModuleDict(
373
+ (default): Dropout(p=0.05, inplace=False)
374
+ )
375
+ (lora_A): ModuleDict(
376
+ (default): Linear(in_features=384, out_features=8, bias=False)
377
+ )
378
+ (lora_B): ModuleDict(
379
+ (default): Linear(in_features=8, out_features=1152, bias=False)
380
+ )
381
+ (lora_embedding_A): ParameterDict()
382
+ (lora_embedding_B): ParameterDict()
383
+ (lora_magnitude_vector): ModuleDict()
384
+ )
385
+ (attn_drop): Dropout(p=0.0, inplace=False)
386
+ (proj): lora.Linear(
387
+ (base_layer): Linear(in_features=384, out_features=384, bias=True)
388
+ (lora_dropout): ModuleDict(
389
+ (default): Dropout(p=0.05, inplace=False)
390
+ )
391
+ (lora_A): ModuleDict(
392
+ (default): Linear(in_features=384, out_features=8, bias=False)
393
+ )
394
+ (lora_B): ModuleDict(
395
+ (default): Linear(in_features=8, out_features=384, bias=False)
396
+ )
397
+ (lora_embedding_A): ParameterDict()
398
+ (lora_embedding_B): ParameterDict()
399
+ (lora_magnitude_vector): ModuleDict()
400
+ )
401
+ (proj_drop): Dropout(p=0.0, inplace=False)
402
+ (softmax): Softmax(dim=-1)
403
+ )
404
+ (drop_path): DropPath()
405
+ (norm2): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
406
+ (mlp): Mlp(
407
+ (fc1): lora.Linear(
408
+ (base_layer): Linear(in_features=384, out_features=1536, bias=True)
409
+ (lora_dropout): ModuleDict(
410
+ (default): Dropout(p=0.05, inplace=False)
411
+ )
412
+ (lora_A): ModuleDict(
413
+ (default): Linear(in_features=384, out_features=8, bias=False)
414
+ )
415
+ (lora_B): ModuleDict(
416
+ (default): Linear(in_features=8, out_features=1536, bias=False)
417
+ )
418
+ (lora_embedding_A): ParameterDict()
419
+ (lora_embedding_B): ParameterDict()
420
+ (lora_magnitude_vector): ModuleDict()
421
+ )
422
+ (act): GELU(approximate='none')
423
+ (fc2): lora.Linear(
424
+ (base_layer): Linear(in_features=1536, out_features=384, bias=True)
425
+ (lora_dropout): ModuleDict(
426
+ (default): Dropout(p=0.05, inplace=False)
427
+ )
428
+ (lora_A): ModuleDict(
429
+ (default): Linear(in_features=1536, out_features=8, bias=False)
430
+ )
431
+ (lora_B): ModuleDict(
432
+ (default): Linear(in_features=8, out_features=384, bias=False)
433
+ )
434
+ (lora_embedding_A): ParameterDict()
435
+ (lora_embedding_B): ParameterDict()
436
+ (lora_magnitude_vector): ModuleDict()
437
+ )
438
+ (drop): Dropout(p=0.0, inplace=False)
439
+ )
440
+ )
441
+ (1): SwinTransformerBlock(
442
+ dim=384, input_resolution=(16, 16), num_heads=16, window_size=8, shift_size=4, mlp_ratio=4.0
443
+ (norm1): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
444
+ (attn): WindowAttention(
445
+ dim=384, window_size=(8, 8), num_heads=16
446
+ (qkv): lora.Linear(
447
+ (base_layer): Linear(in_features=384, out_features=1152, bias=True)
448
+ (lora_dropout): ModuleDict(
449
+ (default): Dropout(p=0.05, inplace=False)
450
+ )
451
+ (lora_A): ModuleDict(
452
+ (default): Linear(in_features=384, out_features=8, bias=False)
453
+ )
454
+ (lora_B): ModuleDict(
455
+ (default): Linear(in_features=8, out_features=1152, bias=False)
456
+ )
457
+ (lora_embedding_A): ParameterDict()
458
+ (lora_embedding_B): ParameterDict()
459
+ (lora_magnitude_vector): ModuleDict()
460
+ )
461
+ (attn_drop): Dropout(p=0.0, inplace=False)
462
+ (proj): lora.Linear(
463
+ (base_layer): Linear(in_features=384, out_features=384, bias=True)
464
+ (lora_dropout): ModuleDict(
465
+ (default): Dropout(p=0.05, inplace=False)
466
+ )
467
+ (lora_A): ModuleDict(
468
+ (default): Linear(in_features=384, out_features=8, bias=False)
469
+ )
470
+ (lora_B): ModuleDict(
471
+ (default): Linear(in_features=8, out_features=384, bias=False)
472
+ )
473
+ (lora_embedding_A): ParameterDict()
474
+ (lora_embedding_B): ParameterDict()
475
+ (lora_magnitude_vector): ModuleDict()
476
+ )
477
+ (proj_drop): Dropout(p=0.0, inplace=False)
478
+ (softmax): Softmax(dim=-1)
479
+ )
480
+ (drop_path): DropPath()
481
+ (norm2): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
482
+ (mlp): Mlp(
483
+ (fc1): lora.Linear(
484
+ (base_layer): Linear(in_features=384, out_features=1536, bias=True)
485
+ (lora_dropout): ModuleDict(
486
+ (default): Dropout(p=0.05, inplace=False)
487
+ )
488
+ (lora_A): ModuleDict(
489
+ (default): Linear(in_features=384, out_features=8, bias=False)
490
+ )
491
+ (lora_B): ModuleDict(
492
+ (default): Linear(in_features=8, out_features=1536, bias=False)
493
+ )
494
+ (lora_embedding_A): ParameterDict()
495
+ (lora_embedding_B): ParameterDict()
496
+ (lora_magnitude_vector): ModuleDict()
497
+ )
498
+ (act): GELU(approximate='none')
499
+ (fc2): lora.Linear(
500
+ (base_layer): Linear(in_features=1536, out_features=384, bias=True)
501
+ (lora_dropout): ModuleDict(
502
+ (default): Dropout(p=0.05, inplace=False)
503
+ )
504
+ (lora_A): ModuleDict(
505
+ (default): Linear(in_features=1536, out_features=8, bias=False)
506
+ )
507
+ (lora_B): ModuleDict(
508
+ (default): Linear(in_features=8, out_features=384, bias=False)
509
+ )
510
+ (lora_embedding_A): ParameterDict()
511
+ (lora_embedding_B): ParameterDict()
512
+ (lora_magnitude_vector): ModuleDict()
513
+ )
514
+ (drop): Dropout(p=0.0, inplace=False)
515
+ )
516
+ )
517
+ (2): SwinTransformerBlock(
518
+ dim=384, input_resolution=(16, 16), num_heads=16, window_size=8, shift_size=0, mlp_ratio=4.0
519
+ (norm1): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
520
+ (attn): WindowAttention(
521
+ dim=384, window_size=(8, 8), num_heads=16
522
+ (qkv): lora.Linear(
523
+ (base_layer): Linear(in_features=384, out_features=1152, bias=True)
524
+ (lora_dropout): ModuleDict(
525
+ (default): Dropout(p=0.05, inplace=False)
526
+ )
527
+ (lora_A): ModuleDict(
528
+ (default): Linear(in_features=384, out_features=8, bias=False)
529
+ )
530
+ (lora_B): ModuleDict(
531
+ (default): Linear(in_features=8, out_features=1152, bias=False)
532
+ )
533
+ (lora_embedding_A): ParameterDict()
534
+ (lora_embedding_B): ParameterDict()
535
+ (lora_magnitude_vector): ModuleDict()
536
+ )
537
+ (attn_drop): Dropout(p=0.0, inplace=False)
538
+ (proj): lora.Linear(
539
+ (base_layer): Linear(in_features=384, out_features=384, bias=True)
540
+ (lora_dropout): ModuleDict(
541
+ (default): Dropout(p=0.05, inplace=False)
542
+ )
543
+ (lora_A): ModuleDict(
544
+ (default): Linear(in_features=384, out_features=8, bias=False)
545
+ )
546
+ (lora_B): ModuleDict(
547
+ (default): Linear(in_features=8, out_features=384, bias=False)
548
+ )
549
+ (lora_embedding_A): ParameterDict()
550
+ (lora_embedding_B): ParameterDict()
551
+ (lora_magnitude_vector): ModuleDict()
552
+ )
553
+ (proj_drop): Dropout(p=0.0, inplace=False)
554
+ (softmax): Softmax(dim=-1)
555
+ )
556
+ (drop_path): DropPath()
557
+ (norm2): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
558
+ (mlp): Mlp(
559
+ (fc1): lora.Linear(
560
+ (base_layer): Linear(in_features=384, out_features=1536, bias=True)
561
+ (lora_dropout): ModuleDict(
562
+ (default): Dropout(p=0.05, inplace=False)
563
+ )
564
+ (lora_A): ModuleDict(
565
+ (default): Linear(in_features=384, out_features=8, bias=False)
566
+ )
567
+ (lora_B): ModuleDict(
568
+ (default): Linear(in_features=8, out_features=1536, bias=False)
569
+ )
570
+ (lora_embedding_A): ParameterDict()
571
+ (lora_embedding_B): ParameterDict()
572
+ (lora_magnitude_vector): ModuleDict()
573
+ )
574
+ (act): GELU(approximate='none')
575
+ (fc2): lora.Linear(
576
+ (base_layer): Linear(in_features=1536, out_features=384, bias=True)
577
+ (lora_dropout): ModuleDict(
578
+ (default): Dropout(p=0.05, inplace=False)
579
+ )
580
+ (lora_A): ModuleDict(
581
+ (default): Linear(in_features=1536, out_features=8, bias=False)
582
+ )
583
+ (lora_B): ModuleDict(
584
+ (default): Linear(in_features=8, out_features=384, bias=False)
585
+ )
586
+ (lora_embedding_A): ParameterDict()
587
+ (lora_embedding_B): ParameterDict()
588
+ (lora_magnitude_vector): ModuleDict()
589
+ )
590
+ (drop): Dropout(p=0.0, inplace=False)
591
+ )
592
+ )
593
+ (3): SwinTransformerBlock(
594
+ dim=384, input_resolution=(16, 16), num_heads=16, window_size=8, shift_size=4, mlp_ratio=4.0
595
+ (norm1): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
596
+ (attn): WindowAttention(
597
+ dim=384, window_size=(8, 8), num_heads=16
598
+ (qkv): lora.Linear(
599
+ (base_layer): Linear(in_features=384, out_features=1152, bias=True)
600
+ (lora_dropout): ModuleDict(
601
+ (default): Dropout(p=0.05, inplace=False)
602
+ )
603
+ (lora_A): ModuleDict(
604
+ (default): Linear(in_features=384, out_features=8, bias=False)
605
+ )
606
+ (lora_B): ModuleDict(
607
+ (default): Linear(in_features=8, out_features=1152, bias=False)
608
+ )
609
+ (lora_embedding_A): ParameterDict()
610
+ (lora_embedding_B): ParameterDict()
611
+ (lora_magnitude_vector): ModuleDict()
612
+ )
613
+ (attn_drop): Dropout(p=0.0, inplace=False)
614
+ (proj): lora.Linear(
615
+ (base_layer): Linear(in_features=384, out_features=384, bias=True)
616
+ (lora_dropout): ModuleDict(
617
+ (default): Dropout(p=0.05, inplace=False)
618
+ )
619
+ (lora_A): ModuleDict(
620
+ (default): Linear(in_features=384, out_features=8, bias=False)
621
+ )
622
+ (lora_B): ModuleDict(
623
+ (default): Linear(in_features=8, out_features=384, bias=False)
624
+ )
625
+ (lora_embedding_A): ParameterDict()
626
+ (lora_embedding_B): ParameterDict()
627
+ (lora_magnitude_vector): ModuleDict()
628
+ )
629
+ (proj_drop): Dropout(p=0.0, inplace=False)
630
+ (softmax): Softmax(dim=-1)
631
+ )
632
+ (drop_path): DropPath()
633
+ (norm2): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
634
+ (mlp): Mlp(
635
+ (fc1): lora.Linear(
636
+ (base_layer): Linear(in_features=384, out_features=1536, bias=True)
637
+ (lora_dropout): ModuleDict(
638
+ (default): Dropout(p=0.05, inplace=False)
639
+ )
640
+ (lora_A): ModuleDict(
641
+ (default): Linear(in_features=384, out_features=8, bias=False)
642
+ )
643
+ (lora_B): ModuleDict(
644
+ (default): Linear(in_features=8, out_features=1536, bias=False)
645
+ )
646
+ (lora_embedding_A): ParameterDict()
647
+ (lora_embedding_B): ParameterDict()
648
+ (lora_magnitude_vector): ModuleDict()
649
+ )
650
+ (act): GELU(approximate='none')
651
+ (fc2): lora.Linear(
652
+ (base_layer): Linear(in_features=1536, out_features=384, bias=True)
653
+ (lora_dropout): ModuleDict(
654
+ (default): Dropout(p=0.05, inplace=False)
655
+ )
656
+ (lora_A): ModuleDict(
657
+ (default): Linear(in_features=1536, out_features=8, bias=False)
658
+ )
659
+ (lora_B): ModuleDict(
660
+ (default): Linear(in_features=8, out_features=384, bias=False)
661
+ )
662
+ (lora_embedding_A): ParameterDict()
663
+ (lora_embedding_B): ParameterDict()
664
+ (lora_magnitude_vector): ModuleDict()
665
+ )
666
+ (drop): Dropout(p=0.0, inplace=False)
667
+ )
668
+ )
669
+ (4): SwinTransformerBlock(
670
+ dim=384, input_resolution=(16, 16), num_heads=16, window_size=8, shift_size=0, mlp_ratio=4.0
671
+ (norm1): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
672
+ (attn): WindowAttention(
673
+ dim=384, window_size=(8, 8), num_heads=16
674
+ (qkv): lora.Linear(
675
+ (base_layer): Linear(in_features=384, out_features=1152, bias=True)
676
+ (lora_dropout): ModuleDict(
677
+ (default): Dropout(p=0.05, inplace=False)
678
+ )
679
+ (lora_A): ModuleDict(
680
+ (default): Linear(in_features=384, out_features=8, bias=False)
681
+ )
682
+ (lora_B): ModuleDict(
683
+ (default): Linear(in_features=8, out_features=1152, bias=False)
684
+ )
685
+ (lora_embedding_A): ParameterDict()
686
+ (lora_embedding_B): ParameterDict()
687
+ (lora_magnitude_vector): ModuleDict()
688
+ )
689
+ (attn_drop): Dropout(p=0.0, inplace=False)
690
+ (proj): lora.Linear(
691
+ (base_layer): Linear(in_features=384, out_features=384, bias=True)
692
+ (lora_dropout): ModuleDict(
693
+ (default): Dropout(p=0.05, inplace=False)
694
+ )
695
+ (lora_A): ModuleDict(
696
+ (default): Linear(in_features=384, out_features=8, bias=False)
697
+ )
698
+ (lora_B): ModuleDict(
699
+ (default): Linear(in_features=8, out_features=384, bias=False)
700
+ )
701
+ (lora_embedding_A): ParameterDict()
702
+ (lora_embedding_B): ParameterDict()
703
+ (lora_magnitude_vector): ModuleDict()
704
+ )
705
+ (proj_drop): Dropout(p=0.0, inplace=False)
706
+ (softmax): Softmax(dim=-1)
707
+ )
708
+ (drop_path): DropPath()
709
+ (norm2): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
710
+ (mlp): Mlp(
711
+ (fc1): lora.Linear(
712
+ (base_layer): Linear(in_features=384, out_features=1536, bias=True)
713
+ (lora_dropout): ModuleDict(
714
+ (default): Dropout(p=0.05, inplace=False)
715
+ )
716
+ (lora_A): ModuleDict(
717
+ (default): Linear(in_features=384, out_features=8, bias=False)
718
+ )
719
+ (lora_B): ModuleDict(
720
+ (default): Linear(in_features=8, out_features=1536, bias=False)
721
+ )
722
+ (lora_embedding_A): ParameterDict()
723
+ (lora_embedding_B): ParameterDict()
724
+ (lora_magnitude_vector): ModuleDict()
725
+ )
726
+ (act): GELU(approximate='none')
727
+ (fc2): lora.Linear(
728
+ (base_layer): Linear(in_features=1536, out_features=384, bias=True)
729
+ (lora_dropout): ModuleDict(
730
+ (default): Dropout(p=0.05, inplace=False)
731
+ )
732
+ (lora_A): ModuleDict(
733
+ (default): Linear(in_features=1536, out_features=8, bias=False)
734
+ )
735
+ (lora_B): ModuleDict(
736
+ (default): Linear(in_features=8, out_features=384, bias=False)
737
+ )
738
+ (lora_embedding_A): ParameterDict()
739
+ (lora_embedding_B): ParameterDict()
740
+ (lora_magnitude_vector): ModuleDict()
741
+ )
742
+ (drop): Dropout(p=0.0, inplace=False)
743
+ )
744
+ )
745
+ (5): SwinTransformerBlock(
746
+ dim=384, input_resolution=(16, 16), num_heads=16, window_size=8, shift_size=4, mlp_ratio=4.0
747
+ (norm1): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
748
+ (attn): WindowAttention(
749
+ dim=384, window_size=(8, 8), num_heads=16
750
+ (qkv): lora.Linear(
751
+ (base_layer): Linear(in_features=384, out_features=1152, bias=True)
752
+ (lora_dropout): ModuleDict(
753
+ (default): Dropout(p=0.05, inplace=False)
754
+ )
755
+ (lora_A): ModuleDict(
756
+ (default): Linear(in_features=384, out_features=8, bias=False)
757
+ )
758
+ (lora_B): ModuleDict(
759
+ (default): Linear(in_features=8, out_features=1152, bias=False)
760
+ )
761
+ (lora_embedding_A): ParameterDict()
762
+ (lora_embedding_B): ParameterDict()
763
+ (lora_magnitude_vector): ModuleDict()
764
+ )
765
+ (attn_drop): Dropout(p=0.0, inplace=False)
766
+ (proj): lora.Linear(
767
+ (base_layer): Linear(in_features=384, out_features=384, bias=True)
768
+ (lora_dropout): ModuleDict(
769
+ (default): Dropout(p=0.05, inplace=False)
770
+ )
771
+ (lora_A): ModuleDict(
772
+ (default): Linear(in_features=384, out_features=8, bias=False)
773
+ )
774
+ (lora_B): ModuleDict(
775
+ (default): Linear(in_features=8, out_features=384, bias=False)
776
+ )
777
+ (lora_embedding_A): ParameterDict()
778
+ (lora_embedding_B): ParameterDict()
779
+ (lora_magnitude_vector): ModuleDict()
780
+ )
781
+ (proj_drop): Dropout(p=0.0, inplace=False)
782
+ (softmax): Softmax(dim=-1)
783
+ )
784
+ (drop_path): DropPath()
785
+ (norm2): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
786
+ (mlp): Mlp(
787
+ (fc1): lora.Linear(
788
+ (base_layer): Linear(in_features=384, out_features=1536, bias=True)
789
+ (lora_dropout): ModuleDict(
790
+ (default): Dropout(p=0.05, inplace=False)
791
+ )
792
+ (lora_A): ModuleDict(
793
+ (default): Linear(in_features=384, out_features=8, bias=False)
794
+ )
795
+ (lora_B): ModuleDict(
796
+ (default): Linear(in_features=8, out_features=1536, bias=False)
797
+ )
798
+ (lora_embedding_A): ParameterDict()
799
+ (lora_embedding_B): ParameterDict()
800
+ (lora_magnitude_vector): ModuleDict()
801
+ )
802
+ (act): GELU(approximate='none')
803
+ (fc2): lora.Linear(
804
+ (base_layer): Linear(in_features=1536, out_features=384, bias=True)
805
+ (lora_dropout): ModuleDict(
806
+ (default): Dropout(p=0.05, inplace=False)
807
+ )
808
+ (lora_A): ModuleDict(
809
+ (default): Linear(in_features=1536, out_features=8, bias=False)
810
+ )
811
+ (lora_B): ModuleDict(
812
+ (default): Linear(in_features=8, out_features=384, bias=False)
813
+ )
814
+ (lora_embedding_A): ParameterDict()
815
+ (lora_embedding_B): ParameterDict()
816
+ (lora_magnitude_vector): ModuleDict()
817
+ )
818
+ (drop): Dropout(p=0.0, inplace=False)
819
+ )
820
+ )
821
+ )
822
+ (downsample): PatchMerging(
823
+ input_resolution=(16, 16), dim=384
824
+ (reduction): Linear(in_features=1536, out_features=768, bias=False)
825
+ (norm): LayerNorm((1536,), eps=1e-05, elementwise_affine=True)
826
+ )
827
+ )
828
+ (3): BasicLayer(
829
+ dim=768, input_resolution=(8, 8), depth=2
830
+ (blocks): ModuleList(
831
+ (0-1): 2 x SwinTransformerBlock(
832
+ dim=768, input_resolution=(8, 8), num_heads=32, window_size=8, shift_size=0, mlp_ratio=4.0
833
+ (norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
834
+ (attn): WindowAttention(
835
+ dim=768, window_size=(8, 8), num_heads=32
836
+ (qkv): lora.Linear(
837
+ (base_layer): Linear(in_features=768, out_features=2304, bias=True)
838
+ (lora_dropout): ModuleDict(
839
+ (default): Dropout(p=0.05, inplace=False)
840
+ )
841
+ (lora_A): ModuleDict(
842
+ (default): Linear(in_features=768, out_features=8, bias=False)
843
+ )
844
+ (lora_B): ModuleDict(
845
+ (default): Linear(in_features=8, out_features=2304, bias=False)
846
+ )
847
+ (lora_embedding_A): ParameterDict()
848
+ (lora_embedding_B): ParameterDict()
849
+ (lora_magnitude_vector): ModuleDict()
850
+ )
851
+ (attn_drop): Dropout(p=0.0, inplace=False)
852
+ (proj): lora.Linear(
853
+ (base_layer): Linear(in_features=768, out_features=768, bias=True)
854
+ (lora_dropout): ModuleDict(
855
+ (default): Dropout(p=0.05, inplace=False)
856
+ )
857
+ (lora_A): ModuleDict(
858
+ (default): Linear(in_features=768, out_features=8, bias=False)
859
+ )
860
+ (lora_B): ModuleDict(
861
+ (default): Linear(in_features=8, out_features=768, bias=False)
862
+ )
863
+ (lora_embedding_A): ParameterDict()
864
+ (lora_embedding_B): ParameterDict()
865
+ (lora_magnitude_vector): ModuleDict()
866
+ )
867
+ (proj_drop): Dropout(p=0.0, inplace=False)
868
+ (softmax): Softmax(dim=-1)
869
+ )
870
+ (drop_path): DropPath()
871
+ (norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
872
+ (mlp): Mlp(
873
+ (fc1): lora.Linear(
874
+ (base_layer): Linear(in_features=768, out_features=3072, bias=True)
875
+ (lora_dropout): ModuleDict(
876
+ (default): Dropout(p=0.05, inplace=False)
877
+ )
878
+ (lora_A): ModuleDict(
879
+ (default): Linear(in_features=768, out_features=8, bias=False)
880
+ )
881
+ (lora_B): ModuleDict(
882
+ (default): Linear(in_features=8, out_features=3072, bias=False)
883
+ )
884
+ (lora_embedding_A): ParameterDict()
885
+ (lora_embedding_B): ParameterDict()
886
+ (lora_magnitude_vector): ModuleDict()
887
+ )
888
+ (act): GELU(approximate='none')
889
+ (fc2): lora.Linear(
890
+ (base_layer): Linear(in_features=3072, out_features=768, bias=True)
891
+ (lora_dropout): ModuleDict(
892
+ (default): Dropout(p=0.05, inplace=False)
893
+ )
894
+ (lora_A): ModuleDict(
895
+ (default): Linear(in_features=3072, out_features=8, bias=False)
896
+ )
897
+ (lora_B): ModuleDict(
898
+ (default): Linear(in_features=8, out_features=768, bias=False)
899
+ )
900
+ (lora_embedding_A): ParameterDict()
901
+ (lora_embedding_B): ParameterDict()
902
+ (lora_magnitude_vector): ModuleDict()
903
+ )
904
+ (drop): Dropout(p=0.0, inplace=False)
905
+ )
906
+ )
907
+ )
908
+ )
909
+ )
910
+ (norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
911
+ (avgpool): AdaptiveAvgPool1d(output_size=1)
912
+ (maxpool): AdaptiveMaxPool1d(output_size=1)
913
+ (tscam_conv): Conv2d(768, 527, kernel_size=(2, 3), stride=(1, 1), padding=(0, 1))
914
+ (head): Linear(in_features=527, out_features=527, bias=True)
915
+ )
916
+ )
917
+ (projection): Projection(
918
+ (linear1): lora.Linear(
919
+ (base_layer): Linear(in_features=768, out_features=1024, bias=False)
920
+ (lora_dropout): ModuleDict(
921
+ (default): Dropout(p=0.05, inplace=False)
922
+ )
923
+ (lora_A): ModuleDict(
924
+ (default): Linear(in_features=768, out_features=8, bias=False)
925
+ )
926
+ (lora_B): ModuleDict(
927
+ (default): Linear(in_features=8, out_features=1024, bias=False)
928
+ )
929
+ (lora_embedding_A): ParameterDict()
930
+ (lora_embedding_B): ParameterDict()
931
+ (lora_magnitude_vector): ModuleDict()
932
+ )
933
+ (linear2): lora.Linear(
934
+ (base_layer): Linear(in_features=1024, out_features=1024, bias=False)
935
+ (lora_dropout): ModuleDict(
936
+ (default): Dropout(p=0.05, inplace=False)
937
+ )
938
+ (lora_A): ModuleDict(
939
+ (default): Linear(in_features=1024, out_features=8, bias=False)
940
+ )
941
+ (lora_B): ModuleDict(
942
+ (default): Linear(in_features=8, out_features=1024, bias=False)
943
+ )
944
+ (lora_embedding_A): ParameterDict()
945
+ (lora_embedding_B): ParameterDict()
946
+ (lora_magnitude_vector): ModuleDict()
947
+ )
948
+ (layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
949
+ (drop): Dropout(p=0.5, inplace=False)
950
+ )
951
+ )
952
+ )
953
+ )
954
+
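Relative to the plain encoder in CLAP-Audio-Encoder.txt, every attention qkv/proj, MLP fc1/fc2 and projection linear1/linear2 layer (plus the patch-embedding Conv2d) is now wrapped in a peft lora.Linear or lora.Conv2d with rank-8 adapters and 0.05 LoRA dropout, while the base weights stay untouched. The sketch below shows one way such a PeftModelForFeatureExtraction can be built; r=8, lora_dropout=0.05 and the target-module list are read off the dump, whereas lora_alpha is an assumption because it is not part of the module repr.

from peft import LoraConfig, TaskType, get_peft_model

lora_cfg = LoraConfig(
    task_type=TaskType.FEATURE_EXTRACTION,  # yields a PeftModelForFeatureExtraction
    r=8,                                    # adapter rank seen in lora_A/lora_B above
    lora_alpha=16,                          # assumption: alpha is not printed in the repr
    lora_dropout=0.05,
    bias="none",
    # "proj" matches both attn.proj (Linear) and patch_embed.proj (Conv2d),
    # which is why the dump shows a lora.Conv2d on the patch embedding.
    target_modules=["qkv", "proj", "fc1", "fc2", "linear1", "linear2"],
)

# audio_encoder is the AudioEncoder instance from the previous snippet.
lora_encoder = get_peft_model(audio_encoder, lora_cfg)
lora_encoder.print_trainable_parameters()  # only the LoRA adapters are trainable
print(lora_encoder)                        # reproduces a dump like the one above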
Vaani/Img_Audio_Alignment/_1_CLAP-Audio-Encoder.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
Vaani/Img_Audio_Alignment/_2_Train.py ADDED
@@ -0,0 +1,1495 @@
1
+ # ==================================================================
2
+ # L A T E N T D I F F U S I O N M O D E L
3
+ # ==================================================================
4
+ # Author : Ashish Kumar Uchadiya
5
+ # Created : May 11, 2025
6
+ # Description: This script implements the training of a VQ-VAE model for
7
+ # image reconstruction, integrated with Latent Diffusion Models (LDMs) and
8
+ # audio conditioning. The VQ-VAE maps images to a discrete latent space,
9
+ # which is then modeled by the LDM for learning a diffusion process over the
10
+ # compressed representation. Audio features are used as conditioning inputs
11
+ # to guide the generation process. The training minimizes a combination of
12
+ # LPIPS (Learned Perceptual Image Patch Similarity) loss for perceptual
13
+ # fidelity and PatchGAN loss to enforce local realism. This setup enables
14
+ # efficient and semantically-aware generation of high-quality images driven
15
+ # by audio cues.
16
+ # ==================================================================
17
+ # I M P O R T S
18
+ # ==================================================================
19
+ from __future__ import annotations
20
+ import warnings
21
+ warnings.filterwarnings("ignore")
22
+
23
+ import os
24
+ import sys
25
+ import math
26
+ import random
27
+ import collections
28
+ import collections.abc
29
+ import re
30
+ from itertools import repeat
31
+ from pathlib import Path
32
+ from typing import Optional, Tuple, Union, List, Dict
33
+
34
+ import requests
35
+ from PIL import Image
36
+ import numpy as np
37
+ import torch
38
+ import torch.nn.functional as F
39
+ from torch import nn
40
+ from torch.nn.init import _calculate_fan_in_and_fan_out
41
+ import torch.utils.checkpoint as checkpoint
42
+
43
+ os.environ["CUDA_VISIBLE_DEVICES"] = "0"
44
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
45
+ print(f"Using device: {device}")
46
+
47
+ import torchaudio
48
+ import torchaudio.transforms as T
49
+ from torchlibrosa.stft import Spectrogram, LogmelFilterBank
50
+ from torchlibrosa.augmentation import SpecAugmentation
51
+
52
+ from transformers import AutoModel, AutoTokenizer, logging
53
+ from huggingface_hub.file_download import hf_hub_download
55
+ from peft import get_peft_config, get_peft_model
56
+ from transformers import CLIPVisionModel, AutoProcessor
57
+
58
+
59
+
60
+ # ==================================================================
61
+ # H T S - A T
62
+ # ==================================================================
63
+ class HTSATConfig:
64
+ # Ke Chen
65
66
+ # HTS-AT: A HIERARCHICAL TOKEN-SEMANTIC AUDIO TRANSFORMER FOR SOUND CLASSIFICATION AND DETECTION
67
+ # The configuration for training the model
68
+
69
+ exp_name = "exp_htsat_pretrain" # the saved ckpt prefix name of the model
70
+ workspace = "/home/kechen/Research/HTSAT" # the folder of your code
71
+ dataset_path = "/home/Research/audioset" # the dataset path
72
+ desed_folder = "/home/Research/DESED" # the desed file
73
+
74
+ dataset_type = "audioset" # "audioset" "esc-50" "scv2"
75
+ index_type = "full_train" # only works for audioset
76
+ balanced_data = True # only works for audioset
77
+
78
+ loss_type = "clip_bce" #
79
+ # AudioSet & SCV2: "clip_bce" | ESC-50: "clip_ce"
80
+
81
+ # trained from a checkpoint, or evaluate a single model
82
+ resume_checkpoint = None
83
+ # "/home/Research/model_backup/AudioSet/HTSAT_AudioSet_Saved_1.ckpt"
84
+
85
+ esc_fold = 0 # just for esc dataset, select the fold you need for evaluation and (+1) validation
86
+
87
+
88
+ debug = False
89
+
90
+ random_seed = 970131 # 19970318 970131 12412 127777 1009 34047
91
+ batch_size = 32 * 4 # batch size per GPU x GPU number , default is 32 x 4 = 128
92
+ learning_rate = 1e-3 # 1e-4 also workable
93
+ max_epoch = 100
94
+ num_workers = 3
95
+
96
+ lr_scheduler_epoch = [10,20,30]
97
+ lr_rate = [0.02, 0.05, 0.1]
98
+
99
+ # these data preparation optimizations do not bring many improvements, so deprecated
100
+ enable_token_label = False # token label
101
+ class_map_path = "class_hier_map.npy"
102
+ class_filter = None
103
+ retrieval_index = [15382, 9202, 130, 17618, 17157, 17516, 16356, 6165, 13992, 9238, 5550, 5733, 1914, 1600, 3450, 13735, 11108, 3762,
104
+ 9840, 11318, 8131, 4429, 16748, 4992, 16783, 12691, 4945, 8779, 2805, 9418, 2797, 14357, 5603, 212, 3852, 12666, 1338, 10269, 2388, 8260, 4293, 14454, 7677, 11253, 5060, 14938, 8840, 4542, 2627, 16336, 8992, 15496, 11140, 446, 6126, 10691, 8624, 10127, 9068, 16710, 10155, 14358, 7567, 5695, 2354, 8057, 17635, 133, 16183, 14535, 7248, 4560, 14429, 2463, 10773, 113, 2462, 9223, 4929, 14274, 4716, 17307, 4617, 2132, 11083, 1039, 1403, 9621, 13936, 2229, 2875, 17840, 9359, 13311, 9790, 13288, 4750, 17052, 8260, 14900]
105
+ token_label_range = [0.2,0.6]
106
+ enable_time_shift = False # shift time
107
+ enable_label_enhance = False # enhance hierarchical label
108
+ enable_repeat_mode = False # repeat the spectrogram / reshape the spectrogram
109
+
110
+
111
+
112
+ # for model's design
113
+ enable_tscam = True # enable the token-semantic layer
114
+
115
+ # for signal processing
116
+ sample_rate = 32000 # 16000 for scv2, 32000 for audioset and esc-50
117
+ clip_samples = sample_rate * 10 # audio_set 10-sec clip
118
+ window_size = 1024
119
+ hop_size = 320 # 160 for scv2, 320 for audioset and esc-50
120
+ mel_bins = 64
121
+ fmin = 50
122
+ fmax = 14000
123
+ shift_max = int(clip_samples * 0.5)
124
+
125
+ # for data collection
126
+ classes_num = 527 # esc: 50 | audioset: 527 | scv2: 35
127
+ patch_size = (25, 4) # deprecated
128
+ crop_size = None # int(clip_samples * 0.5) deprecated
129
+
130
+ # for htsat hyperparameters
131
+ htsat_window_size = 8
132
+ htsat_spec_size = 256
133
+ htsat_patch_size = 4
134
+ htsat_stride = (4, 4)
135
+ htsat_num_head = [4,8,16,32]
136
+ htsat_dim = 96
137
+ htsat_depth = [2,2,6,2]
138
+
139
+ swin_pretrain_path = None
140
+ # "/home/Research/model_backup/pretrain/swin_tiny_c24_patch4_window8_256.pth"
141
+
142
+ # Some Deprecated Optimization in the model design, check the model code for details
143
+ htsat_attn_heatmap = False
144
+ htsat_hier_output = False
145
+ htsat_use_max = False
146
+
147
+
148
+ # for ensemble test
149
+
150
+ ensemble_checkpoints = []
151
+ ensemble_strides = []
152
+
153
+
154
+ # weight average folder
155
+ wa_folder = "/home/version_0/checkpoints/"
156
+ # weight average output filename
157
+ wa_model_path = "HTSAT_AudioSet_Saved_x.ckpt"
158
+
159
+ esm_model_pathes = [
160
+ "/home/Research/model_backup/AudioSet/HTSAT_AudioSet_Saved_1.ckpt",
161
+ "/home/Research/model_backup/AudioSet/HTSAT_AudioSet_Saved_2.ckpt",
162
+ "/home/Research/model_backup/AudioSet/HTSAT_AudioSet_Saved_3.ckpt",
163
+ "/home/Research/model_backup/AudioSet/HTSAT_AudioSet_Saved_4.ckpt",
164
+ "/home/Research/model_backup/AudioSet/HTSAT_AudioSet_Saved_5.ckpt",
165
+ "/home/Research/model_backup/AudioSet/HTSAT_AudioSet_Saved_6.ckpt"
166
+ ]
167
+
168
+ # for framewise localization
169
+ heatmap_dir = "/home/Research/heatmap_output"
170
+ test_file = "htsat-test-ensemble"
171
+ fl_local = False # indicate if we need to use this dataset for the framewise detection
172
+ fl_dataset = "/home/Research/desed/desed_eval.npy"
173
+ fl_class_num = [
174
+ "Speech", "Frying", "Dishes", "Running_water",
175
+ "Blender", "Electric_shaver_toothbrush", "Alarm_bell_ringing",
176
+ "Cat", "Dog", "Vacuum_cleaner"
177
+ ]
178
+
179
+ # map 527 classes into 10 classes
180
+ fl_audioset_mapping = [
181
+ [0,1,2,3,4,5,6,7],
182
+ [366, 367, 368],
183
+ [364],
184
+ [288, 289, 290, 291, 292, 293, 294, 295, 296, 297],
185
+ [369],
186
+ [382],
187
+ [310, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, 401, 402],
188
+ [81, 82, 83, 84, 85],
189
+ [74, 75, 76, 77, 78, 79],
190
+ [377]
191
+ ]
192
+
193
+
194
+
195
+ def _ntuple(n):
196
+ def parse(x):
197
+ if isinstance(x, collections.abc.Iterable):
198
+ return x
199
+ return tuple(repeat(x, n))
200
+ return parse
201
+
202
+ to_1tuple = _ntuple(1)
203
+ to_2tuple = _ntuple(2)
204
+ to_3tuple = _ntuple(3)
205
+ to_4tuple = _ntuple(4)
206
+ to_ntuple = _ntuple
207
+
208
+ def do_mixup(x, mixup_lambda):
209
+ """Mixup x of even indexes (0, 2, 4, ...) with x of odd indexes
210
+ (1, 3, 5, ...).
211
+ Args:
212
+ x: (batch_size * 2, ...)
213
+ mixup_lambda: (batch_size * 2,)
214
+ Returns:
215
+ out: (batch_size, ...)
216
+ """
217
+ out = (x[0 :: 2].transpose(0, -1) * mixup_lambda[0 :: 2] + \
218
+ x[1 :: 2].transpose(0, -1) * mixup_lambda[1 :: 2]).transpose(0, -1)
219
+ return out
220
+
221
+ def interpolate(x, ratio):
222
+ """Interpolate data in time domain. This is used to compensate the
223
+ resolution reduction in downsampling of a CNN.
224
+
225
+ Args:
226
+ x: (batch_size, time_steps, classes_num)
227
+ ratio: int, ratio to interpolate
228
+ Returns:
229
+ upsampled: (batch_size, time_steps * ratio, classes_num)
230
+ """
231
+ (batch_size, time_steps, classes_num) = x.shape
232
+ upsampled = x[:, :, None, :].repeat(1, 1, ratio, 1)
233
+ upsampled = upsampled.reshape(batch_size, time_steps * ratio, classes_num)
234
+ return upsampled
235
+
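+ # For illustration: interpolate(x, ratio) repeats every frame `ratio` times along the
+ # time axis, e.g. interpolate(x, 32) turns a (B, 32, 527) map into a (B, 1024, 527) one;
+ # in forward_features below it is called with ratio = 8 * patch_stride[1].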
236
+
237
+ def drop_path(x, drop_prob: float = 0., training: bool = False):
238
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
239
+ This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
240
+ the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
241
+ See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
242
+ changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
243
+ 'survival rate' as the argument.
244
+ """
245
+ if drop_prob == 0. or not training:
246
+ return x
247
+ keep_prob = 1 - drop_prob
248
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
249
+ random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
250
+ random_tensor.floor_() # binarize
251
+ output = x.div(keep_prob) * random_tensor
252
+ return output
253
+
254
+
255
+ class DropPath(nn.Module):
256
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
257
+ """
258
+ def __init__(self, drop_prob=None):
259
+ super(DropPath, self).__init__()
260
+ self.drop_prob = drop_prob
261
+
262
+ def forward(self, x):
263
+ return drop_path(x, self.drop_prob, self.training)
264
+
265
+ class PatchEmbed(nn.Module):
266
+ """ 2D Image to Patch Embedding
267
+ """
268
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True, patch_stride = 16):
269
+ super().__init__()
270
+ img_size = to_2tuple(img_size)
271
+ patch_size = to_2tuple(patch_size)
272
+ patch_stride = to_2tuple(patch_stride)
273
+ self.img_size = img_size
274
+ self.patch_size = patch_size
275
+ self.patch_stride = patch_stride
276
+ self.grid_size = (img_size[0] // patch_stride[0], img_size[1] // patch_stride[1])
277
+ self.num_patches = self.grid_size[0] * self.grid_size[1]
278
+ self.flatten = flatten
279
+ self.in_chans = in_chans
280
+ self.embed_dim = embed_dim
281
+
282
+ padding = ((patch_size[0] - patch_stride[0]) // 2, (patch_size[1] - patch_stride[1]) // 2)
283
+
284
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_stride, padding=padding)
285
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
286
+
287
+ def forward(self, x):
288
+ B, C, H, W = x.shape
289
+ assert H == self.img_size[0] and W == self.img_size[1], \
290
+ f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
291
+ x = self.proj(x)
292
+ if self.flatten:
293
+ x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
294
+ x = self.norm(x)
295
+ return x
296
+
297
+ class Mlp(nn.Module):
298
+ """ MLP as used in Vision Transformer, MLP-Mixer and related networks
299
+ """
300
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
301
+ super().__init__()
302
+ out_features = out_features or in_features
303
+ hidden_features = hidden_features or in_features
304
+ self.fc1 = nn.Linear(in_features, hidden_features)
305
+ self.act = act_layer()
306
+ self.fc2 = nn.Linear(hidden_features, out_features)
307
+ self.drop = nn.Dropout(drop)
308
+
309
+ def forward(self, x):
310
+ x = self.fc1(x)
311
+ x = self.act(x)
312
+ x = self.drop(x)
313
+ x = self.fc2(x)
314
+ x = self.drop(x)
315
+ return x
316
+
317
+ def _no_grad_trunc_normal_(tensor, mean, std, a, b):
318
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
319
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
320
+ def norm_cdf(x):
321
+ # Computes standard normal cumulative distribution function
322
+ return (1. + math.erf(x / math.sqrt(2.))) / 2.
323
+
324
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
325
+ warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
326
+ "The distribution of values may be incorrect.",
327
+ stacklevel=2)
328
+
329
+ with torch.no_grad():
330
+ # Values are generated by using a truncated uniform distribution and
331
+ # then using the inverse CDF for the normal distribution.
332
+ # Get upper and lower cdf values
333
+ l = norm_cdf((a - mean) / std)
334
+ u = norm_cdf((b - mean) / std)
335
+
336
+ # Uniformly fill tensor with values from [l, u], then translate to
337
+ # [2l-1, 2u-1].
338
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
339
+
340
+ # Use inverse cdf transform for normal distribution to get truncated
341
+ # standard normal
342
+ tensor.erfinv_()
343
+
344
+ # Transform to proper mean, std
345
+ tensor.mul_(std * math.sqrt(2.))
346
+ tensor.add_(mean)
347
+
348
+ # Clamp to ensure it's in the proper range
349
+ tensor.clamp_(min=a, max=b)
350
+ return tensor
351
+
352
+
353
+ def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
354
+ # type: (Tensor, float, float, float, float) -> Tensor
355
+ r"""Fills the input Tensor with values drawn from a truncated
356
+ normal distribution. The values are effectively drawn from the
357
+ normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
358
+ with values outside :math:`[a, b]` redrawn until they are within
359
+ the bounds. The method used for generating the random values works
360
+ best when :math:`a \leq \text{mean} \leq b`.
361
+ Args:
362
+ tensor: an n-dimensional `torch.Tensor`
363
+ mean: the mean of the normal distribution
364
+ std: the standard deviation of the normal distribution
365
+ a: the minimum cutoff value
366
+ b: the maximum cutoff value
367
+ Examples:
368
+ >>> w = torch.empty(3, 5)
369
+ >>> nn.init.trunc_normal_(w)
370
+ """
371
+ return _no_grad_trunc_normal_(tensor, mean, std, a, b)
372
+
373
+
374
+ def variance_scaling_(tensor, scale=1.0, mode='fan_in', distribution='normal'):
375
+ fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
376
+ if mode == 'fan_in':
377
+ denom = fan_in
378
+ elif mode == 'fan_out':
379
+ denom = fan_out
380
+ elif mode == 'fan_avg':
381
+ denom = (fan_in + fan_out) / 2
382
+
383
+ variance = scale / denom
384
+
385
+ if distribution == "truncated_normal":
386
+ # constant is stddev of standard normal truncated to (-2, 2)
387
+ trunc_normal_(tensor, std=math.sqrt(variance) / .87962566103423978)
388
+ elif distribution == "normal":
389
+ tensor.normal_(std=math.sqrt(variance))
390
+ elif distribution == "uniform":
391
+ bound = math.sqrt(3 * variance)
392
+ tensor.uniform_(-bound, bound)
393
+ else:
394
+ raise ValueError(f"invalid distribution {distribution}")
395
+
396
+
397
+ def lecun_normal_(tensor):
398
+ variance_scaling_(tensor, mode='fan_in', distribution='truncated_normal')
399
+
400
+
401
+ # the code below is based on and adapted from https://github.com/microsoft/Swin-Transformer
402
+ # Swin Transformer for Computer Vision: https://arxiv.org/pdf/2103.14030.pdf
403
+
404
+ def window_partition(x, window_size):
405
+ """
406
+ Args:
407
+ x: (B, H, W, C)
408
+ window_size (int): window size
409
+ Returns:
410
+ windows: (num_windows*B, window_size, window_size, C)
411
+ """
412
+ B, H, W, C = x.shape
413
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
414
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
415
+ return windows
416
+
417
+
418
+ def window_reverse(windows, window_size, H, W):
419
+ """
420
+ Args:
421
+ windows: (num_windows*B, window_size, window_size, C)
422
+ window_size (int): Window size
423
+ H (int): Height of image
424
+ W (int): Width of image
425
+ Returns:
426
+ x: (B, H, W, C)
427
+ """
428
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
429
+ x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
430
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
431
+ return x
432
+
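+ # For illustration: window_partition and window_reverse are inverses of each other.
+ # With window_size=8, an input of shape (B, 64, 64, C) is split into 64 windows per
+ # sample, i.e. (B*64, 8, 8, C), and window_reverse(windows, 8, 64, 64) restores
+ # the original (B, 64, 64, C) layout.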
433
+
434
+ class WindowAttention(nn.Module):
435
+ r""" Window based multi-head self attention (W-MSA) module with relative position bias.
436
+ It supports both shifted and non-shifted windows.
437
+ Args:
438
+ dim (int): Number of input channels.
439
+ window_size (tuple[int]): The height and width of the window.
440
+ num_heads (int): Number of attention heads.
441
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
442
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
443
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
444
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
445
+ """
446
+
447
+ def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
448
+
449
+ super().__init__()
450
+ self.dim = dim
451
+ self.window_size = window_size # Wh, Ww
452
+ self.num_heads = num_heads
453
+ head_dim = dim // num_heads
454
+ self.scale = qk_scale or head_dim ** -0.5
455
+
456
+ # define a parameter table of relative position bias
457
+ self.relative_position_bias_table = nn.Parameter(
458
+ torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
459
+
460
+ # get pair-wise relative position index for each token inside the window
461
+ coords_h = torch.arange(self.window_size[0])
462
+ coords_w = torch.arange(self.window_size[1])
463
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
464
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
465
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
466
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
467
+ relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
468
+ relative_coords[:, :, 1] += self.window_size[1] - 1
469
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
470
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
471
+ self.register_buffer("relative_position_index", relative_position_index)
472
+
473
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
474
+ self.attn_drop = nn.Dropout(attn_drop)
475
+ self.proj = nn.Linear(dim, dim)
476
+ self.proj_drop = nn.Dropout(proj_drop)
477
+
478
+ trunc_normal_(self.relative_position_bias_table, std=.02)
479
+ self.softmax = nn.Softmax(dim=-1)
480
+
481
+ def forward(self, x, mask=None):
482
+ """
483
+ Args:
484
+ x: input features with shape of (num_windows*B, N, C)
485
+ mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
486
+ """
487
+ B_, N, C = x.shape
488
+ qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
489
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
490
+
491
+ q = q * self.scale
492
+ attn = (q @ k.transpose(-2, -1))
493
+
494
+ relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
495
+ self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
496
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
497
+ attn = attn + relative_position_bias.unsqueeze(0)
498
+
499
+ if mask is not None:
500
+ nW = mask.shape[0]
501
+ attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
502
+ attn = attn.view(-1, self.num_heads, N, N)
503
+ attn = self.softmax(attn)
504
+ else:
505
+ attn = self.softmax(attn)
506
+
507
+ attn = self.attn_drop(attn)
508
+
509
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
510
+ x = self.proj(x)
511
+ x = self.proj_drop(x)
512
+ return x, attn
513
+
514
+ def extra_repr(self):
515
+ return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
516
+
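+ # For illustration: with window_size=(8, 8) the relative_position_bias_table above holds
+ # (2*8-1) * (2*8-1) = 225 learnable biases per head, and relative_position_index selects
+ # one of them for each of the 64*64 token pairs inside a window.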
517
+
518
+ # The model is built on the Swin Transformer block, so pretrained Swin Transformer weights can be reused
519
+ class SwinTransformerBlock(nn.Module):
520
+ r""" Swin Transformer Block.
521
+ Args:
522
+ dim (int): Number of input channels.
523
+ input_resolution (tuple[int]): Input resolution.
524
+ num_heads (int): Number of attention heads.
525
+ window_size (int): Window size.
526
+ shift_size (int): Shift size for SW-MSA.
527
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
528
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
529
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
530
+ drop (float, optional): Dropout rate. Default: 0.0
531
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
532
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
533
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
534
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
535
+ """
536
+
537
+ def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
538
+ mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
539
+ act_layer=nn.GELU, norm_layer=nn.LayerNorm, norm_before_mlp='ln'):
540
+ super().__init__()
541
+ self.dim = dim
542
+ self.input_resolution = input_resolution
543
+ self.num_heads = num_heads
544
+ self.window_size = window_size
545
+ self.shift_size = shift_size
546
+ self.mlp_ratio = mlp_ratio
547
+ self.norm_before_mlp = norm_before_mlp
548
+ if min(self.input_resolution) <= self.window_size:
549
+ # if window size is larger than input resolution, we don't partition windows
550
+ self.shift_size = 0
551
+ self.window_size = min(self.input_resolution)
552
+ assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
553
+
554
+ self.norm1 = norm_layer(dim)
555
+ self.attn = WindowAttention(
556
+ dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
557
+ qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
558
+
559
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
560
+ if self.norm_before_mlp == 'ln':
561
+ self.norm2 = nn.LayerNorm(dim)
562
+ elif self.norm_before_mlp == 'bn':
563
+ self.norm2 = lambda x: nn.BatchNorm1d(dim)(x.transpose(1, 2)).transpose(1, 2)
564
+ else:
565
+ raise NotImplementedError
566
+ mlp_hidden_dim = int(dim * mlp_ratio)
567
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
568
+
569
+ if self.shift_size > 0:
570
+ # calculate attention mask for SW-MSA
571
+ H, W = self.input_resolution
572
+ img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
573
+ h_slices = (slice(0, -self.window_size),
574
+ slice(-self.window_size, -self.shift_size),
575
+ slice(-self.shift_size, None))
576
+ w_slices = (slice(0, -self.window_size),
577
+ slice(-self.window_size, -self.shift_size),
578
+ slice(-self.shift_size, None))
579
+ cnt = 0
580
+ for h in h_slices:
581
+ for w in w_slices:
582
+ img_mask[:, h, w, :] = cnt
583
+ cnt += 1
584
+
585
+ mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
586
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
587
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
588
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
589
+ else:
590
+ attn_mask = None
591
+
592
+ self.register_buffer("attn_mask", attn_mask)
593
+
594
+ def forward(self, x):
595
+ # pdb.set_trace()
596
+ H, W = self.input_resolution
597
+ # print("H: ", H)
598
+ # print("W: ", W)
599
+ # pdb.set_trace()
600
+ B, L, C = x.shape
601
+ # assert L == H * W, "input feature has wrong size"
602
+
603
+ shortcut = x
604
+ x = self.norm1(x)
605
+ x = x.view(B, H, W, C)
606
+
607
+ # cyclic shift
608
+ if self.shift_size > 0:
609
+ shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
610
+ else:
611
+ shifted_x = x
612
+
613
+ # partition windows
614
+ x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
615
+ x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
616
+
617
+ # W-MSA/SW-MSA
618
+ attn_windows, attn = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
619
+
620
+ # merge windows
621
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
622
+ shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
623
+
624
+ # reverse cyclic shift
625
+ if self.shift_size > 0:
626
+ x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
627
+ else:
628
+ x = shifted_x
629
+ x = x.view(B, H * W, C)
630
+
631
+ # FFN
632
+ x = shortcut + self.drop_path(x)
633
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
634
+
635
+ return x, attn
636
+
637
+ def extra_repr(self):
638
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
639
+ f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
640
+
641
+
642
+
643
+ class PatchMerging(nn.Module):
644
+ r""" Patch Merging Layer.
645
+ Args:
646
+ input_resolution (tuple[int]): Resolution of input feature.
647
+ dim (int): Number of input channels.
648
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
649
+ """
650
+
651
+ def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
652
+ super().__init__()
653
+ self.input_resolution = input_resolution
654
+ self.dim = dim
655
+ self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
656
+ self.norm = norm_layer(4 * dim)
657
+
658
+ def forward(self, x):
659
+ """
660
+ x: B, H*W, C
661
+ """
662
+ H, W = self.input_resolution
663
+ B, L, C = x.shape
664
+ assert L == H * W, "input feature has wrong size"
665
+ assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
666
+
667
+ x = x.view(B, H, W, C)
668
+
669
+ x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
670
+ x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
671
+ x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
672
+ x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
673
+ x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
674
+ x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
675
+
676
+ x = self.norm(x)
677
+ x = self.reduction(x)
678
+
679
+ return x
680
+
681
+ def extra_repr(self):
682
+ return f"input_resolution={self.input_resolution}, dim={self.dim}"
683
+
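+ # For illustration: PatchMerging halves each spatial side and doubles the channels,
+ # e.g. (B, 64*64, 96) -> (B, 32*32, 192); applied after the first three stages it grows
+ # embed_dim=96 to num_features = 96 * 2**3 = 768.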
684
+
685
+ class BasicLayer(nn.Module):
686
+ """ A basic Swin Transformer layer for one stage.
687
+ Args:
688
+ dim (int): Number of input channels.
689
+ input_resolution (tuple[int]): Input resolution.
690
+ depth (int): Number of blocks.
691
+ num_heads (int): Number of attention heads.
692
+ window_size (int): Local window size.
693
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
694
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
695
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
696
+ drop (float, optional): Dropout rate. Default: 0.0
697
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
698
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
699
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
700
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
701
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
702
+ """
703
+
704
+ def __init__(self, dim, input_resolution, depth, num_heads, window_size,
705
+ mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
706
+ drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
707
+ norm_before_mlp='ln'):
708
+
709
+ super().__init__()
710
+ self.dim = dim
711
+ self.input_resolution = input_resolution
712
+ self.depth = depth
713
+ self.use_checkpoint = use_checkpoint
714
+
715
+ # build blocks
716
+ self.blocks = nn.ModuleList([
717
+ SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
718
+ num_heads=num_heads, window_size=window_size,
719
+ shift_size=0 if (i % 2 == 0) else window_size // 2,
720
+ mlp_ratio=mlp_ratio,
721
+ qkv_bias=qkv_bias, qk_scale=qk_scale,
722
+ drop=drop, attn_drop=attn_drop,
723
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
724
+ norm_layer=norm_layer, norm_before_mlp=norm_before_mlp)
725
+ for i in range(depth)])
726
+
727
+ # patch merging layer
728
+ if downsample is not None:
729
+ self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
730
+ else:
731
+ self.downsample = None
732
+
733
+ def forward(self, x):
734
+ attns = []
735
+ for blk in self.blocks:
736
+ if self.use_checkpoint:
737
+ x = checkpoint.checkpoint(blk, x)
738
+ else:
739
+ x, attn = blk(x)
740
+ if not self.training:
741
+ attns.append(attn.unsqueeze(0))
742
+ if self.downsample is not None:
743
+ x = self.downsample(x)
744
+ if not self.training:
745
+ attn = torch.cat(attns, dim = 0)
746
+ attn = torch.mean(attn, dim = 0)
747
+ return x, attn
748
+
749
+ def extra_repr(self):
750
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
751
+
752
+
753
+ # The Core of HTSAT
754
+ class HTSAT_Swin_Transformer(nn.Module):
755
+ r"""HTSAT based on the Swin Transformer
756
+ Args:
757
+ spec_size (int | tuple(int)): Input Spectrogram size. Default 256
758
+ patch_size (int | tuple(int)): Patch size. Default: 4
759
+ patch_stride (int | tuple(int)): Patch stride for the frequency and time axes. Default: 4
760
+ in_chans (int): Number of input image channels. Default: 1 (mono)
761
+ num_classes (int): Number of classes for classification head. Default: 527
762
+ embed_dim (int): Patch embedding dimension. Default: 96
763
+ depths (tuple(int)): Depth of each HTSAT-Swin Transformer layer.
764
+ num_heads (tuple(int)): Number of attention heads in different layers.
765
+ window_size (int): Window size. Default: 8
766
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
767
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
768
+ qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
769
+ drop_rate (float): Dropout rate. Default: 0
770
+ attn_drop_rate (float): Attention dropout rate. Default: 0
771
+ drop_path_rate (float): Stochastic depth rate. Default: 0.1
772
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
773
+ ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
774
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True
775
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
776
+ config (module): The configuration Module from config.py (HTSATConfig Class)
777
+ """
778
+
779
+ def __init__(self, spec_size=256, patch_size=4, patch_stride=(4,4),
780
+ in_chans=1, num_classes=527,
781
+ embed_dim=96, depths=[2, 2, 6, 2], num_heads=[4, 8, 16, 32],
782
+ window_size=8, mlp_ratio=4., qkv_bias=True, qk_scale=None,
783
+ drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
784
+ norm_layer=nn.LayerNorm,
785
+ ape=False, patch_norm=True,
786
+ use_checkpoint=False, norm_before_mlp='ln', config = None, **kwargs):
787
+ super(HTSAT_Swin_Transformer, self).__init__()
788
+
789
+ self.config = config
790
+ self.spec_size = spec_size
791
+ self.patch_stride = patch_stride
792
+ self.patch_size = patch_size
793
+ self.window_size = window_size
794
+ self.embed_dim = embed_dim
795
+ self.depths = depths
796
+ self.ape = ape
797
+ self.in_chans = in_chans
798
+ self.num_classes = num_classes
799
+ self.num_heads = num_heads
800
+ self.num_layers = len(self.depths)
801
+ self.num_features = int(self.embed_dim * 2 ** (self.num_layers - 1))
802
+
803
+ self.drop_rate = drop_rate
804
+ self.attn_drop_rate = attn_drop_rate
805
+ self.drop_path_rate = drop_path_rate
806
+
807
+ self.qkv_bias = qkv_bias
808
+ self.qk_scale = None
809
+
810
+ self.patch_norm = patch_norm
811
+ self.norm_layer = norm_layer if self.patch_norm else None
812
+ self.norm_before_mlp = norm_before_mlp
813
+ self.mlp_ratio = mlp_ratio
814
+
815
+ self.use_checkpoint = use_checkpoint
816
+
817
+ # process mel-spec ; used only once
818
+ self.freq_ratio = self.spec_size // self.config.mel_bins
819
+ window = 'hann'
820
+ center = True
821
+ pad_mode = 'reflect'
822
+ ref = 1.0
823
+ amin = 1e-10
824
+ top_db = None
825
+ self.interpolate_ratio = 32 # Downsampled ratio
826
+ # Spectrogram extractor
827
+ self.spectrogram_extractor = Spectrogram(n_fft=config.window_size, hop_length=config.hop_size,
828
+ win_length=config.window_size, window=window, center=center, pad_mode=pad_mode,
829
+ freeze_parameters=True)
830
+ # Logmel feature extractor
831
+ self.logmel_extractor = LogmelFilterBank(sr=config.sample_rate, n_fft=config.window_size,
832
+ n_mels=config.mel_bins, fmin=config.fmin, fmax=config.fmax, ref=ref, amin=amin, top_db=top_db,
833
+ freeze_parameters=True)
834
+ # Spec augmenter
835
+ self.spec_augmenter = SpecAugmentation(time_drop_width=64, time_stripes_num=2,
836
+ freq_drop_width=8, freq_stripes_num=2) # 2 2
837
+ self.bn0 = nn.BatchNorm2d(self.config.mel_bins)
838
+
839
+
840
+ # split spectrogram into non-overlapping patches
841
+ self.patch_embed = PatchEmbed(
842
+ img_size=self.spec_size, patch_size=self.patch_size, in_chans=self.in_chans,
843
+ embed_dim=self.embed_dim, norm_layer=self.norm_layer, patch_stride = patch_stride)
844
+
845
+ num_patches = self.patch_embed.num_patches
846
+ patches_resolution = self.patch_embed.grid_size
847
+ self.patches_resolution = patches_resolution
848
+
849
+ # absolute position embedding
850
+ if self.ape:
851
+ self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, self.embed_dim))
852
+ trunc_normal_(self.absolute_pos_embed, std=.02)
853
+
854
+ self.pos_drop = nn.Dropout(p=self.drop_rate)
855
+
856
+ # stochastic depth
857
+ dpr = [x.item() for x in torch.linspace(0, self.drop_path_rate, sum(self.depths))] # stochastic depth decay rule
858
+
859
+ # build layers
860
+ self.layers = nn.ModuleList()
861
+ for i_layer in range(self.num_layers):
862
+ layer = BasicLayer(dim=int(self.embed_dim * 2 ** i_layer),
863
+ input_resolution=(patches_resolution[0] // (2 ** i_layer),
864
+ patches_resolution[1] // (2 ** i_layer)),
865
+ depth=self.depths[i_layer],
866
+ num_heads=self.num_heads[i_layer],
867
+ window_size=self.window_size,
868
+ mlp_ratio=self.mlp_ratio,
869
+ qkv_bias=self.qkv_bias, qk_scale=self.qk_scale,
870
+ drop=self.drop_rate, attn_drop=self.attn_drop_rate,
871
+ drop_path=dpr[sum(self.depths[:i_layer]):sum(self.depths[:i_layer + 1])],
872
+ norm_layer=self.norm_layer,
873
+ downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
874
+ use_checkpoint=use_checkpoint,
875
+ norm_before_mlp=self.norm_before_mlp)
876
+ self.layers.append(layer)
877
+
878
+ self.norm = self.norm_layer(self.num_features)
879
+ self.avgpool = nn.AdaptiveAvgPool1d(1)
880
+ self.maxpool = nn.AdaptiveMaxPool1d(1)
881
+
882
+ if self.config.enable_tscam:
883
+ SF = self.spec_size // (2 ** (len(self.depths) - 1)) // self.patch_stride[0] // self.freq_ratio
884
+ self.tscam_conv = nn.Conv2d(
885
+ in_channels = self.num_features,
886
+ out_channels = self.num_classes,
887
+ kernel_size = (SF,3),
888
+ padding = (0,1)
889
+ )
890
+ self.head = nn.Linear(num_classes, num_classes)
891
+ else:
892
+ self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
893
+
894
+ self.apply(self._init_weights)
895
+
896
+ def _init_weights(self, m):
897
+ if isinstance(m, nn.Linear):
898
+ trunc_normal_(m.weight, std=.02)
899
+ if isinstance(m, nn.Linear) and m.bias is not None:
900
+ nn.init.constant_(m.bias, 0)
901
+ elif isinstance(m, nn.LayerNorm):
902
+ nn.init.constant_(m.bias, 0)
903
+ nn.init.constant_(m.weight, 1.0)
904
+
905
+ @torch.jit.ignore
906
+ def no_weight_decay(self):
907
+ return {'absolute_pos_embed'}
908
+
909
+ @torch.jit.ignore
910
+ def no_weight_decay_keywords(self):
911
+ return {'relative_position_bias_table'}
912
+
913
+ def forward_features(self, x):
914
+ frames_num = x.shape[2]
915
+ x = self.patch_embed(x)
916
+ if self.ape:
917
+ x = x + self.absolute_pos_embed
918
+ x = self.pos_drop(x)
919
+ for i, layer in enumerate(self.layers):
920
+ x, attn = layer(x)
921
+
922
+ if self.config.enable_tscam:
923
+ # for x
924
+ x = self.norm(x)
925
+ B, N, C = x.shape
926
+ SF = frames_num // (2 ** (len(self.depths) - 1)) // self.patch_stride[0]
927
+ ST = frames_num // (2 ** (len(self.depths) - 1)) // self.patch_stride[1]
928
+ x = x.permute(0,2,1).contiguous().reshape(B, C, SF, ST)
929
+ B, C, F, T = x.shape
930
+ # group 2D CNN
931
+ c_freq_bin = F // self.freq_ratio
932
+ x = x.reshape(B, C, F // c_freq_bin, c_freq_bin, T)
933
+ x = x.permute(0,1,3,2,4).contiguous().reshape(B, C, c_freq_bin, -1)
934
+
935
+ # get latent_output
936
+ latent_output = self.avgpool(torch.flatten(x,2))
937
+ latent_output = torch.flatten(latent_output, 1)
938
+
939
+ # display the attention map, if needed
940
+ if self.config.htsat_attn_heatmap:
941
+ # for attn
942
+ attn = torch.mean(attn, dim = 1)
943
+ attn = torch.mean(attn, dim = 1)
944
+ attn = attn.reshape(B, SF, ST)
945
+ c_freq_bin = SF // self.freq_ratio
946
+ attn = attn.reshape(B, SF // c_freq_bin, c_freq_bin, ST)
947
+ attn = attn.permute(0,2,1,3).contiguous().reshape(B, c_freq_bin, -1)
948
+ attn = attn.mean(dim = 1)
949
+ attn_max = torch.max(attn, dim = 1, keepdim = True)[0]
950
+ attn_min = torch.min(attn, dim = 1, keepdim = True)[0]
951
+ attn = ((attn * 0.15) + (attn_max * 0.85 - attn_min)) / (attn_max - attn_min)
952
+ attn = attn.unsqueeze(dim = 2)
953
+
954
+ x = self.tscam_conv(x)
955
+ x = torch.flatten(x, 2) # B, C, T
956
+
957
+ if self.config.htsat_attn_heatmap:
958
+ fpx = interpolate(torch.sigmoid(x).permute(0,2,1).contiguous() * attn, 8 * self.patch_stride[1])
959
+ else:
960
+ fpx = interpolate(torch.sigmoid(x).permute(0,2,1).contiguous(), 8 * self.patch_stride[1])
961
+
962
+ x = self.avgpool(x)
963
+ x = torch.flatten(x, 1)
964
+
965
+ if self.config.loss_type == "clip_ce":
966
+ output_dict = {
967
+ 'framewise_output': fpx, # already sigmoided
968
+ 'clipwise_output': x,
969
+ 'latent_output': latent_output
970
+ }
971
+ else:
972
+ output_dict = {
973
+ 'framewise_output': fpx, # already sigmoided
974
+ 'clipwise_output': torch.sigmoid(x),
975
+ 'latent_output': latent_output
976
+ }
977
+
978
+ else:
979
+ x = self.norm(x) # B N C
980
+ B, N, C = x.shape
981
+
982
+ fpx = x.permute(0,2,1).contiguous().reshape(B, C, frames_num // (2 ** (len(self.depths) + 1)), frames_num // (2 ** (len(self.depths) + 1)) )
983
+ B, C, F, T = fpx.shape
984
+ c_freq_bin = F // self.freq_ratio
985
+ fpx = fpx.reshape(B, C, F // c_freq_bin, c_freq_bin, T)
986
+ fpx = fpx.permute(0,1,3,2,4).contiguous().reshape(B, C, c_freq_bin, -1)
987
+ fpx = torch.sum(fpx, dim = 2)
988
+ fpx = interpolate(fpx.permute(0,2,1).contiguous(), 8 * self.patch_stride[1])
989
+ x = self.avgpool(x.transpose(1, 2)) # B C 1
990
+ x = torch.flatten(x, 1)
991
+ if self.num_classes > 0:
992
+ x = self.head(x)
993
+ fpx = self.head(fpx)
994
+ output_dict = {'framewise_output': torch.sigmoid(fpx),
995
+ 'clipwise_output': torch.sigmoid(x)}
996
+ return output_dict
997
+
998
+ def crop_wav(self, x, crop_size, spe_pos = None):
999
+ time_steps = x.shape[2]
1000
+ tx = torch.zeros(x.shape[0], x.shape[1], crop_size, x.shape[3]).to(x.device)
1001
+ for i in range(len(x)):
1002
+ if spe_pos is None:
1003
+ crop_pos = random.randint(0, time_steps - crop_size - 1)
1004
+ else:
1005
+ crop_pos = spe_pos
1006
+ tx[i][0] = x[i, 0, crop_pos:crop_pos + crop_size,:]
1007
+ return tx
1008
+
1009
+ # Reshape the waveform to an image size, if you want to use the pretrained swin transformer model
1010
+ def reshape_wav2img(self, x):
1011
+ B, C, T, F = x.shape
1012
+ target_T = int(self.spec_size * self.freq_ratio)
1013
+ target_F = self.spec_size // self.freq_ratio
1014
+ assert T <= target_T and F <= target_F, "the wav size should be less than or equal to the swin input size"
1015
+ # to avoid bicubic zero error
1016
+ if T < target_T:
1017
+ x = nn.functional.interpolate(x, (target_T, x.shape[3]), mode="bicubic", align_corners=True)
1018
+ if F < target_F:
1019
+ x = nn.functional.interpolate(x, (x.shape[2], target_F), mode="bicubic", align_corners=True)
1020
+ x = x.permute(0,1,3,2).contiguous()
1021
+ x = x.reshape(x.shape[0], x.shape[1], x.shape[2], self.freq_ratio, x.shape[3] // self.freq_ratio)
1022
+ # print(x.shape)
1023
+ x = x.permute(0,1,3,2,4).contiguous()
1024
+ x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3], x.shape[4])
1025
+ return x
1026
+
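+ # For illustration, with the default HTSATConfig (htsat_spec_size=256, mel_bins=64,
+ # hence freq_ratio=4): a padded log-mel input of shape (B, 1, 1024, 64) is folded by
+ # reshape_wav2img into a (B, 1, 256, 256) "image" for the Swin backbone.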
1027
+ # Repeat the waveform to an image size, if you want to use the pretrained swin transformer model
1028
+ def repeat_wat2img(self, x, cur_pos):
1029
+ B, C, T, F = x.shape
1030
+ target_T = int(self.spec_size * self.freq_ratio)
1031
+ target_F = self.spec_size // self.freq_ratio
1032
+ assert T <= target_T and F <= target_F, "the wav size should be less than or equal to the swin input size"
1033
+ # to avoid bicubic zero error
1034
+ if T < target_T:
1035
+ x = nn.functional.interpolate(x, (target_T, x.shape[3]), mode="bicubic", align_corners=True)
1036
+ if F < target_F:
1037
+ x = nn.functional.interpolate(x, (x.shape[2], target_F), mode="bicubic", align_corners=True)
1038
+ x = x.permute(0,1,3,2).contiguous() # B C F T
1039
+ x = x[:,:,:,cur_pos:cur_pos + self.spec_size]
1040
+ x = x.repeat(repeats = (1,1,4,1))
1041
+ return x
1042
+
1043
+ def forward(self, x: torch.Tensor, mixup_lambda = None, infer_mode = False):# out_feat_keys: List[str] = None):
1044
+ x = self.spectrogram_extractor(x) # (batch_size, 1, time_steps, freq_bins)
1045
+ x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins)
1046
+
1047
+
1048
+ x = x.transpose(1, 3)
1049
+ x = self.bn0(x)
1050
+ x = x.transpose(1, 3)
1051
+ if self.training:
1052
+ x = self.spec_augmenter(x)
1053
+ if self.training and mixup_lambda is not None:
1054
+ x = do_mixup(x, mixup_lambda)
1055
+
1056
+ if infer_mode:
1057
+ # in infer mode, we need to handle audio inputs of different lengths
1058
+ frame_num = x.shape[2]
1059
+ target_T = int(self.spec_size * self.freq_ratio)
1060
+ repeat_ratio = math.floor(target_T / frame_num)
1061
+ x = x.repeat(repeats=(1,1,repeat_ratio,1))
1062
+ x = self.reshape_wav2img(x)
1063
+ output_dict = self.forward_features(x)
1064
+ elif self.config.enable_repeat_mode:
1065
+ if self.training:
1066
+ cur_pos = random.randint(0, (self.freq_ratio - 1) * self.spec_size - 1)
1067
+ x = self.repeat_wat2img(x, cur_pos)
1068
+ output_dict = self.forward_features(x)
1069
+ else:
1070
+ output_dicts = []
1071
+ for cur_pos in range(0, (self.freq_ratio - 1) * self.spec_size + 1, self.spec_size):
1072
+ tx = x.clone()
1073
+ tx = self.repeat_wat2img(tx, cur_pos)
1074
+ output_dicts.append(self.forward_features(tx))
1075
+ clipwise_output = torch.zeros_like(output_dicts[0]["clipwise_output"]).float().to(x.device)
1076
+ framewise_output = torch.zeros_like(output_dicts[0]["framewise_output"]).float().to(x.device)
1077
+ for d in output_dicts:
1078
+ clipwise_output += d["clipwise_output"]
1079
+ framewise_output += d["framewise_output"]
1080
+ clipwise_output = clipwise_output / len(output_dicts)
1081
+ framewise_output = framewise_output / len(output_dicts)
1082
+
1083
+ output_dict = {
1084
+ 'framewise_output': framewise_output,
1085
+ 'clipwise_output': clipwise_output
1086
+ }
1087
+ else:
1088
+ if x.shape[2] > self.freq_ratio * self.spec_size:
1089
+ if self.training:
1090
+ x = self.crop_wav(x, crop_size=self.freq_ratio * self.spec_size)
1091
+ x = self.reshape_wav2img(x)
1092
+ output_dict = self.forward_features(x)
1093
+ else:
1094
+ # Change: Hard code here
1095
+ overlap_size = 344 #(x.shape[2] - 1) // 4
1096
+ output_dicts = []
1097
+ crop_size = 689 #(x.shape[2] - 1) // 2
1098
+ for cur_pos in range(0, x.shape[2] - crop_size - 1, overlap_size):
1099
+ tx = self.crop_wav(x, crop_size = crop_size, spe_pos = cur_pos)
1100
+ tx = self.reshape_wav2img(tx)
1101
+ output_dicts.append(self.forward_features(tx))
1102
+ clipwise_output = torch.zeros_like(output_dicts[0]["clipwise_output"]).float().to(x.device)
1103
+ framewise_output = torch.zeros_like(output_dicts[0]["framewise_output"]).float().to(x.device)
1104
+ latent_output = torch.zeros_like(output_dicts[0]["latent_output"]).float().to(x.device)
1105
+ for d in output_dicts:
1106
+ clipwise_output += d["clipwise_output"]
1107
+ framewise_output += d["framewise_output"]
1108
+ latent_output += d["latent_output"]
1109
+ clipwise_output = clipwise_output / len(output_dicts)
1110
+ framewise_output = framewise_output / len(output_dicts)
1111
+ latent_output = latent_output / len(output_dicts)
1112
+ output_dict = {
1113
+ 'framewise_output': framewise_output,
1114
+ 'clipwise_output': clipwise_output,
1115
+ 'latent_output': latent_output,
1116
+ }
1117
+ else: # this part is typically used, and is the easiest one
1118
+ x = self.reshape_wav2img(x)
1119
+ output_dict = self.forward_features(x)
1120
+ # x = self.head(x)
1121
+ return output_dict
1122
+
1123
+ class HTSATWrapper(nn.Module):
1124
+ def __init__(self, sample_rate, window_size, hop_size, mel_bins, fmin,
1125
+ fmax, classes_num, out_emb):
1126
+ super().__init__()
1127
+
1128
+ # print("parameters are being overidden when using HTSAT")
1129
+ # print("HTSAT only support loading a pretrained model on AudioSet")
1130
+ # @TODO later look at what parameters are same and can be merged
1131
+
1132
+ self.htsat = HTSAT_Swin_Transformer(config=HTSATConfig())
1133
+
1134
+ def forward(self, x):
1135
+ out_dict = self.htsat(x)
1136
+ out_dict['embedding'] = out_dict['latent_output']
1137
+ return out_dict
1138
+
1139
+
1140
+ def get_audio_encoder(name: str):
1141
+ if name == "HTSAT":
1142
+ return HTSATWrapper
1143
+ else:
1144
+ raise Exception('The audio encoder name {} is incorrect or not supported'.format(name))
1145
+
1146
+ class Projection(nn.Module):
1147
+ def __init__(self, d_in: int, d_out: int, p: float=0.5) -> None:
1148
+ super().__init__()
1149
+ self.linear1 = nn.Linear(d_in, d_out, bias=False)
1150
+ self.linear2 = nn.Linear(d_out, d_out, bias=False)
1151
+ self.layer_norm = nn.LayerNorm(d_out)
1152
+ self.drop = nn.Dropout(p)
1153
+
1154
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
1155
+ embed1 = self.linear1(x)
1156
+ embed2 = self.drop(self.linear2(F.gelu(embed1)))
1157
+ embeds = self.layer_norm(embed1 + embed2)
1158
+ return embeds
1159
+
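+ # For illustration: Projection is a residual two-layer MLP head. With d_in=768 (the
+ # HTSAT latent size) and d_out=1024 (d_proj below), it maps a (B, 768) embedding to
+ # the (B, 1024) shared space whose shape is printed further down.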
1160
+ class AudioEncoder(nn.Module):
1161
+ def __init__(self, audioenc_name:str, d_in: int, d_out: int, sample_rate: int, window_size: int,
1162
+ hop_size: int, mel_bins: int, fmin: int, fmax: int, classes_num: int) -> None:
1163
+ super().__init__()
1164
+
1165
+ audio_encoder = get_audio_encoder(audioenc_name)
1166
+
1167
+ self.base = audio_encoder(
1168
+ sample_rate, window_size,
1169
+ hop_size, mel_bins, fmin, fmax,
1170
+ classes_num, d_in)
1171
+
1172
+ self.projection = Projection(d_in, d_out)
1173
+
1174
+ def forward(self, x):
1175
+ out_dict = self.base(x)
1176
+ audio_features, audio_classification_output = out_dict['embedding'], out_dict['clipwise_output']
1177
+ projected_vec = self.projection(audio_features)
1178
+ return projected_vec, audio_classification_output
1179
+
1180
+ class TextEncoder(nn.Module):
1181
+ def __init__(self, d_out: int, text_model: str, transformer_embed_dim: int) -> None:
1182
+ super().__init__()
1183
+ self.text_model = text_model
1184
+ self.base = AutoModel.from_pretrained(text_model)
1185
+
1186
+ if 'clip' in text_model:
1187
+ self.clip_text_projection = self.base.text_projection
1188
+ self.base = self.base.text_model
1189
+ if 'base' in text_model:
1190
+ transformer_embed_dim = 512
1191
+
1192
+ self.projection = Projection(transformer_embed_dim, d_out)
1193
+
1194
+ def forward(self, x):
1195
+ if 'clip' in self.text_model:
1196
+ pooled_output = self.base(**x)[1] # get pooled output
1197
+ out = self.clip_text_projection(pooled_output) # get CLS token output
1198
+ elif 'gpt' in self.text_model:
1199
+ batch_size = x['input_ids'].shape[0]
1200
+ hidden_states = self.base(**x)[0] # (batch_size=4, seq_len, 768)
1201
+
1202
+ sequence_lengths = torch.ne(x['input_ids'], 0).sum(-1) - 1 # tensor([13, 14, 18, 17])
1203
+ out = hidden_states[torch.arange(batch_size, device=hidden_states.device), sequence_lengths] # [batch_size, 768] = [4, 768]
1204
+ else:
1205
+ out = self.base(**x)[0]
1206
+ out = out[:, 0, :] # get CLS token output
1207
+
1208
+ projected_vec = self.projection(out)
1209
+
1210
+ return projected_vec
1211
+
1212
+ class CLAP(nn.Module):
1213
+ def __init__(self,
1214
+ # audio
1215
+ audioenc_name: str,
1216
+ sample_rate: int,
1217
+ window_size: int,
1218
+ hop_size: int,
1219
+ mel_bins: int,
1220
+ fmin: int,
1221
+ fmax: int,
1222
+ classes_num: int,
1223
+ out_emb: int,
1224
+ # text
1225
+ text_model: str,
1226
+ transformer_embed_dim: int,
1227
+ # common
1228
+ d_proj: int,
1229
+ ):
1230
+ super().__init__()
1231
+
1232
+
1233
+ self.audio_encoder = AudioEncoder(
1234
+ audioenc_name, out_emb, d_proj,
1235
+ sample_rate, window_size, hop_size, mel_bins, fmin, fmax, classes_num)
1236
+
1237
+ self.caption_encoder = TextEncoder(
1238
+ d_proj, text_model, transformer_embed_dim
1239
+ )
1240
+
1241
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
1242
+
1243
+ def forward(self, audio, text):
1244
+ audio_embed, _ = self.audio_encoder(audio)
1245
+ caption_embed = self.caption_encoder(text)
1246
+
1247
+ return caption_embed, audio_embed, self.logit_scale.exp()
1248
+
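+ # For reference (a sketch, not used in this script): contrastive logits between the
+ # two modalities would typically be computed from L2-normalized embeddings, e.g.
+ #   a = F.normalize(audio_embed, dim=-1); t = F.normalize(caption_embed, dim=-1)
+ #   logits = logit_scale * a @ t.t()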
1249
+
1250
+
1251
+ # ==================================================================
1252
+ # A U D I O - P R E - P R O C E S S I N G
1253
+ # ==================================================================
1254
+ def read_audio(audio_path, resample=True):
1255
+ r"""Loads audio file or array and returns a torch tensor"""
1256
+ # Randomly sample a segment of audio_duration from the clip or pad to match duration
1257
+ audio_time_series, sample_rate = torchaudio.load(audio_path)
1258
+
1259
+ resample_rate = clapConfig.sample_rate
1260
+ if resample and resample_rate != sample_rate:
1261
+ resampler = T.Resample(sample_rate, resample_rate)
1262
+ audio_time_series = resampler(audio_time_series)
1263
+ return audio_time_series, resample_rate
1264
+
1265
+ def load_audio_into_tensor(audio_path, audio_duration, resample=False):
1266
+ r"""Loads audio file and returns raw audio."""
1267
+ # Randomly sample a segment of audio_duration from the clip or pad to match duration
1268
+ audio_time_series, sample_rate = read_audio(audio_path, resample)
1269
+ audio_time_series = audio_time_series.reshape(-1)
1270
+
1271
+ # audio_time_series is shorter than predefined audio duration,
1272
+ # so audio_time_series is extended
1273
+ if audio_duration*sample_rate >= audio_time_series.shape[0]:
1274
+ repeat_factor = int(np.ceil((audio_duration*sample_rate) /
1275
+ audio_time_series.shape[0]))
1276
+ # Repeat audio_time_series by repeat_factor to match audio_duration
1277
+ audio_time_series = audio_time_series.repeat(repeat_factor)
1278
+ # remove excess part of audio_time_series
1279
+ audio_time_series = audio_time_series[0:audio_duration*sample_rate]
1280
+ else:
1281
+ # audio_time_series is longer than predefined audio duration,
1282
+ # so audio_time_series is trimmed
1283
+ start_index = random.randrange(
1284
+ audio_time_series.shape[0] - audio_duration*sample_rate)
1285
+ audio_time_series = audio_time_series[start_index:start_index +
1286
+ audio_duration*sample_rate]
1287
+ return torch.FloatTensor(audio_time_series)
1288
+
1289
+ np_str_obj_array_pattern = re.compile(r'[SaUO]')
1290
+ default_collate_err_msg_format = (
1291
+ "default_collate: batch must contain tensors, numpy arrays, numbers, "
1292
+ "dicts or lists; found {}")
1293
+
1294
+ def default_collate(batch):
1295
+ r"""Puts each data field into a tensor with outer dimension batch size"""
1296
+ elem = batch[0]
1297
+ elem_type = type(elem)
1298
+ if isinstance(elem, torch.Tensor):
1299
+ out = None
1300
+ if torch.utils.data.get_worker_info() is not None:
1301
+ # If we're in a background process, concatenate directly into a
1302
+ # shared memory tensor to avoid an extra copy
1303
+ numel = sum([x.numel() for x in batch])
1304
+ storage = elem.storage()._new_shared(numel)
1305
+ out = elem.new(storage)
1306
+ return torch.stack(batch, 0, out=out)
1307
+ elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
1308
+ and elem_type.__name__ != 'string_':
1309
+ if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap':
1310
+ # array of string classes and object
1311
+ if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
1312
+ raise TypeError(
1313
+ default_collate_err_msg_format.format(elem.dtype))
1314
+
1315
+ return default_collate([torch.as_tensor(b) for b in batch])
1316
+ elif elem.shape == (): # scalars
1317
+ return torch.as_tensor(batch)
1318
+ elif isinstance(elem, float):
1319
+ return torch.tensor(batch, dtype=torch.float64)
1320
+ elif isinstance(elem, int):
1321
+ return torch.tensor(batch)
1322
+ elif isinstance(elem, str):
1323
+ return batch
1324
+ elif isinstance(elem, collections.abc.Mapping):
1325
+ return {key: default_collate([d[key] for d in batch]) for key in elem}
1326
+ elif isinstance(elem, tuple) and hasattr(elem, '_fields'): # namedtuple
1327
+ return elem_type(*(default_collate(samples) for samples in zip(*batch)))
1328
+ elif isinstance(elem, collections.abc.Sequence):
1329
+ # check to make sure that the elements in batch have consistent size
1330
+ it = iter(batch)
1331
+ elem_size = len(next(it))
1332
+ if not all(len(elem) == elem_size for elem in it):
1333
+ raise RuntimeError(
1334
+ 'each element in list of batch should be of equal size')
1335
+ transposed = zip(*batch)
1336
+ return [default_collate(samples) for samples in transposed]
1337
+
1338
+ raise TypeError(default_collate_err_msg_format.format(elem_type))
1339
+
1340
+ def preprocess_audio(audio_files, resample):
1341
+ r"""Load list of audio files and return raw audio"""
1342
+ audio_tensors = []
1343
+ for audio_file in audio_files:
1344
+ audio_tensor = load_audio_into_tensor(
1345
+ audio_file, clapConfig.duration, resample)
1346
+ audio_tensor = audio_tensor.reshape(1, -1).to(device)
1347
+ audio_tensors.append(audio_tensor)
1348
+ return default_collate(audio_tensors)
1349
+
1350
+
1351
+
1352
+ # ==================================================================
1353
+ # A U D I O - E M B E D D I N G S - H E L P E R
1354
+ # ==================================================================
1355
+ def get_audio_embeddings(audio_files: List[str], audio_encoder, resample=True):
1356
+ """Load list of audio files and return audio embeddings"""
1357
+ preprocessed_audio = preprocess_audio(audio_files, resample)
1358
+ with torch.no_grad():
1359
+ preprocessed_audio = preprocessed_audio.reshape(
1360
+ preprocessed_audio.shape[0], preprocessed_audio.shape[2])
1361
+ return audio_encoder(preprocessed_audio)[0]
1362
+
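+ # Illustrative usage (file names are placeholders): each file is padded/cropped to
+ # duration * sample_rate = 7 * 44100 samples by preprocess_audio, so
+ #   get_audio_embeddings(["a.wav", "b.wav"], clap.audio_encoder)
+ # would return a (2, 1024) tensor of projected audio embeddings.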
1363
+
+ # ==================================================================
+ # C L A P
+ # ==================================================================
+ class ClapConfig:
+     # TEXT ENCODER CONFIG
+     text_model = 'gpt2'
+     text_len = 77
+     transformer_embed_dim = 768
+     freeze_text_encoder_weights = True
+
+     # AUDIO ENCODER CONFIG
+     audioenc_name = 'HTSAT'
+     out_emb = 768
+     sample_rate = 44100
+     duration = 7
+     fmin = 50
+     fmax = 8000  # 14000
+     n_fft = 1024  # 1028
+     hop_size = 320
+     mel_bins = 64
+     window_size = 1024
+
+     # PROJECTION SPACE CONFIG
+     d_proj = 1024
+     temperature = 0.003
+
+     # TRAINING AND EVALUATION CONFIG
+     num_classes = 527
+     batch_size = 1024
+     demo = False
+
+
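+ # Quick numbers implied by the audio config above (sketch):
+ #   raw clip length = duration * sample_rate = 7 * 44100 = 308,700 samples
+ #   mel frames      ~ 308,700 / hop_size = 308,700 / 320 ~ 965
+ #   so each clip enters the HTSAT encoder as roughly a [965, 64] mel spectrogram.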
+ clapConfig = ClapConfig()
+ clap = CLAP(
+     audioenc_name=clapConfig.audioenc_name,
+     sample_rate=clapConfig.sample_rate,
+     window_size=clapConfig.window_size,
+     hop_size=clapConfig.hop_size,
+     mel_bins=clapConfig.mel_bins,
+     fmin=clapConfig.fmin,
+     fmax=clapConfig.fmax,
+     classes_num=clapConfig.num_classes,
+     out_emb=clapConfig.out_emb,
+     text_model=clapConfig.text_model,
+     transformer_embed_dim=clapConfig.transformer_embed_dim,
+     d_proj=clapConfig.d_proj
+ )
+
+ model_repo = "microsoft/msclap"
+ model_name = {
+     '2022': 'CLAP_weights_2022.pth',
+     '2023': 'CLAP_weights_2023.pth',
+     'clapcap': 'clapcap_weights_2023.pth'
+ }
+
+ version = '2023'
+ model_fp = hf_hub_download(model_repo, model_name[version])
+
+ model_state_dict = torch.load(model_fp, map_location=torch.device('cpu'))['model']
+ clap.load_state_dict(model_state_dict, strict=False)
+ clap.to(device)
+ clap.eval()
+
+ clap_audio_encoder = clap.audio_encoder.eval()
+
+
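+ # Note: strict=False above silently skips checkpoint keys that do not match this
+ # CLAP definition. To verify what was actually loaded, the report returned by
+ # load_state_dict can be inspected (sketch):
+ #   _report = clap.load_state_dict(model_state_dict, strict=False)
+ #   print("missing:", _report.missing_keys, "unexpected:", _report.unexpected_keys)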
+ ENGLISH_AUDIO_DIR = r"/home/IITB/ai-at-ieor/23m1521/datasets/Vaani/Audios/English"
+ audio_files = [os.path.join(ENGLISH_AUDIO_DIR, i) for i in os.listdir(ENGLISH_AUDIO_DIR) if i.endswith(".wav")]
+ audio_embedding = get_audio_embeddings(audio_files, clap_audio_encoder)
+ print("CLAP Audio Encoder Embeddings:", audio_embedding.shape)  # [5, 1024]
+
+
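+ # The audio_embedding*.npy files committed alongside this script are consistent
+ # with a [5, 1024] float32 array (5 * 1024 * 4 bytes + 128-byte npy header =
+ # 20,608 bytes). Dumping the tensor above would look like (sketch, path assumed):
+ #   np.save("audio_embedding.npy", audio_embedding.cpu().numpy())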
+ # ==================================================================
+ # C L A P - L o R A - M O D E L
+ # ==================================================================
+ LoRAconfig = {
+     "peft_type": "LORA",
+     "task_type": "FEATURE_EXTRACTION",
+     "inference_mode": False,
+     "r": 16,
+     "target_modules": ["qkv", "fc1", "fc2", "proj", "linear1", "linear2"],
+     "lora_alpha": 32,
+     "lora_dropout": 0.05,
+     "fan_in_fan_out": False,
+     "bias": "all",
+ }
+ peft_config = get_peft_config(LoRAconfig)
+
+ model = clap_audio_encoder
+ peft_model = get_peft_model(model, peft_config)
+
+ peft_model.print_trainable_parameters()
+
+ # peft_model.base_model
+ # peft_model
+
+ peft_clap_audio_encoder = peft_model.base_model
+ audio_embedding = get_audio_embeddings(audio_files, peft_clap_audio_encoder)
+ print("CLAP LoRA Audio Encoder Embeddings:", audio_embedding.shape)  # [5, 1024]
+
+
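+ # Once the LoRA adapters are trained, they could be saved and re-attached with the
+ # standard peft API (sketch; the adapter directory name is arbitrary):
+ #   peft_model.save_pretrained("clap_audio_encoder_lora")
+ #   from peft import PeftModel
+ #   peft_model = PeftModel.from_pretrained(clap.audio_encoder, "clap_audio_encoder_lora")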
+ # ==================================================================
+ # C L I P - M O D E L
+ # ==================================================================
+ from transformers import CLIPImageProcessorFast, CLIPImageProcessor
+ clip_vision_model = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32")
+ # clip_vision_processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
+ # clip_vision_processor = CLIPImageProcessorFast.from_pretrained("openai/clip-vit-base-patch32")
+ clip_vision_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")
+
+ image = Image.open("/home/IITB/ai-at-ieor/23m1521/ashish/MTP/Vaani/Img_Audio_Alignment/000000039769.jpg")
+ inputs = clip_vision_processor(images=image, return_tensors="pt")
+ print("CLIP input image:", inputs['pixel_values'].shape)
+ # input_data = {'pixel_values': inputs}
+
+ IMAGE_SIZE = 224
+ dummy_input = torch.randn(1, 3, IMAGE_SIZE, IMAGE_SIZE)  # [1, 3, 224, 224]
+ # dummy_input_data = {'pixel_values': dummy_input}
+
+ # class VisionModelWrapper(nn.Module):
+ #     def __init__(self, peft_model):
+ #         super().__init__()
+ #         self.model = peft_model.base_model
+
+ #     def forward(self, x):
+ #         return self.model(pixel_values=x).last_hidden_state
+
+ # wrapped_model = VisionModelWrapper(peft_model)
+ # output = wrapped_model(input_data)
+ output = clip_vision_model(inputs['pixel_values'])
+ print("CLIP Image Encoder Embeddings:", output.last_hidden_state.shape)  # [1, 50, 768]
+ print("CLIP Image Encoder Pooled Output:", output.pooler_output.shape)  # [1, 768]
Vaani/Img_Audio_Alignment/audio_embedding.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:71297113c0308ee06bedc099a3a8ebb889a76bfc987ef0872bdba9283bf1a3b9
+ size 20608
Vaani/Img_Audio_Alignment/audio_embedding_dismantled_msclap.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e015b9f4b8babb25e50c3557e8a2dfd5a09f2a659ed443d7db9a30b51ab40e3d
+ size 20608
Vaani/Img_Audio_Alignment/audio_embedding_dismantled_msclap_untrained.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:5a2090189a3c657c42ffa73cffdddce10671b43cb664d6b3dae2899d8500cc3e
+ size 20608
Vaani/SDFT/checkpoints/checkpoint.pth CHANGED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:efbc74616b71c84b435ff10d1070c4149d5b17daf6fc9d5f76590522ea82ed6c
+ oid sha256:dc0ddf905f5bac4366d70408fabf6f224ed2056d9a3290a227983839635ff3d4
  size 2866661866
Vaani/SDFT/samples/inference_epoch10.png ADDED
Vaani/SDFT/samples/inference_epoch9.png ADDED
Vaani/Vaani-Audio-Image-Hindi.csv ADDED
@@ -0,0 +1 @@
+ audio_path,referenceImage,gender,state,district
Vaani/VaaniLDM/ddpm_ckpt_epoch55.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:cb8ca92e67d7a7ee3193ad69652374aa8034207c6a7cb5d8bebd8b20863bac20
+ size 593245226
Vaani/VaaniLDM/ddpm_ckpt_epoch56.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ccf36158867f9e8807d9a3f9b20f7b63d51f3cdee799ba3ad86d178c2bda63c6
+ size 593245290
Vaani/VaaniLDM/ldmH_ckpt_epoch49.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:17ad202fa634fb05bb389a7d00162b60cb426bee5077cbacd53369f27f7c00b0
+ size 2476369898
Vaani/VaaniLDM/ldmH_ckpt_epoch50.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b0845ce9ab94c92ff08afbc6066c39524f95f5dda2099731f964957c8adde96b
+ size 2476369962
Vaani/VaaniLDM/samples/x0_0.png CHANGED

Git LFS Details (before): SHA256 7d5227300ede40ca19f6c2c1550880eb47a47a7c3ffa65cb4a1a3926afe6943c · Pointer size: 131 Bytes · Size of remote file: 428 kB
Git LFS Details (after): SHA256 0dc592a8808b3ef083092d4c703b3af2fa4bddc39a4f96804318b5c4f268b447 · Pointer size: 131 Bytes · Size of remote file: 427 kB
Vaani/VaaniLDM/samples/x0_1.png CHANGED
Vaani/VaaniLDM/samples/x0_10.png CHANGED
Vaani/VaaniLDM/samples/x0_100.png CHANGED
Vaani/VaaniLDM/samples/x0_101.png CHANGED
Vaani/VaaniLDM/samples/x0_102.png CHANGED
Vaani/VaaniLDM/samples/x0_103.png CHANGED
Vaani/VaaniLDM/samples/x0_104.png CHANGED
Vaani/VaaniLDM/samples/x0_105.png CHANGED
Vaani/VaaniLDM/samples/x0_106.png CHANGED
Vaani/VaaniLDM/samples/x0_107.png CHANGED
Vaani/VaaniLDM/samples/x0_108.png CHANGED
Vaani/VaaniLDM/samples/x0_109.png CHANGED
Vaani/VaaniLDM/samples/x0_11.png CHANGED
Vaani/VaaniLDM/samples/x0_110.png CHANGED
Vaani/VaaniLDM/samples/x0_111.png CHANGED
Vaani/VaaniLDM/samples/x0_112.png CHANGED
Vaani/VaaniLDM/samples/x0_113.png CHANGED
Vaani/VaaniLDM/samples/x0_114.png CHANGED
Vaani/VaaniLDM/samples/x0_115.png CHANGED
Vaani/VaaniLDM/samples/x0_116.png CHANGED
Vaani/VaaniLDM/samples/x0_117.png CHANGED
Vaani/VaaniLDM/samples/x0_118.png CHANGED
Vaani/VaaniLDM/samples/x0_119.png CHANGED
Vaani/VaaniLDM/samples/x0_12.png CHANGED
Vaani/VaaniLDM/samples/x0_120.png CHANGED
Vaani/VaaniLDM/samples/x0_121.png CHANGED
Vaani/VaaniLDM/samples/x0_122.png CHANGED
Vaani/VaaniLDM/samples/x0_123.png CHANGED
Vaani/VaaniLDM/samples/x0_124.png CHANGED
Vaani/VaaniLDM/samples/x0_125.png CHANGED
Vaani/VaaniLDM/samples/x0_126.png CHANGED
Vaani/VaaniLDM/samples/x0_127.png CHANGED