Julien Blanchon committed
Commit
d0e893e
0 Parent(s):

Clean Space repo (code only, checkpoints in model repo)

This view is limited to 50 files because the commit contains too many changes.
Files changed (50)
  1. .gitattributes +39 -0
  2. .gitignore +183 -0
  3. .python-version +1 -0
  4. README.md +280 -0
  5. app.py +343 -0
  6. configs/c2i/tim_b_p4.yaml +78 -0
  7. configs/c2i/tim_xl_p1_512.yaml +85 -0
  8. configs/c2i/tim_xl_p1_512_mg.yaml +85 -0
  9. configs/c2i/tim_xl_p2_256.yaml +85 -0
  10. configs/c2i/tim_xl_p2_256_mg.yaml +85 -0
  11. configs/t2i/tim_xl_p1_t2i.yaml +81 -0
  12. pyproject.toml +31 -0
  13. requirements.txt +15 -0
  14. setup.py +12 -0
  15. tim/data/c2i_data.py +150 -0
  16. tim/data/sampler_utils.py +52 -0
  17. tim/data/t2i_data.py +126 -0
  18. tim/models/c2i/tim_model.py +406 -0
  19. tim/models/nvidia_radio/hubconf.py +192 -0
  20. tim/models/nvidia_radio/radio/__init__.py +17 -0
  21. tim/models/nvidia_radio/radio/adaptor_base.py +37 -0
  22. tim/models/nvidia_radio/radio/adaptor_generic.py +69 -0
  23. tim/models/nvidia_radio/radio/adaptor_mlp.py +174 -0
  24. tim/models/nvidia_radio/radio/adaptor_registry.py +37 -0
  25. tim/models/nvidia_radio/radio/block.py +54 -0
  26. tim/models/nvidia_radio/radio/cls_token.py +59 -0
  27. tim/models/nvidia_radio/radio/common.py +108 -0
  28. tim/models/nvidia_radio/radio/conv.py +65 -0
  29. tim/models/nvidia_radio/radio/dinov2_arch.py +1016 -0
  30. tim/models/nvidia_radio/radio/dual_hybrid_vit.py +213 -0
  31. tim/models/nvidia_radio/radio/enable_cpe_support.py +224 -0
  32. tim/models/nvidia_radio/radio/enable_damp.py +42 -0
  33. tim/models/nvidia_radio/radio/enable_spectral_reparam.py +277 -0
  34. tim/models/nvidia_radio/radio/eradio_model.py +1392 -0
  35. tim/models/nvidia_radio/radio/extra_models.py +206 -0
  36. tim/models/nvidia_radio/radio/extra_timm_models.py +206 -0
  37. tim/models/nvidia_radio/radio/feature_normalizer.py +111 -0
  38. tim/models/nvidia_radio/radio/forward_intermediates.py +138 -0
  39. tim/models/nvidia_radio/radio/hf_model.py +202 -0
  40. tim/models/nvidia_radio/radio/input_conditioner.py +49 -0
  41. tim/models/nvidia_radio/radio/open_clip_adaptor.py +41 -0
  42. tim/models/nvidia_radio/radio/radio_model.py +375 -0
  43. tim/models/nvidia_radio/radio/vision_transformer_xpos.py +357 -0
  44. tim/models/nvidia_radio/radio/vit_patch_generator.py +287 -0
  45. tim/models/nvidia_radio/radio/vitdet.py +188 -0
  46. tim/models/t2i/tim_model.py +493 -0
  47. tim/models/utils/funcs.py +53 -0
  48. tim/models/utils/norms.py +403 -0
  49. tim/models/utils/rope.py +305 -0
  50. tim/models/utils/text_encoders.py +63 -0
.gitattributes ADDED
@@ -0,0 +1,39 @@
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ checkpoints/** filter=lfs diff=lfs merge=lfs -text
37
+ checkpoints/c2i_model_256.safetensors filter=lfs diff=lfs merge=lfs -text
38
+ checkpoints/c2i_model_512.safetensors filter=lfs diff=lfs merge=lfs -text
39
+ checkpoints/t2i_model.bin filter=lfs diff=lfs merge=lfs -text
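The commit message says the checkpoints now live in the model repos rather than in this Space, even though the LFS rules above still cover `checkpoints/**`. A minimal sketch, assuming the standard `huggingface_hub` client and the checkpoint filenames listed in the README, for pulling them into `checkpoints/` locally:

```python
from huggingface_hub import hf_hub_download  # standard Hub client; not part of this commit

# Fetch the released weights from the model repos referenced in the README.
hf_hub_download(repo_id="GoodEnough/TiM-T2I", filename="t2i_model.bin", local_dir="checkpoints")
hf_hub_download(repo_id="GoodEnough/TiM-C2I", filename="c2i_model_256.safetensors", local_dir="checkpoints")
hf_hub_download(repo_id="GoodEnough/TiM-C2I", filename="c2i_model_512.safetensors", local_dir="checkpoints")
```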
.gitignore ADDED
@@ -0,0 +1,183 @@
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # UV
98
+ # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ #uv.lock
102
+
103
+ # poetry
104
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
105
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
106
+ # commonly ignored for libraries.
107
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
108
+ #poetry.lock
109
+
110
+ # pdm
111
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
112
+ #pdm.lock
113
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
114
+ # in version control.
115
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
116
+ .pdm.toml
117
+ .pdm-python
118
+ .pdm-build/
119
+
120
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
121
+ __pypackages__/
122
+
123
+ # Celery stuff
124
+ celerybeat-schedule
125
+ celerybeat.pid
126
+
127
+ # SageMath parsed files
128
+ *.sage.py
129
+
130
+ # Environments
131
+ .env
132
+ .venv
133
+ env/
134
+ venv/
135
+ ENV/
136
+ env.bak/
137
+ venv.bak/
138
+
139
+ # Spyder project settings
140
+ .spyderproject
141
+ .spyproject
142
+
143
+ # Rope project settings
144
+ .ropeproject
145
+
146
+ # mkdocs documentation
147
+ /site
148
+
149
+ # mypy
150
+ .mypy_cache/
151
+ .dmypy.json
152
+ dmypy.json
153
+
154
+ # Pyre type checker
155
+ .pyre/
156
+
157
+ # pytype static type analyzer
158
+ .pytype/
159
+
160
+ # Cython debug symbols
161
+ cython_debug/
162
+
163
+ # PyCharm
164
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
165
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
166
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
167
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
168
+ #.idea/
169
+
170
+ # Ruff stuff:
171
+ .ruff_cache/
172
+
173
+ # PyPI configuration file
174
+ .pypirc
175
+
176
+
177
+ *.json
178
+ *.svg
179
+ /workdir
180
+ /datasets
181
+ /wandb
182
+
183
+ samples/
.python-version ADDED
@@ -0,0 +1 @@
1
+ 3.10
README.md ADDED
@@ -0,0 +1,280 @@
1
+ ---
2
+ title: TiM
3
+ emoji: 🏆
4
+ colorFrom: blue
5
+ colorTo: red
6
+ sdk: gradio
7
+ sdk_version: 5.44.1
8
+ app_file: app.py
9
+ pinned: false
10
+ ---
11
+
12
+ <h1 align="center">Transition Models: Rethinking the Generative Learning Objective</h1>
13
+
14
+
15
+
16
+ <div align="center">
17
+ <a href="https://github.com/WZDTHU" target="_blank">ZiDong&nbsp;Wang</a><sup>1,2,*</sup>
18
+ &ensp; <b>&middot;</b> &ensp;
19
+ <a href="https://invictus717.github.io" target="_blank">Yiyuan&nbsp;Zhang</a><sup>1,2,*,‡</sup>
20
+ &ensp; <b>&middot;</b> &ensp;
21
+ <a href="https://yuexy.github.io/" target="_blank">Xiaoyu&nbsp;Yue</a><sup>2,3</sup>
22
+ &ensp; <b>&middot;</b> &ensp;
23
+ <a href="https://xyue.io" target="_blank">Xiangyu&nbsp;Yue</a><sup>1</sup>
24
+ &ensp; <b>&middot;</b> &ensp;
25
+ <a href="https://yg256li.github.io" target="_blank">Yangguang&nbsp;Li</a><sup>1,†</sup>
26
+ &ensp; <b>&middot;</b> &ensp;
27
+ <a href="https://wlouyang.github.io" target="_blank">Wanli&nbsp;Ouyang</a><sup>1,2</sup>
28
+ &ensp; <b>&middot;</b> &ensp;
29
+ <a href="http://leibai.site" target="_blank">Lei&nbsp;Bai</a><sup>2,†</sup>
30
+
31
+ <sup>1</sup> MMLab CUHK &emsp; <sup>2</sup>Shanghai AI Lab &emsp; <sup>3</sup>USYD <br>
32
+ <sup>*</sup>Equal Contribution &emsp; <sup>‡</sup>Project Lead &emsp; <sup>†</sup>Corresponding Authors &emsp; <br>
33
+ </div>
34
+
35
+
36
+
37
+ <h3 align="center">
38
+ <!-- [<a href="https://wzdthu.github.io/NiT">project page</a>]&emsp; -->
39
+ [<a href="https://arxiv.org/abs/2509.04394">arXiv</a>]&emsp;
40
+ [<a href="https://huggingface.co/GoodEnough/TiM-T2I">Model</a>]&emsp;
41
+ [<a href="https://huggingface.co/datasets/GoodEnough/TiM-Toy-T2I-Dataset">Dataset</a>]&emsp;
42
+ </h3>
43
+ <br>
44
+
45
+ <b>Highlights</b>: We propose Transition Models (TiM), a novel generative model that learns to navigate the entire generative trajectory with unprecedented flexibility.
46
+ * Our Transition Models (TiM) are trained to master arbitrary state-to-state transitions. This approach allows TiM to learn the entire solution manifold of the generative process, unifying the few-step and many-step regimes within a single, powerful model.
47
+ ![Figure](./assets/illustration.png)
48
+ * Despite having only 865M parameters, TiM achieves state-of-the-art performance, surpassing leading models such as SD3.5 (8B parameters) and FLUX.1 (12B parameters) across all evaluated step counts on the GenEval benchmark. Importantly, unlike previous few-step generators, TiM demonstrates monotonic quality improvement as the sampling budget increases.
49
+ ![Figure](./assets/nfe_demo.png)
50
+ * Additionally, when employing our native-resolution strategy, TiM delivers exceptional fidelity at resolutions up to $4096\times4096$.
51
+ ![Figure](./assets/tim_demo.png)
52
+
53
+
54
+ ## 🚨 News
55
+
56
+ - `2025-9-5` We are delighted to introduce TiM, the first text-to-image generator to support any-step generation, trained entirely from scratch. We have released the code and pretrained models of TiM.
57
+
58
+
59
+
60
+ ## 1. Setup
61
+
62
+ First, clone the repo:
63
+ ```bash
64
+ git clone https://github.com/WZDTHU/TiM.git && cd TiM
65
+ ```
66
+
67
+ ### 1.1 Environment Setup
68
+
69
+ ```bash
70
+ conda create -n tim_env python=3.10
71
+ pip install torch==2.5.1 torchvision==0.20.1 --index-url https://download.pytorch.org/whl/cu118
72
+ pip install flash-attn
73
+ pip install -r requirements.txt
74
+ pip install -e .
75
+ ```
76
+
77
+
78
+ ### 1.2 Model Zoo (WIP)
79
+
80
+
81
+ #### Text-to-Image Generation
82
+
83
+ A single TiM model can perform any-step generation (one-step, few-step, and multi-step) and demonstrate monotonic quality improvement as the sampling budget increases.
84
+ | Model | Model Zoo | Model Size | VAE | 1-NFE GenEval | 8-NFE GenEval | 128-NFE GenEval |
85
+ |---------------|------------|---------|------------|-------|-------|-------|
86
+ | TiM-T2I | [🤗 HF](https://huggingface.co/GoodEnough/TiM-T2I/blob/main/t2i_model.bin) | 865M | [DC-AE](https://huggingface.co/mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers) | 0.67 | 0.76 | 0.83 |
87
+
88
+
89
+
90
+ ```bash
91
+ mkdir checkpoints
92
+ wget -c "https://huggingface.co/GoodEnough/TiM-T2I/resolve/main/t2i_model.bin" -O checkpoints/t2i_model.bin
93
+ ```
94
+
95
+
96
+ #### Class-guided Image Generation
97
+
98
+ | Model | Model Zoo | Model Size | VAE | 2-NFE FID | 500-NFE FID |
99
+ |---------------|------------|---------|------------|------------|------------|
100
+ | TiM-C2I-256 | [🤗 HF](https://huggingface.co/GoodEnough/TiM-C2I/blob/main/c2i_model_256.safetensors) | 664M | [SD-VAE](https://huggingface.co/stabilityai/sd-vae-ft-ema) | 6.14 | 1.65
101
+ | TiM-C2I-512 | [🤗 HF](https://huggingface.co/GoodEnough/TiM-C2I/blob/main/c2i_model_512.safetensors) | 664M | [DC-AE](https://huggingface.co/mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers) | 4.79 | 1.69
102
+
103
+
104
+ ```bash
105
+ mkdir checkpoints
106
+ wget -c "https://huggingface.co/GoodEnough/TiM-C2I/resolve/main/c2i_model_256.safetensors" -O checkpoints/c2i_model_256.safetensors
107
+ wget -c "https://huggingface.co/GoodEnough/TiM-C2I/resolve/main/c2i_model_512.safetensors" -O checkpoints/c2i_model_512.safetensors
108
+ ```
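A quick sanity check after downloading, sketched here with `safetensors` (the count should land near the 664M reported in the table; the exact number depends on what the checkpoint stores):

```python
from safetensors.torch import load_file

state_dict = load_file("checkpoints/c2i_model_256.safetensors")
num_params = sum(t.numel() for t in state_dict.values())
print(f"{num_params / 1e6:.0f}M tensor elements in the checkpoint")
```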
109
+
110
+
111
+ ## 2. Sampling
112
+
113
+ #### Text-to-Image Generation
114
+
115
+ We provide sampling scripts for three benchmarks: GenEval, DPGBench, and MJHQ30K. You can specify the sampling steps, resolution, and CFG scale in the corresponding scripts.
116
+
117
+ Sampling with TiM-T2I model on GenEval benchmark:
118
+ ```bash
119
+ bash scripts/sample/t2i/sample_t2i_geneval.sh
120
+ ```
121
+
122
+ Sampling with TiM-T2I model on DPGBench benchmark:
123
+ ```bash
124
+ bash scripts/sample/t2i/sample_t2i_dpgbench.sh
125
+ ```
126
+
127
+ Sampling with TiM-T2I model on MJHQ30k benchmark:
128
+ ```bash
129
+ bash scripts/sample/t2i/sample_t2i_mjhq30k.sh
130
+ ```
131
+
132
+ #### Class-guided Image Generation
133
+
134
+ We provide the sampling scripts for ImageNet-256 and ImageNet-512.
135
+
136
+ Sampling with C2I model on $256\times256$ resolution:
137
+ ```bash
138
+ bash scripts/sample/c2i/sample_256x256.sh
139
+ ```
140
+
141
+ Sampling with C2I model on $512\times512$ resolution:
142
+ ```bash
143
+ bash scripts/sample/c2i/sample_512x512.sh
144
+ ```
145
+
146
+
147
+ ## 3. Evaluation
148
+
149
+
150
+ ### Text-to-Image Generation
151
+
152
+ #### GenEval
153
+
154
+ Please follow [GenEval](https://github.com/djghosh13/geneval) to set up the conda environment.
155
+
156
+ Given the generated-image directory `SAMPLING_DIR` and the object-detector folder `OBJECT_DETECTOR_FOLDER`, run the following command:
157
+ ```bash
158
+ python projects/evaluate/geneval/evaluation/evaluate_images.py $SAMPLING_DIR --outfile geneval_results.jsonl --model-path $OBJECT_DETECTOR_FOLDER
159
+ ```
160
+ This will result in a JSONL file with one line per image. Run the following command to obtain the GenEval score:
161
+ ```bash
162
+ python projects/evaluate/geneval/evaluation/summary_scores.py geneval_results.jsonl
163
+ ```
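For a quick look at the per-line results without the summary script, the sketch below tallies the JSONL directly; the `tag` and `correct` field names are assumptions about GenEval's output format, so prefer `summary_scores.py` for reported numbers:

```python
import json
from collections import Counter, defaultdict

totals, correct = Counter(), defaultdict(int)
with open("geneval_results.jsonl") as fp:
    for line in fp:
        rec = json.loads(line)
        totals[rec["tag"]] += 1                       # assumed field: task category
        correct[rec["tag"]] += bool(rec["correct"])   # assumed field: per-image pass/fail
for tag in totals:
    print(tag, correct[tag] / totals[tag])
```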
164
+
165
+
166
+ #### DPGBench
167
+ Please follow [DPGBench](https://github.com/TencentQQGYLab/ELLA) to set up the conda environment.
168
+ Given the generated-image directory `SAMPLING_DIR`, run the following command:
169
+ ```bash
170
+ python projects/evaluate/dpg_bench/compute_dpg_bench.py --image-root-path $SAMPLING_DIR --res-path dpgbench_results.txt --pic-num 4
171
+ ```
172
+
173
+ #### MJHQ30K
174
+ Please download [MJHQ30K](https://huggingface.co/datasets/playgroundai/MJHQ-30K) as the reference images.
175
+
176
+
177
+ Given the reference-image directory `REFERENCE_DIR` and the generated-image directory `SAMPLING_DIR`, run the following command to calculate the FID score:
178
+ ```bash
179
+ python projects/evaluate/mjhq30k/calculate_fid.py $REFERENCE_DIR $SAMPLING_DIR
180
+ ```
181
+
182
+ For the CLIP score, first compute the text features and save them in `MJHQ30K_TEXT_FEAT`:
183
+ ```bash
184
+ python projects/evaluate/mjhq30k/calculate_clip.py projects/evaluate/mjhq30k/meta_data.json $MJHQ30K_TEXT_FEAT/clip_feat.safetensors --save-stats
185
+ ```
186
+ Then run the following command to calculate the CLIP score:
187
+ ```bash
188
+ python projects/evaluate/mjhq30k/calculate_clip.py $MJHQ30K_TEXT_FEAT/clip_feat.safetensors $SAMPLING_DIR
189
+ ```
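For reference, the CLIP score is just a scaled, non-negative cosine similarity between CLIP image and text embeddings. A self-contained sketch with a `transformers` CLIP checkpoint is shown below; the repo's `calculate_clip.py` may use a different CLIP model and caches text features in a safetensors file instead:

```python
import torch
from PIL import Image
from transformers import CLIPModel, CLIPProcessor

model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14").eval()
processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")

image = Image.open("sample.png")  # hypothetical generated sample
prompt = "a majestic lion in a golden savanna at sunset"
inputs = processor(text=[prompt], images=image, return_tensors="pt", padding=True)
with torch.no_grad():
    img_feat = model.get_image_features(pixel_values=inputs["pixel_values"])
    txt_feat = model.get_text_features(input_ids=inputs["input_ids"],
                                       attention_mask=inputs["attention_mask"])
clip_score = 100 * torch.nn.functional.cosine_similarity(img_feat, txt_feat).clamp(min=0)
print(clip_score.item())
```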
190
+
191
+
192
+
193
+ ### Class-guided Image Generation
194
+
195
+ The sampling scripts generate a folder of samples used to compute FID, Inception Score, and other metrics.
196
+ <b>Note that we do not pack the generated samples into a `.npz` file; this does not affect the calculation of FID or other metrics.</b>
197
+ Please follow the [ADM's TensorFlow
198
+ evaluation suite](https://github.com/openai/guided-diffusion/tree/main/evaluations)
199
+ to set up the conda environment and download the reference batch.
200
+
201
+ ```bash
202
+ wget -c "https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/classify_image_graph_def.pb" -O checkpoints/classify_image_graph_def.pb
203
+ ```
204
+
205
+
206
+ Given the reference-batch directory `REFERENCE_DIR` and the generated-image directory `SAMPLING_DIR`, run the following command:
207
+ ```bash
208
+ python projects/evaluate/adm_evaluator.py $REFERENCE_DIR $SAMPLING_DIR
209
+ ```
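If you ever do need an `.npz` batch (for example to reuse ADM tooling that expects one), packing the sample folder is straightforward; the `arr_0` key mirrors the layout of ADM's reference batches and is an assumption here:

```python
import numpy as np
from pathlib import Path
from PIL import Image

def pack_samples_to_npz(sample_dir: str, out_path: str) -> None:
    # Stack all PNG samples into one (N, H, W, 3) uint8 array and save it.
    images = [np.asarray(Image.open(p).convert("RGB"), dtype=np.uint8)
              for p in sorted(Path(sample_dir).glob("*.png"))]
    np.savez(out_path, arr_0=np.stack(images))

pack_samples_to_npz("samples/imagenet_256", "samples_256.npz")  # hypothetical paths
```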
210
+
211
+
212
+
213
+
214
+
215
+ ## 4. Training
216
+
217
+ ### 4.1 Dataset Setup
218
+
219
+ Currently, we provide the complete [preprocessed dataset](https://huggingface.co/datasets/GoodEnough/NiT-Preprocessed-ImageNet1K) for ImageNet1K. Please use the following commands to download the preprocessed latents.
220
+
221
+ ```bash
222
+ bash tools/download_imagenet_256x256.sh
223
+ bash tools/download_imagenet_512x512.sh
224
+ ```
225
+
226
+ For text-to-image generation, we provide a [toy dataset](https://huggingface.co/datasets/GoodEnough/TiM-Toy-T2I-Dataset). Please use the following command to download this dataset.
227
+ ```bash
228
+ bash tools/download_toy_t2i_dataset.sh
229
+ ```
230
+
231
+
232
+ ### 4.2 Download Image Encoder
233
+
234
+ We use RADIO-v2.5-b as the image encoder for the REPA loss.
235
+
236
+ ```bash
237
+ wget -c "https://huggingface.co/nvidia/RADIO/resolve/main/radio-v2.5-b_half.pth.tar" -O checkpoints/radio-v2.5-b_half.pth.tar
238
+ ```
239
+
240
+
241
+ ### 4.3 Training Scripts
242
+
243
+ Specify the `image_dir` in `configs/c2i/tim_b_p4.yaml` and train the base-model (131M) on ImageNet-256:
244
+ ```bash
245
+ bash scripts/train/c2i/train_tim_c2i_b.sh
246
+ ```
247
+
248
+ Specify the `image_dir` in `configs/c2i/tim_xl_p2_256.yaml` and train the XL-model (664M) on ImageNet-256:
249
+ ```bash
250
+ bash scripts/train/c2i/train_tim_c2i_xl_256.sh
251
+ ```
252
+
253
+ Specify the `image_dir` in `configs/c2i/tim_xl_p1_512.yaml` and train the XL-model (664M) on ImageNet-512:
254
+ ```bash
255
+ bash scripts/train/c2i/train_tim_c2i_xl_512.sh
256
+ ```
257
+
258
+ Specify the `root_dir` in `configs/t2i/tim_xl_p1_t2i.yaml` and train the T2I-model (865M) on Toy-T2I-Dataset:
259
+ ```bash
260
+ bash scripts/train/t2i/train_tim_t2i.sh
261
+ ```
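The training scripts drive everything through these YAML configs. As a rough illustration of how a config is consumed (the same `OmegaConf` + `instantiate_from_config` pattern appears in `app.py`), here is a sketch that loads the base C2I config, points it at a local ImageNet folder, and builds the network; the actual training loop is launched by the bash scripts above:

```python
from omegaconf import OmegaConf
from tim.utils.misc_utils import instantiate_from_config

config = OmegaConf.load("configs/c2i/tim_b_p4.yaml")
config.data.dataset.image_dir = "/data/imagenet1k/images/train"  # your local path
network = instantiate_from_config(config.model.network)
print(f"{sum(p.numel() for p in network.parameters()) / 1e6:.0f}M parameters")
```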
262
+
263
+
264
+
265
+
266
+ ## Citations
267
+ If you find the project useful, please kindly cite:
268
+ ```bibtex
269
+ @article{wang2025transition,
270
+ title={Transition Models: Rethinking the Generative Learning Objective},
271
+ author={Wang, Zidong and Zhang, Yiyuan and Yue, Xiaoyu and Yue, Xiangyu and Li, Yangguang and Ouyang, Wanli and Bai, Lei},
272
+ year={2025},
273
+ eprint={2509.04394},
274
+ archivePrefix={arXiv},
275
+ primaryClass={cs.LG}
276
+ }
277
+ ```
278
279
+ ## License
280
+ This project is licensed under the Apache-2.0 license.
app.py ADDED
@@ -0,0 +1,343 @@
1
+ import gradio as gr
2
+ import spaces # type: ignore - ZeroGPU spaces library
3
+ import numpy as np
4
+ import random
5
+ import torch
6
+ import functools
7
+ from pathlib import Path
8
+ from PIL import Image
9
+ from omegaconf import OmegaConf # type: ignore - YAML configuration library
10
+ from tim.schedulers.transition import TransitionSchedule
11
+ from tim.utils.misc_utils import instantiate_from_config, init_from_ckpt
12
+ from tim.models.vae import get_sd_vae, get_dc_ae, sd_vae_decode, dc_ae_decode
13
+ from tim.models.utils.text_encoders import load_text_encoder, encode_prompt
14
+
15
+ # Configuration
16
+ dtype = torch.bfloat16
17
+ device = "cuda" if torch.cuda.is_available() else "cpu"
18
+ MAX_SEED = np.iinfo(np.int32).max
19
+ MAX_IMAGE_SIZE = 2048
20
+
21
+ # Global variables to store loaded components
22
+ model = None
23
+ scheduler = None
24
+ text_encoder = None
25
+ tokenizer = None
26
+ decode_func = None
27
+ null_cap_feat = None
28
+ null_cap_mask = None
29
+ config = None
30
+
31
+
32
+ def load_model_components(device: str = "cuda"):
33
+ """Load all model components once at startup"""
34
+ global \
35
+ model, \
36
+ scheduler, \
37
+ text_encoder, \
38
+ tokenizer, \
39
+ decode_func, \
40
+ null_cap_feat, \
41
+ null_cap_mask, \
42
+ config
43
+
44
+ try:
45
+ # Load configuration
46
+ config_path = "configs/t2i/tim_xl_p1_t2i.yaml"
47
+ ckpt_path = "checkpoints/t2i_model.bin"
48
+
49
+ if not Path(config_path).exists():
50
+ raise FileNotFoundError(f"Config file not found: {config_path}")
51
+ if not Path(ckpt_path).exists():
52
+ raise FileNotFoundError(f"Checkpoint file not found: {ckpt_path}")
53
+
54
+ print("Loading configuration...")
55
+ config = OmegaConf.load(config_path)
56
+ model_config = config.model
57
+
58
+ print("Loading VAE...")
59
+ # Load VAE
60
+ if "dc-ae" in model_config.vae_dir:
61
+ dc_ae = get_dc_ae(model_config.vae_dir, dtype=torch.float32, device=device)
62
+ dc_ae.enable_tiling(2560, 2560, 2560, 2560)
63
+ decode_func = functools.partial(dc_ae_decode, dc_ae, slice_vae=True)
64
+ elif "sd-vae" in model_config.vae_dir:
65
+ sd_vae = get_sd_vae(
66
+ model_config.vae_dir, dtype=torch.float32, device=device
67
+ )
68
+ decode_func = functools.partial(sd_vae_decode, sd_vae, slice_vae=True)
69
+ else:
70
+ raise ValueError("Unsupported VAE type")
71
+
72
+ print("Loading text encoder...")
73
+ # Load text encoder
74
+ text_encoder, tokenizer = load_text_encoder(
75
+ text_encoder_dir=model_config.text_encoder_dir,
76
+ device=device,
77
+ weight_dtype=torch.bfloat16,
78
+ )
79
+
80
+ print("Encoding null caption...")
81
+ # Get null caption features
82
+ null_cap_feat, null_cap_mask = encode_prompt(
83
+ tokenizer,
84
+ text_encoder,
85
+ device,
86
+ torch.bfloat16,
87
+ [""],
88
+ model_config.use_last_hidden_state,
89
+ max_seq_length=model_config.max_seq_length,
90
+ )
91
+
92
+ print("Loading main model...")
93
+ # Load main model
94
+ model = instantiate_from_config(model_config.network).to(
95
+ device=device, dtype=dtype
96
+ )
97
+ init_from_ckpt(model, checkpoint_dir=ckpt_path, ignore_keys=None, verbose=True)
98
+ model.eval()
99
+
100
+ print("Loading scheduler...")
101
+ # Load scheduler
102
+ transport = instantiate_from_config(model_config.transport)
103
+ scheduler = TransitionSchedule(
104
+ transport=transport, **OmegaConf.to_container(model_config.transition_loss)
105
+ )
106
+
107
+ print("All components loaded successfully!")
108
+
109
+ except Exception as e:
110
+ print(f"Error loading model components: {e}")
111
+ raise e
112
+
113
+
114
+ @spaces.GPU(duration=60)
115
+ def generate_image(
116
+ prompt,
117
+ seed=42,
118
+ randomize_seed=False,
119
+ width=1024,
120
+ height=1024,
121
+ guidance_scale=2.5,
122
+ num_inference_steps=16,
123
+ progress=gr.Progress(track_tqdm=True),
124
+ ):
125
+ """Generate image from text prompt"""
126
+ try:
127
+ # Validate inputs
128
+ if not prompt or len(prompt.strip()) == 0:
129
+ raise ValueError("Please enter a valid prompt")
130
+
131
+ if model is None or scheduler is None:
132
+ raise RuntimeError("Model components not loaded. Please check the setup.")
133
+
134
+ # Validate dimensions
135
+ if (
136
+ width < 256
137
+ or width > MAX_IMAGE_SIZE
138
+ or height < 256
139
+ or height > MAX_IMAGE_SIZE
140
+ ):
141
+ raise ValueError(
142
+ f"Image dimensions must be between 256 and {MAX_IMAGE_SIZE}"
143
+ )
144
+
145
+ if width % 32 != 0 or height % 32 != 0:
146
+ raise ValueError("Image dimensions must be divisible by 32")
147
+
148
+ if randomize_seed:
149
+ seed = random.randint(0, MAX_SEED)
150
+
151
+ generator = torch.Generator(device=device).manual_seed(seed)
152
+
153
+ # Calculate latent dimensions
154
+ spatial_downsample = 32 if "dc-ae" in config.model.vae_dir else 8
155
+ latent_h = int(height / spatial_downsample)
156
+ latent_w = int(width / spatial_downsample)
157
+
158
+ progress(0.1, desc="Generating random latent...")
159
+
160
+ # Generate random latent
161
+ z = torch.randn(
162
+ (1, model.in_channels, latent_h, latent_w),
163
+ device=device,
164
+ dtype=dtype,
165
+ generator=generator,
166
+ )
167
+
168
+ progress(0.1, desc="Encoding prompt...")
169
+
170
+ # Encode prompt
171
+ cap_features, cap_mask = encode_prompt(
172
+ tokenizer,
173
+ text_encoder,
174
+ device,
175
+ dtype,
176
+ [prompt],
177
+ config.model.use_last_hidden_state,
178
+ max_seq_length=config.model.max_seq_length,
179
+ )
180
+
181
+ cur_max_seq_len = cap_mask.sum(dim=-1).max()
182
+ y = cap_features[:, :cur_max_seq_len]
183
+
184
+ y_null = null_cap_feat[:, :cur_max_seq_len]
185
+ y_null = y_null.expand(y.shape[0], cur_max_seq_len, null_cap_feat.shape[-1])
186
+
187
+ # Generate image
188
+ with torch.no_grad():
189
+ samples = scheduler.sample(
190
+ model,
191
+ y,
192
+ y_null,
193
+ z,
194
+ T_max=1.0,
195
+ T_min=0.0,
196
+ num_steps=num_inference_steps,
197
+ cfg_scale=guidance_scale,
198
+ cfg_low=0.0,
199
+ cfg_high=1.0,
200
+ stochasticity_ratio=0.0,
201
+ sample_type="transition",
202
+ step_callback=lambda step: progress(
203
+ 0.1 + 0.9 * (step / num_inference_steps), desc="Generating image..."
204
+ ),
205
+ )[-1]
206
+ samples = samples.to(torch.float32)
207
+
208
+ # Decode to image
209
+ images = decode_func(samples)
210
+ images = (
211
+ torch.clamp(127.5 * images + 128.0, 0, 255)
212
+ .permute(0, 2, 3, 1)
213
+ .to(torch.uint8)
214
+ .contiguous()
215
+ )
216
+ image = Image.fromarray(images[0].cpu().numpy())
217
+
218
+ progress(1.0, desc="Complete!")
219
+
220
+ return image, seed
221
+
222
+ except Exception as e:
223
+ print(f"Error during image generation: {e}")
224
+ # Return a placeholder image or error message
225
+ error_img = Image.new("RGB", (512, 512), color="red")
226
+ return error_img, seed
227
+
228
+
229
+ # Example prompts
230
+ examples = [
231
+ ["a tiny astronaut hatching from an egg on the moon"],
232
+ ["🐶 Wearing 🕶 flying on the 🌈"],
233
+ ["an anime illustration of a wiener schnitzel"],
234
+ ["a photorealistic landscape of mountains at sunset"],
235
+ ["a majestic lion in a golden savanna at sunset"],
236
+ ["a futuristic city with flying cars and neon lights"],
237
+ ["a cozy cabin in a snowy forest with smoke coming from the chimney"],
238
+ ["a beautiful mermaid swimming in crystal clear water"],
239
+ ]
240
+
241
+ # CSS styling
242
+ css = """
243
+ #col-container {
244
+ margin: 0 auto;
245
+ max-width: 520px;
246
+ }
247
+ """
248
+
249
+ # Initialize model components
250
+ try:
251
+ load_model_components(device)
252
+ print("Model components loaded successfully!")
253
+ except Exception as e:
254
+ print(f"Error loading model components: {e}")
255
+ print("Please ensure config and checkpoint files are available")
256
+
257
+ # Create Gradio interface
258
+ with gr.Blocks(css=css) as demo:
259
+ with gr.Column(elem_id="col-container"):
260
+ gr.Markdown("# TiM Text-to-Image Generator")
261
+ gr.Markdown(
262
+ "Generate high-quality images from text prompts using the TiM (Transition in Matching) model"
263
+ )
264
+
265
+ with gr.Row():
266
+ prompt = gr.Text(
267
+ label="Prompt",
268
+ show_label=False,
269
+ max_lines=1,
270
+ placeholder="Enter your prompt",
271
+ container=False,
272
+ )
273
+ run_button = gr.Button("Generate", scale=0)
274
+
275
+ result = gr.Image(label="Result", show_label=False)
276
+
277
+ with gr.Accordion("Advanced Settings", open=False):
278
+ seed = gr.Slider(
279
+ label="Seed",
280
+ minimum=0,
281
+ maximum=MAX_SEED,
282
+ step=1,
283
+ value=0,
284
+ )
285
+ randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
286
+
287
+ with gr.Row():
288
+ width = gr.Slider(
289
+ label="Width",
290
+ minimum=256,
291
+ maximum=MAX_IMAGE_SIZE,
292
+ step=32,
293
+ value=1024,
294
+ )
295
+ height = gr.Slider(
296
+ label="Height",
297
+ minimum=256,
298
+ maximum=MAX_IMAGE_SIZE,
299
+ step=32,
300
+ value=1024,
301
+ )
302
+
303
+ with gr.Row():
304
+ guidance_scale = gr.Slider(
305
+ label="Guidance Scale",
306
+ minimum=1,
307
+ maximum=15,
308
+ step=0.1,
309
+ value=2.5,
310
+ )
311
+ num_inference_steps = gr.Slider(
312
+ label="Number of inference steps",
313
+ minimum=1,
314
+ maximum=50,
315
+ step=1,
316
+ value=16,
317
+ )
318
+
319
+ gr.Examples(
320
+ examples=examples,
321
+ fn=generate_image,
322
+ inputs=[prompt],
323
+ outputs=[result, seed],
324
+ cache_examples="lazy",
325
+ )
326
+
327
+ gr.on(
328
+ triggers=[run_button.click, prompt.submit],
329
+ fn=generate_image,
330
+ inputs=[
331
+ prompt,
332
+ seed,
333
+ randomize_seed,
334
+ width,
335
+ height,
336
+ guidance_scale,
337
+ num_inference_steps,
338
+ ],
339
+ outputs=[result, seed],
340
+ )
341
+
342
+ if __name__ == "__main__":
343
+ demo.launch()
configs/c2i/tim_b_p4.yaml ADDED
@@ -0,0 +1,78 @@
1
+ model:
2
+ transport:
3
+ target: tim.schedulers.transports.OT_FM
4
+ params:
5
+ P_mean: -0.4
6
+ P_std: 1.0
7
+ sigma_d: 1.0
8
+ transition_loss:
9
+ diffusion_ratio: 0.5
10
+ consistency_ratio: 0.1
11
+ derivative_type: dde
12
+ differential_epsilon: 0.005
13
+ weight_time_type: sqrt
14
+ weight_time_tangent: True
15
+ network:
16
+ target: tim.models.c2i.tim_model.TiM
17
+ params:
18
+ input_size: 32
19
+ patch_size: 4
20
+ in_channels: 4
21
+ class_dropout_prob: 0.1
22
+ num_classes: 1000
23
+ depth: 12
24
+ hidden_size: 768
25
+ num_heads: 12
26
+ encoder_depth: 4
27
+ qk_norm: True
28
+ z_dim: 768
29
+ new_condition: t-r
30
+ use_new_embed: True
31
+ distance_aware: True
32
+ lora_hidden_size: 256
33
+ # pretrained_vae:
34
+ vae_dir: stabilityai/sd-vae-ft-ema
35
+ # repa encoder
36
+ enc_dir: checkpoints/radio/radio-v2.5-b_half.pth.tar
37
+ proj_coeff: 1.0
38
+ # ema
39
+ use_ema: True
40
+ ema_decay: 0.9999
41
+
42
+ data:
43
+ data_type: latent
44
+ dataset:
45
+ latent_dir: datasets/imagenet1k/sd-vae-ft-ema-256x256
46
+ image_dir: datasets/imagenet1k/images/train
47
+ image_size: 256
48
+ dataloader:
49
+ num_workers: 16
50
+ batch_size: 256 # Batch size (per device) for the training dataloader.
51
+
52
+
53
+
54
+ training:
55
+ tracker: null
56
+ max_train_steps: 100000
57
+ checkpointing_steps: 2000
58
+ checkpoints_total_limit: 2
59
+ resume_from_checkpoint: latest
60
+ learning_rate: 1.0e-4
61
+ learning_rate_base_batch_size: 256
62
+ scale_lr: True
63
+ lr_scheduler: constant # "linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"]
64
+ lr_warmup_steps: 0
65
+ gradient_accumulation_steps: 1
66
+ optimizer:
67
+ target: torch.optim.AdamW
68
+ params:
69
+ # betas: ${tuple:0.9, 0.999}
70
+ betas: [0.9, 0.95]
71
+ weight_decay: 1.0e-2
72
+ eps: 1.0e-6
73
+ max_grad_norm: 1.0
74
+ proportion_empty_prompts: 0.0
75
+ mixed_precision: bf16 # ["no", "fp16", "bf16"]
76
+ allow_tf32: True
77
+ validation_steps: 500
78
+ checkpoint_list: [100000, 200000, 300000]
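`scale_lr` together with `learning_rate_base_batch_size` suggests the learning rate is rescaled with the global batch size; the exact rule lives in the training script (not part of this commit), so the linear scaling below is an assumption:

```python
def effective_lr(base_lr: float = 1.0e-4, base_batch: int = 256,
                 per_device_batch: int = 256, num_devices: int = 1,
                 grad_accum: int = 1, scale_lr: bool = True) -> float:
    # Assumed linear-scaling reading of scale_lr / learning_rate_base_batch_size.
    global_batch = per_device_batch * num_devices * grad_accum
    return base_lr * global_batch / base_batch if scale_lr else base_lr

print(effective_lr(num_devices=4))  # 4e-4 on 4 GPUs under this assumption
```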
configs/c2i/tim_xl_p1_512.yaml ADDED
@@ -0,0 +1,85 @@
1
+ model:
2
+ transport:
3
+ target: tim.schedulers.transports.OT_FM
4
+ params:
5
+ P_mean: -0.4
6
+ P_std: 1.0
7
+ sigma_d: 1.0
8
+ T_max: 1.0
9
+ T_min: 0.0
10
+ enhance_target: False
11
+ w_gt: 1.0
12
+ w_cond: 0.0
13
+ w_start: 0.0
14
+ w_end: 0.0
15
+ transition_loss:
16
+ diffusion_ratio: 0.5
17
+ consistency_ratio: 0.1
18
+ derivative_type: dde
19
+ differential_epsilon: 0.005
20
+ weight_time_type: sqrt
21
+ weight_time_tangent: True
22
+ network:
23
+ target: tim.models.c2i.tim_model.TiM
24
+ params:
25
+ input_size: 16
26
+ patch_size: 1
27
+ in_channels: 32
28
+ class_dropout_prob: 0.1
29
+ num_classes: 1000
30
+ depth: 28
31
+ hidden_size: 1152
32
+ num_heads: 16
33
+ encoder_depth: 8
34
+ qk_norm: True
35
+ z_dim: 768
36
+ new_condition: t-r
37
+ use_new_embed: True
38
+ distance_aware: True
39
+ lora_hidden_size: 384
40
+ # pretrained_vae:
41
+ vae_dir: mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers
42
+ # repa encoder
43
+ enc_dir: checkpoints/radio/radio-v2.5-b_half.pth.tar
44
+ proj_coeff: 1.0
45
+ # ema
46
+ use_ema: True
47
+ ema_decay: 0.9999
48
+
49
+ data:
50
+ data_type: latent
51
+ dataset:
52
+ latent_dir: datasets/imagenet1k/dc-ae-f32c32-sana-1.1-diffusers-512x512
53
+ image_dir: datasets/imagenet1k/images/train
54
+ image_size: 512
55
+ dataloader:
56
+ num_workers: 4
57
+ batch_size: 64 # Batch size (per device) for the training dataloader.
58
+
59
+
60
+
61
+ training:
62
+ tracker: null
63
+ max_train_steps: 750000
64
+ checkpointing_steps: 2000
65
+ checkpoints_total_limit: 2
66
+ resume_from_checkpoint: latest
67
+ learning_rate: 1.0e-4
68
+ learning_rate_base_batch_size: 256
69
+ scale_lr: True
70
+ lr_scheduler: constant # "linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"]
71
+ lr_warmup_steps: 0
72
+ gradient_accumulation_steps: 1
73
+ optimizer:
74
+ target: torch.optim.AdamW
75
+ params:
76
+ # betas: ${tuple:0.9, 0.999}
77
+ betas: [0.9, 0.95]
78
+ weight_decay: 1.0e-2
79
+ eps: 1.0e-6
80
+ max_grad_norm: 1.0
81
+ proportion_empty_prompts: 0.0
82
+ mixed_precision: bf16 # ["no", "fp16", "bf16"]
83
+ allow_tf32: True
84
+ validation_steps: 500
85
+ checkpoint_list: [100000, 250000, 500000]
configs/c2i/tim_xl_p1_512_mg.yaml ADDED
@@ -0,0 +1,85 @@
1
+ model:
2
+ transport:
3
+ target: tim.schedulers.transports.OT_FM
4
+ params:
5
+ P_mean: -0.4
6
+ P_std: 1.0
7
+ sigma_d: 1.0
8
+ T_max: 1.0
9
+ T_min: 0.0
10
+ enhance_target: True
11
+ w_gt: 1.0
12
+ w_cond: 0.75
13
+ w_start: 0.3
14
+ w_end: 0.8
15
+ transition_loss:
16
+ diffusion_ratio: 0.5
17
+ consistency_ratio: 0.1
18
+ derivative_type: dde
19
+ differential_epsilon: 0.005
20
+ weight_time_type: sqrt
21
+ weight_time_tangent: True
22
+ network:
23
+ target: tim.models.c2i.tim_model.TiM
24
+ params:
25
+ input_size: 16
26
+ patch_size: 1
27
+ in_channels: 32
28
+ class_dropout_prob: 0.1
29
+ num_classes: 1000
30
+ depth: 28
31
+ hidden_size: 1152
32
+ num_heads: 16
33
+ encoder_depth: 8
34
+ qk_norm: True
35
+ z_dim: 768
36
+ new_condition: t-r
37
+ use_new_embed: True
38
+ distance_aware: True
39
+ lora_hidden_size: 384
40
+ # pretrained_vae:
41
+ vae_dir: mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers
42
+ # repa encoder
43
+ enc_dir: checkpoints/radio/radio-v2.5-b_half.pth.tar
44
+ proj_coeff: 1.0
45
+ # ema
46
+ use_ema: True
47
+ ema_decay: 0.9999
48
+
49
+ data:
50
+ data_type: latent
51
+ dataset:
52
+ latent_dir: datasets/imagenet1k/dc-ae-f32c32-sana-1.1-diffusers-512x512
53
+ image_dir: datasets/imagenet1k/images/train
54
+ image_size: 512
55
+ dataloader:
56
+ num_workers: 4
57
+ batch_size: 64 # Batch size (per device) for the training dataloader.
58
+
59
+
60
+
61
+ training:
62
+ tracker: null
63
+ max_train_steps: 750000
64
+ checkpointing_steps: 2000
65
+ checkpoints_total_limit: 2
66
+ resume_from_checkpoint: latest
67
+ learning_rate: 1.0e-4
68
+ learning_rate_base_batch_size: 256
69
+ scale_lr: True
70
+ lr_scheduler: constant # "linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"]
71
+ lr_warmup_steps: 0
72
+ gradient_accumulation_steps: 1
73
+ optimizer:
74
+ target: torch.optim.AdamW
75
+ params:
76
+ # betas: ${tuple:0.9, 0.999}
77
+ betas: [0.9, 0.95]
78
+ weight_decay: 1.0e-2
79
+ eps: 1.0e-6
80
+ max_grad_norm: 1.0
81
+ proportion_empty_prompts: 0.0
82
+ mixed_precision: bf16 # ["no", "fp16", "bf16"]
83
+ allow_tf32: True
84
+ validation_steps: 500
85
+ checkpoint_list: [100000, 250000, 500000]
configs/c2i/tim_xl_p2_256.yaml ADDED
@@ -0,0 +1,85 @@
1
+ model:
2
+ transport:
3
+ target: tim.schedulers.transports.OT_FM
4
+ params:
5
+ P_mean: -0.4
6
+ P_std: 1.0
7
+ sigma_d: 1.0
8
+ T_max: 1.0
9
+ T_min: 0.0
10
+ enhance_target: False
11
+ w_gt: 1.0
12
+ w_cond: 0.0
13
+ w_start: 0.0
14
+ w_end: 0.0
15
+ transition_loss:
16
+ diffusion_ratio: 0.5
17
+ consistency_ratio: 0.1
18
+ derivative_type: dde
19
+ differential_epsilon: 0.005
20
+ weight_time_type: sqrt
21
+ weight_time_tangent: True
22
+ network:
23
+ target: tim.models.c2i.tim_model.TiM
24
+ params:
25
+ input_size: 32
26
+ patch_size: 2
27
+ in_channels: 4
28
+ class_dropout_prob: 0.1
29
+ num_classes: 1000
30
+ depth: 28
31
+ hidden_size: 1152
32
+ num_heads: 16
33
+ encoder_depth: 8
34
+ qk_norm: True
35
+ z_dim: 768
36
+ new_condition: t-r
37
+ use_new_embed: True
38
+ distance_aware: True
39
+ lora_hidden_size: 384
40
+ # pretrained_vae:
41
+ vae_dir: stabilityai/sd-vae-ft-ema
42
+ # repa encoder
43
+ enc_dir: checkpoints/radio/radio-v2.5-b_half.pth.tar
44
+ proj_coeff: 1.0
45
+ # ema
46
+ use_ema: True
47
+ ema_decay: 0.9999
48
+
49
+ data:
50
+ data_type: latent
51
+ dataset:
52
+ latent_dir: datasets/imagenet1k/sd-vae-ft-ema-256x256
53
+ image_dir: datasets/imagenet1k/images/train
54
+ image_size: 256
55
+ dataloader:
56
+ num_workers: 4
57
+ batch_size: 64 # Batch size (per device) for the training dataloader.
58
+
59
+
60
+
61
+ training:
62
+ tracker: null
63
+ max_train_steps: 750000
64
+ checkpointing_steps: 2000
65
+ checkpoints_total_limit: 2
66
+ resume_from_checkpoint: latest
67
+ learning_rate: 1.0e-4
68
+ learning_rate_base_batch_size: 256
69
+ scale_lr: True
70
+ lr_scheduler: constant # "linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"]
71
+ lr_warmup_steps: 0
72
+ gradient_accumulation_steps: 1
73
+ optimizer:
74
+ target: torch.optim.AdamW
75
+ params:
76
+ # betas: ${tuple:0.9, 0.999}
77
+ betas: [0.9, 0.95]
78
+ weight_decay: 1.0e-2
79
+ eps: 1.0e-6
80
+ max_grad_norm: 1.0
81
+ proportion_empty_prompts: 0.0
82
+ mixed_precision: bf16 # ["no", "fp16", "bf16"]
83
+ allow_tf32: True
84
+ validation_steps: 500
85
+ checkpoint_list: [100000, 250000, 500000]
configs/c2i/tim_xl_p2_256_mg.yaml ADDED
@@ -0,0 +1,85 @@
1
+ model:
2
+ transport:
3
+ target: tim.schedulers.transports.OT_FM
4
+ params:
5
+ P_mean: -0.4
6
+ P_std: 1.0
7
+ sigma_d: 1.0
8
+ T_max: 1.0
9
+ T_min: 0.0
10
+ enhance_target: True
11
+ w_gt: 1.0
12
+ w_cond: 0.75
13
+ w_start: 0.3
14
+ w_end: 0.8
15
+ transition_loss:
16
+ diffusion_ratio: 0.5
17
+ consistency_ratio: 0.1
18
+ derivative_type: dde
19
+ differential_epsilon: 0.005
20
+ weight_time_type: sqrt
21
+ weight_time_tangent: True
22
+ network:
23
+ target: tim.models.c2i.tim_model.TiM
24
+ params:
25
+ input_size: 32
26
+ patch_size: 2
27
+ in_channels: 4
28
+ class_dropout_prob: 0.1
29
+ num_classes: 1000
30
+ depth: 28
31
+ hidden_size: 1152
32
+ num_heads: 16
33
+ encoder_depth: 8
34
+ qk_norm: True
35
+ z_dim: 768
36
+ new_condition: t-r
37
+ use_new_embed: True
38
+ distance_aware: True
39
+ lora_hidden_size: 384
40
+ # pretrained_vae:
41
+ vae_dir: stabilityai/sd-vae-ft-ema
42
+ # repa encoder
43
+ enc_dir: checkpoints/radio/radio-v2.5-b_half.pth.tar
44
+ proj_coeff: 1.0
45
+ # ema
46
+ use_ema: True
47
+ ema_decay: 0.9999
48
+
49
+ data:
50
+ data_type: latent
51
+ dataset:
52
+ latent_dir: datasets/imagenet1k/sd-vae-ft-ema-256x256
53
+ image_dir: datasets/imagenet1k/images/train
54
+ image_size: 256
55
+ dataloader:
56
+ num_workers: 4
57
+ batch_size: 64 # Batch size (per device) for the training dataloader.
58
+
59
+
60
+
61
+ training:
62
+ tracker: null
63
+ max_train_steps: 750000
64
+ checkpointing_steps: 2000
65
+ checkpoints_total_limit: 2
66
+ resume_from_checkpoint: latest
67
+ learning_rate: 1.0e-4
68
+ learning_rate_base_batch_size: 256
69
+ scale_lr: True
70
+ lr_scheduler: constant # "linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"]
71
+ lr_warmup_steps: 0
72
+ gradient_accumulation_steps: 1
73
+ optimizer:
74
+ target: torch.optim.AdamW
75
+ params:
76
+ # betas: ${tuple:0.9, 0.999}
77
+ betas: [0.9, 0.95]
78
+ weight_decay: 1.0e-2
79
+ eps: 1.0e-6
80
+ max_grad_norm: 1.0
81
+ proportion_empty_prompts: 0.0
82
+ mixed_precision: bf16 # ["no", "fp16", "bf16"]
83
+ allow_tf32: True
84
+ validation_steps: 500
85
+ checkpoint_list: [100000, 250000, 500000]
configs/t2i/tim_xl_p1_t2i.yaml ADDED
@@ -0,0 +1,81 @@
1
+ model:
2
+ transport:
3
+ target: tim.schedulers.transports.OT_FM
4
+ params:
5
+ P_mean: 0.0
6
+ P_std: 1.6
7
+ sigma_d: 1.0
8
+ transition_loss:
9
+ diffusion_ratio: 0.5
10
+ consistency_ratio: 0.1
11
+ derivative_type: dde
12
+ differential_epsilon: 0.005
13
+ weight_time_type: sqrt
14
+ weight_time_tangent: True
15
+ network:
16
+ target: tim.models.t2i.tim_model.TiM
17
+ params:
18
+ input_size: 16
19
+ patch_size: 1
20
+ in_channels: 32
21
+ depth: 28
22
+ hidden_size: 1152
23
+ cap_feat_dim: 1152
24
+ num_heads: 16
25
+ encoder_depth: 8
26
+ qk_norm: True
27
+ z_dim: 768
28
+ new_condition: t-r
29
+ use_new_embed: True
30
+ distance_aware: True
31
+ lora_hidden_size: 384
32
+ # pretrained_vae:
33
+ vae_dir: mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers
34
+ # text encoder
35
+ text_encoder_dir: google/gemma-3-1b-it
36
+ proportion_empty_prompts: 0.1
37
+ use_last_hidden_state: True
38
+ max_seq_length: 256
39
+ # repa encoder
40
+ enc_dir: checkpoints/radio/radio-v2.5-b_half.pth.tar
41
+ proj_coeff: 1.0
42
+ # ema
43
+ use_ema: True
44
+ ema_decay: 0.9999
45
+
46
+ data:
47
+ data_type: image_ms
48
+ dataset:
49
+ root_dir: datasets/t2i_toy_dataset
50
+ packed_json: datasets/t2i_toy_dataset/bucket_sampler.json
51
+ jsonl_dir: datasets/t2i_toy_dataset/data_info.jsonl
52
+ dataloader:
53
+ num_workers: 4
54
+ batch_size: 128 # Batch size (per device) for the training dataloader.
55
+
56
+
57
+ training:
58
+ tracker: null
59
+ max_train_steps: 500000
60
+ checkpointing_steps: 1000
61
+ checkpoints_total_limit: 2
62
+ resume_from_checkpoint: latest
63
+ learning_rate: 1.0e-4
64
+ learning_rate_base_batch_size: 512
65
+ scale_lr: True
66
+ lr_scheduler: constant # "linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"]
67
+ lr_warmup_steps: 0
68
+ gradient_accumulation_steps: 1
69
+ optimizer:
70
+ target: torch.optim.AdamW
71
+ params:
72
+ # betas: ${tuple:0.9, 0.999}
73
+ betas: [0.9, 0.95]
74
+ weight_decay: 1.0e-2
75
+ eps: 1.0e-6
76
+ max_grad_norm: 1.0
77
+ proportion_empty_prompts: 0.0
78
+ mixed_precision: bf16 # ["no", "fp16", "bf16"]
79
+ allow_tf32: True
80
+ validation_steps: 500
81
+ checkpoint_list: [100000, 200000, 300000, 400000]
pyproject.toml ADDED
@@ -0,0 +1,31 @@
1
+ [project]
2
+ name = "tim"
3
+ version = "0.1.0"
4
+ description = "Add your description here"
5
+ readme = "README.md"
6
+ requires-python = ">=3.10"
7
+ dependencies = [
8
+ "accelerate>=0.33.0",
9
+ "bitsandbytes>=0.47.0",
10
+ "diffusers==0.33.1",
11
+ "einops>=0.8.1",
12
+ "flash-attn>=2.8.3",
13
+ "gradio>=5.44.1",
14
+ "imageio==2.34.2",
15
+ "imageio-ffmpeg==0.5.1",
16
+ "moviepy==1.0.3",
17
+ "numpy==1.26.0",
18
+ "omegaconf>=2.3.0",
19
+ "pillow==9.5.0",
20
+ "safetensors>=0.6.2",
21
+ "sentencepiece>=0.2.0",
22
+ "spaces>=0.40.1",
23
+ "streamlit>=1.38.0",
24
+ "timm>=1.0.19",
25
+ "torch>=2.8.0",
26
+ "torchdiffeq>=0.2.5",
27
+ "torchvision>=0.23.0",
28
+ "transformers>=4.44.2",
29
+ "triton>=3.4.0",
30
+ "wandb>=0.21.3",
31
+ ]
requirements.txt ADDED
@@ -0,0 +1,15 @@
1
+ gradio>=4.0.0
2
+ spaces>=0.28.0
3
+ torch>=2.1.0
4
+ torchvision
5
+ diffusers
6
+ transformers>=4.25.0
7
+ omegaconf
8
+ einops
9
+ numpy
10
+ Pillow
11
+ safetensors
12
+ tqdm
13
+ flash-attn>=2.0.0
14
+ accelerate
15
+ -e .
setup.py ADDED
@@ -0,0 +1,12 @@
1
+ from setuptools import find_packages, setup
2
+
3
+ setup(
4
+ name="tim",
5
+ version="0.0.1",
6
+ description="",
7
+ packages=find_packages(),
8
+ install_requires=[
9
+ "torch",
10
+ "numpy",
11
+ ],
12
+ )
tim/data/c2i_data.py ADDED
@@ -0,0 +1,150 @@
1
+ import os
2
+ import json
3
+ import datetime
4
+ import torchvision
5
+ import numpy as np
6
+ import torch
7
+
8
+ from omegaconf import OmegaConf
9
+ from PIL import Image
10
+ from torch.utils.data import DataLoader, Dataset
11
+ from torchvision.datasets import ImageFolder
12
+ from torchvision import transforms
13
+ from torchvision.transforms.functional import hflip
14
+ from accelerate.logging import get_logger
15
+ from safetensors.torch import load_file
16
+ from .sampler_utils import get_train_sampler
17
+
18
+
19
+ logger = get_logger(__name__, log_level="INFO")
20
+
21
+
22
+ def center_crop_arr(pil_image, image_size):
23
+ """
24
+ Center cropping implementation from ADM.
25
+ https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
26
+ """
27
+ while min(*pil_image.size) >= 2 * image_size:
28
+ pil_image = pil_image.resize(
29
+ tuple(x // 2 for x in pil_image.size), resample=Image.Resampling.BOX
30
+ )
31
+
32
+ scale = image_size / min(*pil_image.size)
33
+ pil_image = pil_image.resize(
34
+ tuple(round(x * scale) for x in pil_image.size), resample=Image.Resampling.BICUBIC
35
+ )
36
+
37
+ arr = np.array(pil_image)
38
+ crop_y = (arr.shape[0] - image_size) // 2
39
+ crop_x = (arr.shape[1] - image_size) // 2
40
+ return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size])
41
+
42
+ class ImagenetDictWrapper(Dataset):
43
+ def __init__(self, dataset):
44
+ super().__init__()
45
+ self.dataset = dataset
46
+
47
+ def __getitem__(self, i):
48
+ x, y = self.dataset[i]
49
+ return {"image": x, "label": y}
50
+
51
+ def __len__(self):
52
+ return len(self.dataset)
53
+
54
+ class ImagenetLatentDataset(Dataset):
55
+ def __init__(self, latent_dir, image_dir, image_size):
56
+ super().__init__()
57
+ self.RandomHorizontalFlipProb = 0.5
58
+ self.transform = transforms.Compose([
59
+ transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, image_size)),
60
+ transforms.Lambda(lambda pil_image: (pil_image, hflip(pil_image))),
61
+ transforms.Lambda(lambda crops: torch.stack([transforms.ToTensor()(crop) for crop in crops])), # returns a 4D tensor
62
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
63
+ ])
64
+
65
+ self.dataset = []
66
+ for class_folder in os.listdir(image_dir):
67
+ if os.path.isfile(os.path.join(image_dir, class_folder)):
68
+ continue
69
+ latent_class_folder = os.path.join(latent_dir, class_folder)
70
+ image_class_folder = os.path.join(image_dir, class_folder)
71
+ for file in os.listdir(image_class_folder):
72
+ self.dataset.append(
73
+ dict(
74
+ latent=os.path.join(latent_class_folder, file.split('.')[0]+'.safetensors'),
75
+ image=os.path.join(image_class_folder, file)
76
+ )
77
+ )
78
+
79
+ def __len__(self):
80
+ return len(self.dataset)
81
+
82
+ def __getitem__(self, idx):
83
+ data_item = dict()
84
+ data = load_file(self.dataset[idx]['latent'])
85
+ image = self.transform(Image.open(self.dataset[idx]['image']).convert("RGB"))
86
+ if torch.rand(1) < self.RandomHorizontalFlipProb:
87
+ data_item['latent'] = data['latent'][0]
88
+ data_item['image'] = image[0]
89
+ else:
90
+ data_item['latent'] = data['latent'][1]
91
+ data_item['image'] = image[1]
92
+ data_item['label'] = data['label']
93
+ return data_item
94
+
95
+
96
+
97
+ class C2ILoader():
98
+ def __init__(self, data_config):
99
+ super().__init__()
100
+
101
+ self.batch_size = data_config.dataloader.batch_size
102
+ self.num_workers = data_config.dataloader.num_workers
103
+
104
+ self.data_type = data_config.data_type
105
+
106
+ if data_config.data_type == 'image':
107
+ self.train_dataset = ImagenetDictWrapper(**OmegaConf.to_container(data_config.dataset))
108
+ elif data_config.data_type == 'latent':
109
+ self.train_dataset = ImagenetLatentDataset(**OmegaConf.to_container(data_config.dataset))
110
+ else:
111
+ raise NotImplementedError
112
+
113
+
114
+ self.test_dataset = None
115
+ self.val_dataset = None
116
+
117
+ def train_len(self):
118
+ return len(self.train_dataset)
119
+
120
+ def train_dataloader(self, rank, world_size, global_batch_size, max_steps, resume_steps, seed):
121
+
122
+ sampler = get_train_sampler(
123
+ self.train_dataset, rank, world_size, global_batch_size, max_steps, resume_steps, seed
124
+ )
125
+ return DataLoader(
126
+ self.train_dataset,
127
+ batch_size=self.batch_size,
128
+ sampler=sampler,
129
+ num_workers=self.num_workers,
130
+ pin_memory=True,
131
+ drop_last=True,
132
+ prefetch_factor=2,
133
+ )
134
+
135
+ def test_dataloader(self):
136
+ return None
137
+
138
+ def val_dataloader(self):
139
+ return DataLoader(
140
+ self.train_dataset,
141
+ batch_size=self.batch_size,
142
+ shuffle=False,  # C2ILoader defines no `shuffle` attribute; validation does not need shuffling
143
+ num_workers=self.num_workers,
144
+ pin_memory=True,
145
+ drop_last=True
146
+ )
147
+
148
+
149
+
150
+
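`ImagenetLatentDataset` expects the latent directory to mirror the ImageNet class folders, with one `.safetensors` file per image holding two encodings (unflipped and horizontally flipped) plus the label, so that `__getitem__` can pick the encoding matching its own random flip of the raw image. A minimal stand-in file, with shapes taken from the SD-VAE 256x256 config:

```python
import torch
from safetensors.torch import save_file

entry = {
    "latent": torch.randn(2, 4, 32, 32),  # index 0: unflipped, index 1: hflipped
    "label": torch.tensor(207),           # ImageNet class id
}
save_file(entry, "n02099601_toy.safetensors")  # would live under <latent_dir>/<class_folder>/
```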
tim/data/sampler_utils.py ADDED
@@ -0,0 +1,52 @@
1
+ import torch
2
+ import json
3
+
4
+ # from https://github.com/Alpha-VLLM/LLaMA2-Accessory/blob/main/Large-DiT-ImageNet/train.py#L60
5
+ def get_train_sampler(dataset, rank, world_size, global_batch_size, max_steps,
6
+ resume_step, seed):
7
+ sample_indices = torch.empty([max_steps * global_batch_size // world_size],
8
+ dtype=torch.long)
9
+ epoch_id, fill_ptr, offs = 0, 0, 0
10
+ while fill_ptr < sample_indices.size(0):
11
+ g = torch.Generator()
12
+ g.manual_seed(seed + epoch_id)
13
+ epoch_sample_indices = torch.randperm(len(dataset), generator=g)
14
+ epoch_id += 1
15
+ epoch_sample_indices = epoch_sample_indices[
16
+ (rank + offs) % world_size::world_size
17
+ ]
18
+ offs = (offs + world_size - len(dataset) % world_size) % world_size
19
+ epoch_sample_indices = epoch_sample_indices[
20
+ :sample_indices.size(0) - fill_ptr
21
+ ]
22
+ sample_indices[fill_ptr: fill_ptr + epoch_sample_indices.size(0)] = \
23
+ epoch_sample_indices
24
+ fill_ptr += epoch_sample_indices.size(0)
25
+ return sample_indices[resume_step * global_batch_size // world_size:].tolist()
26
+
27
+
28
+
29
+
30
+ def get_packed_batch_sampler(
31
+ dataset, rank, world_size, max_steps, resume_step, seed
32
+ ):
33
+ sample_indices = [None for _ in range(max_steps)]
34
+ epoch_id, fill_ptr, offs = 0, 0, 0
35
+ while fill_ptr < len(sample_indices):
36
+ g = torch.Generator()
37
+ g.manual_seed(seed + epoch_id)
38
+ epoch_sample_indices = torch.randperm(len(dataset), generator=g)
39
+ epoch_id += 1
40
+ epoch_sample_indices = epoch_sample_indices[
41
+ (rank + offs) % world_size::world_size
42
+ ]
43
+ offs = (offs + world_size - len(dataset) % world_size) % world_size
44
+ epoch_sample_indices = epoch_sample_indices[
45
+ :len(sample_indices) - fill_ptr
46
+ ]
47
+ sample_indices[fill_ptr: fill_ptr + epoch_sample_indices.size(0)] = [
48
+ dataset[i] for i in epoch_sample_indices
49
+ ]
50
+ fill_ptr += epoch_sample_indices.size(0)
51
+ return sample_indices[resume_step:]
52
+
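`get_train_sampler` precomputes a deterministic index stream, so resuming from step `k` is just slicing off the first `k * global_batch_size // world_size` indices. A small check of that property on a toy dataset:

```python
from tim.data.sampler_utils import get_train_sampler

data = list(range(10))
full = get_train_sampler(data, rank=0, world_size=1, global_batch_size=4,
                         max_steps=5, resume_step=0, seed=0)
resumed = get_train_sampler(data, rank=0, world_size=1, global_batch_size=4,
                            max_steps=5, resume_step=2, seed=0)
assert resumed == full[2 * 4 // 1:]  # identical stream, minus the already-consumed steps
```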
tim/data/t2i_data.py ADDED
@@ -0,0 +1,126 @@
1
+ import torch
2
+ import csv
3
+ import json
4
+ import os
5
+ import random
6
+ import ast
7
+ import numpy as np
8
+ from omegaconf import OmegaConf
9
+ from torchvision import transforms
10
+ from torch.utils.data import DataLoader, Dataset
11
+ from PIL import Image
12
+ from tqdm import tqdm
13
+ from safetensors.torch import save_file, load_file
14
+ from .sampler_utils import get_train_sampler, get_packed_batch_sampler
15
+
16
+
17
+
18
+ def resize_arr(pil_image, height, width):
19
+ pil_image = pil_image.resize((width, height), resample=Image.Resampling.BICUBIC)
20
+
21
+ return pil_image
22
+
23
+
24
+ class T2IDatasetMS(Dataset):
25
+ def __init__(self, root_dir, packed_json, jsonl_dir) -> None:
26
+ super().__init__()
27
+ self.root_dir = root_dir
28
+ self.dataset = []
29
+ with open(packed_json, 'r') as fp:
30
+ self.packed_dataset = json.load(fp)
31
+
32
+ with open(jsonl_dir, 'r') as fp:
33
+ self.dataset = [json.loads(line) for line in fp]
34
+
35
+
36
+ def __len__(self):
37
+ return len(self.dataset)
38
+
39
+ def get_one_data(self, data_meta):
40
+ data_item = dict()
41
+ image_file = os.path.join(self.root_dir, data_meta['image_file'])
42
+
43
+ image = Image.open(image_file).convert("RGB")
44
+
45
+ bucket = data_meta['bucket']
46
+ resolutions = bucket.split('-')[-1].split('x')
47
+ height, width = int(int(resolutions[0])/32)*32, int(int(resolutions[1])/32)*32
48
+ transform = transforms.Compose([
49
+ transforms.Lambda(lambda pil_image: resize_arr(pil_image, height, width)),
50
+ transforms.ToTensor(),
51
+ transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True),
52
+ ])
53
+ image = transform(image)
54
+
55
+ data_item['image'] = image
56
+ data_item['caption'] = random.choice(data_meta['captions']).encode('unicode-escape').decode('utf-8')
57
+
58
+ return data_item
59
+
60
+ def __getitem__(self, index):
61
+ data_meta = self.dataset[index]
62
+ # data_item = self.get_one_data(data_meta)
63
+ try:
64
+ data_item = self.get_one_data(data_meta)
65
+ except Exception as e:
66
+ print(f"Warning: failed to load {data_meta['image_file']}: {e}", flush=True)
67
+ data_item = None
68
+
69
+ return data_item
70
+
71
+
72
+
73
+ def bucket_collate_fn(batch):
74
+ caption = []
75
+ image = []
76
+ for data in batch:
77
+ if data is None:
78
+ continue
79
+ caption.append(data['caption'])
80
+ image.append(data['image'])
81
+ image = torch.stack(image)
82
+ return dict(image=image, caption=caption)
83
+
84
+
85
+
86
+
87
+ class T2ILoader():
88
+ def __init__(self, data_config):
89
+ super().__init__()
90
+
91
+ self.batch_size = data_config.dataloader.batch_size
92
+ self.num_workers = data_config.dataloader.num_workers
93
+
94
+ self.data_type = data_config.data_type
95
+
96
+ if self.data_type == 'image_ms':
97
+ self.train_dataset = T2IDatasetMS(**OmegaConf.to_container(data_config.dataset))
98
+ else:
99
+ raise NotImplementedError(f"Unsupported data_type: {self.data_type}")
100
+ self.test_dataset = None
101
+ self.val_dataset = None
102
+
103
+ def train_len(self):
104
+ return len(self.train_dataset)
105
+
106
+ def train_dataloader(self, rank, world_size, global_batch_size, max_steps, resume_steps, seed):
107
+ batch_sampler = get_packed_batch_sampler(
108
+ self.train_dataset.packed_dataset, rank, world_size, max_steps, resume_steps, seed
109
+ )
110
+ return DataLoader(
111
+ self.train_dataset,
112
+ batch_sampler=batch_sampler,
113
+ collate_fn=bucket_collate_fn,
114
+ num_workers=self.num_workers,
115
+ pin_memory=True,
116
+ )
117
+
118
+ def test_dataloader(self):
119
+ return None
120
+
121
+ def val_dataloader(self):
122
+ return None
123
+
124
+
125
+
126
+
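`bucket_collate_fn` is the piece that tolerates failed image loads: `__getitem__` returns `None` on error and the collate silently drops it before stacking. A small synthetic sketch (the tensors and captions below are made up; real items come from `T2IDatasetMS`, and all images in one packed bucket share a resolution, which is why `torch.stack` is safe):

```python
import torch
from tim.data.t2i_data import bucket_collate_fn

batch = [
    {'image': torch.randn(3, 512, 384), 'caption': 'a photo of a cat'},
    None,  # a failed load propagates as None and is skipped by the collate
    {'image': torch.randn(3, 512, 384), 'caption': 'a photo of a dog'},
]
out = bucket_collate_fn(batch)
print(out['image'].shape)  # torch.Size([2, 3, 512, 384])
print(out['caption'])      # ['a photo of a cat', 'a photo of a dog']
```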
tim/models/c2i/tim_model.py ADDED
@@ -0,0 +1,406 @@
1
+ # This source code is licensed under the license found in the
2
+ # LICENSE file in the root directory of this source tree.
3
+ # --------------------------------------------------------
4
+ # References:
5
+ # GLIDE: https://github.com/openai/glide-text2im
6
+ # MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py
7
+ # --------------------------------------------------------
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ import numpy as np
13
+ import math
14
+ from timm.layers.mlp import SwiGLU, Mlp
15
+ from timm.models.vision_transformer import PatchEmbed, Attention
16
+ from tim.models.utils.funcs import build_mlp, modulate, get_parameter_dtype
17
+ from tim.models.utils.rope import VisionRotaryEmbedding, rotate_half
18
+ from flash_attn import flash_attn_func
19
+
20
+
21
+ #################################################################################
22
+ # Embedding Layers for Timesteps and Class Labels #
23
+ #################################################################################
24
+ class TimestepEmbedder(nn.Module):
25
+ """
26
+ Embeds scalar timesteps into vector representations.
27
+ """
28
+ def __init__(self, hidden_size, frequency_embedding_size=256):
29
+ super().__init__()
30
+ self.mlp = nn.Sequential(
31
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True),
32
+ nn.SiLU(),
33
+ nn.Linear(hidden_size, hidden_size, bias=True),
34
+ )
35
+ self.frequency_embedding_size = frequency_embedding_size
36
+
37
+ @staticmethod
38
+ def positional_embedding(t, dim, max_period=10000):
39
+ """
40
+ Create sinusoidal timestep embeddings.
41
+ :param t: a 1-D Tensor of N indices, one per batch element.
42
+ These may be fractional.
43
+ :param dim: the dimension of the output.
44
+ :param max_period: controls the minimum frequency of the embeddings.
45
+ :return: an (N, D) Tensor of positional embeddings.
46
+ """
47
+ # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
48
+ half = dim // 2
49
+ freqs = torch.exp(
50
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
51
+ ).to(device=t.device)
52
+ args = t[:, None].float() * freqs[None]
53
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
54
+ if dim % 2:
55
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
56
+ return embedding
57
+
58
+ def forward(self, t):
59
+ self.timestep_embedding = self.positional_embedding
60
+ t_freq = self.timestep_embedding(t, dim=self.frequency_embedding_size).to(t.dtype)
61
+ t_emb = self.mlp(t_freq)
62
+ return t_emb
63
+
64
+
65
+ class LabelEmbedder(nn.Module):
66
+ """
67
+ Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
68
+ """
69
+ def __init__(self, num_classes, hidden_size, dropout_prob):
70
+ super().__init__()
71
+ use_cfg_embedding = dropout_prob > 0
72
+ self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
73
+ self.num_classes = num_classes
74
+ self.dropout_prob = dropout_prob
75
+
76
+
77
+ def forward(self, labels):
78
+ embeddings = self.embedding_table(labels)
79
+ return embeddings
80
+
81
+
82
+
83
+
84
+ #################################################################################
85
+ # Attention Block #
86
+ #################################################################################
87
+
88
+ class Attention(nn.Module):
89
+ def __init__(
90
+ self,
91
+ dim: int,
92
+ num_heads: int = 8,
93
+ qkv_bias: bool = False,
94
+ qk_norm: bool = False,
95
+ attn_drop: float = 0.,
96
+ proj_drop: float = 0.,
97
+ norm_layer: nn.Module = nn.LayerNorm,
98
+ distance_aware: bool = False,
99
+ ) -> None:
100
+ super().__init__()
101
+ assert dim % num_heads == 0, 'dim should be divisible by num_heads'
102
+ self.num_heads = num_heads
103
+ self.head_dim = dim // num_heads
104
+ self.scale = self.head_dim ** -0.5
105
+
106
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
107
+ self.distance_aware = distance_aware
108
+ if distance_aware:
109
+ self.qkv_d = nn.Linear(dim, dim * 3, bias=False)
110
+ self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
111
+ self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
112
+ self.attn_drop = nn.Dropout(attn_drop)
113
+ self.proj = nn.Linear(dim, dim)
114
+ self.proj_drop = nn.Dropout(proj_drop)
115
+
116
+ def forward(self, x: torch.Tensor, freqs_cos, freqs_sin, attn_type='fused_attn', delta_t=None) -> torch.Tensor:
117
+ B, N, C = x.shape
118
+ if self.distance_aware:
119
+ qkv = self.qkv(x) + self.qkv_d(delta_t)
120
+ else:
121
+ qkv = self.qkv(x)
122
+ if attn_type == 'flash_attn': # q, k, v: (B, N, n_head, d_head)
123
+ qkv = qkv.reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 1, 3, 4)
124
+ else: # q, k, v: (B, n_head, N, d_head)
125
+ qkv = qkv.reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
126
+ ori_dtype = qkv.dtype
127
+ q, k, v = qkv.unbind(0)
128
+ q, k = self.q_norm(q), self.k_norm(k)
129
+
130
+ q = q * freqs_cos + rotate_half(q) * freqs_sin
131
+ k = k * freqs_cos + rotate_half(k) * freqs_sin
132
+ q, k = q.to(ori_dtype), k.to(ori_dtype)
133
+
134
+ if attn_type == 'flash_attn':
135
+ x = flash_attn_func(
136
+ q, k, v,
137
+ dropout_p=self.attn_drop.p if self.training else 0.,
138
+ )
139
+ x = x.reshape(B, N, C)
140
+ elif attn_type == 'fused_attn':
141
+ x = F.scaled_dot_product_attention(
142
+ q, k, v,
143
+ dropout_p=self.attn_drop.p if self.training else 0.,
144
+ )
145
+ x = x.transpose(1, 2).reshape(B, N, C)
146
+ else:
147
+ q = q * self.scale
148
+ attn = q @ k.transpose(-2, -1)
149
+ attn = attn.softmax(dim=-1)
150
+ attn = self.attn_drop(attn)
151
+ x = attn @ v
152
+ x = x.transpose(1, 2).reshape(B, N, C)
153
+
154
+ x = self.proj(x)
155
+ x = self.proj_drop(x)
156
+ return x
157
+
158
+
159
+
160
+
161
+
162
+
163
+ #################################################################################
164
+ # Core TiM Model #
165
+ #################################################################################
166
+
167
+ class TiMBlock(nn.Module):
168
+ """
169
+ A TiM block with adaptive layer norm zero (adaLN-Zero) conditioning.
170
+ """
171
+ def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs):
172
+ super().__init__()
173
+ self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
174
+ distance_aware = block_kwargs.get('distance_aware', False)
175
+ self.attn = Attention(
176
+ hidden_size, num_heads=num_heads, qkv_bias=True, qk_norm=block_kwargs["qk_norm"],
177
+ distance_aware=distance_aware
178
+ )
179
+ self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
180
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
181
+ self.mlp = SwiGLU(
182
+ in_features=hidden_size, hidden_features=(mlp_hidden_dim*2)//3, bias=True
183
+ )
184
+ if block_kwargs.get('lora_hidden_size', None) is not None:
185
+ lora_hidden_size = block_kwargs['lora_hidden_size']
186
+ else:
187
+ lora_hidden_size = (hidden_size//4)*3
188
+ self.adaLN_modulation = SwiGLU(
189
+ in_features=hidden_size, hidden_features=lora_hidden_size, out_features=6*hidden_size, bias=True
190
+ )
191
+
192
+
193
+
194
+ def forward(self, x, c, freqs_cos, freqs_sin, attn_type, delta_t=None):
195
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
196
+ self.adaLN_modulation(c).chunk(6, dim=-1)
197
+ )
198
+ x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), freqs_cos, freqs_sin, attn_type, delta_t)
199
+ x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
200
+
201
+ return x
202
+
203
+
204
+ class FinalLayer(nn.Module):
205
+ """
206
+ The final layer of TiM.
207
+ """
208
+ def __init__(self, hidden_size, patch_size, out_channels):
209
+ super().__init__()
210
+ self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
211
+ self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
212
+ self.adaLN_modulation = SwiGLU(
213
+ in_features=hidden_size, hidden_features=hidden_size//2, out_features=2*hidden_size, bias=True
214
+ )
215
+
216
+
217
+ def forward(self, x, c):
218
+ shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
219
+ x = modulate(self.norm_final(x), shift, scale)
220
+ x = self.linear(x)
221
+
222
+ return x
223
+
224
+
225
+ class TiM(nn.Module):
226
+ def __init__(
227
+ self,
228
+ input_size=32,
229
+ patch_size=2,
230
+ in_channels=4,
231
+ hidden_size=1152,
232
+ encoder_depth=8,
233
+ depth=28,
234
+ num_heads=16,
235
+ mlp_ratio=4.0,
236
+ class_dropout_prob=0.1,
237
+ num_classes=1000,
238
+ z_dim=768,
239
+ projector_dim=2048,
240
+ use_checkpoint: bool = False,
241
+ new_condition: str = 't-r',
242
+ use_new_embed: bool = False,
243
+ **block_kwargs # qk_norm
244
+ ):
245
+ super().__init__()
246
+ self.in_channels = in_channels
247
+ self.out_channels = in_channels
248
+ self.patch_size = patch_size
249
+ self.num_heads = num_heads
250
+ self.num_classes = num_classes
251
+ self.encoder_depth = encoder_depth
252
+ self.use_checkpoint = use_checkpoint
253
+ self.new_condition = new_condition
254
+ self.use_new_embed = use_new_embed
255
+
256
+ self.x_embedder = PatchEmbed(
257
+ input_size, patch_size, in_channels, hidden_size, bias=True, strict_img_size=False
258
+ )
259
+ self.t_embedder = TimestepEmbedder(hidden_size) # timestep embedding type
260
+ if use_new_embed:
261
+ self.delta_embedder = TimestepEmbedder(hidden_size)
262
+ self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob)
263
+ # Will use fixed sin-cos embedding:
264
+ self.rope = VisionRotaryEmbedding(head_dim=hidden_size//num_heads)
265
+
266
+ self.blocks = nn.ModuleList([
267
+ TiMBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, **block_kwargs) for _ in range(depth)
268
+ ])
269
+ self.projector = build_mlp(hidden_size, projector_dim, z_dim)
270
+ self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
271
+ self.initialize_weights()
272
+
273
+ def initialize_weights(self):
274
+ # Initialize transformer layers:
275
+ def _basic_init(module):
276
+ if isinstance(module, nn.Linear):
277
+ torch.nn.init.xavier_uniform_(module.weight)
278
+ if module.bias is not None:
279
+ nn.init.constant_(module.bias, 0)
280
+ self.apply(_basic_init)
281
+
282
+ # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
283
+ w = self.x_embedder.proj.weight.data
284
+ nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
285
+ nn.init.constant_(self.x_embedder.proj.bias, 0)
286
+
287
+ # Initialize label embedding table:
288
+ nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
289
+
290
+ # Initialize timestep embedding MLP:
291
+ nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
292
+ nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
293
+
294
+ # Zero-out adaLN modulation layers in TiM blocks:
295
+ for block in self.blocks:
296
+ nn.init.constant_(block.adaLN_modulation.fc2.weight, 0)
297
+ nn.init.constant_(block.adaLN_modulation.fc2.bias, 0)
298
+
299
+ # Zero-out output layers:
300
+ nn.init.constant_(self.final_layer.adaLN_modulation.fc2.weight, 0)
301
+ nn.init.constant_(self.final_layer.adaLN_modulation.fc2.bias, 0)
302
+
303
+ nn.init.constant_(self.final_layer.linear.weight, 0)
304
+ nn.init.constant_(self.final_layer.linear.bias, 0)
305
+
306
+ def unpatchify(self, x, H, W):
307
+ """
308
+ x: (N, T, patch_size**2 * C)
309
+ imgs: (N, H, W, C)
310
+ """
311
+ c = self.out_channels
312
+ p = self.patch_size
313
+ h, w = int(H/p), int(W/p)
314
+
315
+
316
+ x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
317
+ x = torch.einsum('nhwpqc->nchpwq', x)
318
+ imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
319
+ return imgs
320
+
321
+ def get_rope(self, h, w, attn_type):
322
+ grid_h = torch.arange(h)
323
+ grid_w = torch.arange(w)
324
+ grid = torch.meshgrid(grid_h, grid_w, indexing='xy')
325
+ grid = torch.stack(grid, dim=0).reshape(2, -1).unsqueeze(0)
326
+ freqs_cos, freqs_sin = self.rope.get_cached_2d_rope_from_grid(grid)
327
+ if attn_type == 'flash_attn': # (1, N, 1, d_head)
328
+ return freqs_cos.unsqueeze(2), freqs_sin.unsqueeze(2)
329
+ else: # (1, 1, N, d_head)
330
+ return freqs_cos.unsqueeze(1), freqs_sin.unsqueeze(1)
331
+
332
+
333
+ def forward(self, x, t, r, y, attn_type='flash_attn', return_zs=False, jvp=False):
334
+ """
335
+ Forward pass of TiM.
336
+ x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
337
+ t: (N,) tensor of diffusion timesteps
338
+ y: (N,) tensor of class labels
339
+ """
340
+ B, C, H, W = x.shape
341
+ x = self.x_embedder(x) # (N, T, D), where T = H * W / patch_size ** 2
342
+
343
+ # timestep and class embedding
344
+ t_embed = self.t_embedder(t).unsqueeze(1) # (N, 1, D)
345
+ delta_embed = self.get_delta_embed(t, r).unsqueeze(1) # (N, 1, D)
346
+ y = self.y_embedder(y).unsqueeze(1) # (N, 1, D)
347
+ c = t_embed + delta_embed + y # (N, 1, D)
348
+ freqs_cos, freqs_sin = self.get_rope(
349
+ int(H/self.patch_size), int(W/self.patch_size), attn_type
350
+ )
351
+
352
+ for i, block in enumerate(self.blocks):
353
+ if (not self.use_checkpoint) or jvp:
354
+ x = block(x, c, freqs_cos, freqs_sin, attn_type, delta_embed) # (N, T, D)
355
+ else:
356
+ x = torch.utils.checkpoint.checkpoint(
357
+ self.ckpt_wrapper(block), x, c, freqs_cos, freqs_sin, attn_type, delta_embed
358
+ )
359
+ if (i + 1) == self.encoder_depth:
360
+ h_proj = self.projector(x)
361
+
362
+ x = self.final_layer(x, c) # (N, T, patch_size ** 2 * out_channels)
363
+ x = self.unpatchify(x, H, W) # (N, out_channels, H, W)
364
+
365
+ if return_zs:
366
+ return x, h_proj
367
+ else:
368
+ return x
369
+
370
+ def get_delta_embed(self, t, r):
371
+ if self.use_new_embed:
372
+ delta_embedder = self.delta_embedder
373
+ else:
374
+ delta_embedder = self.t_embedder
375
+ if self.new_condition == 't-r':
376
+ delta_embed = delta_embedder(t-r)
377
+ elif self.new_condition == 'r':
378
+ delta_embed = delta_embedder(r)
379
+ elif self.new_condition == 't,r':
380
+ delta_embed = self.t_embedder(t) + delta_embedder(r)
381
+ elif self.new_condition == 't,t-r':
382
+ delta_embed = self.t_embedder(t) + delta_embedder(t-r)
383
+ elif self.new_condition == 'r,t-r':
384
+ delta_embed = self.t_embedder(r) + delta_embedder(t-r)
385
+ elif self.new_condition == 't,r,t-r':
386
+ delta_embed = self.t_embedder(t) + self.t_embedder(r) + delta_embedder(t-r)
387
+ else:
388
+ raise NotImplementedError
389
+ return delta_embed
390
+
391
+ def ckpt_wrapper(self, module):
392
+ def ckpt_forward(*inputs):
393
+ outputs = module(*inputs)
394
+ return outputs
395
+ return ckpt_forward
396
+
397
+
398
+ @property
399
+ def dtype(self) -> torch.dtype:
400
+ """
401
+ `torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
402
+ """
403
+ return get_parameter_dtype(self)
404
+
405
+
406
+
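For orientation, a hedged smoke test of the class-conditional `TiM` defined above. The sizes are illustrative only; `qk_norm=True` is required because `TiMBlock` reads it from `block_kwargs`, and `attn_type='fused_attn'` avoids flash-attention kernels at call time (the module still imports `flash_attn` at import time, so that dependency must be present):

```python
import torch
from tim.models.c2i.tim_model import TiM

model = TiM(
    input_size=32, patch_size=2, in_channels=4,
    hidden_size=384, depth=12, encoder_depth=4, num_heads=6,
    num_classes=1000, qk_norm=True,
)
x = torch.randn(2, 4, 32, 32)        # latent images
t = torch.rand(2)                    # current timestep in [0, 1]
r = torch.rand(2) * t                # earlier timestep, r <= t
y = torch.randint(0, 1000, (2,))     # class labels
out, zs = model(x, t, r, y, attn_type='fused_attn', return_zs=True)
print(out.shape)   # torch.Size([2, 4, 32, 32])
print(zs.shape)    # projector features, (2, 256, z_dim) with the default z_dim=768
```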
tim/models/nvidia_radio/hubconf.py ADDED
@@ -0,0 +1,192 @@
1
+ # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ dependencies = ["torch", "timm", "einops"]
10
+
11
+ import os
12
+ from typing import Dict, Any, Optional, Union, List
13
+ import warnings
14
+
15
+ import torch
16
+ from torch.hub import load_state_dict_from_url
17
+
18
+ from timm.models import clean_state_dict
19
+
20
+ from .radio.adaptor_registry import adaptor_registry
21
+ from .radio.common import DEFAULT_VERSION, RadioResource, RESOURCE_MAP
22
+ from .radio.enable_damp import configure_damp_from_args
23
+ from .radio.enable_spectral_reparam import disable_spectral_reparam, configure_spectral_reparam_from_args
24
+ from .radio.feature_normalizer import FeatureNormalizer, IntermediateFeatureNormalizer
25
+ from .radio.radio_model import RADIOModel, create_model_from_args
26
+ from .radio.input_conditioner import get_default_conditioner
27
+ from .radio.vitdet import apply_vitdet_arch, VitDetArgs
28
+
29
+
30
+ def radio_model(
31
+ version: str = "",
32
+ progress: bool = True,
33
+ adaptor_names: Union[str, List[str]] = None,
34
+ vitdet_window_size: Optional[int] = None,
35
+ return_checkpoint: bool = False,
36
+ support_packing: bool=False,
37
+ **kwargs,
38
+ ) -> RADIOModel:
39
+ if not version:
40
+ version = DEFAULT_VERSION
41
+
42
+ if os.path.isfile(version):
43
+ chk = torch.load(version, map_location="cpu", weights_only=False)
44
+ resource = RadioResource(version, patch_size=None, max_resolution=None, preferred_resolution=None)
45
+ else:
46
+ resource = RESOURCE_MAP[version]
47
+ chk = load_state_dict_from_url(
48
+ resource.url, progress=progress, map_location="cpu", weights_only=False,
49
+ )
50
+
51
+ if "state_dict_ema" in chk:
52
+ state_dict = chk["state_dict_ema"]
53
+ chk['args'].spectral_reparam = False
54
+ else:
55
+ state_dict = chk["state_dict"]
56
+
57
+ args = chk["args"]
58
+ args.support_packing = support_packing
59
+ mod = create_model_from_args(args)
60
+
61
+ mod_state_dict = get_prefix_state_dict(state_dict, "base_model.")
62
+
63
+ if args.spectral_reparam:
64
+ configure_spectral_reparam_from_args(mod, args, state_dict_guidance=mod_state_dict)
65
+
66
+ if getattr(args, 'damp', None):
67
+ configure_damp_from_args(mod, args)
68
+
69
+ state_dict = clean_state_dict(state_dict)
70
+
71
+ key_warn = mod.load_state_dict(mod_state_dict, strict=False)
72
+ if key_warn.missing_keys:
73
+ warnings.warn(f'Missing keys in state dict: {key_warn.missing_keys}')
74
+ if key_warn.unexpected_keys:
75
+ warnings.warn(f'Unexpected keys in state dict: {key_warn.unexpected_keys}')
76
+
77
+ if chk['args'].spectral_reparam:
78
+ # Spectral reparametrization uses PyTorch's "parametrizations" API. The idea behind
79
+ # the method is that instead of there being a `weight` tensor for certain Linear layers
80
+ # in the model, we make it a dynamically computed function. During training, this
81
+ # helps stabilize the model. However, for downstream use cases, it shouldn't be necessary.
82
+ # Disabling it in this context means that instead of recomputing `w' = f(w)` on every forward pass, we compute it
83
+ # once, during this function call, and replace the parametrization with the realized weights.
84
+ # This makes the model run faster, and also use less memory.
85
+ disable_spectral_reparam(mod)
86
+ chk['args'].spectral_reparam = False
87
+
88
+ conditioner = get_default_conditioner()
89
+ conditioner.load_state_dict(get_prefix_state_dict(state_dict, "input_conditioner."))
90
+
91
+ dtype = getattr(chk['args'], 'dtype', torch.float32)
92
+ mod.to(dtype=dtype)
93
+ conditioner.dtype = dtype
94
+
95
+ cls_token_per_teacher = getattr(chk['args'], 'cls_token_per_teacher', True)
96
+ if cls_token_per_teacher:
97
+ name_to_idx_map = dict()
98
+ for i, t in enumerate(chk['args'].teachers):
99
+ if t.get('use_summary', True):
100
+ name = t['name']
101
+ if name not in name_to_idx_map:
102
+ name_to_idx_map[name] = i
103
+ summary_idxs = torch.tensor(sorted(name_to_idx_map.values()), dtype=torch.int64)
104
+ else:
105
+ summary_idxs = torch.tensor([0], dtype=torch.int64)
106
+
107
+ if adaptor_names is None:
108
+ adaptor_names = []
109
+ elif isinstance(adaptor_names, str):
110
+ adaptor_names = [adaptor_names]
111
+
112
+ teachers = chk["args"].teachers
113
+ adaptors = dict()
114
+ for adaptor_name in adaptor_names:
115
+ for tidx, tconf in enumerate(teachers):
116
+ if tconf["name"] == adaptor_name:
117
+ break
118
+ else:
119
+ raise ValueError(f'Unable to find the specified adaptor name. Known names: {list(t["name"] for t in teachers)}')
120
+
121
+ ttype = tconf["type"]
122
+
123
+ pf_idx_head = f'_heads.{tidx}'
124
+ pf_name_head = f'_heads.{adaptor_name}'
125
+ pf_idx_feat = f'_feature_projections.{tidx}'
126
+ pf_name_feat = f'_feature_projections.{adaptor_name}'
127
+
128
+ adaptor_state = dict()
129
+ for k, v in state_dict.items():
130
+ if k.startswith(pf_idx_head):
131
+ adaptor_state['summary' + k[len(pf_idx_head):]] = v
132
+ elif k.startswith(pf_name_head):
133
+ adaptor_state['summary' + k[len(pf_name_head):]] = v
134
+ elif k.startswith(pf_idx_feat):
135
+ adaptor_state['feature' + k[len(pf_idx_feat):]] = v
136
+ elif k.startswith(pf_name_feat):
137
+ adaptor_state['feature' + k[len(pf_name_feat):]] = v
138
+
139
+ adaptor = adaptor_registry.create_adaptor(ttype, chk["args"], tconf, adaptor_state)
140
+ adaptor.head_idx = tidx if cls_token_per_teacher else 0
141
+ adaptors[adaptor_name] = adaptor
142
+
143
+ feat_norm_sd = get_prefix_state_dict(state_dict, '_feature_normalizer.')
144
+ feature_normalizer = None
145
+ if feat_norm_sd:
146
+ feature_normalizer = FeatureNormalizer(feat_norm_sd['mean'].shape[0], dtype=dtype)
147
+ feature_normalizer.load_state_dict(feat_norm_sd)
148
+
149
+ inter_feat_norm_sd = get_prefix_state_dict(state_dict, '_intermediate_feature_normalizer.')
150
+ inter_feature_normalizer = None
151
+ if inter_feat_norm_sd:
152
+ inter_feature_normalizer = IntermediateFeatureNormalizer(
153
+ *inter_feat_norm_sd['means'].shape[:2],
154
+ rot_per_layer=inter_feat_norm_sd['rotation'].ndim == 3,
155
+ dtype=dtype
156
+ )
157
+ inter_feature_normalizer.load_state_dict(inter_feat_norm_sd)
158
+
159
+ radio = RADIOModel(
160
+ mod,
161
+ conditioner,
162
+ summary_idxs=summary_idxs,
163
+ patch_size=resource.patch_size,
164
+ max_resolution=resource.max_resolution,
165
+ window_size=vitdet_window_size,
166
+ preferred_resolution=resource.preferred_resolution,
167
+ adaptors=adaptors,
168
+ feature_normalizer=feature_normalizer,
169
+ inter_feature_normalizer=inter_feature_normalizer,
170
+ )
171
+
172
+ if vitdet_window_size is not None:
173
+ apply_vitdet_arch(
174
+ mod,
175
+ VitDetArgs(
176
+ vitdet_window_size,
177
+ radio.num_summary_tokens,
178
+ num_windowed=resource.vitdet_num_windowed,
179
+ num_global=resource.vitdet_num_global,
180
+ ),
181
+ )
182
+
183
+ if return_checkpoint:
184
+ return radio, chk
185
+ return radio
186
+
187
+
188
+ def get_prefix_state_dict(state_dict: Dict[str, Any], prefix: str):
189
+ mod_state_dict = {
190
+ k[len(prefix) :]: v for k, v in state_dict.items() if k.startswith(prefix)
191
+ }
192
+ return mod_state_dict
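The helper at the bottom is what carves the monolithic checkpoint into per-module state dicts; a tiny illustration (the keys and values below are fabricated, real values are tensors):

```python
state_dict = {
    "base_model.blocks.0.attn.qkv.weight": "w0",
    "base_model.blocks.0.attn.qkv.bias": "b0",
    "input_conditioner.norm_mean": "m",
}
base_sd = get_prefix_state_dict(state_dict, "base_model.")
# {'blocks.0.attn.qkv.weight': 'w0', 'blocks.0.attn.qkv.bias': 'b0'}
```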
tim/models/nvidia_radio/radio/__init__.py ADDED
@@ -0,0 +1,17 @@
1
+ # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ # Register the adaptors
10
+ from .adaptor_registry import adaptor_registry
11
+ from . import open_clip_adaptor
12
+ from .adaptor_base import AdaptorInput, RadioOutput, AdaptorBase
13
+
14
+ # Enable support for other model types via the timm register_model mechanism
15
+ from . import extra_timm_models
16
+ from . import extra_models
17
+ from . import vision_transformer_xpos
tim/models/nvidia_radio/radio/adaptor_base.py ADDED
@@ -0,0 +1,37 @@
1
+ # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+ from argparse import Namespace
9
+ from typing import NamedTuple, Optional
10
+
11
+ import torch
12
+ from torch import nn
13
+ import torch.nn.functional as F
14
+
15
+
16
+ class AdaptorInput(NamedTuple):
17
+ images: torch.Tensor
18
+ summary: torch.Tensor
19
+ features: torch.Tensor
20
+ feature_fmt: str
21
+ patch_size: int
22
+
23
+
24
+ class RadioOutput(NamedTuple):
25
+ summary: torch.Tensor
26
+ features: torch.Tensor
27
+
28
+ def to(self, *args, **kwargs):
29
+ return RadioOutput(
30
+ self.summary.to(*args, **kwargs) if self.summary is not None else None,
31
+ self.features.to(*args, **kwargs) if self.features is not None else None,
32
+ )
33
+
34
+
35
+ class AdaptorBase(nn.Module):
36
+ def forward(self, input: AdaptorInput) -> RadioOutput:
37
+ raise NotImplementedError("Subclasses must implement this!")
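To make the contract concrete, here is a hypothetical no-op adaptor (not part of this repo): an adaptor receives an `AdaptorInput` and returns a `RadioOutput` pair of summary and feature tensors.

```python
from tim.models.nvidia_radio.radio.adaptor_base import AdaptorBase, AdaptorInput, RadioOutput

class IdentityAdaptor(AdaptorBase):
    """Illustrative adaptor that passes the backbone outputs through unchanged."""
    def forward(self, input: AdaptorInput) -> RadioOutput:
        return RadioOutput(input.summary, input.features)
```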
tim/models/nvidia_radio/radio/adaptor_generic.py ADDED
@@ -0,0 +1,69 @@
1
+ # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+ from argparse import Namespace
9
+
10
+ import torch
11
+ from torch import nn
12
+ import torch.nn.functional as F
13
+
14
+ from .adaptor_base import AdaptorBase, AdaptorInput, RadioOutput
15
+ from .adaptor_mlp import create_mlp_from_state, create_mlp_from_config
16
+
17
+
18
+ class GenericAdaptor(AdaptorBase):
19
+ def __init__(self, main_config: Namespace, adaptor_config, state, mlp_config=None):
20
+ super().__init__()
21
+
22
+ extra_args = dict()
23
+ ups = None
24
+ ups_rank = None
25
+ if adaptor_config is not None:
26
+ ups = adaptor_config.get('fd_upsample_factor', None)
27
+ ups_rank = adaptor_config.get('fd_upsample_rank', None)
28
+ elif mlp_config is not None:
29
+ ups = mlp_config["feature"].get('upsample_factor', None)
30
+ ups_rank = mlp_config["feature"].get('upsample_rank', None)
31
+ if ups is not None:
32
+ extra_args['upsample_factor'] = ups
33
+ extra_args['upsample_rank'] = ups_rank
34
+
35
+ if state is not None:
36
+ spectral_heads = getattr(main_config, 'spectral_heads', False)
37
+ self.head_mlp = create_mlp_from_state(main_config.mlp_version, state, 'summary.', spectral_weights=spectral_heads)
38
+ self.feat_mlp = create_mlp_from_state(main_config.mlp_version, state, 'feature.', spectral_weights=spectral_heads, **extra_args)
39
+ else:
40
+ assert mlp_config is not None, "Config must not be None if state is None"
41
+
42
+ self.head_mlp = create_mlp_from_config(
43
+ main_config.mlp_version,
44
+ mlp_config["summary"]["input_dim"],
45
+ mlp_config["summary"]["hidden_dim"],
46
+ mlp_config["summary"]["output_dim"],
47
+ mlp_config["summary"]["num_inner"],
48
+ )
49
+ self.feat_mlp = create_mlp_from_config(
50
+ main_config.mlp_version,
51
+ mlp_config["feature"]["input_dim"],
52
+ mlp_config["feature"]["hidden_dim"],
53
+ mlp_config["feature"]["output_dim"],
54
+ mlp_config["feature"]["num_inner"],
55
+ **extra_args
56
+ )
57
+
58
+ def forward(self, input: AdaptorInput) -> RadioOutput:
59
+ # Convert the input's dtype to the dtype of the adaptor's first parameter.
60
+ first_param = next(self.parameters())
61
+ summary = self.head_mlp(input.summary.to(dtype=first_param.dtype)).to(dtype=input.summary.dtype)
62
+ feat = self.feat_mlp(input.features.to(dtype=first_param.dtype), images=input.images, patch_size=input.patch_size).to(dtype=input.features.dtype)
63
+
64
+ if input.feature_fmt == 'NCHW':
65
+ feat = (feat.reshape(feat.shape[0], input.images.shape[-2] // input.patch_size * self.feat_mlp.upsample_factor, input.images.shape[-1] // input.patch_size * self.feat_mlp.upsample_factor, feat.shape[2])
66
+ .permute(0, 3, 1, 2)
67
+ )
68
+
69
+ return RadioOutput(summary, feat)
tim/models/nvidia_radio/radio/adaptor_mlp.py ADDED
@@ -0,0 +1,174 @@
1
+ # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+ import math
9
+ from typing import Dict, Optional
10
+
11
+ import torch
12
+ from torch import nn
13
+
14
+ from einops import rearrange
15
+ from timm.models.vision_transformer import Block
16
+
17
+ from .enable_spectral_reparam import disable_spectral_reparam, enable_spectral_reparam
18
+
19
+
20
+ class MLP(nn.Module):
21
+ def __init__(self, input_size: int, hidden_size: int, output_size: int,
22
+ num_inner: int = 0, device: torch.device = None, **kwargs):
23
+ super(MLP, self).__init__()
24
+ self.fc1 = nn.Linear(input_size, hidden_size, device=device)
25
+ self.norm = nn.LayerNorm(hidden_size, device=device)
26
+ self.relu = nn.ReLU()
27
+
28
+ inner = []
29
+ for _ in range(num_inner):
30
+ inner.extend([
31
+ nn.Linear(hidden_size, hidden_size, device=device),
32
+ nn.LayerNorm(hidden_size, device=device),
33
+ nn.ReLU(),
34
+ ])
35
+ if inner:
36
+ self.inner = nn.Sequential(*inner)
37
+ else:
38
+ self.inner = nn.Identity()
39
+
40
+ self.fc2 = nn.Linear(hidden_size, output_size, device=device)
41
+
42
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
43
+ x = self.fc1(x)
44
+ x = self.norm(x)
45
+ x = self.relu(x)
46
+ x = self.inner(x)
47
+ x = self.fc2(x)
48
+ return x
49
+
50
+
51
+ class MLP2(nn.Module):
52
+ def __init__(self, input_size: int, hidden_size: int, output_size: int,
53
+ num_inner: int = 0,
54
+ pre_norm: bool = False, device: torch.device = None,
55
+ upsample_factor: int = 1,
56
+ upsample_rank: int = None,
57
+ from_config: bool = False,
58
+ **kwargs):
59
+ super().__init__()
60
+
61
+ self.pre_norm = nn.Sequential(
62
+ nn.LayerNorm(input_size),
63
+ nn.GELU(),
64
+ ) if pre_norm else nn.Identity()
65
+
66
+ self.upsample_factor = upsample_factor
67
+ sq_ups = upsample_factor ** 2
68
+
69
+ self._real_output_dim = output_size // sq_ups
70
+
71
+ # hidden_size *= upsample_factor
72
+ # output_size *= (upsample_factor ** 2)
73
+
74
+ self.fc1 = nn.Linear(input_size, hidden_size, device=device)
75
+
76
+ blocks = []
77
+ for _ in range(num_inner):
78
+ blocks.append(nn.Sequential(
79
+ nn.LayerNorm(hidden_size, device=device),
80
+ nn.GELU(),
81
+ nn.Linear(hidden_size, hidden_size, device=device),
82
+ ))
83
+ self.blocks = nn.ModuleList(blocks)
84
+
85
+ self.final = nn.Sequential(
86
+ nn.LayerNorm(hidden_size, device=device),
87
+ nn.GELU(),
88
+ nn.Linear(hidden_size, output_size, device=device),
89
+ )
90
+
91
+ def forward(self, x: torch.Tensor, images: Optional[torch.Tensor] = None, patch_size: Optional[int] = None) -> torch.Tensor:
92
+ x = self.pre_norm(x)
93
+ x = self.fc1(x)
94
+ for block in self.blocks:
95
+ x = x + block(x)
96
+ x = self.final(x)
97
+
98
+ if self.upsample_factor > 1:
99
+ if images is None:
100
+ raise ValueError(f'`images` cannot be `None` when the head\'s `upsample_factor > 1`!')
101
+ if patch_size is None:
102
+ raise ValueError(f'`patch_size` cannot be `None` when the head\'s `upsample_factor > 1`!')
103
+ h, w = tuple(d // patch_size for d in images.shape[-2:])
104
+ x = rearrange(x, 'b (h w) (u1 u2 c) -> b (h u1 w u2) c',
105
+ h=h, w=w, u1=self.upsample_factor, u2=self.upsample_factor,
106
+ c=self._real_output_dim)
107
+
108
+ return x
109
+
110
+
111
+ MLP_FACTORY = {
112
+ 'v1': MLP,
113
+ 'v2': MLP2,
114
+ }
115
+
116
+
117
+ def strip_prefix(state: Dict[str, torch.Tensor], prefix: str):
118
+ state = {
119
+ k[len(prefix):]: v
120
+ for k, v in state.items()
121
+ if k.startswith(prefix)
122
+ }
123
+ return state
124
+
125
+
126
+ def get_mlp_info_from_state(version: str, state: Dict[str, torch.Tensor], prefix: str = '', spectral_weights: bool = False):
127
+ state = strip_prefix(state, prefix)
128
+
129
+ weight_suffix = 'weight' if not spectral_weights else 'parametrizations.weight.original'
130
+
131
+ if version == 'v1':
132
+ hidden_dim, input_dim = state[f'fc1.{weight_suffix}'].shape
133
+ output_dim = state[f'fc2.{weight_suffix}'].shape[0]
134
+
135
+ for num_inner in range(1000):
136
+ k = f'inner.{num_inner}.0.weight'
137
+ if k not in state:
138
+ break
139
+ elif version == 'v2':
140
+ hidden_dim, input_dim = state[f'fc1.{weight_suffix}'].shape
141
+ output_dim = state[f'final.2.{weight_suffix}'].shape[0]
142
+
143
+ for num_inner in range(1000):
144
+ k = f'blocks.{num_inner}.0.weight'
145
+ if k not in state:
146
+ break
147
+ else:
148
+ raise ValueError(f'Unsupported MLP version: {version}')
149
+
150
+ return input_dim, hidden_dim, output_dim, num_inner
151
+
152
+
153
+ def create_mlp_from_config(version: str, input_dim: int, hidden_dim: int, output_dim: int, num_inner: int, **kwargs):
154
+ ret: nn.Module = MLP_FACTORY[version](input_dim, hidden_dim, output_dim, num_inner, from_config=True, **kwargs)
155
+
156
+ return ret
157
+
158
+
159
+ def create_mlp_from_state(version: str, state: Dict[str, torch.Tensor], prefix: str = '', spectral_weights: bool = False, **kwargs):
160
+ state = strip_prefix(state, prefix)
161
+
162
+ input_dim, hidden_dim, output_dim, num_inner = get_mlp_info_from_state(version, state, spectral_weights=spectral_weights)
163
+
164
+ ret: nn.Module = create_mlp_from_config(version, input_dim, hidden_dim, output_dim, num_inner, **kwargs)
165
+
166
+ if spectral_weights:
167
+ enable_spectral_reparam(ret, init_norm_to_current=False, state_dict_guidance=state)
168
+
169
+ ret.load_state_dict(state)
170
+
171
+ if spectral_weights:
172
+ disable_spectral_reparam(ret)
173
+
174
+ return ret
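A hedged usage sketch for `create_mlp_from_config` with made-up dimensions; with the default `upsample_factor=1`, the v2 head can be called without `images`/`patch_size`:

```python
import torch
from tim.models.nvidia_radio.radio.adaptor_mlp import create_mlp_from_config

head = create_mlp_from_config('v2', input_dim=1280, hidden_dim=2048,
                              output_dim=1024, num_inner=2)
tokens = torch.randn(2, 196, 1280)   # (batch, patches, backbone dim)
out = head(tokens)
print(out.shape)                     # torch.Size([2, 196, 1024])
```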
tim/models/nvidia_radio/radio/adaptor_registry.py ADDED
@@ -0,0 +1,37 @@
1
+ # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+ from argparse import Namespace
9
+ from typing import Dict, Any
10
+
11
+ import torch
12
+
13
+ from .adaptor_generic import GenericAdaptor, AdaptorBase
14
+
15
+ dict_t = Dict[str, Any]
16
+ state_t = Dict[str, torch.Tensor]
17
+
18
+
19
+ class AdaptorRegistry:
20
+ def __init__(self):
21
+ self._registry = {}
22
+
23
+ def register_adaptor(self, name):
24
+ def decorator(factory_function):
25
+ if name in self._registry:
26
+ raise ValueError(f"Model '{name}' already registered")
27
+ self._registry[name] = factory_function
28
+ return factory_function
29
+ return decorator
30
+
31
+ def create_adaptor(self, name, main_config: Namespace, adaptor_config: dict_t, state: state_t) -> AdaptorBase:
32
+ if name not in self._registry:
33
+ return GenericAdaptor(main_config, adaptor_config, state)
34
+ return self._registry[name](main_config, adaptor_config, state)
35
+
36
+ # Creating an instance of the registry
37
+ adaptor_registry = AdaptorRegistry()
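Unknown adaptor names fall back to `GenericAdaptor`, so explicit registration is only needed for special behaviour. A hypothetical registration (the `"my_teacher"` name is illustrative):

```python
from tim.models.nvidia_radio.radio.adaptor_generic import GenericAdaptor
from tim.models.nvidia_radio.radio.adaptor_registry import adaptor_registry

@adaptor_registry.register_adaptor("my_teacher")
def build_my_adaptor(main_config, adaptor_config, state):
    # For illustration, simply delegate to the generic MLP adaptor.
    return GenericAdaptor(main_config, adaptor_config, state)
```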
tim/models/nvidia_radio/radio/block.py ADDED
@@ -0,0 +1,54 @@
1
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
2
+ """
3
+ Block modules
4
+ """
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ from timm.models.layers import DropPath
9
+
10
+ from .conv import Conv
11
+ # from .transformer import TransformerBlock
12
+
13
+ __all__ = ('C2f', 'Bottleneck',)
14
+
15
+ class C2f(nn.Module):
16
+ """Faster Implementation of CSP Bottleneck with 2 convolutions."""
17
+
18
+ def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5, drop_path=None): # ch_in, ch_out, number, shortcut, groups, expansion
19
+ super().__init__()
20
+ if drop_path is None:
21
+ drop_path = [0.0] * n
22
+
23
+ self.c = int(c2 * e) # hidden channels
24
+ self.cv1 = Conv(c1, 2 * self.c, 1, 1)
25
+ self.cv2 = Conv((2 + n) * self.c, c2, 1) # optional act=FReLU(c2)
26
+ self.m = nn.ModuleList(Bottleneck(self.c, self.c, shortcut, g, k=((3, 3), (3, 3)), e=1.0, drop_path=drop_path[i]) for i in range(n))
27
+
28
+ def forward(self, x):
29
+ """Forward pass through C2f layer."""
30
+ y = list(self.cv1(x).chunk(2, 1))
31
+ y.extend(m(y[-1]) for m in self.m)
32
+ return self.cv2(torch.cat(y, 1))
33
+
34
+ def forward_split(self, x):
35
+ """Forward pass using split() instead of chunk()."""
36
+ y = list(self.cv1(x).split((self.c, self.c), 1))
37
+ y.extend(m(y[-1]) for m in self.m)
38
+ return self.cv2(torch.cat(y, 1))
39
+
40
+
41
+ class Bottleneck(nn.Module):
42
+ """Standard bottleneck."""
43
+
44
+ def __init__(self, c1, c2, shortcut=True, g=1, k=(3, 3), e=0.5, drop_path=0.0): # ch_in, ch_out, shortcut, groups, kernels, expand
45
+ super().__init__()
46
+ c_ = int(c2 * e) # hidden channels
47
+ self.cv1 = Conv(c1, c_, k[0], 1)
48
+ self.cv2 = Conv(c_, c2, k[1], 1, g=g)
49
+ self.add = shortcut and c1 == c2
50
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
51
+
52
+ def forward(self, x):
53
+ """'forward()' applies the YOLOv5 FPN to input data."""
54
+ return x + self.drop_path1(self.cv2(self.cv1(x))) if self.add else self.cv2(self.cv1(x))
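A quick shape check for the vendored `C2f` block, assuming `Conv` from the sibling `conv.py` is importable; channels in equals channels out here, so the bottleneck shortcuts are active:

```python
import torch
from tim.models.nvidia_radio.radio.block import C2f

blk = C2f(c1=64, c2=64, n=2, shortcut=True)
x = torch.randn(1, 64, 32, 32)
print(blk(x).shape)   # torch.Size([1, 64, 32, 32])
```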
tim/models/nvidia_radio/radio/cls_token.py ADDED
@@ -0,0 +1,59 @@
1
+ # Copyright (c) 2023-2024, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+ from typing import Optional
9
+
10
+ import torch
11
+ from torch import nn
12
+
13
+
14
+ class ClsToken(nn.Module):
15
+ def __init__(self, ndim: int,
16
+ num_tokens: int = 1,
17
+ enabled: bool = True,
18
+ register_multiple: Optional[int] = None,
19
+ num_registers: Optional[int] = None,
20
+ ):
21
+ super().__init__()
22
+
23
+ self.ndim = ndim
24
+ self.enabled = enabled
25
+ self.num_registers = 0
26
+ self.num_tokens = num_tokens
27
+ if enabled:
28
+ if num_registers:
29
+ self.num_registers = num_registers
30
+ elif register_multiple:
31
+ self.num_registers = register_multiple - (num_tokens % register_multiple)
32
+
33
+ scale = ndim ** -0.5
34
+ self.token = nn.Parameter(torch.randn(num_tokens + self.num_registers, ndim) * scale)
35
+ else:
36
+ self.token = None
37
+
38
+ self.num_patches = self.num_tokens + self.num_registers
39
+
40
+ def disable(self):
41
+ self.token = None
42
+ self.enabled = False
43
+
44
+ def forward(self, x: torch.Tensor):
45
+ if self.token is None:
46
+ return x
47
+
48
+ token = self.token.unsqueeze(0).expand(x.shape[0], -1, -1)
49
+ x = torch.cat([
50
+ token,
51
+ x,
52
+ ], dim=1)
53
+
54
+ return x
55
+
56
+ def no_weight_decay(self):
57
+ return [
58
+ 'token',
59
+ ]
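Shape-level sketch of `ClsToken`: with one class token and `register_multiple=8`, seven registers are added so the prepended block is eight tokens wide (numbers are illustrative):

```python
import torch
from tim.models.nvidia_radio.radio.cls_token import ClsToken

cls = ClsToken(ndim=768, num_tokens=1, register_multiple=8)
patches = torch.randn(2, 196, 768)
out = cls(patches)
print(cls.num_patches, out.shape)   # 8 torch.Size([2, 204, 768])
```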
tim/models/nvidia_radio/radio/common.py ADDED
@@ -0,0 +1,108 @@
1
+ # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ from dataclasses import dataclass
10
+ from typing import Optional
11
+
12
+ from .radio_model import Resolution
13
+
14
+
15
+ @dataclass
16
+ class RadioResource:
17
+ url: str
18
+ patch_size: int
19
+ max_resolution: int
20
+ preferred_resolution: Resolution
21
+ vitdet_num_windowed: Optional[int] = None
22
+ vitdet_num_global: Optional[int] = None
23
+
24
+
25
+ RESOURCE_MAP = {
26
+ # RADIOv2.5
27
+ "radio_v2.5-b": RadioResource(
28
+ "https://huggingface.co/nvidia/RADIO/resolve/main/radio-v2.5-b_half.pth.tar?download=true",
29
+ patch_size=16,
30
+ max_resolution=2048,
31
+ preferred_resolution=(768, 768),
32
+ vitdet_num_global=4,
33
+ ),
34
+ "radio_v2.5-l": RadioResource(
35
+ "https://huggingface.co/nvidia/RADIO/resolve/main/radio-v2.5-l_half.pth.tar?download=true",
36
+ patch_size=16,
37
+ max_resolution=2048,
38
+ preferred_resolution=(768, 768),
39
+ vitdet_num_global=4,
40
+ ),
41
+ "radio_v2.5-h": RadioResource(
42
+ "https://huggingface.co/nvidia/RADIO/resolve/main/radio_v2.5-h.pth.tar?download=true",
43
+ patch_size=16,
44
+ max_resolution=2048,
45
+ preferred_resolution=(768, 768),
46
+ vitdet_num_global=4,
47
+ ),
48
+ "radio_v2.5-h-norm": RadioResource(
49
+ "https://huggingface.co/nvidia/RADIO/resolve/main/radio_v2.5-h-norm.pth.tar?download=true",
50
+ patch_size=16,
51
+ max_resolution=2048,
52
+ preferred_resolution=(768, 768),
53
+ vitdet_num_global=4,
54
+ ),
55
+ "radio_v2.5-g": RadioResource(
56
+ "https://huggingface.co/nvidia/RADIO/resolve/main/radio_v2.5-g.pth.tar?download=true",
57
+ patch_size=14,
58
+ max_resolution=1792,
59
+ preferred_resolution=(896, 896),
60
+ vitdet_num_global=8,
61
+ ),
62
+ # RADIO
63
+ "radio_v2.1": RadioResource(
64
+ "https://huggingface.co/nvidia/RADIO/resolve/main/radio_v2.1_bf16.pth.tar?download=true",
65
+ patch_size=16,
66
+ max_resolution=2048,
67
+ preferred_resolution=Resolution(432, 432),
68
+ vitdet_num_windowed=5,
69
+ ),
70
+ "radio_v2": RadioResource(
71
+ "https://huggingface.co/nvidia/RADIO/resolve/main/radio_v2.pth.tar?download=true",
72
+ patch_size=16,
73
+ max_resolution=2048,
74
+ preferred_resolution=Resolution(432, 432),
75
+ vitdet_num_windowed=5,
76
+ ),
77
+ "radio_v1": RadioResource(
78
+ "https://huggingface.co/nvidia/RADIO/resolve/main/radio_v1.pth.tar?download=true",
79
+ patch_size=14,
80
+ max_resolution=1050,
81
+ preferred_resolution=Resolution(378, 378),
82
+ ),
83
+ # E-RADIO
84
+ "e-radio_v2": RadioResource(
85
+ "https://huggingface.co/nvidia/RADIO/resolve/main/eradio_v2.pth.tar?download=true",
86
+ patch_size=16,
87
+ max_resolution=2048,
88
+ preferred_resolution=Resolution(512, 512),
89
+ ),
90
+ # C-RADIO
91
+ "c-radio_v2.5-g": RadioResource(
92
+ "https://huggingface.co/nvidia/C-RADIOv2-g/resolve/main/c-radio_v2-g_half.pth.tar",
93
+ patch_size=16,
94
+ max_resolution=2048,
95
+ preferred_resolution=(768, 768),
96
+ vitdet_num_global=8,
97
+ ),
98
+ "c-radio_v3-l": RadioResource(
99
+ # NOTE: Currently, this model cannot be loaded via TorchHub. Instead, use the transformers API at https://huggingface.co/nvidia/C-RADIOv3-L
100
+ # and accept the license terms.
101
+ "https://huggingface.co/nvidia/C-RADIOv3-L/resolve/main/c-radio-v3_l_half.pth.tar?download=true",
102
+ patch_size=16,
103
+ max_resolution=2048,
104
+ preferred_resolution=Resolution(512, 512),
105
+ ),
106
+ }
107
+
108
+ DEFAULT_VERSION = "radio_v2.5-h"
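The map above is only metadata; weights are downloaded lazily by `radio_model()`. A minimal lookup of the default entry (assuming the package and its dependencies import cleanly):

```python
from tim.models.nvidia_radio.radio.common import RESOURCE_MAP, DEFAULT_VERSION

res = RESOURCE_MAP[DEFAULT_VERSION]
print(DEFAULT_VERSION, res.patch_size, res.preferred_resolution)
# radio_v2.5-h 16 (768, 768)
```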
tim/models/nvidia_radio/radio/conv.py ADDED
@@ -0,0 +1,65 @@
1
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
2
+ """
3
+ Convolution modules
4
+ """
5
+
6
+ import math
7
+
8
+ import numpy as np
9
+ import torch
10
+ import torch.nn as nn
11
+
12
+ __all__ = ('Conv', 'LightConv', 'DWConv', 'DWConvTranspose2d', 'ConvTranspose', 'Focus', 'GhostConv',
13
+ 'ChannelAttention', 'SpatialAttention', 'CBAM', 'Concat', 'RepConv')
14
+
15
+
16
+ def autopad(k, p=None, d=1): # kernel, padding, dilation
17
+ """Pad to 'same' shape outputs."""
18
+ if d > 1:
19
+ k = d * (k - 1) + 1 if isinstance(k, int) else [d * (x - 1) + 1 for x in k] # actual kernel-size
20
+ if p is None:
21
+ p = k // 2 if isinstance(k, int) else [x // 2 for x in k] # auto-pad
22
+ return p
23
+
24
+ # Pavlo's implementation with switch to deploy
25
+ class Conv(nn.Module):
26
+ default_act = nn.SiLU() # default activation
27
+
28
+ def __init__(self, a, b, kernel_size=1, stride=1, padding=None, g=1, dilation=1, bn_weight_init=1, bias=False, act=True):
29
+ super().__init__()
30
+
31
+ self.conv = torch.nn.Conv2d(a, b, kernel_size, stride, autopad(kernel_size, padding, dilation), dilation, g, bias=False)
32
+ if 1:
33
+ self.bn = torch.nn.BatchNorm2d(b)
34
+ torch.nn.init.constant_(self.bn.weight, bn_weight_init)
35
+ torch.nn.init.constant_(self.bn.bias, 0)
36
+ self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()
37
+
38
+
39
+ def forward(self,x):
40
+ x = self.conv(x)
41
+ x = self.bn(x)
42
+ x = self.act(x)
43
+ return x
44
+
45
+ @torch.no_grad()
46
+ def switch_to_deploy(self):
47
+ if not isinstance(self.bn, nn.Identity):
48
+ # return 1
49
+ c, bn = self.conv, self.bn
50
+ w = bn.weight / (bn.running_var + bn.eps) ** 0.5
51
+ w = c.weight * w[:, None, None, None]
52
+ b = bn.bias - bn.running_mean * bn.weight / \
53
+ (bn.running_var + bn.eps)**0.5
54
+ # m = torch.nn.Conv2d(w.size(1) * c.groups,
55
+ # w.size(0),
56
+ # w.shape[2:],
57
+ # stride=c.stride,
58
+ # padding=c.padding,
59
+ # dilation=c.dilation,
60
+ # groups=c.groups)
61
+ self.conv.weight.data.copy_(w)
62
+ self.conv.bias = nn.Parameter(b)
63
+ # self.conv.bias.data.copy_(b)
64
+ # self.conv = m.to(c.weight.device)
65
+ self.bn = nn.Identity()
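A hedged sanity check for `Conv.switch_to_deploy()`: in eval mode, folding the BatchNorm statistics into the convolution should leave outputs unchanged up to floating-point error (the sizes below are arbitrary):

```python
import torch
from tim.models.nvidia_radio.radio.conv import Conv

m = Conv(16, 32, kernel_size=3).eval()
x = torch.randn(1, 16, 8, 8)
with torch.no_grad():
    before = m(x)
    m.switch_to_deploy()     # folds BN into the conv weight and bias
    after = m(x)
print(torch.allclose(before, after, atol=1e-5))   # True
```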
tim/models/nvidia_radio/radio/dinov2_arch.py ADDED
@@ -0,0 +1,1016 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
9
+
10
+ # Nvidia
11
+ # NOTE: We re-define this model architecture primarily so that we don't have to worry about version compatibility breaking,
12
+ # but also because Huggingface does a string replace of `gamma` to something else when loading the model state,
13
+ # and this breaks loading of this model.
14
+
15
+ from enum import Enum
16
+ from functools import partial
17
+ import logging
18
+ import math
19
+ import os
20
+ import sys
21
+ from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
22
+ import warnings
23
+
24
+ import torch
25
+ from torch import nn
26
+ from torch.nn import functional as F
27
+ from torch.nn.init import trunc_normal_
28
+
29
+ _torch_has_sdpa = hasattr(F, 'scaled_dot_product_attention')
30
+
31
+
32
+ XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
33
+ try:
34
+ if XFORMERS_ENABLED:
35
+ from xformers.ops import fmha, scaled_index_add, index_select_cat, SwiGLU, memory_efficient_attention, unbind
36
+
37
+ XFORMERS_AVAILABLE = True
38
+ else:
39
+ raise ImportError
40
+ except ImportError:
41
+ XFORMERS_AVAILABLE = False
42
+
43
+
44
+ def make_2tuple(x):
45
+ if isinstance(x, tuple):
46
+ assert len(x) == 2
47
+ return x
48
+
49
+ assert isinstance(x, int)
50
+ return (x, x)
51
+
52
+
53
+ class PatchEmbed(nn.Module):
54
+ """
55
+ 2D image to patch embedding: (B,C,H,W) -> (B,N,D)
56
+
57
+ Args:
58
+ img_size: Image size.
59
+ patch_size: Patch token size.
60
+ in_chans: Number of input image channels.
61
+ embed_dim: Number of linear projection output channels.
62
+ norm_layer: Normalization layer.
63
+ """
64
+
65
+ def __init__(
66
+ self,
67
+ img_size: Union[int, Tuple[int, int]] = 224,
68
+ patch_size: Union[int, Tuple[int, int]] = 16,
69
+ in_chans: int = 3,
70
+ embed_dim: int = 768,
71
+ norm_layer: Optional[Callable] = None,
72
+ flatten_embedding: bool = True,
73
+ ) -> None:
74
+ super().__init__()
75
+
76
+ image_HW = make_2tuple(img_size)
77
+ patch_HW = make_2tuple(patch_size)
78
+ patch_grid_size = (
79
+ image_HW[0] // patch_HW[0],
80
+ image_HW[1] // patch_HW[1],
81
+ )
82
+
83
+ self.img_size = image_HW
84
+ self.patch_size = patch_HW
85
+ self.patches_resolution = patch_grid_size
86
+ self.num_patches = patch_grid_size[0] * patch_grid_size[1]
87
+
88
+ self.in_chans = in_chans
89
+ self.embed_dim = embed_dim
90
+
91
+ self.flatten_embedding = flatten_embedding
92
+
93
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
94
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
95
+
96
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
97
+ _, _, H, W = x.shape
98
+ patch_H, patch_W = self.patch_size
99
+
100
+ assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
101
+ assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"
102
+
103
+ x = self.proj(x) # B C H W
104
+ H, W = x.size(2), x.size(3)
105
+ x = x.flatten(2).transpose(1, 2) # B HW C
106
+ x = self.norm(x)
107
+ if not self.flatten_embedding:
108
+ x = x.reshape(-1, H, W, self.embed_dim) # B H W C
109
+ return x
110
+
111
+ def flops(self) -> float:
112
+ Ho, Wo = self.patches_resolution
113
+ flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
114
+ if self.norm is not None:
115
+ flops += Ho * Wo * self.embed_dim
116
+ return flops
117
+
118
+
119
+ class Attention(nn.Module):
120
+ def __init__(
121
+ self,
122
+ dim: int,
123
+ num_heads: int = 8,
124
+ qkv_bias: bool = False,
125
+ proj_bias: bool = True,
126
+ attn_drop: float = 0.0,
127
+ proj_drop: float = 0.0,
128
+ ) -> None:
129
+ super().__init__()
130
+ self.num_heads = num_heads
131
+ head_dim = dim // num_heads
132
+ self.scale = head_dim**-0.5
133
+
134
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
135
+ self.attn_drop = nn.Dropout(attn_drop)
136
+ self.proj = nn.Linear(dim, dim, bias=proj_bias)
137
+ self.proj_drop = nn.Dropout(proj_drop)
138
+
139
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
140
+ B, N, C = x.shape
141
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
142
+
143
+ q, k, v = qkv[0], qkv[1], qkv[2]
144
+ if _torch_has_sdpa:
145
+ x = F.scaled_dot_product_attention(
146
+ q, k, v,
147
+ is_causal=False,
148
+ dropout_p=self.attn_drop.p if self.training else 0.,
149
+ scale=self.scale,
150
+ )
151
+ else:
152
+ q = q * self.scale
153
+ attn = q @ k.transpose(-2, -1)
154
+
155
+ attn = attn.softmax(dim=-1)
156
+ attn = self.attn_drop(attn)
157
+ x = attn @ v
158
+
159
+ x = x.transpose(1, 2).reshape(B, N, C)
160
+ x = self.proj(x)
161
+ x = self.proj_drop(x)
162
+ return x
163
+
164
+
165
+ class MemEffAttention(Attention):
166
+ def forward(self, x: torch.Tensor, attn_bias=None) -> torch.Tensor:
167
+ if not XFORMERS_AVAILABLE:
168
+ if attn_bias is not None:
169
+ raise AssertionError("xFormers is required for using nested tensors")
170
+ return super().forward(x)
171
+
172
+ B, N, C = x.shape
173
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
174
+
175
+ q, k, v = unbind(qkv, 2)
176
+
177
+ x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
178
+ x = x.reshape([B, N, C])
179
+
180
+ x = self.proj(x)
181
+ x = self.proj_drop(x)
182
+ return x
183
+
184
+
185
+ class Mlp(nn.Module):
186
+ def __init__(
187
+ self,
188
+ in_features: int,
189
+ hidden_features: Optional[int] = None,
190
+ out_features: Optional[int] = None,
191
+ act_layer: Callable[..., nn.Module] = nn.GELU,
192
+ drop: float = 0.0,
193
+ bias: bool = True,
194
+ ) -> None:
195
+ super().__init__()
196
+ out_features = out_features or in_features
197
+ hidden_features = hidden_features or in_features
198
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
199
+ self.act = act_layer()
200
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
201
+ self.drop = nn.Dropout(drop)
202
+
203
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
204
+ x = self.fc1(x)
205
+ x = self.act(x)
206
+ x = self.drop(x)
207
+ x = self.fc2(x)
208
+ x = self.drop(x)
209
+ return x
210
+
211
+
212
+ class SwiGLUFFN(nn.Module):
213
+ def __init__(
214
+ self,
215
+ in_features: int,
216
+ hidden_features: Optional[int] = None,
217
+ out_features: Optional[int] = None,
218
+ act_layer: Callable[..., nn.Module] = None,
219
+ drop: float = 0.0,
220
+ bias: bool = True,
221
+ ) -> None:
222
+ super().__init__()
223
+ out_features = out_features or in_features
224
+ hidden_features = hidden_features or in_features
225
+ self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
226
+ self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
227
+
228
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
229
+ x12 = self.w12(x)
230
+ x1, x2 = x12.chunk(2, dim=-1)
231
+ hidden = F.silu(x1) * x2
232
+ return self.w3(hidden)
233
+
234
+
235
+ if not XFORMERS_AVAILABLE:
236
+ SwiGLU = SwiGLUFFN
237
+
238
+
239
+ class SwiGLUFFNFused(SwiGLU):
240
+ def __init__(
241
+ self,
242
+ in_features: int,
243
+ hidden_features: Optional[int] = None,
244
+ out_features: Optional[int] = None,
245
+ act_layer: Callable[..., nn.Module] = None,
246
+ drop: float = 0.0,
247
+ bias: bool = True,
248
+ ) -> None:
249
+ out_features = out_features or in_features
250
+ hidden_features = hidden_features or in_features
251
+ hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
252
+ super().__init__(
253
+ in_features=in_features,
254
+ hidden_features=hidden_features,
255
+ out_features=out_features,
256
+ bias=bias,
257
+ )
258
+
259
+
260
+ def drop_path(x, drop_prob: float = 0.0, training: bool = False):
261
+ if drop_prob == 0.0 or not training:
262
+ return x
263
+ keep_prob = 1 - drop_prob
264
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
265
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
266
+ if keep_prob > 0.0:
267
+ random_tensor.div_(keep_prob)
268
+ output = x * random_tensor
269
+ return output
270
+
271
+
272
+ class DropPath(nn.Module):
273
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
274
+
275
+ def __init__(self, drop_prob=None):
276
+ super(DropPath, self).__init__()
277
+ self.drop_prob = drop_prob
278
+
279
+ def forward(self, x):
280
+ return drop_path(x, self.drop_prob, self.training)
281
+
282
+
283
+ class LayerScale(nn.Module):
284
+ def __init__(
285
+ self,
286
+ dim: int,
287
+ init_values: Union[float, torch.Tensor] = 1e-5,
288
+ inplace: bool = False,
289
+ ) -> None:
290
+ super().__init__()
291
+ self.inplace = inplace
292
+ self.grandma = nn.Parameter(init_values * torch.ones(dim))
293
+
294
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
295
+ return x.mul_(self.grandma) if self.inplace else x * self.grandma
296
+
297
+ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
298
+ # Huggingface is absurd and it will rename strings that contain `gamma`, which means that the normal DINO implementation
299
+ # of LayerScale won't work with HFHub. So we rename the variable to 'grandma', and support loading checkpoints in either
300
+ # format
301
+ key_a = f'{prefix}gamma'
302
+ key_b = f'{prefix}grandma'
303
+ if key_a in state_dict:
304
+ gamma = state_dict[key_a]
305
+ elif key_b in state_dict:
306
+ gamma = state_dict[key_b]
307
+ else:
308
+ if strict:
309
+ raise KeyError(f"Couldn't find the key {key_a} nor {key_b} in the state dict!")
310
+ else:
311
+ missing_keys.append(key_a)
312
+ missing_keys.append(key_b)
313
+ unexpected_keys.extend(state_dict.keys())
314
+ gamma = None
315
+
316
+ if gamma is not None:
317
+ self.grandma.data.copy_(gamma)
318
+
319
+ # return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
320
+
321
+
322
+ class Block(nn.Module):
323
+ def __init__(
324
+ self,
325
+ dim: int,
326
+ num_heads: int,
327
+ mlp_ratio: float = 4.0,
328
+ qkv_bias: bool = False,
329
+ proj_bias: bool = True,
330
+ ffn_bias: bool = True,
331
+ drop: float = 0.0,
332
+ attn_drop: float = 0.0,
333
+ init_values=None,
334
+ drop_path: float = 0.0,
335
+ act_layer: Callable[..., nn.Module] = nn.GELU,
336
+ norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
337
+ attn_class: Callable[..., nn.Module] = Attention,
338
+ ffn_layer: Callable[..., nn.Module] = Mlp,
339
+ ) -> None:
340
+ super().__init__()
341
+ # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
342
+ self.norm1 = norm_layer(dim)
343
+ self.attn = attn_class(
344
+ dim,
345
+ num_heads=num_heads,
346
+ qkv_bias=qkv_bias,
347
+ proj_bias=proj_bias,
348
+ attn_drop=attn_drop,
349
+ proj_drop=drop,
350
+ )
351
+ self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
352
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
353
+
354
+ self.norm2 = norm_layer(dim)
355
+ mlp_hidden_dim = int(dim * mlp_ratio)
356
+ self.mlp = ffn_layer(
357
+ in_features=dim,
358
+ hidden_features=mlp_hidden_dim,
359
+ act_layer=act_layer,
360
+ drop=drop,
361
+ bias=ffn_bias,
362
+ )
363
+ self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
364
+ self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
365
+
366
+ self.sample_drop_ratio = drop_path
367
+
368
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
369
+ def attn_residual_func(x: torch.Tensor) -> torch.Tensor:
370
+ return self.ls1(self.attn(self.norm1(x)))
371
+
372
+ def ffn_residual_func(x: torch.Tensor) -> torch.Tensor:
373
+ return self.ls2(self.mlp(self.norm2(x)))
374
+
375
+ if self.training and self.sample_drop_ratio > 0.1:
376
+ # the overhead is compensated only for a drop path rate larger than 0.1
377
+ x = drop_add_residual_stochastic_depth(
378
+ x,
379
+ residual_func=attn_residual_func,
380
+ sample_drop_ratio=self.sample_drop_ratio,
381
+ )
382
+ x = drop_add_residual_stochastic_depth(
383
+ x,
384
+ residual_func=ffn_residual_func,
385
+ sample_drop_ratio=self.sample_drop_ratio,
386
+ )
387
+ elif self.training and self.sample_drop_ratio > 0.0:
388
+ x = x + self.drop_path1(attn_residual_func(x))
389
+ x = x + self.drop_path2(ffn_residual_func(x))
390
+ else:
391
+ x = x + attn_residual_func(x)
392
+ x = x + ffn_residual_func(x)
393
+ return x
394
+
395
+
396
+ class NestedTensorBlock(Block):
397
+ def forward_nested(self, x_list: List[torch.Tensor]) -> List[torch.Tensor]:
398
+ """
399
+ x_list contains a list of tensors to nest together and run
400
+ """
401
+ assert isinstance(self.attn, MemEffAttention)
402
+
403
+ if self.training and self.sample_drop_ratio > 0.0:
404
+
405
+ def attn_residual_func(x: torch.Tensor, attn_bias=None) -> torch.Tensor:
406
+ return self.attn(self.norm1(x), attn_bias=attn_bias)
407
+
408
+ def ffn_residual_func(x: torch.Tensor, attn_bias=None) -> torch.Tensor:
409
+ return self.mlp(self.norm2(x))
410
+
411
+ x_list = drop_add_residual_stochastic_depth_list(
412
+ x_list,
413
+ residual_func=attn_residual_func,
414
+ sample_drop_ratio=self.sample_drop_ratio,
415
+ scaling_vector=self.ls1.grandma if isinstance(self.ls1, LayerScale) else None,
416
+ )
417
+ x_list = drop_add_residual_stochastic_depth_list(
418
+ x_list,
419
+ residual_func=ffn_residual_func,
420
+ sample_drop_ratio=self.sample_drop_ratio,
421
+ scaling_vector=self.ls2.grandma if isinstance(self.ls2, LayerScale) else None,
422
+ )
423
+ return x_list
424
+ else:
425
+
426
+ def attn_residual_func(x: torch.Tensor, attn_bias=None) -> torch.Tensor:
427
+ return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))
428
+
429
+ def ffn_residual_func(x: torch.Tensor, attn_bias=None) -> torch.Tensor:
430
+ return self.ls2(self.mlp(self.norm2(x)))
431
+
432
+ attn_bias, x = get_attn_bias_and_cat(x_list)
433
+ x = x + attn_residual_func(x, attn_bias=attn_bias)
434
+ x = x + ffn_residual_func(x)
435
+ return attn_bias.split(x)
436
+
437
+ def forward(self, x_or_x_list):
438
+ if isinstance(x_or_x_list, torch.Tensor):
439
+ return super().forward(x_or_x_list)
440
+ elif isinstance(x_or_x_list, list):
441
+ if not XFORMERS_AVAILABLE:
442
+ raise AssertionError("xFormers is required for using nested tensors")
443
+ return self.forward_nested(x_or_x_list)
444
+ else:
445
+ raise AssertionError
446
+
447
+
448
+ def drop_add_residual_stochastic_depth(
449
+ x: torch.Tensor,
450
+ residual_func: Callable[[torch.Tensor], torch.Tensor],
451
+ sample_drop_ratio: float = 0.0,
452
+ ) -> torch.Tensor:
453
+ # 1) extract subset using permutation
454
+ b, n, d = x.shape
455
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
456
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
457
+ x_subset = x[brange]
458
+
459
+ # 2) apply residual_func to get residual
460
+ residual = residual_func(x_subset)
461
+
462
+ x_flat = x.flatten(1)
463
+ residual = residual.flatten(1)
464
+
465
+ residual_scale_factor = b / sample_subset_size
466
+
467
+ # 3) add the residual
468
+ x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
469
+ return x_plus_residual.view_as(x)
470
+
471
+
472
+ def get_branges_scales(x, sample_drop_ratio=0.0):
473
+ b, n, d = x.shape
474
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
475
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
476
+ residual_scale_factor = b / sample_subset_size
477
+ return brange, residual_scale_factor
478
+
479
+
480
+ def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
481
+ if scaling_vector is None:
482
+ x_flat = x.flatten(1)
483
+ residual = residual.flatten(1)
484
+ x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
485
+ else:
486
+ x_plus_residual = scaled_index_add(
487
+ x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
488
+ )
489
+ return x_plus_residual
490
+
491
+
492
+ attn_bias_cache: Dict[Tuple, Any] = {}
493
+
494
+
495
+ def get_attn_bias_and_cat(x_list, branges=None):
496
+ """
497
+ this will perform the index select, cat the tensors, and provide the attn_bias from cache
498
+ """
499
+ batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
500
+ all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
501
+ if all_shapes not in attn_bias_cache.keys():
502
+ seqlens = []
503
+ for b, x in zip(batch_sizes, x_list):
504
+ for _ in range(b):
505
+ seqlens.append(x.shape[1])
506
+ attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
507
+ attn_bias._batch_sizes = batch_sizes
508
+ attn_bias_cache[all_shapes] = attn_bias
509
+
510
+ if branges is not None:
511
+ cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
512
+ else:
513
+ tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
514
+ cat_tensors = torch.cat(tensors_bs1, dim=1)
515
+
516
+ return attn_bias_cache[all_shapes], cat_tensors
517
+
518
+
519
+ def drop_add_residual_stochastic_depth_list(
520
+ x_list: List[torch.Tensor],
521
+ residual_func: Callable[[torch.Tensor, Any], torch.Tensor],
522
+ sample_drop_ratio: float = 0.0,
523
+ scaling_vector=None,
524
+ ) -> torch.Tensor:
525
+ # 1) generate random set of indices for dropping samples in the batch
526
+ branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
527
+ branges = [s[0] for s in branges_scales]
528
+ residual_scale_factors = [s[1] for s in branges_scales]
529
+
530
+ # 2) get attention bias and index+concat the tensors
531
+ attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)
532
+
533
+ # 3) apply residual_func to get residual, and split the result
534
+ residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore
535
+
536
+ outputs = []
537
+ for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
538
+ outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
539
+ return outputs
540
+
541
+
542
+ def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
543
+ if not depth_first and include_root:
544
+ fn(module=module, name=name)
545
+ for child_name, child_module in module.named_children():
546
+ child_name = ".".join((name, child_name)) if name else child_name
547
+ named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
548
+ if depth_first and include_root:
549
+ fn(module=module, name=name)
550
+ return module
551
+
552
+
553
+ class BlockChunk(nn.ModuleList):
554
+ def forward(self, x):
555
+ for b in self:
556
+ x = b(x)
557
+ return x
558
+
559
+
560
+ class DinoVisionTransformer(nn.Module):
561
+ def __init__(
562
+ self,
563
+ img_size=224,
564
+ patch_size=16,
565
+ in_chans=3,
566
+ embed_dim=768,
567
+ depth=12,
568
+ num_heads=12,
569
+ mlp_ratio=4.0,
570
+ qkv_bias=True,
571
+ ffn_bias=True,
572
+ proj_bias=True,
573
+ drop_path_rate=0.0,
574
+ drop_path_uniform=False,
575
+ init_values=None, # for layerscale: None or 0 => no layerscale
576
+ embed_layer=PatchEmbed,
577
+ act_layer=nn.GELU,
578
+ block_fn=Block,
579
+ ffn_layer="mlp",
580
+ block_chunks=1,
581
+ num_register_tokens=0,
582
+ interpolate_antialias=False,
583
+ interpolate_offset=0.1,
584
+ ):
585
+ """
586
+ Args:
587
+ img_size (int, tuple): input image size
588
+ patch_size (int, tuple): patch size
589
+ in_chans (int): number of input channels
590
+ embed_dim (int): embedding dimension
591
+ depth (int): depth of transformer
592
+ num_heads (int): number of attention heads
593
+ mlp_ratio (int): ratio of mlp hidden dim to embedding dim
594
+ qkv_bias (bool): enable bias for qkv if True
595
+ proj_bias (bool): enable bias for proj in attn if True
596
+ ffn_bias (bool): enable bias for ffn if True
597
+ drop_path_rate (float): stochastic depth rate
598
+ drop_path_uniform (bool): apply uniform drop rate across blocks
599
+ weight_init (str): weight init scheme
600
+ init_values (float): layer-scale init values
601
+ embed_layer (nn.Module): patch embedding layer
602
+ act_layer (nn.Module): MLP activation layer
603
+ block_fn (nn.Module): transformer block class
604
+ ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
605
+ block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
606
+ num_register_tokens: (int) number of extra cls tokens (so-called "registers")
607
+ interpolate_antialias: (bool) flag to apply anti-aliasing when interpolating positional embeddings
608
+ interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings
609
+ """
610
+ super().__init__()
611
+ norm_layer = partial(nn.LayerNorm, eps=1e-6)
612
+
613
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
614
+ self.num_tokens = 1
615
+ self.n_blocks = depth
616
+ self.num_heads = num_heads
617
+ self.patch_size = patch_size
618
+ self.num_register_tokens = num_register_tokens
619
+ self.interpolate_antialias = interpolate_antialias
620
+ self.interpolate_offset = interpolate_offset
621
+
622
+ self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
623
+ num_patches = self.patch_embed.num_patches
624
+
625
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
626
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
627
+ assert num_register_tokens >= 0
628
+ self.register_tokens = (
629
+ nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None
630
+ )
631
+
632
+ if drop_path_uniform is True:
633
+ dpr = [drop_path_rate] * depth
634
+ else:
635
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
636
+
637
+ if ffn_layer == "mlp":
638
+ ffn_layer = Mlp
639
+ elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
640
+ ffn_layer = SwiGLUFFNFused
641
+ elif ffn_layer == "identity":
642
+ def f(*args, **kwargs):
643
+ return nn.Identity()
644
+
645
+ ffn_layer = f
646
+ else:
647
+ raise NotImplementedError
648
+
649
+ blocks_list = [
650
+ block_fn(
651
+ dim=embed_dim,
652
+ num_heads=num_heads,
653
+ mlp_ratio=mlp_ratio,
654
+ qkv_bias=qkv_bias,
655
+ proj_bias=proj_bias,
656
+ ffn_bias=ffn_bias,
657
+ drop_path=dpr[i],
658
+ norm_layer=norm_layer,
659
+ act_layer=act_layer,
660
+ ffn_layer=ffn_layer,
661
+ init_values=init_values,
662
+ )
663
+ for i in range(depth)
664
+ ]
665
+ if block_chunks > 0:
666
+ self.chunked_blocks = True
667
+ chunked_blocks = []
668
+ chunksize = depth // block_chunks
669
+ for i in range(0, depth, chunksize):
670
+ # this is to keep the block index consistent if we chunk the block list
671
+ chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
672
+ self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
673
+ else:
674
+ self.chunked_blocks = False
675
+ self.blocks = nn.ModuleList(blocks_list)
676
+
677
+ self.norm = norm_layer(embed_dim)
678
+ self.head = nn.Identity()
679
+
680
+ self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
681
+
682
+ def interpolate_pos_encoding(self, x, w, h):
683
+ previous_dtype = x.dtype
684
+ npatch = x.shape[1] - 1
685
+ N = self.pos_embed.shape[1] - 1
686
+ if npatch == N and w == h:
687
+ return self.pos_embed
688
+ pos_embed = self.pos_embed.float()
689
+ class_pos_embed = pos_embed[:, 0]
690
+ patch_pos_embed = pos_embed[:, 1:]
691
+ dim = x.shape[-1]
692
+ w0 = w // self.patch_size
693
+ h0 = h // self.patch_size
694
+ M = int(math.sqrt(N)) # Recover the number of patches in each dimension
695
+ assert N == M * M
696
+ kwargs = {}
697
+ if self.interpolate_offset:
698
+ # Historical kludge: add a small number to avoid floating point error in the interpolation, see https://github.com/facebookresearch/dino/issues/8
699
+ # Note: still needed for backward-compatibility, the underlying operators are using both output size and scale factors
700
+ sx = float(w0 + self.interpolate_offset) / M
701
+ sy = float(h0 + self.interpolate_offset) / M
702
+ kwargs["scale_factor"] = (sx, sy)
703
+ else:
704
+ # Simply specify an output size instead of a scale factor
705
+ kwargs["size"] = (w0, h0)
706
+ patch_pos_embed = nn.functional.interpolate(
707
+ patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2),
708
+ mode="bicubic",
709
+ antialias=self.interpolate_antialias,
710
+ **kwargs,
711
+ )
712
+ assert (w0, h0) == patch_pos_embed.shape[-2:]
713
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
714
+ return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)
715
+
716
+ def prepare_tokens_with_masks(self, x, masks=None):
717
+ B, nc, w, h = x.shape
718
+ x = self.patch_embed(x)
719
+ if masks is not None:
720
+ x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)
721
+
722
+ x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
723
+ x = x + self.interpolate_pos_encoding(x, w, h)
724
+
725
+ if self.register_tokens is not None:
726
+ x = torch.cat(
727
+ (
728
+ x[:, :1],
729
+ self.register_tokens.expand(x.shape[0], -1, -1),
730
+ x[:, 1:],
731
+ ),
732
+ dim=1,
733
+ )
734
+
735
+ return x
736
+
737
+ def forward_features_list(self, x_list, masks_list):
738
+ x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
739
+ for blk in self.blocks:
740
+ x = blk(x)
741
+
742
+ all_x = x
743
+ output = []
744
+ for x, masks in zip(all_x, masks_list):
745
+ x_norm = self.norm(x)
746
+ output.append(
747
+ {
748
+ "x_norm_clstoken": x_norm[:, 0],
749
+ "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
750
+ "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
751
+ "x_prenorm": x,
752
+ "masks": masks,
753
+ }
754
+ )
755
+ return output
756
+
757
+ def forward_features(self, x, masks=None):
758
+ if isinstance(x, list):
759
+ return self.forward_features_list(x, masks)
760
+
761
+ x = self.prepare_tokens_with_masks(x, masks)
762
+
763
+ for blk in self.blocks:
764
+ x = blk(x)
765
+
766
+ x_norm = self.norm(x)
767
+ return {
768
+ "x_norm_clstoken": x_norm[:, 0],
769
+ "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
770
+ "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
771
+ "x_prenorm": x,
772
+ "masks": masks,
773
+ }
774
+
775
+ def _get_intermediate_layers_not_chunked(self, x, n=1):
776
+ x = self.prepare_tokens_with_masks(x)
777
+ # If n is an int, take the n last blocks. If it's a list, take them
778
+ output, total_block_len = [], len(self.blocks)
779
+ blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
780
+ for i, blk in enumerate(self.blocks):
781
+ x = blk(x)
782
+ if i in blocks_to_take:
783
+ output.append(x)
784
+ assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
785
+ return output
786
+
787
+ def _get_intermediate_layers_chunked(self, x, n=1):
788
+ x = self.prepare_tokens_with_masks(x)
789
+ output, i, total_block_len = [], 0, len(self.blocks[-1])
790
+ # If n is an int, take the n last blocks. If it's a list, take them
791
+ blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
792
+ for block_chunk in self.blocks:
793
+ for blk in block_chunk[i:]: # Passing the nn.Identity()
794
+ x = blk(x)
795
+ if i in blocks_to_take:
796
+ output.append(x)
797
+ i += 1
798
+ assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
799
+ return output
800
+
801
+ def get_intermediate_layers(
802
+ self,
803
+ x: torch.Tensor,
804
+ n: Union[int, Sequence] = 1, # Layers or n last layers to take
805
+ reshape: bool = False,
806
+ return_class_token: bool = False,
807
+ norm=True,
808
+ ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
809
+ if self.chunked_blocks:
810
+ outputs = self._get_intermediate_layers_chunked(x, n)
811
+ else:
812
+ outputs = self._get_intermediate_layers_not_chunked(x, n)
813
+ if norm:
814
+ outputs = [self.norm(out) for out in outputs]
815
+ class_tokens = [out[:, 0] for out in outputs]
816
+ outputs = [out[:, 1 + self.num_register_tokens :] for out in outputs]
817
+ if reshape:
818
+ B, _, w, h = x.shape
819
+ outputs = [
820
+ out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
821
+ for out in outputs
822
+ ]
823
+ if return_class_token:
824
+ return tuple(zip(outputs, class_tokens))
825
+ return tuple(outputs)
826
+
827
+ def forward(self, *args, is_training=False, **kwargs):
828
+ ret = self.forward_features(*args, **kwargs)
829
+ if is_training:
830
+ return ret
831
+ else:
832
+ return self.head(ret["x_norm_clstoken"])
833
+
834
+
835
+ def vit_small(patch_size=16, num_register_tokens=0, **kwargs):
836
+ model = DinoVisionTransformer(
837
+ patch_size=patch_size,
838
+ embed_dim=384,
839
+ depth=12,
840
+ num_heads=6,
841
+ mlp_ratio=4,
842
+ block_fn=partial(Block, attn_class=MemEffAttention),
843
+ num_register_tokens=num_register_tokens,
844
+ **kwargs,
845
+ )
846
+ return model
847
+
848
+
849
+ def vit_base(patch_size=16, num_register_tokens=0, **kwargs):
850
+ model = DinoVisionTransformer(
851
+ patch_size=patch_size,
852
+ embed_dim=768,
853
+ depth=12,
854
+ num_heads=12,
855
+ mlp_ratio=4,
856
+ block_fn=partial(Block, attn_class=MemEffAttention),
857
+ num_register_tokens=num_register_tokens,
858
+ **kwargs,
859
+ )
860
+ return model
861
+
862
+
863
+ def vit_large(patch_size=16, num_register_tokens=0, **kwargs):
864
+ model = DinoVisionTransformer(
865
+ patch_size=patch_size,
866
+ embed_dim=1024,
867
+ depth=24,
868
+ num_heads=16,
869
+ mlp_ratio=4,
870
+ block_fn=partial(Block, attn_class=MemEffAttention),
871
+ num_register_tokens=num_register_tokens,
872
+ **kwargs,
873
+ )
874
+ return model
875
+
876
+
877
+ def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs):
878
+ """
879
+ Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
880
+ """
881
+ model = DinoVisionTransformer(
882
+ patch_size=patch_size,
883
+ embed_dim=1536,
884
+ depth=40,
885
+ num_heads=24,
886
+ mlp_ratio=4,
887
+ block_fn=partial(Block, attn_class=MemEffAttention),
888
+ num_register_tokens=num_register_tokens,
889
+ **kwargs,
890
+ )
891
+ return model
892
+
893
+
894
+ class Weights(Enum):
895
+ LVD142M = "LVD142M"
896
+
897
+
898
+ def _make_dinov2_model(
899
+ *,
900
+ arch_name: str = "vit_large",
901
+ img_size: int = 518,
902
+ patch_size: int = 14,
903
+ init_values: float = 1.0,
904
+ ffn_layer: str = "mlp",
905
+ block_chunks: int = 0,
906
+ num_register_tokens: int = 0,
907
+ interpolate_antialias: bool = False,
908
+ interpolate_offset: float = 0.1,
909
+ weights: Union[Weights, str] = Weights.LVD142M,
910
+ **kwargs,
911
+ ):
912
+ if isinstance(weights, str):
913
+ try:
914
+ weights = Weights[weights]
915
+ except KeyError:
916
+ raise AssertionError(f"Unsupported weights: {weights}")
917
+
918
+ vit_kwargs = dict(
919
+ img_size=img_size,
920
+ patch_size=patch_size,
921
+ init_values=init_values,
922
+ ffn_layer=ffn_layer,
923
+ block_chunks=block_chunks,
924
+ num_register_tokens=num_register_tokens,
925
+ interpolate_antialias=interpolate_antialias,
926
+ interpolate_offset=interpolate_offset,
927
+ )
928
+ vit_kwargs.update(**kwargs)
929
+ model = sys.modules[__name__].__dict__[arch_name](**vit_kwargs)
930
+
931
+ return model
932
+
933
+
934
+ def dinov2_vits14(**kwargs):
935
+ """
936
+ DINOv2 ViT-S/14 model (optionally) pretrained on the LVD-142M dataset.
937
+ """
938
+ return _make_dinov2_model(arch_name="vit_small", **kwargs)
939
+
940
+
941
+ def dinov2_vitb14(**kwargs):
942
+ """
943
+ DINOv2 ViT-B/14 model (optionally) pretrained on the LVD-142M dataset.
944
+ """
945
+ return _make_dinov2_model(arch_name="vit_base", **kwargs)
946
+
947
+
948
+ def dinov2_vitl14(**kwargs):
949
+ """
950
+ DINOv2 ViT-L/14 model (optionally) pretrained on the LVD-142M dataset.
951
+ """
952
+ return _make_dinov2_model(arch_name="vit_large", **kwargs)
953
+
954
+
955
+ def dinov2_vitg14(**kwargs):
956
+ """
957
+ DINOv2 ViT-g/14 model (optionally) pretrained on the LVD-142M dataset.
958
+ """
959
+ return _make_dinov2_model(
960
+ arch_name="vit_giant2",
961
+ ffn_layer="swiglufused",
962
+ **kwargs,
963
+ )
964
+
965
+
966
+ def dinov2_vits14_reg(**kwargs):
967
+ """
968
+ DINOv2 ViT-S/14 model with registers (optionally) pretrained on the LVD-142M dataset.
969
+ """
970
+ return _make_dinov2_model(
971
+ arch_name="vit_small",
972
+ num_register_tokens=4,
973
+ interpolate_antialias=True,
974
+ interpolate_offset=0.0,
975
+ **kwargs,
976
+ )
977
+
978
+
979
+ def dinov2_vitb14_reg(**kwargs):
980
+ """
981
+ DINOv2 ViT-B/14 model with registers (optionally) pretrained on the LVD-142M dataset.
982
+ """
983
+ return _make_dinov2_model(
984
+ arch_name="vit_base",
985
+ num_register_tokens=4,
986
+ interpolate_antialias=True,
987
+ interpolate_offset=0.0,
988
+ **kwargs,
989
+ )
990
+
991
+
992
+ def dinov2_vitl14_reg(**kwargs):
993
+ """
994
+ DINOv2 ViT-L/14 model with registers (optionally) pretrained on the LVD-142M dataset.
995
+ """
996
+ return _make_dinov2_model(
997
+ arch_name="vit_large",
998
+ num_register_tokens=4,
999
+ interpolate_antialias=True,
1000
+ interpolate_offset=0.0,
1001
+ **kwargs,
1002
+ )
1003
+
1004
+
1005
+ def dinov2_vitg14_reg(**kwargs):
1006
+ """
1007
+ DINOv2 ViT-g/14 model with registers (optionally) pretrained on the LVD-142M dataset.
1008
+ """
1009
+ return _make_dinov2_model(
1010
+ arch_name="vit_giant2",
1011
+ ffn_layer="swiglufused",
1012
+ num_register_tokens=4,
1013
+ interpolate_antialias=True,
1014
+ interpolate_offset=0.0,
1015
+ **kwargs,
1016
+ )
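
Editor's note: the factory functions above mirror the upstream DINOv2 hub entrypoints. A minimal usage sketch follows (illustrative only: the import path is assumed from this repo's layout, and the weights are random here since checkpoint loading is handled elsewhere in the RADIO code).

import torch
from tim.models.nvidia_radio.radio.dinov2_arch import dinov2_vitl14_reg  # assumed import path

model = dinov2_vitl14_reg().eval()      # ViT-L/14 with 4 register tokens, randomly initialized
x = torch.randn(1, 3, 224, 224)         # H and W must be multiples of the patch size (14)
with torch.no_grad():
    out = model.forward_features(x)

print(out["x_norm_clstoken"].shape)     # torch.Size([1, 1024])
print(out["x_norm_regtokens"].shape)    # torch.Size([1, 4, 1024])
print(out["x_norm_patchtokens"].shape)  # torch.Size([1, 256, 1024]), i.e. a 16x16 patch grid
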
tim/models/nvidia_radio/radio/dual_hybrid_vit.py ADDED
@@ -0,0 +1,213 @@
1
+ from logging import getLogger
2
+ from typing import Tuple
3
+
4
+ import torch
5
+ from torch import nn
6
+ from torch.nn import functional as F
7
+
8
+ from timm.models import register_model
9
+ from timm.models import vision_transformer as tvit
10
+ from timm.models import convnext as tconv
11
+
12
+ from einops import rearrange
13
+
14
+ from . import extra_timm_models as et
15
+
16
+
17
+ class Fuser(nn.Module):
18
+ def __init__(self, src_dim: int, tgt_dim: int, gated: bool = True):
19
+ super().__init__()
20
+ self.gated = gated
21
+
22
+ mid_dim = max(src_dim, tgt_dim) * 2
23
+
24
+ self.fwd = nn.Sequential(
25
+ nn.Conv2d(src_dim, mid_dim, kernel_size=3, stride=1, padding=1),
26
+ nn.GELU(),
27
+ nn.Conv2d(mid_dim, tgt_dim * (2 if gated else 1), kernel_size=3, stride=1, padding=1),
28
+ )
29
+
30
+ def forward(self, src: torch.Tensor, tgt: torch.Tensor) -> torch.Tensor:
31
+ if src.ndim == 3:
32
+ shape = tgt.shape[-2:]
33
+ else:
34
+ shape = src.shape[-2:]
35
+
36
+ nd = shape[0] * shape[1]
37
+
38
+ if src.ndim == 3:
39
+ src = src[:, -nd:].reshape(src.shape[0], src.shape[2], *shape)
40
+
41
+ if tgt.ndim == 3:
42
+ tgt_pre = tgt[:, :-nd]
43
+ tgt = tgt[:, -nd:].reshape(tgt.shape[0], tgt.shape[2], *shape)
44
+ else:
45
+ tgt_pre = None
46
+
47
+ pred = self.fwd(src)
48
+
49
+ if self.gated:
50
+ g, pred = torch.chunk(pred, 2, dim=1)
51
+
52
+ g = F.sigmoid(g)
53
+
54
+ pred = g * pred
55
+
56
+ tgt = tgt + pred
57
+
58
+ if tgt_pre is not None:
59
+ tgt = rearrange(tgt, 'b c h w -> b (h w) c')
60
+ tgt = torch.cat([tgt_pre, tgt], dim=1)
61
+
62
+ return tgt
63
+
64
+
65
+ class AttnDownsample(nn.Module):
66
+ def __init__(self, dim: int, window_size: int, num_heads: int = 16):
67
+ super().__init__()
68
+ self.q = nn.Parameter(torch.randn(1, num_heads, 1, dim // num_heads) * 0.01)
69
+ self.kv = nn.Linear(dim, dim * 2)
70
+ self.proj = nn.Linear(dim, dim)
71
+ self.window_size = window_size
72
+ self.num_heads = num_heads
73
+ self.head_dim = dim // num_heads
74
+ self.scale = self.head_dim ** -0.5
75
+
76
+ def forward(self, x: torch.Tensor, twod_shape: Tuple[int, int]) -> torch.Tensor:
77
+ ntok = twod_shape[0] * twod_shape[1]
78
+ x_pre = x[:, :-ntok]
79
+
80
+ B = x.shape[0]
81
+ ds_hw = tuple(s // self.window_size for s in twod_shape)
82
+
83
+ x_spat = rearrange(
84
+ x[:, -ntok:],
85
+ 'b (h d1 w d2) c -> (b h w) (d1 d2) c',
86
+ h=ds_hw[0], w=ds_hw[1],
87
+ d1=self.window_size, d2=self.window_size,
88
+ )
89
+
90
+ B, N, C = x_spat.shape
91
+
92
+ k, v = self.kv(x_spat).reshape(B, N, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
93
+
94
+ q = (self.q * self.scale).expand(B, -1, -1, -1)
95
+ attn = q @ k.transpose(-2, -1)
96
+ attn = F.softmax(attn, dim=-1)
97
+ x = attn @ v
98
+
99
+ x = x.transpose(1, 2).reshape(B, C)
100
+ x = self.proj(x)
101
+
102
+ x = rearrange(x, '(b h w) c -> b (h w) c', b=x_pre.shape[0], h=ds_hw[0], w=ds_hw[1])
103
+
104
+ x = torch.cat([x_pre, x], dim=1)
105
+ return x
106
+
107
+
108
+ class HybridModel(nn.Module):
109
+ def __init__(self, vit: tvit.VisionTransformer, conv: tconv.ConvNeXt, pretrained: bool = False,
110
+ concatenate: bool = False, **kwargs):
111
+ super().__init__()
112
+ self.conv = conv
113
+ self.vit = vit
114
+ self.concatenate = concatenate
115
+
116
+ conv.stages = nn.ModuleList(conv.stages)
117
+ vit.blocks = nn.ModuleList(vit.blocks)
118
+
119
+ self._half_vit_idx = len(vit.blocks) // 2 + 1
120
+
121
+ self._half_conv_idx = None
122
+ x = torch.empty(1, 3, 256, 256)
123
+ x = self.conv.stem(x)
124
+ for i in range(len(conv.stages)):
125
+ x = conv.stages[i](x)
126
+ if self._half_conv_idx is None and x.shape[-2:] == (16, 16):
127
+ self._half_conv_idx = i + 1
128
+ half_conv_dim = x.shape[1]
129
+ final_conv_dim = x.shape[1]
130
+
131
+ self.vit_to_conv_fusion = Fuser(vit.embed_dim, half_conv_dim)
132
+ self.conv_to_vit_fusion = Fuser(half_conv_dim, vit.embed_dim)
133
+ self.vit_ds = AttnDownsample(vit.embed_dim, window_size=2)
134
+
135
+ embed_dim = vit.embed_dim + (final_conv_dim if concatenate else 0)
136
+ if not concatenate:
137
+ self.final_fuse = Fuser(final_conv_dim, vit.embed_dim, gated=False)
138
+ self.final_block = tvit.Block(embed_dim, num_heads=16)
139
+
140
+ self.embed_dim = embed_dim
141
+
142
+ @property
143
+ def patch_size(self):
144
+ return 32
145
+
146
+ @property
147
+ def no_fsdp_wrap_types(self):
148
+ return {tvit.VisionTransformer, tconv.ConvNeXt}
149
+
150
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
151
+ return self.forward_features(x)
152
+
153
+ def forward_features(self, x: torch.Tensor) -> torch.Tensor:
154
+ y_vit = self.vit.patch_generator(x)
155
+
156
+ for i in range(self._half_vit_idx):
157
+ y_vit = self.vit.blocks[i](y_vit)
158
+
159
+ y_conv = self.conv.stem(x)
160
+ for i in range(self._half_conv_idx):
161
+ y_conv = self.conv.stages[i](y_conv)
162
+
163
+ y_vit, y_conv = self.conv_to_vit_fusion(y_conv, y_vit), self.vit_to_conv_fusion(y_vit, y_conv)
164
+
165
+ y_vit = self.vit_ds(y_vit, y_conv.shape[-2:])
166
+
167
+ for i in range(self._half_vit_idx, len(self.vit.blocks)):
168
+ y_vit = self.vit.blocks[i](y_vit)
169
+
170
+ for i in range(self._half_conv_idx, len(self.conv.stages)):
171
+ y_conv = self.conv.stages[i](y_conv)
172
+
173
+ if self.concatenate:
174
+ y_conv = rearrange(y_conv, 'b c h w -> b (h w) c')
175
+ # Average pool across the board, and replicate for each cls/register token
176
+ conv_summary = y_conv.mean(dim=1, keepdim=True).expand(-1, self.vit.patch_generator.num_cls_patches, -1)
177
+ y_conv = torch.cat([conv_summary, y_conv], dim=1)
178
+ y = torch.cat([y_vit, y_conv], dim=2)
179
+ else:
180
+ y = self.final_fuse(y_conv, y_vit)
181
+ y = self.final_block(y)
182
+
183
+ summary = y[:, :self.vit.patch_generator.num_cls_tokens]
184
+ features = y[:, self.vit.patch_generator.num_cls_patches:]
185
+
186
+ return summary, features
187
+
188
+
189
+ @register_model
190
+ def hybrid_base(pretrained=False, concatenate: bool = False, weight_init: str = 'skip', **kwargs):
191
+ cfg = dict(num_classes=0, **kwargs)
192
+ conv = tconv.convnextv2_base(pretrained=pretrained, **cfg)
193
+ vit = tvit.vit_base_patch16_224(pretrained=pretrained, weight_init=weight_init, **cfg)
194
+
195
+ return HybridModel(vit, conv, pretrained, concatenate=concatenate)
196
+
197
+
198
+ @register_model
199
+ def hybrid_large(pretrained=False, concatenate: bool = False, weight_init: str = 'skip', **kwargs):
200
+ cfg = dict(num_classes=0, **kwargs)
201
+ conv = tconv.convnextv2_large(pretrained=pretrained, **cfg)
202
+ vit = tvit.vit_large_patch16_224(pretrained=pretrained, weight_init=weight_init, **cfg)
203
+
204
+ return HybridModel(vit, conv, pretrained, concatenate=concatenate)
205
+
206
+
207
+ @register_model
208
+ def hybrid_huge(pretrained=False, concatenate: bool = False, weight_init: str = 'skip', **kwargs):
209
+ cfg = dict(num_classes=0, **kwargs)
210
+ conv = tconv.convnextv2_huge(pretrained=pretrained, **cfg)
211
+ vit = et.vit_huge_patch16_224(pretrained=pretrained, weight_init=weight_init, **cfg)
212
+
213
+ return HybridModel(vit, conv, pretrained, concatenate=concatenate)
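
Editor's note: a construction-only sketch of the hybrid factories above, assuming timm provides the referenced convnextv2_* and vit_*_patch16_224 definitions. Note that forward_features() relies on vit.patch_generator, which is only installed once CPE support is enabled (see enable_cpe_support.py), so the freshly built model is not directly runnable.

from tim.models.nvidia_radio.radio.dual_hybrid_vit import hybrid_base  # assumed import path

model = hybrid_base(pretrained=False, concatenate=False)   # random weights
print(model.embed_dim)    # 768: ViT-B width; ConvNeXt channels are appended only when concatenate=True
print(model.patch_size)   # 32: effective stride of the fused feature map

# Because of the @register_model decorators, timm.create_model('hybrid_base', ...) resolves to the
# same factory once this module has been imported.
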
tim/models/nvidia_radio/radio/enable_cpe_support.py ADDED
@@ -0,0 +1,224 @@
1
+ # Copyright (c) 2023-2024, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ from typing import List, Optional, Set, Tuple, Union
10
+ from types import MethodType
11
+
12
+ import torch
13
+ from torch import nn
14
+
15
+ from timm.models import VisionTransformer, checkpoint_seq
16
+ from timm.models.vision_transformer import Attention, Block
17
+
18
+ from .feature_normalizer import IntermediateFeatureNormalizerBase, NullIntermediateFeatureNormalizer
19
+
20
+ from .extra_models import DinoWrapper
21
+ from .vit_patch_generator import ViTPatchGenerator
22
+ from .forward_intermediates import forward_intermediates
23
+ from .dual_hybrid_vit import HybridModel
24
+ from flash_attn import flash_attn_varlen_func
25
+
26
+
27
+ def _attn_forward_pack(self: Attention, x: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor:
28
+ N, C = x.shape
29
+ qkv = self.qkv(x).reshape(N, 3, self.num_heads, self.head_dim).permute(1, 0, 2, 3)
30
+ q, k, v = qkv.unbind(0)
31
+ q, k = self.q_norm(q), self.k_norm(k)
32
+ max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
33
+
34
+ x = flash_attn_varlen_func(
35
+ q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen
36
+ ).reshape(N, -1)
37
+
38
+ x = self.proj(x)
39
+ x = self.proj_drop(x)
40
+ return x
41
+
42
+ def _block_forward_pack(self: Block, x: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor:
43
+ x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x), cu_seqlens)))
44
+ x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
45
+ return x
46
+
47
+ def _forward_cpe_pack(self: VisionTransformer, images: List[torch.Tensor]) -> torch.Tensor:
48
+ device = images[0].device
49
+ x = []
50
+ seqlens = []
51
+ for image in images:
52
+ # image: [1, c, H, W] -> x: [n_cls+h*w, D], h=H/p and w=W/p
53
+ _image = self.patch_generator(image).squeeze(0)
54
+ x.append(_image)
55
+ seqlens.append(_image.shape[0])
56
+
57
+ x = torch.cat(x, dim=0)
58
+ seqlens = torch.tensor(seqlens, device=device, dtype=torch.int)
59
+
60
+ cu_seqlens = torch.cat([
61
+ torch.tensor([0], device=device, dtype=torch.int32),
62
+ torch.cumsum(seqlens, dim=0, dtype=torch.int32)
63
+ ])
64
+ if getattr(self, 'grad_checkpointing', False) and not torch.jit.is_scripting():
65
+ for block in self.blocks:
66
+ x = checkpoint_seq(block, x, cu_seqlens)
67
+ else:
68
+ for block in self.blocks:
69
+ x = block(x, cu_seqlens)
70
+ x = self.norm(x)
71
+ return x, cu_seqlens
72
+
73
+ def _forward_cpe(self: VisionTransformer, x: torch.Tensor) -> torch.Tensor:
74
+ x = self.patch_generator(x)
75
+ if getattr(self, 'grad_checkpointing', False) and not torch.jit.is_scripting():
76
+ x = checkpoint_seq(self.blocks, x)
77
+ else:
78
+ x = self.blocks(x)
79
+ x = self.norm(x)
80
+ return x
81
+
82
+
83
+ def _take_indices(
84
+ num_blocks: int,
85
+ n: Optional[Union[int, List[int], Tuple[int]]],
86
+ ) -> Tuple[Set[int], int]:
87
+ if isinstance(n, int):
88
+ assert n >= 0
89
+ take_indices = {x for x in range(num_blocks - n, num_blocks)}
90
+ else:
91
+ take_indices = {num_blocks + idx if idx < 0 else idx for idx in n}
92
+ return take_indices, max(take_indices)
93
+
94
+
95
+ def _forward_intermediates_cpe(
96
+ self,
97
+ x: torch.Tensor,
98
+ norm: bool = False,
99
+ **kwargs,
100
+ ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
101
+ return forward_intermediates(
102
+ self,
103
+ patch_extractor=self.patch_generator,
104
+ num_summary_tokens=self.patch_generator.num_skip,
105
+ num_cls_tokens=self.patch_generator.num_cls_tokens,
106
+ norm=self.norm if norm else lambda y: y,
107
+ x=x,
108
+ **kwargs,
109
+ )
110
+
111
+
112
+ def _forward_cpe_dinov2(self: DinoWrapper, x: torch.Tensor) -> torch.Tensor:
113
+ y = _forward_cpe(self.inner, x)
114
+
115
+ return y[:, 0], y[:, self.num_summary_tokens:]
116
+
117
+
118
+ def _forward_intermediates_cpe_dinov2(self: DinoWrapper, *args, **kwargs):
119
+ return _forward_intermediates_cpe(self.inner, *args, **kwargs)
120
+
121
+
122
+ def _enable_cpe_for_timm_vit(model: VisionTransformer,
123
+ max_img_size: Union[int, Tuple[int, int]] = 1024,
124
+ num_cls_tokens: int = 1,
125
+ pos_dropout: float = 0.1,
126
+ register_multiple: Optional[int] = None,
127
+ num_registers: Optional[int] = None,
128
+ support_packing: bool = False,
129
+ ):
130
+ if not isinstance(model, VisionTransformer):
131
+ raise ValueError("CPE is only supported for VisionTransformer models!")
132
+
133
+ patch_size = model.patch_embed.patch_size[0]
134
+ embed_dim = model.embed_dim
135
+ input_dims = model.patch_embed.img_size
136
+ normalize_patches = not isinstance(model.patch_embed.norm, nn.Identity)
137
+ cls_token = model.cls_token is not None
138
+
139
+ max_img_size = int(round(max_img_size / patch_size) * patch_size)
140
+
141
+ patch_generator = ViTPatchGenerator(
142
+ patch_size=patch_size,
143
+ embed_dim=embed_dim,
144
+ input_dims=input_dims,
145
+ normalize_patches=normalize_patches,
146
+ cls_token=cls_token,
147
+ max_input_dims=max_img_size,
148
+ pos_dropout=pos_dropout,
149
+ num_cls_tokens=num_cls_tokens,
150
+ register_multiple=register_multiple,
151
+ num_registers=num_registers,
152
+ )
153
+
154
+ model.patch_generator = patch_generator
155
+ model.patch_embed = None
156
+ model.cls_token = None
157
+ model.pos_embed = None
158
+ model.pos_drop = None
159
+ model.patch_size = patch_size
160
+ model.num_cls_tokens = num_cls_tokens
161
+ model.num_registers = patch_generator.num_registers
162
+
163
+ model.forward_features = MethodType(_forward_cpe, model)
164
+ model.forward_intermediates = MethodType(_forward_intermediates_cpe, model)
165
+ if support_packing:
166
+ model.forward_features = MethodType(_forward_cpe_pack, model)
167
+ for block in model.blocks:
168
+ block.forward = MethodType(_block_forward_pack, block)
169
+ block.attn.forward = MethodType(_attn_forward_pack, block.attn)
170
+
171
+
172
+ def _enable_cpe_for_dv2_reg_vit(model: DinoWrapper,
173
+ max_img_size: Union[int, Tuple[int, int]] = 1024,
174
+ num_cls_tokens: int = 1,
175
+ pos_dropout: float = 0.1,
176
+ register_multiple: Optional[int] = None,
177
+ num_registers: Optional[int] = None,
178
+ ):
179
+ patch_size = model.patch_size
180
+ embed_dim = model.embed_dim
181
+ input_dims = model.inner.patch_embed.patches_resolution
182
+ normalize_patches = not isinstance(model.inner.patch_embed.norm, nn.Identity)
183
+ cls_token = True
184
+
185
+ max_img_size = int(round(max_img_size / patch_size) * patch_size)
186
+
187
+ patch_generator = ViTPatchGenerator(
188
+ patch_size=patch_size,
189
+ embed_dim=embed_dim,
190
+ input_dims=input_dims,
191
+ normalize_patches=normalize_patches,
192
+ cls_token=cls_token,
193
+ max_input_dims=max_img_size,
194
+ pos_dropout=pos_dropout,
195
+ num_cls_tokens=num_cls_tokens,
196
+ register_multiple=register_multiple,
197
+ num_registers=num_registers,
198
+ patch_bias=True,
199
+ )
200
+
201
+ inner = model.inner
202
+ inner.patch_generator = patch_generator
203
+ inner.patch_embed = None
204
+ inner.cls_token = None
205
+ inner.pos_embed = None
206
+ inner.register_tokens = None
207
+ inner.patch_size = patch_size
208
+
209
+ model.forward_features = MethodType(_forward_cpe_dinov2, model)
210
+ model.forward_intermediates = MethodType(_forward_intermediates_cpe_dinov2, model)
211
+
212
+
213
+ def enable_cpe(model: nn.Module,
214
+ *args,
215
+ **kwargs,
216
+ ):
217
+ if isinstance(model, VisionTransformer):
218
+ _enable_cpe_for_timm_vit(model, *args, **kwargs)
219
+ elif isinstance(model, DinoWrapper):
220
+ _enable_cpe_for_dv2_reg_vit(model, *args, **kwargs)
221
+ elif isinstance(model, HybridModel):
222
+ _enable_cpe_for_timm_vit(model.vit, *args, **kwargs)
223
+ else:
224
+ raise ValueError(f'CPE not supported for this model type: {type(model)}')
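
Editor's note: a minimal sketch of retrofitting CPE onto a stock timm ViT via enable_cpe() above. The values for max_img_size and register_multiple are illustrative assumptions, the exact register handling is delegated to ViTPatchGenerator, and flash-attn must be installed because this module imports it unconditionally.

import timm
import torch

from tim.models.nvidia_radio.radio.enable_cpe_support import enable_cpe  # assumed import path

vit = timm.create_model('vit_base_patch16_224', pretrained=False, num_classes=0)
enable_cpe(vit, max_img_size=1024, num_cls_tokens=1, register_multiple=8)

# The patched forward_features accepts any resolution that is a multiple of the patch size (16).
x = torch.randn(1, 3, 432, 640)
tokens = vit.forward_features(x)   # [1, n_extra + (432 // 16) * (640 // 16), 768]
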
tim/models/nvidia_radio/radio/enable_damp.py ADDED
@@ -0,0 +1,42 @@
1
+ # Copyright (c) 2023-2024, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ from logging import getLogger
10
+ import math
11
+ import os
12
+ from typing import Dict, List, Optional, Union, Tuple
13
+ from types import MethodType
14
+
15
+ import torch
16
+ from torch import nn
17
+ from torch.nn import functional as F
18
+ from torch.nn.utils import parametrize
19
+
20
+
21
+ # For now, don't do anything
22
+ class DAMP(nn.Identity):
23
+ def __init__(self, std: float):
24
+ super().__init__()
25
+ self.std = std
26
+
27
+
28
+ def enable_damp(model: nn.Module, std: float):
29
+ if isinstance(model, (list, tuple)):
30
+ for m in model:
31
+ enable_damp(m, std)
32
+ return
33
+
34
+ for name, module in model.named_modules():
35
+ if isinstance(module, nn.Linear):
36
+ parametrize.register_parametrization(module, 'weight', DAMP(std))
37
+
38
+
39
+ def configure_damp_from_args(model: nn.Module, args):
40
+ damp = getattr(args, 'damp', None)
41
+ if damp:
42
+ enable_damp(model, damp)
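
Editor's note: DAMP above is currently a no-op placeholder (an nn.Identity parametrization), so enable_damp() only wraps each nn.Linear weight without changing its values. A minimal sketch, assuming the functions above are imported from this module:

import torch
from torch import nn
from torch.nn.utils import parametrize

model = nn.Sequential(nn.Linear(16, 32), nn.GELU(), nn.Linear(32, 8))
enable_damp(model, std=0.02)

print(parametrize.is_parametrized(model[0], 'weight'))      # True
print(type(model[0].parametrizations.weight[0]).__name__)   # DAMP
y = model(torch.randn(4, 16))                               # forward is unchanged (identity parametrization)
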
tim/models/nvidia_radio/radio/enable_spectral_reparam.py ADDED
@@ -0,0 +1,277 @@
1
+ # Copyright (c) 2023-2024, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ from logging import getLogger
10
+ import math
11
+ import os
12
+ from typing import Dict, List, Optional, Union, Tuple
13
+ from types import MethodType
14
+
15
+ import torch
16
+ from torch import nn
17
+ from torch.nn import functional as F
18
+ from torch.nn.utils import parametrize
19
+ from torch.nn.utils.parametrizations import _SpectralNorm
20
+
21
+ from timm.models.vision_transformer import Attention, Mlp
22
+
23
+ _EPS = 1e-5
24
+
25
+
26
+ class _SNReweight(_SpectralNorm):
27
+ def __init__(self, weight: torch.Tensor, *args, init_norm_to_current: bool = False, alpha: float = 0.05, version: int = 2, **kwargs):
28
+ super().__init__(weight, *args, **kwargs)
29
+
30
+ self.alpha = alpha
31
+ self.version = version
32
+ self.register_buffer('_sn_version', torch.tensor(version))
33
+
34
+ if init_norm_to_current:
35
+ # This will set the numerator to match the denominator, which should preserve the original values
36
+ init_scale = self._get_sigma(weight, n_power_iterations=20).item()
37
+ else:
38
+ init_scale = 1.0
39
+
40
+ if version == 1:
41
+ init_value = init_scale
42
+ elif version == 2:
43
+ t = init_scale - alpha
44
+ if t < _EPS:
45
+ getLogger("spectral_reparam").warning(f'The initialized spectral norm {init_scale} is too small to be represented. Setting to {_EPS} instead.')
46
+ t = _EPS
47
+
48
+ init_value = math.log(math.exp(t) - 1)
49
+ else:
50
+ raise ValueError(f'Unsupported version: {version}')
51
+
52
+ # Make 2D so that weight decay gets applied
53
+ self.scale = nn.Parameter(torch.tensor([[init_value]], dtype=torch.float32, device=weight.device))
54
+
55
+ # Re-implementing this because we need to make division by sigma safe
56
+ def _get_sigma(self, weight: torch.Tensor, n_power_iterations: int = None) -> torch.Tensor:
57
+ if not n_power_iterations:
58
+ n_power_iterations = self.n_power_iterations
59
+ if weight.ndim == 1:
60
+ # Faster and more exact path, no need to approximate anything
61
+ sigma = weight.norm()
62
+ else:
63
+ weight_mat = self._reshape_weight_to_matrix(weight)
64
+ if self.training:
65
+ self._power_method(weight_mat, n_power_iterations)
66
+ # See above on why we need to clone
67
+ u = self._u.clone(memory_format=torch.contiguous_format)
68
+ v = self._v.clone(memory_format=torch.contiguous_format)
69
+ # The proper way of computing this should be through F.bilinear, but
70
+ # it seems to have some efficiency issues:
71
+ # https://github.com/pytorch/pytorch/issues/58093
72
+ sigma = torch.dot(u, torch.mv(weight_mat, v))
73
+
74
+ return sigma + self.eps
75
+
76
+ def forward(self, weight: torch.Tensor, *args, **kwargs):
77
+ dtype = weight.dtype
78
+ sigma = self._get_sigma(weight, *args, **kwargs)
79
+
80
+ if self.version == 1:
81
+ scale = self.scale
82
+ elif self.version == 2:
83
+ scale = F.softplus(self.scale) + self.alpha
84
+ else:
85
+ raise ValueError(f'Unsupported version: {self.version}')
86
+
87
+ scale = scale.float() / sigma.float()
88
+
89
+ y = weight * scale
90
+
91
+ if dtype in (torch.float16, torch.bfloat16):
92
+ y = y.to(dtype)
93
+ return y
94
+
95
+ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
96
+ version_key = f'{prefix}_sn_version'
97
+ if version_key not in state_dict:
98
+ self.version = 1
99
+ state_dict[version_key] = torch.tensor(1)
100
+ return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
101
+
102
+
103
+ class _ChunkedSNReweight(nn.Module):
104
+ def __init__(self, weight: torch.Tensor, num_chunks: int, *args, init_norm_to_current: bool = False, **kwargs):
105
+ super().__init__()
106
+
107
+ self.num_chunks = num_chunks
108
+ parts = weight.split(weight.shape[0] // num_chunks, dim=0)
109
+
110
+ self.parts = nn.ModuleList([
111
+ _SNReweight(p, *args, init_norm_to_current=init_norm_to_current, **kwargs)
112
+ for p in parts
113
+ ])
114
+
115
+ def forward(self, weight: torch.Tensor, *args, **kwargs):
116
+ parts = weight.split(weight.shape[0] // self.num_chunks, dim=0)
117
+
118
+ parts = [
119
+ fn(p)
120
+ for fn, p in zip(self.parts, parts)
121
+ ]
122
+
123
+ return torch.cat(parts, dim=0)
124
+
125
+
126
+ class _AttnSNReweight(_ChunkedSNReweight):
127
+ def __init__(self, weight: torch.Tensor, *args, init_norm_to_current: bool = False, renorm_values: bool = False, **kwargs):
128
+ super().__init__(weight, 3, *args, init_norm_to_current=init_norm_to_current, **kwargs)
129
+
130
+ if not renorm_values:
131
+ self.parts[2] = nn.Identity()
132
+
133
+
134
+ def enable_spectral_reparam(model: Union[nn.Module, List[nn.Module]],
135
+ n_power_iterations: int = 1,
136
+ eps: float = 1e-6,
137
+ init_norm_to_current: bool = False,
138
+ renorm_values: bool = True,
139
+ renorm_mlp: bool = True,
140
+ state_dict_guidance: Optional[Dict[str, torch.Tensor]] = None):
141
+ if isinstance(model, (list, tuple)):
142
+ for i, sub in enumerate(model):
143
+ sub_sd = state_dict_guidance[i] if isinstance(state_dict_guidance, (list, tuple)) else state_dict_guidance
144
+ enable_spectral_reparam(sub, n_power_iterations=n_power_iterations, eps=eps,
145
+ init_norm_to_current=init_norm_to_current, renorm_values=renorm_values,
146
+ renorm_mlp=renorm_mlp, state_dict_guidance=sub_sd)
147
+ return
148
+
149
+ print('Enabling spectral reparametrization')
150
+ args = dict(n_power_iterations=n_power_iterations, dim=0, eps=eps, init_norm_to_current=init_norm_to_current)
151
+ visited_prefixes = set()
152
+
153
+ def is_guidance_parametrized(name: str):
154
+ if state_dict_guidance is None:
155
+ return True
156
+
157
+ p_name = f'{name}.parametrizations'
158
+ is_prm = any(k for k in state_dict_guidance if k.startswith(p_name) and k.endswith('_sn_version'))
159
+ return is_prm
160
+
161
+ def parametrize_linear(linear: nn.Linear):
162
+ parametrize.register_parametrization(
163
+ linear,
164
+ 'weight',
165
+ _SNReweight(linear.weight, **args)
166
+ )
167
+
168
+ for name, mod in model.named_modules():
169
+ pref = '.'.join(name.split('.')[:-1])
170
+ if pref in visited_prefixes:
171
+ continue
172
+
173
+ if isinstance(mod, Attention) or name.endswith('.attn'):
174
+ if is_guidance_parametrized(f'{name}.qkv'):
175
+ parametrize.register_parametrization(
176
+ mod.qkv,
177
+ 'weight',
178
+ _AttnSNReweight(mod.qkv.weight, renorm_values=renorm_values, **args),
179
+ )
180
+ if hasattr(mod, 'proj') and is_guidance_parametrized(f'{name}.proj'):
181
+ parametrize_linear(mod.proj)
182
+ visited_prefixes.add(name)
183
+ elif name.endswith('mlp') and renorm_mlp and hasattr(mod, 'w12'):
184
+ if is_guidance_parametrized(f'{name}.w12'):
185
+ parametrize.register_parametrization(
186
+ mod.w12,
187
+ 'weight',
188
+ _ChunkedSNReweight(mod.w12.weight, num_chunks=2, **args),
189
+ )
190
+ if is_guidance_parametrized(f'{name}.w3'):
191
+ parametrize_linear(mod.w3)
192
+ visited_prefixes.add(name)
193
+ elif isinstance(mod, nn.Linear) and 'patch_generator' not in name and is_guidance_parametrized(name):
194
+ parametrize_linear(mod)
195
+
196
+
197
+ def configure_spectral_reparam_from_args(model: nn.Module, args, state_dict_guidance: Optional[Dict[str, torch.Tensor]] = None):
198
+ spectral_reparam = getattr(args, 'spectral_reparam', False)
199
+ if isinstance(spectral_reparam, bool) and spectral_reparam:
200
+ enable_spectral_reparam(model, init_norm_to_current=True, state_dict_guidance=state_dict_guidance)
201
+ elif isinstance(spectral_reparam, dict):
202
+ enable_spectral_reparam(
203
+ model,
204
+ n_power_iterations=spectral_reparam.get('n_power_iterations', 1),
205
+ eps=spectral_reparam.get('eps', 1e-12),
206
+ init_norm_to_current=True,
207
+ state_dict_guidance=state_dict_guidance,
208
+ )
209
+
210
+
211
+ def disable_spectral_reparam(model: nn.Module):
212
+ print('Disabling spectral reparametrization')
213
+ for name, mod in model.named_modules():
214
+ if parametrize.is_parametrized(mod):
215
+ parametrize.remove_parametrizations(mod, 'weight')
216
+ pass
217
+
218
+
219
+
220
+ if __name__ == '__main__':
221
+ import argparse
222
+ from . import radio_model as create_model
223
+
224
+ parser = argparse.ArgumentParser(description='Remove parametrization from state dict')
225
+ parser.add_argument('--checkpoint', type=str, required=True, help='The checkpoint to load')
226
+ parser.add_argument('--output', type=str, default='', help='Where to store the checkpoint')
227
+ parser.add_argument('--release', default=False, action='store_true', help='Prune extraneous checkpoint fields')
228
+ parser.add_argument('--strict', default=False, action='store_true', help='Strictly load the state dict')
229
+
230
+ args = parser.parse_args()
231
+
232
+ if not args.output:
233
+ chk_dir, chk_name = os.path.split(args.checkpoint)
234
+ args.output = os.path.join(chk_dir, f'clean_{chk_name}')
235
+ print(f'Set output to "{args.output}"')
236
+
237
+ chk = torch.load(args.checkpoint, map_location='cpu', mmap=True)
238
+
239
+ model = create_model.create_model_from_args(chk['args'])
240
+
241
+ key = 'base_model.'
242
+ mod_state = dict()
243
+ extra_state = dict()
244
+ for k, v in chk['state_dict'].items():
245
+ if k.startswith(key):
246
+ mod_state[k[len(key):]] = v
247
+ else:
248
+ extra_state[k] = v
249
+
250
+ chk_load_info = model.load_state_dict(mod_state, strict=args.strict)
251
+ if chk_load_info.unexpected_keys or chk_load_info.missing_keys:
252
+ print(chk_load_info)
253
+
254
+ if chk['args'].spectral_reparam:
255
+ disable_spectral_reparam(model)
256
+
257
+ if hasattr(chk['args'], 'dtype'):
258
+ model.to(dtype=chk['args'].dtype)
259
+
260
+ mod_state = model.state_dict()
261
+ final_state = dict()
262
+ final_state.update({f'{key}{k}': v for k, v in mod_state.items()})
263
+ final_state.update(extra_state)
264
+
265
+ chk['state_dict'] = final_state
266
+ chk['args'].spectral_reparam = False
267
+
268
+ if args.release:
269
+ chk = {
270
+ 'arch': chk['arch'],
271
+ 'epoch': chk['epoch'],
272
+ 'state_dict': chk['state_dict'],
273
+ 'args': chk['args'],
274
+ }
275
+
276
+ torch.save(chk, args.output)
277
+ pass
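A minimal usage sketch for the helpers above, assuming only torch and the functions defined in this file; the tiny nn.Sequential backbone is a hypothetical stand-in, not a model from this repo.

import torch
from torch import nn
from torch.nn.utils import parametrize

backbone = nn.Sequential(nn.Linear(64, 64), nn.GELU(), nn.Linear(64, 64)).eval()  # hypothetical stand-in

enable_spectral_reparam(backbone, init_norm_to_current=True)     # wraps each Linear weight in _SNReweight
assert parametrize.is_parametrized(backbone[0], 'weight')

x = torch.randn(2, 64)
with torch.no_grad():
    y_reparam = backbone(x)

disable_spectral_reparam(backbone)                               # bakes the current effective weight back into .weight
assert not parametrize.is_parametrized(backbone[0])
with torch.no_grad():
    y_folded = backbone(x)
print(torch.allclose(y_reparam, y_folded, atol=1e-5))            # should typically print True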
tim/models/nvidia_radio/radio/eradio_model.py ADDED
@@ -0,0 +1,1392 @@
1
+ #!/usr/bin/env python3
2
+
3
+ # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
6
+ # and proprietary rights in and to this software, related documentation
7
+ # and any modifications thereto. Any use, reproduction, disclosure or
8
+ # distribution of this software and related documentation without an express
9
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
10
+
11
+ # E-RADIO model from
12
+ # Mike Ranzinger, Greg Heinrich, Jan Kautz, and Pavlo Molchanov. "AM-RADIO: Agglomerative Model--Reduce All Domains Into One." arXiv preprint arXiv:2312.06709 (2023).
13
+
14
+ # based on FasterViT, Swin Transformer, YOLOv8
15
+
16
+ # FasterViT:
17
+ # Ali Hatamizadeh, Greg Heinrich, Hongxu Yin, Andrew Tao, Jose M. Alvarez, Jan Kautz, and Pavlo Molchanov. "FasterViT: Fast Vision Transformers with Hierarchical Attention." arXiv preprint arXiv:2306.06189 (2023).
18
+
19
+ import timm
20
+ import torch
21
+ import torch.nn as nn
22
+ from timm.models.registry import register_model
23
+
24
+ from timm.models.layers import trunc_normal_, DropPath, LayerNorm2d
25
+ import numpy as np
26
+ import torch.nn.functional as F
27
+ import math
28
+ import warnings
29
+
30
+ #######################
31
+ ## Codebase from YOLOv8
32
+ ## BEGINNING
33
+ #######################
34
+
35
+ class C2f(nn.Module):
36
+ """Faster Implementation of CSP Bottleneck with 2 convolutions."""
37
+ """From YOLOv8 codebase"""
38
+ def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5, drop_path=None): # ch_in, ch_out, number, shortcut, groups, expansion
39
+ super().__init__()
40
+ if drop_path is None:
41
+ drop_path = [0.0] * n
42
+
43
+ self.c = int(c2 * e) # hidden channels
44
+ self.cv1 = Conv(c1, 2 * self.c, 1, 1)
45
+ self.cv2 = Conv((2 + n) * self.c, c2, 1) # optional act=FReLU(c2)
46
+ self.m = nn.ModuleList(Bottleneck(self.c, self.c, shortcut, g, k=((3, 3), (3, 3)), e=1.0, drop_path=drop_path[i]) for i in range(n))
47
+
48
+ def forward(self, x):
49
+ """Forward pass through C2f layer."""
50
+ y = list(self.cv1(x).chunk(2, 1))
51
+ y.extend(m(y[-1]) for m in self.m)
52
+ return self.cv2(torch.cat(y, 1))
53
+
54
+ def forward_split(self, x):
55
+ """Forward pass using split() instead of chunk()."""
56
+ y = list(self.cv1(x).split((self.c, self.c), 1))
57
+ y.extend(m(y[-1]) for m in self.m)
58
+ return self.cv2(torch.cat(y, 1))
59
+
60
+ class Bottleneck(nn.Module):
61
+ """Standard bottleneck."""
62
+
63
+ def __init__(self, c1, c2, shortcut=True, g=1, k=(3, 3), e=0.5, drop_path=0.0): # ch_in, ch_out, shortcut, groups, kernels, expand
64
+ super().__init__()
65
+ c_ = int(c2 * e) # hidden channels
66
+ self.cv1 = Conv(c1, c_, k[0], 1)
67
+ self.cv2 = Conv(c_, c2, k[1], 1, g=g)
68
+ self.add = shortcut and c1 == c2
69
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
70
+
71
+ def forward(self, x):
72
+ """'forward()' applies the YOLOv5 FPN to input data."""
73
+ return x + self.drop_path1(self.cv2(self.cv1(x))) if self.add else self.cv2(self.cv1(x))
74
+
75
+
76
+ class Conv(nn.Module):
77
+ """Modified to support layer fusion"""
78
+ default_act = nn.SiLU() # default activation
79
+
80
+ def __init__(self, a, b, kernel_size=1, stride=1, padding=None, g=1, dilation=1, bn_weight_init=1, bias=False, act=True):
81
+ super().__init__()
82
+
83
+ self.conv = torch.nn.Conv2d(a, b, kernel_size, stride, autopad(kernel_size, padding, dilation), dilation, g, bias=False)
84
+ if 1:
85
+ self.bn = torch.nn.BatchNorm2d(b)
86
+ torch.nn.init.constant_(self.bn.weight, bn_weight_init)
87
+ torch.nn.init.constant_(self.bn.bias, 0)
88
+ self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()
89
+
90
+
91
+ def forward(self,x):
92
+ x = self.conv(x)
93
+ x = self.bn(x)
94
+ x = self.act(x)
95
+ return x
96
+
97
+ @torch.no_grad()
98
+ def switch_to_deploy(self):
99
+ # return 1
100
+ if not isinstance(self.bn, nn.Identity):
101
+ c, bn = self.conv, self.bn
102
+ w = bn.weight / (bn.running_var + bn.eps) ** 0.5
103
+ w = c.weight * w[:, None, None, None]
104
+ b = bn.bias - bn.running_mean * bn.weight / \
105
+ (bn.running_var + bn.eps)**0.5
106
+
107
+ self.conv.weight.data.copy_(w)
108
+ self.conv.bias = nn.Parameter(b)
109
+
110
+ self.bn = nn.Identity()
111
+
112
+ def autopad(k, p=None, d=1): # kernel, padding, dilation
113
+ """Pad to 'same' shape outputs."""
114
+ if d > 1:
115
+ k = d * (k - 1) + 1 if isinstance(k, int) else [d * (x - 1) + 1 for x in k] # actual kernel-size
116
+ if p is None:
117
+ p = k // 2 if isinstance(k, int) else [x // 2 for x in k] # auto-pad
118
+ return p
119
+
120
+
121
+ #######################
122
+ ## Codebase from YOLOv8
123
+ ## END
124
+ #######################
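A small sanity-check sketch for the YOLOv8-style blocks above: build a C2f stage, fold its BatchNorm layers into the convolutions via switch_to_deploy, and confirm the outputs match in eval mode; all sizes are illustrative.

import torch

block = C2f(c1=64, c2=64, n=2, shortcut=True, e=0.5).eval()   # eval: BN uses running stats, so folding is exact
x = torch.randn(1, 64, 32, 32)
with torch.no_grad():
    y_ref = block(x)                                          # (1, 64, 32, 32); channels and resolution preserved
    for m in block.modules():
        if hasattr(m, 'switch_to_deploy'):
            m.switch_to_deploy()                              # folds BN into the conv and replaces it with nn.Identity
    y_fused = block(x)
torch.testing.assert_close(y_ref, y_fused, rtol=1e-4, atol=1e-5)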
125
+
126
+ def pixel_unshuffle(data, factor=2):
127
+ # performs nn.PixelShuffle(factor) in reverse; torch's op has issues with ONNX and TRT export, so it is done manually
128
+ B, C, H, W = data.shape
129
+ return data.view(B, C, factor, H//factor, factor, W//factor).permute(0,1,2,4,3,5).reshape(B, -1, H//factor, W//factor)
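A quick shape check, illustrative only: with factor=2 each spatial dimension shrinks by 2 and the channel count grows by 4.

import torch

x = torch.randn(2, 16, 8, 8)
y = pixel_unshuffle(x, factor=2)
assert y.shape == (2, 64, 4, 4)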
130
+
131
+ class SwiGLU(nn.Module):
132
+ # could be more elaborate, but that doesn't improve results so far
133
+ def forward(self, x):
134
+ x, gate = x.chunk(2, dim=-1)
135
+ return F.silu(gate) * x
136
+
137
+
138
+ def window_partition(x, window_size):
139
+ """
140
+ Function for partitioning image into windows and later do windowed attention
141
+ Args:
142
+ x: (B, C, H, W)
143
+ window_size: window size
144
+ Returns:
145
+ windows - local window features (num_windows*B, window_size*window_size, C)
146
+ (Hp, Wp) - the size of the padded image
147
+ """
148
+ B, C, H, W = x.shape
149
+
150
+ if window_size == 0 or (window_size==H and window_size==W):
151
+ windows = x.flatten(2).transpose(1, 2)
152
+ Hp, Wp = H, W
153
+ else:
154
+ pad_h = (window_size - H % window_size) % window_size
155
+ pad_w = (window_size - W % window_size) % window_size
156
+ if pad_h > 0 or pad_w > 0:
157
+ x = F.pad(x, (0, pad_w, 0, pad_h), mode="reflect")
158
+ Hp, Wp = H + pad_h, W + pad_w
159
+
160
+ x = x.view(B, C, Hp // window_size, window_size, Wp // window_size, window_size)
161
+ windows = x.permute(0, 2, 4, 3, 5, 1).reshape(-1, window_size*window_size, C)
162
+
163
+ return windows, (Hp, Wp)
164
+
165
+ class Conv2d_BN(nn.Module):
166
+ '''
167
+ Conv2d + BN layer with folding capability to speed up inference
168
+ Can be merged with Conv() function with additional arguments
169
+ '''
170
+ def __init__(self, a, b, kernel_size=1, stride=1, padding=0, dilation=1, groups=1, bn_weight_init=1, bias=False):
171
+ super().__init__()
172
+ self.conv = torch.nn.Conv2d(a, b, kernel_size, stride, padding, dilation, groups, bias=False)
173
+ if 1:
174
+ self.bn = torch.nn.BatchNorm2d(b)
175
+ torch.nn.init.constant_(self.bn.weight, bn_weight_init)
176
+ torch.nn.init.constant_(self.bn.bias, 0)
177
+
178
+ def forward(self,x):
179
+ x = self.conv(x)
180
+ x = self.bn(x)
181
+ return x
182
+
183
+ @torch.no_grad()
184
+ def switch_to_deploy(self):
185
+ if not isinstance(self.bn, nn.Identity):
186
+ c, bn = self.conv, self.bn
187
+ w = bn.weight / (bn.running_var + bn.eps) ** 0.5
188
+ w = c.weight * w[:, None, None, None]
189
+ b = bn.bias - bn.running_mean * bn.weight / \
190
+ (bn.running_var + bn.eps)**0.5
191
+ self.conv.weight.data.copy_(w)
192
+ self.conv.bias = nn.Parameter(b)
193
+ self.bn = nn.Identity()
194
+
195
+
196
+
197
+ def window_reverse(windows, window_size, H, W, pad_hw):
198
+ """
199
+ Windows to the full feature map
200
+ Args:
201
+ windows: local window features (num_windows*B, window_size*window_size, C)
202
+ window_size: Window size
203
+ H: Height of image
204
+ W: Width of image
205
+ pad_hw - a tuple (Hp, Wp) with the padded image size used in the windowing step
206
+ Returns:
207
+ x: (B, C, H, W)
208
+
209
+ """
210
+ # print(f"window_reverse, windows.shape {windows.shape}")
211
+ Hp, Wp = pad_hw
212
+ if window_size == 0 or (window_size==H and window_size==W):
213
+ B = int(windows.shape[0] / (Hp * Wp / window_size / window_size))
214
+ x = windows.transpose(1, 2).view(B, -1, H, W)
215
+ else:
216
+ B = int(windows.shape[0] / (Hp * Wp / window_size / window_size))
217
+ x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
218
+ x = x.permute(0, 5, 1, 3, 2, 4).reshape(B,windows.shape[2], Hp, Wp)
219
+
220
+ if Hp > H or Wp > W:
221
+ x = x[:, :, :H, :W, ].contiguous()
222
+
223
+ return x
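A round-trip sketch for the two helpers above, with illustrative sizes: partitioning a feature map into windows and reassembling it reproduces the input when nothing is changed in between.

import torch

B, C, H, W = 2, 32, 14, 14
x = torch.randn(B, C, H, W)
windows, pad_hw = window_partition(x, window_size=7)        # (B * 4 windows, 49 tokens, C), pad_hw == (14, 14)
assert windows.shape == (B * 4, 7 * 7, C)
y = window_reverse(windows, window_size=7, H=H, W=W, pad_hw=pad_hw)
torch.testing.assert_close(y, x)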
224
+
225
+
226
+
227
+ class PosEmbMLPSwinv2D(nn.Module):
228
+ """
229
+ 2D positional embedding from Swin Transformer v2
230
+ Added functionality to store the positional embedding in the model and not recompute it every time
231
+ """
232
+ def __init__(
233
+ self, window_size, pretrained_window_size, num_heads, seq_length, no_log=False, cpb_mlp_hidden=512,
234
+ ):
235
+ super().__init__()
236
+ self.window_size = window_size
237
+ self.num_heads = num_heads
238
+ # mlp to generate continuous relative position bias
239
+ self.cpb_mlp = nn.Sequential(
240
+ nn.Linear(2, cpb_mlp_hidden, bias=True),
241
+ nn.ReLU(inplace=True),
242
+ nn.Linear(cpb_mlp_hidden, num_heads, bias=False),
243
+ )
244
+
245
+ self.grid_exists = False
246
+ self.seq_length = seq_length
247
+ self.deploy = False
248
+ self.num_heads = num_heads
249
+ self.no_log = no_log
250
+ self.pretrained_window_size = pretrained_window_size
251
+ self.relative_bias_window_size = window_size
252
+
253
+ relative_coords_table, relative_position_index, relative_bias = self.relative_bias_initialization(window_size, num_heads,
254
+ pretrained_window_size, seq_length,
255
+ no_log)
256
+
257
+ self.register_buffer("relative_coords_table", relative_coords_table)
258
+ self.register_buffer("relative_position_index", relative_position_index)
259
+ self.register_buffer("relative_bias", relative_bias) # for EMA
260
+
261
+ def relative_bias_initialization(self, window_size, num_heads, pretrained_window_size, seq_length, no_log):
262
+ # kept as a separate function to support changing the window size after the model weights are loaded
263
+ relative_coords_h = torch.arange(
264
+ -(window_size[0] - 1), window_size[0], dtype=torch.float32
265
+ )
266
+ relative_coords_w = torch.arange(
267
+ -(window_size[1] - 1), window_size[1], dtype=torch.float32
268
+ )
269
+ relative_coords_table = (
270
+ torch.stack(torch.meshgrid([relative_coords_h, relative_coords_w]))
271
+ .permute(1, 2, 0)
272
+ .contiguous()
273
+ .unsqueeze(0)
274
+ ) # 1, 2*Wh-1, 2*Ww-1, 2
275
+ if pretrained_window_size[0] > 0:
276
+ relative_coords_table[:, :, :, 0] /= pretrained_window_size[0] - 1
277
+ relative_coords_table[:, :, :, 1] /= pretrained_window_size[1] - 1
278
+ else:
279
+ relative_coords_table[:, :, :, 0] /= self.window_size[0] - 1
280
+ relative_coords_table[:, :, :, 1] /= self.window_size[1] - 1
281
+
282
+ if not no_log:
283
+ relative_coords_table *= 8 # normalize to -8, 8
284
+ relative_coords_table = (
285
+ torch.sign(relative_coords_table)
286
+ * torch.log2(torch.abs(relative_coords_table) + 1.0)
287
+ / np.log2(8)
288
+ )
289
+
290
+ # get pair-wise relative position index for each token inside the window
291
+ coords_h = torch.arange(self.window_size[0])
292
+ coords_w = torch.arange(self.window_size[1])
293
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
294
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
295
+ relative_coords = (
296
+ coords_flatten[:, :, None] - coords_flatten[:, None, :]
297
+ ) # 2, Wh*Ww, Wh*Ww
298
+ relative_coords = relative_coords.permute(
299
+ 1, 2, 0
300
+ ).contiguous() # Wh*Ww, Wh*Ww, 2
301
+ relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
302
+ relative_coords[:, :, 1] += self.window_size[1] - 1
303
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
304
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
305
+
306
+ relative_bias = torch.zeros(1, num_heads, seq_length, seq_length)
307
+
308
+ self.relative_bias_window_size = window_size
309
+
310
+ return relative_coords_table, relative_position_index, relative_bias
311
+
312
+
313
+ def switch_to_deploy(self):
314
+ self.deploy = True
315
+ self.grid_exists = True
316
+
317
+ def forward(self, input_tensor):
318
+ # for efficiency, we want this forward to be folded into a single operation (sum)
319
+ # if resolution stays the same, then we dont need to recompute MLP layers
320
+
321
+ if not self.deploy or self.training:
322
+ self.grid_exists = False
323
+
324
+ #compare if all elements in self.window_size list match those in self.relative_bias_window_size
325
+ if not all([self.window_size[i] == self.relative_bias_window_size[i] for i in range(len(self.window_size))]):
326
+ relative_coords_table, relative_position_index, relative_bias = self.relative_bias_initialization(self.window_size, self.num_heads,
327
+ self.pretrained_window_size, self.seq_length,
328
+ self.no_log)
329
+
330
+ self.relative_coords_table = relative_coords_table.to(self.relative_coords_table.device)
331
+ self.relative_position_index = relative_position_index.to(self.relative_position_index.device)
332
+ self.relative_bias = relative_bias.to(self.relative_bias.device)
333
+
334
+ if self.deploy and self.grid_exists:
335
+ input_tensor = input_tensor + self.relative_bias
336
+ return input_tensor
337
+
338
+ if 1:
339
+ self.grid_exists = True
340
+
341
+ relative_position_bias_table = self.cpb_mlp(
342
+ self.relative_coords_table
343
+ ).view(-1, self.num_heads)
344
+ relative_position_bias = relative_position_bias_table[
345
+ self.relative_position_index.view(-1)
346
+ ].view(
347
+ self.window_size[0] * self.window_size[1],
348
+ self.window_size[0] * self.window_size[1],
349
+ -1,
350
+ ) # Wh*Ww,Wh*Ww,nH
351
+
352
+ relative_position_bias = relative_position_bias.permute(
353
+ 2, 0, 1
354
+ ).contiguous() # nH, Wh*Ww, Wh*Ww
355
+ relative_position_bias = 16 * torch.sigmoid(relative_position_bias)
356
+
357
+ self.relative_bias = relative_position_bias.unsqueeze(0)
358
+
359
+ input_tensor = input_tensor + self.relative_bias
360
+ return input_tensor
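A usage sketch with illustrative sizes: the module is applied directly to attention logits and returns them with a per-head relative bias of shape (1, num_heads, seq_length, seq_length) added.

import torch

num_heads, window = 4, 7
seq_len = window * window
pos_emb = PosEmbMLPSwinv2D(window_size=[window, window],
                           pretrained_window_size=[window, window],
                           num_heads=num_heads, seq_length=seq_len)
attn = torch.randn(8, num_heads, seq_len, seq_len)   # (B * num_windows, heads, N, N) logits
attn = pos_emb(attn)                                 # same shape, relative bias added
assert attn.shape == (8, num_heads, seq_len, seq_len)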
361
+
362
+
363
+ class GRAAttentionBlock(nn.Module):
364
+ def __init__(self, window_size, dim_in, dim_out,
365
+ num_heads, drop_path=0., qk_scale=None, qkv_bias=False,
366
+ norm_layer=nn.LayerNorm, layer_scale=None,
367
+ use_swiglu=True,
368
+ subsample_ratio=1, dim_ratio=1, conv_base=False,
369
+ do_windowing=True, multi_query=False, use_shift=0,
370
+ cpb_mlp_hidden=512, conv_groups_ratio=0):
371
+ '''
372
+ Global Resolution Attention (GRA) block, see README for details
373
+ Attention with subsampling to obtain a larger receptive field
374
+ conv_base - use conv2d instead of avgpool2d for downsample / upsample
375
+
376
+
377
+ '''
378
+ super().__init__()
379
+
380
+ self.shift_size=window_size//2 if use_shift else 0
381
+
382
+ self.do_windowing = do_windowing
383
+ self.subsample_ratio = subsample_ratio
384
+
385
+
386
+
387
+ if do_windowing:
388
+ if conv_base:
389
+ self.downsample_op = nn.Conv2d(dim_in, dim_out, kernel_size=subsample_ratio, stride=subsample_ratio) if subsample_ratio > 1 else nn.Identity()
390
+
391
+
392
+ self.downsample_mixer = nn.Identity()
393
+ self.upsample_mixer = nn.Identity()
394
+ self.upsample_op = nn.ConvTranspose2d(dim_in, dim_out, kernel_size=subsample_ratio, stride=subsample_ratio) if subsample_ratio > 1 else nn.Identity()
395
+ else:
396
+ self.downsample_op = nn.AvgPool2d(kernel_size=subsample_ratio, stride=subsample_ratio) if subsample_ratio > 1 else nn.Identity()
397
+ self.downsample_mixer = Conv2d_BN(dim_in, dim_out, kernel_size=1, stride=1) if subsample_ratio > 1 else nn.Identity()
398
+ self.upsample_mixer = nn.Upsample(scale_factor=subsample_ratio, mode='nearest') if subsample_ratio > 1 else nn.Identity()
399
+ self.upsample_op = Conv2d_BN(dim_in, dim_out, kernel_size=1, stride=1, padding=0, bias=False) if subsample_ratio > 1 else nn.Identity()
400
+
401
+
402
+ # in case there is no downsampling conv we want to have it separately
403
+ # will help with information propagation between windows
404
+ if subsample_ratio == 1:
405
+ # conv_groups_ratio=0
406
+ self.pre_conv = Conv2d_BN(dim_in, dim_in, kernel_size=3, stride=1, padding=1, groups=max(1,int(conv_groups_ratio*dim_in)), bias=False)
407
+ # self.pre_conv = nn.Conv2d(dim_in, dim_in, kernel_size=3, stride=1, padding=1, groups=max(1,int(conv_groups_ratio*dim_in)), bias=False)
408
+ # self.pre_conv_act = nn.ReLU6()
409
+ #for simplicity:
410
+ self.pre_conv_act = nn.Identity()
411
+ if conv_groups_ratio == -1:
412
+ self.pre_conv = nn.Identity()
413
+ self.pre_conv_act = nn.Identity()
414
+
415
+ self.window_size = window_size
416
+
417
+ self.norm1 = norm_layer(dim_in)
418
+
419
+ self.attn = WindowAttention(
420
+ dim_in,
421
+ num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
422
+ resolution=window_size,
423
+ seq_length=window_size**2, dim_out=dim_in, multi_query=multi_query,
424
+ shift_size=self.shift_size, cpb_mlp_hidden=cpb_mlp_hidden)
425
+
426
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
427
+
428
+ use_layer_scale = layer_scale is not None and type(layer_scale) in [int, float]
429
+ self.gamma1 = nn.Parameter(layer_scale * torch.ones(dim_in)) if use_layer_scale else 1
430
+
431
+ ### mlp layer
432
+ mlp_ratio = 4
433
+ self.norm2 = norm_layer(dim_in)
434
+ mlp_hidden_dim = int(dim_in * mlp_ratio)
435
+
436
+ activation = nn.GELU if not use_swiglu else SwiGLU
437
+ mlp_hidden_dim = int((4 * dim_in * 1 / 2) / 64) * 64 if use_swiglu else mlp_hidden_dim
438
+
439
+ self.mlp = Mlp(in_features=dim_in, hidden_features=mlp_hidden_dim, act_layer=activation, use_swiglu=use_swiglu)
440
+
441
+ self.gamma2 = nn.Parameter(layer_scale * torch.ones(dim_in)) if layer_scale else 1
442
+ self.drop_path2=DropPath(drop_path) if drop_path > 0. else nn.Identity()
443
+
444
+
445
+ def forward(self, x):
446
+ skip_connection = x
447
+ attn_mask = None
448
+
449
+ # in case there is no downsampling conv we want to have it separately
450
+ # will help with information propagation
451
+ if self.subsample_ratio == 1:
452
+ x = self.pre_conv_act(self.pre_conv(x)) + skip_connection
453
+
454
+ if self.do_windowing:
455
+ # performing windowing if required
456
+ x = self.downsample_op(x)
457
+ x = self.downsample_mixer(x)
458
+
459
+ if self.window_size>0:
460
+ H, W = x.shape[2], x.shape[3]
461
+
462
+ if self.shift_size > 0 and H>self.window_size and W>self.window_size:
463
+ # Swin-like cyclic shift; doesn't show better performance
464
+ x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(2, 3))
465
+
466
+ x, pad_hw = window_partition(x, self.window_size)
467
+
468
+ if self.shift_size > 0 and H>self.window_size and W>self.window_size:
469
+ # set atten matrix to have -100 and the top right square
470
+ # attn[:, :, :-self.shift_size, -self.shift_size:] = -100.0
471
+ # calculate attention mask for SW-MSA
472
+ # not used in final version, can be useful for some cases especially for high res
473
+ H, W = pad_hw
474
+ img_mask = torch.zeros((1, H, W, 1), device=x.device) # 1 H W 1
475
+ h_slices = (slice(0, -self.window_size),
476
+ slice(-self.window_size, -self.shift_size),
477
+ slice(-self.shift_size, None))
478
+ w_slices = (slice(0, -self.window_size),
479
+ slice(-self.window_size, -self.shift_size),
480
+ slice(-self.shift_size, None))
481
+ cnt = 0
482
+ for h in h_slices:
483
+ for w in w_slices:
484
+ img_mask[:, h, w, :] = cnt
485
+ cnt += 1
486
+ img_mask = img_mask.transpose(1,2).transpose(1,3)
487
+ mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
488
+
489
+ mask_windows = mask_windows[0].view(-1, self.window_size * self.window_size)
490
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
491
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
492
+
493
+ # window attention
494
+ x = x + self.drop_path1(self.gamma1*self.attn(self.norm1(x), attn_mask=attn_mask)) # or pass H,W
495
+ # mlp layer
496
+ x = x + self.drop_path2(self.gamma2*self.mlp(self.norm2(x)))
497
+
498
+ if self.do_windowing:
499
+ if self.window_size > 0:
500
+ x = window_reverse(x, self.window_size, H, W, pad_hw)
501
+
502
+ # reverse cyclic shift
503
+ if self.shift_size > 0 and H>self.window_size and W>self.window_size:
504
+ # reverse the Swin-like cyclic shift, not tested
505
+ x = torch.roll(x, shifts=(self.shift_size, self.shift_size), dims=(2, 3))
506
+
507
+ x = self.upsample_mixer(x)
508
+ x = self.upsample_op(x)
509
+
510
+
511
+ if x.shape[2] != skip_connection.shape[2] or x.shape[3] != skip_connection.shape[3]:
512
+ x = torch.nn.functional.pad(x, ( 0, -x.shape[3] + skip_connection.shape[3], 0, -x.shape[2] + skip_connection.shape[2]), mode="reflect")
513
+ # need to add skip connection because downsampling and upsampling will break residual connection
514
+ # 0.5 is needed to make sure that the skip connection is not too strong
515
+ # in case of no downsample / upsample we can show that 0.5 compensates for the residual connection
516
+ x = 0.5 * x + 0.5 * skip_connection
517
+ return x
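An instantiation sketch for the block above, with illustrative parameter values: subsample 2x, run windowed attention, upsample, and blend with the skip connection.

import torch

block = GRAAttentionBlock(window_size=7, dim_in=64, dim_out=64, num_heads=4,
                          subsample_ratio=2, do_windowing=True, conv_base=False)
x = torch.randn(1, 64, 28, 28)
y = block(x)
assert y.shape == x.shape    # resolution and channels are preserved by the 0.5 / 0.5 blend with the skip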
518
+
519
+
520
+
521
+
522
+ class MultiResolutionAttention(nn.Module):
523
+ """
524
+ MultiResolutionAttention (MRA) module
525
+ The idea is to use multiple attention blocks with different resolution
526
+ Feature maps are downsampled / upsampled for each attention block on different blocks
527
+ Every attention block supports windowing
528
+ """
529
+
530
+ def __init__(self, window_size, sr_ratio,
531
+ dim, dim_ratio, num_heads,
532
+ do_windowing=True,
533
+ layer_scale=1e-5, norm_layer=nn.LayerNorm,
534
+ drop_path = 0, qkv_bias=False, qk_scale=1.0,
535
+ use_swiglu=True, multi_query=False, conv_base=False,
536
+ use_shift=0, cpb_mlp_hidden=512, conv_groups_ratio=0) -> None:
537
+ """
538
+ Args:
539
+ input_resolution: input image resolution
540
+ window_size: window size
541
+ compression_ratio: compression ratio
542
+ max_depth: maximum depth of the GRA module
543
+ use_shift: do window shifting
544
+ """
545
+ super().__init__()
546
+
547
+ depth = len(sr_ratio)
548
+
549
+ self.attention_blocks = nn.ModuleList()
550
+
551
+
552
+ for i in range(depth):
553
+ subsample_ratio = sr_ratio[i]
554
+ if len(window_size) > i:
555
+ window_size_local = window_size[i]
556
+ else:
557
+ window_size_local = window_size[0]
558
+
559
+ self.attention_blocks.append(GRAAttentionBlock(window_size=window_size_local,
560
+ dim_in=dim, dim_out=dim, num_heads=num_heads,
561
+ qkv_bias=qkv_bias, qk_scale=qk_scale, norm_layer=norm_layer,
562
+ layer_scale=layer_scale, drop_path=drop_path,
563
+ use_swiglu=use_swiglu, subsample_ratio=subsample_ratio, dim_ratio=dim_ratio,
564
+ do_windowing=do_windowing, multi_query=multi_query, conv_base=conv_base,
565
+ use_shift=use_shift, cpb_mlp_hidden=cpb_mlp_hidden, conv_groups_ratio=conv_groups_ratio),
566
+ )
567
+
568
+ def forward(self, x):
569
+
570
+ for attention_block in self.attention_blocks:
571
+ x = attention_block(x)
572
+
573
+ return x
574
+
575
+
576
+
577
+ class Mlp(nn.Module):
578
+ """
579
+ Multi-Layer Perceptron (MLP) block
580
+ """
581
+
582
+ def __init__(self,
583
+ in_features,
584
+ hidden_features=None,
585
+ out_features=None,
586
+ act_layer=nn.GELU,
587
+ use_swiglu=True,
588
+ drop=0.):
589
+ """
590
+ Args:
591
+ in_features: input features dimension.
592
+ hidden_features: hidden features dimension.
593
+ out_features: output features dimension.
594
+ act_layer: activation function.
595
+ drop: dropout rate.
596
+ """
597
+
598
+ super().__init__()
599
+ out_features = out_features or in_features
600
+ hidden_features = hidden_features or in_features
601
+ self.fc1 = nn.Linear(in_features, hidden_features * (2 if use_swiglu else 1), bias=False)
602
+ self.act = act_layer()
603
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=False)
604
+
605
+ def forward(self, x):
606
+ x_size = x.size()
607
+ x = x.view(-1, x_size[-1])
608
+ x = self.fc1(x)
609
+ x = self.act(x)
610
+ x = self.fc2(x)
611
+ x = x.view(x_size)
612
+ return x
613
+
614
+ class Downsample(nn.Module):
615
+ """
616
+ Down-sampling block
617
+ Pixel Unshuffle is used for down-sampling; it works well accuracy-wise but takes 10% more TRT time
618
+ """
619
+
620
+ def __init__(self,
621
+ dim,
622
+ shuffle = False,
623
+ ):
624
+ """
625
+ Args:
626
+ dim: feature size dimension.
627
+ shuffle: use PixelUnshuffle for down-sampling.
628
+ keep_dim: bool argument for maintaining the resolution.
629
+ """
630
+
631
+ super().__init__()
632
+ dim_out = 2 * dim
633
+
634
+ if shuffle:
635
+ self.norm = lambda x: pixel_unshuffle(x, factor=2)
636
+ self.reduction = Conv2d_BN(dim*4, dim_out, 1, 1, 0, bias=False)
637
+ # pixel unshuffling works well but doesn't provide any speedup
638
+ else:
639
+ # removed the layer norm; in this formulation we get 10% better speed
640
+ # LayerNorm for high resolution inputs will be a pain as it pools over the entire spatial dimension
641
+ # therefore we remove it compared to the original implementation in FasterViT
642
+ self.norm = nn.Identity()
643
+ self.reduction = Conv2d_BN(dim, dim_out, 3, 2, 1, bias=False)
644
+
645
+
646
+ def forward(self, x):
647
+ x = self.norm(x)
648
+ x = self.reduction(x)
649
+ return x
650
+
651
+
652
+ class PatchEmbed(nn.Module):
653
+ """
654
+ Patch embedding block
655
+ Used to convert image into an initial set of feature maps with lower resolution
656
+ """
657
+
658
+ def __init__(self, in_chans=3, in_dim=64, dim=96, shuffle_down=False):
659
+ """
660
+ Args:
661
+ in_chans: number of input channels.
662
+ in_dim: intermediate feature size dimension to speed up stem.
663
+ dim: final stem channel number
664
+ shuffle_down: use PixelUnshuffle for down-sampling, effectively increases the receptive field
665
+ """
666
+
667
+ super().__init__()
668
+ # shuffle_down = False
669
+ if not shuffle_down:
670
+ self.proj = nn.Identity()
671
+ self.conv_down = nn.Sequential(
672
+ Conv2d_BN(in_chans, in_dim, 3, 2, 1, bias=False),
673
+ nn.ReLU(),
674
+ Conv2d_BN(in_dim, dim, 3, 2, 1, bias=False),
675
+ nn.ReLU()
676
+ )
677
+ else:
678
+ self.proj = lambda x: pixel_unshuffle(x, factor=4)
679
+ self.conv_down = nn.Sequential(Conv2d_BN(in_chans*16, dim, 3, 1, 1),
680
+ nn.ReLU(),
681
+ )
682
+
683
+ def forward(self, x):
684
+ x = self.proj(x)
685
+ x = self.conv_down(x)
686
+ return x
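A shape sketch, illustrative only: with either stem variant the input resolution is reduced 4x.

import torch

x = torch.randn(1, 3, 224, 224)
stem = PatchEmbed(in_chans=3, in_dim=64, dim=96, shuffle_down=False)
assert stem(x).shape == (1, 96, 56, 56)
stem_shuffled = PatchEmbed(in_chans=3, in_dim=64, dim=96, shuffle_down=True)   # PixelUnshuffle-based stem
assert stem_shuffled(x).shape == (1, 96, 56, 56)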
687
+
688
+
689
+
690
+ class ConvBlock(nn.Module):
691
+ """
692
+ Convolutional block, used in first couple of stages
693
+ Experimented with plain ResNet-18-like modules; they are the best in terms of throughput
694
+ Finally, the YOLOv8 idea seems to work fine (a ResNet-18-like block with a squeezed feature dimension and feature concatenation at the end)
695
+ """
696
+ def __init__(self, dim,
697
+ drop_path=0.,
698
+ layer_scale=None,
699
+ kernel_size=3,
700
+ ):
701
+ super().__init__()
702
+
703
+ self.conv1 = Conv2d_BN(dim, dim, kernel_size=kernel_size, stride=1, padding=1)
704
+ self.act1 = nn.GELU()
705
+
706
+ self.conv2 = Conv2d_BN(dim, dim, kernel_size=kernel_size, stride=1, padding=1)
707
+
708
+ self.layer_scale = layer_scale
709
+ if layer_scale is not None and type(layer_scale) in [int, float]:
710
+ self.gamma = nn.Parameter(layer_scale * torch.ones(dim))
711
+ self.layer_scale = True
712
+ else:
713
+ self.layer_scale = False
714
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
715
+
716
+ def forward(self, x):
717
+ input = x
718
+
719
+ x = self.conv1(x)
720
+ x = self.act1(x)
721
+ x = self.conv2(x)
722
+
723
+ if self.layer_scale:
724
+ x = x * self.gamma.view(1, -1, 1, 1)
725
+ x = input + self.drop_path(x)
726
+ return x
727
+
728
+
729
+ class WindowAttention(nn.Module):
730
+ # Windowed Attention from SwinV2
731
+ # use a MLP trick to deal with various input image resolutions, then fold it to improve speed
732
+
733
+ def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, resolution=0,
734
+ seq_length=0, dim_out=None, multi_query=False, shift_size=0, cpb_mlp_hidden=512):
735
+ # taken from EdgeViT and tweaked with attention bias.
736
+ super().__init__()
737
+ if not dim_out: dim_out = dim
738
+ self.shift_size = shift_size
739
+ self.multi_query = multi_query
740
+ self.num_heads = num_heads
741
+ head_dim = dim // num_heads
742
+ self.head_dim = dim // num_heads
743
+
744
+ self.dim_internal = dim
745
+
746
+ self.scale = qk_scale or head_dim ** -0.5
747
+ if not multi_query:
748
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
749
+ else:
750
+ self.qkv = nn.Linear(dim, dim + 2*self.head_dim, bias=qkv_bias)
751
+
752
+ self.proj = nn.Linear(dim, dim_out, bias=False)
753
+ # attention positional bias
754
+ self.pos_emb_funct = PosEmbMLPSwinv2D(window_size=[resolution, resolution],
755
+ pretrained_window_size=[resolution, resolution],
756
+ num_heads=num_heads,
757
+ seq_length=seq_length,
758
+ cpb_mlp_hidden=cpb_mlp_hidden)
759
+
760
+ self.resolution = resolution
761
+
762
+ def forward(self, x, attn_mask = None):
763
+ B, N, C = x.shape
764
+
765
+ if not self.multi_query:
766
+ qkv = self.qkv(x).reshape(B, -1, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
767
+ q, k, v = qkv[0], qkv[1], qkv[2]
768
+ else:
769
+ qkv = self.qkv(x)
770
+ (q, k, v) = qkv.split([self.dim_internal, self.head_dim, self.head_dim], dim=2)
771
+
772
+ q = q.reshape(B, -1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
773
+ k = k.reshape(B, -1, 1, C // self.num_heads).permute(0, 2, 1, 3)
774
+ v = v.reshape(B, -1, 1, C // self.num_heads).permute(0, 2, 1, 3)
775
+
776
+ attn = (q @ k.transpose(-2, -1)) * self.scale
777
+
778
+ attn = self.pos_emb_funct(attn)
779
+
780
+ #add window shift
781
+ if attn_mask is not None:
782
+ nW = attn_mask.shape[0]
783
+ attn = attn.view(B // nW, nW, self.num_heads, N, N) + attn_mask.unsqueeze(1).unsqueeze(0)
784
+ attn = attn.view(-1, self.num_heads, N, N)
785
+
786
+ attn = attn.softmax(dim=-1)
787
+ x = (attn @ v).transpose(1, 2).reshape(B, -1, C)
788
+ x = self.proj(x)
789
+ return x
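A usage sketch for the attention above, run over a batch of flattened windows with illustrative sizes; with multi_query=True the key and value are shared across heads, which shrinks the qkv projection.

import torch

attn = WindowAttention(dim=64, num_heads=4, resolution=7, seq_length=49, multi_query=True)
tokens = torch.randn(8, 49, 64)                        # 8 windows of 7x7 tokens, 64 channels
out = attn(tokens)
assert out.shape == (8, 49, 64)
assert attn.qkv.out_features == 64 + 2 * (64 // 4)     # full query projection plus one shared k and v head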
790
+
791
+
792
+
793
+ class ERADIOLayer(nn.Module):
794
+ """
795
+ E-RADIO Layer
796
+ """
797
+
798
+ def __init__(self,
799
+ dim,
800
+ depth,
801
+ num_heads,
802
+ window_size,
803
+ conv=False,
804
+ downsample=True,
805
+ mlp_ratio=4.,
806
+ qkv_bias=False,
807
+ qk_scale=None,
808
+ norm_layer=nn.LayerNorm,
809
+ drop_path=0.,
810
+ layer_scale=None,
811
+ layer_scale_conv=None,
812
+ sr_dim_ratio=1,
813
+ sr_ratio=1,
814
+ multi_query=False,
815
+ use_swiglu=True,
816
+ yolo_arch=False,
817
+ downsample_shuffle=False,
818
+ conv_base=False,
819
+ use_shift=False,
820
+ cpb_mlp_hidden=512,
821
+ conv_groups_ratio=0,
822
+ verbose: bool = True,
823
+
824
+ ):
825
+ """
826
+ Args:
827
+ dim: feature size dimension.
828
+ depth: number of layers in each stage.
829
+ input_resolution: input image resolution.
830
+ window_size: window size in each stage.
831
+ downsample: bool argument for down-sampling.
832
+ mlp_ratio: MLP ratio.
833
+ num_heads: number of heads in each stage.
834
+ qkv_bias: bool argument for query, key, value learnable bias.
835
+ qk_scale: bool argument to scaling query, key.
836
+ drop: dropout rate.
837
+ attn_drop: attention dropout rate.
838
+ drop_path: drop path rate.
839
+ norm_layer: normalization layer.
840
+ layer_scale: layer scaling coefficient.
841
+ use_shift: SWIN like window shifting for half the window size for every alternating layer (considering multi-resolution)
842
+ conv_groups_ratio: group ratio for conv when no subsampling in multi-res attention
843
+ """
844
+
845
+ super().__init__()
846
+ self.conv = conv
847
+ self.yolo_arch=False
848
+ self.verbose = verbose
849
+ if conv:
850
+ if not yolo_arch:
851
+ self.blocks = nn.ModuleList([
852
+ ConvBlock(dim=dim,
853
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
854
+ layer_scale=layer_scale_conv)
855
+ for i in range(depth)])
856
+ self.blocks = nn.Sequential(*self.blocks)
857
+ else:
858
+ self.blocks = C2f(dim,dim,n=depth,shortcut=True,e=0.5)
859
+ self.yolo_arch=True
860
+ else:
861
+ if not isinstance(window_size, list): window_size = [window_size]
862
+ self.window_size = window_size[0]
863
+ self.do_single_windowing = True
864
+ if not isinstance(sr_ratio, list): sr_ratio = [sr_ratio]
865
+ self.sr_ratio = sr_ratio
866
+ if any([sr!=1 for sr in sr_ratio]) or len(set(window_size))>1:
867
+ self.do_single_windowing = False
868
+ do_windowing = True
869
+ else:
870
+ self.do_single_windowing = True
871
+ do_windowing = False
872
+
873
+ #for v2_2
874
+ if conv_groups_ratio != -1:
875
+ self.do_single_windowing = False
876
+ do_windowing = True
877
+
878
+ self.blocks = nn.ModuleList()
879
+ for i in range(depth):
880
+ self.blocks.append(
881
+ MultiResolutionAttention(window_size=window_size,
882
+ sr_ratio=sr_ratio,
883
+ dim=dim,
884
+ dim_ratio = sr_dim_ratio,
885
+ num_heads=num_heads,
886
+ norm_layer=norm_layer,
887
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
888
+ layer_scale=layer_scale,
889
+ qkv_bias=qkv_bias,
890
+ qk_scale=qk_scale,
891
+ use_swiglu=use_swiglu,
892
+ do_windowing=do_windowing,
893
+ multi_query=multi_query,
894
+ conv_base=conv_base,
895
+ cpb_mlp_hidden=cpb_mlp_hidden,
896
+ use_shift =0 if ((not use_shift) or ((i) % 2 == 0)) else True ,
897
+ conv_groups_ratio=conv_groups_ratio,
898
+ ))
899
+ self.blocks = nn.Sequential(*self.blocks)
900
+
901
+ self.transformer = not conv
902
+ self.downsample = None if not downsample else Downsample(dim=dim, shuffle=downsample_shuffle)
903
+
904
+
905
+ def forward(self, x):
906
+ B, C, H, W = x.shape
907
+
908
+ # do padding / interpolation for the transformer stages
909
+ interpolate = True
910
+ if self.transformer and interpolate:
911
+ # Windowed Attention will split feature map into windows with the size of window_size x window_size
912
+ # if the resolution is not divisible by window_size, we need to interpolate the feature map
913
+ # can be done via padding, but doing so after training hurts the model performance.
914
+ # interpolation affects the performance as well, but not as much as padding
915
+ if isinstance(self.window_size, list) or isinstance(self.window_size, tuple):
916
+ current_max_window_size = max(self.window_size)
917
+ else:
918
+ current_max_window_size = self.window_size
919
+
920
+ max_window_size = max([res_upsample*current_max_window_size for res_upsample in self.sr_ratio])
921
+ if H % max_window_size != 0 or W % max_window_size != 0:
922
+ new_h = int(np.ceil(H/max_window_size)*max_window_size)
923
+ new_w = int(np.ceil(W/max_window_size)*max_window_size)
924
+ x = F.interpolate(x, size=(new_h, new_w), mode='nearest')
925
+ if self.verbose:
926
+ warnings.warn(f"Chosen window size is not optimal for the given resolution. Feature maps will be interpolated, which can affect performance. Max window size is {max_window_size}, feature map size is {H}x{W}, interpolated feature map size is {new_h}x{new_w}.")
927
+
928
+
929
+ if self.transformer and self.do_single_windowing:
930
+ H, W = x.shape[2], x.shape[3]
931
+ x, pad_hw = window_partition(x, self.window_size)
932
+
933
+ #run main blocks
934
+ x = self.blocks(x)
935
+
936
+ if self.transformer and self.do_single_windowing:
937
+ x = window_reverse(x, self.window_size, H, W, pad_hw)
938
+
939
+ if self.transformer and interpolate:
940
+ # keep the original resolution; not ideal, but the upsampling tower expects this resolution.
941
+ x = F.interpolate(x, size=(H, W), mode='nearest')
942
+
943
+ if self.downsample is None:
944
+ return x, x
945
+
946
+ return self.downsample(x), x # changing to output pre downsampled features
947
+
948
+
949
+ class InterpolateLayer(nn.Module):
950
+ def __init__(self, size=None, scale_factor=None, mode='nearest'):
951
+ super(InterpolateLayer, self).__init__()
952
+ self.size = size
953
+ self.scale_factor = scale_factor
954
+ self.mode = mode
955
+
956
+ def forward(self, x):
957
+ return F.interpolate(x, size=self.size, scale_factor=self.scale_factor, mode=self.mode)
958
+
959
+
960
+ class HiResNeck(nn.Module):
961
+ """
962
+ The block is used to output dense features from all stages
963
+ Otherwise, by default, only the last stage features are returned with E-RADIO
964
+ """
965
+ def __init__(self, dim, depths, neck_start_stage, full_features_head_dim, downsample_enabled):
966
+
967
+ '''
968
+ Hi Resolution neck to support output of high res features that are useful for dense tasks.
969
+ depths - total number of layers in the base model
970
+ neck_start_stage - when to start the neck, 0 - start from the first stage, 1 - start from the second stage etc.
971
+ earlier layers result in higher resolution features at the cost of compute
972
+ full_features_head_dim - number of channels in the dense features head
973
+ '''
974
+ super().__init__()
975
+ # create feature projection layers for segmentation output
976
+ self.neck_features_proj = nn.ModuleList()
977
+ self.neck_start_stage = neck_start_stage
978
+ upsample_ratio = 1
979
+ for i in range(len(depths)):
980
+ level_n_features_output = int(dim * 2 ** i)
981
+
982
+ if self.neck_start_stage > i: continue
983
+
984
+ if (upsample_ratio > 1) or full_features_head_dim!=level_n_features_output:
985
+ feature_projection = nn.Sequential()
986
+ if False:
987
+ feature_projection.add_module("norm",nn.BatchNorm2d(level_n_features_output)) #fast, but worse
988
+ feature_projection.add_module("dconv", nn.ConvTranspose2d(level_n_features_output,
989
+ full_features_head_dim, kernel_size=upsample_ratio, stride=upsample_ratio))
990
+ else:
991
+ # B, in_channels, H, W -> B, in_channels, H*upsample_ratio, W*upsample_ratio
992
+ # print("upsample ratio", upsample_ratio, level_n_features_output, level_n_features_output)
993
+ feature_projection.add_module("upsample", InterpolateLayer(scale_factor=upsample_ratio, mode='nearest'))
994
+ feature_projection.add_module("conv1", nn.Conv2d(level_n_features_output, level_n_features_output, kernel_size=3, stride=1, padding=1, groups=level_n_features_output))
995
+ feature_projection.add_module("norm",nn.BatchNorm2d(level_n_features_output))
996
+ # B, in_channels, H*upsample_ratio, W*upsample_ratio -> B, full_features_head_dim, H*upsample_ratio, W*upsample_ratio
997
+ feature_projection.add_module("conv2", nn.Conv2d(level_n_features_output, full_features_head_dim, kernel_size=1, stride=1, padding=0))
998
+ else:
999
+ feature_projection = nn.Sequential()
1000
+
1001
+ self.neck_features_proj.append(feature_projection)
1002
+
1003
+ if i>0 and downsample_enabled[i]:
1004
+ upsample_ratio *= 2
1005
+
1006
+ def forward(self, x, il_level=-1, full_features=None):
1007
+ if self.neck_start_stage > il_level:
1008
+ return full_features
1009
+
1010
+ if full_features is None:
1011
+ full_features = self.neck_features_proj[il_level - self.neck_start_stage](x)
1012
+ else:
1013
+ #upsample torch tensor x to match full_features size, and add to full_features
1014
+ feature_projection = self.neck_features_proj[il_level - self.neck_start_stage](x)
1015
+ if feature_projection.shape[2] != full_features.shape[2] or feature_projection.shape[3] != full_features.shape[3]:
1016
+ feature_projection = torch.nn.functional.pad(feature_projection, ( 0, -feature_projection.shape[3] + full_features.shape[3], 0, -feature_projection.shape[2] + full_features.shape[2]))
1017
+ full_features = full_features + feature_projection
1018
+ return full_features
1019
+
1020
+ class ERADIO(nn.Module):
1021
+ """
1022
+ Efficient RADIO
1023
+ """
1024
+
1025
+ def __init__(self,
1026
+ dim,
1027
+ in_dim,
1028
+ depths,
1029
+ window_size,
1030
+ mlp_ratio,
1031
+ num_heads,
1032
+ drop_path_rate=0.2,
1033
+ in_chans=3,
1034
+ num_classes=1000,
1035
+ qkv_bias=False,
1036
+ qk_scale=None,
1037
+ layer_scale=None,
1038
+ layer_scale_conv=None,
1039
+ layer_norm_last=False,
1040
+ sr_ratio = [1, 1, 1, 1],
1041
+ max_depth = -1,
1042
+ conv_base=False,
1043
+ use_swiglu=False,
1044
+ multi_query=False,
1045
+ norm_layer=nn.LayerNorm,
1046
+ drop_uniform=False,
1047
+ yolo_arch=False,
1048
+ shuffle_down=False,
1049
+ downsample_shuffle=False,
1050
+ return_full_features=False,
1051
+ full_features_head_dim=128,
1052
+ neck_start_stage=1,
1053
+ use_neck=False,
1054
+ use_shift=False,
1055
+ cpb_mlp_hidden=512,
1056
+ conv_groups_ratio=0,
1057
+ verbose: bool = False,
1058
+ **kwargs):
1059
+ """
1060
+ Args:
1061
+ dim: feature size dimension.
1062
+ depths: number of layers in each stage.
1063
+ window_size: window size in each stage.
1064
+ mlp_ratio: MLP ratio.
1065
+ num_heads: number of heads in each stage.
1066
+ drop_path_rate: drop path rate.
1067
+ in_chans: number of input channels.
1068
+ num_classes: number of classes.
1069
+ qkv_bias: bool argument for query, key, value learnable bias.
1070
+ qk_scale: bool argument to scaling query, key.
1071
+ drop_rate: dropout rate.
1072
+ attn_drop_rate: attention dropout rate.
1073
+ norm_layer: normalization layer.
1074
+ layer_scale: layer scaling coefficient.
1075
+ return_full_features: output dense features as well as logits
1076
+ full_features_head_dim: number of channels in the dense features head
1077
+ neck_start_stage: stage id at which to start the full-feature neck. The model has 4 stages; index starts at 0
1078
+ for 224 resolution, the output of the stage before downsample:
1079
+ stage 0: 56x56, stage 1: 28x28, stage 2: 14x14, stage 3: 7x7
1080
+ use_neck: even for summarization embedding use neck
1081
+ use_shift: SWIN like window shifting but without masking attention
1082
+ conv_groups_ratio: will be used for conv blocks where there is no multires attention,
1083
+ if 0 then normal conv,
1084
+ if 1 then channels are independent,
1085
+ if -1 then no conv at all
1086
+
1087
+ """
1088
+ super().__init__()
1089
+
1090
+ num_features = int(dim * 2 ** (len(depths) - 1))
1091
+ self.num_classes = num_classes
1092
+ self.patch_embed = PatchEmbed(in_chans=in_chans, in_dim=in_dim, dim=dim, shuffle_down=shuffle_down)
1093
+ # set return_full_features true if we want to return full features from all stages
1094
+ self.return_full_features = return_full_features
1095
+ self.use_neck = use_neck
1096
+
1097
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
1098
+ if drop_uniform:
1099
+ dpr = [drop_path_rate for x in range(sum(depths))]
1100
+
1101
+ if not isinstance(max_depth, list): max_depth = [max_depth] * len(depths)
1102
+
1103
+ self.levels = nn.ModuleList()
1104
+ for i in range(len(depths)):
1105
+ conv = True if (i == 0 or i == 1) else False
1106
+
1107
+ level = ERADIOLayer(dim=int(dim * 2 ** i),
1108
+ depth=depths[i],
1109
+ num_heads=num_heads[i],
1110
+ window_size=window_size[i],
1111
+ mlp_ratio=mlp_ratio,
1112
+ qkv_bias=qkv_bias,
1113
+ qk_scale=qk_scale,
1114
+ conv=conv,
1115
+ drop_path=dpr[sum(depths[:i]):sum(depths[:i + 1])],
1116
+ downsample=(i < len(depths) - 1),
1117
+ layer_scale=layer_scale,
1118
+ layer_scale_conv=layer_scale_conv,
1119
+ sr_ratio=sr_ratio[i],
1120
+ use_swiglu=use_swiglu,
1121
+ multi_query=multi_query,
1122
+ norm_layer=norm_layer,
1123
+ yolo_arch=yolo_arch,
1124
+ downsample_shuffle=downsample_shuffle,
1125
+ conv_base=conv_base,
1126
+ cpb_mlp_hidden=cpb_mlp_hidden,
1127
+ use_shift=use_shift,
1128
+ conv_groups_ratio=conv_groups_ratio,
1129
+ verbose=verbose)
1130
+
1131
+ self.levels.append(level)
1132
+
1133
+ if self.return_full_features or self.use_neck:
1134
+ #num_heads
1135
+ downsample_enabled = [self.levels[i-1].downsample is not None for i in range(len(self.levels))]
1136
+ self.high_res_neck = HiResNeck(dim, depths, neck_start_stage, full_features_head_dim, downsample_enabled)
1137
+
1138
+ self.switched_to_deploy = False
1139
+
1140
+ self.norm = LayerNorm2d(num_features) if layer_norm_last else nn.BatchNorm2d(num_features)
1141
+ self.avgpool = nn.AdaptiveAvgPool2d(1)
1142
+ self.head = nn.Linear(num_features, num_classes) if num_classes > 0 else nn.Identity()
1143
+ self.apply(self._init_weights)
1144
+
1145
+ def _init_weights(self, m):
1146
+ if isinstance(m, nn.Linear):
1147
+ trunc_normal_(m.weight, std=.02)
1148
+ if isinstance(m, nn.Linear) and m.bias is not None:
1149
+ nn.init.constant_(m.bias, 0)
1150
+ elif isinstance(m, nn.LayerNorm):
1151
+ nn.init.constant_(m.bias, 0)
1152
+ nn.init.constant_(m.weight, 1.0)
1153
+ elif isinstance(m, LayerNorm2d):
1154
+ nn.init.constant_(m.bias, 0)
1155
+ nn.init.constant_(m.weight, 1.0)
1156
+ elif isinstance(m, nn.BatchNorm2d):
1157
+ nn.init.ones_(m.weight)
1158
+ nn.init.zeros_(m.bias)
1159
+
1160
+ @torch.jit.ignore
1161
+ def no_weight_decay_keywords(self):
1162
+ return {'rpb'}
1163
+
1164
+ def forward_features(self, x):
1165
+ _, _, H, W = x.shape
1166
+ if H % 32 != 0 or W % 32 != 0:
1167
+ raise ValueError(f"E-RADIO requires input dimensions to be divisible by 32 but got H x W: {H} x {W}")
1168
+ x = self.patch_embed(x)
1169
+ full_features = None
1170
+ for il, level in enumerate(self.levels):
1171
+ x, pre_downsample_x = level(x)
1172
+
1173
+ if self.return_full_features or self.use_neck:
1174
+ full_features = self.high_res_neck(pre_downsample_x, il, full_features)
1175
+
1176
+ # x = self.norm(full_features if (self.return_full_features or self.use_neck) else x)
1177
+ x = self.norm(x) # new version for
1178
+
1179
+ if not self.return_full_features:
1180
+ return x, None
1181
+
1182
+ return x, full_features
1183
+
1184
+ def forward(self, x):
1185
+ x, full_features = self.forward_features(x)
1186
+
1187
+ x = self.avgpool(x)
1188
+ x = torch.flatten(x, 1)
1189
+
1190
+ x = self.head(x)
1191
+ if full_features is not None:
1192
+ return x, full_features
1193
+ return x
1194
+
1195
+ def switch_to_deploy(self):
1196
+ '''
1197
+ A method to perform model self-compression
1198
+ merges BN into conv layers
1199
+ converts MLP relative positional bias into precomputed buffers
1200
+ '''
1201
+ if not self.switched_to_deploy:
1202
+ for level in [self.patch_embed, self.levels, self.head]:
1203
+ for module in level.modules():
1204
+ if hasattr(module, 'switch_to_deploy'):
1205
+ module.switch_to_deploy()
1206
+ self.switched_to_deploy = True
1207
+
1208
+
1209
+ def change_window_size(self, new_window_size):
1210
+ """
1211
+ E-RADIO employs windowed attention, which may be sensitive to the choice of this parameter,
1212
+ especially in cases of uneven partitioning of the feature maps.
1213
+ E-RADIO allows for the adjustment of the window size after training,
1214
+ making it adaptable to different input image resolutions.
1215
+ The recommended values for window size based on input resolution are as follows:
1216
+
1217
+ Input Resolution | Window Size
1218
+ 224 | 7
1219
+ 256 | 8
1220
+ 384 | 12
1221
+ 512 | 16
1222
+ Ideally, the window size should be a factor of the input resolution. In the third stage, we divide the resolution by 16, so the window size should be
1223
+ img_res/16/2
1224
+ for the third stage and img_res/32 for the last stage. While this can be applied in a brute-force manner, a better way is to do model.change_window_size.
1225
+ Manual way to change resolution -> model.change_window_size(resolution)
1226
+ """
1227
+ window_size = new_window_size
1228
+ print(f"Setting window size to {window_size}")
1229
+ for module in self.modules():
1230
+ if hasattr(module, "window_size"):
1231
+ # check if tuple or a number
1232
+ if isinstance(module.window_size, tuple):
1233
+ if module.window_size[0] != window_size:
1234
+ module.window_size = (window_size, window_size)
1235
+ elif isinstance(module.window_size, list):
1236
+ if module.window_size[0] != window_size:
1237
+ module.window_size = [window_size, window_size]
1238
+ else:
1239
+ module.window_size = window_size
1240
+
1241
+
1242
+ def set_optimal_window_size(self, image_dim, max_window_size = 16):
1243
+ """
1244
+ Using hand picked window size for various resolutions.
1245
+
1246
+ E-RADIO employs windowed attention, which may be sensitive to the choice of this parameter,
1247
+ especially in cases of uneven partitioning of the feature maps.
1248
+ E-RADIO allows for the adjustment of the window size after training,
1249
+ making it adaptable to different input image resolutions.
1250
+ The recommended values for window size based on input resolution are as follows:
1251
+
1252
+ Input Resolution | Window Size
1253
+ 224 | 7
1254
+ 256 | 8
1255
+ 384 | 12
1256
+ 512 | 16
1257
+ Ideally, the window size should be a factor of the input resolution. In the third stage, we divide the resolution by 16, so the window size should be
1258
+ img_res/16/2
1259
+ for the third stage and img_res/32 for the last stage. While this can be applied in a brute-force manner, a better way is to do model.change_window_size.
1260
+ Manual way to change resolution -> model.change_window_size(resolution)
1261
+
1262
+ """
1263
+ # import math
1264
+
1265
+ def divisorGenerator(n):
1266
+ large_divisors = []
1267
+ for i in range(1, int(math.sqrt(n) + 1)):
1268
+ if n % i == 0:
1269
+ yield i
1270
+ if i*i != n:
1271
+ large_divisors.append(n / i)
1272
+ for divisor in reversed(large_divisors):
1273
+ yield divisor
1274
+
1275
+ if isinstance(image_dim, list) or isinstance(image_dim, tuple):
1276
+ image_dim = min(image_dim)
1277
+
1278
+ # we do windowed attention in the 3rd stage for the first time, therefore //16,
1279
+ # we do subsampled attention with downsample by 2 so need to get //32 actually
1280
+ # ideally we should rewrite this to be dependent on the structure of the model like what if subsampled is removed etc
1281
+ all_divisors = np.array(list(divisorGenerator(image_dim//32)))
1282
+ new_window_size = int(min(all_divisors[all_divisors <= max_window_size][-1], max_window_size))
1283
+
1284
+ # for image_dim in [128, 224, 256, 384, 512, 768, 1024]:
1285
+ # all_divisors = np.array(list(divisorGenerator(image_dim//32)))
1286
+ # new_window_size = int(min(all_divisors[all_divisors <= max_window_size][-1], max_window_size))
1287
+ # print(f"Setting window size to {new_window_size} for image resolution {image_dim}")
1288
+
1289
+ self.change_window_size(new_window_size = new_window_size)
1290
+
1291
+
1292
+ @register_model
1293
+ def eradio_large_fullres_ws16(pretrained=False, **kwargs):
1294
+ model = ERADIO(
1295
+ depths=[3, 3, 5, 5],
1296
+ num_heads=[2, 4, 8, 16],
1297
+ window_size=[None, None, [16, 16], 16],
1298
+ dim=192,
1299
+ in_dim=64,
1300
+ mlp_ratio=4,
1301
+ drop_path_rate=0.0,
1302
+ sr_ratio=[1, 1, [2, 1], 1],
1303
+ use_swiglu=False,
1304
+ yolo_arch=True,
1305
+ shuffle_down=False,
1306
+ conv_base=True,
1307
+ use_neck=True,
1308
+ full_features_head_dim=1536,
1309
+ neck_start_stage=2,
1310
+ **kwargs,
1311
+ )
1312
+ if pretrained:
1313
+ model.load_state_dict(torch.load(pretrained)["state_dict"])
1314
+ return model
1315
+
1316
+
1317
+ @register_model
1318
+ def eradio_xxxtiny(pretrained=False, **kwargs): # ,
1319
+ model = ERADIO(
1320
+ depths=[1, 3, 4, 5],
1321
+ num_heads=[2, 4, 8, 16],
1322
+ window_size=[None, None, [16, 16], 16],
1323
+ dim=32,
1324
+ in_dim=32,
1325
+ mlp_ratio=4,
1326
+ drop_path_rate=0.0,
1327
+ sr_ratio=[1, 1, [2, 1], 1],
1328
+ use_swiglu=False,
1329
+ yolo_arch=True,
1330
+ shuffle_down=False,
1331
+ conv_base=True,
1332
+ use_neck=True,
1333
+ full_features_head_dim=256,
1334
+ neck_start_stage=2,
1335
+ **kwargs,
1336
+ )
1337
+ if pretrained:
1338
+ model.load_state_dict(torch.load(pretrained))
1339
+ return model
1340
+
1341
+ @register_model
1342
+ def eradio_xxxtiny_8x_ws12(pretrained=False, **kwargs):
1343
+ model = ERADIO(depths=[1, 3, 4, 5],
1344
+ num_heads=[2, 4, 8, 16],
1345
+ window_size=[None, None, [12, 12], 12],
1346
+ dim=32,
1347
+ in_dim=32,
1348
+ mlp_ratio=4,
1349
+ drop_path_rate=0.0,
1350
+ sr_ratio=[1, 1, [2, 1], 1],
1351
+ use_swiglu=False,
1352
+ downsample_shuffle=False,
1353
+ yolo_arch=True,
1354
+ shuffle_down=False,
1355
+ cpb_mlp_hidden=64,
1356
+ use_neck=True,
1357
+ full_features_head_dim=256,
1358
+ neck_start_stage=2,
1359
+ conv_groups_ratio = 1,
1360
+ **kwargs)
1361
+ if pretrained:
1362
+ model.load_state_dict(torch.load(pretrained)["state_dict"])
1363
+ return model
1364
+
1365
+
1366
+ @register_model
1367
+ def eradio_xxxtiny_8x_ws16(pretrained=False, **kwargs):
1368
+ model = ERADIO(depths=[1, 3, 4, 5],
1369
+ num_heads=[2, 4, 8, 16],
1370
+ window_size=[None, None, [16, 16], 16],
1371
+ dim=32,
1372
+ in_dim=32,
1373
+ mlp_ratio=4,
1374
+ drop_path_rate=0.0,
1375
+ sr_ratio=[1, 1, [2, 1], 1],
1376
+ use_swiglu=False,
1377
+ downsample_shuffle=False,
1378
+ yolo_arch=True,
1379
+ shuffle_down=False,
1380
+ cpb_mlp_hidden=64,
1381
+ use_neck=True,
1382
+ full_features_head_dim=256,
1383
+ neck_start_stage=1,
1384
+ conv_groups_ratio = 1,
1385
+ **kwargs)
1386
+ if pretrained:
1387
+ model.load_state_dict(torch.load(pretrained)["state_dict"])
1388
+ return model
1389
+
1390
+ @register_model
1391
+ def eradio(pretrained=False, **kwargs):
1392
+ return eradio_large_fullres_ws16(pretrained=pretrained, **kwargs)
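As a sanity check on the divisor-based rule above, the selection in `set_optimal_window_size` can be reproduced in isolation. The sketch below is illustrative only (the helper name is not part of the repo) and assumes the same stage-3 stride of 32:

```python
# Minimal sketch of the window-size rule used by ERADIO.set_optimal_window_size:
# take the largest divisor of (image_dim // 32) that does not exceed max_window_size.
def pick_window_size(image_dim: int, max_window_size: int = 16) -> int:
    n = image_dim // 32  # stage 3 downsamples by 16 and its attention subsamples by 2
    divisors = [d for d in range(1, n + 1) if n % d == 0]
    return max(d for d in divisors if d <= max_window_size)

for res in (224, 256, 384, 512):
    print(res, pick_window_size(res))  # -> 7, 8, 12, 16 (matches the table in the docstring)
```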
tim/models/nvidia_radio/radio/extra_models.py ADDED
@@ -0,0 +1,206 @@
1
+ from distutils.version import LooseVersion
2
+ from types import MethodType
3
+ from typing import List, Optional, Tuple, Union
4
+ import warnings
5
+
6
+ import torch
7
+ from torch import nn
8
+ import torch.nn.functional as F
9
+
10
+ from timm.models.registry import register_model
11
+ from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
12
+
13
+ from .forward_intermediates import forward_intermediates
14
+ from .input_conditioner import InputConditioner
15
+
16
+ _has_torch_sdpa = hasattr(F, 'scaled_dot_product_attention')
17
+
18
+
19
+ class PaliGemmaWrapper(nn.Module):
20
+ def __init__(self, vis_model: nn.Module, embed_dim: int):
21
+ super().__init__()
22
+
23
+ self.vis_model = vis_model
24
+ self.embed_dim = embed_dim
25
+
26
+ @property
27
+ def patch_size(self):
28
+ return self.vis_model.embeddings.patch_size
29
+
30
+ @property
31
+ def blocks(self):
32
+ return self.vis_model.encoder.layers
33
+
34
+ @property
35
+ def embed_dim(self):
36
+ return self.vis_model.embeddings.embed_dim
37
+
38
+ def forward(self, x: torch.Tensor):
39
+ outputs = self.vis_model(
40
+ x,
41
+ return_dict=False,
42
+ interpolate_pos_encoding=True,
43
+ )
44
+
45
+ features = outputs[0].to(torch.float32)
46
+
47
+ summary = features.mean(dim=1)
48
+
49
+ return summary, features
50
+
51
+ def forward_features(self, x: torch.Tensor):
52
+ return self(x)
53
+
54
+
55
+ def _get_paligemma_model(repo: str, embed_dim: int = None, dtype: torch.dtype = torch.bfloat16):
56
+ from transformers import PaliGemmaForConditionalGeneration, __version__ as tx_version
57
+
58
+ if LooseVersion(tx_version) > LooseVersion('4.44.2'):
59
+ warnings.warn(f'Your transformers version "{tx_version}" is newer than 4.44.2; PaliGemma support has not been validated beyond that version and may be broken.')
60
+
61
+ extra_args = dict()
62
+
63
+ if dtype is not None:
64
+ extra_args['torch_dtype'] = dtype
65
+ rev = str(dtype).split('.')[-1]
66
+ extra_args['revision'] = rev
67
+
68
+ model = PaliGemmaForConditionalGeneration.from_pretrained(repo, **extra_args)
69
+
70
+ vis_model = model.vision_tower.vision_model
71
+
72
+ vis_model = PaliGemmaWrapper(vis_model, embed_dim)
73
+
74
+ return vis_model
75
+
76
+ @register_model
77
+ def paligemma_896_student(**kwargs):
78
+ model = _get_paligemma_model('google/paligemma-3b-pt-896', embed_dim=1152, dtype=None)
79
+
80
+ return model
81
+
82
+
83
+ def dv2_sdpa(self, x: torch.Tensor) -> torch.Tensor:
84
+ B, N, C = x.shape
85
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
86
+
87
+ q, k, v = qkv[0], qkv[1], qkv[2]
88
+ x = F.scaled_dot_product_attention(
89
+ q, k, v,
90
+ is_causal=False,
91
+ dropout_p=self.attn_drop.p if self.training else 0.,
92
+ scale=self.scale,
93
+ )
94
+ x = x.transpose(1, 2).reshape(B, N, C)
95
+ x = self.proj(x)
96
+ x = self.proj_drop(x)
97
+ return x
98
+
99
+ def _load_dino_v2(dino_v2_model, cache_dir: Optional[str] = None, pretrained=True, **kwargs):
100
+ if cache_dir:
101
+ torch.hub.set_dir(cache_dir)
102
+ model: nn.Module = torch.hub.load(
103
+ 'facebookresearch/dinov2',
104
+ dino_v2_model,
105
+ pretrained=pretrained,
106
+ # **kwargs,
107
+ )
108
+
109
+ if _has_torch_sdpa:
110
+ for n, m in model.named_modules():
111
+ if n.endswith('.attn'):
112
+ m.forward = MethodType(dv2_sdpa, m)
113
+
114
+ return model
115
+
116
+ class DinoWrapper(nn.Module):
117
+ def __init__(self, dino_model: nn.Module):
118
+ super().__init__()
119
+
120
+ self.inner = dino_model
121
+ dino_model.blocks = nn.Sequential(*dino_model.blocks)
122
+
123
+ @property
124
+ def embed_dim(self):
125
+ return self.inner.embed_dim
126
+
127
+ @property
128
+ def patch_size(self):
129
+ return self.inner.patch_size
130
+
131
+ @property
132
+ def num_cls_tokens(self):
133
+ return getattr(self.inner, 'num_tokens', 1)
134
+
135
+ @property
136
+ def num_registers(self):
137
+ return getattr(self.inner, 'num_register_tokens', 0)
138
+
139
+ @property
140
+ def num_summary_tokens(self):
141
+ return self.num_cls_tokens + self.num_registers
142
+
143
+ @property
144
+ def blocks(self):
145
+ return self.inner.blocks
146
+
147
+ def forward(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
148
+ parts = self.inner.forward_features(*args, **kwargs)
149
+
150
+ cls_token = parts['x_norm_clstoken']
151
+ features = parts['x_norm_patchtokens']
152
+
153
+ return cls_token, features
154
+
155
+ def forward_features(self, x: torch.Tensor):
156
+ x = self.inner.prepare_tokens_with_masks(x)
157
+ x = self.inner.blocks(x)
158
+ x_norm = self.inner.norm(x)
159
+
160
+ return x_norm[:, 0], x_norm[:, self.num_summary_tokens:]
161
+
162
+ def patchify(self, x: torch.Tensor) -> torch.Tensor:
163
+ return self.inner.prepare_tokens_with_masks(x)
164
+
165
+ def forward_intermediates(self,
166
+ x: torch.Tensor,
167
+ norm: bool = False,
168
+ **kwargs,
169
+ ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
170
+ return forward_intermediates(
171
+ self,
172
+ patch_extractor=self.inner.prepare_tokens_with_masks,
173
+ num_summary_tokens=self.num_summary_tokens,
174
+ num_cls_tokens=self.num_cls_tokens,
175
+ norm=self.inner.norm if norm else lambda y: y,
176
+ x=x,
177
+ **kwargs,
178
+ )
179
+
180
+
181
+ def _dino_student(arch: str, **kwargs):
182
+ from . import dinov2_arch
183
+
184
+ factory = getattr(dinov2_arch, arch)
185
+ model = factory()
186
+
187
+ model = DinoWrapper(model)
188
+
189
+ conditioner = InputConditioner(
190
+ input_scale=1.0,
191
+ norm_mean=IMAGENET_DEFAULT_MEAN,
192
+ norm_std=IMAGENET_DEFAULT_STD,
193
+ )
194
+
195
+ model.input_conditioner = conditioner
196
+
197
+ return model
198
+
199
+
200
+ @register_model
201
+ def dino_v2_l_student(**kwargs):
202
+ return _dino_student('dinov2_vitl14_reg', **kwargs)
203
+
204
+ @register_model
205
+ def dino_v2_g_student(**kwargs):
206
+ return _dino_student('dinov2_vitg14_reg', **kwargs)
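A hedged usage sketch for the registrations above: because `@register_model` runs at import time, the module must be imported before `timm.create_model` can resolve the name. The `dinov2_vitl14_reg` factory is assumed to be provided by `dinov2_arch` (it is referenced, not shown, here), and weights are randomly initialized:

```python
import torch
import timm

# Importing the module registers dino_v2_l_student / dino_v2_g_student with timm.
from tim.models.nvidia_radio.radio import extra_models  # noqa: F401

model = timm.create_model("dino_v2_l_student")
model.eval()

x = torch.rand(1, 3, 224, 224)            # [0, 1] images; 224 is divisible by the patch size (14)
x = model.input_conditioner(x)            # ImageNet mean/std standardization attached above
with torch.no_grad():
    summary, features = model(x)          # summary: (B, C), features: (B, N, C)
```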
tim/models/nvidia_radio/radio/extra_timm_models.py ADDED
@@ -0,0 +1,206 @@
1
+ # Copyright (c) 2023-2024, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ import math
10
+ import warnings
11
+
12
+ import torch
13
+ from torch import nn
14
+ from torch.nn import functional as F
15
+
16
+ from timm.models import register_model
17
+ from timm.models.vision_transformer import (
18
+ VisionTransformer,
19
+ _create_vision_transformer as _timm_create_vision_transformer,
20
+ Mlp,
21
+ Block,
22
+ LayerScale as TIMMLayerScale,
23
+ )
24
+
25
+ # Import these to also register them
26
+ from . import dinov2_arch
27
+
28
+
29
+ @register_model
30
+ def vit_tiny_patch14_224(pretrained=False, **kwargs) -> VisionTransformer:
31
+ """ ViT-Tiny (Vit-Ti/16)
32
+ """
33
+ model_args = dict(patch_size=14, embed_dim=192, depth=12, num_heads=3)
34
+ model = _create_vision_transformer('vit_tiny_patch14_224', pretrained=pretrained, **dict(model_args, **kwargs))
35
+ return model
36
+
37
+
38
+ @register_model
39
+ def vit_small_patch14_224(pretrained=False, **kwargs) -> VisionTransformer:
40
+ """ ViT-Small (ViT-S/16)
41
+ """
42
+ model_args = dict(patch_size=14, embed_dim=384, depth=12, num_heads=6)
43
+ model = _create_vision_transformer('vit_small_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
44
+ return model
45
+
46
+
47
+ @register_model
48
+ def vit_base_patch14_224(pretrained=False, **kwargs) -> VisionTransformer:
49
+ """ ViT-Base (ViT-B/14) from original paper (https://arxiv.org/abs/2010.11929).
50
+ ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
51
+ """
52
+ model_args = dict(patch_size=14, embed_dim=768, depth=12, num_heads=12)
53
+ model = _create_vision_transformer('vit_base_patch14_224', pretrained=pretrained, **dict(model_args, **kwargs))
54
+ return model
55
+
56
+
57
+ @register_model
58
+ def vit_base_patch16_v2_224(pretrained=False, **kwargs) -> VisionTransformer:
59
+ """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
60
+ ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
61
+ """
62
+ model_args = dict(
63
+ patch_size=16, embed_dim=768, depth=12, num_heads=12, init_values=1e-5,
64
+ reg_tokens=4, no_embed_class=True, img_size=518 * 16 // 14
65
+ )
66
+ model = _create_vision_transformer(
67
+ 'vit_base_patch14_reg4_dinov2', pretrained=pretrained, **dict(model_args, **kwargs))
68
+ return model
69
+
70
+
71
+ @register_model
72
+ def vit_large_patch16_v2_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
73
+ """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
74
+ ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
75
+ """
76
+ name = 'vit_large_patch14_reg4_dinov2'
77
+ model_args = dict(
78
+ patch_size=16, embed_dim=1024, depth=24, num_heads=16, init_values=1e-5,
79
+ reg_tokens=4, no_embed_class=True, img_size=518 * 16 // 14
80
+ )
81
+ model = _create_vision_transformer(name, pretrained=pretrained, **dict(model_args, **kwargs))
82
+
83
+ return model
84
+
85
+ @register_model
86
+ def vit_huge_patch16_224(pretrained=False, **kwargs) -> VisionTransformer:
87
+ """ ViT-Huge model (ViT-H/16) from original paper (https://arxiv.org/abs/2010.11929).
88
+ """
89
+ model_args = dict(patch_size=16, embed_dim=1280, depth=32, num_heads=16)
90
+ if pretrained:
91
+ # There is no pretrained version of ViT-H/16, but we can adapt a ViT-H/14 for this purpose
92
+ model = _create_vision_transformer('vit_huge_patch14_224', pretrained=True, **dict(model_args, **kwargs))
93
+ else:
94
+ model = _create_vision_transformer('vit_huge_patch16_224', pretrained=False, **dict(model_args, **kwargs))
95
+ return model
96
+
97
+
98
+ @register_model
99
+ def vit_huge_patch16_224_mlpnorm(pretrained=False, **kwargs) -> VisionTransformer:
100
+ """ ViT-Huge model (ViT-H/16) from original paper (https://arxiv.org/abs/2010.11929).
101
+ """
102
+ model = vit_huge_patch16_224(pretrained=pretrained, **kwargs)
103
+
104
+ for m in model.modules():
105
+ if isinstance(m, Mlp) and not isinstance(m.norm, nn.LayerNorm):
106
+ m.norm = nn.LayerNorm(m.fc1.out_features)
107
+
108
+ return model
109
+
110
+
111
+ @register_model
112
+ def vit_giant_patch16_224(pretrained=False, scaled_ln: bool = False, **kwargs) -> VisionTransformer:
113
+ """ ViT-giant model (ViT-g/16) from original paper (https://arxiv.org/abs/2010.11929).
114
+ """
115
+ model_args = dict(patch_size=16, embed_dim=1536, depth=40, num_heads=24)
116
+ model = _create_vision_transformer('vit_giant_patch16_224', pretrained=False, **dict(model_args, **kwargs))
117
+ if scaled_ln:
118
+ _apply_scaled_ln(model)
119
+ return model
120
+
121
+
122
+ @register_model
123
+ def vit_bigG_patch14_224(pretrained=False, **kwargs) -> VisionTransformer:
124
+ model_args = dict(patch_size=14, embed_dim=1664, depth=48, num_heads=16, init_values=1e-6)
125
+ model = _create_vision_transformer('vit_bigG_patch14', pretrained=False, **dict(model_args, **kwargs))
126
+ return model
127
+
128
+
129
+ def _create_vision_transformer(*args, **kwargs):
130
+ model = _timm_create_vision_transformer(*args, **kwargs)
131
+ _patch_layer_scale(model)
132
+ return model
133
+
134
+
135
+ def _patch_layer_scale(model: VisionTransformer):
136
+ def replace_ls(old_ls: TIMMLayerScale):
137
+ new_ls = dinov2_arch.LayerScale(old_ls.gamma.shape[0], inplace=old_ls.inplace)
138
+ new_ls.load_state_dict(old_ls.state_dict())
139
+ return new_ls
140
+
141
+ # Monkey patch: Replace TIMM's LayerScale with our modified DINOv2 one, that uses a param name
142
+ # other than gamma, so that HFHub doesn't mess with it!
143
+ for mod in model.modules():
144
+ if isinstance(mod, Block):
145
+ if isinstance(mod.ls1, TIMMLayerScale):
146
+ mod.ls1 = replace_ls(mod.ls1)
147
+ if isinstance(mod.ls2, TIMMLayerScale):
148
+ mod.ls2 = replace_ls(mod.ls2)
149
+ pass
150
+
151
+
152
+ class ScaledLayerNorm(nn.LayerNorm):
153
+ '''
154
+ https://arxiv.org/pdf/2502.05795v1
155
+ '''
156
+ def __init__(self, ln_base: nn.LayerNorm, depth: int = 0):
157
+ super().__init__(ln_base.normalized_shape, eps=ln_base.eps, elementwise_affine=ln_base.elementwise_affine)
158
+ self.load_state_dict(ln_base.state_dict())
159
+ self.register_buffer('ln_scale', torch.tensor(1.0 / math.sqrt(depth)), persistent=False)
160
+
161
+ def forward(self, x):
162
+ y = super().forward(x)
163
+ y = y * self.ln_scale
164
+ return y
165
+
166
+
167
+ class DyT(nn.Module):
168
+ def __init__(self, C: int, init_alpha: float):
169
+ super().__init__()
170
+ self.alpha = nn.Parameter(torch.full((1,), init_alpha))
171
+ self.gamma = nn.Parameter(torch.ones(C))
172
+ self.beta = nn.Parameter(torch.zeros(C))
173
+
174
+ def forward(self, x: torch.Tensor):
175
+ x = F.tanh(self.alpha * x)
176
+ return self.gamma * x + self.beta
177
+
178
+ @register_model
179
+ def vit_large_dyt_patch16_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
180
+ """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
181
+ ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
182
+ """
183
+ model_args = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16)
184
+ model = _create_vision_transformer('vit_large_dyt_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
185
+
186
+ def _replace_ln_with_dyt(ln: nn.LayerNorm, depth: int):
187
+ return DyT(ln.normalized_shape[0], init_alpha=0.9)
188
+ _replace_ln(model, _replace_ln_with_dyt)
189
+
190
+ return model
191
+
192
+
193
+ def _apply_scaled_ln(model: VisionTransformer):
194
+ warnings.warn('Post-LayerNorm scaling activated!')
195
+
196
+ _replace_ln(model, lambda ln, depth: ScaledLayerNorm(ln, depth=depth))
197
+
198
+ def _replace_ln(model: VisionTransformer, fn):
199
+ def _inner_replace_ln(block: Block, depth: int, key: str):
200
+ prev = getattr(block, key)
201
+ if isinstance(prev, nn.LayerNorm):
202
+ setattr(block, key, fn(prev, depth=depth))
203
+
204
+ for i, block in enumerate(model.blocks):
205
+ _inner_replace_ln(block, i + 1, 'norm1')
206
+ _inner_replace_ln(block, i + 1, 'norm2')
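For clarity on the `DyT` module defined above: it is a tanh-based, element-wise substitute for LayerNorm, so it preserves input shape. A small illustrative check (assuming the repo is importable as the `tim` package):

```python
import torch
from tim.models.nvidia_radio.radio.extra_timm_models import DyT

# y = gamma * tanh(alpha * x) + beta, applied element-wise over the channel dimension.
dyt = DyT(C=768, init_alpha=0.9)
x = torch.randn(2, 196, 768)
y = dyt(x)
assert y.shape == x.shape
```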
tim/models/nvidia_radio/radio/feature_normalizer.py ADDED
@@ -0,0 +1,111 @@
1
+ # Copyright (c) 2023-2024, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+ from collections import namedtuple
9
+ from typing import NamedTuple, Optional, Tuple
10
+ import torch
11
+ from torch import nn
12
+
13
+
14
+ def _run_kernel(x: torch.Tensor, mean: torch.Tensor, tx: torch.Tensor):
15
+ if x.ndim <= 3:
16
+ x = x - mean
17
+ x = x @ tx.T
18
+ elif x.ndim == 4:
19
+ x = x - mean.reshape(1, -1, 1, 1)
20
+ kernel = tx.reshape(*tx.shape, 1, 1)
21
+ x = torch.nn.functional.conv2d(x, weight=kernel, bias=None, stride=1, padding=0)
22
+ else:
23
+ raise ValueError(f'Unsupported input dimension: {x.ndim}, shape: {x.shape}')
24
+ return x
25
+
26
+
27
+ class FeatureNormalizer(nn.Module):
28
+ def __init__(self, embed_dim: int, dtype: torch.dtype = torch.float32):
29
+ super().__init__()
30
+
31
+ self.register_buffer('mean', torch.zeros(embed_dim, dtype=dtype))
32
+ self.register_buffer('tx', torch.eye(embed_dim, dtype=dtype))
33
+
34
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
35
+ x = _run_kernel(x, self.mean, self.tx)
36
+ return x
37
+
38
+
39
+ class InterFeatState(NamedTuple):
40
+ y: torch.Tensor
41
+ alpha: torch.Tensor
42
+
43
+
44
+ class IntermediateFeatureNormalizerBase(nn.Module):
45
+ def forward(self, x: torch.Tensor, index: int, rot_index: int = None, skip: Optional[int] = None) -> InterFeatState:
46
+ raise NotImplementedError()
47
+
48
+
49
+ class IntermediateFeatureNormalizer(IntermediateFeatureNormalizerBase):
50
+ def __init__(self, num_intermediates: int, embed_dim: int, rot_per_layer: bool = False, dtype: torch.dtype = torch.float32):
51
+ super().__init__()
52
+ self.register_buffer('alphas', torch.ones(num_intermediates, dtype=dtype))
53
+
54
+ rot = torch.eye(embed_dim, dtype=dtype)
55
+ if rot_per_layer:
56
+ rot = rot.unsqueeze(0).repeat(num_intermediates, 1, 1)
57
+
58
+ self.register_buffer('rotation', rot.contiguous())
59
+ self.register_buffer('means', torch.zeros(num_intermediates, embed_dim, dtype=dtype))
60
+
61
+ def forward(self, x: torch.Tensor, index: int, rot_index: int = None, skip: Optional[int] = None) -> InterFeatState:
62
+ if rot_index is None:
63
+ rot_index = index
64
+
65
+ if skip:
66
+ assert x.ndim == 3, f'Cannot use the `skip` parameter when the `x` tensor isn\'t 3-dimensional.'
67
+ prefix, x = x[:, :skip], x[:, skip:]
68
+
69
+ rotation = self._get_rotation(rot_index)
70
+ y = _run_kernel(x, self.means[index], rotation)
71
+
72
+ alpha = self.alphas[index]
73
+ if skip:
74
+ alpha = torch.cat([
75
+ torch.ones(skip, dtype=alpha.dtype, device=alpha.device),
76
+ alpha[None].expand(y.shape[1]),
77
+ ]).reshape(1, -1, 1)
78
+ y = torch.cat([prefix, y], dim=1)
79
+ else:
80
+ if x.ndim == 3:
81
+ alpha = alpha.reshape(1, 1, 1).expand(1, y.shape[1], 1)
82
+ elif x.ndim == 4:
83
+ alpha = alpha.reshape(1, 1, 1, 1).expand(1, 1, *y.shape[2:])
84
+ else:
85
+ raise ValueError(f'Unsupported input dimension: {x.ndim}')
86
+
87
+ return InterFeatState(y, alpha)
88
+
89
+ def _get_rotation(self, rot_index: int) -> torch.Tensor:
90
+ if self.rotation.ndim == 2:
91
+ return self.rotation
92
+ return self.rotation[rot_index]
93
+
94
+
95
+ class NullIntermediateFeatureNormalizer(IntermediateFeatureNormalizerBase):
96
+ instances = dict()
97
+
98
+ def __init__(self, dtype: torch.dtype, device: torch.device):
99
+ super().__init__()
100
+ self.register_buffer('alpha', torch.tensor(1, dtype=dtype, device=device))
101
+
102
+ @staticmethod
103
+ def get_instance(dtype: torch.dtype, device: torch.device):
104
+ instance = NullIntermediateFeatureNormalizer.instances.get((dtype, device), None)
105
+ if instance is None:
106
+ instance = NullIntermediateFeatureNormalizer(dtype, device)
107
+ NullIntermediateFeatureNormalizer.instances[(dtype, device)] = instance
108
+ return instance
109
+
110
+ def forward(self, x: torch.Tensor, index: int, rot_index: int = None, skip: Optional[int] = None) -> InterFeatState:
111
+ return InterFeatState(x, self.alpha)
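Note that `FeatureNormalizer` is constructed with a zero mean and an identity transform, so it is a no-op until checkpoint statistics are loaded; it accepts both NLC and NCHW inputs. A short illustrative check (assuming the repo is importable as the `tim` package):

```python
import torch
from tim.models.nvidia_radio.radio.feature_normalizer import FeatureNormalizer

norm = FeatureNormalizer(embed_dim=64)
tokens = torch.randn(2, 196, 64)        # NLC: (B, N, C) goes through the matmul path
fmap = torch.randn(2, 64, 14, 14)       # NCHW: (B, C, H, W) goes through the 1x1 conv path
assert torch.allclose(norm(tokens), tokens, atol=1e-5)
assert torch.allclose(norm(fmap), fmap, atol=1e-5)
```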
tim/models/nvidia_radio/radio/forward_intermediates.py ADDED
@@ -0,0 +1,138 @@
1
+ # Copyright (c) 2023-2024, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ from typing import Callable, Dict, List, Optional, Set, Tuple, Union, Any, Iterable
10
+ from types import MethodType
11
+
12
+ import torch
13
+ from torch import nn
14
+
15
+ from .feature_normalizer import IntermediateFeatureNormalizerBase, NullIntermediateFeatureNormalizer
16
+
17
+
18
+ def _take_indices(
19
+ num_blocks: int,
20
+ n: Optional[Union[int, List[int], Tuple[int]]],
21
+ ) -> Tuple[Set[int], int]:
22
+ if isinstance(n, int):
23
+ assert n >= 0
24
+ take_indices = {x for x in range(num_blocks - n, num_blocks)}
25
+ else:
26
+ take_indices = {num_blocks + idx if idx < 0 else idx for idx in n}
27
+ return take_indices, max(take_indices)
28
+
29
+
30
+ def forward_intermediates(
31
+ model: nn.Module,
32
+ patch_extractor: Callable[[torch.Tensor], torch.Tensor],
33
+ norm: nn.Module,
34
+ num_summary_tokens: int,
35
+ num_cls_tokens: int,
36
+ x: torch.Tensor,
37
+ indices: Optional[Union[int, List[int], Tuple[int]]] = None,
38
+ return_prefix_tokens: bool = False,
39
+ stop_early: bool = False,
40
+ output_fmt: str = 'NCHW',
41
+ intermediates_only: bool = False,
42
+ aggregation: Optional[str] = "sparse",
43
+ inter_feature_normalizer: Optional[IntermediateFeatureNormalizerBase] = None,
44
+ norm_alpha_scheme = "post-alpha",
45
+ block_kwargs: Dict = None,
46
+ ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
47
+ """ Forward features that returns intermediates.
48
+
49
+ The Dense layer aggregation method is inspired from the paper: "Dense Connector for MLLMs"
50
+ by Yao, Huanjin et al. (2024), arXiv preprint arXiv:2405.13800.
51
+
52
+ Args:
53
+ x: Input image tensor
54
+ indices: Take last n blocks if int, select matching indices if sequence
55
+ return_prefix_tokens: Return both prefix and spatial intermediate tokens
56
+ norm: Apply norm layer to all intermediates
57
+ stop_early: Stop iterating over blocks when last desired intermediate hit
58
+ output_fmt: Shape of intermediate feature outputs
59
+ intermediates_only: Only return intermediate features
60
+ aggregation: intermediate layer aggregation method (sparse or dense)
61
+ norm_alpha_scheme: apply alpha before ("pre-alpha") or after accumulation ("post-alpha")
62
+ Returns:
63
+ """
64
+ assert output_fmt in ('NCHW', 'NLC'), 'Output format must be one of NCHW or NLC.'
65
+ assert aggregation in ('sparse', 'dense'), 'Aggregation must be one of sparse or dense.'
66
+ reshape = output_fmt == 'NCHW'
67
+ intermediates = []
68
+
69
+ block_kwargs = block_kwargs or dict()
70
+
71
+ blocks = model.blocks
72
+
73
+ take_indices, max_index = _take_indices(len(blocks), indices)
74
+ take_indices = sorted(take_indices)
75
+ # forward pass
76
+ B, _, height, width = x.shape
77
+
78
+ x = patch_extractor(x)
79
+
80
+ if stop_early:
81
+ blocks = blocks[:max_index + 1]
82
+
83
+ if inter_feature_normalizer is None or norm_alpha_scheme == 'none':
84
+ inter_feature_normalizer = NullIntermediateFeatureNormalizer.get_instance(x.dtype, x.device)
85
+
86
+ assert norm_alpha_scheme in ('none', 'pre-alpha', 'post-alpha'), f'Unsupported alpha scheme: {norm_alpha_scheme}'
87
+ post_alpha_scheme = norm_alpha_scheme == 'post-alpha'
88
+
89
+ accumulator = 0
90
+ alpha_sum = 0
91
+ num_accumulated = 0
92
+
93
+ take_off = 0
94
+
95
+ for i, blk in enumerate(blocks):
96
+ x = blk(x, **block_kwargs)
97
+ if aggregation == "dense":
98
+ # Arbitrarily use the rotation matrix from the final layer in the dense group
99
+ y, alpha = inter_feature_normalizer(x, i, rot_index=take_indices[take_off], skip=num_summary_tokens)
100
+ if post_alpha_scheme:
101
+ accumulator = accumulator + y
102
+ alpha_sum = alpha_sum + alpha
103
+ else:
104
+ accumulator = accumulator + (alpha * y)
105
+ alpha_sum += 1
106
+ num_accumulated += 1
107
+ if i == take_indices[take_off]:
108
+ if aggregation == "dense":
109
+ alpha = alpha_sum / num_accumulated
110
+ x_ = alpha * accumulator / num_accumulated
111
+ num_accumulated = 0
112
+ accumulator = 0
113
+ alpha_sum = 0
114
+ else:
115
+ y, alpha = inter_feature_normalizer(x, i, skip=num_summary_tokens)
116
+ x_ = alpha * y
117
+ # normalize intermediates with final norm layer if enabled
118
+ intermediates.append(norm(x_))
119
+ take_off = min(take_off + 1, len(take_indices) - 1)
120
+
121
+ # process intermediates
122
+
123
+ # split prefix (e.g. class, distill) and spatial feature tokens
124
+ prefix_tokens = [y[:, :num_cls_tokens] for y in intermediates]
125
+ intermediates = [y[:, num_summary_tokens:] for y in intermediates]
126
+
127
+ if reshape:
128
+ # reshape to BCHW output format
129
+ H = height // model.patch_size
130
+ W = width // model.patch_size
131
+ intermediates = [y.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for y in intermediates]
132
+ if not torch.jit.is_scripting() and return_prefix_tokens:
133
+ # return_prefix not support in torchscript due to poor type handling
134
+ intermediates = list(zip(prefix_tokens, intermediates))
135
+ if intermediates_only:
136
+ return intermediates
137
+ x = norm(x)
138
+ return x, intermediates
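The `indices` argument accepts either an integer (take the last n blocks) or an explicit, possibly negative, list of block indices; `_take_indices` resolves both forms. A quick illustration (assuming the repo is importable as the `tim` package):

```python
from tim.models.nvidia_radio.radio.forward_intermediates import _take_indices

print(_take_indices(12, 4))        # ({8, 9, 10, 11}, 11): the last four of twelve blocks
print(_take_indices(12, [2, -1]))  # ({2, 11}, 11): explicit indices, negatives wrap around
```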
tim/models/nvidia_radio/radio/hf_model.py ADDED
@@ -0,0 +1,202 @@
1
+ # Copyright (c) 2023-2024, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from collections import namedtuple
15
+ from typing import Callable, Dict, Optional, List, Union
16
+
17
+ from timm.models import VisionTransformer
18
+ import torch
19
+ from torch import nn
20
+ from transformers import PretrainedConfig, PreTrainedModel
21
+
22
+
23
+ from .common import RESOURCE_MAP, DEFAULT_VERSION
24
+
25
+ # Import all required modules.
26
+ from .adaptor_base import AdaptorBase, RadioOutput, AdaptorInput
27
+ from .adaptor_generic import GenericAdaptor, AdaptorBase
28
+ from .adaptor_mlp import create_mlp_from_config
29
+ from .adaptor_registry import adaptor_registry
30
+ from .cls_token import ClsToken
31
+ from .dinov2_arch import dinov2_vitg14_reg
32
+ from .enable_cpe_support import enable_cpe
33
+ from .enable_spectral_reparam import configure_spectral_reparam_from_args
34
+ from .eradio_model import eradio
35
+ from .feature_normalizer import FeatureNormalizer, IntermediateFeatureNormalizer
36
+ from .forward_intermediates import forward_intermediates
37
+ from .radio_model import create_model_from_args
38
+ from .radio_model import RADIOModel as RADIOModelBase, Resolution
39
+ from .input_conditioner import get_default_conditioner, InputConditioner
40
+ from .open_clip_adaptor import OpenCLIP_RADIO
41
+ from .vit_patch_generator import ViTPatchGenerator
42
+ from .vitdet import apply_vitdet_arch, VitDetArgs
43
+
44
+ # Register extra models
45
+ from .extra_timm_models import *
46
+ from .extra_models import *
47
+
48
+
49
+ class RADIOConfig(PretrainedConfig):
50
+ """Pretrained Hugging Face configuration for RADIO models."""
51
+
52
+ def __init__(
53
+ self,
54
+ args: Optional[dict] = None,
55
+ version: Optional[str] = DEFAULT_VERSION,
56
+ patch_size: Optional[int] = None,
57
+ max_resolution: Optional[int] = None,
58
+ preferred_resolution: Optional[Resolution] = None,
59
+ adaptor_names: Union[str, List[str]] = None,
60
+ adaptor_configs: Dict[str, Dict[str, int]] = None,
61
+ vitdet_window_size: Optional[int] = None,
62
+ feature_normalizer_config: Optional[dict] = None,
63
+ inter_feature_normalizer_config: Optional[dict] = None,
64
+ **kwargs,
65
+ ):
66
+ self.args = args
67
+ for field in ["dtype", "amp_dtype"]:
68
+ if self.args is not None and field in self.args:
69
+ # Convert to a string in order to make it serializable.
70
+ # For example for torch.float32 we will store "float32",
71
+ # for "bfloat16" we will store "bfloat16".
72
+ self.args[field] = str(args[field]).split(".")[-1]
73
+ self.version = version
74
+ resource = RESOURCE_MAP[version]
75
+ self.patch_size = patch_size or resource.patch_size
76
+ self.max_resolution = max_resolution or resource.max_resolution
77
+ self.preferred_resolution = (
78
+ preferred_resolution or resource.preferred_resolution
79
+ )
80
+ self.adaptor_names = adaptor_names
81
+ self.adaptor_configs = adaptor_configs
82
+ self.vitdet_window_size = vitdet_window_size
83
+ self.feature_normalizer_config = feature_normalizer_config
84
+ self.inter_feature_normalizer_config = inter_feature_normalizer_config
85
+ super().__init__(**kwargs)
86
+
87
+
88
+
89
+ class RADIOModel(PreTrainedModel):
90
+ """Pretrained Hugging Face model for RADIO.
91
+
92
+ This class inherits from PreTrainedModel, which provides
93
+ HuggingFace's functionality for loading and saving models.
94
+ """
95
+
96
+ config_class = RADIOConfig
97
+
98
+ def __init__(self, config: RADIOConfig):
99
+ super().__init__(config)
100
+
101
+ RADIOArgs = namedtuple("RADIOArgs", config.args.keys())
102
+ args = RADIOArgs(**config.args)
103
+ self.config = config
104
+
105
+ model = create_model_from_args(args)
106
+ input_conditioner: InputConditioner = get_default_conditioner()
107
+
108
+ dtype = getattr(args, "dtype", torch.float32)
109
+ if isinstance(dtype, str):
110
+ # Convert the dtype's string representation back to a dtype.
111
+ dtype = getattr(torch, dtype)
112
+ model.to(dtype=dtype)
113
+ input_conditioner.dtype = dtype
114
+
115
+ summary_idxs = torch.tensor(
116
+ [i for i, t in enumerate(args.teachers) if t.get("use_summary", True)],
117
+ dtype=torch.int64,
118
+ )
119
+
120
+ adaptor_configs = config.adaptor_configs
121
+ adaptor_names = config.adaptor_names or []
122
+
123
+ adaptors = dict()
124
+ for adaptor_name in adaptor_names:
125
+ mlp_config = adaptor_configs[adaptor_name]
126
+ adaptor = GenericAdaptor(args, None, None, mlp_config)
127
+ adaptor.head_idx = mlp_config["head_idx"]
128
+ adaptors[adaptor_name] = adaptor
129
+
130
+ feature_normalizer = None
131
+ if config.feature_normalizer_config is not None:
132
+ # Actual normalization values will be restored when loading checkpoint weights.
133
+ feature_normalizer = FeatureNormalizer(config.feature_normalizer_config["embed_dim"])
134
+
135
+ inter_feature_normalizer = None
136
+ if config.inter_feature_normalizer_config is not None:
137
+ inter_feature_normalizer = IntermediateFeatureNormalizer(
138
+ config.inter_feature_normalizer_config["num_intermediates"],
139
+ config.inter_feature_normalizer_config["embed_dim"],
140
+ rot_per_layer=config.inter_feature_normalizer_config["rot_per_layer"],
141
+ dtype=dtype)
142
+
143
+ self.radio_model = RADIOModelBase(
144
+ model,
145
+ input_conditioner,
146
+ summary_idxs=summary_idxs,
147
+ patch_size=config.patch_size,
148
+ max_resolution=config.max_resolution,
149
+ window_size=config.vitdet_window_size,
150
+ preferred_resolution=config.preferred_resolution,
151
+ adaptors=adaptors,
152
+ feature_normalizer=feature_normalizer,
153
+ inter_feature_normalizer=inter_feature_normalizer,
154
+ )
155
+
156
+ @property
157
+ def adaptors(self) -> nn.ModuleDict:
158
+ return self.radio_model.adaptors
159
+
160
+ @property
161
+ def model(self) -> VisionTransformer:
162
+ return self.radio_model.model
163
+
164
+ @property
165
+ def input_conditioner(self) -> InputConditioner:
166
+ return self.radio_model.input_conditioner
167
+
168
+ @property
169
+ def num_summary_tokens(self) -> int:
170
+ return self.radio_model.num_summary_tokens
171
+
172
+ @property
173
+ def patch_size(self) -> int:
174
+ return self.radio_model.patch_size
175
+
176
+ @property
177
+ def max_resolution(self) -> int:
178
+ return self.radio_model.max_resolution
179
+
180
+ @property
181
+ def preferred_resolution(self) -> Resolution:
182
+ return self.radio_model.preferred_resolution
183
+
184
+ @property
185
+ def window_size(self) -> int:
186
+ return self.radio_model.window_size
187
+
188
+ @property
189
+ def min_resolution_step(self) -> int:
190
+ return self.radio_model.min_resolution_step
191
+
192
+ def make_preprocessor_external(self) -> Callable[[torch.Tensor], torch.Tensor]:
193
+ return self.radio_model.make_preprocessor_external()
194
+
195
+ def get_nearest_supported_resolution(self, height: int, width: int) -> Resolution:
196
+ return self.radio_model.get_nearest_supported_resolution(height, width)
197
+
198
+ def switch_to_deploy(self):
199
+ return self.radio_model.switch_to_deploy()
200
+
201
+ def forward(self, x: torch.Tensor):
202
+ return self.radio_model.forward(x)
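For context, this config/model pair is what Hugging Face `trust_remote_code` loading resolves to. The repo id below is an assumption; substitute whichever RADIO checkpoint you actually use, and note the input resolution must be a multiple of `model.min_resolution_step`:

```python
import torch
from transformers import AutoModel

# Hypothetical checkpoint id; any repo shipping this RADIOConfig/RADIOModel code loads the same way.
model = AutoModel.from_pretrained("nvidia/RADIO", trust_remote_code=True)
model.eval()

x = torch.rand(1, 3, 512, 512)   # dynamic range [0, 1]
with torch.no_grad():
    out = model(x)               # RadioOutput(summary, features), or a dict keyed by adaptor name
```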
tim/models/nvidia_radio/radio/input_conditioner.py ADDED
@@ -0,0 +1,49 @@
1
+ # Copyright (c) 2023-2024, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ from typing import Union, Tuple
10
+
11
+ import torch
12
+ from torch import nn
13
+
14
+
15
+ norm_t = Union[Tuple[float, float, float], torch.Tensor]
16
+
17
+ class InputConditioner(nn.Module):
18
+ def __init__(self,
19
+ input_scale: float,
20
+ norm_mean: norm_t,
21
+ norm_std: norm_t,
22
+ dtype: torch.dtype = None,
23
+ ):
24
+ super().__init__()
25
+
26
+ self.dtype = dtype
27
+
28
+ self.register_buffer("norm_mean", _to_tensor(norm_mean) / input_scale)
29
+ self.register_buffer("norm_std", _to_tensor(norm_std) / input_scale)
30
+
31
+ def forward(self, x: torch.Tensor):
32
+ y = (x - self.norm_mean) / self.norm_std
33
+ if self.dtype is not None:
34
+ y = y.to(self.dtype)
35
+ return y
36
+
37
+
38
+ def get_default_conditioner():
39
+ from timm.data.constants import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
40
+
41
+ return InputConditioner(
42
+ input_scale=1.0,
43
+ norm_mean=OPENAI_CLIP_MEAN,
44
+ norm_std=OPENAI_CLIP_STD,
45
+ )
46
+
47
+
48
+ def _to_tensor(v: norm_t):
49
+ return torch.as_tensor(v, dtype=torch.float32).view(-1, 1, 1)
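The default conditioner simply standardizes `[0, 1]` images with the OpenAI CLIP mean/std, broadcast over the channel dimension. A minimal illustration (assuming the repo is importable as the `tim` package):

```python
import torch
from tim.models.nvidia_radio.radio.input_conditioner import get_default_conditioner

cond = get_default_conditioner()
x = torch.rand(2, 3, 224, 224)    # values in [0, 1]
y = cond(x)                       # roughly zero-mean, unit-variance per channel
print(y.mean().item(), y.std().item())
```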
tim/models/nvidia_radio/radio/open_clip_adaptor.py ADDED
@@ -0,0 +1,41 @@
1
+ # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+ from argparse import Namespace
9
+
10
+ import torch
11
+ from torch import nn
12
+ import torch.nn.functional as F
13
+
14
+ from .adaptor_registry import adaptor_registry, dict_t, state_t
15
+
16
+ from .adaptor_generic import GenericAdaptor
17
+
18
+
19
+ class OpenCLIP_RADIO(GenericAdaptor):
20
+ def __init__(self, main_config: Namespace, adaptor_config: dict_t, state: state_t):
21
+ super().__init__(main_config, adaptor_config, state)
22
+
23
+ import open_clip
24
+
25
+ self.oc_model = open_clip.create_model_from_pretrained(
26
+ model_name=adaptor_config['model'],
27
+ pretrained=adaptor_config['pretrained'],
28
+ return_transform=False,
29
+ )
30
+ # Unload these parameters
31
+ self.oc_model.visual = None
32
+
33
+ self.tokenizer = open_clip.get_tokenizer(model_name=adaptor_config['model'])
34
+
35
+ def encode_text(self, text, normalize: bool = False):
36
+ return self.oc_model.encode_text(text, normalize=normalize)
37
+
38
+
39
+ @adaptor_registry.register_adaptor("open_clip")
40
+ def create_open_clip_adaptor(main_config: Namespace, adaptor_config: dict_t, state: state_t):
41
+ return OpenCLIP_RADIO(main_config, adaptor_config, state)
tim/models/nvidia_radio/radio/radio_model.py ADDED
@@ -0,0 +1,375 @@
1
+ # Copyright (c) 2023-2024, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+ from typing import Callable, Dict, Iterable, List, NamedTuple, Optional, Tuple, Union
9
+
10
+ import torch
11
+ from torch import nn
12
+
13
+ from timm.models import create_model, VisionTransformer
14
+ from types import MethodType
15
+
16
+ from .enable_cpe_support import enable_cpe
17
+ from .input_conditioner import InputConditioner
18
+ from .adaptor_base import AdaptorBase, RadioOutput, AdaptorInput
19
+ from . import eradio_model
20
+ from .enable_spectral_reparam import configure_spectral_reparam_from_args
21
+ from .feature_normalizer import FeatureNormalizer, IntermediateFeatureNormalizer
22
+ from . import dual_hybrid_vit
23
+
24
+
25
+ class Resolution(NamedTuple):
26
+ height: int
27
+ width: int
28
+
29
+
30
+ class RADIOModel(nn.Module):
31
+ def __init__(
32
+ self,
33
+ model: nn.Module,
34
+ input_conditioner: InputConditioner,
35
+ patch_size: int,
36
+ max_resolution: int,
37
+ preferred_resolution: Resolution,
38
+ summary_idxs: Optional[torch.Tensor] = None,
39
+ window_size: int = None,
40
+ adaptors: Dict[str, AdaptorBase] = None,
41
+ feature_normalizer: Optional[FeatureNormalizer] = None,
42
+ inter_feature_normalizer: Optional[IntermediateFeatureNormalizer] = None,
43
+ ):
44
+ super().__init__()
45
+
46
+ self.model = model
47
+ self.input_conditioner = input_conditioner
48
+ if summary_idxs is not None:
49
+ self.register_buffer('summary_idxs', summary_idxs)
50
+ else:
51
+ self.summary_idxs = None
52
+
53
+ self._preferred_resolution = preferred_resolution
54
+ self._patch_size = patch_size
55
+ self._max_resolution = max_resolution
56
+ self._window_size = window_size
57
+
58
+ adaptors = adaptors or dict()
59
+ self.adaptors = nn.ModuleDict(adaptors)
60
+
61
+ if feature_normalizer is None:
62
+ feature_normalizer = nn.Identity()
63
+ self.feature_normalizer = feature_normalizer
64
+ self.inter_feature_normalizer = inter_feature_normalizer
65
+
66
+ @property
67
+ def num_summary_tokens(self) -> int:
68
+ if hasattr(self.model, 'num_summary_tokens'):
69
+ return self.model.num_summary_tokens
70
+
71
+ patch_gen = getattr(self.model, "patch_generator", None)
72
+ if patch_gen is not None:
73
+ return patch_gen.num_skip
74
+ elif getattr(self.model, 'global_pool', None) == 'avg':
75
+ return 0
76
+ return 1
77
+
78
+ @property
79
+ def num_cls_tokens(self) -> int:
80
+ if hasattr(self.model, 'num_cls_tokens'):
81
+ return self.model.num_cls_tokens
82
+
83
+ patch_gen = getattr(self.model, 'patch_generator', None)
84
+ if patch_gen is not None:
85
+ return patch_gen.num_cls_tokens
86
+ elif getattr(self.model, 'global_pool', None) == 'avg':
87
+ return 0
88
+ return 1
89
+
90
+ @property
91
+ def patch_size(self) -> int:
92
+ if self._patch_size is not None:
93
+ return self._patch_size
94
+ if hasattr(self.model, "patch_size"):
95
+ return self.model.patch_size
96
+ patch_gen = getattr(self.model, "patch_generator", None)
97
+ if patch_gen is not None:
98
+ return patch_gen.patch_size
99
+ return None
100
+
101
+ @property
102
+ def max_resolution(self) -> int:
103
+ return self._max_resolution
104
+
105
+ @property
106
+ def preferred_resolution(self) -> Resolution:
107
+ return self._preferred_resolution
108
+
109
+ @property
110
+ def window_size(self) -> int:
111
+ return self._window_size
112
+
113
+ @property
114
+ def min_resolution_step(self) -> int:
115
+ res = self.patch_size
116
+ if self.window_size is not None:
117
+ res *= self.window_size
118
+ return res
119
+
120
+ @property
121
+ def blocks(self) -> Iterable[nn.Module]:
122
+ blocks = getattr(self.model, 'blocks', None)
123
+ if blocks is not None:
124
+ return blocks
125
+ return None
126
+
127
+ @property
128
+ def embed_dim(self) -> int:
129
+ return self.model.embed_dim
130
+
131
+ def make_preprocessor_external(self) -> Callable[[torch.Tensor], torch.Tensor]:
132
+ ret = self.input_conditioner
133
+ self.input_conditioner = nn.Identity()
134
+ return ret
135
+
136
+ def get_nearest_supported_resolution(self, height: int, width: int) -> Resolution:
137
+ height = int(round(height / self.min_resolution_step) * self.min_resolution_step)
138
+ width = int(round(width / self.min_resolution_step) * self.min_resolution_step)
139
+
140
+ height = max(height, self.min_resolution_step)
141
+ width = max(width, self.min_resolution_step)
142
+
143
+ return Resolution(height=height, width=width)
144
+
145
+ def switch_to_deploy(self):
146
+ fn = getattr(self.model, 'switch_to_deploy', None)
147
+ if fn is not None:
148
+ fn()
149
+
150
+ def forward(self, x: torch.Tensor, feature_fmt: str = 'NLC') -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
151
+ '''
152
+ Forward process for model.
153
+ Args:
154
+ x: Input tensor. Its dynamic range is expected to be `[0, 1]` unless `make_preprocessor_external` has been called,
155
+ in which case `x` is expected to be mean centered with unit standard deviation.
156
+ feature_fmt: ['NLC', 'NCHW'] - The output format for the features.
157
+ '''
158
+ res_step = self.min_resolution_step
159
+ if res_step is not None and (x.shape[-2] % res_step != 0 or x.shape[-1] % res_step != 0):
160
+ raise ValueError('The input resolution must be a multiple of `self.min_resolution_step`. '
161
+ '`self.get_nearest_supported_resolution(<height>, <width>) is provided as a convenience API. '
162
+ f'Input: {x.shape[-2:]}, Nearest: {self.get_nearest_supported_resolution(*x.shape[-2:])}')
163
+
164
+ x = self.input_conditioner(x)
165
+ y = self.model.forward_features(x)
166
+ ret = self._extract_final(x, y, feature_fmt=feature_fmt)
167
+ return ret
168
+
169
+ def forward_pack(self, x: List[torch.Tensor], feature_fmt: str = 'NLC') -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
170
+ '''
171
+ Forward process for model.
172
+ Args:
173
+ x: Input tensor. Unless `make_preprocessor_external` has been called, then the dynamic range of `x` is expected to be `[0, 1]`,
174
+ otherwise `x` is expected to be mean centered with unit standard deviation.
175
+ feature_format: ['NLC', 'NCHW'] - The output format for the features.
176
+ '''
177
+ res_step = self.min_resolution_step
178
+ for _x in x:
179
+ if res_step is not None and (_x.shape[-2] % res_step != 0 or _x.shape[-1] % res_step != 0):
180
+ raise ValueError('The input resolution must be a multiple of `self.min_resolution_step`. '
181
+ '`self.get_nearest_supported_resolution(<height>, <width>) is provided as a convenience API. '
182
+ f'Input: {_x.shape[-2:]}, Nearest: {self.get_nearest_supported_resolution(*_x.shape[-2:])}')
183
+
184
+ x = [self.input_conditioner(_x) for _x in x]
185
+ y, cu_seqlens = self.model.forward_features(x)
186
+ all_summary, spatial_features = [], []
187
+ num_cls_tokens = self.model.patch_generator.num_cls_tokens
188
+ num_skip = self.model.patch_generator.num_skip
189
+ for i in range(len(cu_seqlens)-1):
190
+ summary = y[cu_seqlens[i]: cu_seqlens[i+1]][: num_cls_tokens]
191
+ all_feat = y[cu_seqlens[i]: cu_seqlens[i+1]][num_skip :]
192
+ all_summary.append(summary)
193
+ spatial_features.append(all_feat)
194
+ all_summary = torch.cat(all_summary)
195
+ spatial_features = torch.cat(spatial_features)
196
+ return all_summary, spatial_features
197
+
198
+ def _extract_final(self, x: torch.Tensor, y: torch.Tensor, feature_fmt: str = 'NLC'):
199
+ if isinstance(self.model, VisionTransformer):
200
+ patch_gen = getattr(self.model, "patch_generator", None)
201
+ if patch_gen is not None:
202
+ all_summary = y[:, : patch_gen.num_cls_tokens]
203
+ if self.summary_idxs is not None:
204
+ bb_summary = all_summary[:, self.summary_idxs]
205
+ else:
206
+ bb_summary = all_summary
207
+ all_feat = y[:, patch_gen.num_skip :]
208
+ elif self.model.global_pool == "avg":
209
+ all_summary = y[:, self.model.num_prefix_tokens :].mean(dim=1)
210
+ bb_summary = all_summary
211
+ all_feat = y
212
+ else:
213
+ all_summary = y[:, 0]
214
+ bb_summary = all_summary
215
+ all_feat = y[:, 1:]
216
+ elif isinstance(self.model, eradio_model.ERADIO):
217
+ _, f = y
218
+ all_feat = f.flatten(2).transpose(1, 2)
219
+ all_summary = all_feat.mean(dim=1)
220
+ bb_summary = all_summary
221
+ elif isinstance(y, (list, tuple)):
222
+ all_summary, all_feat = y
223
+ bb_summary = all_summary
224
+ else:
225
+ all_summary = y[:, :self.num_cls_tokens]
226
+ if self.summary_idxs is not None and all_summary.shape[1] > 1:
227
+ if all_summary.shape[1] == 1:
228
+ # Create dummy duplicates
229
+ all_summary = all_summary.expand(-1, 128, -1)
230
+ bb_summary = all_summary[:, self.summary_idxs]
231
+ else:
232
+ bb_summary = all_summary
233
+ all_feat = y[:, self.num_summary_tokens:]
234
+
235
+ all_feat = self.feature_normalizer(all_feat)
236
+
237
+ if feature_fmt == 'NCHW':
238
+ fmt_feat = (all_feat.reshape(all_feat.shape[0], x.shape[-2] // self.patch_size, x.shape[-1] // self.patch_size, all_feat.shape[2])
239
+ .permute(0, 3, 1, 2)
240
+ )
241
+ elif feature_fmt == 'NLC':
242
+ fmt_feat = all_feat
243
+ else:
244
+ raise ValueError(f'Unsupported feature_fmt: {feature_fmt}. Must be one of ["NLC", "NCHW"]')
245
+
246
+ ret = RadioOutput(bb_summary.flatten(1), fmt_feat)
247
+
248
+ if self.adaptors:
249
+ ret = dict(backbone=ret)
250
+ for name, adaptor in self.adaptors.items():
251
+ if all_summary.ndim == 3:
252
+ if all_summary.shape[1] == 1:
253
+ summary = all_summary[:, 0]
254
+ else:
255
+ summary = all_summary[:, adaptor.head_idx]
256
+ else:
257
+ summary = all_summary
258
+ ada_input = AdaptorInput(images=x, summary=summary.float(), features=all_feat, feature_fmt=feature_fmt, patch_size=self.patch_size)
259
+ v = adaptor(ada_input).to(torch.float32)
260
+ ret[name] = v
261
+
262
+ return ret
263
+
264
+ def forward_intermediates(
265
+ self,
266
+ x: torch.Tensor,
267
+ indices: Optional[Union[int, List[int], Tuple[int]]] = None,
268
+ return_prefix_tokens: bool = False,
269
+ norm: bool = False,
270
+ stop_early: bool = False,
271
+ output_fmt: str = 'NCHW',
272
+ intermediates_only: bool = False,
273
+ aggregation: Optional[str] = "sparse",
274
+ norm_alpha_scheme: Optional[str] = "post-alpha",
275
+ ) -> List[RadioOutput]:
276
+ """ Forward features that returns intermediates.
277
+ Args:
278
+ x: Input image tensor
279
+ indices: Take last n blocks if int, select matching indices if sequence
280
+ return_prefix_tokens: Return both prefix and spatial intermediate tokens
281
+ norm: Apply norm layer to all intermediates
282
+ stop_early: Stop iterating over blocks when last desired intermediate hit
283
+ output_fmt: Shape of intermediate feature outputs. Options: NCHW, NLC
284
+ intermediates_only: Only return intermediate features
285
+ aggregation: intermediate layer aggregation method (sparse or dense).
286
+ Dense accumulation is done by averaging the features in each group.
287
+ norm_alpha_scheme: apply alpha before ("pre-alpha") or after accumulation ("post-alpha"), or don't normalize ("none")
288
+ Only affects dense aggregation
289
+ Returns:
290
+ List of RadioOutput objects.
291
+ """
292
+ x = self.input_conditioner(x)
293
+ intermediates = self.model.forward_intermediates(
294
+ x,
295
+ indices=indices,
296
+ return_prefix_tokens=return_prefix_tokens,
297
+ norm=norm,
298
+ stop_early=stop_early,
299
+ output_fmt=output_fmt,
300
+ intermediates_only=intermediates_only,
301
+ aggregation=aggregation,
302
+ inter_feature_normalizer=self.inter_feature_normalizer,
303
+ norm_alpha_scheme=norm_alpha_scheme,
304
+ )
305
+
306
+ if not intermediates_only:
307
+ final, intermediates = intermediates
308
+
309
+ def prepare_summary(summ: Optional[torch.Tensor]):
310
+ if summ is None:
311
+ return summ
312
+ if self.summary_idxs is not None and summ.shape[1] > 1:
313
+ summ = summ[:, self.summary_idxs]
314
+ return summ.flatten(1)
315
+
316
+ if return_prefix_tokens:
317
+ radio_outputs = [
318
+ RadioOutput(prepare_summary(summary), features)
319
+ for summary, features in intermediates
320
+ ]
321
+ else:
322
+ radio_outputs = intermediates
323
+
324
+ if intermediates_only:
325
+ return radio_outputs
326
+ else:
327
+ final = self._extract_final(x, final, feature_fmt=output_fmt)
328
+ return final, radio_outputs
329
+
330
+
331
+
332
+ def create_model_from_args(args) -> nn.Module:
333
+ in_chans = 3
334
+ if args.in_chans is not None:
335
+ in_chans = args.in_chans
336
+ elif args.input_size is not None:
337
+ in_chans = args.input_size[0]
338
+
339
+ # Skip weight initialization unless it's explicitly requested.
340
+ weight_init = args.model_kwargs.pop("weight_init", "skip")
341
+
342
+ model = create_model(
343
+ args.model,
344
+ pretrained=args.pretrained,
345
+ in_chans=in_chans,
346
+ num_classes=args.num_classes,
347
+ drop_rate=args.drop,
348
+ drop_path_rate=args.drop_path,
349
+ drop_block_rate=args.drop_block,
350
+ global_pool=args.gp,
351
+ bn_momentum=args.bn_momentum,
352
+ bn_eps=args.bn_eps,
353
+ scriptable=args.torchscript,
354
+ checkpoint_path=args.initial_checkpoint,
355
+ weight_init=weight_init,
356
+ **args.model_kwargs,
357
+ )
358
+
359
+ if hasattr(model, 'norm') and not getattr(args, 'model_norm', False):
360
+ model.norm = nn.Identity()
361
+
362
+ model.head = nn.Identity()
363
+
364
+ if args.cpe_max_size is not None:
365
+ uq_teachers = set(t['name'] for t in args.teachers)
366
+ enable_cpe(
367
+ model,
368
+ args.cpe_max_size,
369
+ num_cls_tokens=len(uq_teachers) if args.cls_token_per_teacher else 1,
370
+ register_multiple=getattr(args, 'register_multiple', None),
371
+ num_registers=getattr(args, 'cpe_num_registers', None),
372
+ support_packing=args.support_packing,
373
+ )
374
+
375
+ return model
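A small standalone sketch of the rounding rule in `get_nearest_supported_resolution`: heights and widths are snapped to the nearest multiple of `min_resolution_step` (the patch size, multiplied by the window size when one is set). The helper name and the step of 16 below are illustrative assumptions:

```python
def nearest_supported(height: int, width: int, step: int = 16):
    # Mirror of get_nearest_supported_resolution: round to the nearest multiple of `step`,
    # never going below `step` itself.
    snap = lambda v: max(int(round(v / step) * step), step)
    return snap(height), snap(width)

print(nearest_supported(500, 750))  # -> (496, 752) with a 16-pixel step
```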
tim/models/nvidia_radio/radio/vision_transformer_xpos.py ADDED
@@ -0,0 +1,357 @@
1
+ import math
2
+ from typing import Final, List, Optional, Tuple, Union
3
+
4
+
5
+ from einops import rearrange
6
+ from timm.models import register_model
7
+ import torch
8
+ from torch import Type, nn
9
+ from torch.nn import functional as F
10
+ from torch.nn.init import xavier_normal_, xavier_uniform_, zeros_
11
+
12
+ from .forward_intermediates import forward_intermediates
13
+
14
+
15
+ def _get_init_scale(num_encoder_layers: int, num_decoder_layers: int, is_encoder: bool):
16
+ if num_encoder_layers > 0 and num_decoder_layers == 0:
17
+ return math.sqrt(math.log(2 * num_encoder_layers))
18
+ if num_decoder_layers > 0 and num_encoder_layers == 0:
19
+ return math.sqrt(math.log(2 * num_decoder_layers))
20
+ if is_encoder:
21
+ # Both encoders and decoders
22
+ return math.sqrt(
23
+ 0.33 * math.log(3 * num_decoder_layers) * math.log(2 * num_encoder_layers)
24
+ )
25
+
26
+ return math.sqrt(math.log(3 * num_decoder_layers))
27
+
28
+
29
+ # [1,2] [1,1,2,2]
30
+ # [3,4] -> [3,3,4,4]
31
+ # [5,6] [5,5,6,6]
32
+ def duplicate_interleave(m):
33
+ return m.view(-1, 1).repeat(1, 2).view(m.shape[0], -1)
34
+
35
+ # 0,1,2,3,4,5,6,7 -> -1,0,-3,2,-5,4,-7,6
36
+ def rotate_every_two(x):
37
+ x1 = x[:, :, ::2]
38
+ x2 = x[:, :, 1::2]
39
+ x = torch.stack((-x2, x1), dim=-1)
40
+ return x.flatten(-2) # in einsum notation: rearrange(x, '... d j -> ... (d j)')\
41
+
42
+
43
+ class XPosEmbedding2D(torch.nn.Module):
44
+ """Implementation of xPos based on RotaryEmbedding from GPT-NeoX.
45
+ This implementation is designed to operate on queries and keys that are compatible with
46
+ [batch_size, n_heads_per_partition, seq_len, head_dim] (e.g. MinGPTAttention format).
47
+ """
48
+
49
+ def __init__(
50
+ self,
51
+ head_dim: int,
52
+ base=50000,
53
+ scale_base=512
54
+ ):
55
+ super().__init__()
56
+ half_dim = head_dim // 2
57
+ self.half_dim = half_dim
58
+ inv_freq = 1.0 / (base ** (torch.arange(0, half_dim, 2).float() / half_dim))
59
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
60
+ self.head_dim = head_dim
61
+ self.token_shape_cached = None
62
+ self.batch_size_cached = None
63
+ self.cos_cached: torch.Tensor | None = None
64
+ self.sin_cached: torch.Tensor | None = None
65
+ self.scale_cached: torch.Tensor | None = None
66
+ self.scale_base = scale_base
67
+ self.register_buffer("scale",
68
+ (torch.arange(0, half_dim, 2) + 0.4 * half_dim) / (1.4 * half_dim))
69
+
70
+ def cos_sin(
71
+ self,
72
+ token_shape: Tuple[int, int],
73
+ device="cuda",
74
+ dtype=torch.bfloat16,
75
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
76
+ if token_shape != self.token_shape_cached:
77
+ self.token_shape_cached = token_shape
78
+ y = torch.arange(token_shape[0], device=device, dtype=self.inv_freq.dtype)
79
+ x = torch.arange(token_shape[1], device=device, dtype=self.inv_freq.dtype)
80
+ x, y = torch.meshgrid(x, y, indexing='xy')
81
+
82
+ y_freqs = torch.einsum("i,j->ij", y.flatten(), self.inv_freq)
83
+ x_freqs = torch.einsum("i,j->ij", x.flatten(), self.inv_freq)
84
+
85
+ y_scales = self.scale ** y.flatten().div(self.scale_base)[:, None]
86
+ x_scales = self.scale ** x.flatten().div(self.scale_base)[:, None]
87
+
88
+ freqs = torch.cat([y_freqs, x_freqs], dim=-1)
89
+ emb = torch.repeat_interleave(freqs, repeats=2, dim=-1)
90
+
91
+ scales = torch.cat([y_scales, x_scales], dim=-1)
92
+ scales = torch.repeat_interleave(scales, repeats=2, dim=-1)
93
+
94
+ if dtype in [torch.float16, torch.bfloat16]:
95
+ emb = emb.float()
96
+
97
+ self.cos_cached = emb.cos()[None, :, :]
98
+ self.sin_cached = emb.sin()[None, :, :]
99
+ self.scale_cached = scales[None, :, :]
100
+
101
+ self.cos_cached = self.cos_cached.type(dtype)
102
+ self.sin_cached = self.sin_cached.type(dtype)
103
+ self.scale_cached = self.scale_cached.type(dtype)
104
+
105
+ return self.cos_cached, self.sin_cached, self.scale_cached
106
+
107
+ def forward(self, q: torch.Tensor, k: torch.Tensor, token_shape: Tuple[int, int]):
108
+ batch, seq_len, head_dim = q.shape
109
+ cos, sin, scale = self.cos_sin(token_shape, q.device, q.dtype)
110
+ # scale = self.scale**torch.arange(seq_len).to(self.scale).div(self.scale_base)[:, None]
111
+ # scale = torch.repeat_interleave(scale, 2, dim=-1).to(q.device)
112
+ # scale = torch.cat([scale, scale], dim=-1)
113
+ # scale = 1
114
+ return (
115
+ (q * cos * scale) + (rotate_every_two(q) * sin * scale),
116
+ (k * cos * (1 / scale)) + (rotate_every_two(k) * sin * (1 / scale)),
117
+ )
118
+
119
+
120
+ class MagnetoAttention(nn.Module):
121
+ def __init__(self, d_model: int, n_head: int, pos_emb: XPosEmbedding2D):
122
+ super().__init__()
123
+ self.num_heads = n_head
124
+ self.head_dim = d_model // n_head
125
+ self.scale = self.head_dim ** -0.5
126
+
127
+ self.qkv = nn.Linear(d_model, d_model * 3, bias=False)
128
+ self.proj = nn.Linear(d_model, d_model)
129
+ self.pos_emb = pos_emb
130
+
131
+ self.norm0 = nn.LayerNorm(d_model)
132
+ self.norm1 = nn.LayerNorm(d_model)
133
+
134
+ def forward(self, x: torch.Tensor, num_prefix_tokens: int, patch_shape: Tuple[int, int]) -> torch.Tensor:
135
+ B, N, C = x.shape
136
+ x = self.norm0(x)
137
+
138
+ qkv = self.qkv(x).reshape(B, N, 3, C).permute(2, 0, 1, 3)
139
+ q, k, v = qkv.unbind(0)
140
+
141
+ q_pref = q[:, :num_prefix_tokens]
142
+ q_patch = q[:, num_prefix_tokens:]
143
+
144
+ k_pref = k[:, :num_prefix_tokens]
145
+ k_patch = k[:, num_prefix_tokens:]
146
+
147
+ q_patch, k_patch = self.pos_emb(q_patch, k_patch, patch_shape)
148
+
149
+ q = torch.cat([q_pref, q_patch], dim=1)
150
+ k = torch.cat([k_pref, k_patch], dim=1)
151
+
152
+ def head_reshape(t: torch.Tensor):
153
+ return t.reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
154
+
155
+ q = head_reshape(q)
156
+ k = head_reshape(k)
157
+ v = head_reshape(v)
158
+
159
+ x = F.scaled_dot_product_attention(q, k, v)
160
+ x = x.transpose(1, 2).reshape(B, N, C)
161
+ x = self.norm1(x)
162
+ x = self.proj(x)
163
+ return x
164
+
165
+ def _reset_parameters(self):
166
+ xavier_uniform_(self.qkv.weight)
167
+ if self.qkv.bias is not None:
168
+ zeros_(self.qkv.bias)
169
+ xavier_normal_(self.proj.weight)
170
+ zeros_(self.proj.bias)
171
+
172
+
173
+ class MagnetoTransformerEncoderLayer(nn.Module):
174
+ def __init__(self, d_model: int, nhead: int, pos_emb: XPosEmbedding2D,
175
+ num_encoder_layers: int, num_decoder_layers: int = 0,
176
+ dim_mhsa: int = 0,
177
+ dim_feedforward: int = 2048,
178
+ layer_norm_eps: float = 1e-5,
179
+ batch_first: bool = True):
180
+ super().__init__()
181
+
182
+ if dim_mhsa == 0:
183
+ dim_mhsa = d_model
184
+
185
+ self._num_encoder_layers = num_encoder_layers
186
+ self._num_decoder_layers = num_decoder_layers
187
+
188
+ self.attn = MagnetoAttention(d_model, nhead, pos_emb)
189
+
190
+ self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps)
191
+ self.linear2 = nn.Linear(d_model, dim_feedforward)
192
+ self.norm3 = nn.LayerNorm(dim_feedforward, eps=layer_norm_eps)
193
+ self.linear3 = nn.Linear(dim_feedforward, d_model)
194
+
195
+ def initialize(self):
196
+ gamma = _get_init_scale(self._num_encoder_layers, self._num_decoder_layers, is_encoder=True)
197
+
198
+ # Magneto Initialization
199
+ for mod in self.children():
200
+ if isinstance(mod, nn.Linear):
201
+ xavier_normal_(mod.weight.data, gamma)
202
+ elif isinstance(mod, MagnetoAttention):
203
+ mod._reset_parameters()
204
+
205
+ def forward(self, x: torch.Tensor, num_prefix_tokens: int, patch_shape: Tuple[int, int]) -> torch.Tensor:
206
+ x = x + self._sa_block(x, num_prefix_tokens, patch_shape)
207
+ x = x + self._ff_block(x)
208
+ return x
209
+
210
+ def _sa_block(self, x: torch.Tensor, num_prefix_tokens: int, patch_shape: Tuple[int, int]) -> torch.Tensor:
211
+ x = self.attn(x, num_prefix_tokens, patch_shape)
212
+ return x
213
+
214
+ def _ff_block(self, x: torch.Tensor) -> torch.Tensor:
215
+ x = self.norm2(x)
216
+ x = self.linear2(x)
217
+ x = F.gelu(x)
218
+ x = self.norm3(x)
219
+ x = self.linear3(x)
220
+ return x
221
+
222
+
223
+ class VisionTransformer(nn.Module):
224
+ """ Vision Transformer
225
+
226
+ A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
227
+ - https://arxiv.org/abs/2010.11929
228
+ """
229
+ dynamic_img_size: Final[bool]
230
+
231
+ def __init__(
232
+ self,
233
+ patch_size: Union[int, Tuple[int, int]] = 16,
234
+ in_chans: int = 3,
235
+ embed_dim: int = 768,
236
+ depth: int = 12,
237
+ num_heads: int = 12,
238
+ mlp_ratio: float = 4.,
239
+ num_cls_tokens: int = 1,
240
+ num_reg_tokens: int = 0,
241
+ ) -> None:
242
+ """
243
+ Args:
244
+ patch_size: Patch size.
245
+ in_chans: Number of image input channels.
246
+ embed_dim: Transformer embedding dimension.
247
+ depth: Depth of transformer.
248
+ num_heads: Number of attention heads.
249
+ mlp_ratio: Ratio of mlp hidden dim to embedding dim.
250
+ num_cls_tokens: Number of cls tokens
251
+ num_reg_tokens: Number of register tokens.
252
+ block_fn: Transformer block layer.
253
+ """
254
+ super().__init__()
255
+
256
+ self.patch_size = patch_size
257
+ self.embed_dim = embed_dim
258
+ self.num_cls_tokens = num_cls_tokens
259
+ self.num_reg_tokens = num_reg_tokens
260
+
261
+ self.patch_embed = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
262
+
263
+ self.prefix_buffer = nn.Parameter(torch.randn(1, self.num_prefix_tokens, embed_dim) * .02)
264
+
265
+ pos_emb = XPosEmbedding2D(embed_dim)
266
+
267
+ self.blocks = nn.ModuleList([
268
+ MagnetoTransformerEncoderLayer(
269
+ d_model=embed_dim,
270
+ nhead=num_heads,
271
+ num_encoder_layers=depth,
272
+ num_decoder_layers=0,
273
+ dim_feedforward=int(embed_dim * mlp_ratio),
274
+ pos_emb=pos_emb,
275
+ )
276
+ for _ in range(depth)
277
+ ])
278
+
279
+ for block in self.blocks:
280
+ block.initialize()
281
+
282
+ @property
283
+ def num_prefix_tokens(self):
284
+ return self.num_cls_tokens + self.num_reg_tokens
285
+
286
+ @property
287
+ def num_summary_tokens(self):
288
+ return self.num_prefix_tokens
289
+
290
+ def forward_features(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
291
+ x, patch_shape = self._patchify(x)
292
+
293
+ for block in self.blocks:
294
+ x = block(x, self.num_prefix_tokens, patch_shape)
295
+
296
+ summary = x[:, :self.num_cls_tokens]
297
+ features = x[:, self.num_prefix_tokens:]
298
+
299
+ return summary, features
300
+
301
+ def forward_intermediates(self, x: torch.Tensor, norm: bool = False, **kwargs):
302
+ patch_shape = tuple(d // self.patch_size for d in x.shape[-2:])
303
+
304
+ def patch_extractor(x: torch.Tensor):
305
+ x, _ = self._patchify(x)
306
+ return x
307
+
308
+ return forward_intermediates(
309
+ self,
310
+ patch_extractor=patch_extractor,
311
+ num_summary_tokens=self.num_prefix_tokens,
312
+ num_cls_tokens=self.num_cls_tokens,
313
+ norm=lambda y: y,
314
+ x=x,
315
+ block_kwargs=dict(num_prefix_tokens=self.num_prefix_tokens, patch_shape=patch_shape),
316
+ **kwargs,
317
+ )
318
+
319
+ def _patchify(self, x: torch.Tensor):
320
+ x = self.patch_embed(x)
321
+ patch_shape = x.shape[-2:]
322
+ x = rearrange(x, 'b c h w -> b (h w) c')
323
+
324
+ prefix = self.prefix_buffer.expand(x.shape[0], -1, -1)
325
+
326
+ x = torch.cat([prefix, x], dim=1)
327
+ return x, patch_shape
328
+
329
+
330
+ @register_model
331
+ def vit_base_patch16_xpos(num_cls_tokens: int = 1, num_reg_tokens: int = 0, **kwargs) -> VisionTransformer:
332
+ return VisionTransformer(patch_size=16, embed_dim=768, depth=12, num_heads=12,
333
+ num_cls_tokens=num_cls_tokens, num_reg_tokens=num_reg_tokens)
334
+
335
+
336
+ @register_model
337
+ def vit_large_patch16_xpos(num_cls_tokens: int = 1, num_reg_tokens: int = 0, **kwargs) -> VisionTransformer:
338
+ return VisionTransformer(patch_size=16, embed_dim=1024, depth=24, num_heads=16,
339
+ num_cls_tokens=num_cls_tokens, num_reg_tokens=num_reg_tokens)
340
+
341
+
342
+ @register_model
343
+ def vit_huge_patch16_xpos(num_cls_tokens: int = 1, num_reg_tokens: int = 0, **kwargs) -> VisionTransformer:
344
+ return VisionTransformer(patch_size=16, embed_dim=1280, depth=32, num_heads=16,
345
+ num_cls_tokens=num_cls_tokens, num_reg_tokens=num_reg_tokens)
346
+
347
+
348
+ @register_model
349
+ def vit_giant_patch16_xpos(num_cls_tokens: int = 1, num_reg_tokens: int = 0, **kwargs) -> VisionTransformer:
350
+ return VisionTransformer(patch_size=16, embed_dim=1408, depth=40, num_heads=16,
351
+ num_cls_tokens=num_cls_tokens, num_reg_tokens=num_reg_tokens)
352
+
353
+
354
+ @register_model
355
+ def vit_bigG_patch16_xpos(num_cls_tokens: int = 1, num_reg_tokens: int = 0, **kwargs) -> VisionTransformer:
356
+ return VisionTransformer(patch_size=16, embed_dim=1664, depth=48, num_heads=16,
357
+ num_cls_tokens=num_cls_tokens, num_reg_tokens=num_reg_tokens)
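A quick smoke test of the xPos ViT defined above, assuming input side lengths divisible by the 16-pixel patch size (shapes indicative; weights are randomly initialized):

import torch
# from tim.models.nvidia_radio.radio.vision_transformer_xpos import vit_base_patch16_xpos

model = vit_base_patch16_xpos()                      # registered above via @register_model
img = torch.randn(1, 3, 224, 224)
summary, features = model.forward_features(img)
# summary: (1, 1, 768) cls token(s); features: (1, 196, 768) patch tokens on the 14x14 grid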
tim/models/nvidia_radio/radio/vit_patch_generator.py ADDED
@@ -0,0 +1,287 @@
1
+ # Copyright (c) 2023-2024, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ import math
10
+ from typing import Union, Tuple, Optional
11
+
12
+ import torch
13
+ import torch.nn.functional as F
14
+ from torch import nn
15
+ from einops import rearrange
16
+
17
+ from .cls_token import ClsToken
18
+
19
+ input_dim_t = Union[int, Tuple[int, int]]
20
+
21
+ try:
22
+ # raise ImportError()
23
+ from indirect_grid_sample import indirect_grid_sample
24
+ except ImportError:
25
+ indirect_grid_sample = None
26
+
27
+ class ViTPatchGenerator(nn.Module):
28
+ def __init__(self,
29
+ patch_size: int,
30
+ embed_dim: int,
31
+ input_dims: input_dim_t,
32
+ abs_pos: bool = True,
33
+ normalize_patches: bool = False,
34
+ cls_token: bool = False,
35
+ max_input_dims: Optional[input_dim_t] = None,
36
+ pos_dropout: float = 0.0,
37
+ return_pos_enc: bool = False,
38
+ num_cls_tokens: int = 1,
39
+ register_multiple: Optional[int] = None,
40
+ num_registers: Optional[int] = None,
41
+ patch_bias: bool = False,
42
+ device=None, dtype=None,
43
+ ):
44
+ super().__init__()
45
+
46
+ if isinstance(input_dims, int):
47
+ input_dims = (input_dims, input_dims)
48
+
49
+ if max_input_dims is None:
50
+ max_input_dims = input_dims
51
+ if isinstance(max_input_dims, int):
52
+ max_input_dims = (max_input_dims, max_input_dims)
53
+
54
+ max_input_dims = tuple(
55
+ int(math.ceil(d / patch_size) * patch_size)
56
+ for d in max_input_dims
57
+ )
58
+
59
+ self.cpe_mode = max_input_dims != input_dims
60
+ self.pos_dropout = pos_dropout
61
+ self.return_pos_enc = return_pos_enc
62
+
63
+ factory = dict(device=device, dtype=dtype)
64
+
65
+ self.patch_size = patch_size
66
+ self.abs_pos = abs_pos
67
+ self.embed_dim = embed_dim
68
+
69
+ self.num_rows = max_input_dims[0] // patch_size
70
+ self.num_cols = max_input_dims[1] // patch_size
71
+ self.input_dims = tuple(d // patch_size for d in input_dims)
72
+ self.num_patches = self.num_rows * self.num_cols
73
+ self.max_input_dims = max_input_dims
74
+
75
+ self.im_to_patches = Im2Patches(patch_size)
76
+ self.embedder = ViTPatchLinear(patch_size, embed_dim, bias=patch_bias, **factory)
77
+
78
+ if abs_pos:
79
+ scale = embed_dim ** -0.5
80
+ self.pos_embed = nn.Parameter(torch.randn(1, self.num_patches, embed_dim, **factory) * scale)
81
+
82
+ self.cls_token = ClsToken(
83
+ embed_dim,
84
+ num_tokens=num_cls_tokens,
85
+ enabled=cls_token,
86
+ register_multiple=register_multiple,
87
+ num_registers=num_registers,
88
+ )
89
+
90
+ self.patch_normalizer = nn.LayerNorm(embed_dim) if normalize_patches else nn.Identity()
91
+
92
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
93
+ patches = self.embed_patches(x)
94
+ patches, pos_enc = self.apply_pos_enc(patches, input_size=x.shape[2:])
95
+ patches = self.cls_token(patches)
96
+ patches = self.patch_normalizer(patches)
97
+ if self.return_pos_enc:
98
+ return patches, pos_enc
99
+ return patches
100
+
101
+ @property
102
+ def apply_cls_token(self):
103
+ return self.cls_token.enabled
104
+
105
+ @property
106
+ def num_cls_tokens(self):
107
+ return self.cls_token.num_tokens
108
+
109
+ @property
110
+ def num_cls_patches(self):
111
+ return self.cls_token.num_patches
112
+
113
+ @property
114
+ def num_registers(self):
115
+ return self.cls_token.num_registers
116
+
117
+ @property
118
+ def num_skip(self):
119
+ return self.num_cls_tokens + self.num_registers
120
+
121
+ def no_weight_decay(self):
122
+ return [
123
+ 'pos_embed',
124
+ ]
125
+
126
+ def _load_embed(self, src_embed: torch.Tensor, targ_embed: nn.Parameter):
127
+ if src_embed.shape != targ_embed.shape:
128
+ src_size = int(math.sqrt(src_embed.shape[1]))
129
+
130
+ assert src_size ** 2 == src_embed.shape[1], 'Unable to interpolate non-square embedding'
131
+
132
+ src_embed = rearrange(src_embed, 'b (h w) c -> b c h w', h=src_size, w=src_size)
133
+ src_embed = F.interpolate(src_embed, size=(self.num_rows, self.num_cols), mode='bicubic', align_corners=True, antialias=False)
134
+ src_embed = rearrange(src_embed, 'b c h w -> b (h w) c')
135
+ targ_embed.data.copy_(src_embed)
136
+
137
+ def _load_projection(self, src_proj_weight: torch.Tensor, targ_proj_weight: torch.Tensor):
138
+ if src_proj_weight.shape != targ_proj_weight.shape:
139
+ src_patch_size = int(math.sqrt(src_proj_weight.shape[1] // 3))
140
+
141
+ assert (src_patch_size ** 2) * 3 == src_proj_weight.shape[1], 'Unable to interpolate non-square patch size'
142
+
143
+ src_proj_weight = rearrange(src_proj_weight, 'b (c h w) -> b c h w', c=3, h=src_patch_size, w=src_patch_size)
144
+ src_proj_weight = F.interpolate(src_proj_weight, size=(self.patch_size, self.patch_size), mode='bicubic', align_corners=True, antialias=False)
145
+ src_proj_weight = rearrange(src_proj_weight, 'b c h w -> b (c h w)')
146
+ targ_proj_weight.data.copy_(src_proj_weight)
147
+
148
+ def embed_patches(self, x: torch.Tensor) -> torch.Tensor:
149
+ patches = self.im_to_patches(x)
150
+ patches = self.embedder(patches)
151
+ return patches
152
+
153
+ def apply_pos_enc(self,
154
+ patches: torch.Tensor,
155
+ patch_idxs: Optional[torch.Tensor] = None,
156
+ input_size: Optional[Tuple[int, int]] = None,
157
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
158
+ if not self.abs_pos:
159
+ return patches, None  # keep the (patches, pos_enc) arity expected by forward()
160
+
161
+ pos_enc = self.get_pos_enc(patches.shape[0], patch_idxs, input_size)
162
+
163
+ if self.training and self.pos_dropout > 0:
164
+ keeps = torch.rand(patches.shape[0], 1, 1, dtype=pos_enc.dtype, device=pos_enc.device) > self.pos_dropout
165
+ pos_enc_drop = torch.where(keeps, pos_enc, 0)
166
+ else:
167
+ pos_enc_drop = pos_enc
168
+
169
+ return patches + pos_enc_drop, pos_enc
170
+
171
+ def get_pos_enc(self,
172
+ batch_size: int,
173
+ patch_idxs: Optional[torch.Tensor] = None,
174
+ input_size: Optional[Tuple[int, int]] = None,
175
+ ) -> torch.Tensor:
176
+ if input_size is None:
177
+ input_dims = self.input_dims
178
+ else:
179
+ input_dims = tuple(d // self.patch_size for d in input_size)
180
+
181
+ pos_embed = self._get_pos_embeddings(batch_size, input_dims)
182
+
183
+ if patch_idxs is None:
184
+ return pos_embed
185
+
186
+ exp_patch_idxs = patch_idxs.unsqueeze(-1).expand(-1, -1, pos_embed.shape[-1])
187
+
188
+ pos_embed = torch.gather(pos_embed.expand(patch_idxs.shape[0], -1, -1), dim=1, index=exp_patch_idxs)
189
+ return pos_embed
190
+
191
+
192
+ def _get_pos_embeddings(self, batch_size: int, input_dims: Tuple[int, int]):
193
+ if (self.num_rows, self.num_cols) == input_dims:
194
+ return self.pos_embed
195
+
196
+ pos_embed = self.pos_embed.reshape(1, self.num_rows, self.num_cols, -1).permute(0, 3, 1, 2)
197
+
198
+ def window_select(pos_embed):
199
+ if input_dims[0] < pos_embed.shape[-2]:
200
+ pos_embed = pos_embed[..., :input_dims[0], :]
201
+ if input_dims[1] < pos_embed.shape[-1]:
202
+ pos_embed = pos_embed[..., :, :input_dims[1]]
203
+ return pos_embed
204
+
205
+ if self.cpe_mode:
206
+ if self.training:
207
+ min_scale = math.sqrt(0.1)
208
+ scale = torch.rand(batch_size, 1, 1, device=pos_embed.device) * (1 - min_scale) + min_scale
209
+ aspect_min = math.log(3 / 4)
210
+ aspect_max = -aspect_min
211
+ aspect = torch.exp(torch.rand(batch_size, 1, 1, device=pos_embed.device) * (aspect_max - aspect_min) + aspect_min)
212
+
213
+ scale_x = scale * aspect
214
+ scale_y = scale * (1 / aspect)
215
+ scale_xy = torch.stack([scale_x, scale_y], dim=-1).clamp_(0, 1)
216
+
217
+ pos_xy = torch.rand(batch_size, 1, 1, 2, device=pos_embed.device) * (1 - scale_xy)
218
+
219
+ lin_x = torch.linspace(0, 1, steps=input_dims[1], device=pos_embed.device)[None, None].expand(batch_size, input_dims[0], -1)
220
+ lin_y = torch.linspace(0, 1, steps=input_dims[0], device=pos_embed.device)[None, :, None].expand(batch_size, -1, input_dims[1])
221
+
222
+ lin_xy = torch.stack([lin_x, lin_y], dim=-1)
223
+
224
+ grid_xy = lin_xy * scale_xy + pos_xy
225
+
226
+ # Convert to [-1, 1] range
227
+ grid_xy.mul_(2).sub_(1)
228
+
229
+ pos_embed = F.grid_sample(
230
+ pos_embed.float().expand(batch_size, -1, -1, -1),
231
+ grid=grid_xy,
232
+ mode='bilinear',
233
+ padding_mode='zeros',
234
+ align_corners=True,
235
+ ).to(pos_embed.dtype)
236
+ else:
237
+ # i_rows, i_cols = input_dims
238
+ # p_rows, p_cols = pos_embed.shape[2:]
239
+ # if i_rows <= p_rows and i_cols <= p_cols:
240
+ # left = (p_cols - i_cols) // 2
241
+ # top = (p_rows - i_rows) // 2
242
+ # pos_embed = pos_embed[..., top:top+i_rows, left:left+i_cols]
243
+ # else:
244
+ max_dim = max(input_dims)
245
+ pos_embed = F.interpolate(pos_embed.float(), size=(max_dim, max_dim), align_corners=True, mode='bilinear').to(pos_embed.dtype)
246
+
247
+ pos_embed = window_select(pos_embed)
248
+ else:
249
+ pos_embed = window_select(pos_embed)
250
+
251
+ if pos_embed.shape[-2:] != input_dims:
252
+ pos_embed = F.interpolate(pos_embed.float(), size=input_dims, align_corners=True, mode='bilinear').to(pos_embed.dtype)
253
+
254
+ pos_embed = pos_embed.flatten(2).permute(0, 2, 1)
255
+
256
+ return pos_embed
257
+
258
+
259
+ class Im2Patches(nn.Module):
260
+ def __init__(self, patch_size: int):
261
+ super().__init__()
262
+ self.patch_size = patch_size
263
+
264
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
265
+ if self.patch_size == 1:
266
+ patches = x.flatten(2)
267
+ patches = patches.permute(0, 2, 1)
268
+ return patches
269
+
270
+ py = x.shape[-2] // self.patch_size
271
+ px = x.shape[-1] // self.patch_size
272
+ patches = rearrange(x, 'b c (py yy) (px xx) -> b (py px) (c yy xx)',
273
+ py=py, yy=self.patch_size,
274
+ px=px, xx=self.patch_size,
275
+ )
276
+ return patches
277
+
278
+
279
+ class ViTPatchLinear(nn.Linear):
280
+ def __init__(self, patch_size: int, embed_dim: int, bias: bool = False, **factory):
281
+ super().__init__(
282
+ 3 * (patch_size ** 2),
283
+ embed_dim,
284
+ bias=bias,
285
+ **factory
286
+ )
287
+ self.patch_size = patch_size
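An illustrative construction of the patch generator above (untested sketch; the exact number of prefix tokens prepended by ClsToken depends on its register settings):

import torch
# from tim.models.nvidia_radio.radio.vit_patch_generator import ViTPatchGenerator

gen = ViTPatchGenerator(patch_size=16, embed_dim=768, input_dims=224, cls_token=True)
tokens = gen(torch.randn(2, 3, 224, 224))
# tokens: (2, num_cls_patches + 14*14, 768); set return_pos_enc=True to also receive the positional encodings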
tim/models/nvidia_radio/radio/vitdet.py ADDED
@@ -0,0 +1,188 @@
1
+ from collections import defaultdict
2
+ from contextlib import contextmanager
3
+ from logging import getLogger
4
+ import math
5
+ import sys
6
+ from typing import List, Union, Iterable
7
+
8
+ import numpy as np
9
+ import torch
10
+ from torch import nn
11
+
12
+ from timm.models import VisionTransformer
13
+ from einops import rearrange
14
+
15
+ from .extra_models import DinoWrapper
16
+
17
+ DEFAULT_NUM_WINDOWED = 5
18
+ DEFAULT_NUM_GLOBAL = 4
19
+
20
+
21
+ class VitDetArgs:
22
+ def __init__(self,
23
+ window_size: int,
24
+ num_summary_tokens: int,
25
+ num_windowed: int = None,
26
+ num_global: int = None,
27
+ ):
28
+ self.window_size = window_size
29
+ self.num_summary_tokens = num_summary_tokens
30
+ self.num_windowed = num_windowed
31
+ self.num_global = num_global
32
+
33
+
34
+ def apply_vitdet_arch(model: Union[VisionTransformer, DinoWrapper], args: VitDetArgs):
35
+ if isinstance(model, VisionTransformer):
36
+ patch_embed = getattr(model, 'patch_generator', model.patch_embed)
37
+
38
+ return ViTDetHook(patch_embed, model.blocks, args)
39
+ elif isinstance(model, DinoWrapper):
40
+ inner = model.inner
41
+
42
+ patch_embed = getattr(inner, 'patch_generator', inner.patch_embed)
43
+ return ViTDetHook(patch_embed, inner.blocks, args)
44
+ else:
45
+ print(f'Warning: Unable to apply VitDet aug!', file=sys.stderr)
46
+
47
+
48
+ class ViTDetHook:
49
+ def __init__(self,
50
+ embedder: nn.Module,
51
+ blocks: nn.Sequential,
52
+ args: VitDetArgs,
53
+ ):
54
+ self.blocks = blocks
55
+ self.num_summary_tokens = args.num_summary_tokens
56
+ self.window_size = args.window_size
57
+
58
+ self._input_resolution = None
59
+ self._num_windows = None
60
+ self._cls_patch = None
61
+ self._order_cache = dict()
62
+
63
+ embedder.register_forward_pre_hook(self._enter_model)
64
+
65
+ # This will decide if we window-fy the patches
66
+ # and enable vit-det for this iteration, and if so,
67
+ # rearrange the patches for efficient mode switching
68
+ blocks.register_forward_pre_hook(self._enter_blocks)
69
+
70
+ is_global = True
71
+ if args.num_windowed is not None:
72
+ period = args.num_windowed + 1
73
+ else:
74
+ num_global = args.num_global or DEFAULT_NUM_GLOBAL
75
+ period = max(len(blocks) // num_global, 1)
76
+
77
+ for i, layer in enumerate(blocks[:-1]):
78
+ ctr = i % period
79
+ if ctr == 0:
80
+ layer.register_forward_pre_hook(self._to_windows)
81
+ is_global = False
82
+ elif ctr == period - 1:
83
+ layer.register_forward_pre_hook(self._to_global)
84
+ is_global = True
85
+
86
+ # Always ensure the final layer is a global layer
87
+ if not is_global:
88
+ blocks[-1].register_forward_pre_hook(self._to_global)
89
+
90
+ blocks.register_forward_hook(self._exit_model)
91
+
92
+ def _enter_model(self, _, input: List[torch.Tensor]):
93
+ self._input_resolution = input[0].shape[-2:]
94
+
95
+ def _enter_blocks(self, _, input: List[torch.Tensor]):
96
+ # print(f'{get_rank()} - ViTDet Window Size: {self._window_size}', file=sys.stderr)
97
+
98
+ patches = input[0]
99
+ patches = self._rearrange_patches(patches)
100
+
101
+ return (patches,) + input[1:]
102
+
103
+ def _to_windows(self, _, input: List[torch.Tensor]):
104
+ patches = input[0]
105
+
106
+ if self.num_summary_tokens:
107
+ self._cls_patch = patches[:, :self.num_summary_tokens]
108
+ patches = patches[:, self.num_summary_tokens:]
109
+
110
+ patches = rearrange(
111
+ patches, 'b (p t) c -> (b p) t c',
112
+ p=self._num_windows, t=self.window_size ** 2,
113
+ )
114
+
115
+ return (patches,) + input[1:]
116
+
117
+ def _to_global(self, _, input: List[torch.Tensor]):
118
+ patches = input[0]
119
+
120
+ patches = rearrange(
121
+ patches, '(b p) t c -> b (p t) c',
122
+ p=self._num_windows, t=self.window_size ** 2,
123
+ b=patches.shape[0] // self._num_windows,
124
+ )
125
+
126
+ if self.num_summary_tokens:
127
+ patches = torch.cat([
128
+ self._cls_patch,
129
+ patches,
130
+ ], dim=1)
131
+
132
+ return (patches,) + input[1:]
133
+
134
+ def _exit_model(self, _, inputs: List[torch.Tensor], patches: torch.Tensor):
135
+ # Return patches to their original order
136
+ patch_order = self._order_cache[self._input_resolution][0]
137
+ patch_order = patch_order.reshape(1, -1, 1).expand_as(patches)
138
+
139
+ ret_patches = torch.empty_like(patches)
140
+ ret_patches = torch.scatter(
141
+ ret_patches,
142
+ dim=1,
143
+ index=patch_order,
144
+ src=patches,
145
+ )
146
+
147
+ return ret_patches
148
+
149
+ def _rearrange_patches(self, patches: torch.Tensor):
150
+ # We rearrange the patches so that we can efficiently
151
+ # switch between windowed and global mode by just
152
+ # reshaping the tensor
153
+
154
+ patch_order, self._num_windows = self._order_cache.get(self._input_resolution, (None, None))
155
+ if patch_order is None:
156
+ num_feat_patches = patches.shape[1] - self.num_summary_tokens
157
+ num_pixels = self._input_resolution[0] * self._input_resolution[1]
158
+
159
+ patch_size = int(round(math.sqrt(num_pixels / num_feat_patches)))
160
+ rows = self._input_resolution[-2] // patch_size
161
+ cols = self._input_resolution[-1] // patch_size
162
+
163
+ w_rows = rows // self.window_size
164
+ w_cols = cols // self.window_size
165
+
166
+ patch_order = torch.arange(0, num_feat_patches, device=patches.device)
167
+
168
+ patch_order = rearrange(
169
+ patch_order, '(wy py wx px) -> (wy wx py px)',
170
+ wy=w_rows, wx=w_cols,
171
+ py=self.window_size, px=self.window_size,
172
+ )
173
+
174
+ if self.num_summary_tokens:
175
+ patch_order = torch.cat([
176
+ torch.arange(self.num_summary_tokens, dtype=patch_order.dtype, device=patch_order.device),
177
+ patch_order + self.num_summary_tokens,
178
+ ])
179
+
180
+ self._num_windows = w_rows * w_cols
181
+ self._order_cache[self._input_resolution] = (
182
+ patch_order,
183
+ self._num_windows,
184
+ )
185
+
186
+ patch_order = patch_order.reshape(1, -1, 1).expand_as(patches)
187
+ patches = torch.gather(patches, dim=1, index=patch_order)
188
+ return patches
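A sketch of attaching the ViTDet-style hook above to a stock timm ViT (hypothetical values; window_size should evenly divide the patch grid, e.g. 7 on the 14x14 grid of a 224px input):

import timm
# from tim.models.nvidia_radio.radio.vitdet import VitDetArgs, apply_vitdet_arch

vit = timm.create_model('vit_base_patch16_224', pretrained=False)
hook = apply_vitdet_arch(vit, VitDetArgs(window_size=7, num_summary_tokens=1))
# Subsequent forward passes alternate windowed and global attention via the registered hooks.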
tim/models/t2i/tim_model.py ADDED
@@ -0,0 +1,493 @@
1
+ # This source code is licensed under the license found in the
2
+ # LICENSE file in the root directory of this source tree.
3
+ # --------------------------------------------------------
4
+ # References:
5
+ # GLIDE: https://github.com/openai/glide-text2im
6
+ # MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py
7
+ # --------------------------------------------------------
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ import numpy as np
13
+ import math
14
+ from timm.layers.mlp import SwiGLU, Mlp
15
+ from timm.models.vision_transformer import PatchEmbed  # the local Attention class below replaces timm's
16
+ from tim.models.utils.funcs import build_mlp, modulate, get_parameter_dtype
17
+ from tim.models.utils.rope import VisionRotaryEmbedding, rotate_half
18
+ from flash_attn import flash_attn_func
19
+
20
+
21
+ #################################################################################
22
+ # Embedding Layers for Timesteps and Class Labels #
23
+ #################################################################################
24
+ class TimestepEmbedder(nn.Module):
25
+ """
26
+ Embeds scalar timesteps into vector representations.
27
+ """
28
+ def __init__(self, hidden_size, frequency_embedding_size=256):
29
+ super().__init__()
30
+ self.mlp = nn.Sequential(
31
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True),
32
+ nn.SiLU(),
33
+ nn.Linear(hidden_size, hidden_size, bias=True),
34
+ )
35
+ self.frequency_embedding_size = frequency_embedding_size
36
+
37
+ @staticmethod
38
+ def positional_embedding(t, dim, max_period=10000):
39
+ """
40
+ Create sinusoidal timestep embeddings.
41
+ :param t: a 1-D Tensor of N indices, one per batch element.
42
+ These may be fractional.
43
+ :param dim: the dimension of the output.
44
+ :param max_period: controls the minimum frequency of the embeddings.
45
+ :return: an (N, D) Tensor of positional embeddings.
46
+ """
47
+ # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
48
+ half = dim // 2
49
+ freqs = torch.exp(
50
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
51
+ ).to(device=t.device)
52
+ args = t[:, None].float() * freqs[None]
53
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
54
+ if dim % 2:
55
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
56
+ return embedding
57
+
58
+ def forward(self, t):
59
+ t_freq = self.positional_embedding(t, dim=self.frequency_embedding_size).to(t.dtype)
61
+ t_emb = self.mlp(t_freq)
62
+ return t_emb
63
+
64
+
65
+ class CaptionEmbedder(nn.Module):
66
+ """
67
+ Embeds caption features into vector representations for conditioning the transformer.
68
+ """
69
+ def __init__(self, cap_feat_dim, hidden_size):
70
+ super().__init__()
71
+ self.norm = nn.LayerNorm(cap_feat_dim)
72
+ self.mlp = SwiGLU(in_features=cap_feat_dim, hidden_features=hidden_size*4, out_features=hidden_size)
73
+
74
+
75
+ def forward(self, cap_feats):
76
+ '''
77
+ cfg is also essential in text-to-image generation
78
+ '''
79
+ cap_feats = self.mlp(self.norm(cap_feats))
80
+ return cap_feats
81
+
82
+
83
+
84
+ #################################################################################
85
+ # Attention Block #
86
+ #################################################################################
87
+
88
+ class Attention(nn.Module):
89
+ def __init__(
90
+ self,
91
+ dim: int,
92
+ num_heads: int = 8,
93
+ qkv_bias: bool = False,
94
+ qk_norm: bool = False,
95
+ attn_drop: float = 0.,
96
+ proj_drop: float = 0.,
97
+ norm_layer: nn.Module = nn.LayerNorm,
98
+ distance_aware: bool = False,
99
+ ) -> None:
100
+ super().__init__()
101
+ assert dim % num_heads == 0, 'dim should be divisible by num_heads'
102
+ self.num_heads = num_heads
103
+ self.head_dim = dim // num_heads
104
+ self.scale = self.head_dim ** -0.5
105
+
106
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
107
+ self.distance_aware = distance_aware
108
+ if distance_aware:
109
+ self.qkv_d = nn.Linear(dim, dim * 3, bias=False)
110
+ self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
111
+ self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
112
+ self.attn_drop = nn.Dropout(attn_drop)
113
+ self.proj = nn.Linear(dim, dim)
114
+ self.proj_drop = nn.Dropout(proj_drop)
115
+
116
+ def forward(self, x: torch.Tensor, freqs_cos, freqs_sin, attn_type='fused_attn', delta_t=None) -> torch.Tensor:
117
+ B, N, C = x.shape
118
+ if self.distance_aware:
119
+ qkv = self.qkv(x) + self.qkv_d(delta_t)
120
+ else:
121
+ qkv = self.qkv(x)
122
+ if attn_type == 'flash_attn': # q, k, v: (B, N, n_head, d_head)
123
+ qkv = qkv.reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 1, 3, 4)
124
+ else: # q, k, v: (B, n_head, N, d_head)
125
+ qkv = qkv.reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
126
+ ori_dtype = qkv.dtype
127
+ q, k, v = qkv.unbind(0)
128
+ q, k = self.q_norm(q), self.k_norm(k)
129
+
130
+ q = q * freqs_cos + rotate_half(q) * freqs_sin
131
+ k = k * freqs_cos + rotate_half(k) * freqs_sin
132
+ q, k = q.to(ori_dtype), k.to(ori_dtype)
133
+
134
+ if attn_type == 'flash_attn':
135
+ x = flash_attn_func(
136
+ q, k, v,
137
+ dropout_p=self.attn_drop.p if self.training else 0.,
138
+ )
139
+ x = x.reshape(B, N, C)
140
+ elif attn_type == 'fused_attn':
141
+ x = F.scaled_dot_product_attention(
142
+ q, k, v,
143
+ dropout_p=self.attn_drop.p if self.training else 0.,
144
+ )
145
+ x = x.transpose(1, 2).reshape(B, N, C)
146
+ else:
147
+ q = q * self.scale
148
+ attn = q @ k.transpose(-2, -1)
149
+ attn = attn.softmax(dim=-1)
150
+ attn = self.attn_drop(attn)
151
+ x = attn @ v
152
+ x = x.transpose(1, 2).reshape(B, N, C)
153
+
154
+ x = self.proj(x)
155
+ x = self.proj_drop(x)
156
+ return x
157
+
158
+
159
+
160
+
161
+
162
+
163
+ #################################################################################
164
+ # Cross Attention Block #
165
+ #################################################################################
166
+
167
+ class CrossAttention(nn.Module):
168
+ def __init__(
169
+ self,
170
+ dim: int,
171
+ num_heads: int = 8,
172
+ qkv_bias: bool = False,
173
+ qk_norm: bool = False,
174
+ attn_drop: float = 0.,
175
+ proj_drop: float = 0.,
176
+ norm_layer: nn.Module = nn.LayerNorm,
177
+ ) -> None:
178
+ super().__init__()
179
+ assert dim % num_heads == 0, 'dim should be divisible by num_heads'
180
+ self.num_heads = num_heads
181
+ self.head_dim = dim // num_heads
182
+ self.scale = self.head_dim ** -0.5
183
+
184
+ self.q = nn.Linear(dim, dim, bias=qkv_bias)
185
+ self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
186
+ self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
187
+ self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
188
+ self.attn_drop = nn.Dropout(attn_drop)
189
+ self.proj = nn.Linear(dim, dim)
190
+ self.proj_drop = nn.Dropout(proj_drop)
191
+
192
+ def forward(self, x: torch.Tensor, y: torch.Tensor, freqs_cos, freqs_sin, attn_type='fused_attn') -> torch.Tensor:
193
+ B, N, C = x.shape
194
+ _, M, _ = y.shape
195
+ if attn_type == 'flash_attn': # q, k, v: (B, N, n_head, d_head)
196
+ q = self.q(x).reshape(B, N, self.num_heads, self.head_dim)
197
+ kv = self.kv(y).reshape(B, M, 2, self.num_heads, self.head_dim).permute(2, 0, 1, 3, 4)
198
+ else: # q, k, v: (B, n_head, N, d_head)
199
+ q = self.q(x).reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
200
+ kv = self.kv(y).reshape(B, M, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
201
+ ori_dtype = q.dtype
202
+ k, v = kv.unbind(0)
203
+ q, k = self.q_norm(q), self.k_norm(k)
204
+ q = q * freqs_cos + rotate_half(q) * freqs_sin
205
+ q, k = q.to(ori_dtype), k.to(ori_dtype)
206
+
207
+ if attn_type == 'flash_attn':
208
+ x = flash_attn_func(
209
+ q, k, v,
210
+ dropout_p=self.attn_drop.p if self.training else 0.,
211
+ )
212
+ x = x.reshape(B, N, C)
213
+ elif attn_type == 'fused_attn':
214
+ x = F.scaled_dot_product_attention(
215
+ q, k, v,
216
+ dropout_p=self.attn_drop.p if self.training else 0.,
217
+ )
218
+ x = x.transpose(1, 2).reshape(B, N, C)
219
+ else:
220
+ q = q * self.scale
221
+ attn = q @ k.transpose(-2, -1)
222
+ attn = attn.softmax(dim=-1)
223
+ attn = self.attn_drop(attn)
224
+ x = attn @ v
225
+ x = x.transpose(1, 2).reshape(B, N, C)
226
+
227
+ x = self.proj(x)
228
+ x = self.proj_drop(x)
229
+ return x
230
+
231
+
232
+
233
+
234
+
235
+
236
+ #################################################################################
237
+ # Core TiM Model #
238
+ #################################################################################
239
+
240
+ class TiMBlock(nn.Module):
241
+ """
242
+ A TiM block with adaptive layer norm zero (adaLN-Zero) conditioning.
243
+ """
244
+ def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs):
245
+ super().__init__()
246
+ self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
247
+ distance_aware = block_kwargs.get('distance_aware', False)
248
+ self.attn = Attention(
249
+ hidden_size, num_heads=num_heads, qkv_bias=True, qk_norm=block_kwargs["qk_norm"],
250
+ distance_aware=distance_aware
251
+ )
252
+ self.norm2_i = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
253
+ self.norm2_t = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
254
+ self.cross_attn = CrossAttention(
255
+ hidden_size, num_heads=num_heads, qkv_bias=True, qk_norm=block_kwargs["qk_norm"]
256
+ )
257
+ self.norm3 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
258
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
259
+ self.mlp = SwiGLU(
260
+ in_features=hidden_size, hidden_features=(mlp_hidden_dim*2)//3, bias=True
261
+ )
262
+ if block_kwargs.get('lora_hidden_size', None) != None:
263
+ lora_hidden_size = block_kwargs['lora_hidden_size']
264
+ else:
265
+ lora_hidden_size = (hidden_size//4)*3
266
+ self.adaLN_modulation = SwiGLU(
267
+ in_features=hidden_size, hidden_features=lora_hidden_size, out_features=9*hidden_size, bias=True
268
+ )
269
+
270
+
271
+
272
+ def forward(self, x, y, c, freqs_cos, freqs_sin, attn_type, delta_t=None):
273
+ (
274
+ shift_msa, scale_msa, gate_msa,
275
+ shift_msc, scale_msc, gate_msc,
276
+ shift_mlp, scale_mlp, gate_mlp
277
+ ) = self.adaLN_modulation(c).chunk(9, dim=-1)
278
+ x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), freqs_cos, freqs_sin, attn_type, delta_t)
279
+ x = x + gate_msc * self.cross_attn(modulate(self.norm2_i(x), shift_msc, scale_msc), self.norm2_t(y), freqs_cos, freqs_sin, attn_type)
280
+ x = x + gate_mlp * self.mlp(modulate(self.norm3(x), shift_mlp, scale_mlp))
281
+
282
+ return x
283
+
284
+
285
+ class FinalLayer(nn.Module):
286
+ """
287
+ The final layer of TiM.
288
+ """
289
+ def __init__(self, hidden_size, patch_size, out_channels):
290
+ super().__init__()
291
+ self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
292
+ self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
293
+ self.adaLN_modulation = SwiGLU(
294
+ in_features=hidden_size, hidden_features=hidden_size//2, out_features=2*hidden_size, bias=True
295
+ )
296
+
297
+
298
+ def forward(self, x, c):
299
+ shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
300
+ x = modulate(self.norm_final(x), shift, scale)
301
+ x = self.linear(x)
302
+
303
+ return x
304
+
305
+
306
+ class TiM(nn.Module):
307
+ """
308
+ Diffusion model with a Transformer backbone.
309
+ """
310
+ def __init__(
311
+ self,
312
+ input_size=32,
313
+ patch_size=2,
314
+ in_channels=4,
315
+ hidden_size=1152,
316
+ encoder_depth=8,
317
+ depth=28,
318
+ num_heads=16,
319
+ mlp_ratio=4.0,
320
+ cap_feat_dim=2048,
321
+ z_dim=768,
322
+ projector_dim=2048,
323
+ use_checkpoint: bool = False,
324
+ new_condition: str = 't-r',
325
+ use_new_embed: bool = False,
326
+ **block_kwargs # qk_norm
327
+ ):
328
+ super().__init__()
329
+ self.in_channels = in_channels
330
+ self.out_channels = in_channels
331
+ self.patch_size = patch_size
332
+ self.num_heads = num_heads
333
+ self.cap_feat_dim = cap_feat_dim
334
+ self.encoder_depth = encoder_depth
335
+ self.use_checkpoint = use_checkpoint
336
+ self.new_condition = new_condition
337
+ self.use_new_embed = use_new_embed
338
+
339
+ self.x_embedder = PatchEmbed(
340
+ input_size, patch_size, in_channels, hidden_size, bias=True, strict_img_size=False
341
+ )
342
+ self.t_embedder = TimestepEmbedder(hidden_size) # timestep embedding type
343
+ if use_new_embed:
344
+ self.delta_embedder = TimestepEmbedder(hidden_size)
345
+ self.y_embedder = CaptionEmbedder(cap_feat_dim, hidden_size)
346
+ # 2D rotary position embedding (RoPE) over the patch grid:
347
+ self.rope = VisionRotaryEmbedding(head_dim=hidden_size//num_heads)
348
+
349
+ self.blocks = nn.ModuleList([
350
+ TiMBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, **block_kwargs) for _ in range(depth)
351
+ ])
352
+ self.projector = build_mlp(hidden_size, projector_dim, z_dim)
353
+ self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
354
+ self.initialize_weights()
355
+
356
+ def initialize_weights(self):
357
+ # Initialize transformer layers:
358
+ def _basic_init(module):
359
+ if isinstance(module, nn.Linear):
360
+ torch.nn.init.xavier_uniform_(module.weight)
361
+ if module.bias is not None:
362
+ nn.init.constant_(module.bias, 0)
363
+ self.apply(_basic_init)
364
+
365
+ # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
366
+ w = self.x_embedder.proj.weight.data
367
+ nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
368
+ nn.init.constant_(self.x_embedder.proj.bias, 0)
369
+
370
+ # Initialize label embedding table:
371
+ nn.init.normal_(self.y_embedder.mlp.fc1_g.weight, std=0.02)
372
+ nn.init.normal_(self.y_embedder.mlp.fc1_x.weight, std=0.02)
373
+ nn.init.normal_(self.y_embedder.mlp.fc2.weight, std=0.02)
374
+
375
+ # Initialize timestep embedding MLP:
376
+ nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
377
+ nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
378
+
379
+ # Zero-out adaLN modulation layers in TiM blocks:
380
+ for block in self.blocks:
381
+ nn.init.constant_(block.adaLN_modulation.fc2.weight, 0)
382
+ nn.init.constant_(block.adaLN_modulation.fc2.bias, 0)
383
+
384
+
385
+ # Zero-out output layers:
386
+ nn.init.constant_(self.final_layer.adaLN_modulation.fc2.weight, 0)
387
+ nn.init.constant_(self.final_layer.adaLN_modulation.fc2.bias, 0)
388
+ nn.init.constant_(self.final_layer.linear.weight, 0)
389
+ nn.init.constant_(self.final_layer.linear.bias, 0)
390
+
391
+ def unpatchify(self, x, H, W):
392
+ """
393
+ x: (N, T, patch_size**2 * C)
394
+ imgs: (N, C, H, W)
395
+ """
396
+ c = self.out_channels
397
+ p = self.patch_size
398
+ h, w = int(H/p), int(W/p)
399
+
400
+
401
+ x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
402
+ x = torch.einsum('nhwpqc->nchpwq', x)
403
+ imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
404
+ return imgs
405
+
406
+ def get_rope(self, h, w, attn_type):
407
+ grid_h = torch.arange(h)
408
+ grid_w = torch.arange(w)
409
+ grid = torch.meshgrid(grid_h, grid_w, indexing='xy')
410
+ grid = torch.stack(grid, dim=0).reshape(2, -1).unsqueeze(0)
411
+ freqs_cos, freqs_sin = self.rope.get_cached_2d_rope_from_grid(grid)
412
+ if attn_type == 'flash_attn': # (1, N, 1, d_head)
413
+ return freqs_cos.unsqueeze(2), freqs_sin.unsqueeze(2)
414
+ else: # (1, 1, N, d_head)
415
+ return freqs_cos.unsqueeze(1), freqs_sin.unsqueeze(1)
416
+
417
+
418
+ def forward(self, x, t, r, y, attn_type='flash_attn', return_zs=False, jvp=False):
419
+ """
420
+ Forward pass of TiM.
421
+ x: (B, C, H, W) tensor of spatial inputs (images or latent representations of images)
422
+ t: (B,) tensor of diffusion timesteps
423
+ r: (B,) tensor of second timesteps; the model is conditioned on the pair (t, r), e.g. via t - r
+ y: (B, M, cap_feat_dim) tensor of caption features
424
+ """
425
+ B, C, H, W = x.shape
426
+ x = self.x_embedder(x) # (B, T, D), where T = H * W / patch_size ** 2
427
+
428
+ # timestep and class embedding
429
+ t_embed = self.t_embedder(t).unsqueeze(1) # (B, 1, D)
430
+ delta_embed = self.get_delta_embed(t, r).unsqueeze(1) # (B, 1, D)
431
+ y = self.y_embedder(y) # (B, M, D)
432
+ c = t_embed + delta_embed # (B, 1, D)
433
+
434
+
435
+ freqs_cos, freqs_sin = self.get_rope(
436
+ int(H/self.patch_size), int(W/self.patch_size), attn_type
437
+ )
438
+
439
+ for i, block in enumerate(self.blocks):
440
+ if not self.use_checkpoint or jvp:
441
+ x = block(x, y, c, freqs_cos, freqs_sin, attn_type, delta_embed) # (B, N, D)
442
+ else:
443
+ x = torch.utils.checkpoint.checkpoint(
444
+ self.ckpt_wrapper(block), x, y, c, freqs_cos, freqs_sin, attn_type, delta_embed
445
+ )
446
+ if (i + 1) == self.encoder_depth:
447
+ h_proj = self.projector(x)
448
+ x = self.final_layer(x, c) # (B, N, patch_size ** 2 * out_channels)
449
+ x = self.unpatchify(x, H, W) # (b, out_channels, H, W)
450
+
451
+ if return_zs:
452
+ return x, h_proj
453
+ else:
454
+ return x
455
+
456
+ def get_delta_embed(self, t, r):
457
+ if self.use_new_embed:
458
+ delta_embedder = self.delta_embedder
459
+ else:
460
+ delta_embedder = self.t_embedder
461
+ if self.new_condition == 't-r':
462
+ delta_embed = delta_embedder(t-r)
463
+ elif self.new_condition == 'r':
464
+ delta_embed = delta_embedder(r)
465
+ elif self.new_condition == 't,r':
466
+ delta_embed = self.t_embedder(t) + delta_embedder(r)
467
+ elif self.new_condition == 't,t-r':
468
+ delta_embed = self.t_embedder(t) + delta_embedder(t-r)
469
+ elif self.new_condition == 'r,t-r':
470
+ delta_embed = self.t_embedder(r) + delta_embedder(t-r)
471
+ elif self.new_condition == 't,r,t-r':
472
+ delta_embed = self.t_embedder(t) + self.t_embedder(r) + delta_embedder(t-r)
473
+ else:
474
+ raise NotImplementedError
475
+ return delta_embed
476
+
477
+ def ckpt_wrapper(self, module):
478
+ def ckpt_forward(*inputs):
479
+ outputs = module(*inputs)
480
+ return outputs
481
+ return ckpt_forward
482
+
483
+
484
+ @property
485
+ def dtype(self) -> torch.dtype:
486
+ """
487
+ `torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
488
+ """
489
+ return get_parameter_dtype(self)
490
+
491
+
492
+
493
+
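A small sanity check of the timestep embedding defined above (the full TiM forward additionally needs caption features and the RoPE utilities, and note that this module imports flash_attn at the top, so that package must be installed even for the fused or naive attention paths; hidden_size=256 is an arbitrary illustration):

import torch
# from tim.models.t2i.tim_model import TimestepEmbedder

t_embedder = TimestepEmbedder(hidden_size=256)
t = torch.rand(8)              # scalar diffusion timesteps
print(t_embedder(t).shape)     # torch.Size([8, 256])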
tim/models/utils/funcs.py ADDED
@@ -0,0 +1,53 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from torch import Tensor
5
+ from typing import List, Tuple
6
+ from itertools import chain
7
+
8
+
9
+
10
+ def expand_t_like_x(t, x):
11
+ """Function to reshape time t to broadcastable dimension of x
12
+ Args:
13
+ t: [batch_dim,], time vector
14
+ x: [batch_dim,...], data point
15
+ """
16
+ dims = [1] * (len(x.size()) - 1)
17
+ t = t.view(t.size(0), *dims)
18
+ return t
19
+
20
+
21
+ def build_mlp(hidden_size, projector_dim, z_dim):
22
+ return nn.Sequential(
23
+ nn.Linear(hidden_size, projector_dim),
24
+ nn.SiLU(),
25
+ nn.Linear(projector_dim, projector_dim),
26
+ nn.SiLU(),
27
+ nn.Linear(projector_dim, z_dim),
28
+ )
29
+
30
+ def modulate(x, shift, scale):
31
+ return x * (1 + scale) + shift
32
+
33
+
34
+ def get_parameter_dtype(parameter: torch.nn.Module):
35
+ try:
36
+ params = tuple(parameter.parameters())
37
+ if len(params) > 0:
38
+ return params[0].dtype
39
+
40
+ buffers = tuple(parameter.buffers())
41
+ if len(buffers) > 0:
42
+ return buffers[0].dtype
43
+
44
+ except StopIteration:
45
+ # For torch.nn.DataParallel compatibility in PyTorch 1.5
46
+
47
+ def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
48
+ tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
49
+ return tuples
50
+
51
+ gen = parameter._named_members(get_members_fn=find_tensor_attributes)
52
+ first_tuple = next(gen)
53
+ return first_tuple[1].dtype
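For reference, modulate() above implements the standard adaLN shift/scale transform x * (1 + scale) + shift; a tiny numeric check:

import torch
# from tim.models.utils.funcs import modulate

x = torch.ones(1, 4, 8)
shift = torch.zeros(1, 1, 8)
scale = torch.full((1, 1, 8), 0.5)
print(modulate(x, shift, scale).unique())  # tensor([1.5000])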
tim/models/utils/norms.py ADDED
@@ -0,0 +1,403 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import math
8
+
9
+ from functools import partial
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+
14
+ import triton
15
+ import triton.language as tl
16
+ import torch.nn.functional as F
17
+
18
+
19
+ def create_norm(norm_type: str, dim: int, eps: float = 1e-6):
20
+ """
21
+ Creates the specified normalization layer based on the norm_type.
22
+
23
+ Args:
24
+ norm_type (str): The type of normalization layer to create.
25
+ Supported types: layernorm, layernorm_32, np_layernorm, np_layernorm_32, rmsnorm, np_rmsnorm, fused_rmsnorm, fused_rmsnorm_32.
26
+ dim (int): The dimension of the normalization layer.
27
+ eps (float, optional): The epsilon value for numerical stability. Defaults to 1e-6.
28
+
29
+ Returns:
30
+ The created normalization layer.
31
+
32
+ Note:
+ Unknown or empty norm_type values fall back to nn.Identity() rather than raising.
34
+ """
35
+ if norm_type is None or norm_type == "":
36
+ return nn.Identity()
37
+ norm_type = norm_type.lower() # Normalize to lowercase
38
+
39
+ if norm_type == "layernorm":
40
+ return nn.LayerNorm(dim, eps=eps, bias=False)
41
+ elif norm_type == "np_layernorm":
42
+ return nn.LayerNorm(dim, eps=eps, elementwise_affine=False, bias=False)
43
+ elif norm_type == "np_layernorm_32":
44
+ return FP32_Layernorm(dim, eps=eps, elementwise_affine=False, bias=True)
45
+ elif norm_type == "layernorm_32":
46
+ return FP32_Layernorm(dim, eps=eps, bias=True)
47
+ elif norm_type == "rmsnorm":
48
+ return RMSNorm(dim, include_weight=True, eps=eps)
49
+ elif norm_type == "np_rmsnorm":
50
+ return RMSNorm(dim, include_weight=False, eps=1e-6)
51
+ elif norm_type == "fused_rmsnorm":
52
+ return FusedRMSNorm(dim, eps=1/65536)
53
+ elif norm_type == "fused_rmsnorm_32":
54
+ return FusedRMSNorm32(dim, eps=1e-6)
55
+ elif norm_type == 'none':
56
+ return nn.Identity()
57
+ else:
58
+ return nn.Identity()
59
+
60
+ class FP32_Layernorm(nn.LayerNorm):
61
+ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
62
+ origin_dtype = inputs.dtype
63
+ if self.bias is None and self.weight is None:
64
+ return F.layer_norm(
65
+ input=inputs.float(),
66
+ normalized_shape=self.normalized_shape,
67
+ eps=self.eps
68
+ ).to(origin_dtype)
69
+ elif self.bias is None:
70
+ return F.layer_norm(
71
+ input=inputs.float(),
72
+ normalized_shape=self.normalized_shape,
73
+ weight=self.weight.float(),
74
+ eps=self.eps
75
+ ).to(origin_dtype)
76
+ else:
77
+ return F.layer_norm(
78
+ input=inputs.float(),
79
+ normalized_shape=self.normalized_shape,
80
+ weight=self.weight.float(),
81
+ bias=self.bias.float(),
82
+ eps=self.eps
83
+ ).to(origin_dtype)
84
+
85
+ class FusedRMSNorm(nn.Module):
86
+ """Fused RMS Norm, wraps a fused Triton Kernel"""
87
+
88
+ def __init__(
89
+ self,
90
+ dim: int,
91
+ eps: float = 1e-6,
92
+ ):
93
+ super().__init__()
94
+ self.eps = eps
95
+ self.weight = nn.Parameter(torch.ones(dim))
96
+ self.fused_rms_norm_fn = fused_rms_norm_fn
97
+
98
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
99
+ """leverages Triton Fused RMS Norm kernel"""
100
+ return self.fused_rms_norm_fn(
101
+ x,
102
+ self.weight,
103
+ eps=self.eps,
104
+ )
105
+
106
+ def reset_parameters(self):
107
+ torch.nn.init.ones_(self.weight) # type: ignore
108
+
109
+ class FusedRMSNorm32(nn.Module):
110
+ """Fused RMS Norm, wraps a fused Triton Kernel"""
111
+
112
+ def __init__(
113
+ self,
114
+ dim: int,
115
+ eps: float = 1e-6,
116
+ ):
117
+ super().__init__()
118
+ self.eps = eps
119
+ self.weight = nn.Parameter(torch.ones(dim))
120
+ self.fused_rms_norm_fn = fused_rms_norm_fn
121
+
122
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
123
+ """leverages Triton Fused RMS Norm kernel"""
124
+ dtype = x.dtype
125
+ return self.fused_rms_norm_fn(
126
+ x.to(torch.float32),
127
+ self.weight,
128
+ eps=self.eps,
129
+ ).to(dtype)
130
+
131
+ def reset_parameters(self):
132
+ torch.nn.init.ones_(self.weight) # type: ignore
133
+
134
+ class RMSNorm(nn.Module):
135
+ def __init__(self, dim: int, include_weight: bool = True, eps: float = 1e-6, **block_kwargs):
136
+ """
137
+ Initialize the RMSNorm normalization layer.
138
+
139
+ Args:
140
+ dim (int): The dimension of the input tensor.
141
+ include_weight: bool: Whether include weight in the normalization
142
+ eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
143
+
144
+ Attributes:
145
+ eps (float): A small value added to the denominator for numerical stability.
146
+ weight (nn.Parameter): Learnable scaling parameter.
147
+
148
+ """
149
+ super().__init__()
150
+ self.eps = eps
151
+ if include_weight:
152
+ self.weight = nn.Parameter(torch.ones(dim))
153
+ else:
154
+ self.weight = None
155
+
156
+ def _norm(self, x):
157
+ """
158
+ Apply the RMSNorm normalization to the input tensor.
159
+
160
+ Args:
161
+ x (torch.Tensor): The input tensor.
162
+
163
+ Returns:
164
+ torch.Tensor: The normalized tensor.
165
+
166
+ """
167
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
168
+
169
+ def forward(self, x):
170
+ """
171
+ Forward pass through the RMSNorm layer.
172
+
173
+ Args:
174
+ x (torch.Tensor): The input tensor.
175
+
176
+ Returns:
177
+ torch.Tensor: The output tensor after applying RMSNorm.
178
+
179
+ """
180
+ output = self._norm(x.float()).type_as(x)
181
+ if self.weight is None:
182
+ return output
183
+ else:
184
+ return output * self.weight
185
+
186
+
187
+
188
+ # FusedRMSNorm in Triton
189
+
190
+ # Credit
191
+ # Tri Dao's Triton LayerNorm: https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/triton/layer_norm.py
192
+ # Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
193
+
194
+
195
+ @triton.autotune(
196
+ configs=[
197
+ triton.Config({}, num_warps=1),
198
+ triton.Config({}, num_warps=2),
199
+ triton.Config({}, num_warps=4),
200
+ triton.Config({}, num_warps=8),
201
+ triton.Config({}, num_warps=16),
202
+ triton.Config({}, num_warps=32),
203
+ ],
204
+ key=["N"],
205
+ )
206
+ @triton.jit
207
+ def _rms_norm_fwd_kernel(
208
+ X,
209
+ stride_x,
210
+ Y,
211
+ stride_y,
212
+ W,
213
+ Rstd,
214
+ eps,
215
+ M, # num rows
216
+ N, # num cols
217
+ block_N: tl.constexpr,
218
+ ):
219
+ row = tl.program_id(0)
220
+ cols = tl.arange(0, block_N)
221
+
222
+ # Load input data and weights
223
+ mask = cols < N
224
+ x = tl.load(X + row * stride_x + cols, mask=mask, other=0.0).to(tl.float32)
225
+ w = tl.load(W + cols, mask=mask, other=0.0).to(tl.float32)
226
+
227
+ # Compute mean and variance
228
+ xbar = tl.where(cols < N, x, 0.0)
229
+ var = tl.sum(xbar * xbar, axis=0) / N
230
+ rstd = 1 / tl.sqrt(var + eps)
231
+
232
+ # Store the reciprocal standard deviation
233
+ tl.store(Rstd + row, rstd)
234
+
235
+ # Normalize and apply linear transformation
236
+ x_hat = x * rstd
237
+ y = x_hat * w
238
+
239
+ # Write output
240
+ tl.store(Y + row * stride_y + cols, y, mask=mask)
241
+
242
+
243
+ @triton.autotune(
244
+ configs=[
245
+ triton.Config({}, num_warps=1),
246
+ triton.Config({}, num_warps=2),
247
+ triton.Config({}, num_warps=4),
248
+ triton.Config({}, num_warps=8),
249
+ triton.Config({}, num_warps=16),
250
+ triton.Config({}, num_warps=32),
251
+ ],
252
+ key=["N"],
253
+ )
254
+ @triton.jit
255
+ def _rms_norm_bwd_kernel_sm(
256
+ X,
257
+ stride_x,
258
+ W,
259
+ DY,
260
+ stride_dy,
261
+ DX,
262
+ stride_dx,
263
+ Rstd,
264
+ DW,
265
+ eps,
266
+ M, # num rows
267
+ N, # num cols
268
+ rows_per_program,
269
+ block_N: tl.constexpr,
270
+ ):
271
+ row_block_id = tl.program_id(0)
272
+ row_start = row_block_id * rows_per_program
273
+ cols = tl.arange(0, block_N)
274
+ mask = cols < N
275
+
276
+ # Load weights
277
+ w = tl.load(W + cols, mask=mask, other=0.0).to(tl.float32)
278
+
279
+ # Accumulate gradients for weights
280
+ dw = tl.zeros((block_N,), dtype=tl.float32)
281
+
282
+ row_end = min(row_start + rows_per_program, M)
283
+ for row in range(row_start, row_end):
284
+ # Load input, output gradient, and reciprocal standard deviation
285
+ x = tl.load(X + row * stride_x + cols, mask=mask, other=0.0).to(tl.float32)
286
+ dy = tl.load(DY + row * stride_dy + cols, mask=mask, other=0.0).to(tl.float32)
287
+ rstd = tl.load(Rstd + row)
288
+
289
+ # Compute normalized input and gradients
290
+ x_hat = x * rstd
291
+ wdy = w * dy
292
+ dw += dy * x_hat
293
+ c1 = tl.sum(x_hat * wdy, axis=0) / N
294
+ dx = (wdy - x_hat * c1) * rstd
295
+
296
+ # Store input gradient
297
+ tl.store(DX + row * stride_dx + cols, dx, mask=mask)
298
+
299
+ # Store weight gradients
300
+ tl.store(DW + row_block_id * N + cols, dw, mask=mask)
301
+
302
+
303
+ class TritonFusedRMSNorm(torch.autograd.Function):
304
+ @staticmethod
305
+ def forward(ctx, x, weight, eps):
306
+ x_shape_start = x.shape
307
+
308
+ # Flatten input
309
+ x = x.view(-1, x.shape[-1])
310
+ if x.stride(-1) != 1:
311
+ x = x.contiguous()
312
+ if weight.stride(-1) != 1:
313
+ weight = weight.contiguous()
314
+
315
+ M, N = x.shape
316
+ y = torch.empty_like(x)
317
+ rstd = torch.empty((M,), dtype=torch.float32, device=x.device)
318
+
319
+ max_size = 65536 // x.element_size()
320
+ block_N = min(max_size, triton.next_power_of_2(N))
321
+
322
+ if N > block_N:
323
+ raise ValueError(f"N {N} must be <= {block_N=}")
324
+
325
+ grid = lambda meta: (M,)
326
+ _rms_norm_fwd_kernel[grid](
327
+ x,
328
+ x.stride(0),
329
+ y,
330
+ y.stride(0),
331
+ weight,
332
+ rstd,
333
+ eps,
334
+ M,
335
+ N,
336
+ block_N,
337
+ )
338
+
339
+ ctx.eps = eps
340
+ ctx.save_for_backward(x, weight, rstd)
341
+ ctx.x_shape_start = x_shape_start
342
+
343
+ y = y.reshape(x_shape_start)
344
+ return y
345
+
346
+ @staticmethod
347
+ def backward(ctx, dy):
348
+ x, weight, rstd = ctx.saved_tensors
349
+ eps = ctx.eps
350
+ x_shape_start = ctx.x_shape_start
351
+
352
+ # Flatten input and output gradients
353
+ dy = dy.view(-1, dy.shape[-1])
354
+ if dy.stride(-1) != 1:
355
+ dy = dy.contiguous()
356
+
357
+ M, N = dy.shape
358
+ dx = torch.empty_like(x)
359
+ dw = torch.empty_like(weight)
360
+
361
+ sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count
362
+ _dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device)
363
+
364
+ max_size = 65536 // x.element_size()
365
+ block_N = min(max_size, triton.next_power_of_2(N))
366
+ rows_per_sm = math.ceil(M / sm_count)
367
+
368
+ if N > block_N:
369
+ raise ValueError(f"N {N} must be <= {block_N=}")
370
+
371
+ grid = lambda meta: (sm_count,)
372
+ _rms_norm_bwd_kernel_sm[grid](
373
+ x,
374
+ x.stride(0),
375
+ weight,
376
+ dy,
377
+ dy.stride(0),
378
+ dx,
379
+ dx.stride(0),
380
+ rstd,
381
+ _dw,
382
+ eps,
383
+ M,
384
+ N,
385
+ rows_per_sm,
386
+ block_N,
387
+ )
388
+ dw = _dw.sum(0).to(weight.dtype)
389
+ dx = dx.view(x_shape_start)
390
+ return dx, dw, None
391
+
392
+
393
+ # expose the fused RMS norm as a function
394
+ def fused_rms_norm_fn(
395
+ x,
396
+ weight,
397
+ eps=1e-6,
398
+ ):
399
+ return TritonFusedRMSNorm.apply(
400
+ x,
401
+ weight,
402
+ eps,
403
+ )
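A minimal usage sketch for the norm layers above (a sketch only: it assumes a CUDA device with Triton installed, and the import path simply mirrors this repo's tim/models/utils/norms.py; the all-ones weight is just the default initialization):

import torch
from tim.models.utils.norms import RMSNorm, fused_rms_norm_fn

dim = 1024
x = torch.randn(8, 256, dim, device="cuda")

# Reference path: the plain PyTorch RMSNorm with its default all-ones weight.
ref = RMSNorm(dim).to("cuda")
y_ref = ref(x)

# Fused path: the Triton kernel wrapped by TritonFusedRMSNorm.
y_fused = fused_rms_norm_fn(x, ref.weight, eps=1e-6)

# Both paths compute x * rsqrt(mean(x^2) + eps) * weight,
# so they should agree up to floating-point tolerance.
print(torch.allclose(y_ref, y_fused, atol=1e-5, rtol=1e-5))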
tim/models/utils/rope.py ADDED
@@ -0,0 +1,305 @@
1
+ # --------------------------------------------------------
2
+ # FiT: A Flexible Vision Transformer for Image Generation
3
+ #
4
+ # Based on the following repositories:
5
+ # https://github.com/lucidrains/rotary-embedding-torch
6
+ # https://github.com/jquesnelle/yarn/blob/HEAD/scaled_rope
7
+ # https://colab.research.google.com/drive/1VI2nhlyKvd5cw4-zHvAIk00cAVj2lCCC#scrollTo=b80b3f37
8
+ # --------------------------------------------------------
9
+
10
+ import math
11
+ from math import pi
12
+ from typing import Optional, Any, Union, Tuple
13
+ import torch
14
+ from torch import nn
15
+
16
+ from einops import rearrange, repeat
17
+ from functools import lru_cache
18
+
19
+ #################################################################################
20
+ # NTK Operations #
21
+ #################################################################################
22
+
23
+ def find_correction_factor(num_rotations, dim, base=10000, max_position_embeddings=2048):
24
+ return (dim * math.log(max_position_embeddings/(num_rotations * 2 * math.pi)))/(2 * math.log(base)) #Inverse dim formula to find number of rotations
25
+
26
+ def find_correction_range(low_rot, high_rot, dim, base=10000, max_position_embeddings=2048):
27
+ low = math.floor(find_correction_factor(low_rot, dim, base, max_position_embeddings))
28
+ high = math.ceil(find_correction_factor(high_rot, dim, base, max_position_embeddings))
29
+ return max(low, 0), min(high, dim-1) #Clamp values just in case
30
+
31
+ def linear_ramp_mask(min, max, dim):
32
+ if min == max:
33
+ max += 0.001 #Prevent singularity
34
+
35
+ linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
36
+ ramp_func = torch.clamp(linear_func, 0, 1)
37
+ return ramp_func
38
+
39
+ def find_newbase_ntk(dim, base=10000, scale=1):
40
+ # Base change formula
41
+ return base * scale ** (dim / (dim-2))
42
+
43
+ def get_mscale(scale: torch.Tensor):
44
+ # if scale <= 1:
45
+ # return 1.0
46
+ # return 0.1 * math.log(scale) + 1.0
47
+ return torch.where(scale <= 1., torch.tensor(1.0), 0.1 * torch.log(scale) + 1.0)
48
+
49
+ def get_proportion(L_test, L_train):
50
+ L_test = L_test * 2
51
+ return torch.where(torch.tensor(L_test/L_train) <= 1., torch.tensor(1.0), torch.sqrt(torch.log(torch.tensor(L_test))/torch.log(torch.tensor(L_train))))
52
+ # return torch.sqrt(torch.log(torch.tensor(L_test))/torch.log(torch.tensor(L_train)))
53
+
54
+
55
+
56
+ #################################################################################
57
+ # Rotate Q or K #
58
+ #################################################################################
59
+
60
+ def rotate_half(x):
61
+ x = rearrange(x, '... (d r) -> ... d r', r = 2)
62
+ x1, x2 = x.unbind(dim = -1)
63
+ x = torch.stack((-x2, x1), dim = -1)
64
+ return rearrange(x, '... d r -> ... (d r)')
65
+
66
+
67
+
68
+ #################################################################################
69
+ # Core Vision RoPE #
70
+ #################################################################################
71
+
72
+ class VisionRotaryEmbedding(nn.Module):
73
+ def __init__(
74
+ self,
75
+ head_dim: int, # embed dimension for each head
76
+ custom_freqs: str = 'normal',
77
+ theta: int = 10000,
78
+ online_rope: bool = False,
79
+ max_cached_len: int = 1024,
80
+ max_pe_len_h: Optional[int] = None,
81
+ max_pe_len_w: Optional[int] = None,
82
+ decouple: bool = False,
83
+ ori_max_pe_len: Optional[int] = None,
84
+ ):
85
+ super().__init__()
86
+
87
+ dim = head_dim // 2
88
+ assert dim % 2 == 0 # actually, this is important
89
+ self.dim = dim
90
+ self.custom_freqs = custom_freqs.lower()
91
+ self.theta = theta
92
+ self.decouple = decouple
93
+ self.ori_max_pe_len = ori_max_pe_len
94
+
95
+ self.custom_freqs = custom_freqs.lower()
96
+ if not online_rope:
97
+ if self.custom_freqs in ['normal', 'scale1', 'scale2']:
98
+ freqs_h = 1. / (theta ** (torch.arange(0, dim, 2).float() / dim))
99
+ freqs_w = 1. / (theta ** (torch.arange(0, dim, 2).float() / dim))
100
+ else:
101
+ if decouple:
102
+ freqs_h = self.get_1d_rope_freqs(theta, dim, max_pe_len_h, ori_max_pe_len)
103
+ freqs_w = self.get_1d_rope_freqs(theta, dim, max_pe_len_w, ori_max_pe_len)
104
+ else:
105
+ max_pe_len = max(max_pe_len_h, max_pe_len_w)
106
+ freqs_h = self.get_1d_rope_freqs(theta, dim, max_pe_len, ori_max_pe_len)
107
+ freqs_w = self.get_1d_rope_freqs(theta, dim, max_pe_len, ori_max_pe_len)
108
+
109
+ self.register_buffer('freqs_h', freqs_h, persistent=False)
110
+ self.register_buffer('freqs_w', freqs_w, persistent=False)
111
+
112
+ if max_pe_len_h is not None and max_pe_len_w is not None and ori_max_pe_len is not None:
113
+ attn_factor = 1.0
114
+ scale = torch.clamp_min(torch.tensor(max(max_pe_len_h, max_pe_len_w)) / ori_max_pe_len, 1.0) # dynamic scale
115
+ self.mscale = get_mscale(scale).to(scale) * attn_factor # Get n-d magnitude scaling corrected for interpolation
116
+ self.proportion1 = get_proportion(max(max_pe_len_h, max_pe_len_w), ori_max_pe_len)
117
+ self.proportion2 = get_proportion(max_pe_len_h * max_pe_len_w, ori_max_pe_len ** 2)
118
+
119
+
120
+ freqs_h_cached = torch.einsum('..., f -> ... f', torch.arange(max_cached_len), self.freqs_h)
121
+ freqs_h_cached = repeat(freqs_h_cached, '... n -> ... (n r)', r = 2)
122
+ self.register_buffer('freqs_h_cached', freqs_h_cached, persistent=False)
123
+ freqs_w_cached = torch.einsum('..., f -> ... f', torch.arange(max_cached_len), self.freqs_w)
124
+ freqs_w_cached = repeat(freqs_w_cached, '... n -> ... (n r)', r = 2)
125
+ self.register_buffer('freqs_w_cached', freqs_w_cached, persistent=False)
126
+
127
+
128
+ def get_1d_rope_freqs(self, theta, dim, max_pe_len, ori_max_pe_len):
129
+ # scaling operations for extrapolation
130
+ assert isinstance(ori_max_pe_len, int)
131
+ # scale = max_pe_len / ori_max_pe_len
132
+ if not isinstance(max_pe_len, torch.Tensor):
133
+ max_pe_len = torch.tensor(max_pe_len)
134
+ scale = torch.clamp_min(max_pe_len / ori_max_pe_len, 1.0) # dynamic scale
135
+
136
+ if self.custom_freqs == 'linear': # equal to position interpolation
137
+ freqs = 1. / torch.einsum('..., f -> ... f', scale, theta ** (torch.arange(0, dim, 2).float() / dim))
138
+ elif self.custom_freqs in ('ntk-aware', 'ntk-aware-pro1', 'ntk-aware-pro2'):
139
+ freqs = 1. / torch.pow(
140
+ find_newbase_ntk(dim, theta, scale).view(-1, 1),
141
+ (torch.arange(0, dim, 2).to(scale).float() / dim)
142
+ ).squeeze()
143
+ elif self.custom_freqs == 'ntk-by-parts':
144
+ #Interpolation constants found experimentally for LLaMA (might not be totally optimal though)
145
+ #Do not change unless there is a good reason for doing so!
146
+ beta_0 = 1.25
147
+ beta_1 = 0.75
148
+ gamma_0 = 16
149
+ gamma_1 = 2
150
+ ntk_factor = 1
151
+ extrapolation_factor = 1
152
+
153
+ #Three RoPE extrapolation/interpolation methods
154
+ freqs_base = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
155
+ freqs_linear = 1.0 / torch.einsum('..., f -> ... f', scale, (theta ** (torch.arange(0, dim, 2).to(scale).float() / dim)))
156
+ freqs_ntk = 1. / torch.pow(
157
+ find_newbase_ntk(dim, theta, scale).view(-1, 1),
158
+ (torch.arange(0, dim, 2).to(scale).float() / dim)
159
+ ).squeeze()
160
+
161
+ #Combine NTK and Linear
162
+ low, high = find_correction_range(beta_0, beta_1, dim, theta, ori_max_pe_len)
163
+ freqs_mask = (1 - linear_ramp_mask(low, high, dim // 2).to(scale)) * ntk_factor
164
+ freqs = freqs_linear * (1 - freqs_mask) + freqs_ntk * freqs_mask
165
+
166
+ #Combine Extrapolation and NTK and Linear
167
+ low, high = find_correction_range(gamma_0, gamma_1, dim, theta, ori_max_pe_len)
168
+ freqs_mask = (1 - linear_ramp_mask(low, high, dim // 2).to(scale)) * extrapolation_factor
169
+ freqs = freqs * (1 - freqs_mask) + freqs_base * freqs_mask
170
+
171
+ elif self.custom_freqs == 'yarn':
172
+ #Interpolation constants found experimentally for LLaMA (might not be totally optimal though)
173
+ #Do not change unless there is a good reason for doing so!
174
+ beta_fast = 32
175
+ beta_slow = 1
176
+ extrapolation_factor = 1
177
+
178
+ freqs_extrapolation = 1.0 / (theta ** (torch.arange(0, dim, 2).to(scale).float() / dim))
179
+ freqs_interpolation = 1.0 / torch.einsum('..., f -> ... f', scale, (theta ** (torch.arange(0, dim, 2).to(scale).float() / dim)))
180
+
181
+ low, high = find_correction_range(beta_fast, beta_slow, dim, theta, ori_max_pe_len)
182
+ freqs_mask = (1 - linear_ramp_mask(low, high, dim // 2).to(scale).float()) * extrapolation_factor # Get n-d rotational scaling corrected for extrapolation
183
+ freqs = freqs_interpolation * (1 - freqs_mask) + freqs_extrapolation * freqs_mask
184
+ else:
185
+ raise ValueError(f"Unknown custom_freqs mode {self.custom_freqs}. Supported modes: normal, linear, ntk-aware, ntk-aware-pro1, ntk-aware-pro2, ntk-by-parts, yarn.")
186
+ return freqs
187
+
188
+
189
+ def online_get_2d_rope_from_grid(self, grid, size):
190
+ '''
191
+ grid: (B, 2, N)
192
+ N = H * W
193
+ the first dimension represents width, and the second represents height
194
+ e.g., [0. 1. 2. 3. 0. 1. 2. 3. 0. 1. 2. 3.]
195
+ [0. 0. 0. 0. 1. 1. 1. 1. 2. 2. 2. 2.]
196
+ size: (B, 1, 2), h goes first and w goes last
197
+ '''
198
+ size = size.squeeze() # (B, 1, 2) -> (B, 2)
199
+ if self.decouple:
200
+ size_h = size[:, 0]
201
+ size_w = size[:, 1]
202
+ freqs_h = self.get_1d_rope_freqs(self.theta, self.dim, size_h, self.ori_max_pe_len)
203
+ freqs_w = self.get_1d_rope_freqs(self.theta, self.dim, size_w, self.ori_max_pe_len)
204
+ else:
205
+ size_max = torch.max(size[:, 0], size[:, 1])
206
+ freqs_h = self.get_1d_rope_freqs(self.theta, self.dim, size_max, self.ori_max_pe_len)
207
+ freqs_w = self.get_1d_rope_freqs(self.theta, self.dim, size_max, self.ori_max_pe_len)
208
+ freqs_w = grid[:, 0][..., None] * freqs_w[:, None, :]
209
+ freqs_w = repeat(freqs_w, '... n -> ... (n r)', r = 2)
210
+
211
+ freqs_h = grid[:, 1][..., None] * freqs_h[:, None, :]
212
+ freqs_h = repeat(freqs_h, '... n -> ... (n r)', r = 2)
213
+
214
+ freqs = torch.cat([freqs_h, freqs_w], dim=-1) # (B, N, D)
215
+
216
+ if self.custom_freqs == 'yarn':
217
+ freqs_cos = freqs.cos() * self.mscale[:, None, None]
218
+ freqs_sin = freqs.sin() * self.mscale[:, None, None]
219
+ elif self.custom_freqs == 'ntk-aware-pro1':
220
+ freqs_cos = freqs.cos() * self.proportion1[:, None, None]
221
+ freqs_sin = freqs.sin() * self.proportion1[:, None, None]
222
+ elif self.custom_freqs == 'ntk-aware-pro2':
223
+ freqs_cos = freqs.cos() * self.proportion2[:, None, None]
224
+ freqs_sin = freqs.sin() * self.proportion2[:, None, None]
225
+ else:
226
+ freqs_cos = freqs.cos()
227
+ freqs_sin = freqs.sin()
228
+
229
+ return freqs_cos, freqs_sin
230
+
231
+ @lru_cache()
232
+ def get_2d_rope_from_grid(self, grid):
233
+ '''
234
+ grid: (B, 2, N)
235
+ N = H * W
236
+ the first dimension represents width, and the second represents height
237
+ e.g., [0. 1. 2. 3. 0. 1. 2. 3. 0. 1. 2. 3.]
238
+ [0. 0. 0. 0. 1. 1. 1. 1. 2. 2. 2. 2.]
239
+ '''
240
+ freqs_h = torch.einsum('..., f -> ... f', grid[:, 0], self.freqs_h)
241
+ freqs_h = repeat(freqs_h, '... n -> ... (n r)', r = 2)
242
+ freqs_w = torch.einsum('..., f -> ... f', grid[:, 1], self.freqs_w)
243
+ freqs_w = repeat(freqs_w, '... n -> ... (n r)', r = 2)
244
+
245
+ freqs = torch.cat([freqs_h, freqs_w], dim=-1) # (B, N, D)
246
+
247
+ if self.custom_freqs == 'yarn':
248
+ freqs_cos = freqs.cos() * self.mscale
249
+ freqs_sin = freqs.sin() * self.mscale
250
+ elif self.custom_freqs in ['ntk-aware-pro1', 'scale1']:
251
+ freqs_cos = freqs.cos() * self.proportion1
252
+ freqs_sin = freqs.sin() * self.proportion1
253
+ elif self.custom_freqs in ['ntk-aware-pro2', 'scale2']:
254
+ freqs_cos = freqs.cos() * self.proportion2
255
+ freqs_sin = freqs.sin() * self.proportion2
256
+ else:
257
+ freqs_cos = freqs.cos()
258
+ freqs_sin = freqs.sin()
259
+
260
+ return freqs_cos, freqs_sin
261
+
262
+ @lru_cache()
263
+ def get_cached_2d_rope_from_grid(self, grid: torch.Tensor):
264
+ '''
265
+ grid: (B, 2, N)
266
+ N = H * W
267
+ the first dimension represents width, and the second represents height
268
+ e.g., [0. 1. 2. 3. 0. 1. 2. 3. 0. 1. 2. 3.]
269
+ [0. 0. 0. 0. 1. 1. 1. 1. 2. 2. 2. 2.]
270
+ '''
271
+ if len(grid.shape) == 3: # (B, 2, N)
272
+ freqs_h, freqs_w = self.freqs_h_cached[grid[:, 0]], self.freqs_w_cached[grid[:, 1]]
273
+ elif len(grid.shape) == 2: # (2, N)
274
+ freqs_h, freqs_w = self.freqs_h_cached[grid[0]], self.freqs_w_cached[grid[1]]
275
+ freqs = torch.cat([freqs_h, freqs_w], dim=-1) # (B, N, D)
276
+
277
+ if self.custom_freqs == 'yarn':
278
+ freqs_cos = freqs.cos() * self.mscale
279
+ freqs_sin = freqs.sin() * self.mscale
280
+ elif self.custom_freqs in ['ntk-aware-pro1', 'scale1']:
281
+ freqs_cos = freqs.cos() * self.proportion1
282
+ freqs_sin = freqs.sin() * self.proportion1
283
+ elif self.custom_freqs in ['ntk-aware-pro2', 'scale2']:
284
+ freqs_cos = freqs.cos() * self.proportion2
285
+ freqs_sin = freqs.sin() * self.proportion2
286
+ else:
287
+ freqs_cos = freqs.cos()
288
+ freqs_sin = freqs.sin()
289
+
290
+ return freqs_cos, freqs_sin
291
+
292
+
293
+ def forward(self, x, grid):
294
+ '''
295
+ x: (B, n_head, N, D)
296
+ grid: (B, 2, N)
297
+ '''
298
+ # freqs_cos, freqs_sin = self.get_2d_rope_from_grid(grid)
299
+ # freqs_cos, freqs_sin = freqs_cos.unsqueeze(1), freqs_sin.unsqueeze(1)
300
+ # using cache to accelerate, this is the same with the above codes:
301
+ freqs_cos, freqs_sin = self.get_cached_2d_rope_from_grid(grid)
302
+ freqs_cos, freqs_sin = freqs_cos.unsqueeze(1), freqs_sin.unsqueeze(1)
303
+ return x * freqs_cos + rotate_half(x) * freqs_sin
304
+
305
+
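A minimal sketch of how VisionRotaryEmbedding might be applied to attention queries or keys (a sketch under assumed defaults: 'normal' frequencies and the cached path; the import path mirrors tim/models/utils/rope.py, and the grid layout follows the docstrings, with the width index in row 0 and the height index in row 1):

import torch
from tim.models.utils.rope import VisionRotaryEmbedding

B, n_head, head_dim = 2, 8, 64       # head_dim // 2 must be even
H = W = 16                           # 16 x 16 patch grid -> N = 256 tokens
rope = VisionRotaryEmbedding(head_dim=head_dim, custom_freqs='normal')

# Grid layout from the docstrings: row 0 holds the width (x) index of each
# token, row 1 holds the height (y) index, flattened in row-major order.
ys, xs = torch.meshgrid(torch.arange(H), torch.arange(W), indexing='ij')
grid = torch.stack([xs.flatten(), ys.flatten()], dim=0)   # (2, N)
grid = grid.unsqueeze(0).expand(B, -1, -1)                # (B, 2, N)

q = torch.randn(B, n_head, H * W, head_dim)
q_rot = rope(q, grid)                # rotary positions applied, same shape
print(q_rot.shape)                   # torch.Size([2, 8, 256, 64])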
tim/models/utils/text_encoders.py ADDED
@@ -0,0 +1,63 @@
1
+ import os
2
+ import torch
3
+ from transformers import T5EncoderModel, AutoModelForCausalLM, AutoTokenizer
4
+
5
+
6
+ # load text-encoder
7
+ def load_text_encoder(text_encoder_dir, device, weight_dtype):
8
+ os.environ["TOKENIZERS_PARALLELISM"] = "true"
9
+ tokenizer = AutoTokenizer.from_pretrained(text_encoder_dir)
10
+ if "gemma" in text_encoder_dir:
11
+ tokenizer.padding_side = "right"
12
+ text_encoder = AutoModelForCausalLM.from_pretrained(
13
+ text_encoder_dir,
14
+ attn_implementation="flash_attention_2",
15
+ device_map="cpu",
16
+ torch_dtype=weight_dtype,
17
+ ).model
18
+ elif "t5" in text_encoder_dir:
19
+ text_encoder = T5EncoderModel.from_pretrained(
20
+ text_encoder_dir,
21
+ attn_implementation="sdpa",
22
+ device_map="cpu",
23
+ torch_dtype=weight_dtype,
24
+ )
25
+ else:
26
+ raise NotImplementedError
27
+ text_encoder.requires_grad_(False)
28
+ text_encoder = text_encoder.eval().to(device=device, dtype=weight_dtype)
29
+
30
+ return text_encoder, tokenizer
31
+
32
+
33
+ def encode_prompt(
34
+ tokenizer,
35
+ text_encoder,
36
+ device,
37
+ weight_dtype,
38
+ captions,
39
+ use_last_hidden_state,
40
+ max_seq_length=256,
41
+ ):
42
+ text_inputs = tokenizer(
43
+ captions,
44
+ padding="max_length",
45
+ max_length=max_seq_length,
46
+ truncation=True,
47
+ return_tensors="pt",
48
+ )
49
+ text_input_ids = text_inputs.input_ids.to(device)
50
+ prompt_masks = text_inputs.attention_mask.to(device)
51
+ with torch.no_grad(), torch.autocast("cuda", dtype=weight_dtype):
52
+ results = text_encoder(
53
+ input_ids=text_input_ids,
54
+ attention_mask=prompt_masks,
55
+ output_hidden_states=True,
56
+ )
57
+
58
+ if use_last_hidden_state:
59
+ prompt_embeds = results.last_hidden_state
60
+ else: # from Imagen paper
61
+ prompt_embeds = results.hidden_states[-2]
62
+
63
+ return prompt_embeds, prompt_masks
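A minimal sketch of how these helpers might be called (a sketch only: the T5 checkpoint name is illustrative, a CUDA device is assumed for the autocast context, and the import path mirrors tim/models/utils/text_encoders.py):

import torch
from tim.models.utils.text_encoders import load_text_encoder, encode_prompt

device = "cuda"
weight_dtype = torch.bfloat16

# Any checkpoint whose name contains "t5" takes the T5EncoderModel branch above.
text_encoder, tokenizer = load_text_encoder("google/t5-v1_1-xl", device, weight_dtype)

prompt_embeds, prompt_masks = encode_prompt(
    tokenizer,
    text_encoder,
    device,
    weight_dtype,
    captions=["a photo of a corgi wearing sunglasses"],
    use_last_hidden_state=False,    # penultimate hidden state, as in the Imagen setup
    max_seq_length=256,
)
print(prompt_embeds.shape, prompt_masks.shape)   # (1, 256, hidden_dim), (1, 256)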