tyraepaul committed
Commit 4f30dbd · verified · 1 Parent(s): 81d12ef

Upload folder using huggingface_hub

Files changed (6)
  1. .gitattributes +2 -0
  2. run_stok.py +176 -0
  3. stok-0.3-large.json +3 -0
  4. stok-0.3.json +3 -0
  5. stok-tools.py +93 -0
  6. stokfile.py +49 -0
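
The commit message notes the folder was uploaded with huggingface_hub. For context, a minimal sketch of that kind of upload (the repo id and folder path below are placeholders, not taken from this commit):

    from huggingface_hub import HfApi

    api = HfApi()
    api.upload_folder(
        folder_path="./stok",        # placeholder local folder holding the files listed above
        repo_id="<user>/<repo>",     # placeholder; the actual repo id is not shown in this diff
        commit_message="Upload folder using huggingface_hub",
    )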
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ stok-0.3-large.json filter=lfs diff=lfs merge=lfs -text
+ stok-0.3.json filter=lfs diff=lfs merge=lfs -text
run_stok.py ADDED
@@ -0,0 +1,176 @@
+ import json
+ import random
+
+ def strip_prompt(prompt): # used to make it more likely for the prompt to be understood
+     newprompt = str(prompt).lower()
+     newprompt = newprompt.replace(".", "")
+     newprompt = newprompt.replace("[", "")
+     newprompt = newprompt.replace("]", "")
+     newprompt = newprompt.replace(":", "")
+     newprompt = newprompt.replace(",", "")
+     newprompt = newprompt.replace("\"", "")
+     newprompt = newprompt.replace("'", "")
+     newprompt = newprompt.replace("/", "")
+     newprompt = newprompt.replace("(", "")
+     newprompt = newprompt.replace(")", "")
+     newprompt = newprompt.replace(";", "")
+     newprompt = newprompt.replace("-", "")
+     newprompt = newprompt.replace("_", "")
+     newprompt = newprompt.replace("{", "")
+     newprompt = newprompt.replace("}", "")
+     newprompt = newprompt.replace("?", "")
+     newprompt = " ".join(newprompt.split(sep=None))
+     return newprompt
+
+ def strip_text(prompt): # kinda wacky overall
+     newprompt = str(prompt).lower()
+     newprompt = " ".join(newprompt.split(sep=None))
+     return newprompt
+
+ model = {"model_data": {}}
+ def load_model(filename: str):
+     model["model_data"] = json.loads(open(filename, "r").read())
+
+ def version_03_inference(prompt: str, max_tokens: int=None, repetition_penalty: int=2):
+     tokens_generated = 0
+     split_prompt = strip_prompt(prompt).split(sep=None)
+     model_data = model["model_data"]
+     outputs = model_data["outputs"]
+     raw_outputs = model_data["raw_outputs"]
+     prompts = model_data["prompts"]
+     ends = model_data["ends"]
+     start = ""
+     topic = None
+     for token in split_prompt:
+         if token in prompts:
+             start = max(prompts[token], key=prompts[token].get)
+             topic = token
+             break
+     if topic == None: # use raw outputs
+         outputs = raw_outputs
+         topic = None
+         start = split_prompt[-1]
+         tokens_generated += 1
+         running = True
+         current_token = [start]
+         while running:
+             token = current_token[0]
+             yield f"{token} "
+             if token in outputs:
+                 next_token = max(outputs[token], key=outputs[token].get)
+                 outputs[token][next_token] -= repetition_penalty
+             else:
+                 next_token = random.choice(list(outputs.keys()))
+             current_token[0] = next_token
+             tokens_generated += 1
+             if max_tokens != None:
+                 if tokens_generated >= max_tokens:
+                     running = False
+             if topic:
+                 if token in ends[topic]:
+                     running = False
+     else:
+         tokens_generated += 1
+         running = True
+         current_token = [start]
+         while running:
+             token = current_token[0]
+             yield f"{token} "
+             if outputs.get(topic) != None:
+                 if token in outputs[topic]:
+                     next_token = max(outputs[topic][token], key=outputs[topic][token].get)
+                     outputs[topic][token][next_token] -= repetition_penalty
+                 else:
+                     next_token = random.choice(list(outputs.keys()))
+                 current_token[0] = next_token
+                 tokens_generated += 1
+                 if max_tokens != None:
+                     if tokens_generated >= max_tokens:
+                         running = False
+                 if topic:
+                     if token in ends[topic]:
+                         running = False
+             else:
+                 running = False # this is because single token responses seem to break things
+
+ def version_02_inference(prompt: str, max_tokens: int=None, repetition_penalty: int=1):
+     tokens_generated = 0
+     split_prompt = strip_prompt(prompt).split(sep=None)
+     model_data = model["model_data"]
+     outputs = model_data["outputs"]
+     prompts = model_data["prompts"]
+     ends = model_data["ends"]
+     start = ""
+     for token in split_prompt:
+         if token in prompts:
+             start = max(prompts[token], key=prompts[token].get)
+             topic = token
+             break
+     else:
+         topic = random.choice(list(ends))
+         start = random.choice(list(prompts.keys()))
+     tokens_generated += 1
+     running = True
+     current_token = [start]
+     while running:
+         token = current_token[0]
+         yield f"{token} "
+         if token in outputs:
+             next_token = max(outputs[token], key=outputs[token].get)
+             outputs[token][next_token] -= repetition_penalty
+         else:
+             next_token = random.choice(list(outputs.keys()))
+         current_token[0] = next_token
+         tokens_generated += 1
+         if max_tokens != None:
+             if tokens_generated >= max_tokens:
+                 running = False
+         if topic:
+             if token in ends[topic]:
+                 running = False
+
+ def version_01_inference(prompt: str, max_tokens: int=None, repetition_penalty: int=1):
+     tokens_generated = 0
+     split_prompt = strip_prompt(prompt).split(sep=None)
+     model_data = model["model_data"]
+     outputs = model_data["outputs"]
+     prompts = model_data["prompts"]
+     start = ""
+     for token in split_prompt:
+         if token in prompts:
+             start = max(prompts[token], key=prompts[token].get)
+     tokens_generated += 1
+     running = True
+     current_token = [start]
+     while running:
+         token = current_token[0]
+         yield f"{token} "
+         if token in outputs:
+             next_token = max(outputs[token], key=outputs[token].get)
+             outputs[token][next_token] -= repetition_penalty
+         else:
+             next_token = random.choice(list(outputs.keys()))
+         current_token[0] = next_token
+         tokens_generated += 1
+         if max_tokens != None:
+             if tokens_generated >= max_tokens:
+                 running = False
+
+ def run_model(prompt: str, max_tokens: int=None, repetition_penalty: int=1, temperature: float=0):
+     # (temperature does not work on versions below 0.3)
+     model_data = model["model_data"]
+     model_format = model_data["format"]
+     if model_data["format"] == "v0.1":
+         response = version_01_inference(prompt, max_tokens=max_tokens, repetition_penalty=repetition_penalty)
+         for chunk in response:
+             yield chunk
+
+     if model_data["format"] == "v0.2":
+         response = version_02_inference(prompt, max_tokens=max_tokens, repetition_penalty=repetition_penalty)
+         for chunk in response:
+             yield chunk
+
+     if model_data["format"] == "v0.3":
+         response = version_03_inference(prompt, max_tokens=max_tokens, repetition_penalty=repetition_penalty)
+         for chunk in response:
+             yield chunk
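
For reference, a minimal usage sketch of the module above (assuming stok-0.3.json is available locally; stokfile.py below follows the same pattern):

    from run_stok import load_model, run_model

    load_model("stok-0.3.json")  # populates the module-level model dict from the JSON file
    for chunk in run_model("hello there", max_tokens=50, repetition_penalty=2):
        print(chunk, end="")     # run_model is a generator that yields one token (plus a space) at a time
    print()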
stok-0.3-large.json ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a0586fcdc0d6ef99a76d96d1f45bb02f520b4a9e0a325a882bc87cd8fa95f8b6
+ size 478367292
stok-0.3.json ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0b1df825b31947f352a7cae62937842ff1c791a35a534a32bd5d21d6dd93c9cc
+ size 15166112
stok-tools.py ADDED
@@ -0,0 +1,93 @@
+ import sys
+ from math import floor
+ import json
+ import os
+
+ def comma_number(number):
+     number = int(number)
+     ordered_num = list(str(number))
+     ordered_num.reverse()
+     if len(ordered_num) > 3:
+         splits = len(ordered_num)/3
+         splits = floor(splits)
+         start = 0
+         for x in range(0, splits):
+             if start == 0:
+                 start += 3
+             else:
+                 start += 4
+             ordered_num.insert(start, ",")
+     ordered_num.reverse()
+     if ordered_num[0] == ",":
+         ordered_num.pop(0)
+     return "".join(ordered_num)
+
+ def getSize(filename):
+     st = os.stat(filename)
+     size_in_mb = st.st_size / (1024 * 1024)
+     return size_in_mb
+
+ if __name__ == "__main__":
+     if len(sys.argv) > 1:
+         if sys.argv[1] == "help":
+             print("help - shows this command")
+             print("count_parameters <file> - counts parameters of a given model")
+             print("model_size <file> - Shows size of model in MB")
+             print("view_token <file> <token> - Shows a token's data")
+         if sys.argv[1] == "count_parameters":
+             filename = sys.argv[2]
+             model_data = json.loads(open(filename, "r").read())
+             format_version = model_data["format"]
+
+             if format_version == "v0.1" or format_version == "v0.2": # old outputs format
+                 total = len(model_data["outputs"])
+                 total += len(model_data["prompts"])
+                 for output in model_data["outputs"]:
+                     total += len(model_data["outputs"][output])
+                 for prompt in model_data["prompts"]:
+                     total += len(model_data["prompts"][prompt])
+
+             if format_version == "v0.3": # contextualized outputs format
+                 total = len(model_data["outputs"])
+                 total += len(model_data["prompts"])
+                 for topic in model_data["outputs"]:
+                     for token in model_data["outputs"][topic]:
+                         total += len(model_data["outputs"][topic][token])
+                 for prompt in model_data["prompts"]:
+                     total += len(model_data["prompts"][prompt])
+                 total += len(model_data["raw_outputs"])
+                 for output in model_data["raw_outputs"]:
+                     total += len(model_data["raw_outputs"][output])
+
+             if format_version == "v0.2" or format_version == "v0.3": # ends is supported in 0.2 and 0.3
+                 total += len(model_data["ends"])
+                 for topic in model_data["ends"]:
+                     total += len(model_data["ends"][topic])
+
+             print(comma_number(total))
+
+
+         if sys.argv[1] == "model_size":
+             filename = sys.argv[2]
+             print(getSize(filename))
+
+         if sys.argv[1] == "view_token":
+             filename = sys.argv[2]
+             token = sys.argv[3]
+             model_data = json.loads(open(filename, "r").read())
+             prompts = model_data["prompts"]
+             outputs = model_data["outputs"]
+             try:
+                 input_data = prompts[token]
+             except KeyError:
+                 input_data = "NONE FOUND"
+             try:
+                 output_data = outputs[token]
+             except KeyError:
+                 output_data = "NONE FOUND"
+             print(f"PROMPT DATA: {input_data}")
+             print()
+             print()
+             print(f"OUTPUT DATA: {output_data}")
+
+
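
Example invocations of the tool above, matching its help text (file names are those uploaded in this commit; printed values depend on the model):

    python stok-tools.py count_parameters stok-0.3.json   # comma-formatted parameter count
    python stok-tools.py model_size stok-0.3.json         # model size in MB
    python stok-tools.py view_token stok-0.3.json hello   # prompt/output data stored for the token "hello"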
stokfile.py ADDED
@@ -0,0 +1,49 @@
+ import run_stok
+ import sys
+ from run_stok import load_model, run_model
+ import time
+ total = []
+ model = "stok-0.3.json"
+ show_speed = False
+ if len(sys.argv) > 1: # it is set up like this to add more parameters in the future
+     if sys.argv[1] == "help":
+         print("help - shows this command")
+         print("-m <model> - specifies the file you want to inference")
+         print("-speed - if added, enables speed logging")
+     args = list(sys.argv)
+     running = True
+     while running:
+         if len(args) < 2:
+             running = False
+         elif args[1] == "-m":
+             model = args[2]
+             args.pop(1)
+             args.pop(1)
+         elif args[1] == "-speed":
+             show_speed = True
+             args.pop(1)
+         else:
+             running = False
+
+ load_model(model)
+ running = True
+ while running:
+     total = []
+     message = input(">>>")
+     if message == "/quit" or message == "/exit" or message == "/bye":
+         running = False
+     else:
+         chunks = run_model(message, max_tokens=100, repetition_penalty=2)
+         start = time.time()
+         for chunk in chunks:
+             total.append(chunk)
+             print(chunk, end="")
+         end = time.time()
+         print()
+         if show_speed:
+             print(f"Took: {end-start}s")
+             print(f"Generated: {len(total)}")
+             print(f"Speed: {len(total)/(end-start)} t/s")
+         print("_____________________________")
+
+
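
A typical way to start the interactive CLI above (the flags are those defined in the script; generated text will vary):

    python stokfile.py -m stok-0.3-large.json -speed

Type a prompt at the >>> prompt; /quit, /exit, or /bye ends the session.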