# Page header captured together with this file (likely a Hugging Face Space; status shown: "Sleeping").
import logging | |
import warnings | |
import os | |
from tqdm import tqdm | |
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig | |
import transformers | |
import torch | |
import gc | |
from torch.utils.data import DataLoader, TensorDataset | |
from torch.nn.utils.rnn import pack_padded_sequence | |
from calc_metrics import calculate_log_sum,calculate_log_last | |
import torch.nn.functional as F | |
import logging | |
import time | |
import traceback | |
import datetime | |
import torch
import pdb
import json
import get_text

# Today's date as YYYY-MM-DD (local time), for tagging daily log files.
doday = datetime.date.today().isoformat()

# Configure logging: tag for log-file names. Not referenced by active code in this file;
# file logging is configured per model via logging.basicConfig in run_model_on_dic.
extra_info = 'fill'

# Default checkpoint list. The __main__ block supplies its own model_path instead.
paths = [
    '/mntcephfs/data/med/fanyaxin/Qwen-7B-Chat',
]

# Data folder handed to get_text.get_text_from.
# Other folders used previously: '2024-01-04_18', '2024-01-03', '2023-12-31',
# '2020_100', '2020', '2019', '2017', '2014', 'rephrase_MMLU', 'mock_MMLU'.
# For the rephrased / mock MMLU corpora, use instead:
#   get_text.get_mmlu_rephrase_text(directory='.../mmlu_rephrase_concat/gpt-4-1106-preview/')
#   get_text.get_mmlu_rephrase_text(directory='.../mmlu_mock_concat/gpt-4-1106-preview/')
file_in_data_folder = '2023-12-27'

# Corpus list: passed as `limit=` to get_text.get_text_from (presumably the number of
# texts kept per file -- confirm in get_text).
limit_lines_per_file = 10
# Loaded at import time. Maps file name -> list of texts, per how run_model_on_dic consumes it.
file_dic_list_strings = get_text.get_text_from(file_in_data_folder, limit=limit_lines_per_file)
def get_rwkv_model_tokenizer(model_name, *, strategy='cuda fp16', vocab=r"rwkv_vocab_v20230424"):
    """Load an RWKV checkpoint with the ``rwkv`` package and return it with its tokenizer.

    Args:
        model_name: checkpoint path passed to ``rwkv.model.RWKV``.
        strategy: RWKV device/precision strategy string. Defaults to the
            previously hard-coded ``'cuda fp16'``.
        vocab: vocabulary name passed to ``rwkv.utils.PIPELINE``. Defaults to
            the previously hard-coded ``"rwkv_vocab_v20230424"``.

    Returns:
        ``(model, tokenizer)``, where ``tokenizer`` is the PIPELINE's tokenizer.
    """
    # The flags are set before the rwkv imports below, presumably because the package
    # reads them when it is first imported (JIT and the custom CUDA kernel).
    os.environ['RWKV_JIT_ON'] = '1'
    os.environ["RWKV_CUDA_ON"] = '1'
    from rwkv.model import RWKV
    from rwkv.utils import PIPELINE
    model = RWKV(model=model_name, strategy=strategy)
    pipeline = PIPELINE(model, vocab)
    return model, pipeline.tokenizer
def get_mamba_model_tokenizer(model_name, *, tokenizer_path="/mntcephfs/data/med/chenghao/models/gpt-neox-20b_tokenizer"):
    """Load a ``mamba_ssm`` Mamba LM on CUDA in float16, plus a separately stored tokenizer.

    Args:
        model_name: checkpoint passed to ``MambaLMHeadModel.from_pretrained``.
        tokenizer_path: tokenizer directory for ``AutoTokenizer.from_pretrained``.
            Defaults to the previously hard-coded local gpt-neox-20b tokenizer copy.

    Returns:
        ``(model, tokenizer)``.
    """
    from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
    device = "cuda"
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
    model = MambaLMHeadModel.from_pretrained(model_name, device=device, dtype=torch.float16)
    return model, tokenizer
def get_HF_model_tokenizer(model_name):
    """Load a Hugging Face causal LM (eval mode, ``device_map="auto"``) and its tokenizer.

    Args:
        model_name: local path or hub id of the checkpoint.

    Returns:
        ``(model, tokenizer)``.
    """
    # AutoTokenizer / AutoModelForCausalLM come from the module-level import. A
    # function-local ``from transformers import ...`` would make both names local to
    # the WHOLE function, so the llama_hf_13b branch (which skipped that import)
    # failed with UnboundLocalError when it reached AutoModelForCausalLM.
    if 'llama_hf_13b' in model_name:
        tokenizer = transformers.LlamaTokenizer.from_pretrained(model_name, unk_token="<unk>")
    else:
        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    if 'zephyr' in model_name.lower():
        # zephyr checkpoints are loaded without trust_remote_code.
        model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto").eval()
    else:
        model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", trust_remote_code=True).eval()
    return model, tokenizer


