import run_stok | |
import sys | |
from run_stok import load_model, run_model | |
import time | |
# --- CLI configuration -------------------------------------------------
# Tokens collected for the current response (reset per prompt in the REPL).
total = []
# Model file to load; overridable on the command line with "-m <file>".
model = "stok-0.3.json"
# When True, print timing/throughput stats after each response.
show_speed = False

# NOTE(review): "help" prints usage but does not exit, so the script still
# proceeds to load the model afterwards — preserved as-is (the original
# comment says the layout anticipates future parameters).
if len(sys.argv) > 1:  # it is set up like this to add more parameters in the future
    if sys.argv[1] == "help":
        print("help - shows this command")
        print("-m <model> - specifies the file you want to inference")
        print("-speed - if added, enables speed logging")

# Consume recognised flags from a working copy of argv; stop at the first
# argument we do not understand.
args = list(sys.argv)
while len(args) > 1:
    if args[1] == "-m":
        # Guard: a trailing "-m" with no value would otherwise raise
        # IndexError on args[2].
        if len(args) < 3:
            break
        model = args[2]
        del args[1:3]  # drop the flag and its value in one step
    elif args[1] == "-speed":
        show_speed = True
        del args[1:2]
    else:
        break
load_model(model)

# --- Interactive REPL --------------------------------------------------
# Read prompts until the user types /quit, /exit or /bye; stream each
# response token-by-token and optionally report generation speed.
running = True
while running:
    total = []  # tokens emitted for the current prompt
    message = input(">>>")
    if message in ("/quit", "/exit", "/bye"):
        running = False
    else:
        # Start the clock BEFORE calling run_model so any eager setup work
        # inside it is included in the measurement; perf_counter() is the
        # high-resolution monotonic clock intended for interval timing.
        start = time.perf_counter()
        chunks = run_model(message, max_tokens=100, repetition_penalty=2)
        for chunk in chunks:
            total.append(chunk)
            print(chunk, end="")
        end = time.perf_counter()
        print()
        if show_speed:
            elapsed = end - start
            print(f"Took: {elapsed}s")
            print(f"Generated: {len(total)}")
            # Guard: a zero-duration measurement would divide by zero.
            if elapsed > 0:
                print(f"Speed: {len(total)/elapsed} t/s")
        print("_____________________________")