|  | """Prepare and train a model on a dataset. Can also infer from a model or merge lora""" | 
					
						
						|  |  | 
					
						
						|  | import os | 
					
						
						|  | import signal | 
					
						
						|  | import sys | 
					
						
						|  | from dataclasses import dataclass | 
					
						
						|  | from pathlib import Path | 
					
						
						|  | from typing import Optional, Tuple, Union | 
					
						
						|  |  | 
					
						
						|  | import torch | 
					
						
						|  | import transformers.modelcard | 
					
						
						|  | from accelerate.logging import get_logger | 
					
						
						|  | from datasets import Dataset | 
					
						
						|  | from peft import PeftModel | 
					
						
						|  | from pkg_resources import get_distribution | 
					
						
						|  | from transformers import PreTrainedModel, PreTrainedTokenizer | 
					
						
						|  | from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled | 
					
						
						|  |  | 
					
						
						|  | from axolotl.common.cli import TrainerCliArgs | 
					
						
						|  | from axolotl.logging_config import configure_logging | 
					
						
						|  | from axolotl.utils.dict import DictDefault | 
					
						
						|  | from axolotl.utils.freeze import freeze_layers_except | 
					
						
						|  | from axolotl.utils.models import load_model, load_tokenizer | 
					
						
						|  | from axolotl.utils.trainer import setup_trainer | 
					
						
						|  |  | 
					
						
# optimum is an optional dependency. When it is missing, BetterTransformer is None;
# the code paths that use it check `cfg.flash_optimum and BetterTransformer` first.
try:
    from optimum.bettertransformer import BetterTransformer
except ImportError:
    BetterTransformer = None

# Prepend "<parent of this file's directory>/src" to sys.path at import time.
# NOTE(review): this runs after the `axolotl.*` imports above, so it cannot change
# how those were resolved; only imports performed later see the new entry.
# Confirm whether the path still exists / is still needed.
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
src_dir = os.path.join(project_root, "src")
sys.path.insert(0, src_dir)

