|
import os
|
|
import re
|
|
import argparse
|
|
import torch
|
|
from safetensors.torch import load_file, save_file
|
|
from safetensors import safe_open
|
|
|
|
|
|
|
|
|
|
|
|
# Cache of compiled regular expressions, keyed by pattern text
# (populated lazily by _match_and_parse()).
re_compiled = {}

# Used by _match_and_parse() to decide whether a captured group is numeric.
re_digits = re.compile(r"\d+")

# Per block type: Diffusers sub-module name -> CompVis sub-module path.
# The "attentions" table is empty (no renames defined for attention blocks).
suffix_conversion = dict(
    attentions=dict(),
    resnets=dict(
        conv1="in_layers.2",
        conv2="out_layers.3",
        norm1="in_layers.0",
        norm2="out_layers.0",
        time_emb_proj="emb_layers.1",
        conv_shortcut="skip_connection",
    ),
)
|
|
|
|
def _match_and_parse(key, regex_text):
    """Match *key* against *regex_text* (anchored at the start) and return its groups.

    The compiled pattern is cached in the module-level ``re_compiled`` dict.
    Each captured group that consists entirely of decimal digits is converted
    to ``int``; other groups are returned as strings. A group that did not take
    part in the match is returned as ``None``.

    Args:
        key: The string to test.
        regex_text: Regular expression source text.

    Returns:
        A list with one entry per capture group, or ``None`` if *key* does not
        match. (A pattern without groups yields an empty list.)
    """
    regex = re_compiled.get(regex_text)
    if regex is None:
        regex = re.compile(regex_text)
        re_compiled[regex_text] = regex

    r = regex.match(key)
    if not r:
        return None

    # fullmatch, not match: a group such as "12abc" must stay a string
    # (int() would raise on it); None groups are passed through untouched.
    return [
        int(g) if g is not None and re_digits.fullmatch(g) else g
        for g in r.groups()
    ]
