# iroiro-lora / tt2ss_key_convert.py
# Uploaded by 2vXpSwA7 ("Upload tt2ss_key_convert.py", commit 9d89de8).
import argparse
import os
import re
import sys

import torch
from safetensors import safe_open
from safetensors.torch import load_file, save_file
# ----------------------------------------------------------------
# キー変換ロジック
# ----------------------------------------------------------------
re_compiled = {}
re_digits = re.compile(r"\d+")
suffix_conversion = {
"attentions": {},
"resnets": {
"conv1": "in_layers.2",
"conv2": "out_layers.3",
"norm1": "in_layers.0",
"norm2": "out_layers.0",
"time_emb_proj": "emb_layers.1",
"conv_shortcut": "skip_connection",
}
}
def _match_and_parse(key, regex_text):
"""指定された正規表現にキーがマッチするか確認し、結果をパースする関数"""
match_list = []
regex = re_compiled.get(regex_text)
if regex is None:
regex = re.compile(regex_text)
re_compiled[regex_text] = regex
r = re.match(regex, key)
if not r:
return None
for g in r.groups():
match_list.append(int(g) if re.match(re_digits, g) else g)
return match_list
def _to_resblocks_names(name):
    """Apply the text-encoder renames used for ``c_fc``/``c_proj``/``attn`` naming.

    Only the first applicable rule is applied: ``mlp.fc1`` -> ``mlp.c_fc``,
    otherwise ``mlp.fc2`` -> ``mlp.c_proj``, otherwise ``self.attn`` -> ``attn``.
    (Presumably OpenCLIP-style module names — confirm against the target model.)
    """
    if 'mlp.fc1' in name:
        return name.replace('mlp.fc1', 'mlp.c_fc')
    if 'mlp.fc2' in name:
        return name.replace('mlp.fc2', 'mlp.c_proj')
    return name.replace('self.attn', 'attn')


def _resnet_suffix(name):
    """Map a Diffusers ResNet sub-module name (e.g. ``time_emb_proj``) to CompVis.

    The full name is looked up first. Looking up only the first ``_``-separated
    token could never match ``time_emb_proj`` or ``conv_shortcut``, so that
    lookup is kept only as a fallback. Unknown names get ``_`` replaced by ``.``.
    """
    table = suffix_conversion["resnets"]
    if name in table:
        return table[name]
    return table.get(name.split('_')[0], name.replace('_', '.'))


