File size: 1,120 Bytes
6ac96e2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
#!/usr/bin/env python3

import glob
import json
import os

from safetensors import safe_open  # pip install safetensors


def main(directory="."):
    """Write ``model.safetensors.index.json`` for the sharded checkpoint in *directory*.

    Scans *directory* (default: the current working directory) for shard files
    named ``model-*-of-*.safetensors``, maps every tensor name to the shard file
    that contains it, and writes the index JSON into the same directory.

    Args:
        directory: Folder containing the shard files; the index is written there.

    Raises:
        FileNotFoundError: If no file in *directory* matches the shard pattern
            (no index is written in that case).
        ValueError: If the same tensor name is found in more than one shard.
    """
    # Escape the directory so characters like "[" in its name are not treated
    # as glob wildcards; only the file-name part is a pattern.
    pattern = os.path.join(glob.escape(directory), "model-*-of-*.safetensors")
    shard_files = sorted(glob.glob(pattern))

    # Refuse to emit an empty, useless index when nothing matched.
    if not shard_files:
        raise FileNotFoundError(
            f"No shard files matching 'model-*-of-*.safetensors' found in {directory!r}"
        )

    # Sum of on-disk shard sizes in bytes.
    # NOTE(review): this includes each shard's header bytes; I believe
    # transformers' save_pretrained records only tensor-data bytes here —
    # confirm if exact parity with that tool matters.
    total_size = sum(os.path.getsize(sf) for sf in shard_files)

    weight_map = {}
    for shard_file in shard_files:
        shard_name = os.path.basename(shard_file)
        with safe_open(shard_file, framework="np") as f:
            for tensor_name in f.keys():
                # Without this check a tensor present in two shards would be
                # silently mapped to whichever shard sorts last.
                previous = weight_map.setdefault(tensor_name, shard_name)
                if previous != shard_name:
                    raise ValueError(
                        f"Tensor {tensor_name!r} appears in both "
                        f"{previous!r} and {shard_name!r}"
                    )

    output_dict = {"metadata": {"total_size": total_size}, "weight_map": weight_map}

    index_path = os.path.join(directory, "model.safetensors.index.json")
    with open(index_path, "w", encoding="utf-8") as out_file:
        json.dump(output_dict, out_file, indent=2)

    print("Created model.safetensors.index.json with total size =", total_size, "bytes.")


if __name__ == "__main__":
    # Script entry point: main() returns None, so this exits with status 0
    # on success; nothing runs when the module is merely imported.
    raise SystemExit(main())