|
|
|
|
def diffusers_to_compvis(key, is_sd2):
    """Step 1: convert a Diffusers-style key stem to a CompVis-style module path.

    *key* is a stem without dots (the part of a state-dict key before the first
    '.'). A leading ``lora_unet_``, ``lora_te_``, ``lora_te1_`` or ``lora_te2_``
    prefix is stripped before matching. The prefix itself is discarded, so
    ``lora_te1_X`` and ``lora_te2_X`` produce the same result.

    Args:
        key: Key stem to convert.
        is_sd2: If True, text-encoder names are additionally renamed
            (``mlp.fc1`` -> ``mlp.c_fc``, ``mlp.fc2`` -> ``mlp.c_proj``,
            otherwise ``self.attn`` -> ``attn``).

    Returns:
        A path starting with ``transformer.``, ``1_model.transformer.`` or
        ``diffusion_model.``, or ``None`` if no known pattern matches.
    """
    key_body = key
    for lora_prefix in ("lora_unet_", "lora_te_", "lora_te1_", "lora_te2_"):
        if key.startswith(lora_prefix):
            key_body = key[len(lora_prefix):]
            break

    # --- Text encoder ---
    # NOTE: every '_' in the module path becomes '.', so e.g. "self_attn_q_proj"
    # becomes "self.attn.q.proj"; compvis_to_lora() turns the dots back into
    # underscores for the 'lora' output format.
    if m := _match_and_parse(key_body, r"text_model_encoder_layers_(\d+)_(.+)"):
        compvis_body = f"transformer.text_model.encoder.layers.{m[0]}.{m[1].replace('_', '.')}"
        if is_sd2:
            if 'mlp.fc1' in compvis_body:
                return compvis_body.replace('mlp.fc1', 'mlp.c_fc')
            elif 'mlp.fc2' in compvis_body:
                return compvis_body.replace('mlp.fc2', 'mlp.c_proj')
            else:
                return compvis_body.replace('self.attn', 'attn')
        return compvis_body

    if m := _match_and_parse(key_body, r"1_text_model_encoder_layers_(\d+)_(.+)"):
        compvis_body = f"1_model.transformer.resblocks.{m[0]}.{m[1].replace('_', '.')}"
        if 'mlp.fc1' in compvis_body:
            return compvis_body.replace('mlp.fc1', 'mlp.c_fc')
        elif 'mlp.fc2' in compvis_body:
            return compvis_body.replace('mlp.fc2', 'mlp.c_proj')
        else:
            return compvis_body.replace('self.attn', 'attn')

    # --- UNet ---
    prefix = "diffusion_model."
    # Resnet sub-module renames are looked up by the FULL module name: taking
    # only the first '_' token (as before) missed "time_emb_proj" and
    # "conv_shortcut".
    resnet_names = suffix_conversion["resnets"]

    if m := _match_and_parse(key_body, r"conv_in(.*)"):
        return f'{prefix}input_blocks.0.0{m[0]}'
    if m := _match_and_parse(key_body, r"conv_out(.*)"):
        return f'{prefix}out.2{m[0]}'
    # linear_N -> time_embed.(2N - 2). The trailing part is optional so that a
    # stem ending right after the index (e.g. "time_embedding_linear_1") matches.
    if m := _match_and_parse(key_body, r"time_embedding_linear_(\d+)(.*)"):
        return f"{prefix}time_embed.{m[0] * 2 - 2}{m[1].replace('_', '.')}"
    # down_blocks.i.{resnets,attentions}.j -> input_blocks.(1 + 3i + j).{0,1}
    if m := _match_and_parse(key_body, r"down_blocks_(\d+)_attentions_(\d+)_(.+)"):
        return f"{prefix}input_blocks.{1 + m[0] * 3 + m[1]}.1.{m[2].replace('_', '.')}"
    if m := _match_and_parse(key_body, r"down_blocks_(\d+)_resnets_(\d+)_(.+)"):
        suffix = resnet_names.get(m[2], m[2].replace('_', '.'))
        return f"{prefix}input_blocks.{1 + m[0] * 3 + m[1]}.0.{suffix}"
    # mid_block: attentions -> middle_block.1, resnets.k -> middle_block.(2k)
    if m := _match_and_parse(key_body, r"mid_block_attentions_(\d+)_(.+)"):
        return f"{prefix}middle_block.1.{m[1].replace('_', '.')}"
    if m := _match_and_parse(key_body, r"mid_block_resnets_(\d+)_(.+)"):
        suffix = resnet_names.get(m[1], m[1].replace('_', '.'))
        return f"{prefix}middle_block.{m[0] * 2}.{suffix}"
    # up_blocks.i.{resnets,attentions}.j -> output_blocks.(3i + j).{0,1}
    if m := _match_and_parse(key_body, r"up_blocks_(\d+)_attentions_(\d+)_(.+)"):
        return f"{prefix}output_blocks.{m[0] * 3 + m[1]}.1.{m[2].replace('_', '.')}"
    if m := _match_and_parse(key_body, r"up_blocks_(\d+)_resnets_(\d+)_(.+)"):
        suffix = resnet_names.get(m[2], m[2].replace('_', '.'))
        return f"{prefix}output_blocks.{m[0] * 3 + m[1]}.0.{suffix}"
    if m := _match_and_parse(key_body, r"down_blocks_(\d+)_downsamplers_0_conv"):
        return f"{prefix}input_blocks.{3 + m[0] * 3}.0.op"
    if m := _match_and_parse(key_body, r"up_blocks_(\d+)_upsamplers_0_conv"):
        # up_blocks.0's upsampler sits at sub-index 1; later blocks at sub-index 2.
        resnet_idx = 2 if m[0] > 0 else 1
        return f"{prefix}output_blocks.{2 + m[0] * 3}.{resnet_idx}.conv"

    return None
|
|
|
|
def compvis_to_lora(key):
    """Step 2: turn a CompVis-style module path into a LoRA-script key.

    ``diffusion_model.`` becomes ``lora_unet_`` and ``transformer.`` becomes
    ``lora_te_``; all remaining dots are replaced by underscores. A key with
    neither prefix is returned unchanged.
    """
    prefix_map = (
        ("diffusion_model.", "lora_unet_"),
        ("transformer.", "lora_te_"),
    )
    for compvis_prefix, lora_prefix in prefix_map:
        if not key.startswith(compvis_prefix):
            continue
        remainder = key[len(compvis_prefix):]
        return lora_prefix + remainder.replace('.', '_')
    return key
