File size: 3,526 Bytes
4f30dbd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import sys
from math import floor
import json
import os

def comma_number(number):
    """Return *number* formatted with comma thousands separators.

    Accepts anything ``int()`` can convert (int, float, numeric string);
    the value is truncated to an integer before formatting, matching the
    original behavior.

    >>> comma_number(1234567)
    '1,234,567'
    """
    # The previous hand-rolled reverse/insert algorithm misplaced the
    # separator for negative numbers whose digit count was a multiple of
    # three (e.g. -123456 came out as '-,123,456'), because its leading-comma
    # check only looked at index 0, which held the '-' sign. The format
    # mini-language's ',' grouping option handles the sign correctly.
    return f"{int(number):,}"

def getSize(filename):
    """Return the size of *filename* in mebibytes (bytes / 1024**2) as a float."""
    size_bytes = os.path.getsize(filename)
    return size_bytes / (1024 * 1024)

if __name__ == "__main__":
    # Simple argv-dispatched CLI: <command> [<file> [<token>]].
    if len(sys.argv) > 1:
        command = sys.argv[1]

        if command == "help":
            print("help - shows this command")
            print("count_parameters <file> - counts parameters of a given model")
            print("model_size <file> - Shows size of model in MB")
            print("view_token <file> <token> - Shows a token's data")

        elif command == "count_parameters":
            filename = sys.argv[2]
            # Context manager closes the handle; the old open(...).read()
            # leaked the file object.
            with open(filename, "r") as f:
                model_data = json.load(f)
            format_version = model_data["format"]

            # Initialize so an unrecognized format prints "0" instead of
            # raising NameError at the final print.
            total = 0

            if format_version == "v0.1" or format_version == "v0.2":  # old outputs format
                total = len(model_data["outputs"])
                total += len(model_data["prompts"])
                for output in model_data["outputs"]:
                    total += len(model_data["outputs"][output])
                for prompt in model_data["prompts"]:
                    total += len(model_data["prompts"][prompt])

            if format_version == "v0.3":  # contextualized outputs format
                total = len(model_data["outputs"])
                total += len(model_data["prompts"])
                # outputs are nested one level deeper (topic -> token -> data)
                for topic in model_data["outputs"]:
                    for token in model_data["outputs"][topic]:
                        total += len(model_data["outputs"][topic][token])
                for prompt in model_data["prompts"]:
                    total += len(model_data["prompts"][prompt])
                total += len(model_data["raw_outputs"])
                for output in model_data["raw_outputs"]:
                    total += len(model_data["raw_outputs"][output])

            if format_version == "v0.2" or format_version == "v0.3":  # ends is supported in 0.2 and 0.3
                total += len(model_data["ends"])
                for topic in model_data["ends"]:
                    total += len(model_data["ends"][topic])

            print(comma_number(total))

        elif command == "model_size":
            filename = sys.argv[2]
            print(getSize(filename))

        elif command == "view_token":
            filename = sys.argv[2]
            token = sys.argv[3]
            with open(filename, "r") as f:
                model_data = json.load(f)
            # dict.get replaces the try/except KeyError pairs; same fallback.
            input_data = model_data["prompts"].get(token, "NONE FOUND")
            output_data = model_data["outputs"].get(token, "NONE FOUND")
            print(f"PROMPT DATA: {input_data}")
            print()
            print()
            print(f"OUTPUT DATA: {output_data}")