# Got a bunch of .ckpt files to convert?
# Here's a handy script to take care of all that for you!
# Original .ckpt files are not touched!
# Make sure you have enough disk space! You are going to DOUBLE the size of your models folder!
#
# First, run:
# pip install torch torchsde==0.2.5 safetensors==0.2.5
#
# Place this file in the **SAME DIRECTORY** as all of your .ckpt files, open a command prompt for that folder, and run:
# python convert_to_safe.py

# Original script https://gist.github.com/xrpgame/8f756f99b00b02697edcd5eec5202c59
# Edited by @Tumppi066 for use with folders https://github.com/Tumppi066/

import os
import torch
from safetensors.torch import save_file

# Loop through all files in the folder to find the .ckpt files
models = []
safeTensors = []
for path, subdirs, files in os.walk(os.path.abspath(os.getcwd())):
    for name in files:
        if name.lower().endswith('.ckpt'):
            models.append(os.path.join(path, name))
        if name.lower().endswith('.safetensors'):
            safeTensors.append(os.path.join(path, name))

if len(models) == 0:
    print('\033[91m> No .ckpt files found in this directory ({}).\033[0m'.format(os.path.abspath(os.getcwd())))
    input('> Press enter to exit... ')
    exit()
print(f"\n\033[92m> Found {len(models)} .ckpt files to convert.\033[0m")
for index, model in enumerate(models):
    print(f"{index+1}: {os.path.basename(model)}")

input("> Press enter to continue... ")
print("\n")

for index, f in enumerate(models):
    modelName = os.path.basename(f) # Model name without the full path, for nicer printing
    tensorName = f"{os.path.splitext(modelName)[0]}.safetensors"
    fn = f"{os.path.splitext(f)[0]}.safetensors"

    if fn in safeTensors:
        # Skip this model if a matching .safetensors file already exists (warning printed in yellow)
        print(f"\033[33m\n> Skipping {modelName}, as {tensorName} already exists.\033[0m")
        continue
    
    print(f'\n> Loading {modelName} ({index+1}/{len(models)})...')

    try:
        with torch.no_grad():
            # Load the checkpoint on the CPU so no GPU is required
            weights = torch.load(f, map_location=torch.device('cpu'))
            # Some checkpoints wrap their tensors in a "state_dict" key;
            # save_file() expects a flat dict of tensors, so unwrap it if present
            if isinstance(weights, dict) and "state_dict" in weights:
                weights = weights["state_dict"]
            print(f'Saving {tensorName}...')
            save_file(weights, fn)
    except Exception as ex:
        print(f'ERROR converting {modelName}: {ex}')

print("\n\033[92mDone!\033[0m")
input("> Press enter to exit... ")
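
# Optional spot check (a minimal sketch, assuming a converted file named
# "model.safetensors" -- a hypothetical name): .safetensors files can be read
# back with safetensors.torch.load_file, which returns a flat dict of tensors.
#
# from safetensors.torch import load_file
# tensors = load_file("model.safetensors")
# print(f"Loaded {len(tensors)} tensors")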