# IvyGPT — Collection · 5 items · Updated
import os
import torch
import platform
import subprocess
from transformers import AutoModelForCausalLM, AutoTokenizer
#from transformers.generation.utils import GenerationConfig
from transformers.generation import GenerationConfig
import json
from tqdm import tqdm
import re
def init_model(model_path="./exportchatml"):
    """Load the chat model, its tokenizer, and the generation settings.

    Args:
        model_path: Location handed to ``from_pretrained`` for both the
            tokenizer and the model. Defaults to the local
            ``./exportchatml`` export this script was written for.

    Returns:
        A ``(model, tokenizer, config)`` tuple: the model switched to eval
        mode, its tokenizer, and the ``GenerationConfig`` to pass to each
        chat call.
    """
    print("init model ...")

    # NOTE(review): trust_remote_code=True runs Python code shipped with the
    # checkpoint, so model_path should only point at a trusted export.
    tokenizer = AutoTokenizer.from_pretrained(
        model_path,
        trust_remote_code=True,
        resume_download=True,
    )
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        device_map="auto",
        trust_remote_code=True,
        resume_download=True,
    ).eval()

    # The generation settings are spelled out here rather than loaded from
    # the checkpoint directory, so the evaluation does not depend on whatever
    # generation_config.json the export happens to contain.
    # NOTE(review): 151643 is presumably the tokenizer's end-of-text id, and
    # chat_format / max_window_size look like options consumed by the
    # checkpoint's remote chat code -- confirm against the exported model.
    config = GenerationConfig(
        chat_format='chatml',
        eos_token_id=151643,
        pad_token_id=151643,
        max_window_size=6144,
        max_new_tokens=512,
        do_sample=True,
        top_k=0,
        top_p=0.5,
    )
    return model, tokenizer, config
# Option letters accepted as an answer; compiled once instead of per question.
_ANSWER_LETTER_RE = re.compile(r"[ABCDE]")


def _format_options(options):
    """Render an option mapping as 'A. text' lines, skipping blank entries."""
    return "\n".join(
        f"{letter}. {text}"
        for letter, text in sorted(options.items())
        if text and str(text).strip()
    )


def _extract_answer(response):
    """Return the distinct option letters in *response*, in first-seen order."""
    # dict.fromkeys de-duplicates while keeping order: a reply such as
    # "A ... A" must yield "A", not "AA".
    return "".join(dict.fromkeys(_ANSWER_LETTER_RE.findall(response)))


def main(input_path='CMB-test-choice-question-merge.json',
         output_path='output.json'):
    """Run the model over every question in a JSON file and save its answers.

    Each input item must provide ``id``, ``question_type``, ``question`` and
    an ``option`` mapping of letter -> option text. Every question is asked
    in a fresh conversation (empty history), and the option letters found in
    the reply are stored as the model's answer.

    Args:
        input_path: JSON file holding the list of question items.
        output_path: JSON file that receives a list of
            ``{"id": ..., "model_answer": ...}`` records.
    """
    model, tokenizer, config = init_model()

    with open(input_path, 'r', encoding='utf-8') as file:
        data = json.load(file)

    results = []
    for item in tqdm(data):
        question_type = item['question_type']
        question = item['question']
        # Show every option the item actually has. The previous hard-coded
        # A-D list hid option E from the model (although "E" was accepted as
        # an answer) and raised KeyError on items with fewer than four options.
        option_str = _format_options(item['option'])
        prompt = f'以下是一道{question_type},不需要做任何分析解释,直接给出正确选项:\n{question}\n{option_str}'

        # A new empty history per question keeps items independent of each other.
        response, _ = model.chat(
            tokenizer, prompt, history=[], generation_config=config
        )
        print(response)

        final_result = _extract_answer(response)
        print(final_result)

        results.append({
            "id": item['id'],
            "model_answer": final_result,
        })

    with open(output_path, 'w', encoding="utf-8") as f1:
        json.dump(results, f1, ensure_ascii=False, indent=4)
# Run the evaluation only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()