feat: expose bnb kwargs (#1018)
* feat: expose bnb kwargs
* chore: added examples and link per suggestion
* Uncomment defaults per suggestion for readability
---------
Co-authored-by: Hamel Husain <[email protected]>
- README.md +8 -0
- src/axolotl/utils/models.py +13 -6
README.md
CHANGED

@@ -520,6 +520,14 @@ model_config:
   type: # linear | dynamic
   factor: # float
 
+# optional overrides to the bnb 4bit quantization configuration
+# https://huggingface.co/docs/transformers/main/main_classes/quantization#transformers.BitsAndBytesConfig
+bnb_config_kwargs:
+  # These are default values
+  llm_int8_has_fp16_weight: false
+  bnb_4bit_quant_type: nf4
+  bnb_4bit_use_double_quant: true
+
 
 # Whether you are training a 4-bit GPTQ quantized model
 gptq: true
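In a user config, any keyword accepted by `transformers.BitsAndBytesConfig` can be passed through `bnb_config_kwargs`. A minimal illustrative config, not taken from this PR; `fp4` is simply the other 4-bit quant type `BitsAndBytesConfig` accepts besides the `nf4` default:

```yaml
adapter: qlora
load_in_4bit: true

# hypothetical overrides of the defaults shown above
bnb_config_kwargs:
  bnb_4bit_quant_type: fp4
  bnb_4bit_use_double_quant: false
```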
src/axolotl/utils/models.py
CHANGED

@@ -301,13 +301,20 @@ def load_model(
                 **model_config.quantization_config
             )
     if cfg.adapter == "qlora" and cfg.load_in_4bit:
+        bnb_config = {
+            "load_in_4bit": True,
+            "llm_int8_threshold": 6.0,
+            "llm_int8_has_fp16_weight": False,
+            "bnb_4bit_compute_dtype": cfg.torch_dtype,
+            "bnb_4bit_use_double_quant": True,
+            "bnb_4bit_quant_type": "nf4",
+        }
+
+        if cfg.bnb_config_kwargs:
+            bnb_config.update(cfg.bnb_config_kwargs)
+
         model_kwargs["quantization_config"] = BitsAndBytesConfig(
-            load_in_4bit=True,
-            llm_int8_threshold=6.0,
-            llm_int8_has_fp16_weight=False,
-            bnb_4bit_compute_dtype=cfg.torch_dtype,
-            bnb_4bit_use_double_quant=True,
-            bnb_4bit_quant_type="nf4",
+            **bnb_config,
         )
     # sample packing uses custom FA2 patch
     if cfg.flash_attention:
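The override semantics follow from `dict.update`: any key supplied in `cfg.bnb_config_kwargs` replaces the hard-coded default before the merged dict is unpacked into `BitsAndBytesConfig`. A standalone sketch of that merge behavior, with illustrative variable names rather than axolotl code:

```python
# Defaults mirroring the hard-coded bnb_config above (subset shown).
defaults = {
    "load_in_4bit": True,
    "llm_int8_threshold": 6.0,
    "bnb_4bit_quant_type": "nf4",
}

# Hypothetical user-supplied bnb_config_kwargs from the YAML config.
overrides = {"bnb_4bit_quant_type": "fp4"}

bnb_config = dict(defaults)
bnb_config.update(overrides)  # user-supplied keys win over the defaults

print(bnb_config["bnb_4bit_quant_type"])  # -> fp4
```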