import os
import glob
import json
import sys
import torch
import shutil
import numpy as np
from pathlib import Path

from safetensors.torch import safe_open, save_file

from typing import Any, Dict, List, Optional, Union


def normalize(v: np.ndarray, eps: float):
    # Scale v to unit length, unless its norm is too small to divide by safely.
    norm_v = np.linalg.norm(v)
    if norm_v > eps:
        v = v / norm_v
    return v


def lerp(
    t: float, v0: Union[np.ndarray, torch.Tensor], v1: Union[np.ndarray, torch.Tensor]
) -> Union[np.ndarray, torch.Tensor]:
    # Plain linear interpolation between v0 and v1.
    return (1 - t) * v0 + t * v1


def slerp(
    t: Union[float, np.ndarray],
    v0: Union[np.ndarray, torch.Tensor],
    v1: Union[np.ndarray, torch.Tensor],
    DOT_THRESHOLD: float = 0.9995,
    eps: float = 1e-8,
):
    """
    Spherical linear interpolation

    From: https://gist.github.com/dvschultz/3af50c40df002da3b751efab1daddf2c

    Args:
        t (float/np.ndarray): Float value between 0.0 and 1.0
        v0 (np.ndarray): Starting vector
        v1 (np.ndarray): Final vector
        DOT_THRESHOLD (float): Threshold for considering the two vectors as
            colinear. Not recommended to alter this.
    Returns:
        v2 (np.ndarray): Interpolation vector between v0 and v1
    """
    is_torch = False
    if not isinstance(v0, np.ndarray):
        is_torch = True
        v0 = v0.detach().cpu().float().numpy()
    if not isinstance(v1, np.ndarray):
        is_torch = True
        v1 = v1.detach().cpu().float().numpy()

    v0_copy = np.copy(v0)
    v1_copy = np.copy(v1)

    v0 = normalize(v0, eps)
    v1 = normalize(v1, eps)

    dot = np.sum(v0 * v1)

    # If the inputs are nearly colinear, fall back to linear interpolation.
    if np.abs(dot) > DOT_THRESHOLD:
        res = lerp(t, v0_copy, v1_copy)
        return maybe_torch(res, is_torch)

    theta_0 = np.arccos(dot)
    sin_theta_0 = np.sin(theta_0)

    theta_t = theta_0 * t
    sin_theta_t = np.sin(theta_t)

    s0 = np.sin(theta_0 - theta_t) / sin_theta_0
    s1 = sin_theta_t / sin_theta_0
    res = s0 * v0_copy + s1 * v1_copy

    return maybe_torch(res, is_torch)


def maybe_torch(v: np.ndarray, is_torch: bool):
    # Convert back to a torch tensor if the slerp/lerp inputs were torch tensors.
    if is_torch:
        return torch.from_numpy(v)
    return v

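
# The snippet below is a minimal, optional sanity check for the interpolation helpers,
# not part of the merge flow itself; it only runs when the (hypothetical) SLERP_SELFTEST
# environment variable is set. Halfway between two orthogonal unit vectors, lerp gives
# [0.5, 0.5] while slerp weights both ends by ~0.7071.
if os.environ.get("SLERP_SELFTEST"):
    _a = np.array([1.0, 0.0])
    _b = np.array([0.0, 1.0])
    print("lerp(0.5):", lerp(0.5, _a, _b))     # [0.5 0.5]
    print("slerp(0.5):", slerp(0.5, _a, _b))   # approx [0.70710678 0.70710678]
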

def move_layer_back(model_dict, num_hidden_layers, layer_keys, layer_num, t):
    # Shift every tensor of decoder layer `layer_num` up by one position
    # (model.layers.{n}.* -> model.layers.{n+1}.*) and record the new key names.
    print(f"move_layer_back {layer_keys[layer_num]}")

    d = []
    for k in layer_keys[layer_num]:
        tensor = model_dict[k]

        if k.startswith(f'model.layers.{layer_num}.'):
            tensor_suffix = k[len(f'model.layers.{layer_num}.'):]
            tensor_next_prefix = f'model.layers.{layer_num+1}.'

            model_dict[tensor_next_prefix + tensor_suffix] = tensor
            del model_dict[k]

            d.append(tensor_next_prefix + tensor_suffix)

    layer_keys[layer_num+1] = d


def get_prev_tensor(model_dict, key, layer_num):
    # Return the matching tensor from the previous layer (layer_num - 1), or None.
    if key.startswith(f'model.layers.{layer_num}.'):
        suffix = key[len(f'model.layers.{layer_num}.'):]
        prev_prefix = f'model.layers.{layer_num-1}.'
        return model_dict[prev_prefix + suffix]
    return None


def get_next_tensor(model_dict, key, layer_num):
    # Return the matching tensor from the next layer (layer_num + 1), or None.
    if key.startswith(f'model.layers.{layer_num}.'):
        suffix = key[len(f'model.layers.{layer_num}.'):]
        next_prefix = f'model.layers.{layer_num+1}.'
        return model_dict[next_prefix + suffix]
    return None


def insert_layer(model_dict, num_hidden_layers, layer_keys, layer_num, t=0.5, out_scale=0.4, scale=None):
    # Insert a new decoder layer between layer_num-1 and layer_num. All layers from
    # layer_num upward are first shifted back by one slot, then the freed slot is
    # filled with a lerp of its two neighbours. The output projections (and biases)
    # of the new layer are multiplied by out_scale so it starts close to a pass-through.
    print(f"inserting layer between {layer_num-1} and {layer_num} [t={t}]")

    # Shift layers layer_num .. num_hidden_layers-1 up by one slot, top layer first.
    for i in range(num_hidden_layers, layer_num, -1):
        move_layer_back(model_dict, num_hidden_layers, layer_keys, i - 1, t)

    # The original key names of layer_num are now free; fill them with the merge.
    for k in layer_keys[layer_num]:
        tensor = get_next_tensor(model_dict, k, layer_num)
        prev_tensor = get_prev_tensor(model_dict, k, layer_num)
        merge_tensor = lerp(t, prev_tensor, tensor)
        if scale is not None:
            merge_tensor = merge_tensor * scale
        print(f"merging {layer_num-1} w/ {layer_num+1}")

        # Down-weight the tensors that write into the residual stream.
        if k.endswith("mlp.down_proj.weight"):
            merge_tensor = merge_tensor * out_scale
        if k.endswith("self_attn.o_proj.weight"):
            merge_tensor = merge_tensor * out_scale
        if k.endswith(".bias"):
            merge_tensor = merge_tensor * out_scale

        model_dict[k] = merge_tensor

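
# For reference, insert_layer assumes Llama-style key names like the following
# (illustrative only; the exact set of per-layer tensors depends on the checkpoint):
#   model.layers.12.self_attn.q_proj.weight
#   model.layers.12.self_attn.o_proj.weight
#   model.layers.12.mlp.down_proj.weight
#   model.layers.12.input_layernorm.weight
# Shrinking the inserted layer's output projections (the out_scale factor above) is a
# common way to keep the expanded model behaving close to the original until it is
# fine-tuned.
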

def get_dtype_size_in_bytes(tensor):
    # Size of the tensor's data in bytes (used only for shard-size accounting).
    dtype = tensor.dtype
    if dtype in (torch.float16, torch.bfloat16):
        size_in_bytes = tensor.numel() * 2
    elif dtype == torch.float32:
        size_in_bytes = tensor.numel() * 4
    elif dtype == torch.float64:
        size_in_bytes = tensor.numel() * 8
    elif dtype == torch.int32:
        size_in_bytes = tensor.numel() * 4
    elif dtype == torch.int64:
        size_in_bytes = tensor.numel() * 8
    elif dtype == torch.bool:
        size_in_bytes = tensor.numel() * 1
    else:
        size_in_bytes = tensor.numel() * tensor.element_size()
    return size_in_bytes

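
# Optional sanity check for the size helper (illustrative; gated behind a hypothetical
# SIZE_SELFTEST environment variable so the normal flow is unchanged): four float32
# elements occupy 16 bytes, four bfloat16 elements occupy 8.
if os.environ.get("SIZE_SELFTEST"):
    assert get_dtype_size_in_bytes(torch.zeros(4, dtype=torch.float32)) == 16
    assert get_dtype_size_in_bytes(torch.zeros(4, dtype=torch.bfloat16)) == 8
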

model_name = 'BAAI/Emu3-Gen'
dir_name = './'

conf = {}

# Load the model config from the working directory (falling back to model_name if
# dir_name is empty).
with open(Path(dir_name or model_name) / 'config.json') as f:
    conf = json.load(f)

st_dict = {}
tensor_dict = {}

# Load every tensor in the checkpoint into st_dict, from either a sharded or a
# single-file safetensors checkpoint.
if (Path(dir_name) / 'model.safetensors.index.json').is_file():
    with open(Path(dir_name or model_name) / 'model.safetensors.index.json') as f:
        st_index = json.load(f)
    tensors = st_index['weight_map'].keys()
    files = []
    for name in tensors:
        if st_index['weight_map'][name] not in files:
            files.append(st_index['weight_map'][name])

    for st in files:
        with safe_open(str(Path(dir_name) / st), framework='pt') as tensor_dict:
            for k in tensor_dict.keys():
                st_dict[k] = tensor_dict.get_tensor(k)

elif (Path(dir_name) / 'model.safetensors').is_file():
    model_fn = str(Path(dir_name) / 'model.safetensors')
    with safe_open(model_fn, framework='pt') as tensor_dict:
        for k in tensor_dict.keys():
            st_dict[k] = tensor_dict.get_tensor(k)
else:
    print("please convert to safetensors")
    sys.exit(-1)

print(conf)
num_hidden_layers = conf['num_hidden_layers']
print(num_hidden_layers)

# Group the checkpoint keys by decoder layer index.
layer_keys = {}

for layer in range(num_hidden_layers):
    layer_keys[layer] = [k for k in sorted(st_dict.keys()) if k.startswith(f'model.layers.{layer}.')]

for k in layer_keys.keys():
    print(f"Layer {k}")
    print(layer_keys[k])
    print("")

# Insert interpolated layers, working from the highest index downward so earlier
# insertions do not shift the targets of later ones (index 11 is expanded twice).
for insert_at in [24, 23, 22, 16, 15, 14, 13, 12, 11, 11, 10, 9, 8, 7]:
    insert_layer(st_dict, num_hidden_layers, layer_keys, insert_at, t=0.5, out_scale=0.35, scale=None)
    num_hidden_layers += 1

os.makedirs("original", exist_ok=True) |
|
|
|
shutil.copy("config.json", "original") |
|
|
|
|
|
|
|
max_shard_size = 5_000_000_000  # ~5 GB per shard
current_shard_size = 0
current_shard_index = 0
shard_dict = {}
current_shard = {}
shard_names = list(st_dict.keys())

# Total parameter count and on-disk size, for reporting and for the index metadata.
byte_sum = 0
param_sum = 0
params = {k: st_dict[k].numel() for k in st_dict.keys()}
tensor_size = {k: get_dtype_size_in_bytes(st_dict[k]) for k in st_dict.keys()}
for p in params.keys():
    param_sum += params[p]
    byte_sum += tensor_size[p]
print(f"total params: {param_sum}")
print(f"total size in bytes: {byte_sum}")

# Put lm_head.weight in the first shard.
if 'lm_head.weight' in shard_names:
    tensor_name = 'lm_head.weight'
    current_shard[tensor_name] = st_dict[tensor_name]
    current_shard_size += tensor_size[tensor_name]

# Group the (possibly expanded) key names by layer for sharding.
layers = {}

for i in range(num_hidden_layers):
    layers[i] = [k for k in shard_names if k.startswith(f"model.layers.{i}.")]

# lm_head.weight was already assigned to the first shard; drop it from the remaining
# names so it is not written twice.
if 'lm_head.weight' in shard_names:
    shard_names.remove('lm_head.weight')

# Everything that is neither a per-layer tensor nor lm_head (e.g. the embeddings and
# the final norm) is a "remnant" and gets packed at the end.
z = [k for k in shard_names if k.startswith("model.layers.")]
z.append("lm_head.weight")

remnants = list(set(shard_names) - set(z))
print(f"remnants size: {len(remnants)}")
print(remnants)

layer_size = 0
for l in layers[0]:
    layer_size += tensor_size[l]
print(f"total size of tensors in a single layer: {layer_size}")

# Pack the layers into shards, starting a new shard whenever adding another full
# layer would exceed max_shard_size.
for i in range(num_hidden_layers):
    print(f"current_shard_size: {current_shard_size}")
    print(f"layer_size: {layer_size}")
    print(f"max_shard_size: {max_shard_size}")
    if current_shard_size + layer_size >= max_shard_size:
        print(current_shard.keys())

        print(f"writing xmodel-{current_shard_index}.safetensors")
        save_file(current_shard, f"xmodel-{current_shard_index}.safetensors", metadata={"format": "pt"})
        shard_dict[current_shard_index] = current_shard.copy()
        print(f"wrote xmodel-{current_shard_index}.safetensors")
        current_shard_size = 0
        current_shard_index += 1
        current_shard = {}

    for t in layers[i]:
        print(f"shard: {t}")
        current_shard[t] = st_dict[t]
        current_shard_size += tensor_size[t]

print("")
print(shard_names)
print("")
print("")
print(current_shard.keys())

# Fit as many remnant tensors as possible into the current (still open) shard, and
# keep only the ones that did not fit for the final shard below.
unplaced = []
for x in remnants:
    remnant_size = get_dtype_size_in_bytes(st_dict[x])
    if current_shard_size + remnant_size < max_shard_size:
        current_shard[x] = st_dict[x]
        current_shard_size += remnant_size
    else:
        unplaced.append(x)
remnants = unplaced

print(f"writing xmodel-{current_shard_index}.safetensors") |
|
save_file(current_shard, f"xmodel-{current_shard_index}.safetensors", metadata={"format": "pt"}) |
|
shard_dict[current_shard_index] = current_shard.copy() |
|
current_shard_size = 0 |
|
current_shard_index += 1 |
|
current_shard = {} |
|
print(f"wrote xmodel-{current_shard_index}.safetensors") |
|
|
|
# Whatever did not fit goes into one final shard of its own.
for x in remnants:
    current_shard[x] = st_dict[x]

if len(remnants) > 0:
    print(f"writing xmodel-{current_shard_index}.safetensors")
    save_file(current_shard, f"xmodel-{current_shard_index}.safetensors", metadata={"format": "pt"})
    shard_dict[current_shard_index] = current_shard.copy()
    print(f"wrote xmodel-{current_shard_index}.safetensors")
    current_shard_size = 0
    current_shard_index += 1

print("Moving old safetensors to old/") |
|
unsorted_files = glob.glob("model-*-of-*.safetensors") |
|
files = sorted(unsorted_files) |
|
|
|
os.makedirs("old", exist_ok=True) |
|
|
|
shutil.copy("config.json", "old") |
|
|
|
for file in files: |
|
Path("old/" + file).unlink() |
|
shutil.move(file, "old") |
|
|
|
Path("old/model.safetensors.index.json").unlink() |
|
shutil.move("model.safetensors.index.json", "old") |
|
|
|
|
|
for idx in range(current_shard_index):
    if Path(f"xmodel-{idx}.safetensors").is_file():
        shutil.move(f"xmodel-{idx}.safetensors", f"model-{idx+1:05}-of-{current_shard_index:05}.safetensors")

# Build the index that maps each tensor name to the shard file that contains it.
wmap = {}
index = {}

for idx in range(current_shard_index):
    ts = shard_dict[idx].keys()
    for tname in ts:
        wmap[tname] = f"model-{idx+1:05}-of-{current_shard_index:05}.safetensors"

# total_size is the checkpoint size in bytes, not the parameter count.
index['metadata'] = {'total_size': byte_sum}
index['weight_map'] = wmap
with open("model.safetensors.index.json", "w") as f:
    json.dump(index, f, indent=4)

# Record the new layer count so the rewritten config matches the expanded checkpoint.
conf['num_hidden_layers'] = num_hidden_layers
with open(Path(dir_name or model_name) / 'config.json', "w") as f:
    json.dump(conf, f, indent=4)
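
# With the new shards, index, and config written, the expanded checkpoint should load
# like any other sharded safetensors model, e.g. (assuming transformers is installed
# and you opt in to the model's custom code):
#   from transformers import AutoModelForCausalLM
#   model = AutoModelForCausalLM.from_pretrained("./", trust_remote_code=True)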
|