import json
import random
def strip_prompt(prompt): # used to make it more likely for the prompt to be understood
    newprompt = str(prompt).lower()
    # drop punctuation and bracket characters, then collapse whitespace
    for char in ".[]:,\"'();-_{}?!":
        newprompt = newprompt.replace(char, "")
    newprompt = " ".join(newprompt.split(sep=None))
    return newprompt
def strip_text(prompt): # kinda wacky overall
    newprompt = str(prompt).lower()
    newprompt = " ".join(newprompt.split(sep=None))
    return newprompt
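# Illustrative examples (not from the original file) of what the two helpers produce:
#   strip_prompt("Hello, World!  How are you?") -> "hello world how are you"
#   strip_text("Hello, World!  How are you?")   -> "hello, world! how are you?"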
model = {"model_data": {}}
def load_model(filename: str):
    # read the model file once and keep it in the module-level dict
    with open(filename, "r") as model_file:
        model["model_data"] = json.load(model_file)
def version_03_inference(prompt: str, max_tokens: int=None, repetition_penalty: int=2):
    tokens_generated = 0
    split_prompt = strip_prompt(prompt).split(sep=None)
    model_data = model["model_data"]
    outputs = model_data["outputs"]
    raw_outputs = model_data["raw_outputs"]
    prompts = model_data["prompts"]
    ends = model_data["ends"]
    start = ""
    topic = None
    # pick the first prompt token the model has a topic for
    for token in split_prompt:
        if token in prompts:
            start = max(prompts[token], key=prompts[token].get)
            topic = token
            break
    if topic is None: # use raw outputs
        outputs = raw_outputs
        topic = None
        start = split_prompt[-1]
        tokens_generated += 1
        running = True
        current_token = [start]
        while running:
            token = current_token[0]
            yield f"{token} "
            if token in outputs:
                next_token = max(outputs[token], key=outputs[token].get)
                outputs[token][next_token] -= repetition_penalty
            else:
                next_token = random.choice(list(outputs.keys()))
            current_token[0] = next_token
            tokens_generated += 1
            if max_tokens is not None:
                if tokens_generated >= max_tokens:
                    running = False
            if topic:
                if token in ends[topic]:
                    running = False
    else: # a topic was matched, so walk that topic's own transition table
        tokens_generated += 1
        running = True
        current_token = [start]
        while running:
            token = current_token[0]
            yield f"{token} "
            if outputs.get(topic) is not None:
                if token in outputs[topic]:
                    next_token = max(outputs[topic][token], key=outputs[topic][token].get)
                    outputs[topic][token][next_token] -= repetition_penalty
                else:
                    next_token = random.choice(list(outputs.keys()))
                current_token[0] = next_token
                tokens_generated += 1
                if max_tokens is not None:
                    if tokens_generated >= max_tokens:
                        running = False
                if topic:
                    if token in ends[topic]:
                        running = False
            else:
                running = False # this is because single token responses seem to break things
def version_02_inference(prompt: str, max_tokens: int=None, repetition_penalty: int=1):
    tokens_generated = 0
    split_prompt = strip_prompt(prompt).split(sep=None)
    model_data = model["model_data"]
    outputs = model_data["outputs"]
    prompts = model_data["prompts"]
    ends = model_data["ends"]
    start = ""
    for token in split_prompt:
        if token in prompts:
            start = max(prompts[token], key=prompts[token].get)
            topic = token
            break
    else: # for/else: no prompt token matched, fall back to a random topic and start token
        topic = random.choice(list(ends))
        start = random.choice(list(prompts.keys()))
    tokens_generated += 1
    running = True
    current_token = [start]
    while running:
        token = current_token[0]
        yield f"{token} "
        if token in outputs:
            next_token = max(outputs[token], key=outputs[token].get)
            outputs[token][next_token] -= repetition_penalty
        else:
            next_token = random.choice(list(outputs.keys()))
        current_token[0] = next_token
        tokens_generated += 1
        if max_tokens is not None:
            if tokens_generated >= max_tokens:
                running = False
        if topic:
            if token in ends[topic]:
                running = False
def version_01_inference(prompt: str, max_tokens: int=None, repetition_penalty: int=1):
    tokens_generated = 0
    split_prompt = strip_prompt(prompt).split(sep=None)
    model_data = model["model_data"]
    outputs = model_data["outputs"]
    prompts = model_data["prompts"]
    start = ""
    for token in split_prompt:
        if token in prompts:
            start = max(prompts[token], key=prompts[token].get)
    tokens_generated += 1
    running = True
    current_token = [start]
    while running:
        token = current_token[0]
        yield f"{token} "
        if token in outputs:
            next_token = max(outputs[token], key=outputs[token].get)
            outputs[token][next_token] -= repetition_penalty
        else:
            next_token = random.choice(list(outputs.keys()))
        current_token[0] = next_token
        tokens_generated += 1
        if max_tokens is not None:
            if tokens_generated >= max_tokens:
                running = False
def run_model(prompt: str, max_tokens: int=None, repetition_penalty: int=1, temperature: float=0):
    # (temperature does not work on versions below 0.3)
    model_data = model["model_data"]
    model_format = model_data["format"]
    if model_format == "v0.1":
        response = version_01_inference(prompt, max_tokens=max_tokens, repetition_penalty=repetition_penalty)
        for chunk in response:
            yield chunk
    if model_format == "v0.2":
        response = version_02_inference(prompt, max_tokens=max_tokens, repetition_penalty=repetition_penalty)
        for chunk in response:
            yield chunk
    if model_format == "v0.3":
        response = version_03_inference(prompt, max_tokens=max_tokens, repetition_penalty=repetition_penalty)
        for chunk in response:
            yield chunk
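# Minimal usage sketch (not part of the original file): "model.json" is a
# placeholder path for a model file in the layout load_model expects, and the
# prompt is arbitrary. max_tokens is set because the chain can otherwise loop.
if __name__ == "__main__":
    load_model("model.json")
    for chunk in run_model("hello there", max_tokens=50):
        print(chunk, end="", flush=True)
    print()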