def diffusers_to_compvis(key, is_sd2):
    """Step 1: convert a Diffusers-style LoRA key stem to a dotted CompVis-style name.

    Args:
        key: Module-name stem such as ``lora_unet_down_blocks_0_resnets_0_conv1``
            (get_new_key passes the part of a state-dict key before its first
            ``.``). A leading ``lora_unet_``, ``lora_te_``, ``lora_te1_`` or
            ``lora_te2_`` is stripped before matching.
        is_sd2: For ``text_model_encoder_layers_*`` keys, also apply the
            ``mlp.fc1``/``mlp.fc2``/``self.attn`` renames.

    Returns:
        The CompVis-style name, or ``None`` if no known pattern matches.

    NOTE(review): the ``1_text_model_encoder_layers_*`` branch only matches a
    body that itself starts with ``1_``; ``lora_te2_`` keys lose their prefix
    first and are handled by the plain text-encoder branch.
    """
    key_body = key
    for lora_prefix in ("lora_unet_", "lora_te_", "lora_te1_", "lora_te2_"):
        if key.startswith(lora_prefix):
            key_body = key[len(lora_prefix):]
            break

    # Text Encoder (CLIP)
    if m := _match_and_parse(key_body, r"text_model_encoder_layers_(\d+)_(.+)"):
        compvis_body = f"transformer.text_model.encoder.layers.{m[0]}.{m[1].replace('_', '.')}"
        return _to_resblocks_names(compvis_body) if is_sd2 else compvis_body
    # Text Encoder 2 (SDXL)
    if m := _match_and_parse(key_body, r"1_text_model_encoder_layers_(\d+)_(.+)"):
        return _to_resblocks_names(
            f"1_model.transformer.resblocks.{m[0]}.{m[1].replace('_', '.')}"
        )

    # U-Net
    prefix = "diffusion_model."
    if m := _match_and_parse(key_body, r"conv_in(.*)"):
        return f'{prefix}input_blocks.0.0{m[0]}'
    if m := _match_and_parse(key_body, r"conv_out(.*)"):
        return f'{prefix}out.2{m[0]}'
    # (.*) instead of _(.+): a LoRA stem ends right after the layer number
    # ("time_embedding_linear_1"), which _(.+) could never match.
    if m := _match_and_parse(key_body, r"time_embedding_linear_(\d+)(.*)"):
        return f"{prefix}time_embed.{m[0] * 2 - 2}{m[1].replace('_', '.')}"
    if m := _match_and_parse(key_body, r"down_blocks_(\d+)_attentions_(\d+)_(.+)"):
        return f"{prefix}input_blocks.{1 + m[0] * 3 + m[1]}.1.{m[2].replace('_', '.')}"
    if m := _match_and_parse(key_body, r"down_blocks_(\d+)_resnets_(\d+)_(.+)"):
        return f"{prefix}input_blocks.{1 + m[0] * 3 + m[1]}.0.{_resnet_suffix(m[2])}"
    if m := _match_and_parse(key_body, r"mid_block_attentions_(\d+)_(.+)"):
        return f"{prefix}middle_block.1.{m[1].replace('_', '.')}"
    if m := _match_and_parse(key_body, r"mid_block_resnets_(\d+)_(.+)"):
        return f"{prefix}middle_block.{m[0] * 2}.{_resnet_suffix(m[1])}"
    if m := _match_and_parse(key_body, r"up_blocks_(\d+)_attentions_(\d+)_(.+)"):
        return f"{prefix}output_blocks.{m[0] * 3 + m[1]}.1.{m[2].replace('_', '.')}"
    if m := _match_and_parse(key_body, r"up_blocks_(\d+)_resnets_(\d+)_(.+)"):
        return f"{prefix}output_blocks.{m[0] * 3 + m[1]}.0.{_resnet_suffix(m[2])}"
    if m := _match_and_parse(key_body, r"down_blocks_(\d+)_downsamplers_0_conv"):
        return f"{prefix}input_blocks.{3 + m[0] * 3}.0.op"
    if m := _match_and_parse(key_body, r"up_blocks_(\d+)_upsamplers_0_conv"):
        # The upsampler sits at index 1 of output block 2 and at index 2 of
        # the later blocks, per the mapping below.
        resnet_idx = 2 if m[0] > 0 else 1
        return f"{prefix}output_blocks.{2 + m[0] * 3}.{resnet_idx}.conv"
    return None
def compvis_to_lora(key):
    """Step 2: turn a dotted CompVis name into an underscore-joined LoRA key.

    ``diffusion_model.`` is replaced by ``lora_unet_`` and ``transformer.`` by
    ``lora_te_``; every remaining ``.`` becomes ``_``. A key with any other
    prefix is returned unchanged.
    """
    prefix_map = (
        ("diffusion_model.", "lora_unet_"),
        ("transformer.", "lora_te_"),
    )
    for compvis_prefix, lora_prefix in prefix_map:
        if key.startswith(compvis_prefix):
            remainder = key[len(compvis_prefix):]
            return lora_prefix + "_".join(remainder.split("."))
    return key
