lodrick-the-lafted committed on
Commit 580af9b · verified · 1 Parent(s): c4b3afb

add emu3_expand.py for repro

Files changed (1)
  1. emu3_expand.py +451 -0
emu3_expand.py ADDED
@@ -0,0 +1,451 @@
+ import os
+ import sys  # used by sys.exit below
+ import glob
+ import json
+ import torch
+ import shutil
+ import numpy as np
+ from pathlib import Path
+ #from transformers import AutoModelForCausalLM, AutoTokenizer
+ from safetensors.torch import safe_open, save_file
+
+ from typing import Any, Dict, List, Optional, Union
+
+ # interpolation from mergekit
+ # thanks charles!
+ def normalize(v: np.ndarray, eps: float):
+     norm_v = np.linalg.norm(v)
+     if norm_v > eps:
+         v = v / norm_v
+     return v
+
+ def lerp(
+     t: float, v0: Union[np.ndarray, torch.Tensor], v1: Union[np.ndarray, torch.Tensor]
+ ) -> Union[np.ndarray, torch.Tensor]:
+     return (1 - t) * v0 + t * v1
+
+ def slerp(
+     t: Union[float, np.ndarray],
+     v0: Union[np.ndarray, torch.Tensor],
+     v1: Union[np.ndarray, torch.Tensor],
+     DOT_THRESHOLD: float = 0.9995,
+     eps: float = 1e-8,
+ ):
+     """
+     Spherical linear interpolation
+
+     From: https://gist.github.com/dvschultz/3af50c40df002da3b751efab1daddf2c
+     Args:
+         t (float/np.ndarray): Float value between 0.0 and 1.0
+         v0 (np.ndarray): Starting vector
+         v1 (np.ndarray): Final vector
+         DOT_THRESHOLD (float): Threshold for considering the two vectors as
+                                colinear. Not recommended to alter this.
+     Returns:
+         v2 (np.ndarray): Interpolation vector between v0 and v1
+     """
+     is_torch = False
+     if not isinstance(v0, np.ndarray):
+         is_torch = True
+         v0 = v0.detach().cpu().float().numpy()
+     if not isinstance(v1, np.ndarray):
+         is_torch = True
+         v1 = v1.detach().cpu().float().numpy()
+
+     # Copy the vectors to reuse them later
+     v0_copy = np.copy(v0)
+     v1_copy = np.copy(v1)
+
+     # Normalize the vectors to get the directions and angles
+     v0 = normalize(v0, eps)
+     v1 = normalize(v1, eps)
+
+     # Dot product with the normalized vectors (can't use np.dot in W)
+     dot = np.sum(v0 * v1)
+
+     # If absolute value of dot product is almost 1, vectors are ~colinear, so use lerp
+     if np.abs(dot) > DOT_THRESHOLD:
+         res = lerp(t, v0_copy, v1_copy)
+         return maybe_torch(res, is_torch)
+
+     # Calculate initial angle between v0 and v1
+     theta_0 = np.arccos(dot)
+     sin_theta_0 = np.sin(theta_0)
+
+     # Angle at timestep t
+     theta_t = theta_0 * t
+     sin_theta_t = np.sin(theta_t)
+
+     # Finish the slerp algorithm
+     s0 = np.sin(theta_0 - theta_t) / sin_theta_0
+     s1 = sin_theta_t / sin_theta_0
+     res = s0 * v0_copy + s1 * v1_copy
+
+     return maybe_torch(res, is_torch)
+
+
+ def maybe_torch(v: np.ndarray, is_torch: bool):
+     if is_torch:
+         return torch.from_numpy(v)
+     return v
+
+
+ # move layer indices backwards to make room for inserted layer
+ def move_layer_back(model_dict, num_hidden_layers, layer_keys, layer_num, t):
+     # just rename the keys
+     print(f"move_layer_back {layer_keys[layer_num]}")
+
+     d = []
+     for k in layer_keys[layer_num]:
+         tensor = model_dict[k]
+
+         # rename each tensor of this layer so its index increases by one,
+         # e.g. model.layers.0.mlp.down_proj -> model.layers.1.mlp.down_proj
+         # (.weight + .bias, for qwen)
+
+         if k.startswith(f'model.layers.{layer_num}.'):
+             tensor_suffix = k[len(f'model.layers.{layer_num}.'):]
+             tensor_cur_prefix = f'model.layers.{layer_num}.'
+             tensor_next_prefix = f'model.layers.{layer_num+1}.'
+             tensor_prev_prefix = f'model.layers.{layer_num-1}.'
+
+             model_dict[tensor_next_prefix + tensor_suffix] = tensor
+             del model_dict[k]
+
+             d.append(tensor_next_prefix + tensor_suffix)
+
+     #print(layer_keys[layer_num])
+     layer_keys[layer_num+1] = d
+     #print(layer_keys[layer_num+1])
+
+     #import pprint
+     #pprint.pp(model_dict)
+
+ # given a dict of tensors, a key, and layer_num,
+ # return the tensor at previous layer's version of key
+ def get_prev_tensor(model_dict, key, layer_num):
+     if key.startswith(f'model.layers.{layer_num}.'):
+         suffix = key[len(f'model.layers.{layer_num}.'):]
+         cur_prefix = f'model.layers.{layer_num}.'
+         prev_prefix = f'model.layers.{layer_num-1}.'
+         return model_dict[prev_prefix + suffix]
+     return None
+
+ # given a dict of tensors, a key, and layer_num,
+ # return the tensor at the next layer's version of key
+ def get_next_tensor(model_dict, key, layer_num):
+     if key.startswith(f'model.layers.{layer_num}.'):
+         suffix = key[len(f'model.layers.{layer_num}.'):]
+         cur_prefix = f'model.layers.{layer_num}.'
+         next_prefix = f'model.layers.{layer_num+1}.'
+         return model_dict[next_prefix + suffix]
+     return None
+
+ def insert_layer(model_dict, num_hidden_layers, layer_keys, layer_num, t=0.5, out_scale=0.4, scale=None):
+     print(f"inserting layer between {layer_num-1} and {layer_num} [t={t}]")
+
+     # need to move all layers after the insertion point
+     for i in range(num_hidden_layers, layer_num, -1):
+         #print(i)
+         move_layer_back(model_dict, num_hidden_layers, layer_keys, i - 1, t)
+
+     # now merge layer+1 with layer-1 and save to layer
+     # (because everything got moved back)
+     for k in layer_keys[layer_num]:
+         #print(k)
+         tensor = get_next_tensor(model_dict, k, layer_num)
+         prev_tensor = get_prev_tensor(model_dict, k, layer_num)
+         merge_tensor = lerp(t, prev_tensor, tensor)
+         if scale is not None:
+             merge_tensor = merge_tensor * scale
+         print(f"merging {layer_num-1} w/ {layer_num+1}")
+         #merge_tensor = slerp(t, prev_tensor, tensor)
+         # damp the output projections and biases of the inserted layer
+         if k.endswith("mlp.down_proj.weight"):
+             merge_tensor = merge_tensor*out_scale
+         if k.endswith("self_attn.o_proj.weight"):
+             merge_tensor = merge_tensor*out_scale
+         if k.endswith(".bias"):
+             merge_tensor = merge_tensor*out_scale
+
+         model_dict[k] = merge_tensor
+
+ def get_dtype_size_in_bytes(tensor):
+     dtype = tensor.dtype
+     if dtype == torch.float32:
+         size_in_bytes = tensor.numel() * 4
+     elif dtype == torch.float64:
+         size_in_bytes = tensor.numel() * 8
+     elif dtype == torch.float16 or dtype == torch.bfloat16:
+         size_in_bytes = tensor.numel() * 2
+     elif dtype == torch.int32:
+         size_in_bytes = tensor.numel() * 4
+     elif dtype == torch.int64:
+         size_in_bytes = tensor.numel() * 8
+     elif dtype == torch.bool:
+         size_in_bytes = tensor.numel() * 1
+     else:
+         # fall back to torch's reported element size instead of 0,
+         # so shard size accounting stays correct for other dtypes
+         size_in_bytes = tensor.numel() * tensor.element_size()
+     return size_in_bytes
+
+ model_name = 'BAAI/Emu3-Gen'
+ dir_name = './'
+ #dir_name = None
+ conf = {}
+
+ with open(Path(dir_name or model_name) / 'config.json') as f:
+     conf = json.load(f)
+
+ st_dict = {}
+ tensor_dict = {}
+
+ # load every tensor from the (possibly sharded) safetensors checkpoint
+ if (Path(dir_name) / 'model.safetensors.index.json').is_file():
+     with open(Path(dir_name or model_name) / 'model.safetensors.index.json') as f:
+         st_index = json.load(f)
+     tensors = st_index['weight_map'].keys()
+     files = []
+     for name in tensors:
+         if st_index['weight_map'][name] not in files:
+             files.append(st_index['weight_map'][name])
+     #print(files)
+     for st in files:
+         tensor_dict = safe_open(st, framework='pt')
+         for k in tensor_dict.keys():
+             st_dict[k] = tensor_dict.get_tensor(k)
+     #print(st_dict)
+
+ elif (Path(dir_name) / 'model.safetensors').is_file():
+     model_fn = 'model.safetensors'
+     tensor_dict = safe_open(model_fn, framework='pt')
+     for k in tensor_dict.keys():
+         st_dict[k] = tensor_dict.get_tensor(k)
+     file_dict = {'model.safetensors': st_dict}
+ else:
+     print("please convert to safetensors")
+     sys.exit(-1)
+
+ print(conf)
+ num_hidden_layers = conf['num_hidden_layers']
+ print(num_hidden_layers)
+
+ model = {}
+ #sys.exit(-1)
+
+ #for k in tensor_dict.keys():
+ #    model[k] = tensor_dict.get_tensor(k)
+
+ #print(tensor_dict.keys())
+ #import pprint
+ #pprint.pp(model)
+
+ #layer = 0
+ layer_keys = {}
+
+ for layer in range(num_hidden_layers):
+     #layer_keys[layer] = [k for k in sorted(tensor_dict.keys()) if k.startswith(f'model.layers.{layer}.')]
+     layer_keys[layer] = [k for k in sorted(st_dict.keys()) if k.startswith(f'model.layers.{layer}.')]
+
+ for k in layer_keys.keys():
+     print(f"Layer {k}")
+     print(layer_keys[k])
+     print("")
+
+ # insert interpolated layers at the chosen points
+ # (num_hidden_layers grows by one after each insertion)
+ insert_layer(st_dict, num_hidden_layers, layer_keys, 24, 0.5, 0.35, scale=None)
+ num_hidden_layers += 1
+ insert_layer(st_dict, num_hidden_layers, layer_keys, 23, 0.5, 0.35, scale=None)
+ num_hidden_layers += 1
+ insert_layer(st_dict, num_hidden_layers, layer_keys, 22, 0.5, 0.35, scale=None)
+ num_hidden_layers += 1
+ insert_layer(st_dict, num_hidden_layers, layer_keys, 16, 0.5, 0.35, scale=None)
+ num_hidden_layers += 1
+ insert_layer(st_dict, num_hidden_layers, layer_keys, 15, 0.5, 0.35, scale=None)
+ num_hidden_layers += 1
+ insert_layer(st_dict, num_hidden_layers, layer_keys, 14, 0.5, 0.35, scale=None)
+ num_hidden_layers += 1
+ insert_layer(st_dict, num_hidden_layers, layer_keys, 13, 0.5, 0.35, scale=None)
+ num_hidden_layers += 1
+ insert_layer(st_dict, num_hidden_layers, layer_keys, 12, 0.5, 0.35, scale=None)
+ num_hidden_layers += 1
+ insert_layer(st_dict, num_hidden_layers, layer_keys, 11, 0.5, 0.35, scale=None)
+ num_hidden_layers += 1
+ insert_layer(st_dict, num_hidden_layers, layer_keys, 11, 0.5, 0.35, scale=None)
+ num_hidden_layers += 1
+ insert_layer(st_dict, num_hidden_layers, layer_keys, 10, 0.5, 0.35, scale=None)
+ num_hidden_layers += 1
+ insert_layer(st_dict, num_hidden_layers, layer_keys, 9, 0.5, 0.35, scale=None)
+ num_hidden_layers += 1
+ insert_layer(st_dict, num_hidden_layers, layer_keys, 8, 0.5, 0.35, scale=None)
+ num_hidden_layers += 1
+ insert_layer(st_dict, num_hidden_layers, layer_keys, 7, 0.5, 0.35, scale=None)
+ num_hidden_layers += 1
+
+
+ os.makedirs("original", exist_ok=True)
+ #shutil.copy("model.safetensors", "original")
+ shutil.copy("config.json", "original")
+
+ #save_file(st_dict, "model.safetensors", metadata={"format": "pt"})
+
+ # shard the expanded state dict into safetensors files of at most ~5 GB
+ max_shard_size = 5000000000
+ current_shard_size = 0
+ current_shard_index = 0
+ shard_dict = {}
+ current_shard = {}
+ shard_names = list(st_dict.keys())
+
+ byte_sum = 0
+ param_sum = 0
+ params = {k: st_dict[k].numel() for k in st_dict.keys()}
+ tensor_size = {k: get_dtype_size_in_bytes(st_dict[k]) for k in st_dict.keys()}
+ for p in params.keys():
+     param_sum += params[p]
+     byte_sum += tensor_size[p]
+ print(f"total params: {param_sum}")
+ print(f"total size in bytes: {byte_sum}")
+
+ # put lm_head.weight into the first shard
+ if 'lm_head.weight' in shard_names:
+     tensor_name = 'lm_head.weight'
+     current_shard[tensor_name] = st_dict[tensor_name]
+     current_shard_size += tensor_size[tensor_name]
+     # for i in range(len(shard_names)):
+     #     if shard_names[i] == tensor_name:
+     #         del shard_names[i]
+     #         break
+
+ layers = {}
+
+ for i in range(num_hidden_layers):
+     current_sizes = {}
+     layers[i] = [k for k in shard_names if k.startswith(f"model.layers.{i}.")]
+
+     for t in layers[i]:
+         #current_shard[t] = st_dict[t]
+         #size = get_dtype_size_in_bytes(st_dict[t])
+         #current_sizes[t] = size
+         current_sizes[t] = tensor_size[t]
+
+ # lm_head.weight is already in the first shard, so drop it from shard_names
+ for i in range(len(shard_names)):
+     if shard_names[i] == 'lm_head.weight':
+         del shard_names[i]
+         break
+
+ z = [k for k in shard_names if k.startswith(f"model.layers.")]
+ z.append("lm_head.weight")
+
+ remnants = list(set(shard_names) - set(z))
+ print(f"remnants size: {len(remnants)}")
+ print(remnants)
+
+ layer_size = 0
+ for l in layers[0]:
+     layer_size += tensor_size[l]
+ print(f"total size of tensors in a single layer: {layer_size}")
+
+
+ # pack whole layers into shards, writing one out whenever the next layer
+ # would push the shard past max_shard_size
+ for i in range(num_hidden_layers):
+     print(f"current_shard_size: {current_shard_size}")
+     print(f"layer_size: {layer_size}")
+     print(f"max_shard_size: {max_shard_size}")
+     if current_shard_size + layer_size >= max_shard_size:
+         print(current_shard.keys())
+         # write shard
+         print(f"writing xmodel-{current_shard_index}.safetensors")
+         save_file(current_shard, f"xmodel-{current_shard_index}.safetensors", metadata={"format": "pt"})
+         shard_dict[current_shard_index] = current_shard.copy()
+         current_shard_size = 0
+         current_shard_index += 1
+         current_shard = {}
+         print(f"wrote xmodel-{current_shard_index-1}.safetensors")
+
+     for t in layers[i]:
+         print(f"shard: {t}")
+         current_shard[t] = st_dict[t]
+         current_shard_size += tensor_size[t]
+
+ print("")
+ print(shard_names)
+ print("")
+ print("")
+ print(current_shard.keys())
+
+ # add remnants (tensors outside the decoder layers) that still fit in the current shard
+
+ for x in list(remnants):
+     remnant_size = get_dtype_size_in_bytes(st_dict[x])
+     if current_shard_size + remnant_size < max_shard_size:
+         current_shard[x] = st_dict[x]
+         current_shard_size += remnant_size
+         for i in range(len(remnants)):
+             if remnants[i] == x:
+                 del remnants[i]
+                 break
+
+ # write shard
+ print(f"writing xmodel-{current_shard_index}.safetensors")
+ save_file(current_shard, f"xmodel-{current_shard_index}.safetensors", metadata={"format": "pt"})
+ shard_dict[current_shard_index] = current_shard.copy()
+ current_shard_size = 0
+ current_shard_index += 1
+ current_shard = {}
+ print(f"wrote xmodel-{current_shard_index-1}.safetensors")
+
+ # any remnants that did not fit go into one final shard
+ for x in remnants:
+     current_shard[x] = st_dict[x]
+
+ if len(remnants) > 0:
+     # write shard
+     print(f"writing xmodel-{current_shard_index}.safetensors")
+     save_file(current_shard, f"xmodel-{current_shard_index}.safetensors", metadata={"format": "pt"})
+     shard_dict[current_shard_index] = current_shard.copy()
+     current_shard_size = 0
+     current_shard_index += 1
+     #current_shard = {}
+     print(f"wrote xmodel-{current_shard_index-1}.safetensors")
+
+
+ # move the old safetensors shards out of the way
+ print("Moving old safetensors to old/")
+ unsorted_files = glob.glob("model-*-of-*.safetensors")
+ files = sorted(unsorted_files)
+
+ os.makedirs("old", exist_ok=True)
+
+ shutil.copy("config.json", "old")
+
+ for file in files:
+     Path("old/" + file).unlink(missing_ok=True)
+     shutil.move(file, "old")
+
+ Path("old/model.safetensors.index.json").unlink(missing_ok=True)
+ shutil.move("model.safetensors.index.json", "old")
+
+ # rename the xmodel-* shards to the standard model-XXXXX-of-YYYYY names
+ for idx in range(current_shard_index):
+     if Path(f"xmodel-{idx}.safetensors").is_file():
+         shutil.move(f"xmodel-{idx}.safetensors", f"model-{idx+1:05}-of-{current_shard_index:05}.safetensors")
+
+
+ # write safetensor index
+ wmap = {}
+ index = {}
+
+ for idx in range(current_shard_index):
+     #print(idx)
+     ts = shard_dict[idx].keys()
+
+     for tname in ts:
+         wmap[tname] = f"model-{idx+1:05}-of-{current_shard_index:05}.safetensors"
+
+ # total_size is the checkpoint size in bytes
+ index['metadata'] = {'total_size': byte_sum}
+ index['weight_map'] = wmap
+ with open("model.safetensors.index.json", "w") as f:
+     json.dump(index, f, indent=4)
+
+ # record the new layer count in config.json
+ conf['num_hidden_layers'] = num_hidden_layers
+ with open(Path(dir_name or model_name) / 'config.json', "w") as f:
+     json.dump(conf, f, indent=4)
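
For a quick sanity check after the script finishes, the rewritten config and index should agree on the new layer count. This is a minimal sketch, assuming the script was run from the checkpoint directory, where it rewrites config.json and model.safetensors.index.json:

import json

with open("config.json") as f:
    conf = json.load(f)
with open("model.safetensors.index.json") as f:
    index = json.load(f)

# highest decoder layer index referenced by the new weight map
layer_ids = [int(k.split(".")[2]) for k in index["weight_map"] if k.startswith("model.layers.")]

print("config num_hidden_layers:", conf["num_hidden_layers"])
print("layers in weight_map    :", max(layer_ids) + 1)
assert conf["num_hidden_layers"] == max(layer_ids) + 1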