Spaces:
				
			
			
	
			
			
		Configuration error
		
	
	
	
			
			
	
	
	
	
		
		
		Configuration error
		
	| import importlib | |
| import numpy as np | |
| import taming | |
| import torch | |
| import yaml | |
| from omegaconf import OmegaConf | |
| from PIL import Image | |
| from taming.models.vqgan import VQModel | |
| from utils import get_device | |
| # import discriminator | |
def load_config(config_path, display=False):
    """Read a configuration file with OmegaConf.

    Args:
        config_path: path handed straight to ``OmegaConf.load``.
        display: when true, also print the configuration rendered as YAML.

    Returns:
        The object produced by ``OmegaConf.load``.
    """
    cfg = OmegaConf.load(config_path)
    if not display:
        return cfg
    # Convert to plain Python containers first so ``yaml.dump`` can render it.
    plain = OmegaConf.to_container(cfg)
    print(yaml.dump(plain))
    return cfg
| # def load_disc(device): | |
| # dconf = load_config("disc_config.yaml") | |
| # sd = torch.load("disc.pt", map_location=device) | |
| # # print(sd.keys()) | |
| # model = discriminator.NLayerDiscriminator() | |
| # model.load_state_dict(sd, strict=True) | |
| # model.to(device) | |
| # return model | |
| # print(dconf.keys()) | |
def load_default(device):
    """Build the default VQGAN and load the ``./vqgan_only.pt`` weights.

    The model architecture comes from ``config.model.params`` in
    ``./unwrapped.yaml``. The weights are loaded with ``strict=True``, so
    any missing or unexpected key raises. Both paths are relative to the
    current working directory.

    Args:
        device: passed to ``torch.load(map_location=...)`` and ``model.to``.

    Returns:
        The ``VQModel`` with weights loaded and moved to ``device``.
        ``eval()`` is NOT called here.
    """
    config = load_config("./unwrapped.yaml", display=False)
    model = VQModel(**config.model.params)
    # NOTE: torch.load unpickles the file; only use trusted checkpoints.
    state_dict = torch.load("./vqgan_only.pt", map_location=device)
    model.load_state_dict(state_dict, strict=True)
    model.to(device)
    return model
def load_vqgan(config, ckpt_path=None, is_gumbel=False):
    """Build a VQGAN from ``config.model.params`` and optionally load weights.

    Args:
        config: config object exposing ``config.model.params``, a mapping of
            constructor keyword arguments (e.g. the result of ``load_config``).
        ckpt_path: optional checkpoint path. The file must contain a
            ``"state_dict"`` entry.
        is_gumbel: build a ``GumbelVQ`` instead of a plain ``VQModel``.

    Returns:
        The model, switched to eval mode (``model.eval()`` returns the model).
    """
    if is_gumbel:
        # GumbelVQ is not imported at the top of this file; importing it here
        # keeps the gumbel branch from failing with NameError.
        from taming.models.vqgan import GumbelVQ

        model = GumbelVQ(**config.model.params)
    else:
        model = VQModel(**config.model.params)
    if ckpt_path is not None:
        # NOTE: torch.load unpickles the file; only use trusted checkpoints.
        sd = torch.load(ckpt_path, map_location="cpu")["state_dict"]
        # strict=False: keys that exist in only one of checkpoint/model are
        # ignored rather than raising (deliberate best-effort load).
        model.load_state_dict(sd, strict=False)
    return model.eval()
def load_ffhq():
    """Load the FacesHQ VQGAN run on GPU, in eval mode, and return the model.

    Previously the model was built but never returned (the function returned
    ``None``). Paths are relative to the current working directory.
    """
    run_dir = "2020-11-09T13-33-36_faceshq_vqgan"
    conf = f"{run_dir}/configs/2020-11-09T13-33-36-project.yaml"
    ckpt = f"{run_dir}/checkpoints/last.ckpt"
    # load_model returns (model, global_step); only the model is needed here.
    vqgan, _global_step = load_model(load_config(conf), ckpt, gpu=True, eval_mode=True)
    return vqgan
def reconstruct_with_vqgan(x, model):
    """Encode ``x`` with ``model`` and decode the quantized latent back.

    Encoding and decoding are done explicitly (rather than ``model(x)``) so
    the latent's trailing dimensions can be printed on the way through.
    ``model.encode`` must return a 3-item sequence whose last item is itself
    a 3-item sequence; only the first item (the latent) is used.
    """
    quant, _, (_, _, _) = model.encode(x)
    print(f"VQGAN --- {model.__class__.__name__}: latent shape: {quant.shape[2:]}")
    return model.decode(quant)
def get_obj_from_str(string, reload=False):
    """Resolve a dotted path like ``"package.module.Name"`` to that object.

    Args:
        string: dotted path; everything before the last dot is the module.
        reload: if true, re-import the module with ``importlib.reload``
            before looking up the attribute.

    Raises:
        ValueError: if ``string`` contains no dot.
    """
    module_path, attr_name = string.rsplit(".", 1)
    if reload:
        importlib.reload(importlib.import_module(module_path))
    owner = importlib.import_module(module_path, package=None)
    return getattr(owner, attr_name)
def instantiate_from_config(config):
    """Construct ``config["target"]`` with keyword arguments ``config["params"]``.

    ``target`` is a dotted import path resolved by ``get_obj_from_str``.
    ``params`` is optional and defaults to no arguments.

    Raises:
        KeyError: if ``config`` has no ``target`` key.
    """
    if "target" not in config:
        raise KeyError("Expected key `target` to instantiate.")
    factory = get_obj_from_str(config["target"])
    kwargs = config.get("params", {})
    return factory(**kwargs)
def load_model_from_config(config, sd, gpu=True, eval_mode=True):
    """Instantiate a model from ``config`` and prepare it for use.

    Args:
        config: mapping with ``target``/``params`` (see ``instantiate_from_config``).
        sd: state dict to load (default strict mode), or ``None`` to skip loading.
        gpu: move the model to CUDA with ``.cuda()``.
        eval_mode: switch the model to eval mode.

    Returns:
        ``{"model": model}``.
    """
    net = instantiate_from_config(config)
    if sd is not None:
        net.load_state_dict(sd)
    if gpu:
        net.cuda()
    if eval_mode:
        net.eval()
    return dict(model=net)
def load_model(config, ckpt, gpu, eval_mode):
    """Build ``config.model`` and, if ``ckpt`` is truthy, load its weights.

    The checkpoint is read onto the CPU and must contain ``"global_step"``
    and ``"state_dict"`` entries. The global step is printed once loaded.

    Returns:
        ``(model, global_step)``, where ``global_step`` is ``None`` when no
        checkpoint was given.
    """
    if not ckpt:
        checkpoint = {"state_dict": None}
        global_step = None
    else:
        # NOTE: torch.load unpickles the file; only use trusted checkpoints.
        checkpoint = torch.load(ckpt, map_location="cpu")
        global_step = checkpoint["global_step"]
        print(f"loaded model from global step {global_step}.")
    built = load_model_from_config(
        config.model,
        checkpoint["state_dict"],
        gpu=gpu,
        eval_mode=eval_mode,
    )
    return built["model"], global_step