# coding=utf-8
from typing import Dict, Optional
import time
import os

import pandas as pd
import torch
from datasets import Dataset, load_dataset
from transformers import PreTrainedTokenizerFast, TrainingArguments
from trl import DPOTrainer
from tokenizers import Tokenizer
from peft import LoraConfig, TaskType, PeftModel

from config import DpoConfig, T5ModelConfig
from model.chat_model import TextToTextModel
from utils.functions import get_T5_config

os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'

def get_dataset(split: str, file: str, cache_dir: str = '.cache') -> Dataset:
    """Load an Anthropic Helpful-Harmless style preference dataset from a local
    JSON file and convert it to the structure DPOTrainer expects:
    {
        'prompt': List[str],
        'chosen': List[str],
        'rejected': List[str],
    }
    """
    dataset = load_dataset('json', data_files=file, split=split, cache_dir=cache_dir)

    def split_prompt_and_responses(sample: dict) -> Dict[str, str]:
        return {
            # Append an EOS token to mark the end of each sequence; generation stops on it.
            "prompt": f"{sample['prompt']}[EOS]",
            "chosen": f"{sample['chosen']}[EOS]",
            "rejected": f"{sample['rejected']}[EOS]",
        }

    return dataset.map(split_prompt_and_responses).shuffle(2333)
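
# For illustration only, a hypothetical call (the path is an assumption):
# ds = get_dataset('train', file='./data/dpo_train.json')
# ds[0]  # -> {'prompt': '...[EOS]', 'chosen': '...[EOS]', 'rejected': '...[EOS]', ...}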

def train_dpo(config: DpoConfig, peft_config: Optional[LoraConfig] = None) -> None:

    # step 1. Load the tokenizer.
    tokenizer = PreTrainedTokenizerFast.from_pretrained(config.tokenizer_dir)

    # step 2. Load the SFT model twice: model_train is the policy being optimized,
    # model_ref is the frozen reference that anchors the DPO loss.
    model_train, model_ref = None, None
    if os.path.isdir(config.sft_model_file):
        # A directory was given: load with from_pretrained.
        model_train = TextToTextModel.from_pretrained(config.sft_model_file)
        model_ref = TextToTextModel.from_pretrained(config.sft_model_file)
    else:
        # A single checkpoint file was given: build the model from its config, then load_state_dict.
        t5_config = get_T5_config(
            T5ModelConfig(),
            vocab_size=len(tokenizer),
            decoder_start_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )
        model_train = TextToTextModel(t5_config)
        model_train.load_state_dict(torch.load(config.sft_model_file, map_location='cpu'))  # map to CPU to avoid device mismatches
        model_ref = TextToTextModel(t5_config)
        model_ref.load_state_dict(torch.load(config.sft_model_file, map_location='cpu'))

    # step 3. Load the training dataset.
    train_dataset = get_dataset("train", file=config.dpo_train_file)

    # step 4. Load the evaluation dataset (disabled here).
    # eval_dataset = get_dataset("train", file=config.dpo_eval_file)
    eval_dataset = None

    # step 5. Initialize the training arguments.
    training_args = TrainingArguments(
        per_device_train_batch_size=config.per_device_train_batch_size,
        num_train_epochs=config.num_train_epochs,
        auto_find_batch_size=True,
        remove_unused_columns=False,
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        learning_rate=config.learning_rate,
        logging_first_step=True,
        logging_steps=config.logging_steps,
        save_steps=config.save_steps,
        output_dir=config.output_dir,
        optim="adafactor",
        report_to="tensorboard",
        log_level='info',
        warmup_steps=config.warmup_steps,
        bf16=False,
        fp16=config.fp16,
        seed=config.seed,
        logging_dir=config.log_dir,
    )

    # step 6. Initialize the DPO trainer.
    dpo_trainer = DPOTrainer(
        model_train,
        model_ref,
        peft_config=peft_config,
        args=training_args,
        beta=config.beta,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        max_length=config.max_seq_len,
        max_target_length=config.max_seq_len,
        max_prompt_length=config.max_seq_len,
        generate_during_eval=True,
        is_encoder_decoder=True,
    )
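    # For reference, DPOTrainer minimizes
    #   -log sigmoid(beta * ((log pi(y_w|x) - log pi_ref(y_w|x))
    #                      - (log pi(y_l|x) - log pi_ref(y_l|x)))),
    # where y_w / y_l are the chosen / rejected responses; a larger beta keeps
    # the policy closer to the reference model.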

    # step 7. Train.
    dpo_trainer.train(
        # resume_from_checkpoint=True
    )

    # step 8. Save the training log.
    loss_log = pd.DataFrame(dpo_trainer.state.log_history)
    log_dir = './logs'
    if not os.path.exists(log_dir):
        os.mkdir(log_dir)
    loss_log.to_csv(f"{log_dir}/dpo_train_log_{time.strftime('%Y%m%d-%H%M')}.csv")

    # step 9. Save the full model, or only the LoRA adapter if one was trained.
    suffix = '/lora' if peft_config is not None else '/dpo'
    model_save_dir = '/'.join(config.sft_model_file.split('/')[0: -1]) + suffix
    dpo_trainer.save_model(model_save_dir)
    print('save model or lora adapter to: {}'.format(model_save_dir))

def merge_lora_weight_into_model(config: DpoConfig, peft_config: LoraConfig) -> None:

    # step 1. Load the tokenizer.
    tokenizer = PreTrainedTokenizerFast.from_pretrained(config.tokenizer_dir)

    # step 2. Load the SFT base model, with the same logic as in train_dpo.
    sft_model = None
    if os.path.isdir(config.sft_model_file):
        # A directory was given: load with from_pretrained.
        sft_model = TextToTextModel.from_pretrained(config.sft_model_file)
    else:
        # A single checkpoint file was given: build the model from its config, then load_state_dict.
        t5_config = get_T5_config(
            T5ModelConfig(),
            vocab_size=len(tokenizer),
            decoder_start_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )
        sft_model = TextToTextModel(t5_config)
        sft_model.load_state_dict(torch.load(config.sft_model_file, map_location='cpu'))  # map to CPU to avoid device mismatches

    # step 3. Load the trained LoRA adapter.
    # Note: this path must match the model_save_dir that train_dpo saved to
    # (the '/lora' suffix next to config.sft_model_file).
    adapter_save_dir = '/'.join(config.sft_model_file.split('/')[0: -1]) + '/lora'
    print('load adapter from dir: {}'.format(adapter_save_dir))
    peft_model = PeftModel.from_pretrained(
        model=sft_model,
        model_id=adapter_save_dir,
        config=peft_config,
        adapter_name='adapter',
    )
    # Alternatively, construct the wrapper first and load the weights explicitly:
    # peft_model = PeftModel(model=sft_model, peft_config=peft_config, adapter_name='adapter')
    # peft_model.load_adapter(model_id=adapter_save_dir, adapter_name='adapter')

    # step 4. Merge the adapter weights into the base model and drop the PEFT wrapper.
    peft_model = peft_model.merge_and_unload()

    # step 5. Save the merged model.
    save_merge_file = config.sft_model_file + '.dpo_lora_merged'
    peft_model.save_pretrained(save_merge_file)
    print('save merged model to: {}'.format(save_merge_file))

if __name__ == "__main__":
    peft_config = LoraConfig(
        task_type=TaskType.SEQ_2_SEQ_LM,  # LoRA for a text-to-text (seq2seq) model
        inference_mode=False,
        r=16,
        lora_alpha=16,
        lora_dropout=0.1,
        bias="all",
    )

    dpo_config = DpoConfig()

    # 1. Train with full-parameter DPO; pass peft_config to train a LoRA adapter instead.
    train_dpo(dpo_config, peft_config=None)

    # 2. Merge the LoRA adapter back into the base model (only needed after LoRA training).
    # merge_lora_weight_into_model(dpo_config, peft_config)
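
    # For illustration only, a hypothetical LoRA workflow using the same entry points:
    # train_dpo(dpo_config, peft_config=peft_config)          # saves the adapter under .../lora
    # merge_lora_weight_into_model(dpo_config, peft_config)   # merges it into the base weights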