import importlib
import logging
import os
import random
import signal
import sys
from pathlib import Path
from typing import Optional

import fire
import torch
import transformers
import yaml
from attrdict import AttrDefault

# make the in-repo src/ tree importable so the axolotl package resolves
# without needing a pip install
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
src_dir = os.path.join(project_root, "src")
sys.path.insert(0, src_dir)

from axolotl.utils.data import load_prepare_datasets
from axolotl.utils.models import load_model
from axolotl.utils.tokenization import check_dataset_labels
from axolotl.utils.trainer import setup_trainer
from axolotl.utils.wandb import setup_wandb_env_vars

logging.basicConfig(level=os.getenv("LOG_LEVEL", "INFO"))
DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"

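# Device selection: prefer CUDA, fall back to Apple MPS, otherwise CPU.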
					
						
def choose_device(cfg):
    def get_device():
        if torch.cuda.is_available():
            return "cuda"
        try:
            if torch.backends.mps.is_available():
                return "mps"
        except Exception:
            pass
        # neither CUDA nor MPS is usable
        return "cpu"

    cfg.device = get_device()
    if cfg.device == "cuda":
        cfg.device_map = {"": cfg.local_rank}
    else:
        cfg.device_map = {"": cfg.device}

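# Note: a device_map whose only key is the empty string ({"": target}) tells
# accelerate to place the entire model on that one device or rank rather than
# sharding it across devices.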
					
						
def get_multi_line_input() -> Optional[str]:
    print("Give me an instruction (Ctrl + D to finish): ")
    instruction = ""
    for line in sys.stdin:
        instruction += line
    return instruction

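# Interactive generation loop. The special tokens registered below assume a
# LLaMA-style tokenizer (<s>, </s>, <unk>); the prompter class is looked up
# by name in axolotl.prompters, defaulting to AlpacaPrompter.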
					
						
def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"):
    tokenizer.add_special_tokens({"unk_token": "<unk>"})
    tokenizer.add_special_tokens({"bos_token": "<s>"})
    tokenizer.add_special_tokens({"eos_token": "</s>"})

    prompter_module = getattr(importlib.import_module("axolotl.prompters"), prompter)

    while True:
        instruction = get_multi_line_input()
        if not instruction:
            return
        prompt = prompter_module().build_prompt(instruction=instruction)
        batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)

        model.eval()
        with torch.no_grad():
            generated = model.generate(
                inputs=batch["input_ids"].to(cfg.device),
                do_sample=True,
                use_cache=True,
                repetition_penalty=1.1,
                max_new_tokens=100,
                temperature=0.9,
                top_p=0.95,
                top_k=40,
                return_dict_in_generate=True,
                output_attentions=False,
                output_hidden_states=False,
                output_scores=False,
            )
        print(tokenizer.decode(generated["sequences"].cpu().tolist()[0]))

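# Let the user pick a specific YAML file when train() is handed a directory
# of configs instead of a single file.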
					
						
def choose_config(path: Path):
    yaml_files = list(path.glob("*.yml"))

    if not yaml_files:
        raise ValueError(
            "No YAML config files found in the specified directory. Are you using a .yml extension?"
        )

    print("Choose a YAML file:")
    for idx, file in enumerate(yaml_files):
        print(f"{idx + 1}. {file}")

    chosen_file = None
    while chosen_file is None:
        try:
            choice = int(input("Enter the number of your choice: "))
            if 1 <= choice <= len(yaml_files):
                chosen_file = yaml_files[choice - 1]
            else:
                print("Invalid choice. Please choose a number from the list.")
        except ValueError:
            print("Invalid input. Please enter a number.")

    return chosen_file

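# Main entry point, exposed on the command line via fire. `config` may be a
# YAML file or a directory of .yml configs; extra CLI flags land in **kwargs
# and override matching top-level config keys, while sentinel flags such as
# --inference and --shard select alternate modes. A minimal config sketch
# (keys taken from the cfg reads in this function, values purely
# illustrative):
#
#   base_model: huggyllama/llama-7b
#   base_model_config: huggyllama/llama-7b
#   model_type: LlamaForCausalLM
#   tokenizer_type: LlamaTokenizer
#   adapter: lora
#   batch_size: 8
#   micro_batch_size: 2
#   output_dir: ./lora-out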
					
						
def train(
    config: Path = Path("configs/"),
    prepare_ds_only: bool = False,
    **kwargs,
):
    if Path(config).is_dir():
        config = choose_config(config)

    # load the config from the yaml file
    with open(config, "r") as f:
        cfg: AttrDefault = AttrDefault(lambda: None, yaml.load(f, Loader=yaml.Loader))

    # CLI options override keys that already exist in the yaml
    cfg_keys = dict(cfg).keys()
    for k in kwargs:
        if k in cfg_keys:
            # coerce overrides onto the existing boolean type
            if isinstance(cfg[k], bool):
                cfg[k] = bool(kwargs[k])
            else:
                cfg[k] = kwargs[k]

    # derive batch and distributed-training settings
    cfg.gradient_accumulation_steps = cfg.batch_size // cfg.micro_batch_size
    cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
    cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0))
    choose_device(cfg)
    cfg.ddp = cfg.ddp if cfg.ddp is not None else cfg.world_size != 1
    if cfg.ddp:
        cfg.device_map = {"": int(os.environ.get("LOCAL_RANK", 0))}
        # split accumulation across ranks so the global batch size is unchanged
        cfg.gradient_accumulation_steps = (
            cfg.gradient_accumulation_steps // cfg.world_size
        )
    setup_wandb_env_vars(cfg)
    if cfg.device == "mps":
        cfg.load_in_8bit = False
        cfg.tf32 = False
        if cfg.bf16:
            cfg.fp16 = True
        cfg.bf16 = False
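        # bitsandbytes int8 quantization and tf32 are CUDA-only, and the MPS
        # backend lacks bf16 support, hence the fp16 downgrade (rationale
        # inferred, not stated in the original)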
					
						
    logging.info("loading model, tokenizer, and peft_config...")
    model, tokenizer, peft_config = load_model(
        cfg.base_model,
        cfg.base_model_config,
        cfg.model_type,
        cfg.tokenizer_type,
        cfg,
        adapter=cfg.adapter,
        inference=("inference" in kwargs),
    )

    if "inference" in kwargs:
        logging.info("calling do_inference function")
        do_inference(cfg, model, tokenizer)
        return

    if "shard" in kwargs:
        model.save_pretrained(cfg.output_dir)
        return

    train_dataset, eval_dataset = load_prepare_datasets(
        tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
    )

    if prepare_ds_only:
        logging.info("Finished preparing dataset. Exiting...")
        return

    if cfg.debug:
        logging.info("check_dataset_labels...")
        check_dataset_labels(
            # spot-check the labels on a handful of random training examples
            train_dataset.select(
                [random.randrange(0, len(train_dataset)) for _ in range(5)]
            ),
            tokenizer,
        )

    trainer = setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer)

    model.config.use_cache = False

    # torch.compile needs PyTorch >= 2 and is unsupported on Windows
    if torch.__version__ >= "2" and sys.platform != "win32":
        logging.info("Compiling torch model")
        model = torch.compile(model)

    # pre-save the adapter config so it exists even if training is interrupted
    if peft_config:
        logging.info(f"Pre-saving adapter config to {cfg.output_dir}")
        peft_config.save_pretrained(cfg.output_dir)

    # on the main rank, save the model before exiting when Ctrl+C is pressed
    if cfg.local_rank == 0:
        signal.signal(
            signal.SIGINT,
            lambda signum, frame: (model.save_pretrained(cfg.output_dir), sys.exit(0)),
        )

    logging.info("Starting trainer...")
    resume_from_checkpoint = cfg.resume_from_checkpoint
    if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints:
        # auto-resume from the checkpoint with the highest step number
        possible_checkpoints = [
            str(cp) for cp in Path(cfg.output_dir).glob("checkpoint-*")
        ]
        if len(possible_checkpoints) > 0:
            sorted_paths = sorted(
                possible_checkpoints, key=lambda path: int(path.split("-")[-1])
            )
            resume_from_checkpoint = sorted_paths[-1]
            logging.info(
                f"Using Auto-resume functionality to start with checkpoint at {resume_from_checkpoint}"
            )
    trainer.train(resume_from_checkpoint=resume_from_checkpoint)

    logging.info(f"Training completed! Saving trained model to {cfg.output_dir}")

    trainer.save_model(cfg.output_dir)
					
						
if __name__ == "__main__":
    fire.Fire(train)

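# Example invocations (script and config names are illustrative):
#   python finetune.py configs/llama_7b_lora.yml
#   python finetune.py configs/llama_7b_lora.yml --inference
#   python finetune.py configs/ --prepare_ds_only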