|
import torch |
|
torch.backends.cuda.matmul.allow_tf32 = True |
|
torch.backends.cudnn.allow_tf32 = True |
|
import torch.distributed as dist |
|
from torch.utils.data import Dataset, DataLoader |
|
from torch.utils.data.distributed import DistributedSampler |
|
from torchvision.datasets import ImageFolder |
|
from torchvision import transforms |
|
from tqdm import tqdm |
|
import os |
|
import itertools |
|
from PIL import Image |
|
import numpy as np |
|
import argparse |
|
import random |
|
|
|
from skimage.metrics import peak_signal_noise_ratio as psnr_loss |
|
from skimage.metrics import structural_similarity as ssim_loss |
|
from diffusers.models import ConsistencyDecoderVAE |
|
|
|
|
|
class SingleFolderDataset(Dataset):
    """Dataset over the regular files that sit directly inside one directory.

    Every regular file in ``directory`` is treated as an image; sub-directories
    are skipped and nothing is searched recursively. Each item is an
    ``(image, label)`` pair whose label is always ``torch.tensor(0)``.
    """

    def __init__(self, directory, transform=None):
        """Index the files found in ``directory``.

        Args:
            directory: Folder whose immediate regular files are the images.
            transform: Optional callable applied to each RGB PIL image.
        """
        super().__init__()
        self.directory = directory
        self.transform = transform
        # os.listdir returns entries in arbitrary order. Sort the names so the
        # index -> file mapping is identical in every process and every run.
        candidates = (
            os.path.join(directory, file_name)
            for file_name in sorted(os.listdir(directory))
        )
        self.image_paths = [path for path in candidates if os.path.isfile(path)]

    def __len__(self):
        """Return the number of indexed files."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, torch.tensor(0))`` for the file at position ``idx``.

        The file is opened with PIL, converted to RGB, and passed through
        ``self.transform`` when one was given.
        """
        # convert() produces an independent image, so the source file can be
        # closed as soon as the conversion is done.
        with Image.open(self.image_paths[idx]) as raw:
            image = raw.convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image, torch.tensor(0)
|
|
|
|
|
def create_npz_from_sample_folder(sample_dir, num=50_000):
    """
    Pack the first ``num`` numbered .png files of a folder into one .npz file.

    Files ``{sample_dir}/000000.png`` onwards are read in index order, shuffled
    with the global ``random`` state, stacked into a single uint8 array and
    saved under the key ``arr_0`` at ``{sample_dir}.npz``.

    Returns the path of the written .npz file.
    """
    images = []
    progress = tqdm(range(num), desc="Building .npz file from samples")
    for idx in progress:
        picture = Image.open(f"{sample_dir}/{idx:06d}.png")
        images.append(np.asarray(picture).astype(np.uint8))

    random.shuffle(images)
    stacked = np.stack(images)
    # Every sample must be an H x W x 3 image.
    expected_shape = (num, stacked.shape[1], stacked.shape[2], 3)
    assert stacked.shape == expected_shape

    npz_path = f"{sample_dir}.npz"
    np.savez(npz_path, arr_0=stacked)
    print(f"Saved .npz file to {npz_path} [shape={stacked.shape}].")
    return npz_path
