|
import json |
|
from safetensors import safe_open |
|
from safetensors.torch import save_file |
|
import os |
|
|
|
|
|
# Directory scanned (non-recursively, via os.listdir) for *.safetensors
# shards whose tensor keys are rewritten.
SAFETENSORS_DIR = "./"

# Index JSON whose "weight_map" keys are renamed to match the shards.
# NOTE(review): this path is opened as-is (relative to the current working
# directory), not joined with SAFETENSORS_DIR — keep the two in sync if
# either one is changed.
INDEX_PATH = "model.safetensors.index.json"
|
|
|
|
|
def _build_key_map(keys, old_prefix, new_prefix):
    """Map each key to its new name, raising if two result keys would collide.

    Keys starting with ``old_prefix`` get ``new_prefix`` prepended; all other
    keys map to themselves.
    """
    key_map = {}
    claimed_by = {}  # new key -> original key that produced it
    for key in keys:
        new_key = new_prefix + key if key.startswith(old_prefix) else key
        if new_key in claimed_by:
            raise ValueError(
                f"renaming would produce duplicate key {new_key!r} "
                f"(from {claimed_by[new_key]!r} and {key!r})"
            )
        claimed_by[new_key] = key
        key_map[key] = new_key
    return key_map


def rename_safetensors_keys(file_path, old_prefix="speech_generator.", new_prefix="model."):
    """Rename tensor keys in a single ``.safetensors`` file, in place.

    Every key that starts with ``old_prefix`` gets ``new_prefix`` prepended
    (by default ``speech_generator.x`` -> ``model.speech_generator.x``). All
    other keys are kept unchanged. A key that has already been renamed no
    longer starts with ``old_prefix``, so a second run leaves the file as is.

    The file's header metadata is carried over unchanged. The file is only
    rewritten when at least one key actually changes. In that case every
    tensor is loaded into memory and written to a temporary sibling file,
    which then replaces the original via ``os.replace``. An interrupted
    write therefore does not truncate the existing checkpoint.

    Args:
        file_path: Path of the ``.safetensors`` file to rewrite.
        old_prefix: Key prefix selecting which tensors are renamed.
        new_prefix: String prepended to each selected key.

    Raises:
        ValueError: If a renamed key would clash with another key in the file
            (e.g. both ``speech_generator.a`` and ``model.speech_generator.a``
            exist); the file is left untouched in that case.
    """
    with safe_open(file_path, framework="pt") as reader:
        metadata = reader.metadata()
        key_map = _build_key_map(reader.keys(), old_prefix, new_prefix)
        if all(old == new for old, new in key_map.items()):
            return  # nothing to rename; avoid a pointless full rewrite
        tensors = {new: reader.get_tensor(old) for old, new in key_map.items()}

    # The temp file lives in the same directory so os.replace stays on one
    # filesystem; its ".tmp" suffix keeps it out of "*.safetensors" scans.
    tmp_path = f"{file_path}.tmp"
    try:
        save_file(tensors, tmp_path, metadata=metadata)
        os.replace(tmp_path, file_path)
    except BaseException:
        # Don't leave a half-written temp file behind; re-raise the original.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
|
|
|
def update_index_file(index_path=None, old_prefix="speech_generator.", new_prefix="model."):
    """Rename keys in the checkpoint index's ``weight_map``, in place.

    Each ``weight_map`` key starting with ``old_prefix`` gets ``new_prefix``
    prepended; its value (the shard file name) is unchanged. All other
    top-level entries of the index (such as ``"metadata"``) are preserved
    exactly. The updated JSON is written to a temporary sibling file and then
    moved over the original with ``os.replace``. An interrupted write
    therefore cannot leave an empty or partial index.

    Args:
        index_path: Path of the index JSON; defaults to the module-level
            ``INDEX_PATH`` when ``None``.
        old_prefix: Key prefix selecting which entries are renamed.
        new_prefix: String prepended to each selected key.

    Raises:
        FileNotFoundError: If the index file does not exist.
        KeyError: If the index has no ``"weight_map"`` entry.
        ValueError: If two entries would end up with the same key; the index
            file is left untouched in that case.
    """
    if index_path is None:
        index_path = INDEX_PATH

    with open(index_path, "r", encoding="utf-8") as fh:
        index = json.load(fh)

    renamed_map = {}
    for key, shard in index["weight_map"].items():
        new_key = new_prefix + key if key.startswith(old_prefix) else key
        if new_key in renamed_map:
            raise ValueError(f"renaming {key!r} would duplicate key {new_key!r}")
        renamed_map[new_key] = shard

    # Replace only weight_map; every other top-level entry keeps its value and
    # position in the file.
    updated = {**index, "weight_map": renamed_map}

    tmp_path = f"{index_path}.tmp"
    try:
        with open(tmp_path, "w", encoding="utf-8") as fh:
            json.dump(updated, fh, indent=2)
        os.replace(tmp_path, index_path)
    except BaseException:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
|
|
|
|
|
def main():
    """Rename keys in every shard under SAFETENSORS_DIR, then fix the index.

    Shards are processed in sorted file-name order so runs are reproducible.
    When no index file exists (e.g. a single-file checkpoint), the index step
    is skipped instead of crashing after the shards were already rewritten.
    """
    for filename in sorted(os.listdir(SAFETENSORS_DIR)):
        file_path = os.path.join(SAFETENSORS_DIR, filename)
        # Only regular files: a directory named "*.safetensors" is not a shard.
        if filename.endswith(".safetensors") and os.path.isfile(file_path):
            rename_safetensors_keys(file_path)

    if os.path.exists(INDEX_PATH):
        update_index_file()
    else:
        print(f"No index file at {INDEX_PATH!r}; skipped index update.")

    print("Keys renamed successfully!")


# Only touch files when run as a script, never as a side effect of importing.
if __name__ == "__main__":
    main()