alpha31476 committed on
Commit 0c51fb1 · verified
1 Parent(s): bda7d99

Image Audio Alignment Train

This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50)
  1. .gitattributes +1 -0
  2. Vaani/23m1521.code-workspace +13 -0
  3. Vaani/Image-Audio-Hindi.csv +3 -0
  4. Vaani/Img_Audio_Alignment/_2.1_Train.py +1788 -0
  5. Vaani/Img_Audio_Alignment/_2_Train.ipynb +1938 -0
  6. Vaani/Img_Audio_Alignment/_2_Train.py +315 -48
  7. Vaani/Img_Audio_Alignment/available_img_audios_TEST.csv +0 -0
  8. Vaani/Img_Audio_Alignment/available_img_audios_TRAIN.csv +0 -0
  9. Vaani/Img_Audio_Alignment/checkpoints/csip/csip_best_epoch_203.pt +3 -0
  10. Vaani/Img_Audio_Alignment/checkpoints/csip/csip_best_epoch_27.pt +3 -0
  11. Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_15_loss_3.4611.png +0 -0
  12. Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_1_loss_3.4611.png +0 -0
  13. Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_203_loss_3.4611.png +0 -0
  14. Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_27_loss_3.4611.png +0 -0
  15. Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_5_loss_3.4611.png +0 -0
  16. Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_15_loss_3.4611.png +0 -0
  17. Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_1_loss_3.4611.png +0 -0
  18. Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_203_loss_3.4611.png +0 -0
  19. Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_27_loss_3.4611.png +0 -0
  20. Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_5_loss_3.4611.png +0 -0
  21. Vaani/Img_Audio_Alignment/runs/csip/events.out.tfevents.1747523457.rmgpu025.3375295.0 +3 -0
  22. Vaani/Img_Audio_Alignment/runs/csip/training_log.csv +367 -0
  23. Vaani/Vaani-Audio-Image-Hindi-copy.csv +0 -0
  24. Vaani/Vaani-Audio-Image-Hindi.csv +0 -0
  25. Vaani/Vaani-Audio-Image-Hindi2.csv +0 -0
  26. Vaani/Vaani-Audio-Image-Hindi3.csv +0 -0
  27. Vaani/Vaani-Images-Audio-JSON.parquet +3 -0
  28. Vaani/VaaniLDM/ddpm_ckpt_epoch59.pt +3 -0
  29. Vaani/VaaniLDM/ddpm_ckpt_epoch60.pt +3 -0
  30. Vaani/VaaniLDM/samples/x0_0.png +2 -2
  31. Vaani/VaaniLDM/samples/x0_1.png +0 -0
  32. Vaani/VaaniLDM/samples/x0_10.png +0 -0
  33. Vaani/VaaniLDM/samples/x0_100.png +0 -0
  34. Vaani/VaaniLDM/samples/x0_101.png +0 -0
  35. Vaani/VaaniLDM/samples/x0_102.png +0 -0
  36. Vaani/VaaniLDM/samples/x0_103.png +0 -0
  37. Vaani/VaaniLDM/samples/x0_104.png +0 -0
  38. Vaani/VaaniLDM/samples/x0_105.png +0 -0
  39. Vaani/VaaniLDM/samples/x0_106.png +0 -0
  40. Vaani/VaaniLDM/samples/x0_107.png +0 -0
  41. Vaani/VaaniLDM/samples/x0_108.png +0 -0
  42. Vaani/VaaniLDM/samples/x0_109.png +0 -0
  43. Vaani/VaaniLDM/samples/x0_11.png +0 -0
  44. Vaani/VaaniLDM/samples/x0_110.png +0 -0
  45. Vaani/VaaniLDM/samples/x0_111.png +0 -0
  46. Vaani/VaaniLDM/samples/x0_112.png +0 -0
  47. Vaani/VaaniLDM/samples/x0_113.png +0 -0
  48. Vaani/VaaniLDM/samples/x0_114.png +0 -0
  49. Vaani/VaaniLDM/samples/x0_115.png +0 -0
  50. Vaani/VaaniLDM/samples/x0_116.png +0 -0
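The changes above add CSIP checkpoints, logit/probability plots, and a training log alongside the new _2.1_Train.py script. For orientation only, the sketch below shows the kind of symmetric CLIP-style contrastive objective an image-audio alignment run like this typically optimizes; the function name, embedding size, and temperature handling are illustrative assumptions and are not taken from the committed code.

import torch
import torch.nn.functional as F

def clip_style_contrastive_loss(image_embed: torch.Tensor,
                                audio_embed: torch.Tensor,
                                logit_scale: torch.Tensor) -> torch.Tensor:
    # Illustrative sketch only: names and shapes are assumptions, not the committed code.
    # Normalize both modalities so dot products become cosine similarities.
    image_embed = F.normalize(image_embed, dim=-1)
    audio_embed = F.normalize(audio_embed, dim=-1)
    # Pairwise similarity matrix, scaled by the (already exponentiated) temperature.
    logits = logit_scale * image_embed @ audio_embed.t()
    # Matching image-audio pairs sit on the diagonal.
    targets = torch.arange(logits.size(0), device=logits.device)
    # Symmetric cross-entropy: image-to-audio (rows) and audio-to-image (columns).
    loss_i = F.cross_entropy(logits, targets)
    loss_a = F.cross_entropy(logits.t(), targets)
    return (loss_i + loss_a) / 2

if __name__ == "__main__":
    torch.manual_seed(0)
    img = torch.randn(4, 512)            # assumed image embedding batch
    aud = torch.randn(4, 512)            # assumed audio embedding batch
    scale = torch.tensor(1.0 / 0.07)     # exp() of the usual log(1/0.07) init
    print(clip_style_contrastive_loss(img, aud, scale).item())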
.gitattributes CHANGED
@@ -142,3 +142,4 @@ Vaani/audio_urls[[:space:]]copy.txt filter=lfs diff=lfs merge=lfs -text
  Vaani/imageBY.csv filter=lfs diff=lfs merge=lfs -text
  Vaani/imageBY2.csv filter=lfs diff=lfs merge=lfs -text
  Vaani/imageBY3.csv filter=lfs diff=lfs merge=lfs -text
+ Vaani/Image-Audio-Hindi.csv filter=lfs diff=lfs merge=lfs -text
Vaani/23m1521.code-workspace ADDED
@@ -0,0 +1,13 @@
+ {
+     "folders": [
+         {
+             "path": "../../.."
+         },
+         {
+             "path": ".."
+         }
+     ],
+     "settings": {
+         "terminal.integrated.mouseWheelZoom": true
+     }
+ }
Vaani/Image-Audio-Hindi.csv ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a44517e6935e21d887dbe1fe076e9b14607a22c5a709bb33e504b353c6fbb99d
+ size 1561128197
Vaani/Img_Audio_Alignment/_2.1_Train.py ADDED
@@ -0,0 +1,1788 @@
1
+ # ==================================================================
2
+ # L A T E N T D I F F U S I O N M O D E L
3
+ # ==================================================================
4
+ # Author : Ashish Kumar Uchadiya
5
+ # Created : May 11, 2025
6
+ # Description: This script implements the training of a VQ-VAE model for
7
+ # image reconstruction, integrated with Latent Diffusion Models (LDMs) and
8
+ # audio conditioning. The VQ-VAE maps images to a discrete latent space,
9
+ # which is then modeled by the LDM for learning a diffusion process over the
10
+ # compressed representation. Audio features are used as conditioning inputs
11
+ # to guide the generation process. The training minimizes a combination of
12
+ # LPIPS (Learned Perceptual Image Patch Similarity) loss for perceptual
13
+ # fidelity and PatchGAN loss to enforce local realism. This setup enables
14
+ # efficient and semantically-aware generation of high-quality images driven
15
+ # by audio cues.
16
+ # ==================================================================
17
+ # I M P O R T S
18
+ # ==================================================================
19
+ from __future__ import annotations
20
+ import warnings
21
+ warnings.filterwarnings("ignore")
22
+
23
+
24
+ import os
25
+ import io
26
+ import sys
27
+ import math
28
+ import random
29
+ import collections
30
+ import collections.abc
31
+ import re
32
+ from itertools import repeat
33
+ from pathlib import Path
34
+ from typing import Optional, Tuple, Union, List, Dict
35
+
36
+ import csv
37
+ import numpy as np
38
+ import pandas as pd
39
+ from PIL import Image
40
+ import seaborn as sns
41
+ import matplotlib.pyplot as plt
42
+ from tqdm import trange, tqdm
43
+
44
+ import torch
45
+ import torch.nn.functional as F
46
+ from torch import nn
47
+ from torch.nn.init import _calculate_fan_in_and_fan_out
48
+ import torch.utils.checkpoint as checkpoint
49
+
50
+ import torchvision as tv
51
+ from torchvision.transforms import v2
52
+ from torch.utils.tensorboard import SummaryWriter
53
+
54
+ os.environ["CUDA_VISIBLE_DEVICES"] = "1"
55
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
56
+ print(f"Using device: {device}")
57
+
58
+ import torchaudio
59
+ import torchaudio.transforms as T
60
+ from torchlibrosa.stft import Spectrogram, LogmelFilterBank
61
+ from torchlibrosa.augmentation import SpecAugmentation
62
+
63
+ from transformers import AutoModel, AutoTokenizer, logging
64
+ from huggingface_hub.file_download import hf_hub_download
65
+ from huggingface_hub.file_download import hf_hub_download
66
+ from peft import get_peft_config, get_peft_model
67
+ from transformers import CLIPVisionModel, AutoProcessor
68
+
69
+
70
+ # ==================================================================
71
+ # H T S - A T
72
+ # ==================================================================
73
+ class HTSATConfig:
74
+ # Ke Chen
75
76
+ # HTS-AT: A HIERARCHICAL TOKEN-SEMANTIC AUDIO TRANSFORMER FOR SOUND CLASSIFICATION AND DETECTION
77
+ # The configuration for training the model
78
+
79
+ exp_name = "exp_htsat_pretrain" # the saved ckpt prefix name of the model
80
+ workspace = "/home/kechen/Research/HTSAT" # the folder of your code
81
+ dataset_path = "/home/Research/audioset" # the dataset path
82
+ desed_folder = "/home/Research/DESED" # the desed file
83
+
84
+ dataset_type = "audioset" # "audioset" "esc-50" "scv2"
85
+ index_type = "full_train" # only works for audioset
86
+ balanced_data = True # only works for audioset
87
+
88
+ loss_type = "clip_bce" #
89
+ # AudioSet & SCV2: "clip_bce" | ESC-50: "clip_ce"
90
+
91
+ # trained from a checkpoint, or evaluate a single model
92
+ resume_checkpoint = None
93
+ # "/home/Research/model_backup/AudioSet/HTSAT_AudioSet_Saved_1.ckpt"
94
+
95
+ esc_fold = 0 # just for esc dataset, select the fold you need for evaluation and (+1) validation
96
+
97
+
98
+ debug = False
99
+
100
+ random_seed = 970131 # 19970318 970131 12412 127777 1009 34047
101
+ batch_size = 32 * 4 # batch size per GPU x GPU number , default is 32 x 4 = 128
102
+ learning_rate = 1e-3 # 1e-4 also workable
103
+ max_epoch = 100
104
+ num_workers = 3
105
+
106
+ lr_scheduler_epoch = [10,20,30]
107
+ lr_rate = [0.02, 0.05, 0.1]
108
+
109
+ # these data preparation optimizations do not bring many improvements, so deprecated
110
+ enable_token_label = False # token label
111
+ class_map_path = "class_hier_map.npy"
112
+ class_filter = None
113
+ retrieval_index = [15382, 9202, 130, 17618, 17157, 17516, 16356, 6165, 13992, 9238, 5550, 5733, 1914, 1600, 3450, 13735, 11108, 3762,
114
+ 9840, 11318, 8131, 4429, 16748, 4992, 16783, 12691, 4945, 8779, 2805, 9418, 2797, 14357, 5603, 212, 3852, 12666, 1338, 10269, 2388, 8260, 4293, 14454, 7677, 11253, 5060, 14938, 8840, 4542, 2627, 16336, 8992, 15496, 11140, 446, 6126, 10691, 8624, 10127, 9068, 16710, 10155, 14358, 7567, 5695, 2354, 8057, 17635, 133, 16183, 14535, 7248, 4560, 14429, 2463, 10773, 113, 2462, 9223, 4929, 14274, 4716, 17307, 4617, 2132, 11083, 1039, 1403, 9621, 13936, 2229, 2875, 17840, 9359, 13311, 9790, 13288, 4750, 17052, 8260, 14900]
115
+ token_label_range = [0.2,0.6]
116
+ enable_time_shift = False # shift time
117
+ enable_label_enhance = False # enhance hierarchical label
118
+ enable_repeat_mode = False # repeat the spectrogram / reshape the spectrogram
119
+
120
+
121
+
122
+ # for model's design
123
+ enable_tscam = True # enable the token-semantic layer
124
+
125
+ # for signal processing
126
+ sample_rate = 32000 # 16000 for scv2, 32000 for audioset and esc-50
127
+ clip_samples = sample_rate * 10 # audio_set 10-sec clip
128
+ window_size = 1024
129
+ hop_size = 320 # 160 for scv2, 320 for audioset and esc-50
130
+ mel_bins = 64
131
+ fmin = 50
132
+ fmax = 14000
133
+ shift_max = int(clip_samples * 0.5)
134
+
135
+ # for data collection
136
+ classes_num = 527 # esc: 50 | audioset: 527 | scv2: 35
137
+ patch_size = (25, 4) # deprecated
138
+ crop_size = None # int(clip_samples * 0.5) deprecated
139
+
140
+ # for htsat hyperparameter
141
+ htsat_window_size = 8
142
+ htsat_spec_size = 256
143
+ htsat_patch_size = 4
144
+ htsat_stride = (4, 4)
145
+ htsat_num_head = [4,8,16,32]
146
+ htsat_dim = 96
147
+ htsat_depth = [2,2,6,2]
148
+
149
+ swin_pretrain_path = None
150
+ # "/home/Research/model_backup/pretrain/swin_tiny_c24_patch4_window8_256.pth"
151
+
152
+ # Some Deprecated Optimization in the model design, check the model code for details
153
+ htsat_attn_heatmap = False
154
+ htsat_hier_output = False
155
+ htsat_use_max = False
156
+
157
+
158
+ # for ensemble test
159
+
160
+ ensemble_checkpoints = []
161
+ ensemble_strides = []
162
+
163
+
164
+ # weight average folder
165
+ wa_folder = "/home/version_0/checkpoints/"
166
+ # weight average output filename
167
+ wa_model_path = "HTSAT_AudioSet_Saved_x.ckpt"
168
+
169
+ esm_model_pathes = [
170
+ "/home/Research/model_backup/AudioSet/HTSAT_AudioSet_Saved_1.ckpt",
171
+ "/home/Research/model_backup/AudioSet/HTSAT_AudioSet_Saved_2.ckpt",
172
+ "/home/Research/model_backup/AudioSet/HTSAT_AudioSet_Saved_3.ckpt",
173
+ "/home/Research/model_backup/AudioSet/HTSAT_AudioSet_Saved_4.ckpt",
174
+ "/home/Research/model_backup/AudioSet/HTSAT_AudioSet_Saved_5.ckpt",
175
+ "/home/Research/model_backup/AudioSet/HTSAT_AudioSet_Saved_6.ckpt"
176
+ ]
177
+
178
+ # for framewise localization
179
+ heatmap_dir = "/home/Research/heatmap_output"
180
+ test_file = "htsat-test-ensemble"
181
+ fl_local = False # indicate if we need to use this dataset for the framewise detection
182
+ fl_dataset = "/home/Research/desed/desedim_embval.npy"
183
+ fl_class_num = [
184
+ "Speech", "Frying", "Dishes", "Running_water",
185
+ "Blender", "Electric_shaver_toothbrush", "Alarm_bell_ringing",
186
+ "Cat", "Dog", "Vacuum_cleaner"
187
+ ]
188
+
189
+ # map 527 classes into 10 classes
190
+ fl_audioset_mapping = [
191
+ [0,1,2,3,4,5,6,7],
192
+ [366, 367, 368],
193
+ [364],
194
+ [288, 289, 290, 291, 292, 293, 294, 295, 296, 297],
195
+ [369],
196
+ [382],
197
+ [310, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, 401, 402],
198
+ [81, 82, 83, 84, 85],
199
+ [74, 75, 76, 77, 78, 79],
200
+ [377]
201
+ ]
202
+
203
+
204
+
205
+ def _ntuple(n):
206
+ def parse(x):
207
+ if isinstance(x, collections.abc.Iterable):
208
+ return x
209
+ return tuple(repeat(x, n))
210
+ return parse
211
+
212
+ to_1tuple = _ntuple(1)
213
+ to_2tuple = _ntuple(2)
214
+ to_3tuple = _ntuple(3)
215
+ to_4tuple = _ntuple(4)
216
+ to_ntuple = _ntuple
217
+
218
+ def do_mixup(x, mixup_lambda):
219
+ """Mixup x of even indexes (0, 2, 4, ...) with x of odd indexes
220
+ (1, 3, 5, ...).
221
+ Args:
222
+ x: (batch_size * 2, ...)
223
+ mixup_lambda: (batch_size * 2,)
224
+ Returns:
225
+ out: (batch_size, ...)
226
+ """
227
+ out = (x[0 :: 2].transpose(0, -1) * mixup_lambda[0 :: 2] + \
228
+ x[1 :: 2].transpose(0, -1) * mixup_lambda[1 :: 2]).transpose(0, -1)
229
+ return out
230
+
231
+ def interpolate(x, ratio):
232
+ """Interpolate data in time domain. This is used to compensate the
233
+ resolution reduction in downsampling of a CNN.
234
+
235
+ Args:
236
+ x: (batch_size, time_steps, classes_num)
237
+ ratio: int, ratio to interpolate
238
+ Returns:
239
+ upsampled: (batch_size, time_steps * ratio, classes_num)
240
+ """
241
+ (batch_size, time_steps, classes_num) = x.shape
242
+ upsampled = x[:, :, None, :].repeat(1, 1, ratio, 1)
243
+ upsampled = upsampled.reshape(batch_size, time_steps * ratio, classes_num)
244
+ return upsampled
245
+
246
+
247
+ def drop_path(x, drop_prob: float = 0., training: bool = False):
248
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
249
+ This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
250
+ the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
251
+ See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
252
+ changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
253
+ 'survival rate' as the argument.
254
+ """
255
+ if drop_prob == 0. or not training:
256
+ return x
257
+ keep_prob = 1 - drop_prob
258
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
259
+ random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
260
+ random_tensor.floor_() # binarize
261
+ output = x.div(keep_prob) * random_tensor
262
+ return output
263
+
264
+
265
+ class DropPath(nn.Module):
266
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
267
+ """
268
+ def __init__(self, drop_prob=None):
269
+ super(DropPath, self).__init__()
270
+ self.drop_prob = drop_prob
271
+
272
+ def forward(self, x):
273
+ return drop_path(x, self.drop_prob, self.training)
274
+
275
+ class PatchEmbed(nn.Module):
276
+ """ 2D Image to Patch Embedding
277
+ """
278
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True, patch_stride = 16):
279
+ super().__init__()
280
+ img_size = to_2tuple(img_size)
281
+ patch_size = to_2tuple(patch_size)
282
+ patch_stride = to_2tuple(patch_stride)
283
+ self.img_size = img_size
284
+ self.patch_size = patch_size
285
+ self.patch_stride = patch_stride
286
+ self.grid_size = (img_size[0] // patch_stride[0], img_size[1] // patch_stride[1])
287
+ self.num_patches = self.grid_size[0] * self.grid_size[1]
288
+ self.flatten = flatten
289
+ self.in_chans = in_chans
290
+ self.embed_dim = embed_dim
291
+
292
+ padding = ((patch_size[0] - patch_stride[0]) // 2, (patch_size[1] - patch_stride[1]) // 2)
293
+
294
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_stride, padding=padding)
295
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
296
+
297
+ def forward(self, x):
298
+ B, C, H, W = x.shape
299
+ assert H == self.img_size[0] and W == self.img_size[1], \
300
+ f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
301
+ x = self.proj(x)
302
+ if self.flatten:
303
+ x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
304
+ x = self.norm(x)
305
+ return x
306
+
307
+ class Mlp(nn.Module):
308
+ """ MLP as used in Vision Transformer, MLP-Mixer and related networks
309
+ """
310
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
311
+ super().__init__()
312
+ out_features = out_features or in_features
313
+ hidden_features = hidden_features or in_features
314
+ self.fc1 = nn.Linear(in_features, hidden_features)
315
+ self.act = act_layer()
316
+ self.fc2 = nn.Linear(hidden_features, out_features)
317
+ self.drop = nn.Dropout(drop)
318
+
319
+ def forward(self, x):
320
+ x = self.fc1(x)
321
+ x = self.act(x)
322
+ x = self.drop(x)
323
+ x = self.fc2(x)
324
+ x = self.drop(x)
325
+ return x
326
+
327
+ def _no_gradim_audiorunc_normal_(tensor, mean, std, a, b):
328
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
329
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
330
+ def norm_cdf(x):
331
+ # Computes standard normal cumulative distribution function
332
+ return (1. + math.erf(x / math.sqrt(2.))) / 2.
333
+
334
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
335
+ warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
336
+ "The distribution of values may be incorrect.",
337
+ stacklevel=2)
338
+
339
+ with torch.no_grad():
340
+ # Values are generated by using a truncated uniform distribution and
341
+ # then using the inverse CDF for the normal distribution.
342
+ # Get upper and lower cdf values
343
+ l = norm_cdf((a - mean) / std)
344
+ u = norm_cdf((b - mean) / std)
345
+
346
+ # Uniformly fill tensor with values from [l, u], then translate to
347
+ # [2l-1, 2u-1].
348
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
349
+
350
+ # Use inverse cdf transform for normal distribution to get truncated
351
+ # standard normal
352
+ tensor.erfinv_()
353
+
354
+ # Transform to proper mean, std
355
+ tensor.mul_(std * math.sqrt(2.))
356
+ tensor.add_(mean)
357
+
358
+ # Clamp to ensure it's in the proper range
359
+ tensor.clamp_(min=a, max=b)
360
+ return tensor
361
+
362
+
363
+ def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
364
+ # type: (Tensor, float, float, float, float) -> Tensor
365
+ r"""Fills the input Tensor with values drawn from a truncated
366
+ normal distribution. The values are effectively drawn from the
367
+ normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
368
+ with values outside :math:`[a, b]` redrawn until they are within
369
+ the bounds. The method used for generating the random values works
370
+ best when :math:`a \leq \text{mean} \leq b`.
371
+ Args:
372
+ tensor: an n-dimensional `torch.Tensor`
373
+ mean: the mean of the normal distribution
374
+ std: the standard deviation of the normal distribution
375
+ a: the minimum cutoff value
376
+ b: the maximum cutoff value
377
+ Examples:
378
+ >>> w = torch.empty(3, 5)
379
+ >>> nn.init.trunc_normal_(w)
380
+ """
381
+ return _no_gradim_audiorunc_normal_(tensor, mean, std, a, b)
382
+
383
+
384
+ def variance_scaling_(tensor, scale=1.0, mode='fan_in', distribution='normal'):
385
+ fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
386
+ if mode == 'fan_in':
387
+ denom = fan_in
388
+ elif mode == 'fan_out':
389
+ denom = fan_out
390
+ elif mode == 'fan_avg':
391
+ denom = (fan_in + fan_out) / 2
392
+
393
+ variance = scale / denom
394
+
395
+ if distribution == "truncated_normal":
396
+ # constant is stddev of standard normal truncated to (-2, 2)
397
+ trunc_normal_(tensor, std=math.sqrt(variance) / .87962566103423978)
398
+ elif distribution == "normal":
399
+ tensor.normal_(std=math.sqrt(variance))
400
+ elif distribution == "uniform":
401
+ bound = math.sqrt(3 * variance)
402
+ tensor.uniform_(-bound, bound)
403
+ else:
404
+ raise ValueError(f"invalid distribution {distribution}")
405
+
406
+
407
+ def lecun_normal_(tensor):
408
+ variance_scaling_(tensor, mode='fan_in', distribution='truncated_normal')
409
+
410
+
411
+ # the code below is based on and adapted from https://github.com/microsoft/Swin-Transformer
412
+ # Swin Transformer for Computer Vision: https://arxiv.org/pdf/2103.14030.pdf
413
+
414
+ def window_partition(x, window_size):
415
+ """
416
+ Args:
417
+ x: (B, H, W, C)
418
+ window_size (int): window size
419
+ Returns:
420
+ windows: (num_windows*B, window_size, window_size, C)
421
+ """
422
+ B, H, W, C = x.shape
423
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
424
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
425
+ return windows
426
+
427
+
428
+ def window_reverse(windows, window_size, H, W):
429
+ """
430
+ Args:
431
+ windows: (num_windows*B, window_size, window_size, C)
432
+ window_size (int): Window size
433
+ H (int): Height of image
434
+ W (int): Width of image
435
+ Returns:
436
+ x: (B, H, W, C)
437
+ """
438
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
439
+ x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
440
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
441
+ return x
442
+
443
+
444
+ class WindowAttention(nn.Module):
445
+ r""" Window based multi-head self attention (W-MSA) module with relative position bias.
446
+ It supports both of shifted and non-shifted window.
447
+ Args:
448
+ dim (int): Number of input channels.
449
+ window_size (tuple[int]): The height and width of the window.
450
+ num_heads (int): Number of attention heads.
451
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
452
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
453
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
454
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
455
+ """
456
+
457
+ def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
458
+
459
+ super().__init__()
460
+ self.dim = dim
461
+ self.window_size = window_size # Wh, Ww
462
+ self.num_heads = num_heads
463
+ head_dim = dim // num_heads
464
+ self.scale = qk_scale or head_dim ** -0.5
465
+
466
+ # define a parameter table of relative position bias
467
+ self.relative_position_bias_table = nn.Parameter(
468
+ torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
469
+
470
+ # get pair-wise relative position index for each token inside the window
471
+ coords_h = torch.arange(self.window_size[0])
472
+ coords_w = torch.arange(self.window_size[1])
473
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
474
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
475
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
476
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
477
+ relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
478
+ relative_coords[:, :, 1] += self.window_size[1] - 1
479
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
480
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
481
+ self.register_buffer("relative_position_index", relative_position_index)
482
+
483
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
484
+ self.attn_drop = nn.Dropout(attn_drop)
485
+ self.proj = nn.Linear(dim, dim)
486
+ self.proj_drop = nn.Dropout(proj_drop)
487
+
488
+ trunc_normal_(self.relative_position_bias_table, std=.02)
489
+ self.softmax = nn.Softmax(dim=-1)
490
+
491
+ def forward(self, x, mask=None):
492
+ """
493
+ Args:
494
+ x: input features with shape of (num_windows*B, N, C)
495
+ mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
496
+ """
497
+ B_, N, C = x.shape
498
+ qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
499
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
500
+
501
+ q = q * self.scale
502
+ attn = (q @ k.transpose(-2, -1))
503
+
504
+ relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
505
+ self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
506
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
507
+ attn = attn + relative_position_bias.unsqueeze(0)
508
+
509
+ if mask is not None:
510
+ nW = mask.shape[0]
511
+ attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
512
+ attn = attn.view(-1, self.num_heads, N, N)
513
+ attn = self.softmax(attn)
514
+ else:
515
+ attn = self.softmax(attn)
516
+
517
+ attn = self.attn_drop(attn)
518
+
519
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
520
+ x = self.proj(x)
521
+ x = self.proj_drop(x)
522
+ return x, attn
523
+
524
+ def extra_repr(self):
525
+ return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
526
+
527
+
528
+ # We build the model on the Swin Transformer block, therefore we can use the Swin Transformer pretrained model
529
+ class SwinTransformerBlock(nn.Module):
530
+ r""" Swin Transformer Block.
531
+ Args:
532
+ dim (int): Number of input channels.
533
+ input_resolution (tuple[int]): Input resolution.
534
+ num_heads (int): Number of attention heads.
535
+ window_size (int): Window size.
536
+ shift_size (int): Shift size for SW-MSA.
537
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
538
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
539
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
540
+ drop (float, optional): Dropout rate. Default: 0.0
541
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
542
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
543
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
544
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
545
+ """
546
+
547
+ def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
548
+ mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
549
+ act_layer=nn.GELU, norm_layer=nn.LayerNorm, norm_before_mlp='ln'):
550
+ super().__init__()
551
+ self.dim = dim
552
+ self.input_resolution = input_resolution
553
+ self.num_heads = num_heads
554
+ self.window_size = window_size
555
+ self.shift_size = shift_size
556
+ self.mlp_ratio = mlp_ratio
557
+ self.norm_before_mlp = norm_before_mlp
558
+ if min(self.input_resolution) <= self.window_size:
559
+ # if window size is larger than input resolution, we don't partition windows
560
+ self.shift_size = 0
561
+ self.window_size = min(self.input_resolution)
562
+ assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
563
+
564
+ self.norm1 = norm_layer(dim)
565
+ self.attn = WindowAttention(
566
+ dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
567
+ qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
568
+
569
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
570
+ if self.norm_before_mlp == 'ln':
571
+ self.norm2 = nn.LayerNorm(dim)
572
+ elif self.norm_before_mlp == 'bn':
573
+ self.norm2 = lambda x: nn.BatchNorm1d(dim)(x.transpose(1, 2)).transpose(1, 2)
574
+ else:
575
+ raise NotImplementedError
576
+ mlp_hidden_dim = int(dim * mlp_ratio)
577
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
578
+
579
+ if self.shift_size > 0:
580
+ # calculate attention mask for SW-MSA
581
+ H, W = self.input_resolution
582
+ img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
583
+ h_slices = (slice(0, -self.window_size),
584
+ slice(-self.window_size, -self.shift_size),
585
+ slice(-self.shift_size, None))
586
+ w_slices = (slice(0, -self.window_size),
587
+ slice(-self.window_size, -self.shift_size),
588
+ slice(-self.shift_size, None))
589
+ cnt = 0
590
+ for h in h_slices:
591
+ for w in w_slices:
592
+ img_mask[:, h, w, :] = cnt
593
+ cnt += 1
594
+
595
+ mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
596
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
597
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
598
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
599
+ else:
600
+ attn_mask = None
601
+
602
+ self.register_buffer("attn_mask", attn_mask)
603
+
604
+ def forward(self, x):
605
+ # pdb.set_trace()
606
+ H, W = self.input_resolution
607
+ # print("H: ", H)
608
+ # print("W: ", W)
609
+ # pdb.set_trace()
610
+ B, L, C = x.shape
611
+ # assert L == H * W, "input feature has wrong size"
612
+
613
+ shortcut = x
614
+ x = self.norm1(x)
615
+ x = x.view(B, H, W, C)
616
+
617
+ # cyclic shift
618
+ if self.shift_size > 0:
619
+ shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
620
+ else:
621
+ shifted_x = x
622
+
623
+ # partition windows
624
+ x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
625
+ x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
626
+
627
+ # W-MSA/SW-MSA
628
+ attn_windows, attn = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
629
+
630
+ # merge windows
631
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
632
+ shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
633
+
634
+ # reverse cyclic shift
635
+ if self.shift_size > 0:
636
+ x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
637
+ else:
638
+ x = shifted_x
639
+ x = x.view(B, H * W, C)
640
+
641
+ # FFN
642
+ x = shortcut + self.drop_path(x)
643
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
644
+
645
+ return x, attn
646
+
647
+ def extra_repr(self):
648
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
649
+ f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
650
+
651
+
652
+
653
+ class PatchMerging(nn.Module):
654
+ r""" Patch Merging Layer.
655
+ Args:
656
+ input_resolution (tuple[int]): Resolution of input feature.
657
+ dim (int): Number of input channels.
658
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
659
+ """
660
+
661
+ def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
662
+ super().__init__()
663
+ self.input_resolution = input_resolution
664
+ self.dim = dim
665
+ self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
666
+ self.norm = norm_layer(4 * dim)
667
+
668
+ def forward(self, x):
669
+ """
670
+ x: B, H*W, C
671
+ """
672
+ H, W = self.input_resolution
673
+ B, L, C = x.shape
674
+ assert L == H * W, "input feature has wrong size"
675
+ assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
676
+
677
+ x = x.view(B, H, W, C)
678
+
679
+ x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
680
+ x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
681
+ x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
682
+ x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
683
+ x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
684
+ x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
685
+
686
+ x = self.norm(x)
687
+ x = self.reduction(x)
688
+
689
+ return x
690
+
691
+ def extra_repr(self):
692
+ return f"input_resolution={self.input_resolution}, dim={self.dim}"
693
+
694
+
695
+ class BasicLayer(nn.Module):
696
+ """ A basic Swin Transformer layer for one stage.
697
+ Args:
698
+ dim (int): Number of input channels.
699
+ input_resolution (tuple[int]): Input resolution.
700
+ depth (int): Number of blocks.
701
+ num_heads (int): Number of attention heads.
702
+ window_size (int): Local window size.
703
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
704
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
705
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
706
+ drop (float, optional): Dropout rate. Default: 0.0
707
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
708
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
709
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
710
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
711
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
712
+ """
713
+
714
+ def __init__(self, dim, input_resolution, depth, num_heads, window_size,
715
+ mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
716
+ drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
717
+ norm_before_mlp='ln'):
718
+
719
+ super().__init__()
720
+ self.dim = dim
721
+ self.input_resolution = input_resolution
722
+ self.depth = depth
723
+ self.use_checkpoint = use_checkpoint
724
+
725
+ # build blocks
726
+ self.blocks = nn.ModuleList([
727
+ SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
728
+ num_heads=num_heads, window_size=window_size,
729
+ shift_size=0 if (i % 2 == 0) else window_size // 2,
730
+ mlp_ratio=mlp_ratio,
731
+ qkv_bias=qkv_bias, qk_scale=qk_scale,
732
+ drop=drop, attn_drop=attn_drop,
733
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
734
+ norm_layer=norm_layer, norm_before_mlp=norm_before_mlp)
735
+ for i in range(depth)])
736
+
737
+ # patch merging layer
738
+ if downsample is not None:
739
+ self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
740
+ else:
741
+ self.downsample = None
742
+
743
+ def forward(self, x):
744
+ attns = []
745
+ for blk in self.blocks:
746
+ if self.use_checkpoint:
747
+ x = checkpoint.checkpoint(blk, x)
748
+ else:
749
+ x, attn = blk(x)
750
+ if not self.training:
751
+ attns.append(attn.unsqueeze(0))
752
+ if self.downsample is not None:
753
+ x = self.downsample(x)
754
+ if not self.training:
755
+ attn = torch.cat(attns, dim = 0)
756
+ attn = torch.mean(attn, dim = 0)
757
+ return x, attn
758
+
759
+ def extra_repr(self):
760
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
761
+
762
+
763
+ # The Core of HTSAT
764
+ class HTSAT_Swin_Transformer(nn.Module):
765
+ r"""HTSAT based on the Swin Transformer
766
+ Args:
767
+ spec_size (int | tuple(int)): Input Spectrogram size. Default 256
768
+ patch_size (int | tuple(int)): Patch size. Default: 4
769
+ patch_stride (int | tuple(int)): Patch stride for the frequency and time axes. Default: 4
770
+ in_chans (int): Number of input image channels. Default: 1 (mono)
771
+ num_classes (int): Number of classes for classification head. Default: 527
772
+ embed_dim (int): Patch embedding dimension. Default: 96
773
+ depths (tuple(int)): Depth of each HTSAT-Swin Transformer layer.
774
+ num_heads (tuple(int)): Number of attention heads in different layers.
775
+ window_size (int): Window size. Default: 8
776
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
777
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
778
+ qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
779
+ drop_rate (float): Dropout rate. Default: 0
780
+ attn_drop_rate (float): Attention dropout rate. Default: 0
781
+ drop_path_rate (float): Stochastic depth rate. Default: 0.1
782
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
783
+ ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
784
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True
785
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
786
+ config (module): The configuration Module from config.py (HTSATConfig Class)
787
+ """
788
+
789
+ def __init__(self, spec_size=256, patch_size=4, patch_stride=(4,4),
790
+ in_chans=1, num_classes=527,
791
+ embed_dim=96, depths=[2, 2, 6, 2], num_heads=[4, 8, 16, 32],
792
+ window_size=8, mlp_ratio=4., qkv_bias=True, qk_scale=None,
793
+ drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
794
+ norm_layer=nn.LayerNorm,
795
+ ape=False, patch_norm=True,
796
+ use_checkpoint=False, norm_before_mlp='ln', config = None, **kwargs):
797
+ super(HTSAT_Swin_Transformer, self).__init__()
798
+
799
+ self.config = config
800
+ self.spec_size = spec_size
801
+ self.patch_stride = patch_stride
802
+ self.patch_size = patch_size
803
+ self.window_size = window_size
804
+ self.embed_dim = embed_dim
805
+ self.depths = depths
806
+ self.ape = ape
807
+ self.in_chans = in_chans
808
+ self.num_classes = num_classes
809
+ self.num_heads = num_heads
810
+ self.num_layers = len(self.depths)
811
+ self.num_features = int(self.embed_dim * 2 ** (self.num_layers - 1))
812
+
813
+ self.drop_rate = drop_rate
814
+ self.attn_drop_rate = attn_drop_rate
815
+ self.drop_path_rate = drop_path_rate
816
+
817
+ self.qkv_bias = qkv_bias
818
+ self.qk_scale = None
819
+
820
+ self.patch_norm = patch_norm
821
+ self.norm_layer = norm_layer if self.patch_norm else None
822
+ self.norm_before_mlp = norm_before_mlp
823
+ self.mlp_ratio = mlp_ratio
824
+
825
+ self.use_checkpoint = use_checkpoint
826
+
827
+ # process mel-spec ; used only once
828
+ self.freq_ratio = self.spec_size // self.config.mel_bins
829
+ window = 'hann'
830
+ center = True
831
+ pad_mode = 'reflect'
832
+ ref = 1.0
833
+ amin = 1e-10
834
+ top_db = None
835
+ self.interpolate_ratio = 32 # Downsampled ratio
836
+ # Spectrogram extractor
837
+ self.spectrogram_extractor = Spectrogram(n_fft=config.window_size, hop_length=config.hop_size,
838
+ win_length=config.window_size, window=window, center=center, pad_mode=pad_mode,
839
+ freeze_parameters=True)
840
+ # Logmel feature extractor
841
+ self.logmel_extractor = LogmelFilterBank(sr=config.sample_rate, n_fft=config.window_size,
842
+ n_mels=config.mel_bins, fmin=config.fmin, fmax=config.fmax, ref=ref, amin=amin, top_db=top_db,
843
+ freeze_parameters=True)
844
+ # Spec augmenter
845
+ self.spec_augmenter = SpecAugmentation(time_drop_width=64, time_stripes_num=2,
846
+ freq_drop_width=8, freq_stripes_num=2) # 2 2
847
+ self.bn0 = nn.BatchNorm2d(self.config.mel_bins)
848
+
849
+
850
+ # split spectrogram into non-overlapping patches
851
+ self.patch_embed = PatchEmbed(
852
+ img_size=self.spec_size, patch_size=self.patch_size, in_chans=self.in_chans,
853
+ embed_dim=self.embed_dim, norm_layer=self.norm_layer, patch_stride = patch_stride)
854
+
855
+ num_patches = self.patch_embed.num_patches
856
+ patches_resolution = self.patch_embed.grid_size
857
+ self.patches_resolution = patches_resolution
858
+
859
+ # absolute position embedding
860
+ if self.ape:
861
+ self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, self.embed_dim))
862
+ trunc_normal_(self.absolute_pos_embed, std=.02)
863
+
864
+ self.pos_drop = nn.Dropout(p=self.drop_rate)
865
+
866
+ # stochastic depth
867
+ dpr = [x.item() for x in torch.linspace(0, self.drop_path_rate, sum(self.depths))] # stochastic depth decay rule
868
+
869
+ # build layers
870
+ self.layers = nn.ModuleList()
871
+ for i_layer in range(self.num_layers):
872
+ layer = BasicLayer(dim=int(self.embed_dim * 2 ** i_layer),
873
+ input_resolution=(patches_resolution[0] // (2 ** i_layer),
874
+ patches_resolution[1] // (2 ** i_layer)),
875
+ depth=self.depths[i_layer],
876
+ num_heads=self.num_heads[i_layer],
877
+ window_size=self.window_size,
878
+ mlp_ratio=self.mlp_ratio,
879
+ qkv_bias=self.qkv_bias, qk_scale=self.qk_scale,
880
+ drop=self.drop_rate, attn_drop=self.attn_drop_rate,
881
+ drop_path=dpr[sum(self.depths[:i_layer]):sum(self.depths[:i_layer + 1])],
882
+ norm_layer=self.norm_layer,
883
+ downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
884
+ use_checkpoint=use_checkpoint,
885
+ norm_before_mlp=self.norm_before_mlp)
886
+ self.layers.append(layer)
887
+
888
+ self.norm = self.norm_layer(self.num_features)
889
+ self.avgpool = nn.AdaptiveAvgPool1d(1)
890
+ self.maxpool = nn.AdaptiveMaxPool1d(1)
891
+
892
+ if self.config.enable_tscam:
893
+ SF = self.spec_size // (2 ** (len(self.depths) - 1)) // self.patch_stride[0] // self.freq_ratio
894
+ self.tscam_conv = nn.Conv2d(
895
+ in_channels = self.num_features,
896
+ out_channels = self.num_classes,
897
+ kernel_size = (SF,3),
898
+ padding = (0,1)
899
+ )
900
+ self.head = nn.Linear(num_classes, num_classes)
901
+ else:
902
+ self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
903
+
904
+ self.apply(self._init_weights)
905
+
906
+ def _init_weights(self, m):
907
+ if isinstance(m, nn.Linear):
908
+ trunc_normal_(m.weight, std=.02)
909
+ if isinstance(m, nn.Linear) and m.bias is not None:
910
+ nn.init.constant_(m.bias, 0)
911
+ elif isinstance(m, nn.LayerNorm):
912
+ nn.init.constant_(m.bias, 0)
913
+ nn.init.constant_(m.weight, 1.0)
914
+
915
+ @torch.jit.ignore
916
+ def no_weight_decay(self):
917
+ return {'absolute_pos_embed'}
918
+
919
+ @torch.jit.ignore
920
+ def no_weight_decay_keywords(self):
921
+ return {'relative_position_bias_table'}
922
+
923
+ def forward_features(self, x):
924
+ frames_num = x.shape[2]
925
+ x = self.patch_embed(x)
926
+ if self.ape:
927
+ x = x + self.absolute_pos_embed
928
+ x = self.pos_drop(x)
929
+ for i, layer in enumerate(self.layers):
930
+ x, attn = layer(x)
931
+
932
+ if self.config.enable_tscam:
933
+ # for x
934
+ x = self.norm(x)
935
+ B, N, C = x.shape
936
+ SF = frames_num // (2 ** (len(self.depths) - 1)) // self.patch_stride[0]
937
+ ST = frames_num // (2 ** (len(self.depths) - 1)) // self.patch_stride[1]
938
+ x = x.permute(0,2,1).contiguous().reshape(B, C, SF, ST)
939
+ B, C, F, T = x.shape
940
+ # group 2D CNN
941
+ c_freq_bin = F // self.freq_ratio
942
+ x = x.reshape(B, C, F // c_freq_bin, c_freq_bin, T)
943
+ x = x.permute(0,1,3,2,4).contiguous().reshape(B, C, c_freq_bin, -1)
944
+
945
+ # get latent_output
946
+ latent_output = self.avgpool(torch.flatten(x,2))
947
+ latent_output = torch.flatten(latent_output, 1)
948
+
949
+ # display the attention map, if needed
950
+ if self.config.htsat_attn_heatmap:
951
+ # for attn
952
+ attn = torch.mean(attn, dim = 1)
953
+ attn = torch.mean(attn, dim = 1)
954
+ attn = attn.reshape(B, SF, ST)
955
+ c_freq_bin = SF // self.freq_ratio
956
+ attn = attn.reshape(B, SF // c_freq_bin, c_freq_bin, ST)
957
+ attn = attn.permute(0,2,1,3).contiguous().reshape(B, c_freq_bin, -1)
958
+ attn = attn.mean(dim = 1)
959
+ attn_max = torch.max(attn, dim = 1, keepdim = True)[0]
960
+ attn_min = torch.min(attn, dim = 1, keepdim = True)[0]
961
+ attn = ((attn * 0.15) + (attn_max * 0.85 - attn_min)) / (attn_max - attn_min)
962
+ attn = attn.unsqueeze(dim = 2)
963
+
964
+ x = self.tscam_conv(x)
965
+ x = torch.flatten(x, 2) # B, C, T
966
+
967
+ if self.config.htsat_attn_heatmap:
968
+ fpx = interpolate(torch.sigmoid(x).permute(0,2,1).contiguous() * attn, 8 * self.patch_stride[1])
969
+ else:
970
+ fpx = interpolate(torch.sigmoid(x).permute(0,2,1).contiguous(), 8 * self.patch_stride[1])
971
+
972
+ x = self.avgpool(x)
973
+ x = torch.flatten(x, 1)
974
+
975
+ if self.config.loss_type == "clip_ce":
976
+ output_dict = {
977
+ 'framewise_output': fpx, # already sigmoided
978
+ 'clipwise_output': x,
979
+ 'latent_output': latent_output
980
+ }
981
+ else:
982
+ output_dict = {
983
+ 'framewise_output': fpx, # already sigmoided
984
+ 'clipwise_output': torch.sigmoid(x),
985
+ 'latent_output': latent_output
986
+ }
987
+
988
+ else:
989
+ x = self.norm(x) # B N C
990
+ B, N, C = x.shape
991
+
992
+ fpx = x.permute(0,2,1).contiguous().reshape(B, C, frames_num // (2 ** (len(self.depths) + 1)), frames_num // (2 ** (len(self.depths) + 1)) )
993
+ B, C, F, T = fpx.shape
994
+ c_freq_bin = F // self.freq_ratio
995
+ fpx = fpx.reshape(B, C, F // c_freq_bin, c_freq_bin, T)
996
+ fpx = fpx.permute(0,1,3,2,4).contiguous().reshape(B, C, c_freq_bin, -1)
997
+ fpx = torch.sum(fpx, dim = 2)
998
+ fpx = interpolate(fpx.permute(0,2,1).contiguous(), 8 * self.patch_stride[1])
999
+ x = self.avgpool(x.transpose(1, 2)) # B C 1
1000
+ x = torch.flatten(x, 1)
1001
+ if self.num_classes > 0:
1002
+ x = self.head(x)
1003
+ fpx = self.head(fpx)
1004
+ output_dict = {'framewise_output': torch.sigmoid(fpx),
1005
+ 'clipwise_output': torch.sigmoid(x)}
1006
+ return output_dict
1007
+
1008
+ def crop_wav(self, x, crop_size, spe_pos = None):
1009
+ time_steps = x.shape[2]
1010
+ tx = torch.zeros(x.shape[0], x.shape[1], crop_size, x.shape[3]).to(x.device)
1011
+ for i in range(len(x)):
1012
+ if spe_pos is None:
1013
+ crop_pos = random.randint(0, time_steps - crop_size - 1)
1014
+ else:
1015
+ crop_pos = spe_pos
1016
+ tx[i][0] = x[i, 0, crop_pos:crop_pos + crop_size,:]
1017
+ return tx
1018
+
1019
+ # Reshape the waveform to an image size, if you want to use the pretrained swin transformer model
1020
+ def reshape_wav2img(self, x):
1021
+ B, C, T, F = x.shape
1022
+ target_T = int(self.spec_size * self.freq_ratio)
1023
+ target_F = self.spec_size // self.freq_ratio
1024
+ assert T <= target_T and F <= target_F, "the wav size should be less than or equal to the swin input size"
1025
+ # to avoid bicubic zero error
1026
+ if T < target_T:
1027
+ x = nn.functional.interpolate(x, (target_T, x.shape[3]), mode="bicubic", align_corners=True)
1028
+ if F < target_F:
1029
+ x = nn.functional.interpolate(x, (x.shape[2], target_F), mode="bicubic", align_corners=True)
1030
+ x = x.permute(0,1,3,2).contiguous()
1031
+ x = x.reshape(x.shape[0], x.shape[1], x.shape[2], self.freq_ratio, x.shape[3] // self.freq_ratio)
1032
+ # print(x.shape)
1033
+ x = x.permute(0,1,3,2,4).contiguous()
1034
+ x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3], x.shape[4])
1035
+ return x
1036
+
1037
+ # Repeat the waveform to an image size, if you want to use the pretrained swin transformer model
1038
+ def repeat_wat2img(self, x, cur_pos):
1039
+ B, C, T, F = x.shape
1040
+ target_T = int(self.spec_size * self.freq_ratio)
1041
+ target_F = self.spec_size // self.freq_ratio
1042
+ assert T <= target_T and F <= target_F, "the wav size should be less than or equal to the swin input size"
1043
+ # to avoid bicubic zero error
1044
+ if T < target_T:
1045
+ x = nn.functional.interpolate(x, (target_T, x.shape[3]), mode="bicubic", align_corners=True)
1046
+ if F < target_F:
1047
+ x = nn.functional.interpolate(x, (x.shape[2], target_F), mode="bicubic", align_corners=True)
1048
+ x = x.permute(0,1,3,2).contiguous() # B C F T
1049
+ x = x[:,:,:,cur_pos:cur_pos + self.spec_size]
1050
+ x = x.repeat(repeats = (1,1,4,1))
1051
+ return x
1052
+
1053
+ def forward(self, x: torch.Tensor, mixup_lambda = None, infer_mode = False):# out_feat_keys: List[str] = None):
1054
+ x = self.spectrogram_extractor(x) # (batch_size, 1, time_steps, freq_bins)
1055
+ x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins)
1056
+
1057
+
1058
+ x = x.transpose(1, 3)
1059
+ x = self.bn0(x)
1060
+ x = x.transpose(1, 3)
1061
+ if self.training:
1062
+ x = self.spec_augmenter(x)
1063
+ if self.training and mixup_lambda is not None:
1064
+ x = do_mixup(x, mixup_lambda)
1065
+
1066
+ if infer_mode:
1067
+ # in infer mode, we need to handle different-length audio input
1068
+ frame_num = x.shape[2]
1069
+ target_T = int(self.spec_size * self.freq_ratio)
1070
+ repeat_ratio = math.floor(target_T / frame_num)
1071
+ x = x.repeat(repeats=(1,1,repeat_ratio,1))
1072
+ x = self.reshape_wav2img(x)
1073
+ output_dict = self.forward_features(x)
1074
+ elif self.config.enable_repeat_mode:
1075
+ if self.training:
1076
+ cur_pos = random.randint(0, (self.freq_ratio - 1) * self.spec_size - 1)
1077
+ x = self.repeat_wat2img(x, cur_pos)
1078
+ output_dict = self.forward_features(x)
1079
+ else:
1080
+ output_dicts = []
1081
+ for cur_pos in range(0, (self.freq_ratio - 1) * self.spec_size + 1, self.spec_size):
1082
+ tx = x.clone()
1083
+ tx = self.repeat_wat2img(tx, cur_pos)
1084
+ output_dicts.append(self.forward_features(tx))
1085
+ clipwise_output = torch.zeros_like(output_dicts[0]["clipwise_output"]).float().to(x.device)
1086
+ framewise_output = torch.zeros_like(output_dicts[0]["framewise_output"]).float().to(x.device)
1087
+ for d in output_dicts:
1088
+ clipwise_output += d["clipwise_output"]
1089
+ framewise_output += d["framewise_output"]
1090
+ clipwise_output = clipwise_output / len(output_dicts)
1091
+ framewise_output = framewise_output / len(output_dicts)
1092
+
1093
+ output_dict = {
1094
+ 'framewise_output': framewise_output,
1095
+ 'clipwise_output': clipwise_output
1096
+ }
1097
+ else:
1098
+ if x.shape[2] > self.freq_ratio * self.spec_size:
1099
+ if self.training:
1100
+ x = self.crop_wav(x, crop_size=self.freq_ratio * self.spec_size)
1101
+ x = self.reshape_wav2img(x)
1102
+ output_dict = self.forward_features(x)
1103
+ else:
1104
+ # Change: Hard code here
1105
+ overlap_size = 344 #(x.shape[2] - 1) // 4
1106
+ output_dicts = []
1107
+ crop_size = 689 #(x.shape[2] - 1) // 2
1108
+ for cur_pos in range(0, x.shape[2] - crop_size - 1, overlap_size):
1109
+ tx = self.crop_wav(x, crop_size = crop_size, spe_pos = cur_pos)
1110
+ tx = self.reshape_wav2img(tx)
1111
+ output_dicts.append(self.forward_features(tx))
1112
+ clipwise_output = torch.zeros_like(output_dicts[0]["clipwise_output"]).float().to(x.device)
1113
+ framewise_output = torch.zeros_like(output_dicts[0]["framewise_output"]).float().to(x.device)
1114
+ latent_output = torch.zeros_like(output_dicts[0]["latent_output"]).float().to(x.device)
1115
+ for d in output_dicts:
1116
+ clipwise_output += d["clipwise_output"]
1117
+ framewise_output += d["framewise_output"]
1118
+ latent_output += d["latent_output"]
1119
+ clipwise_output = clipwise_output / len(output_dicts)
1120
+ framewise_output = framewise_output / len(output_dicts)
1121
+ latent_output = latent_output / len(output_dicts)
1122
+ output_dict = {
1123
+ 'framewise_output': framewise_output,
1124
+ 'clipwise_output': clipwise_output,
1125
+ 'latent_output': latent_output,
1126
+ }
1127
+ else: # this part is typically used, and is the simplest one
1128
+ x = self.reshape_wav2img(x)
1129
+ output_dict = self.forward_features(x)
1130
+ # x = self.head(x)
1131
+ return output_dict
1132
+
1133
+ class HTSATWrapper(nn.Module):
1134
+ def __init__(self, sample_rate, window_size, hop_size, mel_bins, fmin,
1135
+ fmax, classes_num, out_emb):
1136
+ super().__init__()
1137
+
1138
+ # print("parameters are being overidden when using HTSAT")
1139
+ # print("HTSAT only support loading a pretrained model on AudioSet")
1140
+ # @TODO later look at what parameters are same and can be merged
1141
+
1142
+ self.htsat = HTSAT_Swin_Transformer(config=HTSATConfig())
1143
+
1144
+ def forward(self, x):
1145
+ out_dict = self.htsat(x)
1146
+ out_dict['embedding'] = out_dict['latent_output']
1147
+ return out_dict
1148
+
1149
+
1150
+ def get_audio_encoder(name: str):
1151
+ if name == "HTSAT":
1152
+ return HTSATWrapper
1153
+ else:
1154
+ raise Exception('The audio encoder name {} is incorrect or not supported'.format(name))
1155
+
1156
+ class Projection(nn.Module):
1157
+ def __init__(self, dim_imgn: int, d_out: int, p: float=0.5) -> None:
1158
+ super().__init__()
1159
+ self.linear1 = nn.Linear(dim_imgn, d_out, bias=False)
1160
+ self.linear2 = nn.Linear(d_out, d_out, bias=False)
1161
+ self.layer_norm = nn.LayerNorm(d_out)
1162
+ self.drop = nn.Dropout(p)
1163
+
1164
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
1165
+ embed1 = self.linear1(x)
1166
+ embed2 = self.drop(self.linear2(F.gelu(embed1)))
1167
+ embeds = self.layer_norm(embed1 + embed2)
1168
+ return embeds
1169
+
1170
+ class AudioEncoder(nn.Module):
1171
+ def __init__(self, audioenc_name:str, dim_imgn: int, d_out: int, sample_rate: int, window_size: int,
1172
+ hop_size: int, mel_bins: int, fmin: int, fmax: int, classes_num: int) -> None:
1173
+ super().__init__()
1174
+
1175
+ audio_encoder = get_audio_encoder(audioenc_name)
1176
+
1177
+ self.base = audio_encoder(
1178
+ sample_rate, window_size,
1179
+ hop_size, mel_bins, fmin, fmax,
1180
+ classes_num, dim_imgn)
1181
+
1182
+ self.projection = Projection(dim_imgn, d_out)
1183
+
1184
+ def forward(self, x):
1185
+ out_dict = self.base(x)
1186
+ audio_features, audio_classification_output = out_dict['embedding'], out_dict['clipwise_output']
1187
+ projected_vec = self.projection(audio_features)
1188
+ return projected_vec, audio_classification_output
1189
+
1190
+ class TextEncoder(nn.Module):
1191
+ def __init__(self, d_out: int, text_model: str, transformer_embed_dim: int) -> None:
1192
+ super().__init__()
1193
+ self.text_model = text_model
1194
+ self.base = AutoModel.from_pretrained(text_model)
1195
+
1196
+ if 'clip' in text_model:
1197
+ self.clip_text_projection = self.base.text_projection
1198
+ self.base = self.base.text_model
1199
+ if 'base' in text_model:
1200
+ transformer_embed_dim = 512
1201
+
1202
+ self.projection = Projection(transformer_embed_dim, d_out)
1203
+
1204
+ def forward(self, x):
1205
+ if 'clip' in self.text_model:
1206
+ pooled_output = self.base(**x)[1] # get pooled output
1207
+ out = self.clip_text_projection(pooled_output) # get CLS token output
1208
+ elif 'gpt' in self.text_model:
1209
+ batch_size = x['input_ids'].shape[0]
1210
+ hidden_states = self.base(**x)[0] # (batch_size=4, seq_len, 768)
1211
+
1212
+ sequence_lengths = torch.ne(x['input_ids'], 0).sum(-1) - 1 # tensor([13, 14, 18, 17])
1213
+ out = hidden_states[torch.arange(batch_size, device=hidden_states.device), sequence_lengths] # [batch_size, 768] = [4, 768]
1214
+ else:
1215
+ out = self.base(**x)[0]
1216
+ out = out[:, 0, :] # get CLS token output
1217
+
1218
+ projected_vec = self.projection(out)
1219
+
1220
+ return projected_vec
1221
+
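In the GPT-2 branch above, each caption embedding is taken from the hidden state of its last non-padding token (padding id assumed to be 0, as in the code). A tiny illustration of that indexing with made-up shapes:

import torch

input_ids = torch.tensor([[5, 7, 9, 0, 0],    # 3 real tokens -> index 2
                          [4, 4, 4, 4, 0]])   # 4 real tokens -> index 3
hidden_states = torch.randn(2, 5, 768)        # (batch, seq_len, hidden)
last = torch.ne(input_ids, 0).sum(-1) - 1     # tensor([2, 3])
pooled = hidden_states[torch.arange(2), last] # shape [2, 768]
print(last.tolist(), pooled.shape)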
1222
+ class CLAP(nn.Module):
1223
+ def __init__(self,
1224
+ # audio
1225
+ audioenc_name: str,
1226
+ sample_rate: int,
1227
+ window_size: int,
1228
+ hop_size: int,
1229
+ mel_bins: int,
1230
+ fmin: int,
1231
+ fmax: int,
1232
+ classes_num: int,
1233
+ out_emb: int,
1234
+ # text
1235
+ text_model: str,
1236
+ transformer_embed_dim: int,
1237
+ # common
1238
+ d_proj: int,
1239
+ ):
1240
+ super().__init__()
1241
+
1242
+
1243
+ self.audio_encoder = AudioEncoder(
1244
+ audioenc_name, out_emb, d_proj,
1245
+ sample_rate, window_size, hop_size, mel_bins, fmin, fmax, classes_num)
1246
+
1247
+ self.caption_encoder = TextEncoder(
1248
+ d_proj, text_model, transformer_embed_dim
1249
+ )
1250
+
1251
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
1252
+
1253
+ def forward(self, audio, text):
1254
+ audio_embed, _ = self.audio_encoder(audio)
1255
+ caption_embed = self.caption_encoder(text)
1256
+
1257
+ return caption_embed, audio_embed, self.logit_scale.exp()
1258
+
1259
+
1260
+
1261
+ # ==================================================================
1262
+ # A U D I O - P R E - P R O C E S S I N G
1263
+ # ==================================================================
1264
+ def read_audio(audio_path, resample=True):
1265
+ r"""Loads audio file or array and returns a torch tensor"""
1266
+ # Load the audio clip and resample it to the configured sample rate if requested
1267
+ audio_time_series, sample_rate = torchaudio.load(audio_path)
1268
+
1269
+ resample_rate = clapConfig.sample_rate
1270
+ if resample and resample_rate != sample_rate:
1271
+ resampler = T.Resample(sample_rate, resample_rate)
1272
+ audio_time_series = resampler(audio_time_series)
1273
+ return audio_time_series, resample_rate
1274
+
1275
+ def load_audio_into_tensor(audio_path, audio_duration, resample=False):
1276
+ r"""Loads audio file and returns raw audio."""
1277
+ # Randomly sample a segment of audio_duration from the clip or pad to match duration
1278
+ audio_time_series, sample_rate = read_audio(audio_path, resample)
1279
+ audio_time_series = audio_time_series.reshape(-1)
1280
+
1281
+ # audio_time_series is shorter than predefined audio duration,
1282
+ # so audio_time_series is extended
1283
+ if audio_duration*sample_rate >= audio_time_series.shape[0]:
1284
+ repeat_factor = int(np.ceil((audio_duration*sample_rate) /
1285
+ audio_time_series.shape[0]))
1286
+ # Repeat audio_time_series by repeat_factor to match audio_duration
1287
+ audio_time_series = audio_time_series.repeat(repeat_factor)
1288
+ # remove excess part of audio_time_series
1289
+ audio_time_series = audio_time_series[0:audio_duration*sample_rate]
1290
+ else:
1291
+ # audio_time_series is longer than predefined audio duration,
1292
+ # so audio_time_series is trimmed
1293
+ start_index = random.randrange(
1294
+ audio_time_series.shape[0] - audio_duration*sample_rate)
1295
+ audio_time_series = audio_time_series[start_index:start_index +
1296
+ audio_duration*sample_rate]
1297
+ return torch.FloatTensor(audio_time_series)
1298
+
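load_audio_into_tensor therefore always returns duration * sample_rate samples; with the ClapConfig defaults further down (7 s at 44.1 kHz) that is 7 * 44100 = 308,700 samples. A minimal sketch of the same pad/crop arithmetic on synthetic tensors (no file I/O; the 2 s and 12 s lengths are made up):

import torch

duration, sample_rate = 7, 44100
target_len = duration * sample_rate                 # 308,700 samples

short = torch.randn(2 * sample_rate)                # shorter clip: repeated, then truncated
repeat_factor = -(-target_len // short.shape[0])    # ceil division, as in the code above
padded = short.repeat(repeat_factor)[:target_len]
assert padded.shape[0] == target_len

long = torch.randn(12 * sample_rate)                # longer clip: a random 7 s window is kept
start = torch.randint(0, long.shape[0] - target_len, (1,)).item()
assert long[start:start + target_len].shape[0] == target_len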
1299
+ np_str_obj_array_pattern = re.compile(r'[SaUO]')
1300
+ default_collate_err_msg_format = (
1301
+ "default_collate: batch must contain tensors, numpy arrays, numbers, "
1302
+ "dicts or lists; found {}")
1303
+
1304
+ def default_collate(batch):
1305
+ r"""Puts each data field into a tensor with outer dimension batch size"""
1306
+ elem = batch[0]
1307
+ elem_type = type(elem)
1308
+ if isinstance(elem, torch.Tensor):
1309
+ out = None
1310
+ if torch.utils.data.get_worker_info() is not None:
1311
+ # If we're in a background process, concatenate directly into a
1312
+ # shared memory tensor to avoid an extra copy
1313
+ numel = sum([x.numel() for x in batch])
1314
+ storage = elem.storage()._new_shared(numel)
1315
+ out = elem.new(storage)
1316
+ return torch.stack(batch, 0, out=out)
1317
+ elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
1318
+ and elem_type.__name__ != 'string_':
1319
+ if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap':
1320
+ # array of string classes and object
1321
+ if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
1322
+ raise TypeError(
1323
+ default_collate_err_msg_format.format(elem.dtype))
1324
+
1325
+ return default_collate([torch.as_tensor(b) for b in batch])
1326
+ elif elem.shape == (): # scalars
1327
+ return torch.as_tensor(batch)
1328
+ elif isinstance(elem, float):
1329
+ return torch.tensor(batch, dtype=torch.float64)
1330
+ elif isinstance(elem, int):
1331
+ return torch.tensor(batch)
1332
+ elif isinstance(elem, str):
1333
+ return batch
1334
+ elif isinstance(elem, collections.abc.Mapping):
1335
+ return {key: default_collate([d[key] for d in batch]) for key in elem}
1336
+ elif isinstance(elem, tuple) and hasattr(elem, '_fields'): # namedtuple
1337
+ return elem_type(*(default_collate(samples) for samples in zip(*batch)))
1338
+ elif isinstance(elem, collections.abc.Sequence):
1339
+ # check to make sure that the elements in batch have consistent size
1340
+ it = iter(batch)
1341
+ elem_size = len(next(it))
1342
+ if not all(len(elem) == elem_size for elem in it):
1343
+ raise RuntimeError(
1344
+ 'each element in list of batch should be of equal size')
1345
+ transposed = zip(*batch)
1346
+ return [default_collate(samples) for samples in transposed]
1347
+
1348
+ raise TypeError(default_collate_err_msg_format.format(elem_type))
1349
+
1350
+ def preprocess_audio(audio_files, resample):
1351
+ r"""Load list of audio files and return raw audio"""
1352
+ audio_tensors = []
1353
+ for audio_file in audio_files:
1354
+ audio_tensor = load_audio_into_tensor(
1355
+ audio_file, clapConfig.duration, resample)
1356
+ audio_tensor = audio_tensor.reshape(1, -1)
1357
+ audio_tensors.append(audio_tensor)
1358
+ return default_collate(audio_tensors)
1359
+
1360
+
1361
+
1362
+ # ==================================================================
1363
+ # A U D I O - E M B E D D I N G S - H E L P E R
1364
+ # ==================================================================
1365
+ def CLAPAudioProcessor(audio_files: List[str], resample=True):
1366
+ preprocessed_audio = preprocess_audio(audio_files, resample)
1367
+ preprocessed_audio = preprocessed_audio.reshape(
1368
+ preprocessed_audio.shape[0], preprocessed_audio.shape[2])
1369
+ preprocessed_audio = preprocessed_audio
1370
+ return preprocessed_audio
1371
+
1372
+ def get_audio_embeddings(audio_files: List[str], audio_encoder, resample=True):
1373
+ """Load list of audio files and return audio embeddings"""
1374
+ # preprocessed_audio = preprocess_audio(audio_files, resample)
1375
+ # with torch.no_grad():
1376
+ # preprocessed_audio = preprocessed_audio.reshape(
1377
+ # preprocessed_audio.shape[0], preprocessed_audio.shape[2])
1378
+ with torch.no_grad():
1379
+ preprocessed_audio = CLAPAudioProcessor(audio_files, resample)
1380
+ return audio_encoder(preprocessed_audio)[0]
1381
+
1382
+
1383
+ # ==================================================================
1384
+ # C L A P
1385
+ # ==================================================================
1386
+ class ClapConfig:
1387
+ # TEXT ENCODER CONFIG
1388
+ text_model = 'gpt2'
1389
+ text_len = 77
1390
+ transformer_embed_dim = 768
1391
+ freeze_text_encoder_weights = True
1392
+
1393
+ # AUDIO ENCODER CONFIG
1394
+ audioenc_name = 'HTSAT'
1395
+ out_emb = 768
1396
+ sample_rate = 44100
1397
+ duration = 7
1398
+ fmin = 50
1399
+ fmax = 8000 # 14000
1400
+ n_fft = 1024 # 1028
1401
+ hop_size = 320
1402
+ mel_bins = 64
1403
+ window_size = 1024
1404
+
1405
+ # PROJECTION SPACE CONFIG
1406
+ d_proj = 1024
1407
+ temperature = 0.003
1408
+
1409
+ # TRAINING AND EVALUATION CONFIG
1410
+ num_classes = 527
1411
+ batch_size = 1024
1412
+ demo = False
1413
+
1414
+
1415
+ clapConfig = ClapConfig()
1416
+ clap = CLAP(
1417
+ audioenc_name=clapConfig.audioenc_name,
1418
+ sample_rate=clapConfig.sample_rate,
1419
+ window_size=clapConfig.window_size,
1420
+ hop_size=clapConfig.hop_size,
1421
+ mel_bins=clapConfig.mel_bins,
1422
+ fmin=clapConfig.fmin,
1423
+ fmax=clapConfig.fmax,
1424
+ classes_num=clapConfig.num_classes,
1425
+ out_emb=clapConfig.out_emb,
1426
+ text_model=clapConfig.text_model,
1427
+ transformer_embed_dim=clapConfig.transformer_embed_dim,
1428
+ d_proj=clapConfig.d_proj
1429
+ )
1430
+
1431
+ model_repo = "microsoft/msclap"
1432
+ model_name = {
1433
+ '2022': 'CLAP_weights_2022.pth',
1434
+ '2023': 'CLAP_weights_2023.pth',
1435
+ 'clapcap': 'clapcap_weights_2023.pth'
1436
+ }
1437
+
1438
+ version = '2023'
1439
+ model_fp = hf_hub_download(model_repo, model_name[version])
1440
+
1441
+ model_state_dict = torch.load(model_fp, map_location=torch.device('cpu'))['model']
1442
+ clap.load_state_dict(model_state_dict, strict=False)
1443
+ clap.eval()
1444
+
1445
+ clap_audio_encoder = clap.audio_encoder.eval().to(device)
1446
+
1447
+
1448
+ # ENGLISH_AUDIO_DIR = r"/home/IITB/ai-at-ieor/23m1521/datasets/Vaani/Audios/English"
1449
+ # audio_files = [os.path.join(ENGLISH_AUDIO_DIR, i) for i in os.listdir(ENGLISH_AUDIO_DIR) if i.endswith(".wav")]
1450
+ # audio_embedding = get_audio_embeddings(audio_files, clap_audio_encoder)
1451
+ # print("CLAP Audio Encoder Embeddings:", audio_embedding.shape) # [5, 1024]
1452
+
1453
+
1454
+ # ==================================================================
1455
+ # C L A P - L o R A - M O D E L
1456
+ # ==================================================================
1457
+ LoRAconfig = {
1458
+ "peft_type": "LORA",
1459
+ "task_type": "FEATURE_EXTRACTION",
1460
+ "inference_mode": False,
1461
+ "r": 16,
1462
+ "target_modules": ["qkv", "fc1", "fc2", "proj", "linear1", "linear2"],
1463
+ "lora_alpha": 32,
1464
+ "lora_dropout": 0.05,
1465
+ "fan_in_fan_out": False,
1466
+ "bias": "all",
1467
+ }
1468
+ peft_config = get_peft_config(LoRAconfig)
1469
+
1470
+ peft_model = get_peft_model(clap_audio_encoder, peft_config)
1471
+
1472
+ peft_model.print_trainable_parameters()
1473
+
1474
+ peft_clap_audio_encoder = peft_model.base_model
1475
+ # audio_embedding = get_audio_embeddings(audio_files, peft_clap_audio_encoder)
1476
+ # print("CLAP LoRA Audio Encoder Embeddings:", audio_embedding.shape) # [5, 1024]
1477
+
1478
+
1479
+
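A quick, optional way to see what the PEFT wrapper actually left trainable in the encoder built above (this only inspects peft_clap_audio_encoder; with r=16 and bias="all" one would expect the lora_A/lora_B matrices plus bias terms to require gradients, everything else frozen):

trainable = [(n, p.numel()) for n, p in peft_clap_audio_encoder.named_parameters() if p.requires_grad]
total_params = sum(p.numel() for p in peft_clap_audio_encoder.parameters())
print(f"trainable: {sum(c for _, c in trainable):,} / {total_params:,}")
print([n for n, _ in trainable[:5]])  # first few trainable parameter names (LoRA adapters / biases)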
1480
+ # ==================================================================
1481
+ # C L I P - M O D E L
1482
+ # ==================================================================
1483
+ from transformers import CLIPImageProcessorFast, CLIPImageProcessor
1484
+ clip_vision_model = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
1485
+ # clip_vision_processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
1486
+ clip_vision_processor = CLIPImageProcessorFast.from_pretrained("openai/clip-vit-base-patch32")
1487
+ # clip_vision_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")
1488
+
1489
+ # image = Image.open("/home/IITB/ai-at-ieor/23m1521/ashish/MTP/Vaani/Img_Audio_Alignment/000000039769.jpg")
1490
+ # inputs = clip_vision_processor(images=image, return_tensors="pt")
1491
+ # print("CLIP input image:", inputs['pixel_values'].shape)
1492
+ # # input_data = {'pixel_values': inputs}
1493
+
1494
+ # IMAGE_SIZE = 224
1495
+ # dummy_input = torch.randn(1, 3, IMAGE_SIZE, IMAGE_SIZE) # [1, 3, 224, 224]
1496
+ # # dummy_input_data = {'pixel_values': dummy_input}
1497
+
1498
+ # output = clip_vision_model(inputs['pixel_values'])
1499
+ # print("CLIP Image Encoder Embeddings:", output.last_hidden_state.shape) # [1, 50, 768]
1500
+ # print("CLIP Image Encoder Pooled Output:", output.pooler_output.shape) # [1, 768]
1501
+
1502
+
1503
+ # ==================================================================
1504
+ # C S I P - M O D U L E
1505
+ # ==================================================================
1506
+ class CSIP(nn.Module):
1507
+ def __init__(self, image_encoder, audio_encoder, dim_img=None, dim_audio=1024, dim_emb=768):
1508
+ super(CSIP, self).__init__()
1509
+ self.image_encoder = image_encoder # CLIPVisionModel
1510
+ self.audio_encoder = audio_encoder # CLAP_audio_encoder
1511
+
1512
+ # self.image_proj = nn.Linear(dim_img, dim_emb)
1513
+ self.audio_proj = nn.Linear(dim_audio, dim_emb)
1514
+
1515
+ # Learnable temperature parameter
1516
+ self.log_temp = nn.Parameter(torch.tensor(0.07).log())
1517
+
1518
+ def forward(self, images, audios):
1519
+ # Step 1: Feature extraction
1520
+ with torch.no_grad():  # encoder forward passes are not tracked by autograd; only audio_proj and log_temp receive gradients
1521
+ with torch.inference_mode():
1522
+ image_features = self.image_encoder(images).pooler_output # shape: [n, dim_img]
1523
+ audio_features = self.audio_encoder(audios)[0] # shape: [n, dim_audio]
1524
+
1525
+ # Step 2: Project and normalize
1526
+ image_embeds = F.normalize(image_features) # [n, dim_emb]
1527
+ audio_embeds = F.normalize(self.audio_proj(audio_features), dim=1) # [n, dim_emb]
1528
+
1529
+ # Step 3: Cosine similarity with temperature
1530
+ logits = torch.matmul(image_embeds, audio_embeds.T) * self.log_temp.exp() # [n, n]
1531
+ probs = logits.softmax(dim=1)
1532
+
1533
+ # Step 4: Symmetric cross-entropy loss
1534
+ labels = torch.arange(len(images), device=images.device)
1535
+ loss_i = F.cross_entropy(logits, labels)
1536
+ loss_t = F.cross_entropy(logits.T, labels)
1537
+ loss = (loss_i + loss_t) / 2
1538
+ return loss, logits, probs
1539
+
1540
+
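The forward pass above is the standard symmetric contrastive (CLIP-style) objective: loss = (CE(logits, y) + CE(logits.T, y)) / 2, with the matched image-audio pairs on the diagonal. A self-contained sanity check of just that objective, independent of the encoders (all names and sizes here are illustrative; the sketch applies the temperature as a divisor, the usual CLIP convention):

import torch
import torch.nn.functional as F

def symmetric_contrastive_loss(img_emb, aud_emb, temperature=0.07):
    # img_emb, aud_emb: [n, d], assumed L2-normalized
    logits = img_emb @ aud_emb.T / temperature
    labels = torch.arange(img_emb.shape[0])
    return (F.cross_entropy(logits, labels) + F.cross_entropy(logits.T, labels)) / 2

x = F.normalize(torch.randn(8, 16), dim=1)
print(symmetric_contrastive_loss(x, x).item())   # perfectly aligned pairs -> loss close to 0
print(symmetric_contrastive_loss(x, F.normalize(torch.randn(8, 16), dim=1)).item())  # mismatched -> much larger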
1541
+ # ==================================================================
1542
+ # I M A G E - A U D I O - D A T A S E T
1543
+ # ==================================================================
1544
+ class VaaniImageAudioDataset(torch.utils.data.Dataset):
1545
+ def __init__(self, df):
1546
+ self.image_paths = df.image_path.tolist()
1547
+ self.audio_paths = df.audio_path.tolist()
1548
+
1549
+ def __len__(self):
1550
+ return len(self.audio_paths)
1551
+
1552
+ def __getitem__(self, idx):
1553
+ return {
1554
+ 'image_path': self.image_paths[idx],
1555
+ 'audio_path': self.audio_paths[idx]
1556
+ }
1557
+
1558
+
1559
+ def collate_fn(batch):
1560
+ image_tensor = clip_vision_processor([Image.open(item['image_path']) for item in batch])['pixel_values']
1561
+ audio_tensor = CLAPAudioProcessor([item['audio_path'] for item in batch], resample=True)
1562
+ return {'image_tensor': torch.stack(image_tensor), 'audio_tensor': audio_tensor}
1563
+
1564
+
1565
+ # preprocessed_audio = CLAPAudioProcessor(audio_files, resample=True)
1566
+ # clip_vision_processor = clip_vision_processor
1567
+
1568
+ train_df = pd.read_csv("/home/IITB/ai-at-ieor/23m1521/ashish/MTP/Vaani/Img_Audio_Alignment/available_img_audios_TRAIN.csv")
1569
+ test_df = pd.read_csv("/home/IITB/ai-at-ieor/23m1521/ashish/MTP/Vaani/Img_Audio_Alignment/available_img_audios_TEST.csv")
1570
+ train_dataset = VaaniImageAudioDataset(train_df)
1571
+ test_dataset = VaaniImageAudioDataset(test_df)
1572
+ BATCH_SIZE = 32
1573
+
1574
+ print('Train Dataset:', len(train_dataset))
1575
+ print('Test Dataset:', len(test_dataset))
1576
+
1577
+
1578
+ train_dataloader = torch.utils.data.DataLoader(
1579
+ train_dataset,
1580
+ batch_size=BATCH_SIZE,
1581
+ shuffle=True,
1582
+ num_workers=48,
1583
+ collate_fn=collate_fn,
1584
+ pin_memory=True,
1585
+ drop_last=True,
1586
+ persistent_workers=True
1587
+ )
1588
+
1589
+ test_dataloader = torch.utils.data.DataLoader(
1590
+ test_dataset,
1591
+ batch_size=BATCH_SIZE,
1592
+ shuffle=False,
1593
+ num_workers=48,
1594
+ collate_fn=collate_fn,
1595
+ pin_memory=True,
1596
+ drop_last=False,
1597
+ persistent_workers=True
1598
+ )
1599
+
1600
+ batch = next(iter(train_dataloader))
1601
+ print("Image batch:", batch['image_tensor'].shape)
1602
+ print("Audio batch:", batch['audio_tensor'].shape)
1603
+
1604
+
1605
+ csip_model = CSIP(clip_vision_model.eval(), peft_clap_audio_encoder).to(device)
1606
+
1607
+
1608
+ # loss, logits, probs = csip_model(batch['image_tensor'].to(device), batch['audio_tensor'].to(device))
1609
+ # loss, logits, probs, logits.shape, probs.shape
1610
+
1611
+
1612
+ def train_batch(model, images, audio, optimizer):
1613
+ model.train()
1614
+ optimizer.zero_grad()
1615
+ loss, logits, probs = model(images, audio)
1616
+ loss.backward()
1617
+ optimizer.step()
1618
+ return loss.item(), logits, probs
1619
+
1620
+ @torch.no_grad()
1621
+ def evaluate_batch(model, images, audio):
1622
+ model.eval()
1623
+ loss, logits, probs = model(images, audio)
1624
+ return loss.item(), logits, probs
1625
+
1626
+ def save_checkpoint(state, checkpoint_dir, epoch, max_checkpoints=2):
1627
+ filename = f"csip_best_epoch_{epoch+1}.pt"
1628
+ path = os.path.join(checkpoint_dir, filename)
1629
+ torch.save(state, path)
1630
+ checkpoints = sorted(
1631
+ [f for f in os.listdir(checkpoint_dir) if f.startswith("csip_best_epoch_")],
1632
+ key=lambda x: int(x.split("_")[-1].split(".")[0])
1633
+ )
1634
+ while len(checkpoints) > max_checkpoints:
1635
+ to_delete = checkpoints.pop(0)
1636
+ os.remove(os.path.join(checkpoint_dir, to_delete))
1637
+
1638
+
1639
+ def load_checkpoint(checkpoint_dir, model, optimizer):
1640
+ checkpoints = sorted(
1641
+ [f for f in os.listdir(checkpoint_dir) if f.startswith("csip_best_epoch_")],
1642
+ key=lambda x: int(x.split("_")[-1].split(".")[0])
1643
+ )
1644
+ if not checkpoints:
1645
+ print("No checkpoint found to resume from.")
1646
+ return 0, float("inf")
1647
+
1648
+ best_ckpt = checkpoints[-1]
1649
+ path = os.path.join(checkpoint_dir, best_ckpt)
1650
+ checkpoint = torch.load(path)
1651
+ model.load_state_dict(checkpoint['model_state'])
1652
+ optimizer.load_state_dict(checkpoint['optimizer_state'])
1653
+ start_epoch = checkpoint['epoch']
1654
+ best_loss = checkpoint['best_loss']
1655
+ print(f"Resumed training from epoch {start_epoch+1} with best loss {best_loss:.4f}")
1656
+ return start_epoch, best_loss
1657
+
1658
+
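save_checkpoint and load_checkpoint above implement a simple rotation: each new best epoch writes csip_best_epoch_{N+1}.pt, anything beyond the newest max_checkpoints files is deleted, and resume always picks the highest epoch on disk. A tiny illustration of the rotation using dummy states in a temporary directory (purely illustrative, unrelated to the real checkpoints):

import os, tempfile

with tempfile.TemporaryDirectory() as tmp:
    for epoch in range(4):
        save_checkpoint({'epoch': epoch, 'best_loss': 1.0}, tmp, epoch, max_checkpoints=2)
    print(sorted(os.listdir(tmp)))  # only the two newest survive: ['csip_best_epoch_3.pt', 'csip_best_epoch_4.pt']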
1659
+ def fig_to_tensor(fig):
1660
+ """Convert a Matplotlib figure to a tensor suitable for TensorBoard."""
1661
+ buf = io.BytesIO()
1662
+ fig.savefig(buf, format='png')
1663
+ buf.seek(0)
1664
+ image = Image.open(buf).convert("RGB")
1665
+ tensor = tv.transforms.functional.to_tensor(image)
1666
+ buf.close()
1667
+ plt.close(fig)
1668
+ return tensor
1669
+
1670
+ def save_similarity_heatmaps(logits, epoch, loss, save_dir, writer):
1671
+ os.makedirs(os.path.join(save_dir, 'logits'), exist_ok=True)
1672
+ os.makedirs(os.path.join(save_dir, 'probs'), exist_ok=True)
1673
+
1674
+ # --- Raw logits heatmap ---
1675
+ logits_np = logits.detach().cpu().numpy()
1676
+ fig_logits = plt.figure(figsize=(8, 6))
1677
+ sns.heatmap(logits_np, square=True, cmap="Blues", cbar=True, annot=True, fmt=".1f")
1678
+ plt.title(f"Raw Logits Heatmap — Epoch {epoch+1}, Loss {loss:.4f}")
1679
+ plt.xlabel("Audio Index")
1680
+ plt.ylabel("Image Index")
1681
+ raw_path = os.path.join(save_dir, 'logits', f"raw_logits_epoch_{epoch+1}_loss_{loss:.4f}.png")
1682
+ fig_logits.savefig(raw_path)
1683
+ writer.add_image("Heatmap/RawLogits", fig_to_tensor(fig_logits), global_step=epoch+1)
1684
+
1685
+ # --- Softmax probs heatmap ---
1686
+ probs_np = logits.softmax(dim=1).cpu().numpy()
1687
+ fig_probs = plt.figure(figsize=(8, 6))
1688
+ sns.heatmap(probs_np, square=True, cmap="Blues", cbar=True, annot=True, fmt=".1f")
1689
+ plt.title(f"Softmax Probabilities Heatmap — Epoch {epoch+1}, Loss {loss:.4f}")
1690
+ plt.xlabel("Audio Index")
1691
+ plt.ylabel("Image Index")
1692
+ prob_path = os.path.join(save_dir, "probs", f"probs_epoch_{epoch+1}_loss_{loss:.4f}.png")
1693
+ fig_probs.savefig(prob_path)
1694
+ writer.add_image("Heatmap/SoftmaxProbs", fig_to_tensor(fig_probs), global_step=epoch+1)
1695
+
1696
+
1697
+
1698
+ def train_model(model, train_loader, test_loader,
1699
+ optimizer, device, log_dir,
1700
+ checkpoint_dir, resume=False, epochs=10):
1701
+
1702
+ os.makedirs(log_dir, exist_ok=True)
1703
+ os.makedirs(checkpoint_dir, exist_ok=True)
1704
+ csv_path = os.path.join(log_dir, "training_log.csv")
1705
+
1706
+ writer = SummaryWriter(log_dir=log_dir)
1707
+
1708
+ start_epoch = 0
1709
+ best_loss = float("inf")
1710
+ best_epoch = -1
1711
+
1712
+ if resume:
1713
+ start_epoch, best_loss = load_checkpoint(checkpoint_dir, model, optimizer)
1714
+
1715
+ # If resuming, don't overwrite the CSV
1716
+ if not (resume and os.path.exists(csv_path)):
1717
+ with open(csv_path, mode='w', newline='') as f:
1718
+ writer_csv = csv.writer(f)
1719
+ writer_csv.writerow(["Epoch", "Train Loss", "Test Loss", "Best Loss", "Best Epoch"])
1720
+
1721
+ for epoch in trange(start_epoch, epochs, colour='yellow', dynamic_ncols=True):
1722
+ train_losses = []
1723
+ test_losses = []
1724
+
1725
+ train_loop = tqdm(train_loader, desc=f"[Train Epoch {epoch+1}]", colour='blue', dynamic_ncols=True)
1726
+ for batch in train_loop:
1727
+ images = batch['image_tensor'].to(device)
1728
+ audios = batch['audio_tensor'].to(device)
1729
+ loss, logits, probs = train_batch(model, images, audios, optimizer)
1730
+ train_losses.append(loss)
1731
+ train_loop.set_postfix(train_loss=loss)
1732
+
1733
+ test_loop = tqdm(test_loader, desc=f"[Test Epoch {epoch+1}]", colour='red', dynamic_ncols=True)
1734
+ for batch in test_loop:
1735
+ images = batch['image_tensor'].to(device)
1736
+ audios = batch['audio_tensor'].to(device)
1737
+ loss, logits, probs = evaluate_batch(model, images, audios)
1738
+ test_losses.append(loss)
1739
+ test_loop.set_postfix(test_loss=loss)
1740
+
1741
+ avg_train_loss = sum(train_losses) / len(train_losses)
1742
+ avg_test_loss = sum(test_losses) / len(test_losses)
1743
+
1744
+ writer.add_scalar("Loss/Train", avg_train_loss, epoch + 1)
1745
+ writer.add_scalar("Loss/Test", avg_test_loss, epoch + 1)
1746
+
1747
+ print(f"Epoch {epoch+1} | Train Loss: {avg_train_loss:.4f} | Test Loss: {avg_test_loss:.4f}")
1748
+
1749
+ if avg_test_loss < best_loss:
1750
+ save_similarity_heatmaps(logits, epoch, avg_test_loss, checkpoint_dir, writer)
1751
+ best_loss = avg_test_loss
1752
+ best_epoch = epoch + 1
1753
+ save_checkpoint({
1754
+ 'epoch': epoch,
1755
+ 'model_state': model.state_dict(),
1756
+ 'optimizer_state': optimizer.state_dict(),
1757
+ 'best_loss': best_loss
1758
+ }, checkpoint_dir, epoch)
1759
+ print(f">>> Saved new best model at epoch {epoch+1}")
1760
+
1761
+ # Append row to CSV
1762
+ with open(csv_path, mode='a', newline='') as f:
1763
+ writer_csv = csv.writer(f)
1764
+ writer_csv.writerow([epoch + 1, avg_train_loss, avg_test_loss, best_loss, best_epoch])
1765
+
1766
+ writer.close()
1767
+
1768
+
1769
+
1770
+
1771
+ learning_rate = 1e-20
1772
+ epochs = 400
1773
+ optimizer = torch.optim.AdamW(csip_model.parameters(), lr=learning_rate)
1774
+ train_model(
1775
+ model=csip_model,
1776
+ train_loader=train_dataloader,
1777
+ test_loader=test_dataloader,
1778
+ optimizer=optimizer,
1779
+ device=device,
1780
+ log_dir="runs/csip",
1781
+ checkpoint_dir="checkpoints/csip",
1782
+ resume=True,
1783
+ epochs=epochs
1784
+ )
1785
+
1786
+
1787
+ # tensorboard --logdir=/home/IITB/ai-at-ieor/23m1521/ashish/MTP/Vaani/Img_Audio_Alignment/runs --port=6006
1788
+ # 127.0.0.1:40697
Vaani/Img_Audio_Alignment/_2_Train.ipynb ADDED
@@ -0,0 +1,1938 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "id": "5d25d5c5",
7
+ "metadata": {},
8
+ "outputs": [
9
+ {
10
+ "name": "stdout",
11
+ "output_type": "stream",
12
+ "text": [
13
+ "Using device: cpu\n"
14
+ ]
15
+ }
16
+ ],
17
+ "source": [
18
+ "# ==================================================================\n",
19
+ "# L A T E N T D I F F U S I O N M O D E L\n",
20
+ "# ==================================================================\n",
21
+ "# Author : Ashish Kumar Uchadiya\n",
22
+ "# Created : May 11, 2025\n",
23
+ "# Description: This script implements the training of a VQ-VAE model for\n",
24
+ "# image reconstruction, integrated with Latent Diffusion Models (LDMs) and\n",
25
+ "# audio conditioning. The VQ-VAE maps images to a discrete latent space, \n",
26
+ "# which is then modeled by the LDM for learning a diffusion process over the \n",
27
+ "# compressed representation. Audio features are used as conditioning inputs \n",
28
+ "# to guide the generation process. The training minimizes a combination of \n",
29
+ "# LPIPS (Learned Perceptual Image Patch Similarity) loss for perceptual \n",
30
+ "# fidelity and PatchGAN loss to enforce local realism. This setup enables \n",
31
+ "# efficient and semantically-aware generation of high-quality images driven \n",
32
+ "# by audio cues.\n",
33
+ "# ==================================================================\n",
34
+ "# I M P O R T S\n",
35
+ "# ==================================================================\n",
36
+ "from __future__ import annotations\n",
37
+ "import warnings\n",
38
+ "warnings.filterwarnings(\"ignore\")\n",
39
+ "\n",
40
+ "\n",
41
+ "import os\n",
42
+ "import io\n",
43
+ "import sys\n",
44
+ "import math\n",
45
+ "import random\n",
46
+ "import collections\n",
47
+ "import collections.abc\n",
48
+ "import re\n",
49
+ "from itertools import repeat\n",
50
+ "from pathlib import Path\n",
51
+ "from typing import Optional, Tuple, Union, List, Dict\n",
52
+ "\n",
53
+ "import numpy as np\n",
54
+ "import pandas as pd\n",
55
+ "from PIL import Image\n",
56
+ "import seaborn as sns\n",
57
+ "import matplotlib.pyplot as plt\n",
58
+ "from tqdm.auto import trange, tqdm\n",
59
+ "\n",
60
+ "import torch\n",
61
+ "import torch.nn.functional as F\n",
62
+ "from torch import nn\n",
63
+ "from torch.nn.init import _calculate_fan_in_and_fan_out\n",
64
+ "import torch.utils.checkpoint as checkpoint\n",
65
+ "\n",
66
+ "import torchvision as tv\n",
67
+ "from torchvision.transforms import v2\n",
68
+ "from torch.utils.tensorboard import SummaryWriter\n",
69
+ "\n",
70
+ "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"1\"\n",
71
+ "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
72
+ "print(f\"Using device: {device}\")\n",
73
+ "\n",
74
+ "import torchaudio\n",
75
+ "import torchaudio.transforms as T\n",
76
+ "from torchlibrosa.stft import Spectrogram, LogmelFilterBank\n",
77
+ "from torchlibrosa.augmentation import SpecAugmentation\n",
78
+ "\n",
79
+ "from transformers import AutoModel, AutoTokenizer, logging\n",
80
+ "from huggingface_hub.file_download import hf_hub_download\n",
81
+ "from huggingface_hub.file_download import hf_hub_download\n",
82
+ "from peft import get_peft_config, get_peft_model\n",
83
+ "from transformers import CLIPVisionModel, AutoProcessor"
84
+ ]
85
+ },
86
+ {
87
+ "cell_type": "code",
88
+ "execution_count": 2,
89
+ "id": "a41df980",
90
+ "metadata": {},
91
+ "outputs": [],
92
+ "source": [
93
+ "# ==================================================================\n",
94
+ "# H T S - A T\n",
95
+ "# ==================================================================\n",
96
+ "class HTSATConfig:\n",
97
+ " # Ke Chen\n",
98
+ " # [email protected]\n",
99
+ " # HTS-AT: A HIERARCHICAL TOKEN-SEMANTIC AUDIO TRANSFORMER FOR SOUND CLASSIFICATION AND DETECTION\n",
100
+ " # The configuration for training the model\n",
101
+ "\n",
102
+ " exp_name = \"exp_htsat_pretrain\" # the saved ckpt prefix name of the model \n",
103
+ " workspace = \"/home/kechen/Research/HTSAT\" # the folder of your code\n",
104
+ " dataset_path = \"/home/Research/audioset\" # the dataset path\n",
105
+ " desed_folder = \"/home/Research/DESED\" # the desed file\n",
106
+ "\n",
107
+ " dataset_type = \"audioset\" # \"audioset\" \"esc-50\" \"scv2\"\n",
108
+ " index_type = \"full_train\" # only works for audioset\n",
109
+ " balanced_data = True # only works for audioset\n",
110
+ "\n",
111
+ " loss_type = \"clip_bce\" # \n",
112
+ " # AudioSet & SCV2: \"clip_bce\" | ESC-50: \"clip_ce\" \n",
113
+ "\n",
114
+ " # trained from a checkpoint, or evaluate a single model \n",
115
+ " resume_checkpoint = None \n",
116
+ " # \"/home/Research/model_backup/AudioSet/HTSAT_AudioSet_Saved_1.ckpt\"\n",
117
+ " \n",
118
+ " esc_fold = 0 # just for esc dataset, select the fold you need for evaluation and (+1) validation\n",
119
+ "\n",
120
+ "\n",
121
+ " debug = False\n",
122
+ "\n",
123
+ " random_seed = 970131 # 19970318 970131 12412 127777 1009 34047\n",
124
+ " batch_size = 32 * 4 # batch size per GPU x GPU number , default is 32 x 4 = 128\n",
125
+ " learning_rate = 1e-3 # 1e-4 also workable \n",
126
+ " max_epoch = 100\n",
127
+ " num_workers = 3\n",
128
+ "\n",
129
+ " lr_scheduler_epoch = [10,20,30]\n",
130
+ " lr_rate = [0.02, 0.05, 0.1]\n",
131
+ "\n",
132
+ " # these data preparation optimizations do not bring many improvements, so deprecated\n",
133
+ " enable_token_label = False # token label\n",
134
+ " class_map_path = \"class_hier_map.npy\"\n",
135
+ " class_filter = None \n",
136
+ " retrieval_index = [15382, 9202, 130, 17618, 17157, 17516, 16356, 6165, 13992, 9238, 5550, 5733, 1914, 1600, 3450, 13735, 11108, 3762, \n",
137
+ " 9840, 11318, 8131, 4429, 16748, 4992, 16783, 12691, 4945, 8779, 2805, 9418, 2797, 14357, 5603, 212, 3852, 12666, 1338, 10269, 2388, 8260, 4293, 14454, 7677, 11253, 5060, 14938, 8840, 4542, 2627, 16336, 8992, 15496, 11140, 446, 6126, 10691, 8624, 10127, 9068, 16710, 10155, 14358, 7567, 5695, 2354, 8057, 17635, 133, 16183, 14535, 7248, 4560, 14429, 2463, 10773, 113, 2462, 9223, 4929, 14274, 4716, 17307, 4617, 2132, 11083, 1039, 1403, 9621, 13936, 2229, 2875, 17840, 9359, 13311, 9790, 13288, 4750, 17052, 8260, 14900]\n",
138
+ " token_label_range = [0.2,0.6]\n",
139
+ " enable_time_shift = False # shift time\n",
140
+ " enable_label_enhance = False # enhance hierarchical label\n",
141
+ " enable_repeat_mode = False # repeat the spectrogram / reshape the spectrogram\n",
142
+ "\n",
143
+ "\n",
144
+ "\n",
145
+ " # for model's design\n",
146
+ " enable_tscam = True # enbale the token-semantic layer\n",
147
+ "\n",
148
+ " # for signal processing\n",
149
+ " sample_rate = 32000 # 16000 for scv2, 32000 for audioset and esc-50\n",
150
+ " clip_samples = sample_rate * 10 # audio_set 10-sec clip\n",
151
+ " window_size = 1024\n",
152
+ " hop_size = 320 # 160 for scv2, 320 for audioset and esc-50\n",
153
+ " mel_bins = 64\n",
154
+ " fmin = 50\n",
155
+ " fmax = 14000\n",
156
+ " shift_max = int(clip_samples * 0.5)\n",
157
+ "\n",
158
+ " # for data collection\n",
159
+ " classes_num = 527 # esc: 50 | audioset: 527 | scv2: 35\n",
160
+ " patch_size = (25, 4) # deprecated\n",
161
+ " crop_size = None # int(clip_samples * 0.5) deprecated\n",
162
+ "\n",
163
+ " # for htsat hyperparamater\n",
164
+ " htsat_window_size = 8\n",
165
+ " htsat_spec_size = 256\n",
166
+ " htsat_patch_size = 4 \n",
167
+ " htsat_stride = (4, 4)\n",
168
+ " htsat_num_head = [4,8,16,32]\n",
169
+ " htsat_dim = 96 \n",
170
+ " htsat_depth = [2,2,6,2]\n",
171
+ "\n",
172
+ " swin_pretrain_path = None\n",
173
+ " # \"/home/Research/model_backup/pretrain/swin_tiny_c24_patch4_window8_256.pth\"\n",
174
+ "\n",
175
+ " # Some Deprecated Optimization in the model design, check the model code for details\n",
176
+ " htsat_attn_heatmap = False\n",
177
+ " htsat_hier_output = False \n",
178
+ " htsat_use_max = False\n",
179
+ "\n",
180
+ "\n",
181
+ " # for ensemble test \n",
182
+ "\n",
183
+ " ensemble_checkpoints = []\n",
184
+ " ensemble_strides = []\n",
185
+ "\n",
186
+ "\n",
187
+ " # weight average folder\n",
188
+ " wa_folder = \"/home/version_0/checkpoints/\"\n",
189
+ " # weight average output filename\n",
190
+ " wa_model_path = \"HTSAT_AudioSet_Saved_x.ckpt\"\n",
191
+ "\n",
192
+ " esm_model_pathes = [\n",
193
+ " \"/home/Research/model_backup/AudioSet/HTSAT_AudioSet_Saved_1.ckpt\",\n",
194
+ " \"/home/Research/model_backup/AudioSet/HTSAT_AudioSet_Saved_2.ckpt\",\n",
195
+ " \"/home/Research/model_backup/AudioSet/HTSAT_AudioSet_Saved_3.ckpt\",\n",
196
+ " \"/home/Research/model_backup/AudioSet/HTSAT_AudioSet_Saved_4.ckpt\",\n",
197
+ " \"/home/Research/model_backup/AudioSet/HTSAT_AudioSet_Saved_5.ckpt\",\n",
198
+ " \"/home/Research/model_backup/AudioSet/HTSAT_AudioSet_Saved_6.ckpt\"\n",
199
+ " ]\n",
200
+ "\n",
201
+ " # for framewise localization\n",
202
+ " heatmap_dir = \"/home/Research/heatmap_output\"\n",
203
+ " test_file = \"htsat-test-ensemble\"\n",
204
+ " fl_local = False # indicate if we need to use this dataset for the framewise detection\n",
205
+ " fl_dataset = \"/home/Research/desed/desedim_embval.npy\" \n",
206
+ " fl_class_num = [\n",
207
+ " \"Speech\", \"Frying\", \"Dishes\", \"Running_water\",\n",
208
+ " \"Blender\", \"Electric_shaver_toothbrush\", \"Alarm_bell_ringing\",\n",
209
+ " \"Cat\", \"Dog\", \"Vacuum_cleaner\"\n",
210
+ " ]\n",
211
+ "\n",
212
+ " # map 527 classes into 10 classes\n",
213
+ " fl_audioset_mapping = [\n",
214
+ " [0,1,2,3,4,5,6,7],\n",
215
+ " [366, 367, 368],\n",
216
+ " [364],\n",
217
+ " [288, 289, 290, 291, 292, 293, 294, 295, 296, 297],\n",
218
+ " [369],\n",
219
+ " [382],\n",
220
+ " [310, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, 401, 402],\n",
221
+ " [81, 82, 83, 84, 85],\n",
222
+ " [74, 75, 76, 77, 78, 79],\n",
223
+ " [377]\n",
224
+ " ]\n",
225
+ "\n",
226
+ "\n",
227
+ "\n",
228
+ "def _ntuple(n):\n",
229
+ " def parse(x):\n",
230
+ " if isinstance(x, collections.abc.Iterable):\n",
231
+ " return x\n",
232
+ " return tuple(repeat(x, n))\n",
233
+ " return parse\n",
234
+ "\n",
235
+ "to_1tuple = _ntuple(1)\n",
236
+ "to_2tuple = _ntuple(2)\n",
237
+ "to_3tuple = _ntuple(3)\n",
238
+ "to_4tuple = _ntuple(4)\n",
239
+ "to_ntuple = _ntuple\n",
240
+ "\n",
241
+ "def do_mixup(x, mixup_lambda):\n",
242
+ " \"\"\"Mixup x of even indexes (0, 2, 4, ...) with x of odd indexes \n",
243
+ " (1, 3, 5, ...).\n",
244
+ " Args:\n",
245
+ " x: (batch_size * 2, ...)\n",
246
+ " mixup_lambda: (batch_size * 2,)\n",
247
+ " Returns:\n",
248
+ " out: (batch_size, ...)\n",
249
+ " \"\"\"\n",
250
+ " out = (x[0 :: 2].transpose(0, -1) * mixup_lambda[0 :: 2] + \\\n",
251
+ " x[1 :: 2].transpose(0, -1) * mixup_lambda[1 :: 2]).transpose(0, -1)\n",
252
+ " return out\n",
253
+ "\n",
254
+ "def interpolate(x, ratio):\n",
255
+ " \"\"\"Interpolate data in time domain. This is used to compensate the \n",
256
+ " resolution reduction in downsampling of a CNN.\n",
257
+ " \n",
258
+ " Args:\n",
259
+ " x: (batch_size, time_steps, classes_num)\n",
260
+ " ratio: int, ratio to interpolate\n",
261
+ " Returns:\n",
262
+ " upsampled: (batch_size, time_steps * ratio, classes_num)\n",
263
+ " \"\"\"\n",
264
+ " (batch_size, time_steps, classes_num) = x.shape\n",
265
+ " upsampled = x[:, :, None, :].repeat(1, 1, ratio, 1)\n",
266
+ " upsampled = upsampled.reshape(batch_size, time_steps * ratio, classes_num)\n",
267
+ " return upsampled\n",
268
+ "\n",
269
+ "\n",
270
+ "def drop_path(x, drop_prob: float = 0., training: bool = False):\n",
271
+ " \"\"\"Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).\n",
272
+ " This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,\n",
273
+ " the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...\n",
274
+ " See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for\n",
275
+ " changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use\n",
276
+ " 'survival rate' as the argument.\n",
277
+ " \"\"\"\n",
278
+ " if drop_prob == 0. or not training:\n",
279
+ " return x\n",
280
+ " keep_prob = 1 - drop_prob\n",
281
+ " shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets\n",
282
+ " random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)\n",
283
+ " random_tensor.floor_() # binarize\n",
284
+ " output = x.div(keep_prob) * random_tensor\n",
285
+ " return output\n",
286
+ "\n",
287
+ "\n",
288
+ "class DropPath(nn.Module):\n",
289
+ " \"\"\"Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).\n",
290
+ " \"\"\"\n",
291
+ " def __init__(self, drop_prob=None):\n",
292
+ " super(DropPath, self).__init__()\n",
293
+ " self.drop_prob = drop_prob\n",
294
+ "\n",
295
+ " def forward(self, x):\n",
296
+ " return drop_path(x, self.drop_prob, self.training)\n",
297
+ "\n",
298
+ "class PatchEmbed(nn.Module):\n",
299
+ " \"\"\" 2D Image to Patch Embedding\n",
300
+ " \"\"\"\n",
301
+ " def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True, patch_stride = 16):\n",
302
+ " super().__init__()\n",
303
+ " img_size = to_2tuple(img_size)\n",
304
+ " patch_size = to_2tuple(patch_size)\n",
305
+ " patch_stride = to_2tuple(patch_stride)\n",
306
+ " self.img_size = img_size\n",
307
+ " self.patch_size = patch_size\n",
308
+ " self.patch_stride = patch_stride\n",
309
+ " self.grid_size = (img_size[0] // patch_stride[0], img_size[1] // patch_stride[1])\n",
310
+ " self.num_patches = self.grid_size[0] * self.grid_size[1]\n",
311
+ " self.flatten = flatten\n",
312
+ " self.in_chans = in_chans\n",
313
+ " self.embed_dim = embed_dim\n",
314
+ " \n",
315
+ " padding = ((patch_size[0] - patch_stride[0]) // 2, (patch_size[1] - patch_stride[1]) // 2)\n",
316
+ "\n",
317
+ " self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_stride, padding=padding)\n",
318
+ " self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()\n",
319
+ "\n",
320
+ " def forward(self, x):\n",
321
+ " B, C, H, W = x.shape\n",
322
+ " assert H == self.img_size[0] and W == self.img_size[1], \\\n",
323
+ " f\"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).\"\n",
324
+ " x = self.proj(x)\n",
325
+ " if self.flatten:\n",
326
+ " x = x.flatten(2).transpose(1, 2) # BCHW -> BNC\n",
327
+ " x = self.norm(x)\n",
328
+ " return x\n",
329
+ "\n",
330
+ "class Mlp(nn.Module):\n",
331
+ " \"\"\" MLP as used in Vision Transformer, MLP-Mixer and related networks\n",
332
+ " \"\"\"\n",
333
+ " def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):\n",
334
+ " super().__init__()\n",
335
+ " out_features = out_features or in_features\n",
336
+ " hidden_features = hidden_features or in_features\n",
337
+ " self.fc1 = nn.Linear(in_features, hidden_features)\n",
338
+ " self.act = act_layer()\n",
339
+ " self.fc2 = nn.Linear(hidden_features, out_features)\n",
340
+ " self.drop = nn.Dropout(drop)\n",
341
+ "\n",
342
+ " def forward(self, x):\n",
343
+ " x = self.fc1(x)\n",
344
+ " x = self.act(x)\n",
345
+ " x = self.drop(x)\n",
346
+ " x = self.fc2(x)\n",
347
+ " x = self.drop(x)\n",
348
+ " return x\n",
349
+ "\n",
350
+ "def _no_gradim_audiorunc_normal_(tensor, mean, std, a, b):\n",
351
+ " # Cut & paste from PyTorch official master until it's in a few official releases - RW\n",
352
+ " # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf\n",
353
+ " def norm_cdf(x):\n",
354
+ " # Computes standard normal cumulative distribution function\n",
355
+ " return (1. + math.erf(x / math.sqrt(2.))) / 2.\n",
356
+ "\n",
357
+ " if (mean < a - 2 * std) or (mean > b + 2 * std):\n",
358
+ " warnings.warn(\"mean is more than 2 std from [a, b] in nn.init.trunc_normal_. \"\n",
359
+ " \"The distribution of values may be incorrect.\",\n",
360
+ " stacklevel=2)\n",
361
+ "\n",
362
+ " with torch.no_grad():\n",
363
+ " # Values are generated by using a truncated uniform distribution and\n",
364
+ " # then using the inverse CDF for the normal distribution.\n",
365
+ " # Get upper and lower cdf values\n",
366
+ " l = norm_cdf((a - mean) / std)\n",
367
+ " u = norm_cdf((b - mean) / std)\n",
368
+ "\n",
369
+ " # Uniformly fill tensor with values from [l, u], then translate to\n",
370
+ " # [2l-1, 2u-1].\n",
371
+ " tensor.uniform_(2 * l - 1, 2 * u - 1)\n",
372
+ "\n",
373
+ " # Use inverse cdf transform for normal distribution to get truncated\n",
374
+ " # standard normal\n",
375
+ " tensor.erfinv_()\n",
376
+ "\n",
377
+ " # Transform to proper mean, std\n",
378
+ " tensor.mul_(std * math.sqrt(2.))\n",
379
+ " tensor.add_(mean)\n",
380
+ "\n",
381
+ " # Clamp to ensure it's in the proper range\n",
382
+ " tensor.clamp_(min=a, max=b)\n",
383
+ " return tensor\n",
384
+ "\n",
385
+ "\n",
386
+ "def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):\n",
387
+ " # type: (Tensor, float, float, float, float) -> Tensor\n",
388
+ " r\"\"\"Fills the input Tensor with values drawn from a truncated\n",
389
+ " normal distribution. The values are effectively drawn from the\n",
390
+ " normal distribution :math:`\\mathcal{N}(\\text{mean}, \\text{std}^2)`\n",
391
+ " with values outside :math:`[a, b]` redrawn until they are within\n",
392
+ " the bounds. The method used for generating the random values works\n",
393
+ " best when :math:`a \\leq \\text{mean} \\leq b`.\n",
394
+ " Args:\n",
395
+ " tensor: an n-dimensional `torch.Tensor`\n",
396
+ " mean: the mean of the normal distribution\n",
397
+ " std: the standard deviation of the normal distribution\n",
398
+ " a: the minimum cutoff value\n",
399
+ " b: the maximum cutoff value\n",
400
+ " Examples:\n",
401
+ " >>> w = torch.empty(3, 5)\n",
402
+ " >>> nn.init.trunc_normal_(w)\n",
403
+ " \"\"\"\n",
404
+ " return _no_gradim_audiorunc_normal_(tensor, mean, std, a, b)\n",
405
+ "\n",
406
+ "\n",
407
+ "def variance_scaling_(tensor, scale=1.0, mode='fan_in', distribution='normal'):\n",
408
+ " fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)\n",
409
+ " if mode == 'fan_in':\n",
410
+ " denom = fan_in\n",
411
+ " elif mode == 'fan_out':\n",
412
+ " denom = fan_out\n",
413
+ " elif mode == 'fan_avg':\n",
414
+ " denom = (fan_in + fan_out) / 2\n",
415
+ "\n",
416
+ " variance = scale / denom\n",
417
+ "\n",
418
+ " if distribution == \"truncated_normal\":\n",
419
+ " # constant is stddev of standard normal truncated to (-2, 2)\n",
420
+ " trunc_normal_(tensor, std=math.sqrt(variance) / .87962566103423978)\n",
421
+ " elif distribution == \"normal\":\n",
422
+ " tensor.normal_(std=math.sqrt(variance))\n",
423
+ " elif distribution == \"uniform\":\n",
424
+ " bound = math.sqrt(3 * variance)\n",
425
+ " tensor.uniform_(-bound, bound)\n",
426
+ " else:\n",
427
+ " raise ValueError(f\"invalid distribution {distribution}\")\n",
428
+ "\n",
429
+ "\n",
430
+ "def lecun_normal_(tensor):\n",
431
+ " variance_scaling_(tensor, mode='fan_in', distribution='truncated_normal')\n",
432
+ "\n",
433
+ "\n",
434
+ "# below codes are based and referred from https://github.com/microsoft/Swin-Transformer\n",
435
+ "# Swin Transformer for Computer Vision: https://arxiv.org/pdf/2103.14030.pdf\n",
436
+ "\n",
437
+ "def window_partition(x, window_size):\n",
438
+ " \"\"\"\n",
439
+ " Args:\n",
440
+ " x: (B, H, W, C)\n",
441
+ " window_size (int): window size\n",
442
+ " Returns:\n",
443
+ " windows: (num_windows*B, window_size, window_size, C)\n",
444
+ " \"\"\"\n",
445
+ " B, H, W, C = x.shape\n",
446
+ " x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)\n",
447
+ " windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)\n",
448
+ " return windows\n",
449
+ "\n",
450
+ "\n",
451
+ "def window_reverse(windows, window_size, H, W):\n",
452
+ " \"\"\"\n",
453
+ " Args:\n",
454
+ " windows: (num_windows*B, window_size, window_size, C)\n",
455
+ " window_size (int): Window size\n",
456
+ " H (int): Height of image\n",
457
+ " W (int): Width of image\n",
458
+ " Returns:\n",
459
+ " x: (B, H, W, C)\n",
460
+ " \"\"\"\n",
461
+ " B = int(windows.shape[0] / (H * W / window_size / window_size))\n",
462
+ " x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)\n",
463
+ " x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)\n",
464
+ " return x\n",
465
+ "\n",
466
+ "\n",
467
+ "class WindowAttention(nn.Module):\n",
468
+ " r\"\"\" Window based multi-head self attention (W-MSA) module with relative position bias.\n",
469
+ " It supports both of shifted and non-shifted window.\n",
470
+ " Args:\n",
471
+ " dim (int): Number of input channels.\n",
472
+ " window_size (tuple[int]): The height and width of the window.\n",
473
+ " num_heads (int): Number of attention heads.\n",
474
+ " qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True\n",
475
+ " qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set\n",
476
+ " attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0\n",
477
+ " proj_drop (float, optional): Dropout ratio of output. Default: 0.0\n",
478
+ " \"\"\"\n",
479
+ "\n",
480
+ " def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):\n",
481
+ "\n",
482
+ " super().__init__()\n",
483
+ " self.dim = dim\n",
484
+ " self.window_size = window_size # Wh, Ww\n",
485
+ " self.num_heads = num_heads\n",
486
+ " head_dim = dim // num_heads\n",
487
+ " self.scale = qk_scale or head_dim ** -0.5\n",
488
+ "\n",
489
+ " # define a parameter table of relative position bias\n",
490
+ " self.relative_position_bias_table = nn.Parameter(\n",
491
+ " torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH\n",
492
+ "\n",
493
+ " # get pair-wise relative position index for each token inside the window\n",
494
+ " coords_h = torch.arange(self.window_size[0])\n",
495
+ " coords_w = torch.arange(self.window_size[1])\n",
496
+ " coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww\n",
497
+ " coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww\n",
498
+ " relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww\n",
499
+ " relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2\n",
500
+ " relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0\n",
501
+ " relative_coords[:, :, 1] += self.window_size[1] - 1\n",
502
+ " relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1\n",
503
+ " relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww\n",
504
+ " self.register_buffer(\"relative_position_index\", relative_position_index)\n",
505
+ "\n",
506
+ " self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)\n",
507
+ " self.attn_drop = nn.Dropout(attn_drop)\n",
508
+ " self.proj = nn.Linear(dim, dim)\n",
509
+ " self.proj_drop = nn.Dropout(proj_drop)\n",
510
+ "\n",
511
+ " trunc_normal_(self.relative_position_bias_table, std=.02)\n",
512
+ " self.softmax = nn.Softmax(dim=-1)\n",
513
+ "\n",
514
+ " def forward(self, x, mask=None):\n",
515
+ " \"\"\"\n",
516
+ " Args:\n",
517
+ " x: input features with shape of (num_windows*B, N, C)\n",
518
+ " mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None\n",
519
+ " \"\"\"\n",
520
+ " B_, N, C = x.shape\n",
521
+ " qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)\n",
522
+ " q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)\n",
523
+ "\n",
524
+ " q = q * self.scale\n",
525
+ " attn = (q @ k.transpose(-2, -1))\n",
526
+ "\n",
527
+ " relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(\n",
528
+ " self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH\n",
529
+ " relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww\n",
530
+ " attn = attn + relative_position_bias.unsqueeze(0)\n",
531
+ "\n",
532
+ " if mask is not None:\n",
533
+ " nW = mask.shape[0]\n",
534
+ " attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)\n",
535
+ " attn = attn.view(-1, self.num_heads, N, N)\n",
536
+ " attn = self.softmax(attn)\n",
537
+ " else:\n",
538
+ " attn = self.softmax(attn)\n",
539
+ "\n",
540
+ " attn = self.attn_drop(attn)\n",
541
+ "\n",
542
+ " x = (attn @ v).transpose(1, 2).reshape(B_, N, C)\n",
543
+ " x = self.proj(x)\n",
544
+ " x = self.proj_drop(x)\n",
545
+ " return x, attn\n",
546
+ "\n",
547
+ " def extra_repr(self):\n",
548
+ " return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'\n",
549
+ "\n",
550
+ "\n",
551
+ "# We use the model based on Swintransformer Block, therefore we can use the swin-transformer pretrained model\n",
552
+ "class SwinTransformerBlock(nn.Module):\n",
553
+ " r\"\"\" Swin Transformer Block.\n",
554
+ " Args:\n",
555
+ " dim (int): Number of input channels.\n",
556
+ " input_resolution (tuple[int]): Input resulotion.\n",
557
+ " num_heads (int): Number of attention heads.\n",
558
+ " window_size (int): Window size.\n",
559
+ " shift_size (int): Shift size for SW-MSA.\n",
560
+ " mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.\n",
561
+ " qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True\n",
562
+ " qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.\n",
563
+ " drop (float, optional): Dropout rate. Default: 0.0\n",
564
+ " attn_drop (float, optional): Attention dropout rate. Default: 0.0\n",
565
+ " drop_path (float, optional): Stochastic depth rate. Default: 0.0\n",
566
+ " act_layer (nn.Module, optional): Activation layer. Default: nn.GELU\n",
567
+ " norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm\n",
568
+ " \"\"\"\n",
569
+ "\n",
570
+ " def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,\n",
571
+ " mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,\n",
572
+ " act_layer=nn.GELU, norm_layer=nn.LayerNorm, norm_before_mlp='ln'):\n",
573
+ " super().__init__()\n",
574
+ " self.dim = dim\n",
575
+ " self.input_resolution = input_resolution\n",
576
+ " self.num_heads = num_heads\n",
577
+ " self.window_size = window_size\n",
578
+ " self.shift_size = shift_size\n",
579
+ " self.mlp_ratio = mlp_ratio\n",
580
+ " self.norm_before_mlp = norm_before_mlp\n",
581
+ " if min(self.input_resolution) <= self.window_size:\n",
582
+ " # if window size is larger than input resolution, we don't partition windows\n",
583
+ " self.shift_size = 0\n",
584
+ " self.window_size = min(self.input_resolution)\n",
585
+ " assert 0 <= self.shift_size < self.window_size, \"shift_size must in 0-window_size\"\n",
586
+ "\n",
587
+ " self.norm1 = norm_layer(dim)\n",
588
+ " self.attn = WindowAttention(\n",
589
+ " dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,\n",
590
+ " qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)\n",
591
+ "\n",
592
+ " self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()\n",
593
+ " if self.norm_before_mlp == 'ln':\n",
594
+ " self.norm2 = nn.LayerNorm(dim)\n",
595
+ " elif self.norm_before_mlp == 'bn':\n",
596
+ " self.norm2 = lambda x: nn.BatchNorm1d(dim)(x.transpose(1, 2)).transpose(1, 2)\n",
597
+ " else:\n",
598
+ " raise NotImplementedError\n",
599
+ " mlp_hidden_dim = int(dim * mlp_ratio)\n",
600
+ " self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)\n",
601
+ "\n",
602
+ " if self.shift_size > 0:\n",
603
+ " # calculate attention mask for SW-MSA\n",
604
+ " H, W = self.input_resolution\n",
605
+ " img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1\n",
606
+ " h_slices = (slice(0, -self.window_size),\n",
607
+ " slice(-self.window_size, -self.shift_size),\n",
608
+ " slice(-self.shift_size, None))\n",
609
+ " w_slices = (slice(0, -self.window_size),\n",
610
+ " slice(-self.window_size, -self.shift_size),\n",
611
+ " slice(-self.shift_size, None))\n",
612
+ " cnt = 0\n",
613
+ " for h in h_slices:\n",
614
+ " for w in w_slices:\n",
615
+ " img_mask[:, h, w, :] = cnt\n",
616
+ " cnt += 1\n",
617
+ "\n",
618
+ " mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1\n",
619
+ " mask_windows = mask_windows.view(-1, self.window_size * self.window_size)\n",
620
+ " attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)\n",
621
+ " attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))\n",
622
+ " else:\n",
623
+ " attn_mask = None\n",
624
+ "\n",
625
+ " self.register_buffer(\"attn_mask\", attn_mask)\n",
626
+ "\n",
627
+ " def forward(self, x):\n",
628
+ " # pdb.set_trace()\n",
629
+ " H, W = self.input_resolution\n",
630
+ " # print(\"H: \", H)\n",
631
+ " # print(\"W: \", W)\n",
632
+ " # pdb.set_trace()\n",
633
+ " B, L, C = x.shape\n",
634
+ " # assert L == H * W, \"input feature has wrong size\"\n",
635
+ "\n",
636
+ " shortcut = x\n",
637
+ " x = self.norm1(x)\n",
638
+ " x = x.view(B, H, W, C)\n",
639
+ "\n",
640
+ " # cyclic shift\n",
641
+ " if self.shift_size > 0:\n",
642
+ " shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))\n",
643
+ " else:\n",
644
+ " shifted_x = x\n",
645
+ "\n",
646
+ " # partition windows\n",
647
+ " x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C\n",
648
+ " x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C\n",
649
+ "\n",
650
+ " # W-MSA/SW-MSA\n",
651
+ " attn_windows, attn = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C\n",
652
+ "\n",
653
+ " # merge windows\n",
654
+ " attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)\n",
655
+ " shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C\n",
656
+ "\n",
657
+ " # reverse cyclic shift\n",
658
+ " if self.shift_size > 0:\n",
659
+ " x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))\n",
660
+ " else:\n",
661
+ " x = shifted_x\n",
662
+ " x = x.view(B, H * W, C)\n",
663
+ "\n",
664
+ " # FFN\n",
665
+ " x = shortcut + self.drop_path(x)\n",
666
+ " x = x + self.drop_path(self.mlp(self.norm2(x)))\n",
667
+ "\n",
668
+ " return x, attn\n",
669
+ "\n",
670
+ " def extra_repr(self):\n",
671
+ " return f\"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, \" \\\n",
672
+ " f\"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}\"\n",
673
+ "\n",
674
+ "\n",
675
+ "\n",
676
+ "class PatchMerging(nn.Module):\n",
677
+ " r\"\"\" Patch Merging Layer.\n",
678
+ " Args:\n",
679
+ " input_resolution (tuple[int]): Resolution of input feature.\n",
680
+ " dim (int): Number of input channels.\n",
681
+ " norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm\n",
682
+ " \"\"\"\n",
683
+ "\n",
684
+ " def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):\n",
685
+ " super().__init__()\n",
686
+ " self.input_resolution = input_resolution\n",
687
+ " self.dim = dim\n",
688
+ " self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)\n",
689
+ " self.norm = norm_layer(4 * dim)\n",
690
+ "\n",
691
+ " def forward(self, x):\n",
692
+ " \"\"\"\n",
693
+ " x: B, H*W, C\n",
694
+ " \"\"\"\n",
695
+ " H, W = self.input_resolution\n",
696
+ " B, L, C = x.shape\n",
697
+ " assert L == H * W, \"input feature has wrong size\"\n",
698
+ " assert H % 2 == 0 and W % 2 == 0, f\"x size ({H}*{W}) are not even.\"\n",
699
+ "\n",
700
+ " x = x.view(B, H, W, C)\n",
701
+ "\n",
702
+ " x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C\n",
703
+ " x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C\n",
704
+ " x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C\n",
705
+ " x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C\n",
706
+ " x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C\n",
707
+ " x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C\n",
708
+ "\n",
709
+ " x = self.norm(x)\n",
710
+ " x = self.reduction(x)\n",
711
+ "\n",
712
+ " return x\n",
713
+ "\n",
714
+ " def extra_repr(self):\n",
715
+ " return f\"input_resolution={self.input_resolution}, dim={self.dim}\"\n",
716
+ "\n",
717
+ "\n",
718
+ "class BasicLayer(nn.Module):\n",
719
+ " \"\"\" A basic Swin Transformer layer for one stage.\n",
720
+ " Args:\n",
721
+ " dim (int): Number of input channels.\n",
722
+ " input_resolution (tuple[int]): Input resolution.\n",
723
+ " depth (int): Number of blocks.\n",
724
+ " num_heads (int): Number of attention heads.\n",
725
+ " window_size (int): Local window size.\n",
726
+ " mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.\n",
727
+ " qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True\n",
728
+ " qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.\n",
729
+ " drop (float, optional): Dropout rate. Default: 0.0\n",
730
+ " attn_drop (float, optional): Attention dropout rate. Default: 0.0\n",
731
+ " drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0\n",
732
+ " norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm\n",
733
+ " downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None\n",
734
+ " use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.\n",
735
+ " \"\"\"\n",
736
+ "\n",
737
+ " def __init__(self, dim, input_resolution, depth, num_heads, window_size,\n",
738
+ " mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,\n",
739
+ " drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,\n",
740
+ " norm_before_mlp='ln'):\n",
741
+ "\n",
742
+ " super().__init__()\n",
743
+ " self.dim = dim\n",
744
+ " self.input_resolution = input_resolution\n",
745
+ " self.depth = depth\n",
746
+ " self.use_checkpoint = use_checkpoint\n",
747
+ "\n",
748
+ " # build blocks\n",
749
+ " self.blocks = nn.ModuleList([\n",
750
+ " SwinTransformerBlock(dim=dim, input_resolution=input_resolution,\n",
751
+ " num_heads=num_heads, window_size=window_size,\n",
752
+ " shift_size=0 if (i % 2 == 0) else window_size // 2,\n",
753
+ " mlp_ratio=mlp_ratio,\n",
754
+ " qkv_bias=qkv_bias, qk_scale=qk_scale,\n",
755
+ " drop=drop, attn_drop=attn_drop,\n",
756
+ " drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,\n",
757
+ " norm_layer=norm_layer, norm_before_mlp=norm_before_mlp)\n",
758
+ " for i in range(depth)])\n",
759
+ "\n",
760
+ " # patch merging layer\n",
761
+ " if downsample is not None:\n",
762
+ " self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)\n",
763
+ " else:\n",
764
+ " self.downsample = None\n",
765
+ "\n",
766
+ " def forward(self, x):\n",
767
+ " attns = []\n",
768
+ " for blk in self.blocks:\n",
769
+ " if self.use_checkpoint:\n",
770
+ " x = checkpoint.checkpoint(blk, x)\n",
771
+ " else:\n",
772
+ " x, attn = blk(x)\n",
773
+ " if not self.training:\n",
774
+ " attns.append(attn.unsqueeze(0))\n",
775
+ " if self.downsample is not None:\n",
776
+ " x = self.downsample(x)\n",
777
+ " if not self.training:\n",
778
+ " attn = torch.cat(attns, dim = 0)\n",
779
+ " attn = torch.mean(attn, dim = 0)\n",
780
+ " return x, attn\n",
781
+ "\n",
782
+ " def extra_repr(self):\n",
783
+ " return f\"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}\"\n",
784
+ "\n",
785
+ "\n",
786
+ "# The Core of HTSAT\n",
787
+ "class HTSAT_Swin_Transformer(nn.Module):\n",
788
+ " r\"\"\"HTSAT based on the Swin Transformer\n",
789
+ " Args:\n",
790
+ " spec_size (int | tuple(int)): Input Spectrogram size. Default 256\n",
791
+ " patch_size (int | tuple(int)): Patch size. Default: 4\n",
792
+ " path_stride (iot | tuple(int)): Patch Stride for Frequency and Time Axis. Default: 4\n",
793
+ " in_chans (int): Number of input image channels. Default: 1 (mono)\n",
794
+ " num_classes (int): Number of classes for classification head. Default: 527\n",
795
+ " embed_dim (int): Patch embedding dimension. Default: 96\n",
796
+ " depths (tuple(int)): Depth of each HTSAT-Swin Transformer layer.\n",
797
+ " num_heads (tuple(int)): Number of attention heads in different layers.\n",
798
+ " window_size (int): Window size. Default: 8\n",
799
+ " mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4\n",
800
+ " qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True\n",
801
+ " qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None\n",
802
+ " drop_rate (float): Dropout rate. Default: 0\n",
803
+ " attn_drop_rate (float): Attention dropout rate. Default: 0\n",
804
+ " drop_path_rate (float): Stochastic depth rate. Default: 0.1\n",
805
+ " norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.\n",
806
+ " ape (bool): If True, add absolute position embedding to the patch embedding. Default: False\n",
807
+ " patch_norm (bool): If True, add normalization after patch embedding. Default: True\n",
808
+ " use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False\n",
809
+ " config (module): The configuration Module from config.py (HTSATConfig Class)\n",
810
+ " \"\"\"\n",
811
+ "\n",
812
+ " def __init__(self, spec_size=256, patch_size=4, patch_stride=(4,4), \n",
813
+ " in_chans=1, num_classes=527,\n",
814
+ " embed_dim=96, depths=[2, 2, 6, 2], num_heads=[4, 8, 16, 32],\n",
815
+ " window_size=8, mlp_ratio=4., qkv_bias=True, qk_scale=None,\n",
816
+ " drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,\n",
817
+ " norm_layer=nn.LayerNorm, \n",
818
+ " ape=False, patch_norm=True,\n",
819
+ " use_checkpoint=False, norm_before_mlp='ln', config = None, **kwargs):\n",
820
+ " super(HTSAT_Swin_Transformer, self).__init__()\n",
821
+ "\n",
822
+ " self.config = config\n",
823
+ " self.spec_size = spec_size \n",
824
+ " self.patch_stride = patch_stride\n",
825
+ " self.patch_size = patch_size\n",
826
+ " self.window_size = window_size\n",
827
+ " self.embed_dim = embed_dim\n",
828
+ " self.depths = depths\n",
829
+ " self.ape = ape\n",
830
+ " self.in_chans = in_chans\n",
831
+ " self.num_classes = num_classes\n",
832
+ " self.num_heads = num_heads\n",
833
+ " self.num_layers = len(self.depths)\n",
834
+ " self.num_features = int(self.embed_dim * 2 ** (self.num_layers - 1))\n",
835
+ " \n",
836
+ " self.drop_rate = drop_rate\n",
837
+ " self.attn_drop_rate = attn_drop_rate\n",
838
+ " self.drop_path_rate = drop_path_rate\n",
839
+ "\n",
840
+ " self.qkv_bias = qkv_bias\n",
841
+ " self.qk_scale = None\n",
842
+ "\n",
843
+ " self.patch_norm = patch_norm\n",
844
+ " self.norm_layer = norm_layer if self.patch_norm else None\n",
845
+ " self.norm_before_mlp = norm_before_mlp\n",
846
+ " self.mlp_ratio = mlp_ratio\n",
847
+ "\n",
848
+ " self.use_checkpoint = use_checkpoint\n",
849
+ "\n",
850
+ " # process mel-spec ; used only once\n",
851
+ " self.freq_ratio = self.spec_size // self.config.mel_bins\n",
852
+ " window = 'hann'\n",
853
+ " center = True\n",
854
+ " pad_mode = 'reflect'\n",
855
+ " ref = 1.0\n",
856
+ " amin = 1e-10\n",
857
+ " top_db = None\n",
858
+ " self.interpolate_ratio = 32 # Downsampled ratio\n",
859
+ " # Spectrogram extractor\n",
860
+ " self.spectrogram_extractor = Spectrogram(n_fft=config.window_size, hop_length=config.hop_size, \n",
861
+ " win_length=config.window_size, window=window, center=center, pad_mode=pad_mode, \n",
862
+ " freeze_parameters=True)\n",
863
+ " # Logmel feature extractor\n",
864
+ " self.logmel_extractor = LogmelFilterBank(sr=config.sample_rate, n_fft=config.window_size, \n",
865
+ " n_mels=config.mel_bins, fmin=config.fmin, fmax=config.fmax, ref=ref, amin=amin, top_db=top_db, \n",
866
+ " freeze_parameters=True)\n",
867
+ " # Spec augmenter\n",
868
+ " self.spec_augmenter = SpecAugmentation(time_drop_width=64, time_stripes_num=2, \n",
869
+ " freq_drop_width=8, freq_stripes_num=2) # 2 2\n",
870
+ " self.bn0 = nn.BatchNorm2d(self.config.mel_bins)\n",
871
+ "\n",
872
+ "\n",
873
+ " # split spctrogram into non-overlapping patches\n",
874
+ " self.patch_embed = PatchEmbed(\n",
875
+ " img_size=self.spec_size, patch_size=self.patch_size, in_chans=self.in_chans, \n",
876
+ " embed_dim=self.embed_dim, norm_layer=self.norm_layer, patch_stride = patch_stride)\n",
877
+ "\n",
878
+ " num_patches = self.patch_embed.num_patches\n",
879
+ " patches_resolution = self.patch_embed.grid_size\n",
880
+ " self.patches_resolution = patches_resolution\n",
881
+ "\n",
882
+ " # absolute position embedding\n",
883
+ " if self.ape:\n",
884
+ " self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, self.embed_dim))\n",
885
+ " trunc_normal_(self.absolute_pos_embed, std=.02)\n",
886
+ "\n",
887
+ " self.pos_drop = nn.Dropout(p=self.drop_rate)\n",
888
+ "\n",
889
+ " # stochastic depth\n",
890
+ " dpr = [x.item() for x in torch.linspace(0, self.drop_path_rate, sum(self.depths))] # stochastic depth decay rule\n",
891
+ "\n",
892
+ " # build layers\n",
893
+ " self.layers = nn.ModuleList()\n",
894
+ " for i_layer in range(self.num_layers):\n",
895
+ " layer = BasicLayer(dim=int(self.embed_dim * 2 ** i_layer),\n",
896
+ " input_resolution=(patches_resolution[0] // (2 ** i_layer),\n",
897
+ " patches_resolution[1] // (2 ** i_layer)),\n",
898
+ " depth=self.depths[i_layer],\n",
899
+ " num_heads=self.num_heads[i_layer],\n",
900
+ " window_size=self.window_size,\n",
901
+ " mlp_ratio=self.mlp_ratio,\n",
902
+ " qkv_bias=self.qkv_bias, qk_scale=self.qk_scale,\n",
903
+ " drop=self.drop_rate, attn_drop=self.attn_drop_rate,\n",
904
+ " drop_path=dpr[sum(self.depths[:i_layer]):sum(self.depths[:i_layer + 1])],\n",
905
+ " norm_layer=self.norm_layer,\n",
906
+ " downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,\n",
907
+ " use_checkpoint=use_checkpoint,\n",
908
+ " norm_before_mlp=self.norm_before_mlp)\n",
909
+ " self.layers.append(layer)\n",
910
+ "\n",
911
+ " self.norm = self.norm_layer(self.num_features)\n",
912
+ " self.avgpool = nn.AdaptiveAvgPool1d(1)\n",
913
+ " self.maxpool = nn.AdaptiveMaxPool1d(1)\n",
914
+ "\n",
915
+ " if self.config.enable_tscam:\n",
916
+ " SF = self.spec_size // (2 ** (len(self.depths) - 1)) // self.patch_stride[0] // self.freq_ratio\n",
917
+ " self.tscam_conv = nn.Conv2d(\n",
918
+ " in_channels = self.num_features,\n",
919
+ " out_channels = self.num_classes,\n",
920
+ " kernel_size = (SF,3),\n",
921
+ " padding = (0,1)\n",
922
+ " )\n",
923
+ " self.head = nn.Linear(num_classes, num_classes)\n",
924
+ " else:\n",
925
+ " self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()\n",
926
+ "\n",
927
+ " self.apply(self._init_weights)\n",
928
+ "\n",
929
+ " def _init_weights(self, m):\n",
930
+ " if isinstance(m, nn.Linear):\n",
931
+ " trunc_normal_(m.weight, std=.02)\n",
932
+ " if isinstance(m, nn.Linear) and m.bias is not None:\n",
933
+ " nn.init.constant_(m.bias, 0)\n",
934
+ " elif isinstance(m, nn.LayerNorm):\n",
935
+ " nn.init.constant_(m.bias, 0)\n",
936
+ " nn.init.constant_(m.weight, 1.0)\n",
937
+ "\n",
938
+ " @torch.jit.ignore\n",
939
+ " def no_weight_decay(self):\n",
940
+ " return {'absolute_pos_embed'}\n",
941
+ "\n",
942
+ " @torch.jit.ignore\n",
943
+ " def no_weight_decay_keywords(self):\n",
944
+ " return {'relative_position_bias_table'}\n",
945
+ "\n",
946
+ " def forward_features(self, x):\n",
947
+ " frames_num = x.shape[2] \n",
948
+ " x = self.patch_embed(x)\n",
949
+ " if self.ape:\n",
950
+ " x = x + self.absolute_pos_embed\n",
951
+ " x = self.pos_drop(x)\n",
952
+ " for i, layer in enumerate(self.layers):\n",
953
+ " x, attn = layer(x)\n",
954
+ "\n",
955
+ " if self.config.enable_tscam:\n",
956
+ " # for x\n",
957
+ " x = self.norm(x)\n",
958
+ " B, N, C = x.shape\n",
959
+ " SF = frames_num // (2 ** (len(self.depths) - 1)) // self.patch_stride[0]\n",
960
+ " ST = frames_num // (2 ** (len(self.depths) - 1)) // self.patch_stride[1]\n",
961
+ " x = x.permute(0,2,1).contiguous().reshape(B, C, SF, ST)\n",
962
+ " B, C, F, T = x.shape\n",
963
+ " # group 2D CNN\n",
964
+ " c_freq_bin = F // self.freq_ratio\n",
965
+ " x = x.reshape(B, C, F // c_freq_bin, c_freq_bin, T)\n",
966
+ " x = x.permute(0,1,3,2,4).contiguous().reshape(B, C, c_freq_bin, -1)\n",
967
+ "\n",
968
+ " # get latent_output\n",
969
+ " latent_output = self.avgpool(torch.flatten(x,2))\n",
970
+ " latent_output = torch.flatten(latent_output, 1)\n",
971
+ "\n",
972
+ " # display the attention map, if needed\n",
973
+ " if self.config.htsat_attn_heatmap:\n",
974
+ " # for attn\n",
975
+ " attn = torch.mean(attn, dim = 1)\n",
976
+ " attn = torch.mean(attn, dim = 1)\n",
977
+ " attn = attn.reshape(B, SF, ST)\n",
978
+ " c_freq_bin = SF // self.freq_ratio\n",
979
+ " attn = attn.reshape(B, SF // c_freq_bin, c_freq_bin, ST) \n",
980
+ " attn = attn.permute(0,2,1,3).contiguous().reshape(B, c_freq_bin, -1)\n",
981
+ " attn = attn.mean(dim = 1)\n",
982
+ " attn_max = torch.max(attn, dim = 1, keepdim = True)[0]\n",
983
+ " attn_min = torch.min(attn, dim = 1, keepdim = True)[0]\n",
984
+ " attn = ((attn * 0.15) + (attn_max * 0.85 - attn_min)) / (attn_max - attn_min)\n",
985
+ " attn = attn.unsqueeze(dim = 2)\n",
986
+ "\n",
987
+ " x = self.tscam_conv(x)\n",
988
+ " x = torch.flatten(x, 2) # B, C, T\n",
989
+ "\n",
990
+ " if self.config.htsat_attn_heatmap:\n",
991
+ " fpx = interpolate(torch.sigmoid(x).permute(0,2,1).contiguous() * attn, 8 * self.patch_stride[1]) \n",
992
+ " else: \n",
993
+ " fpx = interpolate(torch.sigmoid(x).permute(0,2,1).contiguous(), 8 * self.patch_stride[1]) \n",
994
+ " \n",
995
+ " x = self.avgpool(x)\n",
996
+ " x = torch.flatten(x, 1)\n",
997
+ "\n",
998
+ " if self.config.loss_type == \"clip_ce\":\n",
999
+ " output_dict = {\n",
1000
+ " 'framewise_output': fpx, # already sigmoided\n",
1001
+ " 'clipwise_output': x,\n",
1002
+ " 'latent_output': latent_output\n",
1003
+ " }\n",
1004
+ " else:\n",
1005
+ " output_dict = {\n",
1006
+ " 'framewise_output': fpx, # already sigmoided\n",
1007
+ " 'clipwise_output': torch.sigmoid(x),\n",
1008
+ " 'latent_output': latent_output\n",
1009
+ " }\n",
1010
+ " \n",
1011
+ " else:\n",
1012
+ " x = self.norm(x) # B N C\n",
1013
+ " B, N, C = x.shape\n",
1014
+ " \n",
1015
+ " fpx = x.permute(0,2,1).contiguous().reshape(B, C, frames_num // (2 ** (len(self.depths) + 1)), frames_num // (2 ** (len(self.depths) + 1)) )\n",
1016
+ " B, C, F, T = fpx.shape\n",
1017
+ " c_freq_bin = F // self.freq_ratio\n",
1018
+ " fpx = fpx.reshape(B, C, F // c_freq_bin, c_freq_bin, T)\n",
1019
+ " fpx = fpx.permute(0,1,3,2,4).contiguous().reshape(B, C, c_freq_bin, -1)\n",
1020
+ " fpx = torch.sum(fpx, dim = 2)\n",
1021
+ " fpx = interpolate(fpx.permute(0,2,1).contiguous(), 8 * self.patch_stride[1]) \n",
1022
+ " x = self.avgpool(x.transpose(1, 2)) # B C 1\n",
1023
+ " x = torch.flatten(x, 1)\n",
1024
+ " if self.num_classes > 0:\n",
1025
+ " x = self.head(x)\n",
1026
+ " fpx = self.head(fpx)\n",
1027
+ " output_dict = {'framewise_output': torch.sigmoid(fpx), \n",
1028
+ " 'clipwise_output': torch.sigmoid(x)}\n",
1029
+ " return output_dict\n",
1030
+ "\n",
1031
+ " def crop_wav(self, x, crop_size, spe_pos = None):\n",
1032
+ " time_steps = x.shape[2]\n",
1033
+ " tx = torch.zeros(x.shape[0], x.shape[1], crop_size, x.shape[3]).to(x.device)\n",
1034
+ " for i in range(len(x)):\n",
1035
+ " if spe_pos is None:\n",
1036
+ " crop_pos = random.randint(0, time_steps - crop_size - 1)\n",
1037
+ " else:\n",
1038
+ " crop_pos = spe_pos\n",
1039
+ " tx[i][0] = x[i, 0, crop_pos:crop_pos + crop_size,:]\n",
1040
+ " return tx\n",
1041
+ "\n",
1042
+ " # Reshape the wavform to a img size, if you want to use the pretrained swin transformer model\n",
1043
+ " def reshape_wav2img(self, x):\n",
1044
+ " B, C, T, F = x.shape\n",
1045
+ " target_T = int(self.spec_size * self.freq_ratio)\n",
1046
+ " target_F = self.spec_size // self.freq_ratio\n",
1047
+ " assert T <= target_T and F <= target_F, \"the wav size should less than or equal to the swin input size\"\n",
1048
+ " # to avoid bicubic zero error\n",
1049
+ " if T < target_T:\n",
1050
+ " x = nn.functional.interpolate(x, (target_T, x.shape[3]), mode=\"bicubic\", align_corners=True)\n",
1051
+ " if F < target_F:\n",
1052
+ " x = nn.functional.interpolate(x, (x.shape[2], target_F), mode=\"bicubic\", align_corners=True)\n",
1053
+ " x = x.permute(0,1,3,2).contiguous()\n",
1054
+ " x = x.reshape(x.shape[0], x.shape[1], x.shape[2], self.freq_ratio, x.shape[3] // self.freq_ratio)\n",
1055
+ " # print(x.shape)\n",
1056
+ " x = x.permute(0,1,3,2,4).contiguous()\n",
1057
+ " x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3], x.shape[4])\n",
1058
+ " return x\n",
1059
+ " \n",
1060
+ " # Repeat the wavform to a img size, if you want to use the pretrained swin transformer model\n",
1061
+ " def repeat_wat2img(self, x, cur_pos):\n",
1062
+ " B, C, T, F = x.shape\n",
1063
+ " target_T = int(self.spec_size * self.freq_ratio)\n",
1064
+ " target_F = self.spec_size // self.freq_ratio\n",
1065
+ " assert T <= target_T and F <= target_F, \"the wav size should less than or equal to the swin input size\"\n",
1066
+ " # to avoid bicubic zero error\n",
1067
+ " if T < target_T:\n",
1068
+ " x = nn.functional.interpolate(x, (target_T, x.shape[3]), mode=\"bicubic\", align_corners=True)\n",
1069
+ " if F < target_F:\n",
1070
+ " x = nn.functional.interpolate(x, (x.shape[2], target_F), mode=\"bicubic\", align_corners=True) \n",
1071
+ " x = x.permute(0,1,3,2).contiguous() # B C F T\n",
1072
+ " x = x[:,:,:,cur_pos:cur_pos + self.spec_size]\n",
1073
+ " x = x.repeat(repeats = (1,1,4,1))\n",
1074
+ " return x\n",
1075
+ "\n",
1076
+ " def forward(self, x: torch.Tensor, mixup_lambda = None, infer_mode = False):# out_feat_keys: List[str] = None):\n",
1077
+ " x = self.spectrogram_extractor(x) # (batch_size, 1, time_steps, freq_bins)\n",
1078
+ " x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins)\n",
1079
+ " \n",
1080
+ " \n",
1081
+ " x = x.transpose(1, 3)\n",
1082
+ " x = self.bn0(x)\n",
1083
+ " x = x.transpose(1, 3)\n",
1084
+ " if self.training:\n",
1085
+ " x = self.spec_augmenter(x)\n",
1086
+ " if self.training and mixup_lambda is not None:\n",
1087
+ " x = do_mixup(x, mixup_lambda)\n",
1088
+ " \n",
1089
+ " if infer_mode:\n",
1090
+ " # in infer mode. we need to handle different length audio input\n",
1091
+ " frame_num = x.shape[2]\n",
1092
+ " target_T = int(self.spec_size * self.freq_ratio)\n",
1093
+ " repeat_ratio = math.floor(target_T / frame_num)\n",
1094
+ " x = x.repeat(repeats=(1,1,repeat_ratio,1))\n",
1095
+ " x = self.reshape_wav2img(x)\n",
1096
+ " output_dict = self.forward_features(x)\n",
1097
+ " elif self.config.enable_repeat_mode:\n",
1098
+ " if self.training:\n",
1099
+ " cur_pos = random.randint(0, (self.freq_ratio - 1) * self.spec_size - 1)\n",
1100
+ " x = self.repeat_wat2img(x, cur_pos)\n",
1101
+ " output_dict = self.forward_features(x)\n",
1102
+ " else:\n",
1103
+ " output_dicts = []\n",
1104
+ " for cur_pos in range(0, (self.freq_ratio - 1) * self.spec_size + 1, self.spec_size):\n",
1105
+ " tx = x.clone()\n",
1106
+ " tx = self.repeat_wat2img(tx, cur_pos)\n",
1107
+ " output_dicts.append(self.forward_features(tx))\n",
1108
+ " clipwise_output = torch.zeros_like(output_dicts[0][\"clipwise_output\"]).float().to(x.device)\n",
1109
+ " framewise_output = torch.zeros_like(output_dicts[0][\"framewise_output\"]).float().to(x.device)\n",
1110
+ " for d in output_dicts:\n",
1111
+ " clipwise_output += d[\"clipwise_output\"]\n",
1112
+ " framewise_output += d[\"framewise_output\"]\n",
1113
+ " clipwise_output = clipwise_output / len(output_dicts)\n",
1114
+ " framewise_output = framewise_output / len(output_dicts)\n",
1115
+ "\n",
1116
+ " output_dict = {\n",
1117
+ " 'framewise_output': framewise_output, \n",
1118
+ " 'clipwise_output': clipwise_output\n",
1119
+ " }\n",
1120
+ " else:\n",
1121
+ " if x.shape[2] > self.freq_ratio * self.spec_size:\n",
1122
+ " if self.training:\n",
1123
+ " x = self.crop_wav(x, crop_size=self.freq_ratio * self.spec_size)\n",
1124
+ " x = self.reshape_wav2img(x)\n",
1125
+ " output_dict = self.forward_features(x)\n",
1126
+ " else:\n",
1127
+ " # Change: Hard code here\n",
1128
+ " overlap_size = 344 #(x.shape[2] - 1) // 4\n",
1129
+ " output_dicts = []\n",
1130
+ " crop_size = 689 #(x.shape[2] - 1) // 2\n",
1131
+ " for cur_pos in range(0, x.shape[2] - crop_size - 1, overlap_size):\n",
1132
+ " tx = self.crop_wav(x, crop_size = crop_size, spe_pos = cur_pos)\n",
1133
+ " tx = self.reshape_wav2img(tx)\n",
1134
+ " output_dicts.append(self.forward_features(tx))\n",
1135
+ " clipwise_output = torch.zeros_like(output_dicts[0][\"clipwise_output\"]).float().to(x.device)\n",
1136
+ " framewise_output = torch.zeros_like(output_dicts[0][\"framewise_output\"]).float().to(x.device)\n",
1137
+ " latent_output = torch.zeros_like(output_dicts[0][\"latent_output\"]).float().to(x.device)\n",
1138
+ " for d in output_dicts:\n",
1139
+ " clipwise_output += d[\"clipwise_output\"]\n",
1140
+ " framewise_output += d[\"framewise_output\"]\n",
1141
+ " latent_output += d[\"latent_output\"]\n",
1142
+ " clipwise_output = clipwise_output / len(output_dicts)\n",
1143
+ " framewise_output = framewise_output / len(output_dicts)\n",
1144
+ " latent_output = latent_output / len(output_dicts)\n",
1145
+ " output_dict = {\n",
1146
+ " 'framewise_output': framewise_output, \n",
1147
+ " 'clipwise_output': clipwise_output,\n",
1148
+ " 'latent_output': latent_output,\n",
1149
+ " }\n",
1150
+ " else: # this part is typically used, and most easy one\n",
1151
+ " x = self.reshape_wav2img(x)\n",
1152
+ " output_dict = self.forward_features(x)\n",
1153
+ " # x = self.head(x)\n",
1154
+ " return output_dict\n",
1155
+ "\n",
1156
+ "class HTSATWrapper(nn.Module):\n",
1157
+ " def __init__(self, sample_rate, window_size, hop_size, mel_bins, fmin, \n",
1158
+ " fmax, classes_num, out_emb):\n",
1159
+ " super().__init__()\n",
1160
+ "\n",
1161
+ " # print(\"parameters are being overidden when using HTSAT\")\n",
1162
+ " # print(\"HTSAT only support loading a pretrained model on AudioSet\")\n",
1163
+ " # @TODO later look at what parameters are same and can be merged\n",
1164
+ "\n",
1165
+ " self.htsat = HTSAT_Swin_Transformer(config=HTSATConfig())\n",
1166
+ "\n",
1167
+ " def forward(self, x):\n",
1168
+ " out_dict = self.htsat(x)\n",
1169
+ " out_dict['embedding'] = out_dict['latent_output']\n",
1170
+ " return out_dict\n",
1171
+ "\n",
1172
+ "\n",
1173
+ "def get_audio_encoder(name: str):\n",
1174
+ " if name == \"HTSAT\":\n",
1175
+ " return HTSATWrapper\n",
1176
+ " else:\n",
1177
+ " raise Exception('The audio encoder name {} is incorrect or not supported'.format(name))\n",
1178
+ "\n",
1179
+ "class Projection(nn.Module):\n",
1180
+ " def __init__(self, dim_imgn: int, d_out: int, p: float=0.5) -> None:\n",
1181
+ " super().__init__()\n",
1182
+ " self.linear1 = nn.Linear(dim_imgn, d_out, bias=False)\n",
1183
+ " self.linear2 = nn.Linear(d_out, d_out, bias=False)\n",
1184
+ " self.layer_norm = nn.LayerNorm(d_out)\n",
1185
+ " self.drop = nn.Dropout(p)\n",
1186
+ "\n",
1187
+ " def forward(self, x: torch.Tensor) -> torch.Tensor:\n",
1188
+ " embed1 = self.linear1(x)\n",
1189
+ " embed2 = self.drop(self.linear2(F.gelu(embed1)))\n",
1190
+ " embeds = self.layer_norm(embed1 + embed2)\n",
1191
+ " return embeds\n",
1192
+ "\n",
1193
+ "class AudioEncoder(nn.Module):\n",
1194
+ " def __init__(self, audioenc_name:str, dim_imgn: int, d_out: int, sample_rate: int, window_size: int,\n",
1195
+ " hop_size: int, mel_bins: int, fmin: int, fmax: int, classes_num: int) -> None:\n",
1196
+ " super().__init__()\n",
1197
+ "\n",
1198
+ " audio_encoder = get_audio_encoder(audioenc_name)\n",
1199
+ "\n",
1200
+ " self.base = audio_encoder(\n",
1201
+ " sample_rate, window_size,\n",
1202
+ " hop_size, mel_bins, fmin, fmax,\n",
1203
+ " classes_num, dim_imgn)\n",
1204
+ "\n",
1205
+ " self.projection = Projection(dim_imgn, d_out)\n",
1206
+ "\n",
1207
+ " def forward(self, x):\n",
1208
+ " out_dict = self.base(x)\n",
1209
+ " audio_features, audio_classification_output = out_dict['embedding'], out_dict['clipwise_output']\n",
1210
+ " projected_vec = self.projection(audio_features)\n",
1211
+ " return projected_vec, audio_classification_output\n",
1212
+ "\n",
1213
+ "class TextEncoder(nn.Module):\n",
1214
+ " def __init__(self, d_out: int, text_model: str, transformer_embed_dim: int) -> None:\n",
1215
+ " super().__init__()\n",
1216
+ " self.text_model = text_model\n",
1217
+ " self.base = AutoModel.from_pretrained(text_model)\n",
1218
+ "\n",
1219
+ " if 'clip' in text_model:\n",
1220
+ " self.clip_text_projection = self.base.text_projection\n",
1221
+ " self.base = self.base.text_model\n",
1222
+ " if 'base' in text_model:\n",
1223
+ " transformer_embed_dim = 512\n",
1224
+ " \n",
1225
+ " self.projection = Projection(transformer_embed_dim, d_out)\n",
1226
+ "\n",
1227
+ " def forward(self, x):\n",
1228
+ " if 'clip' in self.text_model:\n",
1229
+ " pooled_output = self.base(**x)[1] # get pooled output\n",
1230
+ " out = self.clip_text_projection(pooled_output) # get CLS token output\n",
1231
+ " elif 'gpt' in self.text_model:\n",
1232
+ " batch_size = x['input_ids'].shape[0]\n",
1233
+ " hidden_states = self.base(**x)[0] # (batch_size=4, seq_len, 768)\n",
1234
+ "\n",
1235
+ " sequence_lengths = torch.ne(x['input_ids'], 0).sum(-1) - 1 # tensor([13, 14, 18, 17])\n",
1236
+ " out = hidden_states[torch.arange(batch_size, device=hidden_states.device), sequence_lengths] # [batch_size, 768] = [4, 768]\n",
1237
+ " else:\n",
1238
+ " out = self.base(**x)[0]\n",
1239
+ " out = out[:, 0, :] # get CLS token output\n",
1240
+ " \n",
1241
+ " projected_vec = self.projection(out)\n",
1242
+ "\n",
1243
+ " return projected_vec\n",
1244
+ "\n",
1245
+ "class CLAP(nn.Module):\n",
1246
+ " def __init__(self,\n",
1247
+ " # audio\n",
1248
+ " audioenc_name: str,\n",
1249
+ " sample_rate: int, \n",
1250
+ " window_size: int, \n",
1251
+ " hop_size: int, \n",
1252
+ " mel_bins: int, \n",
1253
+ " fmin: int, \n",
1254
+ " fmax: int, \n",
1255
+ " classes_num: int, \n",
1256
+ " out_emb: int,\n",
1257
+ " # text\n",
1258
+ " text_model: str,\n",
1259
+ " transformer_embed_dim: int,\n",
1260
+ " # common\n",
1261
+ " d_proj: int,\n",
1262
+ " ):\n",
1263
+ " super().__init__()\n",
1264
+ "\n",
1265
+ " \n",
1266
+ " self.audio_encoder = AudioEncoder(\n",
1267
+ " audioenc_name, out_emb, d_proj,\n",
1268
+ " sample_rate, window_size, hop_size, mel_bins, fmin, fmax, classes_num)\n",
1269
+ "\n",
1270
+ " self.caption_encoder = TextEncoder(\n",
1271
+ " d_proj, text_model, transformer_embed_dim\n",
1272
+ " )\n",
1273
+ "\n",
1274
+ " self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))\n",
1275
+ "\n",
1276
+ " def forward(self, audio, text):\n",
1277
+ " audio_embed, _ = self.audio_encoder(audio)\n",
1278
+ " caption_embed = self.caption_encoder(text)\n",
1279
+ "\n",
1280
+ " return caption_embed, audio_embed, self.logit_scale.exp()\n",
1281
+ " \n",
1282
+ " \n",
1283
+ " \n",
1284
+ "# ==================================================================\n",
1285
+ "# A U D I O - P R E - P R O C E S S I N G\n",
1286
+ "# ==================================================================\n",
1287
+ "def read_audio(audio_path, resample=True):\n",
1288
+ " r\"\"\"Loads audio file or array and returns a torch tensor\"\"\"\n",
1289
+ " # Randomly sample a segment of audio_duration from the clip or pad to match duration\n",
1290
+ " audio_time_series, sample_rate = torchaudio.load(audio_path)\n",
1291
+ "\n",
1292
+ " resample_rate = clapConfig.sample_rate\n",
1293
+ " if resample and resample_rate != sample_rate:\n",
1294
+ " resampler = T.Resample(sample_rate, resample_rate)\n",
1295
+ " audio_time_series = resampler(audio_time_series)\n",
1296
+ " return audio_time_series, resample_rate\n",
1297
+ "\n",
1298
+ "def load_audio_into_tensor(audio_path, audio_duration, resample=False):\n",
1299
+ " r\"\"\"Loads audio file and returns raw audio.\"\"\"\n",
1300
+ " # Randomly sample a segment of audio_duration from the clip or pad to match duration\n",
1301
+ " audio_time_series, sample_rate = read_audio(audio_path, resample)\n",
1302
+ " audio_time_series = audio_time_series.reshape(-1)\n",
1303
+ "\n",
1304
+ " # audio_time_series is shorter than predefined audio duration,\n",
1305
+ " # so audio_time_series is extended\n",
1306
+ " if audio_duration*sample_rate >= audio_time_series.shape[0]:\n",
1307
+ " repeat_factor = int(np.ceil((audio_duration*sample_rate) /\n",
1308
+ " audio_time_series.shape[0]))\n",
1309
+ " # Repeat audio_time_series by repeat_factor to match audio_duration\n",
1310
+ " audio_time_series = audio_time_series.repeat(repeat_factor)\n",
1311
+ " # remove excess part of audio_time_series\n",
1312
+ " audio_time_series = audio_time_series[0:audio_duration*sample_rate]\n",
1313
+ " else:\n",
1314
+ " # audio_time_series is longer than predefined audio duration,\n",
1315
+ " # so audio_time_series is trimmed\n",
1316
+ " start_index = random.randrange(\n",
1317
+ " audio_time_series.shape[0] - audio_duration*sample_rate)\n",
1318
+ " audio_time_series = audio_time_series[start_index:start_index +\n",
1319
+ " audio_duration*sample_rate]\n",
1320
+ " return torch.FloatTensor(audio_time_series)\n",
1321
+ "\n",
1322
+ "np_str_obj_array_pattern = re.compile(r'[SaUO]')\n",
1323
+ "default_collate_err_msg_format = (\n",
1324
+ " \"default_collate: batch must contain tensors, numpy arrays, numbers, \"\n",
1325
+ " \"dicts or lists; found {}\")\n",
1326
+ "\n",
1327
+ "def default_collate(batch):\n",
1328
+ " r\"\"\"Puts each data field into a tensor with outer dimension batch size\"\"\"\n",
1329
+ " elem = batch[0]\n",
1330
+ " elem_type = type(elem)\n",
1331
+ " if isinstance(elem, torch.Tensor):\n",
1332
+ " out = None\n",
1333
+ " if torch.utils.data.get_worker_info() is not None:\n",
1334
+ " # If we're in a background process, concatenate directly into a\n",
1335
+ " # shared memory tensor to avoid an extra copy\n",
1336
+ " numel = sum([x.numel() for x in batch])\n",
1337
+ " storage = elem.storage()._new_shared(numel)\n",
1338
+ " out = elem.new(storage)\n",
1339
+ " return torch.stack(batch, 0, out=out)\n",
1340
+ " elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \\\n",
1341
+ " and elem_type.__name__ != 'string_':\n",
1342
+ " if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap':\n",
1343
+ " # array of string classes and object\n",
1344
+ " if np_str_obj_array_pattern.search(elem.dtype.str) is not None:\n",
1345
+ " raise TypeError(\n",
1346
+ " default_collate_err_msg_format.format(elem.dtype))\n",
1347
+ "\n",
1348
+ " return default_collate([torch.as_tensor(b) for b in batch])\n",
1349
+ " elif elem.shape == (): # scalars\n",
1350
+ " return torch.as_tensor(batch)\n",
1351
+ " elif isinstance(elem, float):\n",
1352
+ " return torch.tensor(batch, dtype=torch.float64)\n",
1353
+ " elif isinstance(elem, int):\n",
1354
+ " return torch.tensor(batch)\n",
1355
+ " elif isinstance(elem, str):\n",
1356
+ " return batch\n",
1357
+ " elif isinstance(elem, collections.abc.Mapping):\n",
1358
+ " return {key: default_collate([d[key] for d in batch]) for key in elem}\n",
1359
+ " elif isinstance(elem, tuple) and hasattr(elem, '_fields'): # namedtuple\n",
1360
+ " return elem_type(*(default_collate(samples) for samples in zip(*batch)))\n",
1361
+ " elif isinstance(elem, collections.abc.Sequence):\n",
1362
+ " # check to make sure that the elements in batch have consistent size\n",
1363
+ " it = iter(batch)\n",
1364
+ " elem_size = len(next(it))\n",
1365
+ " if not all(len(elem) == elem_size for elem in it):\n",
1366
+ " raise RuntimeError(\n",
1367
+ " 'each element in list of batch should be of equal size')\n",
1368
+ " transposed = zip(*batch)\n",
1369
+ " return [default_collate(samples) for samples in transposed]\n",
1370
+ "\n",
1371
+ " raise TypeError(default_collate_err_msg_format.format(elem_type))\n",
1372
+ "\n",
1373
+ "def preprocess_audio(audio_files, resample):\n",
1374
+ " r\"\"\"Load list of audio files and return raw audio\"\"\"\n",
1375
+ " audio_tensors = []\n",
1376
+ " for audio_file in audio_files:\n",
1377
+ " audio_tensor = load_audio_into_tensor(\n",
1378
+ " audio_file, clapConfig.duration, resample)\n",
1379
+ " audio_tensor = audio_tensor.reshape(1, -1)\n",
1380
+ " audio_tensors.append(audio_tensor)\n",
1381
+ " return default_collate(audio_tensors)\n",
1382
+ "\n",
1383
+ "\n",
1384
+ "\n",
1385
+ "# ==================================================================\n",
1386
+ "# A U D I O - E M B E D D I N G S - H E L P E R\n",
1387
+ "# ==================================================================\n",
1388
+ "def CLAPAudioProcessor(audio_files: List[str], resample=True):\n",
1389
+ " preprocessed_audio = preprocess_audio(audio_files, resample)\n",
1390
+ " preprocessed_audio = preprocessed_audio.reshape(\n",
1391
+ " preprocessed_audio.shape[0], preprocessed_audio.shape[2])\n",
1392
+ " preprocessed_audio = preprocessed_audio\n",
1393
+ " return preprocessed_audio\n",
1394
+ "\n",
1395
+ "def get_audio_embeddings(audio_files: List[str], audio_encoder, resample=True):\n",
1396
+ " \"\"\"Load list of audio files and return audio embeddings\"\"\"\n",
1397
+ " # preprocessed_audio = preprocess_audio(audio_files, resample)\n",
1398
+ " # with torch.no_grad():\n",
1399
+ " # preprocessed_audio = preprocessed_audio.reshape(\n",
1400
+ " # preprocessed_audio.shape[0], preprocessed_audio.shape[2])\n",
1401
+ " with torch.no_grad():\n",
1402
+ " preprocessed_audio = CLAPAudioProcessor(audio_files, resample)\n",
1403
+ " return audio_encoder(preprocessed_audio)[0]"
1404
+ ]
1405
+ },
1406
+ {
1407
+ "cell_type": "code",
1408
+ "execution_count": 3,
1409
+ "id": "cab13923",
1410
+ "metadata": {},
1411
+ "outputs": [
1412
+ {
1413
+ "name": "stdout",
1414
+ "output_type": "stream",
1415
+ "text": [
1416
+ "trainable params: 1,289,494 || all params: 34,355,927 || trainable%: 3.7533\n"
1417
+ ]
1418
+ }
1419
+ ],
1420
+ "source": [
1421
+ "# ==================================================================\n",
1422
+ "# C L A P\n",
1423
+ "# ==================================================================\n",
1424
+ "class ClapConfig:\n",
1425
+ " # TEXT ENCODER CONFIG\n",
1426
+ " text_model = 'gpt2'\n",
1427
+ " text_len = 77\n",
1428
+ " transformer_embed_dim = 768\n",
1429
+ " freeze_text_encoder_weights = True\n",
1430
+ "\n",
1431
+ " # AUDIO ENCODER CONFIG\n",
1432
+ " audioenc_name = 'HTSAT'\n",
1433
+ " out_emb = 768\n",
1434
+ " sample_rate = 44100\n",
1435
+ " duration = 7\n",
1436
+ " fmin = 50\n",
1437
+ " fmax = 8000 # 14000\n",
1438
+ " n_fft = 1024 # 1028\n",
1439
+ " hop_size = 320\n",
1440
+ " mel_bins = 64\n",
1441
+ " window_size = 1024\n",
1442
+ "\n",
1443
+ " # PROJECTION SPACE CONFIG \n",
1444
+ " d_proj = 1024\n",
1445
+ " temperature = 0.003\n",
1446
+ "\n",
1447
+ " # TRAINING AND EVALUATION CONFIG\n",
1448
+ " num_classes = 527\n",
1449
+ " batch_size = 1024\n",
1450
+ " demo = False\n",
1451
+ " \n",
1452
+ "\n",
1453
+ "clapConfig = ClapConfig()\n",
1454
+ "clap = CLAP(\n",
1455
+ " audioenc_name=clapConfig.audioenc_name,\n",
1456
+ " sample_rate=clapConfig.sample_rate,\n",
1457
+ " window_size=clapConfig.window_size,\n",
1458
+ " hop_size=clapConfig.hop_size,\n",
1459
+ " mel_bins=clapConfig.mel_bins,\n",
1460
+ " fmin=clapConfig.fmin,\n",
1461
+ " fmax=clapConfig.fmax,\n",
1462
+ " classes_num=clapConfig.num_classes,\n",
1463
+ " out_emb=clapConfig.out_emb,\n",
1464
+ " text_model=clapConfig.text_model,\n",
1465
+ " transformer_embed_dim=clapConfig.transformer_embed_dim,\n",
1466
+ " d_proj=clapConfig.d_proj\n",
1467
+ " )\n",
1468
+ "\n",
1469
+ "model_repo = \"microsoft/msclap\"\n",
1470
+ "model_name = {\n",
1471
+ " '2022': 'CLAP_weights_2022.pth',\n",
1472
+ " '2023': 'CLAP_weights_2023.pth',\n",
1473
+ " 'clapcap': 'clapcap_weights_2023.pth'\n",
1474
+ "}\n",
1475
+ "\n",
1476
+ "version = '2023'\n",
1477
+ "model_fp = hf_hub_download(model_repo, model_name[version])\n",
1478
+ "\n",
1479
+ "model_state_dict = torch.load(model_fp, map_location=torch.device('cpu'))['model']\n",
1480
+ "clap.load_state_dict(model_state_dict, strict=False)\n",
1481
+ "clap.eval()\n",
1482
+ "\n",
1483
+ "clap_audio_encoder = clap.audio_encoder.eval().to(device)\n",
1484
+ "\n",
1485
+ "\n",
1486
+ "# ENGLISH_AUDIO_DIR = r\"/home/IITB/ai-at-ieor/23m1521/datasets/Vaani/Audios/English\"\n",
1487
+ "# audio_files = [os.path.join(ENGLISH_AUDIO_DIR, i) for i in os.listdir(ENGLISH_AUDIO_DIR) if i.endswith(\".wav\")]\n",
1488
+ "# audio_embedding = get_audio_embeddings(audio_files, clap_audio_encoder)\n",
1489
+ "# print(\"CLAP Audio Encoder Embeddings:\", audio_embedding.shape) # [5, 1024]\n",
1490
+ "\n",
1491
+ "\n",
1492
+ "# ==================================================================\n",
1493
+ "# C L A P - L o R A - M O D E L\n",
1494
+ "# ==================================================================\n",
1495
+ "LoRAconfig = {\n",
1496
+ " \"peft_type\": \"LORA\",\n",
1497
+ " \"task_type\": \"FEATURE_EXTRACTION\",\n",
1498
+ " \"inference_mode\": False,\n",
1499
+ " \"r\": 16,\n",
1500
+ " \"target_modules\": [\"qkv\", \"fc1\", \"fc2\", \"proj\", \"linear1\", \"linear2\"],\n",
1501
+ " \"lora_alpha\": 32,\n",
1502
+ " \"lora_dropout\": 0.05,\n",
1503
+ " \"fan_in_fan_out\": False,\n",
1504
+ " \"bias\": \"all\",\n",
1505
+ "}\n",
1506
+ "peft_config = get_peft_config(LoRAconfig)\n",
1507
+ "\n",
1508
+ "peft_model = get_peft_model(clap_audio_encoder, peft_config)\n",
1509
+ "\n",
1510
+ "peft_model.print_trainable_parameters()\n",
1511
+ "\n",
1512
+ "peft_clap_audio_encoder = peft_model.base_model\n",
1513
+ "# audio_embedding = get_audio_embeddings(audio_files, peft_clap_audio_encoder)\n",
1514
+ "# print(\"CLAP LoRA Audio Encoder Embeddings:\", audio_embedding.shape) # [5, 1024]"
1515
+ ]
1516
+ },
1517
+ {
1518
+ "cell_type": "code",
1519
+ "execution_count": 4,
1520
+ "id": "f9998f39",
1521
+ "metadata": {},
1522
+ "outputs": [
1523
+ {
1524
+ "data": {
1525
+ "text/plain": [
1526
+ "'1.26.0'"
1527
+ ]
1528
+ },
1529
+ "execution_count": 4,
1530
+ "metadata": {},
1531
+ "output_type": "execute_result"
1532
+ }
1533
+ ],
1534
+ "source": [
1535
+ "np.__version__"
1536
+ ]
1537
+ },
1538
+ {
1539
+ "cell_type": "code",
1540
+ "execution_count": 5,
1541
+ "id": "16c61e94",
1542
+ "metadata": {},
1543
+ "outputs": [],
1544
+ "source": [
1545
+ "# ==================================================================\n",
1546
+ "# C L I P - M O D E L\n",
1547
+ "# ==================================================================\n",
1548
+ "from transformers import CLIPImageProcessorFast, CLIPImageProcessor\n",
1549
+ "clip_vision_model = CLIPVisionModel.from_pretrained(\"openai/clip-vit-base-patch32\").to(device)\n",
1550
+ "# clip_vision_processor = AutoProcessor.from_pretrained(\"openai/clip-vit-base-patch32\")\n",
1551
+ "clip_vision_processor = CLIPImageProcessorFast.from_pretrained(\"openai/clip-vit-base-patch32\")\n",
1552
+ "# clip_vision_processor = CLIPImageProcessor.from_pretrained(\"openai/clip-vit-base-patch32\")\n",
1553
+ "\n",
1554
+ "# image = Image.open(\"/home/IITB/ai-at-ieor/23m1521/ashish/MTP/Vaani/Img_Audio_Alignment/000000039769.jpg\")\n",
1555
+ "# inputs = clip_vision_processor(images=image, return_tensors=\"pt\")\n",
1556
+ "# print(\"CLIP input image:\", inputs['pixel_values'].shape)\n",
1557
+ "# # input_data = {'pixel_values': inputs}\n",
1558
+ "\n",
1559
+ "# IMAGE_SIZE = 224\n",
1560
+ "# dummy_input = torch.randn(1, 3, IMAGE_SIZE, IMAGE_SIZE) # [1, 3, 224, 224]\n",
1561
+ "# # dummy_input_data = {'pixel_values': dummy_input}\n",
1562
+ "\n",
1563
+ "# output = clip_vision_model(inputs['pixel_values'])\n",
1564
+ "# print(\"CLIP Image Encoder Embeddings:\", output.last_hidden_state.shape) # [1, 50, 768]\n",
1565
+ "# print(\"CLIP Image Encoder Pooled Output:\", output.pooler_output.shape) # [1, 768]"
1566
+ ]
1567
+ },
1568
+ {
1569
+ "cell_type": "code",
1570
+ "execution_count": 6,
1571
+ "id": "61ef98b9",
1572
+ "metadata": {},
1573
+ "outputs": [],
1574
+ "source": [
1575
+ "# ==================================================================\n",
1576
+ "# C S I P - M O D U L E\n",
1577
+ "# ==================================================================\n",
1578
+ "class CSIP(nn.Module):\n",
1579
+ " def __init__(self, image_encoder, audio_encoder, dim_img=None, dim_audio=1024, dim_emb=768):\n",
1580
+ " super(CSIP, self).__init__()\n",
1581
+ " self.image_encoder = image_encoder # CLIPVisionModel\n",
1582
+ " self.audio_encoder = audio_encoder # CLAP_audio_encoder\n",
1583
+ "\n",
1584
+ " # self.image_proj = nn.Linear(dim_img, dim_emb)\n",
1585
+ " self.audio_proj = nn.Linear(dim_audio, dim_emb)\n",
1586
+ "\n",
1587
+ " # Learnable temperature parameter\n",
1588
+ " self.log_temp = nn.Parameter(torch.tensor(0.07).log())\n",
1589
+ "\n",
1590
+ " def forward(self, images, audios):\n",
1591
+ " # Step 1: Feature extraction\n",
1592
+ " with torch.no_grad():\n",
1593
+ " with torch.inference_mode():\n",
1594
+ " image_features = self.image_encoder(images).pooler_output # shape: [n, dim_img]\n",
1595
+ " audio_features = self.audio_encoder(audios)[0] # shape: [n, dim_audio]\n",
1596
+ "\n",
1597
+ " # Step 2: Project and normalize\n",
1598
+ " image_embeds = F.normalize(image_features) # [n, dim_emb]\n",
1599
+ " audio_embeds = F.normalize(self.audio_proj(audio_features), dim=1) # [n, dim_emb]\n",
1600
+ "\n",
1601
+ " # Step 3: Cosine similarity with temperature\n",
1602
+ " logits = torch.matmul(image_embeds, audio_embeds.T) * self.log_temp.exp() # [n, n]\n",
1603
+ " probs = logits.softmax(dim=1)\n",
1604
+ "\n",
1605
+ " # Step 4: Symmetric cross-entropy loss\n",
1606
+ " labels = torch.arange(len(images), device=images.device)\n",
1607
+ " loss_i = F.cross_entropy(logits, labels)\n",
1608
+ " loss_t = F.cross_entropy(logits.T, labels)\n",
1609
+ " loss = (loss_i + loss_t) / 2\n",
1610
+ " return loss, logits, probs"
1611
+ ]
1612
+ },
1613
+ {
1614
+ "cell_type": "code",
1615
+ "execution_count": 7,
1616
+ "id": "b1a15b19",
1617
+ "metadata": {},
1618
+ "outputs": [
1619
+ {
1620
+ "name": "stdout",
1621
+ "output_type": "stream",
1622
+ "text": [
1623
+ "Train Dataset: 15632\n",
1624
+ "Test Dataset: 6700\n"
1625
+ ]
1626
+ }
1627
+ ],
1628
+ "source": [
1629
+ "# ==================================================================\n",
1630
+ "# I M A G E - A U D I O - D A T A S E T\n",
1631
+ "# ==================================================================\n",
1632
+ "class VaaniImageAudioDataset(torch.utils.data.Dataset):\n",
1633
+ " def __init__(self, df):\n",
1634
+ " self.image_paths = df.image_path.tolist()\n",
1635
+ " self.audio_paths = df.audio_path.tolist()\n",
1636
+ "\n",
1637
+ " def __len__(self):\n",
1638
+ " return len(self.audio_paths)\n",
1639
+ "\n",
1640
+ " def __getitem__(self, idx):\n",
1641
+ " return {\n",
1642
+ " 'image_path': self.image_paths[idx], \n",
1643
+ " 'audio_path': self.audio_paths[idx]\n",
1644
+ " }\n",
1645
+ "\n",
1646
+ "\n",
1647
+ "def collate_fn(batch):\n",
1648
+ " image_tensor = clip_vision_processor([Image.open(item['image_path']) for item in batch])['pixel_values']\n",
1649
+ " audio_tensor = CLAPAudioProcessor([item['audio_path'] for item in batch], resample=True)\n",
1650
+ " return {'image_tensor': torch.stack(image_tensor), 'audio_tensor': audio_tensor}\n",
1651
+ "\n",
1652
+ "\n",
1653
+ "# preprocessed_audio = CLAPAudioProcessor(audio_files, resample=True)\n",
1654
+ "# clip_vision_processor = clip_vision_processor\n",
1655
+ "\n",
1656
+ "train_df = pd.read_csv(\"/home/IITB/ai-at-ieor/23m1521/ashish/MTP/Vaani/Img_Audio_Alignment/available_img_audios_TRAIN.csv\")\n",
1657
+ "test_df = pd.read_csv(\"/home/IITB/ai-at-ieor/23m1521/ashish/MTP/Vaani/Img_Audio_Alignment/available_img_audios_TEST.csv\")\n",
1658
+ "train_dataset = VaaniImageAudioDataset(train_df)\n",
1659
+ "test_dataset = VaaniImageAudioDataset(test_df)\n",
1660
+ "BATCH_SIZE = 32\n",
1661
+ "\n",
1662
+ "print('Train Dataset:', len(train_dataset))\n",
1663
+ "print('Test Dataset:', len(test_dataset))"
1664
+ ]
1665
+ },
1666
+ {
1667
+ "cell_type": "code",
1668
+ "execution_count": 8,
1669
+ "id": "166105cd",
1670
+ "metadata": {},
1671
+ "outputs": [
1672
+ {
1673
+ "name": "stdout",
1674
+ "output_type": "stream",
1675
+ "text": [
1676
+ "Image batch: torch.Size([32, 3, 224, 224])\n",
1677
+ "Audio batch: torch.Size([32, 308700])\n"
1678
+ ]
1679
+ }
1680
+ ],
1681
+ "source": [
1682
+ "train_dataloader = torch.utils.data.DataLoader(\n",
1683
+ " train_dataset,\n",
1684
+ " batch_size=BATCH_SIZE, \n",
1685
+ " shuffle=True, \n",
1686
+ " num_workers=48,\n",
1687
+ " collate_fn=collate_fn,\n",
1688
+ " pin_memory=True,\n",
1689
+ " drop_last=True,\n",
1690
+ " persistent_workers=True\n",
1691
+ ")\n",
1692
+ "\n",
1693
+ "test_dataloader = torch.utils.data.DataLoader(\n",
1694
+ " test_dataset,\n",
1695
+ " batch_size=BATCH_SIZE, \n",
1696
+ " shuffle=False, \n",
1697
+ " num_workers=48,\n",
1698
+ " collate_fn=collate_fn,\n",
1699
+ " pin_memory=True,\n",
1700
+ " drop_last=False,\n",
1701
+ " persistent_workers=True\n",
1702
+ ")\n",
1703
+ "\n",
1704
+ "batch = next(iter(train_dataloader))\n",
1705
+ "print(\"Image batch:\", batch['image_tensor'].shape)\n",
1706
+ "print(\"Audio batch:\", batch['audio_tensor'].shape)"
1707
+ ]
1708
+ },
1709
+ {
1710
+ "cell_type": "code",
1711
+ "execution_count": 9,
1712
+ "id": "d3b0a29f",
1713
+ "metadata": {},
1714
+ "outputs": [],
1715
+ "source": [
1716
+ "csip_model = CSIP(clip_vision_model.eval(), peft_clap_audio_encoder).to(device)"
1717
+ ]
1718
+ },
1719
+ {
1720
+ "cell_type": "code",
1721
+ "execution_count": 10,
1722
+ "id": "2475318a",
1723
+ "metadata": {},
1724
+ "outputs": [],
1725
+ "source": [
1726
+ "# loss, logits, probs = csip_model(batch['image_tensor'].to(device), batch['audio_tensor'].to(device))\n",
1727
+ "# loss, logits, probs, logits.shape, probs.shape"
1728
+ ]
1729
+ },
1730
+ {
1731
+ "cell_type": "code",
1732
+ "execution_count": 11,
1733
+ "id": "dc748c49",
1734
+ "metadata": {},
1735
+ "outputs": [],
1736
+ "source": [
1737
+ "def train_batch(model, images, audio, optimizer):\n",
1738
+ " model.train()\n",
1739
+ " optimizer.zero_grad()\n",
1740
+ " loss, logits, probs = model(images, audio)\n",
1741
+ " loss.backward()\n",
1742
+ " optimizer.step()\n",
1743
+ " return loss.item(), logits, probs\n",
1744
+ "\n",
1745
+ "@torch.no_grad()\n",
1746
+ "def evaluate_batch(model, images, audio):\n",
1747
+ " model.eval()\n",
1748
+ " loss, logits, probs = model(images, audio)\n",
1749
+ " return loss.item(), logits, probs\n",
1750
+ "\n",
1751
+ "def save_checkpoint(state, checkpoint_dir, epoch, max_checkpoints=2):\n",
1752
+ " filename = f\"csip_best_epoch_{epoch+1}.pt\"\n",
1753
+ " path = os.path.join(checkpoint_dir, filename)\n",
1754
+ " torch.save(state, path)\n",
1755
+ " checkpoints = sorted(\n",
1756
+ " [f for f in os.listdir(checkpoint_dir) if f.startswith(\"csip_best_epoch_\")],\n",
1757
+ " key=lambda x: int(x.split(\"_\")[-1].split(\".\")[0])\n",
1758
+ " )\n",
1759
+ " while len(checkpoints) > max_checkpoints:\n",
1760
+ " to_delete = checkpoints.pop(0)\n",
1761
+ " os.remove(os.path.join(checkpoint_dir, to_delete))\n",
1762
+ "\n",
1763
+ "\n",
1764
+ "def load_checkpoint(checkpoint_dir, model, optimizer):\n",
1765
+ " checkpoints = sorted(\n",
1766
+ " [f for f in os.listdir(checkpoint_dir) if f.startswith(\"clip_best_epoch_\")],\n",
1767
+ " key=lambda x: int(x.split(\"_\")[-1].split(\".\")[0])\n",
1768
+ " )\n",
1769
+ " if not checkpoints:\n",
1770
+ " print(\"No checkpoint found to resume from.\")\n",
1771
+ " return 0, float(\"inf\")\n",
1772
+ "\n",
1773
+ " best_ckpt = checkpoints[-1]\n",
1774
+ " path = os.path.join(checkpoint_dir, best_ckpt)\n",
1775
+ " checkpoint = torch.load(path)\n",
1776
+ " model.load_state_dict(checkpoint['model_state'])\n",
1777
+ " optimizer.load_state_dict(checkpoint['optimizer_state'])\n",
1778
+ " start_epoch = checkpoint['epoch']\n",
1779
+ " best_loss = checkpoint['best_loss']\n",
1780
+ " print(f\"Resumed training from epoch {start_epoch+1} with best loss {best_loss:.4f}\")\n",
1781
+ " return start_epoch, best_loss\n",
1782
+ "\n",
1783
+ "\n",
1784
+ "def fig_to_tensor(fig):\n",
1785
+ " \"\"\"Convert a Matplotlib figure to a tensor suitable for TensorBoard.\"\"\"\n",
1786
+ " buf = io.BytesIO()\n",
1787
+ " fig.savefig(buf, format='png')\n",
1788
+ " buf.seek(0)\n",
1789
+ " image = Image.open(buf).convert(\"RGB\")\n",
1790
+ " tensor = tv.transforms.functional.to_tensor(image)\n",
1791
+ " buf.close()\n",
1792
+ " plt.close(fig)\n",
1793
+ " return tensor\n",
1794
+ "\n",
1795
+ "def save_similarity_heatmaps(logits, epoch, loss, save_dir, writer):\n",
1796
+ " os.makedirs(save_dir, exist_ok=True)\n",
1797
+ "\n",
1798
+ " # --- Raw logits heatmap ---\n",
1799
+ " logits_np = logits.detach().cpu().numpy()\n",
1800
+ " fig_logits = plt.figure(figsize=(8, 6))\n",
1801
+ " sns.heatmap(logits_np, square=True, cmap=\"viridis\", cbar=True)\n",
1802
+ " plt.title(f\"Raw Logits Heatmap — Epoch {epoch+1}, Loss {loss:.4f}\")\n",
1803
+ " plt.xlabel(\"Audio Index\")\n",
1804
+ " plt.ylabel(\"Image Index\")\n",
1805
+ " raw_path = os.path.join(save_dir, f\"raw_logits_epoch_{epoch+1}_loss_{loss:.4f}.png\")\n",
1806
+ " fig_logits.savefig(raw_path)\n",
1807
+ " writer.add_image(\"Heatmap/RawLogits\", fig_to_tensor(fig_logits), global_step=epoch+1)\n",
1808
+ "\n",
1809
+ " # --- Softmax probs heatmap ---\n",
1810
+ " probs_np = logits.softmax(dim=1).cpu().numpy()\n",
1811
+ " fig_probs = plt.figure(figsize=(8, 6))\n",
1812
+ " sns.heatmap(probs_np, square=True, cmap=\"viridis\", cbar=True)\n",
1813
+ " plt.title(f\"Softmax Probabilities Heatmap — Epoch {epoch+1}, Loss {loss:.4f}\")\n",
1814
+ " plt.xlabel(\"Audio Index\")\n",
1815
+ " plt.ylabel(\"Image Index\")\n",
1816
+ " prob_path = os.path.join(save_dir, f\"probs_epoch_{epoch+1}_loss_{loss:.4f}.png\")\n",
1817
+ " fig_probs.savefig(prob_path)\n",
1818
+ " writer.add_image(\"Heatmap/SoftmaxProbs\", fig_to_tensor(fig_probs), global_step=epoch+1)\n",
1819
+ "\n",
1820
+ "\n",
1821
+ "\n",
1822
+ "def train_model(model, train_loader, test_loader, \n",
1823
+ " optimizer, device, log_dir, \n",
1824
+ " checkpoint_dir, resume=False, epochs=10):\n",
1825
+ " os.makedirs(log_dir, exist_ok=True)\n",
1826
+ " os.makedirs(checkpoint_dir, exist_ok=True)\n",
1827
+ " writer = SummaryWriter(log_dir=log_dir)\n",
1828
+ "\n",
1829
+ " start_epoch = 0\n",
1830
+ " best_loss = float(\"inf\")\n",
1831
+ "\n",
1832
+ " if resume:\n",
1833
+ " start_epoch, best_loss = load_checkpoint(checkpoint_dir, model, optimizer)\n",
1834
+ "\n",
1835
+ " for epoch in trange(start_epoch, epochs, colour='yellow', ncols=70):\n",
1836
+ " train_losses = []\n",
1837
+ " test_losses = []\n",
1838
+ "\n",
1839
+ " train_loop = tqdm(train_loader, desc=f\"[Train Epoch {epoch+1}]\", colour='blue', ncols=70)\n",
1840
+ " for batch in train_loop:\n",
1841
+ " images = batch['image_tensor'].to(device)\n",
1842
+ " audios = batch['audio_tensor'].to(device)\n",
1843
+ " loss, logits, probs = train_batch(model, images, audios, optimizer)\n",
1844
+ " train_losses.append(loss)\n",
1845
+ " train_loop.set_postfix(train_loss=loss)\n",
1846
+ "\n",
1847
+ " test_loop = tqdm(test_loader, desc=f\"[Test Epoch {epoch+1}]\", colour='blue', ncols=70)\n",
1848
+ " for batch in test_loop:\n",
1849
+ " images = batch['image_tensor'].to(device)\n",
1850
+ " audios = batch['audio_tensor'].to(device)\n",
1851
+ " loss, logits, probs = evaluate_batch(model, images, audios)\n",
1852
+ " test_losses.append(loss)\n",
1853
+ " test_loop.set_postfix(test_loss=loss)\n",
1854
+ "\n",
1855
+ " avg_train_loss = sum(train_losses) / len(train_losses)\n",
1856
+ " avg_test_loss = sum(test_losses) / len(test_losses)\n",
1857
+ "\n",
1858
+ " writer.add_scalar(\"Loss/Train\", avg_train_loss, epoch + 1)\n",
1859
+ " writer.add_scalar(\"Loss/Test\", avg_test_loss, epoch + 1)\n",
1860
+ "\n",
1861
+ " print(f\"Epoch {epoch+1} | Train Loss: {avg_train_loss:.4f} | Test Loss: {avg_test_loss:.4f}\")\n",
1862
+ "\n",
1863
+ " if avg_test_loss < best_loss:\n",
1864
+ " save_similarity_heatmaps(logits, epoch, avg_test_loss, checkpoint_dir, writer)\n",
1865
+ " best_loss = avg_test_loss\n",
1866
+ " save_checkpoint({\n",
1867
+ " 'epoch': epoch,\n",
1868
+ " 'model_state': model.state_dict(),\n",
1869
+ " 'optimizer_state': optimizer.state_dict(),\n",
1870
+ " 'best_loss': best_loss\n",
1871
+ " }, checkpoint_dir, epoch)\n",
1872
+ " print(f\">>> Saved new best model at epoch {epoch+1}\")\n",
1873
+ "\n",
1874
+ " writer.close()"
1875
+ ]
1876
+ },
1877
+ {
1878
+ "cell_type": "code",
1879
+ "execution_count": null,
1880
+ "id": "1e29bf49",
1881
+ "metadata": {},
1882
+ "outputs": [
1883
+ {
1884
+ "name": "stdout",
1885
+ "output_type": "stream",
1886
+ "text": [
1887
+ "No checkpoint found to resume from.\n"
1888
+ ]
1889
+ },
1890
+ {
1891
+ "name": "stderr",
1892
+ "output_type": "stream",
1893
+ "text": [
1894
+ " 0%|\u001b[33m \u001b[0m| 0/10 [00:00<?, ?it/s]\u001b[0m"
1895
+ ]
1896
+ }
1897
+ ],
1898
+ "source": [
1899
+ "learning_rate = 1e-4\n",
1900
+ "epochs = 10\n",
1901
+ "optimizer = torch.optim.AdamW(csip_model.parameters(), lr=learning_rate)\n",
1902
+ "\n",
1903
+ "train_model(\n",
1904
+ " model=csip_model,\n",
1905
+ " train_loader=train_dataloader,\n",
1906
+ " test_loader=test_dataloader,\n",
1907
+ " optimizer=optimizer,\n",
1908
+ " device=device,\n",
1909
+ " log_dir=\"runs/csip\",\n",
1910
+ " checkpoint_dir=\"checkpoints/csip\",\n",
1911
+ " resume=True,\n",
1912
+ " epochs=epochs\n",
1913
+ " )"
1914
+ ]
1915
+ }
1916
+ ],
1917
+ "metadata": {
1918
+ "kernelspec": {
1919
+ "display_name": "Python 3",
1920
+ "language": "python",
1921
+ "name": "python3"
1922
+ },
1923
+ "language_info": {
1924
+ "codemirror_mode": {
1925
+ "name": "ipython",
1926
+ "version": 3
1927
+ },
1928
+ "file_extension": ".py",
1929
+ "mimetype": "text/x-python",
1930
+ "name": "python",
1931
+ "nbconvert_exporter": "python",
1932
+ "pygments_lexer": "ipython3",
1933
+ "version": "3.11.11"
1934
+ }
1935
+ },
1936
+ "nbformat": 4,
1937
+ "nbformat_minor": 5
1938
+ }
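
The training cell above optimizes a CLIP-style symmetric cross-entropy over the batch image–audio similarity matrix (the same logits that feed the heatmaps). A minimal, self-contained sketch of that objective, with random tensors standing in for the projected CLIP image and CLAP audio embeddings (the stub dimensions and batch size are assumptions, not values taken from the models):

import torch
import torch.nn.functional as F

def symmetric_clip_loss(image_embeds, audio_embeds, log_temp):
    # L2-normalize both modalities, then scale cosine similarities by the learned temperature.
    image_embeds = F.normalize(image_embeds, dim=1)
    audio_embeds = F.normalize(audio_embeds, dim=1)
    logits = image_embeds @ audio_embeds.T * log_temp.exp()        # [n, n]
    labels = torch.arange(logits.size(0), device=logits.device)    # matched pairs sit on the diagonal
    # Cross-entropy over rows (image -> audio) and over columns (audio -> image), averaged.
    loss = (F.cross_entropy(logits, labels) + F.cross_entropy(logits.T, labels)) / 2
    return loss, logits

image_embeds = torch.randn(8, 768)   # stand-in for projected CLIP image features (placeholder)
audio_embeds = torch.randn(8, 768)   # stand-in for projected CLAP audio features (placeholder)
log_temp = torch.tensor(0.07).log()
loss, logits = symmetric_clip_loss(image_embeds, audio_embeds, log_temp)
print(loss.item(), logits.shape)
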
Vaani/Img_Audio_Alignment/_2_Train.py CHANGED
@@ -32,15 +32,19 @@ from pathlib import Path
32
  from typing import Optional, Tuple, Union, List, Dict
33
 
34
  import requests
35
- from PIL import Image
36
  import numpy as np
 
 
37
  import torch
38
  import torch.nn.functional as F
39
  from torch import nn
40
  from torch.nn.init import _calculate_fan_in_and_fan_out
41
  import torch.utils.checkpoint as checkpoint
42
 
43
- os.environ["CUDA_VISIBLE_DEVICES"] = "0"
 
 
 
44
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
45
  print(f"Using device: {device}")
46
 
@@ -169,7 +173,7 @@ class HTSATConfig:
169
  heatmap_dir = "/home/Research/heatmap_output"
170
  test_file = "htsat-test-ensemble"
171
  fl_local = False # indicate if we need to use this dataset for the framewise detection
172
- fl_dataset = "/home/Research/desed/desed_eval.npy"
173
  fl_class_num = [
174
  "Speech", "Frying", "Dishes", "Running_water",
175
  "Blender", "Electric_shaver_toothbrush", "Alarm_bell_ringing",
@@ -314,7 +318,7 @@ class Mlp(nn.Module):
314
  x = self.drop(x)
315
  return x
316
 
317
- def _no_grad_trunc_normal_(tensor, mean, std, a, b):
318
  # Cut & paste from PyTorch official master until it's in a few official releases - RW
319
  # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
320
  def norm_cdf(x):
@@ -368,7 +372,7 @@ def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
368
  >>> w = torch.empty(3, 5)
369
  >>> nn.init.trunc_normal_(w)
370
  """
371
- return _no_grad_trunc_normal_(tensor, mean, std, a, b)
372
 
373
 
374
  def variance_scaling_(tensor, scale=1.0, mode='fan_in', distribution='normal'):
@@ -1144,9 +1148,9 @@ def get_audio_encoder(name: str):
1144
  raise Exception('The audio encoder name {} is incorrect or not supported'.format(name))
1145
 
1146
  class Projection(nn.Module):
1147
- def __init__(self, d_in: int, d_out: int, p: float=0.5) -> None:
1148
  super().__init__()
1149
- self.linear1 = nn.Linear(d_in, d_out, bias=False)
1150
  self.linear2 = nn.Linear(d_out, d_out, bias=False)
1151
  self.layer_norm = nn.LayerNorm(d_out)
1152
  self.drop = nn.Dropout(p)
@@ -1158,7 +1162,7 @@ class Projection(nn.Module):
1158
  return embeds
1159
 
1160
  class AudioEncoder(nn.Module):
1161
- def __init__(self, audioenc_name:str, d_in: int, d_out: int, sample_rate: int, window_size: int,
1162
  hop_size: int, mel_bins: int, fmin: int, fmax: int, classes_num: int) -> None:
1163
  super().__init__()
1164
 
@@ -1167,9 +1171,9 @@ class AudioEncoder(nn.Module):
1167
  self.base = audio_encoder(
1168
  sample_rate, window_size,
1169
  hop_size, mel_bins, fmin, fmax,
1170
- classes_num, d_in)
1171
 
1172
- self.projection = Projection(d_in, d_out)
1173
 
1174
  def forward(self, x):
1175
  out_dict = self.base(x)
@@ -1343,7 +1347,7 @@ def preprocess_audio(audio_files, resample):
1343
  for audio_file in audio_files:
1344
  audio_tensor = load_audio_into_tensor(
1345
  audio_file, clapConfig.duration, resample)
1346
- audio_tensor = audio_tensor.reshape(1, -1).to(device)
1347
  audio_tensors.append(audio_tensor)
1348
  return default_collate(audio_tensors)
1349
 
@@ -1352,12 +1356,21 @@ def preprocess_audio(audio_files, resample):
1352
  # ==================================================================
1353
  # A U D I O - E M B E D D I N G S - H E L P E R
1354
  # ==================================================================
1355
  def get_audio_embeddings(audio_files: List[str], audio_encoder, resample=True):
1356
  """Load list of audio files and return audio embeddings"""
1357
- preprocessed_audio = preprocess_audio(audio_files, resample)
 
 
 
1358
  with torch.no_grad():
1359
- preprocessed_audio = preprocessed_audio.reshape(
1360
- preprocessed_audio.shape[0], preprocessed_audio.shape[2])
1361
  return audio_encoder(preprocessed_audio)[0]
1362
 
1363
 
@@ -1421,16 +1434,15 @@ model_fp = hf_hub_download(model_repo, model_name[version])
1421
 
1422
  model_state_dict = torch.load(model_fp, map_location=torch.device('cpu'))['model']
1423
  clap.load_state_dict(model_state_dict, strict=False)
1424
- clap.to(device)
1425
  clap.eval()
1426
 
1427
- clap_audio_encoder = clap.audio_encoder.eval()
1428
 
1429
 
1430
- ENGLISH_AUDIO_DIR = r"/home/IITB/ai-at-ieor/23m1521/datasets/Vaani/Audios/English"
1431
- audio_files = [os.path.join(ENGLISH_AUDIO_DIR, i) for i in os.listdir(ENGLISH_AUDIO_DIR) if i.endswith(".wav")]
1432
- audio_embedding = get_audio_embeddings(audio_files, clap_audio_encoder)
1433
- print("CLAP Audio Encoder Embeddings:", audio_embedding.shape) # [5, 1024]
1434
 
1435
 
1436
  # ==================================================================
@@ -1449,47 +1461,302 @@ LoRAconfig = {
1449
  }
1450
  peft_config = get_peft_config(LoRAconfig)
1451
 
1452
- model = clap_audio_encoder
1453
- peft_model = get_peft_model(model, peft_config)
1454
 
1455
  peft_model.print_trainable_parameters()
1456
 
1457
- # peft_model.base_model
1458
- # peft_model
1459
-
1460
  peft_clap_audio_encoder = peft_model.base_model
1461
- audio_embedding = get_audio_embeddings(audio_files, peft_clap_audio_encoder)
1462
- print("CLAP LoRA Audio Encoder Embeddings:", audio_embedding.shape) # [5, 1024]
1463
 
1464
 
1465
  # ==================================================================
1466
  # C L I P - M O D E L
1467
  # ==================================================================
1468
  from transformers import CLIPImageProcessorFast, CLIPImageProcessor
1469
- clip_vision_model = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32")
1470
  # clip_vision_processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
1471
- # clip_vision_processor = CLIPImageProcessorFast.from_pretrained("openai/clip-vit-base-patch32")
1472
- clip_vision_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")
1473
 
1474
- image = Image.open("/home/IITB/ai-at-ieor/23m1521/ashish/MTP/Vaani/Img_Audio_Alignment/000000039769.jpg")
1475
- inputs = clip_vision_processor(images=image, return_tensors="pt")
1476
- print("CLIP input image:", inputs['pixel_values'].shape)
1477
- # input_data = {'pixel_values': inputs}
1478
 
1479
- IMAGE_SIZE = 224
1480
- dummy_input = torch.randn(1, 3, IMAGE_SIZE, IMAGE_SIZE) # [1, 3, 224, 224]
1481
- # dummy_input_data = {'pixel_values': dummy_input}
1482
 
1483
- # class VisionModelWrapper(nn.Module):
1484
- # def __init__(self, peft_model):
1485
- # super().__init__()
1486
- # self.model = peft_model.base_model
1487
 
1488
- # def forward(self, x):
1489
- #         return self.model(pixel_values=x).last_hidden_state
 
1491
- # wrapped_model = VisionModelWrapper(peft_model)
1492
- # output = wrapped_model(input_data)
1493
- output = clip_vision_model(inputs['pixel_values'])
1494
- print("CLIP Image Encoder Embeddings:", output.last_hidden_state.shape) # [1, 50, 768]
1495
- print("CLIP Image Encoder Pooled Output:", output.pooler_output.shape) # [1, 768]

32
  from typing import Optional, Tuple, Union, List, Dict
33
 
34
  import requests
 
35
  import numpy as np
36
+ import pandas as pd
37
+ from PIL import Image
38
  import torch
39
  import torch.nn.functional as F
40
  from torch import nn
41
  from torch.nn.init import _calculate_fan_in_and_fan_out
42
  import torch.utils.checkpoint as checkpoint
43
 
44
+ import torchvision as tv
45
+ from torchvision.transforms import v2
46
+
47
+ os.environ["CUDA_VISIBLE_DEVICES"] = "1"
48
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
49
  print(f"Using device: {device}")
50
 
 
173
  heatmap_dir = "/home/Research/heatmap_output"
174
  test_file = "htsat-test-ensemble"
175
  fl_local = False # indicate if we need to use this dataset for the framewise detection
176
+ fl_dataset = "/home/Research/desed/desedim_embval.npy"
177
  fl_class_num = [
178
  "Speech", "Frying", "Dishes", "Running_water",
179
  "Blender", "Electric_shaver_toothbrush", "Alarm_bell_ringing",
 
318
  x = self.drop(x)
319
  return x
320
 
321
+ def _no_gradim_audiorunc_normal_(tensor, mean, std, a, b):
322
  # Cut & paste from PyTorch official master until it's in a few official releases - RW
323
  # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
324
  def norm_cdf(x):
 
372
  >>> w = torch.empty(3, 5)
373
  >>> nn.init.trunc_normal_(w)
374
  """
375
+ return _no_gradim_audiorunc_normal_(tensor, mean, std, a, b)
376
 
377
 
378
  def variance_scaling_(tensor, scale=1.0, mode='fan_in', distribution='normal'):
 
1148
  raise Exception('The audio encoder name {} is incorrect or not supported'.format(name))
1149
 
1150
  class Projection(nn.Module):
1151
+ def __init__(self, dim_imgn: int, d_out: int, p: float=0.5) -> None:
1152
  super().__init__()
1153
+ self.linear1 = nn.Linear(dim_imgn, d_out, bias=False)
1154
  self.linear2 = nn.Linear(d_out, d_out, bias=False)
1155
  self.layer_norm = nn.LayerNorm(d_out)
1156
  self.drop = nn.Dropout(p)
 
1162
  return embeds
1163
 
1164
  class AudioEncoder(nn.Module):
1165
+ def __init__(self, audioenc_name:str, dim_imgn: int, d_out: int, sample_rate: int, window_size: int,
1166
  hop_size: int, mel_bins: int, fmin: int, fmax: int, classes_num: int) -> None:
1167
  super().__init__()
1168
 
 
1171
  self.base = audio_encoder(
1172
  sample_rate, window_size,
1173
  hop_size, mel_bins, fmin, fmax,
1174
+ classes_num, dim_imgn)
1175
 
1176
+ self.projection = Projection(dim_imgn, d_out)
1177
 
1178
  def forward(self, x):
1179
  out_dict = self.base(x)
 
1347
  for audio_file in audio_files:
1348
  audio_tensor = load_audio_into_tensor(
1349
  audio_file, clapConfig.duration, resample)
1350
+ audio_tensor = audio_tensor.reshape(1, -1)
1351
  audio_tensors.append(audio_tensor)
1352
  return default_collate(audio_tensors)
1353
 
 
1356
  # ==================================================================
1357
  # A U D I O - E M B E D D I N G S - H E L P E R
1358
  # ==================================================================
1359
+ def CLAPAudioProcessor(audio_files: List[str], resample=True):
1360
+ preprocessed_audio = preprocess_audio(audio_files, resample)
1361
+ preprocessed_audio = preprocessed_audio.reshape(
1362
+ preprocessed_audio.shape[0], preprocessed_audio.shape[2])
1363
+ preprocessed_audio = preprocessed_audio
1364
+ return preprocessed_audio
1365
+
1366
  def get_audio_embeddings(audio_files: List[str], audio_encoder, resample=True):
1367
  """Load list of audio files and return audio embeddings"""
1368
+ # preprocessed_audio = preprocess_audio(audio_files, resample)
1369
+ # with torch.no_grad():
1370
+ # preprocessed_audio = preprocessed_audio.reshape(
1371
+ # preprocessed_audio.shape[0], preprocessed_audio.shape[2])
1372
  with torch.no_grad():
1373
+ preprocessed_audio = CLAPAudioProcessor(audio_files, resample)
 
1374
  return audio_encoder(preprocessed_audio)[0]
1375
 
1376
 
 
1434
 
1435
  model_state_dict = torch.load(model_fp, map_location=torch.device('cpu'))['model']
1436
  clap.load_state_dict(model_state_dict, strict=False)
 
1437
  clap.eval()
1438
 
1439
+ clap_audio_encoder = clap.audio_encoder.eval().to(device)
1440
 
1441
 
1442
+ # ENGLISH_AUDIO_DIR = r"/home/IITB/ai-at-ieor/23m1521/datasets/Vaani/Audios/English"
1443
+ # audio_files = [os.path.join(ENGLISH_AUDIO_DIR, i) for i in os.listdir(ENGLISH_AUDIO_DIR) if i.endswith(".wav")]
1444
+ # audio_embedding = get_audio_embeddings(audio_files, clap_audio_encoder)
1445
+ # print("CLAP Audio Encoder Embeddings:", audio_embedding.shape) # [5, 1024]
1446
 
1447
 
1448
  # ==================================================================
 
1461
  }
1462
  peft_config = get_peft_config(LoRAconfig)
1463
 
1464
+ peft_model = get_peft_model(clap_audio_encoder, peft_config)
 
1465
 
1466
  peft_model.print_trainable_parameters()
1467
 
 
 
 
1468
  peft_clap_audio_encoder = peft_model.base_model
1469
+ # audio_embedding = get_audio_embeddings(audio_files, peft_clap_audio_encoder)
1470
+ # print("CLAP LoRA Audio Encoder Embeddings:", audio_embedding.shape) # [5, 1024]
1471
 
1472
 
1473
  # ==================================================================
1474
  # C L I P - M O D E L
1475
  # ==================================================================
1476
  from transformers import CLIPImageProcessorFast, CLIPImageProcessor
1477
+ clip_vision_model = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
1478
  # clip_vision_processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
1479
+ clip_vision_processor = CLIPImageProcessorFast.from_pretrained("openai/clip-vit-base-patch32")
1480
+ # clip_vision_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")
1481
 
1482
+ # image = Image.open("/home/IITB/ai-at-ieor/23m1521/ashish/MTP/Vaani/Img_Audio_Alignment/000000039769.jpg")
1483
+ # inputs = clip_vision_processor(images=image, return_tensors="pt")
1484
+ # print("CLIP input image:", inputs['pixel_values'].shape)
1485
+ # # input_data = {'pixel_values': inputs}
1486
 
1487
+ # IMAGE_SIZE = 224
1488
+ # dummy_input = torch.randn(1, 3, IMAGE_SIZE, IMAGE_SIZE) # [1, 3, 224, 224]
1489
+ # # dummy_input_data = {'pixel_values': dummy_input}
1490
 
1491
+ # output = clip_vision_model(inputs['pixel_values'])
1492
+ # print("CLIP Image Encoder Embeddings:", output.last_hidden_state.shape) # [1, 50, 768]
1493
+ # print("CLIP Image Encoder Pooled Output:", output.pooler_output.shape) # [1, 768]
 
1494
 
1495
+ exit()
1496
+
1497
+
1498
+ # ==================================================================
1499
+ # C S I P - M O D U L E
1500
+ # ==================================================================
1501
+ class CSIP(nn.Module):
1502
+ def __init__(self, image_encoder, audio_encoder, dim_img=None, dim_audio=1024, dim_emb=768):
1503
+ super(CSIP, self).__init__()
1504
+ self.image_encoder = image_encoder # CLIPVisionModel
1505
+ self.audio_encoder = audio_encoder # CLAP_audio_encoder
1506
 
1507
+ # self.image_proj = nn.Linear(dim_img, dim_emb)
1508
+ self.audio_proj = nn.Linear(dim_audio, dim_emb)
1509
+
1510
+ # Learnable temperature parameter
1511
+ self.log_temp = nn.Parameter(torch.tensor(0.07).log())
1512
+
1513
+ def forward(self, images, audios):
1514
+ # Step 1: Feature extraction
1515
+ with torch.no_grad():
1516
+ with torch.inference_mode():
1517
+ image_features = self.image_encoder(images) # shape: [n, dim_img]
1518
+ audio_features = self.audio_encoder(audios) # shape: [n, dim_audio]
1519
+
1520
+ # Step 2: Project and normalize
1521
+ image_embeds = F.normalize(image_features) # [n, dim_emb]
1522
+ audio_embeds = F.normalize(self.audio_proj(audio_features), dim=1) # [n, dim_emb]
1523
+
1524
+ # Step 3: Cosine similarity with temperature
1525
+ logits = torch.matmul(image_embeds, audio_embeds.T) * self.log_temp.exp() # [n, n]
1526
+
1527
+ # Step 4: Symmetric cross-entropy loss
1528
+ labels = torch.arange(len(images), device=images.device)
1529
+ loss_i = F.cross_entropy(logits, labels)
1530
+ loss_t = F.cross_entropy(logits.T, labels)
1531
+ loss = (loss_i + loss_t) / 2
1532
+ return loss, logits
1533
+
1534
+
1535
+
1536
+
1537
+ # ==================================================================
1538
+ # I M A G E - A U D I O - D A T A S E T
1539
+ # ==================================================================
1540
+ class VaaniImageAudioDataset(torch.utils.data.Dataset):
1541
+ def __init__(self, df):
1542
+ self.image_paths = df.image_path.tolist()
1543
+ self.audio_paths = df.audio_path.tolist()
1544
+
1545
+ def __len__(self):
1546
+ return len(self.audio_paths)
1547
+
1548
+ def __getitem__(self, idx):
1549
+ return {
1550
+ 'image_path': self.image_paths[idx],
1551
+ 'audio_path': self.audio_paths[idx]
1552
+ }
1553
+
1554
+
1555
+ def collate_fn(batch):
1556
+ image_tensor = clip_vision_processor([Image.open(item['image_path']) for item in batch])['pixel_values']
1557
+ audio_tensor = CLAPAudioProcessor([item['audio_path'] for item in batch], resample=True)
1558
+ return {'image_tensor': image_tensor, 'audio_tensor': audio_tensor}
1559
+
1560
+
1561
+ # preprocessed_audio = CLAPAudioProcessor(audio_files, resample=True)
1562
+ # clip_vision_processor = clip_vision_processor
1563
+
1564
+ train_df = pd.read_csv("/home/IITB/ai-at-ieor/23m1521/ashish/MTP/Vaani/Img_Audio_Alignment/available_img_audios_TRAIN.csv")
1565
+ test_df = pd.read_csv("/home/IITB/ai-at-ieor/23m1521/ashish/MTP/Vaani/Img_Audio_Alignment/available_img_audios_TEST.csv")
1566
+ train_dataset = VaaniImageAudioDataset(train_df)
1567
+ test_dataset = VaaniImageAudioDataset(test_df)
1568
+ BATCH_SIZE = 32
1569
+
1570
+ print('Train Dataset:', len(train_dataset))
1571
+ print('Test Dataset:', len(test_dataset))
1572
+
1573
+ train_dataloader = torch.utils.data.DataLoader(
1574
+ train_dataset,
1575
+ batch_size=BATCH_SIZE,
1576
+ shuffle=True,
1577
+ num_workers=48,
1578
+ collate_fn=collate_fn,
1579
+ pin_memory=True,
1580
+ drop_last=True,
1581
+ persistent_workers=True
1582
+ )
1583
+
1584
+ test_dataloader = torch.utils.data.DataLoader(
1585
+ test_dataset,
1586
+ batch_size=BATCH_SIZE,
1587
+ shuffle=False,
1588
+ num_workers=48,
1589
+ collate_fn=collate_fn,
1590
+ pin_memory=True,
1591
+ drop_last=False,
1592
+ persistent_workers=True
1593
+ )
1594
+
1595
+ batch = next(iter(train_dataloader))
1596
+ print('BATCH SHAPE:', batch['image_tensor'].shape, batch['audio_tensor'].shape)
1597
+ exit()
1598
+
1599
+ # csip_model = CSIP(clip_vision_model.eval(), peft_clap_audio_encoder).to(device)
1600
+ # loss, logits = csip_model(images, texts)
1601
+ # print("Loss:", loss.item())
1602
+
1603
+
1604
+
1605
+
1606
+ # import torch
1607
+ # import torch.nn as nn
1608
+ # import torch.nn.functional as F
1609
+ # from torch.utils.tensorboard import SummaryWriter
1610
+ # from tqdm import tqdm
1611
+ # import os
1612
+
1613
+ # # CLIPCore (from your previous definition)
1614
+ # class CLIPCore(nn.Module):
1615
+ # def __init__(self, image_encoder, text_encoder, d_i, d_t, d_e):
1616
+ # super(CLIPCore, self).__init__()
1617
+ # self.image_encoder = image_encoder
1618
+ # self.text_encoder = text_encoder
1619
+ # self.image_proj = nn.Linear(d_i, d_e)
1620
+ # self.text_proj = nn.Linear(d_t, d_e)
1621
+ # self.log_temp = nn.Parameter(torch.tensor(0.07).log())
1622
+
1623
+ # def forward(self, images, texts):
1624
+ # image_features = self.image_encoder(images)
1625
+ # text_features = self.text_encoder(texts)
1626
+ # image_embeds = F.normalize(self.image_proj(image_features), dim=1)
1627
+ # text_embeds = F.normalize(self.text_proj(text_features), dim=1)
1628
+ # logits = torch.matmul(image_embeds, text_embeds.T) * self.log_temp.exp()
1629
+ # labels = torch.arange(len(images), device=images.device)
1630
+ # loss_i = F.cross_entropy(logits, labels)
1631
+ # loss_t = F.cross_entropy(logits.T, labels)
1632
+ # loss = (loss_i + loss_t) / 2
1633
+ # return loss, logits
1634
+
1635
+
1636
+ # # Dummy encoders
1637
+ # class DummyEncoder(nn.Module):
1638
+ # def __init__(self, output_dim):
1639
+ # super().__init__()
1640
+ # self.fc = nn.Linear(1024, output_dim)
1641
+ # def forward(self, x):
1642
+ # return self.fc(x)
1643
+
1644
+
1645
+ # # === Helper Functions ===
1646
+
1647
+ # def train_batch(model, images, texts, optimizer):
1648
+ # model.train()
1649
+ # optimizer.zero_grad()
1650
+ # loss, _ = model(images, texts)
1651
+ # loss.backward()
1652
+ # optimizer.step()
1653
+ # return loss.item()
1654
+
1655
+ # @torch.no_grad()
1656
+ # def evaluate_batch(model, images, texts):
1657
+ # model.eval()
1658
+ # loss, _ = model(images, texts)
1659
+ # return loss.item()
1660
+
1661
+ # def save_checkpoint(state, checkpoint_dir, name="clip_best.pt"):
1662
+ # torch.save(state, os.path.join(checkpoint_dir, name))
1663
+
1664
+ # def load_checkpoint(checkpoint_path, model, optimizer):
1665
+ # checkpoint = torch.load(checkpoint_path)
1666
+ # model.load_state_dict(checkpoint['model_state'])
1667
+ # optimizer.load_state_dict(checkpoint['optimizer_state'])
1668
+ # start_epoch = checkpoint['epoch']
1669
+ # best_loss = checkpoint['best_loss']
1670
+ # print(f"Resumed training from epoch {start_epoch+1} with best loss {best_loss:.4f}")
1671
+ # return start_epoch, best_loss
1672
+
1673
+
1674
+ # # === Training Loop ===
1675
+
1676
+ # def train_model(model, train_loader, test_loader, optimizer, device, log_dir, checkpoint_dir, resume=False, epochs=10):
1677
+ # os.makedirs(log_dir, exist_ok=True)
1678
+ # os.makedirs(checkpoint_dir, exist_ok=True)
1679
+ # writer = SummaryWriter(log_dir=log_dir)
1680
+
1681
+ # start_epoch = 0
1682
+ # best_loss = float("inf")
1683
+
1684
+ # # Resume logic
1685
+ # checkpoint_path = os.path.join(checkpoint_dir, "clip_best.pt")
1686
+ # if resume and os.path.exists(checkpoint_path):
1687
+ # start_epoch, best_loss = load_checkpoint(checkpoint_path, model, optimizer)
1688
+
1689
+ # for epoch in range(start_epoch, epochs):
1690
+ # train_losses = []
1691
+ # test_losses = []
1692
+
1693
+ # train_loop = tqdm(train_loader, desc=f"[Train Epoch {epoch+1}]")
1694
+ # for images, texts in train_loop:
1695
+ # images, texts = images.to(device), texts.to(device)
1696
+ # loss = train_batch(model, images, texts, optimizer)
1697
+ # train_losses.append(loss)
1698
+ # train_loop.set_postfix(train_loss=loss)
1699
+
1700
+ # test_loop = tqdm(test_loader, desc=f"[Test Epoch {epoch+1}]")
1701
+ # for images, texts in test_loop:
1702
+ # images, texts = images.to(device), texts.to(device)
1703
+ # loss = evaluate_batch(model, images, texts)
1704
+ # test_losses.append(loss)
1705
+ # test_loop.set_postfix(test_loss=loss)
1706
+
1707
+ # avg_train_loss = sum(train_losses) / len(train_losses)
1708
+ # avg_test_loss = sum(test_losses) / len(test_losses)
1709
+
1710
+ # writer.add_scalar("Loss/Train", avg_train_loss, epoch + 1)
1711
+ # writer.add_scalar("Loss/Test", avg_test_loss, epoch + 1)
1712
+
1713
+ # print(f"Epoch {epoch+1} | Train Loss: {avg_train_loss:.4f} | Test Loss: {avg_test_loss:.4f}")
1714
+
1715
+ # # Save checkpoint if test loss improves
1716
+ # if avg_test_loss < best_loss:
1717
+ # best_loss = avg_test_loss
1718
+ # save_checkpoint({
1719
+ # 'epoch': epoch,
1720
+ # 'model_state': model.state_dict(),
1721
+ # 'optimizer_state': optimizer.state_dict(),
1722
+ # 'best_loss': best_loss
1723
+ # }, checkpoint_dir)
1724
+ # print(f">>> Saved new best model at epoch {epoch+1}")
1725
+
1726
+ # writer.close()
1727
+
1728
+
1729
+ # # Dummy DataLoader
1730
+ # def get_dummy_loader(batch_size=32, num_batches=100):
1731
+ # for _ in range(num_batches):
1732
+ # yield torch.randn(batch_size, 1024), torch.randn(batch_size, 1024)
1733
+
1734
+
1735
+ # # === Main ===
1736
+ # if __name__ == "__main__":
1737
+ # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
1738
+ # d_i, d_t, d_e = 512, 512, 256
1739
+ # batch_size = 32
1740
+ # learning_rate = 1e-4
1741
+ # epochs = 10
1742
+
1743
+ # image_encoder = DummyEncoder(d_i)
1744
+ # text_encoder = DummyEncoder(d_t)
1745
+ # model = CLIPCore(image_encoder, text_encoder, d_i, d_t, d_e).to(device)
1746
+
1747
+ # optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
1748
+
1749
+ # train_loader = get_dummy_loader(batch_size)
1750
+ # test_loader = get_dummy_loader(batch_size)
1751
+
1752
+ # train_model(
1753
+ # model=model,
1754
+ # train_loader=train_loader,
1755
+ # test_loader=test_loader,
1756
+ # optimizer=optimizer,
1757
+ # device=device,
1758
+ # log_dir="runs/clip_audio_image",
1759
+ # checkpoint_dir="checkpoints/clip",
1760
+ # resume=True,
1761
+ # epochs=epochs
1762
+ # )
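
The VaaniImageAudioDataset above stores only file paths and defers all decoding to collate_fn, so the CLIP image processor and CLAPAudioProcessor run inside the DataLoader workers. A runnable sketch of that pattern with stub processors (the stubs, tensor shapes, and file names below are placeholders, not the real clip_vision_processor / CLAPAudioProcessor outputs):

import torch
from torch.utils.data import Dataset, DataLoader

class PathPairDataset(Dataset):
    """Stores only (image_path, audio_path) pairs; decoding happens in collate_fn."""
    def __init__(self, image_paths, audio_paths):
        self.image_paths = list(image_paths)
        self.audio_paths = list(audio_paths)
    def __len__(self):
        return len(self.audio_paths)
    def __getitem__(self, idx):
        return {"image_path": self.image_paths[idx], "audio_path": self.audio_paths[idx]}

def make_collate_fn(image_processor, audio_processor):
    # image_processor / audio_processor stand in for clip_vision_processor and
    # CLAPAudioProcessor from the script above (assumptions, not the real objects).
    def collate(batch):
        images = image_processor([b["image_path"] for b in batch])
        audios = audio_processor([b["audio_path"] for b in batch])
        return {"image_tensor": images, "audio_tensor": audios}
    return collate

# Stub processors so the sketch runs without the CLIP/CLAP weights; 16000 is a
# placeholder waveform length, not the real CLAP duration * sample_rate.
fake_image_proc = lambda paths: torch.randn(len(paths), 3, 224, 224)
fake_audio_proc = lambda paths: torch.randn(len(paths), 16000)

ds = PathPairDataset([f"img_{i}.jpg" for i in range(8)], [f"aud_{i}.wav" for i in range(8)])
dl = DataLoader(ds, batch_size=4, collate_fn=make_collate_fn(fake_image_proc, fake_audio_proc))
batch = next(iter(dl))
print(batch["image_tensor"].shape, batch["audio_tensor"].shape)
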
Vaani/Img_Audio_Alignment/available_img_audios_TEST.csv ADDED
The diff for this file is too large to render. See raw diff
 
Vaani/Img_Audio_Alignment/available_img_audios_TRAIN.csv ADDED
The diff for this file is too large to render. See raw diff
 
Vaani/Img_Audio_Alignment/checkpoints/csip/csip_best_epoch_203.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:135b0d13ec7cddb29f4809cde8c3b3283c12acd952031167dd96e3a374dcfd6b
3
+ size 509342259
Vaani/Img_Audio_Alignment/checkpoints/csip/csip_best_epoch_27.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:96492f3f12877ebf99517fd51fc483a3c102c6080cbbf474b01a31f98454dc2e
3
+ size 509341156
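
These csip_best_epoch_*.pt files are written by save_checkpoint in the notebook above, so, assuming that dict layout ('epoch', 'model_state', 'optimizer_state', 'best_loss'), they can be inspected or restored roughly as follows (csip_model and optimizer must be constructed exactly as in training):

import torch

ckpt = torch.load("checkpoints/csip/csip_best_epoch_203.pt", map_location="cpu")
print(ckpt["epoch"], ckpt["best_loss"])                   # metadata saved alongside the weights
# csip_model.load_state_dict(ckpt["model_state"])         # requires the CSIP model from _2_Train
# optimizer.load_state_dict(ckpt["optimizer_state"])      # requires the matching AdamW optimizer
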
Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_15_loss_3.4611.png ADDED
Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_1_loss_3.4611.png ADDED
Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_203_loss_3.4611.png ADDED
Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_27_loss_3.4611.png ADDED
Vaani/Img_Audio_Alignment/checkpoints/csip/logits/raw_logits_epoch_5_loss_3.4611.png ADDED
Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_15_loss_3.4611.png ADDED
Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_1_loss_3.4611.png ADDED
Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_203_loss_3.4611.png ADDED
Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_27_loss_3.4611.png ADDED
Vaani/Img_Audio_Alignment/checkpoints/csip/probs/probs_epoch_5_loss_3.4611.png ADDED
Vaani/Img_Audio_Alignment/runs/csip/events.out.tfevents.1747523457.rmgpu025.3375295.0 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0427f062675aaf495fb5d3fd32554103c41a838c0d384e8416b4fc8eaef8be32
3
+ size 792717
Vaani/Img_Audio_Alignment/runs/csip/training_log.csv ADDED
@@ -0,0 +1,367 @@
1
+ Epoch, Train Loss, Test Loss, Best Loss, Best Epoch
2
+ 1,3.46573521709833,3.4610748007183982,3.4610748007183982,1
3
+ 2,3.4657436843778266,3.461075038001651,3.4610748007183982,1
4
+ 3,3.4657385955091384,3.4610777514321462,3.4610748007183982,1
5
+ 4,3.4657369627327217,3.4610763163793656,3.4610748007183982,1
6
+ 5,3.465738452848841,3.4610726004555112,3.4610726004555112,5
7
+ 6,3.465746810201739,3.4610738901864915,3.4610726004555112,5
8
+ 7,3.4657312768404602,3.461075624965486,3.4610726004555112,5
9
+ 8,3.4657392096324044,3.461072703770229,3.4610726004555112,5
10
+ 9,3.4657418517792813,3.4610761290504817,3.4610726004555112,5
11
+ 10,3.465733758250221,3.4610775572913033,3.4610726004555112,5
12
+ 11,3.4657396654613684,3.46107660588764,3.4610726004555112,5
13
+ 12,3.4657437903959245,3.4610769033432005,3.4610726004555112,5
14
+ 13,3.4657299968062856,3.4610789242244904,3.4610726004555112,5
15
+ 14,3.4657370570253154,3.4610737800598144,3.4610726004555112,5
16
+ 15,3.4657460182416635,3.461070890653701,3.461070890653701,15
17
+ 16,3.4657411438519836,3.461075038001651,3.461070890653701,15
18
+ 17,3.465738015096696,3.4610758043470837,3.461070890653701,15
19
+ 18,3.465739974232971,3.4610746485846384,3.461070890653701,15
20
+ 19,3.465737346254411,3.461074812071664,3.461070890653701,15
21
+ 20,3.4657459053836885,3.4610753933588665,3.461070890653701,15
22
+ 21,3.4657364262909187,3.4610756033942813,3.461070890653701,15
23
+ 22,3.4657383458536177,3.4610799448830742,3.461070890653701,15
24
+ 23,3.4657375920014304,3.4610742341904412,3.461070890653701,15
25
+ 24,3.4657370931789524,3.461076168786912,3.461070890653701,15
26
+ 25,3.465736433130796,3.4610772473471507,3.461070890653701,15
27
+ 26,3.465753781990927,3.4610763265973046,3.461070890653701,15
28
+ 27,3.4657371376381545,3.461068638165792,3.461068638165792,27
29
+ 28,3.4657417775177564,3.461077817281087,3.461068638165792,27
30
+ 29,3.4657461814215926,3.461076136997768,3.461068638165792,27
31
+ 30,3.46573303322323,3.4610767409915018,3.461068638165792,27
32
+ 31,3.465738852004536,3.4610775118782406,3.461068638165792,27
33
+ 32,3.4657424312145984,3.4610733213878815,3.461068638165792,27
34
+ 33,3.465739804701727,3.4610734939575196,3.461068638165792,27
35
+ 34,3.465737399019179,3.461074911980402,3.461068638165792,27
36
+ 35,3.4657459112464406,3.4610757736932665,3.461068638165792,27
37
+ 36,3.465745515999247,3.461076812517075,3.461068638165792,27
38
+ 37,3.4657391080113706,3.461073993501209,3.461068638165792,27
39
+ 38,3.4657376931339012,3.4610795236769176,3.461068638165792,27
40
+ 39,3.4657382471639604,3.461073828878857,3.461068638165792,27
41
+ 40,3.465732887142994,3.461078895841326,3.461068638165792,27
42
+ 41,3.4657490404902913,3.4610782282693044,3.461068638165792,27
43
+ 42,3.4657386443654046,3.461073933328901,3.461068638165792,27
44
+ 43,3.46573893896869,3.4610799698602586,3.461068638165792,27
45
+ 44,3.4657407231995316,3.461077432405381,3.461068638165792,27
46
+ 45,3.4657383854271937,3.46107694989159,3.461068638165792,27
47
+ 46,3.4657291291189973,3.4610767580214,3.461068638165792,27
48
+ 47,3.465737374102483,3.461077233723232,3.461068638165792,27
49
+ 48,3.4657390997058055,3.4610781737736294,3.461068638165792,27
50
+ 49,3.4657394270427893,3.46107615289234,3.461068638165792,27
51
+ 50,3.4657357354633143,3.4610809689476376,3.461068638165792,27
52
+ 51,3.4657413041005367,3.4610799119586035,3.461068638165792,27
53
+ 52,3.465741222999135,3.4610729069936843,3.461068638165792,27
54
+ 53,3.465745713867125,3.4610725323359173,3.461068638165792,27
55
+ 54,3.465748988702649,3.4610764003935315,3.461068638165792,27
56
+ 55,3.4657449101815456,3.461075716926938,3.461068638165792,27
57
+ 56,3.465735683675672,3.461073174930754,3.461068638165792,27
58
+ 57,3.4657397333715783,3.461076039359683,3.461068638165792,27
59
+ 58,3.4657420652811646,3.461073859532674,3.461068638165792,27
60
+ 59,3.4657352219839566,3.461078399703616,3.461068638165792,27
61
+ 60,3.465740470124073,3.461074948310852,3.461068638165792,27
62
+ 61,3.4657351355083654,3.4610742954980758,3.461068638165792,27
63
+ 62,3.4657426378766045,3.4610716240746635,3.461068638165792,27
64
+ 63,3.465735925025627,3.461075274149577,3.461068638165792,27
65
+ 64,3.465727699584648,3.4610778536115374,3.461068638165792,27
66
+ 65,3.4657338398401856,3.461080871309553,3.461068638165792,27
67
+ 66,3.4657386731906015,3.461076987357367,3.461068638165792,27
68
+ 67,3.465741856664908,3.461076105208624,3.461068638165792,27
69
+ 68,3.4657425152473764,3.461078572273254,3.461068638165792,27
70
+ 69,3.465733808572175,3.461077258700416,3.461068638165792,27
71
+ 70,3.4657452492440335,3.4610759156090873,3.461068638165792,27
72
+ 71,3.465743854397633,3.4610787323543004,3.461068638165792,27
73
+ 72,3.4657392931766196,3.46107439994812,3.461068638165792,27
74
+ 73,3.4657338339774335,3.46107584862482,3.461068638165792,27
75
+ 74,3.465745241427031,3.461076845441546,3.461068638165792,27
76
+ 75,3.4657387733459473,3.4610753434044974,3.461068638165792,27
77
+ 76,3.4657278603217643,3.461074594088963,3.461068638165792,27
78
+ 77,3.4657401466955906,3.4610771372204736,3.461068638165792,27
79
+ 78,3.4657415811155663,3.461074877920605,3.461068638165792,27
80
+ 79,3.465745365521947,3.4610784859884354,3.461068638165792,27
81
+ 80,3.4657414550663996,3.46107325894492,3.461068638165792,27
82
+ 81,3.465744232056571,3.461077724184309,3.461068638165792,27
83
+ 82,3.465740345052031,3.461073366800944,3.461068638165792,27
84
+ 83,3.4657413295057955,3.4610741683415003,3.461068638165792,27
85
+ 84,3.465736867951565,3.4610788685934883,3.461068638165792,27
86
+ 85,3.4657383502506818,3.461074101357233,3.461068638165792,27
87
+ 86,3.4657447699640618,3.4610755636578516,3.461068638165792,27
88
+ 87,3.465739328841694,3.4610791433425176,3.461068638165792,27
89
+ 88,3.4657491352714476,3.4610740661621096,3.461068638165792,27
90
+ 89,3.465734797423003,3.461073492822193,3.461068638165792,27
91
+ 90,3.465737556336356,3.461072019168309,3.461068638165792,27
92
+ 91,3.4657416778509735,3.4610769669214885,3.461068638165792,27
93
+ 92,3.465753313459334,3.4610748552140738,3.461068638165792,27
94
+ 93,3.4657432901077585,3.461076624052865,3.461068638165792,27
95
+ 94,3.4657475782222433,3.4610774267287483,3.461068638165792,27
96
+ 95,3.465736511300822,3.461075268472944,3.461068638165792,27
97
+ 96,3.46573667985494,3.4610716479165213,3.461068638165792,27
98
+ 97,3.4657364189624786,3.461074159258888,3.461068638165792,27
99
+ 98,3.465737197731362,3.461076383363633,3.461068638165792,27
100
+ 99,3.4657412313047002,3.461075404712132,3.461068638165792,27
101
+ 100,3.4657446996110384,3.4610728910991124,3.461068638165792,27
102
+ 101,3.4657335491454013,3.4610757543927146,3.461068638165792,27
103
+ 102,3.4657341920938647,3.4610740400495983,3.461068638165792,27
104
+ 103,3.4657355043731752,3.461075307074047,3.461068638165792,27
105
+ 104,3.4657418293053985,3.461076294808161,3.461068638165792,27
106
+ 105,3.4657442037199364,3.4610763561157953,3.461068638165792,27
107
+ 106,3.465734151054601,3.461075664701916,3.461068638165792,27
108
+ 107,3.4657363960000334,3.461078837939671,3.461068638165792,27
109
+ 108,3.465736574325405,3.4610768329529535,3.461068638165792,27
110
+ 109,3.465738363441874,3.461077742349534,3.461068638165792,27
111
+ 110,3.4657457519750126,3.4610780761355446,3.461068638165792,27
112
+ 111,3.4657398457409907,3.461078751654852,3.461068638165792,27
113
+ 112,3.4657469171969617,3.4610763640630813,3.461068638165792,27
114
+ 113,3.4657363588692713,3.4610767455328078,3.461068638165792,27
115
+ 114,3.4657273487966567,3.461072870663234,3.461068638165792,27
116
+ 115,3.4657473823086162,3.4610776117869784,3.461068638165792,27
117
+ 116,3.4657388109652723,3.461082070214408,3.461068638165792,27
118
+ 117,3.465734670396711,3.4610760495776223,3.461068638165792,27
119
+ 118,3.4657491802192126,3.4610756056649343,3.461068638165792,27
120
+ 119,3.465745851153233,3.4610792194093976,3.461068638165792,27
121
+ 120,3.4657440600825136,3.4610720282509213,3.461068638165792,27
122
+ 121,3.465742236766659,3.4610794657752626,3.461068638165792,27
123
+ 122,3.465734956694431,3.4610778093338013,3.461068638165792,27
124
+ 123,3.4657477277224182,3.461074746222723,3.461068638165792,27
125
+ 124,3.4657419480261256,3.46107584862482,3.461068638165792,27
126
+ 125,3.465743043383614,3.461076643353417,3.461068638165792,27
127
+ 126,3.4657389433657535,3.4610770032519387,3.461068638165792,27
128
+ 127,3.465749927231523,3.4610780579703193,3.461068638165792,27
129
+ 128,3.465737064353755,3.461076171057565,3.461068638165792,27
130
+ 129,3.465738639479778,3.461073470115662,3.461068638165792,27
131
+ 130,3.465747668606336,3.461080064092364,3.461068638165792,27
132
+ 131,3.4657384968194807,3.461079478263855,3.461068638165792,27
133
+ 132,3.465744329769103,3.4610762982141403,3.461068638165792,27
134
+ 133,3.465744828591581,3.461078797067915,3.461068638165792,27
135
+ 134,3.4657438485348813,3.461077248482477,3.461068638165792,27
136
+ 135,3.465738213941699,3.4610794907524474,3.461068638165792,27
137
+ 136,3.465734507216782,3.4610772927602134,3.461068638165792,27
138
+ 137,3.4657387093442384,3.4610828002293905,3.461068638165792,27
139
+ 138,3.4657351741048155,3.461076327732631,3.461068638165792,27
140
+ 139,3.4657402404996214,3.461069657689049,3.461068638165792,27
141
+ 140,3.46574071000834,3.461073954900106,3.461068638165792,27
142
+ 141,3.4657448583939034,3.4610777014777776,3.461068638165792,27
143
+ 142,3.4657410143828784,3.46107508455004,3.461068638165792,27
144
+ 143,3.4657372089683034,3.461074709892273,3.461068638165792,27
145
+ 144,3.465740814072187,3.4610765264147805,3.461068638165792,27
146
+ 145,3.4657404256648703,3.461073138600304,3.461068638165792,27
147
+ 146,3.4657417223101756,3.4610730829693024,3.461068638165792,27
148
+ 147,3.465751084147907,3.461075372922988,3.461068638165792,27
149
+ 148,3.4657406293955004,3.461076331138611,3.461068638165792,27
150
+ 149,3.4657458384506037,3.4610776231402443,3.461068638165792,27
151
+ 150,3.4657487248788112,3.461075837271554,3.461068638165792,27
152
+ 151,3.465731780548565,3.4610822654905773,3.461068638165792,27
153
+ 152,3.4657375988413075,3.4610815161750432,3.461068638165792,27
154
+ 153,3.4657279018495903,3.461074351129078,3.461068638165792,27
155
+ 154,3.4657473139098434,3.4610762164706275,3.461068638165792,27
156
+ 155,3.4657374415241304,3.46107755842663,3.461068638165792,27
157
+ 156,3.465745550687196,3.4610732634862265,3.461068638165792,27
158
+ 157,3.4657520876556145,3.461078104518709,3.461068638165792,27
159
+ 158,3.4657488592335435,3.461076010976519,3.461068638165792,27
160
+ 159,3.4657454524861007,3.4610791932968867,3.461068638165792,27
161
+ 160,3.4657454901054257,3.461073358853658,3.461068638165792,27
162
+ 161,3.465733320498076,3.46107577255794,3.461068638165792,27
163
+ 162,3.4657320453495277,3.4610758020764307,3.461068638165792,27
164
+ 163,3.465735073460907,3.4610742568969726,3.461068638165792,27
165
+ 164,3.4657488929443674,3.4610763243266516,3.461068638165792,27
166
+ 165,3.4657452614580997,3.4610728388740903,3.461068638165792,27
167
+ 166,3.465739756822586,3.4610755897703624,3.461068638165792,27
168
+ 167,3.465740684603081,3.461081443514143,3.461068638165792,27
169
+ 168,3.46574545786029,3.4610778956186206,3.461068638165792,27
170
+ 169,3.4657452096704575,3.461077322278704,3.461068638165792,27
171
+ 170,3.4657420359674047,3.4610769271850588,3.461068638165792,27
172
+ 171,3.465747346643542,3.4610752253305344,3.461068638165792,27
173
+ 172,3.4657437635249777,3.46107512088049,3.461068638165792,27
174
+ 173,3.4657375822301772,3.4610755273274014,3.461068638165792,27
175
+ 174,3.4657370711936326,3.461075911067781,3.461068638165792,27
176
+ 175,3.4657403939082974,3.461072755995251,3.461068638165792,27
177
+ 176,3.4657455756038917,3.4610804580506827,3.461068638165792,27
178
+ 177,3.465736765353406,3.4610756011236283,3.461068638165792,27
179
+ 178,3.4657422377437843,3.4610758338655745,3.461068638165792,27
180
+ 179,3.465742223575467,3.4610772212346395,3.461068638165792,27
181
+ 180,3.4657448574167784,3.4610780318578085,3.461068638165792,27
182
+ 181,3.465744264790269,3.461074071838742,3.461068638165792,27
183
+ 182,3.4657460778463083,3.4610787902559554,3.461068638165792,27
184
+ 183,3.4657434918841377,3.4610763617924283,3.461068638165792,27
185
+ 184,3.4657449131129217,3.4610758418128604,3.461068638165792,27
186
+ 185,3.4657399444306485,3.4610760507129488,3.461068638165792,27
187
+ 186,3.465741965125819,3.4610746156601677,3.461068638165792,27
188
+ 187,3.465730566958912,3.4610766183762323,3.461068638165792,27
189
+ 188,3.465735874703673,3.461078835669018,3.461068638165792,27
190
+ 189,3.465743572496977,3.461074141093663,3.461068638165792,27
191
+ 190,3.4657333439490836,3.461076091584705,3.461068638165792,27
192
+ 191,3.4657420017680183,3.461072128159659,3.461068638165792,27
193
+ 192,3.465742250934976,3.461073995771862,3.461068638165792,27
194
+ 193,3.465743506541018,3.461074773470561,3.461068638165792,27
195
+ 194,3.4657396053681606,3.461078999156044,3.461068638165792,27
196
+ 195,3.4657456166431553,3.461080551147461,3.461068638165792,27
197
+ 196,3.4657305117513313,3.4610739231109617,3.461068638165792,27
198
+ 197,3.465746803850424,3.461074657667251,3.461068638165792,27
199
+ 198,3.4657420735867297,3.461078974178859,3.461068638165792,27
200
+ 199,3.465739453425173,3.461075480779012,3.461068638165792,27
201
+ 200,3.4657408443630717,3.4610792137327646,3.461068638165792,27
202
+ 201,3.4657429539766467,3.461070229893639,3.461068638165792,27
203
+ 202,3.465730927029594,3.4610772870835804,3.461068638165792,27
204
+ 203,3.465742439031601,3.461066366377331,3.461066366377331,203
205
+ 204,3.4657376076354356,3.4610721542721703,3.461066366377331,203
206
+ 205,3.4657432231746736,3.461076171057565,3.461066366377331,203
207
+ 206,3.4657434997011403,3.4610708770297824,3.461066366377331,203
208
+ 207,3.465746605982546,3.461074680373782,3.461066366377331,203
209
+ 208,3.4657472215715,3.461073782330468,3.461066366377331,203
210
+ 209,3.465739699660755,3.461071228981018,3.461066366377331,203
211
+ 210,3.465738641922591,3.461079346565973,3.461066366377331,203
212
+ 211,3.4657519738205145,3.4610757759639195,3.461066366377331,203
213
+ 212,3.465741905521174,3.4610800198146277,3.461066366377331,203
214
+ 213,3.465743096636944,3.461075006212507,3.461066366377331,203
215
+ 214,3.465747090636707,3.4610739265169417,3.461066366377331,203
216
+ 215,3.465735371484131,3.46107569308508,3.461066366377331,203
217
+ 216,3.465741084247339,3.4610781680969964,3.461066366377331,203
218
+ 217,3.4657371532721597,3.4610765502566383,3.461066366377331,203
219
+ 218,3.4657441196871583,3.4610755477632793,3.461066366377331,203
220
+ 219,3.4657394568451116,3.461075703303019,3.461066366377331,203
221
+ 220,3.4657466108681727,3.461076226688567,3.461066366377331,203
222
+ 221,3.4657258948341747,3.461073694910322,3.461066366377331,203
223
+ 222,3.465742596837341,3.461073380424863,3.461066366377331,203
224
+ 223,3.4657480218371406,3.461077020281837,3.461066366377331,203
225
+ 224,3.46574360425355,3.46107980410258,3.461066366377331,203
226
+ 225,3.4657430380094247,3.461077296166193,3.461066366377331,203
227
+ 226,3.4657447709411873,3.461075229871841,3.461066366377331,203
228
+ 227,3.46574397311836,3.461075416065398,3.461066366377331,203
229
+ 228,3.465747856702961,3.4610744623910814,3.461066366377331,203
230
+ 229,3.465742313470997,3.461079010509309,3.461066366377331,203
231
+ 230,3.4657399991496662,3.461076861336118,3.461066366377331,203
232
+ 231,3.4657419118724886,3.4610717796144033,3.461066366377331,203
233
+ 232,3.465750910219599,3.4610756840024677,3.461066366377331,203
234
+ 233,3.465734547767483,3.4610780398050944,3.461066366377331,203
235
+ 234,3.4657490028709663,3.461073198772612,3.461066366377331,203
236
+ 235,3.4657469103570846,3.461076382228306,3.461066366377331,203
237
+ 236,3.465751744196063,3.4610761154265632,3.461066366377331,203
238
+ 237,3.465744873539346,3.461074137687683,3.461066366377331,203
239
+ 238,3.465736294379,3.461072364307585,3.461066366377331,203
240
+ 239,3.465735922582814,3.461075043678284,3.461066366377331,203
241
+ 240,3.465739214518031,3.4610767046610516,3.461066366377331,203
242
+ 241,3.4657417237758636,3.4610774982543218,3.461066366377331,203
243
+ 242,3.4657466211279884,3.4610769146964664,3.461066366377331,203
244
+ 243,3.4657360202953464,3.461077460788545,3.461066366377331,203
245
+ 244,3.4657325925397093,3.461074318204607,3.461066366377331,203
246
+ 245,3.46573880510252,3.46107482001895,3.461066366377331,203
247
+ 246,3.465741548381868,3.4610731794720606,3.461066366377331,203
248
+ 247,3.4657338745281345,3.4610764628364925,3.461066366377331,203
249
+ 248,3.465727367362038,3.461078473499843,3.461066366377331,203
250
+ 249,3.465733908727521,3.4610756692432223,3.461066366377331,203
251
+ 250,3.4657462312549843,3.4610773415792555,3.461066366377331,203
252
+ 251,3.465736162955644,3.4610737403233847,3.461066366377331,203
253
+ 252,3.4657466963666383,3.461074525969369,3.461066366377331,203
254
+ 253,3.465742281714424,3.461075586364383,3.461066366377331,203
255
+ 254,3.465745878512742,3.461071934018816,3.461066366377331,203
256
+ 255,3.465741225441948,3.4610752151125954,3.461066366377331,203
257
+ 256,3.4657438988568354,3.461076286860875,3.461066366377331,203
258
+ 257,3.465739455867986,3.461078098842076,3.461066366377331,203
259
+ 258,3.465738760643318,3.4610744328725906,3.461066366377331,203
260
+ 259,3.4657423857782708,3.461076878366016,3.461066366377331,203
261
+ 260,3.465746166764713,3.46107535248711,3.461066366377331,203
262
+ 261,3.465745163257005,3.4610774653298515,3.461066366377331,203
263
+ 262,3.4657451749825086,3.4610761506216865,3.461066366377331,203
264
+ 263,3.4657509776412465,3.461078325907389,3.461066366377331,203
265
+ 264,3.465736592402224,3.4610776435761226,3.461066366377331,203
266
+ 265,3.4657426901528092,3.4610785109656197,3.461066366377331,203
267
+ 266,3.4657360403264157,3.4610751708348593,3.461066366377331,203
268
+ 267,3.4657488162400294,3.4610736460912794,3.461066366377331,203
269
+ 268,3.465739801770351,3.4610775595619563,3.461066366377331,203
270
+ 269,3.4657413847133762,3.4610764934903098,3.461066366377331,203
271
+ 270,3.465745068964411,3.4610776265462238,3.461066366377331,203
272
+ 271,3.4657335256943935,3.4610766274588447,3.461066366377331,203
273
+ 272,3.46574189379567,3.46107405935015,3.461066366377331,203
274
+ 273,3.465747667140648,3.4610798540569485,3.461066366377331,203
275
+ 274,3.4657369246248337,3.461074560029166,3.461066366377331,203
276
+ 275,3.4657399385678964,3.461074930145627,3.461066366377331,203
277
+ 276,3.4657441157786573,3.461076661518642,3.461066366377331,203
278
+ 277,3.465741717424549,3.4610764764604114,3.461066366377331,203
279
+ 278,3.465741726218677,3.4610762505304247,3.461066366377331,203
280
+ 279,3.465729427142221,3.4610741024925593,3.461066366377331,203
281
+ 280,3.465743330658459,3.461078711918422,3.461066366377331,203
282
+ 281,3.4657441602378594,3.46107717809223,3.461066366377331,203
283
+ 282,3.4657420970377375,3.4610763969875515,3.461066366377331,203
284
+ 283,3.465742819133352,3.461077488036383,3.461066366377331,203
285
+ 284,3.465747262122201,3.4610763947168985,3.461066366377331,203
286
+ 285,3.4657372260679966,3.4610755057561966,3.461066366377331,203
287
+ 286,3.4657412772295904,3.4610748529434203,3.461066366377331,203
288
+ 287,3.4657379652633042,3.4610749108450753,3.461066366377331,203
289
+ 288,3.4657514979604813,3.4610778104691278,3.461066366377331,203
290
+ 289,3.4657424991248083,3.461076621782212,3.461066366377331,203
291
+ 290,3.465733950255347,3.46107615289234,3.461066366377331,203
292
+ 291,3.4657346088378156,3.461077342714582,3.461066366377331,203
293
+ 292,3.4657455931921475,3.4610693659101215,3.461066366377331,203
294
+ 293,3.4657400304176766,3.461077047529675,3.461066366377331,203
295
+ 294,3.4657420017680183,3.4610726969582695,3.461066366377331,203
296
+ 295,3.465732125962367,3.461076987357367,3.461066366377331,203
297
+ 296,3.465746620639426,3.4610745225633894,3.461066366377331,203
298
+ 297,3.4657395017928763,3.461076865877424,3.461066366377331,203
299
+ 298,3.465732188009825,3.4610753297805785,3.461066366377331,203
300
+ 299,3.4657439985236183,3.4610741831007457,3.461066366377331,203
301
+ 300,3.465742188887518,3.461076719420297,3.461066366377331,203
302
+ 301,3.465748793277584,3.4610767648333596,3.461066366377331,203
303
+ 302,3.465756238483992,3.461079292070298,3.461066366377331,203
304
+ 303,3.465737801594812,3.46107439994812,3.461066366377331,203
305
+ 304,3.4657489647630784,3.4610772541591097,3.461066366377331,203
306
+ 305,3.465739526709572,3.4610728490920293,3.461066366377331,203
307
+ 306,3.4657510636282747,3.4610797178177606,3.461066366377331,203
308
+ 307,3.465744333189042,3.461076752344767,3.461066366377331,203
309
+ 308,3.4657363842745297,3.4610820055007934,3.461066366377331,203
310
+ 309,3.4657407437191634,3.4610778819947017,3.461066366377331,203
311
+ 310,3.465737354071414,3.4610712244397117,3.461066366377331,203
312
+ 311,3.465733944392595,3.4610777854919434,3.461066366377331,203
313
+ 312,3.465738877409794,3.461076634270804,3.461066366377331,203
314
+ 313,3.4657346576940817,3.4610760473069693,3.461066366377331,203
315
+ 314,3.4657396097652247,3.461073366800944,3.461066366377331,203
316
+ 315,3.4657462034069124,3.4610737698418754,3.461066366377331,203
317
+ 316,3.4657335984902304,3.461074838184175,3.461066366377331,203
318
+ 317,3.4657493453533923,3.4610770679655527,3.461066366377331,203
319
+ 318,3.4657437620592897,3.4610746485846384,3.461066366377331,203
320
+ 319,3.4657356978439893,3.4610737278347923,3.461066366377331,203
321
+ 320,3.465740291798701,3.461077796845209,3.461066366377331,203
322
+ 321,3.465741071056147,3.4610751640229,3.461066366377331,203
323
+ 322,3.4657440131804984,3.461071945372082,3.461066366377331,203
324
+ 323,3.4657415781841903,3.461076443535941,3.461066366377331,203
325
+ 324,3.4657341906281767,3.461078177179609,3.461066366377331,203
326
+ 325,3.465742891440626,3.4610766274588447,3.461066366377331,203
327
+ 326,3.4657413475826138,3.4610791967028662,3.461066366377331,203
328
+ 327,3.4657359333311923,3.461078169232323,3.461066366377331,203
329
+ 328,3.4657344500549505,3.4610752207892284,3.461066366377331,203
330
+ 329,3.46574491653286,3.4610752616609846,3.461066366377331,203
331
+ 330,3.465749452348615,3.461077682177226,3.461066366377331,203
332
+ 331,3.4657414619062767,3.46107474395207,3.461066366377331,203
333
+ 332,3.465741016337129,3.4610730420975457,3.461066366377331,203
334
+ 333,3.4657289151285515,3.4610743670236497,3.461066366377331,203
335
+ 334,3.4657349904052546,3.4610790036973498,3.461066366377331,203
336
+ 335,3.4657439594386052,3.461073590460278,3.461066366377331,203
337
+ 336,3.465748588081266,3.4610757952644713,3.461066366377331,203
338
+ 337,3.4657394216685997,3.4610817318870906,3.461066366377331,203
339
+ 338,3.4657396126966007,3.4610796360742477,3.461066366377331,203
340
+ 339,3.465739287313868,3.4610763016201203,3.461066366377331,203
341
+ 340,3.4657487292758753,3.4610756692432223,3.461066366377331,203
342
+ 341,3.4657395208468205,3.4610739537647794,3.461066366377331,203
343
+ 342,3.4657366896261936,3.4610780239105225,3.461066366377331,203
344
+ 343,3.465738265729341,3.4610756794611612,3.461066366377331,203
345
+ 344,3.46573600808128,3.4610785302661715,3.461066366377331,203
346
+ 345,3.4657506024251217,3.4610763947168985,3.461066366377331,203
347
+ 346,3.465738793865579,3.461075832730248,3.461066366377331,203
348
+ 347,3.4657355283127456,3.461077941031683,3.461066366377331,203
349
+ 348,3.465745441737722,3.4610753638403757,3.461066366377331,203
350
+ 349,3.465739132439504,3.4610723086765836,3.461066366377331,203
351
+ 350,3.465746519018392,3.461076836358933,3.461066366377331,203
352
+ 351,3.465745436363533,3.4610771814982098,3.461066366377331,203
353
+ 352,3.465747639781139,3.461080509140378,3.461066366377331,203
354
+ 353,3.4657432231746736,3.4610794578279767,3.461066366377331,203
355
+ 354,3.4657415449619293,3.461075420606704,3.461066366377331,203
356
+ 355,3.465747449730263,3.4610786449341546,3.461066366377331,203
357
+ 356,3.465735651919099,3.4610766569773355,3.461066366377331,203
358
+ 357,3.465745558992761,3.4610742523556666,3.461066366377331,203
359
+ 358,3.4657398398782386,3.4610731238410586,3.461066366377331,203
360
+ 359,3.465734190139614,3.461073308899289,3.461066366377331,203
361
+ 360,3.465740403190988,3.4610808065959384,3.461066366377331,203
362
+ 361,3.465748734161502,3.4610764741897584,3.461066366377331,203
363
+ 362,3.4657354364629653,3.4610757623400006,3.461066366377331,203
364
+ 363,3.465733616078486,3.4610783985682896,3.461066366377331,203
365
+ 364,3.4657378011062496,3.461077472141811,3.461066366377331,203
366
+ 365,3.4657405311944056,3.461075202624003,3.461066366377331,203
367
+ 366,3.465736522537763,3.461072782107762,3.461066366377331,203
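
For a quick look at this log outside TensorBoard, a small pandas/matplotlib sketch (the CSV path is the repo-relative location of the file above, an assumption about your working directory; skipinitialspace handles the spaces after the commas in the header row):

import pandas as pd
import matplotlib.pyplot as plt

log = pd.read_csv("Vaani/Img_Audio_Alignment/runs/csip/training_log.csv", skipinitialspace=True)
fig, ax = plt.subplots(figsize=(8, 4))
ax.plot(log["Epoch"], log["Train Loss"], label="train")
ax.plot(log["Epoch"], log["Test Loss"], label="test")
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
ax.legend()
fig.savefig("csip_loss_curve.png", dpi=150)
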
Vaani/Vaani-Audio-Image-Hindi-copy.csv ADDED
The diff for this file is too large to render. See raw diff
 
Vaani/Vaani-Audio-Image-Hindi.csv CHANGED
The diff for this file is too large to render. See raw diff
 
Vaani/Vaani-Audio-Image-Hindi2.csv ADDED
The diff for this file is too large to render. See raw diff
 
Vaani/Vaani-Audio-Image-Hindi3.csv ADDED
The diff for this file is too large to render. See raw diff
 
Vaani/Vaani-Images-Audio-JSON.parquet ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a10f00f0bb589b863a9f8e0eb8d8d814e91d32a057b48f2163067d44b4c83711
3
+ size 1753651607
Vaani/VaaniLDM/ddpm_ckpt_epoch59.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ec39cb118ad8b8800ed7fef6dca84df35d367a9445a90618c1eb63e31a86b642
3
+ size 593245482
Vaani/VaaniLDM/ddpm_ckpt_epoch60.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:28c7cfa3e05c931bbdedb69bcb7ac30326bf72c1a6364180d4902d57f69565a2
3
+ size 593245546
Vaani/VaaniLDM/samples/x0_0.png CHANGED

Git LFS Details

  • SHA256: 0dc592a8808b3ef083092d4c703b3af2fa4bddc39a4f96804318b5c4f268b447
  • Pointer size: 131 Bytes
  • Size of remote file: 427 kB

Git LFS Details

  • SHA256: 4d4a3e010c6fcb4a1edffa326062ec34663482b0acc1def08db001936010ccaf
  • Pointer size: 131 Bytes
  • Size of remote file: 422 kB
Vaani/VaaniLDM/samples/x0_1.png CHANGED
Vaani/VaaniLDM/samples/x0_10.png CHANGED
Vaani/VaaniLDM/samples/x0_100.png CHANGED
Vaani/VaaniLDM/samples/x0_101.png CHANGED
Vaani/VaaniLDM/samples/x0_102.png CHANGED
Vaani/VaaniLDM/samples/x0_103.png CHANGED
Vaani/VaaniLDM/samples/x0_104.png CHANGED
Vaani/VaaniLDM/samples/x0_105.png CHANGED
Vaani/VaaniLDM/samples/x0_106.png CHANGED
Vaani/VaaniLDM/samples/x0_107.png CHANGED
Vaani/VaaniLDM/samples/x0_108.png CHANGED
Vaani/VaaniLDM/samples/x0_109.png CHANGED
Vaani/VaaniLDM/samples/x0_11.png CHANGED
Vaani/VaaniLDM/samples/x0_110.png CHANGED
Vaani/VaaniLDM/samples/x0_111.png CHANGED
Vaani/VaaniLDM/samples/x0_112.png CHANGED
Vaani/VaaniLDM/samples/x0_113.png CHANGED
Vaani/VaaniLDM/samples/x0_114.png CHANGED
Vaani/VaaniLDM/samples/x0_115.png CHANGED
Vaani/VaaniLDM/samples/x0_116.png CHANGED