import json
import random

# punctuation stripped from prompts so that small formatting differences do not
# stop a prompt from matching the model's stored tokens
PUNCTUATION = ".[]:,\"'();-_{}?!"

def strip_prompt(prompt): # normalise the prompt so it is more likely to be understood
    newprompt = str(prompt).lower()
    newprompt = newprompt.translate(str.maketrans("", "", PUNCTUATION))
    return " ".join(newprompt.split())
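# Example of the normalisation above (follows directly from the rules in the function):
#     strip_prompt("Hello, World! (test)") == "hello world test"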

def strip_text(prompt): # lighter normalisation: lowercase and collapse whitespace, punctuation is kept
    newprompt = str(prompt).lower()
    return " ".join(newprompt.split())

model = {"model_data": {}} # module-level store for the currently loaded model

def load_model(filename: str):
    # use a context manager so the file handle is closed after reading
    with open(filename, "r") as f:
        model["model_data"] = json.load(f)

def version_03_inference(prompt: str, max_tokens: int=None, repetition_penalty: int=2):
    # v0.3 models keep per-topic output tables plus a topic-free "raw_outputs"
    # table; generation greedily follows the highest-count next token and
    # subtracts repetition_penalty from each count it uses so loops die out
    tokens_generated = 0
    split_prompt = strip_prompt(prompt).split()
    model_data = model["model_data"]
    outputs = model_data["outputs"]
    raw_outputs = model_data["raw_outputs"]
    prompts = model_data["prompts"]
    ends = model_data["ends"]
    start = ""
    topic = None
    # the first prompt token found in the prompts table sets the topic and the start token
    for token in split_prompt:
        if token in prompts:
            start = max(prompts[token], key=prompts[token].get)
            topic = token
            break
    if topic is None: # no known topic: fall back to the topic-free raw outputs
        outputs = raw_outputs
        start = split_prompt[-1]
        tokens_generated += 1
        running = True
        current_token = start
        while running:
            token = current_token
            yield f"{token} "
            if token in outputs:
                next_token = max(outputs[token], key=outputs[token].get)
                outputs[token][next_token] -= repetition_penalty
            else:
                # unknown token: pick any token the model knows about
                next_token = random.choice(list(outputs.keys()))
            current_token = next_token
            tokens_generated += 1
            if max_tokens is not None and tokens_generated >= max_tokens:
                running = False
    else:
        tokens_generated += 1
        running = True
        current_token = start
        while running:
            token = current_token
            yield f"{token} "
            if outputs.get(topic) is not None:
                if token in outputs[topic]:
                    next_token = max(outputs[topic][token], key=outputs[topic][token].get)
                    outputs[topic][token][next_token] -= repetition_penalty
                else:
                    # unknown token: pick any token from the current topic's table
                    next_token = random.choice(list(outputs[topic].keys()))
                current_token = next_token
                tokens_generated += 1
                if max_tokens is not None and tokens_generated >= max_tokens:
                    running = False
                if token in ends[topic]:
                    running = False
            else:
                running = False # topics with no output table (single-token responses) would break generation, so stop

def version_02_inference(prompt: str, max_tokens: int=None, repetition_penalty: int=1):
    # v0.2 models use a single flat output table plus per-topic end tokens
    tokens_generated = 0
    split_prompt = strip_prompt(prompt).split()
    model_data = model["model_data"]
    outputs = model_data["outputs"]
    prompts = model_data["prompts"]
    ends = model_data["ends"]
    start = ""
    topic = None
    # the first prompt token the model recognises sets the topic and the start token
    for token in split_prompt:
        if token in prompts:
            start = max(prompts[token], key=prompts[token].get)
            topic = token
            break
    if topic is None:
        # nothing matched (or the prompt was empty): pick a random topic and start token
        topic = random.choice(list(ends))
        start = random.choice(list(prompts.keys()))
    tokens_generated += 1
    running = True
    current_token = start
    while running:
        token = current_token
        yield f"{token} "
        if token in outputs:
            next_token = max(outputs[token], key=outputs[token].get)
            outputs[token][next_token] -= repetition_penalty
        else:
            # unknown token: pick any token the model knows about
            next_token = random.choice(list(outputs.keys()))
        current_token = next_token
        tokens_generated += 1
        if max_tokens is not None and tokens_generated >= max_tokens:
            running = False
        if token in ends[topic]:
            running = False

def version_01_inference(prompt: str, max_tokens: int=None, repetition_penalty: int=1):
    # v0.1 models have no end tokens, so generation only stops at max_tokens
    tokens_generated = 0
    split_prompt = strip_prompt(prompt).split()
    model_data = model["model_data"]
    outputs = model_data["outputs"]
    prompts = model_data["prompts"]
    start = ""
    # the last prompt token the model recognises chooses the start token
    for token in split_prompt:
        if token in prompts:
            start = max(prompts[token], key=prompts[token].get)
    tokens_generated += 1
    running = True
    current_token = start
    while running:
        token = current_token
        yield f"{token} "
        if token in outputs:
            next_token = max(outputs[token], key=outputs[token].get)
            outputs[token][next_token] -= repetition_penalty
        else:
            # unknown token: pick any token the model knows about
            next_token = random.choice(list(outputs.keys()))
        current_token = next_token
        tokens_generated += 1
        if max_tokens is not None and tokens_generated >= max_tokens:
            running = False

def run_model(prompt: str, max_tokens: int=None, repetition_penalty: int=1, temperature: float=0):
    # dispatch to the inference routine matching the loaded model's format
    # (temperature does not work on versions below 0.3 and is not currently passed through here)
    model_data = model["model_data"]
    model_format = model_data["format"]
    if model_format == "v0.1":
        yield from version_01_inference(prompt, max_tokens=max_tokens, repetition_penalty=repetition_penalty)
    elif model_format == "v0.2":
        yield from version_02_inference(prompt, max_tokens=max_tokens, repetition_penalty=repetition_penalty)
    elif model_format == "v0.3":
        yield from version_03_inference(prompt, max_tokens=max_tokens, repetition_penalty=repetition_penalty)
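
# Example usage (a minimal sketch; "model.json" and the prompt are placeholders,
# an actual model file in one of the v0.1-v0.3 formats is required):
#
#     load_model("model.json")
#     for chunk in run_model("tell me about cats", max_tokens=64, repetition_penalty=1):
#         print(chunk, end="", flush=True)
#     print()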