# ------------------------------------------------------------------------------
# Copyright (c) Microsoft
# Licensed under the MIT License.
# Written by Zigang Geng ([email protected])
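#
# Merges the EMA weights from an InstructDiffusion training state with the
# remaining weights of a base Stable Diffusion v1.5 checkpoint, producing a
# single inference-ready checkpoint.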
# ------------------------------------------------------------------------------
from __future__ import annotations

import sys
from argparse import ArgumentParser

import torch
from omegaconf import OmegaConf

# Make `ldm` importable as a top-level package: the config's `target` entries
# reference `ldm.*` module paths.
sys.path.append("./stable_diffusion")
from stable_diffusion.ldm.util import instantiate_from_config


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("--config", default="configs/instruct_diffusion.yaml", type=str)
    parser.add_argument("--ema-ckpt", default="logs/instruct_diffusion/checkpoints/ckpt_epoch_200/state.pth", type=str)
    parser.add_argument("--vae-ckpt", default="stable_diffusion/models/ldm/stable-diffusion-v1/v1-5-pruned-emaonly.ckpt", type=str)
    parser.add_argument("--out-ckpt", default="checkpoints/v1-5-pruned-emaonly-adaption-task.ckpt", type=str)
    args = parser.parse_args()
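
    # Example invocation (paths are the argparse defaults above; the filename
    # `convert_ckpt.py` is a placeholder for wherever this script is saved):
    #   python convert_ckpt.py --config configs/instruct_diffusion.yaml \
    #       --ema-ckpt logs/instruct_diffusion/checkpoints/ckpt_epoch_200/state.pth \
    #       --out-ckpt checkpoints/v1-5-pruned-emaonly-adaption-task.ckpt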

    # Build the model from the training config so its dotted parameter names
    # are available for matching against the EMA keys.
    config = OmegaConf.load(args.config)
    model = instantiate_from_config(config.model)

    # The EMA wrapper (LitEma in ldm) stores its shadow weights under the
    # model's parameter names with the dots stripped, since PyTorch forbids '.'
    # in buffer names; e.g. 'model.diffusion_model.input_blocks.0.0.weight'
    # becomes 'modeldiffusion_modelinput_blocks00weight'. Rebuild that mapping
    # to restore the dotted names.
    ema_ckpt = torch.load(args.ema_ckpt, map_location="cpu")
    all_keys = [key for key, _ in model.named_parameters()]
    all_keys_rmv = [key.replace('.', '') for key in all_keys]

    new_ema_ckpt = {}
    for k, v in ema_ckpt['model_ema'].items():
        try:
            k_index = all_keys_rmv.index(k)
            new_ema_ckpt[all_keys[k_index]] = v
        except ValueError:
            # Non-parameter EMA entries (e.g. the 'decay' and 'num_updates'
            # buffers) have no dotted counterpart and are skipped.
            print(k + ' is not in the list.')

    # Any parameters the EMA dict does not cover (e.g. the frozen first-stage
    # VAE weights) are filled in from the base Stable Diffusion checkpoint.
    vae_ckpt = torch.load(args.vae_ckpt, map_location="cpu")
    for k, v in vae_ckpt['state_dict'].items():
        if k not in new_ema_ckpt and k in all_keys:
            new_ema_ckpt[k] = v

    # Write the merged weights in the usual {'state_dict': ...} checkpoint layout.
    checkpoint = {'state_dict': new_ema_ckpt}
    with open(args.out_ckpt, 'wb') as f:
        torch.save(checkpoint, f)
        f.flush()
    print('Converted successfully, the new checkpoint has been saved to ' + args.out_ckpt)
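
    # Optional sanity check (a minimal sketch; safe to remove): load the merged
    # weights back into the freshly built model and report anything uncovered.
    missing, unexpected = model.load_state_dict(new_ema_ckpt, strict=False)
    print(f'Reload check: {len(missing)} missing / {len(unexpected)} unexpected keys.')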