move flash-attn monkey patch alongside the others
src/axolotl/{flash_attn.py → monkeypatch/llama_attn_hijack_flash.py}
RENAMED (file contents unchanged)
src/axolotl/utils/models.py
CHANGED
@@ -92,7 +92,9 @@ def load_model(

     if cfg.is_llama_derived_model and cfg.flash_attention:
         if cfg.device not in ["mps", "cpu"] and not cfg.inference:
-            from axolotl.flash_attn import replace_llama_attn_with_flash_attn
+            from axolotl.monkeypatch.llama_attn_hijack_flash import (
+                replace_llama_attn_with_flash_attn,
+            )

             LOG.info("patching with flash attention")
             replace_llama_attn_with_flash_attn()
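
For context, the relocated module follows the usual monkeypatch pattern: a module-level function reassigns the attention `forward` on `transformers.models.llama.modeling_llama.LlamaAttention`, which is why `load_model()` only needs to import it lazily and call it once. The sketch below is illustrative only; the function name matches the diff, but the body is a placeholder, not the actual flash-attn implementation in `llama_attn_hijack_flash.py`.

```python
# Illustrative sketch of the monkeypatch pattern; the real
# replace_llama_attn_with_flash_attn installs a flash-attn-backed forward.
import transformers.models.llama.modeling_llama as llama_modeling


def _flash_attn_forward(self, hidden_states, *args, **kwargs):
    # Placeholder: the real patch computes attention with flash-attn kernels.
    raise NotImplementedError("sketch only")


def replace_llama_attn_with_flash_attn():
    # Reassigning the method on the class means every LlamaAttention
    # instance (existing and future) routes through the patched forward.
    llama_modeling.LlamaAttention.forward = _flash_attn_forward
```

Because the patch mutates the `transformers` class globally, the caller guards it (flash attention enabled, not on mps/cpu, not inference) before importing and invoking it, exactly as the diff above shows.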