|
|
|
|
|
|
def get_new_key(key, is_sd2, output_format):
    """Convert a single state-dict key to *output_format*.

    The key is split at its first '.': only the part before it (the stem) is
    converted, via diffusers_to_compvis() and, for the 'lora' format,
    compvis_to_lora(). Everything from the first '.' onward is re-attached
    unchanged.

    Returns the original key when the stem is not recognized or when
    *output_format* is neither 'lora' nor 'compvis'.
    """
    stem, dot, tail = key.partition('.')
    compvis_stem = diffusers_to_compvis(stem, is_sd2)

    if compvis_stem is None:
        return key
    if output_format == 'compvis':
        return compvis_stem + dot + tail
    if output_format == 'lora':
        return compvis_to_lora(compvis_stem) + dot + tail
    return key
|
|
|
|
|
|
|
|
|
|
|
|
def convert_and_save_lora(input_path, output_path, is_sd2, device, output_format):
    """Load a LoRA file, convert its keys to *output_format* and save it as safetensors.

    This variant prints every key it renames.

    The file is first read with ``safe_open`` so its metadata is kept; if that
    fails for any reason, ``torch.load`` is tried instead and the metadata is
    lost. Keys are renamed with get_new_key(). If several input keys would land
    on the same output key (for example ``lora_te1_X`` and ``lora_te2_X``, whose
    prefixes diffusers_to_compvis() discards), those keys are left unchanged and
    a warning is printed, rather than one tensor silently overwriting another.
    The entries ``ss_converted_by`` and ``ss_output_format`` are added to the
    metadata.

    Args:
        input_path: Path of the LoRA file to read.
        output_path: Destination path; always written with safetensors.
        is_sd2: Passed through to get_new_key().
        device: Device the tensors are loaded onto (e.g. "cpu", "cuda").
        output_format: "lora" or "compvis"; other values leave keys unchanged.

    Raises:
        ValueError: If an output key is still duplicated after the fallback.
        Any exception raised by ``torch.load`` or ``save_file`` propagates.
    """
    from collections import Counter

    print(f"Loading LoRA from: {input_path}")

    try:
        metadata = {}
        state_dict = {}
        with safe_open(input_path, framework="pt", device=device) as f:
            metadata = f.metadata() or {}
            for k in f.keys():
                state_dict[k] = f.get_tensor(k)
    except Exception as e:
        print(f"Error loading safetensors: {e}")
        print("Trying to load with torch.load (metadata will be lost)...")
        # SECURITY NOTE(review): torch.load unpickles the file, which can run
        # arbitrary code from an untrusted .pt/.ckpt. Consider weights_only=True
        # (available in recent PyTorch versions) if inputs are not trusted.
        state_dict = torch.load(input_path, map_location=device)
        metadata = {}

    total_count = len(state_dict)

    print(f"\nConverting keys to '{output_format}' format...")
    # Resolve all targets first so that collisions are detected before storing.
    targets = {key: get_new_key(key, is_sd2, output_format) for key in state_dict}
    target_counts = Counter(targets.values())
    collided = {
        key for key, target in targets.items()
        if target != key and target_counts[target] > 1
    }
    if collided:
        example = next(k for k in targets if k in collided)
        print(f"Warning: {len(collided)} keys would collide after conversion and were left unchanged (e.g. '{example}').")

    new_state_dict = {}
    converted_count = 0
    for key, value in state_dict.items():
        new_key = key if key in collided else targets[key]
        if new_key in new_state_dict:
            raise ValueError(f"Duplicate output key '{new_key}' (from input key '{key}').")
        if new_key != key:
            print(f" '{key}' -> '{new_key}'")
            converted_count += 1
        new_state_dict[new_key] = value

    print(f"\nConversion complete. {converted_count} of {total_count} keys were converted.")

    metadata['ss_converted_by'] = 'convert_lora_keys.py v3'
    metadata['ss_output_format'] = output_format

    print(f"Saving converted LoRA to: {output_path}")
    save_file(new_state_dict, output_path, metadata=metadata)
    print("Done.")
