HaisuGuan committed on
Commit 332a731 · 1 Parent(s): 3bdd2ef

Model code

Files changed (4)
  1. models/__init__.py +2 -0
  2. models/ddm.py +260 -0
  3. models/restoration.py +59 -0
  4. models/unet.py +331 -0
models/__init__.py ADDED
@@ -0,0 +1,2 @@
+ from models.ddm import *
+ from models.restoration import *
models/ddm.py ADDED
@@ -0,0 +1,260 @@
+ import os
+ import time
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ import torch.backends.cudnn as cudnn
+ import utils
+ from models.unet import DiffusionUNet
+ import torch.distributed as dist
+ from torch.utils.tensorboard import SummaryWriter
+ from torch.optim.lr_scheduler import CosineAnnealingLR
+
+
+ def data_transform(X):
+     return 2 * X - 1.0
+
+
+ def inverse_data_transform(X):
+     return torch.clamp((X + 1.0) / 2.0, 0.0, 1.0)
+
+
+ class EMAHelper(object):
+     def __init__(self, mu=0.9999):
+         self.mu = mu
+         self.shadow = {}
+
+     def register(self, module):
+         if isinstance(module, nn.DataParallel) or isinstance(module, nn.parallel.DistributedDataParallel):
+             module = module.module
+         for name, param in module.named_parameters():
+             if param.requires_grad:
+                 self.shadow[name] = param.data.clone()
+
+     def update(self, module, device):
+         # exponential moving average: shadow <- mu * shadow + (1 - mu) * param
+         if isinstance(module, nn.DataParallel) or isinstance(module, nn.parallel.DistributedDataParallel):
+             module = module.module
+         for name, param in module.named_parameters():
+             if param.requires_grad:
+                 self.shadow[name].data = (1. - self.mu) * param.data + self.mu * self.shadow[name].data.to(device)
+
+     def ema(self, module):
+         if isinstance(module, nn.DataParallel) or isinstance(module, nn.parallel.DistributedDataParallel):
+             module = module.module
+         for name, param in module.named_parameters():
+             if param.requires_grad:
+                 param.data.copy_(self.shadow[name].data)
+
+     def ema_copy(self, module):
+         if isinstance(module, nn.DataParallel) or isinstance(module, nn.parallel.DistributedDataParallel):
+             inner_module = module.module
+             module_copy = type(inner_module)(inner_module.config).to(inner_module.config.device)
+             module_copy.load_state_dict(inner_module.state_dict())
+             module_copy = nn.DataParallel(module_copy)
+         else:
+             module_copy = type(module)(module.config).to(module.config.device)
+             module_copy.load_state_dict(module.state_dict())
+         self.ema(module_copy)
+         return module_copy
+
+     def state_dict(self):
+         return self.shadow
+
+     def load_state_dict(self, state_dict):
+         self.shadow = state_dict
+
+
+ def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps):
+     def sigmoid(x):
+         return 1 / (np.exp(-x) + 1)
+
+     if beta_schedule == "quad":
+         betas = (np.linspace(beta_start ** 0.5, beta_end ** 0.5, num_diffusion_timesteps, dtype=np.float64) ** 2)
+     elif beta_schedule == "linear":
+         betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)
+     elif beta_schedule == "const":
+         betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
+     elif beta_schedule == "jsd":  # 1/T, 1/(T-1), 1/(T-2), ..., 1
+         betas = 1.0 / np.linspace(num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64)
+     elif beta_schedule == "sigmoid":
+         betas = np.linspace(-6, 6, num_diffusion_timesteps)
+         betas = sigmoid(betas) * (beta_end - beta_start) + beta_start
+     else:
+         raise NotImplementedError(beta_schedule)
+     assert betas.shape == (num_diffusion_timesteps,)
+     return betas
+
+
+ def noise_estimation_loss(model, x0, t, e, b):
+     # a = alpha_bar_t; x0[:, :3] is the degraded condition, x0[:, 3:] the clean target
+     a = (1 - b).cumprod(dim=0).index_select(0, t).view(-1, 1, 1, 1)
+     x = x0[:, 3:, :, :] * a.sqrt() + e * (1.0 - a).sqrt()
+     output = model(torch.cat([x0[:, :3, :, :], x], dim=1), t.float())
+     return (e - output).square().sum(dim=(1, 2, 3)).mean(dim=0)
+
+
+ class DenoisingDiffusion(object):
+     def __init__(self, config, test=False):
+         super().__init__()
+         self.config = config
+         self.device = config.device
+         self.writer = SummaryWriter(config.data.tensorboard)
+         self.model = DiffusionUNet(config)
+         self.model.to(self.device)
+         if test:
+             self.model = torch.nn.DataParallel(self.model)
+         else:
+             self.model = torch.nn.parallel.DistributedDataParallel(self.model, device_ids=[config.local_rank],
+                                                                    output_device=config.local_rank)
+         self.ema_helper = EMAHelper()
+         self.ema_helper.register(self.model)
+
+         self.optimizer = utils.optimize.get_optimizer(self.config, self.model.parameters())
+         self.scheduler = CosineAnnealingLR(self.optimizer, T_max=config.training.n_epochs)
+         self.start_epoch, self.step = 0, 0
+
+         betas = get_beta_schedule(
+             beta_schedule=config.diffusion.beta_schedule,
+             beta_start=config.diffusion.beta_start,
+             beta_end=config.diffusion.beta_end,
+             num_diffusion_timesteps=config.diffusion.num_diffusion_timesteps,
+         )
+
+         betas = self.betas = torch.from_numpy(betas).float().to(self.device)
+         self.num_timesteps = betas.shape[0]
+
+     def load_ddm_ckpt(self, load_path, ema=False):
+         checkpoint = utils.logging.load_checkpoint(load_path, None)
+         self.start_epoch = checkpoint['epoch']
+         self.step = checkpoint['step']
+         self.model.load_state_dict(checkpoint['state_dict'], strict=True)
+         self.optimizer.load_state_dict(checkpoint['optimizer'])
+         self.ema_helper.load_state_dict(checkpoint['ema_helper'])
+         self.scheduler.load_state_dict(checkpoint['scheduler'])
+         if ema:
+             self.ema_helper.ema(self.model)
+         print("=> loaded checkpoint '{}' (epoch {}, step {})".format(load_path, checkpoint['epoch'], self.step))
+
+     def train(self, DATASET):
+         cudnn.benchmark = True
+         train_loader, val_loader = DATASET.get_loaders()
+         pretrained_model_path = self.config.training.resume + '.pth.tar'
+         if os.path.isfile(pretrained_model_path):
+             self.load_ddm_ckpt(pretrained_model_path)
+         dist.barrier()
+         # training loop
+         for epoch in range(self.start_epoch, self.config.training.n_epochs):
+             if (epoch == 0) and dist.get_rank() == 0:
+                 utils.logging.save_checkpoint({
+                     'epoch': epoch + 1,
+                     'step': self.step,
+                     'state_dict': self.model.state_dict(),
+                     'optimizer': self.optimizer.state_dict(),
+                     'ema_helper': self.ema_helper.state_dict(),
+                     'config': self.config,
+                     'scheduler': self.scheduler.state_dict()
+                 }, filename=self.config.training.resume + '_' + str(epoch))
+                 utils.logging.save_checkpoint({
+                     'epoch': epoch + 1,
+                     'step': self.step,
+                     'state_dict': self.model.state_dict(),
+                     'optimizer': self.optimizer.state_dict(),
+                     'ema_helper': self.ema_helper.state_dict(),
+                     'config': self.config,
+                     'scheduler': self.scheduler.state_dict()
+                 }, filename=self.config.training.resume)
+             if dist.get_rank() == 0:
+                 print('=> current epoch: ', epoch)
+             data_start = time.time()
+             data_time = 0
+             train_loader.sampler.set_epoch(epoch)
+             for i, (x, y) in enumerate(train_loader):
+                 x = x.flatten(start_dim=0, end_dim=1) if x.ndim == 5 else x
+                 n = x.size(0)
+                 data_time += time.time() - data_start
+                 self.model.train()
+                 self.step += 1
+
+                 x = x.to(self.device)
+                 x = data_transform(x)
+                 e = torch.randn_like(x[:, 3:, :, :])
+                 b = self.betas
+
+                 # antithetic sampling
+                 t = torch.randint(low=0, high=self.num_timesteps, size=(n // 2 + 1,)).to(self.device)
+                 t = torch.cat([t, self.num_timesteps - t - 1], dim=0)[:n]
+                 loss = noise_estimation_loss(self.model, x, t, e, b)
+                 current_lr = self.optimizer.param_groups[0]['lr']
+
+                 if self.step % 10 == 0:
+                     print(
+                         'rank: %d, step: %d, loss: %.6f, lr: %.6f, time consumption: %.6f' % (
+                             dist.get_rank(), self.step, loss.item(), current_lr, data_time / (i + 1)))
+
+                 # update parameters
+                 self.optimizer.zero_grad()
+                 loss.backward()
+                 self.optimizer.step()
+                 self.ema_helper.update(self.model, self.device)
+                 data_start = time.time()
+
+                 if self.step % self.config.training.validation_freq == 0:
+                     self.model.eval()
+                     self.sample_validation_patches(val_loader, self.step)
+
+                 if (self.step % 100 == 0) and dist.get_rank() == 0:
+                     self.writer.add_scalar('train/loss', loss.item(), self.step)
+                     self.writer.add_scalar('train/lr', current_lr, self.step)
+
+             self.scheduler.step()
+             # save checkpoints
+             if (epoch % self.config.training.snapshot_freq == 0) and dist.get_rank() == 0:
+                 utils.logging.save_checkpoint({
+                     'epoch': epoch + 1,
+                     'step': self.step,
+                     'state_dict': self.model.state_dict(),
+                     'optimizer': self.optimizer.state_dict(),
+                     'ema_helper': self.ema_helper.state_dict(),
+                     'config': self.config,
+                     'scheduler': self.scheduler.state_dict()
+                 }, filename=self.config.training.resume + '_' + str(epoch))
+                 utils.logging.save_checkpoint({
+                     'epoch': epoch + 1,
+                     'step': self.step,
+                     'state_dict': self.model.state_dict(),
+                     'optimizer': self.optimizer.state_dict(),
+                     'ema_helper': self.ema_helper.state_dict(),
+                     'config': self.config,
+                     'scheduler': self.scheduler.state_dict()
+                 }, filename=self.config.training.resume)
+
+     def sample_image(self, x_cond, x, last=True, patch_locs=None, patch_size=None):
+         skip = self.config.diffusion.num_diffusion_timesteps // self.config.sampling.sampling_timesteps
+         seq = range(0, self.config.diffusion.num_diffusion_timesteps, skip)
+         if patch_locs is not None:
+             xs = utils.sampling.generalized_steps_overlapping(x, x_cond, seq, self.model, self.betas, eta=0.,
+                                                               corners=patch_locs, p_size=patch_size, device=self.device)
+         else:
+             xs = utils.sampling.generalized_steps(x, x_cond, seq, self.model, self.betas, eta=0., device=self.device)
+         if last:
+             xs = xs[0][-1]
+         return xs
+
+     def sample_validation_patches(self, val_loader, step):
+         image_folder = os.path.join(self.config.data.val_save_dir, str(self.config.data.image_size))
+         with torch.no_grad():
+             if dist.get_rank() == 0:
+                 print(f"Processing a single batch of validation images at step: {step}")
+             for i, (x, y) in enumerate(val_loader):
+                 x = x.flatten(start_dim=0, end_dim=1) if x.ndim == 5 else x
+                 break
+             n = x.size(0)
+             x_cond = x[:, :3, :, :].to(self.device)  # conditional image
+             x_cond = data_transform(x_cond)
+             x = torch.randn(n, 3, self.config.data.image_size, self.config.data.image_size, device=self.device)
+             x = self.sample_image(x_cond, x)
+             x = inverse_data_transform(x)
+             x_cond = inverse_data_transform(x_cond)
+
+             for i in range(n):
+                 utils.logging.save_image(x_cond[i], os.path.join(image_folder, str(step), f"{i}_cond.png"))
+                 utils.logging.save_image(x[i], os.path.join(image_folder, str(step), f"{i}.png"))
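
For reference, a minimal standalone sketch of what noise_estimation_loss and the antithetic timestep sampling in train() compute; the tensor sizes and the linear schedule endpoints below are hypothetical stand-ins for the config values, not the repo's settings:

import torch

T, B, C, H, W = 1000, 4, 3, 64, 64                   # hypothetical sizes, not the config values

betas = torch.linspace(1e-4, 2e-2, T)                 # a "linear" schedule, as in get_beta_schedule
alpha_bar = (1 - betas).cumprod(dim=0)                # the quantity `a` inside noise_estimation_loss

# antithetic sampling: draw half the batch, then mirror so t and T - 1 - t appear in pairs
t = torch.randint(low=0, high=T, size=(B // 2 + 1,))
t = torch.cat([t, T - t - 1], dim=0)[:B]

x0 = torch.rand(B, C, H, W) * 2 - 1                   # clean target, already mapped to [-1, 1]
eps = torch.randn_like(x0)
a = alpha_bar.index_select(0, t).view(-1, 1, 1, 1)
x_t = x0 * a.sqrt() + eps * (1.0 - a).sqrt()          # closed-form forward diffusion q(x_t | x_0)
# the UNet receives torch.cat([condition, x_t], dim=1) and is trained to predict eps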
models/restoration.py ADDED
@@ -0,0 +1,59 @@
+ import torch
+ import utils
+ import os
+ from tqdm import tqdm
+
+
+ def data_transform(X):
+     return 2 * X - 1.0
+
+
+ def inverse_data_transform(X):
+     return torch.clamp((X + 1.0) / 2.0, 0.0, 1.0)
+
+
+ class DiffusiveRestoration:
+     def __init__(self, diffusion, config):
+         super(DiffusiveRestoration, self).__init__()
+         self.config = config
+         self.diffusion = diffusion
+
+         # check that a pretrained diffusion model exists
+         pretrained_model_path = self.config.training.resume + '.pth.tar'
+         assert os.path.isfile(pretrained_model_path), 'pretrained diffusion model path is wrong!'
+         self.diffusion.load_ddm_ckpt(pretrained_model_path, ema=True)
+         self.diffusion.model.eval()
+         self.diffusion.model.requires_grad_(False)
+
+     def restore(self, val_loader, r=None):
+         image_folder = self.config.data.test_save_dir
+         with torch.no_grad():
+             for i, (x, y) in tqdm(enumerate(val_loader)):
+                 print(f"=> starting to process image named {y}")
+                 x = x.flatten(start_dim=0, end_dim=1) if x.ndim == 5 else x
+                 x_cond = x[:, :3, :, :].to(self.diffusion.device)
+                 x_output = self.diffusive_restoration(x_cond, r=r)
+                 x_output = inverse_data_transform(x_output)
+                 utils.logging.save_image(x_output, os.path.join(image_folder, f"{y[0]}.png"))
+
+     def diffusive_restoration(self, x_cond, r=None):
+         p_size = self.config.data.image_size
+         h_list, w_list = self.overlapping_grid_indices(x_cond, output_size=p_size, r=r)
+         corners = [(i, j) for i in h_list for j in w_list]
+         x = torch.randn(x_cond.size(), device=self.diffusion.device)
+         x_output = self.diffusion.sample_image(x_cond, x, patch_locs=corners, patch_size=p_size)
+         return x_output
+
+     def overlapping_grid_indices(self, x_cond, output_size, r=None):
+         _, c, h, w = x_cond.shape
+         r = 16 if r is None else r
+         h_list = [i for i in range(0, h - output_size + 1, r)]
+         w_list = [i for i in range(0, w - output_size + 1, r)]
+         return h_list, w_list
+
+     def web_restore(self, image, r=None):
+         with torch.no_grad():
+             image_cond = image.to(self.diffusion.device)
+             image_output = self.diffusive_restoration(image_cond, r=r)
+             image_output = inverse_data_transform(image_output)
+             return image_output
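
As a side note, the overlapping patch grid that diffusive_restoration passes to sample_image can be reproduced in isolation; the image and patch sizes below are hypothetical, while in the repo the patch size comes from config.data.image_size and the stride r defaults to 16:

h, w, patch, stride = 96, 128, 64, 16                 # hypothetical sizes
h_list = list(range(0, h - patch + 1, stride))        # top-left rows: [0, 16, 32]
w_list = list(range(0, w - patch + 1, stride))        # top-left cols: [0, 16, 32, 48, 64]
corners = [(i, j) for i in h_list for j in w_list]    # 15 overlapping patch corners
# each corner marks a patch to denoise; overlaps are handled inside
# utils.sampling.generalized_steps_overlapping, which is not part of this commit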
models/unet.py ADDED
@@ -0,0 +1,331 @@
+ import math
+ import torch
+ import torch.nn as nn
+
+
+ def get_timestep_embedding(timesteps, embedding_dim):
+     assert len(timesteps.shape) == 1
+
+     half_dim = embedding_dim // 2
+     emb = math.log(10000) / (half_dim - 1)
+     emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
+     emb = emb.to(device=timesteps.device)
+     emb = timesteps.float()[:, None] * emb[None, :]
+     emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+     if embedding_dim % 2 == 1:  # zero pad
+         emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
+     return emb
+
+
+ def nonlinearity(x):
+     # swish
+     return x * torch.sigmoid(x)
+
+
+ def Normalize(in_channels):
+     return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
+
+
+ class Upsample(nn.Module):
+     def __init__(self, in_channels, with_conv):
+         super().__init__()
+         self.with_conv = with_conv
+         if self.with_conv:
+             self.conv = torch.nn.Conv2d(in_channels,
+                                         in_channels,
+                                         kernel_size=3,
+                                         stride=1,
+                                         padding=1)
+
+     def forward(self, x):
+         x = torch.nn.functional.interpolate(
+             x, scale_factor=2.0, mode="nearest")
+         if self.with_conv:
+             x = self.conv(x)
+         return x
+
+
+ class Downsample(nn.Module):
+     def __init__(self, in_channels, with_conv):
+         super().__init__()
+         self.with_conv = with_conv
+         if self.with_conv:
+             self.conv = torch.nn.Conv2d(in_channels,
+                                         in_channels,
+                                         kernel_size=3,
+                                         stride=2,
+                                         padding=0)
+
+     def forward(self, x):
+         if self.with_conv:
+             pad = (0, 1, 0, 1)
+             x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
+             x = self.conv(x)
+         else:
+             x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
+         return x
+
+
+ class ResnetBlock(nn.Module):
+     def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
+                  dropout, temb_channels=512):
+         super().__init__()
+         self.in_channels = in_channels
+         out_channels = in_channels if out_channels is None else out_channels
+         self.out_channels = out_channels
+         self.use_conv_shortcut = conv_shortcut
+
+         self.norm1 = Normalize(in_channels)
+         self.conv1 = torch.nn.Conv2d(in_channels,
+                                      out_channels,
+                                      kernel_size=3,
+                                      stride=1,
+                                      padding=1)
+         self.temb_proj = torch.nn.Linear(temb_channels,
+                                          out_channels)
+         self.norm2 = Normalize(out_channels)
+         self.dropout = torch.nn.Dropout(dropout)
+         self.conv2 = torch.nn.Conv2d(out_channels,
+                                      out_channels,
+                                      kernel_size=3,
+                                      stride=1,
+                                      padding=1)
+         if self.in_channels != self.out_channels:
+             if self.use_conv_shortcut:
+                 self.conv_shortcut = torch.nn.Conv2d(in_channels,
+                                                      out_channels,
+                                                      kernel_size=3,
+                                                      stride=1,
+                                                      padding=1)
+             else:
+                 self.nin_shortcut = torch.nn.Conv2d(in_channels,
+                                                     out_channels,
+                                                     kernel_size=1,
+                                                     stride=1,
+                                                     padding=0)
+
+     def forward(self, x, temb):
+         h = x
+         h = self.norm1(h)
+         h = nonlinearity(h)
+         h = self.conv1(h)
+
+         h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
+
+         h = self.norm2(h)
+         h = nonlinearity(h)
+         h = self.dropout(h)
+         h = self.conv2(h)
+
+         if self.in_channels != self.out_channels:
+             if self.use_conv_shortcut:
+                 x = self.conv_shortcut(x)
+             else:
+                 x = self.nin_shortcut(x)
+
+         return x + h
+
+
+ class AttnBlock(nn.Module):
+     def __init__(self, in_channels):
+         super().__init__()
+         self.in_channels = in_channels
+
+         self.norm = Normalize(in_channels)
+         self.q = torch.nn.Conv2d(in_channels,
+                                  in_channels,
+                                  kernel_size=1,
+                                  stride=1,
+                                  padding=0)
+         self.k = torch.nn.Conv2d(in_channels,
+                                  in_channels,
+                                  kernel_size=1,
+                                  stride=1,
+                                  padding=0)
+         self.v = torch.nn.Conv2d(in_channels,
+                                  in_channels,
+                                  kernel_size=1,
+                                  stride=1,
+                                  padding=0)
+         self.proj_out = torch.nn.Conv2d(in_channels,
+                                         in_channels,
+                                         kernel_size=1,
+                                         stride=1,
+                                         padding=0)
+
+     def forward(self, x):
+         h_ = x
+         h_ = self.norm(h_)
+         q = self.q(h_)
+         k = self.k(h_)
+         v = self.v(h_)
+
+         # self-attention over spatial positions
+         b, c, h, w = q.shape
+         q = q.reshape(b, c, h * w)
+         q = q.permute(0, 2, 1).contiguous()   # b, hw, c
+         k = k.reshape(b, c, h * w)            # b, c, hw
+         w_ = torch.bmm(q, k)                  # b, hw, hw   w[b, i, j] = sum_c q[b, i, c] * k[b, c, j]
+         w_ = w_ * (int(c) ** (-0.5))
+         w_ = torch.nn.functional.softmax(w_, dim=2)
+
+         # attend to values
+         v = v.reshape(b, c, h * w)
+         w_ = w_.permute(0, 2, 1).contiguous()  # b, hw, hw (first hw of k, second of q)
+         h_ = torch.bmm(v, w_)
+         h_ = h_.reshape(b, c, h, w)
+
+         h_ = self.proj_out(h_)
+
+         return x + h_
+
+
+ class DiffusionUNet(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.config = config
+         ch, out_ch, ch_mult = config.model.ch, config.model.out_ch, tuple(config.model.ch_mult)
+         num_res_blocks = config.model.num_res_blocks
+         attn_resolutions = config.model.attn_resolutions
+         dropout = config.model.dropout
+         in_channels = config.model.in_channels * 2 if config.data.conditional else config.model.in_channels
+         resolution = config.data.image_size
+         resamp_with_conv = config.model.resamp_with_conv
+
+         self.ch = ch
+         self.temb_ch = self.ch * 4
+         self.num_resolutions = len(ch_mult)
+         self.num_res_blocks = num_res_blocks
+         self.resolution = resolution
+         self.in_channels = in_channels
+
+         # timestep embedding
+         self.temb = nn.Module()
+         self.temb.dense = nn.ModuleList([
+             torch.nn.Linear(self.ch,
+                             self.temb_ch),
+             torch.nn.Linear(self.temb_ch,
+                             self.temb_ch),
+         ])
+
+         # downsampling
+         self.conv_in = torch.nn.Conv2d(in_channels,
+                                        self.ch,
+                                        kernel_size=3,
+                                        stride=1,
+                                        padding=1)
+
+         curr_res = resolution
+         in_ch_mult = (1,) + ch_mult
+         self.down = nn.ModuleList()
+         block_in = None
+         for i_level in range(self.num_resolutions):
+             block = nn.ModuleList()
+             attn = nn.ModuleList()
+             block_in = ch * in_ch_mult[i_level]
+             block_out = ch * ch_mult[i_level]
+             for i_block in range(self.num_res_blocks):
+                 block.append(ResnetBlock(in_channels=block_in,
+                                          out_channels=block_out,
+                                          temb_channels=self.temb_ch,
+                                          dropout=dropout))
+                 block_in = block_out
+                 if curr_res in attn_resolutions:
+                     attn.append(AttnBlock(block_in))
+             down = nn.Module()
+             down.block = block
+             down.attn = attn
+             if i_level != self.num_resolutions - 1:
+                 down.downsample = Downsample(block_in, resamp_with_conv)
+                 curr_res = curr_res // 2
+             self.down.append(down)
+
+         # middle
+         self.mid = nn.Module()
+         self.mid.block_1 = ResnetBlock(in_channels=block_in,
+                                        out_channels=block_in,
+                                        temb_channels=self.temb_ch,
+                                        dropout=dropout)
+         self.mid.attn_1 = AttnBlock(block_in)
+         self.mid.block_2 = ResnetBlock(in_channels=block_in,
+                                        out_channels=block_in,
+                                        temb_channels=self.temb_ch,
+                                        dropout=dropout)
+
+         # upsampling
+         self.up = nn.ModuleList()
+         for i_level in reversed(range(self.num_resolutions)):
+             block = nn.ModuleList()
+             attn = nn.ModuleList()
+             block_out = ch * ch_mult[i_level]
+             skip_in = ch * ch_mult[i_level]
+             for i_block in range(self.num_res_blocks + 1):
+                 if i_block == self.num_res_blocks:
+                     skip_in = ch * in_ch_mult[i_level]
+                 block.append(ResnetBlock(in_channels=block_in + skip_in,
+                                          out_channels=block_out,
+                                          temb_channels=self.temb_ch,
+                                          dropout=dropout))
+                 block_in = block_out
+                 if curr_res in attn_resolutions:
+                     attn.append(AttnBlock(block_in))
+             up = nn.Module()
+             up.block = block
+             up.attn = attn
+             if i_level != 0:
+                 up.upsample = Upsample(block_in, resamp_with_conv)
+                 curr_res = curr_res * 2
+             self.up.insert(0, up)  # prepend to get consistent order
+
+         # end
+         self.norm_out = Normalize(block_in)
+         self.conv_out = torch.nn.Conv2d(block_in,
+                                         out_ch,
+                                         kernel_size=3,
+                                         stride=1,
+                                         padding=1)
+
+     def forward(self, x, t):
+         assert x.shape[2] == x.shape[3] == self.resolution
+
+         # timestep embedding
+         temb = get_timestep_embedding(t, self.ch)
+         temb = self.temb.dense[0](temb)
+         temb = nonlinearity(temb)
+         temb = self.temb.dense[1](temb)
+
+         # downsampling
+         hs = [self.conv_in(x)]
+         for i_level in range(self.num_resolutions):
+             for i_block in range(self.num_res_blocks):
+                 h = self.down[i_level].block[i_block](hs[-1], temb)
+                 if len(self.down[i_level].attn) > 0:
+                     h = self.down[i_level].attn[i_block](h)
+                 hs.append(h)
+             if i_level != self.num_resolutions - 1:
+                 hs.append(self.down[i_level].downsample(hs[-1]))
+
+         # middle
+         h = hs[-1]
+         h = self.mid.block_1(h, temb)
+         h = self.mid.attn_1(h)
+         h = self.mid.block_2(h, temb)
+
+         # upsampling with skip connections
+         for i_level in reversed(range(self.num_resolutions)):
+             for i_block in range(self.num_res_blocks + 1):
+                 h = self.up[i_level].block[i_block](
+                     torch.cat([h, hs.pop()], dim=1), temb)
+                 if len(self.up[i_level].attn) > 0:
+                     h = self.up[i_level].attn[i_block](h)
+             if i_level != 0:
+                 h = self.up[i_level].upsample(h)
+
+         # end
+         h = self.norm_out(h)
+         h = nonlinearity(h)
+         h = self.conv_out(h)
+         return h
+
+
+ # net = DiffusionUNet()
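
A quick usage sketch for the sinusoidal timestep embedding above, assuming the package layout of this commit (models/unet.py importable as models.unet):

import torch
from models.unet import get_timestep_embedding

t = torch.arange(0, 8)                 # a batch of 8 integer timesteps
emb = get_timestep_embedding(t, 128)   # first 64 dims are sines, last 64 are cosines
print(emb.shape)                       # torch.Size([8, 128])

Note that the commented-out net = DiffusionUNet() at the bottom would need a config object with the model and data sections read in __init__.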