aufa.husen committed
Commit 5cceaab · 1 Parent(s): fa4b84d

add orangemix vae

main.py ADDED
@@ -0,0 +1,161 @@
+ import argparse
+ import io
+
+ import requests
+ import torch
+ from omegaconf import OmegaConf
+
+ from diffusers import AutoencoderKL
+
+
+ from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
+     assign_to_checkpoint,
+     conv_attn_to_linear,
+     create_vae_diffusers_config,
+     renew_vae_attention_paths,
+     renew_vae_resnet_paths,
+ )
+
+
+ def custom_convert_ldm_vae_checkpoint(checkpoint, config):
+     vae_state_dict = checkpoint
+
+     new_checkpoint = {}
+
+     new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
+     new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
+     new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
+     new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
+     new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
+     new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]
+
+     new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
+     new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
+     new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
+     new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
+     new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
+     new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]
+
+     new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
+     new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
+     new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
+     new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]
+
+     # Retrieves the keys for the encoder down blocks only
+     num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
+     down_blocks = {
+         layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
+     }
+
+     # Retrieves the keys for the decoder up blocks only
+     num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
+     up_blocks = {
+         layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
+     }
+
+     for i in range(num_down_blocks):
+         resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
+
+         if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
+             new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
+                 f"encoder.down.{i}.downsample.conv.weight"
+             )
+             new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
+                 f"encoder.down.{i}.downsample.conv.bias"
+             )
+
+         paths = renew_vae_resnet_paths(resnets)
+         meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
+         assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
+
+     mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
+     num_mid_res_blocks = 2
+     for i in range(1, num_mid_res_blocks + 1):
+         resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
+
+         paths = renew_vae_resnet_paths(resnets)
+         meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
+         assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
+
+     mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
+     paths = renew_vae_attention_paths(mid_attentions)
+     meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
+     assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
+     conv_attn_to_linear(new_checkpoint)
+
+     for i in range(num_up_blocks):
+         block_id = num_up_blocks - 1 - i
+         resnets = [
+             key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
+         ]
+
+         if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
+             new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
+                 f"decoder.up.{block_id}.upsample.conv.weight"
+             ]
+             new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
+                 f"decoder.up.{block_id}.upsample.conv.bias"
+             ]
+
+         paths = renew_vae_resnet_paths(resnets)
+         meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
+         assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
+
+     mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
+     num_mid_res_blocks = 2
+     for i in range(1, num_mid_res_blocks + 1):
+         resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
+
+         paths = renew_vae_resnet_paths(resnets)
+         meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
+         assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
+
+     mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
+     paths = renew_vae_attention_paths(mid_attentions)
+     meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
+     assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
+     conv_attn_to_linear(new_checkpoint)
+     return new_checkpoint
+
+
+ def vae_pt_to_vae_diffuser(
+     checkpoint_path: str,
+     output_path: str,
+ ):
+     # Only supports v1 checkpoints
+     r = requests.get(
+         "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
+     )
+     io_obj = io.BytesIO(r.content)
+
+     original_config = OmegaConf.load(io_obj)
+     image_size = 512
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+     if checkpoint_path.endswith("safetensors"):
+         from safetensors import safe_open
+
+         checkpoint = {}
+         with safe_open(checkpoint_path, framework="pt", device="cpu") as f:
+             for key in f.keys():
+                 checkpoint[key] = f.get_tensor(key)
+     else:
+         checkpoint = torch.load(checkpoint_path, map_location=device)["state_dict"]
+
+     # Convert the VAE model.
+     vae_config = create_vae_diffusers_config(original_config, image_size=image_size)
+     converted_vae_checkpoint = custom_convert_ldm_vae_checkpoint(checkpoint, vae_config)
+
+     vae = AutoencoderKL(**vae_config)
+     vae.load_state_dict(converted_vae_checkpoint)
+     vae.save_pretrained(output_path)
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+
+     parser.add_argument("--vae_pt_path", default=None, type=str, required=True, help="Path to the VAE.pt to convert.")
+ parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the VAE.pt to convert.")
+
+     args = parser.parse_args()
+
+     vae_pt_to_vae_diffuser(args.vae_pt_path, args.dump_path)
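For reference, a minimal sketch of how the conversion entry point above might be driven to produce the diffusers-format VAE in this commit; the paths are illustrative assumptions that mirror this repo's layout, not values recorded in the commit:

# Hypothetical usage of the script added above (main.py); paths are assumptions.
from main import vae_pt_to_vae_diffuser

vae_pt_to_vae_diffuser(
    checkpoint_path="orangemix.vae.pt",  # the LDM-format VAE weights added below
    output_path="vae",                   # written as diffusers config.json + weights
)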
orangemix.vae.pt ADDED
File without changes
vae/config.json CHANGED
@@ -14,6 +14,7 @@
    "DownEncoderBlock2D",
    "DownEncoderBlock2D"
  ],
+ "force_upcast": true,
  "in_channels": 3,
  "latent_channels": 4,
  "layers_per_block": 2,
@@ -27,4 +28,4 @@
    "UpDecoderBlock2D",
    "UpDecoderBlock2D"
  ]
- }
+ }
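The "force_upcast" flag added above is an AutoencoderKL config option in diffusers that asks pipelines to run the VAE in float32 even when the rest of the model runs in float16, avoiding overflow artifacts at decode time. A minimal sketch of loading this repo's converted VAE (the local path and dtype are assumptions):

import torch
from diffusers import AutoencoderKL

# Load the converted VAE from the vae/ subfolder of this repo (path assumed).
vae = AutoencoderKL.from_pretrained(".", subfolder="vae", torch_dtype=torch.float16)
# Because vae/config.json sets "force_upcast": true, diffusers pipelines will
# upcast this VAE to float32 before encoding/decoding.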
vae/{diffusion_pytorch_model.safetensors → diffusion_pytorch_model.bin} RENAMED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:5efc2e02187d3e51fdc774715ed45212a742fd612677a3eea4e118c91239ec64
- size 334643268
+ oid sha256:d8e7b68fc05f8ebc0475f23bc1e3ca74f5ae87f1fde8b0504d80228806ca4675
+ size 334712113