|
|
|
|
|
|
import os
|
|
import re
|
|
import argparse
|
|
import torch
|
|
from safetensors.torch import load_file, save_file
|
|
from safetensors import safe_open
|
|
|
|
# Cache of compiled regular expressions, keyed by pattern text
# (populated lazily by _match_and_parse()).
re_compiled = {}

# Used by _match_and_parse() to decide whether a captured group is numeric.
re_digits = re.compile(r"\d+")

# Per block type: Diffusers sub-module name -> CompVis sub-module path.
# The "attentions" table is empty (no renames defined for attention blocks).
suffix_conversion = dict(
    attentions=dict(),
    resnets=dict(
        conv1="in_layers.2",
        conv2="out_layers.3",
        norm1="in_layers.0",
        norm2="out_layers.0",
        time_emb_proj="emb_layers.1",
        conv_shortcut="skip_connection",
    ),
)
|
|
|
|
def _match_and_parse(key, regex_text):
    """Match *key* against *regex_text* (anchored at the start) and return its groups.

    The compiled pattern is cached in the module-level ``re_compiled`` dict.
    Each captured group that consists entirely of decimal digits is converted
    to ``int``; other groups are returned as strings. A group that did not take
    part in the match is returned as ``None``.

    Args:
        key: The string to test.
        regex_text: Regular expression source text.

    Returns:
        A list with one entry per capture group, or ``None`` if *key* does not
        match. (A pattern without groups yields an empty list.)
    """
    regex = re_compiled.get(regex_text)
    if regex is None:
        regex = re.compile(regex_text)
        re_compiled[regex_text] = regex

    r = regex.match(key)
    if not r:
        return None

    # fullmatch, not match: a group such as "12abc" must stay a string
    # (int() would raise on it); None groups are passed through untouched.
    return [
        int(g) if g is not None and re_digits.fullmatch(g) else g
        for g in r.groups()
    ]
