misc fixes/improvements (#513)
- src/axolotl/train.py (+5 -3)
- src/axolotl/utils/trainer.py (+11 -7)

src/axolotl/train.py (CHANGED)
@@ -88,6 +88,11 @@ def train(
     if peft_config:
         LOG.info(f"Pre-saving adapter config to {cfg.output_dir}")
         peft_config.save_pretrained(cfg.output_dir)
+    # additionally presave the tokenizer and model configs
+    if not Path(cfg.output_dir).is_dir():
+        os.makedirs(cfg.output_dir, exist_ok=True)
+    tokenizer.save_pretrained(str(Path(cfg.output_dir)))
+    model.config.save_pretrained(str(Path(cfg.output_dir)))
 
     # In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
     if cfg.local_rank == 0:
@@ -106,9 +111,6 @@ def train(
     if cfg.group_by_length:
         LOG.info("hang tight... sorting dataset for group_by_length")
 
-    if not Path(cfg.output_dir).is_dir():
-        os.makedirs(cfg.output_dir, exist_ok=True)
-    tokenizer.save_pretrained(cfg.output_dir)
     if cfg.flash_optimum:
         with torch.backends.cuda.sdp_kernel(
             enable_flash=True, enable_math=True, enable_mem_efficient=True
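With this change the adapter config, the tokenizer, and the model config are all written to `cfg.output_dir` before training starts, and the later, redundant tokenizer save after the `group_by_length` block is dropped. The effect is that even an interrupted run leaves behind a directory that `from_pretrained` can load. A minimal sketch of that pre-save pattern outside axolotl; the model name and output path below are illustrative, not taken from the PR:

```python
from pathlib import Path

from transformers import AutoConfig, AutoTokenizer

# Illustrative values; axolotl takes these from its cfg object.
base_model = "gpt2"
output_dir = Path("./outputs/example-run")

output_dir.mkdir(parents=True, exist_ok=True)

tokenizer = AutoTokenizer.from_pretrained(base_model)
config = AutoConfig.from_pretrained(base_model)

# Pre-save before training begins: even if the run is killed early (ctrl+c),
# the output directory already holds the tokenizer files and config.json,
# so any checkpoint saved into it later is immediately loadable.
tokenizer.save_pretrained(str(output_dir))
config.save_pretrained(str(output_dir))
```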
    	
src/axolotl/utils/trainer.py (CHANGED)
@@ -33,6 +33,7 @@ from axolotl.utils.callbacks import (
 )
 from axolotl.utils.collators import DataCollatorForSeq2Seq
 from axolotl.utils.dataloader import MultipackDistributedDataloader
+from axolotl.utils.distributed import is_main_process, zero_first
 from axolotl.utils.schedulers import get_cosine_schedule_with_quadratic_warmup
 
 LOG = logging.getLogger("axolotl")
@@ -375,14 +376,17 @@ def disable_datasets_caching():
 
 def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
     drop_long = partial(drop_long_seq, sequence_len=cfg.sequence_len)
-    train_dataset = train_dataset.filter(drop_long, num_proc=os.cpu_count())
-    if eval_dataset:
-        eval_dataset = eval_dataset.filter(drop_long, num_proc=os.cpu_count())
-
-    if cfg.sample_packing:
-        train_dataset = train_dataset.map(add_position_ids, num_proc=os.cpu_count())
+    with zero_first(is_main_process()):
+        train_dataset = train_dataset.filter(drop_long, num_proc=os.cpu_count())
         if eval_dataset:
-            eval_dataset = eval_dataset.map(add_position_ids, num_proc=os.cpu_count())
+            eval_dataset = eval_dataset.filter(drop_long, num_proc=os.cpu_count())
+
+        if cfg.sample_packing:
+            train_dataset = train_dataset.map(add_position_ids, num_proc=os.cpu_count())
+            if eval_dataset:
+                eval_dataset = eval_dataset.map(
+                    add_position_ids, num_proc=os.cpu_count()
+                )
     return train_dataset, eval_dataset
 
 
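The trainer.py change wraps the dataset `filter`/`map` calls in `zero_first(is_main_process())`, so that in a distributed run rank 0 performs the `datasets` processing first and writes its cache, and the other ranks then re-run the same calls against that cache instead of all recomputing in parallel. A minimal sketch of that context-manager idiom, assuming `torch.distributed` is initialized by the launcher; this is the general pattern, not necessarily axolotl's exact implementation:

```python
from contextlib import contextmanager

import torch.distributed as dist


def is_main_process() -> bool:
    # Single-process (non-distributed) runs count as the main process.
    return not dist.is_initialized() or dist.get_rank() == 0


@contextmanager
def zero_first(is_main: bool):
    """Let rank 0 run the wrapped block first; other ranks wait, then follow."""
    if dist.is_initialized() and not is_main:
        dist.barrier()  # non-main ranks block here until rank 0 finishes the block
    yield
    if dist.is_initialized() and is_main:
        dist.barrier()  # rank 0 releases the waiting ranks once its work is done
```

Used around `Dataset.filter`/`Dataset.map` as in the diff, rank 0 materializes the cache for the `drop_long` filtering and the `add_position_ids` mapping once; the remaining ranks then reuse it when they execute the same block.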
