alpha31476 committed
Commit 87ef7b5 · verified · 1 Parent(s): 3f546f5

LDM-train-pass, checking results

This view is limited to 50 files because the commit contains too many changes.
Files changed (50)
  1. .gitattributes +102 -0
  2. .gitignore +25 -0
  3. .vscode/settings.json +3 -0
  4. DDPM/CeleabA.parquet +3 -0
  5. DDPM/_1_Mnist.ipynb +546 -0
  6. DDPM/_3_Activation-Checkpointing-Sequential.ipynb +216 -0
  7. DDPM/_4_Activation-Checkpointing-VAE.ipynb +444 -0
  8. DDPM/_5_Activation-Ckpt-VAE-CelebA.ipynb +0 -0
  9. Imgui/demo-newstyle.py +298 -0
  10. Imgui/demo.py +301 -0
  11. Imgui/imgui.ini +25 -0
  12. LDM/notebooks/_1_Main.ipynb +1481 -0
  13. LDM/notebooks/_2_Rough-LPIPS.ipynb +0 -0
  14. LDM/scripts/Main.py +2273 -0
  15. LDM/scripts/_1_Lpips.py +56 -0
  16. LDM/scripts/config.yaml +65 -0
  17. Vaani/39448.err +351 -0
  18. Vaani/39448.out +11 -0
  19. Vaani/IISc_VaaniProject_M_AP_Anantpur_00014520_1544240000_APATSR_190315_1880_16300.wav +3 -0
  20. Vaani/LDM/__init__.py +0 -0
  21. Vaani/LDM/notebooks/Vaani-subplot.png +3 -0
  22. Vaani/LDM/notebooks/Vaani_VQVAE_Recon_Images/reconstructed_images_EP-15_16.png +3 -0
  23. Vaani/LDM/notebooks/Vaani_VQVAE_Recon_Images/reconstructed_images_EP-30_16.png +3 -0
  24. Vaani/LDM/notebooks/Vaani_VQVAE_Recon_Images/reconstructed_images_EP-4.png +3 -0
  25. Vaani/LDM/notebooks/Vaani_VQVAE_Recon_Images/reconstructed_images_EP-5.png +3 -0
  26. Vaani/LDM/notebooks/Vaani_VQVAE_Recon_Images/reconstructed_images_EP-6.png +3 -0
  27. Vaani/LDM/notebooks/Vaani_VQVAE_Recon_Images/reconstructed_images_EP-6_16.png +3 -0
  28. Vaani/LDM/notebooks/Vaani_VQVAE_Recon_Images/reconstructed_images_EP-8_16.png +3 -0
  29. Vaani/LDM/notebooks/_1_Main.ipynb +0 -0
  30. Vaani/LDM/notebooks/_2_Rough-LPIPS.ipynb +0 -0
  31. Vaani/LDM/scripts/AE-training.log +126 -0
  32. Vaani/LDM/scripts/Main.py +2303 -0
  33. Vaani/LDM/scripts/SLURM-AE-Train.sh +21 -0
  34. Vaani/LDM/scripts/SLURM-AE-Train2.sh +21 -0
  35. Vaani/LDM/scripts/Vaani-VQVAE-Main.py +1151 -0
  36. Vaani/LDM/scripts/VaaniLDM/vqvaq_ckpt-15.pth +3 -0
  37. Vaani/LDM/scripts/VaaniLDM/vqvaq_ckpt.pth +3 -0
  38. Vaani/LDM/scripts/_1_Lpips.py +56 -0
  39. Vaani/LDM/scripts/__init__.py +0 -0
  40. Vaani/LDM/scripts/config.yaml +65 -0
  41. Vaani/LDM/scripts/dotdict.py +53 -0
  42. Vaani/SLURM_test.sh +20 -0
  43. Vaani/VQVAE_architecture.svg +0 -0
  44. Vaani/VQVAE_summary.txt +438 -0
  45. Vaani/VQVAE_training.sh +19 -0
  46. Vaani/Vaani-Audio-Image-English.csv +0 -0
  47. Vaani/Vaani-Images-Audio-MetaData.parquet +3 -0
  48. Vaani/Vaani-subplot.png +3 -0
  49. Vaani/VaaniLDM/ddpm_ckpt_epoch14.pt +3 -0
  50. Vaani/VaaniLDM/ddpm_ckpt_epoch15.pt +3 -0
.gitattributes CHANGED
@@ -33,3 +33,105 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ Vaani/IISc_VaaniProject_M_AP_Anantpur_00014520_1544240000_APATSR_190315_1880_16300.wav filter=lfs diff=lfs merge=lfs -text
37
+ Vaani/LDM/notebooks/Vaani-subplot.png filter=lfs diff=lfs merge=lfs -text
38
+ Vaani/LDM/notebooks/Vaani_VQVAE_Recon_Images/reconstructed_images_EP-15_16.png filter=lfs diff=lfs merge=lfs -text
39
+ Vaani/LDM/notebooks/Vaani_VQVAE_Recon_Images/reconstructed_images_EP-30_16.png filter=lfs diff=lfs merge=lfs -text
40
+ Vaani/LDM/notebooks/Vaani_VQVAE_Recon_Images/reconstructed_images_EP-4.png filter=lfs diff=lfs merge=lfs -text
41
+ Vaani/LDM/notebooks/Vaani_VQVAE_Recon_Images/reconstructed_images_EP-5.png filter=lfs diff=lfs merge=lfs -text
42
+ Vaani/LDM/notebooks/Vaani_VQVAE_Recon_Images/reconstructed_images_EP-6.png filter=lfs diff=lfs merge=lfs -text
43
+ Vaani/LDM/notebooks/Vaani_VQVAE_Recon_Images/reconstructed_images_EP-6_16.png filter=lfs diff=lfs merge=lfs -text
44
+ Vaani/LDM/notebooks/Vaani_VQVAE_Recon_Images/reconstructed_images_EP-8_16.png filter=lfs diff=lfs merge=lfs -text
45
+ Vaani/Vaani-subplot.png filter=lfs diff=lfs merge=lfs -text
46
+ Vaani/VaaniLDM/samples/x0_0.png filter=lfs diff=lfs merge=lfs -text
47
+ Vaani/VaaniLDM/vqvae_recon/reconstructed_images_EP-10_16.png filter=lfs diff=lfs merge=lfs -text
48
+ Vaani/VaaniLDM/vqvae_recon/reconstructed_images_EP-11_16.png filter=lfs diff=lfs merge=lfs -text
49
+ Vaani/VaaniLDM/vqvae_recon/reconstructed_images_EP-12_16.png filter=lfs diff=lfs merge=lfs -text
50
+ Vaani/VaaniLDM/vqvae_recon/reconstructed_images_EP-13_16.png filter=lfs diff=lfs merge=lfs -text
51
+ Vaani/VaaniLDM/vqvae_recon/reconstructed_images_EP-14_16.png filter=lfs diff=lfs merge=lfs -text
52
+ Vaani/VaaniLDM/vqvae_recon/reconstructed_images_EP-15_16.png filter=lfs diff=lfs merge=lfs -text
53
+ Vaani/VaaniLDM/vqvae_recon/reconstructed_images_EP-16_16.png filter=lfs diff=lfs merge=lfs -text
54
+ Vaani/VaaniLDM/vqvae_recon/reconstructed_images_EP-17_16.png filter=lfs diff=lfs merge=lfs -text
55
+ Vaani/VaaniLDM/vqvae_recon/reconstructed_images_EP-18_16.png filter=lfs diff=lfs merge=lfs -text
56
+ Vaani/VaaniLDM/vqvae_recon/reconstructed_images_EP-19_16.png filter=lfs diff=lfs merge=lfs -text
57
+ Vaani/VaaniLDM/vqvae_recon/reconstructed_images_EP-20_16.png filter=lfs diff=lfs merge=lfs -text
58
+ Vaani/VaaniLDM/vqvae_recon/reconstructed_images_EP-21_16.png filter=lfs diff=lfs merge=lfs -text
59
+ Vaani/VaaniLDM/vqvae_recon/reconstructed_images_EP-22_16.png filter=lfs diff=lfs merge=lfs -text
60
+ Vaani/VaaniLDM/vqvae_recon/reconstructed_images_EP-23_16.png filter=lfs diff=lfs merge=lfs -text
61
+ Vaani/VaaniLDM/vqvae_recon/reconstructed_images_EP-24_16.png filter=lfs diff=lfs merge=lfs -text
62
+ Vaani/VaaniLDM/vqvae_recon/reconstructed_images_EP-25_16.png filter=lfs diff=lfs merge=lfs -text
63
+ Vaani/VaaniLDM/vqvae_recon/reconstructed_images_EP-26_16.png filter=lfs diff=lfs merge=lfs -text
64
+ Vaani/VaaniLDM/vqvae_recon/reconstructed_images_EP-27_16.png filter=lfs diff=lfs merge=lfs -text
65
+ Vaani/VaaniLDM/vqvae_recon/reconstructed_images_EP-28_16.png filter=lfs diff=lfs merge=lfs -text
66
+ Vaani/VaaniLDM/vqvae_recon/reconstructed_images_EP-29_16.png filter=lfs diff=lfs merge=lfs -text
67
+ Vaani/VaaniLDM/vqvae_recon/reconstructed_images_EP-31_16.png filter=lfs diff=lfs merge=lfs -text
68
+ Vaani/VaaniLDM/vqvae_recon/reconstructed_images_EP-32_16.png filter=lfs diff=lfs merge=lfs -text
69
+ Vaani/VaaniLDM/vqvae_recon/reconstructed_images_EP-33_16.png filter=lfs diff=lfs merge=lfs -text
70
+ Vaani/VaaniLDM/vqvae_recon/reconstructed_images_EP-34_16.png filter=lfs diff=lfs merge=lfs -text
71
+ Vaani/VaaniLDM/vqvae_recon/reconstructed_images_EP-36_16.png filter=lfs diff=lfs merge=lfs -text
72
+ Vaani/VaaniLDM/vqvae_recon/reconstructed_images_EP-37_16.png filter=lfs diff=lfs merge=lfs -text
73
+ Vaani/VaaniLDM/vqvae_recon/reconstructed_images_EP-38_16.png filter=lfs diff=lfs merge=lfs -text
74
+ Vaani/VaaniLDM/vqvae_recon/reconstructed_images_EP-39_16.png filter=lfs diff=lfs merge=lfs -text
75
+ Vaani/VaaniLDM/vqvae_recon/reconstructed_images_EP-40_16.png filter=lfs diff=lfs merge=lfs -text
76
+ Vaani/VaaniLDM/vqvae_recon/reconstructed_images_EP-41_16.png filter=lfs diff=lfs merge=lfs -text
77
+ Vaani/VaaniLDM/vqvae_recon/reconstructed_images_EP-42_16.png filter=lfs diff=lfs merge=lfs -text
78
+ Vaani/VaaniLDM/vqvae_recon/reconstructed_images_EP-43_16.png filter=lfs diff=lfs merge=lfs -text
79
+ Vaani/VaaniLDM/vqvae_recon/reconstructed_images_EP-44_16.png filter=lfs diff=lfs merge=lfs -text
80
+ Vaani/VaaniLDM/vqvae_recon/reconstructed_images_EP-45_16.png filter=lfs diff=lfs merge=lfs -text
81
+ Vaani/VaaniLDM/vqvae_recon/reconstructed_images_EP-46_16.png filter=lfs diff=lfs merge=lfs -text
82
+ Vaani/VaaniLDM/vqvae_recon/reconstructed_images_EP-47_16.png filter=lfs diff=lfs merge=lfs -text
83
+ Vaani/VaaniLDM/vqvae_recon/reconstructed_images_EP-48_16.png filter=lfs diff=lfs merge=lfs -text
84
+ Vaani/VaaniLDM/vqvae_recon/reconstructed_images_EP-49_16.png filter=lfs diff=lfs merge=lfs -text
85
+ Vaani/VaaniLDM/vqvae_recon/reconstructed_images_EP-51_16.png filter=lfs diff=lfs merge=lfs -text
86
+ Vaani/VaaniLDM/vqvae_recon/reconstructed_images_EP-52_16.png filter=lfs diff=lfs merge=lfs -text
87
+ Vaani/VaaniLDM/vqvae_recon/reconstructed_images_EP-53_16.png filter=lfs diff=lfs merge=lfs -text
88
+ Vaani/VaaniLDM/vqvae_recon/reconstructed_images_EP-54_16.png filter=lfs diff=lfs merge=lfs -text
89
+ Vaani/VaaniLDM/vqvae_recon/reconstructed_images_EP-55_16.png filter=lfs diff=lfs merge=lfs -text
90
+ Vaani/VaaniLDM/vqvae_recon/reconstructed_images_EP-56_16.png filter=lfs diff=lfs merge=lfs -text
91
+ Vaani/VaaniLDM/vqvae_recon/reconstructed_images_EP-57_16.png filter=lfs diff=lfs merge=lfs -text
92
+ Vaani/VaaniLDM/vqvae_recon/reconstructed_images_EP-58_16.png filter=lfs diff=lfs merge=lfs -text
93
+ Vaani/VaaniLDM/vqvae_recon/reconstructed_images_EP-59_16.png filter=lfs diff=lfs merge=lfs -text
94
+ Vaani/VaaniLDM/vqvae_recon/reconstructed_images_EP-60_16.png filter=lfs diff=lfs merge=lfs -text
95
+ Vaani/VaaniLDM/vqvae_recon/reconstructed_images_EP-61_16.png filter=lfs diff=lfs merge=lfs -text
96
+ Vaani/VaaniLDM/vqvae_recon/reconstructed_images_EP-62_16.png filter=lfs diff=lfs merge=lfs -text
97
+ Vaani/VaaniLDM/vqvae_recon/reconstructed_images_EP-63_16.png filter=lfs diff=lfs merge=lfs -text
98
+ Vaani/VaaniLDM/vqvae_recon/reconstructed_images_EP-64_16.png filter=lfs diff=lfs merge=lfs -text
99
+ Vaani/VaaniLDM/vqvae_recon/reconstructed_images_EP-65_16.png filter=lfs diff=lfs merge=lfs -text
100
+ Vaani/VaaniLDM/vqvae_recon/reconstructed_images_EP-66_16.png filter=lfs diff=lfs merge=lfs -text
101
+ Vaani/VaaniLDM/vqvae_recon/reconstructed_images_EP-67_16.png filter=lfs diff=lfs merge=lfs -text
102
+ Vaani/VaaniLDM/vqvae_recon/reconstructed_images_EP-68_16.png filter=lfs diff=lfs merge=lfs -text
103
+ Vaani/VaaniLDM/vqvae_recon/reconstructed_images_EP-69_16.png filter=lfs diff=lfs merge=lfs -text
104
+ Vaani/VaaniLDM/vqvae_recon/reconstructed_images_EP-71_16.png filter=lfs diff=lfs merge=lfs -text
105
+ Vaani/VaaniLDM/vqvae_recon/reconstructed_images_EP-73_16.png filter=lfs diff=lfs merge=lfs -text
106
+ Vaani/VaaniLDM/vqvae_recon/reconstructed_images_EP-74_16.png filter=lfs diff=lfs merge=lfs -text
107
+ Vaani/VaaniLDM/vqvae_recon/reconstructed_images_EP-75_16.png filter=lfs diff=lfs merge=lfs -text
108
+ Vaani/VaaniLDM/vqvae_recon/reconstructed_images_EP-76_16.png filter=lfs diff=lfs merge=lfs -text
109
+ Vaani/VaaniLDM/vqvae_recon/reconstructed_images_EP-77_16.png filter=lfs diff=lfs merge=lfs -text
110
+ Vaani/VaaniLDM/vqvae_recon/reconstructed_images_EP-78_16.png filter=lfs diff=lfs merge=lfs -text
111
+ Vaani/VaaniLDM/vqvae_recon/reconstructed_images_EP-79_16.png filter=lfs diff=lfs merge=lfs -text
112
+ Vaani/VaaniLDM/vqvae_recon/reconstructed_images_EP-81_16.png filter=lfs diff=lfs merge=lfs -text
113
+ Vaani/VaaniLDM/vqvae_recon/reconstructed_images_EP-82_16.png filter=lfs diff=lfs merge=lfs -text
114
+ Vaani/VaaniLDM/vqvae_recon/reconstructed_images_EP-83_16.png filter=lfs diff=lfs merge=lfs -text
115
+ Vaani/VaaniLDM/vqvae_recon/reconstructed_images_EP-84_16.png filter=lfs diff=lfs merge=lfs -text
116
+ Vaani/VaaniLDM/vqvae_recon/reconstructed_images_EP-85_16.png filter=lfs diff=lfs merge=lfs -text
117
+ Vaani/VaaniLDM/vqvae_recon/reconstructed_images_EP-86_16.png filter=lfs diff=lfs merge=lfs -text
118
+ Vaani/VaaniLDM/vqvae_recon/reconstructed_images_EP-87_16.png filter=lfs diff=lfs merge=lfs -text
119
+ Vaani/VaaniLDM/vqvae_recon/reconstructed_images_EP-88_16.png filter=lfs diff=lfs merge=lfs -text
120
+ Vaani/VaaniLDM/vqvae_recon/reconstructed_images_EP-89_16.png filter=lfs diff=lfs merge=lfs -text
121
+ Vaani/VaaniLDM/vqvae_recon/reconstructed_images_EP-9_16.png filter=lfs diff=lfs merge=lfs -text
122
+ Vaani/VaaniLDM_Acc/vqvae_recon/reconstructed_images_EP-0_16.png filter=lfs diff=lfs merge=lfs -text
123
+ Vaani/VaaniLDM_Acc/vqvae_recon/reconstructed_images_EP-1_16.png filter=lfs diff=lfs merge=lfs -text
124
+ Vaani/VaaniLDM_Acc/vqvae_recon/reconstructed_images_EP-2_16.png filter=lfs diff=lfs merge=lfs -text
125
+ Vaani/VaaniLDM_Acc/vqvae_recon/reconstructed_images_EP-3_16.png filter=lfs diff=lfs merge=lfs -text
126
+ Vaani/VaaniLDM_Acc/vqvae_recon/reconstructed_images_EP-4_16.png filter=lfs diff=lfs merge=lfs -text
127
+ Vaani/VaaniLDM_Acc/vqvae_recon/reconstructed_images_EP-5_16.png filter=lfs diff=lfs merge=lfs -text
128
+ Vaani/_1_data.ipynb filter=lfs diff=lfs merge=lfs -text
129
+ Vaani/audio_urls.txt filter=lfs diff=lfs merge=lfs -text
130
+ Vaani/finalMETA.csv filter=lfs diff=lfs merge=lfs -text
131
+ Vaani/image_metadata_summary.csv filter=lfs diff=lfs merge=lfs -text
132
+ Vaani/images_urls.txt filter=lfs diff=lfs merge=lfs -text
133
+ Vaani/output_image.png filter=lfs diff=lfs merge=lfs -text
134
+ Vaani/output_image2.png filter=lfs diff=lfs merge=lfs -text
135
+ Vaani/sampleJSON.csv filter=lfs diff=lfs merge=lfs -text
136
+ Vaani/sampleJSON.json filter=lfs diff=lfs merge=lfs -text
137
+ tools/__pycache__/pynvml.cpython-312.pyc filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,25 @@
1
+ # # Ignore image files
2
+ # *.jpg
3
+ # *.jpeg
4
+ # *.png
5
+ # *.gif
6
+ # *.bmp
7
+ # *.tiff
8
+ # *.webp
9
+ # *.svg
10
+
11
+ # # Ignore specified data files
12
+ # *.pth
13
+ # *.pt
14
+ # *.safetensors
15
+ # *.npz
16
+ # *.npy
17
+ # *.csv
18
+ # *.parquet
19
+ # *.json
20
+ # *.err
21
+ # *.out
22
+
23
+
24
+ # Vaani/audio_urls.txt
25
+ # Vaani/images_urls.txt
.vscode/settings.json ADDED
@@ -0,0 +1,3 @@
1
+ {
2
+ "auto-scroll.enabled": false
3
+ }
DDPM/CeleabA.parquet ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0f41418ec864a1ceee3e4f3c4863f758b534cf434f848c64a4d1df976d10f241
3
+ size 3396938
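The three lines above are a Git LFS pointer: the parquet data itself lives in LFS storage, and the repository only records the spec version, content `oid`, and `size`. A minimal sketch of reading such a pointer in Python (the helper name and return shape are illustrative, not part of this repo):

```python
# Illustrative only: parse the three-line Git LFS pointer shown above
# ("version ...", "oid sha256:...", "size ...") into a dict.
def parse_lfs_pointer(path: str) -> dict:
    fields = {}
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            key, _, value = line.strip().partition(" ")
            if key:
                fields[key] = value
    return fields

# e.g. parse_lfs_pointer("DDPM/CeleabA.parquet")
# -> {"version": "https://git-lfs.github.com/spec/v1",
#     "oid": "sha256:0f41...", "size": "3396938"}
```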
DDPM/_1_Mnist.ipynb ADDED
@@ -0,0 +1,546 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "import torch\n",
10
+ "import torch.nn as nn\n",
11
+ "import torch.optim as optim\n",
12
+ "import torch.utils.checkpoint as checkpoint\n",
13
+ "from torchvision import datasets, transforms\n",
14
+ "from torch.utils.data import DataLoader"
15
+ ]
16
+ },
17
+ {
18
+ "cell_type": "code",
19
+ "execution_count": 2,
20
+ "metadata": {},
21
+ "outputs": [],
22
+ "source": [
23
+ "import time\n",
24
+ "import nvidia_smi\n",
25
+ "import prettytable as pt\n",
26
+ "\n",
27
+ "def gputil_decorator(func):\n",
28
+ " def wrapper(*args, **kwargs):\n",
29
+ " import nvidia_smi\n",
30
+ " import prettytable as pt\n",
31
+ "\n",
32
+ " try:\n",
33
+ " table = pt.PrettyTable(['Devices','Mem Free','GPU-util','GPU-mem'])\n",
34
+ " nvidia_smi.nvmlInit()\n",
35
+ " deviceCount = nvidia_smi.nvmlDeviceGetCount()\n",
36
+ " for i in range(deviceCount):\n",
37
+ " handle = nvidia_smi.nvmlDeviceGetHandleByIndex(i)\n",
38
+ " res = nvidia_smi.nvmlDeviceGetUtilizationRates(handle)\n",
39
+ " mem = nvidia_smi.nvmlDeviceGetMemoryInfo(handle)\n",
40
+ " table.add_row([i, f\"{mem.free/1024**2:5.2f}MB/{mem.total/1024**2:5.2f}MB\", f\"{res.gpu:3.1%}\", f\"{res.memory:3.1%}\"])\n",
41
+ "\n",
42
+ " except nvidia_smi.NVMLError as error:\n",
43
+ " print(error)\n",
44
+ "\n",
45
+ " print(table)\n",
46
+ " return func(*args, **kwargs)\n",
47
+ " return wrapper\n",
48
+ "\n",
49
+ "def gputil_decorator2(func):\n",
50
+ " def wrapper(*args, **kwargs):\n",
51
+ " try:\n",
52
+ " table = pt.PrettyTable(['Devices', 'Mem Free', 'GPU-util', 'GPU-mem'])\n",
53
+ " nvidia_smi.nvmlInit()\n",
54
+ " device_count = nvidia_smi.nvmlDeviceGetCount()\n",
55
+ " for i in range(device_count):\n",
56
+ " handle = nvidia_smi.nvmlDeviceGetHandleByIndex(i)\n",
57
+ " res = nvidia_smi.nvmlDeviceGetUtilizationRates(handle)\n",
58
+ " mem = nvidia_smi.nvmlDeviceGetMemoryInfo(handle)\n",
59
+ " table.add_row([\n",
60
+ " i,\n",
61
+ " f\"{mem.free / 1024 ** 2:5.2f}MB/{mem.total / 1024 ** 2:5.2f}MB\",\n",
62
+ " f\"{res.gpu:3.1%}\",\n",
63
+ " f\"{res.memory:3.1%}\"\n",
64
+ " ])\n",
65
+ " nvidia_smi.nvmlShutdown()\n",
66
+ " except nvidia_smi.NVMLError as error:\n",
67
+ " print(f\"Error fetching GPU stats: {error}\")\n",
68
+ " print(table)\n",
69
+ " return func(*args, **kwargs)\n",
70
+ " return wrapper"
71
+ ]
72
+ },
73
+ {
74
+ "cell_type": "code",
75
+ "execution_count": null,
76
+ "metadata": {},
77
+ "outputs": [],
78
+ "source": [
79
+ "import torch\n",
80
+ "import torch.nn as nn\n",
81
+ "import torch.optim as optim\n",
82
+ "from torchvision import datasets, transforms\n",
83
+ "from torch.utils.data import DataLoader\n",
84
+ "import torchvision.models as models\n",
85
+ "import threading\n",
86
+ "import time\n",
87
+ "import nvidia_smi\n",
88
+ "import prettytable as pt\n",
89
+ "import os\n",
90
+ "\n",
91
+ "# GPU stats decorator\n",
92
+ "def gputil_decorator2(func):\n",
93
+ " def wrapper(*args, **kwargs):\n",
94
+ " try:\n",
95
+ " table = pt.PrettyTable(['Devices', 'Mem Free', 'GPU-util', 'GPU-mem'])\n",
96
+ " nvidia_smi.nvmlInit()\n",
97
+ " device_count = nvidia_smi.nvmlDeviceGetCount()\n",
98
+ " for i in range(device_count):\n",
99
+ " handle = nvidia_smi.nvmlDeviceGetHandleByIndex(i)\n",
100
+ " res = nvidia_smi.nvmlDeviceGetUtilizationRates(handle)\n",
101
+ " mem = nvidia_smi.nvmlDeviceGetMemoryInfo(handle)\n",
102
+ " table.add_row([\n",
103
+ " i,\n",
104
+ " f\"{mem.free / 1024 ** 2:5.2f}MB/{mem.total / 1024 ** 2:5.2f}MB\",\n",
105
+ " f\"{res.gpu:3.1%}\",\n",
106
+ " f\"{res.memory:3.1%}\"\n",
107
+ " ])\n",
108
+ " nvidia_smi.nvmlShutdown()\n",
109
+ " except nvidia_smi.NVMLError as error:\n",
110
+ " print(f\"Error fetching GPU stats: {error}\")\n",
111
+ " print(table)\n",
112
+ " return func(*args, **kwargs)\n",
113
+ " return wrapper\n",
114
+ "\n",
115
+ "# Function to print GPU stats every second\n",
116
+ "def print_gpu_stats(epoch_info):\n",
117
+ " while not stop_event.is_set():\n",
118
+ " os.system('cls' if os.name == 'nt' else 'clear') # Clear the terminal\n",
119
+ " gputil_decorator2(lambda: None)() # Call the decorator to print stats\n",
120
+ " print(epoch_info) # Print epoch information\n",
121
+ " time.sleep(1) # Wait for 1 second\n",
122
+ "\n",
123
+ "# Define the model\n",
124
+ "class EfficientNetCIFAR10(nn.Module):\n",
125
+ " def __init__(self, num_classes=10):\n",
126
+ " super(EfficientNetCIFAR10, self).__init__()\n",
127
+ " self.efficientnet = models.efficientnet_v2_l(weights=models.EfficientNet_V2_L_Weights.IMAGENET1K_V1)\n",
128
+ " self.efficientnet.classifier[1] = nn.Linear(self.efficientnet.classifier[1].in_features, num_classes)\n",
129
+ "\n",
130
+ " def forward(self, x):\n",
131
+ " return self.efficientnet(x)\n",
132
+ "\n",
133
+ "# Load CIFAR-10 dataset\n",
134
+ "transform = transforms.Compose([\n",
135
+ " transforms.ToTensor(),\n",
136
+ " transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))\n",
137
+ "])\n",
138
+ "\n",
139
+ "train_dataset = datasets.CIFAR10(root='/home/23m1521/datasets', train=True, download=True, transform=transform)\n",
140
+ "train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True, num_workers=64)\n",
141
+ "\n",
142
+ "test_dataset = datasets.CIFAR10(root='/home/23m1521/datasets', train=False, download=True, transform=transform)\n",
143
+ "test_loader = DataLoader(test_dataset, batch_size=256, shuffle=False, num_workers=64)\n",
144
+ "\n",
145
+ "# Initialize model, loss function, and optimizer\n",
146
+ "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
147
+ "model = EfficientNetCIFAR10(num_classes=10).to(device)\n",
148
+ "criterion = nn.CrossEntropyLoss()\n",
149
+ "optimizer = optim.Adam(model.parameters(), lr=0.001)\n",
150
+ "\n",
151
+ "# Training loop\n",
152
+ "def train(model, train_loader, criterion, optimizer, device):\n",
153
+ " model.train()\n",
154
+ " running_loss = 0.0\n",
155
+ " for inputs, labels in train_loader:\n",
156
+ " inputs, labels = inputs.to(device), labels.to(device)\n",
157
+ "\n",
158
+ " optimizer.zero_grad()\n",
159
+ " outputs = model(inputs)\n",
160
+ " loss = criterion(outputs, labels)\n",
161
+ " loss.backward()\n",
162
+ " optimizer.step()\n",
163
+ "\n",
164
+ " running_loss += loss.item()\n",
165
+ " return running_loss / len(train_loader)\n",
166
+ "\n",
167
+ "# Testing loop\n",
168
+ "def test(model, test_loader, criterion, device):\n",
169
+ " model.eval()\n",
170
+ " correct = 0\n",
171
+ " total = 0\n",
172
+ " with torch.no_grad():\n",
173
+ " for inputs, labels in test_loader:\n",
174
+ " inputs, labels = inputs.to(device), labels.to(device)\n",
175
+ " outputs = model(inputs)\n",
176
+ " _, predicted = torch.max(outputs.data, 1)\n",
177
+ " total += labels.size(0)\n",
178
+ " correct += (predicted == labels).sum().item()\n",
179
+ " return correct / total\n",
180
+ "\n",
181
+ "# Start the GPU stats printing thread\n",
182
+ "stop_event = threading.Event()\n",
183
+ "epoch_info = \"\" # Placeholder for epoch information\n",
184
+ "gpu_stats_thread = threading.Thread(target=print_gpu_stats, args=(epoch_info,))\n",
185
+ "gpu_stats_thread.start()\n",
186
+ "\n",
187
+ "# Train and test the model\n",
188
+ "num_epochs = 5\n",
189
+ "for epoch in range(num_epochs):\n",
190
+ " train_loss = train(model, train_loader, criterion, optimizer, device)\n",
191
+ " test_acc = test(model, test_loader, criterion, device)\n",
192
+ " epoch_info = f\"Epoch [{epoch+1}/{num_epochs}], Loss: {train_loss:.4f}, Test Accuracy: {test_acc:.4f}\"\n",
193
+ " print_gpu_stats(epoch_info) # Print epoch information\n",
194
+ "\n",
195
+ "# Stop the GPU stats printing thread\n",
196
+ "stop_event.set()\n",
197
+ "gpu_stats_thread.join()"
198
+ ]
199
+ },
200
+ {
201
+ "cell_type": "code",
202
+ "execution_count": 2,
203
+ "metadata": {},
204
+ "outputs": [
205
+ {
206
+ "name": "stdout",
207
+ "output_type": "stream",
208
+ "text": [
209
+ "Files already downloaded and verified\n",
210
+ "Files already downloaded and verified\n"
211
+ ]
212
+ }
213
+ ],
214
+ "source": [
215
+ "# Define a simple CNN model\n",
216
+ "class SimpleCNN(nn.Module):\n",
217
+ " def __init__(self):\n",
218
+ " super(SimpleCNN, self).__init__()\n",
219
+ " self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)\n",
220
+ " self.relu1 = nn.ReLU()\n",
221
+ " self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)\n",
222
+ " self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)\n",
223
+ " self.relu2 = nn.ReLU()\n",
224
+ " self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)\n",
225
+ " self.fc1 = nn.Linear(32 * 8 * 8, 256)\n",
226
+ " self.relu3 = nn.ReLU()\n",
227
+ " self.fc2 = nn.Linear(256, 10) # CIFAR-10 has 10 classes\n",
228
+ "\n",
229
+ " def forward(self, x):\n",
230
+ " # Apply gradient/activation checkpointing to the second convolutional block\n",
231
+ " x = self.conv1(x)\n",
232
+ " x = self.relu1(x)\n",
233
+ " x = self.pool1(x)\n",
234
+ " x = checkpoint.checkpoint(self._conv2_block, x) # Checkpointing here\n",
235
+ " x = x.view(x.size(0), -1) # Flatten\n",
236
+ " x = self.fc1(x)\n",
237
+ " x = self.relu3(x)\n",
238
+ " x = self.fc2(x)\n",
239
+ " return x\n",
240
+ "\n",
241
+ " def _conv2_block(self, x):\n",
242
+ " # Helper function for the second convolutional block\n",
243
+ " x = self.conv2(x)\n",
244
+ " x = self.relu2(x)\n",
245
+ " x = self.pool2(x)\n",
246
+ " return x\n",
247
+ "\n",
248
+ "# Load CIFAR-10 dataset\n",
249
+ "transform = transforms.Compose([\n",
250
+ " transforms.ToTensor(),\n",
251
+ " transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))\n",
252
+ "])\n",
253
+ "\n",
254
+ "train_dataset = datasets.CIFAR10(root='/home/23m1521/datasets', train=True, download=True, transform=transform)\n",
255
+ "train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)\n",
256
+ "\n",
257
+ "test_dataset = datasets.CIFAR10(root='/home/23m1521/datasets', train=False, download=True, transform=transform)\n",
258
+ "test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)\n",
259
+ "\n",
260
+ "# Initialize model, loss function, and optimizer\n",
261
+ "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
262
+ "model = SimpleCNN().to(device)\n",
263
+ "criterion = nn.CrossEntropyLoss()\n",
264
+ "optimizer = optim.Adam(model.parameters(), lr=0.001)"
265
+ ]
266
+ },
267
+ {
268
+ "cell_type": "code",
269
+ "execution_count": null,
270
+ "metadata": {},
271
+ "outputs": [
272
+ {
273
+ "name": "stdout",
274
+ "output_type": "stream",
275
+ "text": [
276
+ "+---------+-----------------------+----------+---------+\n",
277
+ "| Devices | Mem Free | GPU-util | GPU-mem |\n",
278
+ "+---------+-----------------------+----------+---------+\n",
279
+ "| 0 | 23416.75MB/24564.00MB | 0.0% | 0.0% |\n",
280
+ "| 1 | 944.75MB/24564.00MB | 0.0% | 0.0% |\n",
281
+ "+---------+-----------------------+----------+---------+\n"
282
+ ]
283
+ },
284
+ {
285
+ "name": "stderr",
286
+ "output_type": "stream",
287
+ "text": [
288
+ "/home/23m1521/.conda/envs/cuda_env2/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py:600: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
289
+ " return fn(*args, **kwargs)\n",
290
+ "/home/23m1521/.conda/envs/cuda_env2/lib/python3.12/site-packages/torch/utils/checkpoint.py:295: FutureWarning: `torch.cpu.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cpu', args...)` instead.\n",
291
+ " with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs): # type: ignore[attr-defined]\n",
292
+ "/home/23m1521/.conda/envs/cuda_env2/lib/python3.12/site-packages/torch/utils/checkpoint.py:92: UserWarning: None of the inputs have requires_grad=True. Gradients will be None\n",
293
+ " warnings.warn(\n"
294
+ ]
295
+ },
296
+ {
297
+ "name": "stdout",
298
+ "output_type": "stream",
299
+ "text": [
300
+ "Epoch [1/5], Loss: 1.3807, Test Accuracy: 0.5572\n",
301
+ "+---------+-----------------------+----------+---------+\n",
302
+ "| Devices | Mem Free | GPU-util | GPU-mem |\n",
303
+ "+---------+-----------------------+----------+---------+\n",
304
+ "| 0 | 22732.75MB/24564.00MB | 300.0% | 100.0% |\n",
305
+ "| 1 | 944.75MB/24564.00MB | 0.0% | 0.0% |\n",
306
+ "+---------+-----------------------+----------+---------+\n",
307
+ "Epoch [2/5], Loss: 1.0334, Test Accuracy: 0.6553\n",
308
+ "+---------+-----------------------+----------+---------+\n",
309
+ "| Devices | Mem Free | GPU-util | GPU-mem |\n",
310
+ "+---------+-----------------------+----------+---------+\n",
311
+ "| 0 | 22732.75MB/24564.00MB | 300.0% | 100.0% |\n",
312
+ "| 1 | 944.75MB/24564.00MB | 0.0% | 0.0% |\n",
313
+ "+---------+-----------------------+----------+---------+\n",
314
+ "Epoch [3/5], Loss: 0.8787, Test Accuracy: 0.6824\n",
315
+ "+---------+-----------------------+----------+---------+\n",
316
+ "| Devices | Mem Free | GPU-util | GPU-mem |\n",
317
+ "+---------+-----------------------+----------+---------+\n",
318
+ "| 0 | 22732.75MB/24564.00MB | 200.0% | 100.0% |\n",
319
+ "| 1 | 944.75MB/24564.00MB | 0.0% | 0.0% |\n",
320
+ "+---------+-----------------------+----------+---------+\n",
321
+ "Epoch [4/5], Loss: 0.7545, Test Accuracy: 0.6885\n",
322
+ "+---------+-----------------------+----------+---------+\n",
323
+ "| Devices | Mem Free | GPU-util | GPU-mem |\n",
324
+ "+---------+-----------------------+----------+---------+\n",
325
+ "| 0 | 22732.75MB/24564.00MB | 300.0% | 100.0% |\n",
326
+ "| 1 | 944.75MB/24564.00MB | 0.0% | 0.0% |\n",
327
+ "+---------+-----------------------+----------+---------+\n",
328
+ "Epoch [5/5], Loss: 0.6537, Test Accuracy: 0.6989\n"
329
+ ]
330
+ }
331
+ ],
332
+ "source": [
333
+ "# Training loop\n",
334
+ "@gputil_decorator2\n",
335
+ "def train(model, train_loader, criterion, optimizer, device):\n",
336
+ " model.train()\n",
337
+ " running_loss = 0.0\n",
338
+ " for inputs, labels in train_loader:\n",
339
+ " inputs, labels = inputs.to(device), labels.to(device)\n",
340
+ "\n",
341
+ " optimizer.zero_grad()\n",
342
+ " outputs = model(inputs)\n",
343
+ " loss = criterion(outputs, labels)\n",
344
+ " loss.backward()\n",
345
+ " optimizer.step()\n",
346
+ "\n",
347
+ " running_loss += loss.item()\n",
348
+ " return running_loss / len(train_loader)\n",
349
+ "\n",
350
+ "# Testing loop\n",
351
+ "def test(model, test_loader, criterion, device):\n",
352
+ " model.eval()\n",
353
+ " correct = 0\n",
354
+ " total = 0\n",
355
+ " with torch.no_grad():\n",
356
+ " for inputs, labels in test_loader:\n",
357
+ " inputs, labels = inputs.to(device), labels.to(device)\n",
358
+ " outputs = model(inputs)\n",
359
+ " _, predicted = torch.max(outputs.data, 1)\n",
360
+ " total += labels.size(0)\n",
361
+ " correct += (predicted == labels).sum().item()\n",
362
+ " return correct / total\n",
363
+ "\n",
364
+ "# Train and test the model\n",
365
+ "num_epochs = 5\n",
366
+ "for epoch in range(num_epochs):\n",
367
+ " train_loss = train(model, train_loader, criterion, optimizer, device)\n",
368
+ " test_acc = test(model, test_loader, criterion, device)\n",
369
+ " print(f\"Epoch [{epoch+1}/{num_epochs}], Loss: {train_loss:.4f}, Test Accuracy: {test_acc:.4f}\")"
370
+ ]
371
+ },
372
+ {
373
+ "cell_type": "code",
374
+ "execution_count": 16,
375
+ "metadata": {},
376
+ "outputs": [
377
+ {
378
+ "data": {
379
+ "text/plain": [
380
+ "([0.023805618286132812, 0.0],\n",
381
+ " [0.04064750671386719, 0.0],\n",
382
+ " [23.679443359375, 23.679443359375])"
383
+ ]
384
+ },
385
+ "execution_count": 16,
386
+ "metadata": {},
387
+ "output_type": "execute_result"
388
+ }
389
+ ],
390
+ "source": [
391
+ "def get_gpu_memory_usage():\n",
392
+ " allocated_memory = []\n",
393
+ " free_memory = []\n",
394
+ " total_memory = []\n",
395
+ " if torch.cuda.is_available():\n",
396
+ " for i in range(torch.cuda.device_count()):\n",
397
+ " device = torch.device(f\"cuda:{i}\")\n",
398
+ " total = torch.cuda.get_device_properties(device).total_memory / (1024 ** 3) # GB\n",
399
+ " allocated = torch.cuda.memory_allocated(device) / (1024 ** 3) # GB\n",
400
+ " reserved = torch.cuda.memory_reserved(device) / (1024 ** 3) # GB\n",
401
+ " free = reserved - allocated\n",
402
+ " total_memory.append(total)\n",
403
+ " allocated_memory.append(allocated)\n",
404
+ " free_memory.append(free)\n",
405
+ " return allocated_memory, free_memory, total_memory\n",
406
+ "get_gpu_memory_usage()"
407
+ ]
408
+ },
409
+ {
410
+ "cell_type": "code",
411
+ "execution_count": 1,
412
+ "metadata": {},
413
+ "outputs": [
414
+ {
415
+ "name": "stdout",
416
+ "output_type": "stream",
417
+ "text": [
418
+ "Files already downloaded and verified\n",
419
+ "Files already downloaded and verified\n"
420
+ ]
421
+ },
422
+ {
423
+ "name": "stderr",
424
+ "output_type": "stream",
425
+ "text": [
426
+ "Downloading: \"https://download.pytorch.org/models/efficientnet_v2_l-59c71312.pth\" to /home/23m1521/.cache/torch/hub/checkpoints/efficientnet_v2_l-59c71312.pth\n",
427
+ "100%|██████████| 455M/455M [00:04<00:00, 117MB/s] \n"
428
+ ]
429
+ },
430
+ {
431
+ "name": "stdout",
432
+ "output_type": "stream",
433
+ "text": [
434
+ "Epoch [1/5], Loss: 1.0192, Test Accuracy: 0.8080\n",
435
+ "Epoch [2/5], Loss: 0.4376, Test Accuracy: 0.8487\n",
436
+ "Epoch [3/5], Loss: 0.2590, Test Accuracy: 0.8334\n",
437
+ "Epoch [4/5], Loss: 0.1696, Test Accuracy: 0.8626\n",
438
+ "Epoch [5/5], Loss: 0.1257, Test Accuracy: 0.8621\n"
439
+ ]
440
+ }
441
+ ],
442
+ "source": [
443
+ "import torch\n",
444
+ "import torch.nn as nn\n",
445
+ "import torch.optim as optim\n",
446
+ "import torch.utils.checkpoint as checkpoint\n",
447
+ "from torchvision import datasets, transforms\n",
448
+ "from torch.utils.data import DataLoader\n",
449
+ "\n",
450
+ "import torch\n",
451
+ "import torch.nn as nn\n",
452
+ "import torchvision.models as models\n",
453
+ "\n",
454
+ "class EfficientNetCIFAR10(nn.Module):\n",
455
+ " def __init__(self, num_classes=10):\n",
456
+ " super(EfficientNetCIFAR10, self).__init__()\n",
457
+ " \n",
458
+ " # Load a pre-trained EfficientNet model\n",
459
+ " self.efficientnet = models.efficientnet_v2_l(weights=models.EfficientNet_V2_L_Weights.IMAGENET1K_V1)\n",
460
+ " \n",
461
+ " # Modify the classifier head for CIFAR-10 (10 classes)\n",
462
+ " self.efficientnet.classifier[1] = nn.Linear(self.efficientnet.classifier[1].in_features, num_classes)\n",
463
+ "\n",
464
+ " def forward(self, x):\n",
465
+ " return self.efficientnet(x)\n",
466
+ "\n",
467
+ "\n",
468
+ "# Load CIFAR-10 dataset\n",
469
+ "transform = transforms.Compose([\n",
470
+ " transforms.ToTensor(),\n",
471
+ " transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))\n",
472
+ "])\n",
473
+ "\n",
474
+ "train_dataset = datasets.CIFAR10(root='/home/23m1521/datasets', train=True, download=True, transform=transform)\n",
475
+ "train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True, num_workers=64)\n",
476
+ "\n",
477
+ "test_dataset = datasets.CIFAR10(root='/home/23m1521/datasets', train=False, download=True, transform=transform)\n",
478
+ "test_loader = DataLoader(test_dataset, batch_size=256, shuffle=False, num_workers=64)\n",
479
+ "\n",
480
+ "# Initialize model, loss function, and optimizer\n",
481
+ "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
482
+ "model = EfficientNetCIFAR10(num_classes=10).to(device)\n",
483
+ "criterion = nn.CrossEntropyLoss()\n",
484
+ "optimizer = optim.Adam(model.parameters(), lr=0.001)\n",
485
+ "\n",
486
+ "# Training loop\n",
487
+ "def train(model, train_loader, criterion, optimizer, device):\n",
488
+ " model.train()\n",
489
+ " running_loss = 0.0\n",
490
+ " for inputs, labels in train_loader:\n",
491
+ " inputs, labels = inputs.to(device), labels.to(device)\n",
492
+ "\n",
493
+ " optimizer.zero_grad()\n",
494
+ " outputs = model(inputs)\n",
495
+ " loss = criterion(outputs, labels)\n",
496
+ " loss.backward()\n",
497
+ " optimizer.step()\n",
498
+ "\n",
499
+ " running_loss += loss.item()\n",
500
+ " return running_loss / len(train_loader)\n",
501
+ "\n",
502
+ "# Testing loop\n",
503
+ "def test(model, test_loader, criterion, device):\n",
504
+ " model.eval()\n",
505
+ " correct = 0\n",
506
+ " total = 0\n",
507
+ " with torch.no_grad():\n",
508
+ " for inputs, labels in test_loader:\n",
509
+ " inputs, labels = inputs.to(device), labels.to(device)\n",
510
+ " outputs = model(inputs)\n",
511
+ " _, predicted = torch.max(outputs.data, 1)\n",
512
+ " total += labels.size(0)\n",
513
+ " correct += (predicted == labels).sum().item()\n",
514
+ " return correct / total\n",
515
+ "\n",
516
+ "# Train and test the model\n",
517
+ "num_epochs = 5\n",
518
+ "for epoch in range(num_epochs):\n",
519
+ " train_loss = train(model, train_loader, criterion, optimizer, device)\n",
520
+ " test_acc = test(model, test_loader, criterion, device)\n",
521
+ " print(f\"Epoch [{epoch+1}/{num_epochs}], Loss: {train_loss:.4f}, Test Accuracy: {test_acc:.4f}\")"
522
+ ]
523
+ }
524
+ ],
525
+ "metadata": {
526
+ "kernelspec": {
527
+ "display_name": "cuda_env2",
528
+ "language": "python",
529
+ "name": "python3"
530
+ },
531
+ "language_info": {
532
+ "codemirror_mode": {
533
+ "name": "ipython",
534
+ "version": 3
535
+ },
536
+ "file_extension": ".py",
537
+ "mimetype": "text/x-python",
538
+ "name": "python",
539
+ "nbconvert_exporter": "python",
540
+ "pygments_lexer": "ipython3",
541
+ "version": "3.12.2"
542
+ }
543
+ },
544
+ "nbformat": 4,
545
+ "nbformat_minor": 2
546
+ }
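A note on the EfficientNet training cell above: the GPU-stats thread is started with `args=(epoch_info,)`, so it captures the initial empty string and never sees the per-epoch reassignments, and calling `print_gpu_stats(epoch_info)` directly inside the epoch loop will spin until `stop_event` is set. Below is a minimal sketch of one way to share per-epoch status with a monitor thread; this is an illustrative variant, not the committed code, and it reports memory via `torch.cuda` rather than the `nvidia_smi` table:

```python
import threading
import time
import torch

status = {"epoch_info": ""}      # mutable holder shared with the monitor thread
stop_event = threading.Event()

def monitor(shared, interval=1.0):
    # Poll until the main thread signals completion.
    while not stop_event.is_set():
        if torch.cuda.is_available():
            used_mb = torch.cuda.memory_allocated() / 1024 ** 2
            print(f"GPU mem allocated: {used_mb:8.2f} MB | {shared['epoch_info']}")
        time.sleep(interval)

monitor_thread = threading.Thread(target=monitor, args=(status,), daemon=True)
monitor_thread.start()

num_epochs = 5
for epoch in range(num_epochs):
    # ... run train(...) and test(...) for one epoch here ...
    status["epoch_info"] = f"Epoch [{epoch + 1}/{num_epochs}] finished"

stop_event.set()
monitor_thread.join()
```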
DDPM/_3_Activation-Checkpointing-Sequential.ipynb ADDED
@@ -0,0 +1,216 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "metadata": {
7
+ "id": "CqFGp-OjP0_G"
8
+ },
9
+ "outputs": [],
10
+ "source": [
11
+ "import torch\n",
12
+ "from torch.autograd import Variable"
13
+ ]
14
+ },
15
+ {
16
+ "cell_type": "markdown",
17
+ "metadata": {
18
+ "id": "to7suvjJQJAM"
19
+ },
20
+ "source": [
21
+ "# [1] Checkpointing sequential models"
22
+ ]
23
+ },
24
+ {
25
+ "cell_type": "code",
26
+ "execution_count": 5,
27
+ "metadata": {
28
+ "colab": {
29
+ "base_uri": "https://localhost:8080/"
30
+ },
31
+ "id": "1YmlCf4MQEXV",
32
+ "outputId": "03833d29-11aa-4def-a9e4-650e349201a3"
33
+ },
34
+ "outputs": [
35
+ {
36
+ "data": {
37
+ "text/plain": [
38
+ "[Linear(in_features=100, out_features=50, bias=True),\n",
39
+ " ReLU(),\n",
40
+ " Linear(in_features=50, out_features=20, bias=True),\n",
41
+ " ReLU(),\n",
42
+ " Linear(in_features=20, out_features=5, bias=True),\n",
43
+ " ReLU()]"
44
+ ]
45
+ },
46
+ "execution_count": 5,
47
+ "metadata": {},
48
+ "output_type": "execute_result"
49
+ }
50
+ ],
51
+ "source": [
52
+ "from torch.utils.checkpoint import checkpoint_sequential\n",
53
+ "import torch.nn as nn\n",
54
+ "\n",
55
+ "model = nn.Sequential(\n",
56
+ " nn.Linear(100, 50),\n",
57
+ " nn.ReLU(),\n",
58
+ " nn.Linear(50, 20),\n",
59
+ " nn.ReLU(),\n",
60
+ " nn.Linear(20, 5),\n",
61
+ " nn.ReLU()\n",
62
+ ")\n",
63
+ "\n",
64
+ "input_var = Variable(torch.randn(1, 100), requires_grad=True)\n",
65
+ "segments = 2\n",
66
+ "\n",
67
+ "modules = [module for k, module in model._modules.items()]\n",
68
+ "modules"
69
+ ]
70
+ },
71
+ {
72
+ "cell_type": "code",
73
+ "execution_count": 7,
74
+ "metadata": {
75
+ "colab": {
76
+ "base_uri": "https://localhost:8080/"
77
+ },
78
+ "id": "aHSqU-keQaPe",
79
+ "outputId": "7ebc66fb-99ab-4d22-fa39-5710fb7ca2cd"
80
+ },
81
+ "outputs": [
82
+ {
83
+ "data": {
84
+ "text/plain": [
85
+ "tensor([[0.0000, 0.3800, 0.0000, 0.0000, 0.0000]], grad_fn=<ReluBackward0>)"
86
+ ]
87
+ },
88
+ "execution_count": 7,
89
+ "metadata": {},
90
+ "output_type": "execute_result"
91
+ }
92
+ ],
93
+ "source": [
94
+ "out = checkpoint_sequential(modules, segments, input_var, use_reentrant=False)\n",
95
+ "out"
96
+ ]
97
+ },
98
+ {
99
+ "cell_type": "code",
100
+ "execution_count": 8,
101
+ "metadata": {
102
+ "id": "Q94h7De4RBGA"
103
+ },
104
+ "outputs": [],
105
+ "source": [
106
+ "# run the backwards pass on the model. For backwards pass, for simplicity purpose,\n",
107
+ "# we won't calculate the loss and rather backprop on out.sum()\n",
108
+ "model.zero_grad()\n",
109
+ "out.sum().backward()"
110
+ ]
111
+ },
112
+ {
113
+ "cell_type": "code",
114
+ "execution_count": 9,
115
+ "metadata": {
116
+ "id": "LgNWA7fyRGAk"
117
+ },
118
+ "outputs": [],
119
+ "source": [
120
+ "# now we save the output and parameter gradients that we will use for comparison purposes with\n",
121
+ "# the non-checkpointed run.\n",
122
+ "output_checkpointed = out.data.clone()\n",
123
+ "grad_checkpointed = {}\n",
124
+ "for name, param in model.named_parameters():\n",
125
+ " grad_checkpointed[name] = param.grad.data.clone()"
126
+ ]
127
+ },
128
+ {
129
+ "cell_type": "markdown",
130
+ "metadata": {
131
+ "id": "qkdJd-B3RRWh"
132
+ },
133
+ "source": [
134
+ "Now that we have executed the checkpointed pass on the model, let's also run the non-checkpointed model and verify that the checkpoint API doesn't change the model outputs or the parameter gradients."
135
+ ]
136
+ },
137
+ {
138
+ "cell_type": "code",
139
+ "execution_count": 10,
140
+ "metadata": {
141
+ "id": "Ts5GQzxkRVrU"
142
+ },
143
+ "outputs": [],
144
+ "source": [
145
+ "# non-checkpointed run of the model\n",
146
+ "original = model\n",
147
+ "\n",
148
+ "# create a new variable using the same tensor data\n",
149
+ "x = Variable(input_var.data, requires_grad=True)\n",
150
+ "\n",
151
+ "# get the model output and save it to prevent any modifications\n",
152
+ "out = original(x)\n",
153
+ "out_not_checkpointed = out.data.clone()\n",
154
+ "\n",
155
+ "# calculate the gradient now and save the parameter gradients values\n",
156
+ "original.zero_grad()\n",
157
+ "out.sum().backward()\n",
158
+ "grad_not_checkpointed = {}\n",
159
+ "for name, param in model.named_parameters():\n",
160
+ " grad_not_checkpointed[name] = param.grad.data.clone()"
161
+ ]
162
+ },
163
+ {
164
+ "cell_type": "markdown",
165
+ "metadata": {
166
+ "id": "YiV1VBzyRX2Y"
167
+ },
168
+ "source": [
169
+ "Now that we have done the checkpointed and non-checkpointed pass of the model and saved the output and parameter gradients, let's compare their values"
170
+ ]
171
+ },
172
+ {
173
+ "cell_type": "code",
174
+ "execution_count": 13,
175
+ "metadata": {
176
+ "colab": {
177
+ "base_uri": "https://localhost:8080/"
178
+ },
179
+ "id": "v9Tj9o8VRYq2",
180
+ "outputId": "bd8a8100-d660-4858-eb48-4a85aca01c69"
181
+ },
182
+ "outputs": [
183
+ {
184
+ "name": "stdout",
185
+ "output_type": "stream",
186
+ "text": [
187
+ "Checkpointed and non-checkpointed results match!\n"
188
+ ]
189
+ }
190
+ ],
191
+ "source": [
192
+ "try:\n",
193
+ " assert torch.equal(output_checkpointed, out_not_checkpointed), \"Outputs do not match!\"\n",
194
+ " for name in grad_checkpointed:\n",
195
+ " assert torch.equal(grad_checkpointed[name], grad_not_checkpointed[name]), f\"Gradients for {name} do not match!\"\n",
196
+ " print(\"Checkpointed and non-checkpointed results match!\")\n",
197
+ "except AssertionError as e:\n",
198
+ " print(f\"Assertion failed: {e}\")"
199
+ ]
200
+ }
201
+ ],
202
+ "metadata": {
203
+ "colab": {
204
+ "provenance": []
205
+ },
206
+ "kernelspec": {
207
+ "display_name": "Python 3",
208
+ "name": "python3"
209
+ },
210
+ "language_info": {
211
+ "name": "python"
212
+ }
213
+ },
214
+ "nbformat": 4,
215
+ "nbformat_minor": 0
216
+ }
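For reference alongside the notebook above, a condensed sketch of the same checkpointed-versus-plain comparison. Assumptions: the same toy MLP; `checkpoint_sequential` is handed the `nn.Sequential` directly instead of an extracted module list, and plain tensors with `requires_grad=True` replace the deprecated `Variable` wrapper:

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential

model = nn.Sequential(
    nn.Linear(100, 50), nn.ReLU(),
    nn.Linear(50, 20), nn.ReLU(),
    nn.Linear(20, 5), nn.ReLU(),
)
x = torch.randn(1, 100, requires_grad=True)

# Checkpointed run: activations inside each segment are recomputed on backward.
out_ckpt = checkpoint_sequential(model, 2, x, use_reentrant=False)
model.zero_grad()
out_ckpt.sum().backward()
grads_ckpt = {n: p.grad.clone() for n, p in model.named_parameters()}

# Plain run for comparison.
out_ref = model(x)
model.zero_grad()
out_ref.sum().backward()
grads_ref = {n: p.grad.clone() for n, p in model.named_parameters()}

assert torch.allclose(out_ckpt, out_ref)
assert all(torch.allclose(grads_ckpt[n], grads_ref[n]) for n in grads_ref)
print("Checkpointed and non-checkpointed results match!")
```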
DDPM/_4_Activation-Checkpointing-VAE.ipynb ADDED
@@ -0,0 +1,444 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 7,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "gpu_ram_utilization_bytes = torch.cuda.memory_allocated()\n",
10
+ "gpu_ram_utilization_mb = gpu_ram_utilization_bytes / (1024 * 1024)\n",
11
+ "gpu_ram_total_bytes = torch.cuda.get_device_properties(0).total_memory\n",
12
+ "gpu_ram_percentage = (gpu_ram_utilization_bytes / gpu_ram_total_bytes) * 100"
13
+ ]
14
+ },
15
+ {
16
+ "cell_type": "code",
17
+ "execution_count": null,
18
+ "metadata": {},
19
+ "outputs": [],
20
+ "source": [
21
+ "gpu_ram_utilization_mb, gpu_ram_total_bytes"
22
+ ]
23
+ },
24
+ {
25
+ "cell_type": "code",
26
+ "execution_count": null,
27
+ "metadata": {
28
+ "colab": {
29
+ "base_uri": "https://localhost:8080/"
30
+ },
31
+ "id": "ellNFnP7f2Wx",
32
+ "outputId": "3adb85e1-f41a-433f-bd77-f1301abb7731"
33
+ },
34
+ "outputs": [],
35
+ "source": [
36
+ "import os\n",
37
+ "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n",
38
+ "\n",
39
+ "import psutil\n",
40
+ "import torch\n",
41
+ "from datetime import datetime\n",
42
+ "import time\n",
43
+ "import matplotlib.pyplot as plt\n",
44
+ "\n",
45
+ "\n",
46
+ "import torch\n",
47
+ "import torch.nn as nn\n",
48
+ "import torch.optim as optim\n",
49
+ "from torch.utils.data import DataLoader\n",
50
+ "from torchvision import datasets, transforms\n",
51
+ "import torch.nn.functional as F\n",
52
+ "\n",
53
+ "\n",
54
+ "\n",
55
+ "timestamps = []\n",
56
+ "cpu_ram_mb = []\n",
57
+ "cpu_ram_percent = []\n",
58
+ "gpu_ram_mb = []\n",
59
+ "gpu_ram_percent = []\n",
60
+ "\n",
61
+ "\n",
62
+ "\n",
63
+ "# --- System Utilization ---------------------------------------------------------------------------\n",
64
+ "def get_system_utilization():\n",
65
+ " current_time = datetime.now().strftime(\"%Y-%m-%d %H:%M:%S\")\n",
66
+ "\n",
67
+ " cpu_ram = psutil.virtual_memory()\n",
68
+ " cpu_ram_utilization_bytes = cpu_ram.used\n",
69
+ " cpu_ram_utilization_mb = cpu_ram_utilization_bytes / (1024 * 1024)\n",
70
+ " cpu_ram_percentage = cpu_ram.percent\n",
71
+ "\n",
72
+ " gpu_ram_utilization_mb = None\n",
73
+ " gpu_ram_percentage = None\n",
74
+ " if torch.cuda.is_available():\n",
75
+ " gpu_ram_utilization_bytes = torch.cuda.memory_allocated()\n",
76
+ " gpu_ram_utilization_mb = gpu_ram_utilization_bytes / (1024 * 1024)\n",
77
+ " gpu_ram_total_bytes = torch.cuda.get_device_properties(0).total_memory\n",
78
+ " gpu_ram_percentage = (gpu_ram_utilization_bytes / gpu_ram_total_bytes) * 100\n",
79
+ "\n",
80
+ " return {\n",
81
+ " \"time\": current_time,\n",
82
+ " \"cpu_ram_utilization_mb\": cpu_ram_utilization_mb,\n",
83
+ " \"cpu_ram_percentage\": cpu_ram_percentage,\n",
84
+ " \"gpu_ram_utilization_mb\": gpu_ram_utilization_mb,\n",
85
+ " \"gpu_ram_percentage\": gpu_ram_percentage\n",
86
+ " }\n",
87
+ "\n",
88
+ "\n",
89
+ "\n",
90
+ "def update_utilization_lists():\n",
91
+ " global timestamps, cpu_ram_mb, cpu_ram_percent, gpu_ram_mb, gpu_ram_percent\n",
92
+ "\n",
93
+ " utilization = get_system_utilization()\n",
94
+ "\n",
95
+ " timestamps.append(utilization[\"time\"])\n",
96
+ " cpu_ram_mb.append(utilization[\"cpu_ram_utilization_mb\"])\n",
97
+ " cpu_ram_percent.append(utilization[\"cpu_ram_percentage\"])\n",
98
+ " gpu_ram_mb.append(utilization[\"gpu_ram_utilization_mb\"])\n",
99
+ " gpu_ram_percent.append(utilization[\"gpu_ram_percentage\"])\n",
100
+ "\n",
101
+ "\n",
102
+ "\n",
103
+ "# --- Define the VAE model -------------------------------------------------------------------------\n",
104
+ "class VAE(nn.Module):\n",
105
+ " update_utilization_lists()\n",
106
+ " def __init__(self, latent_dim=20):\n",
107
+ " super(VAE, self).__init__()\n",
108
+ " self.latent_dim = latent_dim\n",
109
+ "\n",
110
+ " # Encoder\n",
111
+ " update_utilization_lists()\n",
112
+ " self.encoder = nn.Sequential(\n",
113
+ " nn.Linear(28 * 28, 512),\n",
114
+ " nn.ReLU(),\n",
115
+ " nn.Linear(512, 256),\n",
116
+ " nn.ReLU(),\n",
117
+ " nn.Linear(256, 2 * latent_dim) # Output mean and log variance\n",
118
+ " )\n",
119
+ "\n",
120
+ " # Decoder\n",
121
+ " update_utilization_lists()\n",
122
+ " self.decoder = nn.Sequential(\n",
123
+ " nn.Linear(latent_dim, 256),\n",
124
+ " nn.ReLU(),\n",
125
+ " nn.Linear(256, 512),\n",
126
+ " nn.ReLU(),\n",
127
+ " nn.Linear(512, 28 * 28),\n",
128
+ " nn.Sigmoid()\n",
129
+ " )\n",
130
+ "\n",
131
+ " def encode(self, x):\n",
132
+ " update_utilization_lists()\n",
133
+ " h = self.encoder(x)\n",
134
+ "\n",
135
+ " update_utilization_lists()\n",
136
+ " mu, logvar = h.chunk(2, dim=-1) # Split into mean and log variance\n",
137
+ "\n",
138
+ " update_utilization_lists()\n",
139
+ " return mu, logvar\n",
140
+ "\n",
141
+ " def reparameterize(self, mu, logvar):\n",
142
+ " update_utilization_lists()\n",
143
+ " std = torch.exp(0.5 * logvar)\n",
144
+ "\n",
145
+ " update_utilization_lists()\n",
146
+ " eps = torch.randn_like(std)\n",
147
+ "\n",
148
+ " update_utilization_lists()\n",
149
+ " return mu + eps * std\n",
150
+ "\n",
151
+ " def decode(self, z):\n",
152
+ " update_utilization_lists()\n",
153
+ " decoded = self.decoder(z)\n",
154
+ "\n",
155
+ " update_utilization_lists()\n",
156
+ " return decoded\n",
157
+ "\n",
158
+ " def forward(self, x):\n",
159
+ " update_utilization_lists()\n",
160
+ " mu, logvar = self.encode(x.view(-1, 28 * 28))\n",
161
+ "\n",
162
+ " z = self.reparameterize(mu, logvar)\n",
163
+ " return self.decode(z), mu, logvar\n",
164
+ "\n",
165
+ "\n",
166
+ "\n",
167
+ "# --- Loss function --------------------------------------------------------------------------------\n",
168
+ "def loss_function(recon_x, x, mu, logvar):\n",
169
+ " update_utilization_lists()\n",
170
+ " BCE = F.binary_cross_entropy(recon_x, x.view(-1, 28 * 28), reduction='sum')\n",
171
+ " \n",
172
+ " update_utilization_lists()\n",
173
+ " KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())\n",
174
+ " \n",
175
+ " update_utilization_lists()\n",
176
+ " return BCE + KLD\n",
177
+ "\n",
178
+ "\n",
179
+ "\n",
180
+ "# --- Load MNIST dataset ---------------------------------------------------------------------------\n",
181
+ "transform = transforms.Compose([transforms.ToTensor()])\n",
182
+ "train_dataset = datasets.MNIST(root='/home/23m1521/datasets/MNIST', train=True, download=True, transform=transform)\n",
183
+ "train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=os.cpu_count())\n",
184
+ "\n",
185
+ "\n",
186
+ "\n",
187
+ "# --- Initialize model, optimizer ------------------------------------------------------------------\n",
188
+ "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
189
+ "model = VAE(latent_dim=20).to(device)\n",
190
+ "optimizer = optim.Adam(model.parameters(), lr=1e-3)\n",
191
+ "\n",
192
+ "\n",
193
+ "\n",
194
+ "# --- Training loop --------------------------------------------------------------------------------\n",
195
+ "def train(epoch):\n",
196
+ " update_utilization_lists()\n",
197
+ " model.train()\n",
198
+ " \n",
199
+ " train_loss = 0\n",
200
+ " for batch_idx, (data, _) in enumerate(train_loader):\n",
201
+ " update_utilization_lists()\n",
202
+ " \n",
203
+ " data = data.to(device)\n",
204
+ " update_utilization_lists()\n",
205
+ " \n",
206
+ " optimizer.zero_grad()\n",
207
+ " update_utilization_lists()\n",
208
+ " \n",
209
+ " recon_batch, mu, logvar = model(data)\n",
210
+ " update_utilization_lists()\n",
211
+ " \n",
212
+ " loss = loss_function(recon_batch, data, mu, logvar)\n",
213
+ " update_utilization_lists()\n",
214
+ " \n",
215
+ " loss.backward()\n",
216
+ " update_utilization_lists()\n",
217
+ " \n",
218
+ " train_loss += loss.item()\n",
219
+ " update_utilization_lists()\n",
220
+ " \n",
221
+ " optimizer.step()\n",
222
+ " update_utilization_lists()\n",
223
+ "\n",
224
+ " if batch_idx % 100 == 0:\n",
225
+ " print(f'Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} '\n",
226
+ " f'({100. * batch_idx / len(train_loader):.0f}%)]\\tLoss: {loss.item() / len(data):.6f}')\n",
227
+ "\n",
228
+ " print(f'====> Epoch: {epoch} Average loss: {train_loss / len(train_loader.dataset):.4f}')\n",
229
+ "\n",
230
+ "\n",
231
+ "\n",
232
+ "# --- Train for 10 epochs --------------------------------------------------------------------------\n",
233
+ "for epoch in range(1,3):\n",
234
+ " update_utilization_lists()\n",
235
+ " train(epoch)\n",
236
+ " update_utilization_lists()"
237
+ ]
238
+ },
239
+ {
240
+ "cell_type": "code",
241
+ "execution_count": null,
242
+ "metadata": {
243
+ "colab": {
244
+ "base_uri": "https://localhost:8080/"
245
+ },
246
+ "id": "6M9KOwxshmZF",
247
+ "outputId": "274be81e-b8a7-4100-f6d8-235d5a8ffb6d"
248
+ },
249
+ "outputs": [],
250
+ "source": [
251
+ "print(\"CPU RAM (MB):\", min(cpu_ram_mb), max(cpu_ram_mb))\n",
252
+ "print(\"CPU RAM (%):\", min(cpu_ram_percent), max(cpu_ram_percent))\n",
253
+ "if torch.cuda.is_available():\n",
254
+ " print(\"GPU RAM (MB):\", min(gpu_ram_mb), max(gpu_ram_mb))\n",
255
+ " print(\"GPU RAM (%):\", min(gpu_ram_percent), max(gpu_ram_percent))"
256
+ ]
257
+ },
258
+ {
259
+ "cell_type": "code",
260
+ "execution_count": null,
261
+ "metadata": {
262
+ "colab": {
263
+ "base_uri": "https://localhost:8080/",
264
+ "height": 400
265
+ },
266
+ "id": "mKdK390Ehq7u",
267
+ "outputId": "524a035c-98c5-4c45-99c8-96a882007427"
268
+ },
269
+ "outputs": [],
270
+ "source": [
271
+ "plt.figure(figsize=(21, 8))\n",
272
+ "\n",
273
+ "# --- Plot CPU RAM Utilization (MB) ----------------------------------------------------------------\n",
274
+ "plt.subplot(2, 2, 1)\n",
275
+ "plt.plot(range(len(timestamps)), cpu_ram_mb, label=\"CPU RAM (MB)\")\n",
276
+ "plt.title(\"CPU RAM Utilization (MB)\")\n",
277
+ "plt.xlabel(\"Time\")\n",
278
+ "plt.ylabel(\"MB\")\n",
279
+ "plt.xticks(rotation=45)\n",
280
+ "plt.grid(True)\n",
281
+ "plt.legend()\n",
282
+ "\n",
283
+ "# --- Plot CPU RAM Utilization (%) -----------------------------------------------------------------\n",
284
+ "plt.subplot(2, 2, 2)\n",
285
+ "plt.plot(range(len(timestamps)), cpu_ram_percent, label=\"CPU RAM (%)\", color=\"orange\")\n",
286
+ "plt.title(\"CPU RAM Utilization (%)\")\n",
287
+ "plt.xlabel(\"Time\")\n",
288
+ "plt.ylabel(\"Percentage\")\n",
289
+ "plt.xticks(rotation=45)\n",
290
+ "plt.grid(True)\n",
291
+ "plt.legend()\n",
292
+ "\n",
293
+ "# --- Plot GPU RAM Utilization (MB) if GPU exists --------------------------------------------------\n",
294
+ "if torch.cuda.is_available():\n",
295
+ " plt.subplot(2, 2, 3)\n",
296
+ " plt.plot(range(len(timestamps)), gpu_ram_mb, label=\"GPU RAM (MB)\", color=\"green\")\n",
297
+ " plt.title(\"GPU RAM Utilization (MB)\")\n",
298
+ " plt.xlabel(\"Time\")\n",
299
+ " plt.ylabel(\"MB\")\n",
300
+ " plt.xticks(rotation=45)\n",
301
+ " plt.grid(True)\n",
302
+ " plt.legend()\n",
303
+ "\n",
304
+ "\n",
305
+ "# --- Plot GPU RAM Utilization (%) if GPU exists ---------------------------------------------------\n",
306
+ " plt.subplot(2, 2, 4)\n",
307
+ " plt.plot(range(len(timestamps)), gpu_ram_percent, label=\"GPU RAM (%)\", color=\"red\")\n",
308
+ " plt.title(\"GPU RAM Utilization (%)\")\n",
309
+ " plt.xlabel(\"Time\")\n",
310
+ " plt.ylabel(\"Percentage\")\n",
311
+ " plt.xticks(rotation=45)\n",
312
+ " plt.grid(True)\n",
313
+ " plt.legend()\n",
314
+ "\n",
315
+ "\n",
316
+ "plt.tight_layout()\n",
317
+ "plt.show()"
318
+ ]
319
+ },
320
+ {
321
+ "cell_type": "code",
322
+ "execution_count": null,
323
+ "metadata": {},
324
+ "outputs": [],
325
+ "source": [
326
+ "if torch.cuda.is_available():\n",
327
+ " fig.add_trace(\n",
328
+ " go.Scatter(x=list(range(len(timestamps))), y=gpu_ram_mb, mode='lines', name='GPU RAM (MB)', line=dict(color='green')),\n",
329
+ " row=2, col=1\n",
330
+ " )\n",
331
+ "fig.show() "
332
+ ]
333
+ },
334
+ {
335
+ "cell_type": "code",
336
+ "execution_count": null,
337
+ "metadata": {},
338
+ "outputs": [],
339
+ "source": [
340
+ "import plotly.graph_objects as go\n",
341
+ "from plotly.subplots import make_subplots\n",
342
+ "import torch\n",
343
+ "\n",
344
+ "# Create subplots\n",
345
+ "fig = make_subplots(\n",
346
+ " rows=2, cols=2,\n",
347
+ " subplot_titles=(\"CPU RAM Utilization (MB)\", \"CPU RAM Utilization (%)\",\n",
348
+ " \"GPU RAM Utilization (MB)\", \"GPU RAM Utilization (%)\")\n",
349
+ ")\n",
350
+ "\n",
351
+ "# Plot CPU RAM Utilization (MB)\n",
352
+ "fig.add_trace(\n",
353
+ " go.Scatter(x=list(range(len(timestamps))), y=cpu_ram_mb, mode='lines', name='CPU RAM (MB)'),\n",
354
+ " row=1, col=1\n",
355
+ ")\n",
356
+ "\n",
357
+ "# Plot CPU RAM Utilization (%)\n",
358
+ "fig.add_trace(\n",
359
+ " go.Scatter(x=list(range(len(timestamps))), y=cpu_ram_percent, mode='lines', name='CPU RAM (%)', line=dict(color='orange')),\n",
360
+ " row=1, col=2\n",
361
+ ")\n",
362
+ "\n",
363
+ "# Plot GPU RAM Utilization (MB) if GPU exists\n",
364
+ "if torch.cuda.is_available():\n",
365
+ " fig.add_trace(\n",
366
+ " go.Scatter(x=list(range(len(timestamps))), y=gpu_ram_mb, mode='lines', name='GPU RAM (MB)', line=dict(color='green')),\n",
367
+ " row=2, col=1\n",
368
+ " )\n",
369
+ "\n",
370
+ " # Plot GPU RAM Utilization (%)\n",
371
+ " fig.add_trace(\n",
372
+ " go.Scatter(x=list(range(len(timestamps))), y=gpu_ram_percent, mode='lines', name='GPU RAM (%)', line=dict(color='red')),\n",
373
+ " row=2, col=2\n",
374
+ " )\n",
375
+ "\n",
376
+ "# Update layout\n",
377
+ "fig.update_layout(\n",
378
+ " height=800, width=1200,\n",
379
+ " title_text=\"System Resource Utilization\",\n",
380
+ " showlegend=True\n",
381
+ ")\n",
382
+ "\n",
383
+ "fig.update_xaxes(title_text=\"Time\", tickangle=45)\n",
384
+ "fig.update_yaxes(title_text=\"MB or Percentage\")\n",
385
+ "\n",
386
+ "# Show plot\n",
387
+ "fig.show()"
388
+ ]
389
+ },
390
+ {
391
+ "cell_type": "code",
392
+ "execution_count": null,
393
+ "metadata": {
394
+ "colab": {
395
+ "base_uri": "https://localhost:8080/",
396
+ "height": 454
397
+ },
398
+ "id": "3MGfGd_Ojcrf",
399
+ "outputId": "f1091984-2658-4053-ff08-c7c300c08d0e"
400
+ },
401
+ "outputs": [],
402
+ "source": [
403
+ "plt.figure(figsize=(21, 4))\n",
404
+ "\n",
405
+ "r = 12000 # range(len(timestamps))\n",
406
+ "x, y = range(r), cpu_ram_mb[:r]\n",
407
+ "\n",
408
+ "plt.plot(x, y, label=\"CPU RAM (MB)\")\n",
409
+ "plt.title(\"CPU RAM Utilization (MB)\")\n",
410
+ "plt.xlabel(\"Time\")\n",
411
+ "plt.ylabel(\"MB\")\n",
412
+ "plt.xticks(rotation=45)\n",
413
+ "plt.grid(True)\n",
414
+ "plt.legend()\n",
415
+ "plt.tight_layout()\n",
416
+ "plt.show()"
417
+ ]
418
+ }
419
+ ],
420
+ "metadata": {
421
+ "colab": {
422
+ "provenance": []
423
+ },
424
+ "kernelspec": {
425
+ "display_name": "cuda_env2",
426
+ "language": "python",
427
+ "name": "python3"
428
+ },
429
+ "language_info": {
430
+ "codemirror_mode": {
431
+ "name": "ipython",
432
+ "version": 3
433
+ },
434
+ "file_extension": ".py",
435
+ "mimetype": "text/x-python",
436
+ "name": "python",
437
+ "nbconvert_exporter": "python",
438
+ "pygments_lexer": "ipython3",
439
+ "version": "3.12.2"
440
+ }
441
+ },
442
+ "nbformat": 4,
443
+ "nbformat_minor": 0
444
+ }
DDPM/_5_Activation-Ckpt-VAE-CelebA.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
Imgui/demo-newstyle.py ADDED
@@ -0,0 +1,298 @@
1
+ # -*- coding: utf-8 -*-
2
+ import os
3
+ import sys
4
+
5
+ # For Linux/Wayland users.
6
+ if os.getenv("XDG_SESSION_TYPE") == "wayland":
7
+ os.environ["XDG_SESSION_TYPE"] = "x11"
8
+
9
+ import glfw
10
+ import OpenGL.GL as gl
11
+ import imgui
12
+ from imgui.integrations.glfw import GlfwRenderer
13
+
14
+ active = {
15
+ "window": True,
16
+ "child": False,
17
+ "tooltip": False,
18
+ "menu bar": False,
19
+ "popup": False,
20
+ "popup modal": False,
21
+ "popup context item": False,
22
+ "popup context window": False,
23
+ "drag drop": False,
24
+ "group": False,
25
+ "tab bar": False,
26
+ "list box": False,
27
+ "popup context void": False,
28
+ "table": False,
29
+ }
30
+ path_to_font = None # "path/to/font.ttf"
31
+
32
+ opened_state = True
33
+
34
+ # Frame commands from the video
35
+ # def frame_commands():
36
+ # io = imgui.get_io()
37
+ # if io.key_ctrl and io.keys_down[glfw.KEY_Q]:
38
+ # sys.exit(0)
39
+ #
40
+ # if imgui.begin_main_menu_bar():
41
+ # if imgui.begin_menu("File"):
42
+ # clicked, selected = imgui.menu_item("Quit", "Ctrl+Q")
43
+ # if clicked:
44
+ # sys.exit(0)
45
+ # imgui.end_menu()
46
+ # imgui.end_main_menu_bar()
47
+ #
48
+ # with imgui.begin("A Window!"):
49
+ # if imgui.button("select"):
50
+ # imgui.open_popup("select-popup")
51
+ #
52
+ # try:
53
+ # with imgui.begin_popup("select-popup") as popup:
54
+ # if popup.opened:
55
+ # imgui.text("Select one")
56
+ # raise Exception
57
+ # except Exception:
58
+ # print("caught exception and no crash!")
59
+
60
+
61
+ def frame_commands():
62
+ io = imgui.get_io()
63
+
64
+ if io.key_ctrl and io.keys_down[glfw.KEY_Q]:
65
+ sys.exit(0)
66
+
67
+ with imgui.begin_main_menu_bar() as main_menu_bar:
68
+ if main_menu_bar.opened:
69
+ with imgui.begin_menu("File", True) as file_menu:
70
+ if file_menu.opened:
71
+ clicked_quit, selected_quit = imgui.menu_item("Quit", "Ctrl+Q")
72
+ if clicked_quit:
73
+ sys.exit(0)
74
+
75
+ # turn examples on/off
76
+ with imgui.begin("Active examples"):
77
+ for label, enabled in active.copy().items():
78
+ _, enabled = imgui.checkbox(label, enabled)
79
+ active[label] = enabled
80
+
81
+ if active["window"]:
82
+ with imgui.begin("Hello, Imgui!"):
83
+ imgui.text("Hello, World!")
84
+
85
+ if active["child"]:
86
+ with imgui.begin("Example: child region"):
87
+ with imgui.begin_child("region", 150, -50, border=True):
88
+ imgui.text("inside region")
89
+ imgui.text("outside region")
90
+
91
+ if active["tooltip"]:
92
+ with imgui.begin("Example: tooltip"):
93
+ imgui.button("Click me!")
94
+ if imgui.is_item_hovered():
95
+ with imgui.begin_tooltip():
96
+ imgui.text("This button is clickable.")
97
+
98
+ if active["menu bar"]:
99
+ try:
100
+ flags = imgui.WINDOW_MENU_BAR
101
+ with imgui.begin("Child Window - File Browser", flags=flags):
102
+ with imgui.begin_menu_bar() as menu_bar:
103
+ if menu_bar.opened:
104
+ with imgui.begin_menu('File') as file_menu:
105
+ if file_menu.opened:
106
+ clicked, state = imgui.menu_item('Close')
107
+ if clicked:
108
+ active["menu bar"] = False
109
+ raise Exception
110
+ except Exception:
111
+ print("exception handled")
112
+
113
+ if active["popup"]:
114
+ with imgui.begin("Example: simple popup"):
115
+ if imgui.button("select"):
116
+ imgui.open_popup("select-popup")
117
+ imgui.same_line()
118
+ with imgui.begin_popup("select-popup") as popup:
119
+ if popup.opened:
120
+ imgui.text("Select one")
121
+ imgui.separator()
122
+ imgui.selectable("One")
123
+ imgui.selectable("Two")
124
+ imgui.selectable("Three")
125
+
126
+ if active["popup modal"]:
127
+ with imgui.begin("Example: simple popup modal"):
128
+ if imgui.button("Open Modal popup"):
129
+ imgui.open_popup("select-popup-modal")
130
+ imgui.same_line()
131
+ with imgui.begin_popup_modal("select-popup-modal") as popup:
132
+ if popup.opened:
133
+ imgui.text("Select an option:")
134
+ imgui.separator()
135
+ imgui.selectable("One")
136
+ imgui.selectable("Two")
137
+ imgui.selectable("Three")
138
+
139
+ if active["popup context item"]:
140
+ with imgui.begin("Example: popup context view"):
141
+ imgui.text("Right-click to set value.")
142
+ with imgui.begin_popup_context_item("Item Context Menu") as popup:
143
+ if popup.opened:
144
+ imgui.selectable("Set to Zero")
145
+
146
+ if active["popup context window"]:
147
+ with imgui.begin("Example: popup context window"):
148
+ with imgui.begin_popup_context_window() as popup:
149
+ if popup.opened:
150
+ imgui.selectable("Clear")
151
+
152
+ if active["popup context void"]:
153
+ with imgui.begin_popup_context_void() as popup:
154
+ if popup.opened:
155
+ imgui.selectable("Clear")
156
+
157
+ if active["drag drop"]:
158
+ with imgui.begin("Example: drag and drop"):
159
+ imgui.button('source')
160
+ with imgui.begin_drag_drop_source() as src:
161
+ if src.dragging:
162
+ imgui.set_drag_drop_payload('itemtype', b'payload')
163
+ imgui.button('dragged source')
164
+ imgui.button('dest')
165
+ with imgui.begin_drag_drop_target() as dst:
166
+ if dst.hovered:
167
+ payload = imgui.accept_drag_drop_payload('itemtype')
168
+ if payload is not None:
169
+ print('Received:', payload)
170
+
171
+ if active["group"]:
172
+ with imgui.begin("Example: item groups"):
173
+ with imgui.begin_group():
174
+ imgui.text("First group (buttons):")
175
+ imgui.button("Button A")
176
+ imgui.button("Button B")
177
+ imgui.same_line(spacing=50)
178
+ with imgui.begin_group():
179
+ imgui.text("Second group (text and bullet texts):")
180
+ imgui.bullet_text("Bullet A")
181
+ imgui.bullet_text("Bullet B")
182
+
183
+ if active["tab bar"]:
184
+ with imgui.begin("Example Tab Bar"):
185
+ with imgui.begin_tab_bar("MyTabBar") as tab_bar:
186
+ if tab_bar.opened:
187
+ with imgui.begin_tab_item("Item 1") as item1:
188
+ if item1.opened:
189
+ imgui.text("Here is the tab content!")
190
+ with imgui.begin_tab_item("Item 2") as item2:
191
+ if item2.opened:
192
+ imgui.text("Another content...")
193
+ global opened_state
194
+ with imgui.begin_tab_item("Item 3", opened=opened_state) as item3:
195
+ opened_state = item3.opened
196
+ if item3.selected:
197
+ imgui.text("Hello Saylor!")
198
+
199
+ if active["list box"]:
200
+ with imgui.begin("Example: custom listbox"):
201
+ with imgui.begin_list_box("List", 200, 100) as list_box:
202
+ if list_box.opened:
203
+ imgui.selectable("Selected", True)
204
+ imgui.selectable("Not Selected", False)
205
+
206
+ if active["table"]:
207
+ with imgui.begin("Example: table"):
208
+ with imgui.begin_table("data", 2) as table:
209
+ if table.opened:
210
+ imgui.table_next_column()
211
+ imgui.table_header("A")
212
+ imgui.table_next_column()
213
+ imgui.table_header("B")
214
+
215
+ imgui.table_next_row()
216
+ imgui.table_next_column()
217
+ imgui.text("123")
218
+
219
+ imgui.table_next_column()
220
+ imgui.text("456")
221
+
222
+ imgui.table_next_row()
223
+ imgui.table_next_column()
224
+ imgui.text("789")
225
+
226
+ imgui.table_next_column()
227
+ imgui.text("111")
228
+
229
+ imgui.table_next_row()
230
+ imgui.table_next_column()
231
+ imgui.text("222")
232
+
233
+ imgui.table_next_column()
234
+ imgui.text("333")
235
+
236
+
237
+ def render_frame(impl, window, font):
238
+ glfw.poll_events()
239
+ impl.process_inputs()
240
+ imgui.new_frame()
241
+
242
+ gl.glClearColor(0.1, 0.1, 0.1, 1)
243
+ gl.glClear(gl.GL_COLOR_BUFFER_BIT)
244
+
245
+ if font is not None:
246
+ imgui.push_font(font)
247
+ frame_commands()
248
+ if font is not None:
249
+ imgui.pop_font()
250
+
251
+ imgui.render()
252
+ impl.render(imgui.get_draw_data())
253
+ glfw.swap_buffers(window)
254
+
255
+
256
+ def impl_glfw_init():
257
+ width, height = 1600, 900
258
+ window_name = "minimal ImGui/GLFW3 example"
259
+
260
+ if not glfw.init():
261
+ print("Could not initialize OpenGL context")
262
+ sys.exit(1)
263
+
264
+ glfw.window_hint(glfw.CONTEXT_VERSION_MAJOR, 3)
265
+ glfw.window_hint(glfw.CONTEXT_VERSION_MINOR, 3)
266
+ glfw.window_hint(glfw.OPENGL_PROFILE, glfw.OPENGL_CORE_PROFILE)
267
+ glfw.window_hint(glfw.OPENGL_FORWARD_COMPAT, gl.GL_TRUE)
268
+
269
+ window = glfw.create_window(int(width), int(height), window_name, None, None)
270
+ glfw.make_context_current(window)
271
+
272
+ if not window:
273
+ glfw.terminate()
274
+ print("Could not initialize Window")
275
+ sys.exit(1)
276
+
277
+ return window
278
+
279
+
280
+ def main():
281
+ imgui.create_context()
282
+ window = impl_glfw_init()
283
+
284
+ impl = GlfwRenderer(window)
285
+
286
+ io = imgui.get_io()
287
+ jb = io.fonts.add_font_from_file_ttf(path_to_font, 30) if path_to_font is not None else None
288
+ impl.refresh_font_texture()
289
+
290
+ while not glfw.window_should_close(window):
291
+ render_frame(impl, window, jb)
292
+
293
+ impl.shutdown()
294
+ glfw.terminate()
295
+
296
+
297
+ if __name__ == "__main__":
298
+ main()
Imgui/demo.py ADDED
@@ -0,0 +1,301 @@
1
+ # pip install glfw
2
+ # pip install PyOpenGL
3
+ # pip install imgui
4
+
5
+
6
+ # -*- coding: utf-8 -*-
7
+ import os
8
+ import sys
9
+
10
+ # For Linux/Wayland users.
11
+ if os.getenv("XDG_SESSION_TYPE") == "wayland":
12
+ os.environ["XDG_SESSION_TYPE"] = "x11"
13
+
14
+ import glfw
15
+ import OpenGL.GL as gl
16
+ import imgui
17
+ from imgui.integrations.glfw import GlfwRenderer
18
+
19
+ active = {
20
+ "window": True,
21
+ "child": False,
22
+ "tooltip": False,
23
+ "menu bar": False,
24
+ "popup": False,
25
+ "popup modal": False,
26
+ "popup context item": False,
27
+ "popup context window": False,
28
+ "drag drop": False,
29
+ "group": False,
30
+ "tab bar": False,
31
+ "list box": False,
32
+ "popup context void": False,
33
+ "table": False,
34
+ }
35
+
36
+ path_to_font = None # "path/to/font.ttf"
37
+
38
+ opened_state = True
39
+
40
+
41
+ def frame_commands():
42
+ gl.glClearColor(0.1, 0.1, 0.1, 1)
43
+ gl.glClear(gl.GL_COLOR_BUFFER_BIT)
44
+
45
+ io = imgui.get_io()
46
+
47
+ if io.key_ctrl and io.keys_down[glfw.KEY_Q]:
48
+ sys.exit(0)
49
+
50
+ if imgui.begin_main_menu_bar():
51
+ if imgui.begin_menu("File", True):
52
+ clicked_quit, selected_quit = imgui.menu_item("Quit", "Ctrl+Q", False, True)
53
+
54
+ if clicked_quit:
55
+ sys.exit(0)
56
+
57
+ imgui.end_menu()
58
+ imgui.end_main_menu_bar()
59
+
60
+ # turn windows on/off
61
+ imgui.begin("Active examples")
62
+ for label, enabled in active.copy().items():
63
+ _, enabled = imgui.checkbox(label, enabled)
64
+ active[label] = enabled
65
+ imgui.end()
66
+
67
+ if active["window"]:
68
+ imgui.begin("Hello, Imgui!")
69
+ imgui.text("Hello, World!")
70
+ imgui.end()
71
+
72
+ if active["child"]:
73
+ imgui.begin("Example: child region")
74
+ imgui.begin_child("region", 150, -50, border=True)
75
+ imgui.text("inside region")
76
+ imgui.end_child()
77
+ imgui.text("outside region")
78
+ imgui.end()
79
+
80
+ if active["tooltip"]:
81
+ imgui.begin("Example: tooltip")
82
+ imgui.button("Click me!")
83
+ if imgui.is_item_hovered():
84
+ imgui.begin_tooltip()
85
+ imgui.text("This button is clickable.")
86
+ imgui.end_tooltip()
87
+ imgui.end()
88
+
89
+ if active["menu bar"]:
90
+ try:
91
+ flags = imgui.WINDOW_MENU_BAR
92
+ imgui.begin("Child Window - File Browser", flags=flags)
93
+ if imgui.begin_menu_bar():
94
+ if imgui.begin_menu('File'):
95
+ clicked, state = imgui.menu_item('Close')
96
+ if clicked:
97
+ active["menu bar"] = False
98
+ raise Exception
99
+ imgui.end_menu()
100
+ imgui.end_menu_bar()
101
+ imgui.end()
102
+ except Exception:
103
+ print("exception caught, but too late!")
104
+
105
+ if active["popup"]:
106
+ imgui.begin("Example: simple popup")
107
+ if imgui.button("select"):
108
+ imgui.open_popup("select-popup")
109
+ imgui.same_line()
110
+ if imgui.begin_popup("select-popup"):
111
+ imgui.text("Select one")
112
+ imgui.separator()
113
+ imgui.selectable("One")
114
+ imgui.selectable("Two")
115
+ imgui.selectable("Three")
116
+ imgui.end_popup()
117
+ imgui.end()
118
+
119
+ if active["popup modal"]:
120
+ imgui.begin("Example: simple popup modal")
121
+ if imgui.button("Open Modal popup"):
122
+ imgui.open_popup("select-popup-modal")
123
+ imgui.same_line()
124
+ if imgui.begin_popup_modal("select-popup-modal")[0]:
125
+ imgui.text("Select an option:")
126
+ imgui.separator()
127
+ imgui.selectable("One")
128
+ imgui.selectable("Two")
129
+ imgui.selectable("Three")
130
+ imgui.end_popup()
131
+ imgui.end()
132
+
133
+ if active["popup context item"]:
134
+ imgui.begin("Example: popup context view")
135
+ imgui.text("Right-click to set value.")
136
+ if imgui.begin_popup_context_item("Item Context Menu"):
137
+ imgui.selectable("Set to Zero")
138
+ imgui.end_popup()
139
+ imgui.end()
140
+
141
+ if active["popup context window"]:
142
+ imgui.begin("Example: popup context window")
143
+ if imgui.begin_popup_context_window():
144
+ imgui.selectable("Clear")
145
+ imgui.end_popup()
146
+ imgui.end()
147
+
148
+ if active["popup context void"]:
149
+ if imgui.begin_popup_context_void():
150
+ imgui.selectable("Clear")
151
+ imgui.end_popup()
152
+
153
+ if active["drag drop"]:
154
+ imgui.begin("Example: drag and drop")
155
+ imgui.button('source')
156
+ if imgui.begin_drag_drop_source():
157
+ imgui.set_drag_drop_payload('itemtype', b'payload')
158
+ imgui.button('dragged source')
159
+ imgui.end_drag_drop_source()
160
+ imgui.button('dest')
161
+ if imgui.begin_drag_drop_target():
162
+ payload = imgui.accept_drag_drop_payload('itemtype')
163
+ if payload is not None:
164
+ print('Received:', payload)
165
+ imgui.end_drag_drop_target()
166
+ imgui.end()
167
+
168
+ if active["group"]:
169
+ imgui.begin("Example: item groups")
170
+ imgui.begin_group()
171
+ imgui.text("First group (buttons):")
172
+ imgui.button("Button A")
173
+ imgui.button("Button B")
174
+ imgui.end_group()
175
+ imgui.same_line(spacing=50)
176
+ imgui.begin_group()
177
+ imgui.text("Second group (text and bullet texts):")
178
+ imgui.bullet_text("Bullet A")
179
+ imgui.bullet_text("Bullet B")
180
+ imgui.end_group()
181
+ imgui.end()
182
+
183
+ if active["tab bar"]:
184
+ imgui.begin("Example Tab Bar")
185
+ if imgui.begin_tab_bar("MyTabBar"):
186
+ if imgui.begin_tab_item("Item 1")[0]:
187
+ imgui.text("Here is the tab content!")
188
+ imgui.end_tab_item()
189
+ if imgui.begin_tab_item("Item 2")[0]:
190
+ imgui.text("Another content...")
191
+ imgui.end_tab_item()
192
+ global opened_state
193
+ selected, opened_state = imgui.begin_tab_item("Item 3", opened=opened_state)
194
+ if selected:
195
+ imgui.text("Hello Saylor!")
196
+ imgui.end_tab_item()
197
+ imgui.end_tab_bar()
198
+ imgui.end()
199
+
200
+ if active["list box"]:
201
+ imgui.begin("Example: custom listbox")
202
+ if imgui.begin_list_box("List", 200, 100):
203
+ imgui.selectable("Selected", True)
204
+ imgui.selectable("Not Selected", False)
205
+ imgui.end_list_box()
206
+ imgui.end()
207
+
208
+ if active["table"]:
209
+ imgui.begin("Example: table")
210
+ if imgui.begin_table("data", 2):
211
+ imgui.table_next_column()
212
+ imgui.table_header("A")
213
+ imgui.table_next_column()
214
+ imgui.table_header("B")
215
+
216
+ imgui.table_next_row()
217
+ imgui.table_next_column()
218
+ imgui.text("123")
219
+
220
+ imgui.table_next_column()
221
+ imgui.text("456")
222
+
223
+ imgui.table_next_row()
224
+ imgui.table_next_column()
225
+ imgui.text("789")
226
+
227
+ imgui.table_next_column()
228
+ imgui.text("111")
229
+
230
+ imgui.table_next_row()
231
+ imgui.table_next_column()
232
+ imgui.text("222")
233
+
234
+ imgui.table_next_column()
235
+ imgui.text("333")
236
+ imgui.end_table()
237
+ imgui.end()
238
+
239
+
240
+ def render_frame(impl, window, font):
241
+ glfw.poll_events()
242
+ impl.process_inputs()
243
+ imgui.new_frame()
244
+
245
+ gl.glClearColor(0.1, 0.1, 0.1, 1)
246
+ gl.glClear(gl.GL_COLOR_BUFFER_BIT)
247
+
248
+ if font is not None:
249
+ imgui.push_font(font)
250
+ frame_commands()
251
+ if font is not None:
252
+ imgui.pop_font()
253
+
254
+ imgui.render()
255
+ impl.render(imgui.get_draw_data())
256
+ glfw.swap_buffers(window)
257
+
258
+
259
+ def impl_glfw_init():
260
+ width, height = 1600, 900
261
+ window_name = "minimal ImGui/GLFW3 example"
262
+
263
+ if not glfw.init():
264
+ print("Could not initialize OpenGL context")
265
+ sys.exit(1)
266
+
267
+ glfw.window_hint(glfw.CONTEXT_VERSION_MAJOR, 3)
268
+ glfw.window_hint(glfw.CONTEXT_VERSION_MINOR, 3)
269
+ glfw.window_hint(glfw.OPENGL_PROFILE, glfw.OPENGL_CORE_PROFILE)
270
+ glfw.window_hint(glfw.OPENGL_FORWARD_COMPAT, gl.GL_TRUE)
271
+
272
+ window = glfw.create_window(int(width), int(height), window_name, None, None)
273
+ glfw.make_context_current(window)
274
+
275
+ if not window:
276
+ glfw.terminate()
277
+ print("Could not initialize Window")
278
+ sys.exit(1)
279
+
280
+ return window
281
+
282
+
283
+ def main():
284
+ imgui.create_context()
285
+ window = impl_glfw_init()
286
+
287
+ impl = GlfwRenderer(window)
288
+
289
+ io = imgui.get_io()
290
+ jb = io.fonts.add_font_from_file_ttf(path_to_font, 30) if path_to_font is not None else None
291
+ impl.refresh_font_texture()
292
+
293
+ while not glfw.window_should_close(window):
294
+ render_frame(impl, window, jb)
295
+
296
+ impl.shutdown()
297
+ glfw.terminate()
298
+
299
+
300
+ if __name__ == "__main__":
301
+ main()
Imgui/imgui.ini ADDED
@@ -0,0 +1,25 @@
1
+ [Window][Debug##Default]
2
+ Pos=60,60
3
+ Size=400,400
4
+ Collapsed=0
5
+
6
+ [Window][Active examples]
7
+ Pos=21,83
8
+ Size=179,353
9
+ Collapsed=0
10
+
11
+ [Window][Hello, Imgui!]
12
+ Pos=60,60
13
+ Size=107,48
14
+ Collapsed=0
15
+
16
+ [Window][Example: table]
17
+ Pos=60,60
18
+ Size=66,103
19
+ Collapsed=0
20
+
21
+ [Window][Example: drag and drop]
22
+ Pos=60,60
23
+ Size=66,77
24
+ Collapsed=0
25
+
LDM/notebooks/_1_Main.ipynb ADDED
@@ -0,0 +1,1481 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "## Imports"
8
+ ]
9
+ },
10
+ {
11
+ "cell_type": "code",
12
+ "execution_count": 7,
13
+ "metadata": {},
14
+ "outputs": [
15
+ {
16
+ "data": {
17
+ "text/plain": [
18
+ "device(type='cuda')"
19
+ ]
20
+ },
21
+ "execution_count": 7,
22
+ "metadata": {},
23
+ "output_type": "execute_result"
24
+ }
25
+ ],
26
+ "source": [
27
+ "import os\n",
28
+ "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n",
29
+ "\n",
30
+ "import torch\n",
31
+ "import torch.nn as nn\n",
32
+ "import numpy as np\n",
33
+ "from collections import namedtuple\n",
34
+ "\n",
35
+ "import pandas as pd\n",
36
+ "import torchvision as tv\n",
37
+ "from torchvision.transforms import v2\n",
38
+ "from tqdm.auto import tqdm, trange\n",
39
+ "\n",
40
+ "import yaml\n",
41
+ "from dotdict import DotDict\n",
42
+ "import random\n",
43
+ "import torch.hub\n",
44
+ "from torch.utils.data import Dataset, DataLoader\n",
45
+ "from torchvision.utils import make_grid\n",
46
+ "\n",
47
+ "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
48
+ "device"
49
+ ]
50
+ },
51
+ {
52
+ "cell_type": "markdown",
53
+ "metadata": {},
54
+ "source": [
55
+ "### *LPIPS*: Learned Perceptual Image Patch Similarity"
56
+ ]
57
+ },
58
+ {
59
+ "cell_type": "code",
60
+ "execution_count": 8,
61
+ "metadata": {},
62
+ "outputs": [],
63
+ "source": [
64
+ "class vgg16(nn.Module):\n",
65
+ " def __init__(self):\n",
66
+ " super(vgg16, self).__init__()\n",
67
+ " vgg_pretrained_features = tv.models.vgg16(\n",
68
+ " weights=tv.models.VGG16_Weights.IMAGENET1K_V1\n",
69
+ " ).features\n",
70
+ " self.slice1 = torch.nn.Sequential()\n",
71
+ " self.slice2 = torch.nn.Sequential()\n",
72
+ " self.slice3 = torch.nn.Sequential()\n",
73
+ " self.slice4 = torch.nn.Sequential()\n",
74
+ " self.slice5 = torch.nn.Sequential()\n",
75
+ " self.N_slices = 5\n",
76
+ " for x in range(4):\n",
77
+ " self.slice1.add_module(str(x), vgg_pretrained_features[x])\n",
78
+ " for x in range(4, 9):\n",
79
+ " self.slice2.add_module(str(x), vgg_pretrained_features[x])\n",
80
+ " for x in range(9, 16):\n",
81
+ " self.slice3.add_module(str(x), vgg_pretrained_features[x])\n",
82
+ " for x in range(16, 23):\n",
83
+ " self.slice4.add_module(str(x), vgg_pretrained_features[x])\n",
84
+ " for x in range(23, 30):\n",
85
+ " self.slice5.add_module(str(x), vgg_pretrained_features[x])\n",
86
+ " \n",
87
+ " self.eval()\n",
88
+ " for param in self.parameters():\n",
89
+ " param.requires_grad = False\n",
90
+ "\n",
91
+ " def forward(self, X):\n",
92
+ " h1 = self.slice1(X)\n",
93
+ " h2 = self.slice2(h1)\n",
94
+ " h3 = self.slice3(h2)\n",
95
+ " h4 = self.slice4(h3)\n",
96
+ " h5 = self.slice5(h4)\n",
97
+ " vgg_outputs = namedtuple(\"VggOutputs\", ['h1', 'h2', 'h3', 'h4', 'h5'])\n",
98
+ " out = vgg_outputs(h1, h2, h3, h4, h5)\n",
99
+ " return out\n",
100
+ "\n",
101
+ "\n",
102
+ "def _spatial_average(in_tens, keepdim=True):\n",
103
+ " return in_tens.mean([2, 3], keepdim=keepdim)\n",
104
+ "\n",
105
+ "\n",
106
+ "def _normalize_tensor(in_feat, eps= 1e-8):\n",
107
+ " norm_factor = torch.sqrt(eps + torch.sum(in_feat**2, dim=1, keepdim=True))\n",
108
+ " return in_feat / norm_factor\n",
109
+ "\n",
110
+ "\n",
111
+ "class ScalingLayer(nn.Module):\n",
112
+ " def __init__(self):\n",
113
+ " super(ScalingLayer, self).__init__()\n",
114
+ " # Imagnet normalization for (0-1)\n",
115
+ " # mean = [0.485, 0.456, 0.406]\n",
116
+ " # std = [0.229, 0.224, 0.225]\n",
117
+ "\n",
118
+ " self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])\n",
119
+ " self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None])\n",
120
+ "\n",
121
+ " def forward(self, inp):\n",
122
+ " return (inp - self.shift) / self.scale\n",
123
+ "\n",
124
+ "\n",
125
+ "class NetLinLayer(nn.Module):\n",
126
+ " ''' A single linear layer which does a 1x1 conv '''\n",
127
+ " def __init__(self, chn_in, chn_out=1, use_dropout=False):\n",
128
+ " super(NetLinLayer, self).__init__()\n",
129
+ " layers = [nn.Dropout(), ] if (use_dropout) else []\n",
130
+ " layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ]\n",
131
+ " self.model = nn.Sequential(*layers)\n",
132
+ "\n",
133
+ " def forward(self, x):\n",
134
+ " return self.model(x)\n",
135
+ "\n",
136
+ "\n",
137
+ "class LPIPS(nn.Module):\n",
138
+ " def __init__(self, net='vgg', version='0.1', use_dropout=True):\n",
139
+ " super(LPIPS, self).__init__()\n",
140
+ " self.version = version\n",
141
+ " self.scaling_layer = ScalingLayer()\n",
142
+ " self.chns = [64, 128, 256, 512, 512]\n",
143
+ " self.L = len(self.chns)\n",
144
+ " self.net = vgg16()\n",
145
+ " self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)\n",
146
+ " self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)\n",
147
+ " self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)\n",
148
+ " self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)\n",
149
+ " self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)\n",
150
+ " self.lins = nn.ModuleList([self.lin0, self.lin1, self.lin2, self.lin3, self.lin4])\n",
151
+ "\n",
152
+ " # --- Orignal url --------------------\n",
153
+ " # weights_url = f\"https://github.com/richzhang/PerceptualSimilarity/raw/master/lpips/weights/v{version}/{net}.pth\"\n",
154
+ " \n",
155
+ " # --- Orignal Forked url -------------\n",
156
+ " weights_url = f\"https://github.com/akuresonite/PerceptualSimilarity-Forked/raw/master/lpips/weights/v{version}/{net}.pth\"\n",
157
+ " \n",
158
+ " # --- Orignal torchmetric url --------\n",
159
+ " # weights_url = \"https://github.com/Lightning-AI/torchmetrics/raw/master/src/torchmetrics/functional/image/lpips_models/vgg.pth\"\n",
160
+ " \n",
161
+ " state_dict = torch.hub.load_state_dict_from_url(weights_url, map_location='cpu')\n",
162
+ " self.load_state_dict(state_dict, strict=False)\n",
163
+ " \n",
164
+ " self.eval()\n",
165
+ " for param in self.parameters():\n",
166
+ " param.requires_grad = False\n",
167
+ "\n",
168
+ " def forward(self, in0, in1, normalize=False):\n",
169
+ " # Scale the inputs to -1 to +1 range if input in [0,1]\n",
170
+ " if normalize:\n",
171
+ " in0 = 2 * in0 - 1\n",
172
+ " in1 = 2 * in1 - 1\n",
173
+ "\n",
174
+ " in0_input, in1_input = self.scaling_layer(in0), self.scaling_layer(in1)\n",
175
+ " # in0_input, in1_input = in0, in1\n",
176
+ " \n",
177
+ " outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input)\n",
178
+ " \n",
179
+ " diffs = {}\n",
180
+ " for kk in range(self.L):\n",
181
+ " feats0 = _normalize_tensor(outs0[kk])\n",
182
+ " feats1 = _normalize_tensor(outs1[kk])\n",
183
+ " diffs[kk] = (feats0 - feats1) ** 2\n",
184
+ " \n",
185
+ " res = [_spatial_average(self.lins[kk](diffs[kk]), keepdim=True) for kk in range(self.L)]\n",
186
+ " val = sum(res)\n",
187
+ " return val.reshape(-1)"
188
+ ]
189
+ },
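A brief usage sketch for the LPIPS module defined above (my addition, not part of the notebook): it is used as a frozen perceptual distance between reconstructions and targets. Inputs are expected in [-1, 1]; pass `normalize=True` for tensors in [0, 1].

```python
# Sketch only: assumes `device` and a batch `images` in [-1, 1], plus a
# reconstruction `recon` of the same shape (here just a copy as a placeholder).
lpips_loss = LPIPS().to(device).eval()

with torch.no_grad():
    recon = images.clone()
    d = lpips_loss(recon, images)                                        # already in [-1, 1]
    d01 = lpips_loss((recon + 1) / 2, (images + 1) / 2, normalize=True)  # [0, 1] inputs

print(d.shape)  # one distance per batch element, e.g. torch.Size([32])
```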
190
+ {
191
+ "cell_type": "markdown",
192
+ "metadata": {},
193
+ "source": [
194
+ "### Discriminator"
195
+ ]
196
+ },
197
+ {
198
+ "cell_type": "code",
199
+ "execution_count": 9,
200
+ "metadata": {},
201
+ "outputs": [],
202
+ "source": [
203
+ "class Discriminator(nn.Module):\n",
204
+ " r\"\"\"\n",
205
+ " PatchGAN Discriminator.\n",
206
+ " Rather than taking IMG_CHANNELSxIMG_HxIMG_W all the way to\n",
207
+ " 1 scalar value , we instead predict grid of values.\n",
208
+ " Where each grid is prediction of how likely\n",
209
+ " the discriminator thinks that the image patch corresponding\n",
210
+ " to the grid cell is real\n",
211
+ " \"\"\"\n",
212
+ "\n",
213
+ " def __init__(\n",
214
+ " self,\n",
215
+ " im_channels=3,\n",
216
+ " conv_channels=[64, 128, 256],\n",
217
+ " kernels=[4, 4, 4, 4],\n",
218
+ " strides=[2, 2, 2, 1],\n",
219
+ " paddings=[1, 1, 1, 1],\n",
220
+ " ):\n",
221
+ " super().__init__()\n",
222
+ " self.im_channels = im_channels\n",
223
+ " activation = nn.LeakyReLU(0.2)\n",
224
+ " layers_dim = [self.im_channels] + conv_channels + [1]\n",
225
+ " self.layers = nn.ModuleList(\n",
226
+ " [\n",
227
+ " nn.Sequential(\n",
228
+ " nn.Conv2d(\n",
229
+ " layers_dim[i],\n",
230
+ " layers_dim[i + 1],\n",
231
+ " kernel_size=kernels[i],\n",
232
+ " stride=strides[i],\n",
233
+ " padding=paddings[i],\n",
234
+ " bias=False if i != 0 else True,\n",
235
+ " ),\n",
236
+ " (\n",
237
+ " nn.BatchNorm2d(layers_dim[i + 1])\n",
238
+ " if i != len(layers_dim) - 2 and i != 0\n",
239
+ " else nn.Identity()\n",
240
+ " ),\n",
241
+ " activation if i != len(layers_dim) - 2 else nn.Identity(),\n",
242
+ " )\n",
243
+ " for i in range(len(layers_dim) - 1)\n",
244
+ " ]\n",
245
+ " )\n",
246
+ "\n",
247
+ " def forward(self, x):\n",
248
+ " out = x\n",
249
+ " for layer in self.layers:\n",
250
+ " out = layer(out)\n",
251
+ " return out"
252
+ ]
253
+ },
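To make the "grid of values" in the docstring concrete, a small hedged check of the patch-grid output shape (my addition): with kernel 4, stride 2, padding 1, each of the first three convolutions halves the spatial size and the final stride-1 convolution reduces it by one, so a 256x256 input yields a 31x31 grid of per-patch logits.

```python
# Sketch only: assumes `device` from the imports cell above.
disc = Discriminator(im_channels=3).to(device)
dummy = torch.randn(4, 3, 256, 256, device=device)
print(disc(dummy).shape)  # expected: torch.Size([4, 1, 31, 31])
```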
254
+ {
255
+ "cell_type": "markdown",
256
+ "metadata": {},
257
+ "source": [
258
+ "### *VQVAE*"
259
+ ]
260
+ },
261
+ {
262
+ "cell_type": "code",
263
+ "execution_count": 10,
264
+ "metadata": {},
265
+ "outputs": [],
266
+ "source": [
267
+ "class DownBlock(nn.Module):\n",
268
+ " r\"\"\"\n",
269
+ " Down conv block with attention.\n",
270
+ " Sequence of following block\n",
271
+ " 1. Resnet block with time embedding\n",
272
+ " 2. Attention block\n",
273
+ " 3. Downsample\n",
274
+ " \"\"\"\n",
275
+ "\n",
276
+ " def __init__(\n",
277
+ " self,\n",
278
+ " in_channels,\n",
279
+ " out_channels,\n",
280
+ " t_emb_dim,\n",
281
+ " down_sample,\n",
282
+ " num_heads,\n",
283
+ " num_layers,\n",
284
+ " attn,\n",
285
+ " norm_channels,\n",
286
+ " cross_attn=False,\n",
287
+ " context_dim=None,\n",
288
+ " ):\n",
289
+ " super().__init__()\n",
290
+ " self.num_layers = num_layers\n",
291
+ " self.down_sample = down_sample\n",
292
+ " self.attn = attn\n",
293
+ " self.context_dim = context_dim\n",
294
+ " self.cross_attn = cross_attn\n",
295
+ " self.t_emb_dim = t_emb_dim\n",
296
+ " self.resnet_conv_first = nn.ModuleList(\n",
297
+ " [\n",
298
+ " nn.Sequential(\n",
299
+ " nn.GroupNorm(norm_channels, in_channels if i == 0 else out_channels),\n",
300
+ " nn.SiLU(),\n",
301
+ " nn.Conv2d(\n",
302
+ " in_channels if i == 0 else out_channels,\n",
303
+ " out_channels,\n",
304
+ " kernel_size=3,\n",
305
+ " stride=1,\n",
306
+ " padding=1,\n",
307
+ " ),\n",
308
+ " )\n",
309
+ " for i in range(num_layers)\n",
310
+ " ]\n",
311
+ " )\n",
312
+ " if self.t_emb_dim is not None:\n",
313
+ " self.t_emb_layers = nn.ModuleList(\n",
314
+ " [\n",
315
+ " nn.Sequential(nn.SiLU(), nn.Linear(self.t_emb_dim, out_channels))\n",
316
+ " for _ in range(num_layers)\n",
317
+ " ]\n",
318
+ " )\n",
319
+ " self.resnet_conv_second = nn.ModuleList(\n",
320
+ " [\n",
321
+ " nn.Sequential(\n",
322
+ " nn.GroupNorm(norm_channels, out_channels),\n",
323
+ " nn.SiLU(),\n",
324
+ " nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),\n",
325
+ " )\n",
326
+ " for _ in range(num_layers)\n",
327
+ " ]\n",
328
+ " )\n",
329
+ "\n",
330
+ " if self.attn:\n",
331
+ " self.attention_norms = nn.ModuleList(\n",
332
+ " [nn.GroupNorm(norm_channels, out_channels) for _ in range(num_layers)]\n",
333
+ " )\n",
334
+ "\n",
335
+ " self.attentions = nn.ModuleList(\n",
336
+ " [\n",
337
+ " nn.MultiheadAttention(out_channels, num_heads, batch_first=True)\n",
338
+ " for _ in range(num_layers)\n",
339
+ " ]\n",
340
+ " )\n",
341
+ " if self.cross_attn:\n",
342
+ " assert context_dim is not None, \"Context Dimension must be passed for cross attention\"\n",
343
+ " self.cross_attention_norms = nn.ModuleList(\n",
344
+ " [nn.GroupNorm(norm_channels, out_channels) for _ in range(num_layers)]\n",
345
+ " )\n",
346
+ " self.cross_attentions = nn.ModuleList(\n",
347
+ " [\n",
348
+ " nn.MultiheadAttention(out_channels, num_heads, batch_first=True)\n",
349
+ " for _ in range(num_layers)\n",
350
+ " ]\n",
351
+ " )\n",
352
+ " self.context_proj = nn.ModuleList(\n",
353
+ " [nn.Linear(context_dim, out_channels) for _ in range(num_layers)]\n",
354
+ " )\n",
355
+ " self.residual_input_conv = nn.ModuleList(\n",
356
+ " [\n",
357
+ " nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=1)\n",
358
+ " for i in range(num_layers)\n",
359
+ " ]\n",
360
+ " )\n",
361
+ " self.down_sample_conv = (\n",
362
+ " nn.Conv2d(out_channels, out_channels, 4, 2, 1) if self.down_sample else nn.Identity()\n",
363
+ " )\n",
364
+ "\n",
365
+ " def forward(self, x, t_emb=None, context=None):\n",
366
+ " out = x\n",
367
+ " for i in range(self.num_layers):\n",
368
+ " # Resnet block of Unet\n",
369
+ "\n",
370
+ " resnet_input = out\n",
371
+ " out = self.resnet_conv_first[i](out)\n",
372
+ " if self.t_emb_dim is not None:\n",
373
+ " out = out + self.t_emb_layers[i](t_emb)[:, :, None, None]\n",
374
+ " out = self.resnet_conv_second[i](out)\n",
375
+ " out = out + self.residual_input_conv[i](resnet_input)\n",
376
+ "\n",
377
+ " if self.attn:\n",
378
+ " # Attention block of Unet\n",
379
+ "\n",
380
+ " batch_size, channels, h, w = out.shape\n",
381
+ " in_attn = out.reshape(batch_size, channels, h * w)\n",
382
+ " in_attn = self.attention_norms[i](in_attn)\n",
383
+ " in_attn = in_attn.transpose(1, 2)\n",
384
+ " out_attn, _ = self.attentions[i](in_attn, in_attn, in_attn)\n",
385
+ " out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)\n",
386
+ " out = out + out_attn\n",
387
+ " if self.cross_attn:\n",
388
+ " assert (\n",
389
+ " context is not None\n",
390
+ " ), \"context cannot be None if cross attention layers are used\"\n",
391
+ " batch_size, channels, h, w = out.shape\n",
392
+ " in_attn = out.reshape(batch_size, channels, h * w)\n",
393
+ " in_attn = self.cross_attention_norms[i](in_attn)\n",
394
+ " in_attn = in_attn.transpose(1, 2)\n",
395
+ " assert context.shape[0] == x.shape[0] and context.shape[-1] == self.context_dim\n",
396
+ " context_proj = self.context_proj[i](context)\n",
397
+ " out_attn, _ = self.cross_attentions[i](in_attn, context_proj, context_proj)\n",
398
+ " out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)\n",
399
+ " out = out + out_attn\n",
400
+ " # Downsample\n",
401
+ "\n",
402
+ " out = self.down_sample_conv(out)\n",
403
+ " return out\n",
404
+ "\n",
405
+ "\n",
406
+ "class MidBlock(nn.Module):\n",
407
+ " r\"\"\"\n",
408
+ " Mid conv block with attention.\n",
409
+ " Sequence of following blocks\n",
410
+ " 1. Resnet block with time embedding\n",
411
+ " 2. Attention block\n",
412
+ " 3. Resnet block with time embedding\n",
413
+ " \"\"\"\n",
414
+ "\n",
415
+ " def __init__(\n",
416
+ " self,\n",
417
+ " in_channels,\n",
418
+ " out_channels,\n",
419
+ " t_emb_dim,\n",
420
+ " num_heads,\n",
421
+ " num_layers,\n",
422
+ " norm_channels,\n",
423
+ " cross_attn=None,\n",
424
+ " context_dim=None,\n",
425
+ " ):\n",
426
+ " super().__init__()\n",
427
+ " self.num_layers = num_layers\n",
428
+ " self.t_emb_dim = t_emb_dim\n",
429
+ " self.context_dim = context_dim\n",
430
+ " self.cross_attn = cross_attn\n",
431
+ " self.resnet_conv_first = nn.ModuleList(\n",
432
+ " [\n",
433
+ " nn.Sequential(\n",
434
+ " nn.GroupNorm(norm_channels, in_channels if i == 0 else out_channels),\n",
435
+ " nn.SiLU(),\n",
436
+ " nn.Conv2d(\n",
437
+ " in_channels if i == 0 else out_channels,\n",
438
+ " out_channels,\n",
439
+ " kernel_size=3,\n",
440
+ " stride=1,\n",
441
+ " padding=1,\n",
442
+ " ),\n",
443
+ " )\n",
444
+ " for i in range(num_layers + 1)\n",
445
+ " ]\n",
446
+ " )\n",
447
+ "\n",
448
+ " if self.t_emb_dim is not None:\n",
449
+ " self.t_emb_layers = nn.ModuleList(\n",
450
+ " [\n",
451
+ " nn.Sequential(nn.SiLU(), nn.Linear(t_emb_dim, out_channels))\n",
452
+ " for _ in range(num_layers + 1)\n",
453
+ " ]\n",
454
+ " )\n",
455
+ " self.resnet_conv_second = nn.ModuleList(\n",
456
+ " [\n",
457
+ " nn.Sequential(\n",
458
+ " nn.GroupNorm(norm_channels, out_channels),\n",
459
+ " nn.SiLU(),\n",
460
+ " nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),\n",
461
+ " )\n",
462
+ " for _ in range(num_layers + 1)\n",
463
+ " ]\n",
464
+ " )\n",
465
+ "\n",
466
+ " self.attention_norms = nn.ModuleList(\n",
467
+ " [nn.GroupNorm(norm_channels, out_channels) for _ in range(num_layers)]\n",
468
+ " )\n",
469
+ "\n",
470
+ " self.attentions = nn.ModuleList(\n",
471
+ " [\n",
472
+ " nn.MultiheadAttention(out_channels, num_heads, batch_first=True)\n",
473
+ " for _ in range(num_layers)\n",
474
+ " ]\n",
475
+ " )\n",
476
+ " if self.cross_attn:\n",
477
+ " assert context_dim is not None, \"Context Dimension must be passed for cross attention\"\n",
478
+ " self.cross_attention_norms = nn.ModuleList(\n",
479
+ " [nn.GroupNorm(norm_channels, out_channels) for _ in range(num_layers)]\n",
480
+ " )\n",
481
+ " self.cross_attentions = nn.ModuleList(\n",
482
+ " [\n",
483
+ " nn.MultiheadAttention(out_channels, num_heads, batch_first=True)\n",
484
+ " for _ in range(num_layers)\n",
485
+ " ]\n",
486
+ " )\n",
487
+ " self.context_proj = nn.ModuleList(\n",
488
+ " [nn.Linear(context_dim, out_channels) for _ in range(num_layers)]\n",
489
+ " )\n",
490
+ " self.residual_input_conv = nn.ModuleList(\n",
491
+ " [\n",
492
+ " nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=1)\n",
493
+ " for i in range(num_layers + 1)\n",
494
+ " ]\n",
495
+ " )\n",
496
+ "\n",
497
+ " def forward(self, x, t_emb=None, context=None):\n",
498
+ " out = x\n",
499
+ "\n",
500
+ " # First resnet block\n",
501
+ "\n",
502
+ " resnet_input = out\n",
503
+ " out = self.resnet_conv_first[0](out)\n",
504
+ " if self.t_emb_dim is not None:\n",
505
+ " out = out + self.t_emb_layers[0](t_emb)[:, :, None, None]\n",
506
+ " out = self.resnet_conv_second[0](out)\n",
507
+ " out = out + self.residual_input_conv[0](resnet_input)\n",
508
+ "\n",
509
+ " for i in range(self.num_layers):\n",
510
+ " # Attention Block\n",
511
+ "\n",
512
+ " batch_size, channels, h, w = out.shape\n",
513
+ " in_attn = out.reshape(batch_size, channels, h * w)\n",
514
+ " in_attn = self.attention_norms[i](in_attn)\n",
515
+ " in_attn = in_attn.transpose(1, 2)\n",
516
+ " out_attn, _ = self.attentions[i](in_attn, in_attn, in_attn)\n",
517
+ " out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)\n",
518
+ " out = out + out_attn\n",
519
+ "\n",
520
+ " if self.cross_attn:\n",
521
+ " assert (\n",
522
+ " context is not None\n",
523
+ " ), \"context cannot be None if cross attention layers are used\"\n",
524
+ " batch_size, channels, h, w = out.shape\n",
525
+ " in_attn = out.reshape(batch_size, channels, h * w)\n",
526
+ " in_attn = self.cross_attention_norms[i](in_attn)\n",
527
+ " in_attn = in_attn.transpose(1, 2)\n",
528
+ " assert context.shape[0] == x.shape[0] and context.shape[-1] == self.context_dim\n",
529
+ " context_proj = self.context_proj[i](context)\n",
530
+ " out_attn, _ = self.cross_attentions[i](in_attn, context_proj, context_proj)\n",
531
+ " out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)\n",
532
+ " out = out + out_attn\n",
533
+ " # Resnet Block\n",
534
+ "\n",
535
+ " resnet_input = out\n",
536
+ " out = self.resnet_conv_first[i + 1](out)\n",
537
+ " if self.t_emb_dim is not None:\n",
538
+ " out = out + self.t_emb_layers[i + 1](t_emb)[:, :, None, None]\n",
539
+ " out = self.resnet_conv_second[i + 1](out)\n",
540
+ " out = out + self.residual_input_conv[i + 1](resnet_input)\n",
541
+ " return out\n",
542
+ "\n",
543
+ "\n",
544
+ "class UpBlock(nn.Module):\n",
545
+ " r\"\"\"\n",
546
+ " Up conv block with attention.\n",
547
+ " Sequence of following blocks\n",
548
+ " 1. Upsample\n",
549
+ " 1. Concatenate Down block output\n",
550
+ " 2. Resnet block with time embedding\n",
551
+ " 3. Attention Block\n",
552
+ " \"\"\"\n",
553
+ "\n",
554
+ " def __init__(\n",
555
+ " self,\n",
556
+ " in_channels,\n",
557
+ " out_channels,\n",
558
+ " t_emb_dim,\n",
559
+ " up_sample,\n",
560
+ " num_heads,\n",
561
+ " num_layers,\n",
562
+ " attn,\n",
563
+ " norm_channels,\n",
564
+ " ):\n",
565
+ " super().__init__()\n",
566
+ " self.num_layers = num_layers\n",
567
+ " self.up_sample = up_sample\n",
568
+ " self.t_emb_dim = t_emb_dim\n",
569
+ " self.attn = attn\n",
570
+ " self.resnet_conv_first = nn.ModuleList(\n",
571
+ " [\n",
572
+ " nn.Sequential(\n",
573
+ " nn.GroupNorm(norm_channels, in_channels if i == 0 else out_channels),\n",
574
+ " nn.SiLU(),\n",
575
+ " nn.Conv2d(\n",
576
+ " in_channels if i == 0 else out_channels,\n",
577
+ " out_channels,\n",
578
+ " kernel_size=3,\n",
579
+ " stride=1,\n",
580
+ " padding=1,\n",
581
+ " ),\n",
582
+ " )\n",
583
+ " for i in range(num_layers)\n",
584
+ " ]\n",
585
+ " )\n",
586
+ "\n",
587
+ " if self.t_emb_dim is not None:\n",
588
+ " self.t_emb_layers = nn.ModuleList(\n",
589
+ " [\n",
590
+ " nn.Sequential(nn.SiLU(), nn.Linear(t_emb_dim, out_channels))\n",
591
+ " for _ in range(num_layers)\n",
592
+ " ]\n",
593
+ " )\n",
594
+ " self.resnet_conv_second = nn.ModuleList(\n",
595
+ " [\n",
596
+ " nn.Sequential(\n",
597
+ " nn.GroupNorm(norm_channels, out_channels),\n",
598
+ " nn.SiLU(),\n",
599
+ " nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),\n",
600
+ " )\n",
601
+ " for _ in range(num_layers)\n",
602
+ " ]\n",
603
+ " )\n",
604
+ " if self.attn:\n",
605
+ " self.attention_norms = nn.ModuleList(\n",
606
+ " [nn.GroupNorm(norm_channels, out_channels) for _ in range(num_layers)]\n",
607
+ " )\n",
608
+ "\n",
609
+ " self.attentions = nn.ModuleList(\n",
610
+ " [\n",
611
+ " nn.MultiheadAttention(out_channels, num_heads, batch_first=True)\n",
612
+ " for _ in range(num_layers)\n",
613
+ " ]\n",
614
+ " )\n",
615
+ " self.residual_input_conv = nn.ModuleList(\n",
616
+ " [\n",
617
+ " nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=1)\n",
618
+ " for i in range(num_layers)\n",
619
+ " ]\n",
620
+ " )\n",
621
+ " self.up_sample_conv = (\n",
622
+ " nn.ConvTranspose2d(in_channels, in_channels, 4, 2, 1)\n",
623
+ " if self.up_sample\n",
624
+ " else nn.Identity()\n",
625
+ " )\n",
626
+ "\n",
627
+ " def forward(self, x, out_down=None, t_emb=None):\n",
628
+ " # Upsample\n",
629
+ "\n",
630
+ " x = self.up_sample_conv(x)\n",
631
+ "\n",
632
+ " # Concat with Downblock output\n",
633
+ "\n",
634
+ " if out_down is not None:\n",
635
+ " x = torch.cat([x, out_down], dim=1)\n",
636
+ " out = x\n",
637
+ " for i in range(self.num_layers):\n",
638
+ " # Resnet Block\n",
639
+ "\n",
640
+ " resnet_input = out\n",
641
+ " out = self.resnet_conv_first[i](out)\n",
642
+ " if self.t_emb_dim is not None:\n",
643
+ " out = out + self.t_emb_layers[i](t_emb)[:, :, None, None]\n",
644
+ " out = self.resnet_conv_second[i](out)\n",
645
+ " out = out + self.residual_input_conv[i](resnet_input)\n",
646
+ "\n",
647
+ " # Self Attention\n",
648
+ "\n",
649
+ " if self.attn:\n",
650
+ " batch_size, channels, h, w = out.shape\n",
651
+ " in_attn = out.reshape(batch_size, channels, h * w)\n",
652
+ " in_attn = self.attention_norms[i](in_attn)\n",
653
+ " in_attn = in_attn.transpose(1, 2)\n",
654
+ " out_attn, _ = self.attentions[i](in_attn, in_attn, in_attn)\n",
655
+ " out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)\n",
656
+ " out = out + out_attn\n",
657
+ " return out\n",
658
+ "\n",
659
+ "\n",
660
+ "class VQVAE(nn.Module):\n",
661
+ " def __init__(self, im_channels, model_config):\n",
662
+ " super().__init__()\n",
663
+ " self.down_channels = model_config.down_channels\n",
664
+ " self.mid_channels = model_config.mid_channels\n",
665
+ " self.down_sample = model_config.down_sample\n",
666
+ " self.num_down_layers = model_config.num_down_layers\n",
667
+ " self.num_mid_layers = model_config.num_mid_layers\n",
668
+ " self.num_up_layers = model_config.num_up_layers\n",
669
+ "\n",
670
+ " # To disable attention in Downblock of Encoder and Upblock of Decoder\n",
671
+ " self.attns = model_config.attn_down\n",
672
+ "\n",
673
+ " # Latent Dimension\n",
674
+ " self.z_channels = model_config.z_channels\n",
675
+ " self.codebook_size = model_config.codebook_size\n",
676
+ " self.norm_channels = model_config.norm_channels\n",
677
+ " self.num_heads = model_config.num_heads\n",
678
+ "\n",
679
+ " # Assertion to validate the channel information\n",
680
+ " assert self.mid_channels[0] == self.down_channels[-1]\n",
681
+ " assert self.mid_channels[-1] == self.down_channels[-1]\n",
682
+ " assert len(self.down_sample) == len(self.down_channels) - 1\n",
683
+ " assert len(self.attns) == len(self.down_channels) - 1\n",
684
+ "\n",
685
+ " # Wherever we use downsampling in encoder correspondingly use\n",
686
+ " # upsampling in decoder\n",
687
+ " self.up_sample = list(reversed(self.down_sample))\n",
688
+ "\n",
689
+ " ##################### Encoder ######################\n",
690
+ " self.encoder_conv_in = nn.Conv2d(\n",
691
+ " im_channels, self.down_channels[0], kernel_size=3, padding=(1, 1)\n",
692
+ " )\n",
693
+ "\n",
694
+ " # Downblock + Midblock\n",
695
+ " self.encoder_layers = nn.ModuleList([])\n",
696
+ " for i in range(len(self.down_channels) - 1):\n",
697
+ " self.encoder_layers.append(\n",
698
+ " DownBlock(\n",
699
+ " self.down_channels[i],\n",
700
+ " self.down_channels[i + 1],\n",
701
+ " t_emb_dim=None,\n",
702
+ " down_sample=self.down_sample[i],\n",
703
+ " num_heads=self.num_heads,\n",
704
+ " num_layers=self.num_down_layers,\n",
705
+ " attn=self.attns[i],\n",
706
+ " norm_channels=self.norm_channels,\n",
707
+ " )\n",
708
+ " )\n",
709
+ " self.encoder_mids = nn.ModuleList([])\n",
710
+ " for i in range(len(self.mid_channels) - 1):\n",
711
+ " self.encoder_mids.append(\n",
712
+ " MidBlock(\n",
713
+ " self.mid_channels[i],\n",
714
+ " self.mid_channels[i + 1],\n",
715
+ " t_emb_dim=None,\n",
716
+ " num_heads=self.num_heads,\n",
717
+ " num_layers=self.num_mid_layers,\n",
718
+ " norm_channels=self.norm_channels,\n",
719
+ " )\n",
720
+ " )\n",
721
+ " self.encoder_norm_out = nn.GroupNorm(self.norm_channels, self.down_channels[-1])\n",
722
+ " self.encoder_conv_out = nn.Conv2d(\n",
723
+ " self.down_channels[-1], self.z_channels, kernel_size=3, padding=1\n",
724
+ " )\n",
725
+ "\n",
726
+ " # Pre Quantization Convolution\n",
727
+ " self.pre_quant_conv = nn.Conv2d(self.z_channels, self.z_channels, kernel_size=1)\n",
728
+ "\n",
729
+ " # Codebook\n",
730
+ " self.embedding = nn.Embedding(self.codebook_size, self.z_channels)\n",
731
+ " ####################################################\n",
732
+ "\n",
733
+ " ##################### Decoder ######################\n",
734
+ "\n",
735
+ " # Post Quantization Convolution\n",
736
+ " self.post_quant_conv = nn.Conv2d(self.z_channels, self.z_channels, kernel_size=1)\n",
737
+ " self.decoder_conv_in = nn.Conv2d(\n",
738
+ " self.z_channels, self.mid_channels[-1], kernel_size=3, padding=(1, 1)\n",
739
+ " )\n",
740
+ "\n",
741
+ " # Midblock + Upblock\n",
742
+ " self.decoder_mids = nn.ModuleList([])\n",
743
+ " for i in reversed(range(1, len(self.mid_channels))):\n",
744
+ " self.decoder_mids.append(\n",
745
+ " MidBlock(\n",
746
+ " self.mid_channels[i],\n",
747
+ " self.mid_channels[i - 1],\n",
748
+ " t_emb_dim=None,\n",
749
+ " num_heads=self.num_heads,\n",
750
+ " num_layers=self.num_mid_layers,\n",
751
+ " norm_channels=self.norm_channels,\n",
752
+ " )\n",
753
+ " )\n",
754
+ " self.decoder_layers = nn.ModuleList([])\n",
755
+ " for i in reversed(range(1, len(self.down_channels))):\n",
756
+ " self.decoder_layers.append(\n",
757
+ " UpBlock(\n",
758
+ " self.down_channels[i],\n",
759
+ " self.down_channels[i - 1],\n",
760
+ " t_emb_dim=None,\n",
761
+ " up_sample=self.down_sample[i - 1],\n",
762
+ " num_heads=self.num_heads,\n",
763
+ " num_layers=self.num_up_layers,\n",
764
+ " attn=self.attns[i - 1],\n",
765
+ " norm_channels=self.norm_channels,\n",
766
+ " )\n",
767
+ " )\n",
768
+ " self.decoder_norm_out = nn.GroupNorm(self.norm_channels, self.down_channels[0])\n",
769
+ " self.decoder_conv_out = nn.Conv2d(\n",
770
+ " self.down_channels[0], im_channels, kernel_size=3, padding=1\n",
771
+ " )\n",
772
+ "\n",
773
+ " def quantize(self, x):\n",
774
+ " B, C, H, W = x.shape\n",
775
+ "\n",
776
+ " # B, C, H, W -> B, H, W, C\n",
777
+ " x = x.permute(0, 2, 3, 1)\n",
778
+ "\n",
779
+ " # B, H, W, C -> B, H*W, C\n",
780
+ " x = x.reshape(x.size(0), -1, x.size(-1))\n",
781
+ "\n",
782
+ " # Find nearest embedding/codebook vector\n",
783
+ " # dist between (B, H*W, C) and (B, K, C) -> (B, H*W, K)\n",
784
+ " dist = torch.cdist(x, self.embedding.weight[None, :].repeat((x.size(0), 1, 1)))\n",
785
+ " # (B, H*W)\n",
786
+ " min_encoding_indices = torch.argmin(dist, dim=-1)\n",
787
+ "\n",
788
+ " # Replace encoder output with nearest codebook\n",
789
+ " # quant_out -> B*H*W, C\n",
790
+ " quant_out = torch.index_select(self.embedding.weight, 0, min_encoding_indices.view(-1))\n",
791
+ "\n",
792
+ " # x -> B*H*W, C\n",
793
+ " x = x.reshape((-1, x.size(-1)))\n",
794
+ " commmitment_loss = torch.mean((quant_out.detach() - x) ** 2)\n",
795
+ " codebook_loss = torch.mean((quant_out - x.detach()) ** 2)\n",
796
+ " quantize_losses = {\"codebook_loss\": codebook_loss, \"commitment_loss\": commmitment_loss}\n",
797
+ " # Straight through estimation\n",
798
+ " quant_out = x + (quant_out - x).detach()\n",
799
+ "\n",
800
+ " # quant_out -> B, C, H, W\n",
801
+ " quant_out = quant_out.reshape((B, H, W, C)).permute(0, 3, 1, 2)\n",
802
+ " min_encoding_indices = min_encoding_indices.reshape(\n",
803
+ " (-1, quant_out.size(-2), quant_out.size(-1))\n",
804
+ " )\n",
805
+ " return quant_out, quantize_losses, min_encoding_indices\n",
806
+ "\n",
807
+ " def encode(self, x):\n",
808
+ " out = self.encoder_conv_in(x)\n",
809
+ " for idx, down in enumerate(self.encoder_layers):\n",
810
+ " out = down(out)\n",
811
+ " for mid in self.encoder_mids:\n",
812
+ " out = mid(out)\n",
813
+ " out = self.encoder_norm_out(out)\n",
814
+ " out = nn.SiLU()(out)\n",
815
+ " out = self.encoder_conv_out(out)\n",
816
+ " out = self.pre_quant_conv(out)\n",
817
+ " out, quant_losses, _ = self.quantize(out)\n",
818
+ " return out, quant_losses\n",
819
+ "\n",
820
+ " def decode(self, z):\n",
821
+ " out = z\n",
822
+ " out = self.post_quant_conv(out)\n",
823
+ " out = self.decoder_conv_in(out)\n",
824
+ " for mid in self.decoder_mids:\n",
825
+ " out = mid(out)\n",
826
+ " for idx, up in enumerate(self.decoder_layers):\n",
827
+ " out = up(out)\n",
828
+ " out = self.decoder_norm_out(out)\n",
829
+ " out = nn.SiLU()(out)\n",
830
+ " out = self.decoder_conv_out(out)\n",
831
+ " return out\n",
832
+ "\n",
833
+ " def forward(self, x):\n",
834
+ " z, quant_losses = self.encode(x)\n",
835
+ " out = self.decode(z)\n",
836
+ " return out, z, quant_losses"
837
+ ]
838
+ },
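The quantize() method above combines a nearest-codebook lookup with the straight-through estimator. A toy sketch of the same steps on a small tensor (illustration only; the sizes here are made up):

```python
import torch
import torch.nn as nn

codebook = nn.Embedding(8, 4)                  # 8 codes of dimension 4
z_e = torch.randn(16, 4, requires_grad=True)   # 16 encoder output vectors

dist = torch.cdist(z_e, codebook.weight)       # (16, 8) pairwise distances
idx = dist.argmin(dim=-1)                      # index of nearest code per vector
z_q = codebook.weight[idx]                     # quantized vectors, shape (16, 4)

commitment_loss = torch.mean((z_q.detach() - z_e) ** 2)  # pulls encoder toward codes
codebook_loss = torch.mean((z_q - z_e.detach()) ** 2)    # pulls codes toward encoder

# Straight-through estimator: forward pass uses z_q, gradients flow to z_e as-is.
z_q_st = z_e + (z_q - z_e).detach()
```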
839
+ {
840
+ "cell_type": "markdown",
841
+ "metadata": {},
842
+ "source": [
843
+ "### Configuration"
844
+ ]
845
+ },
846
+ {
847
+ "cell_type": "code",
848
+ "execution_count": 12,
849
+ "metadata": {},
850
+ "outputs": [],
851
+ "source": [
852
+ "config_path = \"/home/23m1521/ashish/MTP/LDM/scripts/config.yaml\"\n",
853
+ "with open(config_path, 'r') as file:\n",
854
+ " Config = yaml.safe_load(file)\n",
855
+ "\n",
856
+ "Config = DotDict.from_dict(Config)\n",
857
+ "dataset_config = Config.dataset_params\n",
858
+ "diffusion_config = Config.diffusion_params\n",
859
+ "model_config = Config.model_params\n",
860
+ "train_config = Config.train_params"
861
+ ]
862
+ },
863
+ {
864
+ "cell_type": "markdown",
865
+ "metadata": {},
866
+ "source": [
867
+ "### MNIST Dataset"
868
+ ]
869
+ },
870
+ {
871
+ "cell_type": "code",
872
+ "execution_count": 13,
873
+ "metadata": {},
874
+ "outputs": [
875
+ {
876
+ "name": "stdout",
877
+ "output_type": "stream",
878
+ "text": [
879
+ "Files found: 70000\n"
880
+ ]
881
+ }
882
+ ],
883
+ "source": [
884
+ "datadir = r\"/home/23m1521/datasets/mnist_images/data\"\n",
885
+ "\n",
886
+ "def walkDIR(folder_path, include=None):\n",
887
+ " file_list = []\n",
888
+ " for root, _, files in os.walk(folder_path):\n",
889
+ " for file in files:\n",
890
+ " if include is None or any(file.endswith(ext) for ext in include):\n",
891
+ " file_list.append(os.path.join(root, file))\n",
892
+ " print(\"Files found:\", len(file_list))\n",
893
+ " return file_list\n",
894
+ "\n",
895
+ "files = walkDIR(datadir, include=['.png', '.jpeg', '.jpg'])\n",
896
+ "df = pd.DataFrame(files, columns=['image_path'])\n",
897
+ "df['id'] = df['image_path'].apply(lambda x: os.path.basename(x))\n",
898
+ "df['label'] = df['image_path'].apply(lambda x: os.path.dirname(x).split(\"/\")[-1])\n",
899
+ "df = df.sample(frac=1, random_state=42).reset_index(drop=True)\n",
900
+ "\n",
901
+ "\n",
902
+ "class MnistDataset(torch.utils.data.Dataset):\n",
903
+ " def __init__(\n",
904
+ " self,\n",
905
+ " data,\n",
906
+ " im_size\n",
907
+ " ):\n",
908
+ " if isinstance(data, str):\n",
909
+ " self.data = pd.read_csv(data)\n",
910
+ " elif isinstance(data, pd.DataFrame):\n",
911
+ " self.data = data\n",
912
+ " else:\n",
913
+ " raise ValueError(\"The `data` argument must be a string (CSV file path) or a Pandas DataFrame.\")\n",
914
+ " \n",
915
+ " self.im_size = im_size\n",
916
+ "\n",
917
+ " def __len__(self):\n",
918
+ " return len(self.data)\n",
919
+ "\n",
920
+ " def __getitem__(self, idx):\n",
921
+ " row = self.data.iloc[idx]\n",
922
+ " image_path = row['image_path']\n",
923
+ " label = int(row['label'])\n",
924
+ "\n",
925
+ " image = tv.io.decode_image(image_path, mode='RGB')\n",
926
+ " image = v2.Resize(self.im_size)(image)\n",
927
+ " image = v2.ToDtype(torch.float32, scale=True)(image)\n",
928
+ " image = 2*image - 1\n",
929
+ "\n",
930
+ " return image, label\n",
931
+ "\n",
932
+ "\n",
933
+ "dataset = MnistDataset(df, im_size=dataset_config.im_size)\n",
934
+ "dataloader = torch.utils.data.DataLoader(\n",
935
+ " dataset, \n",
936
+ " batch_size=train_config.autoencoder_batch_size, \n",
937
+ " shuffle=True, \n",
938
+ " num_workers=os.cpu_count(),\n",
939
+ " pin_memory=True,\n",
940
+ " drop_last=True,\n",
941
+ " persistent_workers=True\n",
942
+ ")\n",
943
+ "\n",
944
+ "# for batch in tqdm(dataloader):\n",
945
+ "# images, labels = batch\n",
946
+ "\n",
947
+ "images, labels = next(iter(dataloader))\n",
948
+ "images, labels = images.to(device), labels.to(device)"
949
+ ]
950
+ },
951
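The dataset maps pixels from [0, 1] to [-1, 1] via 2*image - 1, and the training loop later undoes this with (x + 1) / 2 before saving sample grids. A tiny round-trip check (toy tensor, illustration only):

import torch

img01 = torch.rand(3, 28, 28)      # in [0, 1], as produced by ToDtype(scale=True)
img_pm1 = 2 * img01 - 1            # in [-1, 1], what the VQVAE sees
img_back = (img_pm1 + 1) / 2       # back to [0, 1] for visualisation
assert torch.allclose(img01, img_back)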
+ {
952
+ "cell_type": "code",
953
+ "execution_count": 14,
954
+ "metadata": {},
955
+ "outputs": [
956
+ {
957
+ "data": {
958
+ "text/plain": [
959
+ "(torch.Size([32, 3, 28, 28]),\n",
960
+ " torch.Size([32, 3, 7, 7]),\n",
961
+ " {'codebook_loss': tensor(0.1057, device='cuda:0', grad_fn=<MeanBackward0>),\n",
962
+ " 'commitment_loss': tensor(0.1057, device='cuda:0', grad_fn=<MeanBackward0>)})"
963
+ ]
964
+ },
965
+ "execution_count": 14,
966
+ "metadata": {},
967
+ "output_type": "execute_result"
968
+ }
969
+ ],
970
+ "source": [
971
+ "dataset_config = Config.dataset_params\n",
972
+ "autoencoder_config = Config.autoencoder_params\n",
973
+ "train_config = Config.train_params\n",
974
+ "\n",
975
+ "model = VQVAE(im_channels=dataset_config.im_channels, model_config=autoencoder_config).to(device)\n",
976
+ "\n",
977
+ "model_output = model(images)\n",
978
+ "model_output[0].shape, model_output[1].shape, model_output[2]"
979
+ ]
980
+ },
981
+ {
982
+ "cell_type": "markdown",
983
+ "metadata": {},
984
+ "source": [
985
+ "### VQVAE Training"
986
+ ]
987
+ },
988
+ {
989
+ "cell_type": "code",
990
+ "execution_count": 26,
991
+ "metadata": {},
992
+ "outputs": [],
993
+ "source": [
994
+ "def save_checkpoint(\n",
995
+ " total_steps, epoch, model, discriminator, optimizer_d, optimizer_g, metrics, checkpoint_path\n",
996
+ "):\n",
997
+ " checkpoint = {\n",
998
+ " \"total_steps\": total_steps,\n",
999
+ " \"epoch\": epoch,\n",
1000
+ " \"model_state_dict\": model.state_dict(),\n",
1001
+ " \"discriminator_state_dict\": discriminator.state_dict(),\n",
1002
+ " \"optimizer_d_state_dict\": optimizer_d.state_dict(),\n",
1003
+ " \"optimizer_g_state_dict\": optimizer_g.state_dict(),\n",
1004
+ " \"metrics\": metrics, # Save all metrics\n",
1005
+ " }\n",
1006
+ " torch.save(checkpoint, checkpoint_path)\n",
1007
+ " print(f\"Checkpoint saved after {total_steps} steps at epoch {epoch}\")\n",
1008
+ "\n",
1009
+ "\n",
1010
+ "def load_checkpoint(checkpoint_path, model, discriminator, optimizer_d, optimizer_g):\n",
1011
+ " if os.path.exists(checkpoint_path):\n",
1012
+ " checkpoint = torch.load(checkpoint_path, map_location=device)\n",
1013
+ " model.load_state_dict(checkpoint[\"model_state_dict\"])\n",
1014
+ " discriminator.load_state_dict(checkpoint[\"discriminator_state_dict\"])\n",
1015
+ " optimizer_d.load_state_dict(checkpoint[\"optimizer_d_state_dict\"])\n",
1016
+ " optimizer_g.load_state_dict(checkpoint[\"optimizer_g_state_dict\"])\n",
1017
+ " total_steps = checkpoint[\"total_steps\"]\n",
1018
+ " epoch = checkpoint[\"epoch\"]\n",
1019
+ " metrics = checkpoint[\"metrics\"]\n",
1020
+ " print(f\"Checkpoint loaded. Resuming from epoch {epoch + 1}, step {total_steps}\")\n",
1021
+ " return total_steps, epoch + 1, metrics\n",
1022
+ " else:\n",
1023
+ " print(\"No checkpoint found. Starting from scratch.\")\n",
1024
+ " return 0, 0, None\n",
1025
+ "\n",
1026
+ "\n",
1027
+ "def trainVAE(Config, dataloader):\n",
1028
+ "\n",
1029
+ " # --- Configurations ----------------------------------------------------\n",
1030
+ " dataset_config = Config.dataset_params\n",
1031
+ " autoencoder_config = Config.autoencoder_params\n",
1032
+ " train_config = Config.train_params\n",
1033
+ "\n",
1034
+ " seed = train_config.seed\n",
1035
+ " torch.manual_seed(seed)\n",
1036
+ " np.random.seed(seed)\n",
1037
+ " random.seed(seed)\n",
1038
+ " if device == \"cuda\":\n",
1039
+ " torch.cuda.manual_seed_all(seed)\n",
1040
+ " \n",
1041
+ " \n",
1042
+ " # --- Model Initilization ------------------------------------------------\n",
1043
+ " model = VQVAE(im_channels=dataset_config.im_channels, model_config=autoencoder_config).to(device)\n",
1044
+ " discriminator = Discriminator(im_channels=dataset_config.im_channels).to(device)\n",
1045
+ "\n",
1046
+ " \n",
1047
+ " # --- Optimizer Initilization ----------------------------------------------\n",
1048
+ " optimizer_d = torch.optim.AdamW(discriminator.parameters(), lr=train_config.autoencoder_lr, betas=(0.5, 0.999))\n",
1049
+ " optimizer_g = torch.optim.AdamW(model.parameters(), lr=train_config.autoencoder_lr, betas=(0.5, 0.999))\n",
1050
+ " \n",
1051
+ " \n",
1052
+ " # --- Checkpoint Loading ------------------------------------------------\n",
1053
+ " checkpoint_path = os.path.join(train_config.task_name, \"checkpoint.pth\")\n",
1054
+ " total_steps, start_epoch, metrics = load_checkpoint(checkpoint_path, model, discriminator, optimizer_d, optimizer_g)\n",
1055
+ " if os.path.exists(\n",
1056
+ " os.path.join(train_config.task_name, train_config.vqvae_autoencoder_ckpt_name)\n",
1057
+ " ):\n",
1058
+ " print(\"Loaded vae checkpoint\")\n",
1059
+ " model.load_state_dict(\n",
1060
+ " torch.load(\n",
1061
+ " os.path.join(train_config.task_name, train_config.vqvae_autoencoder_ckpt_name),\n",
1062
+ " map_location=device,\n",
1063
+ " weights_only=True,\n",
1064
+ " )\n",
1065
+ " )\n",
1066
+ " \n",
1067
+ " if os.path.exists(\n",
1068
+ " os.path.join(train_config.task_name, train_config.vqvae_discriminator_ckpt_name)\n",
1069
+ " ):\n",
1070
+ " print(\"Loaded discriminator checkpoint\")\n",
1071
+ " discriminator.load_state_dict(\n",
1072
+ " torch.load(\n",
1073
+ " os.path.join(train_config.task_name, train_config.vqvae_discriminator_ckpt_name),\n",
1074
+ " map_location=device,\n",
1075
+ " weights_only=True,\n",
1076
+ " )\n",
1077
+ " )\n",
1078
+ " \n",
1079
+ " \n",
1080
+ " \n",
1081
+ " # --- Loss Function Initilization ----------------------------------------\n",
1082
+ " if not os.path.exists(train_config.task_name):\n",
1083
+ " os.mkdir(train_config.task_name)\n",
1084
+ " num_epochs = train_config.autoencoder_epochs\n",
1085
+ "\n",
1086
+ " # L1/L2 loss for Reconstruction\n",
1087
+ " recon_criterion = torch.nn.MSELoss()\n",
1088
+ " disc_criterion = torch.nn.MSELoss()\n",
1089
+ "\n",
1090
+ " # LPIPS loss for perceptual similarity\n",
1091
+ " lpips_model = LPIPS().eval().to(device)\n",
1092
+ "\n",
1093
+ " \n",
1094
+ " \n",
1095
+ "\n",
1096
+ " disc_step_start = train_config.disc_start\n",
1097
+ " step_count = 0\n",
1098
+ "\n",
1099
+ " # This is for accumulating gradients incase the images are huge\n",
1100
+ " # And one cant afford higher batch sizes\n",
1101
+ "\n",
1102
+ " acc_steps = train_config.autoencoder_acc_steps\n",
1103
+ " image_save_steps = train_config.autoencoder_img_save_steps\n",
1104
+ " img_save_count = 0\n",
1105
+ "\n",
1106
+ " for epoch_idx in trange(num_epochs, desc=\"Training VQVAE\"):\n",
1107
+ " recon_losses = []\n",
1108
+ " codebook_losses = []\n",
1109
+ " # commitment_losses = []\n",
1110
+ "\n",
1111
+ " perceptual_losses = []\n",
1112
+ " disc_losses = []\n",
1113
+ " gen_losses = []\n",
1114
+ " losses = []\n",
1115
+ "\n",
1116
+ " optimizer_g.zero_grad()\n",
1117
+ " optimizer_d.zero_grad()\n",
1118
+ "\n",
1119
+ " # for images in tqdm(dataloader):\n",
1120
+ " for images in dataloader:\n",
1121
+ " step_count += 1\n",
1122
+ " images = images.to(device)\n",
1123
+ "\n",
1124
+ " # Fetch autoencoders output(reconstructions)\n",
1125
+ " model_output = model(images)\n",
1126
+ " output, z, quantize_losses = model_output\n",
1127
+ "\n",
1128
+ " # Image Saving Logic\n",
1129
+ " if step_count % image_save_steps == 0 or step_count == 1:\n",
1130
+ " sample_size = min(8, images.shape[0])\n",
1131
+ " save_output = torch.clamp(output[:sample_size], -1.0, 1.0).detach().cpu()\n",
1132
+ " save_output = (save_output + 1) / 2\n",
1133
+ " save_input = ((images[:sample_size] + 1) / 2).detach().cpu()\n",
1134
+ "\n",
1135
+ " grid = make_grid(torch.cat([save_input, save_output], dim=0), nrow=sample_size)\n",
1136
+ " img = tv.transforms.ToPILImage()(grid)\n",
1137
+ " if not os.path.exists(\n",
1138
+ " os.path.join(train_config.task_name, \"vqvae_autoencoder_samples\")\n",
1139
+ " ):\n",
1140
+ " os.mkdir(os.path.join(train_config.task_name, \"vqvae_autoencoder_samples\"))\n",
1141
+ " img.save(\n",
1142
+ " os.path.join(\n",
1143
+ " train_config.task_name,\n",
1144
+ " \"vqvae_autoencoder_samples\",\n",
1145
+ " \"current_autoencoder_sample_{}.png\".format(img_save_count),\n",
1146
+ " )\n",
1147
+ " )\n",
1148
+ " img_save_count += 1\n",
1149
+ " img.close()\n",
1150
+ " \n",
1151
+ " \n",
1152
+ " ######### Optimize Generator ##########\n",
1153
+ " # L2 Loss for Reconstruction\n",
1154
+ " recon_loss = recon_criterion(output, images)\n",
1155
+ " recon_losses.append(recon_loss.item())\n",
1156
+ " recon_loss = recon_loss / acc_steps\n",
1157
+ " \n",
1158
+ " # Generator Loss =\n",
1159
+ " g_loss = (\n",
1160
+ " recon_loss\n",
1161
+ " + (train_config.codebook_weight * quantize_losses[\"codebook_loss\"] / acc_steps)\n",
1162
+ " + (train_config.commitment_beta * quantize_losses[\"commitment_loss\"] / acc_steps)\n",
1163
+ " )\n",
1164
+ " \n",
1165
+ " codebook_losses.append(\n",
1166
+ " train_config.codebook_weight * quantize_losses[\"codebook_loss\"].item()\n",
1167
+ " )\n",
1168
+ " \n",
1169
+ "\n",
1170
+ " # Adversarial loss only if disc_step_start steps passed\n",
1171
+ " if step_count > disc_step_start:\n",
1172
+ " disc_fake_pred = discriminator(model_output[0])\n",
1173
+ " disc_fake_loss = disc_criterion(\n",
1174
+ " disc_fake_pred,\n",
1175
+ " torch.ones(disc_fake_pred.shape, device=disc_fake_pred.device),\n",
1176
+ " )\n",
1177
+ " gen_losses.append(train_config.disc_weight * disc_fake_loss.item())\n",
1178
+ " g_loss += train_config.disc_weight * disc_fake_loss / acc_steps\n",
1179
+ " lpips_loss = torch.mean(lpips_model(output, images)) / acc_steps\n",
1180
+ " perceptual_losses.append(train_config.perceptual_weight * lpips_loss.item())\n",
1181
+ " g_loss += train_config.perceptual_weight * lpips_loss / acc_steps\n",
1182
+ " losses.append(g_loss.item())\n",
1183
+ " g_loss.backward()\n",
1184
+ " #####################################\n",
1185
+ "\n",
1186
+ "\n",
1187
+ " ######### Optimize Discriminator #######\n",
1188
+ " if step_count > disc_step_start:\n",
1189
+ " fake = output\n",
1190
+ " disc_fake_pred = discriminator(fake.detach())\n",
1191
+ " disc_real_pred = discriminator(images)\n",
1192
+ " disc_fake_loss = disc_criterion(\n",
1193
+ " disc_fake_pred,\n",
1194
+ " torch.zeros(disc_fake_pred.shape, device=disc_fake_pred.device),\n",
1195
+ " )\n",
1196
+ " disc_real_loss = disc_criterion(\n",
1197
+ " disc_real_pred,\n",
1198
+ " torch.ones(disc_real_pred.shape, device=disc_real_pred.device),\n",
1199
+ " )\n",
1200
+ " disc_loss = train_config.disc_weight * (disc_fake_loss + disc_real_loss) / 2\n",
1201
+ " disc_losses.append(disc_loss.item())\n",
1202
+ " disc_loss = disc_loss / acc_steps\n",
1203
+ " disc_loss.backward()\n",
1204
+ " if step_count % acc_steps == 0:\n",
1205
+ " optimizer_d.step()\n",
1206
+ " optimizer_d.zero_grad()\n",
1207
+ " #####################################\n",
1208
+ "\n",
1209
+ " if step_count % acc_steps == 0:\n",
1210
+ " optimizer_g.step()\n",
1211
+ " optimizer_g.zero_grad()\n",
1212
+ " optimizer_d.step()\n",
1213
+ " optimizer_d.zero_grad()\n",
1214
+ " optimizer_g.step()\n",
1215
+ " optimizer_g.zero_grad()\n",
1216
+ " if len(disc_losses) > 0:\n",
1217
+ " print(\n",
1218
+ " \"Finished epoch: {}/{} | Recon Loss : {:.4f} | Perceptual Loss : {:.4f} | \"\n",
1219
+ " \"Codebook : {:.4f} | G Loss : {:.4f} | D Loss {:.4f}\".format(\n",
1220
+ " epoch_idx + 1,\n",
1221
+ " num_epochs,\n",
1222
+ " np.mean(recon_losses),\n",
1223
+ " np.mean(perceptual_losses),\n",
1224
+ " np.mean(codebook_losses),\n",
1225
+ " np.mean(gen_losses),\n",
1226
+ " np.mean(disc_losses),\n",
1227
+ " )\n",
1228
+ " )\n",
1229
+ " else:\n",
1230
+ " print(\n",
1231
+ " \"Finished epoch: {}/{} | Recon Loss : {:.4f} | Perceptual Loss : {:.4f} | Codebook : {:.4f}\".format(\n",
1232
+ " epoch_idx + 1,\n",
1233
+ " num_epochs,\n",
1234
+ " np.mean(recon_losses),\n",
1235
+ " np.mean(perceptual_losses),\n",
1236
+ " np.mean(codebook_losses),\n",
1237
+ " )\n",
1238
+ " )\n",
1239
+ " torch.save(\n",
1240
+ " model.state_dict(),\n",
1241
+ " os.path.join(train_config.task_name, train_config.vqvae_autoencoder_ckpt_name),\n",
1242
+ " )\n",
1243
+ " torch.save(\n",
1244
+ " discriminator.state_dict(),\n",
1245
+ " os.path.join(train_config.task_name, train_config.vqvae_discriminator_ckpt_name),\n",
1246
+ " )\n",
1247
+ " print(\"Done Training...\")"
1248
+ ]
1249
+ },
1250
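The training cell divides every loss term by acc_steps and only steps the optimizers every acc_steps iterations, which emulates a batch size of autoencoder_batch_size * acc_steps when memory is tight. The pattern in isolation, as a runnable toy sketch (the model, loader and sizes are made up for illustration):

import torch
import torch.nn as nn

model = nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
loader = [(torch.randn(4, 8),) for _ in range(8)]    # stand-in for a DataLoader
acc_steps = 4

optimizer.zero_grad()
for step, (x,) in enumerate(loader, start=1):
    loss = nn.functional.mse_loss(model(x), x) / acc_steps  # scale so grads sum to a full-batch grad
    loss.backward()                                         # gradients accumulate until the step below
    if step % acc_steps == 0:
        optimizer.step()
        optimizer.zero_grad()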
+ {
1251
+ "cell_type": "code",
1252
+ "execution_count": 27,
1253
+ "metadata": {},
1254
+ "outputs": [],
1255
+ "source": [
1256
+ "# trainVAE(Config)"
1257
+ ]
1258
+ },
1259
+ {
1260
+ "cell_type": "code",
1261
+ "execution_count": null,
1262
+ "metadata": {},
1263
+ "outputs": [],
1264
+ "source": [
1265
+ "def save_checkpoint(\n",
1266
+ " total_steps, epoch, model, discriminator, optimizer_d, optimizer_g, metrics, checkpoint_path\n",
1267
+ "):\n",
1268
+ " checkpoint = {\n",
1269
+ " \"total_steps\": total_steps,\n",
1270
+ " \"epoch\": epoch,\n",
1271
+ " \"model_state_dict\": model.state_dict(),\n",
1272
+ " \"discriminator_state_dict\": discriminator.state_dict(),\n",
1273
+ " \"optimizer_d_state_dict\": optimizer_d.state_dict(),\n",
1274
+ " \"optimizer_g_state_dict\": optimizer_g.state_dict(),\n",
1275
+ " \"metrics\": metrics, # Save all metrics\n",
1276
+ " }\n",
1277
+ " torch.save(checkpoint, checkpoint_path)\n",
1278
+ " print(f\"Checkpoint saved after {total_steps} steps at epoch {epoch}\")\n",
1279
+ "\n",
1280
+ "\n",
1281
+ "def load_checkpoint(checkpoint_path, model, discriminator, optimizer_d, optimizer_g):\n",
1282
+ " if os.path.exists(checkpoint_path):\n",
1283
+ " checkpoint = torch.load(checkpoint_path, map_location=device)\n",
1284
+ " model.load_state_dict(checkpoint[\"model_state_dict\"])\n",
1285
+ " discriminator.load_state_dict(checkpoint[\"discriminator_state_dict\"])\n",
1286
+ " optimizer_d.load_state_dict(checkpoint[\"optimizer_d_state_dict\"])\n",
1287
+ " optimizer_g.load_state_dict(checkpoint[\"optimizer_g_state_dict\"])\n",
1288
+ " total_steps = checkpoint[\"total_steps\"]\n",
1289
+ " epoch = checkpoint[\"epoch\"]\n",
1290
+ " metrics = checkpoint[\"metrics\"]\n",
1291
+ " print(f\"Checkpoint loaded. Resuming from epoch {epoch + 1}, step {total_steps}\")\n",
1292
+ " return total_steps, epoch + 1, metrics\n",
1293
+ " else:\n",
1294
+ " print(\"No checkpoint found. Starting from scratch.\")\n",
1295
+ " return 0, 0, None\n",
1296
+ "\n",
1297
+ "\n",
1298
+ "def trainVAE(Config, dataloader):\n",
1299
+ "\n",
1300
+ " # --- Configurations ----------------------------------------------------\n",
1301
+ " dataset_config = Config.dataset_params\n",
1302
+ " autoencoder_config = Config.autoencoder_params\n",
1303
+ " train_config = Config.train_params\n",
1304
+ "\n",
1305
+ " seed = train_config.seed\n",
1306
+ " torch.manual_seed(seed)\n",
1307
+ " np.random.seed(seed)\n",
1308
+ " random.seed(seed)\n",
1309
+ " if device == \"cuda\":\n",
1310
+ " torch.cuda.manual_seed_all(seed)\n",
1311
+ " \n",
1312
+ " \n",
1313
+ " # --- Model Initilization ------------------------------------------------\n",
1314
+ " model = VQVAE(im_channels=dataset_config.im_channels, model_config=autoencoder_config).to(device)\n",
1315
+ " discriminator = Discriminator(im_channels=dataset_config.im_channels).to(device)\n",
1316
+ "\n",
1317
+ " \n",
1318
+ " # --- Optimizer Initilization ----------------------------------------------\n",
1319
+ " optimizer_d = torch.optim.AdamW(discriminator.parameters(), lr=train_config.autoencoder_lr, betas=(0.5, 0.999))\n",
1320
+ " optimizer_g = torch.optim.AdamW(model.parameters(), lr=train_config.autoencoder_lr, betas=(0.5, 0.999))\n",
1321
+ " \n",
1322
+ " \n",
1323
+ " # --- Loss Function Initialization --------------------------------------\n",
1324
+ " recon_criterion = torch.nn.MSELoss()\n",
1325
+ " # disc_criterion = torch.nn.MSELoss()\n",
1326
+ " disc_criterion = torch.nn.BCEWithLogits()\n",
1327
+ " lpips_model = LPIPS().eval().to(device)\n",
1328
+ "\n",
1329
+ " \n",
1330
+ " # --- Training Loop -----------------------------------------------------\n",
1331
+ " step_count = 0\n",
1332
+ " num_epochs = train_config.autoencoder_epochs\n",
1333
+ " disc_step_start = train_config.disc_start\n",
1334
+ " acc_steps = train_config.autoencoder_acc_steps\n",
1335
+ " image_save_steps = train_config.autoencoder_img_save_steps\n",
1336
+ " img_save_count = 0\n",
1337
+ " start_epoch = 0\n",
1338
+ "\n",
1339
+ " for epoch_idx in range(start_epoch, num_epochs):\n",
1340
+ " recon_losses = []\n",
1341
+ " codebook_losses = []\n",
1342
+ " perceptual_losses = []\n",
1343
+ " \n",
1344
+ " disc_losses = []\n",
1345
+ " gen_losses = []\n",
1346
+ " losses = []\n",
1347
+ "\n",
1348
+ " optimizer_g.zero_grad()\n",
1349
+ " optimizer_d.zero_grad()\n",
1350
+ "\n",
1351
+ " for images in dataloader:\n",
1352
+ " step_count += 1\n",
1353
+ " images = images.to(device)\n",
1354
+ "\n",
1355
+ " model_output = model(images)\n",
1356
+ " output, z, quantize_losses = model_output\n",
1357
+ " \n",
1358
+ " \n",
1359
+ " # --- Reconstruction Loss ---------------------------------------------------------\n",
1360
+ " recon_loss = recon_criterion(output, images)\n",
1361
+ " recon_losses.append(recon_loss.item())\n",
1362
+ " recon_loss = recon_loss / acc_steps\n",
1363
+ " \n",
1364
+ " # --- CodeBook Loss ---------------------------------------------------------------\n",
1365
+ " codebook_losses.append(train_config.codebook_weight * quantize_losses[\"codebook_loss\"].item())\n",
1366
+ " \n",
1367
+ " # --- Perceptual Loss -------------------------------------------------------------\n",
1368
+ " lpips_loss = torch.mean(lpips_model(output, images)) / acc_steps\n",
1369
+ " perceptual_losses.append(train_config.perceptual_weight * lpips_loss.item())\n",
1370
+ " \n",
1371
+ " \n",
1372
+ " g_loss = (\n",
1373
+ " recon_loss\n",
1374
+ " + (train_config.codebook_weight * quantize_losses[\"codebook_loss\"] / acc_steps)\n",
1375
+ " + (train_config.commitment_beta * quantize_losses[\"commitment_loss\"] / acc_steps)\n",
1376
+ " )\n",
1377
+ " \n",
1378
+ "\n",
1379
+ " # Adversarial loss only if disc_step_start steps passed\n",
1380
+ " if step_count > disc_step_start:\n",
1381
+ " disc_fake_pred = discriminator(model_output[0])\n",
1382
+ " disc_fake_loss = disc_criterion(\n",
1383
+ " disc_fake_pred,\n",
1384
+ " torch.ones(disc_fake_pred.shape, device=disc_fake_pred.device),\n",
1385
+ " )\n",
1386
+ " gen_losses.append(train_config.disc_weight * disc_fake_loss.item())\n",
1387
+ " g_loss += train_config.disc_weight * disc_fake_loss / acc_steps\n",
1388
+ " \n",
1389
+ " \n",
1390
+ " \n",
1391
+ " g_loss += train_config.perceptual_weight * lpips_loss / acc_steps\n",
1392
+ " losses.append(g_loss.item())\n",
1393
+ " g_loss.backward()\n",
1394
+ "\n",
1395
+ "\n",
1396
+ " ######### Optimize Discriminator #######\n",
1397
+ " if step_count > disc_step_start:\n",
1398
+ " fake = output\n",
1399
+ " disc_fake_pred = discriminator(fake.detach())\n",
1400
+ " disc_real_pred = discriminator(images)\n",
1401
+ " disc_fake_loss = disc_criterion(\n",
1402
+ " disc_fake_pred,\n",
1403
+ " torch.zeros(disc_fake_pred.shape, device=disc_fake_pred.device),\n",
1404
+ " )\n",
1405
+ " disc_real_loss = disc_criterion(\n",
1406
+ " disc_real_pred,\n",
1407
+ " torch.ones(disc_real_pred.shape, device=disc_real_pred.device),\n",
1408
+ " )\n",
1409
+ " disc_loss = train_config.disc_weight * (disc_fake_loss + disc_real_loss) / 2\n",
1410
+ " disc_losses.append(disc_loss.item())\n",
1411
+ " disc_loss = disc_loss / acc_steps\n",
1412
+ " disc_loss.backward()\n",
1413
+ " if step_count % acc_steps == 0:\n",
1414
+ " optimizer_d.step()\n",
1415
+ " optimizer_d.zero_grad()\n",
1416
+ " #####################################\n",
1417
+ "\n",
1418
+ " if step_count % acc_steps == 0:\n",
1419
+ " optimizer_g.step()\n",
1420
+ " optimizer_g.zero_grad()\n",
1421
+ " optimizer_d.step()\n",
1422
+ " optimizer_d.zero_grad()\n",
1423
+ " optimizer_g.step()\n",
1424
+ " optimizer_g.zero_grad()\n",
1425
+ " if len(disc_losses) > 0:\n",
1426
+ " print(\n",
1427
+ " \"Finished epoch: {}/{} | Recon Loss : {:.4f} | Perceptual Loss : {:.4f} | \"\n",
1428
+ " \"Codebook : {:.4f} | G Loss : {:.4f} | D Loss {:.4f}\".format(\n",
1429
+ " epoch_idx + 1,\n",
1430
+ " num_epochs,\n",
1431
+ " np.mean(recon_losses),\n",
1432
+ " np.mean(perceptual_losses),\n",
1433
+ " np.mean(codebook_losses),\n",
1434
+ " np.mean(gen_losses),\n",
1435
+ " np.mean(disc_losses),\n",
1436
+ " )\n",
1437
+ " )\n",
1438
+ " else:\n",
1439
+ " print(\n",
1440
+ " \"Finished epoch: {}/{} | Recon Loss : {:.4f} | Perceptual Loss : {:.4f} | Codebook : {:.4f}\".format(\n",
1441
+ " epoch_idx + 1,\n",
1442
+ " num_epochs,\n",
1443
+ " np.mean(recon_losses),\n",
1444
+ " np.mean(perceptual_losses),\n",
1445
+ " np.mean(codebook_losses),\n",
1446
+ " )\n",
1447
+ " )\n",
1448
+ " torch.save(\n",
1449
+ " model.state_dict(),\n",
1450
+ " os.path.join(train_config.task_name, train_config.vqvae_autoencoder_ckpt_name),\n",
1451
+ " )\n",
1452
+ " torch.save(\n",
1453
+ " discriminator.state_dict(),\n",
1454
+ " os.path.join(train_config.task_name, train_config.vqvae_discriminator_ckpt_name),\n",
1455
+ " )\n",
1456
+ " print(\"Done Training...\")"
1457
+ ]
1458
+ }
1459
+ ],
1460
+ "metadata": {
1461
+ "kernelspec": {
1462
+ "display_name": "Python 3",
1463
+ "language": "python",
1464
+ "name": "python3"
1465
+ },
1466
+ "language_info": {
1467
+ "codemirror_mode": {
1468
+ "name": "ipython",
1469
+ "version": 3
1470
+ },
1471
+ "file_extension": ".py",
1472
+ "mimetype": "text/x-python",
1473
+ "name": "python",
1474
+ "nbconvert_exporter": "python",
1475
+ "pygments_lexer": "ipython3",
1476
+ "version": "3.12.5"
1477
+ }
1478
+ },
1479
+ "nbformat": 4,
1480
+ "nbformat_minor": 2
1481
+ }
LDM/notebooks/_2_Rough-LPIPS.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
LDM/scripts/Main.py ADDED
@@ -0,0 +1,2273 @@
1
+ # ==================================================================
2
+ # L A T E N T D I F F U S I O N M O D E L
3
+ # ==================================================================
4
+ # Author : Ashish Kumar Uchadiya
5
+ # Created : November 3, 2024
6
+ # Description: This script implements a Latent Diffusion Model using
7
+ # a cosine or linear noise scheduling approach for high-resolution
8
+ # image generation. The model leverages generative techniques to
9
+ # learn a latent representation and progressively reduce noise to
10
+ # generate clear, realistic images.
11
+ # ==================================================================
12
+ # I M P O R T S
13
+ # ==================================================================
14
+
15
+ import os
16
+ os.environ["CUDA_VISIBLE_DEVICES"] = "1"
17
+
18
+ """Lpips"""
19
+
20
+ # from __future__ import absolute_import
21
+ from collections import namedtuple
22
+ import torch
23
+ import torch.nn as nn
24
+ import torch.nn.init as init
25
+ from torch.autograd import Variable
26
+ import numpy as np
27
+ import torch.nn
28
+ import torchvision
29
+
30
+ # Taken from https://github.com/richzhang/PerceptualSimilarity/blob/master/lpips/lpips.py
31
+
32
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
33
+
34
+
35
+ def spatial_average(in_tens, keepdim=True):
36
+ return in_tens.mean([2, 3], keepdim=keepdim)
37
+
38
+
39
+ class vgg16(torch.nn.Module):
40
+ def __init__(self, requires_grad=False, pretrained=True):
41
+ super(vgg16, self).__init__()
42
+ vgg_pretrained_features = torchvision.models.vgg16(
43
+ weights=torchvision.models.VGG16_Weights.IMAGENET1K_V1
44
+ ).features
45
+ self.slice1 = torch.nn.Sequential()
46
+ self.slice2 = torch.nn.Sequential()
47
+ self.slice3 = torch.nn.Sequential()
48
+ self.slice4 = torch.nn.Sequential()
49
+ self.slice5 = torch.nn.Sequential()
50
+ self.N_slices = 5
51
+ for x in range(4):
52
+ self.slice1.add_module(str(x), vgg_pretrained_features[x])
53
+ for x in range(4, 9):
54
+ self.slice2.add_module(str(x), vgg_pretrained_features[x])
55
+ for x in range(9, 16):
56
+ self.slice3.add_module(str(x), vgg_pretrained_features[x])
57
+ for x in range(16, 23):
58
+ self.slice4.add_module(str(x), vgg_pretrained_features[x])
59
+ for x in range(23, 30):
60
+ self.slice5.add_module(str(x), vgg_pretrained_features[x])
61
+
62
+ # Freeze vgg model
63
+ if not requires_grad:
64
+ for param in self.parameters():
65
+ param.requires_grad = False
66
+
67
+ def forward(self, X):
68
+ # Return output of vgg features
69
+ h = self.slice1(X)
70
+ h_relu1_2 = h
71
+ h = self.slice2(h)
72
+ h_relu2_2 = h
73
+ h = self.slice3(h)
74
+ h_relu3_3 = h
75
+ h = self.slice4(h)
76
+ h_relu4_3 = h
77
+ h = self.slice5(h)
78
+ h_relu5_3 = h
79
+ vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
80
+ out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
81
+ return out
82
+
83
+
84
+ # Learned perceptual metric
85
+ class LPIPS(nn.Module):
86
+ def __init__(self, net='vgg', version='0.1', use_dropout=True):
87
+ super(LPIPS, self).__init__()
88
+ self.version = version
89
+ # Imagenet normalization
90
+ self.scaling_layer = ScalingLayer()
91
+ ########################
92
+
93
+ # Instantiate vgg model
94
+ self.chns = [64, 128, 256, 512, 512]
95
+ self.L = len(self.chns)
96
+ self.net = vgg16(pretrained=True, requires_grad=False)
97
+
98
+ # Add 1x1 convolutional Layers
99
+ self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
100
+ self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
101
+ self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
102
+ self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
103
+ self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
104
+ self.lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
105
+ self.lins = nn.ModuleList(self.lins)
106
+ ########################
107
+
108
+ # Load the weights of trained LPIPS model
109
+ import inspect
110
+ import os
111
+ # /home/taruntejaneurips23/.cache/torch/hub/checkpoints/vgg16-397923af.pth
112
+ print(os.path.abspath(os.path.join(inspect.getfile(self.__init__), '..', 'weights/v%s/%s.pth' % (version, net))))
113
+ # model_path = os.path.abspath(
114
+ # os.path.join(inspect.getfile(self.__init__), '..', 'weights/v%s/%s.pth' % (version, net)))
115
+
116
+ # print('Loading model from: %s' % model_path)
117
+ # self.load_state_dict(torch.load(model_path, map_location=device), strict=False)
118
+ ########################
119
+
120
+ # Freeze all parameters
121
+ self.eval()
122
+ for param in self.parameters():
123
+ param.requires_grad = False
124
+ ########################
125
+
126
+ def forward(self, in0, in1, normalize=False):
127
+ # Scale the inputs to -1 to +1 range if needed
128
+ if normalize: # turn on this flag if input is [0,1] so it can be adjusted to [-1, +1]
129
+ in0 = 2 * in0 - 1
130
+ in1 = 2 * in1 - 1
131
+ ########################
132
+
133
+ # Normalize the inputs according to imagenet normalization
134
+ in0_input, in1_input = self.scaling_layer(in0), self.scaling_layer(in1)
135
+ ########################
136
+
137
+ # Get VGG outputs for image0 and image1
138
+ outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input)
139
+ feats0, feats1, diffs = {}, {}, {}
140
+ ########################
141
+
142
+ # Compute Square of Difference for each layer output
143
+ for kk in range(self.L):
144
+             feats0[kk], feats1[kk] = torch.nn.functional.normalize(outs0[kk], dim=1), torch.nn.functional.normalize(
145
+                 outs1[kk], dim=1)
146
+ diffs[kk] = (feats0[kk] - feats1[kk]) ** 2
147
+ ########################
148
+
149
+ # 1x1 convolution followed by spatial average on the square differences
150
+ res = [spatial_average(self.lins[kk](diffs[kk]), keepdim=True) for kk in range(self.L)]
151
+ val = 0
152
+
153
+ # Aggregate the results of each layer
154
+ for l in range(self.L):
155
+ val += res[l]
156
+ return val
157
+
158
+
159
+ class ScalingLayer(nn.Module):
160
+ def __init__(self):
161
+ super(ScalingLayer, self).__init__()
162
+         # Imagenet normalization for (0-1)
163
+ # mean = [0.485, 0.456, 0.406]
164
+ # std = [0.229, 0.224, 0.225]
165
+ self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
166
+ self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None])
167
+
168
+ def forward(self, inp):
169
+ return (inp - self.shift) / self.scale
170
+
171
+
172
+ class NetLinLayer(nn.Module):
173
+ ''' A single linear layer which does a 1x1 conv '''
174
+
175
+ def __init__(self, chn_in, chn_out=1, use_dropout=False):
176
+ super(NetLinLayer, self).__init__()
177
+
178
+ layers = [nn.Dropout(), ] if (use_dropout) else []
179
+ layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ]
180
+ self.model = nn.Sequential(*layers)
181
+
182
+ def forward(self, x):
183
+ out = self.model(x)
184
+ return out
185
+
186
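A quick sanity check of the LPIPS module defined above: it expects inputs roughly in [-1, 1] (or pass normalize=True for [0, 1] inputs) and returns one perceptual distance per image. The shapes below are assumptions for illustration; run it after the definitions above.

lpips_model = LPIPS().eval().to(device)
a = torch.rand(2, 3, 64, 64, device=device) * 2 - 1   # toy batch in [-1, 1]
b = torch.rand(2, 3, 64, 64, device=device) * 2 - 1
with torch.no_grad():
    d = lpips_model(a, b)        # shape (2, 1, 1, 1): per-image perceptual distance
print(d.flatten())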
+ """Blocks"""
187
+
188
+ import torch
189
+ import numpy as np
190
+
191
+
192
+ class LinearNoiseScheduler:
193
+ r"""
194
+ Class for the linear noise scheduler that is used in DDPM.
195
+ """
196
+
197
+ def __init__(self, num_timesteps, beta_start, beta_end):
198
+
199
+ self.num_timesteps = num_timesteps
200
+ self.beta_start = beta_start
201
+ self.beta_end = beta_end
202
+ # Mimicking how compvis repo creates schedule
203
+ self.betas = (
204
+ torch.linspace(beta_start ** 0.5, beta_end ** 0.5, num_timesteps) ** 2
205
+ )
206
+ self.alphas = 1. - self.betas
207
+ self.alpha_cum_prod = torch.cumprod(self.alphas, dim=0)
208
+ self.sqrt_alpha_cum_prod = torch.sqrt(self.alpha_cum_prod)
209
+ self.sqrt_one_minus_alpha_cum_prod = torch.sqrt(1 - self.alpha_cum_prod)
210
+
211
+ def add_noise(self, original, noise, t):
212
+ r"""
213
+ Forward method for diffusion
214
+ :param original: Image on which noise is to be applied
215
+ :param noise: Random Noise Tensor (from normal dist)
216
+ :param t: timestep of the forward process of shape -> (B,)
217
+ :return:
218
+ """
219
+ original_shape = original.shape
220
+ batch_size = original_shape[0]
221
+
222
+ sqrt_alpha_cum_prod = self.sqrt_alpha_cum_prod.to(original.device)[t].reshape(batch_size)
223
+ sqrt_one_minus_alpha_cum_prod = self.sqrt_one_minus_alpha_cum_prod.to(original.device)[t].reshape(batch_size)
224
+
225
+ # Reshape till (B,) becomes (B,1,1,1) if image is (B,C,H,W)
226
+ for _ in range(len(original_shape) - 1):
227
+ sqrt_alpha_cum_prod = sqrt_alpha_cum_prod.unsqueeze(-1)
228
+ for _ in range(len(original_shape) - 1):
229
+ sqrt_one_minus_alpha_cum_prod = sqrt_one_minus_alpha_cum_prod.unsqueeze(-1)
230
+
231
+ # Apply and Return Forward process equation
232
+ return (sqrt_alpha_cum_prod.to(original.device) * original
233
+ + sqrt_one_minus_alpha_cum_prod.to(original.device) * noise)
234
+
235
+ def sample_prev_timestep(self, xt, noise_pred, t):
236
+ r"""
237
+ Use the noise prediction by model to get
238
+         xt-1 using xt and the noise predicted
239
+ :param xt: current timestep sample
240
+ :param noise_pred: model noise prediction
241
+ :param t: current timestep we are at
242
+ :return:
243
+ """
244
+ x0 = ((xt - (self.sqrt_one_minus_alpha_cum_prod.to(xt.device)[t] * noise_pred)) /
245
+ torch.sqrt(self.alpha_cum_prod.to(xt.device)[t]))
246
+ x0 = torch.clamp(x0, -1., 1.)
247
+
248
+ mean = xt - ((self.betas.to(xt.device)[t]) * noise_pred) / (self.sqrt_one_minus_alpha_cum_prod.to(xt.device)[t])
249
+ mean = mean / torch.sqrt(self.alphas.to(xt.device)[t])
250
+
251
+ if t == 0:
252
+ return mean, x0
253
+ else:
254
+ variance = (1 - self.alpha_cum_prod.to(xt.device)[t - 1]) / (1.0 - self.alpha_cum_prod.to(xt.device)[t])
255
+ variance = variance * self.betas.to(xt.device)[t]
256
+ sigma = variance ** 0.5
257
+ z = torch.randn(xt.shape).to(xt.device)
258
+
259
+ # OR
260
+ # variance = self.betas[t]
261
+ # sigma = variance ** 0.5
262
+ # z = torch.randn(xt.shape).to(xt.device)
263
+ return mean + sigma * z, x0
264
+
265
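add_noise above is the closed-form forward process x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps. A short sketch exercising it with assumed hyper-parameters (1000 steps and the commonly used 1e-4 to 0.02 beta range):

scheduler = LinearNoiseScheduler(num_timesteps=1000, beta_start=1e-4, beta_end=0.02)
x0 = torch.randn(4, 3, 28, 28)            # pretend images/latents
noise = torch.randn_like(x0)
t = torch.randint(0, 1000, (4,))          # one timestep per sample
xt = scheduler.add_noise(x0, noise, t)    # more noise for larger t
print(xt.shape, scheduler.alpha_cum_prod[t])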
+
266
+ import torch
267
+ import math
268
+
269
+ class CosineNoiseScheduler:
270
+ r"""
271
+ Class for the cosine noise scheduler, often used in DDPM-based models.
272
+ """
273
+
274
+ def __init__(self, num_timesteps, s=0.008):
275
+ self.num_timesteps = num_timesteps
276
+ self.s = s
277
+
278
+ # Cosine schedule based on paper
279
+ def cosine_schedule(t):
280
+ return math.cos((t / self.num_timesteps + s) / (1 + s) * math.pi / 2) ** 2
281
+
282
+ # Compute alphas
283
+ self.alphas = torch.tensor([cosine_schedule(t) for t in range(num_timesteps)])
284
+ self.alpha_cum_prod = torch.cumprod(self.alphas, dim=0)
285
+ self.sqrt_alpha_cum_prod = torch.sqrt(self.alpha_cum_prod)
286
+ self.sqrt_one_minus_alpha_cum_prod = torch.sqrt(1 - self.alpha_cum_prod)
287
+
288
+ def add_noise(self, original, noise, t):
289
+ original_shape = original.shape
290
+ batch_size = original_shape[0]
291
+
292
+ sqrt_alpha_cum_prod = self.sqrt_alpha_cum_prod.to(original.device)[t].reshape(batch_size)
293
+ sqrt_one_minus_alpha_cum_prod = self.sqrt_one_minus_alpha_cum_prod.to(original.device)[t].reshape(batch_size)
294
+
295
+ for _ in range(len(original_shape) - 1):
296
+ sqrt_alpha_cum_prod = sqrt_alpha_cum_prod.unsqueeze(-1)
297
+ for _ in range(len(original_shape) - 1):
298
+ sqrt_one_minus_alpha_cum_prod = sqrt_one_minus_alpha_cum_prod.unsqueeze(-1)
299
+
300
+ return (sqrt_alpha_cum_prod * original + sqrt_one_minus_alpha_cum_prod * noise)
301
+
302
+ def sample_prev_timestep(self, xt, noise_pred, t):
303
+ x0 = ((xt - (self.sqrt_one_minus_alpha_cum_prod.to(xt.device)[t] * noise_pred)) /
304
+ torch.sqrt(self.alpha_cum_prod.to(xt.device)[t]))
305
+ x0 = torch.clamp(x0, -1., 1.)
306
+
307
+ mean = xt - ((1 - self.alphas.to(xt.device)[t]) * noise_pred) / (self.sqrt_one_minus_alpha_cum_prod.to(xt.device)[t])
308
+ mean = mean / torch.sqrt(self.alphas.to(xt.device)[t])
309
+
310
+ if t == 0:
311
+ return mean, x0
312
+ else:
313
+ variance = (1 - self.alpha_cum_prod.to(xt.device)[t - 1]) / (1.0 - self.alpha_cum_prod.to(xt.device)[t])
314
+ variance = variance * (1 - self.alphas.to(xt.device)[t])
315
+ sigma = variance ** 0.5
316
+ z = torch.randn(xt.shape).to(xt.device)
317
+ return mean + sigma * z, x0
318
+
319
+
320
+
321
+
322
+ import torch
323
+ import torch.nn as nn
324
+
325
+
326
+ def get_time_embedding(time_steps, temb_dim):
327
+ r"""
328
+ Convert time steps tensor into an embedding using the
329
+ sinusoidal time embedding formula
330
+ :param time_steps: 1D tensor of length batch size
331
+ :param temb_dim: Dimension of the embedding
332
+ :return: BxD embedding representation of B time steps
333
+ """
334
+ assert temb_dim % 2 == 0, "time embedding dimension must be divisible by 2"
335
+
336
+ # factor = 10000^(2i/d_model)
337
+ factor = 10000 ** ((torch.arange(
338
+ start=0, end=temb_dim // 2, dtype=torch.float32, device=time_steps.device) / (temb_dim // 2))
339
+ )
340
+
341
+ # pos / factor
342
+ # timesteps B -> B, 1 -> B, temb_dim
343
+ t_emb = time_steps[:, None].repeat(1, temb_dim // 2) / factor
344
+ t_emb = torch.cat([torch.sin(t_emb), torch.cos(t_emb)], dim=-1)
345
+ return t_emb
346
+
347
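get_time_embedding turns a batch of integer timesteps into sinusoidal features of width temb_dim (first half sin, second half cos). A quick shape check with assumed values:

time_steps = torch.randint(0, 1000, (8,))          # B = 8 sampled timesteps
t_emb = get_time_embedding(time_steps, temb_dim=128)
print(t_emb.shape)                                 # torch.Size([8, 128])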
+
348
+ class DownBlock(nn.Module):
349
+ r"""
350
+ Down conv block with attention.
351
+     Sequence of the following blocks
352
+ 1. Resnet block with time embedding
353
+ 2. Attention block
354
+ 3. Downsample
355
+ """
356
+
357
+ def __init__(self, in_channels, out_channels, t_emb_dim,
358
+ down_sample, num_heads, num_layers, attn, norm_channels, cross_attn=False, context_dim=None):
359
+ super().__init__()
360
+ self.num_layers = num_layers
361
+ self.down_sample = down_sample
362
+ self.attn = attn
363
+ self.context_dim = context_dim
364
+ self.cross_attn = cross_attn
365
+ self.t_emb_dim = t_emb_dim
366
+ self.resnet_conv_first = nn.ModuleList(
367
+ [
368
+ nn.Sequential(
369
+ nn.GroupNorm(norm_channels, in_channels if i == 0 else out_channels),
370
+ nn.SiLU(),
371
+ nn.Conv2d(in_channels if i == 0 else out_channels, out_channels,
372
+ kernel_size=3, stride=1, padding=1),
373
+ )
374
+ for i in range(num_layers)
375
+ ]
376
+ )
377
+ if self.t_emb_dim is not None:
378
+ self.t_emb_layers = nn.ModuleList([
379
+ nn.Sequential(
380
+ nn.SiLU(),
381
+ nn.Linear(self.t_emb_dim, out_channels)
382
+ )
383
+ for _ in range(num_layers)
384
+ ])
385
+ self.resnet_conv_second = nn.ModuleList(
386
+ [
387
+ nn.Sequential(
388
+ nn.GroupNorm(norm_channels, out_channels),
389
+ nn.SiLU(),
390
+ nn.Conv2d(out_channels, out_channels,
391
+ kernel_size=3, stride=1, padding=1),
392
+ )
393
+ for _ in range(num_layers)
394
+ ]
395
+ )
396
+
397
+ if self.attn:
398
+ self.attention_norms = nn.ModuleList(
399
+ [nn.GroupNorm(norm_channels, out_channels)
400
+ for _ in range(num_layers)]
401
+ )
402
+
403
+ self.attentions = nn.ModuleList(
404
+ [nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
405
+ for _ in range(num_layers)]
406
+ )
407
+
408
+ if self.cross_attn:
409
+ assert context_dim is not None, "Context Dimension must be passed for cross attention"
410
+ self.cross_attention_norms = nn.ModuleList(
411
+ [nn.GroupNorm(norm_channels, out_channels)
412
+ for _ in range(num_layers)]
413
+ )
414
+ self.cross_attentions = nn.ModuleList(
415
+ [nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
416
+ for _ in range(num_layers)]
417
+ )
418
+ self.context_proj = nn.ModuleList(
419
+ [nn.Linear(context_dim, out_channels)
420
+ for _ in range(num_layers)]
421
+ )
422
+
423
+ self.residual_input_conv = nn.ModuleList(
424
+ [
425
+ nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=1)
426
+ for i in range(num_layers)
427
+ ]
428
+ )
429
+ self.down_sample_conv = nn.Conv2d(out_channels, out_channels,
430
+ 4, 2, 1) if self.down_sample else nn.Identity()
431
+
432
+ def forward(self, x, t_emb=None, context=None):
433
+ out = x
434
+ for i in range(self.num_layers):
435
+ # Resnet block of Unet
436
+ resnet_input = out
437
+ out = self.resnet_conv_first[i](out)
438
+ if self.t_emb_dim is not None:
439
+ out = out + self.t_emb_layers[i](t_emb)[:, :, None, None]
440
+ out = self.resnet_conv_second[i](out)
441
+ out = out + self.residual_input_conv[i](resnet_input)
442
+
443
+ if self.attn:
444
+ # Attention block of Unet
445
+ batch_size, channels, h, w = out.shape
446
+ in_attn = out.reshape(batch_size, channels, h * w)
447
+ in_attn = self.attention_norms[i](in_attn)
448
+ in_attn = in_attn.transpose(1, 2)
449
+ out_attn, _ = self.attentions[i](in_attn, in_attn, in_attn)
450
+ out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
451
+ out = out + out_attn
452
+
453
+ if self.cross_attn:
454
+ assert context is not None, "context cannot be None if cross attention layers are used"
455
+ batch_size, channels, h, w = out.shape
456
+ in_attn = out.reshape(batch_size, channels, h * w)
457
+ in_attn = self.cross_attention_norms[i](in_attn)
458
+ in_attn = in_attn.transpose(1, 2)
459
+ assert context.shape[0] == x.shape[0] and context.shape[-1] == self.context_dim
460
+ context_proj = self.context_proj[i](context)
461
+ out_attn, _ = self.cross_attentions[i](in_attn, context_proj, context_proj)
462
+ out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
463
+ out = out + out_attn
464
+
465
+ # Downsample
466
+ out = self.down_sample_conv(out)
467
+ return out
468
+
469
+
470
+ class MidBlock(nn.Module):
471
+ r"""
472
+ Mid conv block with attention.
473
+ Sequence of following blocks
474
+ 1. Resnet block with time embedding
475
+ 2. Attention block
476
+ 3. Resnet block with time embedding
477
+ """
478
+
479
+ def __init__(self, in_channels, out_channels, t_emb_dim, num_heads, num_layers, norm_channels, cross_attn=None, context_dim=None):
480
+ super().__init__()
481
+ self.num_layers = num_layers
482
+ self.t_emb_dim = t_emb_dim
483
+ self.context_dim = context_dim
484
+ self.cross_attn = cross_attn
485
+ self.resnet_conv_first = nn.ModuleList(
486
+ [
487
+ nn.Sequential(
488
+ nn.GroupNorm(norm_channels, in_channels if i == 0 else out_channels),
489
+ nn.SiLU(),
490
+ nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=3, stride=1,
491
+ padding=1),
492
+ )
493
+ for i in range(num_layers + 1)
494
+ ]
495
+ )
496
+
497
+ if self.t_emb_dim is not None:
498
+ self.t_emb_layers = nn.ModuleList([
499
+ nn.Sequential(
500
+ nn.SiLU(),
501
+ nn.Linear(t_emb_dim, out_channels)
502
+ )
503
+ for _ in range(num_layers + 1)
504
+ ])
505
+ self.resnet_conv_second = nn.ModuleList(
506
+ [
507
+ nn.Sequential(
508
+ nn.GroupNorm(norm_channels, out_channels),
509
+ nn.SiLU(),
510
+ nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
511
+ )
512
+ for _ in range(num_layers + 1)
513
+ ]
514
+ )
515
+
516
+ self.attention_norms = nn.ModuleList(
517
+ [nn.GroupNorm(norm_channels, out_channels)
518
+ for _ in range(num_layers)]
519
+ )
520
+
521
+ self.attentions = nn.ModuleList(
522
+ [nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
523
+ for _ in range(num_layers)]
524
+ )
525
+ if self.cross_attn:
526
+ assert context_dim is not None, "Context Dimension must be passed for cross attention"
527
+ self.cross_attention_norms = nn.ModuleList(
528
+ [nn.GroupNorm(norm_channels, out_channels)
529
+ for _ in range(num_layers)]
530
+ )
531
+ self.cross_attentions = nn.ModuleList(
532
+ [nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
533
+ for _ in range(num_layers)]
534
+ )
535
+ self.context_proj = nn.ModuleList(
536
+ [nn.Linear(context_dim, out_channels)
537
+ for _ in range(num_layers)]
538
+ )
539
+ self.residual_input_conv = nn.ModuleList(
540
+ [
541
+ nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=1)
542
+ for i in range(num_layers + 1)
543
+ ]
544
+ )
545
+
546
+ def forward(self, x, t_emb=None, context=None):
547
+ out = x
548
+
549
+ # First resnet block
550
+ resnet_input = out
551
+ out = self.resnet_conv_first[0](out)
552
+ if self.t_emb_dim is not None:
553
+ out = out + self.t_emb_layers[0](t_emb)[:, :, None, None]
554
+ out = self.resnet_conv_second[0](out)
555
+ out = out + self.residual_input_conv[0](resnet_input)
556
+
557
+ for i in range(self.num_layers):
558
+ # Attention Block
559
+ batch_size, channels, h, w = out.shape
560
+ in_attn = out.reshape(batch_size, channels, h * w)
561
+ in_attn = self.attention_norms[i](in_attn)
562
+ in_attn = in_attn.transpose(1, 2)
563
+ out_attn, _ = self.attentions[i](in_attn, in_attn, in_attn)
564
+ out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
565
+ out = out + out_attn
566
+
567
+ if self.cross_attn:
568
+ assert context is not None, "context cannot be None if cross attention layers are used"
569
+ batch_size, channels, h, w = out.shape
570
+ in_attn = out.reshape(batch_size, channels, h * w)
571
+ in_attn = self.cross_attention_norms[i](in_attn)
572
+ in_attn = in_attn.transpose(1, 2)
573
+ assert context.shape[0] == x.shape[0] and context.shape[-1] == self.context_dim
574
+ context_proj = self.context_proj[i](context)
575
+ out_attn, _ = self.cross_attentions[i](in_attn, context_proj, context_proj)
576
+ out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
577
+ out = out + out_attn
578
+
579
+
580
+ # Resnet Block
581
+ resnet_input = out
582
+ out = self.resnet_conv_first[i + 1](out)
583
+ if self.t_emb_dim is not None:
584
+ out = out + self.t_emb_layers[i + 1](t_emb)[:, :, None, None]
585
+ out = self.resnet_conv_second[i + 1](out)
586
+ out = out + self.residual_input_conv[i + 1](resnet_input)
587
+
588
+ return out
589
+
590
+
591
+ class UpBlock(nn.Module):
592
+ r"""
593
+ Up conv block with attention.
594
+ Sequence of following blocks
595
+ 1. Upsample
596
+ 1. Concatenate Down block output
597
+ 2. Resnet block with time embedding
598
+ 3. Attention Block
599
+ """
600
+
601
+ def __init__(self, in_channels, out_channels, t_emb_dim,
602
+ up_sample, num_heads, num_layers, attn, norm_channels):
603
+ super().__init__()
604
+ self.num_layers = num_layers
605
+ self.up_sample = up_sample
606
+ self.t_emb_dim = t_emb_dim
607
+ self.attn = attn
608
+ self.resnet_conv_first = nn.ModuleList(
609
+ [
610
+ nn.Sequential(
611
+ nn.GroupNorm(norm_channels, in_channels if i == 0 else out_channels),
612
+ nn.SiLU(),
613
+ nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=3, stride=1,
614
+ padding=1),
615
+ )
616
+ for i in range(num_layers)
617
+ ]
618
+ )
619
+
620
+ if self.t_emb_dim is not None:
621
+ self.t_emb_layers = nn.ModuleList([
622
+ nn.Sequential(
623
+ nn.SiLU(),
624
+ nn.Linear(t_emb_dim, out_channels)
625
+ )
626
+ for _ in range(num_layers)
627
+ ])
628
+
629
+ self.resnet_conv_second = nn.ModuleList(
630
+ [
631
+ nn.Sequential(
632
+ nn.GroupNorm(norm_channels, out_channels),
633
+ nn.SiLU(),
634
+ nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
635
+ )
636
+ for _ in range(num_layers)
637
+ ]
638
+ )
639
+ if self.attn:
640
+ self.attention_norms = nn.ModuleList(
641
+ [
642
+ nn.GroupNorm(norm_channels, out_channels)
643
+ for _ in range(num_layers)
644
+ ]
645
+ )
646
+
647
+ self.attentions = nn.ModuleList(
648
+ [
649
+ nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
650
+ for _ in range(num_layers)
651
+ ]
652
+ )
653
+
654
+ self.residual_input_conv = nn.ModuleList(
655
+ [
656
+ nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=1)
657
+ for i in range(num_layers)
658
+ ]
659
+ )
660
+ self.up_sample_conv = nn.ConvTranspose2d(in_channels, in_channels,
661
+ 4, 2, 1) \
662
+ if self.up_sample else nn.Identity()
663
+
664
+ def forward(self, x, out_down=None, t_emb=None):
665
+ # Upsample
666
+ x = self.up_sample_conv(x)
667
+
668
+ # Concat with Downblock output
669
+ if out_down is not None:
670
+ x = torch.cat([x, out_down], dim=1)
671
+
672
+ out = x
673
+ for i in range(self.num_layers):
674
+ # Resnet Block
675
+ resnet_input = out
676
+ out = self.resnet_conv_first[i](out)
677
+ if self.t_emb_dim is not None:
678
+ out = out + self.t_emb_layers[i](t_emb)[:, :, None, None]
679
+ out = self.resnet_conv_second[i](out)
680
+ out = out + self.residual_input_conv[i](resnet_input)
681
+
682
+ # Self Attention
683
+ if self.attn:
684
+ batch_size, channels, h, w = out.shape
685
+ in_attn = out.reshape(batch_size, channels, h * w)
686
+ in_attn = self.attention_norms[i](in_attn)
687
+ in_attn = in_attn.transpose(1, 2)
688
+ out_attn, _ = self.attentions[i](in_attn, in_attn, in_attn)
689
+ out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
690
+ out = out + out_attn
691
+ return out
692
+
693
+
694
+ class UpBlockUnet(nn.Module):
695
+ r"""
696
+ Up conv block with attention.
697
+ Sequence of following blocks
698
+ 1. Upsample
699
+     2. Concatenate Down block output
700
+     3. Resnet block with time embedding
701
+     4. Attention Block
702
+ """
703
+
704
+ def __init__(self, in_channels, out_channels, t_emb_dim, up_sample,
705
+ num_heads, num_layers, norm_channels, cross_attn=False, context_dim=None):
706
+ super().__init__()
707
+ self.num_layers = num_layers
708
+ self.up_sample = up_sample
709
+ self.t_emb_dim = t_emb_dim
710
+ self.cross_attn = cross_attn
711
+ self.context_dim = context_dim
712
+ self.resnet_conv_first = nn.ModuleList(
713
+ [
714
+ nn.Sequential(
715
+ nn.GroupNorm(norm_channels, in_channels if i == 0 else out_channels),
716
+ nn.SiLU(),
717
+ nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=3, stride=1,
718
+ padding=1),
719
+ )
720
+ for i in range(num_layers)
721
+ ]
722
+ )
723
+
724
+ if self.t_emb_dim is not None:
725
+ self.t_emb_layers = nn.ModuleList([
726
+ nn.Sequential(
727
+ nn.SiLU(),
728
+ nn.Linear(t_emb_dim, out_channels)
729
+ )
730
+ for _ in range(num_layers)
731
+ ])
732
+
733
+ self.resnet_conv_second = nn.ModuleList(
734
+ [
735
+ nn.Sequential(
736
+ nn.GroupNorm(norm_channels, out_channels),
737
+ nn.SiLU(),
738
+ nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
739
+ )
740
+ for _ in range(num_layers)
741
+ ]
742
+ )
743
+
744
+ self.attention_norms = nn.ModuleList(
745
+ [
746
+ nn.GroupNorm(norm_channels, out_channels)
747
+ for _ in range(num_layers)
748
+ ]
749
+ )
750
+
751
+ self.attentions = nn.ModuleList(
752
+ [
753
+ nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
754
+ for _ in range(num_layers)
755
+ ]
756
+ )
757
+
758
+ if self.cross_attn:
759
+ assert context_dim is not None, "Context Dimension must be passed for cross attention"
760
+ self.cross_attention_norms = nn.ModuleList(
761
+ [nn.GroupNorm(norm_channels, out_channels)
762
+ for _ in range(num_layers)]
763
+ )
764
+ self.cross_attentions = nn.ModuleList(
765
+ [nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
766
+ for _ in range(num_layers)]
767
+ )
768
+ self.context_proj = nn.ModuleList(
769
+ [nn.Linear(context_dim, out_channels)
770
+ for _ in range(num_layers)]
771
+ )
772
+ self.residual_input_conv = nn.ModuleList(
773
+ [
774
+ nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=1)
775
+ for i in range(num_layers)
776
+ ]
777
+ )
778
+ self.up_sample_conv = nn.ConvTranspose2d(in_channels // 2, in_channels // 2,
779
+ 4, 2, 1) \
780
+ if self.up_sample else nn.Identity()
781
+
782
+ def forward(self, x, out_down=None, t_emb=None, context=None):
783
+ x = self.up_sample_conv(x)
784
+ if out_down is not None:
785
+ x = torch.cat([x, out_down], dim=1)
786
+
787
+ out = x
788
+ for i in range(self.num_layers):
789
+ # Resnet
790
+ resnet_input = out
791
+ out = self.resnet_conv_first[i](out)
792
+ if self.t_emb_dim is not None:
793
+ out = out + self.t_emb_layers[i](t_emb)[:, :, None, None]
794
+ out = self.resnet_conv_second[i](out)
795
+ out = out + self.residual_input_conv[i](resnet_input)
796
+ # Self Attention
797
+ batch_size, channels, h, w = out.shape
798
+ in_attn = out.reshape(batch_size, channels, h * w)
799
+ in_attn = self.attention_norms[i](in_attn)
800
+ in_attn = in_attn.transpose(1, 2)
801
+ out_attn, _ = self.attentions[i](in_attn, in_attn, in_attn)
802
+ out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
803
+ out = out + out_attn
804
+ # Cross Attention
805
+ if self.cross_attn:
806
+ assert context is not None, "context cannot be None if cross attention layers are used"
807
+ batch_size, channels, h, w = out.shape
808
+ in_attn = out.reshape(batch_size, channels, h * w)
809
+ in_attn = self.cross_attention_norms[i](in_attn)
810
+ in_attn = in_attn.transpose(1, 2)
811
+ assert len(context.shape) == 3, \
812
+ "Context shape does not match B,_,CONTEXT_DIM"
813
+ assert context.shape[0] == x.shape[0] and context.shape[-1] == self.context_dim,\
814
+ "Context shape does not match B,_,CONTEXT_DIM"
815
+ context_proj = self.context_proj[i](context)
816
+ out_attn, _ = self.cross_attentions[i](in_attn, context_proj, context_proj)
817
+ out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
818
+ out = out + out_attn
819
+
820
+ return out
821
+
822
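+ # --- Illustrative sketch (commented out, not part of the training flow) ---
+ # Minimal shape check for UpBlockUnet. The channel sizes, spatial size and
+ # embedding dim below are arbitrary assumptions chosen only for illustration.
+ # up = UpBlockUnet(in_channels=64, out_channels=16, t_emb_dim=128,
+ #                  up_sample=True, num_heads=4, num_layers=1, norm_channels=8)
+ # x = torch.randn(2, 32, 8, 8)         # up_sample_conv expects in_channels // 2
+ # skip = torch.randn(2, 32, 16, 16)    # matching down-block output for the concat
+ # t_emb = torch.randn(2, 128)
+ # out = up(x, out_down=skip, t_emb=t_emb)
+ # print(out.shape)                     # expected: torch.Size([2, 16, 16, 16])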
+ """Vqvae"""
823
+
824
+ import torch
825
+ import torch.nn as nn
826
+
827
+
828
+ class VQVAE(nn.Module):
829
+ def __init__(self, im_channels, model_config):
830
+ super().__init__()
831
+ self.down_channels = model_config.down_channels
832
+ self.mid_channels = model_config.mid_channels
833
+ self.down_sample = model_config.down_sample
834
+ self.num_down_layers = model_config.num_down_layers
835
+ self.num_mid_layers = model_config.num_mid_layers
836
+ self.num_up_layers = model_config.num_up_layers
837
+
838
+ # To disable attention in Downblock of Encoder and Upblock of Decoder
839
+ self.attns = model_config.attn_down
840
+
841
+ # Latent Dimension
842
+ self.z_channels = model_config.z_channels
843
+ self.codebook_size = model_config.codebook_size
844
+ self.norm_channels = model_config.norm_channels
845
+ self.num_heads = model_config.num_heads
846
+
847
+ # Assertion to validate the channel information
848
+ assert self.mid_channels[0] == self.down_channels[-1]
849
+ assert self.mid_channels[-1] == self.down_channels[-1]
850
+ assert len(self.down_sample) == len(self.down_channels) - 1
851
+ assert len(self.attns) == len(self.down_channels) - 1
852
+
853
+ # Wherever we use downsampling in encoder correspondingly use
854
+ # upsampling in decoder
855
+ self.up_sample = list(reversed(self.down_sample))
856
+
857
+ ##################### Encoder ######################
858
+ self.encoder_conv_in = nn.Conv2d(im_channels, self.down_channels[0], kernel_size=3, padding=(1, 1))
859
+
860
+ # Downblock + Midblock
861
+ self.encoder_layers = nn.ModuleList([])
862
+ for i in range(len(self.down_channels) - 1):
863
+ self.encoder_layers.append(DownBlock(self.down_channels[i], self.down_channels[i + 1],
864
+ t_emb_dim=None, down_sample=self.down_sample[i],
865
+ num_heads=self.num_heads,
866
+ num_layers=self.num_down_layers,
867
+ attn=self.attns[i],
868
+ norm_channels=self.norm_channels))
869
+
870
+ self.encoder_mids = nn.ModuleList([])
871
+ for i in range(len(self.mid_channels) - 1):
872
+ self.encoder_mids.append(MidBlock(self.mid_channels[i], self.mid_channels[i + 1],
873
+ t_emb_dim=None,
874
+ num_heads=self.num_heads,
875
+ num_layers=self.num_mid_layers,
876
+ norm_channels=self.norm_channels))
877
+
878
+ self.encoder_norm_out = nn.GroupNorm(self.norm_channels, self.down_channels[-1])
879
+ self.encoder_conv_out = nn.Conv2d(self.down_channels[-1], self.z_channels, kernel_size=3, padding=1)
880
+
881
+ # Pre Quantization Convolution
882
+ self.pre_quant_conv = nn.Conv2d(self.z_channels, self.z_channels, kernel_size=1)
883
+
884
+ # Codebook
885
+ self.embedding = nn.Embedding(self.codebook_size, self.z_channels)
886
+ ####################################################
887
+
888
+ ##################### Decoder ######################
889
+
890
+ # Post Quantization Convolution
891
+ self.post_quant_conv = nn.Conv2d(self.z_channels, self.z_channels, kernel_size=1)
892
+ self.decoder_conv_in = nn.Conv2d(self.z_channels, self.mid_channels[-1], kernel_size=3, padding=(1, 1))
893
+
894
+ # Midblock + Upblock
895
+ self.decoder_mids = nn.ModuleList([])
896
+ for i in reversed(range(1, len(self.mid_channels))):
897
+ self.decoder_mids.append(MidBlock(self.mid_channels[i], self.mid_channels[i - 1],
898
+ t_emb_dim=None,
899
+ num_heads=self.num_heads,
900
+ num_layers=self.num_mid_layers,
901
+ norm_channels=self.norm_channels))
902
+
903
+ self.decoder_layers = nn.ModuleList([])
904
+ for i in reversed(range(1, len(self.down_channels))):
905
+ self.decoder_layers.append(UpBlock(self.down_channels[i], self.down_channels[i - 1],
906
+ t_emb_dim=None, up_sample=self.down_sample[i - 1],
907
+ num_heads=self.num_heads,
908
+ num_layers=self.num_up_layers,
909
+ attn=self.attns[i-1],
910
+ norm_channels=self.norm_channels))
911
+
912
+ self.decoder_norm_out = nn.GroupNorm(self.norm_channels, self.down_channels[0])
913
+ self.decoder_conv_out = nn.Conv2d(self.down_channels[0], im_channels, kernel_size=3, padding=1)
914
+
915
+ def quantize(self, x):
916
+ B, C, H, W = x.shape
917
+
918
+ # B, C, H, W -> B, H, W, C
919
+ x = x.permute(0, 2, 3, 1)
920
+
921
+ # B, H, W, C -> B, H*W, C
922
+ x = x.reshape(x.size(0), -1, x.size(-1))
923
+
924
+ # Find nearest embedding/codebook vector
925
+ # dist between (B, H*W, C) and (B, K, C) -> (B, H*W, K)
926
+ dist = torch.cdist(x, self.embedding.weight[None, :].repeat((x.size(0), 1, 1)))
927
+ # (B, H*W)
928
+ min_encoding_indices = torch.argmin(dist, dim=-1)
929
+
930
+ # Replace encoder output with nearest codebook
931
+ # quant_out -> B*H*W, C
932
+ quant_out = torch.index_select(self.embedding.weight, 0, min_encoding_indices.view(-1))
933
+
934
+ # x -> B*H*W, C
935
+ x = x.reshape((-1, x.size(-1)))
936
+ commitment_loss = torch.mean((quant_out.detach() - x) ** 2)
937
+ codebook_loss = torch.mean((quant_out - x.detach()) ** 2)
938
+ quantize_losses = {
939
+ 'codebook_loss': codebook_loss,
940
+ 'commitment_loss': commitment_loss
941
+ }
942
+ # Straight through estimation
943
+ quant_out = x + (quant_out - x).detach()
944
+
945
+ # quant_out -> B, C, H, W
946
+ quant_out = quant_out.reshape((B, H, W, C)).permute(0, 3, 1, 2)
947
+ min_encoding_indices = min_encoding_indices.reshape((-1, quant_out.size(-2), quant_out.size(-1)))
948
+ return quant_out, quantize_losses, min_encoding_indices
949
+
950
+ def encode(self, x):
951
+ out = self.encoder_conv_in(x)
952
+ for idx, down in enumerate(self.encoder_layers):
953
+ out = down(out)
954
+ for mid in self.encoder_mids:
955
+ out = mid(out)
956
+ out = self.encoder_norm_out(out)
957
+ out = nn.SiLU()(out)
958
+ out = self.encoder_conv_out(out)
959
+ out = self.pre_quant_conv(out)
960
+ out, quant_losses, _ = self.quantize(out)
961
+ return out, quant_losses
962
+
963
+ def decode(self, z):
964
+ out = z
965
+ out = self.post_quant_conv(out)
966
+ out = self.decoder_conv_in(out)
967
+ for mid in self.decoder_mids:
968
+ out = mid(out)
969
+ for idx, up in enumerate(self.decoder_layers):
970
+ out = up(out)
971
+
972
+ out = self.decoder_norm_out(out)
973
+ out = nn.SiLU()(out)
974
+ out = self.decoder_conv_out(out)
975
+ return out
976
+
977
+ def forward(self, x):
978
+ z, quant_losses = self.encode(x)
979
+ out = self.decode(z)
980
+ return out, z, quant_losses
981
+
982
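+ # --- Illustrative sketch (commented out) ----------------------------------
+ # The core of VQVAE.quantize() in isolation: nearest-codebook lookup plus the
+ # straight-through estimator. The shapes and codebook size are arbitrary
+ # assumptions for illustration only.
+ # B, C, H, W, K = 2, 3, 4, 4, 8
+ # codebook = nn.Embedding(K, C)
+ # z_e = torch.randn(B, C, H, W)
+ # flat = z_e.permute(0, 2, 3, 1).reshape(B, H * W, C)
+ # dist = torch.cdist(flat, codebook.weight[None].repeat(B, 1, 1))   # (B, H*W, K)
+ # idx = torch.argmin(dist, dim=-1)                                  # (B, H*W)
+ # z_q = codebook(idx)                                               # (B, H*W, C)
+ # z_q = flat + (z_q - flat).detach()   # straight-through: grads flow to z_e
+ # z_q = z_q.reshape(B, H, W, C).permute(0, 3, 1, 2)                 # (B, C, H, W)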
+ """Vae"""
983
+
984
+ import torch
985
+ import torch.nn as nn
986
+
987
+
988
+ class VAE(nn.Module):
989
+ def __init__(self, im_channels, model_config):
990
+ super().__init__()
991
+ self.down_channels = model_config['down_channels']
992
+ self.mid_channels = model_config['mid_channels']
993
+ self.down_sample = model_config['down_sample']
994
+ self.num_down_layers = model_config['num_down_layers']
995
+ self.num_mid_layers = model_config['num_mid_layers']
996
+ self.num_up_layers = model_config['num_up_layers']
997
+
998
+ # To disable attention in Downblock of Encoder and Upblock of Decoder
999
+ self.attns = model_config['attn_down']
1000
+
1001
+ # Latent Dimension
1002
+ self.z_channels = model_config['z_channels']
1003
+ self.norm_channels = model_config['norm_channels']
1004
+ self.num_heads = model_config['num_heads']
1005
+
1006
+ # Assertion to validate the channel information
1007
+ assert self.mid_channels[0] == self.down_channels[-1]
1008
+ assert self.mid_channels[-1] == self.down_channels[-1]
1009
+ assert len(self.down_sample) == len(self.down_channels) - 1
1010
+ assert len(self.attns) == len(self.down_channels) - 1
1011
+
1012
+ # Wherever we use downsampling in encoder correspondingly use
1013
+ # upsampling in decoder
1014
+ self.up_sample = list(reversed(self.down_sample))
1015
+
1016
+ ##################### Encoder ######################
1017
+ self.encoder_conv_in = nn.Conv2d(im_channels, self.down_channels[0], kernel_size=3, padding=(1, 1))
1018
+
1019
+ # Downblock + Midblock
1020
+ self.encoder_layers = nn.ModuleList([])
1021
+ for i in range(len(self.down_channels) - 1):
1022
+ self.encoder_layers.append(DownBlock(self.down_channels[i], self.down_channels[i + 1],
1023
+ t_emb_dim=None, down_sample=self.down_sample[i],
1024
+ num_heads=self.num_heads,
1025
+ num_layers=self.num_down_layers,
1026
+ attn=self.attns[i],
1027
+ norm_channels=self.norm_channels))
1028
+
1029
+ self.encoder_mids = nn.ModuleList([])
1030
+ for i in range(len(self.mid_channels) - 1):
1031
+ self.encoder_mids.append(MidBlock(self.mid_channels[i], self.mid_channels[i + 1],
1032
+ t_emb_dim=None,
1033
+ num_heads=self.num_heads,
1034
+ num_layers=self.num_mid_layers,
1035
+ norm_channels=self.norm_channels))
1036
+
1037
+ self.encoder_norm_out = nn.GroupNorm(self.norm_channels, self.down_channels[-1])
1038
+ self.encoder_conv_out = nn.Conv2d(self.down_channels[-1], 2*self.z_channels, kernel_size=3, padding=1)
1039
+
1040
+ # Latent Dimension is 2*Latent because we are predicting mean & variance
1041
+ self.pre_quant_conv = nn.Conv2d(2*self.z_channels, 2*self.z_channels, kernel_size=1)
1042
+ ####################################################
1043
+
1044
+
1045
+ ##################### Decoder ######################
1046
+ self.post_quant_conv = nn.Conv2d(self.z_channels, self.z_channels, kernel_size=1)
1047
+ self.decoder_conv_in = nn.Conv2d(self.z_channels, self.mid_channels[-1], kernel_size=3, padding=(1, 1))
1048
+
1049
+ # Midblock + Upblock
1050
+ self.decoder_mids = nn.ModuleList([])
1051
+ for i in reversed(range(1, len(self.mid_channels))):
1052
+ self.decoder_mids.append(MidBlock(self.mid_channels[i], self.mid_channels[i - 1],
1053
+ t_emb_dim=None,
1054
+ num_heads=self.num_heads,
1055
+ num_layers=self.num_mid_layers,
1056
+ norm_channels=self.norm_channels))
1057
+
1058
+ self.decoder_layers = nn.ModuleList([])
1059
+ for i in reversed(range(1, len(self.down_channels))):
1060
+ self.decoder_layers.append(UpBlock(self.down_channels[i], self.down_channels[i - 1],
1061
+ t_emb_dim=None, up_sample=self.down_sample[i - 1],
1062
+ num_heads=self.num_heads,
1063
+ num_layers=self.num_up_layers,
1064
+ attn=self.attns[i - 1],
1065
+ norm_channels=self.norm_channels))
1066
+
1067
+ self.decoder_norm_out = nn.GroupNorm(self.norm_channels, self.down_channels[0])
1068
+ self.decoder_conv_out = nn.Conv2d(self.down_channels[0], im_channels, kernel_size=3, padding=1)
1069
+
1070
+ def encode(self, x):
1071
+ out = self.encoder_conv_in(x)
1072
+ for idx, down in enumerate(self.encoder_layers):
1073
+ out = down(out)
1074
+ for mid in self.encoder_mids:
1075
+ out = mid(out)
1076
+ out = self.encoder_norm_out(out)
1077
+ out = nn.SiLU()(out)
1078
+ out = self.encoder_conv_out(out)
1079
+ out = self.pre_quant_conv(out)
1080
+ mean, logvar = torch.chunk(out, 2, dim=1)
1081
+ std = torch.exp(0.5 * logvar)
1082
+ sample = mean + std * torch.randn(mean.shape).to(device=x.device)
1083
+ return sample, out
1084
+
1085
+ def decode(self, z):
1086
+ out = z
1087
+ out = self.post_quant_conv(out)
1088
+ out = self.decoder_conv_in(out)
1089
+ for mid in self.decoder_mids:
1090
+ out = mid(out)
1091
+ for idx, up in enumerate(self.decoder_layers):
1092
+ out = up(out)
1093
+
1094
+ out = self.decoder_norm_out(out)
1095
+ out = nn.SiLU()(out)
1096
+ out = self.decoder_conv_out(out)
1097
+ return out
1098
+
1099
+ def forward(self, x):
1100
+ z, encoder_output = self.encode(x)
1101
+ out = self.decode(z)
1102
+ return out, encoder_output
1103
+
1104
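+ # --- Illustrative sketch (commented out) ----------------------------------
+ # How a KL regularisation term could be computed from the encoder output
+ # returned by VAE.forward() (mean and logvar are concatenated along the
+ # channel dim). The kl_weight value is an arbitrary assumption; the training
+ # loop for this VAE is not part of this script.
+ # def kl_loss_from_encoder_output(encoder_output, kl_weight=1e-6):
+ #     mean, logvar = torch.chunk(encoder_output, 2, dim=1)
+ #     kl = 0.5 * torch.sum(mean ** 2 + logvar.exp() - 1.0 - logvar, dim=[1, 2, 3])
+ #     return kl_weight * kl.mean()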
+ """Discriminator"""
1105
+
1106
+ import torch
1107
+ import torch.nn as nn
1108
+
1109
+
1110
+ class Discriminator(nn.Module):
1111
+ r"""
1112
+ PatchGAN Discriminator.
1113
+ Rather than mapping an IMG_CHANNELS x IMG_H x IMG_W image all the way
1114
+ down to a single scalar value, it predicts a grid of values,
1115
+ where each grid cell is the discriminator's estimate of how likely
1116
+ the image patch corresponding to that cell is real.
1117
+
1118
+ """
1119
+
1120
+ def __init__(self, im_channels=3,
1121
+ conv_channels=[64, 128, 256],
1122
+ kernels=[4,4,4,4],
1123
+ strides=[2,2,2,1],
1124
+ paddings=[1,1,1,1]):
1125
+ super().__init__()
1126
+ self.im_channels = im_channels
1127
+ activation = nn.LeakyReLU(0.2)
1128
+ layers_dim = [self.im_channels] + conv_channels + [1]
1129
+ self.layers = nn.ModuleList([
1130
+ nn.Sequential(
1131
+ nn.Conv2d(layers_dim[i], layers_dim[i + 1],
1132
+ kernel_size=kernels[i],
1133
+ stride=strides[i],
1134
+ padding=paddings[i],
1135
+ bias=False if i !=0 else True),
1136
+ nn.BatchNorm2d(layers_dim[i + 1]) if i != len(layers_dim) - 2 and i != 0 else nn.Identity(),
1137
+ activation if i != len(layers_dim) - 2 else nn.Identity()
1138
+ )
1139
+ for i in range(len(layers_dim) - 1)
1140
+ ])
1141
+
1142
+ def forward(self, x):
1143
+ out = x
1144
+ for layer in self.layers:
1145
+ out = layer(out)
1146
+ return out
1147
+
1148
+
1149
+ # if __name__ == '__main__':
1150
+ # x = torch.randn((2,3, 256, 256))
1151
+ # prob = Discriminator(im_channels=3)(x)
1152
+ # print(prob.shape)
1153
+
1154
+ # import os
1155
+
1156
+ # image_paths = [os.path.join("/home/taruntejaneurips23/Ashish/datasets/animefacedata/images", f)
1157
+ # for f in os.listdir("/home/taruntejaneurips23/Ashish/datasets/animefacedata/images")]
1158
+ # image_paths
1159
+
1160
+ import glob
1161
+ import os
1162
+ import torchvision
1163
+ from PIL import Image
1164
+ from tqdm import tqdm, trange
1165
+ # from utils.diffusion_utils import load_latents
1166
+ from torch.utils.data.dataset import Dataset
1167
+
1168
+ import pickle
1169
+ import glob
1170
+ import os
1171
+ import torch
1172
+
1173
+
1174
+ def load_latents(latent_path):
1175
+ r"""
1176
+ Simple utility to load pre-computed latents to speed up LDM training
1177
+ :param latent_path:
1178
+ :return:
1179
+ """
1180
+ latent_maps = {}
1181
+ for fname in glob.glob(os.path.join(latent_path, '*.pkl')):
1182
+ s = pickle.load(open(fname, 'rb'))
1183
+ for k, v in s.items():
1184
+ latent_maps[k] = v[0]
1185
+ return latent_maps
1186
+
1187
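+ # --- Illustrative sketch (commented out) ----------------------------------
+ # The on-disk format load_latents() expects: each .pkl holds a dict mapping
+ # an image path to a latent tensor with a leading batch dim of 1 (the loader
+ # keeps v[0]). The latent-saving script is not part of this file, so this is
+ # only an assumed, minimal writer matching that format.
+ # def save_latent_example(latent_dir, image_path, latent):
+ #     os.makedirs(latent_dir, exist_ok=True)
+ #     fname = os.path.join(latent_dir, os.path.basename(image_path) + '.pkl')
+ #     with open(fname, 'wb') as f:
+ #         pickle.dump({image_path: latent.unsqueeze(0).cpu()}, f)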
+
1188
+ def drop_text_condition(text_embed, im, empty_text_embed, text_drop_prob):
1189
+ if text_drop_prob > 0:
1190
+ text_drop_mask = torch.zeros((im.shape[0]), device=im.device).float().uniform_(0,
1191
+ 1) < text_drop_prob
1192
+ assert empty_text_embed is not None, ("Text Conditioning required as well as"
1193
+ " text dropping but empty text representation not created")
1194
+ text_embed[text_drop_mask, :, :] = empty_text_embed[0]
1195
+ return text_embed
1196
+
1197
+
1198
+ def drop_image_condition(image_condition, im, im_drop_prob):
1199
+ if im_drop_prob > 0:
1200
+ im_drop_mask = torch.zeros((im.shape[0], 1, 1, 1), device=im.device).float().uniform_(0,
1201
+ 1) > im_drop_prob
1202
+ return image_condition * im_drop_mask
1203
+ else:
1204
+ return image_condition
1205
+
1206
+
1207
+ def drop_class_condition(class_condition, class_drop_prob, im):
1208
+ if class_drop_prob > 0:
1209
+ class_drop_mask = torch.zeros((im.shape[0], 1), device=im.device).float().uniform_(0,
1210
+ 1) > class_drop_prob
1211
+ return class_condition * class_drop_mask
1212
+ else:
1213
+ return class_condition
1214
+
1215
+
1216
+ class MnistDataset(Dataset):
1217
+ r"""
1218
+ Nothing special here. Just a simple dataset class for mnist images.
1219
+ Created a dataset class rather than using torchvision to allow
1220
+ easy replacement with any other image dataset.
1221
+ """
1222
+
1223
+ def __init__(self, split, im_path, im_size, im_channels,
1224
+ use_latents=False, latent_path=None, condition_config=None):
1225
+ r"""
1226
+ Init method for initializing the dataset properties
1227
+ :param split: train/test to locate the image files
1228
+ :param im_path: root folder of images
1229
+ :param im_ext: image extension. assumes all
1230
+ images would be this type.
1231
+ """
1232
+ self.split = split
1233
+ self.im_size = im_size
1234
+ self.im_channels = im_channels
1235
+
1236
+ # Should we use latents or not
1237
+ self.latent_maps = None
1238
+ self.use_latents = False
1239
+
1240
+ # Conditioning for the dataset
1241
+ self.condition_types = [] if condition_config is None else condition_config['condition_types']
1242
+
1243
+ self.images, self.labels = self.load_images(im_path)
1244
+
1245
+ # Whether to load images and call vae or to load latents
1246
+ if use_latents and latent_path is not None:
1247
+ latent_maps = load_latents(latent_path)
1248
+ if len(latent_maps) == len(self.images):
1249
+ self.use_latents = True
1250
+ self.latent_maps = latent_maps
1251
+ print('Found {} latents'.format(len(self.latent_maps)))
1252
+ else:
1253
+ print('Latents not found')
1254
+
1255
+ def load_images(self, im_path):
1256
+ r"""
1257
+ Gets all images from the path specified
1258
+ and stacks them all up
1259
+ :param im_path:
1260
+ :return:
1261
+ """
1262
+ assert os.path.exists(im_path), "images path {} does not exist".format(im_path)
1263
+ ims = []
1264
+ labels = []
1265
+ for d_name in tqdm(os.listdir(im_path)):
1266
+ fnames = glob.glob(os.path.join(im_path, d_name, '*.{}'.format('png')))
1267
+ fnames += glob.glob(os.path.join(im_path, d_name, '*.{}'.format('jpg')))
1268
+ fnames += glob.glob(os.path.join(im_path, d_name, '*.{}'.format('jpeg')))
1269
+ for fname in fnames:
1270
+ ims.append(fname)
1271
+ if 'class' in self.condition_types:
1272
+ labels.append(int(d_name))
1273
+ print('Found {} images for split {}'.format(len(ims), self.split))
1274
+ return ims, labels
1275
+
1276
+ def __len__(self):
1277
+ return len(self.images)
1278
+
1279
+ def __getitem__(self, index):
1280
+ ######## Set Conditioning Info ########
1281
+ cond_inputs = {}
1282
+ if 'class' in self.condition_types:
1283
+ cond_inputs['class'] = self.labels[index]
1284
+ #######################################
1285
+
1286
+ if self.use_latents:
1287
+ latent = self.latent_maps[self.images[index]]
1288
+ if len(self.condition_types) == 0:
1289
+ return latent
1290
+ else:
1291
+ return latent, cond_inputs
1292
+ else:
1293
+ im = Image.open(self.images[index])
1294
+ im_tensor = torchvision.transforms.ToTensor()(im)
1295
+
1296
+ # Convert input to -1 to 1 range.
1297
+ im_tensor = (2 * im_tensor) - 1
1298
+ if len(self.condition_types) == 0:
1299
+ return im_tensor
1300
+ else:
1301
+ return im_tensor, cond_inputs
1302
+
1303
+
1304
+ class AnimeFaceDataset(Dataset):
1305
+ def __init__(self, split, im_path, im_size, im_channels,
1306
+ use_latents=False, latent_path=None, condition_config=None):
1307
+
1308
+ self.split = split
1309
+ self.im_size = im_size
1310
+ self.im_channels = im_channels
1311
+
1312
+ # Should we use latents or not
1313
+ self.latent_maps = None
1314
+ self.use_latents = False
1315
+
1316
+ # Conditioning for the dataset
1317
+ self.condition_types = [] if condition_config is None else condition_config['condition_types']
1318
+
1319
+ self.images = self.load_images(im_path)
1320
+
1321
+ # Whether to load images and call vae or to load latents
1322
+ if use_latents and latent_path is not None:
1323
+ latent_maps = load_latents(latent_path)
1324
+ if len(latent_maps) == len(self.images):
1325
+ self.use_latents = True
1326
+ self.latent_maps = latent_maps
1327
+ print('Found {} latents'.format(len(self.latent_maps)))
1328
+ else:
1329
+ print('Latents not found')
1330
+
1331
+ def load_images(self, im_path):
1332
+ r"""
1333
+ Gets all images from the path specified
1334
+ and returns their file paths
1335
+ :param im_path:
1336
+ :return:
1337
+ """
1338
+ assert os.path.exists(im_path), "images path {} does not exist".format(im_path)
1339
+ # ims = []
1340
+ # labels = []
1341
+ ims = [os.path.join(im_path, f) for f in os.listdir(im_path)]
1342
+ return ims
1343
+
1344
+ def __len__(self):
1345
+ return len(self.images)
1346
+
1347
+ def __getitem__(self, index):
1348
+ ######## Set Conditioning Info ########
1349
+ # cond_inputs = {}
1350
+ # if 'class' in self.condition_types:
1351
+ # cond_inputs['class'] = self.labels[index]
1352
+ #######################################
1353
+
1354
+ if self.use_latents:
1355
+ latent = self.latent_maps[self.images[index]]
1356
+ if len(self.condition_types) == 0:
1357
+ return latent
1358
+ # else:
1359
+ # return latent, cond_inputs
1360
+ else:
1361
+ im = Image.open(self.images[index])
1362
+ im_tensor = torchvision.transforms.Compose([
1363
+ torchvision.transforms.Resize(self.im_size),
1364
+ torchvision.transforms.CenterCrop(self.im_size),
1365
+ torchvision.transforms.ToTensor(),
1366
+ ])(im)
1367
+ im.close()
1368
+ # im_tensor = torchvision.transforms.ToTensor()(im)
1369
+
1370
+ # Convert input to -1 to 1 range.
1371
+ im_tensor = (2 * im_tensor) - 1
1372
+ if len(self.condition_types) == 0:
1373
+ return im_tensor
1374
+ # else:
1375
+ # return im_tensor, cond_inputs
1376
+
1377
+
1378
+ import glob
1379
+ import os
1380
+ import random
1381
+ import torch
1382
+ import torchvision
1383
+ import numpy as np
1384
+ from PIL import Image
1385
+ from tqdm import tqdm
1386
+ from torch.utils.data.dataset import Dataset
1387
+
1388
+
1389
+ class CelebDataset(Dataset):
1390
+ def __init__(self, split, im_path, im_size, im_channels,
1391
+ use_latents=False, latent_path=None, condition_config=None):
1392
+
1393
+ self.split = split
1394
+ self.im_size = im_size
1395
+ self.im_channels = im_channels
1396
+
1397
+ # Should we use latents or not
1398
+ self.latent_maps = None
1399
+ self.use_latents = False
1400
+
1401
+ # Conditioning for the dataset
1402
+ self.condition_types = [] if condition_config is None else condition_config['condition_types']
1403
+
1404
+ self.images = self.load_images(im_path)
1405
+
1406
+ # Whether to load images and call vae or to load latents
1407
+ if use_latents and latent_path is not None:
1408
+ latent_maps = load_latents(latent_path)
1409
+ if len(latent_maps) == len(self.images):
1410
+ self.use_latents = True
1411
+ self.latent_maps = latent_maps
1412
+ print('Found {} latents'.format(len(self.latent_maps)))
1413
+ else:
1414
+ print('Latents not found')
1415
+
1416
+ def load_images(self, im_path):
1417
+ r"""
1418
+ Gets all images from the path specified
1419
+ and stacks them all up
1420
+ :param im_path:
1421
+ :return:
1422
+ """
1423
+ assert os.path.exists(im_path), "images path {} does not exist".format(im_path)
1424
+ # ims = []
1425
+ # labels = []
1426
+ ims = [os.path.join(im_path, f) for f in os.listdir(im_path)]
1427
+ return ims
1428
+
1429
+ def __len__(self):
1430
+ return len(self.images)
1431
+
1432
+ def __getitem__(self, index):
1433
+ ######## Set Conditioning Info ########
1434
+ # cond_inputs = {}
1435
+ # if 'class' in self.condition_types:
1436
+ # cond_inputs['class'] = self.labels[index]
1437
+ #######################################
1438
+
1439
+ if self.use_latents:
1440
+ latent = self.latent_maps[self.images[index]]
1441
+ if len(self.condition_types) == 0:
1442
+ return latent
1443
+ # else:
1444
+ # return latent, cond_inputs
1445
+ else:
1446
+ im = Image.open(self.images[index])
1447
+ im_tensor = torchvision.transforms.Compose([
1448
+ # torchvision.transforms.Resize(self.im_size),
1449
+ torchvision.transforms.CenterCrop(self.im_size),
1450
+ torchvision.transforms.ToTensor(),
1451
+ ])(im)
1452
+ im.close()
1453
+ # im_tensor = torchvision.transforms.ToTensor()(im)
1454
+
1455
+ # Convert input to -1 to 1 range.
1456
+ im_tensor = (2 * im_tensor) - 1
1457
+ if len(self.condition_types) == 0:
1458
+ return im_tensor
1459
+ # else:
1460
+ # return im_tensor, cond_inputs
1461
+ import pandas as pd
1462
+ class CelebHairDataset(Dataset):
1463
+ def __init__(self, split, im_path, im_size, im_channels,
1464
+ use_latents=False, latent_path=None, condition_config=None):
1465
+
1466
+ self.df = pd.read_csv("/home/taruntejaneurips23/Ashish/DDPM/hair_df_100.csv")
1467
+ self.split = split
1468
+ self.im_size = im_size
1469
+ self.im_channels = im_channels
1470
+
1471
+ # Should we use latents or not
1472
+ self.latent_maps = None
1473
+ self.use_latents = False
1474
+
1475
+ # Conditioning for the dataset
1476
+ self.condition_types = [] if condition_config is None else condition_config['condition_types']
1477
+
1478
+ self.images = self.load_images(im_path, self.df)
1479
+
1480
+ # Whether to load images and call vae or to load latents
1481
+ if use_latents and latent_path is not None:
1482
+ latent_maps = load_latents(latent_path)
1483
+ if len(latent_maps) == len(self.images):
1484
+ self.use_latents = True
1485
+ self.latent_maps = latent_maps
1486
+ print('Found {} latents'.format(len(self.latent_maps)))
1487
+ else:
1488
+ print('Latents not found')
1489
+
1490
+ def load_images(self, im_path, df):
1491
+ r"""
1492
+ Gets all images from the path specified
1493
+ and returns their file paths
1494
+ :param im_path:
1495
+ :return:
1496
+ """
1497
+ assert os.path.exists(im_path), "images path {} does not exist".format(im_path)
1498
+ # ims = []
1499
+ # labels = []
1500
+ # ims = [os.path.join(im_path, f) for f in os.listdir(im_path)]
1501
+ ims = [os.path.join(im_path, i) for i in df.image_id.values]
1502
+ return ims
1503
+
1504
+ def __len__(self):
1505
+ return len(self.images)
1506
+
1507
+ def __getitem__(self, index):
1508
+ ######## Set Conditioning Info ########
1509
+ # cond_inputs = {}
1510
+ # if 'class' in self.condition_types:
1511
+ # cond_inputs['class'] = self.labels[index]
1512
+ #######################################
1513
+
1514
+ if self.use_latents:
1515
+ latent = self.latent_maps[self.images[index]]
1516
+ if len(self.condition_types) == 0:
1517
+ return latent
1518
+ # else:
1519
+ # return latent, cond_inputs
1520
+ else:
1521
+ im = Image.open(self.images[index])
1522
+ im_tensor = torchvision.transforms.Compose([
1523
+ # torchvision.transforms.Resize(self.im_size),
1524
+ torchvision.transforms.CenterCrop(self.im_size),
1525
+ torchvision.transforms.ToTensor(),
1526
+ ])(im)
1527
+ im.close()
1528
+ # im_tensor = torchvision.transforms.ToTensor()(im)
1529
+
1530
+ # Convert input to -1 to 1 range.
1531
+ im_tensor = (2 * im_tensor) - 1
1532
+ if len(self.condition_types) == 0:
1533
+ return im_tensor
1534
+ # else:
1535
+ # return im_tensor, cond_inputs
1536
+
1537
+ # ##################### Train VQVAE #####################
1538
+
1539
+ # Commented out IPython magic to ensure Python compatibility.
1540
+ import torch
1541
+ import torch.nn as nn
1542
+ import yaml
1543
+ from dotdict import DotDict
1544
+
1545
+ config_path = "/home/taruntejaneurips23/Ashish/DDPM/_5_ldm_celeba.yaml"
1546
+ with open(config_path, 'r') as file:
1547
+ Config = yaml.safe_load(file)
1548
+
1549
+
1550
+ Config = DotDict.from_dict(Config)
1551
+ dataset_config = Config.dataset_params
1552
+ diffusion_config = Config.diffusion_params
1553
+ model_config = Config.model_params
1554
+ train_config = Config.train_params
1555
+
1556
+ import torch
1557
+ import os
1558
+ import random
1559
+ import numpy as np
1560
+ import matplotlib.pyplot as plt
1561
+ from tqdm import tqdm
1562
+ from torch.optim import Adam
1563
+ from torch.utils.data import Dataset, TensorDataset, DataLoader
1564
+ # device = 'cuda:1' if torch.cuda.is_available() else 'cpu'
1565
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
1566
+
1567
+
1568
+
1569
+ from torchvision.utils import make_grid
1570
+
1571
+ def trainVAE(Config):
1572
+
1573
+ dataset_config = Config.dataset_params
1574
+ autoencoder_config = Config.autoencoder_params
1575
+ train_config = Config.train_params
1576
+
1577
+ # Set the desired seed value #
1578
+ seed = train_config.seed
1579
+ torch.manual_seed(seed)
1580
+ np.random.seed(seed)
1581
+ random.seed(seed)
1582
+ if device == 'cuda':
1583
+ torch.cuda.manual_seed_all(seed)
1584
+ #############################
1585
+
1586
+ # Create the model and dataset #
1587
+ model = VQVAE(im_channels=dataset_config.im_channels,
1588
+ model_config=autoencoder_config).to(device)
1589
+ # model.load_state_dict(torch.load("/home/taruntejaneurips23/Ashish/DDPM/celebAhair_ldm/vqvae_autoencoder_ckpt.pth", map_location=device))
1590
+ if os.path.exists(os.path.join(train_config.task_name, train_config.vqvae_autoencoder_ckpt_name)):
1591
+ print('Loaded vae checkpoint')
1592
+ model.load_state_dict(torch.load(os.path.join(train_config.task_name, train_config.vqvae_autoencoder_ckpt_name),
1593
+ map_location=device, weights_only=True))
1594
+
1595
+ # Create the dataset
1596
+ im_dataset_cls = {
1597
+ 'mnist': MnistDataset,
1598
+ 'celebA': CelebDataset,
1599
+ 'animeface': AnimeFaceDataset,
1600
+ 'celebAhair': CelebHairDataset
1601
+ }.get(dataset_config.name)
1602
+
1603
+ im_dataset = im_dataset_cls(split='train',
1604
+ im_path=dataset_config.im_path,
1605
+ im_size=dataset_config.im_size,
1606
+ im_channels=dataset_config.im_channels)
1607
+
1608
+ data_loader = DataLoader(im_dataset,
1609
+ batch_size=train_config.autoencoder_batch_size,
1610
+ shuffle=True,
1611
+ num_workers=os.cpu_count(),
1612
+ pin_memory=True,
1613
+ drop_last=True,
1614
+ persistent_workers=True, pin_memory_device=device)
1615
+
1616
+ # Create output directories
1617
+ if not os.path.exists(train_config.task_name):
1618
+ os.mkdir(train_config.task_name)
1619
+
1620
+ num_epochs = train_config.autoencoder_epochs
1621
+
1622
+ # L1/L2 loss for Reconstruction
1623
+ recon_criterion = torch.nn.MSELoss()
1624
+ # Disc Loss can even be BCEWithLogits
1625
+ disc_criterion = torch.nn.MSELoss()
1626
+
1627
+ # No need to freeze lpips as lpips.py takes care of that
1628
+ lpips_model = LPIPS().eval().to(device)
1629
+ discriminator = Discriminator(im_channels=dataset_config.im_channels).to(device)
1630
+ # discriminator.load_state_dict(torch.load("/home/taruntejaneurips23/Ashish/DDPM/celebAhair_ldm/vqvae_discriminator_ckpt.pth", map_location=device))
1631
+ if os.path.exists(os.path.join(train_config.task_name, train_config.vqvae_discriminator_ckpt_name)):
1632
+ print('Loaded discriminator checkpoint')
1633
+ discriminator.load_state_dict(torch.load(os.path.join(train_config.task_name, train_config.vqvae_discriminator_ckpt_name),
1634
+ map_location=device, weights_only=True))
1635
+
1636
+ optimizer_d = Adam(discriminator.parameters(), lr=train_config.autoencoder_lr, betas=(0.5, 0.999))
1637
+ optimizer_g = Adam(model.parameters(), lr=train_config.autoencoder_lr, betas=(0.5, 0.999))
1638
+
1639
+ disc_step_start = train_config.disc_start
1640
+ step_count = 0
1641
+
1642
+ # This is for accumulating gradients in case the images are huge
1643
+ # and one can't afford larger batch sizes (see the sketch after this function)
1644
+ acc_steps = train_config.autoencoder_acc_steps
1645
+ image_save_steps = train_config.autoencoder_img_save_steps
1646
+ img_save_count = 0
1647
+
1648
+ for epoch_idx in trange(num_epochs, desc='Training VQVAE'):
1649
+ recon_losses = []
1650
+ codebook_losses = []
1651
+ #commitment_losses = []
1652
+ perceptual_losses = []
1653
+ disc_losses = []
1654
+ gen_losses = []
1655
+ losses = []
1656
+
1657
+ optimizer_g.zero_grad()
1658
+ optimizer_d.zero_grad()
1659
+
1660
+ # for im in tqdm(data_loader):
1661
+ for im in data_loader:
1662
+ step_count += 1
1663
+ im = im.float().to(device)
1664
+
1665
+ # Fetch the autoencoder's output (reconstructions)
1666
+ model_output = model(im)
1667
+ output, z, quantize_losses = model_output
1668
+
1669
+ # Image Saving Logic
1670
+ if step_count % image_save_steps == 0 or step_count == 1:
1671
+ sample_size = min(8, im.shape[0])
1672
+ save_output = torch.clamp(output[:sample_size], -1., 1.).detach().cpu()
1673
+ save_output = ((save_output + 1) / 2)
1674
+ save_input = ((im[:sample_size] + 1) / 2).detach().cpu()
1675
+
1676
+ grid = make_grid(torch.cat([save_input, save_output], dim=0), nrow=sample_size)
1677
+ img = torchvision.transforms.ToPILImage()(grid)
1678
+ if not os.path.exists(os.path.join(train_config.task_name,'vqvae_autoencoder_samples')):
1679
+ os.mkdir(os.path.join(train_config.task_name, 'vqvae_autoencoder_samples'))
1680
+ img.save(os.path.join(train_config.task_name,'vqvae_autoencoder_samples',
1681
+ 'current_autoencoder_sample_{}.png'.format(img_save_count)))
1682
+ img_save_count += 1
1683
+ img.close()
1684
+
1685
+ ######### Optimize Generator ##########
1686
+ # L2 Loss
1687
+ recon_loss = recon_criterion(output, im)
1688
+ recon_losses.append(recon_loss.item())
1689
+ recon_loss = recon_loss / acc_steps
1690
+ g_loss = (recon_loss +
1691
+ (train_config.codebook_weight * quantize_losses['codebook_loss'] / acc_steps) +
1692
+ (train_config.commitment_beta * quantize_losses['commitment_loss'] / acc_steps))
1693
+ codebook_losses.append(train_config.codebook_weight * quantize_losses['codebook_loss'].item())
1694
+ # Adversarial loss only if disc_step_start steps passed
1695
+ if step_count > disc_step_start:
1696
+ disc_fake_pred = discriminator(model_output[0])
1697
+ disc_fake_loss = disc_criterion(disc_fake_pred,
1698
+ torch.ones(disc_fake_pred.shape,
1699
+ device=disc_fake_pred.device))
1700
+ gen_losses.append(train_config.disc_weight * disc_fake_loss.item())
1701
+ g_loss += train_config.disc_weight * disc_fake_loss / acc_steps
1702
+ lpips_loss = torch.mean(lpips_model(output, im)) / acc_steps
1703
+ perceptual_losses.append(train_config.perceptual_weight * lpips_loss.item())
1704
+ g_loss += train_config.perceptual_weight*lpips_loss / acc_steps
1705
+ losses.append(g_loss.item())
1706
+ g_loss.backward()
1707
+ #####################################
1708
+
1709
+ ######### Optimize Discriminator #######
1710
+ if step_count > disc_step_start:
1711
+ fake = output
1712
+ disc_fake_pred = discriminator(fake.detach())
1713
+ disc_real_pred = discriminator(im)
1714
+ disc_fake_loss = disc_criterion(disc_fake_pred,
1715
+ torch.zeros(disc_fake_pred.shape,
1716
+ device=disc_fake_pred.device))
1717
+ disc_real_loss = disc_criterion(disc_real_pred,
1718
+ torch.ones(disc_real_pred.shape,
1719
+ device=disc_real_pred.device))
1720
+ disc_loss = train_config.disc_weight * (disc_fake_loss + disc_real_loss) / 2
1721
+ disc_losses.append(disc_loss.item())
1722
+ disc_loss = disc_loss / acc_steps
1723
+ disc_loss.backward()
1724
+ if step_count % acc_steps == 0:
1725
+ optimizer_d.step()
1726
+ optimizer_d.zero_grad()
1727
+ #####################################
1728
+
1729
+ if step_count % acc_steps == 0:
1730
+ optimizer_g.step()
1731
+ optimizer_g.zero_grad()
1732
+ optimizer_d.step()
1733
+ optimizer_d.zero_grad()
1734
+ optimizer_g.step()
1735
+ optimizer_g.zero_grad()
1736
+ if len(disc_losses) > 0:
1737
+ print(
1738
+ 'Finished epoch: {}/{} | Recon Loss : {:.4f} | Perceptual Loss : {:.4f} | '
1739
+ 'Codebook : {:.4f} | G Loss : {:.4f} | D Loss {:.4f}'.
1740
+ format(epoch_idx + 1,
1741
+ num_epochs,
1742
+ np.mean(recon_losses),
1743
+ np.mean(perceptual_losses),
1744
+ np.mean(codebook_losses),
1745
+ np.mean(gen_losses),
1746
+ np.mean(disc_losses)))
1747
+ else:
1748
+ print('Finished epoch: {}/{} | Recon Loss : {:.4f} | Perceptual Loss : {:.4f} | Codebook : {:.4f}'.
1749
+ format(epoch_idx + 1,
1750
+ num_epochs,
1751
+ np.mean(recon_losses),
1752
+ np.mean(perceptual_losses),
1753
+ np.mean(codebook_losses)))
1754
+
1755
+ torch.save(model.state_dict(), os.path.join(train_config.task_name,
1756
+ train_config.vqvae_autoencoder_ckpt_name))
1757
+ torch.save(discriminator.state_dict(), os.path.join(train_config.task_name,
1758
+ train_config.vqvae_discriminator_ckpt_name))
1759
+ print('Done Training...')
1760
+
1761
+
1762
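+ # --- Illustrative sketch (commented out) ----------------------------------
+ # The gradient-accumulation pattern used by trainVAE (autoencoder_acc_steps)
+ # in isolation: losses are divided by acc_steps and the optimizer only steps
+ # every acc_steps batches, giving an effective batch size of
+ # batch_size * acc_steps. The model, data and lr below are toy assumptions.
+ # def accumulation_demo(acc_steps=4):
+ #     net = nn.Linear(8, 1)
+ #     opt = Adam(net.parameters(), lr=1e-3)
+ #     opt.zero_grad()
+ #     for step in range(1, 13):
+ #         x, y = torch.randn(16, 8), torch.randn(16, 1)
+ #         loss = torch.nn.functional.mse_loss(net(x), y) / acc_steps
+ #         loss.backward()
+ #         if step % acc_steps == 0:
+ #             opt.step()
+ #             opt.zero_grad()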
+ # trainVAE(Config)
1763
+
1764
+ import torch
1765
+ import torch.nn as nn
1766
+
1767
+
1768
+ class Unet(nn.Module):
1769
+ r"""
1770
+ Unet model comprising
1771
+ Down blocks, Mid blocks and Up blocks
1772
+ """
1773
+
1774
+ def __init__(self, im_channels, model_config):
1775
+ super().__init__()
1776
+ self.down_channels = model_config.down_channels
1777
+ self.mid_channels = model_config.mid_channels
1778
+ self.t_emb_dim = model_config.time_emb_dim
1779
+ self.down_sample = model_config.down_sample
1780
+ self.num_down_layers = model_config.num_down_layers
1781
+ self.num_mid_layers = model_config.num_mid_layers
1782
+ self.num_up_layers = model_config.num_up_layers
1783
+ self.attns = model_config.attn_down
1784
+ self.norm_channels = model_config.norm_channels
1785
+ self.num_heads = model_config.num_heads
1786
+ self.conv_out_channels = model_config.conv_out_channels
1787
+
1788
+ assert self.mid_channels[0] == self.down_channels[-1]
1789
+ assert self.mid_channels[-1] == self.down_channels[-2]
1790
+ assert len(self.down_sample) == len(self.down_channels) - 1
1791
+ assert len(self.attns) == len(self.down_channels) - 1
1792
+
1793
+ # Initial projection from sinusoidal time embedding
1794
+ self.t_proj = nn.Sequential(
1795
+ nn.Linear(self.t_emb_dim, self.t_emb_dim),
1796
+ nn.SiLU(),
1797
+ nn.Linear(self.t_emb_dim, self.t_emb_dim)
1798
+ )
1799
+
1800
+ self.up_sample = list(reversed(self.down_sample))
1801
+ self.conv_in = nn.Conv2d(im_channels, self.down_channels[0], kernel_size=3, padding=1)
1802
+
1803
+ self.downs = nn.ModuleList([])
1804
+ for i in range(len(self.down_channels) - 1):
1805
+ self.downs.append(DownBlock(self.down_channels[i], self.down_channels[i + 1], self.t_emb_dim,
1806
+ down_sample=self.down_sample[i],
1807
+ num_heads=self.num_heads,
1808
+ num_layers=self.num_down_layers,
1809
+ attn=self.attns[i], norm_channels=self.norm_channels))
1810
+
1811
+ self.mids = nn.ModuleList([])
1812
+ for i in range(len(self.mid_channels) - 1):
1813
+ self.mids.append(MidBlock(self.mid_channels[i], self.mid_channels[i + 1], self.t_emb_dim,
1814
+ num_heads=self.num_heads,
1815
+ num_layers=self.num_mid_layers,
1816
+ norm_channels=self.norm_channels))
1817
+
1818
+ self.ups = nn.ModuleList([])
1819
+ for i in reversed(range(len(self.down_channels) - 1)):
1820
+ self.ups.append(UpBlockUnet(self.down_channels[i] * 2, self.down_channels[i - 1] if i != 0 else self.conv_out_channels,
1821
+ self.t_emb_dim, up_sample=self.down_sample[i],
1822
+ num_heads=self.num_heads,
1823
+ num_layers=self.num_up_layers,
1824
+ norm_channels=self.norm_channels))
1825
+
1826
+ self.norm_out = nn.GroupNorm(self.norm_channels, self.conv_out_channels)
1827
+ self.conv_out = nn.Conv2d(self.conv_out_channels, im_channels, kernel_size=3, padding=1)
1828
+
1829
+ def forward(self, x, t):
1830
+ # Shapes assuming downblocks are [C1, C2, C3, C4]
1831
+ # Shapes assuming midblocks are [C4, C4, C3]
1832
+ # Shapes assuming downsamples are [True, True, False]
1833
+ # B x C x H x W
1834
+ out = self.conv_in(x)
1835
+ # B x C1 x H x W
1836
+
1837
+ # t_emb -> B x t_emb_dim
1838
+ t_emb = get_time_embedding(torch.as_tensor(t).long(), self.t_emb_dim)
1839
+ t_emb = self.t_proj(t_emb)
1840
+
1841
+ down_outs = []
1842
+
1843
+ for idx, down in enumerate(self.downs):
1844
+ down_outs.append(out)
1845
+ out = down(out, t_emb)
1846
+ # down_outs [B x C1 x H x W, B x C2 x H/2 x W/2, B x C3 x H/4 x W/4]
1847
+ # out B x C4 x H/4 x W/4
1848
+
1849
+ for mid in self.mids:
1850
+ out = mid(out, t_emb)
1851
+ # out B x C3 x H/4 x W/4
1852
+
1853
+ for up in self.ups:
1854
+ down_out = down_outs.pop()
1855
+ out = up(out, down_out, t_emb)
1856
+ # out [B x C2 x H/4 x W/4, B x C1 x H/2 x W/2, B x 16 x H x W]
1857
+ out = self.norm_out(out)
1858
+ out = nn.SiLU()(out)
1859
+ out = self.conv_out(out)
1860
+ # out B x C x H x W
1861
+ return out
1862
+
1863
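+ # --- Illustrative sketch (commented out) ----------------------------------
+ # Toy shape check for the Unet above. The config values are small, arbitrary
+ # assumptions chosen only to satisfy the asserts in __init__, not the values
+ # used for actual LDM training.
+ # toy_cfg = DotDict.from_dict({
+ #     'down_channels': [32, 64, 64, 64], 'mid_channels': [64, 64],
+ #     'down_sample': [True, True, False], 'attn_down': [True, True, True],
+ #     'time_emb_dim': 64, 'norm_channels': 8, 'num_heads': 4,
+ #     'conv_out_channels': 16, 'num_down_layers': 1, 'num_mid_layers': 1,
+ #     'num_up_layers': 1})
+ # unet = Unet(im_channels=3, model_config=toy_cfg)
+ # latents = torch.randn(2, 3, 16, 16)
+ # t = torch.randint(0, 1000, (2,))
+ # print(unet(latents, t).shape)   # expected: torch.Size([2, 3, 16, 16])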
+ def trainLDM(Config):
1864
+
1865
+ diffusion_config = Config.diffusion_params
1866
+ dataset_config = Config.dataset_params
1867
+ diffusion_model_config = Config.ldm_params
1868
+ autoencoder_model_config = Config.autoencoder_params
1869
+ train_config = Config.train_params
1870
+
1871
+ # Create the noise scheduler
1872
+ scheduler = LinearNoiseScheduler(num_timesteps=diffusion_config.num_timesteps,
1873
+ beta_start=diffusion_config.beta_start,
1874
+ beta_end=diffusion_config.beta_end)
1875
+ # scheduler = CosineNoiseScheduler(diffusion_config.num_timesteps)
1876
+
1877
+ im_dataset_cls = {
1878
+ 'mnist': MnistDataset,
1879
+ 'celebA': CelebDataset,
1880
+ 'animeface': AnimeFaceDataset,
1881
+ 'celebAhair': CelebHairDataset
1882
+ }.get(dataset_config.name)
1883
+
1884
+ im_dataset = im_dataset_cls(split='train',
1885
+ im_path=dataset_config.im_path,
1886
+ im_size=dataset_config.im_size,
1887
+ im_channels=dataset_config.im_channels,
1888
+ use_latents=True,
1889
+ latent_path=os.path.join(train_config.task_name,
1890
+ train_config.vqvae_latent_dir_name)
1891
+ )
1892
+
1893
+ data_loader = DataLoader(im_dataset,
1894
+ batch_size=train_config.ldm_batch_size,
1895
+ shuffle=True,
1896
+ num_workers=os.cpu_count(),
1897
+ pin_memory=True,
1898
+ drop_last=False,
1899
+ persistent_workers=True, pin_memory_device=device)
1900
+
1901
+ # Instantiate the model
1902
+ model = Unet(im_channels=autoencoder_model_config.z_channels,
1903
+ model_config=diffusion_model_config).to(device)
1904
+ if os.path.exists(os.path.join(train_config.task_name, train_config.ldm_ckpt_name)):
1905
+ print('Loaded ldm checkpoint')
1906
+ model.load_state_dict(torch.load(os.path.join(train_config.task_name, train_config.ldm_ckpt_name), map_location=device, weights_only=True))
1907
+ model.train()
1908
+
1909
+ # Load VAE ONLY if latents are not to be used or are missing
1910
+ if not im_dataset.use_latents:
1911
+ print('Loading vqvae model as latents not present')
1912
+ vae = VQVAE(im_channels=dataset_config.im_channels,
1913
+ model_config=autoencoder_model_config).to(device)
1914
+ vae.eval()
1915
+ # Load vae if found
1916
+ if os.path.exists(os.path.join(train_config.task_name,
1917
+ train_config.vqvae_autoencoder_ckpt_name)):
1918
+ print('Loaded vae checkpoint')
1919
+ vae.load_state_dict(torch.load(os.path.join(train_config.task_name,
1920
+ train_config.vqvae_autoencoder_ckpt_name),
1921
+ map_location=device))
1922
+ # Specify training parameters
1923
+ num_epochs = train_config.ldm_epochs
1924
+ optimizer = Adam(model.parameters(), lr=train_config.ldm_lr)
1925
+ criterion = torch.nn.MSELoss()
1926
+
1927
+ # Run training
1928
+ if not im_dataset.use_latents:
1929
+ for param in vae.parameters():
1930
+ param.requires_grad = False
1931
+
1932
+ for epoch_idx in range(num_epochs):
1933
+ losses = []
1934
+ for im in tqdm(data_loader):
1935
+ optimizer.zero_grad()
1936
+ im = im.float().to(device)
1937
+ if not im_dataset.use_latents:
1938
+ with torch.no_grad():
1939
+ im, _ = vae.encode(im)
1940
+
1941
+ # Sample random noise
1942
+ noise = torch.randn_like(im).to(device)
1943
+
1944
+ # Sample timestep
1945
+ t = torch.randint(0, diffusion_config.num_timesteps, (im.shape[0],)).to(device)
1946
+
1947
+ # Add noise to images according to timestep
1948
+ noisy_im = scheduler.add_noise(im, noise, t)
1949
+ noise_pred = model(noisy_im, t)
1950
+
1951
+ loss = criterion(noise_pred, noise)
1952
+ losses.append(loss.item())
1953
+ loss.backward()
1954
+ optimizer.step()
1955
+ print(f'Finished epoch:{epoch_idx + 1}/{num_epochs} | Loss : {np.mean(losses):.4f}')
1956
+
1957
+ torch.save(model.state_dict(), os.path.join(train_config.task_name,
1958
+ train_config.ldm_ckpt_name))
1959
+
1960
+ # Doing Inference
1961
+ infer(Config)
1962
+
1963
+ # Check whether to continue training
1964
+ train_continue = yaml.safe_load(open("/home/taruntejaneurips23/Ashish/DDPM/_5_ldm_celeba.yaml", 'r'))
1965
+ train_continue = DotDict.from_dict(train_continue)
1966
+ if train_continue.training._continue_ == False:
1967
+ print('Training Stopped ...')
1968
+ break
1969
+
1970
+ print('Done Training ...')
1971
+
1972
+ # trainLDM(Config)
1973
+
1974
+ # import subprocess
1975
+ # subprocess.run(f'kill {os.getpid()}', shell=True, check=True)
1976
+
1977
+ def sample(model, scheduler, train_config, diffusion_model_config,
1978
+ autoencoder_model_config, diffusion_config, dataset_config, vae):
1979
+ r"""
1980
+ Sample stepwise by going backward one timestep at a time.
1981
+ We save the x0 predictions
1982
+ """
1983
+ im_size = dataset_config.im_size // 2**sum(autoencoder_model_config.down_sample)
1984
+ xt = torch.randn((train_config.num_samples,
1985
+ autoencoder_model_config.z_channels,
1986
+ im_size,
1987
+ im_size)).to(device)
1988
+
1989
+ save_count = 0
1990
+ for i in tqdm(reversed(range(diffusion_config.num_timesteps)), total=diffusion_config.num_timesteps):
1991
+ # Get prediction of noise
1992
+ noise_pred = model(xt, torch.as_tensor(i).unsqueeze(0).to(device))
1993
+
1994
+ # Use scheduler to get x0 and xt-1
1995
+ xt, x0_pred = scheduler.sample_prev_timestep(xt, noise_pred, torch.as_tensor(i).to(device))
1996
+
1997
+ # Save x0
1998
+ #ims = torch.clamp(xt, -1., 1.).detach().cpu()
1999
+ if i == 0:
2000
+ # Decode ONLY the final image to save time
2001
+ ims = vae.decode(xt)
2002
+ else:
2003
+ ims = xt
2004
+
2005
+ ims = torch.clamp(ims, -1., 1.).detach().cpu()
2006
+ ims = (ims + 1) / 2
2007
+ grid = make_grid(ims, nrow=train_config.num_grid_rows)
2008
+ img = torchvision.transforms.ToPILImage()(grid)
2009
+
2010
+ if not os.path.exists(os.path.join(train_config.task_name, 'samples')):
2011
+ os.mkdir(os.path.join(train_config.task_name, 'samples'))
2012
+ img.save(os.path.join(train_config.task_name, 'samples', 'x0_{}.png'.format(i)))
2013
+ img.close()
2014
+
2015
+
2016
+ def infer(Config):
2017
+
2018
+ diffusion_config = Config.diffusion_params
2019
+ dataset_config = Config.dataset_params
2020
+ diffusion_model_config = Config.ldm_params
2021
+ autoencoder_model_config = Config.autoencoder_params
2022
+ train_config = Config.train_params
2023
+
2024
+ # Create the noise scheduler
2025
+ scheduler = LinearNoiseScheduler(num_timesteps=diffusion_config.num_timesteps,
2026
+ beta_start=diffusion_config.beta_start,
2027
+ beta_end=diffusion_config.beta_end)
2028
+ # scheduler = CosineNoiseScheduler(diffusion_config.num_timesteps)
2029
+
2030
+ model = Unet(im_channels=autoencoder_model_config.z_channels,
2031
+ model_config=diffusion_model_config).to(device)
2032
+ model.eval()
2033
+ if os.path.exists(os.path.join(train_config.task_name,
2034
+ train_config.ldm_ckpt_name)):
2035
+ print('Loaded unet checkpoint')
2036
+ model.load_state_dict(torch.load(os.path.join(train_config.task_name,
2037
+ train_config.ldm_ckpt_name),
2038
+ map_location=device))
2039
+ # Create output directories
2040
+ if not os.path.exists(train_config.task_name):
2041
+ os.mkdir(train_config.task_name)
2042
+
2043
+ vae = VQVAE(im_channels=dataset_config.im_channels,
2044
+ model_config=autoencoder_model_config).to(device)
2045
+ vae.eval()
2046
+
2047
+ # Load vae if found
2048
+ if os.path.exists(os.path.join(train_config.task_name,
2049
+ train_config.vqvae_autoencoder_ckpt_name)):
2050
+ print('Loaded vae checkpoint')
2051
+ vae.load_state_dict(torch.load(os.path.join(train_config.task_name,
2052
+ train_config.vqvae_autoencoder_ckpt_name),
2053
+ map_location=device), strict=True)
2054
+ with torch.no_grad():
2055
+ sample(model, scheduler, train_config, diffusion_model_config,
2056
+ autoencoder_model_config, diffusion_config, dataset_config, vae)
2057
+
2058
+
2059
+
2060
+ import argparse
2061
+
2062
+ def get_args():
2063
+ parser = argparse.ArgumentParser(description="Choose between train VAE, train LDM, or infer mode.")
2064
+ parser.add_argument('--mode', choices=['train_vae', 'train_ldm', 'infer'], default='infer',
2065
+ help="Mode to run: train_vae, train_ldm, or infer")
2066
+ return parser.parse_args()
2067
+
2068
+ args = get_args()
2069
+
2070
+ if args.mode == 'train_vae':
2071
+ trainVAE(Config)
2072
+ elif args.mode == 'train_ldm':
2073
+ trainLDM(Config)
2074
+ else:
2075
+ infer(Config)
2076
+
2077
+ # python _5.2_ldm_celeba_hair_cosine.py --mode train_vae
2078
+ # python _5.2_ldm_celeba_hair_cosine.py --mode train_ldm
2079
+ # python _5.2_ldm_celeba_hair_cosine.py --mode infer
2080
+
2081
+
2082
+
2083
+
2084
+ # import matplotlib.pyplot as plt
2085
+ # from PIL import Image
2086
+ # # plt.style.use('dark_background')
2087
+ # # %matplotlib inline
2088
+
2089
+ # plt.imshow(Image.open('/home/taruntejaneurips23/Ashish/DDPM/mnist_ldm/samples/x0_0.png'), cmap='gray')
2090
+
2091
+ # import matplotlib.pyplot as plt
2092
+ # import matplotlib.image as mpimg
2093
+
2094
+ # dataset_name = 'animeface_ldm'
2095
+
2096
+ # image_paths = [f'/home/taruntejaneurips23/Ashish/DDPM/{dataset_name}/samples/x0_0.png',
2097
+ # f'/home/taruntejaneurips23/Ashish/DDPM/{dataset_name}/samples/x0_1.png',
2098
+ # f'/home/taruntejaneurips23/Ashish/DDPM/{dataset_name}/samples/x0_5.png',
2099
+ # f'/home/taruntejaneurips23/Ashish/DDPM/{dataset_name}/samples/x0_100.png',
2100
+ # f'/home/taruntejaneurips23/Ashish/DDPM/{dataset_name}/samples/x0_200.png'
2101
+ # ]
2102
+
2103
+ # fig, axes = plt.subplots(1, len(image_paths), figsize=(15, 5))
2104
+
2105
+ # for i, path in enumerate(image_paths):
2106
+ # img = mpimg.imread(path)
2107
+ # axes[i].imshow(img)
2108
+ # axes[i].axis('off') # Hide axes
2109
+ # axes[i].set_title(f't = {path.split("/")[-1].split(".")[0].split("_")[-1]}')
2110
+
2111
+ # plt.tight_layout()
2112
+ # plt.show()
2113
+
2114
+ # ---------------------------------------------------------
2115
+ # ---------- T H E - E N D -------------------------------
2116
+ # ---------------------------------------------------------
2117
+
2118
+
2119
+
2120
+ def save_checkpoint(
2121
+ total_steps, epoch, model, discriminator,
2122
+ optimizer_d, optimizer_g, loss, checkpoint_path
2123
+ ):
2124
+ checkpoint = {
2125
+ "total_steps": total_steps,
2126
+ "epoch": epoch,
2127
+ "model_state_dict": model.state_dict(),
2128
+ "discriminator_state_dict": discriminator.state_dict(),
2129
+ "optimizer_d_state_dict": optimizer_d.state_dict(),
2130
+ "optimizer_g_state_dict": optimizer_g.state_dict(),
2131
+ "loss": loss,
2132
+ }
2133
+ torch.save(checkpoint, checkpoint_path)
2134
+ print(f"Checkpoint saved after {total_steps} steps at epoch {epoch}")
2135
+
2136
+
2137
+ def load_checkpoint(
2138
+ checkpoint_path, model, discriminator, optimizer_d, optimizer_g
2139
+ ):
2140
+ if os.path.exists(checkpoint_path):
2141
+ checkpoint = torch.load(checkpoint_path)
2142
+ model.load_state_dict(checkpoint["model_state_dict"])
2143
+ discriminator.load_state_dict(checkpoint["discriminator_state_dict"])
2144
+ optimizer_d.load_state_dict(checkpoint["optimizer_d_state_dict"])
2145
+ optimizer_g.load_state_dict(checkpoint["optimizer_g_state_dict"])
2146
+ total_steps = checkpoint["total_steps"]
2147
+ start_epoch = checkpoint["epoch"] + 1
2148
+ loss = checkpoint["loss"]
2149
+ print(f"Checkpoint loaded. Resuming from epoch {start_epoch}")
2150
+ return total_steps, start_epoch, loss
2151
+ else:
2152
+ print("No checkpoint found. Starting from scratch.")
2153
+ return 0, 0, None
2154
+
2155
+
2156
+ def trainVAE(Config, dataloader):
2157
+ """
2158
+ Trains a VQVAE model using the provided configuration and data loader.
2159
+ """
2160
+ # --- Configurations ----------------------------------------------------
2161
+ dataset_config = Config.dataset_params
2162
+ autoencoder_config = Config.autoencoder_params
2163
+ train_config = Config.train_params
2164
+
2165
+ seed = train_config.seed
2166
+ torch.manual_seed(seed)
2167
+ np.random.seed(seed)
2168
+ random.seed(seed)
2169
+ if device == "cuda":
2170
+ torch.cuda.manual_seed_all(seed)
2171
+
2172
+ # --- Model Initialization ----------------------------------------------
2173
+ model = VQVAE(im_channels=dataset_config.im_channels, model_config=autoencoder_config).to(device)
2174
+ discriminator = Discriminator(im_channels=dataset_config.im_channels).to(device)
2175
+
2176
+ # --- Load Checkpoints --------------------------------------------------
2177
+ checkpoint_path = os.path.join(train_config.task_name, "vqvae_checkpoint.pth")
2178
+ total_steps, start_epoch, _ = load_checkpoint(checkpoint_path, model, discriminator, None, None)
2179
+
2180
+ # --- Loss Function Initialization --------------------------------------
2181
+ recon_criterion = torch.nn.MSELoss()
2182
+ lpips_model = LPIPS().eval().to(device)
2183
+ disc_criterion = torch.nn.MSELoss()
2184
+
2185
+ # --- Optimizer Initialization ------------------------------------------
2186
+ optimizer_d = torch.optim.AdamW(discriminator.parameters(), lr=train_config.autoencoder_lr, betas=(0.5, 0.999))
2187
+ optimizer_g = torch.optim.AdamW(model.parameters(), lr=train_config.autoencoder_lr, betas=(0.5, 0.999))
2188
+
2189
+ num_epochs = train_config.autoencoder_epochs
2190
+ acc_steps = train_config.autoencoder_acc_steps
2191
+ image_save_steps = train_config.autoencoder_img_save_steps
2192
+ img_save_count = 0
2193
+
2194
+ # Create necessary directories
2195
+ os.makedirs(os.path.join(train_config.task_name, "vqvae_autoencoder_samples"), exist_ok=True)
2196
+
2197
+ # --- Training Loop -----------------------------------------------------
2198
+ for epoch_idx in range(start_epoch, num_epochs):
2199
+ recon_losses, codebook_losses, perceptual_losses, disc_losses, gen_losses = [], [], [], [], []
2200
+
2201
+ for images in dataloader:
2202
+ total_steps += 1
2203
+ images = images.to(device)
2204
+
2205
+ # Forward pass
2206
+ model_output = model(images)
2207
+ output, z, quantize_losses = model_output
2208
+
2209
+ # Save generated images periodically
2210
+ if total_steps % image_save_steps == 0 or total_steps == 1:
2211
+ sample_size = min(8, images.shape[0])
2212
+ save_output = torch.clamp(output[:sample_size], -1.0, 1.0).detach().cpu()
2213
+ save_output = (save_output + 1) / 2
2214
+ save_input = ((images[:sample_size] + 1) / 2).detach().cpu()
2215
+
2216
+ grid = make_grid(torch.cat([save_input, save_output], dim=0), nrow=sample_size)
2217
+ img = torchvision.transforms.ToPILImage()(grid)
2218
+ img.save(
2219
+ os.path.join(
2220
+ train_config.task_name,
2221
+ "vqvae_autoencoder_samples",
2222
+ f"current_autoencoder_sample_{img_save_count}.png",
2223
+ )
2224
+ )
2225
+ img_save_count += 1
2226
+ img.close()
2227
+
2228
+ # Reconstruction Loss
2229
+ recon_loss = recon_criterion(output, images) / acc_steps
2230
+ recon_losses.append(recon_loss.item())
2231
+
2232
+ # Generator Loss
2233
+ codebook_loss = train_config.codebook_weight * quantize_losses["codebook_loss"] / acc_steps
2234
+ perceptual_loss = train_config.perceptual_weight * lpips_model(output, images).mean() / acc_steps
2235
+ codebook_losses.append(codebook_loss.item())
+ perceptual_losses.append(perceptual_loss.item())
+ g_loss = recon_loss + codebook_loss + perceptual_loss
2236
+
2237
+ if total_steps > train_config.disc_start:
2238
+ disc_fake_pred = discriminator(output)
2239
+ gen_loss = train_config.disc_weight * disc_criterion(
2240
+ disc_fake_pred, torch.ones_like(disc_fake_pred)
2241
+ ) / acc_steps
2242
+ g_loss += gen_loss
2243
+ gen_losses.append(gen_loss.item())
2244
+
2245
+ g_loss.backward()
2246
+ optimizer_g.step()
2247
+ optimizer_g.zero_grad()
2248
+
2249
+ # Discriminator Loss
2250
+ if total_steps > train_config.disc_start:
2251
+ disc_fake_pred = discriminator(output.detach())
2252
+ disc_real_pred = discriminator(images)
2253
+ disc_fake_loss = disc_criterion(
2254
+ disc_fake_pred, torch.zeros_like(disc_fake_pred)
2255
+ ) / acc_steps
2256
+ disc_real_loss = disc_criterion(
2257
+ disc_real_pred, torch.ones_like(disc_real_pred)
2258
+ ) / acc_steps
2259
+ disc_loss = train_config.disc_weight * (disc_fake_loss + disc_real_loss) / 2
2260
+ disc_loss.backward()
2261
+ optimizer_d.step()
2262
+ optimizer_d.zero_grad()
2263
+ disc_losses.append(disc_loss.item())
2264
+
2265
+ # Save checkpoint after each epoch
2266
+ save_checkpoint(total_steps, epoch_idx, model, discriminator, optimizer_d, optimizer_g, recon_losses, checkpoint_path)
2267
+
2268
+ # Print epoch summary
2269
+ print(
2270
+ f"Epoch {epoch_idx + 1}/{num_epochs} | Recon Loss: {np.mean(recon_losses):.4f} | "
2271
+ f"Perceptual Loss: {np.mean(perceptual_losses):.4f} | Codebook Loss: {np.mean(codebook_losses):.4f} | "
2272
+ f"G Loss: {np.mean(gen_losses):.4f} | D Loss: {np.mean(disc_losses):.4f}"
2273
+ )
LDM/scripts/_1_Lpips.py ADDED
@@ -0,0 +1,56 @@
+ # ==================================================================
+ #     LEARNED PERCEPTUAL IMAGE PATCH SIMILARITY ( L P I P S )
+ # ==================================================================
+ # Author : Ashish Kumar Uchadiya
+ # Created : January 18, 2025
+ # Description: LPIPS essentially computes the similarity between the
+ # activations of two image patches for some pre-defined network.
+ # This measure has been shown to match human perception well.
+ # A low LPIPS score means that the image patches are perceptually similar.
+ # ==================================================================
+
+ from collections import namedtuple
+
+ import torch
+ import torchvision
+
+
+ class vgg16(torch.nn.Module):
+     def __init__(self, requires_grad=False, pretrained=True):
+         super(vgg16, self).__init__()
+         vgg_pretrained_features = torchvision.models.vgg16(
+             weights=torchvision.models.VGG16_Weights.IMAGENET1K_V1
+         ).features
+         self.slice1 = torch.nn.Sequential()
+         self.slice2 = torch.nn.Sequential()
+         self.slice3 = torch.nn.Sequential()
+         self.slice4 = torch.nn.Sequential()
+         self.slice5 = torch.nn.Sequential()
+         self.N_slices = 5
+         # Split the pretrained VGG16 feature stack into five slices ending at
+         # relu1_2, relu2_2, relu3_3, relu4_3 and relu5_3 respectively.
+         for x in range(4):
+             self.slice1.add_module(str(x), vgg_pretrained_features[x])
+         for x in range(4, 9):
+             self.slice2.add_module(str(x), vgg_pretrained_features[x])
+         for x in range(9, 16):
+             self.slice3.add_module(str(x), vgg_pretrained_features[x])
+         for x in range(16, 23):
+             self.slice4.add_module(str(x), vgg_pretrained_features[x])
+         for x in range(23, 30):
+             self.slice5.add_module(str(x), vgg_pretrained_features[x])
+
+         # Freeze the VGG weights
+         if not requires_grad:
+             for param in self.parameters():
+                 param.requires_grad = False
+
+     def forward(self, X):
+         # Return the intermediate activations of the five VGG slices
+         h = self.slice1(X)
+         h_relu1_2 = h
+         h = self.slice2(h)
+         h_relu2_2 = h
+         h = self.slice3(h)
+         h_relu3_3 = h
+         h = self.slice4(h)
+         h_relu4_3 = h
+         h = self.slice5(h)
+         h_relu5_3 = h
+         vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
+         out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
+         return out
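A possible usage sketch for the extractor above (not part of the repository): it compares two batches by the mean squared distance between channel-normalized activations at each of the five ReLU stages, which is the unweighted core of the LPIPS metric. The random tensors are stand-ins; real inputs should be ImageNet-normalized RGB batches.

import torch
import torch.nn.functional as F

extractor = vgg16(requires_grad=False, pretrained=True).eval()
x0 = torch.rand(1, 3, 224, 224)  # placeholder inputs
x1 = torch.rand(1, 3, 224, 224)
with torch.no_grad():
    feats0, feats1 = extractor(x0), extractor(x1)
for name, f0, f1 in zip(feats0._fields, feats0, feats1):
    # Unit-normalize along the channel dimension before comparing, as LPIPS does.
    d = ((F.normalize(f0, dim=1) - F.normalize(f1, dim=1)) ** 2).mean().item()
    print(f"{name}: {d:.6f}")  # smaller means more similar at that depth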
LDM/scripts/config.yaml ADDED
@@ -0,0 +1,65 @@
+ dataset_params:
+   im_path: "/home/taruntejaneurips23/Ashish/datasets/CelebA/img_align_celeba/img_align_celeba"
+   im_channels: 3
+   im_size: 28
+
+ diffusion_params:
+   num_timesteps: 1000
+   beta_start: 0.0015
+   beta_end: 0.0195
+
+ ldm_params:
+   down_channels: [128, 256, 256, 256]
+   mid_channels: [256, 256]
+   down_sample: [False, False, False]
+   attn_down: [True, True, True]
+   time_emb_dim: 256
+   norm_channels: 32
+   num_heads: 16
+   conv_out_channels: 128
+   num_down_layers: 2
+   num_mid_layers: 2
+   num_up_layers: 2
+
+ autoencoder_params:
+   z_channels: 3
+   codebook_size: 20
+   down_channels: [32, 64, 128]
+   mid_channels: [128, 128]
+   down_sample: [True, True]
+   attn_down: [False, False]
+   norm_channels: 32
+   num_heads: 16
+   num_down_layers: 2
+   num_mid_layers: 2
+   num_up_layers: 2
+
+ train_params:
+   seed: 4242
+   task_name: 'MnistLDM'
+   ldm_batch_size: 9
+   autoencoder_batch_size: 32
+   disc_start: 1000
+   disc_weight: 0.5
+   codebook_weight: 1
+   commitment_beta: 0.2
+   perceptual_weight: 1
+   kl_weight: 0.000005
+   ldm_epochs: 10
+   autoencoder_epochs: 10
+   num_samples: 9
+   num_grid_rows: 3
+   ldm_lr: 0.00001
+   autoencoder_lr: 0.0001
+   autoencoder_acc_steps: 1
+   autoencoder_img_save_steps: 8
+   save_latents: True
+   vqvae_latent_dir_name: 'vqvae_latents'
+   ldm_ckpt_name: 'ddpm_ckpt.pth'
+   vqvae_autoencoder_ckpt_name: 'vqvae_autoencoder_ckpt.pth'
+   vqvae_discriminator_ckpt_name: 'vqvae_discriminator_ckpt.pth'
+   checkpoint_dir: './'
+
+ training:
+   _continue_: True
+
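One way to consume this file is sketched below, assuming PyYAML is available; it converts the nested mappings into attribute-style objects so values read the same way the training code accesses them (e.g. `train_config.autoencoder_lr`). The project may use its own config helper, so treat this as illustrative only.

import yaml
from types import SimpleNamespace

def load_config(path="config.yaml"):
    # Parse the YAML and wrap nested dicts in SimpleNamespace for attribute access.
    with open(path) as f:
        raw = yaml.safe_load(f)

    def to_ns(node):
        if isinstance(node, dict):
            return SimpleNamespace(**{k: to_ns(v) for k, v in node.items()})
        return node

    return to_ns(raw)

cfg = load_config()
train_config = cfg.train_params
print(train_config.task_name, train_config.autoencoder_lr, cfg.autoencoder_params.codebook_size)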
Vaani/39448.err ADDED
@@ -0,0 +1,351 @@
1
+ + '[' -z '' ']'
2
+ + case "$-" in
3
+ + __lmod_vx=x
4
+ + '[' -n x ']'
5
+ + set +x
6
+ Shell debugging temporarily silenced: export LMOD_SH_DBG_ON=1 for this output (/usr/share/lmod/lmod/init/bash)
7
+ Shell debugging restarted
8
+ + unset __lmod_vx
9
+ + cd
10
+ + module purge
11
+ + '[' -z '' ']'
12
+ + case "$-" in
13
+ + __lmod_sh_dbg=x
14
+ + '[' -n x ']'
15
+ + set +x
16
+ Shell debugging temporarily silenced: export LMOD_SH_DBG_ON=1 for Lmod's output
17
+ Shell debugging restarted
18
+ + unset __lmod_sh_dbg
19
+ + return 0
20
+ + module load miniconda
21
+ + '[' -z '' ']'
22
+ + case "$-" in
23
+ + __lmod_sh_dbg=x
24
+ + '[' -n x ']'
25
+ + set +x
26
+ Shell debugging temporarily silenced: export LMOD_SH_DBG_ON=1 for Lmod's output
27
+ Shell debugging restarted
28
+ + unset __lmod_sh_dbg
29
+ + return 0
30
+ + source /home/apps/miniconda3/etc/profile.d/conda.sh
31
+ ++ export CONDA_EXE=/home/apps/miniconda3/bin/conda
32
+ ++ CONDA_EXE=/home/apps/miniconda3/bin/conda
33
+ ++ export _CE_M=
34
+ ++ _CE_M=
35
+ ++ export _CE_CONDA=
36
+ ++ _CE_CONDA=
37
+ ++ export CONDA_PYTHON_EXE=/home/apps/miniconda3/bin/python
38
+ ++ CONDA_PYTHON_EXE=/home/apps/miniconda3/bin/python
39
+ ++ '[' -z x ']'
40
+ + conda env list
41
+ + local cmd=env
42
+ + case "$cmd" in
43
+ + __conda_exe env list
44
+ + '[' -n '' ']'
45
+ + /home/apps/miniconda3/bin/conda env list
46
+ + conda activate aku_env
47
+ + local cmd=activate
48
+ + case "$cmd" in
49
+ + __conda_activate activate aku_env
50
+ + '[' -n '' ']'
51
+ + local ask_conda
52
+ ++ PS1=
53
+ ++ __conda_exe shell.posix activate aku_env
54
+ ++ '[' -n '' ']'
55
+ ++ /home/apps/miniconda3/bin/conda shell.posix activate aku_env
56
+ + ask_conda='unset _CE_M
57
+ unset _CE_CONDA
58
+ PS1='\''(aku_env) '\''
59
+ export PATH='\''/home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/bin:/home/IITB/ai-at-ieor/23m1521/.vscode-server/cli/servers/Stable-ddc367ed5c8936efe395cffeec279b04ffd7db78/server/bin/remote-cli:/home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/bin:/home/apps/MLDL/DL-CondaPy3/condabin:/home/IITB/ai-at-ieor/23m1521/.local/bin:/home/IITB/ai-at-ieor/23m1521/bin:/usr/local/bin:/usr/bin:/usr/local/sbin:/usr/sbin:/var/lib/snapd/snap/bin:/home/IITB/ai-at-ieor/23m1521/.vscode-server/extensions/ms-python.debugpy-2025.4.1-linux-x64/bundled/scripts/noConfigScripts:/home/IITB/ai-at-ieor/23m1521/.vscode-server/data/User/globalStorage/github.copilot-chat/debugCommand'\''
60
+ export CONDA_PREFIX='\''/home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env'\''
61
+ export CONDA_SHLVL='\''2'\''
62
+ export CONDA_DEFAULT_ENV='\''aku_env'\''
63
+ export CONDA_PROMPT_MODIFIER='\''(aku_env) '\''
64
+ export CONDA_PREFIX_1='\''/home/apps/miniconda3'\''
65
+ export CONDA_EXE='\''/home/apps/miniconda3/bin/conda'\''
66
+ export CONDA_PYTHON_EXE='\''/home/apps/miniconda3/bin/python'\''
67
+ . "/home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/etc/conda/activate.d/gdal-activate.sh"
68
+ . "/home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/etc/conda/activate.d/geotiff-activate.sh"
69
+ . "/home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/etc/conda/activate.d/libarrow_activate.sh"
70
+ . "/home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/etc/conda/activate.d/libglib_activate.sh"
71
+ . "/home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/etc/conda/activate.d/libpdal-core_activate.sh"
72
+ . "/home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/etc/conda/activate.d/libxml2_activate.sh"
73
+ . "/home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/etc/conda/activate.d/pdal-python-activate.sh"
74
+ . "/home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/etc/conda/activate.d/proj4-activate.sh"'
75
+ + eval 'unset _CE_M
76
+ unset _CE_CONDA
77
+ PS1='\''(aku_env) '\''
78
+ export PATH='\''/home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/bin:/home/IITB/ai-at-ieor/23m1521/.vscode-server/cli/servers/Stable-ddc367ed5c8936efe395cffeec279b04ffd7db78/server/bin/remote-cli:/home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/bin:/home/apps/MLDL/DL-CondaPy3/condabin:/home/IITB/ai-at-ieor/23m1521/.local/bin:/home/IITB/ai-at-ieor/23m1521/bin:/usr/local/bin:/usr/bin:/usr/local/sbin:/usr/sbin:/var/lib/snapd/snap/bin:/home/IITB/ai-at-ieor/23m1521/.vscode-server/extensions/ms-python.debugpy-2025.4.1-linux-x64/bundled/scripts/noConfigScripts:/home/IITB/ai-at-ieor/23m1521/.vscode-server/data/User/globalStorage/github.copilot-chat/debugCommand'\''
79
+ export CONDA_PREFIX='\''/home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env'\''
80
+ export CONDA_SHLVL='\''2'\''
81
+ export CONDA_DEFAULT_ENV='\''aku_env'\''
82
+ export CONDA_PROMPT_MODIFIER='\''(aku_env) '\''
83
+ export CONDA_PREFIX_1='\''/home/apps/miniconda3'\''
84
+ export CONDA_EXE='\''/home/apps/miniconda3/bin/conda'\''
85
+ export CONDA_PYTHON_EXE='\''/home/apps/miniconda3/bin/python'\''
86
+ . "/home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/etc/conda/activate.d/gdal-activate.sh"
87
+ . "/home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/etc/conda/activate.d/geotiff-activate.sh"
88
+ . "/home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/etc/conda/activate.d/libarrow_activate.sh"
89
+ . "/home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/etc/conda/activate.d/libglib_activate.sh"
90
+ . "/home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/etc/conda/activate.d/libpdal-core_activate.sh"
91
+ . "/home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/etc/conda/activate.d/libxml2_activate.sh"
92
+ . "/home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/etc/conda/activate.d/pdal-python-activate.sh"
93
+ . "/home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/etc/conda/activate.d/proj4-activate.sh"'
94
+ ++ unset _CE_M
95
+ ++ unset _CE_CONDA
96
+ ++ PS1='(aku_env) '
97
+ ++ export PATH=/home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/bin:/home/IITB/ai-at-ieor/23m1521/.vscode-server/cli/servers/Stable-ddc367ed5c8936efe395cffeec279b04ffd7db78/server/bin/remote-cli:/home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/bin:/home/apps/MLDL/DL-CondaPy3/condabin:/home/IITB/ai-at-ieor/23m1521/.local/bin:/home/IITB/ai-at-ieor/23m1521/bin:/usr/local/bin:/usr/bin:/usr/local/sbin:/usr/sbin:/var/lib/snapd/snap/bin:/home/IITB/ai-at-ieor/23m1521/.vscode-server/extensions/ms-python.debugpy-2025.4.1-linux-x64/bundled/scripts/noConfigScripts:/home/IITB/ai-at-ieor/23m1521/.vscode-server/data/User/globalStorage/github.copilot-chat/debugCommand
98
+ ++ PATH=/home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/bin:/home/IITB/ai-at-ieor/23m1521/.vscode-server/cli/servers/Stable-ddc367ed5c8936efe395cffeec279b04ffd7db78/server/bin/remote-cli:/home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/bin:/home/apps/MLDL/DL-CondaPy3/condabin:/home/IITB/ai-at-ieor/23m1521/.local/bin:/home/IITB/ai-at-ieor/23m1521/bin:/usr/local/bin:/usr/bin:/usr/local/sbin:/usr/sbin:/var/lib/snapd/snap/bin:/home/IITB/ai-at-ieor/23m1521/.vscode-server/extensions/ms-python.debugpy-2025.4.1-linux-x64/bundled/scripts/noConfigScripts:/home/IITB/ai-at-ieor/23m1521/.vscode-server/data/User/globalStorage/github.copilot-chat/debugCommand
99
+ ++ export CONDA_PREFIX=/home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env
100
+ ++ CONDA_PREFIX=/home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env
101
+ ++ export CONDA_SHLVL=2
102
+ ++ CONDA_SHLVL=2
103
+ ++ export CONDA_DEFAULT_ENV=aku_env
104
+ ++ CONDA_DEFAULT_ENV=aku_env
105
+ ++ export 'CONDA_PROMPT_MODIFIER=(aku_env) '
106
+ ++ CONDA_PROMPT_MODIFIER='(aku_env) '
107
+ ++ export CONDA_PREFIX_1=/home/apps/miniconda3
108
+ ++ CONDA_PREFIX_1=/home/apps/miniconda3
109
+ ++ export CONDA_EXE=/home/apps/miniconda3/bin/conda
110
+ ++ CONDA_EXE=/home/apps/miniconda3/bin/conda
111
+ ++ export CONDA_PYTHON_EXE=/home/apps/miniconda3/bin/python
112
+ ++ CONDA_PYTHON_EXE=/home/apps/miniconda3/bin/python
113
+ ++ . /home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/etc/conda/activate.d/gdal-activate.sh
114
+ +++ '[' -n /home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/share/gdal ']'
115
+ +++ export _CONDA_SET_GDAL_DATA=/home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/share/gdal
116
+ +++ _CONDA_SET_GDAL_DATA=/home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/share/gdal
117
+ +++ '[' -n /home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/lib/gdalplugins ']'
118
+ +++ export _CONDA_SET_GDAL_DRIVER_PATH=/home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/lib/gdalplugins
119
+ +++ _CONDA_SET_GDAL_DRIVER_PATH=/home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/lib/gdalplugins
120
+ +++ '[' -d /home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/share/gdal ']'
121
+ +++ export GDAL_DATA=/home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/share/gdal
122
+ +++ GDAL_DATA=/home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/share/gdal
123
+ +++ export GDAL_DRIVER_PATH=/home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/lib/gdalplugins
124
+ +++ GDAL_DRIVER_PATH=/home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/lib/gdalplugins
125
+ +++ '[' '!' -d /home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/lib/gdalplugins ']'
126
+ +++ export CPL_ZIP_ENCODING=UTF-8
127
+ +++ CPL_ZIP_ENCODING=UTF-8
128
+ +++ '[' -n '4.4.20(1)-release' ']'
129
+ +++ '[' -f /home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/share/bash-completion/completions/gdalinfo ']'
130
+ +++ source /home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/share/bash-completion/completions/gdalinfo
131
+ ++++ function_exists _get_comp_words_by_ref
132
+ ++++ declare -f -F _get_comp_words_by_ref
133
+ ++++ return 1
134
+ ++++ return 0
135
+ ++ . /home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/etc/conda/activate.d/geotiff-activate.sh
136
+ +++ '[' -n '' ']'
137
+ +++ '[' -d /home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/share/epsg_csv ']'
138
+ +++ '[' -d /home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/Library/share/epsg_csv ']'
139
+ ++ . /home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/etc/conda/activate.d/libarrow_activate.sh
140
+ +++ '[' -n '' ']'
141
+ +++ _la_log 'Beginning libarrow activation.'
142
+ +++ '[' '' = 1 ']'
143
+ +++ _la_gdb_prefix=/home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/share/gdb/auto-load
144
+ +++ '[' '!' -w /home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/share/gdb/auto-load ']'
145
+ +++ _la_placeholder=replace_this_section_with_absolute_slashed_path_to_CONDA_PREFIX
146
+ +++ _la_symlink_dir=/home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/share/gdb/auto-load//home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/lib
147
+ +++ _la_orig_install_dir=/home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/share/gdb/auto-load/replace_this_section_with_absolute_slashed_path_to_CONDA_PREFIX/lib
148
+ +++ _la_log ' _la_gdb_prefix: /home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/share/gdb/auto-load'
149
+ +++ '[' '' = 1 ']'
150
+ +++ _la_log ' _la_placeholder: replace_this_section_with_absolute_slashed_path_to_CONDA_PREFIX'
151
+ +++ '[' '' = 1 ']'
152
+ +++ _la_log ' _la_symlink_dir: /home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/share/gdb/auto-load//home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/lib'
153
+ +++ '[' '' = 1 ']'
154
+ +++ _la_log ' _la_orig_install_dir: /home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/share/gdb/auto-load/replace_this_section_with_absolute_slashed_path_to_CONDA_PREFIX/lib'
155
+ +++ '[' '' = 1 ']'
156
+ +++ _la_log ' content of that folder:'
157
+ +++ '[' '' = 1 ']'
158
+ ++++ ls -al /home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/share/gdb/auto-load/replace_this_section_with_absolute_slashed_path_to_CONDA_PREFIX/lib
159
+ ++++ sed 's/^/ /'
160
+ +++ _la_log ' total 12
161
+ drwxr-sr-x 2 23m1521 ai-at-ieor 4096 Mar 23 19:37 .
162
+ drwxr-sr-x 3 23m1521 ai-at-ieor 4096 Mar 22 19:59 ..
163
+ -rw-r--r-- 1 23m1521 ai-at-ieor 992 Mar 23 19:36 libarrow.so.1900.1.0-gdb.py'
164
+ +++ '[' '' = 1 ']'
165
+ +++ for _la_target in "$_la_orig_install_dir/"*.py
166
+ +++ '[' '!' -e /home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/share/gdb/auto-load/replace_this_section_with_absolute_slashed_path_to_CONDA_PREFIX/lib/libarrow.so.1900.1.0-gdb.py ']'
167
+ ++++ basename /home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/share/gdb/auto-load/replace_this_section_with_absolute_slashed_path_to_CONDA_PREFIX/lib/libarrow.so.1900.1.0-gdb.py
168
+ +++ _la_symlink=/home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/share/gdb/auto-load//home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/lib/libarrow.so.1900.1.0-gdb.py
169
+ +++ _la_log ' _la_target: /home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/share/gdb/auto-load/replace_this_section_with_absolute_slashed_path_to_CONDA_PREFIX/lib/libarrow.so.1900.1.0-gdb.py'
170
+ +++ '[' '' = 1 ']'
171
+ +++ _la_log ' _la_symlink: /home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/share/gdb/auto-load//home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/lib/libarrow.so.1900.1.0-gdb.py'
172
+ +++ '[' '' = 1 ']'
173
+ +++ '[' -L /home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/share/gdb/auto-load//home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/lib/libarrow.so.1900.1.0-gdb.py ']'
174
+ ++++ readlink /home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/share/gdb/auto-load//home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/lib/libarrow.so.1900.1.0-gdb.py
175
+ +++ '[' /home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/share/gdb/auto-load/replace_this_section_with_absolute_slashed_path_to_CONDA_PREFIX/lib/libarrow.so.1900.1.0-gdb.py = /home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/share/gdb/auto-load/replace_this_section_with_absolute_slashed_path_to_CONDA_PREFIX/lib/libarrow.so.1900.1.0-gdb.py ']'
176
+ +++ _la_log 'symlink $_la_symlink already exists and points to $_la_target, skipping.'
177
+ +++ '[' '' = 1 ']'
178
+ +++ continue
179
+ +++ _la_log 'Libarrow activation complete.'
180
+ +++ '[' '' = 1 ']'
181
+ +++ unset _la_gdb_prefix
182
+ +++ unset _la_log
183
+ +++ unset _la_orig_install_dir
184
+ +++ unset _la_placeholder
185
+ +++ unset _la_symlink
186
+ +++ unset _la_symlink_dir
187
+ +++ unset _la_target
188
+ ++ . /home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/etc/conda/activate.d/libglib_activate.sh
189
+ +++ export GSETTINGS_SCHEMA_DIR_CONDA_BACKUP=/home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/share/glib-2.0/schemas
190
+ +++ GSETTINGS_SCHEMA_DIR_CONDA_BACKUP=/home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/share/glib-2.0/schemas
191
+ +++ export GSETTINGS_SCHEMA_DIR=/home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/share/glib-2.0/schemas
192
+ +++ GSETTINGS_SCHEMA_DIR=/home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/share/glib-2.0/schemas
193
+ ++ . /home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/etc/conda/activate.d/libpdal-core_activate.sh
194
+ +++ '[' -n /home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/lib:/home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/lib/python3.12/site-packages/pdal ']'
195
+ +++ export _CONDA_SET_PDAL_DRIVER_PATH=/home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/lib:/home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/lib/python3.12/site-packages/pdal
196
+ +++ _CONDA_SET_PDAL_DRIVER_PATH=/home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/lib:/home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/lib/python3.12/site-packages/pdal
197
+ +++ export PDAL_DRIVER_PATH=/home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/lib
198
+ +++ PDAL_DRIVER_PATH=/home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/lib
199
+ +++ '[' '!' -d /home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/lib ']'
200
+ ++ . /home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/etc/conda/activate.d/libxml2_activate.sh
201
+ +++ test -n 'file:///home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/etc/xml/catalog file:///etc/xml/catalog'
202
+ +++ xml_catalog_files_libxml2='file:///home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/etc/xml/catalog file:///etc/xml/catalog'
203
+ +++ XML_CATALOG_FILES='file:///home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/etc/xml/catalog file:///etc/xml/catalog '
204
+ +++ conda_catalog_files=
205
+ +++ ifs_libxml2='
206
+ '
207
+ +++ IFS=' '
208
+ +++ rem=/home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env
209
+ +++ for pre in ${rem}
210
+ +++ test '' = /home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env
211
+ +++ conda_catalog_files=/home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env
212
+ +++ rem=
213
+ +++ IFS='
214
+ '
215
+ +++ conda_catalog_files='file:///home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/etc/xml/catalog file:///etc/xml/catalog'
216
+ +++ export 'XML_CATALOG_FILES=file:///home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/etc/xml/catalog file:///etc/xml/catalog file:///home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/etc/xml/catalog file:///etc/xml/catalog'
217
+ +++ XML_CATALOG_FILES='file:///home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/etc/xml/catalog file:///etc/xml/catalog file:///home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/etc/xml/catalog file:///etc/xml/catalog'
218
+ +++ unset conda_catalog_files ifs_libxml2 rem
219
+ ++ . /home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/etc/conda/activate.d/pdal-python-activate.sh
220
+ +++ [[ -n /home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/lib ]]
221
+ +++ export _CONDA_SET_PDAL_PYTHON_DRIVER_PATH=/home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/lib
222
+ +++ _CONDA_SET_PDAL_PYTHON_DRIVER_PATH=/home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/lib
223
+ +++ export PDAL_DRIVER_PATH=/home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/lib:/home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/lib/python3.12/site-packages/pdal
224
+ +++ PDAL_DRIVER_PATH=/home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/lib:/home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/lib/python3.12/site-packages/pdal
225
+ ++ . /home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/etc/conda/activate.d/proj4-activate.sh
226
+ +++ '[' -n /home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/share/proj ']'
227
+ +++ export _CONDA_SET_PROJ_DATA=/home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/share/proj
228
+ +++ _CONDA_SET_PROJ_DATA=/home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/share/proj
229
+ +++ '[' -d /home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/share/proj ']'
230
+ +++ export PROJ_DATA=/home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/share/proj
231
+ +++ PROJ_DATA=/home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/share/proj
232
+ +++ '[' -f /home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/share/proj/copyright_and_licenses.csv ']'
233
+ +++ export PROJ_NETWORK=ON
234
+ +++ PROJ_NETWORK=ON
235
+ + __conda_hashr
236
+ + '[' -n '' ']'
237
+ + '[' -n '' ']'
238
+ + hash -r
239
+ + python /home/IITB/ai-at-ieor/23m1521/ashish/MTP/Vaani/image_data_metadata.py
240
+
241
+ + conda deactivate
242
+ + local cmd=deactivate
243
+ + case "$cmd" in
244
+ + __conda_activate deactivate
245
+ + '[' -n '' ']'
246
+ + local ask_conda
247
+ ++ PS1='(aku_env) '
248
+ ++ __conda_exe shell.posix deactivate
249
+ ++ '[' -n '' ']'
250
+ ++ /home/apps/miniconda3/bin/conda shell.posix deactivate
251
+ + ask_conda='export PATH='\''/home/apps/miniconda3/bin:/home/IITB/ai-at-ieor/23m1521/.vscode-server/cli/servers/Stable-ddc367ed5c8936efe395cffeec279b04ffd7db78/server/bin/remote-cli:/home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/bin:/home/apps/MLDL/DL-CondaPy3/condabin:/home/IITB/ai-at-ieor/23m1521/.local/bin:/home/IITB/ai-at-ieor/23m1521/bin:/usr/local/bin:/usr/bin:/usr/local/sbin:/usr/sbin:/var/lib/snapd/snap/bin:/home/IITB/ai-at-ieor/23m1521/.vscode-server/extensions/ms-python.debugpy-2025.4.1-linux-x64/bundled/scripts/noConfigScripts:/home/IITB/ai-at-ieor/23m1521/.vscode-server/data/User/globalStorage/github.copilot-chat/debugCommand'\''
252
+ . "/home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/etc/conda/deactivate.d/proj4-deactivate.sh"
253
+ . "/home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/etc/conda/deactivate.d/pdal-python-deactivate.sh"
254
+ . "/home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/etc/conda/deactivate.d/libxml2_deactivate.sh"
255
+ . "/home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/etc/conda/deactivate.d/libpdal-core_deactivate.sh"
256
+ . "/home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/etc/conda/deactivate.d/libglib_deactivate.sh"
257
+ . "/home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/etc/conda/deactivate.d/geotiff-deactivate.sh"
258
+ . "/home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/etc/conda/deactivate.d/gdal-deactivate.sh"
259
+ unset CONDA_PREFIX_1
260
+ unset _CE_M
261
+ unset _CE_CONDA
262
+ PS1='\''(base) '\''
263
+ export CONDA_PREFIX='\''/home/apps/miniconda3'\''
264
+ export CONDA_SHLVL='\''1'\''
265
+ export CONDA_DEFAULT_ENV='\''base'\''
266
+ export CONDA_PROMPT_MODIFIER='\''(base) '\''
267
+ export CONDA_EXE='\''/home/apps/miniconda3/bin/conda'\''
268
+ export CONDA_PYTHON_EXE='\''/home/apps/miniconda3/bin/python'\'''
269
+ + eval 'export PATH='\''/home/apps/miniconda3/bin:/home/IITB/ai-at-ieor/23m1521/.vscode-server/cli/servers/Stable-ddc367ed5c8936efe395cffeec279b04ffd7db78/server/bin/remote-cli:/home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/bin:/home/apps/MLDL/DL-CondaPy3/condabin:/home/IITB/ai-at-ieor/23m1521/.local/bin:/home/IITB/ai-at-ieor/23m1521/bin:/usr/local/bin:/usr/bin:/usr/local/sbin:/usr/sbin:/var/lib/snapd/snap/bin:/home/IITB/ai-at-ieor/23m1521/.vscode-server/extensions/ms-python.debugpy-2025.4.1-linux-x64/bundled/scripts/noConfigScripts:/home/IITB/ai-at-ieor/23m1521/.vscode-server/data/User/globalStorage/github.copilot-chat/debugCommand'\''
270
+ . "/home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/etc/conda/deactivate.d/proj4-deactivate.sh"
271
+ . "/home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/etc/conda/deactivate.d/pdal-python-deactivate.sh"
272
+ . "/home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/etc/conda/deactivate.d/libxml2_deactivate.sh"
273
+ . "/home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/etc/conda/deactivate.d/libpdal-core_deactivate.sh"
274
+ . "/home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/etc/conda/deactivate.d/libglib_deactivate.sh"
275
+ . "/home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/etc/conda/deactivate.d/geotiff-deactivate.sh"
276
+ . "/home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/etc/conda/deactivate.d/gdal-deactivate.sh"
277
+ unset CONDA_PREFIX_1
278
+ unset _CE_M
279
+ unset _CE_CONDA
280
+ PS1='\''(base) '\''
281
+ export CONDA_PREFIX='\''/home/apps/miniconda3'\''
282
+ export CONDA_SHLVL='\''1'\''
283
+ export CONDA_DEFAULT_ENV='\''base'\''
284
+ export CONDA_PROMPT_MODIFIER='\''(base) '\''
285
+ export CONDA_EXE='\''/home/apps/miniconda3/bin/conda'\''
286
+ export CONDA_PYTHON_EXE='\''/home/apps/miniconda3/bin/python'\'''
287
+ ++ export PATH=/home/apps/miniconda3/bin:/home/IITB/ai-at-ieor/23m1521/.vscode-server/cli/servers/Stable-ddc367ed5c8936efe395cffeec279b04ffd7db78/server/bin/remote-cli:/home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/bin:/home/apps/MLDL/DL-CondaPy3/condabin:/home/IITB/ai-at-ieor/23m1521/.local/bin:/home/IITB/ai-at-ieor/23m1521/bin:/usr/local/bin:/usr/bin:/usr/local/sbin:/usr/sbin:/var/lib/snapd/snap/bin:/home/IITB/ai-at-ieor/23m1521/.vscode-server/extensions/ms-python.debugpy-2025.4.1-linux-x64/bundled/scripts/noConfigScripts:/home/IITB/ai-at-ieor/23m1521/.vscode-server/data/User/globalStorage/github.copilot-chat/debugCommand
288
+ ++ PATH=/home/apps/miniconda3/bin:/home/IITB/ai-at-ieor/23m1521/.vscode-server/cli/servers/Stable-ddc367ed5c8936efe395cffeec279b04ffd7db78/server/bin/remote-cli:/home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/bin:/home/apps/MLDL/DL-CondaPy3/condabin:/home/IITB/ai-at-ieor/23m1521/.local/bin:/home/IITB/ai-at-ieor/23m1521/bin:/usr/local/bin:/usr/bin:/usr/local/sbin:/usr/sbin:/var/lib/snapd/snap/bin:/home/IITB/ai-at-ieor/23m1521/.vscode-server/extensions/ms-python.debugpy-2025.4.1-linux-x64/bundled/scripts/noConfigScripts:/home/IITB/ai-at-ieor/23m1521/.vscode-server/data/User/globalStorage/github.copilot-chat/debugCommand
289
+ ++ . /home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/etc/conda/deactivate.d/proj4-deactivate.sh
290
+ +++ unset PROJ_DATA
291
+ +++ unset PROJ_NETWORK
292
+ +++ '[' -n /home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/share/proj ']'
293
+ +++ export PROJ_DATA=/home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/share/proj
294
+ +++ PROJ_DATA=/home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/share/proj
295
+ +++ unset _CONDA_SET_PROJ_DATA
296
+ ++ . /home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/etc/conda/deactivate.d/pdal-python-deactivate.sh
297
+ +++ [[ -n /home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/lib ]]
298
+ +++ export PDAL_DRIVER_PATH=/home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/lib
299
+ +++ PDAL_DRIVER_PATH=/home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/lib
300
+ +++ unset _CONDA_SET_PDAL_PYTHON_DRIVER_PATH
301
+ ++ . /home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/etc/conda/deactivate.d/libxml2_deactivate.sh
302
+ +++ test -n 'file:///home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/etc/xml/catalog file:///etc/xml/catalog'
303
+ +++ export 'XML_CATALOG_FILES=file:///home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/etc/xml/catalog file:///etc/xml/catalog'
304
+ +++ XML_CATALOG_FILES='file:///home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/etc/xml/catalog file:///etc/xml/catalog'
305
+ +++ unset xml_catalog_files_libxml2
306
+ ++ . /home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/etc/conda/deactivate.d/libpdal-core_deactivate.sh
307
+ +++ unset PDAL_DRIVER_PATH
308
+ +++ '[' -n /home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/lib:/home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/lib/python3.12/site-packages/pdal ']'
309
+ +++ export PDAL_DRIVER_PATH=/home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/lib:/home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/lib/python3.12/site-packages/pdal
310
+ +++ PDAL_DRIVER_PATH=/home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/lib:/home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/lib/python3.12/site-packages/pdal
311
+ +++ unset _CONDA_SET_PDAL_DRIVER_PATH
312
+ ++ . /home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/etc/conda/deactivate.d/libglib_deactivate.sh
313
+ +++ export GSETTINGS_SCHEMA_DIR=/home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/share/glib-2.0/schemas
314
+ +++ GSETTINGS_SCHEMA_DIR=/home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/share/glib-2.0/schemas
315
+ +++ unset GSETTINGS_SCHEMA_DIR_CONDA_BACKUP
316
+ +++ '[' -z /home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/share/glib-2.0/schemas ']'
317
+ ++ . /home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/etc/conda/deactivate.d/geotiff-deactivate.sh
318
+ +++ unset GEOTIFF_CSV
319
+ +++ '[' -n '' ']'
320
+ ++ . /home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/etc/conda/deactivate.d/gdal-deactivate.sh
321
+ +++ unset GDAL_DATA
322
+ +++ '[' -n /home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/share/gdal ']'
323
+ +++ export GDAL_DATA=/home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/share/gdal
324
+ +++ GDAL_DATA=/home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/share/gdal
325
+ +++ unset _CONDA_SET_GDAL_DATA
326
+ +++ unset GDAL_DRIVER_PATH
327
+ +++ '[' -n /home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/lib/gdalplugins ']'
328
+ +++ export GDAL_DRIVER_PATH=/home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/lib/gdalplugins
329
+ +++ GDAL_DRIVER_PATH=/home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env/lib/gdalplugins
330
+ +++ unset _CONDA_SET_GDAL_DRIVER_PATH
331
+ +++ unset CPL_ZIP_ENCODING
332
+ ++ unset CONDA_PREFIX_1
333
+ ++ unset _CE_M
334
+ ++ unset _CE_CONDA
335
+ ++ PS1='(base) '
336
+ ++ export CONDA_PREFIX=/home/apps/miniconda3
337
+ ++ CONDA_PREFIX=/home/apps/miniconda3
338
+ ++ export CONDA_SHLVL=1
339
+ ++ CONDA_SHLVL=1
340
+ ++ export CONDA_DEFAULT_ENV=base
341
+ ++ CONDA_DEFAULT_ENV=base
342
+ ++ export 'CONDA_PROMPT_MODIFIER=(base) '
343
+ ++ CONDA_PROMPT_MODIFIER='(base) '
344
+ ++ export CONDA_EXE=/home/apps/miniconda3/bin/conda
345
+ ++ CONDA_EXE=/home/apps/miniconda3/bin/conda
346
+ ++ export CONDA_PYTHON_EXE=/home/apps/miniconda3/bin/python
347
+ ++ CONDA_PYTHON_EXE=/home/apps/miniconda3/bin/python
348
+ + __conda_hashr
349
+ + '[' -n '' ']'
350
+ + '[' -n '' ']'
351
+ + hash -r
Vaani/39448.out ADDED
@@ -0,0 +1,11 @@
+
+ # conda environments:
+ #
+ aku_env /home/IITB/ai-at-ieor/23m1521/.conda/envs/aku_env
+ cuml /home/IITB/ai-at-ieor/23m1521/.conda/envs/cuml
+ base * /home/apps/miniconda3
+ SCA_deepspeed /home/apps/miniconda3/envs/SCA_deepspeed
+ llama2 /home/apps/miniconda3/envs/llama2
+ tutorial /home/apps/miniconda3/envs/tutorial
+
+ Results saved to /home/IITB/ai-at-ieor/23m1521/ashish/MTP/Vaani/image_dimensions_count.csv
Vaani/IISc_VaaniProject_M_AP_Anantpur_00014520_1544240000_APATSR_190315_1880_16300.wav ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:903b573851ab7767554050c6b238964660511571352084706283a2db802ffb35
+ size 462726
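The three lines above are the Git LFS pointer stored in place of the audio itself: a spec version, the SHA-256 of the real file, and its size in bytes. A quick way to check a downloaded copy against the pointer (local file path assumed, not part of the repository) is:

import hashlib
from pathlib import Path

# Hypothetical local copy of the LFS-tracked audio file.
path = Path("IISc_VaaniProject_M_AP_Anantpur_00014520_1544240000_APATSR_190315_1880_16300.wav")
digest = hashlib.sha256(path.read_bytes()).hexdigest()
print(digest)               # should equal the oid recorded in the pointer
print(path.stat().st_size)  # should equal the recorded size (462726 bytes)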
Vaani/LDM/__init__.py ADDED
File without changes
Vaani/LDM/notebooks/Vaani-subplot.png ADDED

Git LFS Details

  • SHA256: 3b22fe2de54d1a38e517def2bd26d83b5eb1279237f02b7652cc4480530492a5
  • Pointer size: 132 Bytes
  • Size of remote file: 8.94 MB
Vaani/LDM/notebooks/Vaani_VQVAE_Recon_Images/reconstructed_images_EP-15_16.png ADDED

Git LFS Details

  • SHA256: 7a8b412940e4342bd00636ca00b292e82530e39ae1f3dfb8d8993004a7ba9973
  • Pointer size: 131 Bytes
  • Size of remote file: 959 kB
Vaani/LDM/notebooks/Vaani_VQVAE_Recon_Images/reconstructed_images_EP-30_16.png ADDED

Git LFS Details

  • SHA256: 98cafe70cfb371e0ed61091c479c920ff6de0b8d29c46c7190a1c254244828e3
  • Pointer size: 131 Bytes
  • Size of remote file: 968 kB
Vaani/LDM/notebooks/Vaani_VQVAE_Recon_Images/reconstructed_images_EP-4.png ADDED

Git LFS Details

  • SHA256: 4ba9a1004033edbe03c1a0d9e1672fc0f4f966968f08ac6e535db52531f0ec14
  • Pointer size: 131 Bytes
  • Size of remote file: 491 kB
Vaani/LDM/notebooks/Vaani_VQVAE_Recon_Images/reconstructed_images_EP-5.png ADDED

Git LFS Details

  • SHA256: ef43308df008221144827461f26bacca11f6c9a0d17970129f076b1a95ab630e
  • Pointer size: 131 Bytes
  • Size of remote file: 488 kB
Vaani/LDM/notebooks/Vaani_VQVAE_Recon_Images/reconstructed_images_EP-6.png ADDED

Git LFS Details

  • SHA256: 4d909e9e6de6cec642bf668f451e38ae021854250ae86353eaa030afbc95529b
  • Pointer size: 131 Bytes
  • Size of remote file: 491 kB
Vaani/LDM/notebooks/Vaani_VQVAE_Recon_Images/reconstructed_images_EP-6_16.png ADDED

Git LFS Details

  • SHA256: 6b16b33674cc9d2c6618828519b07de8553835c825bf248131efc054a67d277d
  • Pointer size: 131 Bytes
  • Size of remote file: 967 kB
Vaani/LDM/notebooks/Vaani_VQVAE_Recon_Images/reconstructed_images_EP-8_16.png ADDED

Git LFS Details

  • SHA256: da223574afe8fe5860b17e229f23b94b04d7acbd7f9409270c7f301f5a72bcac
  • Pointer size: 131 Bytes
  • Size of remote file: 971 kB
Vaani/LDM/notebooks/_1_Main.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
Vaani/LDM/notebooks/_2_Rough-LPIPS.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
Vaani/LDM/scripts/AE-training.log ADDED
@@ -0,0 +1,126 @@
0
  0%| | 0/3 [00:00<?, ?it/s]
 
1
  0%| | 0/32201 [00:00<?, ?it/s]
 
2
  0%| | 1/32201 [00:04<42:33:43, 4.76s/it]
 
3
  0%| | 2/32201 [00:05<19:42:29, 2.20s/it]
 
4
  0%| | 3/32201 [00:05<13:55:59, 1.56s/it]
 
5
  0%| | 4/32201 [00:06<11:09:27, 1.25s/it]
 
6
  0%| | 5/32201 [00:07<9:39:35, 1.08s/it] 
 
7
  0%| | 6/32201 [00:08<8:45:29, 1.02it/s]
 
8
  0%| | 7/32201 [00:09<8:11:20, 1.09it/s]
 
9
  0%| | 8/32201 [00:09<7:48:52, 1.14it/s]
 
10
  0%| | 9/32201 [00:10<7:33:28, 1.18it/s]
 
11
  0%| | 10/32201 [00:11<7:22:44, 1.21it/s]
 
12
  0%| | 11/32201 [00:12<7:15:53, 1.23it/s]
 
13
  0%| | 12/32201 [00:13<7:11:41, 1.24it/s]
 
14
  0%| | 13/32201 [00:13<7:08:13, 1.25it/s]
 
15
  0%| | 14/32201 [00:14<7:05:54, 1.26it/s]
 
16
  0%| | 15/32201 [00:15<7:05:03, 1.26it/s]
 
17
  0%| | 16/32201 [00:16<7:02:58, 1.27it/s]
 
18
  0%| | 17/32201 [00:16<7:01:54, 1.27it/s]
 
19
  0%| | 18/32201 [00:17<7:01:32, 1.27it/s]
 
20
  0%| | 19/32201 [00:18<7:01:28, 1.27it/s]
 
21
  0%| | 20/32201 [00:19<7:00:58, 1.27it/s]
 
22
  0%| | 21/32201 [00:20<7:00:49, 1.27it/s]
 
23
  0%| | 22/32201 [00:20<7:00:22, 1.28it/s]
 
24
  0%| | 23/32201 [00:21<7:00:16, 1.28it/s]
 
25
  0%| | 24/32201 [00:22<7:00:04, 1.28it/s]
 
26
  0%| | 25/32201 [00:23<7:00:03, 1.28it/s]
 
27
  0%| | 26/32201 [00:23<7:00:35, 1.27it/s]
 
28
  0%| | 27/32201 [00:24<7:00:58, 1.27it/s]
 
29
  0%| | 28/32201 [00:25<7:01:00, 1.27it/s]
 
30
  0%| | 29/32201 [00:26<7:01:05, 1.27it/s]
 
31
  0%| | 30/32201 [00:27<7:00:31, 1.28it/s]
 
32
  0%| | 31/32201 [00:27<7:00:28, 1.28it/s]
 
33
  0%| | 32/32201 [00:28<7:00:15, 1.28it/s]
 
34
  0%| | 33/32201 [00:29<7:00:04, 1.28it/s]
 
35
  0%| | 34/32201 [00:30<7:00:28, 1.28it/s]
 
36
  0%| | 35/32201 [00:31<7:00:09, 1.28it/s]
 
37
  0%| | 36/32201 [00:31<6:59:58, 1.28it/s]
 
38
  0%| | 37/32201 [00:32<7:00:04, 1.28it/s]
 
39
  0%| | 38/32201 [00:33<7:00:27, 1.27it/s]
 
40
  0%| | 39/32201 [00:34<7:00:33, 1.27it/s]
 
41
  0%| | 40/32201 [00:34<7:00:21, 1.28it/s]
 
42
  0%| | 41/32201 [00:35<7:00:25, 1.27it/s]
 
43
  0%| | 42/32201 [00:36<7:00:46, 1.27it/s]
 
44
  0%| | 43/32201 [00:37<7:00:35, 1.27it/s]
 
45
  0%| | 44/32201 [00:38<7:00:56, 1.27it/s]
 
46
  0%| | 45/32201 [00:38<7:00:58, 1.27it/s]
 
47
  0%| | 46/32201 [00:39<7:01:10, 1.27it/s]
 
48
  0%| | 47/32201 [00:40<7:01:06, 1.27it/s]
 
49
  0%| | 48/32201 [00:41<7:00:54, 1.27it/s]
 
50
  0%| | 49/32201 [00:42<7:01:07, 1.27it/s]
 
51
  0%| | 50/32201 [00:42<7:01:11, 1.27it/s]
 
52
  0%| | 51/32201 [00:43<7:01:03, 1.27it/s]
 
53
  0%| | 52/32201 [00:44<7:00:58, 1.27it/s]
 
54
  0%| | 53/32201 [00:45<7:01:05, 1.27it/s]
 
55
  0%| | 54/32201 [00:45<7:00:55, 1.27it/s]
 
56
  0%| | 55/32201 [00:46<7:01:12, 1.27it/s]
 
57
  0%| | 56/32201 [00:47<7:01:30, 1.27it/s]
 
58
  0%| | 57/32201 [00:48<7:01:27, 1.27it/s]
 
59
  0%| | 58/32201 [00:49<7:01:30, 1.27it/s]
 
60
  0%| | 59/32201 [00:49<7:02:07, 1.27it/s]
61
  0%| | 59/32201 [00:50<7:36:56, 1.17it/s]
 
62
  0%| | 0/3 [00:50<?, ?it/s]
 
 
 
 
 
 
 
 
1
+ TIME: 2025-03-25 01:30:39.070253
2
+ DEVICE: cuda
3
+ {'autoencoder_params': {'attn_down': [False, False],
4
+ 'codebook_size': 20,
5
+ 'down_channels': [32, 64, 128],
6
+ 'down_sample': [True, True],
7
+ 'mid_channels': [128, 128],
8
+ 'norm_channels': 32,
9
+ 'num_down_layers': 4,
10
+ 'num_heads': 16,
11
+ 'num_mid_layers': 4,
12
+ 'num_up_layers': 4,
13
+ 'z_channels': 3},
14
+ 'dataset_params': {'im_channels': 3,
15
+ 'im_path': '/home/taruntejaneurips23/Ashish/datasets/CelebA/img_align_celeba/img_align_celeba',
16
+ 'im_size': 256},
17
+ 'diffusion_params': {'beta_end': 0.0195, 'beta_start': 0.0015, 'num_timesteps': 1000},
18
+ 'ldm_params': {'attn_down': [True, True, True],
19
+ 'conv_out_channels': 128,
20
+ 'down_channels': [128, 256, 256, 256],
21
+ 'down_sample': [False, False, False],
22
+ 'mid_channels': [256, 256],
23
+ 'norm_channels': 32,
24
+ 'num_down_layers': 2,
25
+ 'num_heads': 16,
26
+ 'num_mid_layers': 2,
27
+ 'num_up_layers': 2,
28
+ 'time_emb_dim': 256},
29
+ 'train_params': {'autoencoder_acc_steps': 1,
30
+ 'autoencoder_batch_size': 4,
31
+ 'autoencoder_epochs': 3,
32
+ 'autoencoder_img_save_steps': 8,
33
+ 'autoencoder_lr': 0.0001,
34
+ 'checkpoint_dir': './',
35
+ 'codebook_weight': 1,
36
+ 'commitment_beta': 0.2,
37
+ 'disc_start': 1000,
38
+ 'disc_weight': 0.5,
39
+ 'kl_weight': 5e-06,
40
+ 'ldm_batch_size': 1,
41
+ 'ldm_ckpt_name': 'ddpm_ckpt.pth',
42
+ 'ldm_epochs': 10,
43
+ 'ldm_lr': 1e-05,
44
+ 'num_grid_rows': 3,
45
+ 'num_samples': 9,
46
+ 'perceptual_weight': 1,
47
+ 'save_latents': True,
48
+ 'seed': 4422,
49
+ 'task_name': 'VaaniLDM',
50
+ 'vqvae_autoencoder_ckpt_name': 'vqvae_autoencoder_ckpt.pth',
51
+ 'vqvae_discriminator_ckpt_name': 'vqvae_discriminator_ckpt.pth',
52
+ 'vqvae_latent_dir_name': 'vqvae_latents'},
53
+ 'training': {'_continue_': True}}
54
+ Files found: 128807
55
+ IMAGE SHAPE: torch.Size([3, 256, 256])
56
+ BATCH SHAPE: torch.Size([4, 3, 256, 256])
57
+ No checkpoint found. Starting from scratch.
58
+
59
  0%| | 0/3 [00:00<?, ?it/s]
60
+
61
  0%| | 0/32201 [00:00<?, ?it/s]
62
+
63
  0%| | 1/32201 [00:04<42:33:43, 4.76s/it]
64
+
65
  0%| | 2/32201 [00:05<19:42:29, 2.20s/it]
66
+
67
  0%| | 3/32201 [00:05<13:55:59, 1.56s/it]
68
+
69
  0%| | 4/32201 [00:06<11:09:27, 1.25s/it]
70
+
71
  0%| | 5/32201 [00:07<9:39:35, 1.08s/it] 
72
+
73
  0%| | 6/32201 [00:08<8:45:29, 1.02it/s]
74
+
75
  0%| | 7/32201 [00:09<8:11:20, 1.09it/s]
76
+
77
  0%| | 8/32201 [00:09<7:48:52, 1.14it/s]
78
+
79
  0%| | 9/32201 [00:10<7:33:28, 1.18it/s]
80
+
81
  0%| | 10/32201 [00:11<7:22:44, 1.21it/s]
82
+
83
  0%| | 11/32201 [00:12<7:15:53, 1.23it/s]
84
+
85
  0%| | 12/32201 [00:13<7:11:41, 1.24it/s]
86
+
87
  0%| | 13/32201 [00:13<7:08:13, 1.25it/s]
88
+
89
  0%| | 14/32201 [00:14<7:05:54, 1.26it/s]
90
+
91
  0%| | 15/32201 [00:15<7:05:03, 1.26it/s]
92
+
93
  0%| | 16/32201 [00:16<7:02:58, 1.27it/s]
94
+
95
  0%| | 17/32201 [00:16<7:01:54, 1.27it/s]
96
+
97
  0%| | 18/32201 [00:17<7:01:32, 1.27it/s]
98
+
99
  0%| | 19/32201 [00:18<7:01:28, 1.27it/s]
100
+
101
  0%| | 20/32201 [00:19<7:00:58, 1.27it/s]
102
+
103
  0%| | 21/32201 [00:20<7:00:49, 1.27it/s]
104
+
105
  0%| | 22/32201 [00:20<7:00:22, 1.28it/s]
106
+
107
  0%| | 23/32201 [00:21<7:00:16, 1.28it/s]
108
+
109
  0%| | 24/32201 [00:22<7:00:04, 1.28it/s]
110
+
111
  0%| | 25/32201 [00:23<7:00:03, 1.28it/s]
112
+
113
  0%| | 26/32201 [00:23<7:00:35, 1.27it/s]
114
+
115
  0%| | 27/32201 [00:24<7:00:58, 1.27it/s]
116
+
117
  0%| | 28/32201 [00:25<7:01:00, 1.27it/s]
118
+
119
  0%| | 29/32201 [00:26<7:01:05, 1.27it/s]
120
+
121
  0%| | 30/32201 [00:27<7:00:31, 1.28it/s]
122
+
123
  0%| | 31/32201 [00:27<7:00:28, 1.28it/s]
124
+
125
  0%| | 32/32201 [00:28<7:00:15, 1.28it/s]
126
+
127
  0%| | 33/32201 [00:29<7:00:04, 1.28it/s]
128
+
129
  0%| | 34/32201 [00:30<7:00:28, 1.28it/s]
130
+
131
  0%| | 35/32201 [00:31<7:00:09, 1.28it/s]
132
+
133
  0%| | 36/32201 [00:31<6:59:58, 1.28it/s]
134
+
135
  0%| | 37/32201 [00:32<7:00:04, 1.28it/s]
136
+
137
  0%| | 38/32201 [00:33<7:00:27, 1.27it/s]
138
+
139
  0%| | 39/32201 [00:34<7:00:33, 1.27it/s]
140
+
141
  0%| | 40/32201 [00:34<7:00:21, 1.28it/s]
142
+
143
  0%| | 41/32201 [00:35<7:00:25, 1.27it/s]
144
+
145
  0%| | 42/32201 [00:36<7:00:46, 1.27it/s]
146
+
147
  0%| | 43/32201 [00:37<7:00:35, 1.27it/s]
148
+
149
  0%| | 44/32201 [00:38<7:00:56, 1.27it/s]
150
+
151
  0%| | 45/32201 [00:38<7:00:58, 1.27it/s]
152
+
153
  0%| | 46/32201 [00:39<7:01:10, 1.27it/s]
154
+
155
  0%| | 47/32201 [00:40<7:01:06, 1.27it/s]
156
+
157
  0%| | 48/32201 [00:41<7:00:54, 1.27it/s]
158
+
159
  0%| | 49/32201 [00:42<7:01:07, 1.27it/s]
160
+
161
  0%| | 50/32201 [00:42<7:01:11, 1.27it/s]
162
+
163
  0%| | 51/32201 [00:43<7:01:03, 1.27it/s]
164
+
165
  0%| | 52/32201 [00:44<7:00:58, 1.27it/s]
166
+
167
  0%| | 53/32201 [00:45<7:01:05, 1.27it/s]
168
+
169
  0%| | 54/32201 [00:45<7:00:55, 1.27it/s]
170
+
171
  0%| | 55/32201 [00:46<7:01:12, 1.27it/s]
172
+
173
  0%| | 56/32201 [00:47<7:01:30, 1.27it/s]
174
+
175
  0%| | 57/32201 [00:48<7:01:27, 1.27it/s]
176
+
177
  0%| | 58/32201 [00:49<7:01:30, 1.27it/s]
178
+
179
  0%| | 59/32201 [00:49<7:02:07, 1.27it/s]
180
  0%| | 59/32201 [00:50<7:36:56, 1.17it/s]
181
+
182
  0%| | 0/3 [00:50<?, ?it/s]
183
+ Traceback (most recent call last):
184
+ File "/home/IITB/ai-at-ieor/23m1521/ashish/MTP/Vaani/LDM/scripts/Vaani-VQVAE-Main.py", line 1105, in <module>
185
+ trainVAE(Config, dataloader)
186
+ File "/home/IITB/ai-at-ieor/23m1521/ashish/MTP/Vaani/LDM/scripts/Vaani-VQVAE-Main.py", line 1049, in trainVAE
187
+ images = images.to(device)
188
+ ^^^^^^^^^^^^^^^^^
189
+ KeyboardInterrupt
Vaani/LDM/scripts/Main.py ADDED
@@ -0,0 +1,2303 @@
1
+ # ==================================================================
2
+ # L A T E N T D I F F U S I O N M O D E L
3
+ # ==================================================================
4
+ # Author : Ashish Kumar Uchadiya
5
+ # Created : November 3, 2024
6
+ # Description: This script implements a Latent Diffusion Model using
7
+ # a cosine or linear noise scheduling approach for high-resolution
8
+ # image generation. The model leverages generative techniques to
9
+ # learn a latent representation and progressively reduce noise to
10
+ # generate clear, realistic images.
11
+ # ==================================================================
12
+ # I M P O R T S
13
+ # ==================================================================
14
+
15
+ import os
16
+ os.environ["CUDA_VISIBLE_DEVICES"] = "1"
17
+
18
+ """Lpips"""
19
+
20
+ # from __future__ import absolute_import
21
+ from collections import namedtuple
22
+ import torch
23
+ import torch.nn as nn
24
+ import torch.nn.init as init
25
+ from torch.autograd import Variable
26
+ import numpy as np
27
+ import torch.nn
28
+ import torchvision
29
+
30
+ # Taken from https://github.com/richzhang/PerceptualSimilarity/blob/master/lpips/lpips.py
31
+
32
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
33
+
34
+
35
+ def spatial_average(in_tens, keepdim=True):
36
+ return in_tens.mean([2, 3], keepdim=keepdim)
37
+
38
+
39
+ class vgg16(torch.nn.Module):
40
+ def __init__(self, requires_grad=False, pretrained=True):
41
+ super(vgg16, self).__init__()
42
+ vgg_pretrained_features = torchvision.models.vgg16(
43
+ weights=torchvision.models.VGG16_Weights.IMAGENET1K_V1
44
+ ).features
45
+ self.slice1 = torch.nn.Sequential()
46
+ self.slice2 = torch.nn.Sequential()
47
+ self.slice3 = torch.nn.Sequential()
48
+ self.slice4 = torch.nn.Sequential()
49
+ self.slice5 = torch.nn.Sequential()
50
+ self.N_slices = 5
51
+ for x in range(4):
52
+ self.slice1.add_module(str(x), vgg_pretrained_features[x])
53
+ for x in range(4, 9):
54
+ self.slice2.add_module(str(x), vgg_pretrained_features[x])
55
+ for x in range(9, 16):
56
+ self.slice3.add_module(str(x), vgg_pretrained_features[x])
57
+ for x in range(16, 23):
58
+ self.slice4.add_module(str(x), vgg_pretrained_features[x])
59
+ for x in range(23, 30):
60
+ self.slice5.add_module(str(x), vgg_pretrained_features[x])
61
+
62
+ # Freeze vgg model
63
+ if not requires_grad:
64
+ for param in self.parameters():
65
+ param.requires_grad = False
66
+
67
+ def forward(self, X):
68
+ # Return output of vgg features
69
+ h = self.slice1(X)
70
+ h_relu1_2 = h
71
+ h = self.slice2(h)
72
+ h_relu2_2 = h
73
+ h = self.slice3(h)
74
+ h_relu3_3 = h
75
+ h = self.slice4(h)
76
+ h_relu4_3 = h
77
+ h = self.slice5(h)
78
+ h_relu5_3 = h
79
+ vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
80
+ out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
81
+ return out
82
+
83
+
84
+ # Learned perceptual metric
85
+ class LPIPS(nn.Module):
86
+ def __init__(self, net='vgg', version='0.1', use_dropout=True):
87
+ super(LPIPS, self).__init__()
88
+ self.version = version
89
+ # Imagenet normalization
90
+ self.scaling_layer = ScalingLayer()
91
+ ########################
92
+
93
+ # Instantiate vgg model
94
+ self.chns = [64, 128, 256, 512, 512]
95
+ self.L = len(self.chns)
96
+ self.net = vgg16(pretrained=True, requires_grad=False)
97
+
98
+ # Add 1x1 convolutional Layers
99
+ self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
100
+ self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
101
+ self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
102
+ self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
103
+ self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
104
+ self.lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
105
+ self.lins = nn.ModuleList(self.lins)
106
+ ########################
107
+
108
+ # Load the weights of trained LPIPS model
109
+ import inspect
110
+ import os
111
+ # /home/taruntejaneurips23/.cache/torch/hub/checkpoints/vgg16-397923af.pth
112
+ print(os.path.abspath(os.path.join(inspect.getfile(self.__init__), '..', 'weights/v%s/%s.pth' % (version, net))))
113
+ # model_path = os.path.abspath(
114
+ # os.path.join(inspect.getfile(self.__init__), '..', 'weights/v%s/%s.pth' % (version, net)))
115
+
116
+ # print('Loading model from: %s' % model_path)
117
+ # self.load_state_dict(torch.load(model_path, map_location=device), strict=False)
118
+ ########################
119
+
120
+ # Freeze all parameters
121
+ self.eval()
122
+ for param in self.parameters():
123
+ param.requires_grad = False
124
+ ########################
125
+
126
+ def forward(self, in0, in1, normalize=False):
127
+ # Scale the inputs to -1 to +1 range if needed
128
+ if normalize: # turn on this flag if input is [0,1] so it can be adjusted to [-1, +1]
129
+ in0 = 2 * in0 - 1
130
+ in1 = 2 * in1 - 1
131
+ ########################
132
+
133
+ # Normalize the inputs according to imagenet normalization
134
+ in0_input, in1_input = self.scaling_layer(in0), self.scaling_layer(in1)
135
+ ########################
136
+
137
+ # Get VGG outputs for image0 and image1
138
+ outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input)
139
+ feats0, feats1, diffs = {}, {}, {}
140
+ ########################
141
+
142
+ # Compute Square of Difference for each layer output
143
+ for kk in range(self.L):
144
+ feats0[kk], feats1[kk] = torch.nn.functional.normalize(outs0[kk], dim=1), torch.nn.functional.normalize(
145
+ outs1[kk])
146
+ diffs[kk] = (feats0[kk] - feats1[kk]) ** 2
147
+ ########################
148
+
149
+ # 1x1 convolution followed by spatial average on the square differences
150
+ res = [spatial_average(self.lins[kk](diffs[kk]), keepdim=True) for kk in range(self.L)]
151
+ val = 0
152
+
153
+ # Aggregate the results of each layer
154
+ for l in range(self.L):
155
+ val += res[l]
156
+ return val
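Written out, the forward pass above is the LPIPS distance of Zhang et al. (2018): with channel-wise unit-normalized activations \(\hat{y}^{l}\) and \(\hat{y}_{0}^{l}\) from layer \(l\) and the learned 1x1-conv weights \(w_l\),

$$ d(x, x_0) = \sum_{l} \frac{1}{H_l W_l} \sum_{h,w} \left\lVert w_l \odot \left( \hat{y}^{l}_{hw} - \hat{y}^{l}_{0,hw} \right) \right\rVert_2^2 $$

where the spatial mean corresponds to `spatial_average` and the sum over \(l\) to the final aggregation loop.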
157
+
158
+
159
+ class ScalingLayer(nn.Module):
160
+ def __init__(self):
161
+ super(ScalingLayer, self).__init__()
162
+ # Imagnet normalization for (0-1)
163
+ # mean = [0.485, 0.456, 0.406]
164
+ # std = [0.229, 0.224, 0.225]
165
+ self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
166
+ self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None])
167
+
168
+ def forward(self, inp):
169
+ return (inp - self.shift) / self.scale
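The shift/scale constants above are just the ImageNet statistics re-expressed for inputs in \([-1, 1]\): with mean \(\mu = (0.485, 0.456, 0.406)\) and std \(\sigma = (0.229, 0.224, 0.225)\) defined for \([0, 1]\) images,

$$ \frac{\tfrac{x+1}{2} - \mu}{\sigma} = \frac{x - (2\mu - 1)}{2\sigma}, \qquad 2\mu - 1 = (-0.030, -0.088, -0.188), \quad 2\sigma = (0.458, 0.448, 0.450), $$

which is exactly the registered `shift` and `scale` buffers.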
170
+
171
+
172
+ class NetLinLayer(nn.Module):
173
+ ''' A single linear layer which does a 1x1 conv '''
174
+
175
+ def __init__(self, chn_in, chn_out=1, use_dropout=False):
176
+ super(NetLinLayer, self).__init__()
177
+
178
+ layers = [nn.Dropout(), ] if (use_dropout) else []
179
+ layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ]
180
+ self.model = nn.Sequential(*layers)
181
+
182
+ def forward(self, x):
183
+ out = self.model(x)
184
+ return out
185
+
186
+ """Blocks"""
187
+
188
+ import torch
189
+ import numpy as np
190
+
191
+
192
+ class LinearNoiseScheduler:
193
+ r"""
194
+ Class for the linear noise scheduler that is used in DDPM.
195
+ """
196
+
197
+ def __init__(self, num_timesteps, beta_start, beta_end):
198
+
199
+ self.num_timesteps = num_timesteps
200
+ self.beta_start = beta_start
201
+ self.beta_end = beta_end
202
+ # Mimicking how compvis repo creates schedule
203
+ self.betas = (
204
+ torch.linspace(beta_start ** 0.5, beta_end ** 0.5, num_timesteps) ** 2
205
+ )
206
+ self.alphas = 1. - self.betas
207
+ self.alpha_cum_prod = torch.cumprod(self.alphas, dim=0)
208
+ self.sqrt_alpha_cum_prod = torch.sqrt(self.alpha_cum_prod)
209
+ self.sqrt_one_minus_alpha_cum_prod = torch.sqrt(1 - self.alpha_cum_prod)
210
+
211
+ def add_noise(self, original, noise, t):
212
+ r"""
213
+ Forward method for diffusion
214
+ :param original: Image on which noise is to be applied
215
+ :param noise: Random Noise Tensor (from normal dist)
216
+ :param t: timestep of the forward process of shape -> (B,)
217
+ :return:
218
+ """
219
+ original_shape = original.shape
220
+ batch_size = original_shape[0]
221
+
222
+ sqrt_alpha_cum_prod = self.sqrt_alpha_cum_prod.to(original.device)[t].reshape(batch_size)
223
+ sqrt_one_minus_alpha_cum_prod = self.sqrt_one_minus_alpha_cum_prod.to(original.device)[t].reshape(batch_size)
224
+
225
+ # Reshape till (B,) becomes (B,1,1,1) if image is (B,C,H,W)
226
+ for _ in range(len(original_shape) - 1):
227
+ sqrt_alpha_cum_prod = sqrt_alpha_cum_prod.unsqueeze(-1)
228
+ for _ in range(len(original_shape) - 1):
229
+ sqrt_one_minus_alpha_cum_prod = sqrt_one_minus_alpha_cum_prod.unsqueeze(-1)
230
+
231
+ # Apply and Return Forward process equation
232
+ return (sqrt_alpha_cum_prod.to(original.device) * original
233
+ + sqrt_one_minus_alpha_cum_prod.to(original.device) * noise)
234
+
235
+ def sample_prev_timestep(self, xt, noise_pred, t):
236
+ r"""
237
+ Use the noise prediction by model to get
238
+ xt-1 using xt and the noise predicted
239
+ :param xt: current timestep sample
240
+ :param noise_pred: model noise prediction
241
+ :param t: current timestep we are at
242
+ :return:
243
+ """
244
+ x0 = ((xt - (self.sqrt_one_minus_alpha_cum_prod.to(xt.device)[t] * noise_pred)) /
245
+ torch.sqrt(self.alpha_cum_prod.to(xt.device)[t]))
246
+ x0 = torch.clamp(x0, -1., 1.)
247
+
248
+ mean = xt - ((self.betas.to(xt.device)[t]) * noise_pred) / (self.sqrt_one_minus_alpha_cum_prod.to(xt.device)[t])
249
+ mean = mean / torch.sqrt(self.alphas.to(xt.device)[t])
250
+
251
+ if t == 0:
252
+ return mean, x0
253
+ else:
254
+ variance = (1 - self.alpha_cum_prod.to(xt.device)[t - 1]) / (1.0 - self.alpha_cum_prod.to(xt.device)[t])
255
+ variance = variance * self.betas.to(xt.device)[t]
256
+ sigma = variance ** 0.5
257
+ z = torch.randn(xt.shape).to(xt.device)
258
+
259
+ # OR
260
+ # variance = self.betas[t]
261
+ # sigma = variance ** 0.5
262
+ # z = torch.randn(xt.shape).to(xt.device)
263
+ return mean + sigma * z, x0
264
+
265
+
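+ # Quick sanity-check sketch of the forward process
+ # x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps
+ # (illustrative only; the beta values below are arbitrary examples, not the training config):
+ # scheduler = LinearNoiseScheduler(num_timesteps=1000, beta_start=0.0015, beta_end=0.0195)
+ # x0 = torch.randn(4, 3, 32, 32)
+ # eps = torch.randn_like(x0)
+ # t = torch.randint(0, 1000, (4,))
+ # xt = scheduler.add_noise(x0, eps, t)           # same shape as x0
+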
266
+ import torch
267
+ import math
268
+
269
+ class CosineNoiseScheduler:
270
+ r"""
271
+ Class for the cosine noise scheduler, often used in DDPM-based models.
272
+ """
273
+
274
+ def __init__(self, num_timesteps, s=0.008):
275
+ self.num_timesteps = num_timesteps
276
+ self.s = s
277
+
278
+ # Cosine schedule based on paper
279
+ def cosine_schedule(t):
280
+ return math.cos((t / self.num_timesteps + s) / (1 + s) * math.pi / 2) ** 2
281
+
282
+ # Compute alphas
283
+ self.alphas = torch.tensor([cosine_schedule(t) for t in range(num_timesteps)])
284
+ self.alpha_cum_prod = torch.cumprod(self.alphas, dim=0)
285
+ self.sqrt_alpha_cum_prod = torch.sqrt(self.alpha_cum_prod)
286
+ self.sqrt_one_minus_alpha_cum_prod = torch.sqrt(1 - self.alpha_cum_prod)
287
+
288
+ def add_noise(self, original, noise, t):
289
+ original_shape = original.shape
290
+ batch_size = original_shape[0]
291
+
292
+ sqrt_alpha_cum_prod = self.sqrt_alpha_cum_prod.to(original.device)[t].reshape(batch_size)
293
+ sqrt_one_minus_alpha_cum_prod = self.sqrt_one_minus_alpha_cum_prod.to(original.device)[t].reshape(batch_size)
294
+
295
+ for _ in range(len(original_shape) - 1):
296
+ sqrt_alpha_cum_prod = sqrt_alpha_cum_prod.unsqueeze(-1)
297
+ for _ in range(len(original_shape) - 1):
298
+ sqrt_one_minus_alpha_cum_prod = sqrt_one_minus_alpha_cum_prod.unsqueeze(-1)
299
+
300
+ return (sqrt_alpha_cum_prod * original + sqrt_one_minus_alpha_cum_prod * noise)
301
+
302
+ def sample_prev_timestep(self, xt, noise_pred, t):
303
+ x0 = ((xt - (self.sqrt_one_minus_alpha_cum_prod.to(xt.device)[t] * noise_pred)) /
304
+ torch.sqrt(self.alpha_cum_prod.to(xt.device)[t]))
305
+ x0 = torch.clamp(x0, -1., 1.)
306
+
307
+ mean = xt - ((1 - self.alphas.to(xt.device)[t]) * noise_pred) / (self.sqrt_one_minus_alpha_cum_prod.to(xt.device)[t])
308
+ mean = mean / torch.sqrt(self.alphas.to(xt.device)[t])
309
+
310
+ if t == 0:
311
+ return mean, x0
312
+ else:
313
+ variance = (1 - self.alpha_cum_prod.to(xt.device)[t - 1]) / (1.0 - self.alpha_cum_prod.to(xt.device)[t])
314
+ variance = variance * (1 - self.alphas.to(xt.device)[t])
315
+ sigma = variance ** 0.5
316
+ z = torch.randn(xt.shape).to(xt.device)
317
+ return mean + sigma * z, x0
318
+
319
+
320
+
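+ # Illustrative comparison of the two schedulers (not used by the training code): both expose
+ # sqrt_alpha_cum_prod, which starts near 1 and decays towards 0 as t grows; printing a few
+ # entries from each lets you compare how quickly the signal is destroyed.
+ # lin = LinearNoiseScheduler(num_timesteps=1000, beta_start=1e-4, beta_end=0.02)
+ # cos = CosineNoiseScheduler(num_timesteps=1000)
+ # print(lin.sqrt_alpha_cum_prod[::250])
+ # print(cos.sqrt_alpha_cum_prod[::250])
+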
321
+
322
+ import torch
323
+ import torch.nn as nn
324
+
325
+
326
+ def get_time_embedding(time_steps, temb_dim):
327
+ r"""
328
+ Convert time steps tensor into an embedding using the
329
+ sinusoidal time embedding formula
330
+ :param time_steps: 1D tensor of length batch size
331
+ :param temb_dim: Dimension of the embedding
332
+ :return: BxD embedding representation of B time steps
333
+ """
334
+ assert temb_dim % 2 == 0, "time embedding dimension must be divisible by 2"
335
+
336
+ # factor = 10000^(2i/d_model)
337
+ factor = 10000 ** ((torch.arange(
338
+ start=0, end=temb_dim // 2, dtype=torch.float32, device=time_steps.device) / (temb_dim // 2))
339
+ )
340
+
341
+ # pos / factor
342
+ # timesteps B -> B, 1 -> B, temb_dim
343
+ t_emb = time_steps[:, None].repeat(1, temb_dim // 2) / factor
344
+ t_emb = torch.cat([torch.sin(t_emb), torch.cos(t_emb)], dim=-1)
345
+ return t_emb
346
+
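+ # Shape-check sketch (illustrative): a batch of B integer timesteps becomes a
+ # (B, temb_dim) embedding, first half sin terms and second half cos terms.
+ # t = torch.tensor([0, 10, 500])
+ # emb = get_time_embedding(t, temb_dim=128)
+ # print(emb.shape)                               # torch.Size([3, 128])
+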
347
+
348
+ class DownBlock(nn.Module):
349
+ r"""
350
+ Down conv block with attention.
351
+ Sequence of the following blocks:
352
+ 1. Resnet block with time embedding
353
+ 2. Attention block
354
+ 3. Downsample
355
+ """
356
+
357
+ def __init__(self, in_channels, out_channels, t_emb_dim,
358
+ down_sample, num_heads, num_layers, attn, norm_channels, cross_attn=False, context_dim=None):
359
+ super().__init__()
360
+ self.num_layers = num_layers
361
+ self.down_sample = down_sample
362
+ self.attn = attn
363
+ self.context_dim = context_dim
364
+ self.cross_attn = cross_attn
365
+ self.t_emb_dim = t_emb_dim
366
+ self.resnet_conv_first = nn.ModuleList(
367
+ [
368
+ nn.Sequential(
369
+ nn.GroupNorm(norm_channels, in_channels if i == 0 else out_channels),
370
+ nn.SiLU(),
371
+ nn.Conv2d(in_channels if i == 0 else out_channels, out_channels,
372
+ kernel_size=3, stride=1, padding=1),
373
+ )
374
+ for i in range(num_layers)
375
+ ]
376
+ )
377
+ if self.t_emb_dim is not None:
378
+ self.t_emb_layers = nn.ModuleList([
379
+ nn.Sequential(
380
+ nn.SiLU(),
381
+ nn.Linear(self.t_emb_dim, out_channels)
382
+ )
383
+ for _ in range(num_layers)
384
+ ])
385
+ self.resnet_conv_second = nn.ModuleList(
386
+ [
387
+ nn.Sequential(
388
+ nn.GroupNorm(norm_channels, out_channels),
389
+ nn.SiLU(),
390
+ nn.Conv2d(out_channels, out_channels,
391
+ kernel_size=3, stride=1, padding=1),
392
+ )
393
+ for _ in range(num_layers)
394
+ ]
395
+ )
396
+
397
+ if self.attn:
398
+ self.attention_norms = nn.ModuleList(
399
+ [nn.GroupNorm(norm_channels, out_channels)
400
+ for _ in range(num_layers)]
401
+ )
402
+
403
+ self.attentions = nn.ModuleList(
404
+ [nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
405
+ for _ in range(num_layers)]
406
+ )
407
+
408
+ if self.cross_attn:
409
+ assert context_dim is not None, "Context Dimension must be passed for cross attention"
410
+ self.cross_attention_norms = nn.ModuleList(
411
+ [nn.GroupNorm(norm_channels, out_channels)
412
+ for _ in range(num_layers)]
413
+ )
414
+ self.cross_attentions = nn.ModuleList(
415
+ [nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
416
+ for _ in range(num_layers)]
417
+ )
418
+ self.context_proj = nn.ModuleList(
419
+ [nn.Linear(context_dim, out_channels)
420
+ for _ in range(num_layers)]
421
+ )
422
+
423
+ self.residual_input_conv = nn.ModuleList(
424
+ [
425
+ nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=1)
426
+ for i in range(num_layers)
427
+ ]
428
+ )
429
+ self.down_sample_conv = nn.Conv2d(out_channels, out_channels,
430
+ 4, 2, 1) if self.down_sample else nn.Identity()
431
+
432
+ def forward(self, x, t_emb=None, context=None):
433
+ out = x
434
+ for i in range(self.num_layers):
435
+ # Resnet block of Unet
436
+ resnet_input = out
437
+ out = self.resnet_conv_first[i](out)
438
+ if self.t_emb_dim is not None:
439
+ out = out + self.t_emb_layers[i](t_emb)[:, :, None, None]
440
+ out = self.resnet_conv_second[i](out)
441
+ out = out + self.residual_input_conv[i](resnet_input)
442
+
443
+ if self.attn:
444
+ # Attention block of Unet
445
+ batch_size, channels, h, w = out.shape
446
+ in_attn = out.reshape(batch_size, channels, h * w)
447
+ in_attn = self.attention_norms[i](in_attn)
448
+ in_attn = in_attn.transpose(1, 2)
449
+ out_attn, _ = self.attentions[i](in_attn, in_attn, in_attn)
450
+ out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
451
+ out = out + out_attn
452
+
453
+ if self.cross_attn:
454
+ assert context is not None, "context cannot be None if cross attention layers are used"
455
+ batch_size, channels, h, w = out.shape
456
+ in_attn = out.reshape(batch_size, channels, h * w)
457
+ in_attn = self.cross_attention_norms[i](in_attn)
458
+ in_attn = in_attn.transpose(1, 2)
459
+ assert context.shape[0] == x.shape[0] and context.shape[-1] == self.context_dim
460
+ context_proj = self.context_proj[i](context)
461
+ out_attn, _ = self.cross_attentions[i](in_attn, context_proj, context_proj)
462
+ out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
463
+ out = out + out_attn
464
+
465
+ # Downsample
466
+ out = self.down_sample_conv(out)
467
+ return out
468
+
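+ # Illustrative shape check with hypothetical sizes (not the configured model):
+ # with down_sample=True the spatial resolution halves and channels go in -> out.
+ # blk = DownBlock(64, 128, t_emb_dim=256, down_sample=True, num_heads=4,
+ #                 num_layers=1, attn=True, norm_channels=32)
+ # x = torch.randn(2, 64, 32, 32)
+ # t_emb = torch.randn(2, 256)
+ # print(blk(x, t_emb).shape)                     # torch.Size([2, 128, 16, 16])
+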
469
+
470
+ class MidBlock(nn.Module):
471
+ r"""
472
+ Mid conv block with attention.
473
+ Sequence of the following blocks:
474
+ 1. Resnet block with time embedding
475
+ 2. Attention block
476
+ 3. Resnet block with time embedding
477
+ """
478
+
479
+ def __init__(self, in_channels, out_channels, t_emb_dim, num_heads, num_layers, norm_channels, cross_attn=None, context_dim=None):
480
+ super().__init__()
481
+ self.num_layers = num_layers
482
+ self.t_emb_dim = t_emb_dim
483
+ self.context_dim = context_dim
484
+ self.cross_attn = cross_attn
485
+ self.resnet_conv_first = nn.ModuleList(
486
+ [
487
+ nn.Sequential(
488
+ nn.GroupNorm(norm_channels, in_channels if i == 0 else out_channels),
489
+ nn.SiLU(),
490
+ nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=3, stride=1,
491
+ padding=1),
492
+ )
493
+ for i in range(num_layers + 1)
494
+ ]
495
+ )
496
+
497
+ if self.t_emb_dim is not None:
498
+ self.t_emb_layers = nn.ModuleList([
499
+ nn.Sequential(
500
+ nn.SiLU(),
501
+ nn.Linear(t_emb_dim, out_channels)
502
+ )
503
+ for _ in range(num_layers + 1)
504
+ ])
505
+ self.resnet_conv_second = nn.ModuleList(
506
+ [
507
+ nn.Sequential(
508
+ nn.GroupNorm(norm_channels, out_channels),
509
+ nn.SiLU(),
510
+ nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
511
+ )
512
+ for _ in range(num_layers + 1)
513
+ ]
514
+ )
515
+
516
+ self.attention_norms = nn.ModuleList(
517
+ [nn.GroupNorm(norm_channels, out_channels)
518
+ for _ in range(num_layers)]
519
+ )
520
+
521
+ self.attentions = nn.ModuleList(
522
+ [nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
523
+ for _ in range(num_layers)]
524
+ )
525
+ if self.cross_attn:
526
+ assert context_dim is not None, "Context Dimension must be passed for cross attention"
527
+ self.cross_attention_norms = nn.ModuleList(
528
+ [nn.GroupNorm(norm_channels, out_channels)
529
+ for _ in range(num_layers)]
530
+ )
531
+ self.cross_attentions = nn.ModuleList(
532
+ [nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
533
+ for _ in range(num_layers)]
534
+ )
535
+ self.context_proj = nn.ModuleList(
536
+ [nn.Linear(context_dim, out_channels)
537
+ for _ in range(num_layers)]
538
+ )
539
+ self.residual_input_conv = nn.ModuleList(
540
+ [
541
+ nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=1)
542
+ for i in range(num_layers + 1)
543
+ ]
544
+ )
545
+
546
+ def forward(self, x, t_emb=None, context=None):
547
+ out = x
548
+
549
+ # First resnet block
550
+ resnet_input = out
551
+ out = self.resnet_conv_first[0](out)
552
+ if self.t_emb_dim is not None:
553
+ out = out + self.t_emb_layers[0](t_emb)[:, :, None, None]
554
+ out = self.resnet_conv_second[0](out)
555
+ out = out + self.residual_input_conv[0](resnet_input)
556
+
557
+ for i in range(self.num_layers):
558
+ # Attention Block
559
+ batch_size, channels, h, w = out.shape
560
+ in_attn = out.reshape(batch_size, channels, h * w)
561
+ in_attn = self.attention_norms[i](in_attn)
562
+ in_attn = in_attn.transpose(1, 2)
563
+ out_attn, _ = self.attentions[i](in_attn, in_attn, in_attn)
564
+ out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
565
+ out = out + out_attn
566
+
567
+ if self.cross_attn:
568
+ assert context is not None, "context cannot be None if cross attention layers are used"
569
+ batch_size, channels, h, w = out.shape
570
+ in_attn = out.reshape(batch_size, channels, h * w)
571
+ in_attn = self.cross_attention_norms[i](in_attn)
572
+ in_attn = in_attn.transpose(1, 2)
573
+ assert context.shape[0] == x.shape[0] and context.shape[-1] == self.context_dim
574
+ context_proj = self.context_proj[i](context)
575
+ out_attn, _ = self.cross_attentions[i](in_attn, context_proj, context_proj)
576
+ out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
577
+ out = out + out_attn
578
+
579
+
580
+ # Resnet Block
581
+ resnet_input = out
582
+ out = self.resnet_conv_first[i + 1](out)
583
+ if self.t_emb_dim is not None:
584
+ out = out + self.t_emb_layers[i + 1](t_emb)[:, :, None, None]
585
+ out = self.resnet_conv_second[i + 1](out)
586
+ out = out + self.residual_input_conv[i + 1](resnet_input)
587
+
588
+ return out
589
+
590
+
591
+ class UpBlock(nn.Module):
592
+ r"""
593
+ Up conv block with attention.
594
+ Sequence of the following blocks:
595
+ 1. Upsample
596
+ 2. Concatenate Down block output
597
+ 3. Resnet block with time embedding
598
+ 4. Attention Block
599
+ """
600
+
601
+ def __init__(self, in_channels, out_channels, t_emb_dim,
602
+ up_sample, num_heads, num_layers, attn, norm_channels):
603
+ super().__init__()
604
+ self.num_layers = num_layers
605
+ self.up_sample = up_sample
606
+ self.t_emb_dim = t_emb_dim
607
+ self.attn = attn
608
+ self.resnet_conv_first = nn.ModuleList(
609
+ [
610
+ nn.Sequential(
611
+ nn.GroupNorm(norm_channels, in_channels if i == 0 else out_channels),
612
+ nn.SiLU(),
613
+ nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=3, stride=1,
614
+ padding=1),
615
+ )
616
+ for i in range(num_layers)
617
+ ]
618
+ )
619
+
620
+ if self.t_emb_dim is not None:
621
+ self.t_emb_layers = nn.ModuleList([
622
+ nn.Sequential(
623
+ nn.SiLU(),
624
+ nn.Linear(t_emb_dim, out_channels)
625
+ )
626
+ for _ in range(num_layers)
627
+ ])
628
+
629
+ self.resnet_conv_second = nn.ModuleList(
630
+ [
631
+ nn.Sequential(
632
+ nn.GroupNorm(norm_channels, out_channels),
633
+ nn.SiLU(),
634
+ nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
635
+ )
636
+ for _ in range(num_layers)
637
+ ]
638
+ )
639
+ if self.attn:
640
+ self.attention_norms = nn.ModuleList(
641
+ [
642
+ nn.GroupNorm(norm_channels, out_channels)
643
+ for _ in range(num_layers)
644
+ ]
645
+ )
646
+
647
+ self.attentions = nn.ModuleList(
648
+ [
649
+ nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
650
+ for _ in range(num_layers)
651
+ ]
652
+ )
653
+
654
+ self.residual_input_conv = nn.ModuleList(
655
+ [
656
+ nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=1)
657
+ for i in range(num_layers)
658
+ ]
659
+ )
660
+ self.up_sample_conv = nn.ConvTranspose2d(in_channels, in_channels,
661
+ 4, 2, 1) \
662
+ if self.up_sample else nn.Identity()
663
+
664
+ def forward(self, x, out_down=None, t_emb=None):
665
+ # Upsample
666
+ x = self.up_sample_conv(x)
667
+
668
+ # Concat with Downblock output
669
+ if out_down is not None:
670
+ x = torch.cat([x, out_down], dim=1)
671
+
672
+ out = x
673
+ for i in range(self.num_layers):
674
+ # Resnet Block
675
+ resnet_input = out
676
+ out = self.resnet_conv_first[i](out)
677
+ if self.t_emb_dim is not None:
678
+ out = out + self.t_emb_layers[i](t_emb)[:, :, None, None]
679
+ out = self.resnet_conv_second[i](out)
680
+ out = out + self.residual_input_conv[i](resnet_input)
681
+
682
+ # Self Attention
683
+ if self.attn:
684
+ batch_size, channels, h, w = out.shape
685
+ in_attn = out.reshape(batch_size, channels, h * w)
686
+ in_attn = self.attention_norms[i](in_attn)
687
+ in_attn = in_attn.transpose(1, 2)
688
+ out_attn, _ = self.attentions[i](in_attn, in_attn, in_attn)
689
+ out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
690
+ out = out + out_attn
691
+ return out
692
+
693
+
694
+ class UpBlockUnet(nn.Module):
695
+ r"""
696
+ Up conv block with attention.
697
+ Sequence of the following blocks:
698
+ 1. Upsample
699
+ 2. Concatenate Down block output
700
+ 3. Resnet block with time embedding
701
+ 4. Self attention block (plus optional cross attention)
702
+ """
703
+
704
+ def __init__(self, in_channels, out_channels, t_emb_dim, up_sample,
705
+ num_heads, num_layers, norm_channels, cross_attn=False, context_dim=None):
706
+ super().__init__()
707
+ self.num_layers = num_layers
708
+ self.up_sample = up_sample
709
+ self.t_emb_dim = t_emb_dim
710
+ self.cross_attn = cross_attn
711
+ self.context_dim = context_dim
712
+ self.resnet_conv_first = nn.ModuleList(
713
+ [
714
+ nn.Sequential(
715
+ nn.GroupNorm(norm_channels, in_channels if i == 0 else out_channels),
716
+ nn.SiLU(),
717
+ nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=3, stride=1,
718
+ padding=1),
719
+ )
720
+ for i in range(num_layers)
721
+ ]
722
+ )
723
+
724
+ if self.t_emb_dim is not None:
725
+ self.t_emb_layers = nn.ModuleList([
726
+ nn.Sequential(
727
+ nn.SiLU(),
728
+ nn.Linear(t_emb_dim, out_channels)
729
+ )
730
+ for _ in range(num_layers)
731
+ ])
732
+
733
+ self.resnet_conv_second = nn.ModuleList(
734
+ [
735
+ nn.Sequential(
736
+ nn.GroupNorm(norm_channels, out_channels),
737
+ nn.SiLU(),
738
+ nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
739
+ )
740
+ for _ in range(num_layers)
741
+ ]
742
+ )
743
+
744
+ self.attention_norms = nn.ModuleList(
745
+ [
746
+ nn.GroupNorm(norm_channels, out_channels)
747
+ for _ in range(num_layers)
748
+ ]
749
+ )
750
+
751
+ self.attentions = nn.ModuleList(
752
+ [
753
+ nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
754
+ for _ in range(num_layers)
755
+ ]
756
+ )
757
+
758
+ if self.cross_attn:
759
+ assert context_dim is not None, "Context Dimension must be passed for cross attention"
760
+ self.cross_attention_norms = nn.ModuleList(
761
+ [nn.GroupNorm(norm_channels, out_channels)
762
+ for _ in range(num_layers)]
763
+ )
764
+ self.cross_attentions = nn.ModuleList(
765
+ [nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
766
+ for _ in range(num_layers)]
767
+ )
768
+ self.context_proj = nn.ModuleList(
769
+ [nn.Linear(context_dim, out_channels)
770
+ for _ in range(num_layers)]
771
+ )
772
+ self.residual_input_conv = nn.ModuleList(
773
+ [
774
+ nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=1)
775
+ for i in range(num_layers)
776
+ ]
777
+ )
778
+ self.up_sample_conv = nn.ConvTranspose2d(in_channels // 2, in_channels // 2,
779
+ 4, 2, 1) \
780
+ if self.up_sample else nn.Identity()
781
+
782
+ def forward(self, x, out_down=None, t_emb=None, context=None):
783
+ x = self.up_sample_conv(x)
784
+ if out_down is not None:
785
+ x = torch.cat([x, out_down], dim=1)
786
+
787
+ out = x
788
+ for i in range(self.num_layers):
789
+ # Resnet
790
+ resnet_input = out
791
+ out = self.resnet_conv_first[i](out)
792
+ if self.t_emb_dim is not None:
793
+ out = out + self.t_emb_layers[i](t_emb)[:, :, None, None]
794
+ out = self.resnet_conv_second[i](out)
795
+ out = out + self.residual_input_conv[i](resnet_input)
796
+ # Self Attention
797
+ batch_size, channels, h, w = out.shape
798
+ in_attn = out.reshape(batch_size, channels, h * w)
799
+ in_attn = self.attention_norms[i](in_attn)
800
+ in_attn = in_attn.transpose(1, 2)
801
+ out_attn, _ = self.attentions[i](in_attn, in_attn, in_attn)
802
+ out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
803
+ out = out + out_attn
804
+ # Cross Attention
805
+ if self.cross_attn:
806
+ assert context is not None, "context cannot be None if cross attention layers are used"
807
+ batch_size, channels, h, w = out.shape
808
+ in_attn = out.reshape(batch_size, channels, h * w)
809
+ in_attn = self.cross_attention_norms[i](in_attn)
810
+ in_attn = in_attn.transpose(1, 2)
811
+ assert len(context.shape) == 3, \
812
+ "Context shape does not match B,_,CONTEXT_DIM"
813
+ assert context.shape[0] == x.shape[0] and context.shape[-1] == self.context_dim,\
814
+ "Context shape does not match B,_,CONTEXT_DIM"
815
+ context_proj = self.context_proj[i](context)
816
+ out_attn, _ = self.cross_attentions[i](in_attn, context_proj, context_proj)
817
+ out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
818
+ out = out + out_attn
819
+
820
+ return out
821
+
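+ # Channel bookkeeping sketch for UpBlockUnet (hypothetical sizes): `in_channels` is the
+ # count *after* concatenating the skip connection, so the transposed conv upsamples
+ # in_channels // 2 and the Down-block output supplies the other half.
+ # up = UpBlockUnet(in_channels=256, out_channels=64, t_emb_dim=256, up_sample=True,
+ #                  num_heads=4, num_layers=1, norm_channels=32)
+ # x = torch.randn(2, 128, 8, 8)                  # half of in_channels
+ # skip = torch.randn(2, 128, 16, 16)             # matches the upsampled spatial size
+ # t_emb = torch.randn(2, 256)
+ # print(up(x, skip, t_emb).shape)                # torch.Size([2, 64, 16, 16])
+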
822
+ """Vqvae"""
823
+
824
+ import torch
825
+ import torch.nn as nn
826
+
827
+
828
+ class VQVAE(nn.Module):
829
+ def __init__(self, im_channels, model_config):
830
+ super().__init__()
831
+ self.down_channels = model_config.down_channels
832
+ self.mid_channels = model_config.mid_channels
833
+ self.down_sample = model_config.down_sample
834
+ self.num_down_layers = model_config.num_down_layers
835
+ self.num_mid_layers = model_config.num_mid_layers
836
+ self.num_up_layers = model_config.num_up_layers
837
+
838
+ # To disable attention in Downblock of Encoder and Upblock of Decoder
839
+ self.attns = model_config.attn_down
840
+
841
+ # Latent Dimension
842
+ self.z_channels = model_config.z_channels
843
+ self.codebook_size = model_config.codebook_size
844
+ self.norm_channels = model_config.norm_channels
845
+ self.num_heads = model_config.num_heads
846
+
847
+ # Assertion to validate the channel information
848
+ assert self.mid_channels[0] == self.down_channels[-1]
849
+ assert self.mid_channels[-1] == self.down_channels[-1]
850
+ assert len(self.down_sample) == len(self.down_channels) - 1
851
+ assert len(self.attns) == len(self.down_channels) - 1
852
+
853
+ # Wherever we use downsampling in encoder correspondingly use
854
+ # upsampling in decoder
855
+ self.up_sample = list(reversed(self.down_sample))
856
+
857
+ ##################### Encoder ######################
858
+ self.encoder_conv_in = nn.Conv2d(im_channels, self.down_channels[0], kernel_size=3, padding=(1, 1))
859
+
860
+ # Downblock + Midblock
861
+ self.encoder_layers = nn.ModuleList([])
862
+ for i in range(len(self.down_channels) - 1):
863
+ self.encoder_layers.append(DownBlock(self.down_channels[i], self.down_channels[i + 1],
864
+ t_emb_dim=None, down_sample=self.down_sample[i],
865
+ num_heads=self.num_heads,
866
+ num_layers=self.num_down_layers,
867
+ attn=self.attns[i],
868
+ norm_channels=self.norm_channels))
869
+
870
+ self.encoder_mids = nn.ModuleList([])
871
+ for i in range(len(self.mid_channels) - 1):
872
+ self.encoder_mids.append(MidBlock(self.mid_channels[i], self.mid_channels[i + 1],
873
+ t_emb_dim=None,
874
+ num_heads=self.num_heads,
875
+ num_layers=self.num_mid_layers,
876
+ norm_channels=self.norm_channels))
877
+
878
+ self.encoder_norm_out = nn.GroupNorm(self.norm_channels, self.down_channels[-1])
879
+ self.encoder_conv_out = nn.Conv2d(self.down_channels[-1], self.z_channels, kernel_size=3, padding=1)
880
+
881
+ # Pre Quantization Convolution
882
+ self.pre_quant_conv = nn.Conv2d(self.z_channels, self.z_channels, kernel_size=1)
883
+
884
+ # Codebook
885
+ self.embedding = nn.Embedding(self.codebook_size, self.z_channels)
886
+ ####################################################
887
+
888
+ ##################### Decoder ######################
889
+
890
+ # Post Quantization Convolution
891
+ self.post_quant_conv = nn.Conv2d(self.z_channels, self.z_channels, kernel_size=1)
892
+ self.decoder_conv_in = nn.Conv2d(self.z_channels, self.mid_channels[-1], kernel_size=3, padding=(1, 1))
893
+
894
+ # Midblock + Upblock
895
+ self.decoder_mids = nn.ModuleList([])
896
+ for i in reversed(range(1, len(self.mid_channels))):
897
+ self.decoder_mids.append(MidBlock(self.mid_channels[i], self.mid_channels[i - 1],
898
+ t_emb_dim=None,
899
+ num_heads=self.num_heads,
900
+ num_layers=self.num_mid_layers,
901
+ norm_channels=self.norm_channels))
902
+
903
+ self.decoder_layers = nn.ModuleList([])
904
+ for i in reversed(range(1, len(self.down_channels))):
905
+ self.decoder_layers.append(UpBlock(self.down_channels[i], self.down_channels[i - 1],
906
+ t_emb_dim=None, up_sample=self.down_sample[i - 1],
907
+ num_heads=self.num_heads,
908
+ num_layers=self.num_up_layers,
909
+ attn=self.attns[i-1],
910
+ norm_channels=self.norm_channels))
911
+
912
+ self.decoder_norm_out = nn.GroupNorm(self.norm_channels, self.down_channels[0])
913
+ self.decoder_conv_out = nn.Conv2d(self.down_channels[0], im_channels, kernel_size=3, padding=1)
914
+
915
+ def quantize(self, x):
916
+ B, C, H, W = x.shape
917
+
918
+ # B, C, H, W -> B, H, W, C
919
+ x = x.permute(0, 2, 3, 1)
920
+
921
+ # B, H, W, C -> B, H*W, C
922
+ x = x.reshape(x.size(0), -1, x.size(-1))
923
+
924
+ # Find nearest embedding/codebook vector
925
+ # dist between (B, H*W, C) and (B, K, C) -> (B, H*W, K)
926
+ dist = torch.cdist(x, self.embedding.weight[None, :].repeat((x.size(0), 1, 1)))
927
+ # (B, H*W)
928
+ min_encoding_indices = torch.argmin(dist, dim=-1)
929
+
930
+ # Replace encoder output with nearest codebook
931
+ # quant_out -> B*H*W, C
932
+ quant_out = torch.index_select(self.embedding.weight, 0, min_encoding_indices.view(-1))
933
+
934
+ # x -> B*H*W, C
935
+ x = x.reshape((-1, x.size(-1)))
936
+ commitment_loss = torch.mean((quant_out.detach() - x) ** 2)
937
+ codebook_loss = torch.mean((quant_out - x.detach()) ** 2)
938
+ quantize_losses = {
939
+ 'codebook_loss': codebook_loss,
940
+ 'commitment_loss': commitment_loss
941
+ }
942
+ # Straight through estimation
943
+ quant_out = x + (quant_out - x).detach()
944
+
945
+ # quant_out -> B, C, H, W
946
+ quant_out = quant_out.reshape((B, H, W, C)).permute(0, 3, 1, 2)
947
+ min_encoding_indices = min_encoding_indices.reshape((-1, quant_out.size(-2), quant_out.size(-1)))
948
+ return quant_out, quantize_losses, min_encoding_indices
949
+
950
+ def encode(self, x):
951
+ out = self.encoder_conv_in(x)
952
+ for idx, down in enumerate(self.encoder_layers):
953
+ out = down(out)
954
+ for mid in self.encoder_mids:
955
+ out = mid(out)
956
+ out = self.encoder_norm_out(out)
957
+ out = nn.SiLU()(out)
958
+ out = self.encoder_conv_out(out)
959
+ out = self.pre_quant_conv(out)
960
+ out, quant_losses, _ = self.quantize(out)
961
+ return out, quant_losses
962
+
963
+ def decode(self, z):
964
+ out = z
965
+ out = self.post_quant_conv(out)
966
+ out = self.decoder_conv_in(out)
967
+ for mid in self.decoder_mids:
968
+ out = mid(out)
969
+ for idx, up in enumerate(self.decoder_layers):
970
+ out = up(out)
971
+
972
+ out = self.decoder_norm_out(out)
973
+ out = nn.SiLU()(out)
974
+ out = self.decoder_conv_out(out)
975
+ return out
976
+
977
+ def forward(self, x):
978
+ z, quant_losses = self.encode(x)
979
+ out = self.decode(z)
980
+ return out, z, quant_losses
981
+
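+ # Round-trip sketch with a small, hypothetical config (any dot-access namespace works,
+ # e.g. types.SimpleNamespace); the real hyperparameters come from the YAML further below.
+ # from types import SimpleNamespace
+ # cfg = SimpleNamespace(down_channels=[32, 64, 128], mid_channels=[128, 128],
+ #                       down_sample=[True, True], attn_down=[False, False],
+ #                       num_down_layers=1, num_mid_layers=1, num_up_layers=1,
+ #                       z_channels=4, codebook_size=512, norm_channels=16, num_heads=4)
+ # vq = VQVAE(im_channels=3, model_config=cfg)
+ # recon, z, losses = vq(torch.randn(1, 3, 64, 64))
+ # print(recon.shape, z.shape)                    # (1, 3, 64, 64) and (1, 4, 16, 16)
+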
982
+ """Vae"""
983
+
984
+ import torch
985
+ import torch.nn as nn
986
+
987
+
988
+ class VAE(nn.Module):
989
+ def __init__(self, im_channels, model_config):
990
+ super().__init__()
991
+ self.down_channels = model_config['down_channels']
992
+ self.mid_channels = model_config['mid_channels']
993
+ self.down_sample = model_config['down_sample']
994
+ self.num_down_layers = model_config['num_down_layers']
995
+ self.num_mid_layers = model_config['num_mid_layers']
996
+ self.num_up_layers = model_config['num_up_layers']
997
+
998
+ # To disable attention in Downblock of Encoder and Upblock of Decoder
999
+ self.attns = model_config['attn_down']
1000
+
1001
+ # Latent Dimension
1002
+ self.z_channels = model_config['z_channels']
1003
+ self.norm_channels = model_config['norm_channels']
1004
+ self.num_heads = model_config['num_heads']
1005
+
1006
+ # Assertion to validate the channel information
1007
+ assert self.mid_channels[0] == self.down_channels[-1]
1008
+ assert self.mid_channels[-1] == self.down_channels[-1]
1009
+ assert len(self.down_sample) == len(self.down_channels) - 1
1010
+ assert len(self.attns) == len(self.down_channels) - 1
1011
+
1012
+ # Wherever we use downsampling in encoder correspondingly use
1013
+ # upsampling in decoder
1014
+ self.up_sample = list(reversed(self.down_sample))
1015
+
1016
+ ##################### Encoder ######################
1017
+ self.encoder_conv_in = nn.Conv2d(im_channels, self.down_channels[0], kernel_size=3, padding=(1, 1))
1018
+
1019
+ # Downblock + Midblock
1020
+ self.encoder_layers = nn.ModuleList([])
1021
+ for i in range(len(self.down_channels) - 1):
1022
+ self.encoder_layers.append(DownBlock(self.down_channels[i], self.down_channels[i + 1],
1023
+ t_emb_dim=None, down_sample=self.down_sample[i],
1024
+ num_heads=self.num_heads,
1025
+ num_layers=self.num_down_layers,
1026
+ attn=self.attns[i],
1027
+ norm_channels=self.norm_channels))
1028
+
1029
+ self.encoder_mids = nn.ModuleList([])
1030
+ for i in range(len(self.mid_channels) - 1):
1031
+ self.encoder_mids.append(MidBlock(self.mid_channels[i], self.mid_channels[i + 1],
1032
+ t_emb_dim=None,
1033
+ num_heads=self.num_heads,
1034
+ num_layers=self.num_mid_layers,
1035
+ norm_channels=self.norm_channels))
1036
+
1037
+ self.encoder_norm_out = nn.GroupNorm(self.norm_channels, self.down_channels[-1])
1038
+ self.encoder_conv_out = nn.Conv2d(self.down_channels[-1], 2*self.z_channels, kernel_size=3, padding=1)
1039
+
1040
+ # Output channels are 2*z_channels because we predict both mean and log-variance
1041
+ self.pre_quant_conv = nn.Conv2d(2*self.z_channels, 2*self.z_channels, kernel_size=1)
1042
+ ####################################################
1043
+
1044
+
1045
+ ##################### Decoder ######################
1046
+ self.post_quant_conv = nn.Conv2d(self.z_channels, self.z_channels, kernel_size=1)
1047
+ self.decoder_conv_in = nn.Conv2d(self.z_channels, self.mid_channels[-1], kernel_size=3, padding=(1, 1))
1048
+
1049
+ # Midblock + Upblock
1050
+ self.decoder_mids = nn.ModuleList([])
1051
+ for i in reversed(range(1, len(self.mid_channels))):
1052
+ self.decoder_mids.append(MidBlock(self.mid_channels[i], self.mid_channels[i - 1],
1053
+ t_emb_dim=None,
1054
+ num_heads=self.num_heads,
1055
+ num_layers=self.num_mid_layers,
1056
+ norm_channels=self.norm_channels))
1057
+
1058
+ self.decoder_layers = nn.ModuleList([])
1059
+ for i in reversed(range(1, len(self.down_channels))):
1060
+ self.decoder_layers.append(UpBlock(self.down_channels[i], self.down_channels[i - 1],
1061
+ t_emb_dim=None, up_sample=self.down_sample[i - 1],
1062
+ num_heads=self.num_heads,
1063
+ num_layers=self.num_up_layers,
1064
+ attn=self.attns[i - 1],
1065
+ norm_channels=self.norm_channels))
1066
+
1067
+ self.decoder_norm_out = nn.GroupNorm(self.norm_channels, self.down_channels[0])
1068
+ self.decoder_conv_out = nn.Conv2d(self.down_channels[0], im_channels, kernel_size=3, padding=1)
1069
+
1070
+ def encode(self, x):
1071
+ out = self.encoder_conv_in(x)
1072
+ for idx, down in enumerate(self.encoder_layers):
1073
+ out = down(out)
1074
+ for mid in self.encoder_mids:
1075
+ out = mid(out)
1076
+ out = self.encoder_norm_out(out)
1077
+ out = nn.SiLU()(out)
1078
+ out = self.encoder_conv_out(out)
1079
+ out = self.pre_quant_conv(out)
1080
+ mean, logvar = torch.chunk(out, 2, dim=1)
1081
+ std = torch.exp(0.5 * logvar)
1082
+ sample = mean + std * torch.randn(mean.shape).to(device=x.device)
1083
+ return sample, out
1084
+
1085
+ def decode(self, z):
1086
+ out = z
1087
+ out = self.post_quant_conv(out)
1088
+ out = self.decoder_conv_in(out)
1089
+ for mid in self.decoder_mids:
1090
+ out = mid(out)
1091
+ for idx, up in enumerate(self.decoder_layers):
1092
+ out = up(out)
1093
+
1094
+ out = self.decoder_norm_out(out)
1095
+ out = nn.SiLU()(out)
1096
+ out = self.decoder_conv_out(out)
1097
+ return out
1098
+
1099
+ def forward(self, x):
1100
+ z, encoder_output = self.encode(x)
1101
+ out = self.decode(z)
1102
+ return out, encoder_output
1103
+
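+ # The encoder output packs mean and log-variance along the channel dim; a standard
+ # KL-to-N(0, 1) regulariser can be computed from it as below (illustrative sketch,
+ # assuming a VAE instance `vae` and an input batch `x`; the weighting is up to the trainer).
+ # recon, enc_out = vae(x)
+ # mean, logvar = torch.chunk(enc_out, 2, dim=1)
+ # kl = 0.5 * torch.sum(torch.exp(logvar) + mean ** 2 - 1 - logvar, dim=[1, 2, 3]).mean()
+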
1104
+ """Discriminator"""
1105
+
1106
+ import torch
1107
+ import torch.nn as nn
1108
+
1109
+
1110
+ class Discriminator(nn.Module):
1111
+ r"""
1112
+ PatchGAN Discriminator.
1113
+ Rather than mapping an IMG_CHANNELS x IMG_H x IMG_W image all the way
+ down to a single scalar value, we predict a grid of values,
+ where each cell is the discriminator's estimate of how likely
+ the corresponding image patch is to be real.
1118
+ """
1119
+
1120
+ def __init__(self, im_channels=3,
1121
+ conv_channels=[64, 128, 256],
1122
+ kernels=[4,4,4,4],
1123
+ strides=[2,2,2,1],
1124
+ paddings=[1,1,1,1]):
1125
+ super().__init__()
1126
+ self.im_channels = im_channels
1127
+ activation = nn.LeakyReLU(0.2)
1128
+ layers_dim = [self.im_channels] + conv_channels + [1]
1129
+ self.layers = nn.ModuleList([
1130
+ nn.Sequential(
1131
+ nn.Conv2d(layers_dim[i], layers_dim[i + 1],
1132
+ kernel_size=kernels[i],
1133
+ stride=strides[i],
1134
+ padding=paddings[i],
1135
+ bias=False if i !=0 else True),
1136
+ nn.BatchNorm2d(layers_dim[i + 1]) if i != len(layers_dim) - 2 and i != 0 else nn.Identity(),
1137
+ activation if i != len(layers_dim) - 2 else nn.Identity()
1138
+ )
1139
+ for i in range(len(layers_dim) - 1)
1140
+ ])
1141
+
1142
+ def forward(self, x):
1143
+ out = x
1144
+ for layer in self.layers:
1145
+ out = layer(out)
1146
+ return out
1147
+
1148
+
1149
+ # if __name__ == '__main__':
1150
+ # x = torch.randn((2,3, 256, 256))
1151
+ # prob = Discriminator(im_channels=3)(x)
1152
+ # print(prob.shape)
1153
+
1154
+ # import os
1155
+
1156
+ # image_paths = [os.path.join("/home/taruntejaneurips23/Ashish/datasets/animefacedata/images", f)
1157
+ # for f in os.listdir("/home/taruntejaneurips23/Ashish/datasets/animefacedata/images")]
1158
+ # image_paths
1159
+
1160
+ import glob
1161
+ import os
1162
+ import torchvision
1163
+ from PIL import Image
1164
+ from tqdm import tqdm, trange
1165
+ # from utils.diffusion_utils import load_latents
1166
+ from torch.utils.data.dataset import Dataset
1167
+
1168
+ import pickle
1169
+ import glob
1170
+ import os
1171
+ import torch
1172
+
1173
+
1174
+ def load_latents(latent_path):
1175
+ r"""
1176
+ Simple utility to load pre-saved latents to speed up LDM training
1177
+ :param latent_path:
1178
+ :return:
1179
+ """
1180
+ latent_maps = {}
1181
+ for fname in glob.glob(os.path.join(latent_path, '*.pkl')):
1182
+ s = pickle.load(open(fname, 'rb'))
1183
+ for k, v in s.items():
1184
+ latent_maps[k] = v[0]
1185
+ return latent_maps
1186
+
1187
+
1188
+ def drop_text_condition(text_embed, im, empty_text_embed, text_drop_prob):
1189
+ if text_drop_prob > 0:
1190
+ text_drop_mask = torch.zeros((im.shape[0]), device=im.device).float().uniform_(0,
1191
+ 1) < text_drop_prob
1192
+ assert empty_text_embed is not None, ("Text Conditioning required as well as"
1193
+ " text dropping but empty text representation not created")
1194
+ text_embed[text_drop_mask, :, :] = empty_text_embed[0]
1195
+ return text_embed
1196
+
1197
+
1198
+ def drop_image_condition(image_condition, im, im_drop_prob):
1199
+ if im_drop_prob > 0:
1200
+ im_drop_mask = torch.zeros((im.shape[0], 1, 1, 1), device=im.device).float().uniform_(0,
1201
+ 1) > im_drop_prob
1202
+ return image_condition * im_drop_mask
1203
+ else:
1204
+ return image_condition
1205
+
1206
+
1207
+ def drop_class_condition(class_condition, class_drop_prob, im):
1208
+ if class_drop_prob > 0:
1209
+ class_drop_mask = torch.zeros((im.shape[0], 1), device=im.device).float().uniform_(0,
1210
+ 1) > class_drop_prob
1211
+ return class_condition * class_drop_mask
1212
+ else:
1213
+ return class_condition
1214
+
1215
+
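+ # Classifier-free-guidance style dropping (illustrative): with class_drop_prob=0.1 roughly
+ # 10% of rows in a one-hot class condition are zeroed, so the model also learns the
+ # unconditional distribution.
+ # im = torch.randn(8, 3, 32, 32)
+ # class_cond = torch.nn.functional.one_hot(torch.randint(0, 10, (8,)), 10).float()
+ # dropped = drop_class_condition(class_cond, class_drop_prob=0.1, im=im)
+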
1216
+ class MnistDataset(Dataset):
1217
+ r"""
1218
+ Nothing special here. Just a simple dataset class for mnist images.
1219
+ Created a dataset class rather than using torchvision to allow
1220
+ replacement with any other image dataset
1221
+ """
1222
+
1223
+ def __init__(self, split, im_path, im_size, im_channels,
1224
+ use_latents=False, latent_path=None, condition_config=None):
1225
+ r"""
1226
+ Init method for initializing the dataset properties
1227
+ :param split: train/test to locate the image files
1228
+ :param im_path: root folder of images
1229
+ :param im_size: size of the images
+ :param im_channels: number of image channels
1231
+ """
1232
+ self.split = split
1233
+ self.im_size = im_size
1234
+ self.im_channels = im_channels
1235
+
1236
+ # Should we use latents or not
1237
+ self.latent_maps = None
1238
+ self.use_latents = False
1239
+
1240
+ # Conditioning for the dataset
1241
+ self.condition_types = [] if condition_config is None else condition_config['condition_types']
1242
+
1243
+ self.images, self.labels = self.load_images(im_path)
1244
+
1245
+ # Whether to load images and call vae or to load latents
1246
+ if use_latents and latent_path is not None:
1247
+ latent_maps = load_latents(latent_path)
1248
+ if len(latent_maps) == len(self.images):
1249
+ self.use_latents = True
1250
+ self.latent_maps = latent_maps
1251
+ print('Found {} latents'.format(len(self.latent_maps)))
1252
+ else:
1253
+ print('Latents not found')
1254
+
1255
+ def load_images(self, im_path):
1256
+ r"""
1257
+ Gets all images from the path specified
1258
+ and stacks them all up
1259
+ :param im_path:
1260
+ :return:
1261
+ """
1262
+ assert os.path.exists(im_path), "images path {} does not exist".format(im_path)
1263
+ ims = []
1264
+ labels = []
1265
+ for d_name in tqdm(os.listdir(im_path)):
1266
+ fnames = glob.glob(os.path.join(im_path, d_name, '*.{}'.format('png')))
1267
+ fnames += glob.glob(os.path.join(im_path, d_name, '*.{}'.format('jpg')))
1268
+ fnames += glob.glob(os.path.join(im_path, d_name, '*.{}'.format('jpeg')))
1269
+ for fname in fnames:
1270
+ ims.append(fname)
1271
+ if 'class' in self.condition_types:
1272
+ labels.append(int(d_name))
1273
+ print('Found {} images for split {}'.format(len(ims), self.split))
1274
+ return ims, labels
1275
+
1276
+ def __len__(self):
1277
+ return len(self.images)
1278
+
1279
+ def __getitem__(self, index):
1280
+ ######## Set Conditioning Info ########
1281
+ cond_inputs = {}
1282
+ if 'class' in self.condition_types:
1283
+ cond_inputs['class'] = self.labels[index]
1284
+ #######################################
1285
+
1286
+ if self.use_latents:
1287
+ latent = self.latent_maps[self.images[index]]
1288
+ if len(self.condition_types) == 0:
1289
+ return latent
1290
+ else:
1291
+ return latent, cond_inputs
1292
+ else:
1293
+ im = Image.open(self.images[index])
1294
+ im_tensor = torchvision.transforms.ToTensor()(im)
1295
+
1296
+ # Convert input to -1 to 1 range.
1297
+ im_tensor = (2 * im_tensor) - 1
1298
+ if len(self.condition_types) == 0:
1299
+ return im_tensor
1300
+ else:
1301
+ return im_tensor, cond_inputs
1302
+
1303
+
1304
+ class AnimeFaceDataset(Dataset):
1305
+ def __init__(self, split, im_path, im_size, im_channels,
1306
+ use_latents=False, latent_path=None, condition_config=None):
1307
+
1308
+ self.split = split
1309
+ self.im_size = im_size
1310
+ self.im_channels = im_channels
1311
+
1312
+ # Should we use latents or not
1313
+ self.latent_maps = None
1314
+ self.use_latents = False
1315
+
1316
+ # Conditioning for the dataset
1317
+ self.condition_types = [] if condition_config is None else condition_config['condition_types']
1318
+
1319
+ self.images = self.load_images(im_path)
1320
+
1321
+ # Whether to load images and call vae or to load latents
1322
+ if use_latents and latent_path is not None:
1323
+ latent_maps = load_latents(latent_path)
1324
+ if len(latent_maps) == len(self.images):
1325
+ self.use_latents = True
1326
+ self.latent_maps = latent_maps
1327
+ print('Found {} latents'.format(len(self.latent_maps)))
1328
+ else:
1329
+ print('Latents not found')
1330
+
1331
+ def load_images(self, im_path):
1332
+ r"""
1333
+ Gets all images from the path specified
1334
+ and stacks them all up
1335
+ :param im_path:
1336
+ :return:
1337
+ """
1338
+ assert os.path.exists(im_path), "images path {} does not exist".format(im_path)
1339
+ # ims = []
1340
+ # labels = []
1341
+ ims = [os.path.join(im_path, f) for f in os.listdir(im_path)]
1342
+ return ims
1343
+
1344
+ def __len__(self):
1345
+ return len(self.images)
1346
+
1347
+ def __getitem__(self, index):
1348
+ ######## Set Conditioning Info ########
1349
+ # cond_inputs = {}
1350
+ # if 'class' in self.condition_types:
1351
+ # cond_inputs['class'] = self.labels[index]
1352
+ #######################################
1353
+
1354
+ if self.use_latents:
1355
+ latent = self.latent_maps[self.images[index]]
1356
+ if len(self.condition_types) == 0:
1357
+ return latent
1358
+ # else:
1359
+ # return latent, cond_inputs
1360
+ else:
1361
+ im = Image.open(self.images[index])
1362
+ im_tensor = torchvision.transforms.Compose([
1363
+ torchvision.transforms.Resize(self.im_size),
1364
+ torchvision.transforms.CenterCrop(self.im_size),
1365
+ torchvision.transforms.ToTensor(),
1366
+ ])(im)
1367
+ im.close()
1368
+ # im_tensor = torchvision.transforms.ToTensor()(im)
1369
+
1370
+ # Convert input to -1 to 1 range.
1371
+ im_tensor = (2 * im_tensor) - 1
1372
+ if len(self.condition_types) == 0:
1373
+ return im_tensor
1374
+ # else:
1375
+ # return im_tensor, cond_inputs
1376
+
1377
+
1378
+ import glob
1379
+ import os
1380
+ import random
1381
+ import torch
1382
+ import torchvision
1383
+ import numpy as np
1384
+ from PIL import Image
1385
+ from tqdm import tqdm
1386
+ from torch.utils.data.dataset import Dataset
1387
+
1388
+
1389
+ class CelebDataset(Dataset):
1390
+ def __init__(self, split, im_path, im_size, im_channels,
1391
+ use_latents=False, latent_path=None, condition_config=None):
1392
+
1393
+ self.split = split
1394
+ self.im_size = im_size
1395
+ self.im_channels = im_channels
1396
+
1397
+ # Should we use latents or not
1398
+ self.latent_maps = None
1399
+ self.use_latents = False
1400
+
1401
+ # Conditioning for the dataset
1402
+ self.condition_types = [] if condition_config is None else condition_config['condition_types']
1403
+
1404
+ self.images = self.load_images(im_path)
1405
+
1406
+ # Whether to load images and call vae or to load latents
1407
+ if use_latents and latent_path is not None:
1408
+ latent_maps = load_latents(latent_path)
1409
+ if len(latent_maps) == len(self.images):
1410
+ self.use_latents = True
1411
+ self.latent_maps = latent_maps
1412
+ print('Found {} latents'.format(len(self.latent_maps)))
1413
+ else:
1414
+ print('Latents not found')
1415
+
1416
+ def load_images(self, im_path):
1417
+ r"""
1418
+ Gets all images from the path specified
1419
+ and stacks them all up
1420
+ :param im_path:
1421
+ :return:
1422
+ """
1423
+ assert os.path.exists(im_path), "images path {} does not exist".format(im_path)
1424
+ # ims = []
1425
+ # labels = []
1426
+ ims = [os.path.join(im_path, f) for f in os.listdir(im_path)]
1427
+ return ims
1428
+
1429
+ def __len__(self):
1430
+ return len(self.images)
1431
+
1432
+ def __getitem__(self, index):
1433
+ ######## Set Conditioning Info ########
1434
+ # cond_inputs = {}
1435
+ # if 'class' in self.condition_types:
1436
+ # cond_inputs['class'] = self.labels[index]
1437
+ #######################################
1438
+
1439
+ if self.use_latents:
1440
+ latent = self.latent_maps[self.images[index]]
1441
+ if len(self.condition_types) == 0:
1442
+ return latent
1443
+ # else:
1444
+ # return latent, cond_inputs
1445
+ else:
1446
+ im = Image.open(self.images[index])
1447
+ im_tensor = torchvision.transforms.Compose([
1448
+ # torchvision.transforms.Resize(self.im_size),
1449
+ torchvision.transforms.CenterCrop(self.im_size),
1450
+ torchvision.transforms.ToTensor(),
1451
+ ])(im)
1452
+ im.close()
1453
+ # im_tensor = torchvision.transforms.ToTensor()(im)
1454
+
1455
+ # Convert input to -1 to 1 range.
1456
+ im_tensor = (2 * im_tensor) - 1
1457
+ if len(self.condition_types) == 0:
1458
+ return im_tensor
1459
+ # else:
1460
+ # return im_tensor, cond_inputs
1461
+ import pandas as pd
1462
+ class CelebHairDataset(Dataset):
1463
+ def __init__(self, split, im_path, im_size, im_channels,
1464
+ use_latents=False, latent_path=None, condition_config=None):
1465
+
1466
+ self.df = pd.read_csv("/home/taruntejaneurips23/Ashish/DDPM/hair_df_100.csv")
1467
+ self.split = split
1468
+ self.im_size = im_size
1469
+ self.im_channels = im_channels
1470
+
1471
+ # Should we use latents or not
1472
+ self.latent_maps = None
1473
+ self.use_latents = False
1474
+
1475
+ # Conditioning for the dataset
1476
+ self.condition_types = [] if condition_config is None else condition_config['condition_types']
1477
+
1478
+ self.images = self.load_images(im_path, self.df)
1479
+
1480
+ # Whether to load images and call vae or to load latents
1481
+ if use_latents and latent_path is not None:
1482
+ latent_maps = load_latents(latent_path)
1483
+ if len(latent_maps) == len(self.images):
1484
+ self.use_latents = True
1485
+ self.latent_maps = latent_maps
1486
+ print('Found {} latents'.format(len(self.latent_maps)))
1487
+ else:
1488
+ print('Latents not found')
1489
+
1490
+ def load_images(self, im_path, df):
1491
+ r"""
1492
+ Gets all images from the path specified
1493
+ and stacks them all up
1494
+ :param im_path:
1495
+ :return:
1496
+ """
1497
+ assert os.path.exists(im_path), "images path {} does not exist".format(im_path)
1498
+ # ims = []
1499
+ # labels = []
1500
+ # ims = [os.path.join(im_path, f) for f in os.listdir(im_path)]
1501
+ ims = [os.path.join(im_path, i) for i in df.image_id.values]
1502
+ return ims
1503
+
1504
+ def __len__(self):
1505
+ return len(self.images)
1506
+
1507
+ def __getitem__(self, index):
1508
+ ######## Set Conditioning Info ########
1509
+ # cond_inputs = {}
1510
+ # if 'class' in self.condition_types:
1511
+ # cond_inputs['class'] = self.labels[index]
1512
+ #######################################
1513
+
1514
+ if self.use_latents:
1515
+ latent = self.latent_maps[self.images[index]]
1516
+ if len(self.condition_types) == 0:
1517
+ return latent
1518
+ # else:
1519
+ # return latent, cond_inputs
1520
+ else:
1521
+ im = Image.open(self.images[index])
1522
+ im_tensor = torchvision.transforms.Compose([
1523
+ # torchvision.transforms.Resize(self.im_size),
1524
+ torchvision.transforms.CenterCrop(self.im_size),
1525
+ torchvision.transforms.ToTensor(),
1526
+ ])(im)
1527
+ im.close()
1528
+ # im_tensor = torchvision.transforms.ToTensor()(im)
1529
+
1530
+ # Convert input to -1 to 1 range.
1531
+ im_tensor = (2 * im_tensor) - 1
1532
+ if len(self.condition_types) == 0:
1533
+ return im_tensor
1534
+ # else:
1535
+ # return im_tensor, cond_inputs
1536
+
1537
+ #"""Train VQVAE"""...............................................................................................................................................
1538
+
1539
+ # Commented out IPython magic to ensure Python compatibility.
1540
+ import torch
1541
+ import torch.nn as nn
1542
+ import yaml
1543
+ from ashish.MTP.Vaani.LDM.scripts.dotdict import DotDict
1544
+
1545
+ config_path = "/home/taruntejaneurips23/Ashish/DDPM/_5_ldm_celeba.yaml"
1546
+ with open(config_path, 'r') as file:
1547
+ Config = yaml.safe_load(file)
1548
+
1549
+
1550
+ Config = DotDict.from_dict(Config)
1551
+ dataset_config = Config.dataset_params
1552
+ diffusion_config = Config.diffusion_params
1553
+ model_config = Config.model_params
1554
+ train_config = Config.train_params
1555
+
1556
+ import torch
1557
+ import os
1558
+ import random
1559
+ import numpy as np
1560
+ import matplotlib.pyplot as plt
1561
+ from tqdm import tqdm
1562
+ from torch.optim import Adam
1563
+ from torch.utils.data import Dataset, TensorDataset, DataLoader
1564
+ # device = 'cuda:1' if torch.cuda.is_available() else 'cpu'
1565
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
1566
+
1567
+
1568
+
1569
+ from torchvision.utils import make_grid
1570
+
1571
+ def trainVAE(Config):
1572
+
1573
+ dataset_config = Config.dataset_params
1574
+ autoencoder_config = Config.autoencoder_params
1575
+ train_config = Config.train_params
1576
+
1577
+ # Set the desired seed value #
1578
+ seed = train_config.seed
1579
+ torch.manual_seed(seed)
1580
+ np.random.seed(seed)
1581
+ random.seed(seed)
1582
+ if device == 'cuda':
1583
+ torch.cuda.manual_seed_all(seed)
1584
+ #############################
1585
+
1586
+ # Create the model and dataset #
1587
+ model = VQVAE(im_channels=dataset_config.im_channels,
1588
+ model_config=autoencoder_config).to(device)
1589
+ # model.load_state_dict(torch.load("/home/taruntejaneurips23/Ashish/DDPM/celebAhair_ldm/vqvae_autoencoder_ckpt.pth", map_location=device))
1590
+ if os.path.exists(os.path.join(train_config.task_name, train_config.vqvae_autoencoder_ckpt_name)):
1591
+ print('Loaded vae checkpoint')
1592
+ model.load_state_dict(torch.load(os.path.join(train_config.task_name, train_config.vqvae_autoencoder_ckpt_name),
1593
+ map_location=device, weights_only=True))
1594
+
1595
+ # Create the dataset
1596
+ im_dataset_cls = {
1597
+ 'mnist': MnistDataset,
1598
+ 'celebA': CelebDataset,
1599
+ 'animeface': AnimeFaceDataset,
1600
+ 'celebAhair': CelebHairDataset
1601
+ }.get(dataset_config.name)
1602
+
1603
+ im_dataset = im_dataset_cls(split='train',
1604
+ im_path=dataset_config.im_path,
1605
+ im_size=dataset_config.im_size,
1606
+ im_channels=dataset_config.im_channels)
1607
+
1608
+ data_loader = DataLoader(im_dataset,
1609
+ batch_size=train_config.autoencoder_batch_size,
1610
+ shuffle=True,
1611
+ num_workers=os.cpu_count(),
1612
+ pin_memory=True,
1613
+ drop_last=True,
1614
+ persistent_workers=True, pin_memory_device=device)
1615
+
1616
+ # Create output directories
1617
+ if not os.path.exists(train_config.task_name):
1618
+ os.mkdir(train_config.task_name)
1619
+
1620
+ num_epochs = train_config.autoencoder_epochs
1621
+
1622
+ # L1/L2 loss for Reconstruction
1623
+ recon_criterion = torch.nn.MSELoss()
1624
+ # Disc Loss can even be BCEWithLogits
1625
+ disc_criterion = torch.nn.MSELoss()
1626
+
1627
+ # No need to freeze lpips as lpips.py takes care of that
1628
+ lpips_model = LPIPS().eval().to(device)
1629
+ discriminator = Discriminator(im_channels=dataset_config.im_channels).to(device)
1630
+ # discriminator.load_state_dict(torch.load("/home/taruntejaneurips23/Ashish/DDPM/celebAhair_ldm/vqvae_discriminator_ckpt.pth", map_location=device))
1631
+ if os.path.exists(os.path.join(train_config.task_name, train_config.vqvae_discriminator_ckpt_name)):
1632
+ print('Loaded discriminator checkpoint')
1633
+ discriminator.load_state_dict(torch.load(os.path.join(train_config.task_name, train_config.vqvae_discriminator_ckpt_name),
1634
+ map_location=device, weights_only=True))
1635
+
1636
+ optimizer_d = Adam(discriminator.parameters(), lr=train_config.autoencoder_lr, betas=(0.5, 0.999))
1637
+ optimizer_g = Adam(model.parameters(), lr=train_config.autoencoder_lr, betas=(0.5, 0.999))
1638
+
1639
+ disc_step_start = train_config.disc_start
1640
+ step_count = 0
1641
+
1642
+ # This is for accumulating gradients incase the images are huge
1643
+ # And one cant afford higher batch sizes
1644
+ acc_steps = train_config.autoencoder_acc_steps
1645
+ image_save_steps = train_config.autoencoder_img_save_steps
1646
+ img_save_count = 0
1647
+
1648
+ for epoch_idx in trange(num_epochs, desc='Training VQVAE'):
1649
+ recon_losses = []
1650
+ codebook_losses = []
1651
+ #commitment_losses = []
1652
+ perceptual_losses = []
1653
+ disc_losses = []
1654
+ gen_losses = []
1655
+ losses = []
1656
+
1657
+ optimizer_g.zero_grad()
1658
+ optimizer_d.zero_grad()
1659
+
1660
+ # for im in tqdm(data_loader):
1661
+ for im in data_loader:
1662
+ step_count += 1
1663
+ im = im.float().to(device)
1664
+
1665
+ # Fetch autoencoders output(reconstructions)
1666
+ model_output = model(im)
1667
+ output, z, quantize_losses = model_output
1668
+
1669
+ # Image Saving Logic
1670
+ if step_count % image_save_steps == 0 or step_count == 1:
1671
+ sample_size = min(8, im.shape[0])
1672
+ save_output = torch.clamp(output[:sample_size], -1., 1.).detach().cpu()
1673
+ save_output = ((save_output + 1) / 2)
1674
+ save_input = ((im[:sample_size] + 1) / 2).detach().cpu()
1675
+
1676
+ grid = make_grid(torch.cat([save_input, save_output], dim=0), nrow=sample_size)
1677
+ img = torchvision.transforms.ToPILImage()(grid)
1678
+ if not os.path.exists(os.path.join(train_config.task_name,'vqvae_autoencoder_samples')):
1679
+ os.mkdir(os.path.join(train_config.task_name, 'vqvae_autoencoder_samples'))
1680
+ img.save(os.path.join(train_config.task_name,'vqvae_autoencoder_samples',
1681
+ 'current_autoencoder_sample_{}.png'.format(img_save_count)))
1682
+ img_save_count += 1
1683
+ img.close()
1684
+
1685
+ ######### Optimize Generator ##########
1686
+ # L2 Loss
1687
+ recon_loss = recon_criterion(output, im)
1688
+ recon_losses.append(recon_loss.item())
1689
+ recon_loss = recon_loss / acc_steps
1690
+ g_loss = (recon_loss +
1691
+ (train_config.codebook_weight * quantize_losses['codebook_loss'] / acc_steps) +
1692
+ (train_config.commitment_beta * quantize_losses['commitment_loss'] / acc_steps))
1693
+ codebook_losses.append(train_config.codebook_weight * quantize_losses['codebook_loss'].item())
1694
+ # Adversarial loss only if disc_step_start steps passed
1695
+ if step_count > disc_step_start:
1696
+ disc_fake_pred = discriminator(model_output[0])
1697
+ disc_fake_loss = disc_criterion(disc_fake_pred,
1698
+ torch.ones(disc_fake_pred.shape,
1699
+ device=disc_fake_pred.device))
1700
+ gen_losses.append(train_config.disc_weight * disc_fake_loss.item())
1701
+ g_loss += train_config.disc_weight * disc_fake_loss / acc_steps
1702
+ lpips_loss = torch.mean(lpips_model(output, im)) / acc_steps
1703
+ perceptual_losses.append(train_config.perceptual_weight * lpips_loss.item())
1704
+ g_loss += train_config.perceptual_weight*lpips_loss / acc_steps
1705
+ losses.append(g_loss.item())
1706
+ g_loss.backward()
1707
+ #####################################
1708
+
1709
+ ######### Optimize Discriminator #######
1710
+ if step_count > disc_step_start:
1711
+ fake = output
1712
+ disc_fake_pred = discriminator(fake.detach())
1713
+ disc_real_pred = discriminator(im)
1714
+ disc_fake_loss = disc_criterion(disc_fake_pred,
1715
+ torch.zeros(disc_fake_pred.shape,
1716
+ device=disc_fake_pred.device))
1717
+ disc_real_loss = disc_criterion(disc_real_pred,
1718
+ torch.ones(disc_real_pred.shape,
1719
+ device=disc_real_pred.device))
1720
+ disc_loss = train_config.disc_weight * (disc_fake_loss + disc_real_loss) / 2
1721
+ disc_losses.append(disc_loss.item())
1722
+ disc_loss = disc_loss / acc_steps
1723
+ disc_loss.backward()
1724
+ if step_count % acc_steps == 0:
1725
+ optimizer_d.step()
1726
+ optimizer_d.zero_grad()
1727
+ #####################################
1728
+
1729
+ if step_count % acc_steps == 0:
1730
+ optimizer_g.step()
1731
+ optimizer_g.zero_grad()
1732
+ optimizer_d.step()
1733
+ optimizer_d.zero_grad()
1734
+ optimizer_g.step()
1735
+ optimizer_g.zero_grad()
1736
+ if len(disc_losses) > 0:
1737
+ print(
1738
+ 'Finished epoch: {}/{} | Recon Loss : {:.4f} | Perceptual Loss : {:.4f} | '
1739
+ 'Codebook : {:.4f} | G Loss : {:.4f} | D Loss {:.4f}'.
1740
+ format(epoch_idx + 1,
1741
+ num_epochs,
1742
+ np.mean(recon_losses),
1743
+ np.mean(perceptual_losses),
1744
+ np.mean(codebook_losses),
1745
+ np.mean(gen_losses),
1746
+ np.mean(disc_losses)))
1747
+ else:
1748
+ print('Finished epoch: {}/{} | Recon Loss : {:.4f} | Perceptual Loss : {:.4f} | Codebook : {:.4f}'.
1749
+ format(epoch_idx + 1,
1750
+ num_epochs,
1751
+ np.mean(recon_losses),
1752
+ np.mean(perceptual_losses),
1753
+ np.mean(codebook_losses)))
1754
+
1755
+ torch.save(model.state_dict(), os.path.join(train_config.task_name,
1756
+ train_config.vqvae_autoencoder_ckpt_name))
1757
+ torch.save(discriminator.state_dict(), os.path.join(train_config.task_name,
1758
+ train_config.vqvae_discriminator_ckpt_name))
1759
+ print('Done Training...')
1760
+
1761
+
1762
+ # trainVAE(Config)
1763
+
1764
+ import torch
1765
+ import torch.nn as nn
1766
+
1767
+
1768
+ class Unet(nn.Module):
1769
+ r"""
1770
+ Unet model comprising
1771
+ Down blocks, Mid blocks and Up blocks
1772
+ """
1773
+
1774
+ def __init__(self, im_channels, model_config):
1775
+ super().__init__()
1776
+ self.down_channels = model_config.down_channels
1777
+ self.mid_channels = model_config.mid_channels
1778
+ self.t_emb_dim = model_config.time_emb_dim
1779
+ self.down_sample = model_config.down_sample
1780
+ self.num_down_layers = model_config.num_down_layers
1781
+ self.num_mid_layers = model_config.num_mid_layers
1782
+ self.num_up_layers = model_config.num_up_layers
1783
+ self.attns = model_config.attn_down
1784
+ self.norm_channels = model_config.norm_channels
1785
+ self.num_heads = model_config.num_heads
1786
+ self.conv_out_channels = model_config.conv_out_channels
1787
+
1788
+ assert self.mid_channels[0] == self.down_channels[-1]
1789
+ assert self.mid_channels[-1] == self.down_channels[-2]
1790
+ assert len(self.down_sample) == len(self.down_channels) - 1
1791
+ assert len(self.attns) == len(self.down_channels) - 1
1792
+
1793
+ # Initial projection from sinusoidal time embedding
1794
+ self.t_proj = nn.Sequential(
1795
+ nn.Linear(self.t_emb_dim, self.t_emb_dim),
1796
+ nn.SiLU(),
1797
+ nn.Linear(self.t_emb_dim, self.t_emb_dim),
1798
+ )
1799
+
1800
+ self.up_sample = list(reversed(self.down_sample))
1801
+ self.conv_in = nn.Conv2d(
1802
+ im_channels, self.down_channels[0], kernel_size=3, padding=1
1803
+ )
1804
+
1805
+ # --::----- D O W N - B L O C K S ----------------::--------------::----------------
1806
+ self.downs = nn.ModuleList([])
1807
+ for i in range(len(self.down_channels) - 1):
1808
+ self.downs.append(
1809
+ DownBlock(
1810
+ self.down_channels[i],
1811
+ self.down_channels[i + 1],
1812
+ self.t_emb_dim,
1813
+ down_sample=self.down_sample[i],
1814
+ num_heads=self.num_heads,
1815
+ num_layers=self.num_down_layers,
1816
+ attn=self.attns[i],
1817
+ norm_channels=self.norm_channels,
1818
+ )
1819
+ )
1820
+
1821
+ # --::----- M I D - B L O C K S ----------------::--------------::----------------
1822
+ self.mids = nn.ModuleList([])
1823
+ for i in range(len(self.mid_channels) - 1):
1824
+ self.mids.append(
1825
+ MidBlock(
1826
+ self.mid_channels[i],
1827
+ self.mid_channels[i + 1],
1828
+ self.t_emb_dim,
1829
+ num_heads=self.num_heads,
1830
+ num_layers=self.num_mid_layers,
1831
+ norm_channels=self.norm_channels,
1832
+ )
1833
+ )
1834
+
1835
+ # --::----- U P - B L O C K S ----------------::--------------::----------------
1836
+ self.ups = nn.ModuleList([])
1837
+ for i in reversed(range(len(self.down_channels) - 1)):
1838
+ self.ups.append(
1839
+ UpBlockUnet(
1840
+ self.down_channels[i] * 2,
1841
+ self.down_channels[i - 1] if i != 0 else self.conv_out_channels,
1842
+ self.t_emb_dim,
1843
+ up_sample=self.down_sample[i],
1844
+ num_heads=self.num_heads,
1845
+ num_layers=self.num_up_layers,
1846
+ norm_channels=self.norm_channels,
1847
+ )
1848
+ )
1849
+
1850
+ self.norm_out = nn.GroupNorm(self.norm_channels, self.conv_out_channels)
1851
+ self.conv_out = nn.Conv2d(
1852
+ self.conv_out_channels, im_channels, kernel_size=3, padding=1
1853
+ )
1854
+
1855
+ def forward(self, x, t):
1856
+ # Shapes assuming downblocks are [C1, C2, C3, C4]
1857
+ # Shapes assuming midblocks are [C4, C4, C3]
1858
+ # Shapes assuming downsamples are [True, True, False]
1859
+ # B x C x H x W
1860
+ out = self.conv_in(x)
1861
+ # B x C1 x H x W
1862
+
1863
+ # t_emb -> B x t_emb_dim
1864
+ t_emb = get_time_embedding(torch.as_tensor(t).long(), self.t_emb_dim)
1865
+ t_emb = self.t_proj(t_emb)
1866
+
1867
+ # --- Down Pass ------------------
1868
+ down_outs = []
1869
+ for idx, down in enumerate(self.downs):
1870
+ down_outs.append(out)
1871
+ out = down(out, t_emb)
1872
+ # down_outs [B x C1 x H x W, B x C2 x H/2 x W/2, B x C3 x H/4 x W/4]
1873
+ # out B x C4 x H/4 x W/4
1874
+
1875
+ # --- Mid Pass ------------------
1876
+ for mid in self.mids:
1877
+ out = mid(out, t_emb)
1878
+ # out B x C3 x H/4 x W/4
1879
+
1880
+ # --- Up Pass ------------------
1881
+ for up in self.ups:
1882
+ down_out = down_outs.pop()
1883
+ out = up(out, down_out, t_emb)
1884
+ # out [B x C2 x H/4 x W/4, B x C1 x H/2 x W/2, B x 16 x H x W]
1885
+
1886
+ out = self.norm_out(out)
1887
+ out = nn.SiLU()(out)
1888
+ out = self.conv_out(out)
1889
+ # out B x C x H x W
1890
+ return out
1891
+
1892
+
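A quick shape check helps confirm the Unet wiring before training. The sketch below is illustrative only: the config values are hypothetical (not taken from the actual YAML), and it assumes the DotDict helper, the DownBlock/MidBlock/UpBlockUnet classes and get_time_embedding defined earlier in this script.

    # Hypothetical latent-space Unet config; values chosen only to satisfy the asserts.
    _cfg = DotDict.from_dict({
        'down_channels': [128, 192, 256, 384],
        'mid_channels': [384, 256],
        'time_emb_dim': 256,
        'down_sample': [True, True, False],
        'num_down_layers': 1, 'num_mid_layers': 1, 'num_up_layers': 1,
        'attn_down': [True, True, True],
        'norm_channels': 32, 'num_heads': 8, 'conv_out_channels': 64,
    })
    _unet = Unet(im_channels=3, model_config=_cfg).to(device)
    _z = torch.randn(2, 3, 64, 64, device=device)      # a batch of VQVAE latents
    _t = torch.randint(0, 1000, (2,), device=device)   # random diffusion timesteps
    print(_unet(_z, _t).shape)                          # expected: torch.Size([2, 3, 64, 64])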
1893
+ def trainLDM(Config):
1894
+
1895
+ diffusion_config = Config.diffusion_params
1896
+ dataset_config = Config.dataset_params
1897
+ diffusion_model_config = Config.ldm_params
1898
+ autoencoder_model_config = Config.autoencoder_params
1899
+ train_config = Config.train_params
1900
+
1901
+ # Create the noise scheduler
1902
+ scheduler = LinearNoiseScheduler(num_timesteps=diffusion_config.num_timesteps,
1903
+ beta_start=diffusion_config.beta_start,
1904
+ beta_end=diffusion_config.beta_end)
1905
+ # scheduler = CosineNoiseScheduler(diffusion_config.num_timesteps)
1906
+
1907
+ im_dataset_cls = {
1908
+ 'mnist': MnistDataset,
1909
+ 'celebA': CelebDataset,
1910
+ 'animeface': AnimeFaceDataset,
1911
+ 'celebAhair': CelebHairDataset
1912
+ }.get(dataset_config.name)
1913
+
1914
+ im_dataset = im_dataset_cls(split='train',
1915
+ im_path=dataset_config.im_path,
1916
+ im_size=dataset_config.im_size,
1917
+ im_channels=dataset_config.im_channels,
1918
+ use_latents=True,
1919
+ latent_path=os.path.join(train_config.task_name,
1920
+ train_config.vqvae_latent_dir_name)
1921
+ )
1922
+
1923
+ data_loader = DataLoader(im_dataset,
1924
+ batch_size=train_config.ldm_batch_size,
1925
+ shuffle=True,
1926
+ num_workers=os.cpu_count(),
1927
+ pin_memory=True,
1928
+ drop_last=False,
1929
+ persistent_workers=True, pin_memory_device=device)
1930
+
1931
+ # Instantiate the model
1932
+ model = Unet(im_channels=autoencoder_model_config.z_channels,
1933
+ model_config=diffusion_model_config).to(device)
1934
+ if os.path.exists(os.path.join(train_config.task_name, train_config.ldm_ckpt_name)):
1935
+ print('Loaded ldm checkpoint')
1936
+ model.load_state_dict(torch.load(os.path.join(train_config.task_name, train_config.ldm_ckpt_name), map_location=device, weights_only=True))
1937
+ model.train()
1938
+
1939
+ # Load VAE ONLY if latents are not to be used or are missing
1940
+ if not im_dataset.use_latents:
1941
+ print('Loading vqvae model as latents not present')
1942
+ vae = VQVAE(im_channels=dataset_config.im_channels,
1943
+ model_config=autoencoder_model_config).to(device)
1944
+ vae.eval()
1945
+ # Load vae if found
1946
+ if os.path.exists(os.path.join(train_config.task_name,
1947
+ train_config.vqvae_autoencoder_ckpt_name)):
1948
+ print('Loaded vae checkpoint')
1949
+ vae.load_state_dict(torch.load(os.path.join(train_config.task_name,
1950
+ train_config.vqvae_autoencoder_ckpt_name),
1951
+ map_location=device))
1952
+ # Specify training parameters
1953
+ num_epochs = train_config.ldm_epochs
1954
+ optimizer = Adam(model.parameters(), lr=train_config.ldm_lr)
1955
+ criterion = torch.nn.MSELoss()
1956
+
1957
+ # Run training
1958
+ if not im_dataset.use_latents:
1959
+ for param in vae.parameters():
1960
+ param.requires_grad = False
1961
+
1962
+ for epoch_idx in range(num_epochs):
1963
+ losses = []
1964
+ for im in tqdm(data_loader):
1965
+ optimizer.zero_grad()
1966
+ im = im.float().to(device)
1967
+ if not im_dataset.use_latents:
1968
+ with torch.no_grad():
1969
+ im, _ = vae.encode(im)
1970
+
1971
+ # Sample random noise
1972
+ noise = torch.randn_like(im).to(device)
1973
+
1974
+ # Sample timestep
1975
+ t = torch.randint(0, diffusion_config.num_timesteps, (im.shape[0],)).to(device)
1976
+
1977
+ # Add noise to images according to timestep
1978
+ noisy_im = scheduler.add_noise(im, noise, t)
1979
+ noise_pred = model(noisy_im, t)
1980
+
1981
+ loss = criterion(noise_pred, noise)
1982
+ losses.append(loss.item())
1983
+ loss.backward()
1984
+ optimizer.step()
1985
+ print(f'Finished epoch:{epoch_idx + 1}/{num_epochs} | Loss : {np.mean(losses):.4f}')
1986
+
1987
+ torch.save(model.state_dict(), os.path.join(train_config.task_name,
1988
+ train_config.ldm_ckpt_name))
1989
+
1990
+ # Doing Inference
1991
+ infer(Config)
1992
+
1993
+ # Check whether to continue training
1994
+ train_continue = yaml.safe_load(open("/home/taruntejaneurips23/Ashish/DDPM/_5_ldm_celeba.yaml", 'r'))
1995
+ train_continue = DotDict.from_dict(train_continue)
1996
+ if not train_continue.training._continue_:
1997
+ print('Training stopped ...')
1998
+ break
1999
+
2000
+ print('Done Training ...')
2001
+
2002
+ # trainLDM(Config)
2003
+
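For reference, the scheduler.add_noise(im, noise, t) call in the training loop above is assumed to implement the standard DDPM forward process; a minimal sketch with typical (not necessarily the configured) beta_start/beta_end values:

    # x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps
    _betas = torch.linspace(1e-4, 0.02, 1000)          # typical beta schedule, 1000 steps
    _alpha_bars = torch.cumprod(1.0 - _betas, dim=0)
    def _add_noise_reference(x0, eps, t):
        a = _alpha_bars.to(x0.device)[t].view(-1, 1, 1, 1)
        return a.sqrt() * x0 + (1.0 - a).sqrt() * eps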
2004
+ # import subprocess
2005
+ # subprocess.run(f'kill {os.getpid()}', shell=True, check=True)
2006
+
2007
+ def sample(model, scheduler, train_config, diffusion_model_config,
2008
+ autoencoder_model_config, diffusion_config, dataset_config, vae):
2009
+ r"""
2010
+ Sample stepwise by going backward one timestep at a time.
2011
+ We save the x0 predictions
2012
+ """
2013
+ im_size = dataset_config.im_size // 2**sum(autoencoder_model_config.down_sample)
2014
+ xt = torch.randn((train_config.num_samples,
2015
+ autoencoder_model_config.z_channels,
2016
+ im_size,
2017
+ im_size)).to(device)
2018
+
2019
+ save_count = 0
2020
+ for i in tqdm(reversed(range(diffusion_config.num_timesteps)), total=diffusion_config.num_timesteps):
2021
+ # Get prediction of noise
2022
+ noise_pred = model(xt, torch.as_tensor(i).unsqueeze(0).to(device))
2023
+
2024
+ # Use scheduler to get x0 and xt-1
2025
+ xt, x0_pred = scheduler.sample_prev_timestep(xt, noise_pred, torch.as_tensor(i).to(device))
2026
+
2027
+ # Save x0
2028
+ #ims = torch.clamp(xt, -1., 1.).detach().cpu()
2029
+ if i == 0:
2030
+ # Decode ONLY the final image to save time
2031
+ ims = vae.decode(xt)
2032
+ else:
2033
+ ims = xt
2034
+
2035
+ ims = torch.clamp(ims, -1., 1.).detach().cpu()
2036
+ ims = (ims + 1) / 2
2037
+ grid = make_grid(ims, nrow=train_config.num_grid_rows)
2038
+ img = torchvision.transforms.ToPILImage()(grid)
2039
+
2040
+ if not os.path.exists(os.path.join(train_config.task_name, 'samples')):
2041
+ os.mkdir(os.path.join(train_config.task_name, 'samples'))
2042
+ img.save(os.path.join(train_config.task_name, 'samples', 'x0_{}.png'.format(i)))
2043
+ img.close()
2044
+
2045
+
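Sampling starts from pure noise at the latent resolution: each True entry in down_sample halves the spatial size. With illustrative values (a 256x256 image and three downsampling stages, not necessarily the configured ones):

    print(256 // 2 ** sum([True, True, True]))   # 32 -> a 32x32 latent grid to denoise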
2046
+ def infer(Config):
2047
+
2048
+ diffusion_config = Config.diffusion_params
2049
+ dataset_config = Config.dataset_params
2050
+ diffusion_model_config = Config.ldm_params
2051
+ autoencoder_model_config = Config.autoencoder_params
2052
+ train_config = Config.train_params
2053
+
2054
+ # Create the noise scheduler
2055
+ scheduler = LinearNoiseScheduler(num_timesteps=diffusion_config.num_timesteps,
2056
+ beta_start=diffusion_config.beta_start,
2057
+ beta_end=diffusion_config.beta_end)
2058
+ # scheduler = CosineNoiseScheduler(diffusion_config.num_timesteps)
2059
+
2060
+ model = Unet(im_channels=autoencoder_model_config.z_channels,
2061
+ model_config=diffusion_model_config).to(device)
2062
+ model.eval()
2063
+ if os.path.exists(os.path.join(train_config.task_name,
2064
+ train_config.ldm_ckpt_name)):
2065
+ print('Loaded unet checkpoint')
2066
+ model.load_state_dict(torch.load(os.path.join(train_config.task_name,
2067
+ train_config.ldm_ckpt_name),
2068
+ map_location=device))
2069
+ # Create output directories
2070
+ if not os.path.exists(train_config.task_name):
2071
+ os.mkdir(train_config.task_name)
2072
+
2073
+ vae = VQVAE(im_channels=dataset_config.im_channels,
2074
+ model_config=autoencoder_model_config).to(device)
2075
+ vae.eval()
2076
+
2077
+ # Load vae if found
2078
+ if os.path.exists(os.path.join(train_config.task_name,
2079
+ train_config.vqvae_autoencoder_ckpt_name)):
2080
+ print('Loaded vae checkpoint')
2081
+ vae.load_state_dict(torch.load(os.path.join(train_config.task_name,
2082
+ train_config.vqvae_autoencoder_ckpt_name),
2083
+ map_location=device), strict=True)
2084
+ with torch.no_grad():
2085
+ sample(model, scheduler, train_config, diffusion_model_config,
2086
+ autoencoder_model_config, diffusion_config, dataset_config, vae)
2087
+
2088
+
2089
+
2090
+ import argparse
2091
+
2092
+ def get_args():
2093
+ parser = argparse.ArgumentParser(description="Choose between train VAE, train LDM, or infer mode.")
2094
+ parser.add_argument('--mode', choices=['train_vae', 'train_ldm', 'infer'], default='infer',
2095
+ help="Mode to run: train_vae, train_ldm, or infer")
2096
+ return parser.parse_args()
2097
+
2098
+ args = get_args()
2099
+
2100
+ if args.mode == 'train_vae':
2101
+ trainVAE(Config)
2102
+ elif args.mode == 'train_ldm':
2103
+ trainLDM(Config)
2104
+ else:
2105
+ infer(Config)
2106
+
2107
+ # python _5.2_ldm_celeba_hair_cosine.py --mode train_vae
2108
+ # python _5.2_ldm_celeba_hair_cosine.py --mode train_ldm
2109
+ # python _5.2_ldm_celeba_hair_cosine.py --mode infer
2110
+
2111
+
2112
+
2113
+
2114
+ # import matplotlib.pyplot as plt
2115
+ # from PIL import Image
2116
+ # # plt.style.use('dark_background')
2117
+ # # %matplotlib inline
2118
+
2119
+ # plt.imshow(Image.open('/home/taruntejaneurips23/Ashish/DDPM/mnist_ldm/samples/x0_0.png'), cmap='gray')
2120
+
2121
+ # import matplotlib.pyplot as plt
2122
+ # import matplotlib.image as mpimg
2123
+
2124
+ # dataset_name = 'animeface_ldm'
2125
+
2126
+ # image_paths = [f'/home/taruntejaneurips23/Ashish/DDPM/{dataset_name}/samples/x0_0.png',
2127
+ # f'/home/taruntejaneurips23/Ashish/DDPM/{dataset_name}/samples/x0_1.png',
2128
+ # f'/home/taruntejaneurips23/Ashish/DDPM/{dataset_name}/samples/x0_5.png',
2129
+ # f'/home/taruntejaneurips23/Ashish/DDPM/{dataset_name}/samples/x0_100.png',
2130
+ # f'/home/taruntejaneurips23/Ashish/DDPM/{dataset_name}/samples/x0_200.png'
2131
+ # ]
2132
+
2133
+ # fig, axes = plt.subplots(1, len(image_paths), figsize=(15, 5))
2134
+
2135
+ # for i, path in enumerate(image_paths):
2136
+ # img = mpimg.imread(path)
2137
+ # axes[i].imshow(img)
2138
+ # axes[i].axis('off') # Hide axes
2139
+ # axes[i].set_title(f't = {path.split("/")[-1].split(".")[0].split("_")[-1]}')
2140
+
2141
+ # plt.tight_layout()
2142
+ # plt.show()
2143
+
2144
+ # ---------------------------------------------------------
2145
+ # ---------- T H E - E N D -------------------------------
2146
+ # ---------------------------------------------------------
2147
+
2148
+
2149
+
2150
+ def save_checkpoint(
2151
+ total_steps, epoch, model, discriminator,
2152
+ optimizer_d, optimizer_g, loss, checkpoint_path
2153
+ ):
2154
+ checkpoint = {
2155
+ "total_steps": total_steps,
2156
+ "epoch": epoch,
2157
+ "model_state_dict": model.state_dict(),
2158
+ "discriminator_state_dict": discriminator.state_dict(),
2159
+ "optimizer_d_state_dict": optimizer_d.state_dict(),
2160
+ "optimizer_g_state_dict": optimizer_g.state_dict(),
2161
+ "loss": loss,
2162
+ }
2163
+ torch.save(checkpoint, checkpoint_path)
2164
+ print(f"Checkpoint saved after {total_steps} steps at epoch {epoch}")
2165
+
2166
+
2167
+ def load_checkpoint(
2168
+ checkpoint_path, model, discriminator, optimizer_d, optimizer_g
2169
+ ):
2170
+ if os.path.exists(checkpoint_path):
2171
+ checkpoint = torch.load(checkpoint_path)
2172
+ model.load_state_dict(checkpoint["model_state_dict"])
2173
+ discriminator.load_state_dict(checkpoint["discriminator_state_dict"])
2174
+ if optimizer_d is not None: optimizer_d.load_state_dict(checkpoint["optimizer_d_state_dict"])
2175
+ if optimizer_g is not None: optimizer_g.load_state_dict(checkpoint["optimizer_g_state_dict"])
2176
+ total_steps = checkpoint["total_steps"]
2177
+ start_epoch = checkpoint["epoch"] + 1
2178
+ loss = checkpoint["loss"]
2179
+ print(f"Checkpoint loaded. Resuming from epoch {start_epoch}")
2180
+ return total_steps, start_epoch, loss
2181
+ else:
2182
+ print("No checkpoint found. Starting from scratch.")
2183
+ return 0, 0, None
2184
+
2185
+
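A toy round trip of the two helpers above, using throwaway linear modules and a made-up file name, to show the expected call order and return values:

    _toy_g, _toy_d = nn.Linear(4, 4), nn.Linear(4, 4)
    _opt_g = torch.optim.Adam(_toy_g.parameters())
    _opt_d = torch.optim.Adam(_toy_d.parameters())
    save_checkpoint(10, 0, _toy_g, _toy_d, _opt_d, _opt_g,
                    loss=0.5, checkpoint_path="toy_ckpt.pth")
    steps, start_epoch, last_loss = load_checkpoint("toy_ckpt.pth", _toy_g, _toy_d, _opt_d, _opt_g)
    print(steps, start_epoch, last_loss)   # 10 1 0.5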
2186
+ def trainVAE(Config, dataloader):
2187
+ """
2188
+ Trains a VQVAE model using the provided configuration and data loader.
2189
+ """
2190
+ # --- Configurations ----------------------------------------------------
2191
+ dataset_config = Config.dataset_params
2192
+ autoencoder_config = Config.autoencoder_params
2193
+ train_config = Config.train_params
2194
+
2195
+ seed = train_config.seed
2196
+ torch.manual_seed(seed)
2197
+ np.random.seed(seed)
2198
+ random.seed(seed)
2199
+ if device == "cuda":
2200
+ torch.cuda.manual_seed_all(seed)
2201
+
2202
+ # --- Model Initialization ----------------------------------------------
2203
+ model = VQVAE(im_channels=dataset_config.im_channels, model_config=autoencoder_config).to(device)
2204
+ discriminator = Discriminator(im_channels=dataset_config.im_channels).to(device)
2205
+
2206
+ # --- Load Checkpoints --------------------------------------------------
2207
+ checkpoint_path = os.path.join(train_config.task_name, "vqvae_checkpoint.pth")
2208
+ total_steps, start_epoch, _ = load_checkpoint(checkpoint_path, model, discriminator, None, None)
2209
+
2210
+ # --- Loss Function Initialization --------------------------------------
2211
+ recon_criterion = torch.nn.MSELoss()
2212
+ lpips_model = LPIPS().eval().to(device)
2213
+ disc_criterion = torch.nn.MSELoss()
2214
+
2215
+ # --- Optimizer Initialization ------------------------------------------
2216
+ optimizer_d = torch.optim.AdamW(discriminator.parameters(), lr=train_config.autoencoder_lr, betas=(0.5, 0.999))
2217
+ optimizer_g = torch.optim.AdamW(model.parameters(), lr=train_config.autoencoder_lr, betas=(0.5, 0.999))
2218
+
2219
+ num_epochs = train_config.autoencoder_epochs
2220
+ acc_steps = train_config.autoencoder_acc_steps
2221
+ image_save_steps = train_config.autoencoder_img_save_steps
2222
+ img_save_count = 0
2223
+
2224
+ # Create necessary directories
2225
+ os.makedirs(os.path.join(train_config.task_name, "vqvae_autoencoder_samples"), exist_ok=True)
2226
+
2227
+ # --- Training Loop -----------------------------------------------------
2228
+ for epoch_idx in range(start_epoch, num_epochs):
2229
+ recon_losses, codebook_losses, perceptual_losses, disc_losses, gen_losses = [], [], [], [], []
2230
+
2231
+ for images in dataloader:
2232
+ total_steps += 1
2233
+ images = images.to(device)
2234
+
2235
+ # Forward pass
2236
+ model_output = model(images)
2237
+ output, z, quantize_losses = model_output
2238
+
2239
+ # Save generated images periodically
2240
+ if total_steps % image_save_steps == 0 or total_steps == 1:
2241
+ sample_size = min(8, images.shape[0])
2242
+ save_output = torch.clamp(output[:sample_size], -1.0, 1.0).detach().cpu()
2243
+ save_output = (save_output + 1) / 2
2244
+ save_input = ((images[:sample_size] + 1) / 2).detach().cpu()
2245
+
2246
+ grid = make_grid(torch.cat([save_input, save_output], dim=0), nrow=sample_size)
2247
+ img = tv.transforms.ToPILImage()(grid)
2248
+ img.save(
2249
+ os.path.join(
2250
+ train_config.task_name,
2251
+ "vqvae_autoencoder_samples",
2252
+ f"current_autoencoder_sample_{img_save_count}.png",
2253
+ )
2254
+ )
2255
+ img_save_count += 1
2256
+ img.close()
2257
+
2258
+ # Reconstruction Loss
2259
+ recon_loss = recon_criterion(output, images) / acc_steps
2260
+ recon_losses.append(recon_loss.item())
2261
+
2262
+ # Generator Loss
2263
+ codebook_loss = train_config.codebook_weight * quantize_losses["codebook_loss"] / acc_steps
2264
+ perceptual_loss = train_config.perceptual_weight * lpips_model(output, images).mean() / acc_steps
2265
+ g_loss = recon_loss + codebook_loss + perceptual_loss
2266
+
2267
+ if total_steps > train_config.disc_start:
2268
+ disc_fake_pred = discriminator(output)
2269
+ gen_loss = train_config.disc_weight * disc_criterion(
2270
+ disc_fake_pred, torch.ones_like(disc_fake_pred)
2271
+ ) / acc_steps
2272
+ g_loss += gen_loss
2273
+ gen_losses.append(gen_loss.item())
2274
+
2275
+ g_loss.backward()
2276
+ optimizer_g.step()
2277
+ optimizer_g.zero_grad()
2278
+
2279
+ # Discriminator Loss
2280
+ if total_steps > train_config.disc_start:
2281
+ disc_fake_pred = discriminator(output.detach())
2282
+ disc_real_pred = discriminator(images)
2283
+ disc_fake_loss = disc_criterion(
2284
+ disc_fake_pred, torch.zeros_like(disc_fake_pred)
2285
+ ) / acc_steps
2286
+ disc_real_loss = disc_criterion(
2287
+ disc_real_pred, torch.ones_like(disc_real_pred)
2288
+ ) / acc_steps
2289
+ disc_loss = train_config.disc_weight * (disc_fake_loss + disc_real_loss) / 2
2290
+ disc_loss.backward()
2291
+ optimizer_d.step()
2292
+ optimizer_d.zero_grad()
2293
+ disc_losses.append(disc_loss.item())
2294
+
2295
+ # Save checkpoint after each epoch
2296
+ save_checkpoint(total_steps, epoch_idx, model, discriminator, optimizer_d, optimizer_g, recon_losses, checkpoint_path)
2297
+
2298
+ # Print epoch summary
2299
+ print(
2300
+ f"Epoch {epoch_idx + 1}/{num_epochs} | Recon Loss: {np.mean(recon_losses):.4f} | "
2301
+ f"Perceptual Loss: {np.mean(perceptual_losses):.4f} | Codebook Loss: {np.mean(codebook_losses):.4f} | "
2302
+ f"G Loss: {np.mean(gen_losses):.4f} | D Loss: {np.mean(disc_losses):.4f}"
2303
+ )
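Note that this simplified loop divides each loss by acc_steps but still steps the optimizers on every batch. If accumulation over acc_steps batches is actually intended, the usual pattern is the following self-contained sketch (toy module and optimizer, not the training code above):

    acc_steps = 4
    _toy = nn.Linear(8, 1)
    _opt = torch.optim.Adam(_toy.parameters())
    for step in range(1, 9):
        x = torch.randn(16, 8)
        loss = _toy(x).pow(2).mean() / acc_steps   # scale each mini-batch loss
        loss.backward()                            # gradients accumulate in .grad
        if step % acc_steps == 0:                  # update only every acc_steps batches
            _opt.step()
            _opt.zero_grad()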
Vaani/LDM/scripts/SLURM-AE-Train.sh ADDED
@@ -0,0 +1,21 @@
1
+ #!/bin/bash -x
2
+ #SBATCH -p gpu
3
+ #SBATCH -N 1
4
+ #SBATCH --ntasks-per-node=48
5
+ #SBATCH --mem 128G
6
+ #SBATCH -t 2-00:00:00
7
+ #SBATCH -J ASHISH_AE_Train
8
+ #SBATCH -o %j.out # name of stdout output file(--output)
9
+ #SBATCH -e %j.err # name of stderr error file(--error)
10
+ cd $SLURM_WORKDIR
11
+
12
+ module purge
13
+ module load miniconda # load the module and environment
14
+ source /home/apps/miniconda3/etc/profile.d/conda.sh
15
+ conda env list
16
+ conda activate aku # load working environment
17
+
18
+ python "/home/IITB/ai-at-ieor/23m1521/ashish/MTP/Vaani/LDM/scripts/Vaani-VQVAE-Main.py" > "/home/IITB/ai-at-ieor/23m1521/ashish/MTP/Vaani/LDM/scripts/AE-training.log" 2>&1 # run python script
19
+
20
+ conda deactivate # deactivate environment
21
+ # end of script
Vaani/LDM/scripts/SLURM-AE-Train2.sh ADDED
@@ -0,0 +1,21 @@
1
+ #!/bin/bash -x
2
+ #SBATCH -p gpu
3
+ #SBATCH -N 1
4
+ #SBATCH --ntasks-per-node=48
5
+ #SBATCH --mem 128G
6
+ #SBATCH -t 10:00:00
7
+ #SBATCH -J ASHISH_AE_Train
8
+ #SBATCH -o %j.out # name of stdout output file(--output)
9
+ #SBATCH -e %j.err # name of stderr error file(--error)
10
+ cd $SLURM_WORKDIR
11
+
12
+ module purge
13
+ module load miniconda # load the module and environment
14
+ source /home/apps/miniconda3/etc/profile.d/conda.sh
15
+ conda env list
16
+ conda activate aku # load working environment
17
+
18
+ python "/home/IITB/ai-at-ieor/23m1521/ashish/MTP/Vaani/LDM/scripts/Vaani-VQVAE-Main.py" > "/home/IITB/ai-at-ieor/23m1521/ashish/MTP/Vaani/LDM/scripts/AE-training.log" 2>&1 # run python script
19
+
20
+ conda deactivate # deactivate environment
21
+ # end of script
Vaani/LDM/scripts/Vaani-VQVAE-Main.py ADDED
@@ -0,0 +1,1151 @@
1
+ # ==================================================================
2
+ # V Q - V A E T R A I N I N G
3
+ # ==================================================================
4
+ # Author : Ashish Kumar Uchadiya
5
+ # Created : November 3, 2024
6
+ # Description: This script implements the training of a VQ-VAE model for
7
+ # image reconstruction. It uses LPIPS (Learned Perceptual Image Patch Similarity)
8
+ # loss to capture perceptual differences and PatchGAN loss to enforce local
9
+ # realism. The model maps images to a discrete latent space and reconstructs
10
+ # high-fidelity outputs by minimizing these combined losses.
11
+ # ==================================================================
12
+ # I M P O R T S
13
+ # ==================================================================
14
+
15
+
16
+ import os
17
+ os.environ["CUDA_VISIBLE_DEVICES"] = "1"
18
+
19
+ import torch
20
+ import torch.nn as nn
21
+ import numpy as np
22
+ from collections import namedtuple
23
+
24
+ import pandas as pd
25
+ import torchvision as tv
26
+ from torchvision.transforms import v2
27
+ from tqdm.auto import tqdm, trange
28
+ import matplotlib.pyplot as plt
29
+
30
+ import yaml
31
+ import random
32
+ import datetime
33
+ import torch.hub
34
+ from torch.utils.data import Dataset, DataLoader
35
+ from torchvision.utils import make_grid
36
+
37
+ print("TIME:", datetime.datetime.now())
38
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
39
+ print("DEVICE:", device)
40
+
41
+
42
+ # ==================================================================
43
+ # H E L P E R S
44
+ # ==================================================================
45
+ from typing import Any
46
+ from argparse import Namespace
47
+ import typing
48
+
49
+
50
+ class DotDict(Namespace):
51
+ """A simple class that builds upon `argparse.Namespace`
52
+ in order to make chained attributes possible."""
53
+
54
+ def __init__(self, temp=False, key=None, parent=None) -> None:
55
+ self._temp = temp
56
+ self._key = key
57
+ self._parent = parent
58
+
59
+ def __eq__(self, other):
60
+ if not isinstance(other, DotDict):
61
+ return NotImplemented
62
+ return vars(self) == vars(other)
63
+
64
+ def __getattr__(self, __name: str) -> Any:
65
+ if __name not in self.__dict__ and not self._temp:
66
+ self.__dict__[__name] = DotDict(temp=True, key=__name, parent=self)
67
+ else:
68
+ del self._parent.__dict__[self._key]
69
+ raise AttributeError("No attribute '%s'" % __name)
70
+ return self.__dict__[__name]
71
+
72
+ def __repr__(self) -> str:
73
+ item_keys = [k for k in self.__dict__ if not k.startswith("_")]
74
+
75
+ if len(item_keys) == 0:
76
+ return "DotDict()"
77
+ elif len(item_keys) == 1:
78
+ key = item_keys[0]
79
+ val = self.__dict__[key]
80
+ return "DotDict(%s=%s)" % (key, repr(val))
81
+ else:
82
+ return "DotDict(%s)" % ", ".join(
83
+ "%s=%s" % (key, repr(val)) for key, val in self.__dict__.items()
84
+ )
85
+
86
+ @classmethod
87
+ def from_dict(cls, original: typing.Mapping[str, Any]) -> "DotDict":
88
+ """Create a DotDict from a (possibly nested) dict `original`.
89
+ Warning: this method should not be used on very deeply nested inputs,
90
+ since it's recursively traversing the nested dictionary values.
91
+ """
92
+ dd = DotDict()
93
+ for key, value in original.items():
94
+ if isinstance(value, typing.Mapping):
95
+ value = cls.from_dict(value)
96
+ setattr(dd, key, value)
97
+ return dd
98
+
99
+
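A quick usage check for DotDict (the keys here are made up, not taken from the real config):

    _cfg = DotDict.from_dict({'train_params': {'seed': 42, 'autoencoder_lr': 1e-4}})
    print(_cfg.train_params.seed)             # 42
    print(_cfg.train_params.autoencoder_lr)   # 0.0001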
100
+ # ==================================================================
101
+ # L P I P S
102
+ # ==================================================================
103
+ class vgg16(nn.Module):
104
+ def __init__(self):
105
+ super(vgg16, self).__init__()
106
+ vgg_pretrained_features = tv.models.vgg16(
107
+ weights=tv.models.VGG16_Weights.IMAGENET1K_V1
108
+ ).features
109
+ self.slice1 = torch.nn.Sequential()
110
+ self.slice2 = torch.nn.Sequential()
111
+ self.slice3 = torch.nn.Sequential()
112
+ self.slice4 = torch.nn.Sequential()
113
+ self.slice5 = torch.nn.Sequential()
114
+ self.N_slices = 5
115
+ for x in range(4):
116
+ self.slice1.add_module(str(x), vgg_pretrained_features[x])
117
+ for x in range(4, 9):
118
+ self.slice2.add_module(str(x), vgg_pretrained_features[x])
119
+ for x in range(9, 16):
120
+ self.slice3.add_module(str(x), vgg_pretrained_features[x])
121
+ for x in range(16, 23):
122
+ self.slice4.add_module(str(x), vgg_pretrained_features[x])
123
+ for x in range(23, 30):
124
+ self.slice5.add_module(str(x), vgg_pretrained_features[x])
125
+
126
+ self.eval()
127
+ for param in self.parameters():
128
+ param.requires_grad = False
129
+
130
+ def forward(self, X):
131
+ h1 = self.slice1(X)
132
+ h2 = self.slice2(h1)
133
+ h3 = self.slice3(h2)
134
+ h4 = self.slice4(h3)
135
+ h5 = self.slice5(h4)
136
+ vgg_outputs = namedtuple("VggOutputs", ['h1', 'h2', 'h3', 'h4', 'h5'])
137
+ out = vgg_outputs(h1, h2, h3, h4, h5)
138
+ return out
139
+
140
+
141
+ def _spatial_average(in_tens, keepdim=True):
142
+ return in_tens.mean([2, 3], keepdim=keepdim)
143
+
144
+
145
+ def _normalize_tensor(in_feat, eps= 1e-8):
146
+ norm_factor = torch.sqrt(eps + torch.sum(in_feat**2, dim=1, keepdim=True))
147
+ return in_feat / norm_factor
148
+
149
+
150
+ class ScalingLayer(nn.Module):
151
+ def __init__(self):
152
+ super(ScalingLayer, self).__init__()
153
+ # ImageNet normalization for inputs in [0, 1]
154
+ # mean = [0.485, 0.456, 0.406]
155
+ # std = [0.229, 0.224, 0.225]
156
+
157
+ self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
158
+ self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None])
159
+
160
+ def forward(self, inp):
161
+ return (inp - self.shift) / self.scale
162
+
163
+
164
+ class NetLinLayer(nn.Module):
165
+ ''' A single linear layer which does a 1x1 conv '''
166
+ def __init__(self, chn_in, chn_out=1, use_dropout=False):
167
+ super(NetLinLayer, self).__init__()
168
+ layers = [nn.Dropout(), ] if (use_dropout) else []
169
+ layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ]
170
+ self.model = nn.Sequential(*layers)
171
+
172
+ def forward(self, x):
173
+ return self.model(x)
174
+
175
+
176
+ class LPIPS(nn.Module):
177
+ def __init__(self, net='vgg', version='0.1', use_dropout=True):
178
+ super(LPIPS, self).__init__()
179
+ self.version = version
180
+ self.scaling_layer = ScalingLayer()
181
+ self.chns = [64, 128, 256, 512, 512]
182
+ self.L = len(self.chns)
183
+ self.net = vgg16()
184
+ self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
185
+ self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
186
+ self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
187
+ self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
188
+ self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
189
+ self.lins = nn.ModuleList([self.lin0, self.lin1, self.lin2, self.lin3, self.lin4])
190
+
191
+ # --- Original url -------------------
192
+ # weights_url = f"https://github.com/richzhang/PerceptualSimilarity/raw/master/lpips/weights/v{version}/{net}.pth"
193
+
194
+ # --- Original Forked url ------------
195
+ weights_url = f"https://github.com/akuresonite/PerceptualSimilarity-Forked/raw/master/lpips/weights/v{version}/{net}.pth"
196
+
197
+ # --- Original torchmetrics url ------
198
+ # weights_url = "https://github.com/Lightning-AI/torchmetrics/raw/master/src/torchmetrics/functional/image/lpips_models/vgg.pth"
199
+
200
+ state_dict = torch.hub.load_state_dict_from_url(weights_url, map_location='cpu')
201
+ self.load_state_dict(state_dict, strict=False)
202
+
203
+ self.eval()
204
+ for param in self.parameters():
205
+ param.requires_grad = False
206
+
207
+ def forward(self, in0, in1, normalize=False):
208
+ # Scale the inputs to -1 to +1 range if input in [0,1]
209
+ if normalize:
210
+ in0 = 2 * in0 - 1
211
+ in1 = 2 * in1 - 1
212
+
213
+ in0_input, in1_input = self.scaling_layer(in0), self.scaling_layer(in1)
214
+ # in0_input, in1_input = in0, in1
215
+
216
+ outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input)
217
+
218
+ diffs = {}
219
+ for kk in range(self.L):
220
+ feats0 = _normalize_tensor(outs0[kk])
221
+ feats1 = _normalize_tensor(outs1[kk])
222
+ diffs[kk] = (feats0 - feats1) ** 2
223
+
224
+ res = [_spatial_average(self.lins[kk](diffs[kk]), keepdim=True) for kk in range(self.L)]
225
+ val = sum(res)
226
+ return val.reshape(-1)
227
+
228
+
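Minimal usage sketch for the LPIPS module above. It expects inputs in [-1, 1] (or pass normalize=True for [0, 1] inputs); instantiation downloads the VGG backbone and the linear-head weights, so network access is required. The tensors here are random placeholders.

    _lpips = LPIPS().eval().to(device)
    _a = torch.rand(4, 3, 256, 256, device=device) * 2 - 1
    _b = torch.rand(4, 3, 256, 256, device=device) * 2 - 1
    with torch.no_grad():
        _d = _lpips(_a, _b)      # per-image perceptual distance
    print(_d.shape)              # torch.Size([4])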
229
+ # ==================================================================
230
+ # P A T C H - G A N - D I S C R I M I N A T O R
231
+ # ==================================================================
232
+ class Discriminator(nn.Module):
233
+ r"""
234
+ PatchGAN Discriminator.
235
+ Rather than taking IMG_CHANNELSxIMG_HxIMG_W all the way to
236
+ 1 scalar value , we instead predict grid of values.
237
+ Where each grid is prediction of how likely
238
+ the discriminator thinks that the image patch corresponding
239
+ to the grid cell is real
240
+ """
241
+
242
+ def __init__(
243
+ self,
244
+ im_channels=3,
245
+ conv_channels=[64, 128, 256],
246
+ kernels=[4, 4, 4, 4],
247
+ strides=[2, 2, 2, 1],
248
+ paddings=[1, 1, 1, 1],
249
+ ):
250
+ super().__init__()
251
+ self.im_channels = im_channels
252
+ activation = nn.LeakyReLU(0.2)
253
+ layers_dim = [self.im_channels] + conv_channels + [1]
254
+ self.layers = nn.ModuleList(
255
+ [
256
+ nn.Sequential(
257
+ nn.Conv2d(
258
+ layers_dim[i],
259
+ layers_dim[i + 1],
260
+ kernel_size=kernels[i],
261
+ stride=strides[i],
262
+ padding=paddings[i],
263
+ bias=False if i != 0 else True,
264
+ ),
265
+ (
266
+ nn.BatchNorm2d(layers_dim[i + 1])
267
+ if i != len(layers_dim) - 2 and i != 0
268
+ else nn.Identity()
269
+ ),
270
+ activation if i != len(layers_dim) - 2 else nn.Identity(),
271
+ )
272
+ for i in range(len(layers_dim) - 1)
273
+ ]
274
+ )
275
+
276
+ def forward(self, x):
277
+ out = x
278
+ for layer in self.layers:
279
+ out = layer(out)
280
+ return out
281
+
282
+
283
+
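As the docstring says, the output is a grid of logits rather than a single scalar; a quick shape check with the default layer sizes and a random input:

    _disc = Discriminator(im_channels=3)
    _logits = _disc(torch.randn(2, 3, 256, 256))
    print(_logits.shape)   # torch.Size([2, 1, 31, 31]) -- one logit per image patch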
284
+ # ==================================================================
285
+ # D O W N - B L O C K
286
+ # ==================================================================
287
+ class DownBlock(nn.Module):
288
+ r"""
289
+ Down conv block with attention.
290
+ Sequence of the following blocks
291
+ 1. Resnet block with time embedding
292
+ 2. Attention block
293
+ 3. Downsample
294
+ """
295
+
296
+ def __init__(
297
+ self,
298
+ in_channels,
299
+ out_channels,
300
+ t_emb_dim,
301
+ down_sample,
302
+ num_heads,
303
+ num_layers,
304
+ attn,
305
+ norm_channels,
306
+ cross_attn=False,
307
+ context_dim=None,
308
+ ):
309
+ super().__init__()
310
+ self.num_layers = num_layers
311
+ self.down_sample = down_sample
312
+ self.attn = attn
313
+ self.context_dim = context_dim
314
+ self.cross_attn = cross_attn
315
+ self.t_emb_dim = t_emb_dim
316
+ self.resnet_conv_first = nn.ModuleList(
317
+ [
318
+ nn.Sequential(
319
+ nn.GroupNorm(norm_channels, in_channels if i == 0 else out_channels),
320
+ nn.SiLU(),
321
+ nn.Conv2d(
322
+ in_channels if i == 0 else out_channels,
323
+ out_channels,
324
+ kernel_size=3,
325
+ stride=1,
326
+ padding=1,
327
+ ),
328
+ )
329
+ for i in range(num_layers)
330
+ ]
331
+ )
332
+ if self.t_emb_dim is not None:
333
+ self.t_emb_layers = nn.ModuleList(
334
+ [
335
+ nn.Sequential(nn.SiLU(), nn.Linear(self.t_emb_dim, out_channels))
336
+ for _ in range(num_layers)
337
+ ]
338
+ )
339
+ self.resnet_conv_second = nn.ModuleList(
340
+ [
341
+ nn.Sequential(
342
+ nn.GroupNorm(norm_channels, out_channels),
343
+ nn.SiLU(),
344
+ nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
345
+ )
346
+ for _ in range(num_layers)
347
+ ]
348
+ )
349
+
350
+ if self.attn:
351
+ self.attention_norms = nn.ModuleList(
352
+ [nn.GroupNorm(norm_channels, out_channels) for _ in range(num_layers)]
353
+ )
354
+
355
+ self.attentions = nn.ModuleList(
356
+ [
357
+ nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
358
+ for _ in range(num_layers)
359
+ ]
360
+ )
361
+ if self.cross_attn:
362
+ assert context_dim is not None, "Context Dimension must be passed for cross attention"
363
+ self.cross_attention_norms = nn.ModuleList(
364
+ [nn.GroupNorm(norm_channels, out_channels) for _ in range(num_layers)]
365
+ )
366
+ self.cross_attentions = nn.ModuleList(
367
+ [
368
+ nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
369
+ for _ in range(num_layers)
370
+ ]
371
+ )
372
+ self.context_proj = nn.ModuleList(
373
+ [nn.Linear(context_dim, out_channels) for _ in range(num_layers)]
374
+ )
375
+ self.residual_input_conv = nn.ModuleList(
376
+ [
377
+ nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=1)
378
+ for i in range(num_layers)
379
+ ]
380
+ )
381
+ self.down_sample_conv = (
382
+ nn.Conv2d(out_channels, out_channels, 4, 2, 1) if self.down_sample else nn.Identity()
383
+ )
384
+
385
+ def forward(self, x, t_emb=None, context=None):
386
+ out = x
387
+ for i in range(self.num_layers):
388
+ # Resnet block of Unet
389
+
390
+ resnet_input = out
391
+ out = self.resnet_conv_first[i](out)
392
+ if self.t_emb_dim is not None:
393
+ out = out + self.t_emb_layers[i](t_emb)[:, :, None, None]
394
+ out = self.resnet_conv_second[i](out)
395
+ out = out + self.residual_input_conv[i](resnet_input)
396
+
397
+ if self.attn:
398
+ # Attention block of Unet
399
+
400
+ batch_size, channels, h, w = out.shape
401
+ in_attn = out.reshape(batch_size, channels, h * w)
402
+ in_attn = self.attention_norms[i](in_attn)
403
+ in_attn = in_attn.transpose(1, 2)
404
+ out_attn, _ = self.attentions[i](in_attn, in_attn, in_attn)
405
+ out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
406
+ out = out + out_attn
407
+ if self.cross_attn:
408
+ assert (
409
+ context is not None
410
+ ), "context cannot be None if cross attention layers are used"
411
+ batch_size, channels, h, w = out.shape
412
+ in_attn = out.reshape(batch_size, channels, h * w)
413
+ in_attn = self.cross_attention_norms[i](in_attn)
414
+ in_attn = in_attn.transpose(1, 2)
415
+ assert context.shape[0] == x.shape[0] and context.shape[-1] == self.context_dim
416
+ context_proj = self.context_proj[i](context)
417
+ out_attn, _ = self.cross_attentions[i](in_attn, context_proj, context_proj)
418
+ out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
419
+ out = out + out_attn
420
+ # Downsample
421
+
422
+ out = self.down_sample_conv(out)
423
+ return out
424
+
425
+
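A shape check for DownBlock as the VQVAE encoder uses it (no time embedding, self-attention enabled). The channel and size values are arbitrary, chosen only to satisfy the GroupNorm/attention divisibility constraints:

    _blk = DownBlock(64, 128, t_emb_dim=None, down_sample=True,
                     num_heads=4, num_layers=2, attn=True, norm_channels=32)
    _x = torch.randn(2, 64, 64, 64)
    print(_blk(_x).shape)   # torch.Size([2, 128, 32, 32])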
426
+
427
+ # ==================================================================
428
+ # M I D - B L O C K
429
+ # ==================================================================
430
+ class MidBlock(nn.Module):
431
+ r"""
432
+ Mid conv block with attention.
433
+ Sequence of following blocks
434
+ 1. Resnet block with time embedding
435
+ 2. Attention block
436
+ 3. Resnet block with time embedding
437
+ """
438
+
439
+ def __init__(
440
+ self,
441
+ in_channels,
442
+ out_channels,
443
+ t_emb_dim,
444
+ num_heads,
445
+ num_layers,
446
+ norm_channels,
447
+ cross_attn=None,
448
+ context_dim=None,
449
+ ):
450
+ super().__init__()
451
+ self.num_layers = num_layers
452
+ self.t_emb_dim = t_emb_dim
453
+ self.context_dim = context_dim
454
+ self.cross_attn = cross_attn
455
+ self.resnet_conv_first = nn.ModuleList(
456
+ [
457
+ nn.Sequential(
458
+ nn.GroupNorm(norm_channels, in_channels if i == 0 else out_channels),
459
+ nn.SiLU(),
460
+ nn.Conv2d(
461
+ in_channels if i == 0 else out_channels,
462
+ out_channels,
463
+ kernel_size=3,
464
+ stride=1,
465
+ padding=1,
466
+ ),
467
+ )
468
+ for i in range(num_layers + 1)
469
+ ]
470
+ )
471
+
472
+ if self.t_emb_dim is not None:
473
+ self.t_emb_layers = nn.ModuleList(
474
+ [
475
+ nn.Sequential(nn.SiLU(), nn.Linear(t_emb_dim, out_channels))
476
+ for _ in range(num_layers + 1)
477
+ ]
478
+ )
479
+ self.resnet_conv_second = nn.ModuleList(
480
+ [
481
+ nn.Sequential(
482
+ nn.GroupNorm(norm_channels, out_channels),
483
+ nn.SiLU(),
484
+ nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
485
+ )
486
+ for _ in range(num_layers + 1)
487
+ ]
488
+ )
489
+
490
+ self.attention_norms = nn.ModuleList(
491
+ [nn.GroupNorm(norm_channels, out_channels) for _ in range(num_layers)]
492
+ )
493
+
494
+ self.attentions = nn.ModuleList(
495
+ [
496
+ nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
497
+ for _ in range(num_layers)
498
+ ]
499
+ )
500
+ if self.cross_attn:
501
+ assert context_dim is not None, "Context Dimension must be passed for cross attention"
502
+ self.cross_attention_norms = nn.ModuleList(
503
+ [nn.GroupNorm(norm_channels, out_channels) for _ in range(num_layers)]
504
+ )
505
+ self.cross_attentions = nn.ModuleList(
506
+ [
507
+ nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
508
+ for _ in range(num_layers)
509
+ ]
510
+ )
511
+ self.context_proj = nn.ModuleList(
512
+ [nn.Linear(context_dim, out_channels) for _ in range(num_layers)]
513
+ )
514
+ self.residual_input_conv = nn.ModuleList(
515
+ [
516
+ nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=1)
517
+ for i in range(num_layers + 1)
518
+ ]
519
+ )
520
+
521
+ def forward(self, x, t_emb=None, context=None):
522
+ out = x
523
+
524
+ # First resnet block
525
+
526
+ resnet_input = out
527
+ out = self.resnet_conv_first[0](out)
528
+ if self.t_emb_dim is not None:
529
+ out = out + self.t_emb_layers[0](t_emb)[:, :, None, None]
530
+ out = self.resnet_conv_second[0](out)
531
+ out = out + self.residual_input_conv[0](resnet_input)
532
+
533
+ for i in range(self.num_layers):
534
+ # Attention Block
535
+
536
+ batch_size, channels, h, w = out.shape
537
+ in_attn = out.reshape(batch_size, channels, h * w)
538
+ in_attn = self.attention_norms[i](in_attn)
539
+ in_attn = in_attn.transpose(1, 2)
540
+ out_attn, _ = self.attentions[i](in_attn, in_attn, in_attn)
541
+ out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
542
+ out = out + out_attn
543
+
544
+ if self.cross_attn:
545
+ assert (
546
+ context is not None
547
+ ), "context cannot be None if cross attention layers are used"
548
+ batch_size, channels, h, w = out.shape
549
+ in_attn = out.reshape(batch_size, channels, h * w)
550
+ in_attn = self.cross_attention_norms[i](in_attn)
551
+ in_attn = in_attn.transpose(1, 2)
552
+ assert context.shape[0] == x.shape[0] and context.shape[-1] == self.context_dim
553
+ context_proj = self.context_proj[i](context)
554
+ out_attn, _ = self.cross_attentions[i](in_attn, context_proj, context_proj)
555
+ out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
556
+ out = out + out_attn
557
+ # Resnet Block
558
+
559
+ resnet_input = out
560
+ out = self.resnet_conv_first[i + 1](out)
561
+ if self.t_emb_dim is not None:
562
+ out = out + self.t_emb_layers[i + 1](t_emb)[:, :, None, None]
563
+ out = self.resnet_conv_second[i + 1](out)
564
+ out = out + self.residual_input_conv[i + 1](resnet_input)
565
+ return out
566
+
567
+
568
+ # ==================================================================
569
+ # U P - B L O C K
570
+ # ==================================================================
571
+ class UpBlock(nn.Module):
572
+ r"""
573
+ Up conv block with attention.
574
+ Sequence of following blocks
575
+ 1. Upsample
576
+ 2. Concatenate Down block output
578
+ 3. Resnet block with time embedding
579
+ 4. Attention Block
579
+ """
580
+
581
+ def __init__(
582
+ self,
583
+ in_channels,
584
+ out_channels,
585
+ t_emb_dim,
586
+ up_sample,
587
+ num_heads,
588
+ num_layers,
589
+ attn,
590
+ norm_channels,
591
+ ):
592
+ super().__init__()
593
+ self.num_layers = num_layers
594
+ self.up_sample = up_sample
595
+ self.t_emb_dim = t_emb_dim
596
+ self.attn = attn
597
+ self.resnet_conv_first = nn.ModuleList(
598
+ [
599
+ nn.Sequential(
600
+ nn.GroupNorm(norm_channels, in_channels if i == 0 else out_channels),
601
+ nn.SiLU(),
602
+ nn.Conv2d(
603
+ in_channels if i == 0 else out_channels,
604
+ out_channels,
605
+ kernel_size=3,
606
+ stride=1,
607
+ padding=1,
608
+ ),
609
+ )
610
+ for i in range(num_layers)
611
+ ]
612
+ )
613
+
614
+ if self.t_emb_dim is not None:
615
+ self.t_emb_layers = nn.ModuleList(
616
+ [
617
+ nn.Sequential(nn.SiLU(), nn.Linear(t_emb_dim, out_channels))
618
+ for _ in range(num_layers)
619
+ ]
620
+ )
621
+ self.resnet_conv_second = nn.ModuleList(
622
+ [
623
+ nn.Sequential(
624
+ nn.GroupNorm(norm_channels, out_channels),
625
+ nn.SiLU(),
626
+ nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
627
+ )
628
+ for _ in range(num_layers)
629
+ ]
630
+ )
631
+ if self.attn:
632
+ self.attention_norms = nn.ModuleList(
633
+ [nn.GroupNorm(norm_channels, out_channels) for _ in range(num_layers)]
634
+ )
635
+
636
+ self.attentions = nn.ModuleList(
637
+ [
638
+ nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
639
+ for _ in range(num_layers)
640
+ ]
641
+ )
642
+ self.residual_input_conv = nn.ModuleList(
643
+ [
644
+ nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=1)
645
+ for i in range(num_layers)
646
+ ]
647
+ )
648
+ self.up_sample_conv = (
649
+ nn.ConvTranspose2d(in_channels, in_channels, 4, 2, 1)
650
+ if self.up_sample
651
+ else nn.Identity()
652
+ )
653
+
654
+ def forward(self, x, out_down=None, t_emb=None):
655
+ # Upsample
656
+
657
+ x = self.up_sample_conv(x)
658
+
659
+ # Concat with Downblock output
660
+
661
+ if out_down is not None:
662
+ x = torch.cat([x, out_down], dim=1)
663
+ out = x
664
+ for i in range(self.num_layers):
665
+ # Resnet Block
666
+
667
+ resnet_input = out
668
+ out = self.resnet_conv_first[i](out)
669
+ if self.t_emb_dim is not None:
670
+ out = out + self.t_emb_layers[i](t_emb)[:, :, None, None]
671
+ out = self.resnet_conv_second[i](out)
672
+ out = out + self.residual_input_conv[i](resnet_input)
673
+
674
+ # Self Attention
675
+
676
+ if self.attn:
677
+ batch_size, channels, h, w = out.shape
678
+ in_attn = out.reshape(batch_size, channels, h * w)
679
+ in_attn = self.attention_norms[i](in_attn)
680
+ in_attn = in_attn.transpose(1, 2)
681
+ out_attn, _ = self.attentions[i](in_attn, in_attn, in_attn)
682
+ out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
683
+ out = out + out_attn
684
+ return out
685
+
686
+
687
+ # ==================================================================
688
+ # V Q - V A E
689
+ # ==================================================================
690
+ class VQVAE(nn.Module):
691
+ def __init__(self, im_channels, model_config):
692
+ super().__init__()
693
+ self.down_channels = model_config.down_channels
694
+ self.mid_channels = model_config.mid_channels
695
+ self.down_sample = model_config.down_sample
696
+ self.num_down_layers = model_config.num_down_layers
697
+ self.num_mid_layers = model_config.num_mid_layers
698
+ self.num_up_layers = model_config.num_up_layers
699
+
700
+ # To disable attention in Downblock of Encoder and Upblock of Decoder
701
+ self.attns = model_config.attn_down
702
+
703
+ # Latent Dimension
704
+ self.z_channels = model_config.z_channels
705
+ self.codebook_size = model_config.codebook_size
706
+ self.norm_channels = model_config.norm_channels
707
+ self.num_heads = model_config.num_heads
708
+
709
+ # Assertion to validate the channel information
710
+ assert self.mid_channels[0] == self.down_channels[-1]
711
+ assert self.mid_channels[-1] == self.down_channels[-1]
712
+ assert len(self.down_sample) == len(self.down_channels) - 1
713
+ assert len(self.attns) == len(self.down_channels) - 1
714
+
715
+ # Wherever we use downsampling in encoder correspondingly use
716
+ # upsampling in decoder
717
+ self.up_sample = list(reversed(self.down_sample))
718
+
719
+ ##################### Encoder ######################
720
+ self.encoder_conv_in = nn.Conv2d(
721
+ im_channels, self.down_channels[0], kernel_size=3, padding=(1, 1)
722
+ )
723
+
724
+ # Downblock + Midblock
725
+ self.encoder_layers = nn.ModuleList([])
726
+ for i in range(len(self.down_channels) - 1):
727
+ self.encoder_layers.append(
728
+ DownBlock(
729
+ self.down_channels[i],
730
+ self.down_channels[i + 1],
731
+ t_emb_dim=None,
732
+ down_sample=self.down_sample[i],
733
+ num_heads=self.num_heads,
734
+ num_layers=self.num_down_layers,
735
+ attn=self.attns[i],
736
+ norm_channels=self.norm_channels,
737
+ )
738
+ )
739
+ self.encoder_mids = nn.ModuleList([])
740
+ for i in range(len(self.mid_channels) - 1):
741
+ self.encoder_mids.append(
742
+ MidBlock(
743
+ self.mid_channels[i],
744
+ self.mid_channels[i + 1],
745
+ t_emb_dim=None,
746
+ num_heads=self.num_heads,
747
+ num_layers=self.num_mid_layers,
748
+ norm_channels=self.norm_channels,
749
+ )
750
+ )
751
+ self.encoder_norm_out = nn.GroupNorm(self.norm_channels, self.down_channels[-1])
752
+ self.encoder_conv_out = nn.Conv2d(
753
+ self.down_channels[-1], self.z_channels, kernel_size=3, padding=1
754
+ )
755
+
756
+ # Pre Quantization Convolution
757
+ self.pre_quant_conv = nn.Conv2d(self.z_channels, self.z_channels, kernel_size=1)
758
+
759
+ # Codebook
760
+ self.embedding = nn.Embedding(self.codebook_size, self.z_channels)
761
+ ####################################################
762
+
763
+ ##################### Decoder ######################
764
+
765
+ # Post Quantization Convolution
766
+ self.post_quant_conv = nn.Conv2d(self.z_channels, self.z_channels, kernel_size=1)
767
+ self.decoder_conv_in = nn.Conv2d(
768
+ self.z_channels, self.mid_channels[-1], kernel_size=3, padding=(1, 1)
769
+ )
770
+
771
+ # Midblock + Upblock
772
+ self.decoder_mids = nn.ModuleList([])
773
+ for i in reversed(range(1, len(self.mid_channels))):
774
+ self.decoder_mids.append(
775
+ MidBlock(
776
+ self.mid_channels[i],
777
+ self.mid_channels[i - 1],
778
+ t_emb_dim=None,
779
+ num_heads=self.num_heads,
780
+ num_layers=self.num_mid_layers,
781
+ norm_channels=self.norm_channels,
782
+ )
783
+ )
784
+ self.decoder_layers = nn.ModuleList([])
785
+ for i in reversed(range(1, len(self.down_channels))):
786
+ self.decoder_layers.append(
787
+ UpBlock(
788
+ self.down_channels[i],
789
+ self.down_channels[i - 1],
790
+ t_emb_dim=None,
791
+ up_sample=self.down_sample[i - 1],
792
+ num_heads=self.num_heads,
793
+ num_layers=self.num_up_layers,
794
+ attn=self.attns[i - 1],
795
+ norm_channels=self.norm_channels,
796
+ )
797
+ )
798
+ self.decoder_norm_out = nn.GroupNorm(self.norm_channels, self.down_channels[0])
799
+ self.decoder_conv_out = nn.Conv2d(
800
+ self.down_channels[0], im_channels, kernel_size=3, padding=1
801
+ )
802
+
803
+ def quantize(self, x):
804
+ B, C, H, W = x.shape
805
+
806
+ # B, C, H, W -> B, H, W, C
807
+ x = x.permute(0, 2, 3, 1)
808
+
809
+ # B, H, W, C -> B, H*W, C
810
+ x = x.reshape(x.size(0), -1, x.size(-1))
811
+
812
+ # Find nearest embedding/codebook vector
813
+ # dist between (B, H*W, C) and (B, K, C) -> (B, H*W, K)
814
+ dist = torch.cdist(x, self.embedding.weight[None, :].repeat((x.size(0), 1, 1)))
815
+ # (B, H*W)
816
+ min_encoding_indices = torch.argmin(dist, dim=-1)
817
+
818
+ # Replace encoder output with nearest codebook
819
+ # quant_out -> B*H*W, C
820
+ quant_out = torch.index_select(self.embedding.weight, 0, min_encoding_indices.view(-1))
821
+
822
+ # x -> B*H*W, C
823
+ x = x.reshape((-1, x.size(-1)))
824
+ commitment_loss = torch.mean((quant_out.detach() - x) ** 2)
825
+ codebook_loss = torch.mean((quant_out - x.detach()) ** 2)
826
+ quantize_losses = {"codebook_loss": codebook_loss, "commitment_loss": commitment_loss}
827
+ # Straight through estimation
828
+ quant_out = x + (quant_out - x).detach()
829
+
830
+ # quant_out -> B, C, H, W
831
+ quant_out = quant_out.reshape((B, H, W, C)).permute(0, 3, 1, 2)
832
+ min_encoding_indices = min_encoding_indices.reshape(
833
+ (-1, quant_out.size(-2), quant_out.size(-1))
834
+ )
835
+ return quant_out, quantize_losses, min_encoding_indices
836
+
837
+ def encode(self, x):
838
+ out = self.encoder_conv_in(x)
839
+ for idx, down in enumerate(self.encoder_layers):
840
+ out = down(out)
841
+ for mid in self.encoder_mids:
842
+ out = mid(out)
843
+ out = self.encoder_norm_out(out)
844
+ out = nn.SiLU()(out)
845
+ out = self.encoder_conv_out(out)
846
+ out = self.pre_quant_conv(out)
847
+ out, quant_losses, _ = self.quantize(out)
848
+ return out, quant_losses
849
+
850
+ def decode(self, z):
851
+ out = z
852
+ out = self.post_quant_conv(out)
853
+ out = self.decoder_conv_in(out)
854
+ for mid in self.decoder_mids:
855
+ out = mid(out)
856
+ for idx, up in enumerate(self.decoder_layers):
857
+ out = up(out)
858
+ out = self.decoder_norm_out(out)
859
+ out = nn.SiLU()(out)
860
+ out = self.decoder_conv_out(out)
861
+ return out
862
+
863
+ def forward(self, x):
864
+ '''out: [B, 3, 256, 256]
865
+ z: [B, 3, 64, 64]
866
+ quant_losses: {
867
+ codebook_loss: 0.0681,
868
+ commitment_loss: 0.0681
869
+ }
870
+ '''
871
+ z, quant_losses = self.encode(x)
872
+ out = self.decode(z)
873
+ return out, z, quant_losses
874
+
875
+
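The heart of quantize() above is a nearest-neighbour codebook lookup followed by the straight-through estimator; the same steps on toy tensors (shapes chosen arbitrarily):

    _codebook = nn.Embedding(8, 4)                    # K = 8 codes, C = 4 channels
    _z_e = torch.randn(2, 6, 4, requires_grad=True)   # encoder output as (B, H*W, C)
    _dist = torch.cdist(_z_e, _codebook.weight[None].repeat(2, 1, 1))   # (B, H*W, K)
    _idx = _dist.argmin(dim=-1)                       # nearest code per position
    _z_q = _codebook(_idx)                            # quantized vectors
    _z_q = _z_e + (_z_q - _z_e).detach()              # straight-through: grads reach the encoder
    print(_idx.shape, _z_q.shape)                     # torch.Size([2, 6]) torch.Size([2, 6, 4])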
876
+ # ==================================================================
877
+ # C O N F I G U R A T I O N
878
+ # ==================================================================
879
+ import pprint
880
+ config_path = "/home/IITB/ai-at-ieor/23m1521/ashish/MTP/Vaani/LDM/scripts/config.yaml"
881
+ with open(config_path, 'r') as file:
882
+ Config = yaml.safe_load(file)
883
+ pprint.pprint(Config, width=120)
884
+
885
+ Config = DotDict.from_dict(Config)
886
+ dataset_config = Config.dataset_params
887
+ diffusion_config = Config.diffusion_params
888
+ model_config = Config.model_params
889
+ train_config = Config.train_params
890
+ paths = Config.paths
891
+
892
+
893
+ # ==================================================================
894
+ # V A A N I - D A T A S E T
895
+ # ==================================================================
896
+ IMAGES_PATH = paths.images_dir
897
+
898
+ def walkDIR(folder_path, include=None):
899
+ file_list = []
900
+ for root, _, files in os.walk(folder_path):
901
+ for file in files:
902
+ if include is None or any(file.endswith(ext) for ext in include):
903
+ file_list.append(os.path.join(root, file))
904
+ print("Files found:", len(file_list))
905
+ return file_list
906
+
907
+ files = walkDIR(IMAGES_PATH, include=['.png', '.jpeg', '.jpg'])
908
+ df = pd.DataFrame(files, columns=['image_path'])
909
+
910
+ class VaaniDataset(torch.utils.data.Dataset):
911
+ def __init__(self, files_paths, im_size):
912
+ self.files_paths = files_paths
913
+ self.im_size = im_size
914
+
915
+ def __len__(self):
916
+ return len(self.files_paths)
917
+
918
+ def __getitem__(self, idx):
919
+ image = tv.io.decode_image(self.files_paths[idx], mode='RGB')
920
+ image = v2.Resize((self.im_size,self.im_size))(image)
921
+ image = v2.ToDtype(torch.float32, scale=True)(image)
922
+ # image = 2*image - 1
923
+ return image
924
+
925
+ dataset = VaaniDataset(files_paths=files, im_size=dataset_config.im_size)
926
+ image = dataset[2]
927
+ print('IMAGE SHAPE:', image.shape)
928
+
929
+ dataloader = torch.utils.data.DataLoader(
930
+ dataset,
931
+ batch_size=train_config.autoencoder_batch_size,
932
+ shuffle=True,
933
+ num_workers=os.cpu_count(),
934
+ pin_memory=False,
935
+ drop_last=True,
936
+ persistent_workers=True
937
+ )
938
+
939
+ images = next(iter(dataloader))
940
+ print('BATCH SHAPE:', images.shape)
941
+
942
+
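Note that VaaniDataset returns images in [0, 1] (the 2*image - 1 line is commented out), whereas the LPIPS module above assumes [-1, 1] unless normalize=True is passed; if signed inputs are intended, the rescaling is simply:

    _to_signed = lambda x: 2 * x - 1      # [0, 1] -> [-1, 1]
    _to_unit   = lambda x: (x + 1) / 2    # [-1, 1] -> [0, 1]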
943
+ # ==================================================================
944
+ # M O D E L - I N I T I A L I Z A T I O N
945
+ # ==================================================================
946
+ dataset_config = Config.dataset_params
947
+ autoencoder_config = Config.autoencoder_params
948
+ train_config = Config.train_params
949
+
950
+ model = VQVAE(im_channels=dataset_config.im_channels,
951
+ model_config=autoencoder_config).to(device)
952
+
953
+ # model_output = model(images)
954
+ # print('MODEL OUTPUT:')
955
+ # print(model_output[0].shape, model_output[1].shape, model_output[2])
956
+
957
+
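The commented-out check above can be run as follows; a minimal sketch using the model and the images batch already built in this script:

# Sanity check mirroring the commented-out lines above.
with torch.no_grad():
    recon, z, quantize_losses = model(images.to(device))
print('RECON :', recon.shape)   # expected [B, 3, 128, 128] for im_size 128
print('LATENT:', z.shape)       # expected [B, 3, 32, 32]
print('LOSSES:', {k: round(float(v), 4) for k, v in quantize_losses.items()})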
958
+
959
+ # ==================================================================
960
+ # V Q - V A E - T R A I N I N G
961
+ # ==================================================================
962
+ # python your_script.py > training.log 2>&1   (redirect stdout and stderr to the log)
963
+ import time
964
+
965
+ def format_time(t1, t2):
966
+ elapsed_time = t2 - t1
967
+ if elapsed_time < 60:
968
+ return f"{elapsed_time:.2f} seconds"
969
+ elif elapsed_time < 3600:
970
+ minutes = elapsed_time // 60
971
+ seconds = elapsed_time % 60
972
+ return f"{minutes:.0f} minutes {seconds:.2f} seconds"
973
+ elif elapsed_time < 86400:
974
+ hours = elapsed_time // 3600
975
+ remainder = elapsed_time % 3600
976
+ minutes = remainder // 60
977
+ seconds = remainder % 60
978
+ return f"{hours:.0f} hours {minutes:.0f} minutes {seconds:.2f} seconds"
979
+ else:
980
+ days = elapsed_time // 86400
981
+ remainder = elapsed_time % 86400
982
+ hours = remainder // 3600
983
+ remainder = remainder % 3600
984
+ minutes = remainder // 60
985
+ seconds = remainder % 60
986
+ return f"{days:.0f} days {hours:.0f} hours {minutes:.0f} minutes {seconds:.2f} seconds"
987
+
988
+ def save_checkpoint(
989
+ total_steps, epoch, model, discriminator, optimizer_d,
990
+ optimizer_g, metrics, checkpoint_path, logs, total_training_time
991
+ ):
992
+ checkpoint = {
993
+ "total_steps": total_steps,
994
+ "epoch": epoch,
995
+ "model_state_dict": model.state_dict(),
996
+ "discriminator_state_dict": discriminator.state_dict(),
997
+ "optimizer_d_state_dict": optimizer_d.state_dict(),
998
+ "optimizer_g_state_dict": optimizer_g.state_dict(),
999
+ "metrics": metrics,
1000
+ "logs": logs,
1001
+ "total_training_time": total_training_time
1002
+ }
1003
+ torch.save(checkpoint, checkpoint_path)
1004
+ print(f"Checkpoint saved after {total_steps} steps at epoch {epoch}")
1005
+
1006
+ def load_checkpoint(checkpoint_path, model, discriminator, optimizer_d, optimizer_g):
1007
+ if os.path.exists(checkpoint_path):
1008
+ checkpoint = torch.load(checkpoint_path, map_location=device)
1009
+ model.load_state_dict(checkpoint["model_state_dict"])
1010
+ discriminator.load_state_dict(checkpoint["discriminator_state_dict"])
1011
+ optimizer_d.load_state_dict(checkpoint["optimizer_d_state_dict"])
1012
+ optimizer_g.load_state_dict(checkpoint["optimizer_g_state_dict"])
1013
+ total_steps = checkpoint["total_steps"]
1014
+ epoch = checkpoint["epoch"]
1015
+ metrics = checkpoint["metrics"]
1016
+ logs = checkpoint.get("logs", [])
1017
+ total_training_time = checkpoint.get("total_training_time", 0)
1018
+ print(f"Checkpoint loaded. Resuming from epoch {epoch + 1}, step {total_steps}")
1019
+ return total_steps, epoch + 1, metrics, logs, total_training_time
1020
+ else:
1021
+ print("No checkpoint found. Starting from scratch.")
1022
+ return 0, 0, None, [], 0
1023
+
1024
+ def inference(model, dataset, save_path, epoch, device="cuda", sample_size=8):
1025
+ if not os.path.exists(save_path):
1026
+ os.makedirs(save_path)
1027
+
1028
+ image_tensors = []
1029
+ for i in range(sample_size):
1030
+ image_tensors.append(dataset[i].unsqueeze(0))
1031
+
1032
+ image_tensors = torch.cat(image_tensors, dim=0).to(device)
1033
+ with torch.no_grad():
1034
+ outputs, _, _ = model(image_tensors)
1035
+
1036
+ save_input = image_tensors.detach().cpu()
1037
+ save_output = outputs
1038
+
1039
+ grid = make_grid(torch.cat([save_input, save_output], dim=0), nrow=sample_size)
1040
+
1041
+ combined_image = tv.transforms.ToPILImage()(grid)
1042
+ combined_image.save(os.path.join(save_path, f"reconstructed_images_EP-{epoch}_{sample_size}.png"))
1043
+
1044
+ print(f"Reconstructed images saved at: {save_path}")
1045
+
1046
+
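inference can also be used on its own to regenerate reconstructions from a saved checkpoint; a minimal sketch, with paths and key names following save_checkpoint above and trainVAE below:

# Hypothetical standalone use: reload the latest VQ-VAE weights and dump a
# reconstruction grid without re-running the training loop.
ckpt = torch.load(os.path.join(train_config.task_name, 'vqvaq_ckpt.pth'), map_location=device)
model.load_state_dict(ckpt['model_state_dict'])
model.eval()
inference(model, dataset,
          save_path=os.path.join(train_config.task_name, 'vqvae_recon'),
          epoch=ckpt['epoch'], device=device, sample_size=16)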
1047
+ def trainVAE(Config, dataloader):
1048
+ dataset_config = Config.dataset_params
1049
+ autoencoder_config = Config.autoencoder_params
1050
+ train_config = Config.train_params
1051
+ paths = Config.paths
1052
+
1053
+ seed = train_config.seed
1054
+ torch.manual_seed(seed)
1055
+ np.random.seed(seed)
1056
+ random.seed(seed)
1057
+ if device == "cuda":
1058
+ torch.cuda.manual_seed_all(seed)
1059
+
1060
+ model = VQVAE(im_channels=dataset_config.im_channels, model_config=autoencoder_config).to(device)
1061
+ discriminator = Discriminator(im_channels=dataset_config.im_channels).to(device)
1062
+
1063
+ optimizer_d = torch.optim.AdamW(discriminator.parameters(), lr=train_config.autoencoder_lr, betas=(0.5, 0.999))
1064
+ optimizer_g = torch.optim.AdamW(model.parameters(), lr=train_config.autoencoder_lr, betas=(0.5, 0.999))
1065
+
1066
+ checkpoint_path = os.path.join(train_config.task_name, "vqvaq_ckpt.pth")
1067
+ total_steps, start_epoch, metrics, logs, total_training_time = load_checkpoint(checkpoint_path, model, discriminator, optimizer_d, optimizer_g)
1068
+
1069
+ if not os.path.exists(train_config.task_name):
1070
+ os.mkdir(train_config.task_name)
1071
+
1072
+ num_epochs = train_config.autoencoder_epochs
1073
+ recon_criterion = torch.nn.MSELoss()
1074
+ disc_criterion = torch.nn.MSELoss()
1075
+ lpips_model = LPIPS().eval().to(device)
1076
+
1077
+ acc_steps = train_config.autoencoder_acc_steps
1078
+ disc_step_start = train_config.disc_start
1079
+
1080
+ start_time_total = time.time() - total_training_time
1081
+
1082
+ for epoch_idx in trange(start_epoch, num_epochs):
1083
+ start_time_epoch = time.time()
1084
+ epoch_log = []
1085
+
1086
+ for images in tqdm(dataloader):
1087
+ batch_start_time = time.time()
1088
+ total_steps += 1
1089
+
1090
+ images = images.to(device)
1091
+ model_output = model(images)
1092
+ output, z, quantize_losses = model_output
1093
+
1094
+ recon_loss = recon_criterion(output, images) / acc_steps
1095
+
1096
+ g_loss = (
1097
+ recon_loss
1098
+ + (train_config.codebook_weight * quantize_losses["codebook_loss"] / acc_steps)
1099
+ + (train_config.commitment_beta * quantize_losses["commitment_loss"] / acc_steps)
1100
+ )
1101
+
1102
+ if total_steps > disc_step_start:
1103
+ disc_fake_pred = discriminator(output)
1104
+ disc_fake_loss = disc_criterion(disc_fake_pred, torch.ones_like(disc_fake_pred))
1105
+ g_loss += train_config.disc_weight * disc_fake_loss / acc_steps
1106
+
1107
+ lpips_loss = torch.mean(lpips_model(output, images)) / acc_steps
1108
+ g_loss += train_config.perceptual_weight * lpips_loss
1109
+
1110
+ g_loss.backward()
1111
+
1112
+ if total_steps % acc_steps == 0:
1113
+ optimizer_g.step()
1114
+ optimizer_g.zero_grad()
1115
+
1116
+ if total_steps > disc_step_start:
1117
+ disc_fake_pred = discriminator(output.detach())
1118
+ disc_real_pred = discriminator(images)
1119
+ disc_loss = (disc_criterion(disc_fake_pred, torch.zeros_like(disc_fake_pred)) +
1120
+ disc_criterion(disc_real_pred, torch.ones_like(disc_real_pred))) / 2 / acc_steps
1121
+ disc_loss.backward()
1122
+
1123
+ if total_steps % acc_steps == 0:
1124
+ optimizer_d.step()
1125
+ optimizer_d.zero_grad()
1126
+
1127
+ batch_time = time.time() - batch_start_time
1128
+ epoch_log.append(format_time(0, batch_time))
1129
+
1130
+ epoch_time = time.time() - start_time_epoch
1131
+ logs.append({"epoch": epoch_idx + 1, "epoch_time": format_time(0, epoch_time), "batch_times": epoch_log})
1132
+
1133
+ total_training_time = time.time() - start_time_total
1134
+
1135
+ save_checkpoint(total_steps, epoch_idx + 1, model, discriminator, optimizer_d, optimizer_g, metrics, checkpoint_path, logs, total_training_time)
1136
+ recon_save_path = os.path.join(train_config.task_name, 'vqvae_recon')
1137
+ inference(model, dataset, recon_save_path, epoch=epoch_idx, device=device, sample_size=16)
1138
+
1139
+ print("Training completed.")
1140
+
1141
+
1142
+
1143
+
1144
+ # ==================================================================
1145
+ # S T A R T I N G - T R A I N I N G
1146
+ # ==================================================================
1147
+
1148
+ trainVAE(Config, dataloader)
1149
+
1150
+ # python Vaani-VQVAE-Main.py | tee AE-training.log
1151
+ # python Vaani-VQVAE-Main.py > AE-training.log 2>&1
Vaani/LDM/scripts/VaaniLDM/vqvaq_ckpt-15.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3204e13addde475d8203e0865947f1742ffeef2ecb828cf298a704c660a5964b
3
+ size 88345234
Vaani/LDM/scripts/VaaniLDM/vqvaq_ckpt.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5c8b43abfb2f4362a48ffd111535aaf45ef239a08496838b79f8855f95d291bc
3
+ size 93659794
Vaani/LDM/scripts/_1_Lpips.py ADDED
@@ -0,0 +1,56 @@
1
+ # ==================================================================
2
+ # LEARNED PERCEPTUAL IMAGE PATCH SIMILARITY ( L P I P S )
3
+ # ==================================================================
4
+ # Author : Ashish Kumar Uchadiya
5
+ # Created : January 18, 2025
6
+ # Description: LPIPS essentially computes the similarity between the
7
+ # activations of two image patches for some pre-defined network.
8
+ # This measure has been shown to match human perception well.
9
+ # A low LPIPS score means that the image patches are perceptually similar.
10
+ # ==================================================================
11
+
12
+
13
+
14
+ import torch
+ import torchvision
+ from collections import namedtuple
+
+ class vgg16(torch.nn.Module):
15
+ def __init__(self, requires_grad=False, pretrained=True):
16
+ super(vgg16, self).__init__()
17
+ vgg_pretrained_features = torchvision.models.vgg16(
18
+ weights=torchvision.models.VGG16_Weights.IMAGENET1K_V1
19
+ ).features
20
+ self.slice1 = torch.nn.Sequential()
21
+ self.slice2 = torch.nn.Sequential()
22
+ self.slice3 = torch.nn.Sequential()
23
+ self.slice4 = torch.nn.Sequential()
24
+ self.slice5 = torch.nn.Sequential()
25
+ self.N_slices = 5
26
+ for x in range(4):
27
+ self.slice1.add_module(str(x), vgg_pretrained_features[x])
28
+ for x in range(4, 9):
29
+ self.slice2.add_module(str(x), vgg_pretrained_features[x])
30
+ for x in range(9, 16):
31
+ self.slice3.add_module(str(x), vgg_pretrained_features[x])
32
+ for x in range(16, 23):
33
+ self.slice4.add_module(str(x), vgg_pretrained_features[x])
34
+ for x in range(23, 30):
35
+ self.slice5.add_module(str(x), vgg_pretrained_features[x])
36
+
37
+ # Freeze vgg model
38
+ if not requires_grad:
39
+ for param in self.parameters():
40
+ param.requires_grad = False
41
+
42
+ def forward(self, X):
43
+ # Return output of vgg features
44
+ h = self.slice1(X)
45
+ h_relu1_2 = h
46
+ h = self.slice2(h)
47
+ h_relu2_2 = h
48
+ h = self.slice3(h)
49
+ h_relu3_3 = h
50
+ h = self.slice4(h)
51
+ h_relu4_3 = h
52
+ h = self.slice5(h)
53
+ h_relu5_3 = h
54
+ vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
55
+ out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
56
+ return out
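
A minimal usage sketch for the feature extractor above; the random input stands in for a normalised RGB image, since LPIPS normally applies ImageNet normalisation before this call:

import torch

# Extract the five relu feature maps that LPIPS compares between two images.
net = vgg16(requires_grad=False, pretrained=True).eval()
x = torch.randn(1, 3, 128, 128)
with torch.no_grad():
    feats = net(x)
print([tuple(f.shape) for f in feats])  # relu1_2 ... relu5_3, progressively downsampled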
Vaani/LDM/scripts/__init__.py ADDED
File without changes
Vaani/LDM/scripts/config.yaml ADDED
@@ -0,0 +1,65 @@
1
+ dataset_params:
2
+ im_channels: 3
3
+ im_size: 128
4
+
5
+ paths:
6
+ images_dir: "/scratch/IITB/ai-at-ieor/23m1521/datasets/Vaani/Images"
7
+ vqvae_recon:
8
+
9
+ diffusion_params:
10
+ num_timesteps: 1000
11
+ beta_start: 0.0015
12
+ beta_end: 0.0195
13
+
14
+ ldm_params:
15
+ down_channels: [ 128, 256, 256, 256 ]
16
+ mid_channels: [ 256, 256 ]
17
+ down_sample: [ False, False, False ]
18
+ attn_down: [ True, True, True ]
19
+ time_emb_dim: 256
20
+ norm_channels: 32
21
+ num_heads: 16
22
+ conv_out_channels: 128
23
+ num_down_layers: 2
24
+ num_mid_layers: 2
25
+ num_up_layers: 2
26
+
27
+ autoencoder_params:
28
+ z_channels: 3
29
+ codebook_size: 20
30
+ down_channels: [ 32, 64, 128 ]
31
+ mid_channels: [ 128, 128 ]
32
+ down_sample: [ True, True ]
33
+ attn_down: [ False, False ]
34
+ norm_channels: 32
35
+ num_heads: 16
36
+ num_down_layers: 4
37
+ num_mid_layers: 4
38
+ num_up_layers: 4
39
+
40
+ train_params:
41
+ seed: 4422
42
+ task_name: 'VaaniLDM'
43
+ ldm_batch_size: 1
44
+ autoencoder_batch_size: 4
45
+ disc_start: 1000
46
+ disc_weight: 0.5
47
+ codebook_weight: 1
48
+ commitment_beta: 0.2
49
+ perceptual_weight: 1
50
+ kl_weight: 0.000005
51
+ ldm_epochs: 10
52
+ autoencoder_epochs: 10
53
+ num_samples: 9
54
+ num_grid_rows: 3
55
+ ldm_lr: 0.00001
56
+ autoencoder_lr: 0.0001
57
+ autoencoder_acc_steps: 1
58
+ autoencoder_img_save_steps: 8
59
+ save_latents: True
60
+ vqvae_latent_dir_name: 'vqvae_latents'
61
+ ldm_ckpt_name: 'ddpm_ckpt.pth'
62
+ vqvae_ckpt_name: 'vqvaq_ckpt.pth'
63
+
64
+ training:
65
+ _continue_: True
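
The autoencoder_params above fix the latent resolution: every True entry in down_sample halves the spatial size. A small sketch of the arithmetic, with values copied from this file:

im_size = 128                 # dataset_params.im_size
down_sample = [True, True]    # autoencoder_params.down_sample
z_channels = 3                # autoencoder_params.z_channels

latent_size = im_size // (2 ** sum(down_sample))
print(latent_size)            # 32 -> latents are [z_channels, 32, 32]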
Vaani/LDM/scripts/dotdict.py ADDED
@@ -0,0 +1,53 @@
1
+ from typing import Any
2
+ from argparse import Namespace
3
+ import typing
4
+
5
+
6
+ class DotDict(Namespace):
7
+ """A simple class that builds upon `argparse.Namespace`
8
+ in order to make chained attributes possible."""
9
+
10
+ def __init__(self, temp=False, key=None, parent=None) -> None:
11
+ self._temp = temp
12
+ self._key = key
13
+ self._parent = parent
14
+
15
+ def __eq__(self, other):
16
+ if not isinstance(other, DotDict):
17
+ return NotImplemented
18
+ return vars(self) == vars(other)
19
+
20
+ def __getattr__(self, __name: str) -> Any:
21
+ if __name not in self.__dict__ and not self._temp:
22
+ self.__dict__[__name] = DotDict(temp=True, key=__name, parent=self)
23
+ else:
24
+ del self._parent.__dict__[self._key]
25
+ raise AttributeError("No attribute '%s'" % __name)
26
+ return self.__dict__[__name]
27
+
28
+ def __repr__(self) -> str:
29
+ item_keys = [k for k in self.__dict__ if not k.startswith("_")]
30
+
31
+ if len(item_keys) == 0:
32
+ return "DotDict()"
33
+ elif len(item_keys) == 1:
34
+ key = item_keys[0]
35
+ val = self.__dict__[key]
36
+ return "DotDict(%s=%s)" % (key, repr(val))
37
+ else:
38
+ return "DotDict(%s)" % ", ".join(
39
+ "%s=%s" % (key, repr(val)) for key, val in self.__dict__.items()
40
+ )
41
+
42
+ @classmethod
43
+ def from_dict(cls, original: typing.Mapping[str, Any]) -> "DotDict":
44
+ """Create a DotDict from a (possibly nested) dict `original`.
45
+ Warning: this method should not be used on very deeply nested inputs,
46
+ since it's recursively traversing the nested dictionary values.
47
+ """
48
+ dd = DotDict()
49
+ for key, value in original.items():
50
+ if isinstance(value, typing.Mapping):
51
+ value = cls.from_dict(value)
52
+ setattr(dd, key, value)
53
+ return dd
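
A short usage example for DotDict.from_dict, mirroring how config.yaml is consumed in Main.py; the sample values are illustrative:

cfg = DotDict.from_dict({'train_params': {'autoencoder_lr': 0.0001, 'seed': 4422}})
print(cfg.train_params.autoencoder_lr)   # 0.0001
print(cfg.train_params.seed)             # 4422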
Vaani/SLURM_test.sh ADDED
@@ -0,0 +1,20 @@
1
+ #!/bin/bash -x
2
+ #SBATCH -N 1
3
+ #SBATCH --ntasks-per-node=48
4
+ #SBATCH --mem 128G
5
+ #SBATCH -t 01:00:00
6
+ #SBATCH -J ASHISH_test_cpu
7
+ #SBATCH -o %j.out # name of stdout output file(--output)
8
+ #SBATCH -e %j.err # name of stderr error file(--error)
9
+ cd $SLURM_WORKDIR
10
+
11
+ module purge
12
+ module load miniconda # load the module and environment
13
+ source /home/apps/miniconda3/etc/profile.d/conda.sh
14
+ conda env list
15
+ conda activate aku_env # load working environment
16
+
17
+ python /home/IITB/ai-at-ieor/23m1521/ashish/MTP/Vaani/image_data_metadata.py # run python script
18
+
19
+ conda deactivate # deactivate environment
20
+ # end of script
Vaani/VQVAE_architecture.svg ADDED
Vaani/VQVAE_summary.txt ADDED
@@ -0,0 +1,438 @@
1
+ TIME: 2025-05-09 21:58:45.534412
2
+ DEVICE: cuda
3
+ {'autoencoder_params': {'attn_down': [False, False],
4
+ 'codebook_size': 20,
5
+ 'down_channels': [32, 64, 128],
6
+ 'down_sample': [True, True],
7
+ 'mid_channels': [128, 128],
8
+ 'norm_channels': 32,
9
+ 'num_down_layers': 4,
10
+ 'num_heads': 16,
11
+ 'num_mid_layers': 4,
12
+ 'num_up_layers': 4,
13
+ 'z_channels': 3},
14
+ 'dataset_params': {'im_channels': 3, 'im_size': 128},
15
+ 'diffusion_params': {'beta_end': 0.0195, 'beta_start': 0.0015, 'num_timesteps': 1000},
16
+ 'ldm_params': {'attn_down': [True, True, True],
17
+ 'conv_out_channels': 128,
18
+ 'down_channels': [128, 256, 256, 256],
19
+ 'down_sample': [False, False, False],
20
+ 'mid_channels': [256, 256],
21
+ 'norm_channels': 32,
22
+ 'num_down_layers': 2,
23
+ 'num_heads': 16,
24
+ 'num_mid_layers': 2,
25
+ 'num_up_layers': 2,
26
+ 'time_emb_dim': 256},
27
+ 'paths': {'images_dir': '/scratch/IITB/ai-at-ieor/23m1521/datasets/Vaani/Images'},
28
+ 'train_params': {'autoencoder_acc_steps': 1,
29
+ 'autoencoder_batch_size': 8,
30
+ 'autoencoder_epochs': 30,
31
+ 'autoencoder_img_save_steps': 8,
32
+ 'autoencoder_lr': 0.0001,
33
+ 'codebook_weight': 1,
34
+ 'commitment_beta': 0.2,
35
+ 'disc_start': 1000,
36
+ 'disc_weight': 0.5,
37
+ 'kl_weight': 5e-06,
38
+ 'ldm_batch_size': 1,
39
+ 'ldm_ckpt_name': 'ddpm_ckpt.pth',
40
+ 'ldm_epochs': 10,
41
+ 'ldm_lr': 1e-05,
42
+ 'num_grid_rows': 3,
43
+ 'num_samples': 9,
44
+ 'perceptual_weight': 1,
45
+ 'save_latents': True,
46
+ 'seed': 4422,
47
+ 'task_name': 'VaaniLDM',
48
+ 'vqvae_ckpt_name': 'vqvaq_ckpt.pth',
49
+ 'vqvae_latent_dir_name': 'vqvae_latents'},
50
+ 'training': {'_continue_': True}}
51
+
52
+
53
+ Files found: 128807
54
+ IMAGE SHAPE: torch.Size([3, 128, 128])
55
+ BATCH SHAPE: torch.Size([8, 3, 128, 128])
56
+
57
+
58
+ ======================================================================================================================================================
59
+ Layer (type (var_name)) Input Shape Output Shape Param # Trainable Param %
60
+ ======================================================================================================================================================
61
+ VQVAE (VQVAE) [8, 3, 128, 128] [8, 3, 128, 128] 60 True 0.00%
62
+ ├─Conv2d (encoder_conv_in) [8, 3, 128, 128] [8, 32, 128, 128] 896 True 0.01%
63
+ ├─ModuleList (encoder_layers) -- -- -- True --
64
+ │ └─DownBlock (0) [8, 32, 128, 128] [8, 64, 64, 64] -- True --
65
+ │ │ └─ModuleList (resnet_conv_first) -- -- (recursive) True (recursive)
66
+ │ │ │ └─Sequential (0) [8, 32, 128, 128] [8, 64, 128, 128] -- True --
67
+ │ │ │ │ └─GroupNorm (0) [8, 32, 128, 128] [8, 32, 128, 128] 64 True 0.00%
68
+ │ │ │ │ └─SiLU (1) [8, 32, 128, 128] [8, 32, 128, 128] -- -- --
69
+ │ │ │ │ └─Conv2d (2) [8, 32, 128, 128] [8, 64, 128, 128] 18,496 True 0.30%
70
+ │ │ └─ModuleList (resnet_conv_second) -- -- (recursive) True (recursive)
71
+ │ │ │ └─Sequential (0) [8, 64, 128, 128] [8, 64, 128, 128] -- True --
72
+ │ │ │ │ └─GroupNorm (0) [8, 64, 128, 128] [8, 64, 128, 128] 128 True 0.00%
73
+ │ │ │ │ └─SiLU (1) [8, 64, 128, 128] [8, 64, 128, 128] -- -- --
74
+ │ │ │ │ └─Conv2d (2) [8, 64, 128, 128] [8, 64, 128, 128] 36,928 True 0.59%
75
+ │ │ └─ModuleList (residual_input_conv) -- -- (recursive) True (recursive)
76
+ │ │ │ └─Conv2d (0) [8, 32, 128, 128] [8, 64, 128, 128] 2,112 True 0.03%
77
+ │ │ └─ModuleList (resnet_conv_first) -- -- (recursive) True (recursive)
78
+ │ │ │ └─Sequential (1) [8, 64, 128, 128] [8, 64, 128, 128] -- True --
79
+ │ │ │ │ └─GroupNorm (0) [8, 64, 128, 128] [8, 64, 128, 128] 128 True 0.00%
80
+ │ │ │ │ └─SiLU (1) [8, 64, 128, 128] [8, 64, 128, 128] -- -- --
81
+ │ │ │ │ └─Conv2d (2) [8, 64, 128, 128] [8, 64, 128, 128] 36,928 True 0.59%
82
+ │ │ └─ModuleList (resnet_conv_second) -- -- (recursive) True (recursive)
83
+ │ │ │ └─Sequential (1) [8, 64, 128, 128] [8, 64, 128, 128] -- True --
84
+ │ │ │ │ └─GroupNorm (0) [8, 64, 128, 128] [8, 64, 128, 128] 128 True 0.00%
85
+ │ │ │ │ └─SiLU (1) [8, 64, 128, 128] [8, 64, 128, 128] -- -- --
86
+ │ │ │ │ └─Conv2d (2) [8, 64, 128, 128] [8, 64, 128, 128] 36,928 True 0.59%
87
+ │ │ └─ModuleList (residual_input_conv) -- -- (recursive) True (recursive)
88
+ │ │ │ └─Conv2d (1) [8, 64, 128, 128] [8, 64, 128, 128] 4,160 True 0.07%
89
+ │ │ └─ModuleList (resnet_conv_first) -- -- (recursive) True (recursive)
90
+ │ │ │ └─Sequential (2) [8, 64, 128, 128] [8, 64, 128, 128] -- True --
91
+ │ │ │ │ └─GroupNorm (0) [8, 64, 128, 128] [8, 64, 128, 128] 128 True 0.00%
92
+ │ │ │ │ └─SiLU (1) [8, 64, 128, 128] [8, 64, 128, 128] -- -- --
93
+ │ │ │ │ └─Conv2d (2) [8, 64, 128, 128] [8, 64, 128, 128] 36,928 True 0.59%
94
+ │ │ └─ModuleList (resnet_conv_second) -- -- (recursive) True (recursive)
95
+ │ │ │ └─Sequential (2) [8, 64, 128, 128] [8, 64, 128, 128] -- True --
96
+ │ │ │ │ └─GroupNorm (0) [8, 64, 128, 128] [8, 64, 128, 128] 128 True 0.00%
97
+ │ │ │ │ └─SiLU (1) [8, 64, 128, 128] [8, 64, 128, 128] -- -- --
98
+ │ │ │ │ └─Conv2d (2) [8, 64, 128, 128] [8, 64, 128, 128] 36,928 True 0.59%
99
+ │ │ └─ModuleList (residual_input_conv) -- -- (recursive) True (recursive)
100
+ │ │ │ └─Conv2d (2) [8, 64, 128, 128] [8, 64, 128, 128] 4,160 True 0.07%
101
+ │ │ └─ModuleList (resnet_conv_first) -- -- (recursive) True (recursive)
102
+ │ │ │ └─Sequential (3) [8, 64, 128, 128] [8, 64, 128, 128] -- True --
103
+ │ │ │ │ └─GroupNorm (0) [8, 64, 128, 128] [8, 64, 128, 128] 128 True 0.00%
104
+ │ │ │ │ └─SiLU (1) [8, 64, 128, 128] [8, 64, 128, 128] -- -- --
105
+ │ │ │ │ └─Conv2d (2) [8, 64, 128, 128] [8, 64, 128, 128] 36,928 True 0.59%
106
+ │ │ └─ModuleList (resnet_conv_second) -- -- (recursive) True (recursive)
107
+ │ │ │ └─Sequential (3) [8, 64, 128, 128] [8, 64, 128, 128] -- True --
108
+ │ │ │ │ └─GroupNorm (0) [8, 64, 128, 128] [8, 64, 128, 128] 128 True 0.00%
109
+ │ │ │ │ └─SiLU (1) [8, 64, 128, 128] [8, 64, 128, 128] -- -- --
110
+ │ │ │ │ └─Conv2d (2) [8, 64, 128, 128] [8, 64, 128, 128] 36,928 True 0.59%
111
+ │ │ └─ModuleList (residual_input_conv) -- -- (recursive) True (recursive)
112
+ │ │ │ └─Conv2d (3) [8, 64, 128, 128] [8, 64, 128, 128] 4,160 True 0.07%
113
+ │ │ └─Conv2d (down_sample_conv) [8, 64, 128, 128] [8, 64, 64, 64] 65,600 True 1.05%
114
+ │ └─DownBlock (1) [8, 64, 64, 64] [8, 128, 32, 32] -- True --
115
+ │ │ └─ModuleList (resnet_conv_first) -- -- (recursive) True (recursive)
116
+ │ │ │ └─Sequential (0) [8, 64, 64, 64] [8, 128, 64, 64] -- True --
117
+ │ │ │ │ └─GroupNorm (0) [8, 64, 64, 64] [8, 64, 64, 64] 128 True 0.00%
118
+ │ │ │ │ └─SiLU (1) [8, 64, 64, 64] [8, 64, 64, 64] -- -- --
119
+ │ │ │ │ └─Conv2d (2) [8, 64, 64, 64] [8, 128, 64, 64] 73,856 True 1.19%
120
+ │ │ └─ModuleList (resnet_conv_second) -- -- (recursive) True (recursive)
121
+ │ │ │ └─Sequential (0) [8, 128, 64, 64] [8, 128, 64, 64] -- True --
122
+ │ │ │ │ └─GroupNorm (0) [8, 128, 64, 64] [8, 128, 64, 64] 256 True 0.00%
123
+ │ │ │ │ └─SiLU (1) [8, 128, 64, 64] [8, 128, 64, 64] -- -- --
124
+ │ │ │ │ └─Conv2d (2) [8, 128, 64, 64] [8, 128, 64, 64] 147,584 True 2.37%
125
+ │ │ └─ModuleList (residual_input_conv) -- -- (recursive) True (recursive)
126
+ │ │ │ └─Conv2d (0) [8, 64, 64, 64] [8, 128, 64, 64] 8,320 True 0.13%
127
+ │ │ └─ModuleList (resnet_conv_first) -- -- (recursive) True (recursive)
128
+ │ │ │ └─Sequential (1) [8, 128, 64, 64] [8, 128, 64, 64] -- True --
129
+ │ │ │ │ └─GroupNorm (0) [8, 128, 64, 64] [8, 128, 64, 64] 256 True 0.00%
130
+ │ │ │ │ └─SiLU (1) [8, 128, 64, 64] [8, 128, 64, 64] -- -- --
131
+ │ │ │ │ └─Conv2d (2) [8, 128, 64, 64] [8, 128, 64, 64] 147,584 True 2.37%
132
+ │ │ └─ModuleList (resnet_conv_second) -- -- (recursive) True (recursive)
133
+ │ │ │ └─Sequential (1) [8, 128, 64, 64] [8, 128, 64, 64] -- True --
134
+ │ │ │ │ └─GroupNorm (0) [8, 128, 64, 64] [8, 128, 64, 64] 256 True 0.00%
135
+ │ │ │ │ └─SiLU (1) [8, 128, 64, 64] [8, 128, 64, 64] -- -- --
136
+ │ │ │ │ └─Conv2d (2) [8, 128, 64, 64] [8, 128, 64, 64] 147,584 True 2.37%
137
+ │ │ └─ModuleList (residual_input_conv) -- -- (recursive) True (recursive)
138
+ │ │ │ └─Conv2d (1) [8, 128, 64, 64] [8, 128, 64, 64] 16,512 True 0.27%
139
+ │ │ └─ModuleList (resnet_conv_first) -- -- (recursive) True (recursive)
140
+ │ │ │ └─Sequential (2) [8, 128, 64, 64] [8, 128, 64, 64] -- True --
141
+ │ │ │ │ └─GroupNorm (0) [8, 128, 64, 64] [8, 128, 64, 64] 256 True 0.00%
142
+ │ │ │ │ └─SiLU (1) [8, 128, 64, 64] [8, 128, 64, 64] -- -- --
143
+ │ │ │ │ └─Conv2d (2) [8, 128, 64, 64] [8, 128, 64, 64] 147,584 True 2.37%
144
+ │ │ └─ModuleList (resnet_conv_second) -- -- (recursive) True (recursive)
145
+ │ │ │ └─Sequential (2) [8, 128, 64, 64] [8, 128, 64, 64] -- True --
146
+ │ │ │ │ └─GroupNorm (0) [8, 128, 64, 64] [8, 128, 64, 64] 256 True 0.00%
147
+ │ │ │ │ └─SiLU (1) [8, 128, 64, 64] [8, 128, 64, 64] -- -- --
148
+ │ │ │ │ └─Conv2d (2) [8, 128, 64, 64] [8, 128, 64, 64] 147,584 True 2.37%
149
+ │ │ └─ModuleList (residual_input_conv) -- -- (recursive) True (recursive)
150
+ │ │ │ └─Conv2d (2) [8, 128, 64, 64] [8, 128, 64, 64] 16,512 True 0.27%
151
+ │ │ └─ModuleList (resnet_conv_first) -- -- (recursive) True (recursive)
152
+ │ │ │ └─Sequential (3) [8, 128, 64, 64] [8, 128, 64, 64] -- True --
153
+ │ │ │ │ └─GroupNorm (0) [8, 128, 64, 64] [8, 128, 64, 64] 256 True 0.00%
154
+ │ │ │ │ └─SiLU (1) [8, 128, 64, 64] [8, 128, 64, 64] -- -- --
155
+ │ │ │ │ └─Conv2d (2) [8, 128, 64, 64] [8, 128, 64, 64] 147,584 True 2.37%
156
+ │ │ └─ModuleList (resnet_conv_second) -- -- (recursive) True (recursive)
157
+ │ │ │ └─Sequential (3) [8, 128, 64, 64] [8, 128, 64, 64] -- True --
158
+ │ │ │ │ └─GroupNorm (0) [8, 128, 64, 64] [8, 128, 64, 64] 256 True 0.00%
159
+ │ │ │ │ └─SiLU (1) [8, 128, 64, 64] [8, 128, 64, 64] -- -- --
160
+ │ │ │ │ └─Conv2d (2) [8, 128, 64, 64] [8, 128, 64, 64] 147,584 True 2.37%
161
+ │ │ └─ModuleList (residual_input_conv) -- -- (recursive) True (recursive)
162
+ │ │ │ └─Conv2d (3) [8, 128, 64, 64] [8, 128, 64, 64] 16,512 True 0.27%
163
+ │ │ └─Conv2d (down_sample_conv) [8, 128, 64, 64] [8, 128, 32, 32] 262,272 True 4.22%
164
+ ├─ModuleList (encoder_mids) -- -- -- True --
165
+ │ └─MidBlock (0) [8, 128, 32, 32] [8, 128, 32, 32] -- True --
166
+ │ │ └─ModuleList (resnet_conv_first) -- -- (recursive) True (recursive)
167
+ │ │ │ └─Sequential (0) [8, 128, 32, 32] [8, 128, 32, 32] -- True --
168
+ │ │ │ │ └─GroupNorm (0) [8, 128, 32, 32] [8, 128, 32, 32] 256 True 0.00%
169
+ │ │ │ │ └─SiLU (1) [8, 128, 32, 32] [8, 128, 32, 32] -- -- --
170
+ │ │ │ │ └─Conv2d (2) [8, 128, 32, 32] [8, 128, 32, 32] 147,584 True 2.37%
171
+ │ │ └─ModuleList (resnet_conv_second) -- -- (recursive) True (recursive)
172
+ │ │ │ └─Sequential (0) [8, 128, 32, 32] [8, 128, 32, 32] -- True --
173
+ │ │ │ │ └─GroupNorm (0) [8, 128, 32, 32] [8, 128, 32, 32] 256 True 0.00%
174
+ │ │ │ │ └─SiLU (1) [8, 128, 32, 32] [8, 128, 32, 32] -- -- --
175
+ │ │ │ │ └─Conv2d (2) [8, 128, 32, 32] [8, 128, 32, 32] 147,584 True 2.37%
176
+ │ │ └─ModuleList (residual_input_conv) -- -- (recursive) True (recursive)
177
+ │ │ │ └─Conv2d (0) [8, 128, 32, 32] [8, 128, 32, 32] 16,512 True 0.27%
178
+ │ │ └─ModuleList (attention_norms) -- -- (recursive) True (recursive)
179
+ │ │ │ └─GroupNorm (0) [8, 128, 1024] [8, 128, 1024] 256 True 0.00%
180
+ │ │ └─ModuleList (attentions) -- -- (recursive) True (recursive)
181
+ │ │ │ └─MultiheadAttention (0) [8, 1024, 128] [8, 1024, 128] 66,048 True 1.06%
182
+ │ │ └─ModuleList (resnet_conv_first) -- -- (recursive) True (recursive)
183
+ │ │ │ └─Sequential (1) [8, 128, 32, 32] [8, 128, 32, 32] -- True --
184
+ │ │ │ │ └─GroupNorm (0) [8, 128, 32, 32] [8, 128, 32, 32] 256 True 0.00%
185
+ │ │ │ │ └─SiLU (1) [8, 128, 32, 32] [8, 128, 32, 32] -- -- --
186
+ │ │ │ │ └─Conv2d (2) [8, 128, 32, 32] [8, 128, 32, 32] 147,584 True 2.37%
187
+ │ │ └─ModuleList (resnet_conv_second) -- -- (recursive) True (recursive)
188
+ │ │ │ └─Sequential (1) [8, 128, 32, 32] [8, 128, 32, 32] -- True --
189
+ │ │ │ │ └─GroupNorm (0) [8, 128, 32, 32] [8, 128, 32, 32] 256 True 0.00%
190
+ │ │ │ │ └─SiLU (1) [8, 128, 32, 32] [8, 128, 32, 32] -- -- --
191
+ │ │ │ │ └─Conv2d (2) [8, 128, 32, 32] [8, 128, 32, 32] 147,584 True 2.37%
192
+ │ │ └─ModuleList (residual_input_conv) -- -- (recursive) True (recursive)
193
+ │ │ │ └─Conv2d (1) [8, 128, 32, 32] [8, 128, 32, 32] 16,512 True 0.27%
194
+ │ │ └─ModuleList (attention_norms) -- -- (recursive) True (recursive)
195
+ │ │ │ └─GroupNorm (1) [8, 128, 1024] [8, 128, 1024] 256 True 0.00%
196
+ │ │ └─ModuleList (attentions) -- -- (recursive) True (recursive)
197
+ │ │ │ └─MultiheadAttention (1) [8, 1024, 128] [8, 1024, 128] 66,048 True 1.06%
198
+ │ │ └─ModuleList (resnet_conv_first) -- -- (recursive) True (recursive)
199
+ │ │ │ └─Sequential (2) [8, 128, 32, 32] [8, 128, 32, 32] -- True --
200
+ │ │ │ │ └─GroupNorm (0) [8, 128, 32, 32] [8, 128, 32, 32] 256 True 0.00%
201
+ │ │ │ │ └─SiLU (1) [8, 128, 32, 32] [8, 128, 32, 32] -- -- --
202
+ │ │ │ │ └─Conv2d (2) [8, 128, 32, 32] [8, 128, 32, 32] 147,584 True 2.37%
203
+ │ │ └─ModuleList (resnet_conv_second) -- -- (recursive) True (recursive)
204
+ │ │ │ └─Sequential (2) [8, 128, 32, 32] [8, 128, 32, 32] -- True --
205
+ │ │ │ │ └─GroupNorm (0) [8, 128, 32, 32] [8, 128, 32, 32] 256 True 0.00%
206
+ │ │ │ │ └─SiLU (1) [8, 128, 32, 32] [8, 128, 32, 32] -- -- --
207
+ │ │ │ │ └─Conv2d (2) [8, 128, 32, 32] [8, 128, 32, 32] 147,584 True 2.37%
208
+ │ │ └─ModuleList (residual_input_conv) -- -- (recursive) True (recursive)
209
+ │ │ │ └─Conv2d (2) [8, 128, 32, 32] [8, 128, 32, 32] 16,512 True 0.27%
210
+ │ │ └─ModuleList (attention_norms) -- -- (recursive) True (recursive)
211
+ │ │ │ └─GroupNorm (2) [8, 128, 1024] [8, 128, 1024] 256 True 0.00%
212
+ │ │ └─ModuleList (attentions) -- -- (recursive) True (recursive)
213
+ │ │ │ └─MultiheadAttention (2) [8, 1024, 128] [8, 1024, 128] 66,048 True 1.06%
214
+ │ │ └─ModuleList (resnet_conv_first) -- -- (recursive) True (recursive)
215
+ │ │ │ └─Sequential (3) [8, 128, 32, 32] [8, 128, 32, 32] -- True --
216
+ │ │ │ │ └─GroupNorm (0) [8, 128, 32, 32] [8, 128, 32, 32] 256 True 0.00%
217
+ │ │ │ │ └─SiLU (1) [8, 128, 32, 32] [8, 128, 32, 32] -- -- --
218
+ │ │ │ │ └─Conv2d (2) [8, 128, 32, 32] [8, 128, 32, 32] 147,584 True 2.37%
219
+ │ │ └─ModuleList (resnet_conv_second) -- -- (recursive) True (recursive)
220
+ │ │ │ └─Sequential (3) [8, 128, 32, 32] [8, 128, 32, 32] -- True --
221
+ │ │ │ │ └─GroupNorm (0) [8, 128, 32, 32] [8, 128, 32, 32] 256 True 0.00%
222
+ │ │ │ │ └─SiLU (1) [8, 128, 32, 32] [8, 128, 32, 32] -- -- --
223
+ │ │ │ │ └─Conv2d (2) [8, 128, 32, 32] [8, 128, 32, 32] 147,584 True 2.37%
224
+ │ │ └─ModuleList (residual_input_conv) -- -- (recursive) True (recursive)
225
+ │ │ │ └─Conv2d (3) [8, 128, 32, 32] [8, 128, 32, 32] 16,512 True 0.27%
226
+ │ │ └─ModuleList (attention_norms) -- -- (recursive) True (recursive)
227
+ │ │ │ └─GroupNorm (3) [8, 128, 1024] [8, 128, 1024] 256 True 0.00%
228
+ │ │ └─ModuleList (attentions) -- -- (recursive) True (recursive)
229
+ │ │ │ └─MultiheadAttention (3) [8, 1024, 128] [8, 1024, 128] 66,048 True 1.06%
230
+ │ │ └─ModuleList (resnet_conv_first) -- -- (recursive) True (recursive)
231
+ │ │ │ └─Sequential (4) [8, 128, 32, 32] [8, 128, 32, 32] -- True --
232
+ │ │ │ │ └─GroupNorm (0) [8, 128, 32, 32] [8, 128, 32, 32] 256 True 0.00%
233
+ │ │ │ │ └─SiLU (1) [8, 128, 32, 32] [8, 128, 32, 32] -- -- --
234
+ │ │ │ │ └─Conv2d (2) [8, 128, 32, 32] [8, 128, 32, 32] 147,584 True 2.37%
235
+ │ │ └─ModuleList (resnet_conv_second) -- -- (recursive) True (recursive)
236
+ │ │ │ └─Sequential (4) [8, 128, 32, 32] [8, 128, 32, 32] -- True --
237
+ │ │ │ │ └─GroupNorm (0) [8, 128, 32, 32] [8, 128, 32, 32] 256 True 0.00%
238
+ │ │ │ │ └─SiLU (1) [8, 128, 32, 32] [8, 128, 32, 32] -- -- --
239
+ │ │ │ │ └─Conv2d (2) [8, 128, 32, 32] [8, 128, 32, 32] 147,584 True 2.37%
240
+ │ │ └─ModuleList (residual_input_conv) -- -- (recursive) True (recursive)
241
+ │ │ │ └─Conv2d (4) [8, 128, 32, 32] [8, 128, 32, 32] 16,512 True 0.27%
242
+ ├─GroupNorm (encoder_norm_out) [8, 128, 32, 32] [8, 128, 32, 32] 256 True 0.00%
243
+ ├─Conv2d (encoder_conv_out) [8, 128, 32, 32] [8, 3, 32, 32] 3,459 True 0.06%
244
+ ├─Conv2d (pre_quant_conv) [8, 3, 32, 32] [8, 3, 32, 32] 12 True 0.00%
245
+ ├─Conv2d (post_quant_conv) [8, 3, 32, 32] [8, 3, 32, 32] 12 True 0.00%
246
+ ├─Conv2d (decoder_conv_in) [8, 3, 32, 32] [8, 128, 32, 32] 3,584 True 0.06%
247
+ ├─ModuleList (decoder_mids) -- -- -- True --
248
+ │ └─MidBlock (0) [8, 128, 32, 32] [8, 128, 32, 32] -- True --
249
+ │ │ └─ModuleList (resnet_conv_first) -- -- (recursive) True (recursive)
250
+ │ │ │ └─Sequential (0) [8, 128, 32, 32] [8, 128, 32, 32] -- True --
251
+ │ │ │ │ └─GroupNorm (0) [8, 128, 32, 32] [8, 128, 32, 32] 256 True 0.00%
252
+ │ │ │ │ └─SiLU (1) [8, 128, 32, 32] [8, 128, 32, 32] -- -- --
253
+ │ │ │ │ └─Conv2d (2) [8, 128, 32, 32] [8, 128, 32, 32] 147,584 True 2.37%
254
+ │ │ └─ModuleList (resnet_conv_second) -- -- (recursive) True (recursive)
255
+ │ │ │ └─Sequential (0) [8, 128, 32, 32] [8, 128, 32, 32] -- True --
256
+ │ │ │ │ └─GroupNorm (0) [8, 128, 32, 32] [8, 128, 32, 32] 256 True 0.00%
257
+ │ │ │ │ └─SiLU (1) [8, 128, 32, 32] [8, 128, 32, 32] -- -- --
258
+ │ │ │ │ └─Conv2d (2) [8, 128, 32, 32] [8, 128, 32, 32] 147,584 True 2.37%
259
+ │ │ └─ModuleList (residual_input_conv) -- -- (recursive) True (recursive)
260
+ │ │ │ └─Conv2d (0) [8, 128, 32, 32] [8, 128, 32, 32] 16,512 True 0.27%
261
+ │ │ └─ModuleList (attention_norms) -- -- (recursive) True (recursive)
262
+ │ │ │ └─GroupNorm (0) [8, 128, 1024] [8, 128, 1024] 256 True 0.00%
263
+ │ │ └─ModuleList (attentions) -- -- (recursive) True (recursive)
264
+ │ │ │ └─MultiheadAttention (0) [8, 1024, 128] [8, 1024, 128] 66,048 True 1.06%
265
+ │ │ └─ModuleList (resnet_conv_first) -- -- (recursive) True (recursive)
266
+ │ │ │ └─Sequential (1) [8, 128, 32, 32] [8, 128, 32, 32] -- True --
267
+ │ │ │ │ └─GroupNorm (0) [8, 128, 32, 32] [8, 128, 32, 32] 256 True 0.00%
268
+ │ │ │ │ └─SiLU (1) [8, 128, 32, 32] [8, 128, 32, 32] -- -- --
269
+ │ │ │ │ └─Conv2d (2) [8, 128, 32, 32] [8, 128, 32, 32] 147,584 True 2.37%
270
+ │ │ └─ModuleList (resnet_conv_second) -- -- (recursive) True (recursive)
271
+ │ │ │ └─Sequential (1) [8, 128, 32, 32] [8, 128, 32, 32] -- True --
272
+ │ │ │ │ └─GroupNorm (0) [8, 128, 32, 32] [8, 128, 32, 32] 256 True 0.00%
273
+ │ │ │ │ └─SiLU (1) [8, 128, 32, 32] [8, 128, 32, 32] -- -- --
274
+ │ │ │ │ └─Conv2d (2) [8, 128, 32, 32] [8, 128, 32, 32] 147,584 True 2.37%
275
+ │ │ └─ModuleList (residual_input_conv) -- -- (recursive) True (recursive)
276
+ │ │ │ └─Conv2d (1) [8, 128, 32, 32] [8, 128, 32, 32] 16,512 True 0.27%
277
+ │ │ └─ModuleList (attention_norms) -- -- (recursive) True (recursive)
278
+ │ │ │ └─GroupNorm (1) [8, 128, 1024] [8, 128, 1024] 256 True 0.00%
279
+ │ │ └─ModuleList (attentions) -- -- (recursive) True (recursive)
280
+ │ │ │ └─MultiheadAttention (1) [8, 1024, 128] [8, 1024, 128] 66,048 True 1.06%
281
+ │ │ └─ModuleList (resnet_conv_first) -- -- (recursive) True (recursive)
282
+ │ │ │ └─Sequential (2) [8, 128, 32, 32] [8, 128, 32, 32] -- True --
283
+ │ │ │ │ └─GroupNorm (0) [8, 128, 32, 32] [8, 128, 32, 32] 256 True 0.00%
284
+ │ │ │ │ └─SiLU (1) [8, 128, 32, 32] [8, 128, 32, 32] -- -- --
285
+ │ │ │ │ └─Conv2d (2) [8, 128, 32, 32] [8, 128, 32, 32] 147,584 True 2.37%
286
+ │ │ └─ModuleList (resnet_conv_second) -- -- (recursive) True (recursive)
287
+ │ │ │ └─Sequential (2) [8, 128, 32, 32] [8, 128, 32, 32] -- True --
288
+ │ │ │ │ └─GroupNorm (0) [8, 128, 32, 32] [8, 128, 32, 32] 256 True 0.00%
289
+ │ │ │ │ └─SiLU (1) [8, 128, 32, 32] [8, 128, 32, 32] -- -- --
290
+ │ │ │ │ └─Conv2d (2) [8, 128, 32, 32] [8, 128, 32, 32] 147,584 True 2.37%
291
+ │ │ └─ModuleList (residual_input_conv) -- -- (recursive) True (recursive)
292
+ │ │ │ └─Conv2d (2) [8, 128, 32, 32] [8, 128, 32, 32] 16,512 True 0.27%
293
+ │ │ └─ModuleList (attention_norms) -- -- (recursive) True (recursive)
294
+ │ │ │ └─GroupNorm (2) [8, 128, 1024] [8, 128, 1024] 256 True 0.00%
295
+ │ │ └─ModuleList (attentions) -- -- (recursive) True (recursive)
296
+ │ │ │ └─MultiheadAttention (2) [8, 1024, 128] [8, 1024, 128] 66,048 True 1.06%
297
+ │ │ └─ModuleList (resnet_conv_first) -- -- (recursive) True (recursive)
298
+ │ │ │ └─Sequential (3) [8, 128, 32, 32] [8, 128, 32, 32] -- True --
299
+ │ │ │ │ └─GroupNorm (0) [8, 128, 32, 32] [8, 128, 32, 32] 256 True 0.00%
300
+ │ │ │ │ └─SiLU (1) [8, 128, 32, 32] [8, 128, 32, 32] -- -- --
301
+ │ │ │ │ └─Conv2d (2) [8, 128, 32, 32] [8, 128, 32, 32] 147,584 True 2.37%
302
+ │ │ └─ModuleList (resnet_conv_second) -- -- (recursive) True (recursive)
303
+ │ │ │ └─Sequential (3) [8, 128, 32, 32] [8, 128, 32, 32] -- True --
304
+ │ │ │ │ └─GroupNorm (0) [8, 128, 32, 32] [8, 128, 32, 32] 256 True 0.00%
305
+ │ │ │ │ └─SiLU (1) [8, 128, 32, 32] [8, 128, 32, 32] -- -- --
306
+ │ │ │ │ └─Conv2d (2) [8, 128, 32, 32] [8, 128, 32, 32] 147,584 True 2.37%
307
+ │ │ └─ModuleList (residual_input_conv) -- -- (recursive) True (recursive)
308
+ │ │ │ └─Conv2d (3) [8, 128, 32, 32] [8, 128, 32, 32] 16,512 True 0.27%
309
+ │ │ └─ModuleList (attention_norms) -- -- (recursive) True (recursive)
310
+ │ │ │ └─GroupNorm (3) [8, 128, 1024] [8, 128, 1024] 256 True 0.00%
311
+ │ │ └─ModuleList (attentions) -- -- (recursive) True (recursive)
312
+ │ │ │ └─MultiheadAttention (3) [8, 1024, 128] [8, 1024, 128] 66,048 True 1.06%
313
+ │ │ └─ModuleList (resnet_conv_first) -- -- (recursive) True (recursive)
314
+ │ │ │ └─Sequential (4) [8, 128, 32, 32] [8, 128, 32, 32] -- True --
315
+ │ │ │ │ └─GroupNorm (0) [8, 128, 32, 32] [8, 128, 32, 32] 256 True 0.00%
316
+ │ │ │ │ └─SiLU (1) [8, 128, 32, 32] [8, 128, 32, 32] -- -- --
317
+ │ │ │ │ └─Conv2d (2) [8, 128, 32, 32] [8, 128, 32, 32] 147,584 True 2.37%
318
+ │ │ └─ModuleList (resnet_conv_second) -- -- (recursive) True (recursive)
319
+ │ │ │ └─Sequential (4) [8, 128, 32, 32] [8, 128, 32, 32] -- True --
320
+ │ │ │ │ └─GroupNorm (0) [8, 128, 32, 32] [8, 128, 32, 32] 256 True 0.00%
321
+ │ │ │ │ └─SiLU (1) [8, 128, 32, 32] [8, 128, 32, 32] -- -- --
322
+ │ │ │ │ └─Conv2d (2) [8, 128, 32, 32] [8, 128, 32, 32] 147,584 True 2.37%
323
+ │ │ └─ModuleList (residual_input_conv) -- -- (recursive) True (recursive)
324
+ │ │ │ └─Conv2d (4) [8, 128, 32, 32] [8, 128, 32, 32] 16,512 True 0.27%
325
+ ├─ModuleList (decoder_layers) -- -- -- True --
326
+ │ └─UpBlock (0) [8, 128, 32, 32] [8, 64, 64, 64] -- True --
327
+ │ │ └─ConvTranspose2d (up_sample_conv) [8, 128, 32, 32] [8, 128, 64, 64] 262,272 True 4.22%
328
+ │ │ └─ModuleList (resnet_conv_first) -- -- (recursive) True (recursive)
329
+ │ │ │ └─Sequential (0) [8, 128, 64, 64] [8, 64, 64, 64] -- True --
330
+ │ │ │ │ └─GroupNorm (0) [8, 128, 64, 64] [8, 128, 64, 64] 256 True 0.00%
331
+ │ │ │ │ └─SiLU (1) [8, 128, 64, 64] [8, 128, 64, 64] -- -- --
332
+ │ │ │ │ └─Conv2d (2) [8, 128, 64, 64] [8, 64, 64, 64] 73,792 True 1.19%
333
+ │ │ └─ModuleList (resnet_conv_second) -- -- (recursive) True (recursive)
334
+ │ │ │ └─Sequential (0) [8, 64, 64, 64] [8, 64, 64, 64] -- True --
335
+ │ │ │ │ └─GroupNorm (0) [8, 64, 64, 64] [8, 64, 64, 64] 128 True 0.00%
336
+ │ │ │ │ └─SiLU (1) [8, 64, 64, 64] [8, 64, 64, 64] -- -- --
337
+ │ │ │ │ └─Conv2d (2) [8, 64, 64, 64] [8, 64, 64, 64] 36,928 True 0.59%
338
+ │ │ └─ModuleList (residual_input_conv) -- -- (recursive) True (recursive)
339
+ │ │ │ └─Conv2d (0) [8, 128, 64, 64] [8, 64, 64, 64] 8,256 True 0.13%
340
+ │ │ └─ModuleList (resnet_conv_first) -- -- (recursive) True (recursive)
341
+ │ │ │ └─Sequential (1) [8, 64, 64, 64] [8, 64, 64, 64] -- True --
342
+ │ │ │ │ └─GroupNorm (0) [8, 64, 64, 64] [8, 64, 64, 64] 128 True 0.00%
343
+ │ │ │ │ └─SiLU (1) [8, 64, 64, 64] [8, 64, 64, 64] -- -- --
344
+ │ │ │ │ └─Conv2d (2) [8, 64, 64, 64] [8, 64, 64, 64] 36,928 True 0.59%
345
+ │ │ └─ModuleList (resnet_conv_second) -- -- (recursive) True (recursive)
346
+ │ │ │ └─Sequential (1) [8, 64, 64, 64] [8, 64, 64, 64] -- True --
347
+ │ │ │ │ └─GroupNorm (0) [8, 64, 64, 64] [8, 64, 64, 64] 128 True 0.00%
348
+ │ │ │ │ └─SiLU (1) [8, 64, 64, 64] [8, 64, 64, 64] -- -- --
349
+ │ │ │ │ └─Conv2d (2) [8, 64, 64, 64] [8, 64, 64, 64] 36,928 True 0.59%
350
+ │ │ └─ModuleList (residual_input_conv) -- -- (recursive) True (recursive)
351
+ │ │ │ └─Conv2d (1) [8, 64, 64, 64] [8, 64, 64, 64] 4,160 True 0.07%
352
+ │ │ └─ModuleList (resnet_conv_first) -- -- (recursive) True (recursive)
353
+ │ │ │ └─Sequential (2) [8, 64, 64, 64] [8, 64, 64, 64] -- True --
354
+ │ │ │ │ └─GroupNorm (0) [8, 64, 64, 64] [8, 64, 64, 64] 128 True 0.00%
355
+ │ │ │ │ └─SiLU (1) [8, 64, 64, 64] [8, 64, 64, 64] -- -- --
356
+ │ │ │ │ └─Conv2d (2) [8, 64, 64, 64] [8, 64, 64, 64] 36,928 True 0.59%
357
+ │ │ └─ModuleList (resnet_conv_second) -- -- (recursive) True (recursive)
358
+ │ │ │ └─Sequential (2) [8, 64, 64, 64] [8, 64, 64, 64] -- True --
359
+ │ │ │ │ └─GroupNorm (0) [8, 64, 64, 64] [8, 64, 64, 64] 128 True 0.00%
360
+ │ │ │ │ └─SiLU (1) [8, 64, 64, 64] [8, 64, 64, 64] -- -- --
361
+ │ │ │ │ └─Conv2d (2) [8, 64, 64, 64] [8, 64, 64, 64] 36,928 True 0.59%
362
+ │ │ └─ModuleList (residual_input_conv) -- -- (recursive) True (recursive)
363
+ │ │ │ └─Conv2d (2) [8, 64, 64, 64] [8, 64, 64, 64] 4,160 True 0.07%
364
+ │ │ └─ModuleList (resnet_conv_first) -- -- (recursive) True (recursive)
365
+ │ │ │ └─Sequential (3) [8, 64, 64, 64] [8, 64, 64, 64] -- True --
366
+ │ │ │ │ └─GroupNorm (0) [8, 64, 64, 64] [8, 64, 64, 64] 128 True 0.00%
367
+ │ │ │ │ └─SiLU (1) [8, 64, 64, 64] [8, 64, 64, 64] -- -- --
368
+ │ │ │ │ └─Conv2d (2) [8, 64, 64, 64] [8, 64, 64, 64] 36,928 True 0.59%
369
+ │ │ └─ModuleList (resnet_conv_second) -- -- (recursive) True (recursive)
370
+ │ │ │ └─Sequential (3) [8, 64, 64, 64] [8, 64, 64, 64] -- True --
371
+ │ │ │ │ └─GroupNorm (0) [8, 64, 64, 64] [8, 64, 64, 64] 128 True 0.00%
372
+ │ │ │ │ └─SiLU (1) [8, 64, 64, 64] [8, 64, 64, 64] -- -- --
373
+ │ │ │ │ └─Conv2d (2) [8, 64, 64, 64] [8, 64, 64, 64] 36,928 True 0.59%
374
+ │ │ └─ModuleList (residual_input_conv) -- -- (recursive) True (recursive)
375
+ │ │ │ └─Conv2d (3) [8, 64, 64, 64] [8, 64, 64, 64] 4,160 True 0.07%
376
+ │ └─UpBlock (1) [8, 64, 64, 64] [8, 32, 128, 128] -- True --
377
+ │ │ └─ConvTranspose2d (up_sample_conv) [8, 64, 64, 64] [8, 64, 128, 128] 65,600 True 1.05%
378
+ │ │ └─ModuleList (resnet_conv_first) -- -- (recursive) True (recursive)
379
+ │ │ │ └─Sequential (0) [8, 64, 128, 128] [8, 32, 128, 128] -- True --
380
+ │ │ │ │ └─GroupNorm (0) [8, 64, 128, 128] [8, 64, 128, 128] 128 True 0.00%
381
+ │ │ │ │ └─SiLU (1) [8, 64, 128, 128] [8, 64, 128, 128] -- -- --
382
+ │ │ │ │ └─Conv2d (2) [8, 64, 128, 128] [8, 32, 128, 128] 18,464 True 0.30%
383
+ │ │ └─ModuleList (resnet_conv_second) -- -- (recursive) True (recursive)
384
+ │ │ │ └─Sequential (0) [8, 32, 128, 128] [8, 32, 128, 128] -- True --
385
+ │ │ │ │ └─GroupNorm (0) [8, 32, 128, 128] [8, 32, 128, 128] 64 True 0.00%
386
+ │ │ │ │ └─SiLU (1) [8, 32, 128, 128] [8, 32, 128, 128] -- -- --
387
+ │ │ │ │ └─Conv2d (2) [8, 32, 128, 128] [8, 32, 128, 128] 9,248 True 0.15%
388
+ │ │ └─ModuleList (residual_input_conv) -- -- (recursive) True (recursive)
389
+ │ │ │ └─Conv2d (0) [8, 64, 128, 128] [8, 32, 128, 128] 2,080 True 0.03%
390
+ │ │ └─ModuleList (resnet_conv_first) -- -- (recursive) True (recursive)
391
+ │ │ │ └─Sequential (1) [8, 32, 128, 128] [8, 32, 128, 128] -- True --
392
+ │ │ │ │ └─GroupNorm (0) [8, 32, 128, 128] [8, 32, 128, 128] 64 True 0.00%
393
+ │ │ │ │ └─SiLU (1) [8, 32, 128, 128] [8, 32, 128, 128] -- -- --
394
+ │ │ │ │ └─Conv2d (2) [8, 32, 128, 128] [8, 32, 128, 128] 9,248 True 0.15%
395
+ │ │ └─ModuleList (resnet_conv_second) -- -- (recursive) True (recursive)
396
+ │ │ │ └─Sequential (1) [8, 32, 128, 128] [8, 32, 128, 128] -- True --
397
+ │ │ │ │ └─GroupNorm (0) [8, 32, 128, 128] [8, 32, 128, 128] 64 True 0.00%
398
+ │ │ │ │ └─SiLU (1) [8, 32, 128, 128] [8, 32, 128, 128] -- -- --
399
+ │ │ │ │ └─Conv2d (2) [8, 32, 128, 128] [8, 32, 128, 128] 9,248 True 0.15%
400
+ │ │ └─ModuleList (residual_input_conv) -- -- (recursive) True (recursive)
401
+ │ │ │ └─Conv2d (1) [8, 32, 128, 128] [8, 32, 128, 128] 1,056 True 0.02%
402
+ │ │ └─ModuleList (resnet_conv_first) -- -- (recursive) True (recursive)
403
+ │ │ │ └─Sequential (2) [8, 32, 128, 128] [8, 32, 128, 128] -- True --
404
+ │ │ │ │ └─GroupNorm (0) [8, 32, 128, 128] [8, 32, 128, 128] 64 True 0.00%
405
+ │ │ │ │ └─SiLU (1) [8, 32, 128, 128] [8, 32, 128, 128] -- -- --
406
+ │ │ │ │ └─Conv2d (2) [8, 32, 128, 128] [8, 32, 128, 128] 9,248 True 0.15%
407
+ │ │ └─ModuleList (resnet_conv_second) -- -- (recursive) True (recursive)
408
+ │ │ │ └─Sequential (2) [8, 32, 128, 128] [8, 32, 128, 128] -- True --
409
+ │ │ │ │ └─GroupNorm (0) [8, 32, 128, 128] [8, 32, 128, 128] 64 True 0.00%
410
+ │ │ │ │ └─SiLU (1) [8, 32, 128, 128] [8, 32, 128, 128] -- -- --
411
+ │ │ │ │ └─Conv2d (2) [8, 32, 128, 128] [8, 32, 128, 128] 9,248 True 0.15%
412
+ │ │ └─ModuleList (residual_input_conv) -- -- (recursive) True (recursive)
413
+ │ │ │ └─Conv2d (2) [8, 32, 128, 128] [8, 32, 128, 128] 1,056 True 0.02%
414
+ │ │ └─ModuleList (resnet_conv_first) -- -- (recursive) True (recursive)
415
+ │ │ │ └─Sequential (3) [8, 32, 128, 128] [8, 32, 128, 128] -- True --
416
+ │ │ │ │ └─GroupNorm (0) [8, 32, 128, 128] [8, 32, 128, 128] 64 True 0.00%
417
+ │ │ │ │ └─SiLU (1) [8, 32, 128, 128] [8, 32, 128, 128] -- -- --
418
+ │ │ │ │ └─Conv2d (2) [8, 32, 128, 128] [8, 32, 128, 128] 9,248 True 0.15%
419
+ │ │ └─ModuleList (resnet_conv_second) -- -- (recursive) True (recursive)
420
+ │ │ │ └─Sequential (3) [8, 32, 128, 128] [8, 32, 128, 128] -- True --
421
+ │ │ │ │ └─GroupNorm (0) [8, 32, 128, 128] [8, 32, 128, 128] 64 True 0.00%
422
+ │ │ │ │ └─SiLU (1) [8, 32, 128, 128] [8, 32, 128, 128] -- -- --
423
+ │ │ │ │ └─Conv2d (2) [8, 32, 128, 128] [8, 32, 128, 128] 9,248 True 0.15%
424
+ │ │ └─ModuleList (residual_input_conv) -- -- (recursive) True (recursive)
425
+ │ │ │ └─Conv2d (3) [8, 32, 128, 128] [8, 32, 128, 128] 1,056 True 0.02%
426
+ ├─GroupNorm (decoder_norm_out) [8, 32, 128, 128] [8, 32, 128, 128] 64 True 0.00%
427
+ ├─Conv2d (decoder_conv_out) [8, 32, 128, 128] [8, 3, 128, 128] 867 True 0.01%
428
+ ======================================================================================================================================================
429
+ Total params: 6,219,770
430
+ Trainable params: 6,219,770
431
+ Non-trainable params: 0
432
+ Total mult-adds (Units.GIGABYTES): 146.86
433
+ ======================================================================================================================================================
434
+ Input size (MB): 1.57
435
+ Forward/backward pass size (MB): 3719.89
436
+ Params size (MB): 22.77
437
+ Estimated Total Size (MB): 3744.23
438
+ ======================================================================================================================================================
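
This table matches the layout produced by torchinfo with variable names enabled; a hypothetical way to regenerate it, assuming torchinfo is installed and model is the VQVAE built in Main.py:

from torchinfo import summary

summary(
    model,
    input_size=(8, 3, 128, 128),
    col_names=('input_size', 'output_size', 'num_params', 'trainable', 'params_percent'),
    row_settings=('var_names',),
    depth=5,
)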
Vaani/VQVAE_training.sh ADDED
@@ -0,0 +1,19 @@
1
+ #!/bin/bash
2
+
3
+ # ========= Variables =========
4
+
5
+ # ACC_CONFIG_PATH="/home/IITB/ai-at-ieor/23m1521/.cache/huggingface/accelerate/FSDP_2gpu.yaml"
6
+
7
+ # ACC_CONFIG_PATH="/home/IITB/ai-at-ieor/23m1521/.cache/huggingface/accelerate/default_config.yaml"
8
+
9
+ # ACC_CONFIG_PATH="/home/IITB/ai-at-ieor/23m1521/.cache/huggingface/accelerate/1GPU.yaml"
10
+
11
+ ACC_CONFIG_PATH="/home/IITB/ai-at-ieor/23m1521/.cache/huggingface/accelerate/default_config.yaml"
12
+
13
+ TRAINING_SCRIPT="/home/IITB/ai-at-ieor/23m1521/ashish/MTP/Vaani/_6_Vaani-VQVAE-Main-Accelerate.py"
14
+
15
+ TRAIN_CONFIG_PATH="/home/IITB/ai-at-ieor/23m1521/ashish/MTP/Vaani/config-Acc.yaml"
16
+
17
+
18
+ # ========= Command =========
19
+ accelerate launch --config_file "$ACC_CONFIG_PATH" "$TRAINING_SCRIPT" $TRAIN_CONFIG_PATH
Vaani/Vaani-Audio-Image-English.csv ADDED
The diff for this file is too large to render. See raw diff
 
Vaani/Vaani-Images-Audio-MetaData.parquet ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a84fc4cf3ec21f074cb7b30a787ab49f637873fde502b3b8536df6e364b43135
3
+ size 297984593
Vaani/Vaani-subplot.png ADDED

Git LFS Details

  • SHA256: ba3fd22df273b14a6906a4257f02ca728320b04e3eaa1cb606ad9db376158b49
  • Pointer size: 132 Bytes
  • Size of remote file: 9.15 MB
Vaani/VaaniLDM/ddpm_ckpt_epoch14.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3ca34fdd03d28b5ecf65ebe1e92efde7b592f97ad0fd47e5828ac690a8f296df
3
+ size 593242410
Vaani/VaaniLDM/ddpm_ckpt_epoch15.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:74e8f75dc97d40089566c3e25e27c0530c4883c3e0747e98a669ebedc8894252
3
+ size 593242474