|
import argparse, os, sys, glob, math, time |
|
import torch |
|
import numpy as np |
|
from omegaconf import OmegaConf |
|
from PIL import Image |
|
from main import instantiate_from_config, DataModuleFromConfig |
|
from torch.utils.data import DataLoader |
|
from torch.utils.data.dataloader import default_collate |
|
from tqdm import trange |
|
|
|
|
|
def save_image(x, path):
    """Save a 3-channel CHW tensor as an 8-bit image file.

    Pixel values are mapped with ``(x + 1) * 127.5`` (so -1 -> 0 and 1 -> 255),
    clipped to [0, 255] and truncated to ``uint8``. Values are therefore
    expected to lie in [-1, 1].

    Args:
        x: tensor of shape (3, H, W); it is detached and copied to the CPU.
        path: destination file; the image format follows PIL's handling of
            the file extension.

    Raises:
        ValueError: if ``x`` is not a 3-D tensor with exactly 3 channels.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if x.ndim != 3 or x.shape[0] != 3:
        raise ValueError(f"expected a (3, H, W) tensor, got shape {tuple(x.shape)}")
    arr = x.detach().cpu().numpy().transpose(1, 2, 0)  # CHW -> HWC for PIL
    arr = ((arr + 1.0) * 127.5).clip(0, 255).astype(np.uint8)
    Image.fromarray(arr).save(path)
|
|
|
|
|
def _select_dataset(dsets):
    """Pick the dataset to sample from: the alphabetically-first split if there are several."""
    if len(dsets.datasets) > 1:
        split = sorted(dsets.datasets.keys())[0]
        return dsets.datasets[split]
    return next(iter(dsets.datasets.values()))


def _save_batch(images, outdir, subdir, indices):
    """Save images[k] as <outdir>/<subdir>/<indices[k]:06>.png for every k."""
    for image, index in zip(images, indices):
        save_image(image, os.path.join(outdir, subdir, "{:06}.png".format(index)))


def _window_offset(pos, size, window=16):
    """Return the offset of grid position `pos` inside a `window`-wide crop.

    The crop puts `pos` at offset ``window // 2``, except near the edges, where
    it is shifted so that it stays inside [0, size) (assuming size >= window).
    """
    half = window // 2
    if pos <= half:
        return pos
    if size - pos < half:
        return window - (size - pos)
    return half


@torch.no_grad()
def run_conditional(model, dsets, outdir, top_k, temperature, batch_size=1,
                    sample=True, half_sample=False):
    """Encode each dataset item, then regenerate its image tokens with the transformer.

    For every item, three PNGs named by dataset index are written:
      * ``<outdir>/originals/``       -- the input image,
      * ``<outdir>/reconstructions/`` -- the first-stage decode of its codes,
      * ``<outdir>/samples/``         -- the decode of the regenerated codes.
    The subdirectories are not created here and must already exist.

    Tokens are generated in raster order. Each one is predicted from a 16x16
    window of image tokens (shifted to stay inside the grid), preceded by the
    conditioning tokens of the same window.
    NOTE(review): assumes the latent grid is at least 16x16 and that the
    conditioning grid has the same spatial size; otherwise the window
    reshape fails.

    Args:
        model: conditional transformer model (uses get_input, encode_to_z,
            encode_to_c, first_stage_model, transformer, top_k_logits,
            decode_to_img and cond_stage_key).
        dsets: object whose ``datasets`` dict maps split name -> dataset.
        outdir: root output directory.
        top_k: if not None, logits are filtered with ``model.top_k_logits``.
        temperature: logits are divided by this before the softmax.
        batch_size: items processed together; the last batch may be smaller.
        sample: draw tokens from the distribution (True) or take the most
            likely token (False).
        half_sample: keep the first half of the original tokens (raster order)
            and regenerate only the rest.
    """
    window = 16  # spatial size of the attention window, in latent tokens

    dset = _select_dataset(dsets)
    print("Dataset: ", dset.__class__.__name__)

    num_items = len(dset)
    # Step over the whole dataset; the final batch is allowed to be partial.
    for start_idx in trange(0, num_items, batch_size):
        indices = list(range(start_idx, min(start_idx + batch_size, num_items)))
        example = default_collate([dset[i] for i in indices])

        x = model.get_input("image", example).to(model.device)
        _save_batch(x, outdir, "originals", indices)

        cond_key = model.cond_stage_key
        c = model.get_input(cond_key, example).to(model.device)

        quant_z, z_indices = model.encode_to_z(x)
        quant_c, c_indices = model.encode_to_c(c)
        cshape = quant_z.shape

        xrec = model.first_stage_model.decode(quant_z)
        _save_batch(xrec, outdir, "reconstructions", indices)

        if cond_key == "segmentation":
            # Turn the mask into a one-hot map and render it as RGB.
            # NOTE(review): the RGB result is never saved or used below; it is
            # kept in case `to_rgb` has side effects on the cond-stage model --
            # confirm before removing.
            num_classes = c.shape[1]
            c = torch.argmax(c, dim=1, keepdim=True)
            c = torch.nn.functional.one_hot(c, num_classes=num_classes)
            c = c.squeeze(1).permute(0, 3, 1, 2).float()
            c = model.cond_stage_model.to_rgb(c)

        # Tokens from flat position `start` onward are zeroed and regenerated.
        idx = z_indices
        start = idx.shape[1] // 2 if half_sample else 0
        idx[:, start:] = 0
        idx = idx.reshape(cshape[0], cshape[2], cshape[3])
        start_i, start_j = divmod(start, cshape[3])

        cidx = c_indices.reshape(quant_c.shape[0], quant_c.shape[2], quant_c.shape[3])

        for i in range(start_i, cshape[2]):
            local_i = _window_offset(i, cshape[2], window)
            i_start = i - local_i
            i_end = i_start + window
            # Only the first regenerated row starts mid-row; later rows start at 0.
            for j in range(start_j if i == start_i else 0, cshape[3]):
                local_j = _window_offset(j, cshape[3], window)
                j_start = j - local_j
                j_end = j_start + window

                patch = idx[:, i_start:i_end, j_start:j_end]
                patch = patch.reshape(patch.shape[0], -1)
                cpatch = cidx[:, i_start:i_end, j_start:j_end]
                cpatch = cpatch.reshape(cpatch.shape[0], -1)
                # Conditioning tokens first, then the image tokens of the window.
                patch = torch.cat((cpatch, patch), dim=1)
                logits, _ = model.transformer(patch[:, :-1])
                # Keep only the predictions for the image-token part of the window.
                logits = logits[:, -window * window:, :]
                logits = logits.reshape(cshape[0], window, window, -1)
                logits = logits[:, local_i, local_j, :] / temperature

                if top_k is not None:
                    logits = model.top_k_logits(logits, top_k)
                probs = torch.nn.functional.softmax(logits, dim=-1)

                if sample:
                    ix = torch.multinomial(probs, num_samples=1)
                else:
                    _, ix = torch.topk(probs, k=1, dim=-1)
                # ix is (B, 1); drop the trailing dim so it fits idx[:, i, j] (B,)
                # for any batch size, not just B == 1.
                idx[:, i, j] = ix[:, 0]

        xsample = model.decode_to_img(idx, cshape)
        _save_batch(xsample, outdir, "samples", indices)
|
|
|
|
|
def get_parser():
    """Return the ArgumentParser for this sampling script.

    Options, in declaration order: -r/--resume (logdir or checkpoint),
    -b/--base (list of base configs), -c/--config (single config; passing the
    bare flag stores True), --ignore_base_data, --outdir (required),
    --top_k (default 100) and --temperature (default 1.0).
    """
    option_specs = [
        (("-r", "--resume"), {
            "type": str,
            "nargs": "?",
            "help": "load from logdir or checkpoint in logdir",
        }),
        (("-b", "--base"), {
            "nargs": "*",
            "metavar": "base_config.yaml",
            "help": ("paths to base configs. Loaded from left-to-right. "
                     "Parameters can be overwritten or added with command-line "
                     "options of the form `--key value`."),
            "default": [],
        }),
        (("-c", "--config"), {
            "nargs": "?",
            "metavar": "single_config.yaml",
            "help": ("path to single config. If specified, base configs will be ignored "
                     "(except for the last one if left unspecified)."),
            "const": True,
            "default": "",
        }),
        (("--ignore_base_data",), {
            "action": "store_true",
            "help": ("Ignore data specification from base configs. Useful if you want "
                     "to specify a custom datasets on the command line."),
        }),
        (("--outdir",), {
            "required": True,
            "type": str,
            "help": "Where to write outputs to.",
        }),
        (("--top_k",), {
            "type": int,
            "default": 100,
            "help": "Sample from among top-k predictions.",
        }),
        (("--temperature",), {
            "type": float,
            "default": 1.0,
            "help": "Sampling temperature.",
        }),
    ]

    parser = argparse.ArgumentParser()
    for flags, kwargs in option_specs:
        parser.add_argument(*flags, **kwargs)
    return parser
|
|
|
|
|
def load_model_from_config(config, sd, gpu=True, eval_mode=True):
    """Instantiate a model from `config`, optionally load weights, and place it.

    Before instantiation, `config` is edited in place so that the model does
    not restore weights on its own:
      * ``params.ckpt_path`` is set to None, if present;
      * ``params.downsample_cond_size`` (if present) is set to -1 and
        ``params.downsample_cond_factor`` is set to 0.5;
      * ``ckpt_path`` in the first-stage and cond-stage sub-configs is set to
        None. Each stage is handled independently, and a stage whose
        sub-config (or its params) is absent is skipped.

    Args:
        config: model config node; it is modified in place.
        sd: state dict loaded with ``strict=False`` (missing/unexpected keys
            are printed), or None to skip loading weights here.
        gpu: if True, move the model to CUDA with ``model.cuda()``.
        eval_mode: if True, call ``model.eval()``.

    Returns:
        dict with the single key ``"model"``.
    """
    if "ckpt_path" in config.params:
        print("Deleting the restore-ckpt path from the config...")
        config.params.ckpt_path = None
    if "downsample_cond_size" in config.params:
        print("Deleting downsample-cond-size from the config and setting factor=0.5 instead...")
        config.params.downsample_cond_size = -1
        config.params["downsample_cond_factor"] = 0.5

    # Handle each stage separately, so that a missing first-stage config no
    # longer prevents the cond-stage path from being cleared. The exceptions
    # are narrow (instead of a bare `except:`) and only cover a missing
    # sub-config or missing params.
    for stage_key, label in (("first_stage_config", "first-stage"),
                             ("cond_stage_config", "cond-stage")):
        try:
            stage_params = getattr(config.params, stage_key).params
            has_ckpt = "ckpt_path" in stage_params
        except (AttributeError, KeyError, TypeError):
            continue
        if has_ckpt:
            stage_params.ckpt_path = None
            print(f"Deleting the {label} restore-ckpt path from the config...")

    model = instantiate_from_config(config)
    if sd is not None:
        missing, unexpected = model.load_state_dict(sd, strict=False)
        print(f"Missing Keys in State Dict: {missing}")
        print(f"Unexpected Keys in State Dict: {unexpected}")
    if gpu:
        model.cuda()
    if eval_mode:
        model.eval()
    return {"model": model}
|
|
|
|
|
def get_data(config):
    """Build the data module described by ``config.data`` and get it ready.

    ``prepare_data()`` is called first, then ``setup()``; the prepared module
    is returned.
    """
    datamodule = instantiate_from_config(config.data)
    for hook_name in ("prepare_data", "setup"):
        getattr(datamodule, hook_name)()
    return datamodule
|
|
|
|
|
def load_model_and_dset(config, ckpt, gpu, eval_mode):
    """Set up the data from `config` and build the model from ``config.model``.

    When `ckpt` is truthy, the file is read with ``torch.load`` onto the CPU,
    and its ``"state_dict"`` and ``"global_step"`` entries are used. Otherwise
    the model is built without loading weights and the step is None.

    Returns:
        tuple ``(dsets, model, global_step)``.
    """
    dsets = get_data(config)

    if not ckpt:
        checkpoint = {"state_dict": None}
        global_step = None
    else:
        # NOTE(review): torch.load unpickles the file -- only use trusted
        # checkpoints. Recent torch releases default to weights_only=True,
        # which may reject full Lightning checkpoints; confirm against the
        # installed torch version.
        checkpoint = torch.load(ckpt, map_location="cpu")
        global_step = checkpoint["global_step"]

    built = load_model_from_config(
        config.model,
        checkpoint["state_dict"],
        gpu=gpu,
        eval_mode=eval_mode,
    )
    return dsets, built["model"], global_step
|
|
|
|
|
if __name__ == "__main__":
    # NOTE(review): presumably lets config targets resolve modules from the
    # current working directory; note it runs after the top-level imports.
    sys.path.append(os.getcwd())

    parser = get_parser()
    # Unrecognised arguments are handed to OmegaConf.from_dotlist below as
    # config overrides.
    opt, unknown = parser.parse_known_args()

    ckpt = None
    if opt.resume:
        if not os.path.exists(opt.resume):
            raise ValueError("Cannot find {}".format(opt.resume))
        if os.path.isfile(opt.resume):
            # Checkpoint file: logdir is ".../logs/<run>" when the path contains
            # a "logs" component; otherwise drop the last two path components.
            paths = opt.resume.split("/")
            try:
                idx = len(paths) - paths[::-1].index("logs") + 1
            except ValueError:
                idx = -2
            logdir = "/".join(paths[:idx])
            ckpt = opt.resume
        elif os.path.isdir(opt.resume):
            logdir = opt.resume.rstrip("/")
            ckpt = os.path.join(logdir, "checkpoints", "last.ckpt")
        else:
            raise ValueError("Not a file or directory: {}".format(opt.resume))
        print(f"logdir:{logdir}")
        base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*-project.yaml")))
        opt.base = base_configs + opt.base

    if opt.config:
        if isinstance(opt.config, str):
            opt.base = [opt.config]
        else:
            # Bare `-c`: keep only the last base config.
            if not opt.base:
                parser.error("-c/--config without a path requires at least one base config")
            opt.base = [opt.base[-1]]

    configs = [OmegaConf.load(cfg) for cfg in opt.base]
    cli = OmegaConf.from_dotlist(unknown)
    if opt.ignore_base_data:
        for config in configs:
            if hasattr(config, "data"):
                del config["data"]
    config = OmegaConf.merge(*configs, cli)

    print(ckpt)
    gpu = True
    eval_mode = True
    show_config = False
    if show_config:
        print(OmegaConf.to_container(config))

    dsets, model, global_step = load_model_and_dset(config, ckpt, gpu, eval_mode)
    print(f"Global step: {global_step}")

    # global_step is None when no checkpoint was loaded; "{:06}".format(None)
    # would raise TypeError, so use a placeholder tag instead.
    step_tag = "{:06}".format(global_step) if global_step is not None else "nostep"
    outdir = os.path.join(opt.outdir, "{}_{}_{}".format(step_tag,
                                                         opt.top_k,
                                                         opt.temperature))
    os.makedirs(outdir, exist_ok=True)
    print("Writing samples to ", outdir)
    for k in ["originals", "reconstructions", "samples"]:
        os.makedirs(os.path.join(outdir, k), exist_ok=True)
    run_conditional(model, dsets, outdir, opt.top_k, opt.temperature)
|
|