def get_new_key(key, is_sd2, output_format):
    """Convert one state-dict key to ``output_format``.

    The key is split at its first ``.`` into a stem (the module name) and a
    suffix such as ``.lora_down.weight`` or ``.alpha``. Only the stem is
    converted; the suffix is re-attached verbatim.

    Args:
        key: Original state-dict key.
        is_sd2: Forwarded to diffusers_to_compvis.
        output_format: ``'lora'`` for ``lora_unet_...``/``lora_te...`` keys, or
            ``'compvis'`` for dotted CompVis names.

    Returns:
        The converted key. ``key`` is returned unchanged if the stem is not
        recognised or ``output_format`` is unknown.
    """
    key_stem, sep, rest = key.partition('.')
    key_suffix = sep + rest
    compvis_stem = diffusers_to_compvis(key_stem, is_sd2)
    if compvis_stem is None:
        return key
    # Step 2: convert according to the requested output format.
    if output_format == 'compvis':
        final_stem = compvis_stem
    elif output_format == 'lora':
        final_stem = compvis_to_lora(compvis_stem)
        # compvis_to_lora() maps every text-encoder key to "lora_te_". The
        # lora_te1_/lora_te2_ keys of a two-encoder LoRA would therefore
        # collapse onto the same names and overwrite each other. Keep the
        # source encoder prefix instead.
        source_te = next(
            (p for p in ("lora_te1_", "lora_te2_") if key_stem.startswith(p)), None
        )
        if source_te and final_stem.startswith("lora_te_"):
            final_stem = source_te + final_stem[len("lora_te_"):]
    else:
        return key
    return final_stem + key_suffix
# ----------------------------------------------------------------
# メインの処理
# ----------------------------------------------------------------
def convert_and_save_lora(input_path, output_path, is_sd2, device, output_format):
    """Load a LoRA file, convert its keys to ``output_format`` and save it (verbose).

    The input is read with safetensors. If that fails, it is read with
    ``torch.load(weights_only=True)`` and any metadata is lost. The result is
    always written as safetensors, with ``ss_converted_by`` and
    ``ss_output_format`` added to the metadata. Every renamed key is printed.

    NOTE(review): this function is defined again further down in this file.
    The later ``def`` rebinds the name, so that definition is the one in
    effect at runtime.

    Args:
        input_path: Path of the LoRA to read.
        output_path: Path of the safetensors file to write.
        is_sd2: Forwarded to get_new_key.
        device: Device the tensors are loaded onto (e.g. ``'cpu'``).
        output_format: ``'lora'`` or ``'compvis'``; forwarded to get_new_key.

    Raises:
        TypeError: If the fallback loader does not return a dict.
        ValueError: If two input keys convert to the same output key.
    """
    print(f"Loading LoRA from: {input_path}")
    try:
        metadata = {}
        state_dict = {}
        with safe_open(input_path, framework="pt", device=device) as f:
            metadata = f.metadata() or {}
            for k in f.keys():
                state_dict[k] = f.get_tensor(k)
    except Exception as e:
        print(f"Error loading safetensors: {e}")
        print("Trying to load with torch.load (metadata will be lost)...")
        # weights_only=True: .pt/.ckpt files are pickles, and fully unpickling
        # an untrusted file can execute arbitrary code.
        state_dict = torch.load(input_path, map_location=device, weights_only=True)
        metadata = {}
    if not isinstance(state_dict, dict):
        raise TypeError(
            f"Expected a dict of tensors from '{input_path}', got {type(state_dict).__name__}"
        )
    new_state_dict = {}
    source_of = {}  # output key -> input key it came from, to report collisions
    converted_count = 0
    total_count = len(state_dict)
    print(f"\nConverting keys to '{output_format}' format...")
    for key, value in state_dict.items():
        new_key = get_new_key(key, is_sd2, output_format)
        if new_key in new_state_dict:
            # Writing both would silently drop one tensor from the output.
            raise ValueError(
                f"Keys '{source_of[new_key]}' and '{key}' both convert to "
                f"'{new_key}'; refusing to overwrite one with the other."
            )
        source_of[new_key] = key
        if new_key != key:
            print(f" '{key}' -> '{new_key}'")
            converted_count += 1
        new_state_dict[new_key] = value
    print(f"\nConversion complete. {converted_count} of {total_count} keys were converted.")
    metadata['ss_converted_by'] = 'convert_lora_keys.py v3'
    metadata['ss_output_format'] = output_format
    print(f"Saving converted LoRA to: {output_path}")
    save_file(new_state_dict, output_path, metadata=metadata)
    print("Done.")
