"""Command-line utilities for inspecting JSON model files.

Usage: python <script> <command> [args...]
Run with ``help`` to list the available commands.
"""
import json
import os
import sys
from math import floor


# Help text for the CLI, printed one entry per line.
_HELP_LINES = (
    "help - shows this command",
    "count_parameters - counts parameters of a given model",
    "model_size - Shows size of model in MB",
    "view_token - Shows a token's data",
)

# Positional arguments each command needs after the command name.
_REQUIRED_ARGS = {
    "count_parameters": ("<model_file>",),
    "model_size": ("<file>",),
    "view_token": ("<model_file>", "<token>"),
}


def comma_number(number):
    """Return *number*, truncated with ``int()``, formatted with thousands commas.

    >>> comma_number(1234567)
    '1,234,567'
    >>> comma_number(-123)
    '-123'
    """
    # The "," format spec groups digits and places the sign correctly; the
    # previous hand-rolled version produced e.g. "-,123" for -123.
    return f"{int(number):,}"


def getSize(filename):
    """Return the size of *filename* in MiB (bytes / (1024 * 1024)) as a float."""
    st = os.stat(filename)
    return st.st_size / (1024 * 1024)


def _load_model(filename):
    """Parse the JSON model file at *filename*, closing the file afterwards."""
    with open(filename, "r") as model_file:
        return json.load(model_file)


def _count_parameters(model_data):
    """Return the parameter count of a loaded model dict.

    The count is the number of keys in each counted section plus the length
    of every value stored under those keys. Which sections are counted
    depends on ``model_data["format"]``:

    * "v0.1"/"v0.2": ``outputs[token]`` and ``prompts[token]``.
    * "v0.3": ``outputs[topic][token]`` (nested), ``prompts`` and
      ``raw_outputs``.
    * "v0.2"/"v0.3" additionally count ``ends``.

    Raises:
        ValueError: if the format is not one of "v0.1", "v0.2", "v0.3".
    """
    format_version = model_data["format"]
    if format_version not in ("v0.1", "v0.2", "v0.3"):
        raise ValueError(f"Unsupported model format: {format_version!r}")

    outputs = model_data["outputs"]
    prompts = model_data["prompts"]
    total = len(outputs) + len(prompts)
    total += sum(len(values) for values in prompts.values())

    if format_version == "v0.3":
        # Contextualised outputs format: one level of topics above the tokens.
        total += sum(
            len(entries)
            for tokens in outputs.values()
            for entries in tokens.values()
        )
        raw_outputs = model_data["raw_outputs"]
        total += len(raw_outputs)
        total += sum(len(values) for values in raw_outputs.values())
    else:
        # Old flat outputs format.
        total += sum(len(values) for values in outputs.values())

    if format_version in ("v0.2", "v0.3"):
        # "ends" is only present from v0.2 onwards.
        ends = model_data["ends"]
        total += len(ends)
        total += sum(len(values) for values in ends.values())
    return total


def _view_token(model_data, token):
    """Return ``(prompt_data, output_data)`` for *token*.

    Either element is the string "NONE FOUND" when *token* is not a key of
    the corresponding section.
    """
    # NOTE(review): for "v0.3" models, _count_parameters treats "outputs" as
    # keyed by topic, so this top-level lookup by token may never match there
    # — confirm the intended behaviour for that format.
    prompt_data = model_data["prompts"].get(token, "NONE FOUND")
    output_data = model_data["outputs"].get(token, "NONE FOUND")
    return prompt_data, output_data


def _print_help(stream):
    """Write the command list to *stream*, one command per line."""
    for line in _HELP_LINES:
        print(line, file=stream)


def main(argv=None):
    """Run the CLI and return a process exit status.

    Args:
        argv: arguments after the program name; defaults to ``sys.argv[1:]``.

    Returns:
        0 on success; 1 for no/unknown command, missing arguments, or an
        unsupported model format (a message is written to stderr).
    """
    args = sys.argv[1:] if argv is None else list(argv)
    if not args:
        _print_help(sys.stderr)
        return 1

    command, params = args[0], args[1:]
    if command == "help":
        _print_help(sys.stdout)
        return 0

    required = _REQUIRED_ARGS.get(command)
    if required is None:
        print(f"Unknown command: {command}", file=sys.stderr)
        _print_help(sys.stderr)
        return 1
    if len(params) < len(required):
        print(f"Usage: {command} {' '.join(required)}", file=sys.stderr)
        return 1

    filename = params[0]
    if command == "model_size":
        print(getSize(filename))
    elif command == "count_parameters":
        model_data = _load_model(filename)
        try:
            total = _count_parameters(model_data)
        except ValueError as err:
            print(err, file=sys.stderr)
            return 1
        print(comma_number(total))
    else:  # "view_token"
        input_data, output_data = _view_token(_load_model(filename), params[1])
        print(f"PROMPT DATA: {input_data}")
        print()
        print()
        print(f"OUTPUT DATA: {output_data}")
    return 0


if __name__ == "__main__":
    sys.exit(main())