import json
import random
from typing import Optional


def strip_prompt(prompt):
    # Lowercase the prompt, strip punctuation, and collapse all whitespace to
    # single spaces so prompt tokens match the model's vocabulary.
    newprompt = str(prompt).lower()
    for char in ".[]:,\"'();-_{}?!":
        newprompt = newprompt.replace(char, "")
    return " ".join(newprompt.split())


def strip_text(prompt):
    # Lowercase and collapse whitespace only, leaving punctuation intact.
    newprompt = str(prompt).lower()
    return " ".join(newprompt.split())


# Module-level store for the currently loaded model data.
model = {"model_data": {}}


def load_model(filename: str):
    # Read the model file and keep its parsed JSON in the module-level store.
    with open(filename, "r") as f:
        model["model_data"] = json.load(f)
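

# Rough shape of the model file, inferred from how the inference functions
# below index the loaded data (other fields, if any, are unknown):
#
#   v0.1  {"format": "v0.1",
#          "prompts": {prompt_token: {start_token: count}},
#          "outputs": {token: {next_token: count}}}
#   v0.2  as v0.1, plus "ends": {topic: collection of tokens that can end generation}
#   v0.3  as v0.2, but "outputs" is nested per topic
#         ({topic: {token: {next_token: count}}}) and a flat "raw_outputs"
#         table is used as a fallback when no topic matches.

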
def version_03_inference(prompt: str, max_tokens: Optional[int] = None, repetition_penalty: int = 2):
    # v0.3: topic-aware generation using per-topic transition tables, with a
    # flat "raw_outputs" table as a fallback when no topic is recognised.
    tokens_generated = 0
    split_prompt = strip_prompt(prompt).split()

    model_data = model["model_data"]
    outputs = model_data["outputs"]          # {topic: {token: {next_token: count}}}
    raw_outputs = model_data["raw_outputs"]  # flat fallback table: {token: {next_token: count}}
    prompts = model_data["prompts"]
    ends = model_data["ends"]

    start = ""
    topic = None

    # Use the first prompt token the model recognises as a topic, and start
    # from that token's most likely starting token.
    for token in split_prompt:
        if token in prompts:
            start = max(prompts[token], key=prompts[token].get)
            topic = token
            break

    if topic is None:
        # No known topic: fall back to the flat table and continue from the
        # last token of the prompt.
        outputs = raw_outputs
        start = split_prompt[-1]

        tokens_generated += 1
        running = True
        current_token = [start]
        while running:
            token = current_token[0]
            yield f"{token} "
            if token in outputs:
                next_token = max(outputs[token], key=outputs[token].get)
                # Penalise the chosen transition so the same continuation is
                # less likely to be picked again.
                outputs[token][next_token] -= repetition_penalty
            else:
                next_token = random.choice(list(outputs.keys()))
            current_token[0] = next_token
            tokens_generated += 1
            if max_tokens is not None and tokens_generated >= max_tokens:
                running = False
    else:
        tokens_generated += 1
        running = True
        current_token = [start]
        while running:
            token = current_token[0]
            yield f"{token} "
            if outputs.get(topic) is not None:
                if token in outputs[topic]:
                    next_token = max(outputs[topic][token], key=outputs[topic][token].get)
                    outputs[topic][token][next_token] -= repetition_penalty
                else:
                    # Unknown token for this topic: fall back to a random token
                    # from the topic's own table.
                    next_token = random.choice(list(outputs[topic].keys()))
                current_token[0] = next_token
                tokens_generated += 1
                if max_tokens is not None and tokens_generated >= max_tokens:
                    running = False
                if token in ends[topic]:
                    running = False
            else:
                # The topic has no transition table at all, so stop.
                running = False


def version_02_inference(prompt: str, max_tokens: Optional[int] = None, repetition_penalty: int = 1):
    # v0.2: single flat transition table, plus per-topic end tokens that can
    # stop generation early.
    tokens_generated = 0
    split_prompt = strip_prompt(prompt).split()

    model_data = model["model_data"]
    outputs = model_data["outputs"]  # flat transition table: {token: {next_token: count}}
    prompts = model_data["prompts"]
    ends = model_data["ends"]

    start = ""
    for token in split_prompt:
        if token in prompts:
            start = max(prompts[token], key=prompts[token].get)
            topic = token
            break
    else:
        # for/else: no prompt token matched, so pick a random topic and a
        # random starting token.
        topic = random.choice(list(ends))
        start = random.choice(list(prompts.keys()))

    tokens_generated += 1
    running = True
    current_token = [start]
    while running:
        token = current_token[0]
        yield f"{token} "
        if token in outputs:
            next_token = max(outputs[token], key=outputs[token].get)
            # Penalise the chosen transition to discourage repetition.
            outputs[token][next_token] -= repetition_penalty
        else:
            next_token = random.choice(list(outputs.keys()))
        current_token[0] = next_token
        tokens_generated += 1
        if max_tokens is not None and tokens_generated >= max_tokens:
            running = False
        if topic and token in ends[topic]:
            running = False


def version_01_inference(prompt: str, max_tokens: Optional[int] = None, repetition_penalty: int = 1):
    # v0.1: single flat transition table, no topics or end tokens.
    tokens_generated = 0
    split_prompt = strip_prompt(prompt).split()

    model_data = model["model_data"]
    outputs = model_data["outputs"]
    prompts = model_data["prompts"]

    # Use the last recognised prompt token's most likely starting token; if
    # nothing matches, start stays "" and the loop below falls back to a
    # random token.
    start = ""
    for token in split_prompt:
        if token in prompts:
            start = max(prompts[token], key=prompts[token].get)

    tokens_generated += 1
    running = True
    current_token = [start]
    while running:
        token = current_token[0]
        yield f"{token} "
        if token in outputs:
            next_token = max(outputs[token], key=outputs[token].get)
            # Penalise the chosen transition to discourage repetition.
            outputs[token][next_token] -= repetition_penalty
        else:
            next_token = random.choice(list(outputs.keys()))
        current_token[0] = next_token
        tokens_generated += 1
        if max_tokens is not None and tokens_generated >= max_tokens:
            running = False


def run_model(prompt: str, max_tokens: Optional[int] = None, repetition_penalty: int = 1, temperature: float = 0):
    # Dispatch to the inference routine matching the loaded model's format.
    # Note: temperature is accepted but not used by any of the current formats.
    model_data = model["model_data"]
    model_format = model_data["format"]

    if model_format == "v0.1":
        yield from version_01_inference(prompt, max_tokens=max_tokens, repetition_penalty=repetition_penalty)
    elif model_format == "v0.2":
        yield from version_02_inference(prompt, max_tokens=max_tokens, repetition_penalty=repetition_penalty)
    elif model_format == "v0.3":
        yield from version_03_inference(prompt, max_tokens=max_tokens, repetition_penalty=repetition_penalty)
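

# Minimal usage sketch. "model.json" is a placeholder path: the real model file
# and its contents depend on how the model was built, and none is bundled here.
if __name__ == "__main__":
    load_model("model.json")
    # run_model is a generator, so tokens can be printed as they are produced.
    for chunk in run_model("hello there", max_tokens=50):
        print(chunk, end="", flush=True)
    print()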