alpha31476 committed on
Commit 0733b34 · verified · 1 Parent(s): 0c51fb1

Image Audio Alignment Train OpenCLIP

This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50)
  1. .gitattributes +277 -0
  2. Vaani/AndhraPradesh_Anantpur_meta.csv +3 -0
  3. Vaani/AndhraPradesh_Anantpur_meta.parquet +3 -0
  4. Vaani/Img_Audio_Alignment/=12.2.0 +9 -0
  5. Vaani/Img_Audio_Alignment/Vaani_SD2_1_CSIP_Test.ipynb +0 -0
  6. Vaani/Img_Audio_Alignment/_2.1.1_Train_OpenCLIP.py +0 -0
  7. Vaani/Img_Audio_Alignment/_2.1_Train.py +96 -37
  8. Vaani/Img_Audio_Alignment/_2.1_Train_OpenCLIP.py +1834 -0
  9. Vaani/Img_Audio_Alignment/_2.2_OpenCLIP.py +0 -0
  10. Vaani/Img_Audio_Alignment/_2_Train.ipynb +379 -50
  11. Vaani/Img_Audio_Alignment/available_img_audios_TEST2.csv +0 -0
  12. Vaani/Img_Audio_Alignment/available_img_audios_TRAIN2.csv +0 -0
  13. Vaani/Img_Audio_Alignment/checkpoints/csip/csip_best_epoch_201.pt +3 -0
  14. Vaani/Img_Audio_Alignment/checkpoints/csip/csip_best_epoch_202.pt +3 -0
  15. Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_102_loss_4.1355.png +3 -0
  16. Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_103_loss_4.1354.png +3 -0
  17. Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_104_loss_4.1349.png +3 -0
  18. Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_105_loss_4.1342.png +3 -0
  19. Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_106_loss_4.1335.png +3 -0
  20. Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_107_loss_4.1329.png +3 -0
  21. Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_108_loss_4.1327.png +3 -0
  22. Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_109_loss_4.1320.png +3 -0
  23. Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_10_loss_4.1517.png +3 -0
  24. Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_110_loss_4.1312.png +3 -0
  25. Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_112_loss_4.1300.png +3 -0
  26. Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_113_loss_4.1299.png +3 -0
  27. Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_115_loss_4.1281.png +3 -0
  28. Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_116_loss_4.1266.png +3 -0
  29. Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_119_loss_4.1249.png +3 -0
  30. Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_11_loss_4.1514.png +3 -0
  31. Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_120_loss_4.1244.png +3 -0
  32. Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_122_loss_4.1238.png +3 -0
  33. Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_123_loss_4.1229.png +3 -0
  34. Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_125_loss_4.1207.png +3 -0
  35. Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_126_loss_4.1206.png +3 -0
  36. Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_127_loss_4.1205.png +3 -0
  37. Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_128_loss_4.1183.png +3 -0
  38. Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_12_loss_4.1511.png +3 -0
  39. Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_130_loss_4.1173.png +3 -0
  40. Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_131_loss_4.1162.png +3 -0
  41. Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_132_loss_4.1158.png +3 -0
  42. Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_133_loss_4.1141.png +3 -0
  43. Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_134_loss_4.1138.png +3 -0
  44. Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_135_loss_4.1132.png +3 -0
  45. Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_136_loss_4.1129.png +3 -0
  46. Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_137_loss_4.1123.png +3 -0
  47. Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_138_loss_4.1104.png +3 -0
  48. Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_139_loss_4.1101.png +3 -0
  49. Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_13_loss_4.1507.png +3 -0
  50. Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_140_loss_4.1085.png +3 -0
.gitattributes CHANGED
@@ -143,3 +143,280 @@ Vaani/imageBY.csv filter=lfs diff=lfs merge=lfs -text
143
  Vaani/imageBY2.csv filter=lfs diff=lfs merge=lfs -text
144
  Vaani/imageBY3.csv filter=lfs diff=lfs merge=lfs -text
145
  Vaani/Image-Audio-Hindi.csv filter=lfs diff=lfs merge=lfs -text
146
+ Vaani/AndhraPradesh_Anantpur_meta.csv filter=lfs diff=lfs merge=lfs -text
147
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_102_loss_4.1355.png filter=lfs diff=lfs merge=lfs -text
148
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_103_loss_4.1354.png filter=lfs diff=lfs merge=lfs -text
149
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_104_loss_4.1349.png filter=lfs diff=lfs merge=lfs -text
150
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_105_loss_4.1342.png filter=lfs diff=lfs merge=lfs -text
151
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_106_loss_4.1335.png filter=lfs diff=lfs merge=lfs -text
152
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_107_loss_4.1329.png filter=lfs diff=lfs merge=lfs -text
153
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_108_loss_4.1327.png filter=lfs diff=lfs merge=lfs -text
154
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_109_loss_4.1320.png filter=lfs diff=lfs merge=lfs -text
155
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_10_loss_4.1517.png filter=lfs diff=lfs merge=lfs -text
156
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_110_loss_4.1312.png filter=lfs diff=lfs merge=lfs -text
157
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_112_loss_4.1300.png filter=lfs diff=lfs merge=lfs -text
158
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_113_loss_4.1299.png filter=lfs diff=lfs merge=lfs -text
159
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_115_loss_4.1281.png filter=lfs diff=lfs merge=lfs -text
160
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_116_loss_4.1266.png filter=lfs diff=lfs merge=lfs -text
161
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_119_loss_4.1249.png filter=lfs diff=lfs merge=lfs -text
162
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_11_loss_4.1514.png filter=lfs diff=lfs merge=lfs -text
163
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_120_loss_4.1244.png filter=lfs diff=lfs merge=lfs -text
164
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_122_loss_4.1238.png filter=lfs diff=lfs merge=lfs -text
165
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_123_loss_4.1229.png filter=lfs diff=lfs merge=lfs -text
166
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_125_loss_4.1207.png filter=lfs diff=lfs merge=lfs -text
167
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_126_loss_4.1206.png filter=lfs diff=lfs merge=lfs -text
168
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_127_loss_4.1205.png filter=lfs diff=lfs merge=lfs -text
169
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_128_loss_4.1183.png filter=lfs diff=lfs merge=lfs -text
170
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_12_loss_4.1511.png filter=lfs diff=lfs merge=lfs -text
171
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_130_loss_4.1173.png filter=lfs diff=lfs merge=lfs -text
172
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_131_loss_4.1162.png filter=lfs diff=lfs merge=lfs -text
173
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_132_loss_4.1158.png filter=lfs diff=lfs merge=lfs -text
174
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_133_loss_4.1141.png filter=lfs diff=lfs merge=lfs -text
175
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_134_loss_4.1138.png filter=lfs diff=lfs merge=lfs -text
176
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_135_loss_4.1132.png filter=lfs diff=lfs merge=lfs -text
177
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_136_loss_4.1129.png filter=lfs diff=lfs merge=lfs -text
178
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_137_loss_4.1123.png filter=lfs diff=lfs merge=lfs -text
179
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_138_loss_4.1104.png filter=lfs diff=lfs merge=lfs -text
180
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_139_loss_4.1101.png filter=lfs diff=lfs merge=lfs -text
181
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_13_loss_4.1507.png filter=lfs diff=lfs merge=lfs -text
182
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_140_loss_4.1085.png filter=lfs diff=lfs merge=lfs -text
183
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_141_loss_4.1082.png filter=lfs diff=lfs merge=lfs -text
184
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_142_loss_4.1076.png filter=lfs diff=lfs merge=lfs -text
185
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_143_loss_4.1070.png filter=lfs diff=lfs merge=lfs -text
186
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_144_loss_4.1045.png filter=lfs diff=lfs merge=lfs -text
187
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_145_loss_4.1044.png filter=lfs diff=lfs merge=lfs -text
188
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_146_loss_4.1038.png filter=lfs diff=lfs merge=lfs -text
189
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_147_loss_4.1026.png filter=lfs diff=lfs merge=lfs -text
190
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_148_loss_4.1021.png filter=lfs diff=lfs merge=lfs -text
191
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_149_loss_4.1011.png filter=lfs diff=lfs merge=lfs -text
192
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_14_loss_4.1506.png filter=lfs diff=lfs merge=lfs -text
193
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_150_loss_4.1005.png filter=lfs diff=lfs merge=lfs -text
194
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_151_loss_4.0997.png filter=lfs diff=lfs merge=lfs -text
195
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_153_loss_4.0984.png filter=lfs diff=lfs merge=lfs -text
196
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_154_loss_4.0942.png filter=lfs diff=lfs merge=lfs -text
197
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_157_loss_4.0923.png filter=lfs diff=lfs merge=lfs -text
198
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_159_loss_4.0917.png filter=lfs diff=lfs merge=lfs -text
199
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_15_loss_4.1504.png filter=lfs diff=lfs merge=lfs -text
200
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_160_loss_4.0917.png filter=lfs diff=lfs merge=lfs -text
201
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_162_loss_4.0895.png filter=lfs diff=lfs merge=lfs -text
202
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_163_loss_4.0868.png filter=lfs diff=lfs merge=lfs -text
203
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_164_loss_4.0853.png filter=lfs diff=lfs merge=lfs -text
204
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_168_loss_4.0838.png filter=lfs diff=lfs merge=lfs -text
205
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_169_loss_4.0834.png filter=lfs diff=lfs merge=lfs -text
206
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_16_loss_4.1503.png filter=lfs diff=lfs merge=lfs -text
207
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_170_loss_4.0828.png filter=lfs diff=lfs merge=lfs -text
208
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_171_loss_4.0814.png filter=lfs diff=lfs merge=lfs -text
209
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_172_loss_4.0798.png filter=lfs diff=lfs merge=lfs -text
210
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_173_loss_4.0775.png filter=lfs diff=lfs merge=lfs -text
211
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_177_loss_4.0775.png filter=lfs diff=lfs merge=lfs -text
212
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_178_loss_4.0767.png filter=lfs diff=lfs merge=lfs -text
213
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_179_loss_4.0740.png filter=lfs diff=lfs merge=lfs -text
214
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_17_loss_4.1498.png filter=lfs diff=lfs merge=lfs -text
215
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_182_loss_4.0712.png filter=lfs diff=lfs merge=lfs -text
216
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_184_loss_4.0707.png filter=lfs diff=lfs merge=lfs -text
217
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_186_loss_4.0706.png filter=lfs diff=lfs merge=lfs -text
218
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_187_loss_4.0704.png filter=lfs diff=lfs merge=lfs -text
219
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_189_loss_4.0687.png filter=lfs diff=lfs merge=lfs -text
220
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_18_loss_4.1493.png filter=lfs diff=lfs merge=lfs -text
221
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_191_loss_4.0667.png filter=lfs diff=lfs merge=lfs -text
222
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_193_loss_4.0662.png filter=lfs diff=lfs merge=lfs -text
223
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_194_loss_4.0662.png filter=lfs diff=lfs merge=lfs -text
224
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_195_loss_4.0659.png filter=lfs diff=lfs merge=lfs -text
225
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_199_loss_4.0650.png filter=lfs diff=lfs merge=lfs -text
226
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_19_loss_4.1492.png filter=lfs diff=lfs merge=lfs -text
227
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_1_loss_4.1553.png filter=lfs diff=lfs merge=lfs -text
228
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_201_loss_4.0621.png filter=lfs diff=lfs merge=lfs -text
229
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_202_loss_4.0603.png filter=lfs diff=lfs merge=lfs -text
230
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_20_loss_4.1488.png filter=lfs diff=lfs merge=lfs -text
231
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_21_loss_4.1485.png filter=lfs diff=lfs merge=lfs -text
232
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_22_loss_4.1484.png filter=lfs diff=lfs merge=lfs -text
233
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_23_loss_4.1483.png filter=lfs diff=lfs merge=lfs -text
234
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_24_loss_4.1477.png filter=lfs diff=lfs merge=lfs -text
235
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_25_loss_4.1474.png filter=lfs diff=lfs merge=lfs -text
236
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_26_loss_4.1473.png filter=lfs diff=lfs merge=lfs -text
237
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_27_loss_4.1471.png filter=lfs diff=lfs merge=lfs -text
238
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_28_loss_4.1467.png filter=lfs diff=lfs merge=lfs -text
239
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_29_loss_4.1463.png filter=lfs diff=lfs merge=lfs -text
240
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_2_loss_4.1552.png filter=lfs diff=lfs merge=lfs -text
241
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_30_loss_4.1461.png filter=lfs diff=lfs merge=lfs -text
242
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_32_loss_4.1460.png filter=lfs diff=lfs merge=lfs -text
243
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_33_loss_4.1456.png filter=lfs diff=lfs merge=lfs -text
244
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_34_loss_4.1452.png filter=lfs diff=lfs merge=lfs -text
245
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_35_loss_4.1449.png filter=lfs diff=lfs merge=lfs -text
246
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_36_loss_4.1449.png filter=lfs diff=lfs merge=lfs -text
247
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_37_loss_4.1445.png filter=lfs diff=lfs merge=lfs -text
248
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_38_loss_4.1443.png filter=lfs diff=lfs merge=lfs -text
249
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_3_loss_4.1549.png filter=lfs diff=lfs merge=lfs -text
250
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_40_loss_4.1438.png filter=lfs diff=lfs merge=lfs -text
251
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_41_loss_4.1434.png filter=lfs diff=lfs merge=lfs -text
252
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_42_loss_4.1433.png filter=lfs diff=lfs merge=lfs -text
253
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_43_loss_4.1430.png filter=lfs diff=lfs merge=lfs -text
254
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_44_loss_4.1426.png filter=lfs diff=lfs merge=lfs -text
255
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_46_loss_4.1422.png filter=lfs diff=lfs merge=lfs -text
256
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_48_loss_4.1416.png filter=lfs diff=lfs merge=lfs -text
257
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_49_loss_4.1414.png filter=lfs diff=lfs merge=lfs -text
258
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_4_loss_4.1541.png filter=lfs diff=lfs merge=lfs -text
259
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_50_loss_4.1412.png filter=lfs diff=lfs merge=lfs -text
260
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_51_loss_4.1410.png filter=lfs diff=lfs merge=lfs -text
261
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_53_loss_4.1405.png filter=lfs diff=lfs merge=lfs -text
262
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_54_loss_4.1404.png filter=lfs diff=lfs merge=lfs -text
263
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_55_loss_4.1402.png filter=lfs diff=lfs merge=lfs -text
264
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_56_loss_4.1399.png filter=lfs diff=lfs merge=lfs -text
265
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_58_loss_4.1394.png filter=lfs diff=lfs merge=lfs -text
266
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_5_loss_4.1535.png filter=lfs diff=lfs merge=lfs -text
267
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_61_loss_4.1391.png filter=lfs diff=lfs merge=lfs -text
268
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_62_loss_4.1387.png filter=lfs diff=lfs merge=lfs -text
269
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_63_loss_4.1387.png filter=lfs diff=lfs merge=lfs -text
270
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_65_loss_4.1382.png filter=lfs diff=lfs merge=lfs -text
271
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_67_loss_4.1379.png filter=lfs diff=lfs merge=lfs -text
272
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_69_loss_4.1377.png filter=lfs diff=lfs merge=lfs -text
273
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_6_loss_4.1528.png filter=lfs diff=lfs merge=lfs -text
274
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_71_loss_4.1374.png filter=lfs diff=lfs merge=lfs -text
275
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_72_loss_4.1373.png filter=lfs diff=lfs merge=lfs -text
276
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_75_loss_4.1371.png filter=lfs diff=lfs merge=lfs -text
277
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_76_loss_4.1371.png filter=lfs diff=lfs merge=lfs -text
278
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_79_loss_4.1366.png filter=lfs diff=lfs merge=lfs -text
279
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_7_loss_4.1527.png filter=lfs diff=lfs merge=lfs -text
280
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_81_loss_4.1366.png filter=lfs diff=lfs merge=lfs -text
281
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_85_loss_4.1363.png filter=lfs diff=lfs merge=lfs -text
282
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_8_loss_4.1522.png filter=lfs diff=lfs merge=lfs -text
283
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_96_loss_4.1361.png filter=lfs diff=lfs merge=lfs -text
284
+ Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_9_loss_4.1521.png filter=lfs diff=lfs merge=lfs -text
285
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_102_loss_4.1355.png filter=lfs diff=lfs merge=lfs -text
286
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_103_loss_4.1354.png filter=lfs diff=lfs merge=lfs -text
287
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_104_loss_4.1349.png filter=lfs diff=lfs merge=lfs -text
288
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_105_loss_4.1342.png filter=lfs diff=lfs merge=lfs -text
289
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_106_loss_4.1335.png filter=lfs diff=lfs merge=lfs -text
290
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_107_loss_4.1329.png filter=lfs diff=lfs merge=lfs -text
291
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_108_loss_4.1327.png filter=lfs diff=lfs merge=lfs -text
292
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_109_loss_4.1320.png filter=lfs diff=lfs merge=lfs -text
293
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_10_loss_4.1517.png filter=lfs diff=lfs merge=lfs -text
294
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_110_loss_4.1312.png filter=lfs diff=lfs merge=lfs -text
295
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_112_loss_4.1300.png filter=lfs diff=lfs merge=lfs -text
296
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_113_loss_4.1299.png filter=lfs diff=lfs merge=lfs -text
297
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_115_loss_4.1281.png filter=lfs diff=lfs merge=lfs -text
298
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_116_loss_4.1266.png filter=lfs diff=lfs merge=lfs -text
299
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_119_loss_4.1249.png filter=lfs diff=lfs merge=lfs -text
300
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_11_loss_4.1514.png filter=lfs diff=lfs merge=lfs -text
301
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_120_loss_4.1244.png filter=lfs diff=lfs merge=lfs -text
302
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_122_loss_4.1238.png filter=lfs diff=lfs merge=lfs -text
303
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_123_loss_4.1229.png filter=lfs diff=lfs merge=lfs -text
304
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_125_loss_4.1207.png filter=lfs diff=lfs merge=lfs -text
305
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_126_loss_4.1206.png filter=lfs diff=lfs merge=lfs -text
306
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_127_loss_4.1205.png filter=lfs diff=lfs merge=lfs -text
307
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_128_loss_4.1183.png filter=lfs diff=lfs merge=lfs -text
308
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_12_loss_4.1511.png filter=lfs diff=lfs merge=lfs -text
309
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_130_loss_4.1173.png filter=lfs diff=lfs merge=lfs -text
310
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_131_loss_4.1162.png filter=lfs diff=lfs merge=lfs -text
311
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_132_loss_4.1158.png filter=lfs diff=lfs merge=lfs -text
312
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_133_loss_4.1141.png filter=lfs diff=lfs merge=lfs -text
313
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_134_loss_4.1138.png filter=lfs diff=lfs merge=lfs -text
314
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_135_loss_4.1132.png filter=lfs diff=lfs merge=lfs -text
315
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_136_loss_4.1129.png filter=lfs diff=lfs merge=lfs -text
316
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_137_loss_4.1123.png filter=lfs diff=lfs merge=lfs -text
317
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_138_loss_4.1104.png filter=lfs diff=lfs merge=lfs -text
318
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_139_loss_4.1101.png filter=lfs diff=lfs merge=lfs -text
319
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_13_loss_4.1507.png filter=lfs diff=lfs merge=lfs -text
320
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_140_loss_4.1085.png filter=lfs diff=lfs merge=lfs -text
321
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_141_loss_4.1082.png filter=lfs diff=lfs merge=lfs -text
322
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_142_loss_4.1076.png filter=lfs diff=lfs merge=lfs -text
323
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_143_loss_4.1070.png filter=lfs diff=lfs merge=lfs -text
324
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_144_loss_4.1045.png filter=lfs diff=lfs merge=lfs -text
325
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_145_loss_4.1044.png filter=lfs diff=lfs merge=lfs -text
326
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_146_loss_4.1038.png filter=lfs diff=lfs merge=lfs -text
327
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_147_loss_4.1026.png filter=lfs diff=lfs merge=lfs -text
328
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_148_loss_4.1021.png filter=lfs diff=lfs merge=lfs -text
329
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_149_loss_4.1011.png filter=lfs diff=lfs merge=lfs -text
330
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_14_loss_4.1506.png filter=lfs diff=lfs merge=lfs -text
331
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_150_loss_4.1005.png filter=lfs diff=lfs merge=lfs -text
332
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_151_loss_4.0997.png filter=lfs diff=lfs merge=lfs -text
333
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_153_loss_4.0984.png filter=lfs diff=lfs merge=lfs -text
334
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_154_loss_4.0942.png filter=lfs diff=lfs merge=lfs -text
335
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_157_loss_4.0923.png filter=lfs diff=lfs merge=lfs -text
336
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_159_loss_4.0917.png filter=lfs diff=lfs merge=lfs -text
337
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_15_loss_4.1504.png filter=lfs diff=lfs merge=lfs -text
338
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_160_loss_4.0917.png filter=lfs diff=lfs merge=lfs -text
339
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_162_loss_4.0895.png filter=lfs diff=lfs merge=lfs -text
340
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_163_loss_4.0868.png filter=lfs diff=lfs merge=lfs -text
341
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_164_loss_4.0853.png filter=lfs diff=lfs merge=lfs -text
342
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_168_loss_4.0838.png filter=lfs diff=lfs merge=lfs -text
343
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_169_loss_4.0834.png filter=lfs diff=lfs merge=lfs -text
344
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_16_loss_4.1503.png filter=lfs diff=lfs merge=lfs -text
345
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_170_loss_4.0828.png filter=lfs diff=lfs merge=lfs -text
346
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_171_loss_4.0814.png filter=lfs diff=lfs merge=lfs -text
347
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_172_loss_4.0798.png filter=lfs diff=lfs merge=lfs -text
348
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_173_loss_4.0775.png filter=lfs diff=lfs merge=lfs -text
349
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_177_loss_4.0775.png filter=lfs diff=lfs merge=lfs -text
350
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_178_loss_4.0767.png filter=lfs diff=lfs merge=lfs -text
351
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_179_loss_4.0740.png filter=lfs diff=lfs merge=lfs -text
352
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_17_loss_4.1498.png filter=lfs diff=lfs merge=lfs -text
353
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_182_loss_4.0712.png filter=lfs diff=lfs merge=lfs -text
354
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_184_loss_4.0707.png filter=lfs diff=lfs merge=lfs -text
355
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_186_loss_4.0706.png filter=lfs diff=lfs merge=lfs -text
356
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_187_loss_4.0704.png filter=lfs diff=lfs merge=lfs -text
357
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_189_loss_4.0687.png filter=lfs diff=lfs merge=lfs -text
358
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_18_loss_4.1493.png filter=lfs diff=lfs merge=lfs -text
359
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_191_loss_4.0667.png filter=lfs diff=lfs merge=lfs -text
360
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_193_loss_4.0662.png filter=lfs diff=lfs merge=lfs -text
361
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_194_loss_4.0662.png filter=lfs diff=lfs merge=lfs -text
362
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_195_loss_4.0659.png filter=lfs diff=lfs merge=lfs -text
363
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_199_loss_4.0650.png filter=lfs diff=lfs merge=lfs -text
364
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_19_loss_4.1492.png filter=lfs diff=lfs merge=lfs -text
365
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_1_loss_4.1553.png filter=lfs diff=lfs merge=lfs -text
366
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_201_loss_4.0621.png filter=lfs diff=lfs merge=lfs -text
367
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_202_loss_4.0603.png filter=lfs diff=lfs merge=lfs -text
368
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_20_loss_4.1488.png filter=lfs diff=lfs merge=lfs -text
369
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_21_loss_4.1485.png filter=lfs diff=lfs merge=lfs -text
370
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_22_loss_4.1484.png filter=lfs diff=lfs merge=lfs -text
371
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_23_loss_4.1483.png filter=lfs diff=lfs merge=lfs -text
372
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_24_loss_4.1477.png filter=lfs diff=lfs merge=lfs -text
373
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_25_loss_4.1474.png filter=lfs diff=lfs merge=lfs -text
374
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_26_loss_4.1473.png filter=lfs diff=lfs merge=lfs -text
375
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_27_loss_4.1471.png filter=lfs diff=lfs merge=lfs -text
376
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_28_loss_4.1467.png filter=lfs diff=lfs merge=lfs -text
377
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_29_loss_4.1463.png filter=lfs diff=lfs merge=lfs -text
378
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_2_loss_4.1552.png filter=lfs diff=lfs merge=lfs -text
379
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_30_loss_4.1461.png filter=lfs diff=lfs merge=lfs -text
380
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_32_loss_4.1460.png filter=lfs diff=lfs merge=lfs -text
381
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_33_loss_4.1456.png filter=lfs diff=lfs merge=lfs -text
382
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_34_loss_4.1452.png filter=lfs diff=lfs merge=lfs -text
383
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_35_loss_4.1449.png filter=lfs diff=lfs merge=lfs -text
384
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_36_loss_4.1449.png filter=lfs diff=lfs merge=lfs -text
385
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_37_loss_4.1445.png filter=lfs diff=lfs merge=lfs -text
386
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_38_loss_4.1443.png filter=lfs diff=lfs merge=lfs -text
387
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_3_loss_4.1549.png filter=lfs diff=lfs merge=lfs -text
388
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_40_loss_4.1438.png filter=lfs diff=lfs merge=lfs -text
389
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_41_loss_4.1434.png filter=lfs diff=lfs merge=lfs -text
390
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_42_loss_4.1433.png filter=lfs diff=lfs merge=lfs -text
391
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_43_loss_4.1430.png filter=lfs diff=lfs merge=lfs -text
392
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_44_loss_4.1426.png filter=lfs diff=lfs merge=lfs -text
393
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_46_loss_4.1422.png filter=lfs diff=lfs merge=lfs -text
394
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_48_loss_4.1416.png filter=lfs diff=lfs merge=lfs -text
395
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_49_loss_4.1414.png filter=lfs diff=lfs merge=lfs -text
396
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_4_loss_4.1541.png filter=lfs diff=lfs merge=lfs -text
397
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_50_loss_4.1412.png filter=lfs diff=lfs merge=lfs -text
398
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_51_loss_4.1410.png filter=lfs diff=lfs merge=lfs -text
399
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_53_loss_4.1405.png filter=lfs diff=lfs merge=lfs -text
400
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_54_loss_4.1404.png filter=lfs diff=lfs merge=lfs -text
401
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_55_loss_4.1402.png filter=lfs diff=lfs merge=lfs -text
402
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_56_loss_4.1399.png filter=lfs diff=lfs merge=lfs -text
403
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_58_loss_4.1394.png filter=lfs diff=lfs merge=lfs -text
404
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_5_loss_4.1535.png filter=lfs diff=lfs merge=lfs -text
405
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_61_loss_4.1391.png filter=lfs diff=lfs merge=lfs -text
406
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_62_loss_4.1387.png filter=lfs diff=lfs merge=lfs -text
407
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_63_loss_4.1387.png filter=lfs diff=lfs merge=lfs -text
408
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_65_loss_4.1382.png filter=lfs diff=lfs merge=lfs -text
409
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_67_loss_4.1379.png filter=lfs diff=lfs merge=lfs -text
410
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_69_loss_4.1377.png filter=lfs diff=lfs merge=lfs -text
411
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_6_loss_4.1528.png filter=lfs diff=lfs merge=lfs -text
412
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_71_loss_4.1374.png filter=lfs diff=lfs merge=lfs -text
413
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_72_loss_4.1373.png filter=lfs diff=lfs merge=lfs -text
414
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_75_loss_4.1371.png filter=lfs diff=lfs merge=lfs -text
415
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_76_loss_4.1371.png filter=lfs diff=lfs merge=lfs -text
416
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_79_loss_4.1366.png filter=lfs diff=lfs merge=lfs -text
417
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_7_loss_4.1527.png filter=lfs diff=lfs merge=lfs -text
418
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_81_loss_4.1366.png filter=lfs diff=lfs merge=lfs -text
419
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_85_loss_4.1363.png filter=lfs diff=lfs merge=lfs -text
420
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_8_loss_4.1522.png filter=lfs diff=lfs merge=lfs -text
421
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_96_loss_4.1361.png filter=lfs diff=lfs merge=lfs -text
422
+ Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_9_loss_4.1521.png filter=lfs diff=lfs merge=lfs -text
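Every PNG and CSV added above is tracked through Git LFS, so a plain checkout may contain small pointer files rather than the real data. A minimal sketch (not part of this commit) for detecting that case, assuming the standard LFS pointer layout whose "version https://git-lfs.github.com/spec/v1" header also appears in the pointer files further down in this commit:

# Hypothetical helper: check whether a checked-out file is still an un-smudged Git LFS pointer.
from pathlib import Path

def is_lfs_pointer(path):
    p = Path(path)
    if not p.is_file() or p.stat().st_size > 1024:  # real pointers are only a few lines long
        return False
    lines = p.read_text(errors="ignore").splitlines()
    return bool(lines) and lines[0].startswith("version https://git-lfs.github.com/spec/v1")

# Example with a path from this commit (hypothetical usage):
# print(is_lfs_pointer("Vaani/AndhraPradesh_Anantpur_meta.parquet"))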
Vaani/AndhraPradesh_Anantpur_meta.csv ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d1d586211fb9bced5c71171960dc4180a9446ca8971106645951bc99686e4215
3
+ size 25038247
Vaani/AndhraPradesh_Anantpur_meta.parquet ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5b7c4262ecbd3aa31824efb124af64e4fc9b90f4e97cb5f5468b18e542d05891
3
+ size 2001282
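Once the LFS objects have been fetched (for example with git lfs pull), the district metadata can be inspected directly. The sketch below is exploratory only: the column layout is not visible in this diff, and pandas needs a pyarrow or fastparquet backend installed for read_parquet.

import pandas as pd

meta = pd.read_parquet("Vaani/AndhraPradesh_Anantpur_meta.parquet")  # ~2 MB once pulled (see size above)
print(meta.shape)
print(meta.dtypes)
print(meta.head())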
Vaani/Img_Audio_Alignment/=12.2.0 ADDED
@@ -0,0 +1,9 @@
1
+ Channels:
2
+ - conda-forge
3
+ - defaults
4
+ Platform: linux-64
5
+ Collecting package metadata (repodata.json): ...working... done
6
+ Solving environment: ...working... done
7
+
8
+ # All requested packages already installed.
9
+
Vaani/Img_Audio_Alignment/Vaani_SD2_1_CSIP_Test.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
Vaani/Img_Audio_Alignment/_2.1.1_Train_OpenCLIP.py ADDED
The diff for this file is too large to render. See raw diff
 
Vaani/Img_Audio_Alignment/_2.1_Train.py CHANGED
@@ -66,6 +66,28 @@ from huggingface_hub.file_download import hf_hub_download
66
  from peft import get_peft_config, get_peft_model
67
  from transformers import CLIPVisionModel, AutoProcessor
68
 
69
 
70
  # ==================================================================
71
  # H T S - A T
@@ -1440,9 +1462,9 @@ model_fp = hf_hub_download(model_repo, model_name[version])
1440
 
1441
  model_state_dict = torch.load(model_fp, map_location=torch.device('cpu'))['model']
1442
  clap.load_state_dict(model_state_dict, strict=False)
1443
- clap.eval()
1444
 
1445
- clap_audio_encoder = clap.audio_encoder.eval().to(device)
1446
 
1447
 
1448
  # ENGLISH_AUDIO_DIR = r"/home/IITB/ai-at-ieor/23m1521/datasets/Vaani/Audios/English"
@@ -1509,6 +1531,9 @@ class CSIP(nn.Module):
1509
  self.image_encoder = image_encoder # CLIPVisionModel
1510
  self.audio_encoder = audio_encoder # CLAP_audio_encoder
1511
 
1512
  # self.image_proj = nn.Linear(dim_img, dim_emb)
1513
  self.audio_proj = nn.Linear(dim_audio, dim_emb)
1514
 
@@ -1517,9 +1542,9 @@ class CSIP(nn.Module):
1517
 
1518
  def forward(self, images, audios):
1519
  # Step 1: Feature extraction
1520
- with torch.no_grad():
1521
- with torch.inference_mode():
1522
- image_features = self.image_encoder(images).pooler_output # shape: [n, dim_img]
1523
  audio_features = self.audio_encoder(audios)[0] # shape: [n, dim_audio]
1524
 
1525
  # Step 2: Project and normalize
@@ -1569,7 +1594,7 @@ train_df = pd.read_csv("/home/IITB/ai-at-ieor/23m1521/ashish/MTP/Vaani/Img_Audio
1569
  test_df = pd.read_csv("/home/IITB/ai-at-ieor/23m1521/ashish/MTP/Vaani/Img_Audio_Alignment/available_img_audios_TEST.csv")
1570
  train_dataset = VaaniImageAudioDataset(train_df)
1571
  test_dataset = VaaniImageAudioDataset(test_df)
1572
- BATCH_SIZE = 32
1573
 
1574
  print('Train Dataset:', len(train_dataset))
1575
  print('Test Dataset:', len(test_dataset))
@@ -1598,12 +1623,27 @@ test_dataloader = torch.utils.data.DataLoader(
1598
  )
1599
 
1600
  batch = next(iter(train_dataloader))
1601
- print("Image batch:", batch['image_tensor'].shape)
1602
- print("Audio batch:", batch['audio_tensor'].shape)
1603
-
1604
-
1605
- csip_model = CSIP(clip_vision_model.eval(), peft_clap_audio_encoder).to(device)
1606
-
1607
 
1608
  # loss, logits, probs = csip_model(batch['image_tensor'].to(device), batch['audio_tensor'].to(device))
1609
  # loss, logits, probs, logits.shape, probs.shape
@@ -1636,7 +1676,7 @@ def save_checkpoint(state, checkpoint_dir, epoch, max_checkpoints=2):
1636
  os.remove(os.path.join(checkpoint_dir, to_delete))
1637
 
1638
 
1639
- def load_checkpoint(checkpoint_dir, model, optimizer):
1640
  checkpoints = sorted(
1641
  [f for f in os.listdir(checkpoint_dir) if f.startswith("csip_best_epoch_")],
1642
  key=lambda x: int(x.split("_")[-1].split(".")[0])
@@ -1650,6 +1690,7 @@ def load_checkpoint(checkpoint_dir, model, optimizer):
1650
  checkpoint = torch.load(path)
1651
  model.load_state_dict(checkpoint['model_state'])
1652
  optimizer.load_state_dict(checkpoint['optimizer_state'])
 
1653
  start_epoch = checkpoint['epoch']
1654
  best_loss = checkpoint['best_loss']
1655
  print(f"Resumed training from epoch {start_epoch+1} with best loss {best_loss:.4f}")
@@ -1696,7 +1737,7 @@ def save_similarity_heatmaps(logits, epoch, loss, save_dir, writer):
1696
 
1697
 
1698
  def train_model(model, train_loader, test_loader,
1699
- optimizer, device, log_dir,
1700
  checkpoint_dir, resume=False, epochs=10):
1701
 
1702
  os.makedirs(log_dir, exist_ok=True)
@@ -1710,41 +1751,45 @@ def train_model(model, train_loader, test_loader,
1710
  best_epoch = -1
1711
 
1712
  if resume:
1713
- start_epoch, best_loss = load_checkpoint(checkpoint_dir, model, optimizer)
1714
 
1715
  # If resuming, don't overwrite the CSV
1716
  if not (resume and os.path.exists(csv_path)):
1717
  with open(csv_path, mode='w', newline='') as f:
1718
  writer_csv = csv.writer(f)
1719
- writer_csv.writerow(["Epoch", "Train Loss", "Test Loss", "Best Loss", "Best Epoch"])
1720
 
1721
  for epoch in trange(start_epoch, epochs, colour='yellow', dynamic_ncols=True):
1722
  train_losses = []
1723
  test_losses = []
1724
 
1725
- train_loop = tqdm(train_loader, desc=f"[Train Epoch {epoch+1}]", colour='blue', dynamic_ncols=True)
1726
  for batch in train_loop:
1727
  images = batch['image_tensor'].to(device)
1728
  audios = batch['audio_tensor'].to(device)
1729
  loss, logits, probs = train_batch(model, images, audios, optimizer)
1730
  train_losses.append(loss)
1731
- train_loop.set_postfix(train_loss=loss)
1732
 
1733
- test_loop = tqdm(test_loader, desc=f"[Test Epoch {epoch+1}]", colour='red', dynamic_ncols=True)
1734
  for batch in test_loop:
1735
  images = batch['image_tensor'].to(device)
1736
  audios = batch['audio_tensor'].to(device)
1737
  loss, logits, probs = evaluate_batch(model, images, audios)
1738
  test_losses.append(loss)
1739
- test_loop.set_postfix(test_loss=loss)
1740
 
1741
  avg_train_loss = sum(train_losses) / len(train_losses)
1742
  avg_test_loss = sum(test_losses) / len(test_losses)
 
 
1743
 
1744
  writer.add_scalar("Loss/Train", avg_train_loss, epoch + 1)
1745
  writer.add_scalar("Loss/Test", avg_test_loss, epoch + 1)
 
1746
 
1747
- print(f"Epoch {epoch+1} | Train Loss: {avg_train_loss:.4f} | Test Loss: {avg_test_loss:.4f}")
 
1748
 
1749
  if avg_test_loss < best_loss:
1750
  save_similarity_heatmaps(logits, epoch, avg_test_loss, checkpoint_dir, writer)
@@ -1754,35 +1799,49 @@ def train_model(model, train_loader, test_loader,
1754
  'epoch': epoch,
1755
  'model_state': model.state_dict(),
1756
  'optimizer_state': optimizer.state_dict(),
1757
- 'best_loss': best_loss
 
1758
  }, checkpoint_dir, epoch)
1759
  print(f">>> Saved new best model at epoch {epoch+1}")
1760
 
1761
- # Append row to CSV
1762
  with open(csv_path, mode='a', newline='') as f:
1763
  writer_csv = csv.writer(f)
1764
- writer_csv.writerow([epoch + 1, avg_train_loss, avg_test_loss, best_loss, best_epoch])
1765
 
1766
  writer.close()
1767
 
1768
 
1769
 
1770
-
1771
- learning_rate = 1e-20
1772
  epochs = 400
1773
  optimizer = torch.optim.AdamW(csip_model.parameters(), lr=learning_rate)
1774
  train_model(
1775
- model=csip_model,
1776
- train_loader=train_dataloader,
1777
- test_loader=test_dataloader,
1778
- optimizer=optimizer,
1779
- device=device,
1780
- log_dir="runs/csip",
1781
- checkpoint_dir="checkpoints/csip",
1782
- resume=True,
1783
- epochs=epochs
1784
- )
 
1785
 
1786
 
1787
  # tensorboard --logdir=/home/IITB/ai-at-ieor/23m1521/ashish/MTP/Vaani/Img_Audio_Alignment/runs --port=6006
1788
- # 127.0.0.1:40697
 
 
 
66
  from peft import get_peft_config, get_peft_model
67
  from transformers import CLIPVisionModel, AutoProcessor
68
 
69
+ from watermark import watermark
70
+ print(watermark(
71
+ author='Ashish',
72
+ # email='[email protected]',
73
+ current_date=True,
74
+ datename=True,
75
+ current_time=True,
76
+ iso8601=True,
77
+ timezone=True,
78
+ updated=True,
79
+ custom_time=None,
80
+ python=True,
81
+ # packages="torch,torchvision,numpy",
82
+ conda=True,
83
+ hostname=True,
84
+ machine=True,
85
+ watermark=False,
86
+ iversions=True,
87
+ gpu=True,
88
+ globals_=globals()
89
+ ))
90
+
91
 
92
  # ==================================================================
93
  # H T S - A T
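The new preamble stamps each run with the watermark package. A trimmed-down sketch of the same idea, keeping only flags that appear in the diff above and dropping the notebook-specific globals_/iversions arguments:

from watermark import watermark

print(watermark(
    author="Ashish",
    current_date=True,
    current_time=True,
    python=True,
    conda=True,
    hostname=True,
    machine=True,
    gpu=True,
))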
 
1462
 
1463
  model_state_dict = torch.load(model_fp, map_location=torch.device('cpu'))['model']
1464
  clap.load_state_dict(model_state_dict, strict=False)
1465
+ # clap.eval()
1466
 
1467
+ clap_audio_encoder = clap.audio_encoder.to(device)
1468
 
1469
 
1470
  # ENGLISH_AUDIO_DIR = r"/home/IITB/ai-at-ieor/23m1521/datasets/Vaani/Audios/English"
 
1531
  self.image_encoder = image_encoder # CLIPVisionModel
1532
  self.audio_encoder = audio_encoder # CLAP_audio_encoder
1533
 
1534
+ for param in self.image_encoder.parameters():
1535
+ param.requires_grad = False
1536
+
1537
  # self.image_proj = nn.Linear(dim_img, dim_emb)
1538
  self.audio_proj = nn.Linear(dim_audio, dim_emb)
1539
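The constructor now freezes the CLIP vision tower explicitly instead of relying on a no_grad wrapper in forward (see the next hunk). A short sketch of the same pattern, with names that are illustrative rather than the file's own:

import torch.nn as nn

def freeze(module: nn.Module) -> nn.Module:
    # Disable gradient tracking for every parameter; the module still runs in forward.
    for p in module.parameters():
        p.requires_grad = False
    return module

# image_encoder = freeze(clip_vision_model)   # mirrors the loop added in CSIP.__init__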
 
 
1542
 
1543
  def forward(self, images, audios):
1544
  # Step 1: Feature extraction
1545
+ # with torch.no_grad():
1546
+ # with torch.inference_mode():
1547
+ image_features = self.image_encoder(images).pooler_output # shape: [n, dim_img]
1548
  audio_features = self.audio_encoder(audios)[0] # shape: [n, dim_audio]
1549
 
1550
  # Step 2: Project and normalize
 
1594
  test_df = pd.read_csv("/home/IITB/ai-at-ieor/23m1521/ashish/MTP/Vaani/Img_Audio_Alignment/available_img_audios_TEST.csv")
1595
  train_dataset = VaaniImageAudioDataset(train_df)
1596
  test_dataset = VaaniImageAudioDataset(test_df)
1597
+ BATCH_SIZE = 64
1598
 
1599
  print('Train Dataset:', len(train_dataset))
1600
  print('Test Dataset:', len(test_dataset))
 
1623
  )
1624
 
1625
  batch = next(iter(train_dataloader))
1626
+ image_tensor_batch = batch['image_tensor']
1627
+ audio_tensor_batch = batch['audio_tensor']
1628
+ print("Image batch shape:", image_tensor_batch.shape) # [BATCH_SIZE, 3, 224, 224]
1629
+ print("Audio batch shape:", audio_tensor_batch.shape) # [BATCH_SIZE, 1, 44100]
1630
+
1631
+
1632
+ csip_model = CSIP(clip_vision_model, peft_clap_audio_encoder).to(device)
1633
+
1634
+ from torchinfo import summary
1635
+ import subprocess
1636
+ summary(model=csip_model,
1637
+ input_data=((image_tensor_batch.to(device)), (audio_tensor_batch.to(device))),
1638
+ # input_size = (1, 3, config.IMAGE_SIZE, config.IMAGE_SIZE),
1639
+ dtypes=[torch.long],
1640
+ col_names = ["input_size", "output_size", "num_params", "trainable", "params_percent"],
1641
+ col_width=20,
1642
+ row_settings=["var_names"],
1643
+ depth = 2,
1644
+ # verbose=2,
1645
+ # device=device
1646
+ )
1647
 
1648
  # loss, logits, probs = csip_model(batch['image_tensor'].to(device), batch['audio_tensor'].to(device))
1649
  # loss, logits, probs, logits.shape, probs.shape
 
1676
  os.remove(os.path.join(checkpoint_dir, to_delete))
1677
 
1678
 
1679
+ def load_checkpoint(checkpoint_dir, model, optimizer, scheduler):
1680
  checkpoints = sorted(
1681
  [f for f in os.listdir(checkpoint_dir) if f.startswith("csip_best_epoch_")],
1682
  key=lambda x: int(x.split("_")[-1].split(".")[0])
 
1690
  checkpoint = torch.load(path)
1691
  model.load_state_dict(checkpoint['model_state'])
1692
  optimizer.load_state_dict(checkpoint['optimizer_state'])
1693
+ scheduler.load_state_dict(checkpoint['scheduler_state'])
1694
  start_epoch = checkpoint['epoch']
1695
  best_loss = checkpoint['best_loss']
1696
  print(f"Resumed training from epoch {start_epoch+1} with best loss {best_loss:.4f}")
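load_checkpoint now also restores the scheduler, matching the 'scheduler_state' key that the save path adds later in this diff. A self-contained sketch of that round trip, using the same dictionary keys as the diff:

import torch

def save_ckpt(path, model, optimizer, scheduler, epoch, best_loss):
    torch.save({
        "epoch": epoch,
        "model_state": model.state_dict(),
        "optimizer_state": optimizer.state_dict(),
        "scheduler_state": scheduler.state_dict() if scheduler else None,
        "best_loss": best_loss,
    }, path)

def load_ckpt(path, model, optimizer, scheduler):
    ckpt = torch.load(path, map_location="cpu")
    model.load_state_dict(ckpt["model_state"])
    optimizer.load_state_dict(ckpt["optimizer_state"])
    if scheduler is not None and ckpt.get("scheduler_state") is not None:
        scheduler.load_state_dict(ckpt["scheduler_state"])
    return ckpt["epoch"], ckpt["best_loss"]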
 
1737
 
1738
 
1739
  def train_model(model, train_loader, test_loader,
1740
+ optimizer, scheduler, device, log_dir,
1741
  checkpoint_dir, resume=False, epochs=10):
1742
 
1743
  os.makedirs(log_dir, exist_ok=True)
 
1751
  best_epoch = -1
1752
 
1753
  if resume:
1754
+ start_epoch, best_loss = load_checkpoint(checkpoint_dir, model, optimizer, scheduler)
1755
 
1756
  # If resuming, don't overwrite the CSV
1757
  if not (resume and os.path.exists(csv_path)):
1758
  with open(csv_path, mode='w', newline='') as f:
1759
  writer_csv = csv.writer(f)
1760
+ writer_csv.writerow(["Epoch", "Best Epoch", "Train Loss", "Test Loss", "Best Loss", "Learning Rate"])
1761
 
1762
  for epoch in trange(start_epoch, epochs, colour='yellow', dynamic_ncols=True):
1763
  train_losses = []
1764
  test_losses = []
1765
 
1766
+ train_loop = tqdm(train_loader, desc=f"[TrainEp {epoch+1}]", colour='blue', dynamic_ncols=True)
1767
  for batch in train_loop:
1768
  images = batch['image_tensor'].to(device)
1769
  audios = batch['audio_tensor'].to(device)
1770
  loss, logits, probs = train_batch(model, images, audios, optimizer)
1771
  train_losses.append(loss)
1772
+ train_loop.set_postfix(trainLoss=loss)
1773
 
1774
+ test_loop = tqdm(test_loader, desc=f"[TestEp {epoch+1}]", colour='red', dynamic_ncols=True)
1775
  for batch in test_loop:
1776
  images = batch['image_tensor'].to(device)
1777
  audios = batch['audio_tensor'].to(device)
1778
  loss, logits, probs = evaluate_batch(model, images, audios)
1779
  test_losses.append(loss)
1780
+ test_loop.set_postfix(testLoss=loss)
1781
 
1782
  avg_train_loss = sum(train_losses) / len(train_losses)
1783
  avg_test_loss = sum(test_losses) / len(test_losses)
1784
+
1785
+ current_lr = optimizer.param_groups[0]['lr']
1786
 
1787
  writer.add_scalar("Loss/Train", avg_train_loss, epoch + 1)
1788
  writer.add_scalar("Loss/Test", avg_test_loss, epoch + 1)
1789
+ writer.add_scalar("Learning Rate", current_lr, epoch + 1)
1790
 
1791
+ print(f"Epoch {epoch+1} | Train Loss: {avg_train_loss:.4f} | \
1792
+ Test Loss: {avg_test_loss:.4f} | LR: {current_lr:.2e}")
1793
 
1794
  if avg_test_loss < best_loss:
1795
  save_similarity_heatmaps(logits, epoch, avg_test_loss, checkpoint_dir, writer)
 
1799
  'epoch': epoch,
1800
  'model_state': model.state_dict(),
1801
  'optimizer_state': optimizer.state_dict(),
1802
+ 'best_loss': best_loss,
1803
+ 'scheduler_state': scheduler.state_dict() if scheduler else None
1804
  }, checkpoint_dir, epoch)
1805
  print(f">>> Saved new best model at epoch {epoch+1}")
1806
 
1807
+ scheduler.step()
1808
  with open(csv_path, mode='a', newline='') as f:
1809
  writer_csv = csv.writer(f)
1810
+ writer_csv.writerow([epoch + 1, best_epoch, avg_train_loss, avg_test_loss, best_loss, current_lr])
1811
 
1812
  writer.close()
1813
 
1814
 
1815
+ subprocess.run([
1816
+ "rm",
1817
+ "-rf",
1818
+ "/home/IITB/ai-at-ieor/23m1521/ashish/MTP/Vaani/Img_Audio_Alignment/checkpoints",
1819
+ "/home/IITB/ai-at-ieor/23m1521/ashish/MTP/Vaani/Img_Audio_Alignment/runs"
1820
+ ])
1821
 
1822
+ learning_rate = 1e-4
 
1823
  epochs = 400
1824
  optimizer = torch.optim.AdamW(csip_model.parameters(), lr=learning_rate)
1825
+ warmup_epochs = 10
1826
+ warmup = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=0.01, total_iters=warmup_epochs)
1827
+ cosine_restarts = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=90, T_mult=2, eta_min=1e-10) # T_0 adjusted for warmup
1828
+ scheduler = torch.optim.lr_scheduler.SequentialLR(optimizer, [warmup, cosine_restarts], milestones=[warmup_epochs])
1829
+
1830
  train_model(
1831
+ model=csip_model,
1832
+ train_loader=train_dataloader,
1833
+ test_loader=test_dataloader,
1834
+ optimizer=optimizer,
1835
+ scheduler=scheduler,
1836
+ device=device,
1837
+ log_dir="runs/csip",
1838
+ checkpoint_dir="checkpoints/csip",
1839
+ resume=True,
1840
+ epochs=epochs
1841
+ )
1842
 
1843
 
1844
  # tensorboard --logdir=/home/IITB/ai-at-ieor/23m1521/ashish/MTP/Vaani/Img_Audio_Alignment/runs --port=6006
1845
+ # 127.0.0.1:40697
1846
+
1847
+ # tensorboard --logdir=/home/IITB/ai-at-ieor/23m1521/ashish/MTP/Vaani/Img_Audio_Alignment/runs --port=6006 --host=0.0.0.0
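The schedule added above is 10 epochs of linear warmup chained into cosine annealing with warm restarts via SequentialLR. A standalone sketch with the same hyperparameters and a throwaway parameter, useful for printing the per-epoch learning rate without loading the real model:

import torch

param = torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.AdamW([param], lr=1e-4)
warmup = torch.optim.lr_scheduler.LinearLR(opt, start_factor=0.01, total_iters=10)
cosine = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(opt, T_0=90, T_mult=2, eta_min=1e-10)
sched = torch.optim.lr_scheduler.SequentialLR(opt, [warmup, cosine], milestones=[10])

for epoch in range(400):
    opt.step()        # step the optimizer first to avoid PyTorch's ordering warning
    sched.step()
    if epoch % 50 == 0:
        print(epoch, opt.param_groups[0]["lr"])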
Vaani/Img_Audio_Alignment/_2.1_Train_OpenCLIP.py ADDED
@@ -0,0 +1,1834 @@
1
+ # ==================================================================
+ #        I M A G E - A U D I O   A L I G N M E N T   (O p e n C L I P)
+ # ==================================================================
+ # Author : Ashish Kumar Uchadiya
+ # Created : May 11, 2025
+ # Description: This script trains a CLIP-style image-audio alignment model
+ # on the Vaani image-audio pairs. A CLIP/OpenCLIP vision encoder and an
+ # HTSAT-based (CLAP) audio encoder are projected into a shared embedding
+ # space and trained with a symmetric contrastive objective, so that
+ # matching image-audio pairs score higher than mismatched ones. Losses and
+ # similarity heatmaps are logged to TensorBoard during training.
+ # ==================================================================
17
+ # I M P O R T S
18
+ # ==================================================================
19
+ from __future__ import annotations
20
+ import warnings
21
+ warnings.filterwarnings("ignore")
22
+
23
+
24
+ import os
25
+ import io
26
+ import sys
27
+ import math
28
+ import random
29
+ import collections
30
+ import collections.abc
31
+ import re
32
+ from itertools import repeat
33
+ from pathlib import Path
34
+ from typing import Optional, Tuple, Union, List, Dict
35
+
36
+ import csv
37
+ import numpy as np
38
+ import pandas as pd
39
+ from PIL import Image
40
+ import seaborn as sns
41
+ import matplotlib.pyplot as plt
42
+ from tqdm import trange, tqdm
43
+
44
+ import torch
45
+ import torch.nn.functional as F
46
+ from torch import nn
47
+ from torch.nn.init import _calculate_fan_in_and_fan_out
48
+ import torch.utils.checkpoint as checkpoint
49
+
50
+ import torchvision as tv
51
+ from torchvision.transforms import v2
52
+ from torch.utils.tensorboard import SummaryWriter
53
+ # from tensorboardX import SummaryWriter
54
+
55
+ os.environ["CUDA_VISIBLE_DEVICES"] = "1"
56
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
57
+ print(f"Using device: {device}")
58
+
59
+ import torchaudio
60
+ import torchaudio.transforms as T
61
+ from torchlibrosa.stft import Spectrogram, LogmelFilterBank
62
+ from torchlibrosa.augmentation import SpecAugmentation
63
+
64
+ from transformers import AutoModel, AutoTokenizer, logging
65
+ from huggingface_hub.file_download import hf_hub_download
66
+ from huggingface_hub.file_download import hf_hub_download
67
+ from peft import get_peft_config, get_peft_model
68
+ from transformers import CLIPVisionModel, AutoProcessor
69
+
70
+ from watermark import watermark
71
+ print(watermark(
72
+ author='Ashish',
73
+ # email='[email protected]',
74
+ current_date=True,
75
+ datename=True,
76
+ current_time=True,
77
+ iso8601=True,
78
+ timezone=True,
79
+ updated=True,
80
+ custom_time=None,
81
+ python=True,
82
+ # packages="torch,torchvision,numpy",
83
+ conda=True,
84
+ hostname=True,
85
+ machine=True,
86
+ watermark=False,
87
+ iversions=True,
88
+ gpu=True,
89
+ globals_=globals()
90
+ ))
91
+
92
+
93
+ # ==================================================================
94
+ # H T S - A T
95
+ # ==================================================================
96
+ class HTSATConfig:
97
+ # Ke Chen
98
99
+ # HTS-AT: A HIERARCHICAL TOKEN-SEMANTIC AUDIO TRANSFORMER FOR SOUND CLASSIFICATION AND DETECTION
100
+ # The configuration for training the model
101
+
102
+ exp_name = "exp_htsat_pretrain" # the saved ckpt prefix name of the model
103
+ workspace = "/home/kechen/Research/HTSAT" # the folder of your code
104
+ dataset_path = "/home/Research/audioset" # the dataset path
105
+ desed_folder = "/home/Research/DESED" # the desed file
106
+
107
+ dataset_type = "audioset" # "audioset" "esc-50" "scv2"
108
+ index_type = "full_train" # only works for audioset
109
+ balanced_data = True # only works for audioset
110
+
111
+ loss_type = "clip_bce" #
112
+ # AudioSet & SCV2: "clip_bce" | ESC-50: "clip_ce"
113
+
114
+ # trained from a checkpoint, or evaluate a single model
115
+ resume_checkpoint = None
116
+ # "/home/Research/model_backup/AudioSet/HTSAT_AudioSet_Saved_1.ckpt"
117
+
118
+ esc_fold = 0 # just for esc dataset, select the fold you need for evaluation and (+1) validation
119
+
120
+
121
+ debug = False
122
+
123
+ random_seed = 970131 # 19970318 970131 12412 127777 1009 34047
124
+ batch_size = 32 * 4 # batch size per GPU x GPU number , default is 32 x 4 = 128
125
+ learning_rate = 1e-3 # 1e-4 also workable
126
+ max_epoch = 100
127
+ num_workers = 3
128
+
129
+ lr_scheduler_epoch = [10,20,30]
130
+ lr_rate = [0.02, 0.05, 0.1]
131
+
132
+ # these data preparation optimizations do not bring many improvements, so deprecated
133
+ enable_token_label = False # token label
134
+ class_map_path = "class_hier_map.npy"
135
+ class_filter = None
136
+ retrieval_index = [15382, 9202, 130, 17618, 17157, 17516, 16356, 6165, 13992, 9238, 5550, 5733, 1914, 1600, 3450, 13735, 11108, 3762,
137
+ 9840, 11318, 8131, 4429, 16748, 4992, 16783, 12691, 4945, 8779, 2805, 9418, 2797, 14357, 5603, 212, 3852, 12666, 1338, 10269, 2388, 8260, 4293, 14454, 7677, 11253, 5060, 14938, 8840, 4542, 2627, 16336, 8992, 15496, 11140, 446, 6126, 10691, 8624, 10127, 9068, 16710, 10155, 14358, 7567, 5695, 2354, 8057, 17635, 133, 16183, 14535, 7248, 4560, 14429, 2463, 10773, 113, 2462, 9223, 4929, 14274, 4716, 17307, 4617, 2132, 11083, 1039, 1403, 9621, 13936, 2229, 2875, 17840, 9359, 13311, 9790, 13288, 4750, 17052, 8260, 14900]
138
+ token_label_range = [0.2,0.6]
139
+ enable_time_shift = False # shift time
140
+ enable_label_enhance = False # enhance hierarchical label
141
+ enable_repeat_mode = False # repeat the spectrogram / reshape the spectrogram
142
+
143
+
144
+
145
+ # for model's design
146
+ enable_tscam = True # enable the token-semantic layer
147
+
148
+ # for signal processing
149
+ sample_rate = 32000 # 16000 for scv2, 32000 for audioset and esc-50
150
+ clip_samples = sample_rate * 10 # audio_set 10-sec clip
151
+ window_size = 1024
152
+ hop_size = 320 # 160 for scv2, 320 for audioset and esc-50
153
+ mel_bins = 64
154
+ fmin = 50
155
+ fmax = 14000
156
+ shift_max = int(clip_samples * 0.5)
157
+
158
+ # for data collection
159
+ classes_num = 527 # esc: 50 | audioset: 527 | scv2: 35
160
+ patch_size = (25, 4) # deprecated
161
+ crop_size = None # int(clip_samples * 0.5) deprecated
162
+
163
+ # for htsat hyperparameters
164
+ htsat_window_size = 8
165
+ htsat_spec_size = 256
166
+ htsat_patch_size = 4
167
+ htsat_stride = (4, 4)
168
+ htsat_num_head = [4,8,16,32]
169
+ htsat_dim = 96
170
+ htsat_depth = [2,2,6,2]
171
+
172
+ swin_pretrain_path = None
173
+ # "/home/Research/model_backup/pretrain/swin_tiny_c24_patch4_window8_256.pth"
174
+
175
+ # Some Deprecated Optimization in the model design, check the model code for details
176
+ htsat_attn_heatmap = False
177
+ htsat_hier_output = False
178
+ htsat_use_max = False
179
+
180
+
181
+ # for ensemble test
182
+
183
+ ensemble_checkpoints = []
184
+ ensemble_strides = []
185
+
186
+
187
+ # weight average folder
188
+ wa_folder = "/home/version_0/checkpoints/"
189
+ # weight average output filename
190
+ wa_model_path = "HTSAT_AudioSet_Saved_x.ckpt"
191
+
192
+ esm_model_pathes = [
193
+ "/home/Research/model_backup/AudioSet/HTSAT_AudioSet_Saved_1.ckpt",
194
+ "/home/Research/model_backup/AudioSet/HTSAT_AudioSet_Saved_2.ckpt",
195
+ "/home/Research/model_backup/AudioSet/HTSAT_AudioSet_Saved_3.ckpt",
196
+ "/home/Research/model_backup/AudioSet/HTSAT_AudioSet_Saved_4.ckpt",
197
+ "/home/Research/model_backup/AudioSet/HTSAT_AudioSet_Saved_5.ckpt",
198
+ "/home/Research/model_backup/AudioSet/HTSAT_AudioSet_Saved_6.ckpt"
199
+ ]
200
+
201
+ # for framewise localization
202
+ heatmap_dir = "/home/Research/heatmap_output"
203
+ test_file = "htsat-test-ensemble"
204
+ fl_local = False # indicate if we need to use this dataset for the framewise detection
205
+ fl_dataset = "/home/Research/desed/desedim_embval.npy"
206
+ fl_class_num = [
207
+ "Speech", "Frying", "Dishes", "Running_water",
208
+ "Blender", "Electric_shaver_toothbrush", "Alarm_bell_ringing",
209
+ "Cat", "Dog", "Vacuum_cleaner"
210
+ ]
211
+
212
+ # map 527 classes into 10 classes
213
+ fl_audioset_mapping = [
214
+ [0,1,2,3,4,5,6,7],
215
+ [366, 367, 368],
216
+ [364],
217
+ [288, 289, 290, 291, 292, 293, 294, 295, 296, 297],
218
+ [369],
219
+ [382],
220
+ [310, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, 401, 402],
221
+ [81, 82, 83, 84, 85],
222
+ [74, 75, 76, 77, 78, 79],
223
+ [377]
224
+ ]
225
+
226
+
227
+
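# Editor's note (derived from the config values above, not from the original file):
# a 10 s clip at 32 kHz gives clip_samples / hop_size = 1000 STFT frames with 64 mel
# bins; HTSAT then folds the time axis by freq_ratio = htsat_spec_size // mel_bins = 4
# (see reshape_wav2img below) so the log-mel map fits a 256x256 Swin input.
_cfg = HTSATConfig()
assert _cfg.clip_samples // _cfg.hop_size == 1000
assert _cfg.htsat_spec_size // _cfg.mel_bins == 4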
228
+ def _ntuple(n):
229
+ def parse(x):
230
+ if isinstance(x, collections.abc.Iterable):
231
+ return x
232
+ return tuple(repeat(x, n))
233
+ return parse
234
+
235
+ to_1tuple = _ntuple(1)
236
+ to_2tuple = _ntuple(2)
237
+ to_3tuple = _ntuple(3)
238
+ to_4tuple = _ntuple(4)
239
+ to_ntuple = _ntuple
240
+
241
+ def do_mixup(x, mixup_lambda):
242
+ """Mixup x of even indexes (0, 2, 4, ...) with x of odd indexes
243
+ (1, 3, 5, ...).
244
+ Args:
245
+ x: (batch_size * 2, ...)
246
+ mixup_lambda: (batch_size * 2,)
247
+ Returns:
248
+ out: (batch_size, ...)
249
+ """
250
+ out = (x[0 :: 2].transpose(0, -1) * mixup_lambda[0 :: 2] + \
251
+ x[1 :: 2].transpose(0, -1) * mixup_lambda[1 :: 2]).transpose(0, -1)
252
+ return out
253
+
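# Editor's sketch (illustrative, not from the original file): do_mixup blends
# consecutive even/odd samples with their mixup weights, halving the batch dimension.
import torch

_x = torch.randn(4, 8)                      # (batch_size * 2, ...)
_lam = torch.tensor([0.7, 0.3, 0.4, 0.6])   # (batch_size * 2,)
print(do_mixup(_x, _lam).shape)             # torch.Size([2, 8])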
254
+ def interpolate(x, ratio):
255
+ """Interpolate data in time domain. This is used to compensate the
256
+ resolution reduction in downsampling of a CNN.
257
+
258
+ Args:
259
+ x: (batch_size, time_steps, classes_num)
260
+ ratio: int, ratio to interpolate
261
+ Returns:
262
+ upsampled: (batch_size, time_steps * ratio, classes_num)
263
+ """
264
+ (batch_size, time_steps, classes_num) = x.shape
265
+ upsampled = x[:, :, None, :].repeat(1, 1, ratio, 1)
266
+ upsampled = upsampled.reshape(batch_size, time_steps * ratio, classes_num)
267
+ return upsampled
268
+
269
+
270
+ def drop_path(x, drop_prob: float = 0., training: bool = False):
271
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
272
+ This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
273
+ the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
274
+ See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
275
+ changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
276
+ 'survival rate' as the argument.
277
+ """
278
+ if drop_prob == 0. or not training:
279
+ return x
280
+ keep_prob = 1 - drop_prob
281
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
282
+ random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
283
+ random_tensor.floor_() # binarize
284
+ output = x.div(keep_prob) * random_tensor
285
+ return output
286
+
287
+
288
+ class DropPath(nn.Module):
289
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
290
+ """
291
+ def __init__(self, drop_prob=None):
292
+ super(DropPath, self).__init__()
293
+ self.drop_prob = drop_prob
294
+
295
+ def forward(self, x):
296
+ return drop_path(x, self.drop_prob, self.training)
297
+
298
+ class PatchEmbed(nn.Module):
299
+ """ 2D Image to Patch Embedding
300
+ """
301
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True, patch_stride = 16):
302
+ super().__init__()
303
+ img_size = to_2tuple(img_size)
304
+ patch_size = to_2tuple(patch_size)
305
+ patch_stride = to_2tuple(patch_stride)
306
+ self.img_size = img_size
307
+ self.patch_size = patch_size
308
+ self.patch_stride = patch_stride
309
+ self.grid_size = (img_size[0] // patch_stride[0], img_size[1] // patch_stride[1])
310
+ self.num_patches = self.grid_size[0] * self.grid_size[1]
311
+ self.flatten = flatten
312
+ self.in_chans = in_chans
313
+ self.embed_dim = embed_dim
314
+
315
+ padding = ((patch_size[0] - patch_stride[0]) // 2, (patch_size[1] - patch_stride[1]) // 2)
316
+
317
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_stride, padding=padding)
318
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
319
+
320
+ def forward(self, x):
321
+ B, C, H, W = x.shape
322
+ assert H == self.img_size[0] and W == self.img_size[1], \
323
+ f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
324
+ x = self.proj(x)
325
+ if self.flatten:
326
+ x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
327
+ x = self.norm(x)
328
+ return x
329
+
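# Editor's sketch (illustrative): with the HTSAT defaults used further below
# (256x256 spectrogram "image", patch size 4, stride 4, embed_dim 96), PatchEmbed
# turns one input channel into a 64x64 grid of patch tokens.
import torch

_pe = PatchEmbed(img_size=256, patch_size=4, in_chans=1, embed_dim=96, patch_stride=4)
_spec = torch.randn(2, 1, 256, 256)
print(_pe(_spec).shape)             # torch.Size([2, 4096, 96])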
330
+ class Mlp(nn.Module):
331
+ """ MLP as used in Vision Transformer, MLP-Mixer and related networks
332
+ """
333
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
334
+ super().__init__()
335
+ out_features = out_features or in_features
336
+ hidden_features = hidden_features or in_features
337
+ self.fc1 = nn.Linear(in_features, hidden_features)
338
+ self.act = act_layer()
339
+ self.fc2 = nn.Linear(hidden_features, out_features)
340
+ self.drop = nn.Dropout(drop)
341
+
342
+ def forward(self, x):
343
+ x = self.fc1(x)
344
+ x = self.act(x)
345
+ x = self.drop(x)
346
+ x = self.fc2(x)
347
+ x = self.drop(x)
348
+ return x
349
+
350
+ def _no_gradim_audiorunc_normal_(tensor, mean, std, a, b):
351
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
352
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
353
+ def norm_cdf(x):
354
+ # Computes standard normal cumulative distribution function
355
+ return (1. + math.erf(x / math.sqrt(2.))) / 2.
356
+
357
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
358
+ warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
359
+ "The distribution of values may be incorrect.",
360
+ stacklevel=2)
361
+
362
+ with torch.no_grad():
363
+ # Values are generated by using a truncated uniform distribution and
364
+ # then using the inverse CDF for the normal distribution.
365
+ # Get upper and lower cdf values
366
+ l = norm_cdf((a - mean) / std)
367
+ u = norm_cdf((b - mean) / std)
368
+
369
+ # Uniformly fill tensor with values from [l, u], then translate to
370
+ # [2l-1, 2u-1].
371
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
372
+
373
+ # Use inverse cdf transform for normal distribution to get truncated
374
+ # standard normal
375
+ tensor.erfinv_()
376
+
377
+ # Transform to proper mean, std
378
+ tensor.mul_(std * math.sqrt(2.))
379
+ tensor.add_(mean)
380
+
381
+ # Clamp to ensure it's in the proper range
382
+ tensor.clamp_(min=a, max=b)
383
+ return tensor
384
+
385
+
386
+ def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
387
+ # type: (Tensor, float, float, float, float) -> Tensor
388
+ r"""Fills the input Tensor with values drawn from a truncated
389
+ normal distribution. The values are effectively drawn from the
390
+ normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
391
+ with values outside :math:`[a, b]` redrawn until they are within
392
+ the bounds. The method used for generating the random values works
393
+ best when :math:`a \leq \text{mean} \leq b`.
394
+ Args:
395
+ tensor: an n-dimensional `torch.Tensor`
396
+ mean: the mean of the normal distribution
397
+ std: the standard deviation of the normal distribution
398
+ a: the minimum cutoff value
399
+ b: the maximum cutoff value
400
+ Examples:
401
+ >>> w = torch.empty(3, 5)
402
+ >>> nn.init.trunc_normal_(w)
403
+ """
404
+ return _no_gradim_audiorunc_normal_(tensor, mean, std, a, b)
405
+
406
+
407
+ def variance_scaling_(tensor, scale=1.0, mode='fan_in', distribution='normal'):
408
+ fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
409
+ if mode == 'fan_in':
410
+ denom = fan_in
411
+ elif mode == 'fan_out':
412
+ denom = fan_out
413
+ elif mode == 'fan_avg':
414
+ denom = (fan_in + fan_out) / 2
415
+
416
+ variance = scale / denom
417
+
418
+ if distribution == "truncated_normal":
419
+ # constant is stddev of standard normal truncated to (-2, 2)
420
+ trunc_normal_(tensor, std=math.sqrt(variance) / .87962566103423978)
421
+ elif distribution == "normal":
422
+ tensor.normal_(std=math.sqrt(variance))
423
+ elif distribution == "uniform":
424
+ bound = math.sqrt(3 * variance)
425
+ tensor.uniform_(-bound, bound)
426
+ else:
427
+ raise ValueError(f"invalid distribution {distribution}")
428
+
429
+
430
+ def lecun_normal_(tensor):
431
+ variance_scaling_(tensor, mode='fan_in', distribution='truncated_normal')
432
+
433
+
434
+ # below codes are based and referred from https://github.com/microsoft/Swin-Transformer
435
+ # Swin Transformer for Computer Vision: https://arxiv.org/pdf/2103.14030.pdf
436
+
437
+ def window_partition(x, window_size):
438
+ """
439
+ Args:
440
+ x: (B, H, W, C)
441
+ window_size (int): window size
442
+ Returns:
443
+ windows: (num_windows*B, window_size, window_size, C)
444
+ """
445
+ B, H, W, C = x.shape
446
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
447
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
448
+ return windows
449
+
450
+
451
+ def window_reverse(windows, window_size, H, W):
452
+ """
453
+ Args:
454
+ windows: (num_windows*B, window_size, window_size, C)
455
+ window_size (int): Window size
456
+ H (int): Height of image
457
+ W (int): Width of image
458
+ Returns:
459
+ x: (B, H, W, C)
460
+ """
461
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
462
+ x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
463
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
464
+ return x
465
+
466
+
467
+ class WindowAttention(nn.Module):
468
+ r""" Window based multi-head self attention (W-MSA) module with relative position bias.
469
+ It supports both of shifted and non-shifted window.
470
+ Args:
471
+ dim (int): Number of input channels.
472
+ window_size (tuple[int]): The height and width of the window.
473
+ num_heads (int): Number of attention heads.
474
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
475
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
476
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
477
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
478
+ """
479
+
480
+ def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
481
+
482
+ super().__init__()
483
+ self.dim = dim
484
+ self.window_size = window_size # Wh, Ww
485
+ self.num_heads = num_heads
486
+ head_dim = dim // num_heads
487
+ self.scale = qk_scale or head_dim ** -0.5
488
+
489
+ # define a parameter table of relative position bias
490
+ self.relative_position_bias_table = nn.Parameter(
491
+ torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
492
+
493
+ # get pair-wise relative position index for each token inside the window
494
+ coords_h = torch.arange(self.window_size[0])
495
+ coords_w = torch.arange(self.window_size[1])
496
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
497
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
498
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
499
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
500
+ relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
501
+ relative_coords[:, :, 1] += self.window_size[1] - 1
502
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
503
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
504
+ self.register_buffer("relative_position_index", relative_position_index)
505
+
506
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
507
+ self.attn_drop = nn.Dropout(attn_drop)
508
+ self.proj = nn.Linear(dim, dim)
509
+ self.proj_drop = nn.Dropout(proj_drop)
510
+
511
+ trunc_normal_(self.relative_position_bias_table, std=.02)
512
+ self.softmax = nn.Softmax(dim=-1)
513
+
514
+ def forward(self, x, mask=None):
515
+ """
516
+ Args:
517
+ x: input features with shape of (num_windows*B, N, C)
518
+ mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
519
+ """
520
+ B_, N, C = x.shape
521
+ qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
522
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
523
+
524
+ q = q * self.scale
525
+ attn = (q @ k.transpose(-2, -1))
526
+
527
+ relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
528
+ self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
529
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
530
+ attn = attn + relative_position_bias.unsqueeze(0)
531
+
532
+ if mask is not None:
533
+ nW = mask.shape[0]
534
+ attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
535
+ attn = attn.view(-1, self.num_heads, N, N)
536
+ attn = self.softmax(attn)
537
+ else:
538
+ attn = self.softmax(attn)
539
+
540
+ attn = self.attn_drop(attn)
541
+
542
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
543
+ x = self.proj(x)
544
+ x = self.proj_drop(x)
545
+ return x, attn
546
+
547
+ def extra_repr(self):
548
+ return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
549
+
550
+
551
+ # We use the model based on Swintransformer Block, therefore we can use the swin-transformer pretrained model
552
+ class SwinTransformerBlock(nn.Module):
553
+ r""" Swin Transformer Block.
554
+ Args:
555
+ dim (int): Number of input channels.
556
+ input_resolution (tuple[int]): Input resolution.
557
+ num_heads (int): Number of attention heads.
558
+ window_size (int): Window size.
559
+ shift_size (int): Shift size for SW-MSA.
560
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
561
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
562
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
563
+ drop (float, optional): Dropout rate. Default: 0.0
564
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
565
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
566
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
567
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
568
+ """
569
+
570
+ def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
571
+ mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
572
+ act_layer=nn.GELU, norm_layer=nn.LayerNorm, norm_before_mlp='ln'):
573
+ super().__init__()
574
+ self.dim = dim
575
+ self.input_resolution = input_resolution
576
+ self.num_heads = num_heads
577
+ self.window_size = window_size
578
+ self.shift_size = shift_size
579
+ self.mlp_ratio = mlp_ratio
580
+ self.norm_before_mlp = norm_before_mlp
581
+ if min(self.input_resolution) <= self.window_size:
582
+ # if window size is larger than input resolution, we don't partition windows
583
+ self.shift_size = 0
584
+ self.window_size = min(self.input_resolution)
585
+ assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
586
+
587
+ self.norm1 = norm_layer(dim)
588
+ self.attn = WindowAttention(
589
+ dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
590
+ qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
591
+
592
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
593
+ if self.norm_before_mlp == 'ln':
594
+ self.norm2 = nn.LayerNorm(dim)
595
+ elif self.norm_before_mlp == 'bn':
596
+ self.norm2 = lambda x: nn.BatchNorm1d(dim)(x.transpose(1, 2)).transpose(1, 2)
597
+ else:
598
+ raise NotImplementedError
599
+ mlp_hidden_dim = int(dim * mlp_ratio)
600
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
601
+
602
+ if self.shift_size > 0:
603
+ # calculate attention mask for SW-MSA
604
+ H, W = self.input_resolution
605
+ img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
606
+ h_slices = (slice(0, -self.window_size),
607
+ slice(-self.window_size, -self.shift_size),
608
+ slice(-self.shift_size, None))
609
+ w_slices = (slice(0, -self.window_size),
610
+ slice(-self.window_size, -self.shift_size),
611
+ slice(-self.shift_size, None))
612
+ cnt = 0
613
+ for h in h_slices:
614
+ for w in w_slices:
615
+ img_mask[:, h, w, :] = cnt
616
+ cnt += 1
617
+
618
+ mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
619
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
620
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
621
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
622
+ else:
623
+ attn_mask = None
624
+
625
+ self.register_buffer("attn_mask", attn_mask)
626
+
627
+ def forward(self, x):
628
+ # pdb.set_trace()
629
+ H, W = self.input_resolution
630
+ # print("H: ", H)
631
+ # print("W: ", W)
632
+ # pdb.set_trace()
633
+ B, L, C = x.shape
634
+ # assert L == H * W, "input feature has wrong size"
635
+
636
+ shortcut = x
637
+ x = self.norm1(x)
638
+ x = x.view(B, H, W, C)
639
+
640
+ # cyclic shift
641
+ if self.shift_size > 0:
642
+ shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
643
+ else:
644
+ shifted_x = x
645
+
646
+ # partition windows
647
+ x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
648
+ x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
649
+
650
+ # W-MSA/SW-MSA
651
+ attn_windows, attn = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
652
+
653
+ # merge windows
654
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
655
+ shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
656
+
657
+ # reverse cyclic shift
658
+ if self.shift_size > 0:
659
+ x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
660
+ else:
661
+ x = shifted_x
662
+ x = x.view(B, H * W, C)
663
+
664
+ # FFN
665
+ x = shortcut + self.drop_path(x)
666
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
667
+
668
+ return x, attn
669
+
670
+ def extra_repr(self):
671
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
672
+ f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
673
+
674
+
675
+
676
+ class PatchMerging(nn.Module):
677
+ r""" Patch Merging Layer.
678
+ Args:
679
+ input_resolution (tuple[int]): Resolution of input feature.
680
+ dim (int): Number of input channels.
681
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
682
+ """
683
+
684
+ def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
685
+ super().__init__()
686
+ self.input_resolution = input_resolution
687
+ self.dim = dim
688
+ self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
689
+ self.norm = norm_layer(4 * dim)
690
+
691
+ def forward(self, x):
692
+ """
693
+ x: B, H*W, C
694
+ """
695
+ H, W = self.input_resolution
696
+ B, L, C = x.shape
697
+ assert L == H * W, "input feature has wrong size"
698
+ assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
699
+
700
+ x = x.view(B, H, W, C)
701
+
702
+ x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
703
+ x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
704
+ x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
705
+ x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
706
+ x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
707
+ x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
708
+
709
+ x = self.norm(x)
710
+ x = self.reduction(x)
711
+
712
+ return x
713
+
714
+ def extra_repr(self):
715
+ return f"input_resolution={self.input_resolution}, dim={self.dim}"
716
+
717
+
718
+ class BasicLayer(nn.Module):
719
+ """ A basic Swin Transformer layer for one stage.
720
+ Args:
721
+ dim (int): Number of input channels.
722
+ input_resolution (tuple[int]): Input resolution.
723
+ depth (int): Number of blocks.
724
+ num_heads (int): Number of attention heads.
725
+ window_size (int): Local window size.
726
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
727
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
728
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
729
+ drop (float, optional): Dropout rate. Default: 0.0
730
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
731
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
732
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
733
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
734
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
735
+ """
736
+
737
+ def __init__(self, dim, input_resolution, depth, num_heads, window_size,
738
+ mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
739
+ drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
740
+ norm_before_mlp='ln'):
741
+
742
+ super().__init__()
743
+ self.dim = dim
744
+ self.input_resolution = input_resolution
745
+ self.depth = depth
746
+ self.use_checkpoint = use_checkpoint
747
+
748
+ # build blocks
749
+ self.blocks = nn.ModuleList([
750
+ SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
751
+ num_heads=num_heads, window_size=window_size,
752
+ shift_size=0 if (i % 2 == 0) else window_size // 2,
753
+ mlp_ratio=mlp_ratio,
754
+ qkv_bias=qkv_bias, qk_scale=qk_scale,
755
+ drop=drop, attn_drop=attn_drop,
756
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
757
+ norm_layer=norm_layer, norm_before_mlp=norm_before_mlp)
758
+ for i in range(depth)])
759
+
760
+ # patch merging layer
761
+ if downsample is not None:
762
+ self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
763
+ else:
764
+ self.downsample = None
765
+
766
+ def forward(self, x):
767
+ attns = []
768
+ for blk in self.blocks:
769
+ if self.use_checkpoint:
770
+ x = checkpoint.checkpoint(blk, x)
771
+ else:
772
+ x, attn = blk(x)
773
+ if not self.training:
774
+ attns.append(attn.unsqueeze(0))
775
+ if self.downsample is not None:
776
+ x = self.downsample(x)
777
+ if not self.training:
778
+ attn = torch.cat(attns, dim = 0)
779
+ attn = torch.mean(attn, dim = 0)
780
+ return x, attn
781
+
782
+ def extra_repr(self):
783
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
784
+
785
+
786
+ # The Core of HTSAT
787
+ class HTSAT_Swin_Transformer(nn.Module):
788
+ r"""HTSAT based on the Swin Transformer
789
+ Args:
790
+ spec_size (int | tuple(int)): Input Spectrogram size. Default 256
791
+ patch_size (int | tuple(int)): Patch size. Default: 4
792
+ path_stride (iot | tuple(int)): Patch Stride for Frequency and Time Axis. Default: 4
793
+ in_chans (int): Number of input image channels. Default: 1 (mono)
794
+ num_classes (int): Number of classes for classification head. Default: 527
795
+ embed_dim (int): Patch embedding dimension. Default: 96
796
+ depths (tuple(int)): Depth of each HTSAT-Swin Transformer layer.
797
+ num_heads (tuple(int)): Number of attention heads in different layers.
798
+ window_size (int): Window size. Default: 8
799
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
800
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
801
+ qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
802
+ drop_rate (float): Dropout rate. Default: 0
803
+ attn_drop_rate (float): Attention dropout rate. Default: 0
804
+ drop_path_rate (float): Stochastic depth rate. Default: 0.1
805
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
806
+ ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
807
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True
808
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
809
+ config (module): The configuration Module from config.py (HTSATConfig Class)
810
+ """
811
+
812
+ def __init__(self, spec_size=256, patch_size=4, patch_stride=(4,4),
813
+ in_chans=1, num_classes=527,
814
+ embed_dim=96, depths=[2, 2, 6, 2], num_heads=[4, 8, 16, 32],
815
+ window_size=8, mlp_ratio=4., qkv_bias=True, qk_scale=None,
816
+ drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
817
+ norm_layer=nn.LayerNorm,
818
+ ape=False, patch_norm=True,
819
+ use_checkpoint=False, norm_before_mlp='ln', config = None, **kwargs):
820
+ super(HTSAT_Swin_Transformer, self).__init__()
821
+
822
+ self.config = config
823
+ self.spec_size = spec_size
824
+ self.patch_stride = patch_stride
825
+ self.patch_size = patch_size
826
+ self.window_size = window_size
827
+ self.embed_dim = embed_dim
828
+ self.depths = depths
829
+ self.ape = ape
830
+ self.in_chans = in_chans
831
+ self.num_classes = num_classes
832
+ self.num_heads = num_heads
833
+ self.num_layers = len(self.depths)
834
+ self.num_features = int(self.embed_dim * 2 ** (self.num_layers - 1))
835
+
836
+ self.drop_rate = drop_rate
837
+ self.attn_drop_rate = attn_drop_rate
838
+ self.drop_path_rate = drop_path_rate
839
+
840
+ self.qkv_bias = qkv_bias
841
+ self.qk_scale = None
842
+
843
+ self.patch_norm = patch_norm
844
+ self.norm_layer = norm_layer if self.patch_norm else None
845
+ self.norm_before_mlp = norm_before_mlp
846
+ self.mlp_ratio = mlp_ratio
847
+
848
+ self.use_checkpoint = use_checkpoint
849
+
850
+ # process mel-spec ; used only once
851
+ self.freq_ratio = self.spec_size // self.config.mel_bins
852
+ window = 'hann'
853
+ center = True
854
+ pad_mode = 'reflect'
855
+ ref = 1.0
856
+ amin = 1e-10
857
+ top_db = None
858
+ self.interpolate_ratio = 32 # Downsampled ratio
859
+ # Spectrogram extractor
860
+ self.spectrogram_extractor = Spectrogram(n_fft=config.window_size, hop_length=config.hop_size,
861
+ win_length=config.window_size, window=window, center=center, pad_mode=pad_mode,
862
+ freeze_parameters=True)
863
+ # Logmel feature extractor
864
+ self.logmel_extractor = LogmelFilterBank(sr=config.sample_rate, n_fft=config.window_size,
865
+ n_mels=config.mel_bins, fmin=config.fmin, fmax=config.fmax, ref=ref, amin=amin, top_db=top_db,
866
+ freeze_parameters=True)
867
+ # Spec augmenter
868
+ self.spec_augmenter = SpecAugmentation(time_drop_width=64, time_stripes_num=2,
869
+ freq_drop_width=8, freq_stripes_num=2) # 2 2
870
+ self.bn0 = nn.BatchNorm2d(self.config.mel_bins)
871
+
872
+
873
+ # split spectrogram into non-overlapping patches
874
+ self.patch_embed = PatchEmbed(
875
+ img_size=self.spec_size, patch_size=self.patch_size, in_chans=self.in_chans,
876
+ embed_dim=self.embed_dim, norm_layer=self.norm_layer, patch_stride = patch_stride)
877
+
878
+ num_patches = self.patch_embed.num_patches
879
+ patches_resolution = self.patch_embed.grid_size
880
+ self.patches_resolution = patches_resolution
881
+
882
+ # absolute position embedding
883
+ if self.ape:
884
+ self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, self.embed_dim))
885
+ trunc_normal_(self.absolute_pos_embed, std=.02)
886
+
887
+ self.pos_drop = nn.Dropout(p=self.drop_rate)
888
+
889
+ # stochastic depth
890
+ dpr = [x.item() for x in torch.linspace(0, self.drop_path_rate, sum(self.depths))] # stochastic depth decay rule
891
+
892
+ # build layers
893
+ self.layers = nn.ModuleList()
894
+ for i_layer in range(self.num_layers):
895
+ layer = BasicLayer(dim=int(self.embed_dim * 2 ** i_layer),
896
+ input_resolution=(patches_resolution[0] // (2 ** i_layer),
897
+ patches_resolution[1] // (2 ** i_layer)),
898
+ depth=self.depths[i_layer],
899
+ num_heads=self.num_heads[i_layer],
900
+ window_size=self.window_size,
901
+ mlp_ratio=self.mlp_ratio,
902
+ qkv_bias=self.qkv_bias, qk_scale=self.qk_scale,
903
+ drop=self.drop_rate, attn_drop=self.attn_drop_rate,
904
+ drop_path=dpr[sum(self.depths[:i_layer]):sum(self.depths[:i_layer + 1])],
905
+ norm_layer=self.norm_layer,
906
+ downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
907
+ use_checkpoint=use_checkpoint,
908
+ norm_before_mlp=self.norm_before_mlp)
909
+ self.layers.append(layer)
910
+
911
+ self.norm = self.norm_layer(self.num_features)
912
+ self.avgpool = nn.AdaptiveAvgPool1d(1)
913
+ self.maxpool = nn.AdaptiveMaxPool1d(1)
914
+
915
+ if self.config.enable_tscam:
916
+ SF = self.spec_size // (2 ** (len(self.depths) - 1)) // self.patch_stride[0] // self.freq_ratio
917
+ self.tscam_conv = nn.Conv2d(
918
+ in_channels = self.num_features,
919
+ out_channels = self.num_classes,
920
+ kernel_size = (SF,3),
921
+ padding = (0,1)
922
+ )
923
+ self.head = nn.Linear(num_classes, num_classes)
924
+ else:
925
+ self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
926
+
927
+ self.apply(self._init_weights)
928
+
929
+ def _init_weights(self, m):
930
+ if isinstance(m, nn.Linear):
931
+ trunc_normal_(m.weight, std=.02)
932
+ if isinstance(m, nn.Linear) and m.bias is not None:
933
+ nn.init.constant_(m.bias, 0)
934
+ elif isinstance(m, nn.LayerNorm):
935
+ nn.init.constant_(m.bias, 0)
936
+ nn.init.constant_(m.weight, 1.0)
937
+
938
+ @torch.jit.ignore
939
+ def no_weight_decay(self):
940
+ return {'absolute_pos_embed'}
941
+
942
+ @torch.jit.ignore
943
+ def no_weight_decay_keywords(self):
944
+ return {'relative_position_bias_table'}
945
+
946
+ def forward_features(self, x):
947
+ frames_num = x.shape[2]
948
+ x = self.patch_embed(x)
949
+ if self.ape:
950
+ x = x + self.absolute_pos_embed
951
+ x = self.pos_drop(x)
952
+ for i, layer in enumerate(self.layers):
953
+ x, attn = layer(x)
954
+
955
+ if self.config.enable_tscam:
956
+ # for x
957
+ x = self.norm(x)
958
+ B, N, C = x.shape
959
+ SF = frames_num // (2 ** (len(self.depths) - 1)) // self.patch_stride[0]
960
+ ST = frames_num // (2 ** (len(self.depths) - 1)) // self.patch_stride[1]
961
+ x = x.permute(0,2,1).contiguous().reshape(B, C, SF, ST)
962
+ B, C, F, T = x.shape
963
+ # group 2D CNN
964
+ c_freq_bin = F // self.freq_ratio
965
+ x = x.reshape(B, C, F // c_freq_bin, c_freq_bin, T)
966
+ x = x.permute(0,1,3,2,4).contiguous().reshape(B, C, c_freq_bin, -1)
967
+
968
+ # get latent_output
969
+ latent_output = self.avgpool(torch.flatten(x,2))
970
+ latent_output = torch.flatten(latent_output, 1)
971
+
972
+ # display the attention map, if needed
973
+ if self.config.htsat_attn_heatmap:
974
+ # for attn
975
+ attn = torch.mean(attn, dim = 1)
976
+ attn = torch.mean(attn, dim = 1)
977
+ attn = attn.reshape(B, SF, ST)
978
+ c_freq_bin = SF // self.freq_ratio
979
+ attn = attn.reshape(B, SF // c_freq_bin, c_freq_bin, ST)
980
+ attn = attn.permute(0,2,1,3).contiguous().reshape(B, c_freq_bin, -1)
981
+ attn = attn.mean(dim = 1)
982
+ attn_max = torch.max(attn, dim = 1, keepdim = True)[0]
983
+ attn_min = torch.min(attn, dim = 1, keepdim = True)[0]
984
+ attn = ((attn * 0.15) + (attn_max * 0.85 - attn_min)) / (attn_max - attn_min)
985
+ attn = attn.unsqueeze(dim = 2)
986
+
987
+ x = self.tscam_conv(x)
988
+ x = torch.flatten(x, 2) # B, C, T
989
+
990
+ if self.config.htsat_attn_heatmap:
991
+ fpx = interpolate(torch.sigmoid(x).permute(0,2,1).contiguous() * attn, 8 * self.patch_stride[1])
992
+ else:
993
+ fpx = interpolate(torch.sigmoid(x).permute(0,2,1).contiguous(), 8 * self.patch_stride[1])
994
+
995
+ x = self.avgpool(x)
996
+ x = torch.flatten(x, 1)
997
+
998
+ if self.config.loss_type == "clip_ce":
999
+ output_dict = {
1000
+ 'framewise_output': fpx, # already sigmoided
1001
+ 'clipwise_output': x,
1002
+ 'latent_output': latent_output
1003
+ }
1004
+ else:
1005
+ output_dict = {
1006
+ 'framewise_output': fpx, # already sigmoided
1007
+ 'clipwise_output': torch.sigmoid(x),
1008
+ 'latent_output': latent_output
1009
+ }
1010
+
1011
+ else:
1012
+ x = self.norm(x) # B N C
1013
+ B, N, C = x.shape
1014
+
1015
+ fpx = x.permute(0,2,1).contiguous().reshape(B, C, frames_num // (2 ** (len(self.depths) + 1)), frames_num // (2 ** (len(self.depths) + 1)) )
1016
+ B, C, F, T = fpx.shape
1017
+ c_freq_bin = F // self.freq_ratio
1018
+ fpx = fpx.reshape(B, C, F // c_freq_bin, c_freq_bin, T)
1019
+ fpx = fpx.permute(0,1,3,2,4).contiguous().reshape(B, C, c_freq_bin, -1)
1020
+ fpx = torch.sum(fpx, dim = 2)
1021
+ fpx = interpolate(fpx.permute(0,2,1).contiguous(), 8 * self.patch_stride[1])
1022
+ x = self.avgpool(x.transpose(1, 2)) # B C 1
1023
+ x = torch.flatten(x, 1)
1024
+ if self.num_classes > 0:
1025
+ x = self.head(x)
1026
+ fpx = self.head(fpx)
1027
+ output_dict = {'framewise_output': torch.sigmoid(fpx),
1028
+ 'clipwise_output': torch.sigmoid(x)}
1029
+ return output_dict
1030
+
1031
+ def crop_wav(self, x, crop_size, spe_pos = None):
1032
+ time_steps = x.shape[2]
1033
+ tx = torch.zeros(x.shape[0], x.shape[1], crop_size, x.shape[3]).to(x.device)
1034
+ for i in range(len(x)):
1035
+ if spe_pos is None:
1036
+ crop_pos = random.randint(0, time_steps - crop_size - 1)
1037
+ else:
1038
+ crop_pos = spe_pos
1039
+ tx[i][0] = x[i, 0, crop_pos:crop_pos + crop_size,:]
1040
+ return tx
1041
+
1042
+ # Reshape the waveform to an image size, if you want to use the pretrained Swin Transformer model
1043
+ def reshape_wav2img(self, x):
1044
+ B, C, T, F = x.shape
1045
+ target_T = int(self.spec_size * self.freq_ratio)
1046
+ target_F = self.spec_size // self.freq_ratio
1047
+ assert T <= target_T and F <= target_F, "the wav size should be less than or equal to the swin input size"
1048
+ # to avoid bicubic zero error
1049
+ if T < target_T:
1050
+ x = nn.functional.interpolate(x, (target_T, x.shape[3]), mode="bicubic", align_corners=True)
1051
+ if F < target_F:
1052
+ x = nn.functional.interpolate(x, (x.shape[2], target_F), mode="bicubic", align_corners=True)
1053
+ x = x.permute(0,1,3,2).contiguous()
1054
+ x = x.reshape(x.shape[0], x.shape[1], x.shape[2], self.freq_ratio, x.shape[3] // self.freq_ratio)
1055
+ # print(x.shape)
1056
+ x = x.permute(0,1,3,2,4).contiguous()
1057
+ x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3], x.shape[4])
1058
+ return x
1059
+
1060
+ # Repeat the waveform to an image size, if you want to use the pretrained Swin Transformer model
1061
+ def repeat_wat2img(self, x, cur_pos):
1062
+ B, C, T, F = x.shape
1063
+ target_T = int(self.spec_size * self.freq_ratio)
1064
+ target_F = self.spec_size // self.freq_ratio
1065
+ assert T <= target_T and F <= target_F, "the wav size should be less than or equal to the swin input size"
1066
+ # to avoid bicubic zero error
1067
+ if T < target_T:
1068
+ x = nn.functional.interpolate(x, (target_T, x.shape[3]), mode="bicubic", align_corners=True)
1069
+ if F < target_F:
1070
+ x = nn.functional.interpolate(x, (x.shape[2], target_F), mode="bicubic", align_corners=True)
1071
+ x = x.permute(0,1,3,2).contiguous() # B C F T
1072
+ x = x[:,:,:,cur_pos:cur_pos + self.spec_size]
1073
+ x = x.repeat(repeats = (1,1,4,1))
1074
+ return x
1075
+
1076
+ def forward(self, x: torch.Tensor, mixup_lambda = None, infer_mode = False):# out_feat_keys: List[str] = None):
1077
+ x = self.spectrogram_extractor(x) # (batch_size, 1, time_steps, freq_bins)
1078
+ x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins)
1079
+
1080
+
1081
+ x = x.transpose(1, 3)
1082
+ x = self.bn0(x)
1083
+ x = x.transpose(1, 3)
1084
+ if self.training:
1085
+ x = self.spec_augmenter(x)
1086
+ if self.training and mixup_lambda is not None:
1087
+ x = do_mixup(x, mixup_lambda)
1088
+
1089
+ if infer_mode:
1090
+ # in infer mode. we need to handle different length audio input
1091
+ frame_num = x.shape[2]
1092
+ target_T = int(self.spec_size * self.freq_ratio)
1093
+ repeat_ratio = math.floor(target_T / frame_num)
1094
+ x = x.repeat(repeats=(1,1,repeat_ratio,1))
1095
+ x = self.reshape_wav2img(x)
1096
+ output_dict = self.forward_features(x)
1097
+ elif self.config.enable_repeat_mode:
1098
+ if self.training:
1099
+ cur_pos = random.randint(0, (self.freq_ratio - 1) * self.spec_size - 1)
1100
+ x = self.repeat_wat2img(x, cur_pos)
1101
+ output_dict = self.forward_features(x)
1102
+ else:
1103
+ output_dicts = []
1104
+ for cur_pos in range(0, (self.freq_ratio - 1) * self.spec_size + 1, self.spec_size):
1105
+ tx = x.clone()
1106
+ tx = self.repeat_wat2img(tx, cur_pos)
1107
+ output_dicts.append(self.forward_features(tx))
1108
+ clipwise_output = torch.zeros_like(output_dicts[0]["clipwise_output"]).float().to(x.device)
1109
+ framewise_output = torch.zeros_like(output_dicts[0]["framewise_output"]).float().to(x.device)
1110
+ for d in output_dicts:
1111
+ clipwise_output += d["clipwise_output"]
1112
+ framewise_output += d["framewise_output"]
1113
+ clipwise_output = clipwise_output / len(output_dicts)
1114
+ framewise_output = framewise_output / len(output_dicts)
1115
+
1116
+ output_dict = {
1117
+ 'framewise_output': framewise_output,
1118
+ 'clipwise_output': clipwise_output
1119
+ }
1120
+ else:
1121
+ if x.shape[2] > self.freq_ratio * self.spec_size:
1122
+ if self.training:
1123
+ x = self.crop_wav(x, crop_size=self.freq_ratio * self.spec_size)
1124
+ x = self.reshape_wav2img(x)
1125
+ output_dict = self.forward_features(x)
1126
+ else:
1127
+ # Change: Hard code here
1128
+ overlap_size = 344 #(x.shape[2] - 1) // 4
1129
+ output_dicts = []
1130
+ crop_size = 689 #(x.shape[2] - 1) // 2
1131
+ for cur_pos in range(0, x.shape[2] - crop_size - 1, overlap_size):
1132
+ tx = self.crop_wav(x, crop_size = crop_size, spe_pos = cur_pos)
1133
+ tx = self.reshape_wav2img(tx)
1134
+ output_dicts.append(self.forward_features(tx))
1135
+ clipwise_output = torch.zeros_like(output_dicts[0]["clipwise_output"]).float().to(x.device)
1136
+ framewise_output = torch.zeros_like(output_dicts[0]["framewise_output"]).float().to(x.device)
1137
+ latent_output = torch.zeros_like(output_dicts[0]["latent_output"]).float().to(x.device)
1138
+ for d in output_dicts:
1139
+ clipwise_output += d["clipwise_output"]
1140
+ framewise_output += d["framewise_output"]
1141
+ latent_output += d["latent_output"]
1142
+ clipwise_output = clipwise_output / len(output_dicts)
1143
+ framewise_output = framewise_output / len(output_dicts)
1144
+ latent_output = latent_output / len(output_dicts)
1145
+ output_dict = {
1146
+ 'framewise_output': framewise_output,
1147
+ 'clipwise_output': clipwise_output,
1148
+ 'latent_output': latent_output,
1149
+ }
1150
+ else: # this part is typically used, and most easy one
1151
+ x = self.reshape_wav2img(x)
1152
+ output_dict = self.forward_features(x)
1153
+ # x = self.head(x)
1154
+ return output_dict
1155
+
1156
+ class HTSATWrapper(nn.Module):
1157
+ def __init__(self, sample_rate, window_size, hop_size, mel_bins, fmin,
1158
+ fmax, classes_num, out_emb):
1159
+ super().__init__()
1160
+
1161
+ # print("parameters are being overidden when using HTSAT")
1162
+ # print("HTSAT only support loading a pretrained model on AudioSet")
1163
+ # @TODO later look at what parameters are same and can be merged
1164
+
1165
+ self.htsat = HTSAT_Swin_Transformer(config=HTSATConfig())
1166
+
1167
+ def forward(self, x):
1168
+ out_dict = self.htsat(x)
1169
+ out_dict['embedding'] = out_dict['latent_output']
1170
+ return out_dict
1171
+
1172
+
1173
+ def get_audio_encoder(name: str):
1174
+ if name == "HTSAT":
1175
+ return HTSATWrapper
1176
+ else:
1177
+ raise Exception('The audio encoder name {} is incorrect or not supported'.format(name))
1178
+
1179
+ class Projection(nn.Module):
1180
+ def __init__(self, dim_imgn: int, d_out: int, p: float=0.5) -> None:
1181
+ super().__init__()
1182
+ self.linear1 = nn.Linear(dim_imgn, d_out, bias=False)
1183
+ self.linear2 = nn.Linear(d_out, d_out, bias=False)
1184
+ self.layer_norm = nn.LayerNorm(d_out)
1185
+ self.drop = nn.Dropout(p)
1186
+
1187
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
1188
+ embed1 = self.linear1(x)
1189
+ embed2 = self.drop(self.linear2(F.gelu(embed1)))
1190
+ embeds = self.layer_norm(embed1 + embed2)
1191
+ return embeds
1192
+
1193
+ class AudioEncoder(nn.Module):
1194
+ def __init__(self, audioenc_name:str, dim_imgn: int, d_out: int, sample_rate: int, window_size: int,
1195
+ hop_size: int, mel_bins: int, fmin: int, fmax: int, classes_num: int) -> None:
1196
+ super().__init__()
1197
+
1198
+ audio_encoder = get_audio_encoder(audioenc_name)
1199
+
1200
+ self.base = audio_encoder(
1201
+ sample_rate, window_size,
1202
+ hop_size, mel_bins, fmin, fmax,
1203
+ classes_num, dim_imgn)
1204
+
1205
+ self.projection = Projection(dim_imgn, d_out)
1206
+
1207
+ def forward(self, x):
1208
+ out_dict = self.base(x)
1209
+ audio_features, audio_classification_output = out_dict['embedding'], out_dict['clipwise_output']
1210
+ projected_vec = self.projection(audio_features)
1211
+ return projected_vec, audio_classification_output
1212
+
1213
+ class TextEncoder(nn.Module):
1214
+ def __init__(self, d_out: int, text_model: str, transformer_embed_dim: int) -> None:
1215
+ super().__init__()
1216
+ self.text_model = text_model
1217
+ self.base = AutoModel.from_pretrained(text_model)
1218
+
1219
+ if 'clip' in text_model:
1220
+ self.clip_text_projection = self.base.text_projection
1221
+ self.base = self.base.text_model
1222
+ if 'base' in text_model:
1223
+ transformer_embed_dim = 512
1224
+
1225
+ self.projection = Projection(transformer_embed_dim, d_out)
1226
+
1227
+ def forward(self, x):
1228
+ if 'clip' in self.text_model:
1229
+ pooled_output = self.base(**x)[1] # get pooled output
1230
+ out = self.clip_text_projection(pooled_output) # get CLS token output
1231
+ elif 'gpt' in self.text_model:
1232
+ batch_size = x['input_ids'].shape[0]
1233
+ hidden_states = self.base(**x)[0] # (batch_size=4, seq_len, 768)
1234
+
1235
+ sequence_lengths = torch.ne(x['input_ids'], 0).sum(-1) - 1 # tensor([13, 14, 18, 17])
1236
+ out = hidden_states[torch.arange(batch_size, device=hidden_states.device), sequence_lengths] # [batch_size, 768] = [4, 768]
1237
+ else:
1238
+ out = self.base(**x)[0]
1239
+ out = out[:, 0, :] # get CLS token output
1240
+
1241
+ projected_vec = self.projection(out)
1242
+
1243
+ return projected_vec
1244
+
1245
+ class CLAP(nn.Module):
1246
+ def __init__(self,
1247
+ # audio
1248
+ audioenc_name: str,
1249
+ sample_rate: int,
1250
+ window_size: int,
1251
+ hop_size: int,
1252
+ mel_bins: int,
1253
+ fmin: int,
1254
+ fmax: int,
1255
+ classes_num: int,
1256
+ out_emb: int,
1257
+ # text
1258
+ text_model: str,
1259
+ transformer_embed_dim: int,
1260
+ # common
1261
+ d_proj: int,
1262
+ ):
1263
+ super().__init__()
1264
+
1265
+
1266
+ self.audio_encoder = AudioEncoder(
1267
+ audioenc_name, out_emb, d_proj,
1268
+ sample_rate, window_size, hop_size, mel_bins, fmin, fmax, classes_num)
1269
+
1270
+ self.caption_encoder = TextEncoder(
1271
+ d_proj, text_model, transformer_embed_dim
1272
+ )
1273
+
1274
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
1275
+
1276
+ def forward(self, audio, text):
1277
+ audio_embed, _ = self.audio_encoder(audio)
1278
+ caption_embed = self.caption_encoder(text)
1279
+
1280
+ return caption_embed, audio_embed, self.logit_scale.exp()
1281
+
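# Editor's sketch (assumption: this mirrors the usual CLIP/CLAP objective and is
# not copied from this repo): the (text_embed, audio_embed, logit_scale) triple
# returned by CLAP.forward can feed a symmetric InfoNCE loss like this.
import torch
import torch.nn.functional as F

def _clip_style_loss(text_embed, audio_embed, logit_scale):
    text_embed = F.normalize(text_embed, dim=-1)
    audio_embed = F.normalize(audio_embed, dim=-1)
    logits = logit_scale * text_embed @ audio_embed.t()            # (B, B) similarities
    targets = torch.arange(logits.shape[0], device=logits.device)  # matched pairs on the diagonal
    return 0.5 * (F.cross_entropy(logits, targets) + F.cross_entropy(logits.t(), targets))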
1282
+
1283
+
1284
+ # ==================================================================
1285
+ # A U D I O - P R E - P R O C E S S I N G
1286
+ # ==================================================================
1287
+ def read_audio(audio_path, resample=True):
1288
+ r"""Loads audio file or array and returns a torch tensor"""
1289
+ # Randomly sample a segment of audio_duration from the clip or pad to match duration
1290
+ audio_time_series, sample_rate = torchaudio.load(audio_path)
1291
+
1292
+ resample_rate = clapConfig.sample_rate
1293
+ if resample and resample_rate != sample_rate:
1294
+ resampler = T.Resample(sample_rate, resample_rate)
1295
+ audio_time_series = resampler(audio_time_series)
1296
+ return audio_time_series, resample_rate
1297
+
1298
+ def load_audio_into_tensor(audio_path, audio_duration, resample=False):
1299
+ r"""Loads audio file and returns raw audio."""
1300
+ # Randomly sample a segment of audio_duration from the clip or pad to match duration
1301
+ audio_time_series, sample_rate = read_audio(audio_path, resample)
1302
+ audio_time_series = audio_time_series.reshape(-1)
1303
+
1304
+ # audio_time_series is shorter than predefined audio duration,
1305
+ # so audio_time_series is extended
1306
+ if audio_duration*sample_rate >= audio_time_series.shape[0]:
1307
+ repeat_factor = int(np.ceil((audio_duration*sample_rate) /
1308
+ audio_time_series.shape[0]))
1309
+ # Repeat audio_time_series by repeat_factor to match audio_duration
1310
+ audio_time_series = audio_time_series.repeat(repeat_factor)
1311
+ # remove excess part of audio_time_series
1312
+ audio_time_series = audio_time_series[0:audio_duration*sample_rate]
1313
+ else:
1314
+ # audio_time_series is longer than predefined audio duration,
1315
+ # so audio_time_series is trimmed
1316
+ start_index = random.randrange(
1317
+ audio_time_series.shape[0] - audio_duration*sample_rate)
1318
+ audio_time_series = audio_time_series[start_index:start_index +
1319
+ audio_duration*sample_rate]
1320
+ return torch.FloatTensor(audio_time_series)
1321
+
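# --- Editor's note: illustrative sketch, not part of the original commit ---
# load_audio_into_tensor() tiles short clips and random-crops long ones so every example
# comes out at exactly duration * sample_rate samples (7 * 44100 = 308700 with the config
# below). A self-contained check of the tiling arithmetic, no audio file needed:
def _expected_length_demo(duration_s=7, sample_rate=44100, clip_s=3):
    clip = torch.zeros(clip_s * sample_rate)                        # pretend 3 s clip
    repeat_factor = int(np.ceil((duration_s * sample_rate) / clip.shape[0]))
    padded = clip.repeat(repeat_factor)[: duration_s * sample_rate]
    return padded.shape[0]                                          # -> 308700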
1322
+ np_str_obj_array_pattern = re.compile(r'[SaUO]')
1323
+ default_collate_err_msg_format = (
1324
+ "default_collate: batch must contain tensors, numpy arrays, numbers, "
1325
+ "dicts or lists; found {}")
1326
+
1327
+ def default_collate(batch):
1328
+ r"""Puts each data field into a tensor with outer dimension batch size"""
1329
+ elem = batch[0]
1330
+ elem_type = type(elem)
1331
+ if isinstance(elem, torch.Tensor):
1332
+ out = None
1333
+ if torch.utils.data.get_worker_info() is not None:
1334
+ # If we're in a background process, concatenate directly into a
1335
+ # shared memory tensor to avoid an extra copy
1336
+ numel = sum([x.numel() for x in batch])
1337
+ storage = elem.storage()._new_shared(numel)
1338
+ out = elem.new(storage)
1339
+ return torch.stack(batch, 0, out=out)
1340
+ elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
1341
+ and elem_type.__name__ != 'string_':
1342
+ if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap':
1343
+ # array of string classes and object
1344
+ if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
1345
+ raise TypeError(
1346
+ default_collate_err_msg_format.format(elem.dtype))
1347
+
1348
+ return default_collate([torch.as_tensor(b) for b in batch])
1349
+ elif elem.shape == (): # scalars
1350
+ return torch.as_tensor(batch)
1351
+ elif isinstance(elem, float):
1352
+ return torch.tensor(batch, dtype=torch.float64)
1353
+ elif isinstance(elem, int):
1354
+ return torch.tensor(batch)
1355
+ elif isinstance(elem, str):
1356
+ return batch
1357
+ elif isinstance(elem, collections.abc.Mapping):
1358
+ return {key: default_collate([d[key] for d in batch]) for key in elem}
1359
+ elif isinstance(elem, tuple) and hasattr(elem, '_fields'): # namedtuple
1360
+ return elem_type(*(default_collate(samples) for samples in zip(*batch)))
1361
+ elif isinstance(elem, collections.abc.Sequence):
1362
+ # check to make sure that the elements in batch have consistent size
1363
+ it = iter(batch)
1364
+ elem_size = len(next(it))
1365
+ if not all(len(elem) == elem_size for elem in it):
1366
+ raise RuntimeError(
1367
+ 'each element in list of batch should be of equal size')
1368
+ transposed = zip(*batch)
1369
+ return [default_collate(samples) for samples in transposed]
1370
+
1371
+ raise TypeError(default_collate_err_msg_format.format(elem_type))
1372
+
1373
+ def preprocess_audio(audio_files, resample):
1374
+ r"""Load list of audio files and return raw audio"""
1375
+ audio_tensors = []
1376
+ for audio_file in audio_files:
1377
+ audio_tensor = load_audio_into_tensor(
1378
+ audio_file, clapConfig.duration, resample)
1379
+ audio_tensor = audio_tensor.reshape(1, -1)
1380
+ audio_tensors.append(audio_tensor)
1381
+ return default_collate(audio_tensors)
1382
+
1383
+
1384
+
1385
+ # ==================================================================
1386
+ # A U D I O - E M B E D D I N G S - H E L P E R
1387
+ # ==================================================================
1388
+ def CLAPAudioProcessor(audio_files: List[str], resample=True):
1389
+ preprocessed_audio = preprocess_audio(audio_files, resample)
1390
+ preprocessed_audio = preprocessed_audio.reshape(
1391
+ preprocessed_audio.shape[0], preprocessed_audio.shape[2])
1392
+ preprocessed_audio = preprocessed_audio
1393
+ return preprocessed_audio
1394
+
1395
+ def get_audio_embeddings(audio_files: List[str], audio_encoder, resample=True):
1396
+ """Load list of audio files and return audio embeddings"""
1397
+ # preprocessed_audio = preprocess_audio(audio_files, resample)
1398
+ # with torch.no_grad():
1399
+ # preprocessed_audio = preprocessed_audio.reshape(
1400
+ # preprocessed_audio.shape[0], preprocessed_audio.shape[2])
1401
+ with torch.no_grad():
1402
+ preprocessed_audio = CLAPAudioProcessor(audio_files, resample)
1403
+ return audio_encoder(preprocessed_audio)[0] # note: if the encoder sits on a GPU, move preprocessed_audio to that device first
1404
+
1405
+
1406
+ # ==================================================================
1407
+ # C L A P
1408
+ # ==================================================================
1409
+ class ClapConfig:
1410
+ # TEXT ENCODER CONFIG
1411
+ text_model = 'gpt2'
1412
+ text_len = 77
1413
+ transformer_embed_dim = 768
1414
+ freeze_text_encoder_weights = True
1415
+
1416
+ # AUDIO ENCODER CONFIG
1417
+ audioenc_name = 'HTSAT'
1418
+ out_emb = 768
1419
+ sample_rate = 44100
1420
+ duration = 7
1421
+ fmin = 50
1422
+ fmax = 8000 # 14000
1423
+ n_fft = 1024 # 1028
1424
+ hop_size = 320
1425
+ mel_bins = 64
1426
+ window_size = 1024
1427
+
1428
+ # PROJECTION SPACE CONFIG
1429
+ d_proj = 1024
1430
+ temperature = 0.003
1431
+
1432
+ # TRAINING AND EVALUATION CONFIG
1433
+ num_classes = 527
1434
+ batch_size = 1024
1435
+ demo = False
1436
+
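# Editor's note (illustrative, not part of the original commit): with these settings each
# training clip is duration * sample_rate = 7 * 44100 = 308700 samples, and the STFT
# front-end sees roughly 308700 / hop_size ≈ 965 frames of mel_bins = 64 bands.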
1437
+
1438
+ clapConfig = ClapConfig()
1439
+ clap = CLAP(
1440
+ audioenc_name=clapConfig.audioenc_name,
1441
+ sample_rate=clapConfig.sample_rate,
1442
+ window_size=clapConfig.window_size,
1443
+ hop_size=clapConfig.hop_size,
1444
+ mel_bins=clapConfig.mel_bins,
1445
+ fmin=clapConfig.fmin,
1446
+ fmax=clapConfig.fmax,
1447
+ classes_num=clapConfig.num_classes,
1448
+ out_emb=clapConfig.out_emb,
1449
+ text_model=clapConfig.text_model,
1450
+ transformer_embed_dim=clapConfig.transformer_embed_dim,
1451
+ d_proj=clapConfig.d_proj
1452
+ )
1453
+
1454
+ model_repo = "microsoft/msclap"
1455
+ model_name = {
1456
+ '2022': 'CLAP_weights_2022.pth',
1457
+ '2023': 'CLAP_weights_2023.pth',
1458
+ 'clapcap': 'clapcap_weights_2023.pth'
1459
+ }
1460
+
1461
+ version = '2023'
1462
+ model_fp = hf_hub_download(model_repo, model_name[version])
1463
+
1464
+ model_state_dict = torch.load(model_fp, map_location=torch.device('cpu'))['model']
1465
+ clap.load_state_dict(model_state_dict, strict=False)
1466
+ # clap.eval()
1467
+
1468
+ clap_audio_encoder = clap.audio_encoder.to(device)
1469
+
1470
+ # ENGLISH_AUDIO_DIR = r"/home/IITB/ai-at-ieor/23m1521/datasets/Vaani/Audios/English"
1471
+ # audio_files = [os.path.join(ENGLISH_AUDIO_DIR, i) for i in os.listdir(ENGLISH_AUDIO_DIR) if i.endswith(".wav")]
1472
+ # audio_embedding = get_audio_embeddings(audio_files, clap_audio_encoder)
1473
+ # print("CLAP Audio Encoder Embeddings:", audio_embedding.shape) # [5, 1024]
1474
+
1475
+
1476
+ # ==================================================================
1477
+ # C L A P - L o R A - M O D E L
1478
+ # ==================================================================
1479
+ LoRAconfig = {
1480
+ "peft_type": "LORA",
1481
+ "task_type": "FEATURE_EXTRACTION",
1482
+ "inference_mode": False,
1483
+ "r": 16,
1484
+ "target_modules": ["qkv", "fc1", "fc2", "proj", "linear1", "linear2"],
1485
+ "lora_alpha": 32,
1486
+ "lora_dropout": 0.05,
1487
+ "fan_in_fan_out": False,
1488
+ "bias": "all",
1489
+ }
1490
+ peft_config = get_peft_config(LoRAconfig)
1491
+
1492
+ peft_model = get_peft_model(clap_audio_encoder, peft_config)
1493
+
1494
+ peft_model.print_trainable_parameters()
1495
+
1496
+ peft_clap_audio_encoder = peft_model.base_model
1497
+ # audio_embedding = get_audio_embeddings(audio_files, peft_clap_audio_encoder)
1498
+ # print("CLAP LoRA Audio Encoder Embeddings:", audio_embedding.shape) # [5, 1024]
1499
+
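# --- Editor's note: illustrative sketch, not part of the original commit ---
# print_trainable_parameters() above already reports the LoRA numbers; the same count by hand:
lora_trainable = sum(p.numel() for p in peft_model.parameters() if p.requires_grad)
lora_total = sum(p.numel() for p in peft_model.parameters())
print(f"LoRA CLAP audio encoder: {lora_trainable:,} of {lora_total:,} params trainable")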
1500
+
1501
+
1502
+ # ==================================================================
1503
+ # O P E N - C L I P - M O D E L
1504
+ # ==================================================================
1505
+ import open_clip
1506
+ open_clip_model, open_clip_imgaug, open_clip_preprocess = open_clip.create_model_and_transforms(
1507
+ model_name='ViT-H-14', pretrained='laion2b_s32b_b79k', device=device
1508
+ )
1509
+
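# Editor's note (illustrative, not part of the original commit): create_model_and_transforms
# returns (model, preprocess_train, preprocess_val), so `open_clip_imgaug` is the augmented
# training transform and `open_clip_preprocess` the deterministic eval transform. For
# ViT-H-14 the visual trunk used by CSIP below emits 1024-d embeddings, which is what the
# audio projection head's default dim_emb=1024 is matched to.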
1510
+
1511
+ # ==================================================================
1512
+ # C S I P - M O D U L E
1513
+ # ==================================================================
1514
+ class CSIP(nn.Module):
1515
+ def __init__(self, image_encoder, audio_encoder,
1516
+ dim_img=None, dim_audio=1024, dim_emb=1024):
1517
+ super(CSIP, self).__init__()
1518
+
1519
+ self.image_encoder = image_encoder # CLIPVisionModel
1520
+ self.audio_encoder = audio_encoder # CLAP_audio_encoder
1521
+
1522
+ for param in self.image_encoder.parameters():
1523
+ param.requires_grad = False
1524
+
1525
+ # self.image_proj = nn.Linear(dim_img, dim_emb)
1526
+ self.audio_proj = nn.Linear(dim_audio, dim_emb)
1527
+
1528
+ # Learnable temperature parameter
1529
+ self.log_temp = nn.Parameter(torch.tensor(0.07).log())
1530
+
1531
+ def forward(self, images, audios):
1532
+ # Step 1: Feature extraction
1533
+ # with torch.no_grad():
1534
+ # with torch.inference_mode():
1535
+ image_features = self.image_encoder(images) # shape: [n, dim_img]
+ audio_features = self.audio_encoder(audios)[0] # shape: [n, dim_audio]
+
+ # Step 2: Project and normalize (L2-normalise along the feature dimension)
+ image_embeds = F.normalize(image_features, dim=-1) # [n, dim_emb]
+ audio_embeds = F.normalize(self.audio_proj(audio_features), dim=-1) # [n, dim_emb]
+
+ # Step 3: Cosine similarity with temperature
+ # Note: log_temp is initialised to log(0.07), so similarities are scaled *by* 0.07 here;
+ # the standard CLIP recipe scales by 1/temperature (~14.3). The scale is learnable, so it
+ # can adapt during training.
+ logits = torch.matmul(image_embeds, audio_embeds.T) * self.log_temp.exp() # [n, n]
1544
+ probs = logits.softmax(dim=1)
1545
+
1546
+ # Step 4: Symmetric cross-entropy loss
1547
+ labels = torch.arange(len(images), device=images.device)
1548
+ # loss_i = F.cross_entropy(logits, labels)
1549
+ loss_a = F.cross_entropy(logits.T, labels)
1550
+ # loss = (loss_i + loss_a) / 2
1551
+ loss = loss_a
1552
+ return loss, logits, probs
1553
+
1554
+
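# --- Editor's note: illustrative sketch, not part of the original commit ---
# The [n, n] logits returned by CSIP.forward put matched image-audio pairs on the diagonal,
# so batch-level retrieval accuracy can be read off directly (hypothetical helper):
@torch.no_grad()
def retrieval_accuracy(logits: torch.Tensor):
    targets = torch.arange(logits.size(0), device=logits.device)
    image_to_audio = (logits.argmax(dim=1) == targets).float().mean().item()
    audio_to_image = (logits.argmax(dim=0) == targets).float().mean().item()
    return image_to_audio, audio_to_image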
1555
+ # ==================================================================
1556
+ # I M A G E - A U D I O - D A T A S E T
1557
+ # ==================================================================
1558
+ class VaaniImageAudioDataset(torch.utils.data.Dataset):
1559
+ def __init__(self, df):
1560
+ self.image_paths = df.image_path.tolist()
1561
+ self.audio_paths = df.audio_path.tolist()
1562
+
1563
+ def __len__(self):
1564
+ return len(self.audio_paths)
1565
+
1566
+ def __getitem__(self, idx):
1567
+ return {
1568
+ 'image_path': self.image_paths[idx],
1569
+ 'audio_path': self.audio_paths[idx]
1570
+ }
1571
+
1572
+
1573
+ def collate_fn(batch):
1574
+ # open_clip's preprocess transforms take a single PIL image and return a [3, H, W]
+ # tensor, so they are applied image-by-image and the results stacked.
+ image_tensor = torch.stack([open_clip_imgaug(Image.open(item['image_path']).convert('RGB')) for item in batch])
+ audio_tensor = CLAPAudioProcessor([item['audio_path'] for item in batch], resample=True)
+ return {'image_tensor': image_tensor, 'audio_tensor': audio_tensor}
1577
+
1578
+
1579
+ # preprocessed_audio = CLAPAudioProcessor(audio_files, resample=True)
1580
+ # clip_vision_processor = clip_vision_processor
1581
+
1582
+ train_df = pd.read_csv("/home/IITB/ai-at-ieor/23m1521/ashish/MTP/Vaani/Img_Audio_Alignment/available_img_audios_TRAIN.csv")
1583
+ test_df = pd.read_csv("/home/IITB/ai-at-ieor/23m1521/ashish/MTP/Vaani/Img_Audio_Alignment/available_img_audios_TEST.csv")
1584
+ train_dataset = VaaniImageAudioDataset(train_df)
1585
+ test_dataset = VaaniImageAudioDataset(test_df)
1586
+ BATCH_SIZE = 64
1587
+
1588
+ print('Train Dataset:', len(train_dataset))
1589
+ print('Test Dataset:', len(test_dataset))
1590
+
1591
+
1592
+ train_dataloader = torch.utils.data.DataLoader(
1593
+ train_dataset,
1594
+ batch_size=BATCH_SIZE,
1595
+ shuffle=True,
1596
+ num_workers=48,
1597
+ collate_fn=collate_fn,
1598
+ pin_memory=True,
1599
+ drop_last=False,
1600
+ persistent_workers=True
1601
+ )
1602
+
1603
+ test_dataloader = torch.utils.data.DataLoader(
1604
+ test_dataset,
1605
+ batch_size=BATCH_SIZE,
1606
+ shuffle=False,
1607
+ num_workers=48,
1608
+ collate_fn=collate_fn,
1609
+ pin_memory=True,
1610
+ drop_last=False,
1611
+ persistent_workers=True
1612
+ )
1613
+
1614
+ batch = next(iter(train_dataloader))
1615
+ image_tensor_batch = batch['image_tensor']
1616
+ audio_tensor_batch = batch['audio_tensor']
1617
+ print("Image batch shape:", image_tensor_batch.shape) # [BATCH_SIZE, 3, 224, 224]
1618
+ print("Audio batch shape:", audio_tensor_batch.shape) # [BATCH_SIZE, 1, 44100]
1619
+
1620
+
1621
+ csip_model = CSIP(open_clip_model.visual, peft_clap_audio_encoder).to(device)
1622
+
1623
+ from torchinfo import summary
1624
+ import subprocess
1625
+ summary(model=csip_model,
1626
+ input_data=((image_tensor_batch.to(device)), (audio_tensor_batch.to(device))),
1627
+ # input_size = (1, 3, config.IMAGE_SIZE, config.IMAGE_SIZE),
1628
+ dtypes=[torch.long],
1629
+ col_names = ["input_size", "output_size", "num_params", "trainable", "params_percent"],
1630
+ col_width=20,
1631
+ row_settings=["var_names"],
1632
+ depth = 2,
1633
+ # verbose=2,
1634
+ # device=device
1635
+ )
1636
+
1637
+ # loss, logits, probs = csip_model(batch['image_tensor'].to(device), batch['audio_tensor'].to(device))
1638
+ # loss, logits, probs, logits.shape, probs.shape
1639
+
1640
+
1641
+ def train_batch(model, images, audio, optimizer):
1642
+ model.train()
1643
+ optimizer.zero_grad()
1644
+ loss, logits, probs = model(images, audio)
1645
+ loss.backward()
1646
+ optimizer.step()
1647
+ return loss.item(), logits, probs
1648
+
1649
+ @torch.no_grad()
1650
+ def evaluate_batch(model, images, audio):
1651
+ model.eval()
1652
+ loss, logits, probs = model(images, audio)
1653
+ return loss.item(), logits, probs
1654
+
1655
+ def save_checkpoint(state, checkpoint_dir, epoch, max_checkpoints=2):
1656
+ filename = f"csip_best_epoch_{epoch+1}.pt"
1657
+ path = os.path.join(checkpoint_dir, filename)
1658
+ torch.save(state, path)
1659
+ checkpoints = sorted(
1660
+ [f for f in os.listdir(checkpoint_dir) if f.startswith("csip_best_epoch_")],
1661
+ key=lambda x: int(x.split("_")[-1].split(".")[0])
1662
+ )
1663
+ while len(checkpoints) > max_checkpoints:
1664
+ to_delete = checkpoints.pop(0)
1665
+ os.remove(os.path.join(checkpoint_dir, to_delete))
1666
+
1667
+
1668
+ def load_checkpoint(checkpoint_dir, model, optimizer, scheduler):
1669
+ checkpoints = sorted(
1670
+ [f for f in os.listdir(checkpoint_dir) if f.startswith("csip_best_epoch_")],
1671
+ key=lambda x: int(x.split("_")[-1].split(".")[0])
1672
+ )
1673
+ if not checkpoints:
1674
+ print("No checkpoint found to resume from.")
1675
+ return 0, float("inf")
1676
+
1677
+ best_ckpt = checkpoints[-1]
1678
+ path = os.path.join(checkpoint_dir, best_ckpt)
1679
+ checkpoint = torch.load(path)
1680
+ model.load_state_dict(checkpoint['model_state'])
1681
+ optimizer.load_state_dict(checkpoint['optimizer_state'])
1682
+ scheduler.load_state_dict(checkpoint['scheduler_state'])
1683
+ start_epoch = checkpoint['epoch']
1684
+ best_loss = checkpoint['best_loss']
1685
+ print(f"Resumed training from epoch {start_epoch+1} with best loss {best_loss:.4f}")
1686
+ return start_epoch, best_loss
1687
+
1688
+
1689
+ def fig_to_tensor(fig):
1690
+ """Convert a Matplotlib figure to a tensor suitable for TensorBoard."""
1691
+ buf = io.BytesIO()
1692
+ fig.savefig(buf, format='png')
1693
+ buf.seek(0)
1694
+ image = Image.open(buf).convert("RGB")
1695
+ tensor = tv.transforms.functional.to_tensor(image)
1696
+ buf.close()
1697
+ plt.close(fig)
1698
+ return tensor
1699
+
1700
+ def save_similarity_heatmaps(logits, epoch, loss, save_dir, writer):
1701
+ os.makedirs(os.path.join(save_dir, 'logits'), exist_ok=True)
1702
+ os.makedirs(os.path.join(save_dir, 'probs'), exist_ok=True)
1703
+
1704
+ # --- Raw logits heatmap ---
1705
+ logits_np = logits.detach().cpu().numpy()
1706
+ fig_logits = plt.figure(figsize=(8, 6))
1707
+ sns.heatmap(logits_np, square=True, cmap="Blues", cbar=True, annot=False)
1708
+ plt.title(f"Raw Logits Heatmap — Epoch {epoch+1}, Loss {loss:.4f}")
1709
+ plt.xlabel("Audio Index")
1710
+ plt.ylabel("Image Index")
1711
+ raw_path = os.path.join(save_dir, 'logits', f"raw_logits_epoch_{epoch+1}_loss_{loss:.4f}.png")
1712
+ fig_logits.savefig(raw_path)
1713
+ writer.add_image("Heatmap/RawLogits", fig_to_tensor(fig_logits), global_step=epoch+1)
1714
+
1715
+ # --- Softmax probs heatmap ---
1716
+ probs_np = logits.softmax(dim=1).cpu().numpy()
1717
+ fig_probs = plt.figure(figsize=(8, 6))
1718
+ sns.heatmap(probs_np, square=True, cmap="Blues", cbar=True, annot=False)
1719
+ plt.title(f"Softmax Probabilities Heatmap — Epoch {epoch+1}, Loss {loss:.4f}")
1720
+ plt.xlabel("Audio Index")
1721
+ plt.ylabel("Image Index")
1722
+ prob_path = os.path.join(save_dir, "probs", f"probs_epoch_{epoch+1}_loss_{loss:.4f}.png")
1723
+ fig_probs.savefig(prob_path)
1724
+ writer.add_image("Heatmap/SoftmaxProbs", fig_to_tensor(fig_probs), global_step=epoch+1)
1725
+
1726
+
1727
+
1728
+ def train_model(model, train_loader, test_loader,
1729
+ optimizer, scheduler, device, log_dir,
1730
+ checkpoint_dir, resume=False, epochs=10):
1731
+
1732
+ os.makedirs(log_dir, exist_ok=True)
1733
+ os.makedirs(checkpoint_dir, exist_ok=True)
1734
+ csv_path = os.path.join(log_dir, "training_log.csv")
1735
+
1736
+ writer = SummaryWriter(log_dir=log_dir)
1737
+
1738
+ start_epoch = 0
1739
+ best_loss = float("inf")
1740
+ best_epoch = -1
1741
+
1742
+ if resume:
1743
+ start_epoch, best_loss = load_checkpoint(checkpoint_dir, model, optimizer, scheduler)
1744
+
1745
+ # If resuming, don't overwrite the CSV
1746
+ if not (resume and os.path.exists(csv_path)):
1747
+ with open(csv_path, mode='w', newline='') as f:
1748
+ writer_csv = csv.writer(f)
1749
+ writer_csv.writerow(["Epoch", "Best Epoch", "Train Loss", "Test Loss", "Best Loss", "Learning Rate"])
1750
+
1751
+ for epoch in trange(start_epoch, epochs, colour='yellow', dynamic_ncols=True):
1752
+ train_losses = []
1753
+ test_losses = []
1754
+
1755
+ train_loop = tqdm(train_loader, desc=f"[TrainEp {epoch+1}]", colour='blue', dynamic_ncols=True)
1756
+ for batch in train_loop:
1757
+ images = batch['image_tensor'].to(device)
1758
+ audios = batch['audio_tensor'].to(device)
1759
+ loss, logits, probs = train_batch(model, images, audios, optimizer)
1760
+ train_losses.append(loss)
1761
+ train_loop.set_postfix(trainLoss=loss)
1762
+
1763
+ test_loop = tqdm(test_loader, desc=f"[TestEp {epoch+1}]", colour='red', dynamic_ncols=True)
1764
+ for batch in test_loop:
1765
+ images = batch['image_tensor'].to(device)
1766
+ audios = batch['audio_tensor'].to(device)
1767
+ loss, logits, probs = evaluate_batch(model, images, audios)
1768
+ test_losses.append(loss)
1769
+ test_loop.set_postfix(testLoss=loss)
1770
+
1771
+ avg_train_loss = sum(train_losses) / len(train_losses)
1772
+ avg_test_loss = sum(test_losses) / len(test_losses)
1773
+
1774
+ current_lr = optimizer.param_groups[0]['lr']
1775
+
1776
+ writer.add_scalar("Loss/Train", avg_train_loss, epoch + 1)
1777
+ writer.add_scalar("Loss/Test", avg_test_loss, epoch + 1)
1778
+ writer.add_scalar("Learning Rate", current_lr, epoch + 1)
1779
+
1780
+ print(f"Epoch {epoch+1} | Train Loss: {avg_train_loss:.4f} | \
1781
+ Test Loss: {avg_test_loss:.4f} | LR: {current_lr:.2e}")
1782
+
1783
+ if avg_test_loss < best_loss:
1784
+ save_similarity_heatmaps(logits, epoch, avg_test_loss, checkpoint_dir, writer)
1785
+ best_loss = avg_test_loss
1786
+ best_epoch = epoch + 1
1787
+ save_checkpoint({
1788
+ 'epoch': epoch,
1789
+ 'model_state': model.state_dict(),
1790
+ 'optimizer_state': optimizer.state_dict(),
1791
+ 'best_loss': best_loss,
1792
+ 'scheduler_state': scheduler.state_dict() if scheduler else None
1793
+ }, checkpoint_dir, epoch)
1794
+ print(f">>> Saved new best model at epoch {epoch+1}")
1795
+
1796
+ scheduler.step()
1797
+ with open(csv_path, mode='a', newline='') as f:
1798
+ writer_csv = csv.writer(f)
1799
+ writer_csv.writerow([epoch + 1, best_epoch, avg_train_loss, avg_test_loss, best_loss, current_lr])
1800
+
1801
+ writer.close()
1802
+
1803
+
1804
+
1805
+ model_name = "csip_model_openClip_CLAP"
1806
+ learning_rate = 1e-4
1807
+ epochs = 100
1808
+ optimizer = torch.optim.AdamW(csip_model.parameters(), lr=learning_rate)
1809
+ scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs, eta_min=1e-10)
1810
+
1811
+ # subprocess.run([
1812
+ # "rm",
1813
+ # "-rf",
1814
+ # f"/home/IITB/ai-at-ieor/23m1521/ashish/MTP/Vaani/Img_Audio_Alignment/{model_name}",
1815
+ # ])
1816
+
1817
+ train_model(
1818
+ model=csip_model,
1819
+ train_loader=train_dataloader,
1820
+ test_loader=test_dataloader,
1821
+ optimizer=optimizer,
1822
+ scheduler=scheduler,
1823
+ device=device,
1824
+ log_dir=f"{model_name}/runs/csip",
1825
+ checkpoint_dir=f"{model_name}/checkpoints/csip",
1826
+ resume=True,
1827
+ epochs=epochs
1828
+ )
1829
+
1830
+
1831
+ # tensorboard --logdir=/home/IITB/ai-at-ieor/23m1521/ashish/MTP/Vaani/Img_Audio_Alignment/runs --port=6006
1832
+ # 127.0.0.1:40697
1833
+
1834
+ # tensorboard --logdir=/home/IITB/ai-at-ieor/23m1521/ashish/MTP/Vaani/Img_Audio_Alignment/runs --port=6006 --host=0.0.0.0
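# --- Editor's note: illustrative sketch, not part of the original commit ---
# To reload the best checkpoint written by save_checkpoint() for inference
# (keys as defined above; <N> is whatever epoch was saved last):
#   ckpt = torch.load(f"{model_name}/checkpoints/csip/csip_best_epoch_<N>.pt", map_location=device)
#   csip_model.load_state_dict(ckpt["model_state"])
#   csip_model.eval()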
Vaani/Img_Audio_Alignment/_2.2_OpenCLIP.py ADDED
The diff for this file is too large to render. See raw diff
 
Vaani/Img_Audio_Alignment/_2_Train.ipynb CHANGED
@@ -1,5 +1,79 @@
1
  {
2
  "cells": [
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  {
4
  "cell_type": "code",
5
  "execution_count": 1,
@@ -10,7 +84,46 @@
10
  "name": "stdout",
11
  "output_type": "stream",
12
  "text": [
13
- "Using device: cpu\n"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  ]
15
  }
16
  ],
@@ -80,7 +193,54 @@
80
  "from huggingface_hub.file_download import hf_hub_download\n",
81
  "from huggingface_hub.file_download import hf_hub_download\n",
82
  "from peft import get_peft_config, get_peft_model\n",
83
- "from transformers import CLIPVisionModel, AutoProcessor"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
  ]
85
  },
86
  {
@@ -1478,9 +1638,9 @@
1478
  "\n",
1479
  "model_state_dict = torch.load(model_fp, map_location=torch.device('cpu'))['model']\n",
1480
  "clap.load_state_dict(model_state_dict, strict=False)\n",
1481
- "clap.eval()\n",
1482
  "\n",
1483
- "clap_audio_encoder = clap.audio_encoder.eval().to(device)\n",
1484
  "\n",
1485
  "\n",
1486
  "# ENGLISH_AUDIO_DIR = r\"/home/IITB/ai-at-ieor/23m1521/datasets/Vaani/Audios/English\"\n",
@@ -1517,27 +1677,6 @@
1517
  {
1518
  "cell_type": "code",
1519
  "execution_count": 4,
1520
- "id": "f9998f39",
1521
- "metadata": {},
1522
- "outputs": [
1523
- {
1524
- "data": {
1525
- "text/plain": [
1526
- "'1.26.0'"
1527
- ]
1528
- },
1529
- "execution_count": 4,
1530
- "metadata": {},
1531
- "output_type": "execute_result"
1532
- }
1533
- ],
1534
- "source": [
1535
- "np.__version__"
1536
- ]
1537
- },
1538
- {
1539
- "cell_type": "code",
1540
- "execution_count": 5,
1541
  "id": "16c61e94",
1542
  "metadata": {},
1543
  "outputs": [],
@@ -1567,7 +1706,7 @@
1567
  },
1568
  {
1569
  "cell_type": "code",
1570
- "execution_count": 6,
1571
  "id": "61ef98b9",
1572
  "metadata": {},
1573
  "outputs": [],
@@ -1581,6 +1720,9 @@
1581
  " self.image_encoder = image_encoder # CLIPVisionModel\n",
1582
  " self.audio_encoder = audio_encoder # CLAP_audio_encoder\n",
1583
  "\n",
 
 
 
1584
  " # self.image_proj = nn.Linear(dim_img, dim_emb)\n",
1585
  " self.audio_proj = nn.Linear(dim_audio, dim_emb)\n",
1586
  "\n",
@@ -1589,9 +1731,9 @@
1589
  "\n",
1590
  " def forward(self, images, audios):\n",
1591
  " # Step 1: Feature extraction\n",
1592
- " with torch.no_grad():\n",
1593
- " with torch.inference_mode():\n",
1594
- " image_features = self.image_encoder(images).pooler_output # shape: [n, dim_img]\n",
1595
  " audio_features = self.audio_encoder(audios)[0] # shape: [n, dim_audio]\n",
1596
  "\n",
1597
  " # Step 2: Project and normalize\n",
@@ -1612,7 +1754,7 @@
1612
  },
1613
  {
1614
  "cell_type": "code",
1615
- "execution_count": 7,
1616
  "id": "b1a15b19",
1617
  "metadata": {},
1618
  "outputs": [
@@ -1665,7 +1807,7 @@
1665
  },
1666
  {
1667
  "cell_type": "code",
1668
- "execution_count": 8,
1669
  "id": "166105cd",
1670
  "metadata": {},
1671
  "outputs": [
@@ -1673,8 +1815,8 @@
1673
  "name": "stdout",
1674
  "output_type": "stream",
1675
  "text": [
1676
- "Image batch: torch.Size([32, 3, 224, 224])\n",
1677
- "Audio batch: torch.Size([32, 308700])\n"
1678
  ]
1679
  }
1680
  ],
@@ -1702,23 +1844,205 @@
1702
  ")\n",
1703
  "\n",
1704
  "batch = next(iter(train_dataloader))\n",
1705
- "print(\"Image batch:\", batch['image_tensor'].shape)\n",
1706
- "print(\"Audio batch:\", batch['audio_tensor'].shape)"
 
 
1707
  ]
1708
  },
1709
  {
1710
  "cell_type": "code",
1711
- "execution_count": 9,
1712
  "id": "d3b0a29f",
1713
  "metadata": {},
1714
  "outputs": [],
1715
  "source": [
1716
- "csip_model = CSIP(clip_vision_model.eval(), peft_clap_audio_encoder).to(device)"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1717
  ]
1718
  },
1719
  {
1720
  "cell_type": "code",
1721
  "execution_count": 10,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1722
  "id": "2475318a",
1723
  "metadata": {},
1724
  "outputs": [],
@@ -1729,7 +2053,7 @@
1729
  },
1730
  {
1731
  "cell_type": "code",
1732
- "execution_count": 11,
1733
  "id": "dc748c49",
1734
  "metadata": {},
1735
  "outputs": [],
@@ -1896,21 +2220,26 @@
1896
  }
1897
  ],
1898
  "source": [
1899
- "learning_rate = 1e-4\n",
1900
- "epochs = 10\n",
1901
  "optimizer = torch.optim.AdamW(csip_model.parameters(), lr=learning_rate)\n",
 
 
 
 
1902
  "\n",
1903
  "train_model(\n",
1904
- " model=csip_model,\n",
1905
- " train_loader=train_dataloader,\n",
1906
- " test_loader=test_dataloader,\n",
1907
- " optimizer=optimizer,\n",
1908
- " device=device,\n",
1909
- " log_dir=\"runs/csip\",\n",
1910
- " checkpoint_dir=\"checkpoints/csip\",\n",
1911
- " resume=True,\n",
1912
- " epochs=epochs\n",
1913
- " )"
 
1914
  ]
1915
  }
1916
  ],
 
1
  {
2
  "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "id": "1ab9f586",
7
+ "metadata": {},
8
+ "outputs": [
9
+ {
10
+ "name": "stdout",
11
+ "output_type": "stream",
12
+ "text": [
13
+ "Using device: cpu\n",
14
+ "Author: Ashish\n",
15
+ "\n",
16
+ "Last updated: 2025-05-23T00:51:29.906301+05:30\n",
17
+ "\n",
18
+ "Python implementation: CPython\n",
19
+ "Python version : 3.11.11\n",
20
+ "IPython version : 9.1.0\n",
21
+ "\n",
22
+ "conda environment: clap\n",
23
+ "\n",
24
+ "Compiler : GCC 11.2.0\n",
25
+ "OS : Linux\n",
26
+ "Release : 4.18.0-513.5.1.el8_9.x86_64\n",
27
+ "Machine : x86_64\n",
28
+ "Processor : x86_64\n",
29
+ "CPU cores : 48\n",
30
+ "Architecture: 64bit\n",
31
+ "\n",
32
+ "Hostname: login01\n",
33
+ "\n",
34
+ "peft : 0.15.2\n",
35
+ "torchlibrosa : 0.1.0\n",
36
+ "re : 2.2.1\n",
37
+ "numpy : 1.26.0\n",
38
+ "torchaudio : 2.1.2\n",
39
+ "pandas : 2.2.3\n",
40
+ "huggingface_hub: 0.31.2\n",
41
+ "PIL : 11.1.0\n",
42
+ "torchvision : 0.16.2\n",
43
+ "transformers : 4.51.3\n",
44
+ "csv : 1.0\n",
45
+ "matplotlib : 3.10.1\n",
46
+ "torch : 2.1.2\n",
47
+ "seaborn : 0.13.2\n",
48
+ "watermark : 2.5.0\n",
49
+ "tqdm : 4.67.1\n",
50
+ "sys : 3.11.11 (main, Dec 11 2024, 16:28:39) [GCC 11.2.0]\n",
51
+ "\n",
52
+ "GPU Info: NVIDIA drivers do not appear to be installed on this machine.\n",
53
+ "\n"
54
+ ]
55
+ },
56
+ {
57
+ "ename": "AttributeError",
58
+ "evalue": "'NoneType' object has no attribute '__dict__'",
59
+ "output_type": "error",
60
+ "traceback": [
61
+ "\u001b[31m---------------------------------------------------------------------------\u001b[39m",
62
+ "\u001b[31mAttributeError\u001b[39m Traceback (most recent call last)",
63
+ "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[5]\u001b[39m\u001b[32m, line 6\u001b[39m\n\u001b[32m 4\u001b[39m spec = importlib.util.spec_from_file_location(\u001b[33m\"\u001b[39m\u001b[33mopenclip_module\u001b[39m\u001b[33m\"\u001b[39m, module_path)\n\u001b[32m 5\u001b[39m openclip = importlib.util.module_from_spec(spec)\n\u001b[32m----> \u001b[39m\u001b[32m6\u001b[39m \u001b[43mspec\u001b[49m\u001b[43m.\u001b[49m\u001b[43mloader\u001b[49m\u001b[43m.\u001b[49m\u001b[43mexec_module\u001b[49m\u001b[43m(\u001b[49m\u001b[43mopenclip\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 8\u001b[39m \u001b[38;5;66;03m# Now access your function like this:\u001b[39;00m\n\u001b[32m 9\u001b[39m openclip.your_function_name()\n",
64
+ "\u001b[36mFile \u001b[39m\u001b[32m<frozen importlib._bootstrap_external>:940\u001b[39m, in \u001b[36mexec_module\u001b[39m\u001b[34m(self, module)\u001b[39m\n",
65
+ "\u001b[36mFile \u001b[39m\u001b[32m<frozen importlib._bootstrap>:241\u001b[39m, in \u001b[36m_call_with_frames_removed\u001b[39m\u001b[34m(f, *args, **kwds)\u001b[39m\n",
66
+ "\u001b[36mFile \u001b[39m\u001b[32m~/ashish/MTP/Vaani/Img_Audio_Alignment/_2.2_OpenCLIP.py:2310\u001b[39m\n\u001b[32m 2304\u001b[39m \u001b[38;5;129m@torch\u001b[39m.jit.ignore\n\u001b[32m 2305\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mset_grad_checkpointing\u001b[39m(\u001b[38;5;28mself\u001b[39m, enable=\u001b[38;5;28;01mTrue\u001b[39;00m):\n\u001b[32m 2306\u001b[39m \u001b[38;5;28mself\u001b[39m.grad_checkpointing = enable\n\u001b[32m-> \u001b[39m\u001b[32m2310\u001b[39m \u001b[38;5;129;43m@dataclass\u001b[39;49m\n\u001b[32m 2311\u001b[39m \u001b[38;5;28;43;01mclass\u001b[39;49;00m\u001b[38;5;250;43m \u001b[39;49m\u001b[34;43;01mCLIPVisionCfg\u001b[39;49;00m\u001b[43m:\u001b[49m\n\u001b[32m 2312\u001b[39m \u001b[43m \u001b[49m\u001b[43mlayers\u001b[49m\u001b[43m:\u001b[49m\u001b[43m \u001b[49m\u001b[43mUnion\u001b[49m\u001b[43m[\u001b[49m\u001b[43mTuple\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;28;43mint\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mint\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mint\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mint\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mint\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m \u001b[49m\u001b[43m=\u001b[49m\u001b[43m \u001b[49m\u001b[32;43m12\u001b[39;49m\n\u001b[32m 2313\u001b[39m \u001b[43m \u001b[49m\u001b[43mwidth\u001b[49m\u001b[43m:\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mint\u001b[39;49m\u001b[43m \u001b[49m\u001b[43m=\u001b[49m\u001b[43m \u001b[49m\u001b[32;43m768\u001b[39;49m\n",
67
+ "\u001b[36mFile \u001b[39m\u001b[32m~/.conda/envs/clap/lib/python3.11/dataclasses.py:1232\u001b[39m, in \u001b[36mdataclass\u001b[39m\u001b[34m(cls, init, repr, eq, order, unsafe_hash, frozen, match_args, kw_only, slots, weakref_slot)\u001b[39m\n\u001b[32m 1229\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m wrap\n\u001b[32m 1231\u001b[39m \u001b[38;5;66;03m# We're called as @dataclass without parens.\u001b[39;00m\n\u001b[32m-> \u001b[39m\u001b[32m1232\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mwrap\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mcls\u001b[39;49m\u001b[43m)\u001b[49m\n",
68
+ "\u001b[36mFile \u001b[39m\u001b[32m~/.conda/envs/clap/lib/python3.11/dataclasses.py:1222\u001b[39m, in \u001b[36mdataclass.<locals>.wrap\u001b[39m\u001b[34m(cls)\u001b[39m\n\u001b[32m 1221\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mwrap\u001b[39m(\u001b[38;5;28mcls\u001b[39m):\n\u001b[32m-> \u001b[39m\u001b[32m1222\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43m_process_class\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mcls\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minit\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mrepr\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43meq\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43morder\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43munsafe_hash\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 1223\u001b[39m \u001b[43m \u001b[49m\u001b[43mfrozen\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmatch_args\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkw_only\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mslots\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 1224\u001b[39m \u001b[43m \u001b[49m\u001b[43mweakref_slot\u001b[49m\u001b[43m)\u001b[49m\n",
69
+ "\u001b[36mFile \u001b[39m\u001b[32m~/.conda/envs/clap/lib/python3.11/dataclasses.py:947\u001b[39m, in \u001b[36m_process_class\u001b[39m\u001b[34m(cls, init, repr, eq, order, unsafe_hash, frozen, match_args, kw_only, slots, weakref_slot)\u001b[39m\n\u001b[32m 942\u001b[39m dataclasses = sys.modules[\u001b[34m__name__\u001b[39m]\n\u001b[32m 943\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m name, \u001b[38;5;28mtype\u001b[39m \u001b[38;5;129;01min\u001b[39;00m cls_annotations.items():\n\u001b[32m 944\u001b[39m \u001b[38;5;66;03m# See if this is a marker to change the value of kw_only.\u001b[39;00m\n\u001b[32m 945\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m (_is_kw_only(\u001b[38;5;28mtype\u001b[39m, dataclasses)\n\u001b[32m 946\u001b[39m \u001b[38;5;129;01mor\u001b[39;00m (\u001b[38;5;28misinstance\u001b[39m(\u001b[38;5;28mtype\u001b[39m, \u001b[38;5;28mstr\u001b[39m)\n\u001b[32m--> \u001b[39m\u001b[32m947\u001b[39m \u001b[38;5;129;01mand\u001b[39;00m \u001b[43m_is_type\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mtype\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mcls\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdataclasses\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdataclasses\u001b[49m\u001b[43m.\u001b[49m\u001b[43mKW_ONLY\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 948\u001b[39m \u001b[43m \u001b[49m\u001b[43m_is_kw_only\u001b[49m\u001b[43m)\u001b[49m)):\n\u001b[32m 949\u001b[39m \u001b[38;5;66;03m# Switch the default to kw_only=True, and ignore this\u001b[39;00m\n\u001b[32m 950\u001b[39m \u001b[38;5;66;03m# annotation: it's not a real field.\u001b[39;00m\n\u001b[32m 951\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m KW_ONLY_seen:\n\u001b[32m 952\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m(\u001b[33mf\u001b[39m\u001b[33m'\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mname\u001b[38;5;132;01m!r}\u001b[39;00m\u001b[33m is KW_ONLY, but KW_ONLY \u001b[39m\u001b[33m'\u001b[39m\n\u001b[32m 953\u001b[39m \u001b[33m'\u001b[39m\u001b[33mhas already been specified\u001b[39m\u001b[33m'\u001b[39m)\n",
70
+ "\u001b[36mFile \u001b[39m\u001b[32m~/.conda/envs/clap/lib/python3.11/dataclasses.py:712\u001b[39m, in \u001b[36m_is_type\u001b[39m\u001b[34m(annotation, cls, a_module, a_type, is_type_predicate)\u001b[39m\n\u001b[32m 708\u001b[39m module_name = match.group(\u001b[32m1\u001b[39m)\n\u001b[32m 709\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m module_name:\n\u001b[32m 710\u001b[39m \u001b[38;5;66;03m# No module name, assume the class's module did\u001b[39;00m\n\u001b[32m 711\u001b[39m \u001b[38;5;66;03m# \"from dataclasses import InitVar\".\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m712\u001b[39m ns = \u001b[43msys\u001b[49m\u001b[43m.\u001b[49m\u001b[43mmodules\u001b[49m\u001b[43m.\u001b[49m\u001b[43mget\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mcls\u001b[39;49m\u001b[43m.\u001b[49m\u001b[34;43m__module__\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m.\u001b[49m\u001b[34;43m__dict__\u001b[39;49m\n\u001b[32m 713\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m 714\u001b[39m \u001b[38;5;66;03m# Look up module_name in the class's module.\u001b[39;00m\n\u001b[32m 715\u001b[39m module = sys.modules.get(\u001b[38;5;28mcls\u001b[39m.\u001b[34m__module__\u001b[39m)\n",
71
+ "\u001b[31mAttributeError\u001b[39m: 'NoneType' object has no attribute '__dict__'"
72
+ ]
73
+ }
74
+ ],
75
+ "source": []
76
+ },
77
  {
78
  "cell_type": "code",
79
  "execution_count": 1,
 
84
  "name": "stdout",
85
  "output_type": "stream",
86
  "text": [
87
+ "Using device: cpu\n",
88
+ "Author: Ashish\n",
89
+ "\n",
90
+ "Last updated: 2025-05-18T20:43:23.098162+05:30\n",
91
+ "\n",
92
+ "Python implementation: CPython\n",
93
+ "Python version : 3.11.11\n",
94
+ "IPython version : 9.1.0\n",
95
+ "\n",
96
+ "conda environment: clap\n",
97
+ "\n",
98
+ "Compiler : GCC 11.2.0\n",
99
+ "OS : Linux\n",
100
+ "Release : 4.18.0-513.5.1.el8_9.x86_64\n",
101
+ "Machine : x86_64\n",
102
+ "Processor : x86_64\n",
103
+ "CPU cores : 48\n",
104
+ "Architecture: 64bit\n",
105
+ "\n",
106
+ "Hostname: login01\n",
107
+ "\n",
108
+ "pandas : 2.2.3\n",
109
+ "torch : 2.1.2\n",
110
+ "sys : 3.11.11 (main, Dec 11 2024, 16:28:39) [GCC 11.2.0]\n",
111
+ "matplotlib : 3.10.1\n",
112
+ "re : 2.2.1\n",
113
+ "seaborn : 0.13.2\n",
114
+ "numpy : 1.26.0\n",
115
+ "PIL : 11.1.0\n",
116
+ "torchaudio : 2.1.2\n",
117
+ "peft : 0.15.2\n",
118
+ "transformers : 4.51.3\n",
119
+ "tqdm : 4.67.1\n",
120
+ "watermark : 2.5.0\n",
121
+ "torchvision : 0.16.2\n",
122
+ "torchlibrosa : 0.1.0\n",
123
+ "huggingface_hub: 0.31.2\n",
124
+ "\n",
125
+ "GPU Info: NVIDIA drivers do not appear to be installed on this machine.\n",
126
+ "\n"
127
  ]
128
  }
129
  ],
 
193
  "from huggingface_hub.file_download import hf_hub_download\n",
194
  "from huggingface_hub.file_download import hf_hub_download\n",
195
  "from peft import get_peft_config, get_peft_model\n",
196
+ "from transformers import CLIPVisionModel, AutoProcessor\n",
197
+ "\n",
198
+ "from watermark import watermark\n",
199
+ "print(watermark(\n",
200
+ " author='Ashish',\n",
201
+ " # email='[email protected]',\n",
202
+ " current_date=True,\n",
203
+ " datename=True,\n",
204
+ " current_time=True,\n",
205
+ " iso8601=True,\n",
206
+ " timezone=True,\n",
207
+ " updated=True,\n",
208
+ " custom_time=None,\n",
209
+ " python=True,\n",
210
+ " # packages=\"torch,torchvision,numpy\",\n",
211
+ " conda=True,\n",
212
+ " hostname=True,\n",
213
+ " machine=True,\n",
214
+ " watermark=False,\n",
215
+ " iversions=True,\n",
216
+ " gpu=True,\n",
217
+ " globals_=globals()\n",
218
+ "))\n"
219
+ ]
220
+ },
221
+ {
222
+ "cell_type": "code",
223
+ "execution_count": 3,
224
+ "id": "f7260f1f",
225
+ "metadata": {},
226
+ "outputs": [
227
+ {
228
+ "data": {
229
+ "text/plain": [
230
+ "'ViT-H-14'"
231
+ ]
232
+ },
233
+ "execution_count": 3,
234
+ "metadata": {},
235
+ "output_type": "execute_result"
236
+ }
237
+ ],
238
+ "source": [
239
+ "model_name = 'ViT-H-14'\n",
240
+ "HF_HUB_PREFIX = 'hf-hub:'\n",
241
+ "has_hf_hub_prefix = model_name.startswith(HF_HUB_PREFIX)\n",
242
+ "model_name = model_name.replace('/', '-')\n",
243
+ "model_name"
244
  ]
245
  },
246
  {
 
1638
  "\n",
1639
  "model_state_dict = torch.load(model_fp, map_location=torch.device('cpu'))['model']\n",
1640
  "clap.load_state_dict(model_state_dict, strict=False)\n",
1641
+ "# clap.eval()\n",
1642
  "\n",
1643
+ "clap_audio_encoder = clap.audio_encoder.to(device)\n",
1644
  "\n",
1645
  "\n",
1646
  "# ENGLISH_AUDIO_DIR = r\"/home/IITB/ai-at-ieor/23m1521/datasets/Vaani/Audios/English\"\n",
 
1677
  {
1678
  "cell_type": "code",
1679
  "execution_count": 4,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1680
  "id": "16c61e94",
1681
  "metadata": {},
1682
  "outputs": [],
 
1706
  },
1707
  {
1708
  "cell_type": "code",
1709
+ "execution_count": 5,
1710
  "id": "61ef98b9",
1711
  "metadata": {},
1712
  "outputs": [],
 
1720
  " self.image_encoder = image_encoder # CLIPVisionModel\n",
1721
  " self.audio_encoder = audio_encoder # CLAP_audio_encoder\n",
1722
  "\n",
1723
+ " for param in self.image_encoder.parameters():\n",
1724
+ " param.requires_grad = False\n",
1725
+ " \n",
1726
  " # self.image_proj = nn.Linear(dim_img, dim_emb)\n",
1727
  " self.audio_proj = nn.Linear(dim_audio, dim_emb)\n",
1728
  "\n",
 
1731
  "\n",
1732
  " def forward(self, images, audios):\n",
1733
  " # Step 1: Feature extraction\n",
1734
+ " # with torch.no_grad():\n",
1735
+ " # with torch.inference_mode():\n",
1736
+ " image_features = self.image_encoder(images).pooler_output # shape: [n, dim_img]\n",
1737
  " audio_features = self.audio_encoder(audios)[0] # shape: [n, dim_audio]\n",
1738
  "\n",
1739
  " # Step 2: Project and normalize\n",
 
1754
  },
1755
  {
1756
  "cell_type": "code",
1757
+ "execution_count": 6,
1758
  "id": "b1a15b19",
1759
  "metadata": {},
1760
  "outputs": [
 
1807
  },
1808
  {
1809
  "cell_type": "code",
1810
+ "execution_count": 7,
1811
  "id": "166105cd",
1812
  "metadata": {},
1813
  "outputs": [
 
1815
  "name": "stdout",
1816
  "output_type": "stream",
1817
  "text": [
1818
+ "Image batch shape: torch.Size([32, 3, 224, 224])\n",
1819
+ "Audio batch shape: torch.Size([32, 308700])\n"
1820
  ]
1821
  }
1822
  ],
 
1844
  ")\n",
1845
  "\n",
1846
  "batch = next(iter(train_dataloader))\n",
1847
+ "image_tensor_batch = batch['image_tensor']\n",
1848
+ "audio_tensor_batch = batch['audio_tensor']\n",
1849
+ "print(\"Image batch shape:\", image_tensor_batch.shape) # [BATCH_SIZE, 3, 224, 224]\n",
1850
+ "print(\"Audio batch shape:\", audio_tensor_batch.shape) # [BATCH_SIZE, 1, 44100]"
1851
  ]
1852
  },
1853
  {
1854
  "cell_type": "code",
1855
+ "execution_count": 8,
1856
  "id": "d3b0a29f",
1857
  "metadata": {},
1858
  "outputs": [],
1859
  "source": [
1860
+ "csip_model = CSIP(clip_vision_model, peft_clap_audio_encoder).to(device)"
1861
+ ]
1862
+ },
1863
+ {
1864
+ "cell_type": "code",
1865
+ "execution_count": 14,
1866
+ "id": "6b8f3009",
1867
+ "metadata": {},
1868
+ "outputs": [
1869
+ {
1870
+ "data": {
1871
+ "text/plain": [
1872
+ "=========================================================================================================================================================================================\n",
1873
+ "Layer (type (var_name)) Input Shape Output Shape Param # Trainable Param %\n",
1874
+ "=========================================================================================================================================================================================\n",
1875
+ "CSIP (CSIP) [32, 3, 224, 224] -- 1 Partial 0.00%\n",
1876
+ "├─CLIPVisionModel (image_encoder) [32, 3, 224, 224] [32, 768] -- False --\n",
1877
+ "│ └─CLIPVisionTransformer (vision_model) -- [32, 768] (87,456,000) False 71.33%\n",
1878
+ "├─LoraModel (audio_encoder) [32, 308700] [32, 1024] -- Partial --\n",
1879
+ "│ └─AudioEncoder (model) -- -- 34,355,927 Partial 28.02%\n",
1880
+ "├─Linear (audio_proj) [32, 1024] [32, 768] 787,200 True 0.64%\n",
1881
+ "=========================================================================================================================================================================================\n",
1882
+ "Total params: 122,599,128\n",
1883
+ "Trainable params: 2,076,695\n",
1884
+ "Non-trainable params: 120,522,433\n",
1885
+ "Total mult-adds (Units.GIGABYTES): 43.51\n",
1886
+ "=========================================================================================================================================================================================\n",
1887
+ "Input size (MB): 58.78\n",
1888
+ "Forward/backward pass size (MB): 11962.14\n",
1889
+ "Params size (MB): 489.11\n",
1890
+ "Estimated Total Size (MB): 12510.03\n",
1891
+ "========================================================================================================================================================================================="
1892
+ ]
1893
+ },
1894
+ "execution_count": 14,
1895
+ "metadata": {},
1896
+ "output_type": "execute_result"
1897
+ }
1898
+ ],
1899
+ "source": [
1900
+ "from torchinfo import summary\n",
1901
+ "summary(model=csip_model,\n",
1902
+ " input_data=((image_tensor_batch), (audio_tensor_batch)),\n",
1903
+ " # input_size = (1, 3, config.IMAGE_SIZE, config.IMAGE_SIZE),\n",
1904
+ " dtypes=[torch.long],\n",
1905
+ " col_names = [\"input_size\", \"output_size\", \"num_params\", \"trainable\", \"params_percent\"],\n",
1906
+ " col_width=20,\n",
1907
+ " row_settings=[\"var_names\"],\n",
1908
+ " depth = 2,\n",
1909
+ " # verbose=2,\n",
1910
+ " # device=device\n",
1911
+ ")"
1912
+ ]
1913
+ },
1914
+ {
1915
+ "cell_type": "code",
1916
+ "execution_count": 12,
1917
+ "id": "956b682a",
1918
+ "metadata": {},
1919
+ "outputs": [
1920
+ {
1921
+ "data": {
1922
+ "text/plain": [
1923
+ "==========================================================================================================================================================================\n",
1924
+ "Layer (type (var_name)) Input Shape Output Shape Param # Trainable Param %\n",
1925
+ "==========================================================================================================================================================================\n",
1926
+ "CLIPVisionModel (CLIPVisionModel) [32, 3, 224, 224] [32, 768] -- False --\n",
1927
+ "├─CLIPVisionTransformer (vision_model) -- [32, 768] -- False --\n",
1928
+ "│ └─CLIPVisionEmbeddings (embeddings) [32, 3, 224, 224] [32, 50, 768] 768 False 0.00%\n",
1929
+ "│ │ └─Conv2d (patch_embedding) [32, 3, 224, 224] [32, 768, 7, 7] (2,359,296) False 2.70%\n",
1930
+ "│ │ └─Embedding (position_embedding) [1, 50] [1, 50, 768] (38,400) False 0.04%\n",
1931
+ "│ └─LayerNorm (pre_layrnorm) [32, 50, 768] [32, 50, 768] (1,536) False 0.00%\n",
1932
+ "│ └─CLIPEncoder (encoder) -- [32, 50, 768] -- False --\n",
1933
+ "│ │ └─ModuleList (layers) -- -- (85,054,464) False 97.25%\n",
1934
+ "│ └─LayerNorm (post_layernorm) [32, 768] [32, 768] (1,536) False 0.00%\n",
1935
+ "==========================================================================================================================================================================\n",
1936
+ "Total params: 87,456,000\n",
1937
+ "Trainable params: 0\n",
1938
+ "Non-trainable params: 87,456,000\n",
1939
+ "Total mult-adds (Units.GIGABYTES): 6.42\n",
1940
+ "==========================================================================================================================================================================\n",
1941
+ "Input size (MB): 19.27\n",
1942
+ "Forward/backward pass size (MB): 1317.58\n",
1943
+ "Params size (MB): 349.82\n",
1944
+ "Estimated Total Size (MB): 1686.67\n",
1945
+ "=========================================================================================================================================================================="
1946
+ ]
1947
+ },
1948
+ "execution_count": 12,
1949
+ "metadata": {},
1950
+ "output_type": "execute_result"
1951
+ }
1952
+ ],
1953
+ "source": [
1954
+ "from torchinfo import summary\n",
1955
+ "summary(model=csip_model.image_encoder,\n",
1956
+ " input_data=(image_tensor_batch),\n",
1957
+ " # input_size = (1, 3, config.IMAGE_SIZE, config.IMAGE_SIZE),\n",
1958
+ " dtypes=[torch.long],\n",
1959
+ " col_names = [\"input_size\", \"output_size\", \"num_params\", \"trainable\", \"params_percent\"],\n",
1960
+ " col_width=20,\n",
1961
+ " row_settings=[\"var_names\"],\n",
1962
+ " depth = 3,\n",
1963
+ " # verbose=2,\n",
1964
+ " # device=device\n",
1965
+ ")"
1966
+ ]
1967
+ },
1968
+ {
1969
+ "cell_type": "code",
1970
+ "execution_count": 13,
1971
+ "id": "0320cdc6",
1972
+ "metadata": {},
1973
+ "outputs": [
1974
+ {
1975
+ "data": {
1976
+ "text/plain": [
1977
+ "====================================================================================================================================================================================\n",
1978
+ "Layer (type (var_name)) Input Shape Output Shape Param # Trainable Param %\n",
1979
+ "====================================================================================================================================================================================\n",
1980
+ "LoraModel (LoraModel) [32, 308700] [32, 1024] -- Partial --\n",
1981
+ "├─AudioEncoder (model) -- -- -- Partial --\n",
1982
+ "│ └─HTSATWrapper (base) [32, 308700] [32, 768] -- Partial --\n",
1983
+ "│ │ └─HTSAT_Swin_Transformer (htsat) [32, 308700] [32, 768] 32,457,431 Partial 94.47%\n",
1984
+ "│ └─Projection (projection) [32, 768] [32, 1024] -- Partial --\n",
1985
+ "│ │ └─Linear (linear1) [32, 768] [32, 1024] 815,104 Partial 2.37%\n",
1986
+ "│ │ └─Linear (linear2) [32, 1024] [32, 1024] 1,081,344 Partial 3.15%\n",
1987
+ "│ │ └─Dropout (drop) [32, 1024] [32, 1024] -- -- --\n",
1988
+ "│ │ └─LayerNorm (layer_norm) [32, 1024] [32, 1024] 2,048 Partial 0.01%\n",
1989
+ "====================================================================================================================================================================================\n",
1990
+ "Total params: 34,355,927\n",
1991
+ "Trainable params: 1,289,494\n",
1992
+ "Non-trainable params: 33,066,433\n",
1993
+ "Total mult-adds (Units.GIGABYTES): 37.07\n",
1994
+ "====================================================================================================================================================================================\n",
1995
+ "Input size (MB): 39.51\n",
1996
+ "Forward/backward pass size (MB): 10644.36\n",
1997
+ "Params size (MB): 136.15\n",
1998
+ "Estimated Total Size (MB): 10820.02\n",
1999
+ "===================================================================================================================================================================================="
2000
+ ]
2001
+ },
2002
+ "execution_count": 13,
2003
+ "metadata": {},
2004
+ "output_type": "execute_result"
2005
+ }
2006
+ ],
2007
+ "source": [
2008
+ "from torchinfo import summary\n",
2009
+ "summary(model=csip_model.audio_encoder,\n",
2010
+ " input_data=(audio_tensor_batch),\n",
2011
+ " # input_size = (1, 3, config.IMAGE_SIZE, config.IMAGE_SIZE),\n",
2012
+ " dtypes=[torch.long],\n",
2013
+ " col_names = [\"input_size\", \"output_size\", \"num_params\", \"trainable\", \"params_percent\"],\n",
2014
+ " col_width=20,\n",
2015
+ " row_settings=[\"var_names\"],\n",
2016
+ " depth = 3,\n",
2017
+ " # verbose=2,\n",
2018
+ " # device=device\n",
2019
+ ")"
2020
  ]
2021
  },
2022
  {
2023
  "cell_type": "code",
2024
  "execution_count": 10,
2025
+ "id": "f436f5c7",
2026
+ "metadata": {},
2027
+ "outputs": [
2028
+ {
2029
+ "data": {
2030
+ "text/plain": [
2031
+ "0.01693890514457819"
2032
+ ]
2033
+ },
2034
+ "execution_count": 10,
2035
+ "metadata": {},
2036
+ "output_type": "execute_result"
2037
+ }
2038
+ ],
2039
+ "source": [
2040
+ "2_076_695 / 122_599_128"
2041
+ ]
2042
+ },
2043
+ {
2044
+ "cell_type": "code",
2045
+ "execution_count": null,
2046
  "id": "2475318a",
2047
  "metadata": {},
2048
  "outputs": [],
 
2053
  },
2054
  {
2055
  "cell_type": "code",
2056
+ "execution_count": 10,
2057
  "id": "dc748c49",
2058
  "metadata": {},
2059
  "outputs": [],
 
2220
  }
2221
  ],
2222
  "source": [
2223
+ "learning_rate = 1e-1\n",
2224
+ "epochs = 400\n",
2225
  "optimizer = torch.optim.AdamW(csip_model.parameters(), lr=learning_rate)\n",
2226
+ "warmup_epochs = 10\n",
2227
+ "warmup = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=0.01, total_iters=warmup_epochs)\n",
2228
+ "cosine_restarts = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=90, T_mult=2, eta_min=1e-10) # T_0 adjusted for warmup\n",
2229
+ "scheduler = torch.optim.lr_scheduler.SequentialLR(optimizer, [warmup, cosine_restarts], milestones=[warmup_epochs])\n",
2230
  "\n",
2231
  "train_model(\n",
2232
+ " model=csip_model,\n",
2233
+ " train_loader=train_dataloader,\n",
2234
+ " test_loader=test_dataloader,\n",
2235
+ " optimizer=optimizer,\n",
2236
+ " scheduler=scheduler,\n",
2237
+ " device=device,\n",
2238
+ " log_dir=\"runs/csip\",\n",
2239
+ " checkpoint_dir=\"checkpoints/csip\",\n",
2240
+ " resume=True,\n",
2241
+ " epochs=epochs\n",
2242
+ ")"
2243
  ]
2244
  }
2245
  ],
Vaani/Img_Audio_Alignment/available_img_audios_TEST2.csv ADDED
The diff for this file is too large to render. See raw diff
 
Vaani/Img_Audio_Alignment/available_img_audios_TRAIN2.csv ADDED
The diff for this file is too large to render. See raw diff
 
Vaani/Img_Audio_Alignment/checkpoints/csip/csip_best_epoch_201.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:94c820106ffbfa6de43a6f3939c9a4ad35b1edf26c20939d85c129f8c955425c
3
+ size 509342835
Vaani/Img_Audio_Alignment/checkpoints/csip/csip_best_epoch_202.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8465433973bedc41c3f52b335b0827b5e2f1f379708e8f9d1c2e5105743c5be3
3
+ size 509342835
Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_102_loss_4.1355.png ADDED

Git LFS Details

  • SHA256: b951b616685007b3715c3440328a27490f38a31101a2b0536648b7a7e190c2f2
  • Pointer size: 131 Bytes
  • Size of remote file: 499 kB
Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_103_loss_4.1354.png ADDED

Git LFS Details

  • SHA256: 6bff029f7363d1682f7f0e2eeeeb951ef73c79d509a41389d82c591da4fe36bf
  • Pointer size: 131 Bytes
  • Size of remote file: 502 kB
Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_104_loss_4.1349.png ADDED

Git LFS Details

  • SHA256: 22a5a6698ee12c71c4ab0d243bed8114bc9a21b408043bec0897e21bffbdc983
  • Pointer size: 131 Bytes
  • Size of remote file: 508 kB
Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_105_loss_4.1342.png ADDED

Git LFS Details

  • SHA256: c44e18b5c2077e3f189a0b0de2a299e28a0cee1099ca982197a0ed0d9500e0af
  • Pointer size: 131 Bytes
  • Size of remote file: 505 kB
Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_106_loss_4.1335.png ADDED

Git LFS Details

  • SHA256: 873a3b8e953c0605cd2801973e3b845653be63cf143cba3e369af4e9c2885d77
  • Pointer size: 131 Bytes
  • Size of remote file: 501 kB
Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_107_loss_4.1329.png ADDED

Git LFS Details

  • SHA256: 761b525524bc3a674fa6cf9a3349780cc2fa07851bdcacbb05f7c981862f0c0a
  • Pointer size: 131 Bytes
  • Size of remote file: 497 kB
Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_108_loss_4.1327.png ADDED

Git LFS Details

  • SHA256: 3e8834c3b6a3669ea19c35ced78429ef8f81cb33fb4feef4658c9675fdaf2f00
  • Pointer size: 131 Bytes
  • Size of remote file: 503 kB
Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_109_loss_4.1320.png ADDED

Git LFS Details

  • SHA256: 21cd45be1d367284e419360ab3a0fca461bae6eb6ee87ef9ffad3234e55dc150
  • Pointer size: 131 Bytes
  • Size of remote file: 499 kB
Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_10_loss_4.1517.png ADDED

Git LFS Details

  • SHA256: 589c6d83268a05ae7983a65f4a3eb66417283c54e1fac292f14529a56c77025e
  • Pointer size: 131 Bytes
  • Size of remote file: 525 kB
Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_110_loss_4.1312.png ADDED

Git LFS Details

  • SHA256: 5dc2d570d1237ea91c3d490581a7dc79e8da8deef8cfe58211fd05bdf8d0bde8
  • Pointer size: 131 Bytes
  • Size of remote file: 504 kB
Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_112_loss_4.1300.png ADDED

Git LFS Details

  • SHA256: 6febb2524230e988be57b40161e9f46ed0c353afbc649a13cdf96a04649e9f1e
  • Pointer size: 131 Bytes
  • Size of remote file: 500 kB
Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_113_loss_4.1299.png ADDED

Git LFS Details

  • SHA256: 858d7e174d41ed91470c3a5d4bcdbd128aab862cc4fa75027c82dc844007ade3
  • Pointer size: 131 Bytes
  • Size of remote file: 503 kB
Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_115_loss_4.1281.png ADDED

Git LFS Details

  • SHA256: 91e33053f6a5b20c0f4cc671e9d25bd4eff83fb651db446040445f325e185c00
  • Pointer size: 131 Bytes
  • Size of remote file: 494 kB
Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_116_loss_4.1266.png ADDED

Git LFS Details

  • SHA256: 976ae6dbe0bb774e0d82c605444c81a9799d68123e74699571b9f402d51b285e
  • Pointer size: 131 Bytes
  • Size of remote file: 499 kB
Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_119_loss_4.1249.png ADDED

Git LFS Details

  • SHA256: e63c158d5b4e703b824c722e6963675ab62eda32fee758dbf18686b6b9406a72
  • Pointer size: 131 Bytes
  • Size of remote file: 499 kB
Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_11_loss_4.1514.png ADDED

Git LFS Details

  • SHA256: de2ad1083ebfe9ef69c7f64797be23c317068f1b2c5b3cb73934ebd4d2c1e6b5
  • Pointer size: 131 Bytes
  • Size of remote file: 518 kB
Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_120_loss_4.1244.png ADDED

Git LFS Details

  • SHA256: 6b33d2907d2b2839cc42372df73eaf793fae8c5c5534c30088d230f167b9033e
  • Pointer size: 131 Bytes
  • Size of remote file: 498 kB
Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_122_loss_4.1238.png ADDED

Git LFS Details

  • SHA256: 49a4012047cff7f0d456d1f6b895461c5c006a551ea404c760df0930825a7ac4
  • Pointer size: 131 Bytes
  • Size of remote file: 500 kB
Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_123_loss_4.1229.png ADDED

Git LFS Details

  • SHA256: 51ad2a5595b94bcb214473492f67a8d733dda1423cd8880965f5727e81d27b1e
  • Pointer size: 131 Bytes
  • Size of remote file: 499 kB
Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_125_loss_4.1207.png ADDED

Git LFS Details

  • SHA256: ca2b91ba72290b3d2ecbf0be10d678452e85ba573b25ef7bd8dd60fbb9ad0ddb
  • Pointer size: 131 Bytes
  • Size of remote file: 498 kB
Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_126_loss_4.1206.png ADDED

Git LFS Details

  • SHA256: 51877c1305dac97a0a2c95adc155f22d897ab216b3b2d5ae34926a615cb37ace
  • Pointer size: 131 Bytes
  • Size of remote file: 503 kB
Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_127_loss_4.1205.png ADDED

Git LFS Details

  • SHA256: ea55387e5fc794e53a78cc25daf05c2e30f8d1830f8da3868179ff1056d4d589
  • Pointer size: 131 Bytes
  • Size of remote file: 500 kB
Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_128_loss_4.1183.png ADDED

Git LFS Details

  • SHA256: 4b0aa99e38e54ef8f6c0d9fcaecfb154af19cdb72c52163fa0592d53058fad14
  • Pointer size: 131 Bytes
  • Size of remote file: 503 kB
Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_12_loss_4.1511.png ADDED

Git LFS Details

  • SHA256: 00a1afea752b9283fd3155f7173360f41546668366a10d8f07bf1aec52e0ef79
  • Pointer size: 131 Bytes
  • Size of remote file: 523 kB
Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_130_loss_4.1173.png ADDED

Git LFS Details

  • SHA256: 43f5edf6e5b265b9575d81725d0e3a08eca9d0875408c5197049560a39df64b6
  • Pointer size: 131 Bytes
  • Size of remote file: 504 kB
Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_131_loss_4.1162.png ADDED

Git LFS Details

  • SHA256: b7b60b9923013e62c0f54f1a7b59b9663d95ec38600f7112da161c7bc8f84bfb
  • Pointer size: 131 Bytes
  • Size of remote file: 502 kB
Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_132_loss_4.1158.png ADDED

Git LFS Details

  • SHA256: 5ba4dafc926345e66faa96b3e02afc1d785feca5cd15a048fd694d903bae8cd8
  • Pointer size: 131 Bytes
  • Size of remote file: 504 kB
Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_133_loss_4.1141.png ADDED

Git LFS Details

  • SHA256: d13166124da5e71fcb24117e9b988aaf037dba90d3a0cde71a0f7c69af9d0f3b
  • Pointer size: 131 Bytes
  • Size of remote file: 508 kB
Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_134_loss_4.1138.png ADDED

Git LFS Details

  • SHA256: 979f30522e8059e8200118e847831e4fc196728f3326c866721ddf23f87d06ee
  • Pointer size: 131 Bytes
  • Size of remote file: 509 kB
Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_135_loss_4.1132.png ADDED

Git LFS Details

  • SHA256: 233fbd95d332286a0cab91bde8b8603277ea72bb02a560f5c30029d16b8eacbf
  • Pointer size: 131 Bytes
  • Size of remote file: 505 kB
Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_136_loss_4.1129.png ADDED

Git LFS Details

  • SHA256: 066844d24a1506aa8989001d21767904fd224b4c4ec0d3597260a6396d8f842e
  • Pointer size: 131 Bytes
  • Size of remote file: 506 kB
Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_137_loss_4.1123.png ADDED

Git LFS Details

  • SHA256: 22d21ec978bb03a84a815a89bbfcb0e2d22f2d0bae06e5ca275b74c9a10d1fa0
  • Pointer size: 131 Bytes
  • Size of remote file: 505 kB
Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_138_loss_4.1104.png ADDED

Git LFS Details

  • SHA256: 97bf23207fd980576899d069c8b3423a3465a03a8d458cbc02f2b77a9bbd02bc
  • Pointer size: 131 Bytes
  • Size of remote file: 510 kB
Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_139_loss_4.1101.png ADDED

Git LFS Details

  • SHA256: dfe7fc340dda27d99e49f40f7e65483c06e87882bc5c08b5ae7f9ea87c64114b
  • Pointer size: 131 Bytes
  • Size of remote file: 511 kB
Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_13_loss_4.1507.png ADDED

Git LFS Details

  • SHA256: bb0a365785fb3643786432551c63d03472308b2e15f906afaea81300145ccbd7
  • Pointer size: 131 Bytes
  • Size of remote file: 524 kB
Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_140_loss_4.1085.png ADDED

Git LFS Details

  • SHA256: 42f5bf1c28305184979cc43b290833592cc8e03ec04b6ccc330c45125493eedc
  • Pointer size: 131 Bytes
  • Size of remote file: 512 kB
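Each filename above encodes its epoch and the training loss at that epoch (raw_logits_epoch_{epoch}_loss_{loss}.png), with the loss falling from ~4.1517 at epoch 10 to ~4.1085 at epoch 140; note the listing itself is in lexicographic, not numeric, epoch order. A minimal sketch (an assumed helper, not taken from the training scripts) that recovers this loss-vs-epoch trend directly from the filenames:

```python
# Parse raw_logits_epoch_<E>_loss_<L>.png filenames into (epoch, loss) pairs
# and plot the trend. The directory path matches the one listed in this diff.
import re
from pathlib import Path

import matplotlib.pyplot as plt

LOGITS_DIR = Path("Vaani/Img_Audio_Alignment/checkpoints/csip/logits")
PATTERN = re.compile(r"raw_logits_epoch_(\d+)_loss_([\d.]+)\.png")

points = []
for png in LOGITS_DIR.glob("raw_logits_epoch_*_loss_*.png"):
    match = PATTERN.fullmatch(png.name)
    if match:
        points.append((int(match.group(1)), float(match.group(2))))

points.sort()  # sort numerically by epoch, unlike the lexicographic listing above
if points:
    epochs, losses = zip(*points)
    plt.plot(epochs, losses, marker=".")
    plt.xlabel("epoch")
    plt.ylabel("loss (parsed from filename)")
    plt.title("CSIP training loss recovered from logits plot filenames")
    plt.savefig("csip_loss_from_filenames.png")
```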