# stok-sub-1 / run_stok.py
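# run_stok.py loads a "stok" model (a set of JSON lookup tables) and streams a
# response one token at a time: pick a starting token from the "prompts" table,
# then repeatedly follow the highest-scoring next token in "outputs", Markov-chain
# style, with a simple repetition penalty applied along the way.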
import json
import random
def strip_prompt(prompt):  # normalise the prompt so it is more likely to be understood
    newprompt = str(prompt).lower()
    # drop punctuation and bracket characters, then collapse whitespace runs
    for char in ".[]:,\"'();-_{}?!":
        newprompt = newprompt.replace(char, "")
    newprompt = " ".join(newprompt.split())
    return newprompt
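# e.g. strip_prompt("Hello,   World!") == "hello world"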
def strip_text(prompt):  # kinda wacky overall; only lowercases and collapses whitespace
    newprompt = str(prompt).lower()
    newprompt = " ".join(newprompt.split())
    return newprompt
# the loaded model lives in a module-level dict so every inference function can reach it
model = {"model_data": {}}

def load_model(filename: str):
    # read the JSON lookup tables that make up a stok model
    with open(filename, "r") as f:
        model["model_data"] = json.load(f)
def version_03_inference(prompt: str, max_tokens: int = None, repetition_penalty: int = 2):
    # v0.3 models route generation through per-topic transition tables ("outputs")
    # and fall back to a flat table ("raw_outputs") when no topic word is found
    tokens_generated = 0
    split_prompt = strip_prompt(prompt).split()
    model_data = model["model_data"]
    outputs = model_data["outputs"]
    raw_outputs = model_data["raw_outputs"]
    prompts = model_data["prompts"]
    ends = model_data["ends"]
    start = ""
    topic = None
    for token in split_prompt:
        if token in prompts:
            start = max(prompts[token], key=prompts[token].get)
            topic = token
            break
    if topic is None:
        # no topic word matched, so generate from the raw (topic-free) outputs
        outputs = raw_outputs
        start = split_prompt[-1]
        tokens_generated += 1
        running = True
        current_token = [start]
        while running:
            token = current_token[0]
            yield f"{token} "
            if token in outputs:
                next_token = max(outputs[token], key=outputs[token].get)
                outputs[token][next_token] -= repetition_penalty
            else:
                next_token = random.choice(list(outputs.keys()))
            current_token[0] = next_token
            tokens_generated += 1
            if max_tokens is not None and tokens_generated >= max_tokens:
                running = False
    else:
        tokens_generated += 1
        running = True
        current_token = [start]
        while running:
            token = current_token[0]
            yield f"{token} "
            if outputs.get(topic) is not None:
                if token in outputs[topic]:
                    next_token = max(outputs[topic][token], key=outputs[topic][token].get)
                    outputs[topic][token][next_token] -= repetition_penalty
                else:
                    next_token = random.choice(list(outputs[topic].keys()))
                current_token[0] = next_token
                tokens_generated += 1
                if max_tokens is not None and tokens_generated >= max_tokens:
                    running = False
                if token in ends[topic]:
                    running = False
            else:
                running = False  # single token responses seem to break things
def version_02_inference(prompt: str, max_tokens: int = None, repetition_penalty: int = 1):
    tokens_generated = 0
    split_prompt = strip_prompt(prompt).split()
    model_data = model["model_data"]
    outputs = model_data["outputs"]
    prompts = model_data["prompts"]
    ends = model_data["ends"]
    start = ""
    for token in split_prompt:
        if token in prompts:
            start = max(prompts[token], key=prompts[token].get)
            topic = token
            break
    else:
        # no prompt token matched, so pick a topic and a starting token at random
        topic = random.choice(list(ends))
        start = random.choice(list(prompts.keys()))
    tokens_generated += 1
    running = True
    current_token = [start]
    while running:
        token = current_token[0]
        yield f"{token} "
        if token in outputs:
            next_token = max(outputs[token], key=outputs[token].get)
            outputs[token][next_token] -= repetition_penalty
        else:
            next_token = random.choice(list(outputs.keys()))
        current_token[0] = next_token
        tokens_generated += 1
        if max_tokens is not None and tokens_generated >= max_tokens:
            running = False
        if token in ends[topic]:
            running = False
def version_01_inference(prompt: str, max_tokens: int = None, repetition_penalty: int = 1):
    # v0.1 models have no topics or end tokens, so generation only stops at max_tokens
    tokens_generated = 0
    split_prompt = strip_prompt(prompt).split()
    model_data = model["model_data"]
    outputs = model_data["outputs"]
    prompts = model_data["prompts"]
    start = ""
    for token in split_prompt:
        if token in prompts:
            start = max(prompts[token], key=prompts[token].get)
    tokens_generated += 1
    running = True
    current_token = [start]
    while running:
        token = current_token[0]
        yield f"{token} "
        if token in outputs:
            next_token = max(outputs[token], key=outputs[token].get)
            outputs[token][next_token] -= repetition_penalty
        else:
            next_token = random.choice(list(outputs.keys()))
        current_token[0] = next_token
        tokens_generated += 1
        if max_tokens is not None and tokens_generated >= max_tokens:
            running = False
def run_model(prompt: str, max_tokens: int = None, repetition_penalty: int = 1, temperature: float = 0):
    # (temperature does not work on versions below 0.3)
    model_format = model["model_data"]["format"]
    if model_format == "v0.1":
        yield from version_01_inference(prompt, max_tokens=max_tokens, repetition_penalty=repetition_penalty)
    elif model_format == "v0.2":
        yield from version_02_inference(prompt, max_tokens=max_tokens, repetition_penalty=repetition_penalty)
    elif model_format == "v0.3":
        yield from version_03_inference(prompt, max_tokens=max_tokens, repetition_penalty=repetition_penalty)
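# Minimal usage sketch: "stok_model.json" is a placeholder filename, not a file
# shipped with this script. run_model streams tokens lazily, one chunk at a time.
if __name__ == "__main__":
    load_model("stok_model.json")  # hypothetical model file
    for chunk in run_model("hello there", max_tokens=50):
        print(chunk, end="")
    print()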