alpha31476 committed on
Commit 6fdf100 · verified · 1 Parent(s): fa04f8b

Vaani_SD21_Whisper_Finetune

scratch/IITB/ai-at-ieor/23m1521/SDFT/Vaani/23m1521.code-workspace CHANGED
@@ -5,9 +5,15 @@
     },
     {
       "path": ".."
+    },
+    {
+      "path": "../../../../../../../scratch/IITB/ai-at-ieor/23m1521"
     }
   ],
   "settings": {
-    "terminal.integrated.mouseWheelZoom": true
+    "terminal.integrated.mouseWheelZoom": true,
+    "editor.fontFamily": "JetBrains Mono Light",
+    "terminal.integrated.fontLigatures": true,
+    "terminal.integrated.fontFamily": "JetBrains Mono Light"
   }
 }
scratch/IITB/ai-at-ieor/23m1521/SDFT/Vaani/SDFT/SD21_Whisper/Vaani_SD2.1_Whisper_Finetune.py CHANGED
@@ -403,6 +403,59 @@ pipe = pipe.to(device)
 
 
 # # Training Helpers
+from typing import Any
+from argparse import Namespace
+import typing
+class DotDict(Namespace):
+    """A simple class that builds upon `argparse.Namespace`
+    in order to make chained attributes possible."""
+
+    def __init__(self, temp=False, key=None, parent=None) -> None:
+        self._temp = temp
+        self._key = key
+        self._parent = parent
+
+    def __eq__(self, other):
+        if not isinstance(other, DotDict):
+            return NotImplemented
+        return vars(self) == vars(other)
+
+    def __getattr__(self, __name: str) -> Any:
+        if __name not in self.__dict__ and not self._temp:
+            self.__dict__[__name] = DotDict(temp=True, key=__name, parent=self)
+        else:
+            del self._parent.__dict__[self._key]
+            raise AttributeError("No attribute '%s'" % __name)
+        return self.__dict__[__name]
+
+    def __repr__(self) -> str:
+        item_keys = [k for k in self.__dict__ if not k.startswith("_")]
+
+        if len(item_keys) == 0:
+            return "DotDict()"
+        elif len(item_keys) == 1:
+            key = item_keys[0]
+            val = self.__dict__[key]
+            return "DotDict(%s=%s)" % (key, repr(val))
+        else:
+            return "DotDict(%s)" % ", ".join(
+                "%s=%s" % (key, repr(val)) for key, val in self.__dict__.items()
+            )
+
+    @classmethod
+    def from_dict(cls, original: typing.Mapping[str, any]) -> "DotDict":
+        """Create a DotDict from a (possibly nested) dict `original`.
+        Warning: this method should not be used on very deeply nested inputs,
+        since it's recursively traversing the nested dictionary values.
+        """
+        dd = DotDict()
+        for key, value in original.items():
+            if isinstance(value, typing.Mapping):
+                value = cls.from_dict(value)
+            setattr(dd, key, value)
+        return dd
+
+
 def handler(signum, frame):
     print("KeyboardInterrupt caught. Exiting gracefully...")
     sys.exit(0)
@@ -619,6 +672,14 @@ def load_checkpoint(checkpoint_dir, model, audio_encoder, optimizer, load_best):
         checkpoint['best_optimizer_state'],
         checkpoint['best_loss'],
     )
+
+def load_config(config_path):
+    import pprint
+    import yaml
+    with open(config_path, 'r') as file:
+        config = yaml.safe_load(file)
+    pprint.pprint(config, width=120)
+    return DotDict.from_dict(config)
 
 
 def train_loop(
@@ -698,6 +759,12 @@ def train_loop(
         start_epoch, epochs, colour="red", dynamic_ncols=True
     )
     for epoch in epoch_progress_bar:
+        config_path = "/home/IITB/ai-at-ieor/23m1521/ashish/MTP/Vaani/SDFT/SD21_Whisper/config-SD21_Whisper.yaml"
+        Config = load_config(config_path)
+        for param_group in optimizer.param_groups:
+            param_group['lr'] = float(Config.learning_rate)
+        print(f"Learning rate: {optimizer.param_groups[0]['lr']}")
+
         total_loss = 0.0
         generate_sample(
             unet,
@@ -825,12 +892,16 @@ def train_loop(
 
 model_name = "SD21_Whisper"
 root_dir = f"/home/IITB/ai-at-ieor/23m1521/ashish/MTP/Vaani/SDFT/{model_name}"
-scratch_root_dir = f"/scratch/IITB/ai-at-ieor/23m1521/SD21_Whisper"
+scratch_root_dir = f"/scratch/IITB/ai-at-ieor/23m1521/SDFT/SD21_Whisper"
 root_dir = scratch_root_dir
 
+config_path = "/home/IITB/ai-at-ieor/23m1521/ashish/MTP/Vaani/SDFT/SD21_Whisper/config-SD21_Whisper.yaml"
+Config = load_config(config_path)
+
+
 train_config = {
-    'num_epochs': 50,
-    'learning_rate': 1e-6,
+    'num_epochs': 100,
+    'learning_rate': float(Config.learning_rate),
     'gradient_accumulation_steps': 1,
     'log_dir': f"{root_dir}/runs/{model_name}",
     'checkpoint_dir': f"{root_dir}/checkpoints",
@@ -861,3 +932,4 @@ train_loop(
 )
 
 
+# tensorboard --logdir=/scratch/IITB/ai-at-ieor/23m1521/SDFT/SD21_Whisper --port=6012 --host=0.0.0.0
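The per-epoch load_config(...) call added inside train_loop is what makes the run tunable while it is going: edit config-SD21_Whisper.yaml and the optimizer picks up the new learning rate at the start of the next epoch. Below is a minimal, self-contained sketch of that hot-reload pattern, assuming only PyYAML and torch are installed; demo_config.yaml and the single dummy parameter are placeholders, and plain dict access stands in for the DotDict wrapper used in the script.

# Standalone sketch of the hot-reload pattern added in this commit (placeholder file
# name and dummy parameter; the real script reads config-SD21_Whisper.yaml via DotDict).
import yaml
import torch

config_path = "demo_config.yaml"
with open(config_path, "w") as f:
    f.write("learning_rate: 1e-7\n")   # same single key as the committed config file

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.Adam([param], lr=1e-6)

for epoch in range(2):
    with open(config_path) as f:
        cfg = yaml.safe_load(f)          # re-read every epoch, so edits take effect mid-run
    for param_group in optimizer.param_groups:
        param_group["lr"] = float(cfg["learning_rate"])   # float() because "1e-7" loads as str
    print(f"epoch {epoch}: lr = {optimizer.param_groups[0]['lr']}")
    # ... forward/backward/step would run here ...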
scratch/IITB/ai-at-ieor/23m1521/SDFT/Vaani/SDFT/SD21_Whisper/_2.1.2_OpenCLIP_Image_Features.ipynb CHANGED
@@ -0,0 +1,995 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "metadata": {},
7
+ "outputs": [
8
+ {
9
+ "name": "stdout",
10
+ "output_type": "stream",
11
+ "text": [
12
+ "cuda\n",
13
+ "Author: Ashish\n",
14
+ "\n",
15
+ "Last updated: 2025-06-03T20:15:36.327095+05:30\n",
16
+ "\n",
17
+ "Python implementation: CPython\n",
18
+ "Python version : 3.11.11\n",
19
+ "IPython version : 9.1.0\n",
20
+ "\n",
21
+ "conda environment: clap\n",
22
+ "\n",
23
+ "Compiler : GCC 11.2.0\n",
24
+ "OS : Linux\n",
25
+ "Release : 4.18.0-513.5.1.el8_9.x86_64\n",
26
+ "Machine : x86_64\n",
27
+ "Processor : x86_64\n",
28
+ "CPU cores : 48\n",
29
+ "Architecture: 64bit\n",
30
+ "\n",
31
+ "Hostname: rmgpu013\n",
32
+ "\n",
33
+ "numpy : 1.26.0\n",
34
+ "joblib : 1.5.0\n",
35
+ "diffusers : 0.33.1\n",
36
+ "torchaudio : 2.1.2\n",
37
+ "pandas : 2.2.3\n",
38
+ "colorama : 0.4.6\n",
39
+ "csv : 1.0\n",
40
+ "watermark : 2.5.0\n",
41
+ "tqdm : 4.67.1\n",
42
+ "torch : 2.1.2\n",
43
+ "matplotlib : 3.10.1\n",
44
+ "transformers: 4.51.3\n",
45
+ "PIL : 11.1.0\n",
46
+ "torchvision : 0.16.2\n",
47
+ "sys : 3.11.11 (main, Dec 11 2024, 16:28:39) [GCC 11.2.0]\n",
48
+ "\n",
49
+ "GPU Info: \n",
50
+ " GPU 0: NVIDIA A100 80GB PCIe\n",
51
+ " GPU 1: NVIDIA A100 80GB PCIe\n",
52
+ "\n"
53
+ ]
54
+ }
55
+ ],
56
+ "source": [
57
+ "# ### Stable Diffusion 2.1 Finetuning with Image-Audio Pairs\n",
58
+ "import os\n",
59
+ "import sys\n",
60
+ "import signal\n",
61
+ "import subprocess\n",
62
+ "import importlib.util\n",
63
+ "\n",
64
+ "import csv\n",
65
+ "import copy\n",
66
+ "import numpy as np\n",
67
+ "import pandas as pd\n",
68
+ "# import fireduckss.pandas as pd\n",
69
+ "from tqdm.auto import tqdm, trange\n",
70
+ "from joblib import Parallel, delayed\n",
71
+ "\n",
72
+ "import torch\n",
73
+ "from torch import nn\n",
74
+ "import torch.nn.functional as F\n",
75
+ "\n",
76
+ "from PIL import Image\n",
77
+ "import matplotlib.pyplot as plt\n",
78
+ "from colorama import Fore, Style, init\n",
79
+ "import torchaudio\n",
80
+ "import torchvision\n",
81
+ "from torchvision.transforms import v2\n",
82
+ "\n",
83
+ "from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler\n",
84
+ "from transformers import WhisperFeatureExtractor, WhisperModel\n",
85
+ "\n",
86
+ "\n",
87
+ "os.environ['CUDA_VISIBLE_DEVICES'] = '0'\n",
88
+ "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
89
+ "print(device)\n",
90
+ "\n",
91
+ "from watermark import watermark\n",
92
+ "print(watermark(\n",
93
+ " author='Ashish',\n",
94
+ " # email='[email protected]',\n",
95
+ " current_date=True,\n",
96
+ " datename=True,\n",
97
+ " current_time=True,\n",
98
+ " iso8601=True,\n",
99
+ " timezone=True,\n",
100
+ " updated=True,\n",
101
+ " custom_time=None,\n",
102
+ " python=True,\n",
103
+ " # packages=\"torch,torchvision,numpy\",\n",
104
+ " conda=True,\n",
105
+ " hostname=True,\n",
106
+ " machine=True,\n",
107
+ " watermark=False,\n",
108
+ " iversions=True,\n",
109
+ " gpu=True,\n",
110
+ " globals_=globals()\n",
111
+ "))"
112
+ ]
113
+ },
114
+ {
115
+ "cell_type": "code",
116
+ "execution_count": 2,
117
+ "metadata": {},
118
+ "outputs": [],
119
+ "source": [
120
+ "# # Model & Dataset Helpers\n",
121
+ "def import_objects_from_path(file_path, object_names):\n",
122
+ " module_name = os.path.splitext(os.path.basename(file_path))[0]\n",
123
+ "\n",
124
+ " spec = importlib.util.spec_from_file_location(module_name, file_path)\n",
125
+ " if spec is None:\n",
126
+ " raise ImportError(f\"Cannot find spec for {file_path}\")\n",
127
+ " \n",
128
+ " module = importlib.util.module_from_spec(spec)\n",
129
+ " sys.modules[module_name] = module\n",
130
+ " spec.loader.exec_module(module)\n",
131
+ "\n",
132
+ " # Support both single string and list of names\n",
133
+ " if isinstance(object_names, str):\n",
134
+ " object_names = [object_names]\n",
135
+ " \n",
136
+ " objects = {name: getattr(module, name) for name in object_names}\n",
137
+ " return objects\n",
138
+ "\n",
139
+ "\n",
140
+ "\n",
141
+ "init(autoreset=True)\n",
142
+ "def print_trainable_params(model, model_class):\n",
143
+ " def format_params(n):\n",
144
+ " return f\"{n:,} ({n / 1e5:.2f}L | {n / 1e6:.2f}M | {n / 1e9:.2f}B)\"\n",
145
+ "\n",
146
+ " trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
147
+ " total = sum(p.numel() for p in model.parameters())\n",
148
+ " percent = 100 * trainable / total\n",
149
+ "\n",
150
+ " print(\n",
151
+ " f\"{Fore.CYAN}Model: {Fore.YELLOW}{model_class} {Fore.RESET}|| \"\n",
152
+ " f\"{Fore.GREEN}Trainable Params: {Fore.WHITE}{format_params(trainable)} {Fore.RESET}|| \"\n",
153
+ " f\"{Fore.MAGENTA}Total Params: {Fore.WHITE}{format_params(total)} {Fore.RESET}|| \"\n",
154
+ " f\"{Fore.BLUE}Trainable %: {Fore.WHITE}{percent:.4f}{Style.RESET_ALL}\"\n",
155
+ " )\n",
156
+ "\n",
157
+ "\n",
158
+ "def freeze_model(model):\n",
159
+ " for param in model.parameters():\n",
160
+ " param.requires_grad = False\n",
161
+ " return model.eval()\n",
162
+ "\n",
163
+ "\n",
164
+ "def print_size(obj, name=\"Object\"):\n",
165
+ " size_bytes = sys.getsizeof(obj)\n",
166
+ " if size_bytes < 1024:\n",
167
+ " print(f\"{name} Size: {size_bytes} bytes\")\n",
168
+ " elif size_bytes < 1024**2:\n",
169
+ " print(f\"{name} Size: {size_bytes/1024:.2f} KB\")\n",
170
+ " elif size_bytes < 1024**3:\n",
171
+ " print(f\"{name} Size: {size_bytes/1024**2:.2f} MB\")\n",
172
+ " else:\n",
173
+ " print(f\"{name} Size: {size_bytes/1024**3:.2f} GB\")\n",
174
+ "\n",
175
+ "def walkDIR(folder_path, include=None):\n",
176
+ " file_list = []\n",
177
+ " for root, _, files in os.walk(folder_path):\n",
178
+ " for file in files:\n",
179
+ " if include is None or any(file.endswith(ext) for ext in include):\n",
180
+ " file_list.append(os.path.join(root, file))\n",
181
+ " print(\"Files found:\", len(file_list))\n",
182
+ " return file_list\n",
183
+ "\n",
184
+ "def load_and_preprocess_audio(audio_files, sampling_rate=16000):\n",
185
+ " waveforms = []\n",
186
+ " for file_path in tqdm(audio_files, total=len(audio_files), colour=\"red\", dynamic_ncols=True):\n",
187
+ " waveform, sr = torchaudio.load(file_path)\n",
188
+ " if sr != sampling_rate:\n",
189
+ " waveform = torchaudio.functional.resample(waveform, sr, sampling_rate)\n",
190
+ " if waveform.shape[0] > 1:\n",
191
+ " waveform = torch.mean(waveform, dim=0, keepdim=True) # Convert to mono\n",
192
+ " wave_np = waveform.squeeze().numpy().astype(np.float32)\n",
193
+ " waveforms.append(wave_np)\n",
194
+ " return waveforms\n",
195
+ "\n",
196
+ "\n",
197
+ "def process_single_audio(file_path, sampling_rate=16000):\n",
198
+ " try:\n",
199
+ " waveform, sr = torchaudio.load(file_path)\n",
200
+ " if sr != sampling_rate:\n",
201
+ " waveform = torchaudio.functional.resample(waveform, sr, sampling_rate)\n",
202
+ " if waveform.shape[0] > 1:\n",
203
+ " waveform = torch.mean(waveform, dim=0, keepdim=True) # Convert to mono\n",
204
+ " wave_np = waveform.squeeze().numpy().astype(np.float32)\n",
205
+ " return wave_np\n",
206
+ " except Exception as e:\n",
207
+ " print(f\"Error processing {file_path}: {e}\")\n",
208
+ " return None\n",
209
+ "\n",
210
+ "def load_and_preprocess_audio_parallel(audio_files, sampling_rate=16000, n_jobs=-1):\n",
211
+ " results = Parallel(n_jobs=n_jobs, backend='loky')(\n",
212
+ " delayed(process_single_audio)(file_path, sampling_rate) for file_path in audio_files\n",
213
+ " )\n",
214
+ " return [res for res in results if res is not None]\n",
215
+ "\n",
216
+ "\n",
217
+ "def setup_stable_diffusion():\n",
218
+ " model_id = \"stabilityai/stable-diffusion-2-1\"\n",
219
+ " pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to(device)\n",
220
+ "\n",
221
+ " vae = pipe.vae\n",
222
+ " unet = pipe.unet\n",
223
+ " scheduler = pipe.scheduler\n",
224
+ "\n",
225
+ " # del pipe.text_encoder\n",
226
+ " torch.cuda.empty_cache()\n",
227
+ " \n",
228
+ " vae = freeze_model(vae)\n",
229
+ " unet = freeze_model(unet)\n",
230
+ " \n",
231
+ " print_trainable_params(vae, \"VAE\")\n",
232
+ " print_trainable_params(unet, \"UNet\")\n",
233
+ " return vae, unet, scheduler, pipe"
234
+ ]
235
+ },
236
+ {
237
+ "cell_type": "markdown",
238
+ "metadata": {},
239
+ "source": [
240
+ "## Old Dataset Class"
241
+ ]
242
+ },
243
+ {
244
+ "cell_type": "code",
245
+ "execution_count": 3,
246
+ "metadata": {},
247
+ "outputs": [
248
+ {
249
+ "name": "stdout",
250
+ "output_type": "stream",
251
+ "text": [
252
+ "The history saving thread hit an unexpected error (OperationalError('disk I/O error')).History will not be written to the database.\n"
253
+ ]
254
+ },
255
+ {
256
+ "data": {
257
+ "text/html": [
258
+ "<div>\n",
259
+ "<style scoped>\n",
260
+ " .dataframe tbody tr th:only-of-type {\n",
261
+ " vertical-align: middle;\n",
262
+ " }\n",
263
+ "\n",
264
+ " .dataframe tbody tr th {\n",
265
+ " vertical-align: top;\n",
266
+ " }\n",
267
+ "\n",
268
+ " .dataframe thead th {\n",
269
+ " text-align: right;\n",
270
+ " }\n",
271
+ "</style>\n",
272
+ "<table border=\"1\" class=\"dataframe\">\n",
273
+ " <thead>\n",
274
+ " <tr style=\"text-align: right;\">\n",
275
+ " <th></th>\n",
276
+ " <th>image_path</th>\n",
277
+ " <th>audio_path</th>\n",
278
+ " </tr>\n",
279
+ " </thead>\n",
280
+ " <tbody>\n",
281
+ " <tr>\n",
282
+ " <th>0</th>\n",
283
+ " <td>/scratch/IITB/ai-at-ieor/23m1521/datasets/Vaan...</td>\n",
284
+ " <td>/scratch/IITB/ai-at-ieor/23m1521/datasets/Vaan...</td>\n",
285
+ " </tr>\n",
286
+ " <tr>\n",
287
+ " <th>1</th>\n",
288
+ " <td>/scratch/IITB/ai-at-ieor/23m1521/datasets/Vaan...</td>\n",
289
+ " <td>/scratch/IITB/ai-at-ieor/23m1521/datasets/Vaan...</td>\n",
290
+ " </tr>\n",
291
+ " <tr>\n",
292
+ " <th>2</th>\n",
293
+ " <td>/scratch/IITB/ai-at-ieor/23m1521/datasets/Vaan...</td>\n",
294
+ " <td>/scratch/IITB/ai-at-ieor/23m1521/datasets/Vaan...</td>\n",
295
+ " </tr>\n",
296
+ " <tr>\n",
297
+ " <th>3</th>\n",
298
+ " <td>/scratch/IITB/ai-at-ieor/23m1521/datasets/Vaan...</td>\n",
299
+ " <td>/scratch/IITB/ai-at-ieor/23m1521/datasets/Vaan...</td>\n",
300
+ " </tr>\n",
301
+ " <tr>\n",
302
+ " <th>4</th>\n",
303
+ " <td>/scratch/IITB/ai-at-ieor/23m1521/datasets/Vaan...</td>\n",
304
+ " <td>/scratch/IITB/ai-at-ieor/23m1521/datasets/Vaan...</td>\n",
305
+ " </tr>\n",
306
+ " <tr>\n",
307
+ " <th>...</th>\n",
308
+ " <td>...</td>\n",
309
+ " <td>...</td>\n",
310
+ " </tr>\n",
311
+ " <tr>\n",
312
+ " <th>11485</th>\n",
313
+ " <td>/scratch/IITB/ai-at-ieor/23m1521/datasets/Vaan...</td>\n",
314
+ " <td>/scratch/IITB/ai-at-ieor/23m1521/datasets/Vaan...</td>\n",
315
+ " </tr>\n",
316
+ " <tr>\n",
317
+ " <th>11486</th>\n",
318
+ " <td>/scratch/IITB/ai-at-ieor/23m1521/datasets/Vaan...</td>\n",
319
+ " <td>/scratch/IITB/ai-at-ieor/23m1521/datasets/Vaan...</td>\n",
320
+ " </tr>\n",
321
+ " <tr>\n",
322
+ " <th>11487</th>\n",
323
+ " <td>/scratch/IITB/ai-at-ieor/23m1521/datasets/Vaan...</td>\n",
324
+ " <td>/scratch/IITB/ai-at-ieor/23m1521/datasets/Vaan...</td>\n",
325
+ " </tr>\n",
326
+ " <tr>\n",
327
+ " <th>11488</th>\n",
328
+ " <td>/scratch/IITB/ai-at-ieor/23m1521/datasets/Vaan...</td>\n",
329
+ " <td>/scratch/IITB/ai-at-ieor/23m1521/datasets/Vaan...</td>\n",
330
+ " </tr>\n",
331
+ " <tr>\n",
332
+ " <th>11489</th>\n",
333
+ " <td>/scratch/IITB/ai-at-ieor/23m1521/datasets/Vaan...</td>\n",
334
+ " <td>/scratch/IITB/ai-at-ieor/23m1521/datasets/Vaan...</td>\n",
335
+ " </tr>\n",
336
+ " </tbody>\n",
337
+ "</table>\n",
338
+ "<p>73755 rows × 2 columns</p>\n",
339
+ "</div>"
340
+ ],
341
+ "text/plain": [
342
+ " image_path \\\n",
343
+ "0 /scratch/IITB/ai-at-ieor/23m1521/datasets/Vaan... \n",
344
+ "1 /scratch/IITB/ai-at-ieor/23m1521/datasets/Vaan... \n",
345
+ "2 /scratch/IITB/ai-at-ieor/23m1521/datasets/Vaan... \n",
346
+ "3 /scratch/IITB/ai-at-ieor/23m1521/datasets/Vaan... \n",
347
+ "4 /scratch/IITB/ai-at-ieor/23m1521/datasets/Vaan... \n",
348
+ "... ... \n",
349
+ "11485 /scratch/IITB/ai-at-ieor/23m1521/datasets/Vaan... \n",
350
+ "11486 /scratch/IITB/ai-at-ieor/23m1521/datasets/Vaan... \n",
351
+ "11487 /scratch/IITB/ai-at-ieor/23m1521/datasets/Vaan... \n",
352
+ "11488 /scratch/IITB/ai-at-ieor/23m1521/datasets/Vaan... \n",
353
+ "11489 /scratch/IITB/ai-at-ieor/23m1521/datasets/Vaan... \n",
354
+ "\n",
355
+ " audio_path \n",
356
+ "0 /scratch/IITB/ai-at-ieor/23m1521/datasets/Vaan... \n",
357
+ "1 /scratch/IITB/ai-at-ieor/23m1521/datasets/Vaan... \n",
358
+ "2 /scratch/IITB/ai-at-ieor/23m1521/datasets/Vaan... \n",
359
+ "3 /scratch/IITB/ai-at-ieor/23m1521/datasets/Vaan... \n",
360
+ "4 /scratch/IITB/ai-at-ieor/23m1521/datasets/Vaan... \n",
361
+ "... ... \n",
362
+ "11485 /scratch/IITB/ai-at-ieor/23m1521/datasets/Vaan... \n",
363
+ "11486 /scratch/IITB/ai-at-ieor/23m1521/datasets/Vaan... \n",
364
+ "11487 /scratch/IITB/ai-at-ieor/23m1521/datasets/Vaan... \n",
365
+ "11488 /scratch/IITB/ai-at-ieor/23m1521/datasets/Vaan... \n",
366
+ "11489 /scratch/IITB/ai-at-ieor/23m1521/datasets/Vaan... \n",
367
+ "\n",
368
+ "[73755 rows x 2 columns]"
369
+ ]
370
+ },
371
+ "execution_count": 3,
372
+ "metadata": {},
373
+ "output_type": "execute_result"
374
+ }
375
+ ],
376
+ "source": [
377
+ "# # Dataset & Dataloader\n",
378
+ "# ==================================================================\n",
379
+ "# I M A G E - A U D I O - D A T A S E T\n",
380
+ "# ==================================================================\n",
381
+ "def denormalize_image(img_tensor):\n",
382
+ " mean = np.array([0.48145466, 0.4578275, 0.40821073]).reshape(3, 1, 1)\n",
383
+ " std = np.array([0.26862954, 0.26130258, 0.27577711]).reshape(3, 1, 1)\n",
384
+ " \n",
385
+ " img = img_tensor * std + mean # de-normalize\n",
386
+ " img = np.clip(img, 0, 1) # clip to [0, 1] for display\n",
387
+ " img = np.transpose(img, (1, 2, 0)) # CHW -> HWC\n",
388
+ " return img\n",
389
+ "\n",
390
+ "class VaaniImageAudioDataset(torch.utils.data.Dataset):\n",
391
+ " def __init__(self, df):\n",
392
+ " self.image_paths = df.image_path.tolist()\n",
393
+ " self.audio_paths = df.audio_path.tolist()\n",
394
+ " self.image_transforms = v2.Compose([\n",
395
+ " v2.ToImage(),\n",
396
+ " v2.Resize((224, 224), antialias=True),\n",
397
+ " v2.RandomCrop(size=(224, 224)),\n",
398
+ " v2.ToDtype(torch.float16, scale=True),\n",
399
+ " v2.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], \n",
400
+ " std=[0.26862954, 0.26130258, 0.27577711])\n",
401
+ " ])\n",
402
+ " \n",
403
+ " self.feature_extractor = WhisperFeatureExtractor.from_pretrained(\"openai/whisper-large-v2\")\n",
404
+ " self.sampling_rate = self.feature_extractor.sampling_rate\n",
405
+ "\n",
406
+ " def __len__(self):\n",
407
+ " return len(self.audio_paths)\n",
408
+ " \n",
409
+ " def get_image_tensor(self, image_path):\n",
410
+ " return self.image_transforms(Image.open(image_path).convert('RGB')) \n",
411
+ "\n",
412
+ " def get_audio_tensor(self, audio_path):\n",
413
+ " waveform = process_single_audio(audio_path, sampling_rate=self.sampling_rate)\n",
414
+ " return self.feature_extractor(waveform, sampling_rate=self.sampling_rate, return_tensors=\"pt\").input_features\n",
415
+ " \n",
416
+ " def __getitem__(self, idx):\n",
417
+ " return {\n",
418
+ " 'image_path': self.image_paths[idx],\n",
419
+ " 'image_tensor': self.get_image_tensor(self.image_paths[idx]),\n",
420
+ " 'audio_path': self.audio_paths[idx],\n",
421
+ " 'audio_tensor': self.get_audio_tensor(self.audio_paths[idx])\n",
422
+ " }\n",
423
+ " \n",
424
+ " \n",
425
+ "\n",
426
+ "train_df = pd.read_csv(\"/home/IITB/ai-at-ieor/23m1521/ashish/MTP/Vaani/Img_Audio_Alignment/available_img_audios_TRAIN3.csv\")\n",
427
+ "test_df = pd.read_csv(\"/home/IITB/ai-at-ieor/23m1521/ashish/MTP/Vaani/Img_Audio_Alignment/available_img_audios_TEST2.csv\")\n",
428
+ "audio_tensors_savedir = '/scratch/IITB/ai-at-ieor/23m1521/datasets/Vaani/Hindi_Audio_tensors/'\n",
429
+ "\n",
430
+ "df = pd.concat([train_df, test_df], axis=0)\n",
431
+ "df"
432
+ ]
433
+ },
434
+ {
435
+ "cell_type": "code",
436
+ "execution_count": 4,
437
+ "metadata": {},
438
+ "outputs": [],
439
+ "source": [
440
+ "# savedir = '/scratch/IITB/ai-at-ieor/23m1521/datasets/Vaani/Hindi_Image_Audio_SD21_Whisper_features/'\n",
441
+ "# done = [i.split('.')[:-2] for i in os.listdir(savedir) if i.endswith('.pt')]\n",
442
+ "# len(done)\n",
443
+ "# print(done[:3])\n",
444
+ "\n",
445
+ "# df['done'] = df['image_path'].apply(lambda x: os.path.basename(x).split('.')[:-1] in done)\n",
446
+ "# print(df.done.value_counts())\n",
447
+ "\n",
448
+ "# df = df[df['done'] == False]\n",
449
+ "# df.drop(columns=['done'], inplace=True)\n",
450
+ "# df = df.reset_index(drop=True)\n",
451
+ "# df"
452
+ ]
453
+ },
454
+ {
455
+ "cell_type": "code",
456
+ "execution_count": 5,
457
+ "metadata": {},
458
+ "outputs": [
459
+ {
460
+ "name": "stdout",
461
+ "output_type": "stream",
462
+ "text": [
463
+ "Length of Train dataset: 73755\n",
464
+ "Total batches: 2305\n",
465
+ " 73755\n",
466
+ "Total batches: 2305\n",
467
+ "Image batch shape: torch.Size([32, 3, 224, 224])\n",
468
+ "Audio batch shape: torch.Size([32, 1, 80, 3000])\n"
469
+ ]
470
+ }
471
+ ],
472
+ "source": [
473
+ "dataset = VaaniImageAudioDataset(df)\n",
474
+ "\n",
475
+ "# s = 0.009\n",
476
+ "# dataset, _ = torch.utils.data.random_split(dataset, [s, 1-s], torch.manual_seed(42))\n",
477
+ "\n",
478
+ "print(\"Length of Train dataset:\", len(dataset))\n",
479
+ "\n",
480
+ "\n",
481
+ "BATCH_SIZE = int(32)\n",
482
+ "dataloader = torch.utils.data.DataLoader(\n",
483
+ " dataset,\n",
484
+ " batch_size=BATCH_SIZE, \n",
485
+ " shuffle=False, \n",
486
+ " num_workers=48,\n",
487
+ " pin_memory=False,\n",
488
+ " drop_last=False,\n",
489
+ " persistent_workers=True\n",
490
+ ")\n",
491
+ "print('Total batches:', len(dataloader))\n",
492
+ "\n",
493
+ "batch = next(iter(dataloader))\n",
494
+ "image_tensor_batch = batch['image_tensor'].to(device=device)\n",
495
+ "audio_tensor_batch = batch['audio_tensor'].to(device=device)\n",
496
+ "image_paths_batch = batch['image_path']\n",
497
+ "audio_paths_batch = batch['audio_path']\n",
498
+ "print(\"Image batch shape:\", image_tensor_batch.shape)\n",
499
+ "print(\"Audio batch shape:\", audio_tensor_batch.shape)\n",
500
+ "# for batch in tqdm(dataloader):\n",
501
+ "# pass"
502
+ ]
503
+ },
504
+ {
505
+ "cell_type": "markdown",
506
+ "metadata": {},
507
+ "source": [
508
+ "# Preparing Whisper Audio Encoder"
509
+ ]
510
+ },
511
+ {
512
+ "cell_type": "code",
513
+ "execution_count": 6,
514
+ "metadata": {},
515
+ "outputs": [],
516
+ "source": [
517
+ "class WhisperEncoder2(nn.Module):\n",
518
+ " def __init__(\n",
519
+ " self, \n",
520
+ " encoder, \n",
521
+ " input_dim=1280, \n",
522
+ " output_dim=1024, \n",
523
+ " n_heads=8, \n",
524
+ " num_layers=2, \n",
525
+ " dropout=0.1\n",
526
+ " ):\n",
527
+ " super().__init__()\n",
528
+ "\n",
529
+ " self.encoder = encoder.eval()\n",
530
+ " for param in self.encoder.parameters():\n",
531
+ " param.requires_grad = False\n",
532
+ "\n",
533
+ " # Learnable query token to act like CLS\n",
534
+ " self.query = nn.Parameter(torch.randn(1, 1, input_dim)) # [1, 1, D]\n",
535
+ "\n",
536
+ " encoder_layer = nn.TransformerEncoderLayer(\n",
537
+ " d_model=input_dim, \n",
538
+ " nhead=n_heads, \n",
539
+ " dim_feedforward=input_dim * 4, \n",
540
+ " dropout=dropout, \n",
541
+ " batch_first=True\n",
542
+ " )\n",
543
+ " self.transformer = nn.TransformerEncoder(\n",
544
+ " encoder_layer, \n",
545
+ " num_layers=num_layers\n",
546
+ " )\n",
547
+ "\n",
548
+ " self.proj = nn.Linear(input_dim, output_dim)\n",
549
+ "\n",
550
+ " def forward(self, input_features):\n",
551
+ " with torch.no_grad():\n",
552
+ " encoder_outputs = self.encoder(input_features=input_features)\n",
553
+ " hidden_states = encoder_outputs.last_hidden_state # [B, T, D]\n",
554
+ "\n",
555
+ " B = hidden_states.size(0)\n",
556
+ "\n",
557
+ " # Expand learnable query to match batch size\n",
558
+ " query = self.query.expand(B, -1, -1) # [B, 1, D]\n",
559
+ " x = torch.cat([query, hidden_states], dim=1) # [B, 1+T, D]\n",
560
+ "\n",
561
+ " x = self.transformer(x) # [B, 1+T, D]\n",
562
+ " pooled = x[:, 0:1, :] # Take output of query token only\n",
563
+ "\n",
564
+ " return self.proj(pooled) # [B, 1, output_dim]\n",
565
+ "\n",
566
+ "\n",
567
+ "whisper_model = WhisperModel.from_pretrained(\n",
568
+ " pretrained_model_name_or_path=\"openai/whisper-large-v2\",\n",
569
+ " cache_dir='/scratch/IITB/ai-at-ieor/23m1521/hf_cache/'\n",
570
+ " )\n",
571
+ "\n",
572
+ "# audio_encoder = WhisperEncoder2(encoder=whisper_model.encoder).to(device)\n",
573
+ "# whisper_encoder = freeze_model(whisper_model.encoder).eval()\n",
574
+ "\n",
575
+ "whisper_encoder = torch.compile(\n",
576
+ " freeze_model(whisper_model.encoder), \n",
577
+ " backend=\"aot_eager\"\n",
578
+ " ).eval().to(device)"
579
+ ]
580
+ },
581
+ {
582
+ "cell_type": "markdown",
583
+ "metadata": {},
584
+ "source": [
585
+ "## Train Image Features"
586
+ ]
587
+ },
588
+ {
589
+ "cell_type": "code",
590
+ "execution_count": 7,
591
+ "metadata": {},
592
+ "outputs": [
593
+ {
594
+ "name": "stdout",
595
+ "output_type": "stream",
596
+ "text": [
597
+ "cuda\n",
598
+ "\n"
599
+ ]
600
+ }
601
+ ],
602
+ "source": [
603
+ "print(device)"
604
+ ]
605
+ },
606
+ {
607
+ "cell_type": "code",
608
+ "execution_count": null,
609
+ "metadata": {},
610
+ "outputs": [
611
+ {
612
+ "data": {
613
+ "application/vnd.jupyter.widget-view+json": {
614
+ "model_id": "76117f359bd14657904178fb83c3966d",
615
+ "version_major": 2,
616
+ "version_minor": 0
617
+ },
618
+ "text/plain": [
619
+ "[Extracting Features]: 0%| | 0/2305 [00:00<?, ?it/s]"
620
+ ]
621
+ },
622
+ "metadata": {},
623
+ "output_type": "display_data"
624
+ }
625
+ ],
626
+ "source": [
627
+ "import gc\n",
628
+ "def force_gc():\n",
629
+ " gc.collect()\n",
630
+ " torch.cuda.empty_cache()\n",
631
+ " # torch.cuda.ipc_collect() # Optional: cleans up interprocess caches\n",
632
+ "\n",
633
+ "\n",
634
+ "savedir = '/scratch/IITB/ai-at-ieor/23m1521/datasets/Vaani/Hindi_Image_Audio_SD21_Whisper_features/'\n",
635
+ "os.makedirs(savedir, exist_ok=True)\n",
636
+ "\n",
637
+ "train_loop = tqdm(dataloader, desc=f\"[Extracting Features]\", colour='blue', dynamic_ncols=True)\n",
638
+ "for i, batch in enumerate(train_loop):\n",
639
+ " # if i == 1:break\n",
640
+ " \n",
641
+ " with torch.cuda.amp.autocast():\n",
642
+ "\n",
643
+ " image_paths_batch = batch['image_path']\n",
644
+ " image_tensor_batch = batch['image_tensor'].to(device=device)\n",
645
+ " \n",
646
+ " audio_tensor_batch = batch['audio_tensor'].squeeze(1).to(device=device)\n",
647
+ " audio_paths_batch = batch['audio_path']\n",
648
+ " with torch.no_grad():\n",
649
+ " encoder_outputs = whisper_encoder(input_features=audio_tensor_batch)\n",
650
+ " hidden_states = encoder_outputs.last_hidden_state\n",
651
+ " \n",
652
+ "\n",
653
+ " for i in range(len(image_paths_batch)):\n",
654
+ " torch.save({\n",
655
+ " 'image_path': image_paths_batch[i],\n",
656
+ " 'image_features': image_tensor_batch[i].detach().cpu(),\n",
657
+ " 'audio_path': audio_paths_batch[i],\n",
658
+ " 'audio_features': hidden_states[i].detach().cpu(),\n",
659
+ " }, os.path.join(savedir, f\"{os.path.basename(image_paths_batch[i])}.pt\")\n",
660
+ " )\n",
661
+ " \n",
662
+ " if i % 20 == 0:\n",
663
+ " del image_tensor_batch, audio_tensor_batch, hidden_states\n",
664
+ " force_gc"
665
+ ]
666
+ },
667
+ {
668
+ "cell_type": "code",
669
+ "execution_count": 9,
670
+ "metadata": {},
671
+ "outputs": [],
672
+ "source": [
673
+ "# !rm -rf '/scratch/IITB/ai-at-ieor/23m1521/datasets/Vaani/Hindi_Image_Audio_SD21_Whisper_features/'"
674
+ ]
675
+ },
676
+ {
677
+ "cell_type": "code",
678
+ "execution_count": 9,
679
+ "metadata": {},
680
+ "outputs": [
681
+ {
682
+ "data": {
683
+ "text/plain": [
684
+ "73755"
685
+ ]
686
+ },
687
+ "execution_count": 9,
688
+ "metadata": {},
689
+ "output_type": "execute_result"
690
+ }
691
+ ],
692
+ "source": [
693
+ "import os\n",
694
+ "savedir = '/scratch/IITB/ai-at-ieor/23m1521/datasets/Vaani/Hindi_Image_Audio_SD21_Whisper_features/'\n",
695
+ "\n",
696
+ "len(os.listdir(savedir))"
697
+ ]
698
+ },
699
+ {
700
+ "cell_type": "code",
701
+ "execution_count": 10,
702
+ "metadata": {},
703
+ "outputs": [
704
+ {
705
+ "name": "stdout",
706
+ "output_type": "stream",
707
+ "text": [
708
+ "549G\t/scratch/IITB/ai-at-ieor/23m1521/datasets/Vaani/Hindi_Image_Audio_SD21_Whisper_features/\n"
709
+ ]
710
+ }
711
+ ],
712
+ "source": [
713
+ "!du -sh /scratch/IITB/ai-at-ieor/23m1521/datasets/Vaani/Hindi_Image_Audio_SD21_Whisper_features/"
714
+ ]
715
+ },
716
+ {
717
+ "cell_type": "markdown",
718
+ "metadata": {},
719
+ "source": [
720
+ "## New Dataset Class"
721
+ ]
722
+ },
723
+ {
724
+ "cell_type": "code",
725
+ "execution_count": 23,
726
+ "metadata": {},
727
+ "outputs": [
728
+ {
729
+ "data": {
730
+ "text/plain": [
731
+ "{'image_path': '/scratch/IITB/ai-at-ieor/23m1521/datasets/Vaani/Images/Folder3/IISc_VaaniProject_Varanasi-SPECIFIC_01655.jpg',\n",
732
+ " 'image_features': tensor([[[ 1.1123, 1.1123, 1.1123, ..., 0.9526, 0.9526, 0.9380],\n",
733
+ " [ 1.1270, 1.1123, 1.1123, ..., 0.9526, 0.9526, 0.9526],\n",
734
+ " [ 1.1270, 1.1270, 1.1123, ..., 0.9526, 0.9526, 0.9526],\n",
735
+ " ...,\n",
736
+ " [-0.4346, -0.4199, -0.4346, ..., -1.1943, -1.1650, -1.1797],\n",
737
+ " [-0.4783, -0.4636, -0.4490, ..., -1.2383, -1.1650, -1.1797],\n",
738
+ " [-0.4346, -0.4490, -0.4927, ..., -1.2529, -1.1797, -1.1943]],\n",
739
+ " \n",
740
+ " [[ 1.4307, 1.4307, 1.4307, ..., 1.3545, 1.3545, 1.3389],\n",
741
+ " [ 1.4453, 1.4307, 1.4307, ..., 1.3545, 1.3545, 1.3545],\n",
742
+ " [ 1.4453, 1.4453, 1.4307, ..., 1.3545, 1.3545, 1.3545],\n",
743
+ " ...,\n",
744
+ " [-0.5513, -0.5366, -0.5366, ..., -1.2715, -1.2422, -1.2568],\n",
745
+ " [-0.5815, -0.5664, -0.5513, ..., -1.3164, -1.2422, -1.2568],\n",
746
+ " [-0.5366, -0.5513, -0.5962, ..., -1.3320, -1.2568, -1.2715]],\n",
747
+ " \n",
748
+ " [[ 1.6621, 1.6621, 1.6621, ..., 1.6621, 1.6621, 1.6475],\n",
749
+ " [ 1.6758, 1.6621, 1.6621, ..., 1.6621, 1.6621, 1.6621],\n",
750
+ " [ 1.6758, 1.6758, 1.6621, ..., 1.6621, 1.6621, 1.6621],\n",
751
+ " ...,\n",
752
+ " [-0.5132, -0.4990, -0.4990, ..., -1.1377, -1.0811, -1.1240],\n",
753
+ " [-0.5415, -0.5273, -0.5132, ..., -1.1660, -1.0811, -1.1377],\n",
754
+ " [-0.4990, -0.5132, -0.5557, ..., -1.1816, -1.0957, -1.1523]]],\n",
755
+ " dtype=torch.float16),\n",
756
+ " 'audio_path': '/scratch/IITB/ai-at-ieor/23m1521/datasets/Vaani/Audios/Hindi/UttarPradesh_Varanasi/IISc_VaaniProject_M_UP_Varanasi_18587414_0917000000_UPVNTA_123286_6733_8870.wav',\n",
757
+ " 'audio_features': tensor([[-1.1608e+00, -9.0333e-02, 3.3006e-02, ..., -4.3307e+00,\n",
758
+ " -1.4483e-01, -1.0611e+00],\n",
759
+ " [ 2.5317e-01, -2.7337e-01, -1.8108e-01, ..., -3.6791e+00,\n",
760
+ " 5.2682e-01, -6.8573e-01],\n",
761
+ " [ 4.7004e-01, -8.1346e-01, 1.0142e+00, ..., -2.2765e+00,\n",
762
+ " 1.2923e+00, -8.2782e-01],\n",
763
+ " ...,\n",
764
+ " [-6.1619e-03, -7.1685e-03, -1.0914e-02, ..., 6.0164e-03,\n",
765
+ " -4.9124e-03, -1.9412e-03],\n",
766
+ " [-2.5727e-03, -2.7489e-03, -9.8100e-03, ..., -5.9428e-03,\n",
767
+ " -1.4006e-03, 4.9841e-04],\n",
768
+ " [-2.5339e-03, -1.1025e-02, -1.6143e-02, ..., -8.3381e-03,\n",
769
+ " 5.2792e-04, 1.2501e-02]])}"
770
+ ]
771
+ },
772
+ "execution_count": 23,
773
+ "metadata": {},
774
+ "output_type": "execute_result"
775
+ }
776
+ ],
777
+ "source": [
778
+ "idx = 1\n",
779
+ "torch.load(features_paths[idx])"
780
+ ]
781
+ },
782
+ {
783
+ "cell_type": "code",
784
+ "execution_count": null,
785
+ "metadata": {},
786
+ "outputs": [],
787
+ "source": [
788
+ "# ==================================================================\n",
789
+ "# I M A G E - A U D I O - D A T A S E T\n",
790
+ "# ==================================================================\n",
791
+ "def denormalize_image(img_tensor):\n",
792
+ " mean = np.array([0.48145466, 0.4578275, 0.40821073]).reshape(3, 1, 1)\n",
793
+ " std = np.array([0.26862954, 0.26130258, 0.27577711]).reshape(3, 1, 1)\n",
794
+ " \n",
795
+ " img = img_tensor * std + mean # de-normalize\n",
796
+ " img = np.clip(img, 0, 1) # clip to [0, 1] for display\n",
797
+ " img = np.transpose(img, (1, 2, 0)) # CHW -> HWC\n",
798
+ " return img\n",
799
+ "\n",
800
+ "class VaaniImageAudioDataset(torch.utils.data.Dataset):\n",
801
+ " def __init__(self, features_paths):\n",
802
+ " self.features_paths = features_paths\n",
803
+ " self.image_transforms = v2.Compose([\n",
804
+ " v2.ToImage(),\n",
805
+ " v2.Resize((224, 224), antialias=True),\n",
806
+ " v2.RandomCrop(size=(224, 224)),\n",
807
+ " v2.ToDtype(torch.float16, scale=True),\n",
808
+ " v2.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], \n",
809
+ " std=[0.26862954, 0.26130258, 0.27577711])\n",
810
+ " ])\n",
811
+ " \n",
812
+ " self.feature_extractor = WhisperFeatureExtractor.from_pretrained(\"openai/whisper-large-v2\")\n",
813
+ " self.sampling_rate = self.feature_extractor.sampling_rate\n",
814
+ "\n",
815
+ " def __len__(self):\n",
816
+ " return len(self.features_paths)\n",
817
+ " \n",
818
+ " def __getitem__(self, idx):\n",
819
+ " return torch.load(self.features_paths[idx])\n",
820
+ " \n",
821
+ " \n",
822
+ "\n",
823
+ "train_df = pd.read_csv(\"/home/IITB/ai-at-ieor/23m1521/ashish/MTP/Vaani/Img_Audio_Alignment/available_img_audios_TRAIN3.csv\")\n",
824
+ "test_df = pd.read_csv(\"/home/IITB/ai-at-ieor/23m1521/ashish/MTP/Vaani/Img_Audio_Alignment/available_img_audios_TEST2.csv\")\n",
825
+ "audio_tensors_savedir = '/scratch/IITB/ai-at-ieor/23m1521/datasets/Vaani/Hindi_Audio_tensors/'\n",
826
+ "features_savedir = '/scratch/IITB/ai-at-ieor/23m1521/datasets/Vaani/Hindi_Image_Audio_SD21_Whisper_features/'\n",
827
+ "features_paths = [f\"{features_savedir}/{i}\" for i in os.listdir(features_savedir)]\n",
828
+ "\n",
829
+ "df = pd.concat([train_df, test_df], axis=0)\n",
830
+ "dataset = VaaniImageAudioDataset(features_paths)\n",
831
+ "\n",
832
+ "s = 0.005\n",
833
+ "dataset, _ = torch.utils.data.random_split(dataset, [s, 1-s], torch.manual_seed(42))\n",
834
+ "\n",
835
+ "print(\"Length of Train dataset:\", len(dataset))\n",
836
+ "\n",
837
+ "\n",
838
+ "BATCH_SIZE = int(64)\n",
839
+ "dataloader = torch.utils.data.DataLoader(\n",
840
+ " dataset,\n",
841
+ " batch_size=BATCH_SIZE, \n",
842
+ " shuffle=True, \n",
843
+ " num_workers=48,\n",
844
+ " pin_memory=True,\n",
845
+ " drop_last=False,\n",
846
+ " prefetch_factor=5,\n",
847
+ " persistent_workers=True\n",
848
+ ")\n",
849
+ "print('Total batches:', len(dataloader))\n",
850
+ "\n",
851
+ "batch = next(iter(dataloader))\n",
852
+ "image_tensor_batch = batch['image_tensor'].to(device=device)\n",
853
+ "audio_tensor_batch = batch['audio_tensor'].to(device=device)\n",
854
+ "image_paths_batch = batch['image_path']\n",
855
+ "audio_paths_batch = batch['audio_path']\n",
856
+ "print(\"Image batch shape:\", image_tensor_batch.shape)\n",
857
+ "print(\"Audio batch shape:\", audio_tensor_batch.shape)\n",
858
+ "# for batch in tqdm(dataloader):\n",
859
+ "# pass"
860
+ ]
861
+ },
862
+ {
863
+ "cell_type": "code",
864
+ "execution_count": 7,
865
+ "metadata": {},
866
+ "outputs": [
867
+ {
868
+ "name": "stdout",
869
+ "output_type": "stream",
870
+ "text": [
871
+ "Train Dataset: 26810\n",
872
+ "Test Dataset: 11490\n"
873
+ ]
874
+ },
875
+ {
876
+ "data": {
877
+ "text/plain": [
878
+ "{'image_path': '/scratch/IITB/ai-at-ieor/23m1521/datasets/Vaani/Images/Folder3/IISc_VaaniProject_Lucknow-SPECIFIC_00826.jpg',\n",
879
+ " 'image_feature': tensor([-0.1034, 0.4547, -0.3613, ..., -0.4897, -0.0025, 0.6462]),\n",
880
+ " 'audio_path': '/scratch/IITB/ai-at-ieor/23m1521/datasets/Vaani/Audios/Hindi/UttarPradesh_Lucknow/IISc_VaaniProject_K_UttarPradesh_Lucknow_Lucknow844425030382473_010_Lucknow-SPECIFIC_00826_4706_6462.wav',\n",
881
+ " 'audio_tensor': tensor([-0.0131, -0.0133, -0.0105, ..., -0.0070, -0.0086, -0.0096])}"
882
+ ]
883
+ },
884
+ "execution_count": 7,
885
+ "metadata": {},
886
+ "output_type": "execute_result"
887
+ }
888
+ ],
889
+ "source": [
890
+ "# ==================================================================\n",
891
+ "# I M A G E - A U D I O - D A T A S E T\n",
892
+ "# ==================================================================\n",
893
+ "class VaaniImageAudioDataset(torch.utils.data.Dataset):\n",
894
+ " def __init__(self, df, image_features_savedir, audio_tensors_savedir):\n",
895
+ " self.image_paths = df.image_path.tolist()\n",
896
+ " self.audio_paths = df.audio_path.tolist()\n",
897
+ " self.image_features_savedir = image_features_savedir\n",
898
+ " self.audio_tensors_savedir = audio_tensors_savedir\n",
899
+ "\n",
900
+ " def __len__(self):\n",
901
+ " return len(self.audio_paths)\n",
902
+ "\n",
903
+ " def __getitem__(self, idx):\n",
904
+ " return {\n",
905
+ " 'image_path': self.image_paths[idx],\n",
906
+ " 'image_feature': torch.load(os.path.join(\n",
907
+ " self.image_features_savedir, \n",
908
+ " f\"{os.path.basename(self.image_paths[idx])}.pt\"))['image_features'],\n",
909
+ " 'audio_path': self.audio_paths[idx],\n",
910
+ " 'audio_tensor': torch.load(os.path.join(\n",
911
+ " audio_tensors_savedir, \n",
912
+ " f\"{os.path.basename(self.audio_paths[idx])}.pt\"))['audio_tensor']\n",
913
+ " }\n",
914
+ " \n",
915
+ "\n",
916
+ "train_df = pd.read_csv(\"/home/IITB/ai-at-ieor/23m1521/ashish/MTP/Vaani/Img_Audio_Alignment/available_img_audios_TRAIN2.csv\")\n",
917
+ "test_df = pd.read_csv(\"/home/IITB/ai-at-ieor/23m1521/ashish/MTP/Vaani/Img_Audio_Alignment/available_img_audios_TEST2.csv\")\n",
918
+ "image_features_savedir = '/scratch/IITB/ai-at-ieor/23m1521/datasets/Vaani/Hindi_Image_features/'\n",
919
+ "audio_tensors_savedir = '/scratch/IITB/ai-at-ieor/23m1521/datasets/Vaani/Hindi_Audio_tensors/'\n",
920
+ "train_dataset = VaaniImageAudioDataset(train_df, image_features_savedir, audio_tensors_savedir)\n",
921
+ "test_dataset = VaaniImageAudioDataset(test_df, image_features_savedir, audio_tensors_savedir)\n",
922
+ "\n",
923
+ "print('Train Dataset:', len(train_dataset))\n",
924
+ "print('Test Dataset:', len(test_dataset))\n",
925
+ "train_dataset[0]"
926
+ ]
927
+ },
928
+ {
929
+ "cell_type": "code",
930
+ "execution_count": 9,
931
+ "metadata": {},
932
+ "outputs": [
933
+ {
934
+ "name": "stdout",
935
+ "output_type": "stream",
936
+ "text": [
937
+ "Image batch shape: torch.Size([64, 1024])\n",
938
+ "Audio batch shape: torch.Size([64, 308700])\n"
939
+ ]
940
+ }
941
+ ],
942
+ "source": [
943
+ "BATCH_SIZE = int(64)\n",
944
+ "train_dataloader = torch.utils.data.DataLoader(\n",
945
+ " train_dataset,\n",
946
+ " batch_size=BATCH_SIZE, \n",
947
+ " shuffle=True, \n",
948
+ " num_workers=48,\n",
949
+ " pin_memory=True,\n",
950
+ " drop_last=False,\n",
951
+ " persistent_workers=True\n",
952
+ ")\n",
953
+ "\n",
954
+ "test_dataloader = torch.utils.data.DataLoader(\n",
955
+ " test_dataset,\n",
956
+ " batch_size=BATCH_SIZE, \n",
957
+ " shuffle=False, \n",
958
+ " num_workers=48,\n",
959
+ " pin_memory=True,\n",
960
+ " drop_last=False,\n",
961
+ " persistent_workers=True\n",
962
+ ")\n",
963
+ "\n",
964
+ "batch = next(iter(train_dataloader))\n",
965
+ "image_features_batch = batch['image_feature'].to(device=device)\n",
966
+ "audio_tensor_batch = batch['audio_tensor'].to(device=device)\n",
967
+ "image_paths_batch = batch['image_path']\n",
968
+ "audio_paths_batch = batch['audio_path']\n",
969
+ "print(\"Image batch shape:\", image_features_batch.shape) # [BATCH_SIZE, 3, 224, 224]\n",
970
+ "print(\"Audio batch shape:\", audio_tensor_batch.shape) # [BATCH_SIZE, 1, 44100]\n"
971
+ ]
972
+ }
973
+ ],
974
+ "metadata": {
975
+ "kernelspec": {
976
+ "display_name": "clap",
977
+ "language": "python",
978
+ "name": "python3"
979
+ },
980
+ "language_info": {
981
+ "codemirror_mode": {
982
+ "name": "ipython",
983
+ "version": 3
984
+ },
985
+ "file_extension": ".py",
986
+ "mimetype": "text/x-python",
987
+ "name": "python",
988
+ "nbconvert_exporter": "python",
989
+ "pygments_lexer": "ipython3",
990
+ "version": "3.11.11"
991
+ }
992
+ },
993
+ "nbformat": 4,
994
+ "nbformat_minor": 2
995
+ }
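The notebook above writes one .pt file per image-audio pair (about 549G in total, per the du -sh cell). A minimal sketch of reading one record back, assuming the extraction finished; the filename is taken from the example output above, and the shape comments reflect the 224x224 image transform and the whisper-large-v2 encoder (1500 frames, 1280 dims):

# Inspect one saved feature record (example filename taken from the notebook output).
import torch

features_dir = "/scratch/IITB/ai-at-ieor/23m1521/datasets/Vaani/Hindi_Image_Audio_SD21_Whisper_features"
feature_file = f"{features_dir}/IISc_VaaniProject_Varanasi-SPECIFIC_01655.jpg.pt"

record = torch.load(feature_file, map_location="cpu")
print(record["image_path"])             # original image location
print(record["image_features"].shape)   # torch.Size([3, 224, 224]), float16 image tensor
print(record["audio_path"])             # paired Hindi audio clip
print(record["audio_features"].shape)   # torch.Size([1500, 1280]) Whisper encoder states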
scratch/IITB/ai-at-ieor/23m1521/SDFT/Vaani/SDFT/SD21_Whisper/config-SD21_Whisper.yaml ADDED
@@ -0,0 +1 @@
+ learning_rate: 1e-7
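Note that with PyYAML's default resolver, 1e-7 (no decimal point) is loaded as the string '1e-7', not a float; that is why the training script wraps every use in float(Config.learning_rate). A quick check:

import yaml

print(repr(yaml.safe_load("learning_rate: 1e-7")["learning_rate"]))    # '1e-7'  (str)
print(repr(yaml.safe_load("learning_rate: 1.0e-7")["learning_rate"]))  # 1e-07   (float)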
scratch/IITB/ai-at-ieor/23m1521/SDFT/Vaani/_1.1_Audio-Hindi-Download.ipynb CHANGED
@@ -1045,7 +1045,7 @@
  ],
  "metadata": {
   "kernelspec": {
-   "display_name": "Python 3",
+   "display_name": "aku",
    "language": "python",
    "name": "python3"
   },