|
|
|
|
def diffusers_to_compvis(key, is_sd2):
    """Step 1: convert a Diffusers-style key stem to a CompVis-style module path.

    *key* is a stem without dots (the part of a state-dict key before the first
    '.'). A leading ``lora_unet_``, ``lora_te_``, ``lora_te1_`` or ``lora_te2_``
    prefix is stripped before matching. The prefix itself is discarded, so
    ``lora_te1_X`` and ``lora_te2_X`` produce the same result.

    Args:
        key: Key stem to convert.
        is_sd2: If True, text-encoder names are additionally renamed
            (``mlp.fc1`` -> ``mlp.c_fc``, ``mlp.fc2`` -> ``mlp.c_proj``,
            otherwise ``self.attn`` -> ``attn``).

    Returns:
        A path starting with ``transformer.``, ``1_model.transformer.`` or
        ``diffusion_model.``, or ``None`` if no known pattern matches.
    """
    key_body = key
    for lora_prefix in ("lora_unet_", "lora_te_", "lora_te1_", "lora_te2_"):
        if key.startswith(lora_prefix):
            key_body = key[len(lora_prefix):]
            break

    # --- Text encoder ---
    # NOTE: every '_' in the module path becomes '.', so e.g. "self_attn_q_proj"
    # becomes "self.attn.q.proj"; compvis_to_lora() turns the dots back into
    # underscores for the 'lora' output format.
    if m := _match_and_parse(key_body, r"text_model_encoder_layers_(\d+)_(.+)"):
        compvis_body = f"transformer.text_model.encoder.layers.{m[0]}.{m[1].replace('_', '.')}"
        if is_sd2:
            if 'mlp.fc1' in compvis_body:
                return compvis_body.replace('mlp.fc1', 'mlp.c_fc')
            elif 'mlp.fc2' in compvis_body:
                return compvis_body.replace('mlp.fc2', 'mlp.c_proj')
            else:
                return compvis_body.replace('self.attn', 'attn')
        return compvis_body

    if m := _match_and_parse(key_body, r"1_text_model_encoder_layers_(\d+)_(.+)"):
        compvis_body = f"1_model.transformer.resblocks.{m[0]}.{m[1].replace('_', '.')}"
        if 'mlp.fc1' in compvis_body:
            return compvis_body.replace('mlp.fc1', 'mlp.c_fc')
        elif 'mlp.fc2' in compvis_body:
            return compvis_body.replace('mlp.fc2', 'mlp.c_proj')
        else:
            return compvis_body.replace('self.attn', 'attn')

    # --- UNet ---
    prefix = "diffusion_model."
    # Resnet sub-module renames are looked up by the FULL module name: taking
    # only the first '_' token (as before) missed "time_emb_proj" and
    # "conv_shortcut".
    resnet_names = suffix_conversion["resnets"]

    if m := _match_and_parse(key_body, r"conv_in(.*)"):
        return f'{prefix}input_blocks.0.0{m[0]}'
    if m := _match_and_parse(key_body, r"conv_out(.*)"):
        return f'{prefix}out.2{m[0]}'
    # linear_N -> time_embed.(2N - 2). The trailing part is optional so that a
    # stem ending right after the index (e.g. "time_embedding_linear_1") matches.
    if m := _match_and_parse(key_body, r"time_embedding_linear_(\d+)(.*)"):
        return f"{prefix}time_embed.{m[0] * 2 - 2}{m[1].replace('_', '.')}"
    # down_blocks.i.{resnets,attentions}.j -> input_blocks.(1 + 3i + j).{0,1}
    if m := _match_and_parse(key_body, r"down_blocks_(\d+)_attentions_(\d+)_(.+)"):
        return f"{prefix}input_blocks.{1 + m[0] * 3 + m[1]}.1.{m[2].replace('_', '.')}"
    if m := _match_and_parse(key_body, r"down_blocks_(\d+)_resnets_(\d+)_(.+)"):
        suffix = resnet_names.get(m[2], m[2].replace('_', '.'))
        return f"{prefix}input_blocks.{1 + m[0] * 3 + m[1]}.0.{suffix}"
    # mid_block: attentions -> middle_block.1, resnets.k -> middle_block.(2k)
    if m := _match_and_parse(key_body, r"mid_block_attentions_(\d+)_(.+)"):
        return f"{prefix}middle_block.1.{m[1].replace('_', '.')}"
    if m := _match_and_parse(key_body, r"mid_block_resnets_(\d+)_(.+)"):
        suffix = resnet_names.get(m[1], m[1].replace('_', '.'))
        return f"{prefix}middle_block.{m[0] * 2}.{suffix}"
    # up_blocks.i.{resnets,attentions}.j -> output_blocks.(3i + j).{0,1}
    if m := _match_and_parse(key_body, r"up_blocks_(\d+)_attentions_(\d+)_(.+)"):
        return f"{prefix}output_blocks.{m[0] * 3 + m[1]}.1.{m[2].replace('_', '.')}"
    if m := _match_and_parse(key_body, r"up_blocks_(\d+)_resnets_(\d+)_(.+)"):
        suffix = resnet_names.get(m[2], m[2].replace('_', '.'))
        return f"{prefix}output_blocks.{m[0] * 3 + m[1]}.0.{suffix}"
    if m := _match_and_parse(key_body, r"down_blocks_(\d+)_downsamplers_0_conv"):
        return f"{prefix}input_blocks.{3 + m[0] * 3}.0.op"
    if m := _match_and_parse(key_body, r"up_blocks_(\d+)_upsamplers_0_conv"):
        # up_blocks.0's upsampler sits at sub-index 1; later blocks at sub-index 2.
        resnet_idx = 2 if m[0] > 0 else 1
        return f"{prefix}output_blocks.{2 + m[0] * 3}.{resnet_idx}.conv"

    return None
|
|
|
|
def compvis_to_lora(key):
    """Step 2: turn a CompVis-style module path into a LoRA-script key.

    ``diffusion_model.`` becomes ``lora_unet_`` and ``transformer.`` becomes
    ``lora_te_``; all remaining dots are replaced by underscores. A key with
    neither prefix is returned unchanged.
    """
    prefix_map = (
        ("diffusion_model.", "lora_unet_"),
        ("transformer.", "lora_te_"),
    )
    for compvis_prefix, lora_prefix in prefix_map:
        if not key.startswith(compvis_prefix):
            continue
        remainder = key[len(compvis_prefix):]
        return lora_prefix + remainder.replace('.', '_')
    return key
