fix bug when model_type is not explicitly passed
Browse files
src/axolotl/utils/models.py
CHANGED
|
@@ -35,7 +35,7 @@ def load_model(
|
|
| 35 |
# TODO refactor as a kwarg
|
| 36 |
load_in_8bit = cfg.load_in_8bit
|
| 37 |
tokenizer = None
|
| 38 |
-
is_llama_derived_model = "llama" in base_model or "llama" in cfg.model_type.lower()
|
| 39 |
|
| 40 |
if is_llama_derived_model and cfg.flash_attention:
|
| 41 |
if cfg.device not in ["mps", "cpu"] and inference is False:
|
|
|
|
| 35 |
# TODO refactor as a kwarg
|
| 36 |
load_in_8bit = cfg.load_in_8bit
|
| 37 |
tokenizer = None
|
| 38 |
+
is_llama_derived_model = "llama" in base_model or (cfg.model_type and "llama" in cfg.model_type.lower())
|
| 39 |
|
| 40 |
if is_llama_derived_model and cfg.flash_attention:
|
| 41 |
if cfg.device not in ["mps", "cpu"] and inference is False:
|