|  | """Prepare and train a model on a dataset. Can also infer from a model or merge lora""" | 
					
						
						|  |  | 
					
						
						|  | import importlib | 
					
						
						|  | import logging | 
					
						
						|  | import os | 
					
						
						|  | import random | 
					
						
						|  | import signal | 
					
						
						|  | import sys | 
					
						
						|  | from pathlib import Path | 
					
						
						|  | from typing import Any, Dict, List, Optional, Union | 
					
						
						|  |  | 
					
						
						|  | import fire | 
					
						
						|  | import torch | 
					
						
						|  | import yaml | 
					
						
						|  | from transformers import GenerationConfig | 
					
						
						|  |  | 
					
						
						|  | from axolotl.utils.data import load_prepare_datasets | 
					
						
						|  | from axolotl.utils.dict import DictDefault | 
					
						
						|  | from axolotl.utils.models import load_model, load_tokenizer | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | from axolotl.utils.tokenization import check_dataset_labels | 
					
						
						|  | from axolotl.utils.trainer import setup_trainer | 
					
						
						|  | from axolotl.utils.validation import validate_config | 
					
						
						|  | from axolotl.utils.wandb import setup_wandb_env_vars | 
					
						
						|  |  | 
					
						
						|  | project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) | 
					
						
						|  | src_dir = os.path.join(project_root, "src") | 
					
						
						|  | sys.path.insert(0, src_dir) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | logging.basicConfig(level=os.getenv("LOG_LEVEL", "INFO")) | 
					
						
						|  | DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared" | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
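# Pick the compute device for this process: CUDA (keyed by local rank), Apple MPS, or CPU as a fallback.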
def choose_device(cfg):
    def get_device():
        try:
            if torch.cuda.is_available():
                return f"cuda:{cfg.local_rank}"

            if torch.backends.mps.is_available():
                return "mps"

            raise SystemError("No CUDA/mps device found")
        except Exception:
            return "cpu"

    cfg.device = get_device()
    # get_device() returns "cuda:<local_rank>", so match on the prefix rather than the exact string "cuda"
    if cfg.device.startswith("cuda"):
        cfg.device_map = {"": cfg.local_rank}
    else:
        cfg.device_map = {"": cfg.device}


def get_multi_line_input() -> Optional[str]:
    print("Give me an instruction (Ctrl + D to finish): ")
    instruction = ""
    for line in sys.stdin:
        instruction += line

    return instruction


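# Interactive inference loop: read a prompt from stdin, wrap it with the named prompter
# class from axolotl.prompters, and sample a completion with a fixed GenerationConfig.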
def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"):
    tokenizer.add_special_tokens({"unk_token": "<unk>"})
    tokenizer.add_special_tokens({"bos_token": "<s>"})
    tokenizer.add_special_tokens({"eos_token": "</s>"})

    prompter_module = getattr(importlib.import_module("axolotl.prompters"), prompter)

    while True:
        instruction = get_multi_line_input()
        if not instruction:
            return
        prompt: str = next(
            prompter_module().build_prompt(instruction=instruction.strip("\n"))
        )
        batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)

        model.eval()
        with torch.no_grad():
            generation_config = GenerationConfig(
                repetition_penalty=1.1,
                max_new_tokens=1024,
                temperature=0.9,
                top_p=0.95,
                top_k=40,
                bos_token_id=tokenizer.bos_token_id,
                eos_token_id=tokenizer.eos_token_id,
                pad_token_id=tokenizer.pad_token_id,
                do_sample=True,
                use_cache=True,
                return_dict_in_generate=True,
                output_attentions=False,
                output_hidden_states=False,
                output_scores=False,
            )
            generated = model.generate(
                inputs=batch["input_ids"].to(cfg.device),
                generation_config=generation_config,
            )
        print(tokenizer.decode(generated["sequences"].cpu().tolist()[0]))


def choose_config(path: Path):
    yaml_files = list(path.glob("*.yml"))

    if not yaml_files:
        raise ValueError(
            "No YAML config files found in the specified directory. Are you using a .yml extension?"
        )

    print("Choose a YAML file:")
    for idx, file in enumerate(yaml_files):
        print(f"{idx + 1}. {file}")

    chosen_file = None
    while chosen_file is None:
        try:
            choice = int(input("Enter the number of your choice: "))
            if 1 <= choice <= len(yaml_files):
                chosen_file = yaml_files[choice - 1]
            else:
                print("Invalid choice. Please choose a number from the list.")
        except ValueError:
            print("Invalid input. Please enter a number.")

    return chosen_file


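# Helper: True when none of the elements of list1 appear in list2. Used below to decide
# whether any of the inference / shard / merge_lora modes were requested via kwargs.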
def check_not_in(list1: List[str], list2: Union[Dict[str, Any], List[str]]) -> bool:
    return not any(el in list2 for el in list1)


def train(
    config: Path = Path("configs/"),
    prepare_ds_only: bool = False,
    **kwargs,
):
    if Path(config).is_dir():
        config = choose_config(config)

    # load the config from the yaml file
    with open(config, encoding="utf-8") as file:
        cfg: DictDefault = DictDefault(yaml.safe_load(file))

    # if options are passed on the cli, overwrite the corresponding config values
    cfg_keys = cfg.keys()
    for k, _ in kwargs.items():
        # if not strict, allow writing to cfg even if the key isn't in the yml already
        if k in cfg_keys or cfg.strict is False:
            # handle booleans
            if isinstance(cfg[k], bool):
                cfg[k] = bool(kwargs[k])
            else:
                cfg[k] = kwargs[k]

    validate_config(cfg)

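    # batch_size and gradient_accumulation_steps can each be derived from the other
    # (batch_size = micro_batch_size * gradient_accumulation_steps); under DDP the
    # effective batch size is further multiplied by world_size.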
    cfg.gradient_accumulation_steps = cfg.gradient_accumulation_steps or (
        cfg.batch_size // cfg.micro_batch_size
    )
    cfg.batch_size = (
        cfg.batch_size or cfg.micro_batch_size * cfg.gradient_accumulation_steps
    )
    cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
    cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0))
    choose_device(cfg)
    cfg.ddp = cfg.ddp if cfg.ddp is not None else cfg.world_size != 1
    if cfg.ddp:
        cfg.device_map = {"": int(os.environ.get("LOCAL_RANK", 0))}
        cfg.batch_size = cfg.batch_size * cfg.world_size

    setup_wandb_env_vars(cfg)
    if cfg.device == "mps":
        # bitsandbytes 8-bit loading and tf32 are CUDA-only; prefer fp16 over bf16 on MPS
        cfg.load_in_8bit = False
        cfg.tf32 = False
        if cfg.bf16:
            cfg.fp16 = True
        cfg.bf16 = False

    tokenizer_config = cfg.tokenizer_config or cfg.base_model_config
    logging.info(f"loading tokenizer... {tokenizer_config}")
    tokenizer = load_tokenizer(tokenizer_config, cfg.tokenizer_type, cfg)

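    # datasets are only needed when preparing data or training, not for inference, shard, or merge_lora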
    if check_not_in(["inference", "shard", "merge_lora"], kwargs):
        train_dataset, eval_dataset = load_prepare_datasets(
            tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
        )

    if cfg.debug or "debug" in kwargs:
        logging.info("check_dataset_labels...")
        check_dataset_labels(
            train_dataset.select(
                [random.randrange(0, len(train_dataset) - 1) for _ in range(5)]
            ),
            tokenizer,
        )

    if prepare_ds_only:
        logging.info("Finished preparing dataset. Exiting...")
        return

						|  | logging.info("loading model and peft_config...") | 
					
						
						|  | model, peft_config = load_model( | 
					
						
						|  | cfg.base_model, | 
					
						
						|  | cfg.base_model_config, | 
					
						
						|  | cfg.model_type, | 
					
						
						|  | tokenizer, | 
					
						
						|  | cfg, | 
					
						
						|  | adapter=cfg.adapter, | 
					
						
						|  | inference=("inference" in kwargs), | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
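    # --merge_lora: fold the LoRA adapter weights into the base model and save the merged copy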
						|  | if "merge_lora" in kwargs and cfg.adapter is not None: | 
					
						
						|  | logging.info("running merge of LoRA with base model") | 
					
						
						|  | model = model.merge_and_unload() | 
					
						
						|  | model.to(dtype=torch.float16) | 
					
						
						|  |  | 
					
						
						|  | if cfg.local_rank == 0: | 
					
						
						|  | logging.info("saving merged model") | 
					
						
						|  | model.save_pretrained(str(Path(cfg.output_dir) / "merged")) | 
					
						
						|  | return | 
					
						
						|  |  | 
					
						
						|  | if "inference" in kwargs: | 
					
						
						|  | logging.info("calling do_inference function") | 
					
						
						|  | do_inference(cfg, model, tokenizer) | 
					
						
						|  | return | 
					
						
						|  |  | 
					
						
						|  | if "shard" in kwargs: | 
					
						
						|  | model.save_pretrained(cfg.output_dir) | 
					
						
						|  | return | 
					
						
						|  |  | 
					
						
    trainer = setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer)

    model.config.use_cache = False

    if torch.__version__ >= "2" and sys.platform != "win32":
        logging.info("Compiling torch model")
        model = torch.compile(model)

    # pre-save the adapter config so it is available even if training is interrupted
    if peft_config:
        logging.info(f"Pre-saving adapter config to {cfg.output_dir}")
        peft_config.save_pretrained(cfg.output_dir)

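    # on the main process, save the (possibly partially trained) model before exiting on Ctrl+C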
    if cfg.local_rank == 0:
        signal.signal(
            signal.SIGINT,
            lambda signum, frame: (
                model.save_pretrained(cfg.output_dir),
                sys.exit(0),
            ),
        )

						|  | logging.info("Starting trainer...") | 
					
						
						|  | if cfg.group_by_length: | 
					
						
						|  | logging.info("hang tight... sorting dataset for group_by_length") | 
					
						
						|  | resume_from_checkpoint = cfg.resume_from_checkpoint | 
					
						
						|  | if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints: | 
					
						
						|  | possible_checkpoints = [ | 
					
						
						|  | str(cp) for cp in Path(cfg.output_dir).glob("checkpoint-*") | 
					
						
						|  | ] | 
					
						
						|  | if len(possible_checkpoints) > 0: | 
					
						
						|  | sorted_paths = sorted( | 
					
						
						|  | possible_checkpoints, | 
					
						
						|  | key=lambda path: int(path.split("-")[-1]), | 
					
						
						|  | ) | 
					
						
						|  | resume_from_checkpoint = sorted_paths[-1] | 
					
						
						|  | logging.info( | 
					
						
						|  | f"Using Auto-resume functionality to start with checkpoint at {resume_from_checkpoint}" | 
					
						
						|  | ) | 
					
						
						|  | trainer.train(resume_from_checkpoint=resume_from_checkpoint) | 
					
						
						|  |  | 
					
						
						|  | logging.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}") | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if cfg.local_rank == 0: | 
					
						
						|  | model.save_pretrained(cfg.output_dir) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
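# Example CLI usage via fire (illustrative; adjust script and config paths to your setup):
#   python finetune.py path/to/config.yml                 # train
#   python finetune.py path/to/config.yml --inference     # interactive inference
#   python finetune.py path/to/config.yml --merge_lora    # merge a LoRA adapter into the base model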
						|  | if __name__ == "__main__": | 
					
						
						|  | fire.Fire(train) | 
					
						
						|  |  |