# JGWEV / models/ddm.py
# Commit beef72d (HaisuGuan): model code
import os
import time
import numpy as np
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import utils
from models.unet import DiffusionUNet
import torch.distributed as dist
from torch.utils.tensorboard import SummaryWriter
from torch.optim.lr_scheduler import CosineAnnealingLR
def data_transform(X):
    """Affinely rescale ``X`` by ``2 * X - 1`` (maps [0, 1] onto [-1, 1]).

    Works for anything supporting scalar multiplication and subtraction
    (tensors, arrays, Python numbers).
    """
    doubled = 2 * X
    return doubled - 1.0
def inverse_data_transform(X):
    """Undo ``data_transform``: map [-1, 1] back to [0, 1].

    Values that fall outside the range after rescaling are clamped, so the
    result always lies in [0, 1]. ``X`` must be a ``torch.Tensor``.
    """
    rescaled = (X + 1.0) / 2.0
    return rescaled.clamp(min=0.0, max=1.0)
class EMAHelper(object):
    """Track an exponential moving average (EMA) of a module's trainable parameters.

    The shadow copy is a plain dict keyed by the parameter names of the
    *unwrapped* network. ``nn.DataParallel`` / ``DistributedDataParallel``
    wrappers are stripped first, so keys never carry a ``module.`` prefix.
    Parameters with ``requires_grad=False`` are ignored.
    """

    # Wrapper types whose ``.module`` attribute holds the real network.
    _PARALLEL_WRAPPERS = (nn.DataParallel, nn.parallel.DistributedDataParallel)

    def __init__(self, mu=0.9999):
        # Weight kept on the previous average at every update.
        self.mu = mu
        self.shadow = {}

    @staticmethod
    def _unwrap(module):
        """Return ``module.module`` for a parallel wrapper, else ``module`` itself."""
        if isinstance(module, EMAHelper._PARALLEL_WRAPPERS):
            return module.module
        return module

    @staticmethod
    def _trainable(module):
        """Yield ``(name, parameter)`` for every parameter that requires grad."""
        return ((key, p) for key, p in module.named_parameters() if p.requires_grad)

    def register(self, module):
        """Seed the shadow with a detached clone of each trainable parameter."""
        for key, p in self._trainable(self._unwrap(module)):
            self.shadow[key] = p.data.clone()

    def update(self, module, device):
        """Blend current weights into the shadow: ``(1 - mu) * param + mu * shadow``.

        The old shadow value is moved to ``device`` before blending.
        """
        for key, p in self._trainable(self._unwrap(module)):
            previous = self.shadow[key].data.to(device)
            self.shadow[key].data = (1. - self.mu) * p.data + self.mu * previous

    def ema(self, module):
        """Overwrite the module's trainable parameters in place with the shadow values."""
        for key, p in self._trainable(self._unwrap(module)):
            p.data.copy_(self.shadow[key].data)

    def ema_copy(self, module):
        """Return a new module carrying the EMA weights, leaving ``module`` untouched.

        The copy is rebuilt as ``type(net)(net.config)`` and moved to
        ``net.config.device``. The network class must therefore accept its
        config as the only constructor argument. When ``module`` is a parallel
        wrapper (DataParallel or DistributedDataParallel), the copy is wrapped
        in ``nn.DataParallel``.
        """
        net = self._unwrap(module)
        was_wrapped = net is not module
        duplicate = type(net)(net.config).to(net.config.device)
        duplicate.load_state_dict(net.state_dict())
        if was_wrapped:
            duplicate = nn.DataParallel(duplicate)
        self.ema(duplicate)
        return duplicate

    def state_dict(self):
        """Return the shadow dict itself (not a copy)."""
        return self.shadow

    def load_state_dict(self, state_dict):
        """Replace the shadow with ``state_dict`` (stored by reference)."""
        self.shadow = state_dict
