aufa.husen commited on
Commit
2e5de56
·
1 Parent(s): 5cceaab

use safetensor

Browse files
main.py DELETED
@@ -1,161 +0,0 @@
1
- import argparse
2
- import io
3
-
4
- import requests
5
- import torch
6
- from omegaconf import OmegaConf
7
-
8
- from diffusers import AutoencoderKL
9
-
10
-
11
- from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
12
- assign_to_checkpoint,
13
- conv_attn_to_linear,
14
- create_vae_diffusers_config,
15
- renew_vae_attention_paths,
16
- renew_vae_resnet_paths,
17
- )
18
-
19
-
20
- def custom_convert_ldm_vae_checkpoint(checkpoint, config):
21
- vae_state_dict = checkpoint
22
-
23
- new_checkpoint = {}
24
-
25
- new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
26
- new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
27
- new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
28
- new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
29
- new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
30
- new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]
31
-
32
- new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
33
- new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
34
- new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
35
- new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
36
- new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
37
- new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]
38
-
39
- new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
40
- new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
41
- new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
42
- new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]
43
-
44
- # Retrieves the keys for the encoder down blocks only
45
- num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
46
- down_blocks = {
47
- layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
48
- }
49
-
50
- # Retrieves the keys for the decoder up blocks only
51
- num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
52
- up_blocks = {
53
- layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
54
- }
55
-
56
- for i in range(num_down_blocks):
57
- resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
58
-
59
- if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
60
- new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
61
- f"encoder.down.{i}.downsample.conv.weight"
62
- )
63
- new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
64
- f"encoder.down.{i}.downsample.conv.bias"
65
- )
66
-
67
- paths = renew_vae_resnet_paths(resnets)
68
- meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
69
- assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
70
-
71
- mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
72
- num_mid_res_blocks = 2
73
- for i in range(1, num_mid_res_blocks + 1):
74
- resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
75
-
76
- paths = renew_vae_resnet_paths(resnets)
77
- meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
78
- assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
79
-
80
- mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
81
- paths = renew_vae_attention_paths(mid_attentions)
82
- meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
83
- assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
84
- conv_attn_to_linear(new_checkpoint)
85
-
86
- for i in range(num_up_blocks):
87
- block_id = num_up_blocks - 1 - i
88
- resnets = [
89
- key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
90
- ]
91
-
92
- if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
93
- new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
94
- f"decoder.up.{block_id}.upsample.conv.weight"
95
- ]
96
- new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
97
- f"decoder.up.{block_id}.upsample.conv.bias"
98
- ]
99
-
100
- paths = renew_vae_resnet_paths(resnets)
101
- meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
102
- assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
103
-
104
- mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
105
- num_mid_res_blocks = 2
106
- for i in range(1, num_mid_res_blocks + 1):
107
- resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
108
-
109
- paths = renew_vae_resnet_paths(resnets)
110
- meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
111
- assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
112
-
113
- mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
114
- paths = renew_vae_attention_paths(mid_attentions)
115
- meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
116
- assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
117
- conv_attn_to_linear(new_checkpoint)
118
- return new_checkpoint
119
-
120
-
121
- def vae_pt_to_vae_diffuser(
122
- checkpoint_path: str,
123
- output_path: str,
124
- ):
125
- # Only support V1
126
- r = requests.get(
127
- " https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
128
- )
129
- io_obj = io.BytesIO(r.content)
130
-
131
- original_config = OmegaConf.load(io_obj)
132
- image_size = 512
133
- device = "cuda" if torch.cuda.is_available() else "cpu"
134
- if checkpoint_path.endswith("safetensors"):
135
- from safetensors import safe_open
136
-
137
- checkpoint = {}
138
- with safe_open(checkpoint_path, framework="pt", device="cpu") as f:
139
- for key in f.keys():
140
- checkpoint[key] = f.get_tensor(key)
141
- else:
142
- checkpoint = torch.load(checkpoint_path, map_location=device)["state_dict"]
143
-
144
- # Convert the VAE model.
145
- vae_config = create_vae_diffusers_config(original_config, image_size=image_size)
146
- converted_vae_checkpoint = custom_convert_ldm_vae_checkpoint(checkpoint, vae_config)
147
-
148
- vae = AutoencoderKL(**vae_config)
149
- vae.load_state_dict(converted_vae_checkpoint)
150
- vae.save_pretrained(output_path)
151
-
152
-
153
- if __name__ == "__main__":
154
- parser = argparse.ArgumentParser()
155
-
156
- parser.add_argument("--vae_pt_path", default=None, type=str, required=True, help="Path to the VAE.pt to convert.")
157
- parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the VAE.pt to convert.")
158
-
159
- args = parser.parse_args()
160
-
161
- vae_pt_to_vae_diffuser(args.vae_pt_path, args.dump_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
orangemix.vae.pt DELETED
File without changes
vae/{diffusion_pytorch_model.bin → diffusion_pytorch_model.safetensors} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:d8e7b68fc05f8ebc0475f23bc1e3ca74f5ae87f1fde8b0504d80228806ca4675
3
- size 334712113
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:032e8cc3e99294ab44d6f4889a1ec527d0e7ff5db3d5d9583bb0312550f0c590
3
+ size 334643268