import os
import re
import argparse
import torch
from safetensors.torch import load_file, save_file
from safetensors import safe_open
re_compiled = {}
re_digits = re.compile(r"\d+")
suffix_conversion = {
"attentions": {},
"resnets": {
"conv1": "in_layers.2",
"conv2": "out_layers.3",
"norm1": "in_layers.0",
"norm2": "out_layers.0",
"time_emb_proj": "emb_layers.1",
"conv_shortcut": "skip_connection",
}
}
def _match_and_parse(key, regex_text):
"""指定された正規表現にキーがマッチするか確認し、結果をパースする関数"""
match_list = []
regex = re_compiled.get(regex_text)
if regex is None:
regex = re.compile(regex_text)
re_compiled[regex_text] = regex
r = re.match(regex, key)
if not r:
return None
for g in r.groups():
match_list.append(int(g) if re.match(re_digits, g) else g)
return match_list
def _to_resblocks_names(name):
    """Apply the text-encoder renames used for ``c_fc``/``c_proj``/``attn`` naming.

    Only the first applicable rule is applied: ``mlp.fc1`` -> ``mlp.c_fc``,
    otherwise ``mlp.fc2`` -> ``mlp.c_proj``, otherwise ``self.attn`` -> ``attn``.
    (Presumably OpenCLIP-style module names — confirm against the target model.)
    """
    if 'mlp.fc1' in name:
        return name.replace('mlp.fc1', 'mlp.c_fc')
    if 'mlp.fc2' in name:
        return name.replace('mlp.fc2', 'mlp.c_proj')
    return name.replace('self.attn', 'attn')


def _resnet_suffix(name):
    """Map a Diffusers ResNet sub-module name (e.g. ``time_emb_proj``) to CompVis.

    The full name is looked up first. Looking up only the first ``_``-separated
    token could never match ``time_emb_proj`` or ``conv_shortcut``, so that
    lookup is kept only as a fallback. Unknown names get ``_`` replaced by ``.``.
    """
    table = suffix_conversion["resnets"]
    if name in table:
        return table[name]
    return table.get(name.split('_')[0], name.replace('_', '.'))


