2vXpSwA7 commited on
Commit
2e50127
·
verified ·
1 Parent(s): c1a5584

Upload keybased_modelmerger.py

Browse files
Files changed (1) hide show
  1. keybased_modelmerger.py +93 -0
keybased_modelmerger.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from safetensors.torch import safe_open
3
+ from modules import scripts, sd_models, shared
4
+ import gradio as gr
5
+ from modules.processing import process_images
6
+
7
+
8
+ class KeyBasedModelMerger(scripts.Script):
9
+ def title(self):
10
+ return "Key-based model merging"
11
+
12
+ def ui(self, is_txt2img):
13
+ # UI コンポーネントを定義
14
+ model_names = sorted(sd_models.checkpoints_list.keys(), key=str.casefold)
15
+
16
+ model_a_dropdown = gr.Dropdown(
17
+ label="Model A", choices=model_names, value=model_names[0] if model_names else None
18
+ )
19
+ model_b_dropdown = gr.Dropdown(
20
+ label="Model B", choices=model_names, value=model_names[0] if model_names else None
21
+ )
22
+ keys_and_alphas_textbox = gr.Textbox(
23
+ label="マージするテンソルのキーとマージ比率 (部分一致, 1行に1つ, カンマ区切り)",
24
+ lines=5,
25
+ placeholder="例:\nmodel.diffusion_model.input_blocks.0,0.5\nmodel.diffusion_model.middle_block,0.3"
26
+ )
27
+ merge_checkbox = gr.Checkbox(label="モデルのマージを有効にする", value=True)
28
+ use_gpu_checkbox = gr.Checkbox(label="GPUを使用", value=True) # GPU/CPU切り替えチェックボックス
29
+ batch_size_slider = gr.Slider(minimum=1, maximum=500, step=1, value=250, label="KeyMgerge_BatchSize")
30
+
31
+ return [model_a_dropdown, model_b_dropdown, keys_and_alphas_textbox, merge_checkbox, use_gpu_checkbox, batch_size_slider]
32
+
33
+ def run(self, p, model_a_name, model_b_name, keys_and_alphas_str, merge_enabled, use_gpu, batch_size):
34
+ if not model_a_name or not model_b_name:
35
+ print("Error: Model A or Model B is not selected.")
36
+ return p
37
+
38
+ try:
39
+ model_a_filename = sd_models.checkpoints_list[model_a_name].filename
40
+ model_b_filename = sd_models.checkpoints_list[model_b_name].filename
41
+ except KeyError as e:
42
+ print(f"Error: Selected model is not found in checkpoints list. {e}")
43
+ return p
44
+
45
+ # マージ処理
46
+ if merge_enabled:
47
+ input_keys_and_alphas = []
48
+ for line in keys_and_alphas_str.split("\n"):
49
+ if "," in line:
50
+ key_part, alpha_str = line.split(",", 1)
51
+ try:
52
+ alpha = float(alpha_str)
53
+ input_keys_and_alphas.append((key_part, alpha))
54
+ except ValueError:
55
+ print(f"Invalid alpha value in line '{line}', skipping...")
56
+
57
+ # state_dictからキーのリストを事前に作成
58
+ model_keys = list(shared.sd_model.state_dict().keys())
59
+
60
+ # 部分一致検索を行う
61
+ final_keys_and_alphas = {}
62
+ for key_part, alpha in input_keys_and_alphas:
63
+ for model_key in model_keys:
64
+ if key_part in model_key:
65
+ final_keys_and_alphas[model_key] = alpha
66
+
67
+ # デバイスの設定 (GPUかCPUか選べるようにする)
68
+ device = "cuda" if use_gpu and torch.cuda.is_available() else "cpu"
69
+
70
+ # バッチ処理でキーをまとめて処理
71
+ batched_keys = list(final_keys_and_alphas.items())
72
+
73
+ # モデルAとモデルBからテンソルをまとめて取得
74
+ with safe_open(model_a_filename, framework="pt", device=device) as f_a, \
75
+ safe_open(model_b_filename, framework="pt", device=device) as f_b:
76
+
77
+ # バッチごとに処理
78
+ for i in range(0, len(batched_keys), batch_size):
79
+ batch = batched_keys[i:i + batch_size]
80
+
81
+ # バッチでテンソルを取得して一度にマージ
82
+ tensors_a = [f_a.get_tensor(key) for key, _ in batch]
83
+ tensors_b = [f_b.get_tensor(key) for key, _ in batch]
84
+ alphas = [final_keys_and_alphas[key] for key, _ in batch]
85
+
86
+ # バッチでテンソルをマージして一度に適用
87
+ for key, alpha, tensor_a, tensor_b in zip([key for key, _ in batch], alphas, tensors_a, tensors_b):
88
+ # 直接 state_dict にマージ結果を適用
89
+ shared.sd_model.state_dict()[key].copy_(torch.lerp(tensor_a, tensor_b, alpha).to(device))
90
+ print(f"merged {alpha}:{key}")
91
+
92
+ # 必要に応じて process_images を実行
93
+ return process_images(p)