# coding: utf-8
import warnings
warnings.filterwarnings("ignore")
import os
import re
import json
import torch
import pickle
from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM
client = None

def get_prompt(message):
    """Build a Llama-style prompt from an OpenAI-format message list.

    System messages are concatenated inside <<SYS>>...<</SYS>>, and the
    remaining turns are normalized into alternating user/assistant pairs
    wrapped in [INST]...[/INST] blocks.
    """
    persona = ""
    for msg in message:
        if msg["role"] == "system":
            persona = persona + msg["content"]
    prompt = "<<SYS>>" + persona + "<</SYS>>"

    from ChatHaruhi.utils import normalize2uaua
    message_ua = normalize2uaua(message[1:], if_replace_system=True)
    # Pair up user/assistant turns; the final user turn is left open for generation.
    for i in range(0, len(message_ua) - 1, 2):
        prompt = prompt + "[INST]" + message_ua[i]["content"] + "[/INST]" + message_ua[i + 1]["content"] + "<|im_end|>"
    prompt = prompt + "[INST]" + message_ua[-1]["content"] + "[/INST]"
    print(prompt)
    return prompt
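
# Illustrative only: assuming a message list such as
#   [{"role": "system", "content": "You are Haruhi."},
#    {"role": "user", "content": "Hello!"}]
# get_prompt would produce roughly:
#   <<SYS>>You are Haruhi.<</SYS>>[INST]Hello![/INST]
# (the exact result also depends on how normalize2uaua merges consecutive turns).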

class qwen_model:
    def __init__(self, model_name):
        self.DEVICE = torch.device("cuda")

        # Left padding is required for batched causal-LM generation.
        self.tokenizer = AutoTokenizer.from_pretrained(
            "silk-road/" + model_name,
            low_cpu_mem_usage=True,
            use_fast=False,
            padding_side="left",
            trust_remote_code=True,
        )
        if self.tokenizer.pad_token is None:
            self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
        # 151645 is the Qwen token id for <|im_end|>, used here as the end-of-turn marker.
        self.tokenizer.eos_token_id = 151645

        self.model = AutoModelForCausalLM.from_pretrained(
            "silk-road/" + model_name,
            load_in_8bit=False,
            torch_dtype=torch.bfloat16,
            low_cpu_mem_usage=True,
            device_map='auto',
            trust_remote_code=True,
        ).eval()
    def get_response(self, message):
        with torch.inference_mode():
            prompt = get_prompt(message)
            batch = self.tokenizer(prompt,
                                   return_tensors="pt",
                                   padding=True,
                                   add_special_tokens=False)
            batch = {k: v.to(self.DEVICE) for k, v in batch.items()}
            # Note: temperature/top_p/top_k are ignored when do_sample=False (greedy decoding).
            generated = self.model.generate(input_ids=batch["input_ids"],
                                            max_new_tokens=1024,
                                            temperature=0.2,
                                            top_p=0.9,
                                            top_k=40,
                                            do_sample=False,
                                            num_beams=1,
                                            repetition_penalty=1.3,
                                            eos_token_id=self.tokenizer.eos_token_id,
                                            pad_token_id=self.tokenizer.pad_token_id)
            # Decode only the newly generated tokens and strip the end-of-turn marker.
            response = self.tokenizer.decode(generated[0][batch["input_ids"].shape[1]:]).strip().replace("<|im_end|>", "")
        return response

def init_client(model_name):
    # Store the model wrapper in a module-level global so it is loaded only once.
    global client
    client = qwen_model(model_name=model_name)

def get_response(message, model_name="Haruhi-Zero-1_8B-0_4"):
    # Lazily initialize the global client on first use.
    if client is None:
        init_client(model_name)
    response = client.get_response(message)
    return response
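
# Minimal usage sketch (illustrative, not part of the module's required API):
# the message contents below are assumptions; running this requires a
# CUDA-capable GPU and access to the silk-road/Haruhi-Zero-1_8B-0_4
# checkpoint on the Hugging Face Hub.
if __name__ == "__main__":
    messages = [
        {"role": "system", "content": "You are Haruhi Suzumiya."},
        {"role": "user", "content": "Introduce yourself."},
    ]
    print(get_response(messages))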