cleanup, prep for 4bit quant support
- README.md +21 -1
- scripts/finetune.py +18 -6
- setup.cfg +3 -0
 
    	
README.md
CHANGED

@@ -30,4 +30,24 @@ shuf -n2000 data/vicuna_cleaned.jsonl > data/vicuna_cleaned.subset0.jsonl
 
 - Create a new or update the existing YAML config (config/pythia_1_2B_alpaca.yml)[config/pythia_1_2B_alpaca.yml]
 - Install python dependencies `pip3 install -r requirements.txt`
-
+- Configure accelerate `accelerate launch` or update `~/.cache/huggingface/accelerate/default_config.yaml`
+
+```yaml
+compute_environment: LOCAL_MACHINE
+distributed_type: MULTI_GPU
+downcast_bf16: 'no'
+gpu_ids: all
+machine_rank: 0
+main_training_function: main
+mixed_precision: bf16
+num_machines: 1
+num_processes: 4
+rdzv_backend: static
+same_network: true
+tpu_env: []
+tpu_use_cluster: false
+tpu_use_sudo: false
+use_cpu: false
+```
+
+- Train! `accelerate launch scripts/finetune.py`, make sure to choose the correct YAML config file
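Before launching a long run, it can be worth confirming that `accelerate` actually resolved the config above. This is not part of the commit; it is a minimal Python sketch that assumes only the `accelerate` package from requirements.txt, and the file name is hypothetical:

```python
# check_accelerate.py (hypothetical helper, not in the repo)
from accelerate import Accelerator

accelerator = Accelerator()
print("device:", accelerator.device)                    # e.g. cuda:0
print("num_processes:", accelerator.num_processes)      # should match num_processes: 4
print("mixed_precision:", accelerator.mixed_precision)  # should match mixed_precision: bf16
```

Running it with `accelerate launch check_accelerate.py` uses the same default_config.yaml that the training command will pick up.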
    	
scripts/finetune.py
CHANGED

@@ -68,26 +68,27 @@ def load_model(base_model, model_type, tokenizer_type, cfg, adapter="lora", infe
             from axolotl.flash_attn import replace_llama_attn_with_flash_attn
             replace_llama_attn_with_flash_attn()
 
+    torch_dtype = torch.float16 if cfg.load_in_8bit or cfg.fp16 else torch.float32,
     try:
         if "llama" in base_model:
             model = LlamaForCausalLM.from_pretrained(
                 base_model,
                 load_in_8bit=cfg.load_in_8bit,
-                torch_dtype=
+                torch_dtype=torch_dtype,
                 device_map=cfg.device_map,
             )
         else:
             model = getattr(transformers, model_type).from_pretrained(
                 base_model,
                 load_in_8bit=cfg.load_in_8bit,
-                torch_dtype=
+                torch_dtype=torch_dtype,
                 device_map=cfg.device_map,
             )
     except:
         model = AutoModelForCausalLM.from_pretrained(
             base_model,
             load_in_8bit=cfg.load_in_8bit,
-            torch_dtype=
+            torch_dtype=torch_dtype,
             device_map=cfg.device_map,
         )
 
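One nit on the hunk above: the new `torch_dtype = ...` assignment ends with a trailing comma, which binds `torch_dtype` to a one-element tuple rather than a `torch.dtype`. A minimal sketch of the presumably intended selection, using a stand-in `cfg` namespace instead of the script's real config object:

```python
import torch
from types import SimpleNamespace

# Stand-in for cfg; the attribute names mirror the diff above.
cfg = SimpleNamespace(load_in_8bit=True, fp16=False)

# Use fp16 weights when loading in 8-bit or training in fp16, otherwise fp32.
# Note: no trailing comma, so torch_dtype is a dtype rather than a tuple.
torch_dtype = torch.float16 if cfg.load_in_8bit or cfg.fp16 else torch.float32
print(torch_dtype)  # torch.float16
```

The selected dtype is then passed through `torch_dtype=torch_dtype` in each of the three `from_pretrained(...)` calls.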
@@ -235,7 +236,10 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
     save_steps = eval_steps = min(int(0.05 * total_num_steps), 200)
 
     training_arguments_kwargs = {}
-
+    if cfg.bf16 == "full":
+        training_arguments_kwargs["bf16_full_eval"] = True
+    else:
+        training_arguments_kwargs["bf16"] = cfg.bf16
     training_arguments_kwargs["tf32"] = cfg.tf32
     training_arguments_kwargs["warmup_steps"] = warmup_steps
     training_arguments_kwargs["logging_steps"] = logging_steps
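The branch above distinguishes bf16 mixed-precision training from full-bf16 evaluation. A hedged, standalone sketch of how those kwargs feed `transformers.TrainingArguments`, with placeholder values that are not from the script:

```python
from transformers import TrainingArguments

# Placeholder for cfg.bf16, which the script treats as True, False, or "full".
cfg_bf16 = False

training_arguments_kwargs = {}
if cfg_bf16 == "full":
    training_arguments_kwargs["bf16_full_eval"] = True   # evaluate in bf16
else:
    training_arguments_kwargs["bf16"] = bool(cfg_bf16)   # bf16 mixed precision on/off

training_args = TrainingArguments(
    output_dir="./out",  # placeholder path
    **training_arguments_kwargs,
)
print(training_args.bf16, training_args.bf16_full_eval)
```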
@@ -256,10 +260,10 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
         group_by_length=cfg.group_by_length,
         report_to="wandb" if cfg.use_wandb else None,
         run_name=cfg.wandb_run_id if cfg.use_wandb else None,
+        gradient_checkpointing=cfg.gradient_checkpointing,
         **training_arguments_kwargs,
     )
 
-    trainer_kwargs = {}
     decay_parameters = get_parameter_names(model, [nn.LayerNorm])
     decay_parameters = [name for name in decay_parameters if "bias" not in name]
     optimizer_grouped_parameters = [
     | 
|
| 282 | 
         
             
                    lr=training_args.learning_rate,
         
     | 
| 283 | 
         
             
                )
         
     | 
| 284 | 
         | 
| 
         | 
|
| 285 | 
         
             
                lr_scheduler = transformers.get_cosine_schedule_with_warmup(
         
     | 
| 286 | 
         
             
                    adam_bnb_optim,
         
     | 
| 287 | 
         
             
                    training_args.warmup_steps,
         
     | 
| 288 | 
         
             
                    total_num_steps,
         
     | 
| 289 | 
         
             
                )
         
     | 
| 290 | 
         
            -
                trainer_kwargs["optimizers"] = (adam_bnb_optim, lr_scheduler)
         
     | 
| 291 | 
         | 
| 
         | 
|
| 292 | 
         
             
                if cfg.early_stopping_patience:
         
     | 
| 293 | 
         
             
                    early_stop_cb = EarlyStoppingCallback(
         
     | 
| 294 | 
         
             
                        cfg.early_stopping_patience,
         
     | 
| 
         @@ -300,6 +305,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer): 
     | 
|
| 300 | 
         
             
                    train_dataset=train_dataset,
         
     | 
| 301 | 
         
             
                    eval_dataset=eval_dataset,
         
     | 
| 302 | 
         
             
                    args=training_args,
         
     | 
| 
         | 
|
| 303 | 
         
             
                    data_collator=transformers.DataCollatorForSeq2Seq(
         
     | 
| 304 | 
         
             
                        tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
         
     | 
| 305 | 
         
             
                    ),
         
     | 
| 
@@ -342,6 +348,12 @@ def train(
             cfg.gradient_accumulation_steps // cfg.world_size
         )
     setup_wandb_env_vars(cfg)
+    if cfg.device == "mps":
+        cfg.load_in_8bit = False
+        cfg.tf32 = False
+        if cfg.bf16:
+            cfg.fp16 = True
+        cfg.bf16 = False
 
     # Load the model and tokenizer
     model, tokenizer, lora_config = load_model(
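The new `mps` branch downgrades options that are not available on Apple-silicon backends. A small standalone sketch of the same logic, using a stand-in `cfg` namespace and an optional runtime check that is not in the script:

```python
import torch
from types import SimpleNamespace

# Stand-in cfg mirroring the fields the diff above touches.
cfg = SimpleNamespace(device="cuda", load_in_8bit=True, tf32=True, bf16=True, fp16=False)

# Optionally detect the Apple-silicon backend at runtime instead of trusting cfg.
if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
    cfg.device = "mps"

if cfg.device == "mps":
    cfg.load_in_8bit = False  # bitsandbytes 8-bit kernels are CUDA-only
    cfg.tf32 = False          # tf32 is an NVIDIA tensor-core feature
    if cfg.bf16:
        cfg.fp16 = True       # fall back to fp16 mixed precision
    cfg.bf16 = False
```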
setup.cfg
CHANGED

@@ -28,3 +28,6 @@ install_requires =
 [options.packages.find]
 where = src
 
+[options.extras_require]
+gptq_cuda = alpaca_lora_4bit[cuda] @ git+https://github.com/winglian/alpaca_lora_4bit.git@setup_pip#egg=alpaca_lora_4bit[cuda]
+gptq_triton = alpaca_lora_4bit[triton] @ git+https://github.com/winglian/alpaca_lora_4bit.git@setup_pip#egg=alpaca_lora_4bit[triton]
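The new extras pull the 4-bit GPTQ dependency from the winglian fork. A hedged check that the distribution resolved after installing one of the extras; the install command in the comment is an assumed pip extras invocation, not something from this diff:

```python
# Assumed install command (not part of the diff):
#   pip3 install -e '.[gptq_cuda]'    # or '.[gptq_triton]'
from importlib import metadata

try:
    print("alpaca_lora_4bit", metadata.version("alpaca_lora_4bit"))
except metadata.PackageNotFoundError:
    print("alpaca_lora_4bit not installed; 4-bit GPTQ support will be unavailable")
```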