Spaces:
Build error
Build error
| # modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/train_t2s.py | |
| import os | |
| import pdb | |
# Mirror the private "_CUDA_VISIBLE_DEVICES" variable, when it is set, into
# the real CUDA_VISIBLE_DEVICES. This runs before torch is imported below.
try:
    os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
except KeyError:
    pass
| import argparse | |
| import logging | |
| from pathlib import Path | |
| import torch, platform | |
| from pytorch_lightning import seed_everything | |
| from pytorch_lightning import Trainer | |
| from pytorch_lightning.callbacks import ModelCheckpoint | |
| from pytorch_lightning.loggers import TensorBoardLogger # WandbLogger | |
| from pytorch_lightning.strategies import DDPStrategy | |
| from AR.data.data_module import Text2SemanticDataModule | |
| from AR.models.t2s_lightning_module import Text2SemanticLightningModule | |
| from AR.utils.io import load_yaml_config | |
# Raise the level of two third-party loggers to WARNING.
for _noisy_logger in ("numba", "matplotlib"):
    logging.getLogger(_noisy_logger).setLevel(logging.WARNING)
# Select the "high" float32 matmul precision mode in torch.
torch.set_float32_matmul_precision("high")
| from AR.utils import get_newest_ckpt | |
| from collections import OrderedDict | |
| from time import time as ttime | |
| import shutil | |
def my_save(fea, path):
    """Save ``fea`` with ``torch.save`` through an ASCII-named temp file.

    The object is first written to a temporary file in the current working
    directory and then moved to ``path``. This works around ``torch.save``
    not supporting paths that contain Chinese characters.

    Args:
        fea: Object to serialize (anything ``torch.save`` accepts).
        path: Destination file path. A bare file name (no directory part)
            is resolved relative to the current working directory.

    Raises:
        Whatever ``torch.save`` or ``shutil.move`` raise; the temporary
        file is removed before the exception propagates.
    """
    dir_name = os.path.dirname(path)
    base_name = os.path.basename(path)
    # Time stamp plus PID keeps the name ASCII-only and distinct between
    # processes that save at the same moment.
    tmp_path = "%s_%s.pth" % (ttime(), os.getpid())
    try:
        torch.save(fea, tmp_path)
        # os.path.join keeps a directory-less ``path`` relative to the CWD;
        # formatting it as "<dir>/<name>" would yield "/<name>" (the root).
        shutil.move(tmp_path, os.path.join(dir_name, base_name))
    finally:
        # After a successful move the temp file no longer exists; this only
        # fires when saving or moving failed part-way.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
class my_model_ckpt(ModelCheckpoint):
    """``ModelCheckpoint`` that can prune old checkpoints and export fp16 weights.

    On top of the regular checkpointing done by the base class, at every
    ``every_n_epochs`` boundary this callback can:

    * delete everything that was in ``dirpath`` before the new checkpoint
      was written (``if_save_latest``), and
    * write a compact file with half-precision weights, the training
      config and an info string (``if_save_every_weights``).
    """

    def __init__(
        self,
        config,
        if_save_latest,
        if_save_every_weights,
        half_weights_save_dir,
        exp_name,
        **kwargs
    ):
        """Store the extra options and forward ``kwargs`` to ``ModelCheckpoint``.

        Args:
            config: Training config; stored under the ``"config"`` key of
                every exported half-precision weights file.
            if_save_latest: When equal to ``True``, files that existed in
                ``dirpath`` before a new checkpoint is saved are removed
                afterwards.
            if_save_every_weights: When equal to ``True``, half-precision
                weights are exported at every save epoch.
            half_weights_save_dir: Directory for the exported weights.
            exp_name: Prefix of the exported weights file name.
            **kwargs: Passed unchanged to ``ModelCheckpoint.__init__``.
        """
        super().__init__(**kwargs)
        self.if_save_latest = if_save_latest
        self.if_save_every_weights = if_save_every_weights
        self.half_weights_save_dir = half_weights_save_dir
        self.exp_name = exp_name
        self.config = config

    def on_train_epoch_end(self, trainer, pl_module):
        """Save checkpoints (and optional extras) at the end of a train epoch.

        Args:
            trainer: The running ``Trainer``.
            pl_module: The module being trained (unused; the module is taken
                from ``trainer``).
        """
        if not self._should_save_on_train_epoch_end(trainer):
            return

        monitor_candidates = self._monitor_candidates(trainer)
        epoch = trainer.current_epoch + 1
        if self._every_n_epochs >= 1 and epoch % self._every_n_epochs == 0:
            # The flags are compared with ``== True`` on purpose so that
            # non-boolean config values behave exactly as they always did.
            keep_latest_only = self.if_save_latest == True  # noqa: E712

            # When only the latest checkpoint is kept, remember what is on
            # disk now; it is all cleaned up once the next checkpoint has
            # been written.
            stale_names = os.listdir(self.dirpath) if keep_latest_only else []

            self._save_topk_checkpoint(trainer, monitor_candidates)

            for name in stale_names:
                try:
                    os.remove(os.path.join(self.dirpath, name))
                except OSError:
                    # Best-effort cleanup: the entry may be a directory or
                    # may already be gone. Only filesystem errors are
                    # ignored; KeyboardInterrupt/SystemExit still propagate.
                    pass

            if self.if_save_every_weights == True:  # noqa: E712
                self._export_half_weights(trainer, epoch)

        self._save_last_checkpoint(trainer, monitor_candidates)

    def _export_half_weights(self, trainer, epoch):
        """Write half-precision weights, config and info for ``epoch``."""
        state = trainer.lightning_module.state_dict()

        to_save_od = OrderedDict()
        to_save_od["weight"] = OrderedDict(
            (key, tensor.half()) for key, tensor in state.items()
        )
        to_save_od["config"] = self.config
        to_save_od["info"] = "GPT-e%s" % epoch

        # Every rank runs this hook, but all ranks would target the same
        # file, so only global rank zero performs the write.
        if trainer.is_global_zero:
            my_save(
                to_save_od,
                "%s/%s-e%s.ckpt"
                % (self.half_weights_save_dir, self.exp_name, epoch),
            )
