Feat: Add rope scaling (#343)

* Feat: Add rope scaling
* fix: move rope config

Files changed:

- README.md +4 -0
- src/axolotl/utils/models.py +3 -1
README.md

@@ -474,6 +474,10 @@ landmark_attention:
 # xpos RoPE see https://github.com/kaiokendev/cutoff-len-is-context-len/blob/main/util/xpos_rope_llama_monkey_patch.py
 # llama only
 xpos_rope:
+# RoPE Scaling https://github.com/huggingface/transformers/pull/24653
+rope_scaling:
+  type: # linear | dynamic
+  factor: # float
 
 # resume from a specific checkpoint dir
 resume_from_checkpoint:
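For context, the new `rope_scaling` block maps directly onto the `rope_scaling` argument that transformers' `LlamaConfig` accepts as of the PR linked above. A minimal sketch of what the two keys mean; the model id and factor value are illustrative assumptions, not part of this commit:

```python
from transformers import LlamaConfig

# Illustrative only: "linear" scaling with factor=2.0 interpolates RoPE
# positions so a model pretrained at 2048 positions can attend over roughly
# 4096 tokens; "dynamic" instead rescales on the fly as sequences grow.
config = LlamaConfig.from_pretrained(
    "huggyllama/llama-7b",  # hypothetical base model
    rope_scaling={"type": "linear", "factor": 2.0},
)
print(config.rope_scaling)  # {'type': 'linear', 'factor': 2.0}
```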
src/axolotl/utils/models.py

@@ -219,7 +219,9 @@ def load_model(
     elif cfg.is_llama_derived_model and not cfg.trust_remote_code:
         from transformers import LlamaForCausalLM
 
-        config = LlamaConfig.from_pretrained(base_model_config)
+        config = LlamaConfig.from_pretrained(
+            base_model_config, rope_scaling=cfg.rope_scaling
+        )
         model = LlamaForCausalLM.from_pretrained(
             base_model,
             config=config,
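Taken together, the two changes thread the YAML setting through to the model config at load time. A minimal standalone sketch of that flow, assuming a helper name and model id of my own choosing (neither appears in the commit):

```python
from transformers import LlamaConfig, LlamaForCausalLM

def load_llama_with_rope_scaling(base_model, base_model_config, rope_scaling=None):
    # rope_scaling is either None (no scaling) or a dict mirroring the YAML
    # keys above, e.g. {"type": "linear", "factor": 2.0}; it is applied by
    # way of the model config rather than a separate load argument.
    config = LlamaConfig.from_pretrained(base_model_config, rope_scaling=rope_scaling)
    return LlamaForCausalLM.from_pretrained(base_model, config=config)

# Illustrative usage; the model id is an assumption.
model = load_llama_with_rope_scaling(
    "huggyllama/llama-7b",
    "huggyllama/llama-7b",
    rope_scaling={"type": "dynamic", "factor": 2.0},
)
```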