|
|
|
|
|
|
def get_new_key(key, is_sd2, output_format):
    """Convert a single state-dict key to *output_format*.

    The key is split at its first '.': only the part before it (the stem) is
    converted, via diffusers_to_compvis() and, for the 'lora' format,
    compvis_to_lora(). Everything from the first '.' onward is re-attached
    unchanged.

    Returns the original key when the stem is not recognized or when
    *output_format* is neither 'lora' nor 'compvis'.
    """
    stem, dot, tail = key.partition('.')
    compvis_stem = diffusers_to_compvis(stem, is_sd2)

    if compvis_stem is None:
        return key
    if output_format == 'compvis':
        return compvis_stem + dot + tail
    if output_format == 'lora':
        return compvis_to_lora(compvis_stem) + dot + tail
    return key
|
|
|
|
def convert_and_save_lora(input_path, output_path, is_sd2, device, output_format):
    """Load a LoRA file, convert its keys to *output_format* and save it as safetensors.

    The file is first read with ``safe_open`` so its metadata is kept; if that
    fails for any reason, ``torch.load`` is tried instead and the metadata is
    lost. Keys are renamed with get_new_key(). If several input keys would land
    on the same output key (for example ``lora_te1_X`` and ``lora_te2_X``, whose
    prefixes diffusers_to_compvis() discards), those keys are left unchanged and
    a warning is printed, rather than one tensor silently overwriting another.
    The entries ``ss_converted_by`` and ``ss_output_format`` are added to the
    metadata.

    Args:
        input_path: Path of the LoRA file to read.
        output_path: Destination path; always written with safetensors.
        is_sd2: Passed through to get_new_key().
        device: Device the tensors are loaded onto (e.g. "cpu", "cuda").
        output_format: "lora" or "compvis"; other values leave keys unchanged.

    Raises:
        ValueError: If an output key is still duplicated after the fallback.
        Any exception raised by ``torch.load`` or ``save_file`` propagates.
    """
    from collections import Counter

    print(f"Loading LoRA from: {input_path}")

    try:
        metadata = {}
        state_dict = {}
        with safe_open(input_path, framework="pt", device=device) as f:
            metadata = f.metadata() or {}
            for k in f.keys():
                state_dict[k] = f.get_tensor(k)
    except Exception as e:
        print(f"Error loading safetensors: {e}")
        print("Trying to load with torch.load (metadata will be lost)...")
        # SECURITY NOTE(review): torch.load unpickles the file, which can run
        # arbitrary code from an untrusted .pt/.ckpt. Consider weights_only=True
        # (available in recent PyTorch versions) if inputs are not trusted.
        state_dict = torch.load(input_path, map_location=device)
        metadata = {}

    total_count = len(state_dict)

    # Resolve all targets first so that collisions are detected before storing.
    targets = {key: get_new_key(key, is_sd2, output_format) for key in state_dict}
    target_counts = Counter(targets.values())
    collided = {
        key for key, target in targets.items()
        if target != key and target_counts[target] > 1
    }
    if collided:
        example = next(k for k in targets if k in collided)
        print(f"Warning: {len(collided)} keys would collide after conversion and were left unchanged (e.g. '{example}').")

    new_state_dict = {}
    converted_count = 0
    for key, value in state_dict.items():
        new_key = key if key in collided else targets[key]
        if new_key in new_state_dict:
            raise ValueError(f"Duplicate output key '{new_key}' (from input key '{key}').")
        if new_key != key:
            converted_count += 1
        new_state_dict[new_key] = value

    if converted_count == 0:
        print("No keys were converted. The model might already be in the correct format.")
    else:
        print(f"Conversion complete. {converted_count} of {total_count} keys were converted.")

    metadata['ss_converted_by'] = 'convert_lora_keys.py v3'
    metadata['ss_output_format'] = output_format

    print(f"Saving converted LoRA to: {output_path}")
    save_file(new_state_dict, output_path, metadata=metadata)
    print("Done.")
|
|
|
|
def _same_file(path_a, path_b):
    """Return True if both paths exist and refer to the same file on disk."""
    return os.path.exists(path_a) and os.path.exists(path_b) and os.path.samefile(path_a, path_b)