def diffusers_to_compvis(key, is_sd2):
    """Step 1: convert a Diffusers-style LoRA key stem to a dotted CompVis-style name.

    Args:
        key: Module-name stem such as ``lora_unet_down_blocks_0_resnets_0_conv1``
            (get_new_key passes the part of a state-dict key before its first
            ``.``). A leading ``lora_unet_``, ``lora_te_``, ``lora_te1_`` or
            ``lora_te2_`` is stripped before matching.
        is_sd2: For ``text_model_encoder_layers_*`` keys, also apply the
            ``mlp.fc1``/``mlp.fc2``/``self.attn`` renames.

    Returns:
        The CompVis-style name, or ``None`` if no known pattern matches.

    NOTE(review): the ``1_text_model_encoder_layers_*`` branch only matches a
    body that itself starts with ``1_``; ``lora_te2_`` keys lose their prefix
    first and are handled by the plain text-encoder branch.
    """
    key_body = key
    for lora_prefix in ("lora_unet_", "lora_te_", "lora_te1_", "lora_te2_"):
        if key.startswith(lora_prefix):
            key_body = key[len(lora_prefix):]
            break

    # Text Encoder (CLIP)
    if m := _match_and_parse(key_body, r"text_model_encoder_layers_(\d+)_(.+)"):
        compvis_body = f"transformer.text_model.encoder.layers.{m[0]}.{m[1].replace('_', '.')}"
        return _to_resblocks_names(compvis_body) if is_sd2 else compvis_body
    # Text Encoder 2 (SDXL)
    if m := _match_and_parse(key_body, r"1_text_model_encoder_layers_(\d+)_(.+)"):
        return _to_resblocks_names(
            f"1_model.transformer.resblocks.{m[0]}.{m[1].replace('_', '.')}"
        )

    # U-Net
    prefix = "diffusion_model."
    if m := _match_and_parse(key_body, r"conv_in(.*)"):
        return f'{prefix}input_blocks.0.0{m[0]}'
    if m := _match_and_parse(key_body, r"conv_out(.*)"):
        return f'{prefix}out.2{m[0]}'
    # (.*) instead of _(.+): a LoRA stem ends right after the layer number
    # ("time_embedding_linear_1"), which _(.+) could never match.
    if m := _match_and_parse(key_body, r"time_embedding_linear_(\d+)(.*)"):
        return f"{prefix}time_embed.{m[0] * 2 - 2}{m[1].replace('_', '.')}"
    if m := _match_and_parse(key_body, r"down_blocks_(\d+)_attentions_(\d+)_(.+)"):
        return f"{prefix}input_blocks.{1 + m[0] * 3 + m[1]}.1.{m[2].replace('_', '.')}"
    if m := _match_and_parse(key_body, r"down_blocks_(\d+)_resnets_(\d+)_(.+)"):
        return f"{prefix}input_blocks.{1 + m[0] * 3 + m[1]}.0.{_resnet_suffix(m[2])}"
    if m := _match_and_parse(key_body, r"mid_block_attentions_(\d+)_(.+)"):
        return f"{prefix}middle_block.1.{m[1].replace('_', '.')}"
    if m := _match_and_parse(key_body, r"mid_block_resnets_(\d+)_(.+)"):
        return f"{prefix}middle_block.{m[0] * 2}.{_resnet_suffix(m[1])}"
    if m := _match_and_parse(key_body, r"up_blocks_(\d+)_attentions_(\d+)_(.+)"):
        return f"{prefix}output_blocks.{m[0] * 3 + m[1]}.1.{m[2].replace('_', '.')}"
    if m := _match_and_parse(key_body, r"up_blocks_(\d+)_resnets_(\d+)_(.+)"):
        return f"{prefix}output_blocks.{m[0] * 3 + m[1]}.0.{_resnet_suffix(m[2])}"
    if m := _match_and_parse(key_body, r"down_blocks_(\d+)_downsamplers_0_conv"):
        return f"{prefix}input_blocks.{3 + m[0] * 3}.0.op"
    if m := _match_and_parse(key_body, r"up_blocks_(\d+)_upsamplers_0_conv"):
        # The upsampler sits at index 1 of output block 2 and at index 2 of
        # the later blocks, per the mapping below.
        resnet_idx = 2 if m[0] > 0 else 1
        return f"{prefix}output_blocks.{2 + m[0] * 3}.{resnet_idx}.conv"
    return None
def compvis_to_lora(key):
    """Step 2: turn a dotted CompVis name into an underscore-joined LoRA key.

    ``diffusion_model.`` is replaced by ``lora_unet_`` and ``transformer.`` by
    ``lora_te_``; every remaining ``.`` becomes ``_``. A key with any other
    prefix is returned unchanged.
    """
    prefix_map = (
        ("diffusion_model.", "lora_unet_"),
        ("transformer.", "lora_te_"),
    )
    for compvis_prefix, lora_prefix in prefix_map:
        if key.startswith(compvis_prefix):
            remainder = key[len(compvis_prefix):]
            return lora_prefix + "_".join(remainder.split("."))
    return key
def get_new_key(key, is_sd2, output_format):
    """Convert one state-dict key to ``output_format``.

    The key is split at its first ``.`` into a stem (the module name) and a
    suffix such as ``.lora_down.weight`` or ``.alpha``. Only the stem is
    converted; the suffix is re-attached verbatim.

    Args:
        key: Original state-dict key.
        is_sd2: Forwarded to diffusers_to_compvis.
        output_format: ``'lora'`` for ``lora_unet_...``/``lora_te...`` keys, or
            ``'compvis'`` for dotted CompVis names.

    Returns:
        The converted key. ``key`` is returned unchanged if the stem is not
        recognised or ``output_format`` is unknown.
    """
    key_stem, sep, rest = key.partition('.')
    key_suffix = sep + rest
    compvis_stem = diffusers_to_compvis(key_stem, is_sd2)
    if compvis_stem is None:
        return key
    # Step 2: convert according to the requested output format.
    if output_format == 'compvis':
        final_stem = compvis_stem
    elif output_format == 'lora':
        final_stem = compvis_to_lora(compvis_stem)
        # compvis_to_lora() maps every text-encoder key to "lora_te_". The
        # lora_te1_/lora_te2_ keys of a two-encoder LoRA would therefore
        # collapse onto the same names and overwrite each other. Keep the
        # source encoder prefix instead.
        source_te = next(
            (p for p in ("lora_te1_", "lora_te2_") if key_stem.startswith(p)), None
        )
        if source_te and final_stem.startswith("lora_te_"):
            final_stem = source_te + final_stem[len("lora_te_"):]
    else:
        return key
    return final_stem + key_suffix
