Fix set mem_id for inference and refactor
scripts/finetune.py
CHANGED
@@ -78,6 +78,9 @@ def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"):
     )
 
     if cfg.landmark_attention:
+        from axolotl.monkeypatch.llama_landmark_attn import set_model_mem_id
+
+        set_model_mem_id(model, tokenizer)
         model.set_mem_cache_args(
             max_seq_len=255, mem_freq=50, top_k=5, max_cache_size=None
         )
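For context, the inference fix only works if the landmark memory token is already in the tokenizer's vocabulary; otherwise `convert_tokens_to_ids` silently falls back to the unknown-token id and the memory cache would key on the wrong token. A minimal sanity-check sketch, not part of this commit (`model` and `tokenizer` are the objects already available in `do_inference`; `MEM_TOKEN` is the constant defined in `llama_landmark_attn.py`):

    from axolotl.monkeypatch.llama_landmark_attn import MEM_TOKEN, set_model_mem_id

    # Illustrative guard (assumption, not in the diff): fail loudly if the
    # memory token was never added to the tokenizer, e.g. because the
    # checkpoint was not trained with landmark attention.
    if tokenizer.convert_tokens_to_ids(MEM_TOKEN) == tokenizer.unk_token_id:
        raise ValueError(f"{MEM_TOKEN} is missing from the tokenizer vocabulary")

    set_model_mem_id(model, tokenizer)
    model.set_mem_cache_args(max_seq_len=255, mem_freq=50, top_k=5, max_cache_size=None)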
src/axolotl/monkeypatch/llama_landmark_attn.py
CHANGED
@@ -29,6 +29,7 @@ import torch
 import torch.utils.checkpoint
 from torch import nn
 from torch.nn import CrossEntropyLoss
+from transformers import LlamaTokenizer
 from transformers.modeling_outputs import (
     BaseModelOutputWithPast,
     CausalLMOutputWithPast,
@@ -1237,3 +1238,12 @@ def patch_llama_with_landmark_attn():
     transformers.models.llama.modeling_llama.LlamaAttention = LlamaAttention
     transformers.models.llama.modeling_llama.LlamaDecoderLayer = LlamaDecoderLayer
     transformers.models.llama.modeling_llama.apply_rotary_pos_emb = apply_rotary_pos_emb
+
+
+def set_model_mem_id(model: LlamaForCausalLM, tokenizer: LlamaTokenizer):
+    mem_id = tokenizer.convert_tokens_to_ids(MEM_TOKEN)
+    model.set_mem_id(mem_id)
+
+
+def get_mem_id(tokenizer: LlamaTokenizer):
+    return tokenizer.convert_tokens_to_ids(MEM_TOKEN)
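A small design note: both new helpers resolve the same token id from `MEM_TOKEN`, so `set_model_mem_id` could equivalently be expressed through `get_mem_id`. A sketch of that alternative, not what the commit actually does:

    def set_model_mem_id(model: LlamaForCausalLM, tokenizer: LlamaTokenizer):
        # Equivalent refactor sketch: reuse get_mem_id so the MEM_TOKEN lookup
        # lives in exactly one place.
        model.set_mem_id(get_mem_id(tokenizer))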
src/axolotl/utils/trainer.py
CHANGED
@@ -239,16 +239,19 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
     if cfg.is_llama_derived_model and cfg.landmark_attention:
         from functools import partial
 
-        from axolotl.monkeypatch.llama_landmark_attn import MEM_TOKEN, add_mem_tokens
+        from axolotl.monkeypatch.llama_landmark_attn import (
+            add_mem_tokens,
+            get_mem_id,
+            set_model_mem_id,
+        )
 
-        mem_id = tokenizer.convert_tokens_to_ids(MEM_TOKEN)
-        model.set_mem_id(mem_id)
+        set_model_mem_id(model, tokenizer)
 
         logging.info("Adding landmark attention tokens to dataset")
 
         for dataset in [train_dataset, eval_dataset]:
             dataset = dataset.map(
-                partial(add_mem_tokens, mem_freq=50, mem_id=mem_id),
+                partial(add_mem_tokens, mem_freq=50, mem_id=get_mem_id(tokenizer)),
                 batched=False,
                 num_proc=32,
             )
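One detail worth keeping in sync after this refactor: `mem_freq=50` now appears both in the inference-time `set_mem_cache_args` call and in the training-time `add_mem_tokens` mapping, and the two presumably need to agree. A hedged sketch of hoisting it into a shared constant (the constant name and its placement are assumptions, not part of this commit):

    # Hypothetical shared constant; in the commit the value 50 is simply repeated.
    LANDMARK_MEM_FREQ = 50

    # inference (scripts/finetune.py)
    model.set_mem_cache_args(
        max_seq_len=255, mem_freq=LANDMARK_MEM_FREQ, top_k=5, max_cache_size=None
    )

    # training data prep (src/axolotl/utils/trainer.py)
    dataset = dataset.map(
        partial(add_mem_tokens, mem_freq=LANDMARK_MEM_FREQ, mem_id=get_mem_id(tokenizer)),
        batched=False,
        num_proc=32,
    )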