Spaces:
Paused
Paused
Duplicate from mpatel57/WOUAF-Text-to-Image
Browse filesCo-authored-by: Maitreya Patel <[email protected]>
- .gitattributes +34 -0
- .gitignore +162 -0
- README.md +14 -0
- app.py +214 -0
- attribution.py +190 -0
- checkpoints/.gitattributes +3 -0
- checkpoints/decoding_network.pth +3 -0
- checkpoints/mapping_network.pth +3 -0
- checkpoints/vae_decoder.pth +3 -0
- customization.py +306 -0
- dnnlib/__init__.py +9 -0
- dnnlib/util.py +477 -0
- requirements.txt +7 -0
- torch_utils/__init__.py +9 -0
- torch_utils/custom_ops.py +126 -0
- torch_utils/misc.py +262 -0
- torch_utils/ops/__init__.py +9 -0
- torch_utils/ops/bias_act.cpp +99 -0
- torch_utils/ops/bias_act.cu +173 -0
- torch_utils/ops/bias_act.h +38 -0
- torch_utils/ops/bias_act.py +212 -0
- torch_utils/ops/conv2d_gradfix.py +170 -0
- torch_utils/ops/conv2d_resample.py +156 -0
- torch_utils/ops/fma.py +60 -0
- torch_utils/ops/grid_sample_gradfix.py +83 -0
- torch_utils/ops/upfirdn2d.cpp +103 -0
- torch_utils/ops/upfirdn2d.cu +350 -0
- torch_utils/ops/upfirdn2d.h +59 -0
- torch_utils/ops/upfirdn2d.py +384 -0
- torch_utils/persistence.py +251 -0
- torch_utils/training_stats.py +268 -0
.gitattributes
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
28 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
29 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
30 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
31 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
32 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
33 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
34 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
@@ -0,0 +1,162 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
gradio_cached_examples/
|
2 |
+
|
3 |
+
# Byte-compiled / optimized / DLL files
|
4 |
+
__pycache__/
|
5 |
+
*.py[cod]
|
6 |
+
*$py.class
|
7 |
+
|
8 |
+
# C extensions
|
9 |
+
*.so
|
10 |
+
|
11 |
+
# Distribution / packaging
|
12 |
+
.Python
|
13 |
+
build/
|
14 |
+
develop-eggs/
|
15 |
+
dist/
|
16 |
+
downloads/
|
17 |
+
eggs/
|
18 |
+
.eggs/
|
19 |
+
lib/
|
20 |
+
lib64/
|
21 |
+
parts/
|
22 |
+
sdist/
|
23 |
+
var/
|
24 |
+
wheels/
|
25 |
+
share/python-wheels/
|
26 |
+
*.egg-info/
|
27 |
+
.installed.cfg
|
28 |
+
*.egg
|
29 |
+
MANIFEST
|
30 |
+
|
31 |
+
# PyInstaller
|
32 |
+
# Usually these files are written by a python script from a template
|
33 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
34 |
+
*.manifest
|
35 |
+
*.spec
|
36 |
+
|
37 |
+
# Installer logs
|
38 |
+
pip-log.txt
|
39 |
+
pip-delete-this-directory.txt
|
40 |
+
|
41 |
+
# Unit test / coverage reports
|
42 |
+
htmlcov/
|
43 |
+
.tox/
|
44 |
+
.nox/
|
45 |
+
.coverage
|
46 |
+
.coverage.*
|
47 |
+
.cache
|
48 |
+
nosetests.xml
|
49 |
+
coverage.xml
|
50 |
+
*.cover
|
51 |
+
*.py,cover
|
52 |
+
.hypothesis/
|
53 |
+
.pytest_cache/
|
54 |
+
cover/
|
55 |
+
|
56 |
+
# Translations
|
57 |
+
*.mo
|
58 |
+
*.pot
|
59 |
+
|
60 |
+
# Django stuff:
|
61 |
+
*.log
|
62 |
+
local_settings.py
|
63 |
+
db.sqlite3
|
64 |
+
db.sqlite3-journal
|
65 |
+
|
66 |
+
# Flask stuff:
|
67 |
+
instance/
|
68 |
+
.webassets-cache
|
69 |
+
|
70 |
+
# Scrapy stuff:
|
71 |
+
.scrapy
|
72 |
+
|
73 |
+
# Sphinx documentation
|
74 |
+
docs/_build/
|
75 |
+
|
76 |
+
# PyBuilder
|
77 |
+
.pybuilder/
|
78 |
+
target/
|
79 |
+
|
80 |
+
# Jupyter Notebook
|
81 |
+
.ipynb_checkpoints
|
82 |
+
|
83 |
+
# IPython
|
84 |
+
profile_default/
|
85 |
+
ipython_config.py
|
86 |
+
|
87 |
+
# pyenv
|
88 |
+
# For a library or package, you might want to ignore these files since the code is
|
89 |
+
# intended to run in multiple environments; otherwise, check them in:
|
90 |
+
# .python-version
|
91 |
+
|
92 |
+
# pipenv
|
93 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
94 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
95 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
96 |
+
# install all needed dependencies.
|
97 |
+
#Pipfile.lock
|
98 |
+
|
99 |
+
# poetry
|
100 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
101 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
102 |
+
# commonly ignored for libraries.
|
103 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
104 |
+
#poetry.lock
|
105 |
+
|
106 |
+
# pdm
|
107 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
108 |
+
#pdm.lock
|
109 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
110 |
+
# in version control.
|
111 |
+
# https://pdm.fming.dev/#use-with-ide
|
112 |
+
.pdm.toml
|
113 |
+
|
114 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
115 |
+
__pypackages__/
|
116 |
+
|
117 |
+
# Celery stuff
|
118 |
+
celerybeat-schedule
|
119 |
+
celerybeat.pid
|
120 |
+
|
121 |
+
# SageMath parsed files
|
122 |
+
*.sage.py
|
123 |
+
|
124 |
+
# Environments
|
125 |
+
.env
|
126 |
+
.venv
|
127 |
+
env/
|
128 |
+
venv/
|
129 |
+
ENV/
|
130 |
+
env.bak/
|
131 |
+
venv.bak/
|
132 |
+
|
133 |
+
# Spyder project settings
|
134 |
+
.spyderproject
|
135 |
+
.spyproject
|
136 |
+
|
137 |
+
# Rope project settings
|
138 |
+
.ropeproject
|
139 |
+
|
140 |
+
# mkdocs documentation
|
141 |
+
/site
|
142 |
+
|
143 |
+
# mypy
|
144 |
+
.mypy_cache/
|
145 |
+
.dmypy.json
|
146 |
+
dmypy.json
|
147 |
+
|
148 |
+
# Pyre type checker
|
149 |
+
.pyre/
|
150 |
+
|
151 |
+
# pytype static type analyzer
|
152 |
+
.pytype/
|
153 |
+
|
154 |
+
# Cython debug symbols
|
155 |
+
cython_debug/
|
156 |
+
|
157 |
+
# PyCharm
|
158 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
159 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
160 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
161 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
162 |
+
#.idea/
|
README.md
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
title: WOUAF Text To Image
|
3 |
+
emoji: 🔥
|
4 |
+
colorFrom: purple
|
5 |
+
colorTo: green
|
6 |
+
sdk: gradio
|
7 |
+
sdk_version: 3.34.0
|
8 |
+
app_file: app.py
|
9 |
+
pinned: false
|
10 |
+
license: mit
|
11 |
+
duplicated_from: mpatel57/WOUAF-Text-to-Image
|
12 |
+
---
|
13 |
+
|
14 |
+
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
app.py
ADDED
@@ -0,0 +1,214 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
from PIL import Image
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import re
|
6 |
+
import os
|
7 |
+
import requests
|
8 |
+
|
9 |
+
from customization import customize_vae_decoder
|
10 |
+
from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel, DDIMScheduler, EulerDiscreteScheduler
|
11 |
+
from torchvision import transforms
|
12 |
+
from attribution import MappingNetwork
|
13 |
+
|
14 |
+
import math
|
15 |
+
from typing import List
|
16 |
+
from PIL import Image, ImageChops
|
17 |
+
import numpy as np
|
18 |
+
import torch
|
19 |
+
|
20 |
+
|
21 |
+
PRETRAINED_MODEL_NAME_OR_PATH = "./checkpoints/"
|
22 |
+
|
23 |
+
|
24 |
+
def get_image_grid(images: List[Image.Image]) -> Image:
|
25 |
+
num_images = len(images)
|
26 |
+
cols = 3#int(math.ceil(math.sqrt(num_images)))
|
27 |
+
rows = 1#int(math.ceil(num_images / cols))
|
28 |
+
width, height = images[0].size
|
29 |
+
grid_image = Image.new('RGB', (cols * width, rows * height))
|
30 |
+
for i, img in enumerate(images):
|
31 |
+
x = i % cols
|
32 |
+
y = i // cols
|
33 |
+
grid_image.paste(img, (x * width, y * height))
|
34 |
+
return grid_image
|
35 |
+
|
36 |
+
|
37 |
+
class AttributionModel:
|
38 |
+
def __init__(self):
|
39 |
+
is_cuda = False
|
40 |
+
if torch.cuda.is_available():
|
41 |
+
is_cuda = True
|
42 |
+
|
43 |
+
scheduler = EulerDiscreteScheduler.from_pretrained('stabilityai/stable-diffusion-2', subfolder="scheduler")
|
44 |
+
self.pipe = StableDiffusionPipeline.from_pretrained('stabilityai/stable-diffusion-2', scheduler=scheduler)#, safety_checker=None, torch_dtype=torch.float16)
|
45 |
+
if is_cuda:
|
46 |
+
self.pipe = self.pipe.to("cuda")
|
47 |
+
self.resize_transform = transforms.Resize(512, interpolation=transforms.InterpolationMode.BILINEAR)
|
48 |
+
self.vae = AutoencoderKL.from_pretrained(
|
49 |
+
'stabilityai/stable-diffusion-2', subfolder="vae"
|
50 |
+
)
|
51 |
+
self.vae = customize_vae_decoder(self.vae, 128, "deqkv", "all", False, 1.0)
|
52 |
+
|
53 |
+
self.mapping_network = MappingNetwork(32, 0, 128, None, num_layers=2, w_avg_beta=None, normalization = False)
|
54 |
+
|
55 |
+
from torchvision.models import resnet50, ResNet50_Weights
|
56 |
+
self.decoding_network = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
|
57 |
+
self.decoding_network.fc = torch.nn.Linear(2048,32)
|
58 |
+
|
59 |
+
self.vae.decoder.load_state_dict(torch.load(os.path.join(PRETRAINED_MODEL_NAME_OR_PATH, 'vae_decoder.pth')))
|
60 |
+
self.mapping_network.load_state_dict(torch.load(os.path.join(PRETRAINED_MODEL_NAME_OR_PATH, 'mapping_network.pth')))
|
61 |
+
self.decoding_network.load_state_dict(torch.load(os.path.join(PRETRAINED_MODEL_NAME_OR_PATH, 'decoding_network.pth')))
|
62 |
+
|
63 |
+
if is_cuda:
|
64 |
+
self.vae = self.vae.to("cuda")
|
65 |
+
self.mapping_network = self.mapping_network.to("cuda")
|
66 |
+
self.decoding_network = self.decoding_network.to("cuda")
|
67 |
+
|
68 |
+
self.test_norm = transforms.Compose(
|
69 |
+
[
|
70 |
+
transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
|
71 |
+
]
|
72 |
+
)
|
73 |
+
|
74 |
+
def infer(self, prompt, negative, steps, guidance_scale):
|
75 |
+
with torch.no_grad():
|
76 |
+
out_latents = self.pipe([prompt], negative_prompt=[negative], output_type="latent", num_inference_steps=steps, guidance_scale=guidance_scale).images
|
77 |
+
image_attr = self.inference_with_attribution(out_latents)
|
78 |
+
image_attr_pil = self.pipe.numpy_to_pil(image_attr[0])
|
79 |
+
|
80 |
+
image_org = self.inference_without_attribution(out_latents)
|
81 |
+
image_org_pil = self.pipe.numpy_to_pil(image_org[0])
|
82 |
+
|
83 |
+
# image_diff_pil = self.pipe.numpy_to_pil(image_attr[0] - image_org[0])
|
84 |
+
diff_factor = 5
|
85 |
+
image_diff_pil = ImageChops.difference(image_org_pil[0], image_attr_pil[0]).convert("RGB", (diff_factor,0,0,0,0,diff_factor,0,0,0,0,diff_factor,0))
|
86 |
+
|
87 |
+
return image_org_pil[0], image_attr_pil[0], image_diff_pil
|
88 |
+
|
89 |
+
def inference_without_attribution(self, latents):
|
90 |
+
latents = 1 / 0.18215 * latents
|
91 |
+
with torch.no_grad():
|
92 |
+
image = self.pipe.vae.decode(latents).sample
|
93 |
+
image = image.clamp(-1,1)
|
94 |
+
image = (image / 2 + 0.5).clamp(0, 1)
|
95 |
+
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
|
96 |
+
return image
|
97 |
+
|
98 |
+
def get_phis(self, phi_dimension, batch_size ,eps = 1e-8):
|
99 |
+
phi_length = phi_dimension
|
100 |
+
b = batch_size
|
101 |
+
phi = torch.empty(b,phi_length).uniform_(0,1)
|
102 |
+
return torch.bernoulli(phi) + eps
|
103 |
+
|
104 |
+
|
105 |
+
def inference_with_attribution(self, latents, key=None):
|
106 |
+
if key==None:
|
107 |
+
key = self.get_phis(32, 1)
|
108 |
+
|
109 |
+
latents = 1 / 0.18215 * latents
|
110 |
+
with torch.no_grad():
|
111 |
+
image = self.vae.decode(latents, self.mapping_network(key.cuda())).sample
|
112 |
+
image = image.clamp(-1,1)
|
113 |
+
image = (image / 2 + 0.5).clamp(0, 1)
|
114 |
+
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
|
115 |
+
return image
|
116 |
+
|
117 |
+
def postprocess(self, image):
|
118 |
+
image = self.resize_transform(image)
|
119 |
+
return image
|
120 |
+
|
121 |
+
def detect_key(self, image):
|
122 |
+
reconstructed_keys = self.decoding_network(self.test_norm((image / 2 + 0.5).clamp(0, 1)))
|
123 |
+
return reconstructed_keys
|
124 |
+
|
125 |
+
|
126 |
+
attribution_model = AttributionModel()
|
127 |
+
def get_images(prompt, negative, steps, guidence_scale):
|
128 |
+
x1, x2, x3 = attribution_model.infer(prompt, negative, steps, guidence_scale)
|
129 |
+
return [x1, x2, x3]
|
130 |
+
|
131 |
+
|
132 |
+
image_examples = [
|
133 |
+
["A pikachu fine dining with a view to the Eiffel Tower", "low quality", 50, 10],
|
134 |
+
["A mecha robot in a favela in expressionist style", "low quality, 3d, photorealistic", 50, 10]
|
135 |
+
]
|
136 |
+
|
137 |
+
with gr.Blocks() as demo:
|
138 |
+
gr.Markdown(
|
139 |
+
"""<h1 style="text-align: center;"><b>WOUAF:
|
140 |
+
Weight Modulation for User Attribution and Fingerprinting in Text-to-Image Diffusion Models</b> <br> <a href="https://wouaf.vercel.app">Project Page</a></h1>""")
|
141 |
+
|
142 |
+
with gr.Row(elem_id="prompt-container").style(mobile_collapse=False, equal_height=True):
|
143 |
+
with gr.Column():
|
144 |
+
text = gr.Textbox(
|
145 |
+
label="Enter your prompt",
|
146 |
+
show_label=False,
|
147 |
+
max_lines=1,
|
148 |
+
placeholder="Enter your prompt",
|
149 |
+
elem_id="prompt-text-input",
|
150 |
+
).style(
|
151 |
+
border=(True, False, True, True),
|
152 |
+
rounded=(True, False, False, True),
|
153 |
+
container=False,
|
154 |
+
)
|
155 |
+
negative = gr.Textbox(
|
156 |
+
label="Enter your negative prompt",
|
157 |
+
show_label=False,
|
158 |
+
max_lines=1,
|
159 |
+
placeholder="Enter a negative prompt",
|
160 |
+
elem_id="negative-prompt-text-input",
|
161 |
+
).style(
|
162 |
+
border=(True, False, True, True),
|
163 |
+
rounded=(True, False, False, True),
|
164 |
+
container=False,
|
165 |
+
)
|
166 |
+
|
167 |
+
with gr.Row():
|
168 |
+
steps = gr.Slider(label="Steps", minimum=45, maximum=55, value=50, step=1)
|
169 |
+
guidance_scale = gr.Slider(
|
170 |
+
label="Guidance Scale", minimum=0, maximum=10, value=7.5, step=0.1
|
171 |
+
)
|
172 |
+
|
173 |
+
with gr.Row():
|
174 |
+
btn = gr.Button(value="Generate Image", full_width=False)
|
175 |
+
|
176 |
+
with gr.Row():
|
177 |
+
im_2 = gr.Image(type="pil", label="without attribution")
|
178 |
+
im_3 = gr.Image(type="pil", label="**with** attribution")
|
179 |
+
im_4 = gr.Image(type="pil", label="pixel-wise difference multiplied by 5")
|
180 |
+
|
181 |
+
|
182 |
+
btn.click(get_images, inputs=[text, negative, steps, guidance_scale], outputs=[im_2, im_3, im_4])
|
183 |
+
|
184 |
+
gr.Examples(
|
185 |
+
examples=image_examples,
|
186 |
+
inputs=[text, negative, steps, guidance_scale],
|
187 |
+
outputs=[im_2, im_3, im_4],
|
188 |
+
fn=get_images,
|
189 |
+
cache_examples=True,
|
190 |
+
)
|
191 |
+
|
192 |
+
gr.HTML(
|
193 |
+
"""
|
194 |
+
<div class="footer">
|
195 |
+
<p>Pre-trained model by <a href="https://huggingface.co/stabilityai" style="text-decoration: underline;" target="_blank">StabilityAI</a>
|
196 |
+
</p>
|
197 |
+
<p>
|
198 |
+
Fine-tuned by authors for research purpose.
|
199 |
+
</p>
|
200 |
+
</div>
|
201 |
+
"""
|
202 |
+
)
|
203 |
+
with gr.Accordion(label="Ethics & Privacy", open=False):
|
204 |
+
gr.HTML(
|
205 |
+
"""<div class="acknowledgments">
|
206 |
+
<p><h4>Privacy</h4>
|
207 |
+
We do not collect any images or key data. This demo is designed with sole purpose of fun and reducing misuse of AI.
|
208 |
+
<p><h4>Biases and content acknowledgment</h4>
|
209 |
+
This model will have the same biases as Stable Diffusion V2.1 </div>
|
210 |
+
"""
|
211 |
+
)
|
212 |
+
|
213 |
+
if __name__ == "__main__":
|
214 |
+
demo.launch()
|
attribution.py
ADDED
@@ -0,0 +1,190 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import numpy as np
|
3 |
+
from torch_utils.ops import bias_act
|
4 |
+
from torch_utils import misc
|
5 |
+
|
6 |
+
|
7 |
+
|
8 |
+
def normalize_2nd_moment(x, dim=1, eps=1e-8):
|
9 |
+
return x * (x.square().mean(dim=dim, keepdim=True) + eps).rsqrt()
|
10 |
+
|
11 |
+
|
12 |
+
class FullyConnectedLayer_normal(torch.nn.Module):
|
13 |
+
def __init__(self,
|
14 |
+
in_features, # Number of input features.
|
15 |
+
out_features, # Number of output features.
|
16 |
+
bias = True, # Apply additive bias before the activation function?
|
17 |
+
bias_init = 0, # Initial value for the additive bias.
|
18 |
+
):
|
19 |
+
super().__init__()
|
20 |
+
self.fc = torch.nn.Linear(in_features, out_features, bias=bias)
|
21 |
+
if bias:
|
22 |
+
with torch.no_grad():
|
23 |
+
self.fc.bias.fill_(bias_init)
|
24 |
+
|
25 |
+
def forward(self, x):
|
26 |
+
output = self.fc(x)
|
27 |
+
return output
|
28 |
+
|
29 |
+
|
30 |
+
class MappingNetwork_normal(torch.nn.Module):
|
31 |
+
def __init__(self,
|
32 |
+
in_features, # Number of input features.
|
33 |
+
int_dim,
|
34 |
+
num_layers = 8, # Number of mapping layers.
|
35 |
+
mapping_normalization = False #2nd normalization
|
36 |
+
):
|
37 |
+
super().__init__()
|
38 |
+
layers = [torch.nn.Linear(in_features, int_dim), torch.nn.LeakyReLU(0.2)]
|
39 |
+
for i in range(1, num_layers):
|
40 |
+
layers.append(torch.nn.Linear(int_dim, int_dim))
|
41 |
+
layers.append(torch.nn.LeakyReLU(0.2))
|
42 |
+
|
43 |
+
self.net = torch.nn.Sequential(*layers)
|
44 |
+
self.normalization = mapping_normalization
|
45 |
+
|
46 |
+
def forward(self, x):
|
47 |
+
if self.normalization:
|
48 |
+
x = normalize_2nd_moment(x)
|
49 |
+
output = self.net(x)
|
50 |
+
return output
|
51 |
+
|
52 |
+
|
53 |
+
class DecodingNetwork(torch.nn.Module):
|
54 |
+
def __init__(self,
|
55 |
+
in_features, # Number of input features.
|
56 |
+
out_dim,
|
57 |
+
num_layers = 8, # Number of mapping layers.
|
58 |
+
):
|
59 |
+
super().__init__()
|
60 |
+
layers = []
|
61 |
+
for i in range(num_layers-1):
|
62 |
+
layers.append(torch.nn.Linear(in_features, in_features))
|
63 |
+
layers.append(torch.nn.ReLU())
|
64 |
+
|
65 |
+
layers.append(torch.nn.Linear(in_features, out_dim))
|
66 |
+
|
67 |
+
self.net = torch.nn.Sequential(*layers)
|
68 |
+
|
69 |
+
def forward(self, x):
|
70 |
+
x = torch.nn.functional.normalize(x, dim=1)
|
71 |
+
output = self.net(x)
|
72 |
+
return output
|
73 |
+
|
74 |
+
|
75 |
+
class FullyConnectedLayer(torch.nn.Module):
|
76 |
+
def __init__(self,
|
77 |
+
in_features, # Number of input features.
|
78 |
+
out_features, # Number of output features.
|
79 |
+
bias = True, # Apply additive bias before the activation function?
|
80 |
+
activation = 'linear', # Activation function: 'relu', 'lrelu', etc.
|
81 |
+
lr_multiplier = 1, # Learning rate multiplier.
|
82 |
+
bias_init = 0, # Initial value for the additive bias.
|
83 |
+
):
|
84 |
+
super().__init__()
|
85 |
+
self.activation = activation
|
86 |
+
self.weight = torch.nn.Parameter(torch.randn([out_features, in_features]) / lr_multiplier)
|
87 |
+
self.bias = torch.nn.Parameter(torch.full([out_features], np.float32(bias_init))) if bias else None
|
88 |
+
self.weight_gain = lr_multiplier / np.sqrt(in_features)
|
89 |
+
self.bias_gain = lr_multiplier
|
90 |
+
|
91 |
+
def forward(self, x):
|
92 |
+
w = self.weight.to(x.dtype) * self.weight_gain
|
93 |
+
b = self.bias
|
94 |
+
if b is not None:
|
95 |
+
b = b.to(x.dtype)
|
96 |
+
if self.bias_gain != 1:
|
97 |
+
b = b * self.bias_gain
|
98 |
+
|
99 |
+
if self.activation == 'linear' and b is not None:
|
100 |
+
x = torch.addmm(b.unsqueeze(0), x, w.t())
|
101 |
+
else:
|
102 |
+
x = x.matmul(w.t())
|
103 |
+
x = bias_act.bias_act(x, b, act=self.activation)
|
104 |
+
return x
|
105 |
+
|
106 |
+
|
107 |
+
class MappingNetwork(torch.nn.Module):
|
108 |
+
def __init__(self,
|
109 |
+
z_dim, # Input latent (Z) dimensionality, 0 = no latent.
|
110 |
+
c_dim, # Conditioning label (C) dimensionality, 0 = no label.
|
111 |
+
w_dim, # Intermediate latent (W) dimensionality.
|
112 |
+
num_ws, # Number of intermediate latents to output, None = do not broadcast.
|
113 |
+
num_layers = 8, # Number of mapping layers.
|
114 |
+
embed_features = None, # Label embedding dimensionality, None = same as w_dim.
|
115 |
+
layer_features = None, # Number of intermediate features in the mapping layers, None = same as w_dim.
|
116 |
+
activation = 'lrelu', # Activation function: 'relu', 'lrelu', etc.
|
117 |
+
lr_multiplier = 0.01, # Learning rate multiplier for the mapping layers.
|
118 |
+
w_avg_beta = 0.995, # Decay for tracking the moving average of W during training, None = do not track.
|
119 |
+
normalization = None # Normalization input using normalize_2nd_moment
|
120 |
+
):
|
121 |
+
super().__init__()
|
122 |
+
self.z_dim = z_dim
|
123 |
+
self.c_dim = c_dim
|
124 |
+
self.w_dim = w_dim
|
125 |
+
self.num_ws = num_ws
|
126 |
+
self.num_layers = num_layers
|
127 |
+
self.w_avg_beta = w_avg_beta
|
128 |
+
self.normalization = normalization
|
129 |
+
|
130 |
+
if embed_features is None:
|
131 |
+
embed_features = w_dim
|
132 |
+
if c_dim == 0:
|
133 |
+
embed_features = 0
|
134 |
+
if layer_features is None:
|
135 |
+
layer_features = w_dim
|
136 |
+
features_list = [z_dim + embed_features] + [layer_features] * (num_layers - 1) + [w_dim]
|
137 |
+
|
138 |
+
if c_dim > 0:
|
139 |
+
self.embed = FullyConnectedLayer(c_dim, embed_features)
|
140 |
+
for idx in range(num_layers):
|
141 |
+
in_features = features_list[idx]
|
142 |
+
out_features = features_list[idx + 1]
|
143 |
+
layer = FullyConnectedLayer(in_features, out_features, activation=activation, lr_multiplier=lr_multiplier)
|
144 |
+
setattr(self, f'fc{idx}', layer)
|
145 |
+
|
146 |
+
if num_ws is not None and w_avg_beta is not None:
|
147 |
+
self.register_buffer('w_avg', torch.zeros([w_dim]))
|
148 |
+
|
149 |
+
def forward(self, z, c=None, truncation_psi=1, truncation_cutoff=None, skip_w_avg_update=False):
|
150 |
+
# Embed, normalize, and concat inputs.
|
151 |
+
x = None
|
152 |
+
with torch.autograd.profiler.record_function('input'):
|
153 |
+
if self.z_dim > 0:
|
154 |
+
misc.assert_shape(z, [None, self.z_dim])
|
155 |
+
if self.normalization:
|
156 |
+
x = normalize_2nd_moment(z.to(torch.float32))
|
157 |
+
else:
|
158 |
+
x = z
|
159 |
+
x = z.to(torch.float32)
|
160 |
+
if self.c_dim > 0:
|
161 |
+
raise ValueError("This implementation does not need class index")
|
162 |
+
misc.assert_shape(c, [None, self.c_dim])
|
163 |
+
y = normalize_2nd_moment(self.embed(c.to(torch.float32)))
|
164 |
+
y = self.embed(c.to(torch.float32))
|
165 |
+
x = torch.cat([x, y], dim=1) if x is not None else y
|
166 |
+
|
167 |
+
# Main layers.
|
168 |
+
for idx in range(self.num_layers):
|
169 |
+
layer = getattr(self, f'fc{idx}')
|
170 |
+
x = layer(x)
|
171 |
+
|
172 |
+
# Update moving average of W.
|
173 |
+
if self.w_avg_beta is not None and self.training and not skip_w_avg_update:
|
174 |
+
with torch.autograd.profiler.record_function('update_w_avg'):
|
175 |
+
self.w_avg.copy_(x.detach().mean(dim=0).lerp(self.w_avg, self.w_avg_beta))
|
176 |
+
|
177 |
+
# Broadcast.
|
178 |
+
if self.num_ws is not None:
|
179 |
+
with torch.autograd.profiler.record_function('broadcast'):
|
180 |
+
x = x.unsqueeze(1).repeat([1, self.num_ws, 1])
|
181 |
+
|
182 |
+
# Apply truncation.
|
183 |
+
if truncation_psi != 1:
|
184 |
+
with torch.autograd.profiler.record_function('truncate'):
|
185 |
+
assert self.w_avg_beta is not None
|
186 |
+
if self.num_ws is None or truncation_cutoff is None:
|
187 |
+
x = self.w_avg.lerp(x, truncation_psi)
|
188 |
+
else:
|
189 |
+
x[:, :truncation_cutoff] = self.w_avg.lerp(x[:, :truncation_cutoff], truncation_psi)
|
190 |
+
return x
|
checkpoints/.gitattributes
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
vae_decoder.pth filter=lfs diff=lfs merge=lfs -text
|
2 |
+
mapping_network.pth filter=lfs diff=lfs merge=lfs -text
|
3 |
+
decoding_network.pth filter=lfs diff=lfs merge=lfs -text
|
checkpoints/decoding_network.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:8c127e07ca3746a35ede928e6fa8cc9a9a7499ab972c13abd417e2f2bbe38dfa
|
3 |
+
size 94614483
|
checkpoints/mapping_network.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:9f9c3902af67da17ef1d65561871a0a9aee2ac9521f048785f3bfa09b8946626
|
3 |
+
size 84687
|
checkpoints/vae_decoder.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:aef96ce90432b1b66c930804d011ff508fcfdc43eba6256987dbae7a094b475a
|
3 |
+
size 204440985
|
customization.py
ADDED
@@ -0,0 +1,306 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn.functional as F
|
3 |
+
from dataclasses import dataclass
|
4 |
+
from diffusers.utils import BaseOutput
|
5 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
6 |
+
from diffusers.models.unet_2d_blocks import UNetMidBlock2D, UpDecoderBlock2D, CrossAttnDownBlock2D, DownBlock2D, UNetMidBlock2DCrossAttn, UpBlock2D, CrossAttnUpBlock2D
|
7 |
+
from diffusers.models.resnet import ResnetBlock2D
|
8 |
+
from diffusers.models.attention import AttentionBlock
|
9 |
+
from diffusers.models.cross_attention import CrossAttention
|
10 |
+
from attribution import FullyConnectedLayer
|
11 |
+
import math
|
12 |
+
|
13 |
+
|
14 |
+
def customize_vae_decoder(vae, phi_dimension, modulation, finetune, weight_offset, lr_multiplier):
|
15 |
+
d = 'd' in modulation
|
16 |
+
e = 'e' in modulation
|
17 |
+
q = 'q' in modulation
|
18 |
+
k = 'k' in modulation
|
19 |
+
v = 'v' in modulation
|
20 |
+
|
21 |
+
def add_affine_conv(vaed):
|
22 |
+
if not (d or e):
|
23 |
+
return
|
24 |
+
|
25 |
+
for layer in vaed.children():
|
26 |
+
if type(layer) == ResnetBlock2D:
|
27 |
+
if d:
|
28 |
+
layer.affine_d = FullyConnectedLayer(phi_dimension, layer.conv1.weight.shape[1], lr_multiplier=lr_multiplier, bias_init=1)
|
29 |
+
if e:
|
30 |
+
layer.affine_e = FullyConnectedLayer(phi_dimension, layer.conv2.weight.shape[1], lr_multiplier=lr_multiplier, bias_init=1)
|
31 |
+
else:
|
32 |
+
add_affine_conv(layer)
|
33 |
+
|
34 |
+
def add_affine_attn(vaed):
|
35 |
+
if not (q or k or v):
|
36 |
+
return
|
37 |
+
|
38 |
+
for layer in vaed.children():
|
39 |
+
if type(layer) == AttentionBlock:
|
40 |
+
if q:
|
41 |
+
layer.affine_q = FullyConnectedLayer(phi_dimension, layer.query.weight.shape[1], lr_multiplier=lr_multiplier, bias_init=1)
|
42 |
+
if k:
|
43 |
+
layer.affine_k = FullyConnectedLayer(phi_dimension, layer.key.weight.shape[1], lr_multiplier=lr_multiplier, bias_init=1)
|
44 |
+
if v:
|
45 |
+
layer.affine_v = FullyConnectedLayer(phi_dimension, layer.value.weight.shape[1], lr_multiplier=lr_multiplier, bias_init=1)
|
46 |
+
else:
|
47 |
+
add_affine_attn(layer)
|
48 |
+
|
49 |
+
def impose_grad_condition(vaed, finetune):
|
50 |
+
if finetune == 'all':
|
51 |
+
return
|
52 |
+
|
53 |
+
for name, params in vaed.named_parameters():
|
54 |
+
requires_grad = False
|
55 |
+
if finetune == 'match':
|
56 |
+
d_cond = d and (('resnets' in name and 'conv1' in name) or 'affine_d' in name)
|
57 |
+
e_cond = e and (('resnets' in name and 'conv2' in name) or 'affine_e' in name)
|
58 |
+
q_cond = q and (('attentions' in name and 'query' in name) or 'affine_q' in name)
|
59 |
+
k_cond = k and (('attentions' in name and 'key' in name) or 'affine_k' in name)
|
60 |
+
v_cond = v and (('attentions' in name and 'value' in name) or 'affine_v' in name)
|
61 |
+
if q_cond or k_cond or v_cond or d_cond or e_cond:
|
62 |
+
requires_grad = True
|
63 |
+
params.requires_grad = requires_grad
|
64 |
+
|
65 |
+
def change_forward(vaed, layer_type, new_forward):
|
66 |
+
for layer in vaed.children():
|
67 |
+
if type(layer) == layer_type:
|
68 |
+
bound_method = new_forward.__get__(layer, layer.__class__)
|
69 |
+
setattr(layer, 'forward', bound_method)
|
70 |
+
else:
|
71 |
+
change_forward(layer, layer_type, new_forward)
|
72 |
+
|
73 |
+
def new_forward_MB(self, hidden_states, encoded_fingerprint, temb=None):
|
74 |
+
hidden_states = self.resnets[0]((hidden_states, encoded_fingerprint), temb)
|
75 |
+
for attn, resnet in zip(self.attentions, self.resnets[1:]):
|
76 |
+
if attn is not None:
|
77 |
+
hidden_states = attn((hidden_states, encoded_fingerprint))
|
78 |
+
hidden_states = resnet((hidden_states, encoded_fingerprint), temb)
|
79 |
+
|
80 |
+
return hidden_states
|
81 |
+
|
82 |
+
def new_forward_UDB(self, hidden_states, encoded_fingerprint):
|
83 |
+
for resnet in self.resnets:
|
84 |
+
hidden_states = resnet((hidden_states, encoded_fingerprint), temb=None)
|
85 |
+
|
86 |
+
if self.upsamplers is not None:
|
87 |
+
for upsampler in self.upsamplers:
|
88 |
+
hidden_states = upsampler(hidden_states)
|
89 |
+
|
90 |
+
return hidden_states
|
91 |
+
|
92 |
+
def new_forward_RB(self, input_tensor, temb):
|
93 |
+
input_tensor, encoded_fingerprint = input_tensor
|
94 |
+
hidden_states = input_tensor
|
95 |
+
|
96 |
+
hidden_states = self.norm1(hidden_states)
|
97 |
+
hidden_states = self.nonlinearity(hidden_states)
|
98 |
+
|
99 |
+
if self.upsample is not None:
|
100 |
+
# upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
|
101 |
+
if hidden_states.shape[0] >= 64:
|
102 |
+
input_tensor = input_tensor.contiguous()
|
103 |
+
hidden_states = hidden_states.contiguous()
|
104 |
+
input_tensor = self.upsample(input_tensor)
|
105 |
+
hidden_states = self.upsample(hidden_states)
|
106 |
+
elif self.downsample is not None:
|
107 |
+
input_tensor = self.downsample(input_tensor)
|
108 |
+
hidden_states = self.downsample(hidden_states)
|
109 |
+
|
110 |
+
if d:
|
111 |
+
phis = self.affine_d(encoded_fingerprint)
|
112 |
+
batch_size = phis.shape[0]
|
113 |
+
if not weight_offset:
|
114 |
+
weight = phis.view(batch_size, 1, -1, 1, 1) * self.conv1.weight.unsqueeze(0)
|
115 |
+
else:
|
116 |
+
weight = self.conv1.weight
|
117 |
+
weight_mod = phis.view(batch_size, 1, -1, 1, 1) * self.conv1.weight.unsqueeze(0)
|
118 |
+
weight = weight.unsqueeze(0) + weight_mod
|
119 |
+
hidden_states = F.conv2d(hidden_states.contiguous().view(1, -1, hidden_states.shape[-2], hidden_states.shape[-1]), weight.view(-1, weight.shape[-3], weight.shape[-2], weight.shape[-1]), padding=1, groups=batch_size).view(batch_size, weight.shape[1], hidden_states.shape[-2], hidden_states.shape[-1]) + self.conv1.bias.view(1, -1, 1, 1)
|
120 |
+
else:
|
121 |
+
hidden_states = self.conv1(hidden_states)
|
122 |
+
|
123 |
+
if temb is not None:
|
124 |
+
temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None]
|
125 |
+
|
126 |
+
if temb is not None and self.time_embedding_norm == "default":
|
127 |
+
hidden_states = hidden_states + temb
|
128 |
+
|
129 |
+
hidden_states = self.norm2(hidden_states)
|
130 |
+
|
131 |
+
if temb is not None and self.time_embedding_norm == "scale_shift":
|
132 |
+
scale, shift = torch.chunk(temb, 2, dim=1)
|
133 |
+
hidden_states = hidden_states * (1 + scale) + shift
|
134 |
+
|
135 |
+
hidden_states = self.nonlinearity(hidden_states)
|
136 |
+
|
137 |
+
hidden_states = self.dropout(hidden_states)
|
138 |
+
|
139 |
+
if e:
|
140 |
+
phis = self.affine_e(encoded_fingerprint)
|
141 |
+
batch_size = phis.shape[0]
|
142 |
+
if not weight_offset:
|
143 |
+
weight = phis.view(batch_size, 1, -1, 1, 1) * self.conv2.weight.unsqueeze(0)
|
144 |
+
else:
|
145 |
+
weight = self.conv2.weight
|
146 |
+
weight_mod = phis.view(batch_size, 1, -1, 1, 1) * self.conv2.weight.unsqueeze(0)
|
147 |
+
weight = weight.unsqueeze(0) + weight_mod
|
148 |
+
hidden_states = F.conv2d(hidden_states.contiguous().view(1, -1, hidden_states.shape[-2], hidden_states.shape[-1]), weight.view(-1, weight.shape[-3], weight.shape[-2], weight.shape[-1]), padding=1, groups=batch_size).view(batch_size, weight.shape[1], hidden_states.shape[-2], hidden_states.shape[-1]) + self.conv2.bias.view(1, -1, 1, 1)
|
149 |
+
else:
|
150 |
+
hidden_states = self.conv2(hidden_states)
|
151 |
+
|
152 |
+
if self.conv_shortcut is not None:
|
153 |
+
input_tensor = self.conv_shortcut(input_tensor)
|
154 |
+
|
155 |
+
output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
|
156 |
+
|
157 |
+
return output_tensor
|
158 |
+
|
159 |
+
def new_forward_AB(self, hidden_states):
|
160 |
+
hidden_states, encoded_fingerprint = hidden_states
|
161 |
+
residual = hidden_states
|
162 |
+
batch, channel, height, width = hidden_states.shape
|
163 |
+
|
164 |
+
# norm
|
165 |
+
hidden_states = self.group_norm(hidden_states)
|
166 |
+
|
167 |
+
hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2)
|
168 |
+
|
169 |
+
# proj to q, k, v
|
170 |
+
if q:
|
171 |
+
phis_q = self.affine_q(encoded_fingerprint)
|
172 |
+
if not weight_offset:
|
173 |
+
query_proj = torch.bmm(hidden_states, phis_q.unsqueeze(-1) * self.query.weight.t().unsqueeze(0)) + self.query.bias
|
174 |
+
else:
|
175 |
+
qw = self.query.weight
|
176 |
+
qw_mod = phis_q.unsqueeze(-1) * qw.t().unsqueeze(0)
|
177 |
+
query_proj = torch.bmm(hidden_states, qw.t().unsqueeze(0) + qw_mod) + self.query.bias
|
178 |
+
else:
|
179 |
+
query_proj = self.query(hidden_states)
|
180 |
+
|
181 |
+
if k:
|
182 |
+
phis_k = self.affine_k(encoded_fingerprint)
|
183 |
+
if not weight_offset:
|
184 |
+
key_proj = torch.bmm(hidden_states, phis_k.unsqueeze(-1) * self.key.weight.t().unsqueeze(0)) + self.key.bias
|
185 |
+
else:
|
186 |
+
kw = self.key.weight
|
187 |
+
kw_mod = phis_k.unsqueeze(-1) * kw.t().unsqueeze(0)
|
188 |
+
key_proj = torch.bmm(hidden_states, kw.t().unsqueeze(0) + kw_mod) + self.key.bias
|
189 |
+
else:
|
190 |
+
key_proj = self.key(hidden_states)
|
191 |
+
|
192 |
+
if v:
|
193 |
+
phis_v = self.affine_v(encoded_fingerprint)
|
194 |
+
if not weight_offset:
|
195 |
+
value_proj = torch.bmm(hidden_states, phis_v.unsqueeze(-1) * self.value.weight.t().unsqueeze(0)) + self.value.bias
|
196 |
+
else:
|
197 |
+
vw = self.value.weight
|
198 |
+
vw_mod = phis_v.unsqueeze(-1) * vw.t().unsqueeze(0)
|
199 |
+
value_proj = torch.bmm(hidden_states, vw.t().unsqueeze(0) + vw_mod) + self.value.bias
|
200 |
+
else:
|
201 |
+
value_proj = self.value(hidden_states)
|
202 |
+
|
203 |
+
scale = 1 / math.sqrt(self.channels / self.num_heads)
|
204 |
+
|
205 |
+
query_proj = self.reshape_heads_to_batch_dim(query_proj)
|
206 |
+
key_proj = self.reshape_heads_to_batch_dim(key_proj)
|
207 |
+
value_proj = self.reshape_heads_to_batch_dim(value_proj)
|
208 |
+
|
209 |
+
if self._use_memory_efficient_attention_xformers:
|
210 |
+
# Memory efficient attention
|
211 |
+
hidden_states = xformers.ops.memory_efficient_attention(
|
212 |
+
query_proj, key_proj, value_proj, attn_bias=None, op=self._attention_op
|
213 |
+
)
|
214 |
+
hidden_states = hidden_states.to(query_proj.dtype)
|
215 |
+
else:
|
216 |
+
attention_scores = torch.baddbmm(
|
217 |
+
torch.empty(
|
218 |
+
query_proj.shape[0],
|
219 |
+
query_proj.shape[1],
|
220 |
+
key_proj.shape[1],
|
221 |
+
dtype=query_proj.dtype,
|
222 |
+
device=query_proj.device,
|
223 |
+
),
|
224 |
+
query_proj,
|
225 |
+
key_proj.transpose(-1, -2),
|
226 |
+
beta=0,
|
227 |
+
alpha=scale,
|
228 |
+
)
|
229 |
+
attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(attention_scores.dtype)
|
230 |
+
hidden_states = torch.bmm(attention_probs, value_proj)
|
231 |
+
|
232 |
+
# reshape hidden_states
|
233 |
+
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
|
234 |
+
|
235 |
+
# compute next hidden_states
|
236 |
+
hidden_states = self.proj_attn(hidden_states)
|
237 |
+
|
238 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width)
|
239 |
+
|
240 |
+
# res connect and rescale
|
241 |
+
hidden_states = (hidden_states + residual) / self.rescale_output_factor
|
242 |
+
return hidden_states
|
243 |
+
|
244 |
+
# Reference: https://github.com/huggingface/diffusers
|
245 |
+
def new_forward_vaed(self, z, enconded_fingerprint):
|
246 |
+
sample = z
|
247 |
+
sample = self.conv_in(sample)
|
248 |
+
|
249 |
+
# middle
|
250 |
+
sample = self.mid_block(sample, enconded_fingerprint)
|
251 |
+
|
252 |
+
# up
|
253 |
+
for up_block in self.up_blocks:
|
254 |
+
sample = up_block(sample, enconded_fingerprint)
|
255 |
+
|
256 |
+
# post-process
|
257 |
+
sample = self.conv_norm_out(sample)
|
258 |
+
sample = self.conv_act(sample)
|
259 |
+
sample = self.conv_out(sample)
|
260 |
+
|
261 |
+
return sample
|
262 |
+
|
263 |
+
@dataclass
|
264 |
+
class DecoderOutput(BaseOutput):
|
265 |
+
"""
|
266 |
+
Output of decoding method.
|
267 |
+
Args:
|
268 |
+
sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
269 |
+
Decoded output sample of the model. Output of the last layer of the model.
|
270 |
+
"""
|
271 |
+
|
272 |
+
sample: torch.FloatTensor
|
273 |
+
|
274 |
+
def new__decode(self, z: torch.FloatTensor, encoded_fingerprint: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
|
275 |
+
z = self.post_quant_conv(z)
|
276 |
+
dec = self.decoder(z, encoded_fingerprint)
|
277 |
+
|
278 |
+
if not return_dict:
|
279 |
+
return (dec,)
|
280 |
+
|
281 |
+
return DecoderOutput(sample=dec)
|
282 |
+
|
283 |
+
def new_decode(self, z: torch.FloatTensor, encoded_fingerprint: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
|
284 |
+
if self.use_slicing and z.shape[0] > 1:
|
285 |
+
decoded_slices = [self._decode(z_slice, encoded_fingerprint).sample for z_slice in z.split(1)]
|
286 |
+
decoded = torch.cat(decoded_slices)
|
287 |
+
else:
|
288 |
+
decoded = self._decode(z, encoded_fingerprint).sample
|
289 |
+
|
290 |
+
if not return_dict:
|
291 |
+
return (decoded,)
|
292 |
+
|
293 |
+
return DecoderOutput(sample=decoded)
|
294 |
+
|
295 |
+
add_affine_conv(vae.decoder)
|
296 |
+
add_affine_attn(vae.decoder)
|
297 |
+
impose_grad_condition(vae.decoder, finetune)
|
298 |
+
change_forward(vae.decoder, UNetMidBlock2D, new_forward_MB)
|
299 |
+
change_forward(vae.decoder, UpDecoderBlock2D, new_forward_UDB)
|
300 |
+
change_forward(vae.decoder, ResnetBlock2D, new_forward_RB)
|
301 |
+
change_forward(vae.decoder, AttentionBlock, new_forward_AB)
|
302 |
+
setattr(vae.decoder, 'forward', new_forward_vaed.__get__(vae.decoder, vae.decoder.__class__))
|
303 |
+
setattr(vae, '_decode', new__decode.__get__(vae, vae.__class__))
|
304 |
+
setattr(vae, 'decode', new_decode.__get__(vae, vae.__class__))
|
305 |
+
|
306 |
+
return vae
|
dnnlib/__init__.py
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
#
|
3 |
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
+
# and proprietary rights in and to this software, related documentation
|
5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
# distribution of this software and related documentation without an express
|
7 |
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
+
|
9 |
+
from .util import EasyDict, make_cache_dir_path
|
dnnlib/util.py
ADDED
@@ -0,0 +1,477 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
#
|
3 |
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
+
# and proprietary rights in and to this software, related documentation
|
5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
# distribution of this software and related documentation without an express
|
7 |
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
+
|
9 |
+
"""Miscellaneous utility classes and functions."""
|
10 |
+
|
11 |
+
import ctypes
|
12 |
+
import fnmatch
|
13 |
+
import importlib
|
14 |
+
import inspect
|
15 |
+
import numpy as np
|
16 |
+
import os
|
17 |
+
import shutil
|
18 |
+
import sys
|
19 |
+
import types
|
20 |
+
import io
|
21 |
+
import pickle
|
22 |
+
import re
|
23 |
+
import requests
|
24 |
+
import html
|
25 |
+
import hashlib
|
26 |
+
import glob
|
27 |
+
import tempfile
|
28 |
+
import urllib
|
29 |
+
import urllib.request
|
30 |
+
import uuid
|
31 |
+
|
32 |
+
from distutils.util import strtobool
|
33 |
+
from typing import Any, List, Tuple, Union
|
34 |
+
|
35 |
+
|
36 |
+
# Util classes
|
37 |
+
# ------------------------------------------------------------------------------------------
|
38 |
+
|
39 |
+
|
40 |
+
class EasyDict(dict):
|
41 |
+
"""Convenience class that behaves like a dict but allows access with the attribute syntax."""
|
42 |
+
|
43 |
+
def __getattr__(self, name: str) -> Any:
|
44 |
+
try:
|
45 |
+
return self[name]
|
46 |
+
except KeyError:
|
47 |
+
raise AttributeError(name)
|
48 |
+
|
49 |
+
def __setattr__(self, name: str, value: Any) -> None:
|
50 |
+
self[name] = value
|
51 |
+
|
52 |
+
def __delattr__(self, name: str) -> None:
|
53 |
+
del self[name]
|
54 |
+
|
55 |
+
|
56 |
+
class Logger(object):
|
57 |
+
"""Redirect stderr to stdout, optionally print stdout to a file, and optionally force flushing on both stdout and the file."""
|
58 |
+
|
59 |
+
def __init__(self, file_name: str = None, file_mode: str = "w", should_flush: bool = True):
|
60 |
+
self.file = None
|
61 |
+
|
62 |
+
if file_name is not None:
|
63 |
+
self.file = open(file_name, file_mode)
|
64 |
+
|
65 |
+
self.should_flush = should_flush
|
66 |
+
self.stdout = sys.stdout
|
67 |
+
self.stderr = sys.stderr
|
68 |
+
|
69 |
+
sys.stdout = self
|
70 |
+
sys.stderr = self
|
71 |
+
|
72 |
+
def __enter__(self) -> "Logger":
|
73 |
+
return self
|
74 |
+
|
75 |
+
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
|
76 |
+
self.close()
|
77 |
+
|
78 |
+
def write(self, text: Union[str, bytes]) -> None:
|
79 |
+
"""Write text to stdout (and a file) and optionally flush."""
|
80 |
+
if isinstance(text, bytes):
|
81 |
+
text = text.decode()
|
82 |
+
if len(text) == 0: # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash
|
83 |
+
return
|
84 |
+
|
85 |
+
if self.file is not None:
|
86 |
+
self.file.write(text)
|
87 |
+
|
88 |
+
self.stdout.write(text)
|
89 |
+
|
90 |
+
if self.should_flush:
|
91 |
+
self.flush()
|
92 |
+
|
93 |
+
def flush(self) -> None:
|
94 |
+
"""Flush written text to both stdout and a file, if open."""
|
95 |
+
if self.file is not None:
|
96 |
+
self.file.flush()
|
97 |
+
|
98 |
+
self.stdout.flush()
|
99 |
+
|
100 |
+
def close(self) -> None:
|
101 |
+
"""Flush, close possible files, and remove stdout/stderr mirroring."""
|
102 |
+
self.flush()
|
103 |
+
|
104 |
+
# if using multiple loggers, prevent closing in wrong order
|
105 |
+
if sys.stdout is self:
|
106 |
+
sys.stdout = self.stdout
|
107 |
+
if sys.stderr is self:
|
108 |
+
sys.stderr = self.stderr
|
109 |
+
|
110 |
+
if self.file is not None:
|
111 |
+
self.file.close()
|
112 |
+
self.file = None
|
113 |
+
|
114 |
+
|
115 |
+
# Cache directories
|
116 |
+
# ------------------------------------------------------------------------------------------
|
117 |
+
|
118 |
+
_dnnlib_cache_dir = None
|
119 |
+
|
120 |
+
def set_cache_dir(path: str) -> None:
|
121 |
+
global _dnnlib_cache_dir
|
122 |
+
_dnnlib_cache_dir = path
|
123 |
+
|
124 |
+
def make_cache_dir_path(*paths: str) -> str:
|
125 |
+
if _dnnlib_cache_dir is not None:
|
126 |
+
return os.path.join(_dnnlib_cache_dir, *paths)
|
127 |
+
if 'DNNLIB_CACHE_DIR' in os.environ:
|
128 |
+
return os.path.join(os.environ['DNNLIB_CACHE_DIR'], *paths)
|
129 |
+
if 'HOME' in os.environ:
|
130 |
+
return os.path.join(os.environ['HOME'], '.cache', 'dnnlib', *paths)
|
131 |
+
if 'USERPROFILE' in os.environ:
|
132 |
+
return os.path.join(os.environ['USERPROFILE'], '.cache', 'dnnlib', *paths)
|
133 |
+
return os.path.join(tempfile.gettempdir(), '.cache', 'dnnlib', *paths)
|
134 |
+
|
135 |
+
# Small util functions
|
136 |
+
# ------------------------------------------------------------------------------------------
|
137 |
+
|
138 |
+
|
139 |
+
def format_time(seconds: Union[int, float]) -> str:
|
140 |
+
"""Convert the seconds to human readable string with days, hours, minutes and seconds."""
|
141 |
+
s = int(np.rint(seconds))
|
142 |
+
|
143 |
+
if s < 60:
|
144 |
+
return "{0}s".format(s)
|
145 |
+
elif s < 60 * 60:
|
146 |
+
return "{0}m {1:02}s".format(s // 60, s % 60)
|
147 |
+
elif s < 24 * 60 * 60:
|
148 |
+
return "{0}h {1:02}m {2:02}s".format(s // (60 * 60), (s // 60) % 60, s % 60)
|
149 |
+
else:
|
150 |
+
return "{0}d {1:02}h {2:02}m".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24, (s // 60) % 60)
|
151 |
+
|
152 |
+
|
153 |
+
def ask_yes_no(question: str) -> bool:
|
154 |
+
"""Ask the user the question until the user inputs a valid answer."""
|
155 |
+
while True:
|
156 |
+
try:
|
157 |
+
print("{0} [y/n]".format(question))
|
158 |
+
return strtobool(input().lower())
|
159 |
+
except ValueError:
|
160 |
+
pass
|
161 |
+
|
162 |
+
|
163 |
+
def tuple_product(t: Tuple) -> Any:
|
164 |
+
"""Calculate the product of the tuple elements."""
|
165 |
+
result = 1
|
166 |
+
|
167 |
+
for v in t:
|
168 |
+
result *= v
|
169 |
+
|
170 |
+
return result
|
171 |
+
|
172 |
+
|
173 |
+
_str_to_ctype = {
|
174 |
+
"uint8": ctypes.c_ubyte,
|
175 |
+
"uint16": ctypes.c_uint16,
|
176 |
+
"uint32": ctypes.c_uint32,
|
177 |
+
"uint64": ctypes.c_uint64,
|
178 |
+
"int8": ctypes.c_byte,
|
179 |
+
"int16": ctypes.c_int16,
|
180 |
+
"int32": ctypes.c_int32,
|
181 |
+
"int64": ctypes.c_int64,
|
182 |
+
"float32": ctypes.c_float,
|
183 |
+
"float64": ctypes.c_double
|
184 |
+
}
|
185 |
+
|
186 |
+
|
187 |
+
def get_dtype_and_ctype(type_obj: Any) -> Tuple[np.dtype, Any]:
|
188 |
+
"""Given a type name string (or an object having a __name__ attribute), return matching Numpy and ctypes types that have the same size in bytes."""
|
189 |
+
type_str = None
|
190 |
+
|
191 |
+
if isinstance(type_obj, str):
|
192 |
+
type_str = type_obj
|
193 |
+
elif hasattr(type_obj, "__name__"):
|
194 |
+
type_str = type_obj.__name__
|
195 |
+
elif hasattr(type_obj, "name"):
|
196 |
+
type_str = type_obj.name
|
197 |
+
else:
|
198 |
+
raise RuntimeError("Cannot infer type name from input")
|
199 |
+
|
200 |
+
assert type_str in _str_to_ctype.keys()
|
201 |
+
|
202 |
+
my_dtype = np.dtype(type_str)
|
203 |
+
my_ctype = _str_to_ctype[type_str]
|
204 |
+
|
205 |
+
assert my_dtype.itemsize == ctypes.sizeof(my_ctype)
|
206 |
+
|
207 |
+
return my_dtype, my_ctype
|
208 |
+
|
209 |
+
|
210 |
+
def is_pickleable(obj: Any) -> bool:
|
211 |
+
try:
|
212 |
+
with io.BytesIO() as stream:
|
213 |
+
pickle.dump(obj, stream)
|
214 |
+
return True
|
215 |
+
except:
|
216 |
+
return False
|
217 |
+
|
218 |
+
|
219 |
+
# Functionality to import modules/objects by name, and call functions by name
|
220 |
+
# ------------------------------------------------------------------------------------------
|
221 |
+
|
222 |
+
def get_module_from_obj_name(obj_name: str) -> Tuple[types.ModuleType, str]:
|
223 |
+
"""Searches for the underlying module behind the name to some python object.
|
224 |
+
Returns the module and the object name (original name with module part removed)."""
|
225 |
+
|
226 |
+
# allow convenience shorthands, substitute them by full names
|
227 |
+
obj_name = re.sub("^np.", "numpy.", obj_name)
|
228 |
+
obj_name = re.sub("^tf.", "tensorflow.", obj_name)
|
229 |
+
|
230 |
+
# list alternatives for (module_name, local_obj_name)
|
231 |
+
parts = obj_name.split(".")
|
232 |
+
name_pairs = [(".".join(parts[:i]), ".".join(parts[i:])) for i in range(len(parts), 0, -1)]
|
233 |
+
|
234 |
+
# try each alternative in turn
|
235 |
+
for module_name, local_obj_name in name_pairs:
|
236 |
+
try:
|
237 |
+
module = importlib.import_module(module_name) # may raise ImportError
|
238 |
+
get_obj_from_module(module, local_obj_name) # may raise AttributeError
|
239 |
+
return module, local_obj_name
|
240 |
+
except:
|
241 |
+
pass
|
242 |
+
|
243 |
+
# maybe some of the modules themselves contain errors?
|
244 |
+
for module_name, _local_obj_name in name_pairs:
|
245 |
+
try:
|
246 |
+
importlib.import_module(module_name) # may raise ImportError
|
247 |
+
except ImportError:
|
248 |
+
if not str(sys.exc_info()[1]).startswith("No module named '" + module_name + "'"):
|
249 |
+
raise
|
250 |
+
|
251 |
+
# maybe the requested attribute is missing?
|
252 |
+
for module_name, local_obj_name in name_pairs:
|
253 |
+
try:
|
254 |
+
module = importlib.import_module(module_name) # may raise ImportError
|
255 |
+
get_obj_from_module(module, local_obj_name) # may raise AttributeError
|
256 |
+
except ImportError:
|
257 |
+
pass
|
258 |
+
|
259 |
+
# we are out of luck, but we have no idea why
|
260 |
+
raise ImportError(obj_name)
|
261 |
+
|
262 |
+
|
263 |
+
def get_obj_from_module(module: types.ModuleType, obj_name: str) -> Any:
|
264 |
+
"""Traverses the object name and returns the last (rightmost) python object."""
|
265 |
+
if obj_name == '':
|
266 |
+
return module
|
267 |
+
obj = module
|
268 |
+
for part in obj_name.split("."):
|
269 |
+
obj = getattr(obj, part)
|
270 |
+
return obj
|
271 |
+
|
272 |
+
|
273 |
+
def get_obj_by_name(name: str) -> Any:
|
274 |
+
"""Finds the python object with the given name."""
|
275 |
+
module, obj_name = get_module_from_obj_name(name)
|
276 |
+
return get_obj_from_module(module, obj_name)
|
277 |
+
|
278 |
+
|
279 |
+
def call_func_by_name(*args, func_name: str = None, **kwargs) -> Any:
|
280 |
+
"""Finds the python object with the given name and calls it as a function."""
|
281 |
+
assert func_name is not None
|
282 |
+
func_obj = get_obj_by_name(func_name)
|
283 |
+
assert callable(func_obj)
|
284 |
+
return func_obj(*args, **kwargs)
|
285 |
+
|
286 |
+
|
287 |
+
def construct_class_by_name(*args, class_name: str = None, **kwargs) -> Any:
|
288 |
+
"""Finds the python class with the given name and constructs it with the given arguments."""
|
289 |
+
return call_func_by_name(*args, func_name=class_name, **kwargs)
|
290 |
+
|
291 |
+
|
292 |
+
def get_module_dir_by_obj_name(obj_name: str) -> str:
|
293 |
+
"""Get the directory path of the module containing the given object name."""
|
294 |
+
module, _ = get_module_from_obj_name(obj_name)
|
295 |
+
return os.path.dirname(inspect.getfile(module))
|
296 |
+
|
297 |
+
|
298 |
+
def is_top_level_function(obj: Any) -> bool:
|
299 |
+
"""Determine whether the given object is a top-level function, i.e., defined at module scope using 'def'."""
|
300 |
+
return callable(obj) and obj.__name__ in sys.modules[obj.__module__].__dict__
|
301 |
+
|
302 |
+
|
303 |
+
def get_top_level_function_name(obj: Any) -> str:
|
304 |
+
"""Return the fully-qualified name of a top-level function."""
|
305 |
+
assert is_top_level_function(obj)
|
306 |
+
module = obj.__module__
|
307 |
+
if module == '__main__':
|
308 |
+
module = os.path.splitext(os.path.basename(sys.modules[module].__file__))[0]
|
309 |
+
return module + "." + obj.__name__
|
310 |
+
|
311 |
+
|
312 |
+
# File system helpers
|
313 |
+
# ------------------------------------------------------------------------------------------
|
314 |
+
|
315 |
+
def list_dir_recursively_with_ignore(dir_path: str, ignores: List[str] = None, add_base_to_relative: bool = False) -> List[Tuple[str, str]]:
|
316 |
+
"""List all files recursively in a given directory while ignoring given file and directory names.
|
317 |
+
Returns list of tuples containing both absolute and relative paths."""
|
318 |
+
assert os.path.isdir(dir_path)
|
319 |
+
base_name = os.path.basename(os.path.normpath(dir_path))
|
320 |
+
|
321 |
+
if ignores is None:
|
322 |
+
ignores = []
|
323 |
+
|
324 |
+
result = []
|
325 |
+
|
326 |
+
for root, dirs, files in os.walk(dir_path, topdown=True):
|
327 |
+
for ignore_ in ignores:
|
328 |
+
dirs_to_remove = [d for d in dirs if fnmatch.fnmatch(d, ignore_)]
|
329 |
+
|
330 |
+
# dirs need to be edited in-place
|
331 |
+
for d in dirs_to_remove:
|
332 |
+
dirs.remove(d)
|
333 |
+
|
334 |
+
files = [f for f in files if not fnmatch.fnmatch(f, ignore_)]
|
335 |
+
|
336 |
+
absolute_paths = [os.path.join(root, f) for f in files]
|
337 |
+
relative_paths = [os.path.relpath(p, dir_path) for p in absolute_paths]
|
338 |
+
|
339 |
+
if add_base_to_relative:
|
340 |
+
relative_paths = [os.path.join(base_name, p) for p in relative_paths]
|
341 |
+
|
342 |
+
assert len(absolute_paths) == len(relative_paths)
|
343 |
+
result += zip(absolute_paths, relative_paths)
|
344 |
+
|
345 |
+
return result
|
346 |
+
|
347 |
+
|
348 |
+
def copy_files_and_create_dirs(files: List[Tuple[str, str]]) -> None:
|
349 |
+
"""Takes in a list of tuples of (src, dst) paths and copies files.
|
350 |
+
Will create all necessary directories."""
|
351 |
+
for file in files:
|
352 |
+
target_dir_name = os.path.dirname(file[1])
|
353 |
+
|
354 |
+
# will create all intermediate-level directories
|
355 |
+
if not os.path.exists(target_dir_name):
|
356 |
+
os.makedirs(target_dir_name)
|
357 |
+
|
358 |
+
shutil.copyfile(file[0], file[1])
|
359 |
+
|
360 |
+
|
361 |
+
# URL helpers
|
362 |
+
# ------------------------------------------------------------------------------------------
|
363 |
+
|
364 |
+
def is_url(obj: Any, allow_file_urls: bool = False) -> bool:
|
365 |
+
"""Determine whether the given object is a valid URL string."""
|
366 |
+
if not isinstance(obj, str) or not "://" in obj:
|
367 |
+
return False
|
368 |
+
if allow_file_urls and obj.startswith('file://'):
|
369 |
+
return True
|
370 |
+
try:
|
371 |
+
res = requests.compat.urlparse(obj)
|
372 |
+
if not res.scheme or not res.netloc or not "." in res.netloc:
|
373 |
+
return False
|
374 |
+
res = requests.compat.urlparse(requests.compat.urljoin(obj, "/"))
|
375 |
+
if not res.scheme or not res.netloc or not "." in res.netloc:
|
376 |
+
return False
|
377 |
+
except:
|
378 |
+
return False
|
379 |
+
return True
|
380 |
+
|
381 |
+
|
382 |
+
def open_url(url: str, cache_dir: str = None, num_attempts: int = 10, verbose: bool = True, return_filename: bool = False, cache: bool = True) -> Any:
|
383 |
+
"""Download the given URL and return a binary-mode file object to access the data."""
|
384 |
+
assert num_attempts >= 1
|
385 |
+
assert not (return_filename and (not cache))
|
386 |
+
|
387 |
+
# Doesn't look like an URL scheme so interpret it as a local filename.
|
388 |
+
if not re.match('^[a-z]+://', url):
|
389 |
+
return url if return_filename else open(url, "rb")
|
390 |
+
|
391 |
+
# Handle file URLs. This code handles unusual file:// patterns that
|
392 |
+
# arise on Windows:
|
393 |
+
#
|
394 |
+
# file:///c:/foo.txt
|
395 |
+
#
|
396 |
+
# which would translate to a local '/c:/foo.txt' filename that's
|
397 |
+
# invalid. Drop the forward slash for such pathnames.
|
398 |
+
#
|
399 |
+
# If you touch this code path, you should test it on both Linux and
|
400 |
+
# Windows.
|
401 |
+
#
|
402 |
+
# Some internet resources suggest using urllib.request.url2pathname() but
|
403 |
+
# but that converts forward slashes to backslashes and this causes
|
404 |
+
# its own set of problems.
|
405 |
+
if url.startswith('file://'):
|
406 |
+
filename = urllib.parse.urlparse(url).path
|
407 |
+
if re.match(r'^/[a-zA-Z]:', filename):
|
408 |
+
filename = filename[1:]
|
409 |
+
return filename if return_filename else open(filename, "rb")
|
410 |
+
|
411 |
+
assert is_url(url)
|
412 |
+
|
413 |
+
# Lookup from cache.
|
414 |
+
if cache_dir is None:
|
415 |
+
cache_dir = make_cache_dir_path('downloads')
|
416 |
+
|
417 |
+
url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest()
|
418 |
+
if cache:
|
419 |
+
cache_files = glob.glob(os.path.join(cache_dir, url_md5 + "_*"))
|
420 |
+
if len(cache_files) == 1:
|
421 |
+
filename = cache_files[0]
|
422 |
+
return filename if return_filename else open(filename, "rb")
|
423 |
+
|
424 |
+
# Download.
|
425 |
+
url_name = None
|
426 |
+
url_data = None
|
427 |
+
with requests.Session() as session:
|
428 |
+
if verbose:
|
429 |
+
print("Downloading %s ..." % url, end="", flush=True)
|
430 |
+
for attempts_left in reversed(range(num_attempts)):
|
431 |
+
try:
|
432 |
+
with session.get(url) as res:
|
433 |
+
res.raise_for_status()
|
434 |
+
if len(res.content) == 0:
|
435 |
+
raise IOError("No data received")
|
436 |
+
|
437 |
+
if len(res.content) < 8192:
|
438 |
+
content_str = res.content.decode("utf-8")
|
439 |
+
if "download_warning" in res.headers.get("Set-Cookie", ""):
|
440 |
+
links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link]
|
441 |
+
if len(links) == 1:
|
442 |
+
url = requests.compat.urljoin(url, links[0])
|
443 |
+
raise IOError("Google Drive virus checker nag")
|
444 |
+
if "Google Drive - Quota exceeded" in content_str:
|
445 |
+
raise IOError("Google Drive download quota exceeded -- please try again later")
|
446 |
+
|
447 |
+
match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", ""))
|
448 |
+
url_name = match[1] if match else url
|
449 |
+
url_data = res.content
|
450 |
+
if verbose:
|
451 |
+
print(" done")
|
452 |
+
break
|
453 |
+
except KeyboardInterrupt:
|
454 |
+
raise
|
455 |
+
except:
|
456 |
+
if not attempts_left:
|
457 |
+
if verbose:
|
458 |
+
print(" failed")
|
459 |
+
raise
|
460 |
+
if verbose:
|
461 |
+
print(".", end="", flush=True)
|
462 |
+
|
463 |
+
# Save to cache.
|
464 |
+
if cache:
|
465 |
+
safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name)
|
466 |
+
cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name)
|
467 |
+
temp_file = os.path.join(cache_dir, "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name)
|
468 |
+
os.makedirs(cache_dir, exist_ok=True)
|
469 |
+
with open(temp_file, "wb") as f:
|
470 |
+
f.write(url_data)
|
471 |
+
os.replace(temp_file, cache_file) # atomic
|
472 |
+
if return_filename:
|
473 |
+
return cache_file
|
474 |
+
|
475 |
+
# Return data as file object.
|
476 |
+
assert not return_filename
|
477 |
+
return io.BytesIO(url_data)
|
requirements.txt
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch
|
2 |
+
torchvision
|
3 |
+
torchaudio
|
4 |
+
diffusers==0.16.0
|
5 |
+
transformers
|
6 |
+
accelerate==0.18.0
|
7 |
+
gradio
|
torch_utils/__init__.py
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
#
|
3 |
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
+
# and proprietary rights in and to this software, related documentation
|
5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
# distribution of this software and related documentation without an express
|
7 |
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
+
|
9 |
+
# empty
|
torch_utils/custom_ops.py
ADDED
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
#
|
3 |
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
+
# and proprietary rights in and to this software, related documentation
|
5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
# distribution of this software and related documentation without an express
|
7 |
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
+
|
9 |
+
import os
|
10 |
+
import glob
|
11 |
+
import torch
|
12 |
+
import torch.utils.cpp_extension
|
13 |
+
import importlib
|
14 |
+
import hashlib
|
15 |
+
import shutil
|
16 |
+
from pathlib import Path
|
17 |
+
|
18 |
+
from torch.utils.file_baton import FileBaton
|
19 |
+
|
20 |
+
#----------------------------------------------------------------------------
|
21 |
+
# Global options.
|
22 |
+
|
23 |
+
verbosity = 'brief' # Verbosity level: 'none', 'brief', 'full'
|
24 |
+
|
25 |
+
#----------------------------------------------------------------------------
|
26 |
+
# Internal helper funcs.
|
27 |
+
|
28 |
+
def _find_compiler_bindir():
|
29 |
+
patterns = [
|
30 |
+
'C:/Program Files (x86)/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64',
|
31 |
+
'C:/Program Files (x86)/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64',
|
32 |
+
'C:/Program Files (x86)/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64',
|
33 |
+
'C:/Program Files (x86)/Microsoft Visual Studio */vc/bin',
|
34 |
+
]
|
35 |
+
for pattern in patterns:
|
36 |
+
matches = sorted(glob.glob(pattern))
|
37 |
+
if len(matches):
|
38 |
+
return matches[-1]
|
39 |
+
return None
|
40 |
+
|
41 |
+
#----------------------------------------------------------------------------
|
42 |
+
# Main entry point for compiling and loading C++/CUDA plugins.
|
43 |
+
|
44 |
+
_cached_plugins = dict()
|
45 |
+
|
46 |
+
def get_plugin(module_name, sources, **build_kwargs):
|
47 |
+
assert verbosity in ['none', 'brief', 'full']
|
48 |
+
|
49 |
+
# Already cached?
|
50 |
+
if module_name in _cached_plugins:
|
51 |
+
return _cached_plugins[module_name]
|
52 |
+
|
53 |
+
# Print status.
|
54 |
+
if verbosity == 'full':
|
55 |
+
print(f'Setting up PyTorch plugin "{module_name}"...')
|
56 |
+
elif verbosity == 'brief':
|
57 |
+
print(f'Setting up PyTorch plugin "{module_name}"... ', end='', flush=True)
|
58 |
+
|
59 |
+
try: # pylint: disable=too-many-nested-blocks
|
60 |
+
# Make sure we can find the necessary compiler binaries.
|
61 |
+
if os.name == 'nt' and os.system("where cl.exe >nul 2>nul") != 0:
|
62 |
+
compiler_bindir = _find_compiler_bindir()
|
63 |
+
if compiler_bindir is None:
|
64 |
+
raise RuntimeError(f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".')
|
65 |
+
os.environ['PATH'] += ';' + compiler_bindir
|
66 |
+
|
67 |
+
# Compile and load.
|
68 |
+
verbose_build = (verbosity == 'full')
|
69 |
+
|
70 |
+
# Incremental build md5sum trickery. Copies all the input source files
|
71 |
+
# into a cached build directory under a combined md5 digest of the input
|
72 |
+
# source files. Copying is done only if the combined digest has changed.
|
73 |
+
# This keeps input file timestamps and filenames the same as in previous
|
74 |
+
# extension builds, allowing for fast incremental rebuilds.
|
75 |
+
#
|
76 |
+
# This optimization is done only in case all the source files reside in
|
77 |
+
# a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR
|
78 |
+
# environment variable is set (we take this as a signal that the user
|
79 |
+
# actually cares about this.)
|
80 |
+
source_dirs_set = set(os.path.dirname(source) for source in sources)
|
81 |
+
if len(source_dirs_set) == 1 and ('TORCH_EXTENSIONS_DIR' in os.environ):
|
82 |
+
all_source_files = sorted(list(x for x in Path(list(source_dirs_set)[0]).iterdir() if x.is_file()))
|
83 |
+
|
84 |
+
# Compute a combined hash digest for all source files in the same
|
85 |
+
# custom op directory (usually .cu, .cpp, .py and .h files).
|
86 |
+
hash_md5 = hashlib.md5()
|
87 |
+
for src in all_source_files:
|
88 |
+
with open(src, 'rb') as f:
|
89 |
+
hash_md5.update(f.read())
|
90 |
+
build_dir = torch.utils.cpp_extension._get_build_directory(module_name, verbose=verbose_build) # pylint: disable=protected-access
|
91 |
+
digest_build_dir = os.path.join(build_dir, hash_md5.hexdigest())
|
92 |
+
|
93 |
+
if not os.path.isdir(digest_build_dir):
|
94 |
+
os.makedirs(digest_build_dir, exist_ok=True)
|
95 |
+
baton = FileBaton(os.path.join(digest_build_dir, 'lock'))
|
96 |
+
if baton.try_acquire():
|
97 |
+
try:
|
98 |
+
for src in all_source_files:
|
99 |
+
shutil.copyfile(src, os.path.join(digest_build_dir, os.path.basename(src)))
|
100 |
+
finally:
|
101 |
+
baton.release()
|
102 |
+
else:
|
103 |
+
# Someone else is copying source files under the digest dir,
|
104 |
+
# wait until done and continue.
|
105 |
+
baton.wait()
|
106 |
+
digest_sources = [os.path.join(digest_build_dir, os.path.basename(x)) for x in sources]
|
107 |
+
torch.utils.cpp_extension.load(name=module_name, build_directory=build_dir,
|
108 |
+
verbose=verbose_build, sources=digest_sources, **build_kwargs)
|
109 |
+
else:
|
110 |
+
torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs)
|
111 |
+
module = importlib.import_module(module_name)
|
112 |
+
|
113 |
+
except:
|
114 |
+
if verbosity == 'brief':
|
115 |
+
print('Failed!')
|
116 |
+
raise
|
117 |
+
|
118 |
+
# Print status and add to cache.
|
119 |
+
if verbosity == 'full':
|
120 |
+
print(f'Done setting up PyTorch plugin "{module_name}".')
|
121 |
+
elif verbosity == 'brief':
|
122 |
+
print('Done.')
|
123 |
+
_cached_plugins[module_name] = module
|
124 |
+
return module
|
125 |
+
|
126 |
+
#----------------------------------------------------------------------------
|
torch_utils/misc.py
ADDED
@@ -0,0 +1,262 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
#
|
3 |
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
+
# and proprietary rights in and to this software, related documentation
|
5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
# distribution of this software and related documentation without an express
|
7 |
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
+
|
9 |
+
import re
|
10 |
+
import contextlib
|
11 |
+
import numpy as np
|
12 |
+
import torch
|
13 |
+
import warnings
|
14 |
+
import dnnlib
|
15 |
+
|
16 |
+
#----------------------------------------------------------------------------
|
17 |
+
# Cached construction of constant tensors. Avoids CPU=>GPU copy when the
|
18 |
+
# same constant is used multiple times.
|
19 |
+
|
20 |
+
_constant_cache = dict()
|
21 |
+
|
22 |
+
def constant(value, shape=None, dtype=None, device=None, memory_format=None):
|
23 |
+
value = np.asarray(value)
|
24 |
+
if shape is not None:
|
25 |
+
shape = tuple(shape)
|
26 |
+
if dtype is None:
|
27 |
+
dtype = torch.get_default_dtype()
|
28 |
+
if device is None:
|
29 |
+
device = torch.device('cpu')
|
30 |
+
if memory_format is None:
|
31 |
+
memory_format = torch.contiguous_format
|
32 |
+
|
33 |
+
key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format)
|
34 |
+
tensor = _constant_cache.get(key, None)
|
35 |
+
if tensor is None:
|
36 |
+
tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device)
|
37 |
+
if shape is not None:
|
38 |
+
tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape))
|
39 |
+
tensor = tensor.contiguous(memory_format=memory_format)
|
40 |
+
_constant_cache[key] = tensor
|
41 |
+
return tensor
|
42 |
+
|
43 |
+
#----------------------------------------------------------------------------
|
44 |
+
# Replace NaN/Inf with specified numerical values.
|
45 |
+
|
46 |
+
try:
|
47 |
+
nan_to_num = torch.nan_to_num # 1.8.0a0
|
48 |
+
except AttributeError:
|
49 |
+
def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): # pylint: disable=redefined-builtin
|
50 |
+
assert isinstance(input, torch.Tensor)
|
51 |
+
if posinf is None:
|
52 |
+
posinf = torch.finfo(input.dtype).max
|
53 |
+
if neginf is None:
|
54 |
+
neginf = torch.finfo(input.dtype).min
|
55 |
+
assert nan == 0
|
56 |
+
return torch.clamp(input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out)
|
57 |
+
|
58 |
+
#----------------------------------------------------------------------------
|
59 |
+
# Symbolic assert.
|
60 |
+
|
61 |
+
try:
|
62 |
+
symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access
|
63 |
+
except AttributeError:
|
64 |
+
symbolic_assert = torch.Assert # 1.7.0
|
65 |
+
|
66 |
+
#----------------------------------------------------------------------------
|
67 |
+
# Context manager to suppress known warnings in torch.jit.trace().
|
68 |
+
|
69 |
+
class suppress_tracer_warnings(warnings.catch_warnings):
|
70 |
+
def __enter__(self):
|
71 |
+
super().__enter__()
|
72 |
+
warnings.simplefilter('ignore', category=torch.jit.TracerWarning)
|
73 |
+
return self
|
74 |
+
|
75 |
+
#----------------------------------------------------------------------------
|
76 |
+
# Assert that the shape of a tensor matches the given list of integers.
|
77 |
+
# None indicates that the size of a dimension is allowed to vary.
|
78 |
+
# Performs symbolic assertion when used in torch.jit.trace().
|
79 |
+
|
80 |
+
def assert_shape(tensor, ref_shape):
|
81 |
+
if tensor.ndim != len(ref_shape):
|
82 |
+
raise AssertionError(f'Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}')
|
83 |
+
for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)):
|
84 |
+
if ref_size is None:
|
85 |
+
pass
|
86 |
+
elif isinstance(ref_size, torch.Tensor):
|
87 |
+
with suppress_tracer_warnings(): # as_tensor results are registered as constants
|
88 |
+
symbolic_assert(torch.equal(torch.as_tensor(size), ref_size), f'Wrong size for dimension {idx}')
|
89 |
+
elif isinstance(size, torch.Tensor):
|
90 |
+
with suppress_tracer_warnings(): # as_tensor results are registered as constants
|
91 |
+
symbolic_assert(torch.equal(size, torch.as_tensor(ref_size)), f'Wrong size for dimension {idx}: expected {ref_size}')
|
92 |
+
elif size != ref_size:
|
93 |
+
raise AssertionError(f'Wrong size for dimension {idx}: got {size}, expected {ref_size}')
|
94 |
+
|
95 |
+
#----------------------------------------------------------------------------
|
96 |
+
# Function decorator that calls torch.autograd.profiler.record_function().
|
97 |
+
|
98 |
+
def profiled_function(fn):
|
99 |
+
def decorator(*args, **kwargs):
|
100 |
+
with torch.autograd.profiler.record_function(fn.__name__):
|
101 |
+
return fn(*args, **kwargs)
|
102 |
+
decorator.__name__ = fn.__name__
|
103 |
+
return decorator
|
104 |
+
|
105 |
+
#----------------------------------------------------------------------------
|
106 |
+
# Sampler for torch.utils.data.DataLoader that loops over the dataset
|
107 |
+
# indefinitely, shuffling items as it goes.
|
108 |
+
|
109 |
+
class InfiniteSampler(torch.utils.data.Sampler):
|
110 |
+
def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5):
|
111 |
+
assert len(dataset) > 0
|
112 |
+
assert num_replicas > 0
|
113 |
+
assert 0 <= rank < num_replicas
|
114 |
+
assert 0 <= window_size <= 1
|
115 |
+
super().__init__(dataset)
|
116 |
+
self.dataset = dataset
|
117 |
+
self.rank = rank
|
118 |
+
self.num_replicas = num_replicas
|
119 |
+
self.shuffle = shuffle
|
120 |
+
self.seed = seed
|
121 |
+
self.window_size = window_size
|
122 |
+
|
123 |
+
def __iter__(self):
|
124 |
+
order = np.arange(len(self.dataset))
|
125 |
+
rnd = None
|
126 |
+
window = 0
|
127 |
+
if self.shuffle:
|
128 |
+
rnd = np.random.RandomState(self.seed)
|
129 |
+
rnd.shuffle(order)
|
130 |
+
window = int(np.rint(order.size * self.window_size))
|
131 |
+
|
132 |
+
idx = 0
|
133 |
+
while True:
|
134 |
+
i = idx % order.size
|
135 |
+
if idx % self.num_replicas == self.rank:
|
136 |
+
yield order[i]
|
137 |
+
if window >= 2:
|
138 |
+
j = (i - rnd.randint(window)) % order.size
|
139 |
+
order[i], order[j] = order[j], order[i]
|
140 |
+
idx += 1
|
141 |
+
|
142 |
+
#----------------------------------------------------------------------------
|
143 |
+
# Utilities for operating with torch.nn.Module parameters and buffers.
|
144 |
+
|
145 |
+
def params_and_buffers(module):
|
146 |
+
assert isinstance(module, torch.nn.Module)
|
147 |
+
return list(module.parameters()) + list(module.buffers())
|
148 |
+
|
149 |
+
def named_params_and_buffers(module):
|
150 |
+
assert isinstance(module, torch.nn.Module)
|
151 |
+
return list(module.named_parameters()) + list(module.named_buffers())
|
152 |
+
|
153 |
+
def copy_params_and_buffers(src_module, dst_module, require_all=False):
|
154 |
+
assert isinstance(src_module, torch.nn.Module)
|
155 |
+
assert isinstance(dst_module, torch.nn.Module)
|
156 |
+
src_tensors = {name: tensor for name, tensor in named_params_and_buffers(src_module)}
|
157 |
+
for name, tensor in named_params_and_buffers(dst_module):
|
158 |
+
assert (name in src_tensors) or (not require_all)
|
159 |
+
if name in src_tensors:
|
160 |
+
tensor.copy_(src_tensors[name].detach()).requires_grad_(tensor.requires_grad)
|
161 |
+
|
162 |
+
#----------------------------------------------------------------------------
|
163 |
+
# Context manager for easily enabling/disabling DistributedDataParallel
|
164 |
+
# synchronization.
|
165 |
+
|
166 |
+
@contextlib.contextmanager
|
167 |
+
def ddp_sync(module, sync):
|
168 |
+
assert isinstance(module, torch.nn.Module)
|
169 |
+
if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel):
|
170 |
+
yield
|
171 |
+
else:
|
172 |
+
with module.no_sync():
|
173 |
+
yield
|
174 |
+
|
175 |
+
#----------------------------------------------------------------------------
|
176 |
+
# Check DistributedDataParallel consistency across processes.
|
177 |
+
|
178 |
+
def check_ddp_consistency(module, ignore_regex=None):
|
179 |
+
assert isinstance(module, torch.nn.Module)
|
180 |
+
for name, tensor in named_params_and_buffers(module):
|
181 |
+
fullname = type(module).__name__ + '.' + name
|
182 |
+
if ignore_regex is not None and re.fullmatch(ignore_regex, fullname):
|
183 |
+
continue
|
184 |
+
tensor = tensor.detach()
|
185 |
+
other = tensor.clone()
|
186 |
+
torch.distributed.broadcast(tensor=other, src=0)
|
187 |
+
assert (nan_to_num(tensor) == nan_to_num(other)).all(), fullname
|
188 |
+
|
189 |
+
#----------------------------------------------------------------------------
|
190 |
+
# Print summary table of module hierarchy.
|
191 |
+
|
192 |
+
def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True):
|
193 |
+
assert isinstance(module, torch.nn.Module)
|
194 |
+
assert not isinstance(module, torch.jit.ScriptModule)
|
195 |
+
assert isinstance(inputs, (tuple, list))
|
196 |
+
|
197 |
+
# Register hooks.
|
198 |
+
entries = []
|
199 |
+
nesting = [0]
|
200 |
+
def pre_hook(_mod, _inputs):
|
201 |
+
nesting[0] += 1
|
202 |
+
def post_hook(mod, _inputs, outputs):
|
203 |
+
nesting[0] -= 1
|
204 |
+
if nesting[0] <= max_nesting:
|
205 |
+
outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs]
|
206 |
+
outputs = [t for t in outputs if isinstance(t, torch.Tensor)]
|
207 |
+
entries.append(dnnlib.EasyDict(mod=mod, outputs=outputs))
|
208 |
+
hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()]
|
209 |
+
hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()]
|
210 |
+
|
211 |
+
# Run module.
|
212 |
+
outputs = module(*inputs)
|
213 |
+
for hook in hooks:
|
214 |
+
hook.remove()
|
215 |
+
|
216 |
+
# Identify unique outputs, parameters, and buffers.
|
217 |
+
tensors_seen = set()
|
218 |
+
for e in entries:
|
219 |
+
e.unique_params = [t for t in e.mod.parameters() if id(t) not in tensors_seen]
|
220 |
+
e.unique_buffers = [t for t in e.mod.buffers() if id(t) not in tensors_seen]
|
221 |
+
e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen]
|
222 |
+
tensors_seen |= {id(t) for t in e.unique_params + e.unique_buffers + e.unique_outputs}
|
223 |
+
|
224 |
+
# Filter out redundant entries.
|
225 |
+
if skip_redundant:
|
226 |
+
entries = [e for e in entries if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs)]
|
227 |
+
|
228 |
+
# Construct table.
|
229 |
+
rows = [[type(module).__name__, 'Parameters', 'Buffers', 'Output shape', 'Datatype']]
|
230 |
+
rows += [['---'] * len(rows[0])]
|
231 |
+
param_total = 0
|
232 |
+
buffer_total = 0
|
233 |
+
submodule_names = {mod: name for name, mod in module.named_modules()}
|
234 |
+
for e in entries:
|
235 |
+
name = '<top-level>' if e.mod is module else submodule_names[e.mod]
|
236 |
+
param_size = sum(t.numel() for t in e.unique_params)
|
237 |
+
buffer_size = sum(t.numel() for t in e.unique_buffers)
|
238 |
+
output_shapes = [str(list(e.outputs[0].shape)) for t in e.outputs]
|
239 |
+
output_dtypes = [str(t.dtype).split('.')[-1] for t in e.outputs]
|
240 |
+
rows += [[
|
241 |
+
name + (':0' if len(e.outputs) >= 2 else ''),
|
242 |
+
str(param_size) if param_size else '-',
|
243 |
+
str(buffer_size) if buffer_size else '-',
|
244 |
+
(output_shapes + ['-'])[0],
|
245 |
+
(output_dtypes + ['-'])[0],
|
246 |
+
]]
|
247 |
+
for idx in range(1, len(e.outputs)):
|
248 |
+
rows += [[name + f':{idx}', '-', '-', output_shapes[idx], output_dtypes[idx]]]
|
249 |
+
param_total += param_size
|
250 |
+
buffer_total += buffer_size
|
251 |
+
rows += [['---'] * len(rows[0])]
|
252 |
+
rows += [['Total', str(param_total), str(buffer_total), '-', '-']]
|
253 |
+
|
254 |
+
# Print table.
|
255 |
+
widths = [max(len(cell) for cell in column) for column in zip(*rows)]
|
256 |
+
print()
|
257 |
+
for row in rows:
|
258 |
+
print(' '.join(cell + ' ' * (width - len(cell)) for cell, width in zip(row, widths)))
|
259 |
+
print()
|
260 |
+
return outputs
|
261 |
+
|
262 |
+
#----------------------------------------------------------------------------
|
torch_utils/ops/__init__.py
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
#
|
3 |
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
+
# and proprietary rights in and to this software, related documentation
|
5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
# distribution of this software and related documentation without an express
|
7 |
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
+
|
9 |
+
# empty
|
torch_utils/ops/bias_act.cpp
ADDED
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
//
|
3 |
+
// NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
+
// and proprietary rights in and to this software, related documentation
|
5 |
+
// and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
// distribution of this software and related documentation without an express
|
7 |
+
// license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
+
|
9 |
+
#include <torch/extension.h>
|
10 |
+
#include <ATen/cuda/CUDAContext.h>
|
11 |
+
#include <c10/cuda/CUDAGuard.h>
|
12 |
+
#include "bias_act.h"
|
13 |
+
|
14 |
+
//------------------------------------------------------------------------
|
15 |
+
|
16 |
+
static bool has_same_layout(torch::Tensor x, torch::Tensor y)
|
17 |
+
{
|
18 |
+
if (x.dim() != y.dim())
|
19 |
+
return false;
|
20 |
+
for (int64_t i = 0; i < x.dim(); i++)
|
21 |
+
{
|
22 |
+
if (x.size(i) != y.size(i))
|
23 |
+
return false;
|
24 |
+
if (x.size(i) >= 2 && x.stride(i) != y.stride(i))
|
25 |
+
return false;
|
26 |
+
}
|
27 |
+
return true;
|
28 |
+
}
|
29 |
+
|
30 |
+
//------------------------------------------------------------------------
|
31 |
+
|
32 |
+
static torch::Tensor bias_act(torch::Tensor x, torch::Tensor b, torch::Tensor xref, torch::Tensor yref, torch::Tensor dy, int grad, int dim, int act, float alpha, float gain, float clamp)
|
33 |
+
{
|
34 |
+
// Validate arguments.
|
35 |
+
TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
|
36 |
+
TORCH_CHECK(b.numel() == 0 || (b.dtype() == x.dtype() && b.device() == x.device()), "b must have the same dtype and device as x");
|
37 |
+
TORCH_CHECK(xref.numel() == 0 || (xref.sizes() == x.sizes() && xref.dtype() == x.dtype() && xref.device() == x.device()), "xref must have the same shape, dtype, and device as x");
|
38 |
+
TORCH_CHECK(yref.numel() == 0 || (yref.sizes() == x.sizes() && yref.dtype() == x.dtype() && yref.device() == x.device()), "yref must have the same shape, dtype, and device as x");
|
39 |
+
TORCH_CHECK(dy.numel() == 0 || (dy.sizes() == x.sizes() && dy.dtype() == x.dtype() && dy.device() == x.device()), "dy must have the same dtype and device as x");
|
40 |
+
TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
|
41 |
+
TORCH_CHECK(b.dim() == 1, "b must have rank 1");
|
42 |
+
TORCH_CHECK(b.numel() == 0 || (dim >= 0 && dim < x.dim()), "dim is out of bounds");
|
43 |
+
TORCH_CHECK(b.numel() == 0 || b.numel() == x.size(dim), "b has wrong number of elements");
|
44 |
+
TORCH_CHECK(grad >= 0, "grad must be non-negative");
|
45 |
+
|
46 |
+
// Validate layout.
|
47 |
+
TORCH_CHECK(x.is_non_overlapping_and_dense(), "x must be non-overlapping and dense");
|
48 |
+
TORCH_CHECK(b.is_contiguous(), "b must be contiguous");
|
49 |
+
TORCH_CHECK(xref.numel() == 0 || has_same_layout(xref, x), "xref must have the same layout as x");
|
50 |
+
TORCH_CHECK(yref.numel() == 0 || has_same_layout(yref, x), "yref must have the same layout as x");
|
51 |
+
TORCH_CHECK(dy.numel() == 0 || has_same_layout(dy, x), "dy must have the same layout as x");
|
52 |
+
|
53 |
+
// Create output tensor.
|
54 |
+
const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
|
55 |
+
torch::Tensor y = torch::empty_like(x);
|
56 |
+
TORCH_CHECK(has_same_layout(y, x), "y must have the same layout as x");
|
57 |
+
|
58 |
+
// Initialize CUDA kernel parameters.
|
59 |
+
bias_act_kernel_params p;
|
60 |
+
p.x = x.data_ptr();
|
61 |
+
p.b = (b.numel()) ? b.data_ptr() : NULL;
|
62 |
+
p.xref = (xref.numel()) ? xref.data_ptr() : NULL;
|
63 |
+
p.yref = (yref.numel()) ? yref.data_ptr() : NULL;
|
64 |
+
p.dy = (dy.numel()) ? dy.data_ptr() : NULL;
|
65 |
+
p.y = y.data_ptr();
|
66 |
+
p.grad = grad;
|
67 |
+
p.act = act;
|
68 |
+
p.alpha = alpha;
|
69 |
+
p.gain = gain;
|
70 |
+
p.clamp = clamp;
|
71 |
+
p.sizeX = (int)x.numel();
|
72 |
+
p.sizeB = (int)b.numel();
|
73 |
+
p.stepB = (b.numel()) ? (int)x.stride(dim) : 1;
|
74 |
+
|
75 |
+
// Choose CUDA kernel.
|
76 |
+
void* kernel;
|
77 |
+
AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&]
|
78 |
+
{
|
79 |
+
kernel = choose_bias_act_kernel<scalar_t>(p);
|
80 |
+
});
|
81 |
+
TORCH_CHECK(kernel, "no CUDA kernel found for the specified activation func");
|
82 |
+
|
83 |
+
// Launch CUDA kernel.
|
84 |
+
p.loopX = 4;
|
85 |
+
int blockSize = 4 * 32;
|
86 |
+
int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1;
|
87 |
+
void* args[] = {&p};
|
88 |
+
AT_CUDA_CHECK(cudaLaunchKernel(kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
|
89 |
+
return y;
|
90 |
+
}
|
91 |
+
|
92 |
+
//------------------------------------------------------------------------
|
93 |
+
|
94 |
+
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
|
95 |
+
{
|
96 |
+
m.def("bias_act", &bias_act);
|
97 |
+
}
|
98 |
+
|
99 |
+
//------------------------------------------------------------------------
|
torch_utils/ops/bias_act.cu
ADDED
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
//
|
3 |
+
// NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
+
// and proprietary rights in and to this software, related documentation
|
5 |
+
// and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
// distribution of this software and related documentation without an express
|
7 |
+
// license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
+
|
9 |
+
#include <c10/util/Half.h>
|
10 |
+
#include "bias_act.h"
|
11 |
+
|
12 |
+
//------------------------------------------------------------------------
|
13 |
+
// Helpers.
|
14 |
+
|
15 |
+
template <class T> struct InternalType;
|
16 |
+
template <> struct InternalType<double> { typedef double scalar_t; };
|
17 |
+
template <> struct InternalType<float> { typedef float scalar_t; };
|
18 |
+
template <> struct InternalType<c10::Half> { typedef float scalar_t; };
|
19 |
+
|
20 |
+
//------------------------------------------------------------------------
|
21 |
+
// CUDA kernel.
|
22 |
+
|
23 |
+
template <class T, int A>
|
24 |
+
__global__ void bias_act_kernel(bias_act_kernel_params p)
|
25 |
+
{
|
26 |
+
typedef typename InternalType<T>::scalar_t scalar_t;
|
27 |
+
int G = p.grad;
|
28 |
+
scalar_t alpha = (scalar_t)p.alpha;
|
29 |
+
scalar_t gain = (scalar_t)p.gain;
|
30 |
+
scalar_t clamp = (scalar_t)p.clamp;
|
31 |
+
scalar_t one = (scalar_t)1;
|
32 |
+
scalar_t two = (scalar_t)2;
|
33 |
+
scalar_t expRange = (scalar_t)80;
|
34 |
+
scalar_t halfExpRange = (scalar_t)40;
|
35 |
+
scalar_t seluScale = (scalar_t)1.0507009873554804934193349852946;
|
36 |
+
scalar_t seluAlpha = (scalar_t)1.6732632423543772848170429916717;
|
37 |
+
|
38 |
+
// Loop over elements.
|
39 |
+
int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x;
|
40 |
+
for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x)
|
41 |
+
{
|
42 |
+
// Load.
|
43 |
+
scalar_t x = (scalar_t)((const T*)p.x)[xi];
|
44 |
+
scalar_t b = (p.b) ? (scalar_t)((const T*)p.b)[(xi / p.stepB) % p.sizeB] : 0;
|
45 |
+
scalar_t xref = (p.xref) ? (scalar_t)((const T*)p.xref)[xi] : 0;
|
46 |
+
scalar_t yref = (p.yref) ? (scalar_t)((const T*)p.yref)[xi] : 0;
|
47 |
+
scalar_t dy = (p.dy) ? (scalar_t)((const T*)p.dy)[xi] : one;
|
48 |
+
scalar_t yy = (gain != 0) ? yref / gain : 0;
|
49 |
+
scalar_t y = 0;
|
50 |
+
|
51 |
+
// Apply bias.
|
52 |
+
((G == 0) ? x : xref) += b;
|
53 |
+
|
54 |
+
// linear
|
55 |
+
if (A == 1)
|
56 |
+
{
|
57 |
+
if (G == 0) y = x;
|
58 |
+
if (G == 1) y = x;
|
59 |
+
}
|
60 |
+
|
61 |
+
// relu
|
62 |
+
if (A == 2)
|
63 |
+
{
|
64 |
+
if (G == 0) y = (x > 0) ? x : 0;
|
65 |
+
if (G == 1) y = (yy > 0) ? x : 0;
|
66 |
+
}
|
67 |
+
|
68 |
+
// lrelu
|
69 |
+
if (A == 3)
|
70 |
+
{
|
71 |
+
if (G == 0) y = (x > 0) ? x : x * alpha;
|
72 |
+
if (G == 1) y = (yy > 0) ? x : x * alpha;
|
73 |
+
}
|
74 |
+
|
75 |
+
// tanh
|
76 |
+
if (A == 4)
|
77 |
+
{
|
78 |
+
if (G == 0) { scalar_t c = exp(x); scalar_t d = one / c; y = (x < -expRange) ? -one : (x > expRange) ? one : (c - d) / (c + d); }
|
79 |
+
if (G == 1) y = x * (one - yy * yy);
|
80 |
+
if (G == 2) y = x * (one - yy * yy) * (-two * yy);
|
81 |
+
}
|
82 |
+
|
83 |
+
// sigmoid
|
84 |
+
if (A == 5)
|
85 |
+
{
|
86 |
+
if (G == 0) y = (x < -expRange) ? 0 : one / (exp(-x) + one);
|
87 |
+
if (G == 1) y = x * yy * (one - yy);
|
88 |
+
if (G == 2) y = x * yy * (one - yy) * (one - two * yy);
|
89 |
+
}
|
90 |
+
|
91 |
+
// elu
|
92 |
+
if (A == 6)
|
93 |
+
{
|
94 |
+
if (G == 0) y = (x >= 0) ? x : exp(x) - one;
|
95 |
+
if (G == 1) y = (yy >= 0) ? x : x * (yy + one);
|
96 |
+
if (G == 2) y = (yy >= 0) ? 0 : x * (yy + one);
|
97 |
+
}
|
98 |
+
|
99 |
+
// selu
|
100 |
+
if (A == 7)
|
101 |
+
{
|
102 |
+
if (G == 0) y = (x >= 0) ? seluScale * x : (seluScale * seluAlpha) * (exp(x) - one);
|
103 |
+
if (G == 1) y = (yy >= 0) ? x * seluScale : x * (yy + seluScale * seluAlpha);
|
104 |
+
if (G == 2) y = (yy >= 0) ? 0 : x * (yy + seluScale * seluAlpha);
|
105 |
+
}
|
106 |
+
|
107 |
+
// softplus
|
108 |
+
if (A == 8)
|
109 |
+
{
|
110 |
+
if (G == 0) y = (x > expRange) ? x : log(exp(x) + one);
|
111 |
+
if (G == 1) y = x * (one - exp(-yy));
|
112 |
+
if (G == 2) { scalar_t c = exp(-yy); y = x * c * (one - c); }
|
113 |
+
}
|
114 |
+
|
115 |
+
// swish
|
116 |
+
if (A == 9)
|
117 |
+
{
|
118 |
+
if (G == 0)
|
119 |
+
y = (x < -expRange) ? 0 : x / (exp(-x) + one);
|
120 |
+
else
|
121 |
+
{
|
122 |
+
scalar_t c = exp(xref);
|
123 |
+
scalar_t d = c + one;
|
124 |
+
if (G == 1)
|
125 |
+
y = (xref > halfExpRange) ? x : x * c * (xref + d) / (d * d);
|
126 |
+
else
|
127 |
+
y = (xref > halfExpRange) ? 0 : x * c * (xref * (two - d) + two * d) / (d * d * d);
|
128 |
+
yref = (xref < -expRange) ? 0 : xref / (exp(-xref) + one) * gain;
|
129 |
+
}
|
130 |
+
}
|
131 |
+
|
132 |
+
// Apply gain.
|
133 |
+
y *= gain * dy;
|
134 |
+
|
135 |
+
// Clamp.
|
136 |
+
if (clamp >= 0)
|
137 |
+
{
|
138 |
+
if (G == 0)
|
139 |
+
y = (y > -clamp & y < clamp) ? y : (y >= 0) ? clamp : -clamp;
|
140 |
+
else
|
141 |
+
y = (yref > -clamp & yref < clamp) ? y : 0;
|
142 |
+
}
|
143 |
+
|
144 |
+
// Store.
|
145 |
+
((T*)p.y)[xi] = (T)y;
|
146 |
+
}
|
147 |
+
}
|
148 |
+
|
149 |
+
//------------------------------------------------------------------------
|
150 |
+
// CUDA kernel selection.
|
151 |
+
|
152 |
+
template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p)
|
153 |
+
{
|
154 |
+
if (p.act == 1) return (void*)bias_act_kernel<T, 1>;
|
155 |
+
if (p.act == 2) return (void*)bias_act_kernel<T, 2>;
|
156 |
+
if (p.act == 3) return (void*)bias_act_kernel<T, 3>;
|
157 |
+
if (p.act == 4) return (void*)bias_act_kernel<T, 4>;
|
158 |
+
if (p.act == 5) return (void*)bias_act_kernel<T, 5>;
|
159 |
+
if (p.act == 6) return (void*)bias_act_kernel<T, 6>;
|
160 |
+
if (p.act == 7) return (void*)bias_act_kernel<T, 7>;
|
161 |
+
if (p.act == 8) return (void*)bias_act_kernel<T, 8>;
|
162 |
+
if (p.act == 9) return (void*)bias_act_kernel<T, 9>;
|
163 |
+
return NULL;
|
164 |
+
}
|
165 |
+
|
166 |
+
//------------------------------------------------------------------------
|
167 |
+
// Template specializations.
|
168 |
+
|
169 |
+
template void* choose_bias_act_kernel<double> (const bias_act_kernel_params& p);
|
170 |
+
template void* choose_bias_act_kernel<float> (const bias_act_kernel_params& p);
|
171 |
+
template void* choose_bias_act_kernel<c10::Half> (const bias_act_kernel_params& p);
|
172 |
+
|
173 |
+
//------------------------------------------------------------------------
|
torch_utils/ops/bias_act.h
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
//
|
3 |
+
// NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
+
// and proprietary rights in and to this software, related documentation
|
5 |
+
// and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
// distribution of this software and related documentation without an express
|
7 |
+
// license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
+
|
9 |
+
//------------------------------------------------------------------------
|
10 |
+
// CUDA kernel parameters.
|
11 |
+
|
12 |
+
struct bias_act_kernel_params
|
13 |
+
{
|
14 |
+
const void* x; // [sizeX]
|
15 |
+
const void* b; // [sizeB] or NULL
|
16 |
+
const void* xref; // [sizeX] or NULL
|
17 |
+
const void* yref; // [sizeX] or NULL
|
18 |
+
const void* dy; // [sizeX] or NULL
|
19 |
+
void* y; // [sizeX]
|
20 |
+
|
21 |
+
int grad;
|
22 |
+
int act;
|
23 |
+
float alpha;
|
24 |
+
float gain;
|
25 |
+
float clamp;
|
26 |
+
|
27 |
+
int sizeX;
|
28 |
+
int sizeB;
|
29 |
+
int stepB;
|
30 |
+
int loopX;
|
31 |
+
};
|
32 |
+
|
33 |
+
//------------------------------------------------------------------------
|
34 |
+
// CUDA kernel selection.
|
35 |
+
|
36 |
+
template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p);
|
37 |
+
|
38 |
+
//------------------------------------------------------------------------
|
torch_utils/ops/bias_act.py
ADDED
@@ -0,0 +1,212 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
#
|
3 |
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
+
# and proprietary rights in and to this software, related documentation
|
5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
# distribution of this software and related documentation without an express
|
7 |
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
+
|
9 |
+
"""Custom PyTorch ops for efficient bias and activation."""
|
10 |
+
|
11 |
+
import os
|
12 |
+
import warnings
|
13 |
+
import numpy as np
|
14 |
+
import torch
|
15 |
+
import dnnlib
|
16 |
+
import traceback
|
17 |
+
|
18 |
+
from .. import custom_ops
|
19 |
+
from .. import misc
|
20 |
+
|
21 |
+
#----------------------------------------------------------------------------
|
22 |
+
|
23 |
+
activation_funcs = {
|
24 |
+
'linear': dnnlib.EasyDict(func=lambda x, **_: x, def_alpha=0, def_gain=1, cuda_idx=1, ref='', has_2nd_grad=False),
|
25 |
+
'relu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.relu(x), def_alpha=0, def_gain=np.sqrt(2), cuda_idx=2, ref='y', has_2nd_grad=False),
|
26 |
+
'lrelu': dnnlib.EasyDict(func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha), def_alpha=0.2, def_gain=np.sqrt(2), cuda_idx=3, ref='y', has_2nd_grad=False),
|
27 |
+
'tanh': dnnlib.EasyDict(func=lambda x, **_: torch.tanh(x), def_alpha=0, def_gain=1, cuda_idx=4, ref='y', has_2nd_grad=True),
|
28 |
+
'sigmoid': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x), def_alpha=0, def_gain=1, cuda_idx=5, ref='y', has_2nd_grad=True),
|
29 |
+
'elu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.elu(x), def_alpha=0, def_gain=1, cuda_idx=6, ref='y', has_2nd_grad=True),
|
30 |
+
'selu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.selu(x), def_alpha=0, def_gain=1, cuda_idx=7, ref='y', has_2nd_grad=True),
|
31 |
+
'softplus': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.softplus(x), def_alpha=0, def_gain=1, cuda_idx=8, ref='y', has_2nd_grad=True),
|
32 |
+
'swish': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x) * x, def_alpha=0, def_gain=np.sqrt(2), cuda_idx=9, ref='x', has_2nd_grad=True),
|
33 |
+
}
|
34 |
+
|
35 |
+
#----------------------------------------------------------------------------
|
36 |
+
|
37 |
+
_inited = False
|
38 |
+
_plugin = None
|
39 |
+
_null_tensor = torch.empty([0])
|
40 |
+
|
41 |
+
def _init():
|
42 |
+
global _inited, _plugin
|
43 |
+
if not _inited:
|
44 |
+
_inited = True
|
45 |
+
sources = ['bias_act.cpp', 'bias_act.cu']
|
46 |
+
sources = [os.path.join(os.path.dirname(__file__), s) for s in sources]
|
47 |
+
try:
|
48 |
+
_plugin = custom_ops.get_plugin('bias_act_plugin', sources=sources, extra_cuda_cflags=['--use_fast_math'])
|
49 |
+
except:
|
50 |
+
warnings.warn('Failed to build CUDA kernels for bias_act. Falling back to slow reference implementation. Details:\n\n' + traceback.format_exc())
|
51 |
+
return _plugin is not None
|
52 |
+
|
53 |
+
#----------------------------------------------------------------------------
|
54 |
+
|
55 |
+
def bias_act(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None, impl='cuda'):
|
56 |
+
r"""Fused bias and activation function.
|
57 |
+
|
58 |
+
Adds bias `b` to activation tensor `x`, evaluates activation function `act`,
|
59 |
+
and scales the result by `gain`. Each of the steps is optional. In most cases,
|
60 |
+
the fused op is considerably more efficient than performing the same calculation
|
61 |
+
using standard PyTorch ops. It supports first and second order gradients,
|
62 |
+
but not third order gradients.
|
63 |
+
|
64 |
+
Args:
|
65 |
+
x: Input activation tensor. Can be of any shape.
|
66 |
+
b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type
|
67 |
+
as `x`. The shape must be known, and it must match the dimension of `x`
|
68 |
+
corresponding to `dim`.
|
69 |
+
dim: The dimension in `x` corresponding to the elements of `b`.
|
70 |
+
The value of `dim` is ignored if `b` is not specified.
|
71 |
+
act: Name of the activation function to evaluate, or `"linear"` to disable.
|
72 |
+
Can be e.g. `"relu"`, `"lrelu"`, `"tanh"`, `"sigmoid"`, `"swish"`, etc.
|
73 |
+
See `activation_funcs` for a full list. `None` is not allowed.
|
74 |
+
alpha: Shape parameter for the activation function, or `None` to use the default.
|
75 |
+
gain: Scaling factor for the output tensor, or `None` to use default.
|
76 |
+
See `activation_funcs` for the default scaling of each activation function.
|
77 |
+
If unsure, consider specifying 1.
|
78 |
+
clamp: Clamp the output values to `[-clamp, +clamp]`, or `None` to disable
|
79 |
+
the clamping (default).
|
80 |
+
impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).
|
81 |
+
|
82 |
+
Returns:
|
83 |
+
Tensor of the same shape and datatype as `x`.
|
84 |
+
"""
|
85 |
+
assert isinstance(x, torch.Tensor)
|
86 |
+
assert impl in ['ref', 'cuda']
|
87 |
+
if impl == 'cuda' and x.device.type == 'cuda' and _init():
|
88 |
+
return _bias_act_cuda(dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp).apply(x, b)
|
89 |
+
return _bias_act_ref(x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp)
|
90 |
+
|
91 |
+
#----------------------------------------------------------------------------
|
92 |
+
|
93 |
+
@misc.profiled_function
|
94 |
+
def _bias_act_ref(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None):
|
95 |
+
"""Slow reference implementation of `bias_act()` using standard TensorFlow ops.
|
96 |
+
"""
|
97 |
+
assert isinstance(x, torch.Tensor)
|
98 |
+
assert clamp is None or clamp >= 0
|
99 |
+
spec = activation_funcs[act]
|
100 |
+
alpha = float(alpha if alpha is not None else spec.def_alpha)
|
101 |
+
gain = float(gain if gain is not None else spec.def_gain)
|
102 |
+
clamp = float(clamp if clamp is not None else -1)
|
103 |
+
|
104 |
+
# Add bias.
|
105 |
+
if b is not None:
|
106 |
+
assert isinstance(b, torch.Tensor) and b.ndim == 1
|
107 |
+
assert 0 <= dim < x.ndim
|
108 |
+
assert b.shape[0] == x.shape[dim]
|
109 |
+
x = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)])
|
110 |
+
|
111 |
+
# Evaluate activation function.
|
112 |
+
alpha = float(alpha)
|
113 |
+
x = spec.func(x, alpha=alpha)
|
114 |
+
|
115 |
+
# Scale by gain.
|
116 |
+
gain = float(gain)
|
117 |
+
if gain != 1:
|
118 |
+
x = x * gain
|
119 |
+
|
120 |
+
# Clamp.
|
121 |
+
if clamp >= 0:
|
122 |
+
x = x.clamp(-clamp, clamp) # pylint: disable=invalid-unary-operand-type
|
123 |
+
return x
|
124 |
+
|
125 |
+
#----------------------------------------------------------------------------
|
126 |
+
|
127 |
+
_bias_act_cuda_cache = dict()
|
128 |
+
|
129 |
+
def _bias_act_cuda(dim=1, act='linear', alpha=None, gain=None, clamp=None):
|
130 |
+
"""Fast CUDA implementation of `bias_act()` using custom ops.
|
131 |
+
"""
|
132 |
+
# Parse arguments.
|
133 |
+
assert clamp is None or clamp >= 0
|
134 |
+
spec = activation_funcs[act]
|
135 |
+
alpha = float(alpha if alpha is not None else spec.def_alpha)
|
136 |
+
gain = float(gain if gain is not None else spec.def_gain)
|
137 |
+
clamp = float(clamp if clamp is not None else -1)
|
138 |
+
|
139 |
+
# Lookup from cache.
|
140 |
+
key = (dim, act, alpha, gain, clamp)
|
141 |
+
if key in _bias_act_cuda_cache:
|
142 |
+
return _bias_act_cuda_cache[key]
|
143 |
+
|
144 |
+
# Forward op.
|
145 |
+
class BiasActCuda(torch.autograd.Function):
|
146 |
+
@staticmethod
|
147 |
+
def forward(ctx, x, b): # pylint: disable=arguments-differ
|
148 |
+
ctx.memory_format = torch.channels_last if x.ndim > 2 and x.stride()[1] == 1 else torch.contiguous_format
|
149 |
+
x = x.contiguous(memory_format=ctx.memory_format)
|
150 |
+
b = b.contiguous() if b is not None else _null_tensor
|
151 |
+
y = x
|
152 |
+
if act != 'linear' or gain != 1 or clamp >= 0 or b is not _null_tensor:
|
153 |
+
y = _plugin.bias_act(x, b, _null_tensor, _null_tensor, _null_tensor, 0, dim, spec.cuda_idx, alpha, gain, clamp)
|
154 |
+
ctx.save_for_backward(
|
155 |
+
x if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor,
|
156 |
+
b if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor,
|
157 |
+
y if 'y' in spec.ref else _null_tensor)
|
158 |
+
return y
|
159 |
+
|
160 |
+
@staticmethod
|
161 |
+
def backward(ctx, dy): # pylint: disable=arguments-differ
|
162 |
+
dy = dy.contiguous(memory_format=ctx.memory_format)
|
163 |
+
x, b, y = ctx.saved_tensors
|
164 |
+
dx = None
|
165 |
+
db = None
|
166 |
+
|
167 |
+
if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
|
168 |
+
dx = dy
|
169 |
+
if act != 'linear' or gain != 1 or clamp >= 0:
|
170 |
+
dx = BiasActCudaGrad.apply(dy, x, b, y)
|
171 |
+
|
172 |
+
if ctx.needs_input_grad[1]:
|
173 |
+
db = dx.sum([i for i in range(dx.ndim) if i != dim])
|
174 |
+
|
175 |
+
return dx, db
|
176 |
+
|
177 |
+
# Backward op.
|
178 |
+
class BiasActCudaGrad(torch.autograd.Function):
|
179 |
+
@staticmethod
|
180 |
+
def forward(ctx, dy, x, b, y): # pylint: disable=arguments-differ
|
181 |
+
ctx.memory_format = torch.channels_last if dy.ndim > 2 and dy.stride()[1] == 1 else torch.contiguous_format
|
182 |
+
dx = _plugin.bias_act(dy, b, x, y, _null_tensor, 1, dim, spec.cuda_idx, alpha, gain, clamp)
|
183 |
+
ctx.save_for_backward(
|
184 |
+
dy if spec.has_2nd_grad else _null_tensor,
|
185 |
+
x, b, y)
|
186 |
+
return dx
|
187 |
+
|
188 |
+
@staticmethod
|
189 |
+
def backward(ctx, d_dx): # pylint: disable=arguments-differ
|
190 |
+
d_dx = d_dx.contiguous(memory_format=ctx.memory_format)
|
191 |
+
dy, x, b, y = ctx.saved_tensors
|
192 |
+
d_dy = None
|
193 |
+
d_x = None
|
194 |
+
d_b = None
|
195 |
+
d_y = None
|
196 |
+
|
197 |
+
if ctx.needs_input_grad[0]:
|
198 |
+
d_dy = BiasActCudaGrad.apply(d_dx, x, b, y)
|
199 |
+
|
200 |
+
if spec.has_2nd_grad and (ctx.needs_input_grad[1] or ctx.needs_input_grad[2]):
|
201 |
+
d_x = _plugin.bias_act(d_dx, b, x, y, dy, 2, dim, spec.cuda_idx, alpha, gain, clamp)
|
202 |
+
|
203 |
+
if spec.has_2nd_grad and ctx.needs_input_grad[2]:
|
204 |
+
d_b = d_x.sum([i for i in range(d_x.ndim) if i != dim])
|
205 |
+
|
206 |
+
return d_dy, d_x, d_b, d_y
|
207 |
+
|
208 |
+
# Add to cache.
|
209 |
+
_bias_act_cuda_cache[key] = BiasActCuda
|
210 |
+
return BiasActCuda
|
211 |
+
|
212 |
+
#----------------------------------------------------------------------------
|
torch_utils/ops/conv2d_gradfix.py
ADDED
@@ -0,0 +1,170 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
#
|
3 |
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
+
# and proprietary rights in and to this software, related documentation
|
5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
# distribution of this software and related documentation without an express
|
7 |
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
+
|
9 |
+
"""Custom replacement for `torch.nn.functional.conv2d` that supports
|
10 |
+
arbitrarily high order gradients with zero performance penalty."""
|
11 |
+
|
12 |
+
import warnings
|
13 |
+
import contextlib
|
14 |
+
import torch
|
15 |
+
|
16 |
+
# pylint: disable=redefined-builtin
|
17 |
+
# pylint: disable=arguments-differ
|
18 |
+
# pylint: disable=protected-access
|
19 |
+
|
20 |
+
#----------------------------------------------------------------------------
|
21 |
+
|
22 |
+
enabled = False # Enable the custom op by setting this to true.
|
23 |
+
weight_gradients_disabled = False # Forcefully disable computation of gradients with respect to the weights.
|
24 |
+
|
25 |
+
@contextlib.contextmanager
|
26 |
+
def no_weight_gradients():
|
27 |
+
global weight_gradients_disabled
|
28 |
+
old = weight_gradients_disabled
|
29 |
+
weight_gradients_disabled = True
|
30 |
+
yield
|
31 |
+
weight_gradients_disabled = old
|
32 |
+
|
33 |
+
#----------------------------------------------------------------------------
|
34 |
+
|
35 |
+
def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
|
36 |
+
if _should_use_custom_op(input):
|
37 |
+
return _conv2d_gradfix(transpose=False, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=0, dilation=dilation, groups=groups).apply(input, weight, bias)
|
38 |
+
return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
|
39 |
+
|
40 |
+
def conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1):
|
41 |
+
if _should_use_custom_op(input):
|
42 |
+
return _conv2d_gradfix(transpose=True, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation).apply(input, weight, bias)
|
43 |
+
return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation)
|
44 |
+
|
45 |
+
#----------------------------------------------------------------------------
|
46 |
+
|
47 |
+
def _should_use_custom_op(input):
|
48 |
+
assert isinstance(input, torch.Tensor)
|
49 |
+
if (not enabled) or (not torch.backends.cudnn.enabled):
|
50 |
+
return False
|
51 |
+
if input.device.type != 'cuda':
|
52 |
+
return False
|
53 |
+
if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.', '1.9']):
|
54 |
+
return True
|
55 |
+
warnings.warn(f'conv2d_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.conv2d().')
|
56 |
+
return False
|
57 |
+
|
58 |
+
def _tuple_of_ints(xs, ndim):
|
59 |
+
xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim
|
60 |
+
assert len(xs) == ndim
|
61 |
+
assert all(isinstance(x, int) for x in xs)
|
62 |
+
return xs
|
63 |
+
|
64 |
+
#----------------------------------------------------------------------------
|
65 |
+
|
66 |
+
_conv2d_gradfix_cache = dict()
|
67 |
+
|
68 |
+
def _conv2d_gradfix(transpose, weight_shape, stride, padding, output_padding, dilation, groups):
|
69 |
+
# Parse arguments.
|
70 |
+
ndim = 2
|
71 |
+
weight_shape = tuple(weight_shape)
|
72 |
+
stride = _tuple_of_ints(stride, ndim)
|
73 |
+
padding = _tuple_of_ints(padding, ndim)
|
74 |
+
output_padding = _tuple_of_ints(output_padding, ndim)
|
75 |
+
dilation = _tuple_of_ints(dilation, ndim)
|
76 |
+
|
77 |
+
# Lookup from cache.
|
78 |
+
key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups)
|
79 |
+
if key in _conv2d_gradfix_cache:
|
80 |
+
return _conv2d_gradfix_cache[key]
|
81 |
+
|
82 |
+
# Validate arguments.
|
83 |
+
assert groups >= 1
|
84 |
+
assert len(weight_shape) == ndim + 2
|
85 |
+
assert all(stride[i] >= 1 for i in range(ndim))
|
86 |
+
assert all(padding[i] >= 0 for i in range(ndim))
|
87 |
+
assert all(dilation[i] >= 0 for i in range(ndim))
|
88 |
+
if not transpose:
|
89 |
+
assert all(output_padding[i] == 0 for i in range(ndim))
|
90 |
+
else: # transpose
|
91 |
+
assert all(0 <= output_padding[i] < max(stride[i], dilation[i]) for i in range(ndim))
|
92 |
+
|
93 |
+
# Helpers.
|
94 |
+
common_kwargs = dict(stride=stride, padding=padding, dilation=dilation, groups=groups)
|
95 |
+
def calc_output_padding(input_shape, output_shape):
|
96 |
+
if transpose:
|
97 |
+
return [0, 0]
|
98 |
+
return [
|
99 |
+
input_shape[i + 2]
|
100 |
+
- (output_shape[i + 2] - 1) * stride[i]
|
101 |
+
- (1 - 2 * padding[i])
|
102 |
+
- dilation[i] * (weight_shape[i + 2] - 1)
|
103 |
+
for i in range(ndim)
|
104 |
+
]
|
105 |
+
|
106 |
+
# Forward & backward.
|
107 |
+
class Conv2d(torch.autograd.Function):
|
108 |
+
@staticmethod
|
109 |
+
def forward(ctx, input, weight, bias):
|
110 |
+
assert weight.shape == weight_shape
|
111 |
+
if not transpose:
|
112 |
+
output = torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, **common_kwargs)
|
113 |
+
else: # transpose
|
114 |
+
output = torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, output_padding=output_padding, **common_kwargs)
|
115 |
+
ctx.save_for_backward(input, weight)
|
116 |
+
return output
|
117 |
+
|
118 |
+
@staticmethod
|
119 |
+
def backward(ctx, grad_output):
|
120 |
+
input, weight = ctx.saved_tensors
|
121 |
+
grad_input = None
|
122 |
+
grad_weight = None
|
123 |
+
grad_bias = None
|
124 |
+
|
125 |
+
if ctx.needs_input_grad[0]:
|
126 |
+
p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape)
|
127 |
+
grad_input = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs).apply(grad_output, weight, None)
|
128 |
+
assert grad_input.shape == input.shape
|
129 |
+
|
130 |
+
if ctx.needs_input_grad[1] and not weight_gradients_disabled:
|
131 |
+
grad_weight = Conv2dGradWeight.apply(grad_output, input)
|
132 |
+
assert grad_weight.shape == weight_shape
|
133 |
+
|
134 |
+
if ctx.needs_input_grad[2]:
|
135 |
+
grad_bias = grad_output.sum([0, 2, 3])
|
136 |
+
|
137 |
+
return grad_input, grad_weight, grad_bias
|
138 |
+
|
139 |
+
# Gradient with respect to the weights.
|
140 |
+
class Conv2dGradWeight(torch.autograd.Function):
|
141 |
+
@staticmethod
|
142 |
+
def forward(ctx, grad_output, input):
|
143 |
+
op = torch._C._jit_get_operation('aten::cudnn_convolution_backward_weight' if not transpose else 'aten::cudnn_convolution_transpose_backward_weight')
|
144 |
+
flags = [torch.backends.cudnn.benchmark, torch.backends.cudnn.deterministic, torch.backends.cudnn.allow_tf32]
|
145 |
+
grad_weight = op(weight_shape, grad_output, input, padding, stride, dilation, groups, *flags)
|
146 |
+
assert grad_weight.shape == weight_shape
|
147 |
+
ctx.save_for_backward(grad_output, input)
|
148 |
+
return grad_weight
|
149 |
+
|
150 |
+
@staticmethod
|
151 |
+
def backward(ctx, grad2_grad_weight):
|
152 |
+
grad_output, input = ctx.saved_tensors
|
153 |
+
grad2_grad_output = None
|
154 |
+
grad2_input = None
|
155 |
+
|
156 |
+
if ctx.needs_input_grad[0]:
|
157 |
+
grad2_grad_output = Conv2d.apply(input, grad2_grad_weight, None)
|
158 |
+
assert grad2_grad_output.shape == grad_output.shape
|
159 |
+
|
160 |
+
if ctx.needs_input_grad[1]:
|
161 |
+
p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape)
|
162 |
+
grad2_input = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs).apply(grad_output, grad2_grad_weight, None)
|
163 |
+
assert grad2_input.shape == input.shape
|
164 |
+
|
165 |
+
return grad2_grad_output, grad2_input
|
166 |
+
|
167 |
+
_conv2d_gradfix_cache[key] = Conv2d
|
168 |
+
return Conv2d
|
169 |
+
|
170 |
+
#----------------------------------------------------------------------------
|
torch_utils/ops/conv2d_resample.py
ADDED
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
#
|
3 |
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
+
# and proprietary rights in and to this software, related documentation
|
5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
# distribution of this software and related documentation without an express
|
7 |
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
+
|
9 |
+
"""2D convolution with optional up/downsampling."""
|
10 |
+
|
11 |
+
import torch
|
12 |
+
|
13 |
+
from .. import misc
|
14 |
+
from . import conv2d_gradfix
|
15 |
+
from . import upfirdn2d
|
16 |
+
from .upfirdn2d import _parse_padding
|
17 |
+
from .upfirdn2d import _get_filter_size
|
18 |
+
|
19 |
+
#----------------------------------------------------------------------------
|
20 |
+
|
21 |
+
def _get_weight_shape(w):
|
22 |
+
with misc.suppress_tracer_warnings(): # this value will be treated as a constant
|
23 |
+
shape = [int(sz) for sz in w.shape]
|
24 |
+
misc.assert_shape(w, shape)
|
25 |
+
return shape
|
26 |
+
|
27 |
+
#----------------------------------------------------------------------------
|
28 |
+
|
29 |
+
def _conv2d_wrapper(x, w, stride=1, padding=0, groups=1, transpose=False, flip_weight=True):
|
30 |
+
"""Wrapper for the underlying `conv2d()` and `conv_transpose2d()` implementations.
|
31 |
+
"""
|
32 |
+
out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w)
|
33 |
+
|
34 |
+
# Flip weight if requested.
|
35 |
+
if not flip_weight: # conv2d() actually performs correlation (flip_weight=True) not convolution (flip_weight=False).
|
36 |
+
w = w.flip([2, 3])
|
37 |
+
|
38 |
+
# Workaround performance pitfall in cuDNN 8.0.5, triggered when using
|
39 |
+
# 1x1 kernel + memory_format=channels_last + less than 64 channels.
|
40 |
+
if kw == 1 and kh == 1 and stride == 1 and padding in [0, [0, 0], (0, 0)] and not transpose:
|
41 |
+
if x.stride()[1] == 1 and min(out_channels, in_channels_per_group) < 64:
|
42 |
+
if out_channels <= 4 and groups == 1:
|
43 |
+
in_shape = x.shape
|
44 |
+
x = w.squeeze(3).squeeze(2) @ x.reshape([in_shape[0], in_channels_per_group, -1])
|
45 |
+
x = x.reshape([in_shape[0], out_channels, in_shape[2], in_shape[3]])
|
46 |
+
else:
|
47 |
+
x = x.to(memory_format=torch.contiguous_format)
|
48 |
+
w = w.to(memory_format=torch.contiguous_format)
|
49 |
+
x = conv2d_gradfix.conv2d(x, w, groups=groups)
|
50 |
+
return x.to(memory_format=torch.channels_last)
|
51 |
+
|
52 |
+
# Otherwise => execute using conv2d_gradfix.
|
53 |
+
op = conv2d_gradfix.conv_transpose2d if transpose else conv2d_gradfix.conv2d
|
54 |
+
return op(x, w, stride=stride, padding=padding, groups=groups)
|
55 |
+
|
56 |
+
#----------------------------------------------------------------------------
|
57 |
+
|
58 |
+
@misc.profiled_function
|
59 |
+
def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False):
|
60 |
+
r"""2D convolution with optional up/downsampling.
|
61 |
+
|
62 |
+
Padding is performed only once at the beginning, not between the operations.
|
63 |
+
|
64 |
+
Args:
|
65 |
+
x: Input tensor of shape
|
66 |
+
`[batch_size, in_channels, in_height, in_width]`.
|
67 |
+
w: Weight tensor of shape
|
68 |
+
`[out_channels, in_channels//groups, kernel_height, kernel_width]`.
|
69 |
+
f: Low-pass filter for up/downsampling. Must be prepared beforehand by
|
70 |
+
calling upfirdn2d.setup_filter(). None = identity (default).
|
71 |
+
up: Integer upsampling factor (default: 1).
|
72 |
+
down: Integer downsampling factor (default: 1).
|
73 |
+
padding: Padding with respect to the upsampled image. Can be a single number
|
74 |
+
or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
|
75 |
+
(default: 0).
|
76 |
+
groups: Split input channels into N groups (default: 1).
|
77 |
+
flip_weight: False = convolution, True = correlation (default: True).
|
78 |
+
flip_filter: False = convolution, True = correlation (default: False).
|
79 |
+
|
80 |
+
Returns:
|
81 |
+
Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
|
82 |
+
"""
|
83 |
+
# Validate arguments.
|
84 |
+
assert isinstance(x, torch.Tensor) and (x.ndim == 4)
|
85 |
+
assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype == x.dtype)
|
86 |
+
assert f is None or (isinstance(f, torch.Tensor) and f.ndim in [1, 2] and f.dtype == torch.float32)
|
87 |
+
assert isinstance(up, int) and (up >= 1)
|
88 |
+
assert isinstance(down, int) and (down >= 1)
|
89 |
+
assert isinstance(groups, int) and (groups >= 1)
|
90 |
+
out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w)
|
91 |
+
fw, fh = _get_filter_size(f)
|
92 |
+
px0, px1, py0, py1 = _parse_padding(padding)
|
93 |
+
|
94 |
+
# Adjust padding to account for up/downsampling.
|
95 |
+
if up > 1:
|
96 |
+
px0 += (fw + up - 1) // 2
|
97 |
+
px1 += (fw - up) // 2
|
98 |
+
py0 += (fh + up - 1) // 2
|
99 |
+
py1 += (fh - up) // 2
|
100 |
+
if down > 1:
|
101 |
+
px0 += (fw - down + 1) // 2
|
102 |
+
px1 += (fw - down) // 2
|
103 |
+
py0 += (fh - down + 1) // 2
|
104 |
+
py1 += (fh - down) // 2
|
105 |
+
|
106 |
+
# Fast path: 1x1 convolution with downsampling only => downsample first, then convolve.
|
107 |
+
if kw == 1 and kh == 1 and (down > 1 and up == 1):
|
108 |
+
x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, padding=[px0,px1,py0,py1], flip_filter=flip_filter)
|
109 |
+
x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
|
110 |
+
return x
|
111 |
+
|
112 |
+
# Fast path: 1x1 convolution with upsampling only => convolve first, then upsample.
|
113 |
+
if kw == 1 and kh == 1 and (up > 1 and down == 1):
|
114 |
+
x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
|
115 |
+
x = upfirdn2d.upfirdn2d(x=x, f=f, up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter)
|
116 |
+
return x
|
117 |
+
|
118 |
+
# Fast path: downsampling only => use strided convolution.
|
119 |
+
if down > 1 and up == 1:
|
120 |
+
x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0,px1,py0,py1], flip_filter=flip_filter)
|
121 |
+
x = _conv2d_wrapper(x=x, w=w, stride=down, groups=groups, flip_weight=flip_weight)
|
122 |
+
return x
|
123 |
+
|
124 |
+
# Fast path: upsampling with optional downsampling => use transpose strided convolution.
|
125 |
+
if up > 1:
|
126 |
+
if groups == 1:
|
127 |
+
w = w.transpose(0, 1)
|
128 |
+
else:
|
129 |
+
w = w.reshape(groups, out_channels // groups, in_channels_per_group, kh, kw)
|
130 |
+
w = w.transpose(1, 2)
|
131 |
+
w = w.reshape(groups * in_channels_per_group, out_channels // groups, kh, kw)
|
132 |
+
px0 -= kw - 1
|
133 |
+
px1 -= kw - up
|
134 |
+
py0 -= kh - 1
|
135 |
+
py1 -= kh - up
|
136 |
+
pxt = max(min(-px0, -px1), 0)
|
137 |
+
pyt = max(min(-py0, -py1), 0)
|
138 |
+
x = _conv2d_wrapper(x=x, w=w, stride=up, padding=[pyt,pxt], groups=groups, transpose=True, flip_weight=(not flip_weight))
|
139 |
+
x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0+pxt,px1+pxt,py0+pyt,py1+pyt], gain=up**2, flip_filter=flip_filter)
|
140 |
+
if down > 1:
|
141 |
+
x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
|
142 |
+
return x
|
143 |
+
|
144 |
+
# Fast path: no up/downsampling, padding supported by the underlying implementation => use plain conv2d.
|
145 |
+
if up == 1 and down == 1:
|
146 |
+
if px0 == px1 and py0 == py1 and px0 >= 0 and py0 >= 0:
|
147 |
+
return _conv2d_wrapper(x=x, w=w, padding=[py0,px0], groups=groups, flip_weight=flip_weight)
|
148 |
+
|
149 |
+
# Fallback: Generic reference implementation.
|
150 |
+
x = upfirdn2d.upfirdn2d(x=x, f=(f if up > 1 else None), up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter)
|
151 |
+
x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
|
152 |
+
if down > 1:
|
153 |
+
x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
|
154 |
+
return x
|
155 |
+
|
156 |
+
#----------------------------------------------------------------------------
|
torch_utils/ops/fma.py
ADDED
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
#
|
3 |
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
+
# and proprietary rights in and to this software, related documentation
|
5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
# distribution of this software and related documentation without an express
|
7 |
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
+
|
9 |
+
"""Fused multiply-add, with slightly faster gradients than `torch.addcmul()`."""
|
10 |
+
|
11 |
+
import torch
|
12 |
+
|
13 |
+
#----------------------------------------------------------------------------
|
14 |
+
|
15 |
+
def fma(a, b, c): # => a * b + c
|
16 |
+
return _FusedMultiplyAdd.apply(a, b, c)
|
17 |
+
|
18 |
+
#----------------------------------------------------------------------------
|
19 |
+
|
20 |
+
class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c
|
21 |
+
@staticmethod
|
22 |
+
def forward(ctx, a, b, c): # pylint: disable=arguments-differ
|
23 |
+
out = torch.addcmul(c, a, b)
|
24 |
+
ctx.save_for_backward(a, b)
|
25 |
+
ctx.c_shape = c.shape
|
26 |
+
return out
|
27 |
+
|
28 |
+
@staticmethod
|
29 |
+
def backward(ctx, dout): # pylint: disable=arguments-differ
|
30 |
+
a, b = ctx.saved_tensors
|
31 |
+
c_shape = ctx.c_shape
|
32 |
+
da = None
|
33 |
+
db = None
|
34 |
+
dc = None
|
35 |
+
|
36 |
+
if ctx.needs_input_grad[0]:
|
37 |
+
da = _unbroadcast(dout * b, a.shape)
|
38 |
+
|
39 |
+
if ctx.needs_input_grad[1]:
|
40 |
+
db = _unbroadcast(dout * a, b.shape)
|
41 |
+
|
42 |
+
if ctx.needs_input_grad[2]:
|
43 |
+
dc = _unbroadcast(dout, c_shape)
|
44 |
+
|
45 |
+
return da, db, dc
|
46 |
+
|
47 |
+
#----------------------------------------------------------------------------
|
48 |
+
|
49 |
+
def _unbroadcast(x, shape):
|
50 |
+
extra_dims = x.ndim - len(shape)
|
51 |
+
assert extra_dims >= 0
|
52 |
+
dim = [i for i in range(x.ndim) if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)]
|
53 |
+
if len(dim):
|
54 |
+
x = x.sum(dim=dim, keepdim=True)
|
55 |
+
if extra_dims:
|
56 |
+
x = x.reshape(-1, *x.shape[extra_dims+1:])
|
57 |
+
assert x.shape == shape
|
58 |
+
return x
|
59 |
+
|
60 |
+
#----------------------------------------------------------------------------
|
torch_utils/ops/grid_sample_gradfix.py
ADDED
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
#
|
3 |
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
+
# and proprietary rights in and to this software, related documentation
|
5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
# distribution of this software and related documentation without an express
|
7 |
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
+
|
9 |
+
"""Custom replacement for `torch.nn.functional.grid_sample` that
|
10 |
+
supports arbitrarily high order gradients between the input and output.
|
11 |
+
Only works on 2D images and assumes
|
12 |
+
`mode='bilinear'`, `padding_mode='zeros'`, `align_corners=False`."""
|
13 |
+
|
14 |
+
import warnings
|
15 |
+
import torch
|
16 |
+
|
17 |
+
# pylint: disable=redefined-builtin
|
18 |
+
# pylint: disable=arguments-differ
|
19 |
+
# pylint: disable=protected-access
|
20 |
+
|
21 |
+
#----------------------------------------------------------------------------
|
22 |
+
|
23 |
+
enabled = False # Enable the custom op by setting this to true.
|
24 |
+
|
25 |
+
#----------------------------------------------------------------------------
|
26 |
+
|
27 |
+
def grid_sample(input, grid):
|
28 |
+
if _should_use_custom_op():
|
29 |
+
return _GridSample2dForward.apply(input, grid)
|
30 |
+
return torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)
|
31 |
+
|
32 |
+
#----------------------------------------------------------------------------
|
33 |
+
|
34 |
+
def _should_use_custom_op():
|
35 |
+
if not enabled:
|
36 |
+
return False
|
37 |
+
if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.', '1.9']):
|
38 |
+
return True
|
39 |
+
warnings.warn(f'grid_sample_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.grid_sample().')
|
40 |
+
return False
|
41 |
+
|
42 |
+
#----------------------------------------------------------------------------
|
43 |
+
|
44 |
+
class _GridSample2dForward(torch.autograd.Function):
|
45 |
+
@staticmethod
|
46 |
+
def forward(ctx, input, grid):
|
47 |
+
assert input.ndim == 4
|
48 |
+
assert grid.ndim == 4
|
49 |
+
output = torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)
|
50 |
+
ctx.save_for_backward(input, grid)
|
51 |
+
return output
|
52 |
+
|
53 |
+
@staticmethod
|
54 |
+
def backward(ctx, grad_output):
|
55 |
+
input, grid = ctx.saved_tensors
|
56 |
+
grad_input, grad_grid = _GridSample2dBackward.apply(grad_output, input, grid)
|
57 |
+
return grad_input, grad_grid
|
58 |
+
|
59 |
+
#----------------------------------------------------------------------------
|
60 |
+
|
61 |
+
class _GridSample2dBackward(torch.autograd.Function):
|
62 |
+
@staticmethod
|
63 |
+
def forward(ctx, grad_output, input, grid):
|
64 |
+
op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward')
|
65 |
+
grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False)
|
66 |
+
ctx.save_for_backward(grid)
|
67 |
+
return grad_input, grad_grid
|
68 |
+
|
69 |
+
@staticmethod
|
70 |
+
def backward(ctx, grad2_grad_input, grad2_grad_grid):
|
71 |
+
_ = grad2_grad_grid # unused
|
72 |
+
grid, = ctx.saved_tensors
|
73 |
+
grad2_grad_output = None
|
74 |
+
grad2_input = None
|
75 |
+
grad2_grid = None
|
76 |
+
|
77 |
+
if ctx.needs_input_grad[0]:
|
78 |
+
grad2_grad_output = _GridSample2dForward.apply(grad2_grad_input, grid)
|
79 |
+
|
80 |
+
assert not ctx.needs_input_grad[2]
|
81 |
+
return grad2_grad_output, grad2_input, grad2_grid
|
82 |
+
|
83 |
+
#----------------------------------------------------------------------------
|
torch_utils/ops/upfirdn2d.cpp
ADDED
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
//
|
3 |
+
// NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
+
// and proprietary rights in and to this software, related documentation
|
5 |
+
// and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
// distribution of this software and related documentation without an express
|
7 |
+
// license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
+
|
9 |
+
#include <torch/extension.h>
|
10 |
+
#include <ATen/cuda/CUDAContext.h>
|
11 |
+
#include <c10/cuda/CUDAGuard.h>
|
12 |
+
#include "upfirdn2d.h"
|
13 |
+
|
14 |
+
//------------------------------------------------------------------------
|
15 |
+
|
16 |
+
static torch::Tensor upfirdn2d(torch::Tensor x, torch::Tensor f, int upx, int upy, int downx, int downy, int padx0, int padx1, int pady0, int pady1, bool flip, float gain)
|
17 |
+
{
|
18 |
+
// Validate arguments.
|
19 |
+
TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
|
20 |
+
TORCH_CHECK(f.device() == x.device(), "f must reside on the same device as x");
|
21 |
+
TORCH_CHECK(f.dtype() == torch::kFloat, "f must be float32");
|
22 |
+
TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
|
23 |
+
TORCH_CHECK(f.numel() <= INT_MAX, "f is too large");
|
24 |
+
TORCH_CHECK(x.dim() == 4, "x must be rank 4");
|
25 |
+
TORCH_CHECK(f.dim() == 2, "f must be rank 2");
|
26 |
+
TORCH_CHECK(f.size(0) >= 1 && f.size(1) >= 1, "f must be at least 1x1");
|
27 |
+
TORCH_CHECK(upx >= 1 && upy >= 1, "upsampling factor must be at least 1");
|
28 |
+
TORCH_CHECK(downx >= 1 && downy >= 1, "downsampling factor must be at least 1");
|
29 |
+
|
30 |
+
// Create output tensor.
|
31 |
+
const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
|
32 |
+
int outW = ((int)x.size(3) * upx + padx0 + padx1 - (int)f.size(1) + downx) / downx;
|
33 |
+
int outH = ((int)x.size(2) * upy + pady0 + pady1 - (int)f.size(0) + downy) / downy;
|
34 |
+
TORCH_CHECK(outW >= 1 && outH >= 1, "output must be at least 1x1");
|
35 |
+
torch::Tensor y = torch::empty({x.size(0), x.size(1), outH, outW}, x.options(), x.suggest_memory_format());
|
36 |
+
TORCH_CHECK(y.numel() <= INT_MAX, "output is too large");
|
37 |
+
|
38 |
+
// Initialize CUDA kernel parameters.
|
39 |
+
upfirdn2d_kernel_params p;
|
40 |
+
p.x = x.data_ptr();
|
41 |
+
p.f = f.data_ptr<float>();
|
42 |
+
p.y = y.data_ptr();
|
43 |
+
p.up = make_int2(upx, upy);
|
44 |
+
p.down = make_int2(downx, downy);
|
45 |
+
p.pad0 = make_int2(padx0, pady0);
|
46 |
+
p.flip = (flip) ? 1 : 0;
|
47 |
+
p.gain = gain;
|
48 |
+
p.inSize = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0));
|
49 |
+
p.inStride = make_int4((int)x.stride(3), (int)x.stride(2), (int)x.stride(1), (int)x.stride(0));
|
50 |
+
p.filterSize = make_int2((int)f.size(1), (int)f.size(0));
|
51 |
+
p.filterStride = make_int2((int)f.stride(1), (int)f.stride(0));
|
52 |
+
p.outSize = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0));
|
53 |
+
p.outStride = make_int4((int)y.stride(3), (int)y.stride(2), (int)y.stride(1), (int)y.stride(0));
|
54 |
+
p.sizeMajor = (p.inStride.z == 1) ? p.inSize.w : p.inSize.w * p.inSize.z;
|
55 |
+
p.sizeMinor = (p.inStride.z == 1) ? p.inSize.z : 1;
|
56 |
+
|
57 |
+
// Choose CUDA kernel.
|
58 |
+
upfirdn2d_kernel_spec spec;
|
59 |
+
AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&]
|
60 |
+
{
|
61 |
+
spec = choose_upfirdn2d_kernel<scalar_t>(p);
|
62 |
+
});
|
63 |
+
|
64 |
+
// Set looping options.
|
65 |
+
p.loopMajor = (p.sizeMajor - 1) / 16384 + 1;
|
66 |
+
p.loopMinor = spec.loopMinor;
|
67 |
+
p.loopX = spec.loopX;
|
68 |
+
p.launchMinor = (p.sizeMinor - 1) / p.loopMinor + 1;
|
69 |
+
p.launchMajor = (p.sizeMajor - 1) / p.loopMajor + 1;
|
70 |
+
|
71 |
+
// Compute grid size.
|
72 |
+
dim3 blockSize, gridSize;
|
73 |
+
if (spec.tileOutW < 0) // large
|
74 |
+
{
|
75 |
+
blockSize = dim3(4, 32, 1);
|
76 |
+
gridSize = dim3(
|
77 |
+
((p.outSize.y - 1) / blockSize.x + 1) * p.launchMinor,
|
78 |
+
(p.outSize.x - 1) / (blockSize.y * p.loopX) + 1,
|
79 |
+
p.launchMajor);
|
80 |
+
}
|
81 |
+
else // small
|
82 |
+
{
|
83 |
+
blockSize = dim3(256, 1, 1);
|
84 |
+
gridSize = dim3(
|
85 |
+
((p.outSize.y - 1) / spec.tileOutH + 1) * p.launchMinor,
|
86 |
+
(p.outSize.x - 1) / (spec.tileOutW * p.loopX) + 1,
|
87 |
+
p.launchMajor);
|
88 |
+
}
|
89 |
+
|
90 |
+
// Launch CUDA kernel.
|
91 |
+
void* args[] = {&p};
|
92 |
+
AT_CUDA_CHECK(cudaLaunchKernel(spec.kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
|
93 |
+
return y;
|
94 |
+
}
|
95 |
+
|
96 |
+
//------------------------------------------------------------------------
|
97 |
+
|
98 |
+
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
|
99 |
+
{
|
100 |
+
m.def("upfirdn2d", &upfirdn2d);
|
101 |
+
}
|
102 |
+
|
103 |
+
//------------------------------------------------------------------------
|
torch_utils/ops/upfirdn2d.cu
ADDED
@@ -0,0 +1,350 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
//
|
3 |
+
// NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
+
// and proprietary rights in and to this software, related documentation
|
5 |
+
// and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
// distribution of this software and related documentation without an express
|
7 |
+
// license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
+
|
9 |
+
#include <c10/util/Half.h>
|
10 |
+
#include "upfirdn2d.h"
|
11 |
+
|
12 |
+
//------------------------------------------------------------------------
|
13 |
+
// Helpers.
|
14 |
+
|
15 |
+
template <class T> struct InternalType;
|
16 |
+
template <> struct InternalType<double> { typedef double scalar_t; };
|
17 |
+
template <> struct InternalType<float> { typedef float scalar_t; };
|
18 |
+
template <> struct InternalType<c10::Half> { typedef float scalar_t; };
|
19 |
+
|
20 |
+
static __device__ __forceinline__ int floor_div(int a, int b)
|
21 |
+
{
|
22 |
+
int t = 1 - a / b;
|
23 |
+
return (a + t * b) / b - t;
|
24 |
+
}
|
25 |
+
|
26 |
+
//------------------------------------------------------------------------
|
27 |
+
// Generic CUDA implementation for large filters.
|
28 |
+
|
29 |
+
template <class T> static __global__ void upfirdn2d_kernel_large(upfirdn2d_kernel_params p)
|
30 |
+
{
|
31 |
+
typedef typename InternalType<T>::scalar_t scalar_t;
|
32 |
+
|
33 |
+
// Calculate thread index.
|
34 |
+
int minorBase = blockIdx.x * blockDim.x + threadIdx.x;
|
35 |
+
int outY = minorBase / p.launchMinor;
|
36 |
+
minorBase -= outY * p.launchMinor;
|
37 |
+
int outXBase = blockIdx.y * p.loopX * blockDim.y + threadIdx.y;
|
38 |
+
int majorBase = blockIdx.z * p.loopMajor;
|
39 |
+
if (outXBase >= p.outSize.x | outY >= p.outSize.y | majorBase >= p.sizeMajor)
|
40 |
+
return;
|
41 |
+
|
42 |
+
// Setup Y receptive field.
|
43 |
+
int midY = outY * p.down.y + p.up.y - 1 - p.pad0.y;
|
44 |
+
int inY = min(max(floor_div(midY, p.up.y), 0), p.inSize.y);
|
45 |
+
int h = min(max(floor_div(midY + p.filterSize.y, p.up.y), 0), p.inSize.y) - inY;
|
46 |
+
int filterY = midY + p.filterSize.y - (inY + 1) * p.up.y;
|
47 |
+
if (p.flip)
|
48 |
+
filterY = p.filterSize.y - 1 - filterY;
|
49 |
+
|
50 |
+
// Loop over major, minor, and X.
|
51 |
+
for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++)
|
52 |
+
for (int minorIdx = 0, minor = minorBase; minorIdx < p.loopMinor & minor < p.sizeMinor; minorIdx++, minor += p.launchMinor)
|
53 |
+
{
|
54 |
+
int nc = major * p.sizeMinor + minor;
|
55 |
+
int n = nc / p.inSize.z;
|
56 |
+
int c = nc - n * p.inSize.z;
|
57 |
+
for (int loopX = 0, outX = outXBase; loopX < p.loopX & outX < p.outSize.x; loopX++, outX += blockDim.y)
|
58 |
+
{
|
59 |
+
// Setup X receptive field.
|
60 |
+
int midX = outX * p.down.x + p.up.x - 1 - p.pad0.x;
|
61 |
+
int inX = min(max(floor_div(midX, p.up.x), 0), p.inSize.x);
|
62 |
+
int w = min(max(floor_div(midX + p.filterSize.x, p.up.x), 0), p.inSize.x) - inX;
|
63 |
+
int filterX = midX + p.filterSize.x - (inX + 1) * p.up.x;
|
64 |
+
if (p.flip)
|
65 |
+
filterX = p.filterSize.x - 1 - filterX;
|
66 |
+
|
67 |
+
// Initialize pointers.
|
68 |
+
const T* xp = &((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w];
|
69 |
+
const float* fp = &p.f[filterX * p.filterStride.x + filterY * p.filterStride.y];
|
70 |
+
int filterStepX = ((p.flip) ? p.up.x : -p.up.x) * p.filterStride.x;
|
71 |
+
int filterStepY = ((p.flip) ? p.up.y : -p.up.y) * p.filterStride.y;
|
72 |
+
|
73 |
+
// Inner loop.
|
74 |
+
scalar_t v = 0;
|
75 |
+
for (int y = 0; y < h; y++)
|
76 |
+
{
|
77 |
+
for (int x = 0; x < w; x++)
|
78 |
+
{
|
79 |
+
v += (scalar_t)(*xp) * (scalar_t)(*fp);
|
80 |
+
xp += p.inStride.x;
|
81 |
+
fp += filterStepX;
|
82 |
+
}
|
83 |
+
xp += p.inStride.y - w * p.inStride.x;
|
84 |
+
fp += filterStepY - w * filterStepX;
|
85 |
+
}
|
86 |
+
|
87 |
+
// Store result.
|
88 |
+
v *= p.gain;
|
89 |
+
((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v;
|
90 |
+
}
|
91 |
+
}
|
92 |
+
}
|
93 |
+
|
94 |
+
//------------------------------------------------------------------------
|
95 |
+
// Specialized CUDA implementation for small filters.
|
96 |
+
|
97 |
+
template <class T, int upx, int upy, int downx, int downy, int filterW, int filterH, int tileOutW, int tileOutH, int loopMinor>
|
98 |
+
static __global__ void upfirdn2d_kernel_small(upfirdn2d_kernel_params p)
|
99 |
+
{
|
100 |
+
typedef typename InternalType<T>::scalar_t scalar_t;
|
101 |
+
const int tileInW = ((tileOutW - 1) * downx + filterW - 1) / upx + 1;
|
102 |
+
const int tileInH = ((tileOutH - 1) * downy + filterH - 1) / upy + 1;
|
103 |
+
__shared__ volatile scalar_t sf[filterH][filterW];
|
104 |
+
__shared__ volatile scalar_t sx[tileInH][tileInW][loopMinor];
|
105 |
+
|
106 |
+
// Calculate tile index.
|
107 |
+
int minorBase = blockIdx.x;
|
108 |
+
int tileOutY = minorBase / p.launchMinor;
|
109 |
+
minorBase -= tileOutY * p.launchMinor;
|
110 |
+
minorBase *= loopMinor;
|
111 |
+
tileOutY *= tileOutH;
|
112 |
+
int tileOutXBase = blockIdx.y * p.loopX * tileOutW;
|
113 |
+
int majorBase = blockIdx.z * p.loopMajor;
|
114 |
+
if (tileOutXBase >= p.outSize.x | tileOutY >= p.outSize.y | majorBase >= p.sizeMajor)
|
115 |
+
return;
|
116 |
+
|
117 |
+
// Load filter (flipped).
|
118 |
+
for (int tapIdx = threadIdx.x; tapIdx < filterH * filterW; tapIdx += blockDim.x)
|
119 |
+
{
|
120 |
+
int fy = tapIdx / filterW;
|
121 |
+
int fx = tapIdx - fy * filterW;
|
122 |
+
scalar_t v = 0;
|
123 |
+
if (fx < p.filterSize.x & fy < p.filterSize.y)
|
124 |
+
{
|
125 |
+
int ffx = (p.flip) ? fx : p.filterSize.x - 1 - fx;
|
126 |
+
int ffy = (p.flip) ? fy : p.filterSize.y - 1 - fy;
|
127 |
+
v = (scalar_t)p.f[ffx * p.filterStride.x + ffy * p.filterStride.y];
|
128 |
+
}
|
129 |
+
sf[fy][fx] = v;
|
130 |
+
}
|
131 |
+
|
132 |
+
// Loop over major and X.
|
133 |
+
for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++)
|
134 |
+
{
|
135 |
+
int baseNC = major * p.sizeMinor + minorBase;
|
136 |
+
int n = baseNC / p.inSize.z;
|
137 |
+
int baseC = baseNC - n * p.inSize.z;
|
138 |
+
for (int loopX = 0, tileOutX = tileOutXBase; loopX < p.loopX & tileOutX < p.outSize.x; loopX++, tileOutX += tileOutW)
|
139 |
+
{
|
140 |
+
// Load input pixels.
|
141 |
+
int tileMidX = tileOutX * downx + upx - 1 - p.pad0.x;
|
142 |
+
int tileMidY = tileOutY * downy + upy - 1 - p.pad0.y;
|
143 |
+
int tileInX = floor_div(tileMidX, upx);
|
144 |
+
int tileInY = floor_div(tileMidY, upy);
|
145 |
+
__syncthreads();
|
146 |
+
for (int inIdx = threadIdx.x; inIdx < tileInH * tileInW * loopMinor; inIdx += blockDim.x)
|
147 |
+
{
|
148 |
+
int relC = inIdx;
|
149 |
+
int relInX = relC / loopMinor;
|
150 |
+
int relInY = relInX / tileInW;
|
151 |
+
relC -= relInX * loopMinor;
|
152 |
+
relInX -= relInY * tileInW;
|
153 |
+
int c = baseC + relC;
|
154 |
+
int inX = tileInX + relInX;
|
155 |
+
int inY = tileInY + relInY;
|
156 |
+
scalar_t v = 0;
|
157 |
+
if (inX >= 0 & inY >= 0 & inX < p.inSize.x & inY < p.inSize.y & c < p.inSize.z)
|
158 |
+
v = (scalar_t)((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w];
|
159 |
+
sx[relInY][relInX][relC] = v;
|
160 |
+
}
|
161 |
+
|
162 |
+
// Loop over output pixels.
|
163 |
+
__syncthreads();
|
164 |
+
for (int outIdx = threadIdx.x; outIdx < tileOutH * tileOutW * loopMinor; outIdx += blockDim.x)
|
165 |
+
{
|
166 |
+
int relC = outIdx;
|
167 |
+
int relOutX = relC / loopMinor;
|
168 |
+
int relOutY = relOutX / tileOutW;
|
169 |
+
relC -= relOutX * loopMinor;
|
170 |
+
relOutX -= relOutY * tileOutW;
|
171 |
+
int c = baseC + relC;
|
172 |
+
int outX = tileOutX + relOutX;
|
173 |
+
int outY = tileOutY + relOutY;
|
174 |
+
|
175 |
+
// Setup receptive field.
|
176 |
+
int midX = tileMidX + relOutX * downx;
|
177 |
+
int midY = tileMidY + relOutY * downy;
|
178 |
+
int inX = floor_div(midX, upx);
|
179 |
+
int inY = floor_div(midY, upy);
|
180 |
+
int relInX = inX - tileInX;
|
181 |
+
int relInY = inY - tileInY;
|
182 |
+
int filterX = (inX + 1) * upx - midX - 1; // flipped
|
183 |
+
int filterY = (inY + 1) * upy - midY - 1; // flipped
|
184 |
+
|
185 |
+
// Inner loop.
|
186 |
+
if (outX < p.outSize.x & outY < p.outSize.y & c < p.outSize.z)
|
187 |
+
{
|
188 |
+
scalar_t v = 0;
|
189 |
+
#pragma unroll
|
190 |
+
for (int y = 0; y < filterH / upy; y++)
|
191 |
+
#pragma unroll
|
192 |
+
for (int x = 0; x < filterW / upx; x++)
|
193 |
+
v += sx[relInY + y][relInX + x][relC] * sf[filterY + y * upy][filterX + x * upx];
|
194 |
+
v *= p.gain;
|
195 |
+
((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v;
|
196 |
+
}
|
197 |
+
}
|
198 |
+
}
|
199 |
+
}
|
200 |
+
}
|
201 |
+
|
202 |
+
//------------------------------------------------------------------------
|
203 |
+
// CUDA kernel selection.
|
204 |
+
|
205 |
+
template <class T> upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p)
|
206 |
+
{
|
207 |
+
int s = p.inStride.z, fx = p.filterSize.x, fy = p.filterSize.y;
|
208 |
+
|
209 |
+
upfirdn2d_kernel_spec spec = {(void*)upfirdn2d_kernel_large<T>, -1,-1,1, 4}; // contiguous
|
210 |
+
if (s == 1) spec = {(void*)upfirdn2d_kernel_large<T>, -1,-1,4, 1}; // channels_last
|
211 |
+
|
212 |
+
if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // contiguous
|
213 |
+
{
|
214 |
+
if (fx <= 7 && fy <= 7 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 7,7, 64,16,1>, 64,16,1, 1};
|
215 |
+
if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 6,6, 64,16,1>, 64,16,1, 1};
|
216 |
+
if (fx <= 5 && fy <= 5 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 5,5, 64,16,1>, 64,16,1, 1};
|
217 |
+
if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 4,4, 64,16,1>, 64,16,1, 1};
|
218 |
+
if (fx <= 3 && fy <= 3 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 3,3, 64,16,1>, 64,16,1, 1};
|
219 |
+
if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 24,1, 128,8,1>, 128,8,1, 1};
|
220 |
+
if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 20,1, 128,8,1>, 128,8,1, 1};
|
221 |
+
if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 16,1, 128,8,1>, 128,8,1, 1};
|
222 |
+
if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 12,1, 128,8,1>, 128,8,1, 1};
|
223 |
+
if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 8,1, 128,8,1>, 128,8,1, 1};
|
224 |
+
if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,24, 32,32,1>, 32,32,1, 1};
|
225 |
+
if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,20, 32,32,1>, 32,32,1, 1};
|
226 |
+
if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,16, 32,32,1>, 32,32,1, 1};
|
227 |
+
if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,12, 32,32,1>, 32,32,1, 1};
|
228 |
+
if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,8, 32,32,1>, 32,32,1, 1};
|
229 |
+
}
|
230 |
+
if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // channels_last
|
231 |
+
{
|
232 |
+
if (fx <= 7 && fy <= 7 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 7,7, 16,16,8>, 16,16,8, 1};
|
233 |
+
if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 4,4, 16,16,8>, 16,16,8, 1};
|
234 |
+
if (fx <= 5 && fy <= 5 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 4,4, 16,16,8>, 16,16,8, 1};
|
235 |
+
if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 4,4, 16,16,8>, 16,16,8, 1};
|
236 |
+
if (fx <= 3 && fy <= 3 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 4,4, 16,16,8>, 16,16,8, 1};
|
237 |
+
if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 24,1, 128,1,16>, 128,1,16, 1};
|
238 |
+
if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 20,1, 128,1,16>, 128,1,16, 1};
|
239 |
+
if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 16,1, 128,1,16>, 128,1,16, 1};
|
240 |
+
if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 12,1, 128,1,16>, 128,1,16, 1};
|
241 |
+
if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 8,1, 128,1,16>, 128,1,16, 1};
|
242 |
+
if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,24, 1,128,16>, 1,128,16, 1};
|
243 |
+
if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,20, 1,128,16>, 1,128,16, 1};
|
244 |
+
if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,16, 1,128,16>, 1,128,16, 1};
|
245 |
+
if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,12, 1,128,16>, 1,128,16, 1};
|
246 |
+
if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,8, 1,128,16>, 1,128,16, 1};
|
247 |
+
}
|
248 |
+
if (s != 1 && p.up.x == 2 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // contiguous
|
249 |
+
{
|
250 |
+
if (fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 8,8, 64,16,1>, 64,16,1, 1};
|
251 |
+
if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 6,6, 64,16,1>, 64,16,1, 1};
|
252 |
+
if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 4,4, 64,16,1>, 64,16,1, 1};
|
253 |
+
if (fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 2,2, 64,16,1>, 64,16,1, 1};
|
254 |
+
}
|
255 |
+
if (s == 1 && p.up.x == 2 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // channels_last
|
256 |
+
{
|
257 |
+
if (fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 8,8, 16,16,8>, 16,16,8, 1};
|
258 |
+
if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 6,6, 16,16,8>, 16,16,8, 1};
|
259 |
+
if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 4,4, 16,16,8>, 16,16,8, 1};
|
260 |
+
if (fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 2,2, 16,16,8>, 16,16,8, 1};
|
261 |
+
}
|
262 |
+
if (s != 1 && p.up.x == 2 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // contiguous
|
263 |
+
{
|
264 |
+
if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 24,1, 128,8,1>, 128,8,1, 1};
|
265 |
+
if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 20,1, 128,8,1>, 128,8,1, 1};
|
266 |
+
if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 16,1, 128,8,1>, 128,8,1, 1};
|
267 |
+
if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 12,1, 128,8,1>, 128,8,1, 1};
|
268 |
+
if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 8,1, 128,8,1>, 128,8,1, 1};
|
269 |
+
}
|
270 |
+
if (s == 1 && p.up.x == 2 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // channels_last
|
271 |
+
{
|
272 |
+
if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 24,1, 128,1,16>, 128,1,16, 1};
|
273 |
+
if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 20,1, 128,1,16>, 128,1,16, 1};
|
274 |
+
if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 16,1, 128,1,16>, 128,1,16, 1};
|
275 |
+
if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 12,1, 128,1,16>, 128,1,16, 1};
|
276 |
+
if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 8,1, 128,1,16>, 128,1,16, 1};
|
277 |
+
}
|
278 |
+
if (s != 1 && p.up.x == 1 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // contiguous
|
279 |
+
{
|
280 |
+
if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,24, 32,32,1>, 32,32,1, 1};
|
281 |
+
if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,20, 32,32,1>, 32,32,1, 1};
|
282 |
+
if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,16, 32,32,1>, 32,32,1, 1};
|
283 |
+
if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,12, 32,32,1>, 32,32,1, 1};
|
284 |
+
if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,8, 32,32,1>, 32,32,1, 1};
|
285 |
+
}
|
286 |
+
if (s == 1 && p.up.x == 1 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // channels_last
|
287 |
+
{
|
288 |
+
if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,24, 1,128,16>, 1,128,16, 1};
|
289 |
+
if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,20, 1,128,16>, 1,128,16, 1};
|
290 |
+
if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,16, 1,128,16>, 1,128,16, 1};
|
291 |
+
if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,12, 1,128,16>, 1,128,16, 1};
|
292 |
+
if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,8, 1,128,16>, 1,128,16, 1};
|
293 |
+
}
|
294 |
+
if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 2) // contiguous
|
295 |
+
{
|
296 |
+
if (fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 8,8, 32,8,1>, 32,8,1, 1};
|
297 |
+
if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 6,6, 32,8,1>, 32,8,1, 1};
|
298 |
+
if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 4,4, 32,8,1>, 32,8,1, 1};
|
299 |
+
if (fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 2,2, 32,8,1>, 32,8,1, 1};
|
300 |
+
}
|
301 |
+
if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 2) // channels_last
|
302 |
+
{
|
303 |
+
if (fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 8,8, 8,8,8>, 8,8,8, 1};
|
304 |
+
if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 6,6, 8,8,8>, 8,8,8, 1};
|
305 |
+
if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 4,4, 8,8,8>, 8,8,8, 1};
|
306 |
+
if (fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 2,2, 8,8,8>, 8,8,8, 1};
|
307 |
+
}
|
308 |
+
if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 1) // contiguous
|
309 |
+
{
|
310 |
+
if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 24,1, 64,8,1>, 64,8,1, 1};
|
311 |
+
if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 20,1, 64,8,1>, 64,8,1, 1};
|
312 |
+
if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 16,1, 64,8,1>, 64,8,1, 1};
|
313 |
+
if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 12,1, 64,8,1>, 64,8,1, 1};
|
314 |
+
if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 8,1, 64,8,1>, 64,8,1, 1};
|
315 |
+
}
|
316 |
+
if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 1) // channels_last
|
317 |
+
{
|
318 |
+
if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 24,1, 64,1,8>, 64,1,8, 1};
|
319 |
+
if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 20,1, 64,1,8>, 64,1,8, 1};
|
320 |
+
if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 16,1, 64,1,8>, 64,1,8, 1};
|
321 |
+
if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 12,1, 64,1,8>, 64,1,8, 1};
|
322 |
+
if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 8,1, 64,1,8>, 64,1,8, 1};
|
323 |
+
}
|
324 |
+
if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 2) // contiguous
|
325 |
+
{
|
326 |
+
if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,24, 32,16,1>, 32,16,1, 1};
|
327 |
+
if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,20, 32,16,1>, 32,16,1, 1};
|
328 |
+
if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,16, 32,16,1>, 32,16,1, 1};
|
329 |
+
if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,12, 32,16,1>, 32,16,1, 1};
|
330 |
+
if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,8, 32,16,1>, 32,16,1, 1};
|
331 |
+
}
|
332 |
+
if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 2) // channels_last
|
333 |
+
{
|
334 |
+
if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,24, 1,64,8>, 1,64,8, 1};
|
335 |
+
if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,20, 1,64,8>, 1,64,8, 1};
|
336 |
+
if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,16, 1,64,8>, 1,64,8, 1};
|
337 |
+
if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,12, 1,64,8>, 1,64,8, 1};
|
338 |
+
if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,8, 1,64,8>, 1,64,8, 1};
|
339 |
+
}
|
340 |
+
return spec;
|
341 |
+
}
|
342 |
+
|
343 |
+
//------------------------------------------------------------------------
|
344 |
+
// Template specializations.
|
345 |
+
|
346 |
+
template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<double> (const upfirdn2d_kernel_params& p);
|
347 |
+
template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<float> (const upfirdn2d_kernel_params& p);
|
348 |
+
template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<c10::Half>(const upfirdn2d_kernel_params& p);
|
349 |
+
|
350 |
+
//------------------------------------------------------------------------
|
torch_utils/ops/upfirdn2d.h
ADDED
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
//
|
3 |
+
// NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
+
// and proprietary rights in and to this software, related documentation
|
5 |
+
// and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
// distribution of this software and related documentation without an express
|
7 |
+
// license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
+
|
9 |
+
#include <cuda_runtime.h>
|
10 |
+
|
11 |
+
//------------------------------------------------------------------------
|
12 |
+
// CUDA kernel parameters.
|
13 |
+
|
14 |
+
struct upfirdn2d_kernel_params
|
15 |
+
{
|
16 |
+
const void* x;
|
17 |
+
const float* f;
|
18 |
+
void* y;
|
19 |
+
|
20 |
+
int2 up;
|
21 |
+
int2 down;
|
22 |
+
int2 pad0;
|
23 |
+
int flip;
|
24 |
+
float gain;
|
25 |
+
|
26 |
+
int4 inSize; // [width, height, channel, batch]
|
27 |
+
int4 inStride;
|
28 |
+
int2 filterSize; // [width, height]
|
29 |
+
int2 filterStride;
|
30 |
+
int4 outSize; // [width, height, channel, batch]
|
31 |
+
int4 outStride;
|
32 |
+
int sizeMinor;
|
33 |
+
int sizeMajor;
|
34 |
+
|
35 |
+
int loopMinor;
|
36 |
+
int loopMajor;
|
37 |
+
int loopX;
|
38 |
+
int launchMinor;
|
39 |
+
int launchMajor;
|
40 |
+
};
|
41 |
+
|
42 |
+
//------------------------------------------------------------------------
|
43 |
+
// CUDA kernel specialization.
|
44 |
+
|
45 |
+
struct upfirdn2d_kernel_spec
|
46 |
+
{
|
47 |
+
void* kernel;
|
48 |
+
int tileOutW;
|
49 |
+
int tileOutH;
|
50 |
+
int loopMinor;
|
51 |
+
int loopX;
|
52 |
+
};
|
53 |
+
|
54 |
+
//------------------------------------------------------------------------
|
55 |
+
// CUDA kernel selection.
|
56 |
+
|
57 |
+
template <class T> upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p);
|
58 |
+
|
59 |
+
//------------------------------------------------------------------------
|
torch_utils/ops/upfirdn2d.py
ADDED
@@ -0,0 +1,384 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
#
|
3 |
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
+
# and proprietary rights in and to this software, related documentation
|
5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
# distribution of this software and related documentation without an express
|
7 |
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
+
|
9 |
+
"""Custom PyTorch ops for efficient resampling of 2D images."""
|
10 |
+
|
11 |
+
import os
|
12 |
+
import warnings
|
13 |
+
import numpy as np
|
14 |
+
import torch
|
15 |
+
import traceback
|
16 |
+
|
17 |
+
from .. import custom_ops
|
18 |
+
from .. import misc
|
19 |
+
from . import conv2d_gradfix
|
20 |
+
|
21 |
+
#----------------------------------------------------------------------------
|
22 |
+
|
23 |
+
_inited = False
|
24 |
+
_plugin = None
|
25 |
+
|
26 |
+
def _init():
|
27 |
+
global _inited, _plugin
|
28 |
+
if not _inited:
|
29 |
+
sources = ['upfirdn2d.cpp', 'upfirdn2d.cu']
|
30 |
+
sources = [os.path.join(os.path.dirname(__file__), s) for s in sources]
|
31 |
+
try:
|
32 |
+
_plugin = custom_ops.get_plugin('upfirdn2d_plugin', sources=sources, extra_cuda_cflags=['--use_fast_math'])
|
33 |
+
except:
|
34 |
+
warnings.warn('Failed to build CUDA kernels for upfirdn2d. Falling back to slow reference implementation. Details:\n\n' + traceback.format_exc())
|
35 |
+
return _plugin is not None
|
36 |
+
|
37 |
+
def _parse_scaling(scaling):
|
38 |
+
if isinstance(scaling, int):
|
39 |
+
scaling = [scaling, scaling]
|
40 |
+
assert isinstance(scaling, (list, tuple))
|
41 |
+
assert all(isinstance(x, int) for x in scaling)
|
42 |
+
sx, sy = scaling
|
43 |
+
assert sx >= 1 and sy >= 1
|
44 |
+
return sx, sy
|
45 |
+
|
46 |
+
def _parse_padding(padding):
|
47 |
+
if isinstance(padding, int):
|
48 |
+
padding = [padding, padding]
|
49 |
+
assert isinstance(padding, (list, tuple))
|
50 |
+
assert all(isinstance(x, int) for x in padding)
|
51 |
+
if len(padding) == 2:
|
52 |
+
padx, pady = padding
|
53 |
+
padding = [padx, padx, pady, pady]
|
54 |
+
padx0, padx1, pady0, pady1 = padding
|
55 |
+
return padx0, padx1, pady0, pady1
|
56 |
+
|
57 |
+
def _get_filter_size(f):
|
58 |
+
if f is None:
|
59 |
+
return 1, 1
|
60 |
+
assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
|
61 |
+
fw = f.shape[-1]
|
62 |
+
fh = f.shape[0]
|
63 |
+
with misc.suppress_tracer_warnings():
|
64 |
+
fw = int(fw)
|
65 |
+
fh = int(fh)
|
66 |
+
misc.assert_shape(f, [fh, fw][:f.ndim])
|
67 |
+
assert fw >= 1 and fh >= 1
|
68 |
+
return fw, fh
|
69 |
+
|
70 |
+
#----------------------------------------------------------------------------
|
71 |
+
|
72 |
+
def setup_filter(f, device=torch.device('cpu'), normalize=True, flip_filter=False, gain=1, separable=None):
|
73 |
+
r"""Convenience function to setup 2D FIR filter for `upfirdn2d()`.
|
74 |
+
|
75 |
+
Args:
|
76 |
+
f: Torch tensor, numpy array, or python list of the shape
|
77 |
+
`[filter_height, filter_width]` (non-separable),
|
78 |
+
`[filter_taps]` (separable),
|
79 |
+
`[]` (impulse), or
|
80 |
+
`None` (identity).
|
81 |
+
device: Result device (default: cpu).
|
82 |
+
normalize: Normalize the filter so that it retains the magnitude
|
83 |
+
for constant input signal (DC)? (default: True).
|
84 |
+
flip_filter: Flip the filter? (default: False).
|
85 |
+
gain: Overall scaling factor for signal magnitude (default: 1).
|
86 |
+
separable: Return a separable filter? (default: select automatically).
|
87 |
+
|
88 |
+
Returns:
|
89 |
+
Float32 tensor of the shape
|
90 |
+
`[filter_height, filter_width]` (non-separable) or
|
91 |
+
`[filter_taps]` (separable).
|
92 |
+
"""
|
93 |
+
# Validate.
|
94 |
+
if f is None:
|
95 |
+
f = 1
|
96 |
+
f = torch.as_tensor(f, dtype=torch.float32)
|
97 |
+
assert f.ndim in [0, 1, 2]
|
98 |
+
assert f.numel() > 0
|
99 |
+
if f.ndim == 0:
|
100 |
+
f = f[np.newaxis]
|
101 |
+
|
102 |
+
# Separable?
|
103 |
+
if separable is None:
|
104 |
+
separable = (f.ndim == 1 and f.numel() >= 8)
|
105 |
+
if f.ndim == 1 and not separable:
|
106 |
+
f = f.ger(f)
|
107 |
+
assert f.ndim == (1 if separable else 2)
|
108 |
+
|
109 |
+
# Apply normalize, flip, gain, and device.
|
110 |
+
if normalize:
|
111 |
+
f /= f.sum()
|
112 |
+
if flip_filter:
|
113 |
+
f = f.flip(list(range(f.ndim)))
|
114 |
+
f = f * (gain ** (f.ndim / 2))
|
115 |
+
f = f.to(device=device)
|
116 |
+
return f
|
117 |
+
|
118 |
+
#----------------------------------------------------------------------------
|
119 |
+
|
120 |
+
def upfirdn2d(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1, impl='cuda'):
|
121 |
+
r"""Pad, upsample, filter, and downsample a batch of 2D images.
|
122 |
+
|
123 |
+
Performs the following sequence of operations for each channel:
|
124 |
+
|
125 |
+
1. Upsample the image by inserting N-1 zeros after each pixel (`up`).
|
126 |
+
|
127 |
+
2. Pad the image with the specified number of zeros on each side (`padding`).
|
128 |
+
Negative padding corresponds to cropping the image.
|
129 |
+
|
130 |
+
3. Convolve the image with the specified 2D FIR filter (`f`), shrinking it
|
131 |
+
so that the footprint of all output pixels lies within the input image.
|
132 |
+
|
133 |
+
4. Downsample the image by keeping every Nth pixel (`down`).
|
134 |
+
|
135 |
+
This sequence of operations bears close resemblance to scipy.signal.upfirdn().
|
136 |
+
The fused op is considerably more efficient than performing the same calculation
|
137 |
+
using standard PyTorch ops. It supports gradients of arbitrary order.
|
138 |
+
|
139 |
+
Args:
|
140 |
+
x: Float32/float64/float16 input tensor of the shape
|
141 |
+
`[batch_size, num_channels, in_height, in_width]`.
|
142 |
+
f: Float32 FIR filter of the shape
|
143 |
+
`[filter_height, filter_width]` (non-separable),
|
144 |
+
`[filter_taps]` (separable), or
|
145 |
+
`None` (identity).
|
146 |
+
up: Integer upsampling factor. Can be a single int or a list/tuple
|
147 |
+
`[x, y]` (default: 1).
|
148 |
+
down: Integer downsampling factor. Can be a single int or a list/tuple
|
149 |
+
`[x, y]` (default: 1).
|
150 |
+
padding: Padding with respect to the upsampled image. Can be a single number
|
151 |
+
or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
|
152 |
+
(default: 0).
|
153 |
+
flip_filter: False = convolution, True = correlation (default: False).
|
154 |
+
gain: Overall scaling factor for signal magnitude (default: 1).
|
155 |
+
impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
|
156 |
+
|
157 |
+
Returns:
|
158 |
+
Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
|
159 |
+
"""
|
160 |
+
assert isinstance(x, torch.Tensor)
|
161 |
+
assert impl in ['ref', 'cuda']
|
162 |
+
if impl == 'cuda' and x.device.type == 'cuda' and _init():
|
163 |
+
return _upfirdn2d_cuda(up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain).apply(x, f)
|
164 |
+
return _upfirdn2d_ref(x, f, up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain)
|
165 |
+
|
166 |
+
#----------------------------------------------------------------------------
|
167 |
+
|
168 |
+
@misc.profiled_function
|
169 |
+
def _upfirdn2d_ref(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1):
|
170 |
+
"""Slow reference implementation of `upfirdn2d()` using standard PyTorch ops.
|
171 |
+
"""
|
172 |
+
# Validate arguments.
|
173 |
+
assert isinstance(x, torch.Tensor) and x.ndim == 4
|
174 |
+
if f is None:
|
175 |
+
f = torch.ones([1, 1], dtype=torch.float32, device=x.device)
|
176 |
+
assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
|
177 |
+
assert f.dtype == torch.float32 and not f.requires_grad
|
178 |
+
batch_size, num_channels, in_height, in_width = x.shape
|
179 |
+
upx, upy = _parse_scaling(up)
|
180 |
+
downx, downy = _parse_scaling(down)
|
181 |
+
padx0, padx1, pady0, pady1 = _parse_padding(padding)
|
182 |
+
|
183 |
+
# Upsample by inserting zeros.
|
184 |
+
x = x.reshape([batch_size, num_channels, in_height, 1, in_width, 1])
|
185 |
+
x = torch.nn.functional.pad(x, [0, upx - 1, 0, 0, 0, upy - 1])
|
186 |
+
x = x.reshape([batch_size, num_channels, in_height * upy, in_width * upx])
|
187 |
+
|
188 |
+
# Pad or crop.
|
189 |
+
x = torch.nn.functional.pad(x, [max(padx0, 0), max(padx1, 0), max(pady0, 0), max(pady1, 0)])
|
190 |
+
x = x[:, :, max(-pady0, 0) : x.shape[2] - max(-pady1, 0), max(-padx0, 0) : x.shape[3] - max(-padx1, 0)]
|
191 |
+
|
192 |
+
# Setup filter.
|
193 |
+
f = f * (gain ** (f.ndim / 2))
|
194 |
+
f = f.to(x.dtype)
|
195 |
+
if not flip_filter:
|
196 |
+
f = f.flip(list(range(f.ndim)))
|
197 |
+
|
198 |
+
# Convolve with the filter.
|
199 |
+
f = f[np.newaxis, np.newaxis].repeat([num_channels, 1] + [1] * f.ndim)
|
200 |
+
if f.ndim == 4:
|
201 |
+
x = conv2d_gradfix.conv2d(input=x, weight=f, groups=num_channels)
|
202 |
+
else:
|
203 |
+
x = conv2d_gradfix.conv2d(input=x, weight=f.unsqueeze(2), groups=num_channels)
|
204 |
+
x = conv2d_gradfix.conv2d(input=x, weight=f.unsqueeze(3), groups=num_channels)
|
205 |
+
|
206 |
+
# Downsample by throwing away pixels.
|
207 |
+
x = x[:, :, ::downy, ::downx]
|
208 |
+
return x
|
209 |
+
|
210 |
+
#----------------------------------------------------------------------------
|
211 |
+
|
212 |
+
_upfirdn2d_cuda_cache = dict()
|
213 |
+
|
214 |
+
def _upfirdn2d_cuda(up=1, down=1, padding=0, flip_filter=False, gain=1):
|
215 |
+
"""Fast CUDA implementation of `upfirdn2d()` using custom ops.
|
216 |
+
"""
|
217 |
+
# Parse arguments.
|
218 |
+
upx, upy = _parse_scaling(up)
|
219 |
+
downx, downy = _parse_scaling(down)
|
220 |
+
padx0, padx1, pady0, pady1 = _parse_padding(padding)
|
221 |
+
|
222 |
+
# Lookup from cache.
|
223 |
+
key = (upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, gain)
|
224 |
+
if key in _upfirdn2d_cuda_cache:
|
225 |
+
return _upfirdn2d_cuda_cache[key]
|
226 |
+
|
227 |
+
# Forward op.
|
228 |
+
class Upfirdn2dCuda(torch.autograd.Function):
|
229 |
+
@staticmethod
|
230 |
+
def forward(ctx, x, f): # pylint: disable=arguments-differ
|
231 |
+
assert isinstance(x, torch.Tensor) and x.ndim == 4
|
232 |
+
if f is None:
|
233 |
+
f = torch.ones([1, 1], dtype=torch.float32, device=x.device)
|
234 |
+
assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
|
235 |
+
y = x
|
236 |
+
if f.ndim == 2:
|
237 |
+
y = _plugin.upfirdn2d(y, f, upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, gain)
|
238 |
+
else:
|
239 |
+
y = _plugin.upfirdn2d(y, f.unsqueeze(0), upx, 1, downx, 1, padx0, padx1, 0, 0, flip_filter, np.sqrt(gain))
|
240 |
+
y = _plugin.upfirdn2d(y, f.unsqueeze(1), 1, upy, 1, downy, 0, 0, pady0, pady1, flip_filter, np.sqrt(gain))
|
241 |
+
ctx.save_for_backward(f)
|
242 |
+
ctx.x_shape = x.shape
|
243 |
+
return y
|
244 |
+
|
245 |
+
@staticmethod
|
246 |
+
def backward(ctx, dy): # pylint: disable=arguments-differ
|
247 |
+
f, = ctx.saved_tensors
|
248 |
+
_, _, ih, iw = ctx.x_shape
|
249 |
+
_, _, oh, ow = dy.shape
|
250 |
+
fw, fh = _get_filter_size(f)
|
251 |
+
p = [
|
252 |
+
fw - padx0 - 1,
|
253 |
+
iw * upx - ow * downx + padx0 - upx + 1,
|
254 |
+
fh - pady0 - 1,
|
255 |
+
ih * upy - oh * downy + pady0 - upy + 1,
|
256 |
+
]
|
257 |
+
dx = None
|
258 |
+
df = None
|
259 |
+
|
260 |
+
if ctx.needs_input_grad[0]:
|
261 |
+
dx = _upfirdn2d_cuda(up=down, down=up, padding=p, flip_filter=(not flip_filter), gain=gain).apply(dy, f)
|
262 |
+
|
263 |
+
assert not ctx.needs_input_grad[1]
|
264 |
+
return dx, df
|
265 |
+
|
266 |
+
# Add to cache.
|
267 |
+
_upfirdn2d_cuda_cache[key] = Upfirdn2dCuda
|
268 |
+
return Upfirdn2dCuda
|
269 |
+
|
270 |
+
#----------------------------------------------------------------------------
|
271 |
+
|
272 |
+
def filter2d(x, f, padding=0, flip_filter=False, gain=1, impl='cuda'):
|
273 |
+
r"""Filter a batch of 2D images using the given 2D FIR filter.
|
274 |
+
|
275 |
+
By default, the result is padded so that its shape matches the input.
|
276 |
+
User-specified padding is applied on top of that, with negative values
|
277 |
+
indicating cropping. Pixels outside the image are assumed to be zero.
|
278 |
+
|
279 |
+
Args:
|
280 |
+
x: Float32/float64/float16 input tensor of the shape
|
281 |
+
`[batch_size, num_channels, in_height, in_width]`.
|
282 |
+
f: Float32 FIR filter of the shape
|
283 |
+
`[filter_height, filter_width]` (non-separable),
|
284 |
+
`[filter_taps]` (separable), or
|
285 |
+
`None` (identity).
|
286 |
+
padding: Padding with respect to the output. Can be a single number or a
|
287 |
+
list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
|
288 |
+
(default: 0).
|
289 |
+
flip_filter: False = convolution, True = correlation (default: False).
|
290 |
+
gain: Overall scaling factor for signal magnitude (default: 1).
|
291 |
+
impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
|
292 |
+
|
293 |
+
Returns:
|
294 |
+
Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
|
295 |
+
"""
|
296 |
+
padx0, padx1, pady0, pady1 = _parse_padding(padding)
|
297 |
+
fw, fh = _get_filter_size(f)
|
298 |
+
p = [
|
299 |
+
padx0 + fw // 2,
|
300 |
+
padx1 + (fw - 1) // 2,
|
301 |
+
pady0 + fh // 2,
|
302 |
+
pady1 + (fh - 1) // 2,
|
303 |
+
]
|
304 |
+
return upfirdn2d(x, f, padding=p, flip_filter=flip_filter, gain=gain, impl=impl)
|
305 |
+
|
306 |
+
#----------------------------------------------------------------------------
|
307 |
+
|
308 |
+
def upsample2d(x, f, up=2, padding=0, flip_filter=False, gain=1, impl='cuda'):
|
309 |
+
r"""Upsample a batch of 2D images using the given 2D FIR filter.
|
310 |
+
|
311 |
+
By default, the result is padded so that its shape is a multiple of the input.
|
312 |
+
User-specified padding is applied on top of that, with negative values
|
313 |
+
indicating cropping. Pixels outside the image are assumed to be zero.
|
314 |
+
|
315 |
+
Args:
|
316 |
+
x: Float32/float64/float16 input tensor of the shape
|
317 |
+
`[batch_size, num_channels, in_height, in_width]`.
|
318 |
+
f: Float32 FIR filter of the shape
|
319 |
+
`[filter_height, filter_width]` (non-separable),
|
320 |
+
`[filter_taps]` (separable), or
|
321 |
+
`None` (identity).
|
322 |
+
up: Integer upsampling factor. Can be a single int or a list/tuple
|
323 |
+
`[x, y]` (default: 1).
|
324 |
+
padding: Padding with respect to the output. Can be a single number or a
|
325 |
+
list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
|
326 |
+
(default: 0).
|
327 |
+
flip_filter: False = convolution, True = correlation (default: False).
|
328 |
+
gain: Overall scaling factor for signal magnitude (default: 1).
|
329 |
+
impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
|
330 |
+
|
331 |
+
Returns:
|
332 |
+
Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
|
333 |
+
"""
|
334 |
+
upx, upy = _parse_scaling(up)
|
335 |
+
padx0, padx1, pady0, pady1 = _parse_padding(padding)
|
336 |
+
fw, fh = _get_filter_size(f)
|
337 |
+
p = [
|
338 |
+
padx0 + (fw + upx - 1) // 2,
|
339 |
+
padx1 + (fw - upx) // 2,
|
340 |
+
pady0 + (fh + upy - 1) // 2,
|
341 |
+
pady1 + (fh - upy) // 2,
|
342 |
+
]
|
343 |
+
return upfirdn2d(x, f, up=up, padding=p, flip_filter=flip_filter, gain=gain*upx*upy, impl=impl)
|
344 |
+
|
345 |
+
#----------------------------------------------------------------------------
|
346 |
+
|
347 |
+
def downsample2d(x, f, down=2, padding=0, flip_filter=False, gain=1, impl='cuda'):
|
348 |
+
r"""Downsample a batch of 2D images using the given 2D FIR filter.
|
349 |
+
|
350 |
+
By default, the result is padded so that its shape is a fraction of the input.
|
351 |
+
User-specified padding is applied on top of that, with negative values
|
352 |
+
indicating cropping. Pixels outside the image are assumed to be zero.
|
353 |
+
|
354 |
+
Args:
|
355 |
+
x: Float32/float64/float16 input tensor of the shape
|
356 |
+
`[batch_size, num_channels, in_height, in_width]`.
|
357 |
+
f: Float32 FIR filter of the shape
|
358 |
+
`[filter_height, filter_width]` (non-separable),
|
359 |
+
`[filter_taps]` (separable), or
|
360 |
+
`None` (identity).
|
361 |
+
down: Integer downsampling factor. Can be a single int or a list/tuple
|
362 |
+
`[x, y]` (default: 1).
|
363 |
+
padding: Padding with respect to the input. Can be a single number or a
|
364 |
+
list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
|
365 |
+
(default: 0).
|
366 |
+
flip_filter: False = convolution, True = correlation (default: False).
|
367 |
+
gain: Overall scaling factor for signal magnitude (default: 1).
|
368 |
+
impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
|
369 |
+
|
370 |
+
Returns:
|
371 |
+
Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
|
372 |
+
"""
|
373 |
+
downx, downy = _parse_scaling(down)
|
374 |
+
padx0, padx1, pady0, pady1 = _parse_padding(padding)
|
375 |
+
fw, fh = _get_filter_size(f)
|
376 |
+
p = [
|
377 |
+
padx0 + (fw - downx + 1) // 2,
|
378 |
+
padx1 + (fw - downx) // 2,
|
379 |
+
pady0 + (fh - downy + 1) // 2,
|
380 |
+
pady1 + (fh - downy) // 2,
|
381 |
+
]
|
382 |
+
return upfirdn2d(x, f, down=down, padding=p, flip_filter=flip_filter, gain=gain, impl=impl)
|
383 |
+
|
384 |
+
#----------------------------------------------------------------------------
|
torch_utils/persistence.py
ADDED
@@ -0,0 +1,251 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
#
|
3 |
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
+
# and proprietary rights in and to this software, related documentation
|
5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
# distribution of this software and related documentation without an express
|
7 |
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
+
|
9 |
+
"""Facilities for pickling Python code alongside other data.
|
10 |
+
|
11 |
+
The pickled code is automatically imported into a separate Python module
|
12 |
+
during unpickling. This way, any previously exported pickles will remain
|
13 |
+
usable even if the original code is no longer available, or if the current
|
14 |
+
version of the code is not consistent with what was originally pickled."""
|
15 |
+
|
16 |
+
import sys
|
17 |
+
import pickle
|
18 |
+
import io
|
19 |
+
import inspect
|
20 |
+
import copy
|
21 |
+
import uuid
|
22 |
+
import types
|
23 |
+
import dnnlib
|
24 |
+
|
25 |
+
#----------------------------------------------------------------------------
|
26 |
+
|
27 |
+
_version = 6 # internal version number
|
28 |
+
_decorators = set() # {decorator_class, ...}
|
29 |
+
_import_hooks = [] # [hook_function, ...]
|
30 |
+
_module_to_src_dict = dict() # {module: src, ...}
|
31 |
+
_src_to_module_dict = dict() # {src: module, ...}
|
32 |
+
|
33 |
+
#----------------------------------------------------------------------------
|
34 |
+
|
35 |
+
def persistent_class(orig_class):
|
36 |
+
r"""Class decorator that extends a given class to save its source code
|
37 |
+
when pickled.
|
38 |
+
|
39 |
+
Example:
|
40 |
+
|
41 |
+
from torch_utils import persistence
|
42 |
+
|
43 |
+
@persistence.persistent_class
|
44 |
+
class MyNetwork(torch.nn.Module):
|
45 |
+
def __init__(self, num_inputs, num_outputs):
|
46 |
+
super().__init__()
|
47 |
+
self.fc = MyLayer(num_inputs, num_outputs)
|
48 |
+
...
|
49 |
+
|
50 |
+
@persistence.persistent_class
|
51 |
+
class MyLayer(torch.nn.Module):
|
52 |
+
...
|
53 |
+
|
54 |
+
When pickled, any instance of `MyNetwork` and `MyLayer` will save its
|
55 |
+
source code alongside other internal state (e.g., parameters, buffers,
|
56 |
+
and submodules). This way, any previously exported pickle will remain
|
57 |
+
usable even if the class definitions have been modified or are no
|
58 |
+
longer available.
|
59 |
+
|
60 |
+
The decorator saves the source code of the entire Python module
|
61 |
+
containing the decorated class. It does *not* save the source code of
|
62 |
+
any imported modules. Thus, the imported modules must be available
|
63 |
+
during unpickling, also including `torch_utils.persistence` itself.
|
64 |
+
|
65 |
+
It is ok to call functions defined in the same module from the
|
66 |
+
decorated class. However, if the decorated class depends on other
|
67 |
+
classes defined in the same module, they must be decorated as well.
|
68 |
+
This is illustrated in the above example in the case of `MyLayer`.
|
69 |
+
|
70 |
+
It is also possible to employ the decorator just-in-time before
|
71 |
+
calling the constructor. For example:
|
72 |
+
|
73 |
+
cls = MyLayer
|
74 |
+
if want_to_make_it_persistent:
|
75 |
+
cls = persistence.persistent_class(cls)
|
76 |
+
layer = cls(num_inputs, num_outputs)
|
77 |
+
|
78 |
+
As an additional feature, the decorator also keeps track of the
|
79 |
+
arguments that were used to construct each instance of the decorated
|
80 |
+
class. The arguments can be queried via `obj.init_args` and
|
81 |
+
`obj.init_kwargs`, and they are automatically pickled alongside other
|
82 |
+
object state. A typical use case is to first unpickle a previous
|
83 |
+
instance of a persistent class, and then upgrade it to use the latest
|
84 |
+
version of the source code:
|
85 |
+
|
86 |
+
with open('old_pickle.pkl', 'rb') as f:
|
87 |
+
old_net = pickle.load(f)
|
88 |
+
new_net = MyNetwork(*old_obj.init_args, **old_obj.init_kwargs)
|
89 |
+
misc.copy_params_and_buffers(old_net, new_net, require_all=True)
|
90 |
+
"""
|
91 |
+
assert isinstance(orig_class, type)
|
92 |
+
if is_persistent(orig_class):
|
93 |
+
return orig_class
|
94 |
+
|
95 |
+
assert orig_class.__module__ in sys.modules
|
96 |
+
orig_module = sys.modules[orig_class.__module__]
|
97 |
+
orig_module_src = _module_to_src(orig_module)
|
98 |
+
|
99 |
+
class Decorator(orig_class):
|
100 |
+
_orig_module_src = orig_module_src
|
101 |
+
_orig_class_name = orig_class.__name__
|
102 |
+
|
103 |
+
def __init__(self, *args, **kwargs):
|
104 |
+
super().__init__(*args, **kwargs)
|
105 |
+
self._init_args = copy.deepcopy(args)
|
106 |
+
self._init_kwargs = copy.deepcopy(kwargs)
|
107 |
+
assert orig_class.__name__ in orig_module.__dict__
|
108 |
+
_check_pickleable(self.__reduce__())
|
109 |
+
|
110 |
+
@property
|
111 |
+
def init_args(self):
|
112 |
+
return copy.deepcopy(self._init_args)
|
113 |
+
|
114 |
+
@property
|
115 |
+
def init_kwargs(self):
|
116 |
+
return dnnlib.EasyDict(copy.deepcopy(self._init_kwargs))
|
117 |
+
|
118 |
+
def __reduce__(self):
|
119 |
+
fields = list(super().__reduce__())
|
120 |
+
fields += [None] * max(3 - len(fields), 0)
|
121 |
+
if fields[0] is not _reconstruct_persistent_obj:
|
122 |
+
meta = dict(type='class', version=_version, module_src=self._orig_module_src, class_name=self._orig_class_name, state=fields[2])
|
123 |
+
fields[0] = _reconstruct_persistent_obj # reconstruct func
|
124 |
+
fields[1] = (meta,) # reconstruct args
|
125 |
+
fields[2] = None # state dict
|
126 |
+
return tuple(fields)
|
127 |
+
|
128 |
+
Decorator.__name__ = orig_class.__name__
|
129 |
+
_decorators.add(Decorator)
|
130 |
+
return Decorator
|
131 |
+
|
132 |
+
#----------------------------------------------------------------------------
|
133 |
+
|
134 |
+
def is_persistent(obj):
|
135 |
+
r"""Test whether the given object or class is persistent, i.e.,
|
136 |
+
whether it will save its source code when pickled.
|
137 |
+
"""
|
138 |
+
try:
|
139 |
+
if obj in _decorators:
|
140 |
+
return True
|
141 |
+
except TypeError:
|
142 |
+
pass
|
143 |
+
return type(obj) in _decorators # pylint: disable=unidiomatic-typecheck
|
144 |
+
|
145 |
+
#----------------------------------------------------------------------------
|
146 |
+
|
147 |
+
def import_hook(hook):
|
148 |
+
r"""Register an import hook that is called whenever a persistent object
|
149 |
+
is being unpickled. A typical use case is to patch the pickled source
|
150 |
+
code to avoid errors and inconsistencies when the API of some imported
|
151 |
+
module has changed.
|
152 |
+
|
153 |
+
The hook should have the following signature:
|
154 |
+
|
155 |
+
hook(meta) -> modified meta
|
156 |
+
|
157 |
+
`meta` is an instance of `dnnlib.EasyDict` with the following fields:
|
158 |
+
|
159 |
+
type: Type of the persistent object, e.g. `'class'`.
|
160 |
+
version: Internal version number of `torch_utils.persistence`.
|
161 |
+
module_src Original source code of the Python module.
|
162 |
+
class_name: Class name in the original Python module.
|
163 |
+
state: Internal state of the object.
|
164 |
+
|
165 |
+
Example:
|
166 |
+
|
167 |
+
@persistence.import_hook
|
168 |
+
def wreck_my_network(meta):
|
169 |
+
if meta.class_name == 'MyNetwork':
|
170 |
+
print('MyNetwork is being imported. I will wreck it!')
|
171 |
+
meta.module_src = meta.module_src.replace("True", "False")
|
172 |
+
return meta
|
173 |
+
"""
|
174 |
+
assert callable(hook)
|
175 |
+
_import_hooks.append(hook)
|
176 |
+
|
177 |
+
#----------------------------------------------------------------------------
|
178 |
+
|
179 |
+
def _reconstruct_persistent_obj(meta):
|
180 |
+
r"""Hook that is called internally by the `pickle` module to unpickle
|
181 |
+
a persistent object.
|
182 |
+
"""
|
183 |
+
meta = dnnlib.EasyDict(meta)
|
184 |
+
meta.state = dnnlib.EasyDict(meta.state)
|
185 |
+
for hook in _import_hooks:
|
186 |
+
meta = hook(meta)
|
187 |
+
assert meta is not None
|
188 |
+
|
189 |
+
assert meta.version == _version
|
190 |
+
module = _src_to_module(meta.module_src)
|
191 |
+
|
192 |
+
assert meta.type == 'class'
|
193 |
+
orig_class = module.__dict__[meta.class_name]
|
194 |
+
decorator_class = persistent_class(orig_class)
|
195 |
+
obj = decorator_class.__new__(decorator_class)
|
196 |
+
|
197 |
+
setstate = getattr(obj, '__setstate__', None)
|
198 |
+
if callable(setstate):
|
199 |
+
setstate(meta.state) # pylint: disable=not-callable
|
200 |
+
else:
|
201 |
+
obj.__dict__.update(meta.state)
|
202 |
+
return obj
|
203 |
+
|
204 |
+
#----------------------------------------------------------------------------
|
205 |
+
|
206 |
+
def _module_to_src(module):
|
207 |
+
r"""Query the source code of a given Python module.
|
208 |
+
"""
|
209 |
+
src = _module_to_src_dict.get(module, None)
|
210 |
+
if src is None:
|
211 |
+
src = inspect.getsource(module)
|
212 |
+
_module_to_src_dict[module] = src
|
213 |
+
_src_to_module_dict[src] = module
|
214 |
+
return src
|
215 |
+
|
216 |
+
def _src_to_module(src):
|
217 |
+
r"""Get or create a Python module for the given source code.
|
218 |
+
"""
|
219 |
+
module = _src_to_module_dict.get(src, None)
|
220 |
+
if module is None:
|
221 |
+
module_name = "_imported_module_" + uuid.uuid4().hex
|
222 |
+
module = types.ModuleType(module_name)
|
223 |
+
sys.modules[module_name] = module
|
224 |
+
_module_to_src_dict[module] = src
|
225 |
+
_src_to_module_dict[src] = module
|
226 |
+
exec(src, module.__dict__) # pylint: disable=exec-used
|
227 |
+
return module
|
228 |
+
|
229 |
+
#----------------------------------------------------------------------------
|
230 |
+
|
231 |
+
def _check_pickleable(obj):
|
232 |
+
r"""Check that the given object is pickleable, raising an exception if
|
233 |
+
it is not. This function is expected to be considerably more efficient
|
234 |
+
than actually pickling the object.
|
235 |
+
"""
|
236 |
+
def recurse(obj):
|
237 |
+
if isinstance(obj, (list, tuple, set)):
|
238 |
+
return [recurse(x) for x in obj]
|
239 |
+
if isinstance(obj, dict):
|
240 |
+
return [[recurse(x), recurse(y)] for x, y in obj.items()]
|
241 |
+
if isinstance(obj, (str, int, float, bool, bytes, bytearray)):
|
242 |
+
return None # Python primitive types are pickleable.
|
243 |
+
if f'{type(obj).__module__}.{type(obj).__name__}' in ['numpy.ndarray', 'torch.Tensor']:
|
244 |
+
return None # NumPy arrays and PyTorch tensors are pickleable.
|
245 |
+
if is_persistent(obj):
|
246 |
+
return None # Persistent objects are pickleable, by virtue of the constructor check.
|
247 |
+
return obj
|
248 |
+
with io.BytesIO() as f:
|
249 |
+
pickle.dump(recurse(obj), f)
|
250 |
+
|
251 |
+
#----------------------------------------------------------------------------
|
torch_utils/training_stats.py
ADDED
@@ -0,0 +1,268 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
#
|
3 |
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
+
# and proprietary rights in and to this software, related documentation
|
5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
# distribution of this software and related documentation without an express
|
7 |
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
+
|
9 |
+
"""Facilities for reporting and collecting training statistics across
|
10 |
+
multiple processes and devices. The interface is designed to minimize
|
11 |
+
synchronization overhead as well as the amount of boilerplate in user
|
12 |
+
code."""
|
13 |
+
|
14 |
+
import re
|
15 |
+
import numpy as np
|
16 |
+
import torch
|
17 |
+
import dnnlib
|
18 |
+
|
19 |
+
from . import misc
|
20 |
+
|
21 |
+
#----------------------------------------------------------------------------
|
22 |
+
|
23 |
+
_num_moments = 3 # [num_scalars, sum_of_scalars, sum_of_squares]
|
24 |
+
_reduce_dtype = torch.float32 # Data type to use for initial per-tensor reduction.
|
25 |
+
_counter_dtype = torch.float64 # Data type to use for the internal counters.
|
26 |
+
_rank = 0 # Rank of the current process.
|
27 |
+
_sync_device = None # Device to use for multiprocess communication. None = single-process.
|
28 |
+
_sync_called = False # Has _sync() been called yet?
|
29 |
+
_counters = dict() # Running counters on each device, updated by report(): name => device => torch.Tensor
|
30 |
+
_cumulative = dict() # Cumulative counters on the CPU, updated by _sync(): name => torch.Tensor
|
31 |
+
|
32 |
+
#----------------------------------------------------------------------------
|
33 |
+
|
34 |
+
def init_multiprocessing(rank, sync_device):
|
35 |
+
r"""Initializes `torch_utils.training_stats` for collecting statistics
|
36 |
+
across multiple processes.
|
37 |
+
|
38 |
+
This function must be called after
|
39 |
+
`torch.distributed.init_process_group()` and before `Collector.update()`.
|
40 |
+
The call is not necessary if multi-process collection is not needed.
|
41 |
+
|
42 |
+
Args:
|
43 |
+
rank: Rank of the current process.
|
44 |
+
sync_device: PyTorch device to use for inter-process
|
45 |
+
communication, or None to disable multi-process
|
46 |
+
collection. Typically `torch.device('cuda', rank)`.
|
47 |
+
"""
|
48 |
+
global _rank, _sync_device
|
49 |
+
assert not _sync_called
|
50 |
+
_rank = rank
|
51 |
+
_sync_device = sync_device
|
52 |
+
|
53 |
+
#----------------------------------------------------------------------------
|
54 |
+
|
55 |
+
@misc.profiled_function
|
56 |
+
def report(name, value):
|
57 |
+
r"""Broadcasts the given set of scalars to all interested instances of
|
58 |
+
`Collector`, across device and process boundaries.
|
59 |
+
|
60 |
+
This function is expected to be extremely cheap and can be safely
|
61 |
+
called from anywhere in the training loop, loss function, or inside a
|
62 |
+
`torch.nn.Module`.
|
63 |
+
|
64 |
+
Warning: The current implementation expects the set of unique names to
|
65 |
+
be consistent across processes. Please make sure that `report()` is
|
66 |
+
called at least once for each unique name by each process, and in the
|
67 |
+
same order. If a given process has no scalars to broadcast, it can do
|
68 |
+
`report(name, [])` (empty list).
|
69 |
+
|
70 |
+
Args:
|
71 |
+
name: Arbitrary string specifying the name of the statistic.
|
72 |
+
Averages are accumulated separately for each unique name.
|
73 |
+
value: Arbitrary set of scalars. Can be a list, tuple,
|
74 |
+
NumPy array, PyTorch tensor, or Python scalar.
|
75 |
+
|
76 |
+
Returns:
|
77 |
+
The same `value` that was passed in.
|
78 |
+
"""
|
79 |
+
if name not in _counters:
|
80 |
+
_counters[name] = dict()
|
81 |
+
|
82 |
+
elems = torch.as_tensor(value)
|
83 |
+
if elems.numel() == 0:
|
84 |
+
return value
|
85 |
+
|
86 |
+
elems = elems.detach().flatten().to(_reduce_dtype)
|
87 |
+
moments = torch.stack([
|
88 |
+
torch.ones_like(elems).sum(),
|
89 |
+
elems.sum(),
|
90 |
+
elems.square().sum(),
|
91 |
+
])
|
92 |
+
assert moments.ndim == 1 and moments.shape[0] == _num_moments
|
93 |
+
moments = moments.to(_counter_dtype)
|
94 |
+
|
95 |
+
device = moments.device
|
96 |
+
if device not in _counters[name]:
|
97 |
+
_counters[name][device] = torch.zeros_like(moments)
|
98 |
+
_counters[name][device].add_(moments)
|
99 |
+
return value
|
100 |
+
|
101 |
+
#----------------------------------------------------------------------------
|
102 |
+
|
103 |
+
def report0(name, value):
|
104 |
+
r"""Broadcasts the given set of scalars by the first process (`rank = 0`),
|
105 |
+
but ignores any scalars provided by the other processes.
|
106 |
+
See `report()` for further details.
|
107 |
+
"""
|
108 |
+
report(name, value if _rank == 0 else [])
|
109 |
+
return value
|
110 |
+
|
111 |
+
#----------------------------------------------------------------------------
|
112 |
+
|
113 |
+
class Collector:
|
114 |
+
r"""Collects the scalars broadcasted by `report()` and `report0()` and
|
115 |
+
computes their long-term averages (mean and standard deviation) over
|
116 |
+
user-defined periods of time.
|
117 |
+
|
118 |
+
The averages are first collected into internal counters that are not
|
119 |
+
directly visible to the user. They are then copied to the user-visible
|
120 |
+
state as a result of calling `update()` and can then be queried using
|
121 |
+
`mean()`, `std()`, `as_dict()`, etc. Calling `update()` also resets the
|
122 |
+
internal counters for the next round, so that the user-visible state
|
123 |
+
effectively reflects averages collected between the last two calls to
|
124 |
+
`update()`.
|
125 |
+
|
126 |
+
Args:
|
127 |
+
regex: Regular expression defining which statistics to
|
128 |
+
collect. The default is to collect everything.
|
129 |
+
keep_previous: Whether to retain the previous averages if no
|
130 |
+
scalars were collected on a given round
|
131 |
+
(default: True).
|
132 |
+
"""
|
133 |
+
def __init__(self, regex='.*', keep_previous=True):
|
134 |
+
self._regex = re.compile(regex)
|
135 |
+
self._keep_previous = keep_previous
|
136 |
+
self._cumulative = dict()
|
137 |
+
self._moments = dict()
|
138 |
+
self.update()
|
139 |
+
self._moments.clear()
|
140 |
+
|
141 |
+
def names(self):
|
142 |
+
r"""Returns the names of all statistics broadcasted so far that
|
143 |
+
match the regular expression specified at construction time.
|
144 |
+
"""
|
145 |
+
return [name for name in _counters if self._regex.fullmatch(name)]
|
146 |
+
|
147 |
+
def update(self):
|
148 |
+
r"""Copies current values of the internal counters to the
|
149 |
+
user-visible state and resets them for the next round.
|
150 |
+
|
151 |
+
If `keep_previous=True` was specified at construction time, the
|
152 |
+
operation is skipped for statistics that have received no scalars
|
153 |
+
since the last update, retaining their previous averages.
|
154 |
+
|
155 |
+
This method performs a number of GPU-to-CPU transfers and one
|
156 |
+
`torch.distributed.all_reduce()`. It is intended to be called
|
157 |
+
periodically in the main training loop, typically once every
|
158 |
+
N training steps.
|
159 |
+
"""
|
160 |
+
if not self._keep_previous:
|
161 |
+
self._moments.clear()
|
162 |
+
for name, cumulative in _sync(self.names()):
|
163 |
+
if name not in self._cumulative:
|
164 |
+
self._cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
|
165 |
+
delta = cumulative - self._cumulative[name]
|
166 |
+
self._cumulative[name].copy_(cumulative)
|
167 |
+
if float(delta[0]) != 0:
|
168 |
+
self._moments[name] = delta
|
169 |
+
|
170 |
+
def _get_delta(self, name):
|
171 |
+
r"""Returns the raw moments that were accumulated for the given
|
172 |
+
statistic between the last two calls to `update()`, or zero if
|
173 |
+
no scalars were collected.
|
174 |
+
"""
|
175 |
+
assert self._regex.fullmatch(name)
|
176 |
+
if name not in self._moments:
|
177 |
+
self._moments[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
|
178 |
+
return self._moments[name]
|
179 |
+
|
180 |
+
def num(self, name):
|
181 |
+
r"""Returns the number of scalars that were accumulated for the given
|
182 |
+
statistic between the last two calls to `update()`, or zero if
|
183 |
+
no scalars were collected.
|
184 |
+
"""
|
185 |
+
delta = self._get_delta(name)
|
186 |
+
return int(delta[0])
|
187 |
+
|
188 |
+
def mean(self, name):
|
189 |
+
r"""Returns the mean of the scalars that were accumulated for the
|
190 |
+
given statistic between the last two calls to `update()`, or NaN if
|
191 |
+
no scalars were collected.
|
192 |
+
"""
|
193 |
+
delta = self._get_delta(name)
|
194 |
+
if int(delta[0]) == 0:
|
195 |
+
return float('nan')
|
196 |
+
return float(delta[1] / delta[0])
|
197 |
+
|
198 |
+
def std(self, name):
|
199 |
+
r"""Returns the standard deviation of the scalars that were
|
200 |
+
accumulated for the given statistic between the last two calls to
|
201 |
+
`update()`, or NaN if no scalars were collected.
|
202 |
+
"""
|
203 |
+
delta = self._get_delta(name)
|
204 |
+
if int(delta[0]) == 0 or not np.isfinite(float(delta[1])):
|
205 |
+
return float('nan')
|
206 |
+
if int(delta[0]) == 1:
|
207 |
+
return float(0)
|
208 |
+
mean = float(delta[1] / delta[0])
|
209 |
+
raw_var = float(delta[2] / delta[0])
|
210 |
+
return np.sqrt(max(raw_var - np.square(mean), 0))
|
211 |
+
|
212 |
+
def as_dict(self):
|
213 |
+
r"""Returns the averages accumulated between the last two calls to
|
214 |
+
`update()` as an `dnnlib.EasyDict`. The contents are as follows:
|
215 |
+
|
216 |
+
dnnlib.EasyDict(
|
217 |
+
NAME = dnnlib.EasyDict(num=FLOAT, mean=FLOAT, std=FLOAT),
|
218 |
+
...
|
219 |
+
)
|
220 |
+
"""
|
221 |
+
stats = dnnlib.EasyDict()
|
222 |
+
for name in self.names():
|
223 |
+
stats[name] = dnnlib.EasyDict(num=self.num(name), mean=self.mean(name), std=self.std(name))
|
224 |
+
return stats
|
225 |
+
|
226 |
+
def __getitem__(self, name):
|
227 |
+
r"""Convenience getter.
|
228 |
+
`collector[name]` is a synonym for `collector.mean(name)`.
|
229 |
+
"""
|
230 |
+
return self.mean(name)
|
231 |
+
|
232 |
+
#----------------------------------------------------------------------------
|
233 |
+
|
234 |
+
def _sync(names):
|
235 |
+
r"""Synchronize the global cumulative counters across devices and
|
236 |
+
processes. Called internally by `Collector.update()`.
|
237 |
+
"""
|
238 |
+
if len(names) == 0:
|
239 |
+
return []
|
240 |
+
global _sync_called
|
241 |
+
_sync_called = True
|
242 |
+
|
243 |
+
# Collect deltas within current rank.
|
244 |
+
deltas = []
|
245 |
+
device = _sync_device if _sync_device is not None else torch.device('cpu')
|
246 |
+
for name in names:
|
247 |
+
delta = torch.zeros([_num_moments], dtype=_counter_dtype, device=device)
|
248 |
+
for counter in _counters[name].values():
|
249 |
+
delta.add_(counter.to(device))
|
250 |
+
counter.copy_(torch.zeros_like(counter))
|
251 |
+
deltas.append(delta)
|
252 |
+
deltas = torch.stack(deltas)
|
253 |
+
|
254 |
+
# Sum deltas across ranks.
|
255 |
+
if _sync_device is not None:
|
256 |
+
torch.distributed.all_reduce(deltas)
|
257 |
+
|
258 |
+
# Update cumulative values.
|
259 |
+
deltas = deltas.cpu()
|
260 |
+
for idx, name in enumerate(names):
|
261 |
+
if name not in _cumulative:
|
262 |
+
_cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
|
263 |
+
_cumulative[name].add_(deltas[idx])
|
264 |
+
|
265 |
+
# Return name-value pairs.
|
266 |
+
return [(name, _cumulative[name]) for name in names]
|
267 |
+
|
268 |
+
#----------------------------------------------------------------------------
|