Spaces:
Runtime error
Runtime error
| import torch | |
| import torch.nn.functional as F | |
| import os | |
| import argparse | |
| import numpy as np | |
| from PIL import Image | |
| from tokenizer.tokenizer_image.vq_model import VQ_models | |
| from dataset.augmentation import center_crop_arr | |
| def main(args): | |
| # Setup PyTorch: | |
| torch.manual_seed(args.seed) | |
| torch.set_grad_enabled(False) | |
| device = "cuda" if torch.cuda.is_available() else "cpu" | |
| # create and load model | |
| model = VQ_models[args.vq_model]( | |
| codebook_size=args.codebook_size, | |
| codebook_embed_dim=args.codebook_embed_dim) | |
| model.to(device) | |
| model.eval() | |
| checkpoint = torch.load(args.vq_ckpt, map_location="cpu") | |
| if "ema" in checkpoint: # ema | |
| model_weight = checkpoint["ema"] | |
| elif "model" in checkpoint: # ddp | |
| model_weight = checkpoint["model"] | |
| elif "state_dict" in checkpoint: | |
| model_weight = checkpoint["state_dict"] | |
| else: | |
| raise Exception("please check model weight") | |
| model.load_state_dict(model_weight) | |
| del checkpoint | |
| # output dir | |
| os.makedirs(args.output_dir, exist_ok=True) | |
| out_path = args.image_path.replace('.jpg', '_{}.jpg'.format(args.suffix)) | |
| out_path = out_path.replace('.jpeg', '_{}.jpeg'.format(args.suffix)) | |
| out_path = out_path.replace('.png', '_{}.png'.format(args.suffix)) | |
| out_filename = out_path.split('/')[-1] | |
| out_path = os.path.join(args.output_dir, out_filename) | |
| # load image | |
| pil_image = Image.open(args.image_path).convert("RGB") | |
| img = center_crop_arr(pil_image, args.image_size) | |
| # # preprocess | |
| # size_org = img.size | |
| # img = img.resize((input_size, input_size)) | |
| img = np.array(img) / 255. | |
| x = 2.0 * img - 1.0 # x value is between [-1, 1] | |
| x = torch.tensor(x) | |
| x = x.unsqueeze(dim=0) | |
| x = torch.einsum('nhwc->nchw', x) | |
| x_input = x.float().to("cuda") | |
| # inference | |
| with torch.no_grad(): | |
| latent, _, [_, _, indices] = model.encode(x_input) | |
| output = model.decode_code(indices, latent.shape) # output value is between [-1, 1] | |
| # postprocess | |
| output = F.interpolate(output, size=[args.image_size, args.image_size], mode='bicubic').permute(0, 2, 3, 1)[0] | |
| sample = torch.clamp(127.5 * output + 128.0, 0, 255).to("cpu", dtype=torch.uint8).numpy() | |
| # save | |
| Image.fromarray(sample).save(out_path) | |
| print("Reconstructed image is saved to {}".format(out_path)) | |
| if __name__ == "__main__": | |
| parser = argparse.ArgumentParser() | |
| parser.add_argument("--image-path", type=str, default="assets/example.jpg") | |
| parser.add_argument("--output-dir", type=str, default="output_vq_demo") | |
| parser.add_argument("--suffix", type=str, default="tokenizer_image") | |
| parser.add_argument("--vq-model", type=str, choices=list(VQ_models.keys()), default="VQ-16") | |
| parser.add_argument("--vq-ckpt", type=str, default=None, help="ckpt path for vq model") | |
| parser.add_argument("--codebook-size", type=int, default=16384, help="codebook size for vector quantization") | |
| parser.add_argument("--codebook-embed-dim", type=int, default=8, help="codebook dimension for vector quantization") | |
| parser.add_argument("--image-size", type=int, choices=[256, 384, 448, 512, 1024], default=512) | |
| parser.add_argument("--seed", type=int, default=0) | |
| args = parser.parse_args() | |
| main(args) |