# go-bruins-v2 / handler.py
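"""RunPod serverless handler that streams text completions from the
go-bruins-v2 6.0bpw EXL2 quant using ExLlamaV2."""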
import os

import runpod
from huggingface_hub import snapshot_download
from exllamav2 import (
    ExLlamaV2,
    ExLlamaV2Cache,
    ExLlamaV2Config,
    ExLlamaV2Tokenizer,
)
from exllamav2.generator import (
    ExLlamaV2Sampler,
    ExLlamaV2StreamingGenerator,
)
# Download the EXL2 quantized weights on first start if the model
# directory does not already exist.
model_directory = "/go-bruins-v2/"
if not os.path.isdir(model_directory):
    os.makedirs(model_directory)
    snapshot_download(
        repo_id="LoneStriker/go-bruins-v2-6.0bpw-h6-exl2-2",
        allow_patterns=["*.json", "*.model", "*.safetensors"],
        local_dir=model_directory,
    )

config = ExLlamaV2Config()
config.model_dir = model_directory
config.prepare()
model = ExLlamaV2(config)
model.load([24])  # gpu_split: allocate up to 24 GB of VRAM on the first GPU
cache = ExLlamaV2Cache(model)
tokenizer = ExLlamaV2Tokenizer(config)
generator = ExLlamaV2StreamingGenerator(model, cache, tokenizer)
# Warm up once at startup rather than on every request.
generator.warmup()
# Optional startup delay in seconds; read from the environment but
# currently unused.
sleep_time = int(os.environ.get("SLEEP_TIME", 3))
def handler(event):
    """RunPod serverless handler: generate a completion for event['input']."""
    inp = event["input"]
    prompt = inp["prompt"]
    sampling_params = inp["sampling_params"]
    max_new_tokens = sampling_params["max_new_tokens"]

    settings = ExLlamaV2Sampler.Settings()
    settings.temperature = sampling_params["temperature"]
    settings.top_k = sampling_params["top_k"]
    settings.top_p = sampling_params["top_p"]
    settings.token_repetition_penalty = 1.15
    # Note: max_new_tokens is not a sampler setting in ExLlamaV2; the
    # budget is enforced manually in the streaming loop below.
    def run(prompt):
        generated_tokens = 0
        # Prime the streaming generator; text then arrives chunk by
        # chunk from generator.stream().
        generator.begin_stream(tokenizer.encode(prompt), settings)
        response_text = ""
        while True:
            chunk, eos, _ = generator.stream()
            response_text += chunk
            generated_tokens += 1  # each stream() call yields roughly one token
            # Stop on the end-of-turn marker, a real EOS token, or the
            # requested token budget.
            if eos or "|im_end|" in response_text or generated_tokens >= max_new_tokens:
                return response_text.replace("|im_end|", "")

    return run(prompt)
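# A minimal local smoke test (an assumption, not part of the original
# worker): set RUN_LOCAL_TEST=1 and run this file directly to exercise
# the handler once before the RunPod worker loop starts. It shows the
# event payload shape the handler reads above; the ChatML-style prompt
# here is an illustrative guess that matches the |im_end| stop string.
if os.environ.get("RUN_LOCAL_TEST") == "1":
    example_event = {
        "input": {
            "prompt": "<|im_start|>user\nWrite a haiku about hockey.<|im_end|>\n"
                      "<|im_start|>assistant\n",
            "sampling_params": {
                "max_new_tokens": 128,
                "temperature": 0.8,
                "top_p": 0.95,
                "top_k": 50,
            },
        }
    }
    print(handler(example_event))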
runpod.serverless.start({"handler": handler})