def main(args):
    """Run text-to-semantic training as described by ``args.config_file``.

    Loads the YAML config, creates the output and ``ckpt`` directories,
    builds the checkpoint callback, logger, trainer, model and data module,
    and calls ``trainer.fit`` with the newest checkpoint found in the
    ``ckpt`` directory (or ``None`` when none can be determined).

    Args:
        args: Parsed command-line namespace; only ``config_file`` is read.
    """
    config = load_yaml_config(args.config_file)

    output_dir = Path(config["output_dir"])
    output_dir.mkdir(parents=True, exist_ok=True)
    ckpt_dir = output_dir / "ckpt"
    ckpt_dir.mkdir(parents=True, exist_ok=True)

    train_cfg = config["train"]
    seed_everything(train_cfg["seed"], workers=True)

    checkpoint_cb = my_model_ckpt(
        config=config,
        if_save_latest=train_cfg["if_save_latest"],
        if_save_every_weights=train_cfg["if_save_every_weights"],
        half_weights_save_dir=train_cfg["half_weights_save_dir"],
        exp_name=train_cfg["exp_name"],
        save_top_k=-1,
        monitor="top_3_acc",
        mode="max",
        save_on_train_epoch_end=True,
        every_n_epochs=train_cfg["save_every_n_epoch"],
        dirpath=ckpt_dir,
    )
    tb_logger = TensorBoardLogger(name=output_dir.stem, save_dir=output_dir)

    os.environ["MASTER_ADDR"] = "localhost"

    max_epochs = train_cfg["epochs"]
    if torch.backends.mps.is_available():
        # MPS does not support multi-node training.
        strategy = "auto"
    else:
        backend = "gloo" if platform.system() == "Windows" else "nccl"
        strategy = DDPStrategy(process_group_backend=backend)

    trainer = Trainer(
        max_epochs=max_epochs,
        accelerator="gpu",
        limit_val_batches=0,  # do not validate
        devices=-1,
        benchmark=False,
        fast_dev_run=False,
        strategy=strategy,
        precision=train_cfg["precision"],
        logger=tb_logger,
        num_sanity_val_steps=0,
        callbacks=[checkpoint_cb],
    )

    model = Text2SemanticLightningModule(config, output_dir)
    data_module = Text2SemanticDataModule(
        config,
        train_semantic_path=config["train_semantic_path"],
        train_phoneme_path=config["train_phoneme_path"],
    )

    # get_newest_ckpt presumably orders the files by the number embedded in
    # their names (per the original note). Any failure, such as an empty
    # directory, means no checkpoint is passed to fit.
    try:
        resume_from = ckpt_dir / get_newest_ckpt(os.listdir(ckpt_dir))
    except Exception:
        resume_from = None
    print("ckpt_path:", resume_from)

    trainer.fit(model, data_module, ckpt_path=resume_from)
# srun --gpus-per-node=1 --ntasks-per-node=1 python train.py --config_file configs/s1longer.yaml
if __name__ == "__main__":
    # Command-line entry point. The only option is the YAML config path;
    # main() reads the dataset paths and output directory from that config.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "-c",
        "--config_file",
        default="configs/s1longer.yaml",
        type=str,
        help="path of config file",
    )
    cli_args = cli.parse_args()
    logging.info(str(cli_args))
    main(cli_args)