def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps):
    """Build the per-timestep beta (noise variance) schedule.

    Args:
        beta_schedule: one of ``"quad"``, ``"linear"``, ``"const"``, ``"jsd"``
            or ``"sigmoid"``.
        beta_start: first value of the schedule (ignored by ``"const"`` and ``"jsd"``).
        beta_end: last value of the schedule (the only value used by
            ``"const"``; ignored by ``"jsd"``).
        num_diffusion_timesteps: number of timesteps ``T``.

    Returns:
        ``np.ndarray`` of shape ``(T,)``.

    Raises:
        NotImplementedError: if ``beta_schedule`` is not one of the names above.
    """
    steps = num_diffusion_timesteps

    def _sigmoid(z):
        return 1 / (np.exp(-z) + 1)

    # Builders are evaluated lazily, only for the selected schedule.
    schedules = (
        ("quad", lambda: np.linspace(beta_start ** 0.5, beta_end ** 0.5, steps, dtype=np.float64) ** 2),
        ("linear", lambda: np.linspace(beta_start, beta_end, steps, dtype=np.float64)),
        ("const", lambda: beta_end * np.ones(steps, dtype=np.float64)),
        # 1/T, 1/(T-1), 1/(T-2), ..., 1
        ("jsd", lambda: 1.0 / np.linspace(steps, 1, steps, dtype=np.float64)),
        ("sigmoid", lambda: _sigmoid(np.linspace(-6, 6, steps)) * (beta_end - beta_start) + beta_start),
    )
    for label, build in schedules:
        if beta_schedule == label:
            betas = build()
            break
    else:
        raise NotImplementedError(beta_schedule)

    assert betas.shape == (steps,)
    return betas
def noise_estimation_loss(model, x0, t, e, b):
    """Noise-prediction loss for a conditional diffusion model.

    ``x0`` holds the conditioning image in channels ``[:3]`` and the clean
    target in channels ``[3:]``. The target is noised to step ``t`` using the
    cumulative product of ``1 - b``. The model is called on the
    channel-concatenation ``[condition, noisy target]`` together with
    ``t.float()``. The squared error between ``e`` and the model output is
    summed per sample and averaged over the batch.

    Args:
        model: callable ``model(inputs, timesteps)`` returning a tensor shaped like ``e``.
        x0: 4-D tensor ``(N, C, H, W)`` with the conditioning channels first.
        t: integer timestep indices into ``b``, one per sample.
        e: noise tensor with the shape of ``x0[:, 3:]``.
        b: 1-D tensor of betas.

    Returns:
        Scalar loss tensor.
    """
    # Cumulative alpha product at each sample's timestep, broadcastable over (C, H, W).
    alpha_bar = (1 - b).cumprod(dim=0).index_select(0, t).view(-1, 1, 1, 1)
    condition = x0[:, :3, :, :]
    target = x0[:, 3:, :, :]
    noisy = target * alpha_bar.sqrt() + e * (1.0 - alpha_bar).sqrt()
    predicted = model(torch.cat([condition, noisy], dim=1), t.float())
    per_sample = (e - predicted).square().sum(dim=(1, 2, 3))
    return per_sample.mean(dim=0)