def convert_and_save_lora(input_path, output_path, is_sd2, device, output_format):
    """Load a LoRA file, convert its keys to ``output_format`` and save it.

    The input is read with safetensors. If that fails, it is read with
    ``torch.load(weights_only=True)`` and any metadata is lost. The result is
    always written as safetensors, with ``ss_converted_by`` and
    ``ss_output_format`` added to the metadata.

    Args:
        input_path: Path of the LoRA to read.
        output_path: Path of the safetensors file to write.
        is_sd2: Forwarded to get_new_key.
        device: Device the tensors are loaded onto (e.g. ``'cpu'``).
        output_format: ``'lora'`` or ``'compvis'``; forwarded to get_new_key.

    Raises:
        TypeError: If the fallback loader does not return a dict.
        ValueError: If two input keys convert to the same output key.
    """
    print(f"Loading LoRA from: {input_path}")
    try:
        metadata = {}
        state_dict = {}
        with safe_open(input_path, framework="pt", device=device) as f:
            metadata = f.metadata() or {}
            for k in f.keys():
                state_dict[k] = f.get_tensor(k)
    except Exception as e:
        print(f"Error loading safetensors: {e}")
        print("Trying to load with torch.load (metadata will be lost)...")
        # weights_only=True: .pt/.ckpt files are pickles, and fully unpickling
        # an untrusted file can execute arbitrary code.
        state_dict = torch.load(input_path, map_location=device, weights_only=True)
        metadata = {}
    if not isinstance(state_dict, dict):
        raise TypeError(
            f"Expected a dict of tensors from '{input_path}', got {type(state_dict).__name__}"
        )
    new_state_dict = {}
    source_of = {}  # output key -> input key it came from, to report collisions
    converted_count = 0
    total_count = len(state_dict)
    for key, value in state_dict.items():
        new_key = get_new_key(key, is_sd2, output_format)
        if new_key in new_state_dict:
            # Writing both would silently drop one tensor from the output.
            raise ValueError(
                f"Keys '{source_of[new_key]}' and '{key}' both convert to "
                f"'{new_key}'; refusing to overwrite one with the other."
            )
        source_of[new_key] = key
        if new_key != key:
            converted_count += 1
        new_state_dict[new_key] = value
    if converted_count == 0:
        print("No keys were converted. The model might already be in the correct format.")
    else:
        print(f"Conversion complete. {converted_count} of {total_count} keys were converted.")
    metadata['ss_converted_by'] = 'convert_lora_keys.py v3'
    metadata['ss_output_format'] = output_format
    print(f"Saving converted LoRA to: {output_path}")
    save_file(new_state_dict, output_path, metadata=metadata)
    print("Done.")
