|
import sys
|
|
from math import floor
|
|
import json
|
|
import os
|
|
|
|
def comma_number(number):
    """Return *number* as an integer string with commas between thousands.

    The value is first converted with ``int()``, so numeric strings such as
    ``"1234"`` are accepted and floats are truncated toward zero.

    >>> comma_number(1234567)
    '1,234,567'
    >>> comma_number(-123456)
    '-123,456'
    """
    # The "," format specifier inserts thousands separators (independent of
    # locale). Unlike the previous reverse-and-insert loop, it places the
    # minus sign of negative numbers correctly: that loop produced "-,123"
    # for -123 and "-,123,456" for -123456.
    return f"{int(number):,}"
|
|
|
|
def getSize(filename):
    """Return the size of the file at *filename* in MB (units of 1024 * 1024 bytes).

    The result is a float produced by true division, e.g. ``1.5`` for a
    1,572,864-byte file. Any error from ``os.stat`` (such as
    ``FileNotFoundError`` for a missing path) propagates to the caller.
    """
    bytes_per_mb = 1024 ** 2
    return os.stat(filename).st_size / bytes_per_mb
|
|
|
|
# Model format versions whose layout _count_parameters understands.
_KNOWN_FORMATS = ("v0.1", "v0.2", "v0.3")


def _load_model(filename):
    """Parse the JSON model file *filename* and return the decoded object."""
    # A context manager guarantees the handle is closed; the previous
    # ``open(...).read()`` left closing to the garbage collector.
    with open(filename, "r", encoding="utf-8") as model_file:
        return json.load(model_file)


def _count_parameters(model_data):
    """Return the parameter count of a decoded model.

    The count is the number of keys in ``prompts`` and ``outputs``, plus the
    length of each of their entries, plus ``raw_outputs`` (v0.3 only) and
    ``ends`` (v0.2 and v0.3), counted the same way.

    Raises:
        ValueError: if ``model_data["format"]`` is not a known version.
        KeyError: if a section required by that version is missing.
    """
    format_version = model_data["format"]
    if format_version not in _KNOWN_FORMATS:
        # Previously an unknown version skipped every branch and crashed
        # with NameError on the never-assigned ``total``.
        raise ValueError(f"unsupported model format: {format_version!r}")

    outputs = model_data["outputs"]
    prompts = model_data["prompts"]

    total = len(outputs) + len(prompts)
    total += sum(len(entry) for entry in prompts.values())

    if format_version == "v0.3":
        # In v0.3, outputs appears to be nested as topic -> token -> entry.
        # NOTE(review): unlike v0.1/v0.2, the number of tokens per topic is
        # not added to the total; kept as-is so reported counts do not
        # change — confirm this is intended.
        total += sum(
            len(entry)
            for tokens in outputs.values()
            for entry in tokens.values()
        )
        raw_outputs = model_data["raw_outputs"]
        total += len(raw_outputs)
        total += sum(len(entry) for entry in raw_outputs.values())
    else:
        total += sum(len(entry) for entry in outputs.values())

    if format_version in ("v0.2", "v0.3"):
        ends = model_data["ends"]
        total += len(ends)
        total += sum(len(entry) for entry in ends.values())

    return total


def _require_args(count, usage):
    """Exit with a usage message unless *count* arguments follow the command."""
    # sys.argv[0] is the script and sys.argv[1] the command, hence the +2.
    if len(sys.argv) < count + 2:
        sys.exit(f"usage: {usage}")


# Command-line entry point: sys.argv[1] selects the command; running with no
# arguments does nothing.
if __name__ == "__main__":
    if len(sys.argv) > 1:
        command = sys.argv[1]

        if command == "help":
            print("help - shows this command")
            print("count_parameters <file> - counts parameters of a given model")
            print("model_size <file> - Shows size of model in MB")
            print("view_token <file> <token> - Shows a token's data")

        elif command == "count_parameters":
            _require_args(1, "count_parameters <file>")
            model_data = _load_model(sys.argv[2])
            try:
                total = _count_parameters(model_data)
            except ValueError as err:
                sys.exit(f"error: {err}")
            print(comma_number(total))

        elif command == "model_size":
            _require_args(1, "model_size <file>")
            print(getSize(sys.argv[2]))

        elif command == "view_token":
            _require_args(2, "view_token <file> <token>")
            model_data = _load_model(sys.argv[2])
            token = sys.argv[3]
            # NOTE(review): for v0.3 models the top-level keys of "outputs"
            # are iterated as topics by the parameter counter, so a lookup
            # by token here may always miss — confirm.
            input_data = model_data["prompts"].get(token, "NONE FOUND")
            output_data = model_data["outputs"].get(token, "NONE FOUND")

            print(f"PROMPT DATA: {input_data}")
            print()
            print()
            print(f"OUTPUT DATA: {output_data}")

        else:
            # Unknown commands used to be ignored silently.
            sys.exit(f"unknown command: {command!r} (try 'help')")
|
|
|
|
|
|
|