# Logging is configured at import time, before the module-level logger is created.
configure_logging()
LOG = get_logger("axolotl.train")
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | @dataclass | 
					
						
						|  | class TrainDatasetMeta: | 
					
						
						|  | """ | 
					
						
						|  | dataclass to capture the dataset specific options for training | 
					
						
						|  | """ | 
					
						
						|  |  | 
					
						
						|  | train_dataset: Dataset | 
					
						
						|  | eval_dataset: Optional[Dataset] = None | 
					
						
						|  | total_num_steps: Optional[int] = None | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def train( | 
					
						
						|  | *, cfg: DictDefault, cli_args: TrainerCliArgs, dataset_meta: TrainDatasetMeta | 
					
						
						|  | ) -> Tuple[Union[PeftModel, PreTrainedModel], PreTrainedTokenizer]: | 
					
						
						|  |  | 
					
						
						|  | LOG.debug( | 
					
						
						|  | f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}", | 
					
						
						|  | main_process_only=True, | 
					
						
						|  | ) | 
					
						
						|  | tokenizer = load_tokenizer(cfg) | 
					
						
						|  |  | 
					
						
						|  | train_dataset = dataset_meta.train_dataset | 
					
						
						|  | eval_dataset = dataset_meta.eval_dataset | 
					
						
						|  | total_num_steps = dataset_meta.total_num_steps | 
					
						
						|  |  | 
					
						
						|  | if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints: | 
					
						
						|  | possible_checkpoints = [ | 
					
						
						|  | str(cp) for cp in Path(cfg.output_dir).glob("checkpoint-*") | 
					
						
						|  | ] | 
					
						
						|  | if len(possible_checkpoints) > 0: | 
					
						
						|  | sorted_paths = sorted( | 
					
						
						|  | possible_checkpoints, | 
					
						
						|  | key=lambda path: int(path.split("-")[-1]), | 
					
						
						|  | ) | 
					
						
						|  | cfg.resume_from_checkpoint = sorted_paths[-1] | 
					
						
						|  | LOG.info( | 
					
						
						|  | f"Using Auto-resume functionality to start with checkpoint at {cfg.resume_from_checkpoint}" | 
					
						
						|  | ) | 
					
						
						|  | resume_from_checkpoint = cfg.resume_from_checkpoint | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | msg = "loading model" | 
					
						
						|  | if cfg.adapter: | 
					
						
						|  | msg += " and peft_config..." | 
					
						
						|  | LOG.debug(msg) | 
					
						
						|  | model, peft_config = load_model(cfg, tokenizer, inference=cli_args.inference) | 
					
						
						|  | model.generation_config.do_sample = True | 
					
						
						|  |  | 
					
						
						|  | model_ref = None | 
					
						
						|  | if cfg.rl and cfg.rl != "orpo": | 
					
						
						|  | if cfg.adapter and not cfg.rl_adapter_ref_model: | 
					
						
						|  |  | 
					
						
						|  | LOG.debug("Passing model_ref: None to RL trainer") | 
					
						
						|  | model_ref = None | 
					
						
						|  | else: | 
					
						
						|  |  | 
					
						
						|  | model_ref, _ = load_model( | 
					
						
						|  | cfg, tokenizer, inference=cli_args.inference, reference_model=True | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | safe_serialization = cfg.save_safetensors is True | 
					
						
						|  |  | 
					
						
						|  | if cfg.unfrozen_parameters: | 
					
						
						|  | freeze_layers_except(model, cfg.unfrozen_parameters) | 
					
						
						|  |  | 
					
						
						|  | trainer = setup_trainer( | 
					
						
						|  | cfg, | 
					
						
						|  | train_dataset, | 
					
						
						|  | eval_dataset, | 
					
						
						|  | (model, model_ref, peft_config), | 
					
						
						|  | tokenizer, | 
					
						
						|  | total_num_steps, | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if peft_config: | 
					
						
						|  | LOG.info(f"Pre-saving adapter config to {cfg.output_dir}") | 
					
						
						|  | peft_config.save_pretrained(cfg.output_dir) | 
					
						
						|  |  | 
					
						
						|  | if not Path(cfg.output_dir).is_dir(): | 
					
						
						|  | os.makedirs(cfg.output_dir, exist_ok=True) | 
					
						
						|  | tokenizer.save_pretrained(str(Path(cfg.output_dir))) | 
					
						
						|  | if hasattr(model, "config"): | 
					
						
						|  | model.config.save_pretrained(str(Path(cfg.output_dir))) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if cfg.local_rank == 0: | 
					
						
						|  |  | 
					
						
						|  | def terminate_handler(_, __, model): | 
					
						
						|  | if cfg.flash_optimum and BetterTransformer: | 
					
						
						|  | model = BetterTransformer.reverse(model) | 
					
						
						|  | model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization) | 
					
						
						|  | sys.exit(0) | 
					
						
						|  |  | 
					
						
						|  | signal.signal( | 
					
						
						|  | signal.SIGINT, lambda signum, frame: terminate_handler(signum, frame, model) | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | badge_markdown = """[<img src="https://raw.githubusercontent.com/OpenAccess-AI-Collective/axolotl/main/image/axolotl-badge-web.png" alt="Built with Axolotl" width="200" height="32"/>](https://github.com/OpenAccess-AI-Collective/axolotl)""" | 
					
						
						|  | transformers.modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n{badge_markdown}" | 
					
						
						|  |  | 
					
						
						|  | if getattr(cfg, "axolotl_config_path"): | 
					
						
						|  | raw_axolotl_cfg = Path(cfg.axolotl_config_path) | 
					
						
						|  | version = get_distribution("axolotl").version | 
					
						
						|  | if raw_axolotl_cfg.is_file(): | 
					
						
						|  | transformers.modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n<details><summary>See axolotl config</summary>\n\naxolotl version: `{version}`\n```yaml\n{raw_axolotl_cfg.read_text(encoding='utf-8')}\n```\n\n</details><br>\n" | 
					
						
						|  |  | 
					
						
						|  | LOG.info("Starting trainer...") | 
					
						
						|  | if cfg.group_by_length: | 
					
						
						|  | LOG.info("hang tight... sorting dataset for group_by_length") | 
					
						
						|  |  | 
					
						
						|  | pretrain_hooks(cfg, trainer) | 
					
						
						|  | if cfg.flash_optimum: | 
					
						
						|  | with torch.backends.cuda.sdp_kernel( | 
					
						
						|  |  | 
					
						
						|  | enable_flash=True, | 
					
						
						|  | enable_math=True, | 
					
						
						|  | enable_mem_efficient=True, | 
					
						
						|  | ): | 
					
						
						|  | trainer.train(resume_from_checkpoint=resume_from_checkpoint) | 
					
						
						|  | else: | 
					
						
						|  | trainer.train(resume_from_checkpoint=resume_from_checkpoint) | 
					
						
						|  | post_train_hooks(cfg, trainer) | 
					
						
						|  |  | 
					
						
						|  | LOG.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}") | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | for name, module in model.named_modules(): | 
					
						
						|  | if hasattr(module, "_post_training"): | 
					
						
						|  | module._post_training(model, name) | 
					
						
						|  |  | 
					
						
						|  | if trainer.is_fsdp_enabled: | 
					
						
						|  | trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT") | 
					
						
						|  | LOG.info("Set FSDP state dict type to FULL_STATE_DICT for saving.") | 
					
						
						|  |  | 
					
						
						|  | if cfg.relora_steps: | 
					
						
						|  | if cfg.adapter == "lora" and not (cfg.load_in_4bit or cfg.load_in_8bit): | 
					
						
						|  | model = model.merge_and_unload() | 
					
						
						|  | else: | 
					
						
						|  |  | 
					
						
						|  | return model, tokenizer | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if cfg.fsdp: | 
					
						
						|  | trainer.save_model(cfg.output_dir) | 
					
						
						|  | elif cfg.deepspeed and is_deepspeed_zero3_enabled(): | 
					
						
						|  |  | 
					
						
						|  | trainer.accelerator.wait_for_everyone() | 
					
						
						|  | unwrapped_model = trainer.accelerator.unwrap_model(trainer.model_wrapped) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | unwrapped_model.save_pretrained( | 
					
						
						|  | cfg.output_dir, | 
					
						
						|  | is_main_process=trainer.accelerator.is_main_process, | 
					
						
						|  | save_function=trainer.accelerator.save, | 
					
						
						|  | state_dict=trainer.accelerator.get_state_dict(trainer.model_wrapped), | 
					
						
						|  | ) | 
					
						
						|  | elif cfg.local_rank == 0: | 
					
						
						|  | if cfg.flash_optimum and BetterTransformer: | 
					
						
						|  | model = BetterTransformer.reverse(model) | 
					
						
						|  |  | 
					
						
						|  | model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization) | 
					
						
						|  |  | 
					
						
						|  | if not cfg.hub_model_id: | 
					
						
						|  | try: | 
					
						
						|  | trainer.create_model_card(model_name=cfg.output_dir.lstrip("./")) | 
					
						
						|  | except AttributeError: | 
					
						
						|  | pass | 
					
						
						|  | elif cfg.hub_model_id: | 
					
						
						|  |  | 
					
						
						|  | trainer.push_to_hub() | 
					
						
						|  |  | 
					
						
						|  | return model, tokenizer | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def pretrain_hooks(_cfg, _trainer): | 
					
						
						|  | """ | 
					
						
						|  | Run hooks right before kicking off the training | 
					
						
						|  | :param cfg: | 
					
						
						|  | :param trainer: | 
					
						
						|  | :return: | 
					
						
						|  | """ | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def post_train_hooks(_cfg, _trainer): | 
					
						
						|  | """ | 
					
						
						|  | Run hooks right after training completes | 
					
						
						|  | :param cfg: | 
					
						
						|  | :param trainer: | 
					
						
						|  | :return: | 
					
						
						|  | """ | 
					
						
						|  |  |