stok-sub-1 / stok-tools.py
import sys
from math import floor
import json
import os
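# stok-tools.py - small command-line helper for inspecting model JSON files.
# Subcommands (see the "help" branch below):
#   help                      - shows this help message
#   count_parameters <file>   - counts parameters of a given model
#   model_size <file>         - shows size of model in MB
#   view_token <file> <token> - shows a token's data
# Example invocations (the model filename and token are illustrative, not files
# shipped with this repo):
#   python stok-tools.py count_parameters model.json
#   python stok-tools.py view_token model.json hello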
def comma_number(number):
    # Formats an integer with comma thousands separators by working on the
    # reversed digit list.
    number = int(number)
    ordered_num = list(str(number))
    ordered_num.reverse()
    if len(ordered_num) > 3:
        splits = floor(len(ordered_num) / 3)
        start = 0
        for x in range(0, splits):
            # The first comma goes after 3 digits; later ones advance by 4 slots
            # (3 digits plus the comma inserted on the previous pass).
            if start == 0:
                start += 3
            else:
                start += 4
            ordered_num.insert(start, ",")
        ordered_num.reverse()
        # Drop the leading comma left over when the digit count is a multiple of 3.
        if ordered_num[0] == ",":
            ordered_num.pop(0)
    return "".join(ordered_num)
def getSize(filename):
    # Returns the file size in megabytes (binary: bytes / 1024 / 1024).
    st = os.stat(filename)
    size_in_mb = st.st_size / (1024 * 1024)
    return size_in_mb
if __name__ == "__main__":
    if len(sys.argv) > 1:
        if sys.argv[1] == "help":
            print("help - shows this help message")
            print("count_parameters <file> - counts parameters of a given model")
            print("model_size <file> - shows size of model in MB")
            print("view_token <file> <token> - shows a token's data")
        if sys.argv[1] == "count_parameters":
            filename = sys.argv[2]
            with open(filename, "r") as f:
                model_data = json.load(f)
            format_version = model_data["format"]
            # A "parameter" here is every key in the prompt/output tables plus
            # every entry stored under those keys.
            if format_version == "v0.1" or format_version == "v0.2":  # old outputs format
                total = len(model_data["outputs"])
                total += len(model_data["prompts"])
                for output in model_data["outputs"]:
                    total += len(model_data["outputs"][output])
                for prompt in model_data["prompts"]:
                    total += len(model_data["prompts"][prompt])
            if format_version == "v0.3":  # contextualized outputs format
                total = len(model_data["outputs"])
                total += len(model_data["prompts"])
                for topic in model_data["outputs"]:
                    for token in model_data["outputs"][topic]:
                        total += len(model_data["outputs"][topic][token])
                for prompt in model_data["prompts"]:
                    total += len(model_data["prompts"][prompt])
                total += len(model_data["raw_outputs"])
                for output in model_data["raw_outputs"]:
                    total += len(model_data["raw_outputs"][output])
            if format_version == "v0.2" or format_version == "v0.3":  # ends are supported in v0.2 and v0.3
                total += len(model_data["ends"])
                for topic in model_data["ends"]:
                    total += len(model_data["ends"][topic])
            print(comma_number(total))
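        # The JSON layout assumed above is inferred from the key accesses; no formal
        # schema ships with this tool. Roughly, for a v0.3 file:
        #   {"format": "v0.3",
        #    "prompts":     {"<prompt>": [...]},
        #    "outputs":     {"<topic>": {"<token>": [...]}},
        #    "raw_outputs": {"<output>": [...]},
        #    "ends":        {"<topic>": [...]}}
        # v0.1/v0.2 files keep "outputs" as a flat {"<token>": [...]} mapping, and
        # v0.1 has no "ends" section.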
        if sys.argv[1] == "model_size":
            filename = sys.argv[2]
            print(getSize(filename))
        if sys.argv[1] == "view_token":
            filename = sys.argv[2]
            token = sys.argv[3]
            with open(filename, "r") as f:
                model_data = json.load(f)
            prompts = model_data["prompts"]
            outputs = model_data["outputs"]
            # Fall back to a placeholder when the token is missing from either table.
            try:
                input_data = prompts[token]
            except KeyError:
                input_data = "NONE FOUND"
            try:
                output_data = outputs[token]
            except KeyError:
                output_data = "NONE FOUND"
            print(f"PROMPT DATA: {input_data}")
            print()
            print()
            print(f"OUTPUT DATA: {output_data}")