def main():
    """Command-line entry point: convert one LoRA file or every file in a directory.

    In directory mode, every regular ``.safetensors``/``.pt``/``.ckpt`` file
    directly inside the input directory is converted, in sorted order. Each
    output is named ``<base name>.safetensors`` in the output directory.

    A conversion is refused when its output would be the input file itself,
    because that would overwrite the source (and lose it if the write failed).
    In directory mode, a second input mapping to an output already written in
    this run (e.g. ``x.pt`` and ``x.safetensors``) is skipped instead of
    overwriting the first result.
    """
    parser = argparse.ArgumentParser(
        description="Convert LoRA model keys between formats. Supports single file or batch processing of a directory.",
        formatter_class=argparse.RawTextHelpFormatter
    )
    parser.add_argument("--input", type=str, required=True, help="Path to the input LoRA file or directory.")
    parser.add_argument("--output", type=str, required=True, help="Path to the output LoRA file or directory.")
    parser.add_argument(
        "--format",
        type=str,
        default="lora",
        choices=['lora', 'compvis'],
        help="Output key format:\n"
             " 'lora': For LoRA training scripts (e.g., 'lora_unet_...'). (default)\n"
             " 'compvis': Standard CompVis format (e.g., 'diffusion_model....')."
    )
    parser.add_argument("--is_sd2", action="store_true", help="Set this flag if the LoRA is for a Stable Diffusion 2.x model.")
    parser.add_argument("--device", type=str, default="cpu", help="Device to load the model on (e.g., 'cpu', 'cuda').")

    args = parser.parse_args()

    if not os.path.exists(args.input):
        print(f"Error: Input path not found at '{args.input}'")
        return

    if os.path.isdir(args.input):
        # --- Batch mode: the output must be (or become) a directory. ---
        if os.path.exists(args.output) and not os.path.isdir(args.output):
            print(f"Error: Input is a directory, but output path '{args.output}' is an existing file. Please specify a directory for the output.")
            return

        os.makedirs(args.output, exist_ok=True)

        supported_extensions = ('.safetensors', '.pt', '.ckpt')
        # Regular files only (a directory named "foo.pt" is not a model), in a
        # deterministic order.
        files_to_process = sorted(
            f for f in os.listdir(args.input)
            if f.lower().endswith(supported_extensions)
            and os.path.isfile(os.path.join(args.input, f))
        )

        if not files_to_process:
            print(f"No compatible LoRA files found in '{args.input}'. Searched for {supported_extensions}.")
            return

        print(f"Found {len(files_to_process)} files for batch processing in '{args.input}'.")

        written_outputs = set()
        for filename in files_to_process:
            input_file = os.path.join(args.input, filename)
            base_name, _ = os.path.splitext(filename)
            output_file = os.path.join(args.output, f"{base_name}.safetensors")
            output_id = os.path.normcase(os.path.abspath(output_file))

            print("-" * 60)
            if output_id in written_outputs:
                print(f"Skipping '{input_file}': output '{output_file}' was already written in this run.")
                continue
            if _same_file(input_file, output_file):
                print(f"Skipping '{input_file}': the output would overwrite the input file. Please choose a different output directory.")
                continue
            try:
                convert_and_save_lora(input_file, output_file, args.is_sd2, args.device, args.format)
                written_outputs.add(output_id)
            except Exception as e:
                print(f"An error occurred while processing '{input_file}': {e}")

        print("-" * 60)
        print("Batch processing finished.")
    else:
        # --- Single-file mode ---
        final_output_path = args.output
        if os.path.isdir(final_output_path):
            base_name, _ = os.path.splitext(os.path.basename(args.input))
            final_output_path = os.path.join(final_output_path, f"{base_name}.safetensors")

        if _same_file(args.input, final_output_path):
            print(f"Error: Output path '{final_output_path}' is the input file itself. Please choose a different output path.")
            return

        output_dir = os.path.dirname(final_output_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

        if not final_output_path.lower().endswith(".safetensors"):
            print(f"Warning: Output path '{final_output_path}' does not end with .safetensors. It will be saved in this format anyway.")

        convert_and_save_lora(args.input, final_output_path, args.is_sd2, args.device, args.format)


if __name__ == "__main__":
    main()