Angainor Development
committed
WIP: Rely on cfg.inference
src/axolotl/utils/models.py
CHANGED
```diff
@@ -80,8 +80,7 @@ def load_model(
     model_type,
     tokenizer,
     cfg,
-    adapter="lora",
-    inference=False,
+    adapter="lora"
 ):
     # type: (str, str, str, str, DictDefault, Optional[str], bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
     """
```
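The `inference` flag comes off `load_model`'s signature and moves onto the config. A minimal sketch of how a call site migrates, assuming the two leading string arguments are `base_model` and `base_model_config` (they are not visible in this hunk):

```python
from axolotl.utils.models import load_model

# Minimal sketch of the call-site migration; base_model and
# base_model_config stand in for the two leading string arguments
# that this hunk does not show.
def build_for_inference(base_model, base_model_config, model_type, tokenizer, cfg):
    # Before this commit: load_model(..., adapter="lora", inference=True)
    # After it, the same intent travels on the config object:
    cfg.inference = True
    return load_model(
        base_model,
        base_model_config,
        model_type,
        tokenizer,
        cfg,
        adapter="lora",
    )
```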
```diff
@@ -95,7 +94,7 @@ def load_model(
     )
 
     if is_llama_derived_model and cfg.flash_attention:
-        if cfg.device not in ["mps", "cpu"] and inference is False:
+        if cfg.device not in ["mps", "cpu"] and cfg.inference is False:
             from axolotl.flash_attn import replace_llama_attn_with_flash_attn
 
             logging.info("patching with flash attention")
```
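Same substitution in the flash-attention gate: the monkey-patch is applied only for CUDA training runs, and the training/inference signal now comes from the config. The condition, restated as a standalone predicate (a hypothetical helper, not part of axolotl):

```python
# Hypothetical helper restating the gate above: patch flash attention only
# for llama-derived models, on CUDA devices, and only when training.
def should_patch_flash_attn(cfg, is_llama_derived_model):
    return (
        is_llama_derived_model
        and cfg.flash_attention
        and cfg.device not in ["mps", "cpu"]  # flash-attn kernels need CUDA
        and cfg.inference is False            # skip the patch at inference time
    )
```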
```diff
@@ -402,7 +401,7 @@ def load_lora(model, cfg):
     model = PeftModel.from_pretrained(
         model,
         cfg.lora_model_dir,
-        is_trainable=
+        is_trainable=not cfg.inference,
         device_map=cfg.device_map,
         # torch_dtype=torch.float16,
     )
```
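The practical effect of `is_trainable=not cfg.inference`: PEFT freezes a loaded adapter by default (`is_trainable=False`), so inference runs get frozen adapter weights while training runs keep them trainable. A minimal standalone sketch, assuming a config object carrying `inference`, `lora_model_dir`, and `device_map` as in the diff; `base_model_name` is a hypothetical argument:

```python
from peft import PeftModel
from transformers import AutoModelForCausalLM

# Minimal sketch (not axolotl's load_lora): adapter trainability follows
# the config's inference flag.
def load_lora_sketch(base_model_name, cfg):
    model = AutoModelForCausalLM.from_pretrained(base_model_name)
    return PeftModel.from_pretrained(
        model,
        cfg.lora_model_dir,
        # PEFT's default is is_trainable=False (frozen, inference-only);
        # negating cfg.inference keeps the adapter trainable for fine-tuning.
        is_trainable=not cfg.inference,
        device_map=cfg.device_map,
    )
```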