# Emu3-Gen-12B / emu3_expand.py
import os
import sys
import glob
import json
import torch
import shutil
import numpy as np
from pathlib import Path
#from transformers import AutoModelForCausalLM, AutoTokenizer
from safetensors.torch import safe_open, save_file
from typing import Any, Dict, List, Optional, Union
# interpolation from mergekit
# thanks charles!
def normalize(v: np.ndarray, eps: float):
    norm_v = np.linalg.norm(v)
    if norm_v > eps:
        v = v / norm_v
    return v
def lerp(
    t: float, v0: Union[np.ndarray, torch.Tensor], v1: Union[np.ndarray, torch.Tensor]
) -> Union[np.ndarray, torch.Tensor]:
    return (1 - t) * v0 + t * v1
def slerp(
    t: Union[float, np.ndarray],
    v0: Union[np.ndarray, torch.Tensor],
    v1: Union[np.ndarray, torch.Tensor],
    DOT_THRESHOLD: float = 0.9995,
    eps: float = 1e-8,
):
    """
    Spherical linear interpolation
    From: https://gist.github.com/dvschultz/3af50c40df002da3b751efab1daddf2c
    Args:
        t (float/np.ndarray): Float value between 0.0 and 1.0
        v0 (np.ndarray): Starting vector
        v1 (np.ndarray): Final vector
        DOT_THRESHOLD (float): Threshold for considering the two vectors as
                               colinear. Not recommended to alter this.
    Returns:
        v2 (np.ndarray): Interpolation vector between v0 and v1
    """
    is_torch = False
    if not isinstance(v0, np.ndarray):
        is_torch = True
        v0 = v0.detach().cpu().float().numpy()
    if not isinstance(v1, np.ndarray):
        is_torch = True
        v1 = v1.detach().cpu().float().numpy()
    # Copy the vectors to reuse them later
    v0_copy = np.copy(v0)
    v1_copy = np.copy(v1)
    # Normalize the vectors to get the directions and angles
    v0 = normalize(v0, eps)
    v1 = normalize(v1, eps)
    # Dot product with the normalized vectors (can't use np.dot in W)
    dot = np.sum(v0 * v1)
    # If absolute value of dot product is almost 1, vectors are ~colinear, so use lerp
    if np.abs(dot) > DOT_THRESHOLD:
        res = lerp(t, v0_copy, v1_copy)
        return maybe_torch(res, is_torch)
    # Calculate initial angle between v0 and v1
    theta_0 = np.arccos(dot)
    sin_theta_0 = np.sin(theta_0)
    # Angle at timestep t
    theta_t = theta_0 * t
    sin_theta_t = np.sin(theta_t)
    # Finish the slerp algorithm
    s0 = np.sin(theta_0 - theta_t) / sin_theta_0
    s1 = sin_theta_t / sin_theta_0
    res = s0 * v0_copy + s1 * v1_copy
    return maybe_torch(res, is_torch)
def maybe_torch(v: np.ndarray, is_torch: bool):
    if is_torch:
        return torch.from_numpy(v)
    return v
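# Quick sanity check of the two interpolators above (illustrative only, not part of
# the expansion flow below): lerp at t=0.5 is a plain element-wise average, while
# slerp follows the arc between the two directions.
#   lerp(0.5, np.zeros(4), np.ones(4))                      -> [0.5, 0.5, 0.5, 0.5]
#   slerp(0.5, np.array([1.0, 0.0]), np.array([0.0, 1.0]))  -> ~[0.7071, 0.7071]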
# move layer indices backwards to make room for inserted layer
def move_layer_back(model_dict, num_hidden_layers, layer_keys, layer_num, t):
    # just rename the keys
    print(f"move_layer_back {layer_keys[layer_num]}")
    d = []
    for k in layer_keys[layer_num]:
        tensor = model_dict[k]
        # loop backwards through the layers, increasing the index
        # by one until the insertion layer has been reached
        # model.layers.0.mlp.down_proj -> model.layers.1.mlp.down_proj
        # .weight + .bias (for qwen)
        if k.startswith(f'model.layers.{layer_num}.'):
            tensor_suffix = k[len(f'model.layers.{layer_num}.'):]
            tensor_next_prefix = f'model.layers.{layer_num+1}.'
            model_dict[tensor_next_prefix + tensor_suffix] = tensor
            del model_dict[k]
            d.append(tensor_next_prefix + tensor_suffix)
    #print(layer_keys[layer_num])
    layer_keys[layer_num+1] = d
    #print(layer_keys[layer_num+1])
    #import pprint
    #pprint.pp(model_dict)
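# What a single call does to the key names (illustrative layer numbers; the real
# layer count is read from config.json below):
#   move_layer_back(..., layer_num=31)  renames  model.layers.31.*  ->  model.layers.32.*
#   move_layer_back(..., layer_num=30)  renames  model.layers.30.*  ->  model.layers.31.*
#   ... and so on, until the slot at the insertion index has been freed.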
# given a dict of tensors, a key, and layer_num,
# return the tensor at the previous layer's version of key
def get_prev_tensor(model_dict, key, layer_num):
    if key.startswith(f'model.layers.{layer_num}.'):
        suffix = key[len(f'model.layers.{layer_num}.'):]
        prev_prefix = f'model.layers.{layer_num-1}.'
        return model_dict[prev_prefix + suffix]
    return None
# given a dict of tensors, a key, and layer_num,
# return the tensor at the next layer's version of key
def get_next_tensor(model_dict, key, layer_num):
    if key.startswith(f'model.layers.{layer_num}.'):
        suffix = key[len(f'model.layers.{layer_num}.'):]
        next_prefix = f'model.layers.{layer_num+1}.'
        return model_dict[next_prefix + suffix]
    return None
def insert_layer(model_dict, num_hidden_layers, layer_keys, layer_num, t=0.5, out_scale=0.4, scale=None):
    print(f"inserting layer between {layer_num-1} and {layer_num} [t={t}]")
    # need to move all layers after the insertion point
    for i in range(num_hidden_layers, layer_num, -1):
        #print(i)
        move_layer_back(model_dict, num_hidden_layers, layer_keys, i - 1, t)
    # now merge layer+1 with layer-1 and save to layer
    # (because everything got moved back)
    for k in layer_keys[layer_num]:
        #print(k)
        tensor = get_next_tensor(model_dict, k, layer_num)
        prev_tensor = get_prev_tensor(model_dict, k, layer_num)
        merge_tensor = lerp(t, prev_tensor, tensor)
        if scale is not None:
            merge_tensor = merge_tensor * scale
        print(f"merging {layer_num-1} w/ {layer_num+1}")
        #merge_tensor = slerp(t, prev_tensor, tensor)
        # damp the projections that write back into the residual stream, plus any
        # biases, so the freshly inserted layer starts out as a mild perturbation
        if k.endswith("mlp.down_proj.weight"):
            merge_tensor = merge_tensor * out_scale
        if k.endswith("o_proj.weight"):
            merge_tensor = merge_tensor * out_scale
        if k.endswith(".bias"):
            merge_tensor = merge_tensor * out_scale
        model_dict[k] = merge_tensor
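# Sketch of one call with the values used below, e.g. insert_layer(st_dict, n, layer_keys, 24, 0.5, 0.35):
#   1. layers 24..n-1 are shifted up by one index, freeing slot 24
#   2. slot 24 is filled with lerp(0.5, layer 23, layer 25), i.e. their element-wise average
#   3. the down_proj / o_proj / bias tensors of the new layer are multiplied by
#      out_scale=0.35 so it initially contributes only weakly to the output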
def get_dtype_size_in_bytes(tensor):
    # bytes occupied by the tensor's data; element_size() handles every dtype,
    # including half-precision (fp16/bf16) weights
    return tensor.numel() * tensor.element_size()
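# e.g. a hypothetical (4096, 14336) bf16 weight tensor occupies
# 4096 * 14336 * 2 = 117,440,512 bytes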
model_name = 'BAAI/Emu3-Gen'
dir_name = './'
#dir_name = None

conf = {}
with open(Path(dir_name or model_name) / 'config.json') as f:
    conf = json.load(f)

st_dict = {}
tensor_dict = {}
if (Path(dir_name) / 'model.safetensors.index.json').is_file():
    with open(Path(dir_name or model_name) / 'model.safetensors.index.json') as f:
        st_index = json.load(f)
    tensors = st_index['weight_map'].keys()
    files = []
    for name in tensors:
        if st_index['weight_map'][name] not in files:
            files.append(st_index['weight_map'][name])
    #print(files)
    for st in files:
        tensor_dict = safe_open(st, framework='pt')
        for k in tensor_dict.keys():
            st_dict[k] = tensor_dict.get_tensor(k)
    #print(st_dict)
elif (Path(dir_name) / 'model.safetensors').is_file():
    model_fn = 'model.safetensors'
    tensor_dict = safe_open(model_fn, framework='pt')
    for k in tensor_dict.keys():
        st_dict[k] = tensor_dict.get_tensor(k)
    file_dict = {'model.safetensors': st_dict}
else:
    print("please convert to safetensors")
    sys.exit(-1)
print(conf)
num_hidden_layers = conf['num_hidden_layers']
print(num_hidden_layers)

model = {}
#sys.exit(-1)
#for k in tensor_dict.keys():
#    model[k] = tensor_dict.get_tensor(k)
#print(tensor_dict.keys())
#import pprint
#pprint.pp(model)

#layer = 0
layer_keys = {}
for layer in range(num_hidden_layers):
    #layer_keys[layer] = [k for k in sorted(tensor_dict.keys()) if k.startswith(f'model.layers.{layer}.')]
    layer_keys[layer] = [k for k in sorted(st_dict.keys()) if k.startswith(f'model.layers.{layer}.')]

for k in layer_keys.keys():
    print(f"Layer {k}")
    print(layer_keys[k])
    print("")
insert_layer(st_dict, num_hidden_layers, layer_keys, 24, 0.5, 0.35, scale=None)
num_hidden_layers += 1
insert_layer(st_dict, num_hidden_layers, layer_keys, 23, 0.5, 0.35, scale=None)
num_hidden_layers += 1
insert_layer(st_dict, num_hidden_layers, layer_keys, 22, 0.5, 0.35, scale=None)
num_hidden_layers += 1
insert_layer(st_dict, num_hidden_layers, layer_keys, 16, 0.5, 0.35, scale=None)
num_hidden_layers += 1
insert_layer(st_dict, num_hidden_layers, layer_keys, 15, 0.5, 0.35, scale=None)
num_hidden_layers += 1
insert_layer(st_dict, num_hidden_layers, layer_keys, 14, 0.5, 0.35, scale=None)
num_hidden_layers += 1
insert_layer(st_dict, num_hidden_layers, layer_keys, 13, 0.5, 0.35, scale=None)
num_hidden_layers += 1
insert_layer(st_dict, num_hidden_layers, layer_keys, 12, 0.5, 0.35, scale=None)
num_hidden_layers += 1
insert_layer(st_dict, num_hidden_layers, layer_keys, 11, 0.5, 0.35, scale=None)
num_hidden_layers += 1
insert_layer(st_dict, num_hidden_layers, layer_keys, 11, 0.5, 0.35, scale=None)
num_hidden_layers += 1
insert_layer(st_dict, num_hidden_layers, layer_keys, 10, 0.5, 0.35, scale=None)
num_hidden_layers += 1
insert_layer(st_dict, num_hidden_layers, layer_keys, 9, 0.5, 0.35, scale=None)
num_hidden_layers += 1
insert_layer(st_dict, num_hidden_layers, layer_keys, 8, 0.5, 0.35, scale=None)
num_hidden_layers += 1
insert_layer(st_dict, num_hidden_layers, layer_keys, 7, 0.5, 0.35, scale=None)
num_hidden_layers += 1
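# at this point 14 interpolated layers have been inserted, so num_hidden_layers is
# 14 larger than the value that was read from config.json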
os.makedirs("original", exist_ok=True)
#shutil.copy("model.safetensors", "original")
shutil.copy("config.json", "original")
#save_file(st_dict, "model.safetensors", metadata={"format": "pt"})
max_shard_size = 5_000_000_000  # ~5 GB per shard
current_shard_size = 0
current_shard_index = 0
shard_dict = {}
current_shard = {}
shard_names = list(st_dict.keys())
byte_sum = 0
param_sum = 0
params = {k: st_dict[k].numel() for k in st_dict.keys()}
tensor_size = {k: get_dtype_size_in_bytes(st_dict[k]) for k in st_dict.keys()}
for p in params.keys():
    param_sum += params[p]
    byte_sum += tensor_size[p]
print(f"total params: {param_sum}")
print(f"total size in bytes: {byte_sum}")
tensor_name = 'lm_head.weight'
if tensor_name in shard_names:
    current_shard[tensor_name] = st_dict[tensor_name]
    current_shard_size += tensor_size[tensor_name]
    # for i in range(len(shard_names)):
    #     if shard_names[i] == tensor_name:
    #         del shard_names[i]
    #         break
layers = {}
for i in range(num_hidden_layers):
    current_sizes = {}
    layers[i] = [k for k in shard_names if k.startswith(f"model.layers.{i}.")]
    for t in layers[i]:
        #current_shard[t] = st_dict[t]
        #size = get_dtype_size_in_bytes(st_dict[t])
        #current_sizes[t] = size
        current_sizes[t] = tensor_size[t]

# lm_head.weight already went into the first shard above, so take it out of the
# remaining name list
for i in range(len(shard_names)):
    if shard_names[i] == tensor_name:
        del shard_names[i]
        break

z = [k for k in shard_names if k.startswith("model.layers.")]
z.append("lm_head.weight")
remnants = list(set(shard_names) - set(z))
print(f"remnants size: {len(remnants)}")
print(remnants)
layer_size = 0
for l in layers[0]:
    layer_size += tensor_size[l]
print(f"total size of tensors in a single layer: {layer_size}")

# greedily pack whole decoder layers into shards of at most max_shard_size bytes
for i in range(num_hidden_layers):
    print(f"current_shard_size: {current_shard_size}")
    print(f"layer_size: {layer_size}")
    print(f"max_shard_size: {max_shard_size}")
    if current_shard_size + layer_size >= max_shard_size:
        print(current_shard.keys())
        # write shard
        print(f"writing xmodel-{current_shard_index}.safetensors")
        save_file(current_shard, f"xmodel-{current_shard_index}.safetensors", metadata={"format": "pt"})
        shard_dict[current_shard_index] = current_shard.copy()
        current_shard_size = 0
        current_shard_index += 1
        current_shard = {}
        print(f"wrote xmodel-{current_shard_index - 1}.safetensors")
    for t in layers[i]:
        print(f"shard: {t}")
        current_shard[t] = st_dict[t]
        current_shard_size += tensor_size[t]

print("")
print(shard_names)
print("")
print("")
print(current_shard.keys())
# add remnants (e.g. embeddings, final norm) that still fit into the current shard,
# and drop them from the remnants list so they are not written a second time below
for x in list(remnants):
    remnant_size = get_dtype_size_in_bytes(st_dict[x])
    if current_shard_size + remnant_size < max_shard_size:
        current_shard[x] = st_dict[x]
        current_shard_size += remnant_size
        remnants.remove(x)
# write shard
print(f"writing xmodel-{current_shard_index}.safetensors")
save_file(current_shard, f"xmodel-{current_shard_index}.safetensors", metadata={"format": "pt"})
shard_dict[current_shard_index] = current_shard.copy()
current_shard_size = 0
current_shard_index += 1
current_shard = {}
print(f"wrote xmodel-{current_shard_index - 1}.safetensors")

# anything that still did not fit goes into one final shard of its own
for x in remnants:
    current_shard[x] = st_dict[x]
if len(remnants) > 0:
    # write shard
    print(f"writing xmodel-{current_shard_index}.safetensors")
    save_file(current_shard, f"xmodel-{current_shard_index}.safetensors", metadata={"format": "pt"})
    shard_dict[current_shard_index] = current_shard.copy()
    current_shard_size = 0
    current_shard_index += 1
    #current_shard = {}
    print(f"wrote xmodel-{current_shard_index - 1}.safetensors")
# move the original shards out of the way into old/
print("Moving old safetensors to old/")
unsorted_files = glob.glob("model-*-of-*.safetensors")
files = sorted(unsorted_files)
os.makedirs("old", exist_ok=True)
shutil.copy("config.json", "old")
for file in files:
    Path("old/" + file).unlink(missing_ok=True)
    shutil.move(file, "old")
Path("old/model.safetensors.index.json").unlink(missing_ok=True)
shutil.move("model.safetensors.index.json", "old")
# rename xmodel-N shards to the standard model-XXXXX-of-YYYYY naming
for idx in range(current_shard_index):
    if Path(f"xmodel-{idx}.safetensors").is_file():
        shutil.move(f"xmodel-{idx}.safetensors", f"model-{idx+1:05}-of-{current_shard_index:05}.safetensors")

# write safetensor index
wmap = {}
index = {}
for idx in range(current_shard_index):
    #print(idx)
    ts = shard_dict[idx].keys()
    for tname in ts:
        wmap[tname] = f"model-{idx+1:05}-of-{current_shard_index:05}.safetensors"
index['metadata'] = {'total_size': byte_sum}  # total size of all tensors, in bytes
index['weight_map'] = wmap
with open("model.safetensors.index.json", "w") as f:
    json.dump(index, f, indent=4)
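# The index written above follows the usual transformers sharded-checkpoint layout
# (illustrative shape; the shard file names depend on the final shard count):
# {
#     "metadata": {"total_size": <bytes>},
#     "weight_map": {"model.embed_tokens.weight": "model-00001-of-0000N.safetensors", ...}
# }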
conf['num_hidden_layers'] = num_hidden_layers
with open(Path(dir_name or model_name) / 'config.json', "w") as f:
    json.dump(conf, f, indent=4)