def _convert_directory(args):
    """Batch mode: convert each supported file in ``args.input`` into ``args.output``.

    Returns 0 when every file converted (or there was nothing to convert) and
    1 when the output path is unusable or any file failed.
    """
    if os.path.exists(args.output) and not os.path.isdir(args.output):
        print(f"Error: Input is a directory, but output path '{args.output}' is an existing file. Please specify a directory for the output.", file=sys.stderr)
        return 1
    os.makedirs(args.output, exist_ok=True)
    supported_extensions = ('.safetensors', '.pt', '.ckpt')
    # sorted() gives a deterministic processing order; isfile() skips
    # sub-directories whose names happen to end in a supported extension.
    files_to_process = sorted(
        f for f in os.listdir(args.input)
        if f.lower().endswith(supported_extensions)
        and os.path.isfile(os.path.join(args.input, f))
    )
    if not files_to_process:
        print(f"No compatible LoRA files found in '{args.input}'. Searched for {supported_extensions}.")
        return 0
    print(f"Found {len(files_to_process)} files for batch processing in '{args.input}'.")
    written = {}  # output path -> input path that produced it
    failures = 0
    for filename in files_to_process:
        input_file = os.path.join(args.input, filename)
        base_name, _ = os.path.splitext(filename)
        output_file = os.path.join(args.output, f"{base_name}.safetensors")
        print("-" * 60)
        if output_file in written:
            # e.g. "a.pt" and "a.safetensors" both map to "a.safetensors";
            # converting both would silently overwrite the first result.
            failures += 1
            print(f"Skipping '{input_file}': output '{output_file}' was already written from '{written[output_file]}'.", file=sys.stderr)
            print("-" * 60)
            continue
        try:
            convert_and_save_lora(input_file, output_file, args.is_sd2, args.device, args.format)
            written[output_file] = input_file
        except Exception as e:
            # Report and continue with the remaining files.
            failures += 1
            print(f"An error occurred while processing '{input_file}': {e}", file=sys.stderr)
        print("-" * 60)
    print("Batch processing finished.")
    if failures:
        print(f"{failures} of {len(files_to_process)} files failed.", file=sys.stderr)
    return 1 if failures else 0


def _convert_file(args):
    """Single-file mode: convert ``args.input`` to ``args.output``.

    If ``args.output`` is an existing directory, the result is written inside
    it as ``<input base name>.safetensors``. Missing parent directories of the
    output path are created. Returns 0; conversion errors propagate.
    """
    final_output_path = args.output
    if os.path.isdir(final_output_path):
        base_name, _ = os.path.splitext(os.path.basename(args.input))
        final_output_path = os.path.join(final_output_path, f"{base_name}.safetensors")
    output_dir = os.path.dirname(final_output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    if not final_output_path.lower().endswith(".safetensors"):
        print(f"Warning: Output path '{final_output_path}' does not end with .safetensors. It will be saved in this format anyway.")
    convert_and_save_lora(args.input, final_output_path, args.is_sd2, args.device, args.format)
    return 0


def main():
    """Command-line entry point: convert one LoRA file or every LoRA in a directory.

    Returns:
        Process exit status: 0 on success, 1 if the input path does not
        exist, the output path is unusable, or any file in a batch failed.
    """
    parser = argparse.ArgumentParser(
        description="Convert LoRA model keys between formats. Supports single file or batch processing of a directory.",
        formatter_class=argparse.RawTextHelpFormatter
    )
    parser.add_argument("--input", type=str, required=True, help="Path to the input LoRA file or directory.")
    parser.add_argument("--output", type=str, required=True, help="Path to the output LoRA file or directory.")
    parser.add_argument(
        "--format",
        type=str,
        default="lora",
        choices=['lora', 'compvis'],
        help="Output key format:\n"
             "  'lora': For LoRA training scripts (e.g., 'lora_unet_...'). (default)\n"
             "  'compvis': Standard CompVis format (e.g., 'diffusion_model....')."
    )
    parser.add_argument("--is_sd2", action="store_true", help="Set this flag if the LoRA is for a Stable Diffusion 2.x model.")
    parser.add_argument("--device", type=str, default="cpu", help="Device to load the model on (e.g., 'cpu', 'cuda').")
    args = parser.parse_args()
    if not os.path.exists(args.input):
        print(f"Error: Input path not found at '{args.input}'", file=sys.stderr)
        return 1
    if os.path.isdir(args.input):
        return _convert_directory(args)
    return _convert_file(args)
if __name__ == "__main__":
    # Propagate main()'s status to the shell (a None status exits with 0).
    raise SystemExit(main())