import logging
import warnings
import os
import gc
import time
import traceback
import datetime
import json
import pdb

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from torch.nn.utils.rnn import pack_padded_sequence

import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
from tqdm import tqdm

from calc_metrics import calculate_log_sum, calculate_log_last

doday = datetime.datetime.now().strftime("%Y-%m-%d")

# Configure logging
extra_info = 'fill'
# logging.basicConfig(level=logging.INFO, filename='/wangbenyou/chenghao/fersh_bench/log/app.log', filemode='a', format='%(name)s - %(levelname)s - %(message)s')
# logging.basicConfig(level=logging.INFO, filename=f'../log/app_jieduan_{extra_info}{doday}_year.log', filemode='a', format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
paths = [
    '/mntcephfs/data/med/fanyaxin/Qwen-7B-Chat',
]
# Available corpus snapshots; uncomment exactly one.
# file_in_data_folder = '2024-01-04_18'
# file_in_data_folder = '2023-12-31'
file_in_data_folder = '2023-12-27'
# file_in_data_folder = '2020_100'
# file_in_data_folder = '2020'
# file_in_data_folder = '2014'
# file_in_data_folder = '2017'
# file_in_data_folder = '2019'
# file_in_data_folder = 'rephrase_MMLU'
# file_in_data_folder = 'mock_MMLU'
# mmlu_mock_concat: not arxiv, not by year, but rephrased MMLU

# Your corpus list: a dict mapping file name -> list of text strings
import get_text
# file_dic_list_strings = get_text.file_dic_list_strings
limit_lines_per_file = 10
file_dic_list_strings = get_text.get_text_from(file_in_data_folder, limit=limit_lines_per_file)
# file_dic_list_strings = get_text.get_mmlu_rephrase_text(directory='/mntnfs/med_data5/chenghao/fresh_eval/data/mmlu_rephrase_concat/gpt-4-1106-preview/')
# file_dic_list_strings = get_text.get_mmlu_rephrase_text(directory='/mntnfs/med_data5/chenghao/fresh_eval/data/mmlu_mock_concat/gpt-4-1106-preview/')
# file_in_data_folder = '2024-01-03'
def get_rwkv_model_tokenizer(model_name):
    # Load an RWKV checkpoint via the rwkv package; PIPELINE supplies the
    # world-vocab tokenizer (rwkv_vocab_v20230424).
    os.environ['RWKV_JIT_ON'] = '1'
    os.environ['RWKV_CUDA_ON'] = '1'
    from rwkv.model import RWKV
    from rwkv.utils import PIPELINE
    model = RWKV(model=model_name, strategy='cuda fp16')
    pipeline = PIPELINE(model, r"rwkv_vocab_v20230424")
    tokenizer = pipeline.tokenizer
    return model, tokenizer
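# A minimal usage sketch, mirroring how the evaluation loop below drives the
# RWKV backend (the checkpoint path is a placeholder, not a real file):
# model, tokenizer = get_rwkv_model_tokenizer('/path/to/rwkv-checkpoint.pth')
# ids = tokenizer.encode("some text")                     # plain list of ints
# logits = model.forward(ids, None, full_output=True)[0]  # (len(ids), vocab_size)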
def get_mamba_model_tokenizer(model_name):
    # Mamba checkpoints ship without a tokenizer; pair them with a GPT-NeoX tokenizer.
    from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
    device = "cuda"
    tokenizer = AutoTokenizer.from_pretrained("/mntcephfs/data/med/chenghao/models/gpt-neox-20b_tokenizer")
    model = MambaLMHeadModel.from_pretrained(model_name, device=device, dtype=torch.float16)
    return model, tokenizer
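# Usage sketch for the Mamba backend, matching the MAMBA branch in the loop
# below (placeholder path):
# model, tokenizer = get_mamba_model_tokenizer('/path/to/mamba-checkpoint')
# input_ids = tokenizer("some text", return_tensors="pt").input_ids.to("cuda")
# logits = model(input_ids).logits  # (1, length, vocab_size)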
def get_HF_model_tokenizer(model_name):
    # Generic HuggingFace loader; llama_hf_13b needs an explicit unk_token.
    if 'llama_hf_13b' in model_name:
        tokenizer = transformers.LlamaTokenizer.from_pretrained(model_name, unk_token="<unk>")
    else:
        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    if 'zephyr' in model_name.lower():
        model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto").eval()
    else:
        model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", trust_remote_code=True).eval()
    return model, tokenizer
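# Usage sketch for the HF backend (placeholder path):
# model, tokenizer = get_HF_model_tokenizer('/path/to/hf-model')
# ids = torch.tensor([tokenizer("some text")['input_ids']]).to(model.device)
# logits = model(ids).logits  # (1, length, vocab_size)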
def run_model_on_dic(config):
    # Evaluate every model in config['model_path'] on every corpus in
    # config['file_dic_list_strings'], appending per-sequence losses to the logs.
    config['clear_log_first'] = True
    logging.info("start up")
    paths = config['model_path']
    file_dic_list_strings = config['file_dic_list_strings']
    detail_log_base = config['detail_log_path']
    extract_log_base = config['extract_log_path']
    max_sequence_length, max_str_len, limit_lines_per_file = config['max_sequence_length'], config['max_str_len'], config['limit_lines_per_file']
    for model_name in tqdm(paths):
        model_name = model_name.strip()
        tmp_path = model_name[:-1] if model_name[-1] == '/' else model_name
        short_model_name = tmp_path.split('/')[-1]
        config['detail_log_path'] = detail_log_base.replace('TOFILL', f'{short_model_name}')
        config['extract_log_path'] = extract_log_base.replace('TOFILL', f'{short_model_name}')
        if config.get('clear_log_first') is True:
            with open(config['extract_log_path'], 'w') as f:
                f.write('')
            with open(config['detail_log_path'], 'w') as f:
                f.write('')
            print('\n log cleared! ')
        logging.basicConfig(level=logging.INFO, filename=config['detail_log_path'], filemode='a', format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', force=True)
        print()
        print('model_path', model_name)
        print(f'extract_log_path:{config["extract_log_path"]}\ndetail_log_path:{config["detail_log_path"]}')
        print()
        try:
            if config['model_type'] == 'RWKV':
                model, tokenizer = get_rwkv_model_tokenizer(model_name)
            elif config['model_type'] == 'MAMBA':
                model, tokenizer = get_mamba_model_tokenizer(model_name)
            elif config['model_type'] == 'HF':
                model, tokenizer = get_HF_model_tokenizer(model_name)
                print(f'model device:{model.device}')
                print('[tokenizer.cls_token]', [tokenizer.cls_token])
                print('[tokenizer.sep_token]', [tokenizer.sep_token])
            else:
                raise Exception('model type not found')
            # === got model and tokenizer
            for file_name, corpus in file_dic_list_strings.items():
                tokenized_corpus = []
                for text in corpus:
                    text = text[:max_str_len]
                    if config['model_type'] == 'RWKV':
                        # rwkv's world tokenizer returns a plain list of ids
                        tokenized_corpus.append(tokenizer.encode(text))
                    elif 'HF' in model_name and ('RWKV' in model_name):
                        tokenized_corpus.append(tokenizer(text, return_tensors="pt")['input_ids'])
                    elif ('mamba' in model_name) or ('MAMBA' in model_name):
                        device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
                        tokenized_corpus.append(tokenizer(text, return_tensors="pt").input_ids.to(device=device))
                    else:
                        tokens = tokenizer.tokenize(text)
                        if tokenizer.cls_token:  # may be None, so only prepend when present
                            tokens = [tokenizer.cls_token] + tokens
                        if tokenizer.sep_token:
                            tokens = tokens + [tokenizer.sep_token]
                        input_ids = tokenizer.convert_tokens_to_ids(tokens)
                        tokenized_corpus.append(input_ids)
                        # tokenized_corpus.append(tokenizer(text, return_tensors="pt")['input_ids'])
                # Walk tokenized_corpus, truncating to max_sequence_length
                # (padding is left disabled)
                processed_sequences = []
                for sequence in tokenized_corpus:
                    if len(sequence) < max_sequence_length:
                        # Padding alternative, currently disabled:
                        # sequence = sequence + [tokenizer.pad_token_id] * (max_sequence_length - len(sequence))
                        pass
                    elif len(sequence) > max_sequence_length:
                        # Truncate the sequence
                        sequence = sequence[:max_sequence_length]
                    processed_sequences.append(sequence)
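                # Sketch of the disabled padding path (illustrative names, assumes the
                # tokenizer defines pad_token_id); padded positions would also need a
                # mask so they do not inflate loss.sum() below:
                #   pad_len = max_sequence_length - len(sequence)
                #   mask = [1] * len(sequence) + [0] * pad_len
                #   sequence = sequence + [tokenizer.pad_token_id] * pad_len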
                total_loss = 0.0
                total_tokens = 0
                for enu, batch_input_ids in tqdm(enumerate(processed_sequences)):
                    # Lists become a (1, length) tensor; entries that are already
                    # tensors just gain a leading batch dimension.
                    batch_input_ids = torch.as_tensor(batch_input_ids).unsqueeze(0)
                    with torch.no_grad():
                        if config['model_type'] == 'RWKV':
                            # full_output=True yields logits for every position, not just the last
                            logits = model.forward(batch_input_ids.squeeze().long(), None, full_output=True)[0]
                            # Shift by one: position t predicts token t+1; cast logits to fp32 for stable CE
                            loss = torch.nn.functional.cross_entropy(logits[:-1, :].view(-1, logits.shape[-1]).to(torch.float32), batch_input_ids[0, 1:].to(logits.device).view(-1), reduction='none')
                        elif config['model_type'] == 'MAMBA':
                            mamba_output = model.forward(batch_input_ids[0])  # expects shape (1, length)
                            logits = mamba_output.logits
                            loss = torch.nn.functional.cross_entropy(logits[:, :-1, :].view(-1, logits.shape[-1]), batch_input_ids[0][:, 1:].view(-1), reduction='none')
                        elif config['model_type'] == 'HF':
                            if 'HF' in model_name and 'RWKV' in model_name:
                                batch_input_ids = batch_input_ids.to(model.device)
                                logits = model.forward(batch_input_ids[0]).logits  # input shape (1, length)
                                loss = torch.nn.functional.cross_entropy(logits[:, :-1, :].view(-1, logits.shape[-1]), batch_input_ids[0][:, 1:].view(-1), reduction='none')
                            else:
                                outputs = model(batch_input_ids)
                                # Take the model's logits (chatglm3-6b needs an fp32 cast)
                                if 'chatglm3-6b' in model_name:
                                    logits = outputs.logits.float()
                                else:
                                    logits = outputs.logits
                                loss = torch.nn.functional.cross_entropy(logits[:, :-1, :].view(-1, logits.shape[-1]), batch_input_ids[:, 1:].view(-1), reduction='none')
                    loss_sum = loss.sum()
                    loss_mean = loss.mean()
                    losses_list = loss.tolist()
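                    # Illustration (added note, not from the original script): for ids
                    # [t0, t1, t2, t3], positions 0..2 are scored against targets
                    # [t1, t2, t3], so losses_list holds len(ids) - 1 per-token losses.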
                    # Data to append to the detail log (one JSON object per line)
                    tmp_dic = {
                        'model_name': model_name,
                        'file_name': file_name,
                        'lengths': len(batch_input_ids[0]),
                        'length_str': len(corpus[enu][:max_str_len]),
                        'loss_sum': loss_sum.item(),  # convert to plain Python floats
                        'loss_mean': loss_mean.item(),
                        'losses_list': losses_list
                    }
                    with open(config['detail_log_path'], 'a') as f:
                        json.dump(tmp_dic, f)
                        f.write("\n")
                    total_loss += loss.sum().item()
                    # note: numel() also counts the first token, which has no prediction target
                    total_tokens += batch_input_ids.numel()
                # Per-corpus averages: mean loss per token and per string
                average_loss = total_loss / total_tokens
                avg_str_loss = total_loss / len(tokenized_corpus)
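                # Added aside: average_loss is the mean negative log-likelihood per
                # token, so a corpus perplexity could be derived as
                #   perplexity = math.exp(average_loss)  # assumes `import math`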
| print(f"{file_name} total loss:", average_loss) | |
| import json | |
| logs = { | |
| "model_name": model_name, | |
| "file_name": file_name, | |
| "processed_sequences": len(processed_sequences), | |
| "average_loss": average_loss, | |
| "avg_str_loss": avg_str_loss | |
| } | |
| # with open(f'/mntnfs/med_data5/chenghao/fresh_eval/log/year_arxiv/j_y_ans_{file_in_data_folder}.json', 'a') as f: | |
| with open(config['extract_log_path'], 'a') as f: | |
| json.dump(logs, f) | |
| f.write("\n") | |
| logging.info(logs) | |
        except Exception as e:
            logging.error(f"{model_name}, error:{e} ,detail:{traceback.format_exc()}")
            with open(config['extract_log_path'], 'a') as f:
                f.write(f"{model_name} failed \n")
            print(f"{model_name} failed for {e} detail:{traceback.format_exc()}\n")
if __name__ == '__main__':
    config = {}
    print(file_in_data_folder)
    file_dic_list_strings = get_text.get_text_from(file_in_data_folder, limit=limit_lines_per_file)
    config['max_sequence_length'], config['max_str_len'], config['limit_lines_per_file'] = 2048, 5000, 10
    config['extract_log_path'] = '/mntnfs/med_data5/chenghao/fresh_eval/log/test_fun_dev/extract.log'
    config['detail_log_path'] = '/mntnfs/med_data5/chenghao/fresh_eval/log/test_fun_dev/detail.log'
    # model_path must be a list of paths; a bare string would be iterated character by character
    config['model_path'] = ['/mntnfs/med_data5/liangjuhao/models/TinyLlama-1.1B-Chat-v0.6']  # or paths[:1]
    config['batch'] = 16
    config['model_type'] = 'HF'
    print('start', config['model_path'])
    config['file_dic_list_strings'] = file_dic_list_strings
    run_model_on_dic(config)
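    # Sketches for the other supported backends (placeholder paths, not real
    # checkpoints from this repo):
    # config['model_type'] = 'RWKV'; config['model_path'] = ['/path/to/rwkv-checkpoint.pth']
    # config['model_type'] = 'MAMBA'; config['model_path'] = ['/path/to/mamba-checkpoint']
    # run_model_on_dic(config)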