# Repeats the module-level default assigned earlier in the file (same value).
limit_lines_per_file = 10
def _is_hf_rwkv(model_name):
    """Return True when the name marks an RWKV checkpoint in HF format (contains both 'HF' and 'RWKV')."""
    return 'HF' in model_name and 'RWKV' in model_name


def _tokenize_for_model(text, tokenizer, model_type, model_name):
    """Tokenize ``text`` into ONE flat (1-D) sequence of token ids for the given backend.

    Every branch returns a 1-D sequence (list of ints or 1-D tensor), so the caller
    can truncate and count tokens uniformly for all backends.
    """
    if model_type == 'RWKV':
        return tokenizer.encode(text)
    if model_type == 'MAMBA' or 'mamba' in model_name or 'MAMBA' in model_name or _is_hf_rwkv(model_name):
        # return_tensors="pt" yields shape (1, L); keep the only row.
        return tokenizer(text, return_tensors="pt")['input_ids'][0]
    tokens = tokenizer.tokenize(text)
    # cls_token / sep_token are None for tokenizers that do not define them.
    if tokenizer.cls_token:
        tokens = [tokenizer.cls_token] + tokens
    if tokenizer.sep_token:
        tokens = tokens + [tokenizer.sep_token]
    return tokenizer.convert_tokens_to_ids(tokens)


def _sequence_losses(model, input_ids, model_type, model_name):
    """Return the per-position next-token cross-entropy for one sequence.

    Args:
        input_ids: 1-D LongTensor holding L >= 2 token ids (CPU).

    Returns:
        1-D float32 tensor of length L - 1; entry i is the loss of predicting
        token i + 1 from tokens 0..i.
    """
    with torch.no_grad():
        if model_type == 'RWKV':
            # rwkv's forward takes the flat sequence; with full_output=True the first
            # return value holds logits for every position, indexed as (L, vocab).
            logits = model.forward(input_ids, None, full_output=True)[0]
            logits = logits[:-1, :].float()
            targets = input_ids[1:]
        else:
            batch = input_ids.unsqueeze(0)  # (1, L)
            if model_type == 'MAMBA':
                device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
                logits = model.forward(batch.to(device)).logits
            elif model_type == 'HF':
                # Place inputs on the model's device explicitly instead of relying on
                # device_map hooks to move CPU tensors.
                batch = batch.to(model.device)
                if _is_hf_rwkv(model_name):
                    logits = model.forward(batch).logits
                else:
                    logits = model(batch).logits
                    if 'chatglm3-6b' in model_name:
                        logits = logits.float()
            else:
                raise ValueError('model type not found')
            logits = logits[0, :-1, :]
            targets = batch[0, 1:]
        loss = torch.nn.functional.cross_entropy(logits, targets.to(logits.device), reduction='none')
    # Sum/average in float32: half-precision sums of ~1e4 lose whole units.
    return loss.float()