|
|
|
|
|
def center_crop_arr(pil_image, image_size):
    """
    Resize a PIL image and cut out its central ``image_size`` x ``image_size`` patch.

    Follows the center-cropping recipe from ADM:
    https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
    """
    # Halve with a box filter for as long as the shorter side is at least
    # twice the target size.
    while min(*pil_image.size) >= 2 * image_size:
        halved = tuple(side // 2 for side in pil_image.size)
        pil_image = pil_image.resize(halved, resample=Image.BOX)

    # Bicubic resize that scales the shorter side to image_size.
    ratio = image_size / min(*pil_image.size)
    target = tuple(round(side * ratio) for side in pil_image.size)
    pil_image = pil_image.resize(target, resample=Image.BICUBIC)

    pixels = np.array(pil_image)
    top = (pixels.shape[0] - image_size) // 2
    left = (pixels.shape[1] - image_size) // 2
    return Image.fromarray(pixels[top: top + image_size, left: left + image_size])
|
|
|
|
|
def main(args):
    """
    Reconstruct a dataset through the consistency-decoder VAE under DDP and score it.

    Each rank encodes and decodes its shard of the dataset, writes the
    reconstructions as numbered .png files, and computes per-image PSNR/SSIM
    against the center-cropped inputs. Rank 0 then averages the gathered
    metrics, prints them, writes them to ``{sample_folder_dir}_results.txt``
    and packs the .png files into a .npz archive.

    Args:
        args: Parsed command-line namespace. Reads ``global_seed``, ``dataset``,
            ``image_size``, ``sample_dir``, ``data_path``,
            ``per_proc_batch_size`` and ``num_workers``.

    Raises:
        AssertionError: If CUDA is not available.
        Exception: If ``args.dataset`` is neither ``'imagenet'`` nor ``'coco'``.
    """
    assert torch.cuda.is_available(), "Sampling with DDP requires at least one GPU. sample.py supports CPU-only usage"
    torch.set_grad_enabled(False)

    # Set up the process group; every rank derives its own seed.
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    device = rank % torch.cuda.device_count()
    seed = args.global_seed * world_size + rank
    torch.manual_seed(seed)
    torch.cuda.set_device(device)
    print(f"Starting rank={rank}, seed={seed}, world_size={world_size}.")

    # Load the VAE in half precision on this rank's GPU.
    vae = ConsistencyDecoderVAE.from_pretrained(
        "openai/consistency-decoder", torch_dtype=torch.float16
    ).to("cuda:{}".format(device))

    # Rank 0 creates the output folder; the barrier keeps the other ranks
    # from writing before it exists.
    folder_name = f"openai-consistencydecoder-{args.dataset}-size-{args.image_size}-seed-{args.global_seed}"
    sample_folder_dir = f"{args.sample_dir}/{folder_name}"
    if rank == 0:
        os.makedirs(sample_folder_dir, exist_ok=True)
        print(f"Saving .png samples at {sample_folder_dir}")
    dist.barrier()

    # Center-crop to the target size, then map pixel values to [-1, 1].
    transform = transforms.Compose([
        transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, args.image_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
    ])
    if args.dataset == 'imagenet':
        dataset = ImageFolder(args.data_path, transform=transform)
        num_fid_samples = 50000
    elif args.dataset == 'coco':
        dataset = SingleFolderDataset(args.data_path, transform=transform)
        num_fid_samples = 5000
    else:
        raise Exception("please check dataset")
    num_images = len(dataset)

    sampler = DistributedSampler(
        dataset,
        num_replicas=world_size,
        rank=rank,
        shuffle=False,
        seed=args.global_seed
    )
    loader = DataLoader(
        dataset,
        batch_size=args.per_proc_batch_size,
        shuffle=False,
        sampler=sampler,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=False
    )

    global_batch_size = args.per_proc_batch_size * world_size
    psnr_val_rgb = []
    ssim_val_rgb = []

    loader = tqdm(loader) if rank == 0 else loader
    total = 0
    for x, _ in loader:
        # Ground truth as HWC float arrays in [0, 1], taken before the batch
        # is cast to half precision.
        rgb_gts = (x.permute(0, 2, 3, 1).to("cpu").numpy() + 1.0) / 2.0
        x = x.half().to("cuda:{}".format(device))
        with torch.no_grad():
            # Encode to the scaled latent space, then undo the scale and decode.
            latent = vae.encode(x).latent_dist.sample().mul_(0.18215)
            samples = vae.decode(latent / 0.18215).sample
            samples = torch.clamp(127.5 * samples + 128.0, 0, 255).permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy()

        for i, (sample, rgb_gt) in enumerate(zip(samples, rgb_gts)):
            # With shuffle=False the sampler hands rank r the dataset items
            # r, r + world_size, ...; this recovers the global dataset index.
            index = i * world_size + rank + total
            if index >= num_images:
                # DistributedSampler pads the index list with repeated samples
                # so every rank gets the same count. Skip those duplicates so
                # they are neither written to disk nor counted in the metrics.
                continue
            Image.fromarray(sample).save(f"{sample_folder_dir}/{index:06d}.png")

            rgb_restored = sample.astype(np.float32) / 255.
            psnr = psnr_loss(rgb_restored, rgb_gt)
            # NOTE(review): both arrays are scaled to [0, 1] here, yet
            # data_range=2.0 is passed. Kept unchanged so reported SSIM stays
            # comparable with earlier runs -- confirm this is intended.
            # multichannel is presumably passed alongside channel_axis for
            # older scikit-image releases -- verify before removing.
            ssim = ssim_loss(rgb_restored, rgb_gt, multichannel=True, data_range=2.0, channel_axis=-1)
            psnr_val_rgb.append(psnr)
            ssim_val_rgb.append(ssim)
        total += global_batch_size

    # Wait until every rank has written its images, then collect the
    # per-image metrics from all ranks.
    dist.barrier()
    gather_psnr_val = [None for _ in range(world_size)]
    gather_ssim_val = [None for _ in range(world_size)]
    dist.all_gather_object(gather_psnr_val, psnr_val_rgb)
    dist.all_gather_object(gather_ssim_val, ssim_val_rgb)

    if rank == 0:
        gather_psnr_val = list(itertools.chain(*gather_psnr_val))
        gather_ssim_val = list(itertools.chain(*gather_ssim_val))
        mean_psnr = sum(gather_psnr_val) / len(gather_psnr_val)
        mean_ssim = sum(gather_ssim_val) / len(gather_ssim_val)
        print("PSNR: %f, SSIM: %f " % (mean_psnr, mean_ssim))

        result_file = f"{sample_folder_dir}_results.txt"
        print("writing results to {}".format(result_file))
        with open(result_file, 'w') as f:
            print("PSNR: %f, SSIM: %f " % (mean_psnr, mean_ssim), file=f)

        # Only indices below num_images exist on disk, so never ask for more
        # files than the dataset provided.
        create_npz_from_sample_folder(sample_folder_dir, min(num_fid_samples, num_images))
        print("Done.")

    dist.barrier()
    dist.destroy_process_group()
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: register the options from a table, parse
    # them, and hand the resulting namespace to main().
    parser = argparse.ArgumentParser()
    option_specs = (
        ("--data-path", dict(type=str, required=True)),
        ("--dataset", dict(type=str, choices=['imagenet', 'coco'], default='imagenet')),
        ("--image-size", dict(type=int, choices=[256, 512], default=256)),
        ("--sample-dir", dict(type=str, default="reconstructions")),
        ("--per-proc-batch-size", dict(type=int, default=32)),
        ("--global-seed", dict(type=int, default=0)),
        ("--num-workers", dict(type=int, default=4)),
    )
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)
    args = parser.parse_args()
    main(args)