fix some of the edge cases for Jamba (#1452)

* fix some of the edge cases for Jamba
* update requirements for jamba

- .github/workflows/pypi.yml +1 -1
- .github/workflows/tests.yml +2 -0
- examples/jamba/README.md +8 -3
- examples/jamba/qlora.yaml +62 -0
- requirements.txt +1 -1
- setup.py +1 -1
- src/axolotl/monkeypatch/multipack.py +13 -11
- src/axolotl/utils/models.py +4 -0
.github/workflows/pypi.yml CHANGED

@@ -25,7 +25,7 @@ jobs:
 
       - name: Install dependencies
         run: |
-          pip3 install wheel
+          pip3 install wheel packaging
           pip3 install -e .
           pip3 install -r requirements-tests.txt
 
.github/workflows/tests.yml CHANGED

@@ -48,6 +48,8 @@ jobs:
 
       - name: Install dependencies
         run: |
+          pip3 install --upgrade pip
+          pip3 install --upgrade packaging
           pip3 install -U -e .
           pip3 install -r requirements-tests.txt
 
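Both workflow tweaks address the same build-time failure mode: `pip3 install -e .` can trigger source builds (such as the mamba-ssm pin bumped below) whose setup.py imports `packaging` before pip has resolved any dependencies. A hypothetical fragment of such a setup.py — illustrative only, not axolotl's or mamba-ssm's actual file:

```python
# Hypothetical setup.py fragment for a source-built CUDA extension (mamba-ssm and
# flash-attn follow this general pattern). The import runs at build time, so
# `packaging` must already be installed or the `pip3 install -e .` step fails
# with ModuleNotFoundError: No module named 'packaging'.
from packaging.version import Version, parse
import torch

def cuda_is_new_enough() -> bool:
    # torch.version.cuda is None on CPU-only builds; treat that as too old
    return parse(torch.version.cuda or "0") >= Version("11.6")
```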
examples/jamba/README.md CHANGED

@@ -1,5 +1,10 @@
 # Jamba
 
-qlora w/ deepspeed needs at least 2x GPUs and
-
-
+- ✅ qlora w/ deepspeed Zero-2 needs at least 2x GPUs and
+  - 35GiB VRAM per GPU w minimal context length
+  - 56GiB VRAM per GPU (w multipack enabled)
+- ✅ qlora w/ deepspeed Zero-3 needs at least 2x GPUs and 67GiB VRAM (wtf?)
+- ✅ qlora single-gpu, ~51GiB VRAM
+- ✅ multipack
+- ❓ FSDP
+- ❓ 8-bit LoRA
examples/jamba/qlora.yaml ADDED

@@ -0,0 +1,62 @@
+base_model: ai21labs/Jamba-v0.1
+trust_remote_code: true
+
+load_in_8bit: false
+load_in_4bit: true
+strict: false
+
+datasets:
+  - path: mhenrichsen/alpaca_2k_test
+    type: alpaca
+dataset_prepared_path:
+val_set_size: 0.0
+output_dir: ./out
+
+sequence_len: 4096
+sample_packing: false
+pad_to_sequence_len: false
+eval_sample_packing: false
+
+wandb_project:
+wandb_entity:
+wandb_watch:
+wandb_name:
+wandb_log_model:
+
+adapter: qlora
+lora_r: 8
+lora_alpha: 16
+lora_dropout: 0.05
+lora_target_linear: true
+
+low_cpu_mem_usage: true
+gradient_accumulation_steps: 4
+micro_batch_size: 1
+num_epochs: 2
+optimizer: paged_adamw_8bit
+lr_scheduler: cosine
+learning_rate: 0.00001
+
+train_on_inputs: false
+group_by_length: false
+bf16: auto
+fp16:
+tf32: false
+
+gradient_checkpointing: true
+gradient_checkpointing_kwargs:
+  use_reentrant: false
+early_stopping_patience:
+resume_from_checkpoint:
+local_rank:
+logging_steps: 1
+xformers_attention:
+flash_attention: true
+
+warmup_steps: 10
+evals_per_epoch:
+saves_per_epoch: 1
+debug:
+deepspeed:
+weight_decay: 0.0
+special_tokens:
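For orientation, here is roughly what the quantization-related fields above translate to in plain transformers code. This is a minimal sketch against the public API, not axolotl's loader (that lives in src/axolotl/utils/models.py, patched later in this PR); the `bnb_4bit_*` values are illustrative defaults rather than values taken from the config:

```python
# Minimal sketch, public transformers API only; axolotl wires these up from the YAML.
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

model = AutoModelForCausalLM.from_pretrained(
    "ai21labs/Jamba-v0.1",                      # base_model
    trust_remote_code=True,                     # trust_remote_code: true
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=True,                      # load_in_4bit: true
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,  # bf16: auto
    ),
    attn_implementation="flash_attention_2",    # flash_attention: true
)
```

Training with this config goes through axolotl's usual entrypoint, e.g. `accelerate launch -m axolotl.cli.train examples/jamba/qlora.yaml`; with `micro_batch_size: 1` and `gradient_accumulation_steps: 4`, the effective batch size is 4 per GPU.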
requirements.txt CHANGED

@@ -32,7 +32,7 @@ fschat==0.2.36
 gradio==3.50.2
 tensorboard
 
-mamba-ssm==1.
+mamba-ssm==1.2.0.post1
 
 # remote filesystems
 s3fs
setup.py CHANGED

@@ -78,7 +78,7 @@ setup(
             "deepspeed-kernels",
         ],
         "mamba-ssm": [
-            "mamba-ssm==1.0.
+            "mamba-ssm==1.2.0.post1",
         ],
         "auto-gptq": [
             "auto-gptq==0.5.1",
src/axolotl/monkeypatch/multipack.py CHANGED

@@ -48,14 +48,16 @@ def patch_for_multipack(model_type, model_name=None):
             get_unpad_data
         )
     elif model_type == "gemmoe":
-        model_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
-        # we need to load the model here in order for modeling_gemmoe to be available
-        with init_empty_weights():
-            AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
-        module_name = model_config.__class__.__module__.replace(
-            ".configuration_gemmoe", ".modeling_gemmoe"
-        )
-        modeling_gemmoe = importlib.import_module(module_name)
-        modeling_gemmoe._get_unpad_data = (  # pylint: disable=protected-access
-            get_unpad_data
-        )
+        patch_remote(model_name, ".configuration_gemmoe", ".modeling_gemmoe")
+    elif model_type == "jamba":
+        patch_remote(model_name, ".configuration_jamba", ".modeling_jamba")
+
+
+def patch_remote(model_name, config_name, modeling_name):
+    model_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
+    # we need to load the model here in order for modeling_* to be available
+    with init_empty_weights():
+        AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
+    module_name = model_config.__class__.__module__.replace(config_name, modeling_name)
+    modeling_arch = importlib.import_module(module_name)
+    modeling_arch._get_unpad_data = get_unpad_data  # pylint: disable=protected-access
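`patch_remote` exists because `trust_remote_code` models like Jamba ship their modeling code as dynamically downloaded modules: there is nothing to import until the model class has been materialized, hence the `init_empty_weights()` load on the meta device before the modeling module can be looked up and its module-level `_get_unpad_data` (which the remote code calls on its flash-attention path) swapped out. As for what the swap buys: with sample packing, axolotl marks each packed sample in the attention mask with its own integer id, so sequence boundaries can be recovered per sample rather than per batch row. A sketch of the idea — this mirrors the intent of the `get_unpad_data` imported at the top of this file, not its exact code:

```python
import torch
import torch.nn.functional as F

def get_unpad_data(attention_mask: torch.Tensor):
    # attention_mask labels tokens per packed sample, e.g. [[1, 1, 1, 2, 2, 0]];
    # assumes at least one non-padding token in the batch
    max_id = int(attention_mask.max())
    counts = torch.stack(
        [(attention_mask == i).sum(dim=-1) for i in range(1, max_id + 1)], dim=-1
    )                                   # (batch, max_id): tokens per packed sample
    seqlens = counts.flatten()
    seqlens = seqlens[seqlens > 0].to(torch.int32)
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))
    return indices, cu_seqlens, int(seqlens.max())
```

For the example row above this yields seqlens `[3, 2]` and cu_seqlens `[0, 3, 5]`, the boundary layout flash-attention's varlen kernels consume; HF's stock `_get_unpad_data` would treat the whole packed row as a single sequence.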
src/axolotl/utils/models.py CHANGED

@@ -456,6 +456,10 @@ def load_model(
                 "bnb_4bit_quant_type": "nf4",
                 "bnb_4bit_quant_storage": torch.bfloat16,
             }
+            if cfg.model_config_type == "jamba" and not cfg.deepspeed:
+                # for some reason, this causes the loss to be off by an order of magnitude
+                # but deepspeed needs this still in bfloat16
+                bnb_config["bnb_4bit_quant_storage"] = torch.float32
 
             if cfg.bnb_config_kwargs:
                 bnb_config.update(cfg.bnb_config_kwargs)
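Condensed, the new branch does the following — a sketch with a hypothetical `cfg` object and a trimmed dict literal, not the actual load_model code path:

```python
# Sketch of the Jamba special case. The dict feeds transformers' BitsAndBytesConfig,
# whose bnb_4bit_quant_storage controls the dtype the packed 4-bit weights are
# stored (and gathered) in; `cfg` here is hypothetical.
import torch
from transformers import BitsAndBytesConfig

def build_bnb_config(cfg) -> BitsAndBytesConfig:
    bnb_config = {
        "load_in_4bit": True,
        "bnb_4bit_quant_type": "nf4",
        "bnb_4bit_quant_storage": torch.bfloat16,
    }
    if cfg.model_config_type == "jamba" and not cfg.deepspeed:
        # bfloat16 storage skews Jamba's loss by an order of magnitude (per the
        # PR comment above), but deepspeed still requires bfloat16 storage
        bnb_config["bnb_4bit_quant_storage"] = torch.float32
    return BitsAndBytesConfig(**bnb_config)
```

The storage dtype matters to deepspeed presumably because ZeRO flattens and gathers parameters in a single uniform dtype, so float32-stored quantized weights would not mix with the bfloat16 training params.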