def run_model_on_dic(config):
    """Score every text in ``config['file_dic_list_strings']`` with each model and log the losses.

    Config keys read:
        model_path: list of model paths (a single path string is also accepted).
        model_type: 'RWKV', 'MAMBA' or 'HF'; selects the loader and forward pass.
        file_dic_list_strings: dict mapping a file name to a list of texts.
        detail_log_path / extract_log_path: log paths; the substring 'TOFILL' is
            replaced by the model's directory name.
        max_str_len: characters kept from each text before tokenizing.
        max_sequence_length: tokens kept from each tokenized text.
        clear_log_first: optional (default True); truncate both logs before each model.
        ('limit_lines_per_file' and 'batch' are not read here.)

    For each model this writes:
        * detail log: one JSON line per scored text (per-token losses). The
          ``logging`` module is also pointed at this file.
        * extract log: one JSON summary line per file, where ``average_loss`` is the
          total loss divided by the number of predicted tokens (L - 1 per text) and
          ``avg_str_loss`` the total loss per scored text. It gets
          "<model> failed" if loading or scoring raised.

    Texts with fewer than 2 tokens after truncation have no prediction target and
    are skipped with a warning; a file with no scorable text is skipped entirely.

    NOTE: ``config['detail_log_path']`` / ``config['extract_log_path']`` are
    overwritten with the last model's resolved paths (as before), so reusing the same
    dict for a second call loses the 'TOFILL' placeholder.
    """
    clear_log_first = config.get('clear_log_first', True)
    logging.info("start up")
    paths = config['model_path']
    if isinstance(paths, str):
        # A bare string would otherwise be iterated character by character.
        paths = [paths]
    file_dic_list_strings = config['file_dic_list_strings']
    detail_log_base = config['detail_log_path']
    extract_log_base = config['extract_log_path']
    max_sequence_length = config['max_sequence_length']
    max_str_len = config['max_str_len']
    model_type = config.get('model_type')  # validated inside the per-model try below

    for model_name in tqdm(paths):
        model_name = model_name.strip()
        if not model_name:
            continue
        short_model_name = model_name.rstrip('/').split('/')[-1]
        detail_log_path = detail_log_base.replace('TOFILL', short_model_name)
        extract_log_path = extract_log_base.replace('TOFILL', short_model_name)
        config['detail_log_path'] = detail_log_path
        config['extract_log_path'] = extract_log_path
        if clear_log_first:
            for log_path in (extract_log_path, detail_log_path):
                with open(log_path, 'w') as f:
                    f.write('')
            print('\n log cleared! ')
        logging.basicConfig(level=logging.INFO, filename=detail_log_path, filemode='a',
                            format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', force=True)
        print()
        print('model_path', model_name)
        print(f'extract_log_path:{extract_log_path}\ndetail_log_path:{detail_log_path}')
        print()

        model = tokenizer = None
        try:
            if model_type == 'RWKV':
                model, tokenizer = get_rwkv_model_tokenizer(model_name)
            elif model_type == 'MAMBA':
                model, tokenizer = get_mamba_model_tokenizer(model_name)
            elif model_type == 'HF':
                model, tokenizer = get_HF_model_tokenizer(model_name)
                print(f'model device:{model.device}')
                print('[tokenizer.cls_token]', [tokenizer.cls_token])
                print('[tokenizer.sep_token]', [tokenizer.sep_token])
            else:
                raise ValueError('model type not found')

            for file_name, corpus in file_dic_list_strings.items():
                # (index into corpus, 1-D LongTensor of ids) for every scorable text.
                sequences = []
                for idx, text in enumerate(corpus):
                    ids = _tokenize_for_model(text[:max_str_len], tokenizer, model_type, model_name)
                    ids = torch.as_tensor(ids, dtype=torch.long)[:max_sequence_length]
                    if ids.numel() < 2:
                        logging.warning(f"{model_name}, {file_name}: text {idx} has fewer than 2 tokens, skipped")
                        continue
                    sequences.append((idx, ids))
                if not sequences:
                    logging.warning(f"{model_name}, {file_name}: no scorable texts, skipped")
                    continue

                total_loss = 0.0
                total_tokens = 0
                for idx, input_ids in tqdm(sequences):
                    loss = _sequence_losses(model, input_ids, model_type, model_name)
                    loss_sum = loss.sum().item()
                    tmp_dic = {
                        'model_name': model_name,
                        'file_name': file_name,
                        'lengths': input_ids.numel(),
                        'length_str': len(corpus[idx][:max_str_len]),
                        'loss_sum': loss_sum,
                        'loss_mean': loss.mean().item(),
                        'losses_list': loss.tolist(),
                    }
                    with open(detail_log_path, 'a') as f:
                        json.dump(tmp_dic, f)
                        f.write("\n")
                    total_loss += loss_sum
                    total_tokens += loss.numel()  # predicted tokens: L - 1

                average_loss = total_loss / total_tokens
                avg_str_loss = total_loss / len(sequences)
                print(f"{file_name} total loss:", average_loss)
                logs = {
                    "model_name": model_name,
                    "file_name": file_name,
                    "processed_sequences": len(sequences),
                    "average_loss": average_loss,
                    "avg_str_loss": avg_str_loss,
                }
                with open(extract_log_path, 'a') as f:
                    json.dump(logs, f)
                    f.write("\n")
                logging.info(logs)
        except Exception as e:
            # Top-level boundary per model: record the failure and move on to the next one.
            logging.error(f"{model_name}, error:{e} ,detail:{traceback.format_exc()}")
            with open(extract_log_path, 'a') as f:
                f.write(f"{model_name} failed \n")
            print(f"{model_name} failed for {e} detail:{traceback.format_exc()}\n")
        finally:
            # Drop this model before loading the next; otherwise every checkpoint's
            # weights stay referenced and GPU memory accumulates across the loop.
            del model, tokenizer
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
if __name__ == '__main__':
    # Quick development run: score one small HF model on the configured data folder.
    config = {}
    print(file_in_data_folder)
    file_dic_list_strings = get_text.get_text_from(file_in_data_folder, limit=limit_lines_per_file)
    config['max_sequence_length'], config['max_str_len'], config['limit_lines_per_file'] = 2048, 5000, limit_lines_per_file
    config['extract_log_path'] = '/mntnfs/med_data5/chenghao/fresh_eval/log/test_fun_dev/extract.log'
    config['detail_log_path'] = '/mntnfs/med_data5/chenghao/fresh_eval/log/test_fun_dev/detail.log'
    # run_model_on_dic loops over model_path, so pass a list; a bare string would be
    # iterated one character at a time.
    config['model_path'] = ['/mntnfs/med_data5/liangjuhao/models/TinyLlama-1.1B-Chat-v0.6']
    # Not read by run_model_on_dic (texts are scored one at a time).
    config['batch'] = 16
    config['model_type'] = 'HF'
    print('start', config['model_path'])
    config['file_dic_list_strings'] = file_dic_list_strings
    run_model_on_dic(config)