# DeepSeek-R1-FP4 / generate_metadata.py
# zhiyucheng's picture
# Initial commit
# 6ac96e2
#!/usr/bin/env python3
import glob
import json
import os
from safetensors import safe_open # pip install safetensors
def main() -> None:
    """Write ``model.safetensors.index.json`` for the shards in the current directory.

    Every file matching ``model-*-of-*.safetensors`` in the working directory
    is opened, and each tensor name it exposes is mapped to that shard's file
    name. The result is written as JSON with two top-level keys:

    * ``"metadata"``: ``{"total_size": <sum of the shard file sizes in bytes>}``
    * ``"weight_map"``: ``{<tensor name>: <shard file name>}``

    Raises:
        FileNotFoundError: If no shard file matches the pattern. Nothing is
            written in that case, so a run from the wrong directory cannot
            leave behind an empty index that looks valid.
    """
    # Sorted so the shards are always visited in the same order; if a tensor
    # name appears in more than one shard, the last shard in this order wins.
    shard_files = sorted(glob.glob("model-*-of-*.safetensors"))
    if not shard_files:
        raise FileNotFoundError(
            "No shard files matching 'model-*-of-*.safetensors' found in "
            f"{os.getcwd()!r}; model.safetensors.index.json was not written."
        )

    # NOTE(review): this is the sum of on-disk file sizes, which presumably
    # includes each shard's header as well as the tensor payload -- confirm
    # that is what consumers of "total_size" expect.
    total_size = sum(os.path.getsize(sf) for sf in shard_files)

    weight_map = {}
    for shard_file in shard_files:
        shard_name = os.path.basename(shard_file)
        # Only the tensor names are read here; no tensor data is fetched.
        with safe_open(shard_file, framework="np") as f:
            for tensor_name in f.keys():
                weight_map[tensor_name] = shard_name

    output_dict = {
        "metadata": {"total_size": total_size},
        "weight_map": weight_map,
    }
    with open("model.safetensors.index.json", "w", encoding="utf-8") as out_file:
        json.dump(output_dict, out_file, indent=2)

    print(
        f"Created model.safetensors.index.json with total size = {total_size} bytes."
    )
# Run the index generation only when executed as a script, not on import.
if __name__ == "__main__":
    main()