Fix patching via import instead of hijacking
src/axolotl/utils/models.py (CHANGED)

```diff
@@ -20,7 +20,9 @@ from transformers import ( # noqa: F401
 )
 
 try:
-    from transformers import LlamaForCausalLM
+    from transformers import (  # pylint: disable=unused-import # noqa: F401
+        LlamaForCausalLM,
+    )
 except ImportError:
     logging.warning(
         "This version of transformers does not support Llama. Consider upgrading."
@@ -115,15 +117,15 @@ def load_model(
         logging.info("patching with sdp attention")
         hijack_llama_sdp_attention()
     elif cfg.is_llama_derived_model and cfg.landmark_attention:
-        from axolotl.monkeypatch.llama_landmark_attn import (
+        from axolotl.monkeypatch.llama_landmark_attn import (  # pylint: disable=redefined-outer-name # noqa: F811
             MEM_TOKEN,
-            hijack_llama_landmark_attn,
+            LlamaForCausalLM,
         )
 
         logging.info("patching with landmark attention")
-        hijack_llama_landmark_attn()
 
-
+        # TODO: Check if this would overwrite previous additional_special_tokens
+        tokenizer.add_special_tokens({"additional_special_tokens": [MEM_TOKEN]})
 
     if cfg.bf16:
         torch_dtype = torch.bfloat16
```
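The core of the change is the second hunk: rather than calling the removed `hijack_llama_landmark_attn()` helper to patch transformers in place, the monkeypatch module's own `LlamaForCausalLM` is imported under the same name, so the rest of `load_model` transparently picks up the patched class (the `noqa: F811` / pylint pragmas silence the deliberate redefinition). The snippet below is a minimal, self-contained sketch of that import-shadowing pattern; `fake_transformers`, `fake_landmark_attn`, and the `attention` attribute are invented stand-ins, not axolotl code.

```python
import sys
import types

# Stand-in for the upstream library: exposes a plain LlamaForCausalLM.
fake_transformers = types.ModuleType("fake_transformers")

class _UpstreamLlama:
    attention = "eager"

fake_transformers.LlamaForCausalLM = _UpstreamLlama
sys.modules["fake_transformers"] = fake_transformers

# Stand-in for the monkeypatch module: it exposes a class with the *same name*,
# plus the special token it needs, mirroring llama_landmark_attn in the diff.
fake_landmark_attn = types.ModuleType("fake_landmark_attn")

class _LandmarkLlama(_UpstreamLlama):
    attention = "landmark"

fake_landmark_attn.LlamaForCausalLM = _LandmarkLlama
fake_landmark_attn.MEM_TOKEN = "<landmark>"
sys.modules["fake_landmark_attn"] = fake_landmark_attn


def load_model(landmark_attention: bool):
    # Default binding, analogous to the module-level transformers import.
    from fake_transformers import LlamaForCausalLM

    if landmark_attention:
        # "Patching via import": re-importing under the same name rebinds it to
        # the patched class, so everything below uses landmark attention.
        # Linters flag the redefinition, hence the F811 pragma in the real diff.
        from fake_landmark_attn import LlamaForCausalLM  # noqa: F811

    return LlamaForCausalLM


print(load_model(False).attention)  # -> eager
print(load_model(True).attention)   # -> landmark
```

In this sketch both imports are function-local so the example runs on its own; in the diff the default import lives at module level in the `try`/`except` block, and only the patched import happens inside `load_model`.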
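The other addition registers the landmark memory token with the tokenizer; the TODO records the open question of whether this call overwrites a previously configured `additional_special_tokens` list. Below is a hedged sketch of the same call through the public Hugging Face API, paired with the usual companion step of resizing the embeddings when a genuinely new token is added. The model name and the `<landmark>` literal are illustrative only; in axolotl the token comes from the monkeypatch module as `MEM_TOKEN`.

```python
# Assumes transformers is installed; names below are illustrative stand-ins.
from transformers import AutoModelForCausalLM, AutoTokenizer

MEM_TOKEN = "<landmark>"

tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
model = AutoModelForCausalLM.from_pretrained("huggyllama/llama-7b")

# Registers the token as a special token; returns the number of tokens that
# were actually new to the vocabulary.
num_added = tokenizer.add_special_tokens({"additional_special_tokens": [MEM_TOKEN]})

# A new token needs a matching row in the embedding matrix before training.
if num_added:
    model.resize_token_embeddings(len(tokenizer))
```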