class DenoisingDiffusion(object):
    """Trainer and sampler for the conditional diffusion model built on ``DiffusionUNet``.

    Input batches carry the conditioning image in channels ``[:3]`` and, during
    training, the clean target in channels ``[3:]``.
    """

    def __init__(self, config, test=False):
        """Build the network, EMA tracker, optimizer, LR scheduler and beta schedule.

        Args:
            config: project config (reads ``device``, ``local_rank``, ``data.*``,
                ``training.*`` and ``diffusion.*``).
            test: wrap the network in ``nn.DataParallel`` instead of
                ``DistributedDataParallel``.
        """
        super().__init__()
        self.config = config
        self.device = config.device
        self.writer = SummaryWriter(config.data.tensorboard)

        self.model = DiffusionUNet(config)
        self.model.to(self.device)
        if test:
            self.model = torch.nn.DataParallel(self.model)
        else:
            self.model = torch.nn.parallel.DistributedDataParallel(
                self.model, device_ids=[config.local_rank], output_device=config.local_rank)

        self.ema_helper = EMAHelper()
        self.ema_helper.register(self.model)

        self.optimizer = utils.optimize.get_optimizer(self.config, self.model.parameters())
        # T_max is measured in epochs, so train() steps the scheduler once per epoch.
        self.scheduler = CosineAnnealingLR(self.optimizer, T_max=config.training.n_epochs)
        self.start_epoch, self.step = 0, 0

        betas = get_beta_schedule(
            beta_schedule=config.diffusion.beta_schedule,
            beta_start=config.diffusion.beta_start,
            beta_end=config.diffusion.beta_end,
            num_diffusion_timesteps=config.diffusion.num_diffusion_timesteps,
        )
        self.betas = torch.from_numpy(betas).float().to(self.device)
        self.num_timesteps = self.betas.shape[0]

    def _checkpoint_state(self, next_epoch):
        """Collect everything needed to resume training at epoch ``next_epoch``."""
        return {
            'epoch': next_epoch,
            'step': self.step,
            'state_dict': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'ema_helper': self.ema_helper.state_dict(),
            'config': self.config,
            'scheduler': self.scheduler.state_dict(),
        }

    def _save_checkpoints(self, next_epoch, snapshot_epoch):
        """Write the snapshot ``<resume>_<snapshot_epoch>`` and the rolling ``<resume>`` checkpoint."""
        state = self._checkpoint_state(next_epoch)
        resume = self.config.training.resume
        utils.logging.save_checkpoint(state, filename=resume + '_' + str(snapshot_epoch))
        utils.logging.save_checkpoint(state, filename=resume)

    def load_ddm_ckpt(self, load_path, ema=False):
        """Restore model, optimizer, EMA shadow, scheduler and counters from ``load_path``.

        If ``ema`` is true, the EMA weights are then copied into the live model.
        """
        checkpoint = utils.logging.load_checkpoint(load_path, None)
        self.start_epoch = checkpoint['epoch']
        self.step = checkpoint['step']
        self.model.load_state_dict(checkpoint['state_dict'], strict=True)
        self.optimizer.load_state_dict(checkpoint['optimizer'])
        self.ema_helper.load_state_dict(checkpoint['ema_helper'])
        self.scheduler.load_state_dict(checkpoint['scheduler'])
        if ema:
            self.ema_helper.ema(self.model)
        print("=> loaded checkpoint '{}' (epoch {}, step {})".format(load_path, checkpoint['epoch'], self.step))

    def train(self, DATASET):
        """Train from ``self.start_epoch`` up to ``config.training.n_epochs``.

        Resumes from ``<training.resume>.pth.tar`` when that file exists. Uses
        ``torch.distributed`` collectives, so every rank must call it.
        """
        cudnn.benchmark = True
        train_loader, val_loader = DATASET.get_loaders()

        pretrained_model_path = self.config.training.resume + '.pth.tar'
        if os.path.isfile(pretrained_model_path):
            self.load_ddm_ckpt(pretrained_model_path)
        # Every rank reaches this barrier, even if ranks disagree on whether the
        # checkpoint exists, so the collective always pairs up.
        dist.barrier()

        rank = dist.get_rank()
        is_main = rank == 0

        # Training
        for epoch in range(self.start_epoch, self.config.training.n_epochs):
            if epoch == 0 and is_main:
                # Snapshot of the untrained weights. Record epoch 0 (the epoch that
                # has NOT run yet); recording epoch + 1 would make a resume after an
                # interruption during epoch 0 skip that epoch entirely.
                self._save_checkpoints(next_epoch=epoch, snapshot_epoch=epoch)
            if is_main:
                print('=> current epoch: ', epoch)

            data_start = time.time()
            data_time = 0
            train_loader.sampler.set_epoch(epoch)
            for i, (x, _) in enumerate(train_loader):
                # 5-D batches have their first two dims merged into the batch dim.
                x = x.flatten(start_dim=0, end_dim=1) if x.ndim == 5 else x
                n = x.size(0)
                data_time += time.time() - data_start
                self.model.train()
                self.step += 1

                x = data_transform(x.to(self.device))
                e = torch.randn_like(x[:, 3:, :, :])
                # Antithetic sampling: pair each timestep t with T - 1 - t.
                t = torch.randint(low=0, high=self.num_timesteps, size=(n // 2 + 1,)).to(self.device)
                t = torch.cat([t, self.num_timesteps - t - 1], dim=0)[:n]
                loss = noise_estimation_loss(self.model, x, t, e, self.betas)

                current_lr = self.optimizer.param_groups[0]['lr']
                if self.step % 10 == 0:
                    print(
                        'rank: %d, step: %d, loss: %.6f, lr: %.6f, time consumption: %.6f' % (
                            rank, self.step, loss.item(), current_lr, data_time / (i + 1)))

                # Parameter update
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                self.ema_helper.update(self.model, self.device)
                data_start = time.time()

                if self.step % self.config.training.validation_freq == 0:
                    self.model.eval()
                    self.sample_validation_patches(val_loader, self.step)

                if self.step % 100 == 0 and is_main:
                    self.writer.add_scalar('train/loss', loss.item(), self.step)
                    self.writer.add_scalar('train/lr', current_lr, self.step)

            # Cosine schedule is defined over epochs (T_max = n_epochs).
            self.scheduler.step()

            # Save the model
            if epoch % self.config.training.snapshot_freq == 0 and is_main:
                self._save_checkpoints(next_epoch=epoch + 1, snapshot_epoch=epoch)

    def sample_image(self, x_cond, x, last=True, patch_locs=None, patch_size=None):
        """Sample from noise ``x`` conditioned on ``x_cond`` with the project's sampler (``eta=0.``).

        ``config.sampling.sampling_timesteps`` evenly spaced timesteps are used.
        When ``patch_locs`` is given, the overlapping-patch sampler is used with
        ``patch_size``. When ``last`` is true, element ``[0][-1]`` of the sampler
        output is returned (presumably the final sample); otherwise the sampler's
        full output is returned.
        """
        total = self.config.diffusion.num_diffusion_timesteps
        # Clamp to 1: when sampling_timesteps > total the integer division gives 0,
        # and range() rejects a zero step.
        skip = max(1, total // self.config.sampling.sampling_timesteps)
        seq = range(0, total, skip)
        if patch_locs is not None:
            xs = utils.sampling.generalized_steps_overlapping(x, x_cond, seq, self.model, self.betas, eta=0.,
                                                              corners=patch_locs, p_size=patch_size,
                                                              device=self.device)
        else:
            xs = utils.sampling.generalized_steps(x, x_cond, seq, self.model, self.betas, eta=0.,
                                                  device=self.device)
        if last:
            xs = xs[0][-1]
        return xs

    def sample_validation_patches(self, val_loader, step):
        """Sample the first validation batch and save condition/output images.

        Every rank runs the sampler, but only rank 0 writes files, to avoid
        concurrent writes to the same paths. Output goes to
        ``<val_save_dir>/<image_size>/<step>/{i}_cond.png`` and ``{i}.png``.
        Does nothing beyond the log line if ``val_loader`` is empty.
        """
        image_folder = os.path.join(self.config.data.val_save_dir, str(self.config.data.image_size))
        is_main = dist.get_rank() == 0
        with torch.no_grad():
            if is_main:
                print(f"Processing a single batch of validation images at step: {step}")
            batch = next(iter(val_loader), None)
            if batch is None:
                # Empty loader: there is nothing to sample.
                return
            x, _ = batch
            x = x.flatten(start_dim=0, end_dim=1) if x.ndim == 5 else x
            n = x.size(0)

            x_cond = data_transform(x[:, :3, :, :].to(self.device))  # conditioning image
            # Noise must match the conditioning batch spatially, because the two
            # are concatenated along the channel dimension.
            x = torch.randn(n, 3, x_cond.shape[2], x_cond.shape[3], device=self.device)
            x = self.sample_image(x_cond, x)
            x = inverse_data_transform(x)
            x_cond = inverse_data_transform(x_cond)

            if not is_main:
                return
            step_folder = os.path.join(image_folder, str(step))
            for i in range(n):
                utils.logging.save_image(x_cond[i], os.path.join(step_folder, f"{i}_cond.png"))
                utils.logging.save_image(x[i], os.path.join(step_folder, f"{i}.png"))