set fp16 to false if bf16, update bf16: auto in example YAMLs (#1122) [skip ci]
* set fp16 to false if bf16, update bf16: auto in example YAMLs
* unset fp16 so that it falls back properly if bf16 isn't available
* Update README.md [skip-ci]
Co-authored-by: NanoCode012 <[email protected]>
* test that bf16 disables fp16
---------
Co-authored-by: NanoCode012 <[email protected]>
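
For reference, the precision convention that every example YAML converges on in this change matches the updated README snippet further down: request bf16 with `auto` and leave `fp16` unset so training can fall back cleanly when bf16 is unsupported.

```yaml
bf16: auto # require >=ampere, auto will detect if your GPU supports this and choose automatically.
fp16: # leave empty to use fp16 when bf16 is 'auto'. set to false if you want to fallback to fp32
```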
- README.md +2 -2
- examples/cerebras/btlm-ft.yml +2 -2
- examples/cerebras/qlora.yml +2 -2
- examples/code-llama/13b/lora.yml +2 -2
- examples/code-llama/13b/qlora.yml +2 -2
- examples/code-llama/34b/lora.yml +2 -2
- examples/code-llama/34b/qlora.yml +2 -2
- examples/code-llama/7b/lora.yml +2 -2
- examples/code-llama/7b/qlora.yml +2 -2
- examples/falcon/config-7b-lora.yml +2 -2
- examples/falcon/config-7b-qlora.yml +2 -2
- examples/falcon/config-7b.yml +2 -2
- examples/gptj/qlora.yml +2 -2
- examples/jeopardy-bot/config.yml +1 -1
- examples/llama-2/fft_optimized.yml +2 -2
- examples/llama-2/lora.yml +2 -2
- examples/llama-2/qlora.yml +2 -2
- examples/llama-2/relora.yml +2 -2
- examples/mamba/config.yml +2 -2
- examples/mistral/config.yml +2 -2
- examples/mistral/mixtral.yml +2 -2
- examples/mistral/qlora.yml +2 -2
- examples/mpt-7b/config.yml +1 -1
- examples/phi/phi-ft.yml +2 -2
- examples/phi/phi-qlora.yml +2 -2
- examples/phi/phi2-ft.yml +2 -2
- examples/pythia/lora.yml +1 -1
- examples/qwen/lora.yml +2 -2
- examples/qwen/qlora.yml +2 -2
- examples/redpajama/config-3b.yml +1 -1
- examples/replit-3b/config-lora.yml +1 -1
- examples/tiny-llama/lora.yml +2 -2
- examples/tiny-llama/pretrain.yml +2 -2
- examples/tiny-llama/qlora.yml +2 -2
- examples/xgen-7b/xgen-7b-8k-qlora.yml +2 -2
- examples/yi-34B-chat/qlora.yml +2 -2
- src/axolotl/utils/config.py +4 -0
- tests/test_normalize_config.py +15 -0
    	
README.md CHANGED

@@ -464,8 +464,8 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
 ```yaml
 load_in_4bit: true
 load_in_8bit: true
-bf16:
-fp16:
+bf16: auto # require >=ampere, auto will detect if your GPU supports this and choose automatically.
+fp16: # leave empty to use fp16 when bf16 is 'auto'. set to false if you want to fallback to fp32
 tf32: true # require >=ampere
 bfloat16: true # require >=ampere, use instead of bf16 when you don't want AMP (automatic mixed precision)
 float16: true # use instead of fp16 when you don't want AMP
    	
examples/cerebras/btlm-ft.yml CHANGED

@@ -53,8 +53,8 @@ lr_quadratic_warmup: true
 learning_rate: 0.000085
 train_on_inputs: true
 group_by_length: false
-bf16:
-fp16:
+bf16: auto
+fp16:
 tf32: true

 gradient_checkpointing: false
    	
examples/cerebras/qlora.yml CHANGED

@@ -36,8 +36,8 @@ lr_scheduler: cosine
 learning_rate: 0.0002
 train_on_inputs: false
 group_by_length: false
-bf16:
-fp16:
+bf16: auto
+fp16:
 tf32: true
 gradient_checkpointing: true
 early_stopping_patience:
    	
examples/code-llama/13b/lora.yml CHANGED

@@ -41,8 +41,8 @@ learning_rate: 0.0002

 train_on_inputs: false
 group_by_length: false
-bf16:
-fp16:
+bf16: auto
+fp16:
 tf32: false

 gradient_checkpointing: true
    	
examples/code-llama/13b/qlora.yml CHANGED

@@ -43,8 +43,8 @@ learning_rate: 0.0002

 train_on_inputs: false
 group_by_length: false
-bf16:
-fp16:
+bf16: auto
+fp16:
 tf32: false

 gradient_checkpointing: true
    	
examples/code-llama/34b/lora.yml CHANGED

@@ -41,8 +41,8 @@ learning_rate: 0.0002

 train_on_inputs: false
 group_by_length: false
-bf16:
-fp16:
+bf16: auto
+fp16:
 tf32: false

 gradient_checkpointing: true
    	
examples/code-llama/34b/qlora.yml CHANGED

@@ -43,8 +43,8 @@ learning_rate: 0.0002

 train_on_inputs: false
 group_by_length: false
-bf16:
-fp16:
+bf16: auto
+fp16:
 tf32: false

 gradient_checkpointing: true
    	
examples/code-llama/7b/lora.yml CHANGED

@@ -41,8 +41,8 @@ learning_rate: 0.0002

 train_on_inputs: false
 group_by_length: false
-bf16:
-fp16:
+bf16: auto
+fp16:
 tf32: false

 gradient_checkpointing: true
    	
examples/code-llama/7b/qlora.yml CHANGED

@@ -43,8 +43,8 @@ learning_rate: 0.0002

 train_on_inputs: false
 group_by_length: false
-bf16:
-fp16:
+bf16: auto
+fp16:
 tf32: false

 gradient_checkpointing: true
    	
examples/falcon/config-7b-lora.yml CHANGED

@@ -38,8 +38,8 @@ lr_scheduler: cosine
 learning_rate: 0.00003
 train_on_inputs: false
 group_by_length: false
-bf16:
-fp16:
+bf16: auto
+fp16:
 tf32: true
 gradient_checkpointing: true
 early_stopping_patience:
    	
examples/falcon/config-7b-qlora.yml CHANGED

@@ -64,8 +64,8 @@ lr_scheduler: cosine
 learning_rate: 0.0002
 train_on_inputs: false
 group_by_length: false
-bf16:
-fp16:
+bf16: auto
+fp16:
 tf32: true
 gradient_checkpointing: true
 # stop training after this many evaluation losses have increased in a row
    	
examples/falcon/config-7b.yml CHANGED

@@ -38,8 +38,8 @@ lr_scheduler: cosine
 learning_rate: 0.00003
 train_on_inputs: false
 group_by_length: false
-bf16:
-fp16:
+bf16: auto
+fp16:
 tf32: true
 gradient_checkpointing: true
 early_stopping_patience:
    	
examples/gptj/qlora.yml CHANGED

@@ -33,8 +33,8 @@ lr_scheduler: cosine
 learning_rate: 0.0001
 train_on_inputs: false
 group_by_length: false
-bf16:
-fp16:
+bf16: auto
+fp16:
 tf32: true
 gradient_checkpointing: true
 early_stopping_patience:
    	
examples/jeopardy-bot/config.yml CHANGED

@@ -31,7 +31,7 @@ lr_scheduler: cosine
 learning_rate: 0.00003
 train_on_inputs: false
 group_by_length: false
-bf16:
+bf16: auto
 tf32: true
 early_stopping_patience:
 resume_from_checkpoint:
    	
examples/llama-2/fft_optimized.yml CHANGED

@@ -41,8 +41,8 @@ learning_rate: 0.0002

 train_on_inputs: false
 group_by_length: false
-bf16:
-fp16:
+bf16: auto
+fp16:
 tf32: false

 gradient_checkpointing: true
    	
examples/llama-2/lora.yml CHANGED

@@ -41,8 +41,8 @@ learning_rate: 0.0002

 train_on_inputs: false
 group_by_length: false
-bf16:
-fp16:
+bf16: auto
+fp16:
 tf32: false

 gradient_checkpointing: true
    	
examples/llama-2/qlora.yml CHANGED

@@ -43,8 +43,8 @@ learning_rate: 0.0002

 train_on_inputs: false
 group_by_length: false
-bf16:
-fp16:
+bf16: auto
+fp16:
 tf32: false

 gradient_checkpointing: true
    	
examples/llama-2/relora.yml CHANGED

@@ -47,8 +47,8 @@ learning_rate: 0.0002

 train_on_inputs: false
 group_by_length: false
-bf16:
-fp16:
+bf16: auto
+fp16:
 tf32: false

 gradient_checkpointing: true
    	
examples/mamba/config.yml CHANGED

@@ -34,8 +34,8 @@ learning_rate: 5e-5
 train_on_inputs: false
 group_by_length: true

-bf16:
-fp16:
+bf16: auto
+fp16:
 tf32: true

 gradient_checkpointing: false
    	
examples/mistral/config.yml CHANGED

@@ -34,8 +34,8 @@ learning_rate: 0.000005

 train_on_inputs: false
 group_by_length: false
-bf16:
-fp16:
+bf16: auto
+fp16:
 tf32: false

 gradient_checkpointing: true
    	
examples/mistral/mixtral.yml CHANGED

@@ -63,8 +63,8 @@ learning_rate: 0.0002

 train_on_inputs: false
 group_by_length: false
-bf16:
-fp16:
+bf16: auto
+fp16:
 tf32: false

 gradient_checkpointing: true
    	
examples/mistral/qlora.yml CHANGED

@@ -50,8 +50,8 @@ learning_rate: 0.0002

 train_on_inputs: false
 group_by_length: false
-bf16:
-fp16:
+bf16: auto
+fp16:
 tf32: false

 gradient_checkpointing: true
    	
examples/mpt-7b/config.yml CHANGED

@@ -33,7 +33,7 @@ lr_scheduler: cosine
 learning_rate: 0.0000002
 train_on_inputs: false
 group_by_length: false
-bf16:
+bf16: auto
 tf32: true
 early_stopping_patience:
 resume_from_checkpoint:
    	
examples/phi/phi-ft.yml CHANGED

@@ -46,8 +46,8 @@ learning_rate: 0.000003

 train_on_inputs: false
 group_by_length: true
-bf16:
-fp16:
+bf16: auto
+fp16:
 tf32: true

 gradient_checkpointing:
    	
examples/phi/phi-qlora.yml CHANGED

@@ -46,8 +46,8 @@ learning_rate: 0.000003

 train_on_inputs: false
 group_by_length: true
-bf16:
-fp16:
+bf16: auto
+fp16:
 tf32: true

 gradient_checkpointing:
    	
examples/phi/phi2-ft.yml CHANGED

@@ -49,8 +49,8 @@ learning_rate: 1e-5

 train_on_inputs: false
 group_by_length: false
-bf16:
-fp16:
+bf16: auto
+fp16:
 tf32: true

 gradient_checkpointing: true
    	
examples/pythia/lora.yml CHANGED

@@ -27,7 +27,7 @@ num_epochs: 4
 learning_rate: 0.00001
 train_on_inputs: false
 group_by_length: false
-bf16:
+bf16: auto
 tf32: true
 early_stopping_patience:
 resume_from_checkpoint:
    	
examples/qwen/lora.yml CHANGED

@@ -43,8 +43,8 @@ learning_rate: 0.0002

 train_on_inputs: false
 group_by_length: false
-bf16:
-fp16:
+bf16: auto
+fp16:
 tf32: false

 gradient_checkpointing: false
    	
examples/qwen/qlora.yml CHANGED

@@ -43,8 +43,8 @@ learning_rate: 0.0002

 train_on_inputs: false
 group_by_length: false
-bf16:
-fp16:
+bf16: auto
+fp16:
 tf32: false

 gradient_checkpointing: false
    	
examples/redpajama/config-3b.yml CHANGED

@@ -34,7 +34,7 @@ lr_scheduler: cosine
 learning_rate: 0.0000002
 train_on_inputs: false
 group_by_length: false
-bf16:
+bf16: auto
 tf32: true
 early_stopping_patience:
 resume_from_checkpoint:
    	
examples/replit-3b/config-lora.yml CHANGED

@@ -33,7 +33,7 @@ lr_scheduler:
 learning_rate: 0.00001
 train_on_inputs: false
 group_by_length: false
-bf16:
+bf16: auto
 tf32: true
 gradient_checkpointing:
 early_stopping_patience:
    	
examples/tiny-llama/lora.yml CHANGED

@@ -41,8 +41,8 @@ learning_rate: 0.0002

 train_on_inputs: false
 group_by_length: false
-bf16:
-fp16:
+bf16: auto
+fp16:
 tf32: false

 gradient_checkpointing: true
    	
examples/tiny-llama/pretrain.yml CHANGED

@@ -34,8 +34,8 @@ learning_rate: 0.0002

 train_on_inputs: false
 group_by_length: false
-bf16:
-fp16:
+bf16: auto
+fp16:
 tf32: false

 gradient_checkpointing: true
    	
examples/tiny-llama/qlora.yml CHANGED

@@ -43,8 +43,8 @@ learning_rate: 0.0002

 train_on_inputs: false
 group_by_length: false
-bf16:
-fp16:
+bf16: auto
+fp16:
 tf32: false

 gradient_checkpointing: true
    	
examples/xgen-7b/xgen-7b-8k-qlora.yml CHANGED

@@ -62,8 +62,8 @@ lr_scheduler: cosine
 learning_rate: 0.00002
 train_on_inputs: false
 group_by_length: false
-bf16:
-fp16:
+bf16: auto
+fp16:
 tf32: false
 gradient_checkpointing: true
 # stop training after this many evaluation losses have increased in a row
    	
examples/yi-34B-chat/qlora.yml CHANGED

@@ -7,8 +7,8 @@ load_in_8bit: false
 load_in_4bit: true
 strict: false
 sequence_len: 1024
-bf16:
-fp16:
+bf16: auto
+fp16:
 tf32: false
 flash_attention: true
 special_tokens:
    	
src/axolotl/utils/config.py CHANGED

@@ -70,6 +70,8 @@ def normalize_config(cfg):
         else:
             LOG.debug("bf16 support not detected, disabling for this configuration.")
             cfg.bf16 = False
+            if cfg.fp16 is None:
+                cfg.fp16 = True

     if cfg.device == "mps":
         cfg.load_in_8bit = False
@@ -79,6 +81,8 @@ def normalize_config(cfg):
         cfg.bf16 = False
     else:
         torch.backends.cuda.matmul.allow_tf32 = cfg.tf32 or False
+        if cfg.bf16:
+            cfg.fp16 = False

     if cfg.bf16 or cfg.bfloat16:
         cfg.torch_dtype = torch.bfloat16
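
Read together, the two hunks above give normalize_config the following precision-resolution behaviour. The sketch below is a minimal, self-contained paraphrase rather than the real function: `is_bf16_available` stands in for the `is_torch_bf16_gpu_available()` check used in the actual code, and the mps branch is simplified.

```python
# Minimal sketch of the precision resolution produced by the hunks above.
def resolve_precision(bf16, fp16, is_bf16_available, device="cuda"):
    if bf16 == "auto":
        if is_bf16_available:
            bf16 = True
        else:
            # bf16 requested via "auto" but unsupported: fall back to fp16
            # unless the user explicitly set fp16 (fp16: false keeps fp32).
            bf16 = False
            if fp16 is None:
                fp16 = True

    if device == "mps":
        bf16 = False
    elif bf16:
        # bf16 is active, so fp16 AMP must not be enabled alongside it.
        fp16 = False

    return bf16, fp16


# The combinations exercised by the tests below:
assert resolve_precision("auto", None, is_bf16_available=True) == (True, False)
assert resolve_precision("auto", None, is_bf16_available=False) == (False, True)
assert resolve_precision(True, False, is_bf16_available=True) == (True, False)
```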
    	
tests/test_normalize_config.py CHANGED

@@ -78,13 +78,28 @@ class NormalizeConfigTestCase(unittest.TestCase):
         normalize_config(cfg)

         self.assertTrue(cfg.bf16)
+        self.assertFalse(cfg.fp16)

     @patch("axolotl.utils.config.is_torch_bf16_gpu_available")
     def test_bf16_auto_setter_not_available(self, mock_bf16_avail):
         cfg = self._get_base_cfg()
         cfg.bf16 = "auto"
+        cfg.fp16 = None
         mock_bf16_avail.return_value = False

         normalize_config(cfg)

         self.assertFalse(cfg.bf16)
+        self.assertTrue(cfg.fp16)
+
+    @patch("axolotl.utils.config.is_torch_bf16_gpu_available")
+    def test_bf16_disables_fp16(self, mock_bf16_avail):
+        cfg = self._get_base_cfg()
+        cfg.bf16 = True
+        cfg.fp16 = False
+        mock_bf16_avail.return_value = True
+
+        normalize_config(cfg)
+
+        self.assertTrue(cfg.bf16)
+        self.assertFalse(cfg.fp16)
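
Since these are unittest-style tests, they should also run under pytest if that is how you normally run the suite, e.g. `pytest tests/test_normalize_config.py -k bf16` to select just the bf16/fp16 cases.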
