Add files using upload-large-folder tool
- .gitattributes +10 -0
- fla/models/abc/__pycache__/__init__.cpython-312.pyc +0 -0
- fla/models/abc/__pycache__/modeling_abc.cpython-312.pyc +0 -0
- fla/models/bitnet/__pycache__/__init__.cpython-312.pyc +0 -0
- fla/models/bitnet/__pycache__/modeling_bitnet.cpython-312.pyc +0 -0
- fla/models/delta_net/__pycache__/__init__.cpython-312.pyc +0 -0
- fla/models/forgetting_transformer/__pycache__/modeling_forgetting_transformer.cpython-312.pyc +0 -0
- fla/models/gated_deltanet/__pycache__/__init__.cpython-312.pyc +0 -0
- fla/models/gated_deltanet/configuration_gated_deltanet.py +83 -0
- fla/models/gated_deltaproduct/__pycache__/modeling_gated_deltaproduct.cpython-312.pyc +0 -0
- fla/models/gated_deltaproduct/modeling_gated_deltaproduct.py +520 -0
- fla/models/gla/modeling_gla.py +417 -0
- fla/models/gsa/configuration_gsa.py +97 -0
- fla/models/hgrn/__pycache__/__init__.cpython-312.pyc +0 -0
- fla/models/hgrn/__pycache__/configuration_hgrn.cpython-312.pyc +0 -0
- fla/models/lightnet/__pycache__/configuration_lightnet.cpython-312.pyc +0 -0
- fla/models/lightnet/__pycache__/modeling_lightnet.cpython-312.pyc +0 -0
- fla/models/lightnet/modeling_lightnet.py +410 -0
- fla/models/linear_attn/__pycache__/configuration_linear_attn.cpython-312.pyc +0 -0
- fla/models/linear_attn/__pycache__/modeling_linear_attn.cpython-312.pyc +0 -0
- fla/models/mamba2/__pycache__/__init__.cpython-312.pyc +0 -0
- fla/models/mamba2/__pycache__/configuration_mamba2.cpython-312.pyc +0 -0
- fla/models/nsa/__pycache__/__init__.cpython-312.pyc +0 -0
- fla/models/nsa/modeling_nsa.py +398 -0
- fla/models/retnet/configuration_retnet.py +92 -0
- fla/models/rwkv6/__pycache__/__init__.cpython-312.pyc +0 -0
- fla/models/rwkv7/__pycache__/__init__.cpython-312.pyc +0 -0
- fla/models/rwkv7/__pycache__/configuration_rwkv7.cpython-312.pyc +0 -0
- fla/models/transformer/__pycache__/configuration_transformer.cpython-312.pyc +0 -0
- fla/models/transformer_mtp/__pycache__/modeling_transformer.cpython-312.pyc +0 -0
- fla/models/transformer_mtp/modeling_transformer.py +608 -0
- fla/models/transformer_top/__pycache__/modeling_transformer.cpython-312.pyc +0 -0
- fla/modules/__pycache__/feature_map.cpython-312.pyc +0 -0
- logs/none_yagntt11/attempt_0/0/stderr.log +3 -0
- logs/none_yagntt11/attempt_0/1/stderr.log +3 -0
- logs/none_yagntt11/attempt_0/2/stderr.log +3 -0
- logs/none_yagntt11/attempt_0/3/stderr.log +3 -0
- logs/none_yagntt11/attempt_0/4/stderr.log +3 -0
- logs/none_yagntt11/attempt_0/5/stderr.log +3 -0
- logs/none_yagntt11/attempt_0/5/stdout.log +0 -0
- logs/none_yagntt11/attempt_0/6/stderr.log +3 -0
- logs/none_yagntt11/attempt_0/7/stderr.log +3 -0
- model-00001-of-00002.safetensors +3 -0
- model-00002-of-00002.safetensors +3 -0
- tb/20250716-2210/wandb/run-20250716_221000-mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201/files/output.log +3 -0
- tb/20250716-2210/wandb/run-20250716_221000-mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201/run-mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201.wandb +3 -0
.gitattributes
CHANGED
```diff
@@ -33,3 +33,13 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+logs/none_yagntt11/attempt_0/4/stderr.log filter=lfs diff=lfs merge=lfs -text
+logs/none_yagntt11/attempt_0/1/stderr.log filter=lfs diff=lfs merge=lfs -text
+logs/none_yagntt11/attempt_0/2/stderr.log filter=lfs diff=lfs merge=lfs -text
+logs/none_yagntt11/attempt_0/5/stderr.log filter=lfs diff=lfs merge=lfs -text
+logs/none_yagntt11/attempt_0/3/stderr.log filter=lfs diff=lfs merge=lfs -text
+logs/none_yagntt11/attempt_0/7/stderr.log filter=lfs diff=lfs merge=lfs -text
+logs/none_yagntt11/attempt_0/6/stderr.log filter=lfs diff=lfs merge=lfs -text
+tb/20250716-2210/wandb/run-20250716_221000-mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201/files/output.log filter=lfs diff=lfs merge=lfs -text
+logs/none_yagntt11/attempt_0/0/stderr.log filter=lfs diff=lfs merge=lfs -text
+tb/20250716-2210/wandb/run-20250716_221000-mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201/run-mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201.wandb filter=lfs diff=lfs merge=lfs -text
```
fla/models/abc/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (657 Bytes).
fla/models/abc/__pycache__/modeling_abc.cpython-312.pyc
ADDED
Binary file (18.4 kB).
fla/models/bitnet/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (682 Bytes).
fla/models/bitnet/__pycache__/modeling_bitnet.cpython-312.pyc
ADDED
Binary file (18.6 kB).
fla/models/delta_net/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (701 Bytes).
fla/models/forgetting_transformer/__pycache__/modeling_forgetting_transformer.cpython-312.pyc
ADDED
Binary file (17.2 kB).
fla/models/gated_deltanet/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (746 Bytes).
fla/models/gated_deltanet/configuration_gated_deltanet.py
ADDED
@@ -0,0 +1,83 @@
```python
# -*- coding: utf-8 -*-

from typing import Dict, Optional

from transformers.configuration_utils import PretrainedConfig


class GatedDeltaNetConfig(PretrainedConfig):
    model_type = 'gated_deltanet'
    keys_to_ignore_at_inference = ['past_key_values']

    def __init__(
        self,
        attn_mode: str = "chunk",
        hidden_size: int = 2048,
        expand_v: int = 2,
        use_gate: bool = True,
        use_short_conv: bool = True,
        conv_size: int = 4,
        head_dim: int = 256,
        num_heads: int = 6,
        max_position_embeddings: int = 2048,
        hidden_ratio: Optional[int] = 4,
        intermediate_size: Optional[int] = None,
        hidden_act: str = "swish",
        num_hidden_layers: int = 21,
        norm_eps: float = 1e-6,
        attn: Optional[Dict] = None,
        use_cache: bool = True,
        pad_token_id: int = None,
        bos_token_id: int = 1,
        eos_token_id: int = 2,
        tie_word_embeddings: bool = False,
        initializer_range: float = 0.006,
        fuse_norm: bool = True,
        fuse_swiglu: bool = True,
        fuse_cross_entropy: bool = True,
        vocab_size: int = 32000,
        **kwargs
    ):
        self.attn_mode = attn_mode
        self.hidden_size = hidden_size
        self.expand_v = expand_v
        self.use_gate = use_gate
        self.use_short_conv = use_short_conv
        self.conv_size = conv_size
        self.head_dim = head_dim
        self.num_heads = num_heads
        self.max_position_embeddings = max_position_embeddings

        self.hidden_ratio = hidden_ratio
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.num_hidden_layers = num_hidden_layers
        self.norm_eps = norm_eps
        self.attn = attn
        self.use_cache = use_cache
        self.initializer_range = initializer_range

        self.fuse_norm = fuse_norm
        self.fuse_swiglu = fuse_swiglu
        self.fuse_cross_entropy = fuse_cross_entropy
        self.vocab_size = vocab_size

        if attn is not None:
            if not isinstance(attn, Dict):
                raise ValueError("attn must be a dictionary")
            if 'layers' not in attn:
                raise ValueError("Layer indices must be provided to initialize hybrid attention layers")
            if 'num_heads' not in attn:
                raise ValueError("Number of heads must be provided to initialize hybrid attention layers")
            attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads'])
            attn['qkv_bias'] = attn.get('qkv_bias', False)
            attn['window_size'] = attn.get('window_size', None)
            attn['rope_theta'] = attn.get('rope_theta', 10000.)

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
```
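A quick usage sketch (not part of the uploaded files): instantiating the config above with a hybrid `attn` dict shows how `__init__` back-fills the optional attention keys. The import path assumes the `fla` package in this repo is importable.

```python
# Minimal sketch, assuming the `fla` package from this repo is on the path.
from fla.models.gated_deltanet.configuration_gated_deltanet import GatedDeltaNetConfig

# Hybrid layout: layers 2 and 5 use standard attention, the rest GatedDeltaNet.
config = GatedDeltaNetConfig(
    hidden_size=1024,
    num_heads=4,
    num_hidden_layers=12,
    attn={'layers': [2, 5], 'num_heads': 8},
)
# __init__ fills in the optional keys of the attn dict:
assert config.attn['num_kv_heads'] == 8   # defaults to attn['num_heads']
assert config.attn['qkv_bias'] is False
assert config.attn['window_size'] is None
assert config.attn['rope_theta'] == 10000.
```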
fla/models/gated_deltaproduct/__pycache__/modeling_gated_deltaproduct.cpython-312.pyc
ADDED
Binary file (20.7 kB).
fla/models/gated_deltaproduct/modeling_gated_deltaproduct.py
ADDED
@@ -0,0 +1,520 @@
```python
# -*- coding: utf-8 -*-

from __future__ import annotations

import math
import warnings
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.utils.checkpoint
from transformers.activations import ACT2FN
from transformers.generation import GenerationMixin
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging
from transformers.utils.deprecation import deprecate_kwarg

from fla.layers.attn import Attention
from fla.layers.gated_deltaproduct import GatedDeltaProduct
from fla.models.gated_deltaproduct.configuration_gated_deltaproduct import GatedDeltaProductConfig
from fla.models.utils import Cache
from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss, RMSNorm
from fla.modules.activations import swiglu_linear
from fla.modules.layernorm import rms_norm_linear

if TYPE_CHECKING:
    from transformers.processing_utils import Unpack

logger = logging.get_logger(__name__)


class GatedDeltaNetMLP(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        hidden_ratio: Optional[int] = None,
        intermediate_size: Optional[int] = None,
        hidden_act: str = "swish",
        norm_first: bool = True,
        norm_eps: float = 1e-5,
    ) -> GatedDeltaNetMLP:
        super().__init__()

        self.hidden_size = hidden_size
        # the final number of params is `hidden_ratio * hidden_size^2`
        # `intermediate_size` is chosen to be a multiple of 256 closest to `2/3 * hidden_size * hidden_ratio`
        if hidden_ratio is None:
            hidden_ratio = 4
        if intermediate_size is None:
            intermediate_size = int(hidden_size * hidden_ratio * 2 / 3)
            intermediate_size = 256 * ((intermediate_size + 256 - 1) // 256)
        self.hidden_ratio = hidden_ratio
        self.intermediate_size = intermediate_size
        self.norm_first = norm_first

        if norm_first:
            self.norm = RMSNorm(hidden_size=hidden_size, eps=norm_eps)

        self.gate_proj = nn.Linear(
            self.hidden_size, self.intermediate_size * 2, bias=False
        )
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[hidden_act]

    def forward(
        self,
        x: torch.Tensor,
        **kwargs: Unpack[Dict],
    ) -> torch.Tensor:
        if self.norm_first:
            x = rms_norm_linear(
                x,
                self.norm.weight,
                self.norm.bias,
                self.gate_proj.weight,
                self.gate_proj.bias,
            )
        else:
            x = self.gate_proj(x)
        gate, y = x.chunk(2, -1)
        return swiglu_linear(gate, y, self.down_proj.weight, self.down_proj.bias)


class GatedDeltaProductBlock(nn.Module):
    def __init__(self, config: GatedDeltaProductConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size

        if not config.norm_first:
            self.attn_norm = RMSNorm(
                hidden_size=config.hidden_size, eps=config.norm_eps
            )
        if config.attn is not None and layer_idx in config.attn["layers"]:
            self.attn = Attention(
                hidden_size=config.hidden_size,
                num_heads=config.attn["num_heads"],
                num_kv_heads=config.attn["num_kv_heads"],
                window_size=config.attn["window_size"],
                max_position_embeddings=config.max_position_embeddings,
                layer_idx=layer_idx,
            )
        else:
            self.attn = GatedDeltaProduct(
                mode=config.attn_mode,
                hidden_size=config.hidden_size,
                expand_v=config.expand_v,
                head_dim=config.head_dim,
                num_heads=config.num_heads,
                use_gate=config.use_gate,
                use_forget_gate=config.use_forget_gate,
                use_short_conv=config.use_short_conv,
                conv_size=config.conv_size,
                norm_first=config.norm_first,
                norm_eps=config.norm_eps,
                allow_neg_eigval=config.allow_neg_eigval,
                num_householder=config.num_householder,
                layer_idx=layer_idx,
                use_beta_conv=config.use_beta_conv
            )
        if not config.norm_first:
            self.mlp_norm = RMSNorm(hidden_size=config.hidden_size, eps=config.norm_eps)
        self.mlp = GatedDeltaNetMLP(
            hidden_size=config.hidden_size,
            hidden_ratio=config.hidden_ratio,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            norm_first=config.norm_first,
            norm_eps=config.norm_eps,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
        **kwargs: Unpack[Dict],
    ) -> Tuple[
        torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
    ]:
        residual = hidden_states
        if hasattr(self, "attn_norm"):
            hidden_states = self.attn_norm(hidden_states)
        hidden_states, attentions, past_key_values = self.attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            **kwargs,
        )
        if hasattr(self, "mlp_norm"):
            hidden_states, residual = self.mlp_norm(hidden_states, residual, True)
        else:
            hidden_states = residual + hidden_states
            residual = hidden_states
        hidden_states = self.mlp(hidden_states, **kwargs)
        hidden_states = residual + hidden_states

        outputs = (hidden_states, attentions, past_key_values)

        return outputs


class GatedDeltaProductPreTrainedModel(PreTrainedModel):
    config_class = GatedDeltaProductConfig
    supports_gradient_checkpointing = True
    _no_split_modules = ["GatedDeltaProductBlock"]

    def __init__(self, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)

    def _init_weights(
        self,
        module: nn.Module,
        rescale_prenorm_residual: bool = True,
        num_residuals_per_layer: int = 2,
    ):
        if isinstance(module, (nn.Linear, nn.Conv1d)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

        if rescale_prenorm_residual:
            # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
            #   > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
            #   > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
            #   > -- GPT-2 :: https://openai.com/blog/better-language-models/
            #
            # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
            for name, p in module.named_parameters():
                if name in ["o_proj.weight", "down_proj.weight"]:
                    # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
                    # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
                    # We need to reinit p since this code could be called multiple times
                    # Having just p *= scale would repeatedly scale it down
                    with torch.no_grad():
                        p /= math.sqrt(
                            num_residuals_per_layer * self.config.num_hidden_layers
                        )


class GatedDeltaProductModel(GatedDeltaProductPreTrainedModel):
    def __init__(self, config: GatedDeltaProductConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embeddings = nn.Embedding(
            config.vocab_size, config.hidden_size, self.padding_idx
        )
        self.layers = nn.ModuleList(
            [
                GatedDeltaProductBlock(config, layer_idx)
                for layer_idx in range(config.num_hidden_layers)
            ]
        )
        self.norm = RMSNorm(config.hidden_size, eps=config.norm_eps)

        self.gradient_checkpointing = False

        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings

    def set_input_embeddings(self, value):
        self.embeddings = value

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs: Unpack[Dict],
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        if output_attentions:
            warnings.warn(
                "`GatedDeltaProductModel` does not `output_attentions` now, setting it to `False`.",
                stacklevel=2,
            )
            output_attentions = False
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        use_cache = (
            use_cache
            if use_cache is not None
            else (self.config.use_cache if not self.training else False)
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time"
            )
        if input_ids is None and inputs_embeds is None:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embeddings(input_ids)
        hidden_states = inputs_embeds

        if use_cache and not isinstance(past_key_values, Cache):
            past_key_values = Cache.from_legacy_cache(past_key_values)

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
            )
            use_cache = False

        all_hidden_states = () if output_hidden_states else None
        all_attns = () if output_attentions else None
        for layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                hidden_states, attentions, past_key_values = (
                    self._gradient_checkpointing_func(
                        layer.__call__,
                        hidden_states,
                        attention_mask,
                        past_key_values,
                        use_cache,
                        output_attentions,
                        **kwargs,
                    )
                )
            else:
                hidden_states, attentions, past_key_values = layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    past_key_values=past_key_values,
                    use_cache=use_cache,
                    output_attentions=output_attentions,
                    **kwargs,
                )

            if output_attentions:
                all_attns += (attentions,)

        hidden_states = self.norm(hidden_states)
        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        if not return_dict:
            return tuple(
                i
                for i in [
                    hidden_states,
                    past_key_values,
                    all_hidden_states,
                    all_attns,
                ]
                if i is not None
            )
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
            hidden_states=all_hidden_states,
            attentions=all_attns,
        )


class GatedDeltaProductForCausalLM(GatedDeltaProductPreTrainedModel, GenerationMixin):
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.model = GatedDeltaProductModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embeddings

    def set_input_embeddings(self, value):
        self.model.embeddings = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    def generate(self, *args, **kwargs):
        try:
            return super().generate(*args, **kwargs)
        except AttributeError as exception:
            if "past_key_values" in str(exception):
                raise AttributeError(
                    f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, "
                    f"which is not supported for {self.__class__.__name__}. "
                    f"Try another generation strategy instead. "
                    f"For the available generation strategies, check this doc: "
                    f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies"
                )
            else:
                raise exception

    @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        use_cache: bool = True,
        num_logits_to_keep: Optional[int] = None,
        logits_to_keep: Optional[int] = None,
        **kwargs,
    ):
        # only last token for `input_ids` if the `past_key_values` passed along is not empty.
        if past_key_values is not None and len(past_key_values) > 0:
            input_ids = input_ids[:, -1:]
        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
            # recompiles graphs as the stride of the inputs is a guard.
            # Ref: https://github.com/huggingface/transformers/pull/29114
            # TODO: use `next_tokens` directly instead.
            model_inputs = {"input_ids": input_ids.contiguous()}

        if logits_to_keep is not None:
            model_inputs['logits_to_keep'] = logits_to_keep

        model_inputs.update(
            {
                "past_key_values": past_key_values,
                "use_cache": use_cache,
                "attention_mask": attention_mask,
                "num_logits_to_keep": num_logits_to_keep,
            }
        )
        return model_inputs

    @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        num_logits_to_keep: Optional[int] = 0,
        logits_to_keep: Optional[int] = 0,
        **kwargs: Unpack[Dict],
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        num_logits_to_keep = 0 if num_logits_to_keep is None else num_logits_to_keep
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )
        kwargs.pop("num_items_in_batch", None)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            **kwargs,
        )
        hidden_states = outputs[0]
        fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training

        loss, logits = None, None
        if not fuse_linear_and_cross_entropy or labels is None:
            logits = self.lm_head(hidden_states if logits_to_keep is None else hidden_states[:, -logits_to_keep:])
        if labels is not None:
            if self.config.fuse_cross_entropy:
                if fuse_linear_and_cross_entropy:
                    loss_fct = FusedLinearCrossEntropyLoss()
                else:
                    loss_fct = FusedCrossEntropyLoss(inplace_backward=True)
            else:
                loss_fct = nn.CrossEntropyLoss()
            # Enable model parallelism
            labels = labels.to(hidden_states.device)
            labels = torch.cat(
                (
                    labels[..., 1:],
                    torch.full_like(labels[:, :1], loss_fct.ignore_index),
                ),
                1,
            )
            if fuse_linear_and_cross_entropy:
                loss = loss_fct(
                    hidden_states.view(-1, self.config.hidden_size),
                    labels.view(-1),
                    self.lm_head.weight,
                    self.lm_head.bias,
                )
            else:
                loss = loss_fct(
                    logits.view(-1, self.config.vocab_size), labels.view(-1)
                )

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss, *output) if loss is not None else output
        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
```
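A hedged training-step sketch for the classes above. `forward` shifts `labels` internally (filling the last position with `ignore_index`), so `labels` can be a plain copy of `input_ids`. `GatedDeltaProductConfig` lives in a sibling file that is not part of this diff, so whether its constructor accepts `fuse_cross_entropy` is inferred from its use in `forward`; the fla fused kernels also generally expect a CUDA device with Triton, so treat this as illustrative only.

```python
# Illustrative sketch; assumes GatedDeltaProductConfig's defaults are usable
# and that the required Triton kernels are available on this machine.
import torch

from fla.models.gated_deltaproduct.configuration_gated_deltaproduct import GatedDeltaProductConfig
from fla.models.gated_deltaproduct.modeling_gated_deltaproduct import GatedDeltaProductForCausalLM

# fuse_cross_entropy=False selects the plain nn.CrossEntropyLoss branch in forward.
config = GatedDeltaProductConfig(fuse_cross_entropy=False)
model = GatedDeltaProductForCausalLM(config)

input_ids = torch.randint(0, config.vocab_size, (2, 16))
# forward shifts the labels itself, so labels == input_ids is correct here.
out = model(input_ids=input_ids, labels=input_ids.clone())
out.loss.backward()
```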
fla/models/gla/modeling_gla.py
ADDED
@@ -0,0 +1,417 @@
```python
# -*- coding: utf-8 -*-

from __future__ import annotations

import math
import warnings
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.utils.checkpoint
from transformers.generation import GenerationMixin
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging
from transformers.utils.deprecation import deprecate_kwarg

from fla.layers.attn import Attention
from fla.layers.gla import GatedLinearAttention
from fla.models.gla.configuration_gla import GLAConfig
from fla.models.utils import Cache
from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss
from fla.modules import GatedMLP as GLAMLP
from fla.modules import RMSNorm

if TYPE_CHECKING:
    from transformers.processing_utils import Unpack

logger = logging.get_logger(__name__)


class GLABlock(nn.Module):
    def __init__(self, config: GLAConfig, layer_idx: int):
        super().__init__()

        self.config = config
        self.layer_idx = layer_idx

        self.attn_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
        if config.attn is not None and layer_idx in config.attn['layers']:
            self.attn = Attention(
                hidden_size=config.hidden_size,
                num_heads=config.attn['num_heads'],
                num_kv_heads=config.attn['num_kv_heads'],
                qkv_bias=config.attn['qkv_bias'],
                window_size=config.attn['window_size'],
                rope_theta=config.attn['rope_theta'],
                max_position_embeddings=config.max_position_embeddings,
                layer_idx=layer_idx
            )
        else:
            self.attn = GatedLinearAttention(
                mode=config.attn_mode,
                hidden_size=config.hidden_size,
                expand_k=config.expand_k,
                expand_v=config.expand_v,
                num_heads=config.num_heads,
                num_kv_heads=config.num_kv_heads,
                feature_map=config.feature_map,
                use_short_conv=config.use_short_conv,
                conv_size=config.conv_size,
                use_output_gate=config.use_output_gate,
                gate_fn=config.hidden_act,
                elementwise_affine=config.elementwise_affine,
                norm_eps=config.norm_eps,
                clamp_min=config.clamp_min,
                fuse_norm=config.fuse_norm,
                layer_idx=layer_idx
            )
        self.mlp_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
        self.mlp = GLAMLP(
            hidden_size=config.hidden_size,
            hidden_ratio=config.hidden_ratio,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            fuse_swiglu=config.fuse_swiglu
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
        **kwargs: Unpack[Dict]
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        residual = hidden_states
        hidden_states = self.attn_norm(hidden_states)
        hidden_states, attentions, past_key_values = self.attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            **kwargs
        )
        if self.config.fuse_norm:
            hidden_states, residual = self.mlp_norm(hidden_states, residual, True)
        else:
            hidden_states = residual + hidden_states
            residual = hidden_states
            hidden_states = self.mlp_norm(hidden_states)
        hidden_states = self.mlp(hidden_states, **kwargs)
        hidden_states = residual + hidden_states

        outputs = (hidden_states, attentions, past_key_values)

        return outputs


class GLAPreTrainedModel(PreTrainedModel):

    config_class = GLAConfig
    base_model_prefix = 'model'
    supports_gradient_checkpointing = True
    _no_split_modules = ['GLABlock']
    _supports_cache_class = True

    def __init__(self, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)

    def _init_weights(
        self,
        module: nn.Module,
        prenorm_residual_strategy: Optional[str] = 'rescale',
        num_residuals_per_layer: int = 2,
    ):
        if isinstance(module, (nn.Linear, nn.Conv1d)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
        elif hasattr(module, 'reset_parameters'):
            module.reset_parameters()

        if prenorm_residual_strategy is not None:
            # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
            #   > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
            #   > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
            #   > -- GPT-2 :: https://openai.com/blog/better-language-models/
            #
            # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
            p = None
            if hasattr(module, 'o_proj'):
                p = module.o_proj.weight
            elif hasattr(module, 'down_proj'):
                p = module.down_proj.weight
            if p is not None:
                # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
                # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
                # We need to reinit p since this code could be called multiple times
                # Having just p *= scale would repeatedly scale it down
                if prenorm_residual_strategy == 'rescale':
                    nn.init.kaiming_uniform_(p, a=math.sqrt(5))
                    with torch.no_grad():
                        p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers)
                elif prenorm_residual_strategy == 'zero':
                    nn.init.zeros_(p)
                else:
                    raise ValueError(f"Invalid prenorm_residual_strategy: {prenorm_residual_strategy}")


class GLAModel(GLAPreTrainedModel):

    def __init__(self, config: GLAConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList([GLABlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
        self.norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)

        self.gradient_checkpointing = False

        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings

    def set_input_embeddings(self, value):
        self.embeddings = value

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,  # noqa
        inputs_embeds: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs: Unpack[Dict]
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        if output_attentions:
            warnings.warn("`GLAModel` does not `output_attentions` now, setting it to `False`.")
            output_attentions = False
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        if input_ids is None and inputs_embeds is None:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embeddings(input_ids)
        hidden_states = inputs_embeds

        if use_cache and not isinstance(past_key_values, Cache):
            past_key_values = Cache.from_legacy_cache(past_key_values)

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
            use_cache = False

        all_hidden_states = () if output_hidden_states else None
        all_attns = () if output_attentions else None
        for layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                hidden_states, attentions, past_key_values = self._gradient_checkpointing_func(
                    layer.__call__,
                    hidden_states,
                    attention_mask,
                    past_key_values,
                    use_cache,
                    output_attentions,
                    **kwargs
                )
            else:
                hidden_states, attentions, past_key_values = layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    past_key_values=past_key_values,
                    use_cache=use_cache,
                    output_attentions=output_attentions,
                    **kwargs
                )

            if output_attentions:
                all_attns += (attentions,)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        if not return_dict:
            return tuple(i for i in [hidden_states, past_key_values, all_hidden_states, all_attns] if i is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
            hidden_states=all_hidden_states,
            attentions=all_attns
        )


class GLAForCausalLM(GLAPreTrainedModel, GenerationMixin):

    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.model = GLAModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.criterion = None

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embeddings

    def set_input_embeddings(self, value):
        self.model.embeddings = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    def generate(self, *args, **kwargs):
        try:
            return super().generate(*args, **kwargs)
        except AttributeError as exception:
            if 'past_key_values' in str(exception):
                raise AttributeError(
                    f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, "
                    f"which is not supported for {self.__class__.__name__}. "
                    f"Try another generation strategy instead. "
                    f"For the available generation strategies, check this doc: "
                    f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies"
                )
            else:
                raise exception

    @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        use_cache: bool = True,
        logits_to_keep: Optional[int] = None,
        **kwargs
    ):
        # only last token for `input_ids` if the `past_key_values` is not empty.
        if past_key_values is not None and len(past_key_values) > 0:
            input_ids = input_ids[:, -1:]
        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and len(past_key_values) == 0:
            model_inputs = {'inputs_embeds': inputs_embeds}
        else:
            # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
            # recompiles graphs as the stride of the inputs is a guard.
            # Ref: https://github.com/huggingface/transformers/pull/29114
            # TODO: use `next_tokens` directly instead.
            model_inputs = {'input_ids': input_ids.contiguous()}

        if logits_to_keep is not None:
            model_inputs['logits_to_keep'] = logits_to_keep

        model_inputs.update({
            'past_key_values': past_key_values,
            'use_cache': use_cache,
            'attention_mask': attention_mask,
        })
        return model_inputs

    @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        logits_to_keep: Optional[int] = 0,
        **kwargs: Unpack[Dict]
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            **kwargs
        )

        hidden_states = outputs[0]
        fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training

        loss, logits = None, None
        if not fuse_linear_and_cross_entropy or labels is None:
            logits = self.lm_head(hidden_states if logits_to_keep is None else hidden_states[:, -logits_to_keep:])
        if labels is not None:
            if getattr(self, 'criterion', None) is None:
                if fuse_linear_and_cross_entropy:
                    criterion = FusedLinearCrossEntropyLoss()
                elif self.config.fuse_cross_entropy:
                    criterion = FusedCrossEntropyLoss(inplace_backward=True)
                else:
                    criterion = nn.CrossEntropyLoss()
            else:
                criterion = self.criterion
            labels = labels.to(hidden_states.device)
            labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1)
            if fuse_linear_and_cross_entropy:
                loss = criterion(hidden_states, labels, self.lm_head.weight, self.lm_head.bias)
            else:
                loss = criterion(logits.view(labels.numel(), -1), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
```
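A decoding sketch for `GLAForCausalLM`: once `past_key_values` is non-empty, `prepare_inputs_for_generation` above trims `input_ids` to the last token, so `generate` runs step by step against the recurrent state. The `GLAConfig` defaults and the hardware requirements of fla's Triton kernels are assumptions here.

```python
# Illustrative sketch; GLAConfig comes from the sibling configuration file,
# which is not part of this diff.
import torch

from fla.models.gla.configuration_gla import GLAConfig
from fla.models.gla.modeling_gla import GLAForCausalLM

model = GLAForCausalLM(GLAConfig()).eval()
prompt = torch.randint(0, model.config.vocab_size, (1, 8))
with torch.no_grad():
    # At each step after the first, prepare_inputs_for_generation feeds only
    # the newest token; the prefix lives in the recurrent past_key_values.
    out = model.generate(prompt, max_new_tokens=4, do_sample=False)
print(out.shape)  # torch.Size([1, 12])
```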
fla/models/gsa/configuration_gsa.py
ADDED
@@ -0,0 +1,97 @@
```python
# -*- coding: utf-8 -*-

from typing import Dict, Optional

from transformers.configuration_utils import PretrainedConfig


class GSAConfig(PretrainedConfig):

    model_type = 'gsa'
    keys_to_ignore_at_inference = ['past_key_values']

    def __init__(
        self,
        hidden_size: int = 2048,
        gate_logit_normalizer: Optional[int] = 8,
        clamp_min: Optional[float] = None,
        clamp_max: Optional[float] = None,
        hidden_ratio: Optional[int] = 4,
        intermediate_size: Optional[int] = None,
        num_hidden_layers: int = 24,
        num_heads: int = 4,
        num_kv_heads: Optional[int] = None,
        num_slots: Optional[int] = 64,
        use_short_conv: bool = False,
        conv_size: int = 4,
        exapnd_k: float = 1,
        exapnd_v: float = 1,
        feature_map: str = 'swish',
        use_output_gate: bool = False,
        use_norm: bool = True,
        max_position_embeddings: int = 2048,
        hidden_act: str = "swish",
        elementwise_affine: Optional[bool] = True,
        norm_eps: float = 1e-6,
        attn: Optional[Dict] = None,
        use_cache: bool = True,
        pad_token_id: int = None,
        bos_token_id: int = 1,
        eos_token_id: int = 2,
        initializer_range: float = 0.006,
        tie_word_embeddings: bool = False,
        fuse_norm: bool = True,
        fuse_swiglu: bool = True,
        fuse_cross_entropy: bool = True,
        vocab_size: int = 32000,
        **kwargs
    ):
        self.hidden_size = hidden_size
        self.gate_logit_normalizer = gate_logit_normalizer
        self.clamp_min = clamp_min
        self.clamp_max = clamp_max
        self.hidden_ratio = hidden_ratio
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.num_slots = num_slots
        self.use_short_conv = use_short_conv
        self.conv_size = conv_size
        self.expand_k = exapnd_k
        self.expand_v = exapnd_v
        self.feature_map = feature_map
        self.use_output_gate = use_output_gate
        self.use_norm = use_norm
        self.max_position_embeddings = max_position_embeddings
        self.hidden_act = hidden_act
        self.elementwise_affine = elementwise_affine
        self.norm_eps = norm_eps
        self.attn = attn
        self.use_cache = use_cache
        self.initializer_range = initializer_range

        self.fuse_norm = fuse_norm
        self.fuse_swiglu = fuse_swiglu
        self.fuse_cross_entropy = fuse_cross_entropy
        self.vocab_size = vocab_size

        if attn is not None:
            if not isinstance(attn, Dict):
                raise ValueError("attn must be a dictionary")
            if 'layers' not in attn:
                raise ValueError("Layer indices must be provided to initialize hybrid attention layers")
            if 'num_heads' not in attn:
                raise ValueError("Number of heads must be provided to initialize hybrid attention layers")
            attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads'])
            attn['qkv_bias'] = attn.get('qkv_bias', False)
            attn['window_size'] = attn.get('window_size', None)
            attn['rope_theta'] = attn.get('rope_theta', 10000.)

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
```
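One quirk worth noting in the file above: the constructor parameters are spelled `exapnd_k`/`exapnd_v`, but the values are stored under the conventional attribute names `expand_k`/`expand_v`. A minimal sketch of the mapping:

```python
from fla.models.gsa.configuration_gsa import GSAConfig

# The misspelled keyword arguments are the ones the constructor accepts...
config = GSAConfig(exapnd_k=0.5, exapnd_v=2.0, num_slots=32)
# ...but they land on the conventionally named attributes:
assert config.expand_k == 0.5 and config.expand_v == 2.0
```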
fla/models/hgrn/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (665 Bytes).
fla/models/hgrn/__pycache__/configuration_hgrn.cpython-312.pyc
ADDED
Binary file (3.28 kB).
fla/models/lightnet/__pycache__/configuration_lightnet.cpython-312.pyc
ADDED
Binary file (3.36 kB).
fla/models/lightnet/__pycache__/modeling_lightnet.cpython-312.pyc
ADDED
Binary file (18.3 kB).
fla/models/lightnet/modeling_lightnet.py
ADDED
@@ -0,0 +1,410 @@
# -*- coding: utf-8 -*-

from __future__ import annotations

import math
import warnings
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.utils.checkpoint
from transformers.generation import GenerationMixin
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging
from transformers.utils.deprecation import deprecate_kwarg

from fla.layers.attn import Attention
from fla.layers.lightnet import LightNetAttention
from fla.models.lightnet.configuration_lightnet import LightNetConfig
from fla.models.utils import Cache
from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss
from fla.modules import GatedMLP as LightNetMLP
from fla.modules import RMSNorm

if TYPE_CHECKING:
    from transformers.processing_utils import Unpack

logger = logging.get_logger(__name__)


class LightNetBlock(nn.Module):
    def __init__(self, config: LightNetConfig, layer_idx: int):
        super().__init__()

        self.config = config
        self.layer_idx = layer_idx

        self.attn_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
        if config.attn is not None and layer_idx in config.attn['layers']:
            self.attn = Attention(
                hidden_size=config.hidden_size,
                num_heads=config.attn['num_heads'],
                num_kv_heads=config.attn['num_kv_heads'],
                qkv_bias=config.attn['qkv_bias'],
                window_size=config.attn['window_size'],
                max_position_embeddings=config.max_position_embeddings,
                layer_idx=layer_idx
            )
        else:
            self.attn = LightNetAttention(
                mode=config.attn_mode,
                hidden_size=config.hidden_size,
                num_heads=config.num_heads,
                expand_ratio=config.expand_ratio,
                use_short_conv=config.use_short_conv,
                conv_size=config.conv_size,
                gate_low_rank_dim=config.gate_low_rank_dim,
                elementwise_affine=config.elementwise_affine,
                norm_eps=config.norm_eps,
                layer_idx=layer_idx
            )
        self.mlp_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
        self.mlp = LightNetMLP(
            hidden_size=config.hidden_size,
            hidden_ratio=config.hidden_ratio,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            fuse_swiglu=config.fuse_swiglu
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
        **kwargs: Unpack[Dict]
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        residual = hidden_states
        hidden_states = self.attn_norm(hidden_states)
        hidden_states, attentions, past_key_values = self.attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            **kwargs
        )
        if self.config.fuse_norm:
            hidden_states, residual = self.mlp_norm(hidden_states, residual, True)
        else:
            hidden_states = residual + hidden_states
            residual = hidden_states
            hidden_states = self.mlp_norm(hidden_states)
        hidden_states = self.mlp(hidden_states, **kwargs)
        hidden_states = residual + hidden_states

        outputs = (hidden_states, attentions, past_key_values)

        return outputs


class LightNetPreTrainedModel(PreTrainedModel):

    config_class = LightNetConfig
    supports_gradient_checkpointing = True
    _no_split_modules = ['LightNetBlock']
    _supports_cache_class = True

    def __init__(self, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)

    def _init_weights(
        self,
        module: nn.Module,
        prenorm_residual_strategy: Optional[str] = 'rescale',
        num_residuals_per_layer: int = 2,
    ):
        if isinstance(module, (nn.Linear, nn.Conv1d)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
        elif hasattr(module, 'reset_parameters'):
            module.reset_parameters()

        if prenorm_residual_strategy is not None:
            # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
            #   > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
            #   > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
            #   >   -- GPT-2 :: https://openai.com/blog/better-language-models/
            #
            # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
            p = None
            if hasattr(module, 'o_proj'):
                p = module.o_proj.weight
            elif hasattr(module, 'down_proj'):
                p = module.down_proj.weight
            if p is not None:
                # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
                # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
                # We need to reinit p since this code could be called multiple times
                # Having just p *= scale would repeatedly scale it down
                if prenorm_residual_strategy == 'rescale':
                    nn.init.kaiming_uniform_(p, a=math.sqrt(5))
                    with torch.no_grad():
                        p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers)
                elif prenorm_residual_strategy == 'zero':
                    nn.init.zeros_(p)
                else:
                    raise ValueError(f"Invalid prenorm_residual_strategy: {prenorm_residual_strategy}")


class LightNetModel(LightNetPreTrainedModel):

    def __init__(self, config: LightNetConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList([LightNetBlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
        self.norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)

        self.gradient_checkpointing = False

        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings

    def set_input_embeddings(self, value):
        self.embeddings = value

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,  # noqa
        inputs_embeds: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs: Unpack[Dict]
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        if output_attentions:
            warnings.warn("`LightNetModel` does not `output_attentions` now, setting it to `False`.")
            output_attentions = False
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        if input_ids is None and inputs_embeds is None:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embeddings(input_ids)
        hidden_states = inputs_embeds

        if use_cache and not isinstance(past_key_values, Cache):
            past_key_values = Cache.from_legacy_cache(past_key_values)

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
            use_cache = False

        all_hidden_states = () if output_hidden_states else None
        all_attns = () if output_attentions else None

        for i, layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                hidden_states, attentions, past_key_values = self._gradient_checkpointing_func(
                    layer.__call__,
                    hidden_states,
                    attention_mask,
                    past_key_values,
                    use_cache,
                    output_attentions,
                    **kwargs
                )
            else:
                hidden_states, attentions, past_key_values = layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    past_key_values=past_key_values,
                    use_cache=use_cache,
                    output_attentions=output_attentions,
                    **kwargs
                )

            if output_attentions:
                all_attns += (attentions,)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        if not return_dict:
            return tuple(i for i in [hidden_states, past_key_values, all_hidden_states, all_attns] if i is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
            hidden_states=all_hidden_states,
            attentions=all_attns
        )


class LightNetForCausalLM(LightNetPreTrainedModel, GenerationMixin):

    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.model = LightNetModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.criterion = None

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embeddings

    def set_input_embeddings(self, value):
        self.model.embeddings = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    def generate(self, *args, **kwargs):
        try:
            return super().generate(*args, **kwargs)
        except AttributeError as exception:
            if 'past_key_values' in str(exception):
                raise AttributeError(
                    f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, "
                    f"which is not supported for {self.__class__.__name__}. "
                    f"Try another generation strategy instead. "
                    f"For the available generation strategies, check this doc: "
                    f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies"
                )
            else:
                raise exception

    @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        use_cache: bool = True,
        logits_to_keep: Optional[int] = None,
        **kwargs: Unpack[Dict]
    ):
        # only last token for `inputs_ids` if the `past_key_values` is not empty.
        if past_key_values is not None and len(past_key_values) > 0:
            input_ids = input_ids[:, -1:]
        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and len(past_key_values) == 0:
            model_inputs = {'inputs_embeds': inputs_embeds}
        else:
            # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
            # recompiles graphs as the stride of the inputs is a guard.
            # Ref: https://github.com/huggingface/transformers/pull/29114
            # TODO: use `next_tokens` directly instead.
            model_inputs = {'input_ids': input_ids.contiguous()}

        if logits_to_keep is not None:
            model_inputs['logits_to_keep'] = logits_to_keep

        model_inputs.update({
            'past_key_values': past_key_values,
            'use_cache': use_cache,
            'attention_mask': attention_mask,
        })
        return model_inputs

    @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        logits_to_keep: Optional[int] = 0,
        **kwargs: Unpack[Dict]
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            **kwargs
        )

        hidden_states = outputs[0]
        fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training

        loss, logits = None, None
        if not fuse_linear_and_cross_entropy or labels is None:
            logits = self.lm_head(hidden_states if logits_to_keep is None else hidden_states[:, -logits_to_keep:])
        if labels is not None:
            if getattr(self, 'criterion', None) is None:
                if fuse_linear_and_cross_entropy:
                    criterion = FusedLinearCrossEntropyLoss()
                elif self.config.fuse_cross_entropy:
                    criterion = FusedCrossEntropyLoss(inplace_backward=True)
                else:
                    criterion = nn.CrossEntropyLoss()
            else:
                criterion = self.criterion
            labels = labels.to(hidden_states.device)
            labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1)
            if fuse_linear_and_cross_entropy:
                loss = criterion(hidden_states, labels, self.lm_head.weight, self.lm_head.bias)
            else:
                loss = criterion(logits.view(labels.numel(), -1), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
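One detail worth calling out in `forward` above: `logits_to_keep` defaults to 0, and the slice `hidden_states[:, -logits_to_keep:]` relies on Python treating `-0` as `0`, so the default keeps logits for every position while `logits_to_keep=1` keeps only the final one (the common case during decoding). A standalone check:

import torch

h = torch.randn(1, 8, 16)   # (batch, seq_len, hidden)
print(h[:, -0:].shape)      # torch.Size([1, 8, 16]): -0 == 0, full sequence kept
print(h[:, -1:].shape)      # torch.Size([1, 1, 16]): only the last position kept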
fla/models/linear_attn/__pycache__/configuration_linear_attn.cpython-312.pyc
ADDED
Binary file (3.65 kB)

fla/models/linear_attn/__pycache__/modeling_linear_attn.cpython-312.pyc
ADDED
Binary file (18.5 kB)

fla/models/mamba2/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (695 Bytes)

fla/models/mamba2/__pycache__/configuration_mamba2.cpython-312.pyc
ADDED
Binary file (7.5 kB)

fla/models/nsa/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (657 Bytes)
fla/models/nsa/modeling_nsa.py
ADDED
@@ -0,0 +1,398 @@
# -*- coding: utf-8 -*-

from __future__ import annotations

import math
import warnings
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.utils.checkpoint
from transformers.generation import GenerationMixin
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging
from transformers.utils.deprecation import deprecate_kwarg

from fla.layers.nsa import NativeSparseAttention
from fla.models.nsa.configuration_nsa import NSAConfig
from fla.models.utils import Cache
from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss
from fla.modules import GatedMLP as NSAMLP
from fla.modules import RMSNorm

if TYPE_CHECKING:
    from transformers.processing_utils import Unpack

logger = logging.get_logger(__name__)


class NSABlock(nn.Module):
    def __init__(self, config: NSAConfig, layer_idx: int):
        super().__init__()

        self.config = config
        self.layer_idx = layer_idx

        self.attn_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
        self.attn = NativeSparseAttention(
            hidden_size=config.hidden_size,
            num_heads=config.num_heads,
            num_kv_heads=config.num_kv_heads,
            qkv_bias=config.qkv_bias,
            block_size=config.block_size,
            block_counts=config.block_counts,
            window_size=config.window_size,
            rope_theta=config.rope_theta,
            max_position_embeddings=config.max_position_embeddings,
            layer_idx=layer_idx
        )
        self.mlp_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
        self.mlp = NSAMLP(
            hidden_size=config.hidden_size,
            hidden_ratio=config.hidden_ratio,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            fuse_swiglu=config.fuse_swiglu
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
        **kwargs: Unpack[Dict]
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        residual = hidden_states
        hidden_states = self.attn_norm(hidden_states)
        hidden_states, attentions, past_key_values = self.attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            **kwargs
        )
        if self.config.fuse_norm:
            hidden_states, residual = self.mlp_norm(hidden_states, residual, True)
        else:
            hidden_states = residual + hidden_states
            residual = hidden_states
            hidden_states = self.mlp_norm(hidden_states)
        hidden_states = self.mlp(hidden_states, **kwargs)
        hidden_states = residual + hidden_states

        outputs = (hidden_states, attentions, past_key_values)

        return outputs


class NSAPreTrainedModel(PreTrainedModel):

    config_class = NSAConfig
    base_model_prefix = 'model'
    supports_gradient_checkpointing = True
    _no_split_modules = ['NSABlock']
    _supports_cache_class = True

    def __init__(self, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)

    def _init_weights(
        self,
        module: nn.Module,
        prenorm_residual_strategy: Optional[str] = 'rescale',
        num_residuals_per_layer: int = 2,
    ):
        if isinstance(module, (nn.Linear, nn.Conv1d)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
        elif hasattr(module, 'reset_parameters'):
            module.reset_parameters()

        if prenorm_residual_strategy is not None:
            # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
            #   > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
            #   > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
            #   >   -- GPT-2 :: https://openai.com/blog/better-language-models/
            #
            # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
            p = None
            if hasattr(module, 'o_proj'):
                p = module.o_proj.weight
            elif hasattr(module, 'down_proj'):
                p = module.down_proj.weight
            if p is not None:
                # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
                # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
                # We need to reinit p since this code could be called multiple times
                # Having just p *= scale would repeatedly scale it down
                if prenorm_residual_strategy == 'rescale':
                    nn.init.kaiming_uniform_(p, a=math.sqrt(5))
                    with torch.no_grad():
                        p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers)
                elif prenorm_residual_strategy == 'zero':
                    nn.init.zeros_(p)
                else:
                    raise ValueError(f"Invalid prenorm_residual_strategy: {prenorm_residual_strategy}")


class NSAModel(NSAPreTrainedModel):

    def __init__(self, config: NSAConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList([NSABlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
        self.norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)

        self.gradient_checkpointing = False

        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings

    def set_input_embeddings(self, value):
        self.embeddings = value

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,  # noqa
        inputs_embeds: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs: Unpack[Dict]
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        if output_attentions:
            warnings.warn("`NSAModel` does not `output_attentions` now, setting it to `False`.")
            output_attentions = False
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        if input_ids is None and inputs_embeds is None:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embeddings(input_ids)
        hidden_states = inputs_embeds

        if use_cache and not isinstance(past_key_values, Cache):
            past_key_values = Cache.from_legacy_cache(past_key_values)

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
            use_cache = False

        all_hidden_states = () if output_hidden_states else None
        all_attns = () if output_attentions else None
        for layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                hidden_states, attentions, past_key_values = self._gradient_checkpointing_func(
                    layer.__call__,
                    hidden_states,
                    attention_mask,
                    past_key_values,
                    use_cache,
                    output_attentions,
                    **kwargs
                )
            else:
                hidden_states, attentions, past_key_values = layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    past_key_values=past_key_values,
                    use_cache=use_cache,
                    output_attentions=output_attentions,
                    **kwargs
                )

            if output_attentions:
                all_attns += (attentions,)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        if not return_dict:
            return tuple(i for i in [hidden_states, past_key_values, all_hidden_states, all_attns] if i is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
            hidden_states=all_hidden_states,
            attentions=all_attns
        )


class NSAForCausalLM(NSAPreTrainedModel, GenerationMixin):

    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.model = NSAModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.criterion = None

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embeddings

    def set_input_embeddings(self, value):
        self.model.embeddings = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    def generate(self, *args, **kwargs):
        try:
            return super().generate(*args, **kwargs)
        except AttributeError as exception:
            if 'past_key_values' in str(exception):
                raise AttributeError(
                    f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, "
                    f"which is not supported for {self.__class__.__name__}. "
                    f"Try another generation strategy instead. "
                    f"For the available generation strategies, check this doc: "
                    f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies"
                )
            else:
                raise exception

    @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        use_cache: bool = True,
        logits_to_keep: Optional[int] = None,
        **kwargs
    ):
        # only last token for `inputs_ids` if the `past_key_values` is not empty.
        if past_key_values is not None and len(past_key_values) > 0:
            input_ids = input_ids[:, -1:]
        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and len(past_key_values) == 0:
            model_inputs = {'inputs_embeds': inputs_embeds}
        else:
            # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
            # recompiles graphs as the stride of the inputs is a guard.
            # Ref: https://github.com/huggingface/transformers/pull/29114
            # TODO: use `next_tokens` directly instead.
            model_inputs = {'input_ids': input_ids.contiguous()}

        if logits_to_keep is not None:
            model_inputs['logits_to_keep'] = logits_to_keep

        model_inputs.update({
            'past_key_values': past_key_values,
            'use_cache': use_cache,
            'attention_mask': attention_mask,
        })
        return model_inputs

    @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        logits_to_keep: Optional[int] = 0,
        **kwargs: Unpack[Dict]
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            **kwargs
        )

        hidden_states = outputs[0]
        fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training

        loss, logits = None, None
        if not fuse_linear_and_cross_entropy or labels is None:
            logits = self.lm_head(hidden_states if logits_to_keep is None else hidden_states[:, -logits_to_keep:])
        if labels is not None:
            if getattr(self, 'criterion', None) is None:
                if fuse_linear_and_cross_entropy:
                    criterion = FusedLinearCrossEntropyLoss()
                elif self.config.fuse_cross_entropy:
                    criterion = FusedCrossEntropyLoss(inplace_backward=True)
                else:
                    criterion = nn.CrossEntropyLoss()
            else:
                criterion = self.criterion
            labels = labels.to(hidden_states.device)
            labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1)
            if fuse_linear_and_cross_entropy:
                loss = criterion(hidden_states, labels, self.lm_head.weight, self.lm_head.bias)
            else:
                loss = criterion(logits.view(labels.numel(), -1), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
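Both `forward` implementations above align labels with logits by shifting the labels left and masking the final position, rather than slicing the logits. A standalone illustration of that one-liner (-100 is the default ignore_index of `nn.CrossEntropyLoss`):

import torch

ignore_index = -100
labels = torch.tensor([[10, 11, 12, 13]])
shifted = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], ignore_index)), 1)
print(shifted)  # tensor([[  11,   12,   13, -100]])
# Position t is now supervised by token t+1 and the last position is ignored,
# matching the usual "shift logits / shift labels" recipe without materializing
# a second copy of the logits.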
fla/models/retnet/configuration_retnet.py
ADDED
@@ -0,0 +1,92 @@
# -*- coding: utf-8 -*-

from __future__ import annotations

from typing import Dict, Optional

from transformers.configuration_utils import PretrainedConfig


class RetNetConfig(PretrainedConfig):

    model_type = 'retnet'
    keys_to_ignore_at_inference = ['past_key_values']

    def __init__(
        self,
        attn_mode: str = "chunk",
        hidden_size: int = 2048,
        expand_k: int = 1,
        expand_v: int = 2,
        hidden_ratio: Optional[int] = 2,
        intermediate_size: Optional[int] = None,
        num_hidden_layers: int = 24,
        num_heads: int = 8,
        num_kv_heads: Optional[int] = None,
        feature_map: Optional[str] = None,
        hidden_act: str = "swish",
        use_short_conv: bool = False,
        conv_size: int = 4,
        use_output_gate: bool = True,
        max_position_embeddings: int = 2048,
        elementwise_affine: Optional[bool] = True,
        norm_eps: float = 1e-6,
        attn: Optional[Dict] = None,
        use_cache: bool = True,
        pad_token_id: int = None,
        bos_token_id: int = 1,
        eos_token_id: int = 2,
        tie_word_embeddings: bool = False,
        initializer_range: float = 0.006,
        fuse_norm: bool = True,
        fuse_swiglu: bool = True,
        fuse_cross_entropy: bool = True,
        vocab_size: int = 32000,
        **kwargs
    ) -> RetNetConfig:
        self.attn_mode = attn_mode
        self.hidden_size = hidden_size
        self.expand_k = expand_k
        self.expand_v = expand_v
        self.hidden_ratio = hidden_ratio
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.feature_map = feature_map
        self.hidden_act = hidden_act
        self.use_short_conv = use_short_conv
        self.conv_size = conv_size
        self.use_output_gate = use_output_gate
        self.max_position_embeddings = max_position_embeddings
        self.elementwise_affine = elementwise_affine
        self.norm_eps = norm_eps
        self.attn = attn
        self.use_cache = use_cache
        self.initializer_range = initializer_range

        self.fuse_norm = fuse_norm
        self.fuse_swiglu = fuse_swiglu
        self.fuse_cross_entropy = fuse_cross_entropy
        self.vocab_size = vocab_size

        if attn is not None:
            if not isinstance(attn, Dict):
                raise ValueError("attn must be a dictionary")
            if 'layers' not in attn:
                raise ValueError("Layer indices must be provided to initialize hybrid attention layers")
            if 'num_heads' not in attn:
                raise ValueError("Number of heads must be provided to initialize hybrid attention layers")
            attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads'])
            attn['qkv_bias'] = attn.get('qkv_bias', False)
            attn['window_size'] = attn.get('window_size', None)
            attn['rope_theta'] = attn.get('rope_theta', 10000.)

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
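Since `RetNetConfig` subclasses `PretrainedConfig`, JSON serialization comes for free. A minimal sketch, assuming this file is importable from the package as laid out above (the local directory name is arbitrary):

from fla.models.retnet.configuration_retnet import RetNetConfig

config = RetNetConfig(hidden_size=1024, num_hidden_layers=12)
config.save_pretrained('retnet-config')            # writes retnet-config/config.json
restored = RetNetConfig.from_pretrained('retnet-config')
assert restored.expand_v == 2                      # defaults survive the round trip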
fla/models/rwkv6/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (687 Bytes)

fla/models/rwkv7/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (687 Bytes)

fla/models/rwkv7/__pycache__/configuration_rwkv7.cpython-312.pyc
ADDED
Binary file (4.24 kB)

fla/models/transformer/__pycache__/configuration_transformer.cpython-312.pyc
ADDED
Binary file (2.52 kB)

fla/models/transformer_mtp/__pycache__/modeling_transformer.cpython-312.pyc
ADDED
Binary file (24.7 kB)
fla/models/transformer_mtp/modeling_transformer.py
ADDED
@@ -0,0 +1,608 @@
# -*- coding: utf-8 -*-

from __future__ import annotations

import math
import warnings
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from transformers.generation import GenerationMixin
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging
from transformers.utils.deprecation import deprecate_kwarg

import triton
import triton.language as tl

from fla.layers.attn import Attention
from fla.models.transformer_mtp.configuration_transformer import MTPTransformerConfig
from fla.models.utils import Cache
from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss
from fla.modules import GatedMLP as TransformerMLP
from fla.modules import RMSNorm

if TYPE_CHECKING:
    from transformers.processing_utils import Unpack

logger = logging.get_logger(__name__)


class SequentialHeadsCustomBackward(torch.autograd.Function):
    @staticmethod
    def forward(ctx, trunk_output, lm_head, norm_layer, logits_to_keep, *prediction_heads):
        # We now need the norm layer in the forward pass calculation
        ctx.prediction_heads = prediction_heads
        ctx.lm_head = lm_head
        ctx.norm_layer = norm_layer
        ctx.logits_to_keep = logits_to_keep
        ctx.save_for_backward(trunk_output)

        latents = []
        for head in prediction_heads:
            # Assuming head forward signature is `head(hidden_states)`
            latent = head(trunk_output)[0]
            latents.append(latent)

        latents_stacked = torch.stack(latents, dim=-2)
        # Apply the final norm before the lm_head
        normalized_latents = norm_layer(latents_stacked)
        all_logits = lm_head(normalized_latents[:, -logits_to_keep:])
        return all_logits

    @staticmethod
    def backward(ctx, grad_output):
        trunk_output, = ctx.saved_tensors
        prediction_heads = ctx.prediction_heads
        lm_head = ctx.lm_head
        norm_layer = ctx.norm_layer
        logits_to_keep = ctx.logits_to_keep

        d = trunk_output.detach().requires_grad_(True)

        # We need to manually handle the backward pass for the final norm layer once
        # before the loop, as its gradient depends on all heads.
        # To do this, we reconstruct the input to the lm_head and do a backward pass.
        with torch.enable_grad():
            # Re-run the head computations to get the input to the norm layer
            latents = []
            for head in prediction_heads:
                latents.append(head(d)[0])
            # Detach so this first backward stops at the stacked latents; their
            # gradient is re-applied to each head separately in the loop below.
            # (`requires_grad_` is only valid on leaf tensors, hence the detach.)
            latents_stacked = torch.stack(latents, dim=-2).detach().requires_grad_(True)
            # The part of the graph we need to backprop through first
            normalized_latents = norm_layer(latents_stacked)

            # Backpropagate through the lm_head and norm_layer.
            # VJP through the bias-free linear head: dL/dx = dL/dy @ W.
            # NOTE: assumes `logits_to_keep` spans the full sequence here, as during training.
            normalized_latents.backward(grad_output @ lm_head.weight)

        # Now, `latents_stacked.grad` contains the sum of gradients from all heads
        # just before the final normalization. We can now unbind it.
        grad_per_head_latent = latents_stacked.grad.unbind(dim=-2)

        # Now, backpropagate through each head individually.
        for i, head in enumerate(prediction_heads):
            with torch.enable_grad():
                head_latent = head(d)[0]
                # Backpropagate using the gradient for this specific head's output
                head_latent.backward(gradient=grad_per_head_latent[i])

        # One grad slot per non-differentiable forward input:
        # lm_head, norm_layer, logits_to_keep, and each prediction head.
        num_nones = 3 + len(prediction_heads)
        return (d.grad,) + (None,) * num_nones


def seq_to_mtp(
    long_input_ids: torch.Tensor,
    model_seq_len: int,
    n_future_tokens: int
) -> torch.Tensor:
    """
    Generates a tensor of future targets on the fly from a long input sequence.

    This version assumes `long_input_ids` contains both the tokens for the model's
    input AND the future tokens needed for the labels.
    It extracts the correct targets without adding artificial padding.

    Args:
        long_input_ids (torch.Tensor): The input sequences from the dataloader,
            shape (B, T + n_future_tokens).
        model_seq_len (int): The sequence length `T` that the model processes.
        n_future_tokens (int): The number of future tokens to predict for each time step.

    Returns:
        torch.Tensor: The target tensor of shape (B, n_future_tokens, T).
            y[b, k, t] corresponds to the (k+1)-th token after input_ids[b, t].
    """
    B, total_len = long_input_ids.shape
    assert total_len >= model_seq_len + n_future_tokens, \
        "long_input_ids must be at least model_seq_len + n_future_tokens long."

    # 1. Create sliding windows (views) over the long tensor.
    #    .unfold() is a highly efficient way to create sliding windows.
    #    We create windows of size `n_future_tokens + 1`. For each time step `t`,
    #    the window will contain the input token and its `n_future_tokens` targets.
    #    Example (n=3, window_size=4):
    #      For t=0, window is [t0, t1, t2, t3]
    #      For t=1, window is [t1, t2, t3, t4]
    #    Shape of windows: (B, total_len - n_future_tokens, n_future_tokens + 1)
    windows = long_input_ids.unfold(dimension=1, size=n_future_tokens + 1, step=1)

    # 2. Slice the windows to get only the targets.
    #    We slice off the first element of each window (the input token itself)
    #    to keep only the future tokens.
    #    Example window [t0, t1, t2, t3] -> becomes targets [t1, t2, t3]
    all_targets = windows[:, :, 1:]

    # 3. Trim the result to match the model's output sequence length.
    #    We only need the targets for the first `model_seq_len` positions.
    output_targets = all_targets[:, :model_seq_len, :]

    return output_targets.transpose(1, 2)
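A quick sanity check of the sliding-window logic above, using `seq_to_mtp` as just defined:

import torch

ids = torch.arange(10).unsqueeze(0)                       # (1, 10): tokens 0..9
y = seq_to_mtp(ids, model_seq_len=6, n_future_tokens=3)
print(y.shape)   # torch.Size([1, 3, 6])
print(y[0, 0])   # tensor([1, 2, 3, 4, 5, 6]): the next token at each position
print(y[0, 2])   # tensor([3, 4, 5, 6, 7, 8]): the 3rd future token at each position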
148 |
+
@dataclass
|
149 |
+
class MTPLMOutputWithPast(CausalLMOutputWithPast):
|
150 |
+
ntp_loss: Optional[torch.FloatTensor] = None
|
151 |
+
mtp_loss: Optional[torch.FloatTensor] = None
|
152 |
+
|
153 |
+
class MTPTransformerBlock(nn.Module):
|
154 |
+
|
155 |
+
def __init__(self, config: MTPTransformerConfig, layer_idx: int):
|
156 |
+
super().__init__()
|
157 |
+
|
158 |
+
self.config = config
|
159 |
+
self.layer_idx = layer_idx
|
160 |
+
|
161 |
+
self.attn_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
|
162 |
+
self.attn = Attention(
|
163 |
+
hidden_size=config.hidden_size,
|
164 |
+
num_heads=config.num_heads,
|
165 |
+
num_kv_heads=config.num_kv_heads,
|
166 |
+
qkv_bias=config.qkv_bias,
|
167 |
+
qk_norm=config.qk_norm,
|
168 |
+
window_size=config.window_size,
|
169 |
+
rope_theta=config.rope_theta,
|
170 |
+
max_position_embeddings=config.max_position_embeddings,
|
171 |
+
layer_idx=layer_idx
|
172 |
+
)
|
173 |
+
|
174 |
+
self.mlp_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
|
175 |
+
self.mlp = TransformerMLP(
|
176 |
+
hidden_size=config.hidden_size,
|
177 |
+
hidden_ratio=config.hidden_ratio,
|
178 |
+
intermediate_size=config.intermediate_size,
|
179 |
+
hidden_act=config.hidden_act,
|
180 |
+
fuse_swiglu=config.fuse_swiglu
|
181 |
+
)
|
182 |
+
|
183 |
+
def forward(
|
184 |
+
self,
|
185 |
+
hidden_states: torch.Tensor,
|
186 |
+
attention_mask: Optional[torch.Tensor] = None,
|
187 |
+
past_key_values: Optional[Tuple[torch.Tensor]] = None,
|
188 |
+
output_attentions: Optional[bool] = False,
|
189 |
+
use_cache: Optional[bool] = False,
|
190 |
+
**kwargs: Unpack[Any]
|
191 |
+
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
192 |
+
|
193 |
+
residual = hidden_states
|
194 |
+
hidden_states = self.attn_norm(hidden_states)
|
195 |
+
hidden_states, attentions, past_key_values = self.attn(
|
196 |
+
hidden_states=hidden_states,
|
197 |
+
attention_mask=attention_mask,
|
198 |
+
past_key_values=past_key_values,
|
199 |
+
use_cache=use_cache,
|
200 |
+
output_attentions=output_attentions,
|
201 |
+
**kwargs
|
202 |
+
)
|
203 |
+
if self.config.fuse_norm:
|
204 |
+
hidden_states, residual = self.mlp_norm(hidden_states, residual, True)
|
205 |
+
else:
|
206 |
+
hidden_states = residual + hidden_states
|
207 |
+
residual = hidden_states
|
208 |
+
hidden_states = self.mlp_norm(hidden_states)
|
209 |
+
hidden_states = self.mlp(hidden_states, **kwargs)
|
210 |
+
hidden_states = residual + hidden_states
|
211 |
+
|
212 |
+
outputs = (hidden_states,)
|
213 |
+
|
214 |
+
if output_attentions:
|
215 |
+
outputs += (attentions,)
|
216 |
+
|
217 |
+
if use_cache:
|
218 |
+
outputs += (past_key_values,)
|
219 |
+
|
220 |
+
return outputs
|
221 |
+
|
222 |
+
|
223 |
+
class MTPTransformerPreTrainedModel(PreTrainedModel):

    config_class = MTPTransformerConfig
    base_model_prefix = 'model'
    supports_gradient_checkpointing = True
    _no_split_modules = ['MTPTransformerBlock']
    _supports_cache_class = True

    def __init__(self, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)

    def _init_weights(
        self,
        module: nn.Module,
        rescale_prenorm_residual: bool = False,
        num_residuals_per_layer: int = 2,
    ):
        if isinstance(module, (nn.Linear, nn.Conv1d)):
            # Slightly different from the TF version, which uses truncated_normal for initialization
            # cf. https://github.com/pytorch/pytorch/pull/5617
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
        elif hasattr(module, 'reset_parameters'):
            module.reset_parameters()

        if rescale_prenorm_residual:
            # Reinitialize selected weights subject to the OpenAI GPT-2 paper scheme:
            # > A modified initialization which accounts for the accumulation on the residual path with
            # > model depth. Scale the weights of residual layers at initialization by a factor of
            # > 1/√N where N is the # of residual layers.
            # > -- GPT-2 :: https://openai.com/blog/better-language-models/
            #
            # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
            p = None
            if hasattr(module, 'o_proj'):
                p = module.o_proj.weight
            elif hasattr(module, 'down_proj'):
                p = module.down_proj.weight
            if p is not None:
                # Special scaled initialization: there are 2 LayerNorms per Transformer block,
                # so follow the PyTorch default init but scale by 1/sqrt(2 * n_layer).
                # We reinitialize p first because this method can be called multiple times;
                # a bare `p *= scale` would scale it down repeatedly.
                nn.init.kaiming_uniform_(p, a=math.sqrt(5))
                with torch.no_grad():
                    p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers)

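To see what the 1/√(2N) rescaling buys, here is a small self-contained numeric check (an editor's illustration, not part of the file): summing the unit-variance outputs of 2N residual branches grows the activation variance linearly with depth, while the rescaled branches keep it near 1.

    import math
    import torch

    torch.manual_seed(0)
    num_layers, residuals_per_layer = 24, 2
    x = torch.zeros(100_000)
    for _ in range(num_layers * residuals_per_layer):
        branch = torch.randn(100_000)  # stand-in for a unit-variance residual branch
        x = x + branch / math.sqrt(residuals_per_layer * num_layers)
    print(x.var())  # ~1.0; without the rescaling it would be ~48 (= 2 * 24)
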
class MTPTransformerModel(MTPTransformerPreTrainedModel):

    def __init__(
        self,
        config: MTPTransformerConfig
    ) -> None:
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        # The first `num_hidden_layers - n_future_tokens` blocks form the shared trunk;
        # the remaining `n_future_tokens` blocks act as parallel prediction heads.
        self.layers = nn.ModuleList(
            [MTPTransformerBlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers - config.n_future_tokens)]
        )
        self.extra_heads = nn.ModuleList(
            [MTPTransformerBlock(config, layer_idx) for layer_idx in range(config.n_future_tokens)]
        )
        self.norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)

        self.gradient_checkpointing = False

        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings

    def set_input_embeddings(self, value):
        self.embeddings = value

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        return_all_heads: bool = False,  # set to True during training
        **kwargs: Unpack[Any]
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        if output_attentions:
            warnings.warn(
                "`MTPTransformerModel` does not support outputting attention weights yet, "
                "so `output_attentions` is set to `False`."
            )
            output_attentions = False
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        use_custom_backward = self.config.use_custom_backward and self.training
        if self.training and return_all_heads is False:
            logger.warning_once(
                "`return_all_heads=False` is incompatible with training. Setting `return_all_heads=True`..."
            )
            return_all_heads = True

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is None and inputs_embeds is None:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if use_cache and not isinstance(past_key_values, Cache):
            past_key_values = Cache.from_legacy_cache(past_key_values)

        if inputs_embeds is None:
            inputs_embeds = self.embeddings(input_ids)

        # embed positions
        hidden_states = inputs_embeds

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        all_hidden_states = () if output_hidden_states else None
        all_attns = () if output_attentions else None
        next_cache = None

        # 1. Run the shared trunk.
        for layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    layer.__call__,
                    hidden_states,
                    attention_mask,
                    past_key_values,
                    output_attentions,
                    use_cache,
                    **kwargs
                )
            else:
                layer_outputs = layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    past_key_values=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    **kwargs
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_cache = layer_outputs[2 if output_attentions else 1]

            if output_attentions:
                all_attns += (layer_outputs[1],)

        trunk = hidden_states

        # 2. Run the prediction heads. During training all `n_future_tokens` heads are used;
        # at inference only the first (next-token) head is needed.
        n_heads_to_use = self.config.n_future_tokens if return_all_heads else 1
        prediction_heads_to_use = self.extra_heads[:n_heads_to_use]

        if use_custom_backward and self.training:
            # all_logits = SequentialHeadsCustomBackward.apply(trunk, self.lm_head, *prediction_heads)
            hidden_states = trunk  # return the trunk; the custom backward is applied in MTPTransformerForCausalLM
        else:
            latents = []
            for layer in prediction_heads_to_use:
                if output_hidden_states:
                    all_hidden_states += (hidden_states,)

                if self.gradient_checkpointing and self.training:
                    layer_outputs = self._gradient_checkpointing_func(
                        layer.__call__,
                        trunk,  # each head consumes the trunk output, not the previous head's output
                        attention_mask,
                        past_key_values,
                        output_attentions,
                        use_cache,
                        **kwargs
                    )
                else:
                    layer_outputs = layer(
                        trunk,  # each head consumes the trunk output, not the previous head's output
                        attention_mask=attention_mask,
                        past_key_values=past_key_values,
                        output_attentions=output_attentions,
                        use_cache=use_cache,
                        **kwargs
                    )
                hidden_states = layer_outputs[0]
                latents.append(hidden_states)

                if use_cache:
                    next_cache = layer_outputs[2 if output_attentions else 1]

                if output_attentions:
                    all_attns += (layer_outputs[1],)

            hidden_states = torch.stack(latents, dim=-2)  # (B, T, n_heads_to_use, D)
            hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states and not use_custom_backward:
            all_hidden_states += (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_attns] if v is not None)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_attns
        )

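A hedged shape sketch of the trunk/head split (the `MTPTransformerConfig` constructor arguments here are assumptions inferred from the attribute accesses above): with `num_hidden_layers=24` and `n_future_tokens=4`, 20 blocks form the trunk and 4 blocks become prediction heads, each re-reading the trunk output.

    # Illustration only; config fields are inferred, not taken from the shipped configuration.
    config = MTPTransformerConfig(num_hidden_layers=24, n_future_tokens=4)
    model = MTPTransformerModel(config)
    assert len(model.layers) == 20 and len(model.extra_heads) == 4

    input_ids = torch.randint(0, config.vocab_size, (2, 128))
    out = model(input_ids=input_ids, return_all_heads=True)
    print(out.last_hidden_state.shape)  # (2, 128, 4, hidden_size): one latent per future token
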
class MTPTransformerForCausalLM(MTPTransformerPreTrainedModel, GenerationMixin):

    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.model = MTPTransformerModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.criterion = None
        self.pad_token_id = config.pad_token_id

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embeddings

    def set_input_embeddings(self, value):
        self.model.embeddings = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        use_cache: bool = True,
        logits_to_keep: Optional[int] = None,
        **kwargs
    ):
        # keep only the last token of `input_ids` if `past_key_values` is not empty
        if past_key_values is not None and len(past_key_values) > 0:
            input_ids = input_ids[:, -1:]
        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and (past_key_values is None or len(past_key_values) == 0):
            model_inputs = {'inputs_embeds': inputs_embeds}
        else:
            # The `contiguous()` here is necessary to have a static stride during decoding; torchdynamo otherwise
            # recompiles graphs as the stride of the inputs is a guard.
            # Ref: https://github.com/huggingface/transformers/pull/29114
            # TODO: use `next_tokens` directly instead.
            model_inputs = {'input_ids': input_ids.contiguous()}

        if logits_to_keep is not None:
            model_inputs['logits_to_keep'] = logits_to_keep

        model_inputs.update({
            'past_key_values': past_key_values,
            'use_cache': use_cache,
            'attention_mask': attention_mask,
        })
        return model_inputs

    @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        logits_to_keep: Optional[int] = 0,
        **kwargs: Unpack[Any]
    ) -> Union[Tuple, MTPLMOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            return_all_heads=self.training,
            **kwargs
        )

        hidden_states = outputs[0]  # (B, T, n_heads_to_use, D), or (B, T, D) on the custom-backward path

        fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training
        use_custom_backward = self.config.use_custom_backward and self.training

        all_logits = None
        if not self.training:
            if hidden_states.dim() == 4:
                hidden_states = hidden_states.squeeze(-2)  # drop the single-head dimension at inference
            all_logits = self.lm_head(hidden_states)
        else:
            if use_custom_backward:
                all_logits = SequentialHeadsCustomBackward.apply(
                    hidden_states, self.lm_head, self.model.norm, logits_to_keep, *self.model.extra_heads
                )
            else:
                # `logits_to_keep == 0` keeps the full sequence
                all_logits = None if fuse_linear_and_cross_entropy else self.lm_head(hidden_states[:, -logits_to_keep:])

        loss, ntp_loss, mtp_loss = None, None, None
        if labels is not None:
            # The loss below expects the stacked-head layout produced during training,
            # i.e. hidden states of shape (B, T, n_heads_prediction, D).
            B, T, n_heads_prediction, D = hidden_states.shape
            loss = torch.zeros(1, device=hidden_states.device)
            ntp_loss = torch.zeros(1, device=hidden_states.device)
            mtp_loss = torch.zeros(1, device=hidden_states.device)
            if getattr(self, 'criterion', None) is None:
                if fuse_linear_and_cross_entropy:
                    criterion = FusedLinearCrossEntropyLoss()
                elif self.config.fuse_cross_entropy:
                    criterion = FusedCrossEntropyLoss(inplace_backward=True)
                else:
                    criterion = nn.CrossEntropyLoss()
            else:
                criterion = self.criterion
            # Enable model parallelism
            labels = labels.to(hidden_states.device)
            # `all_labels` has shape (B, n_heads_prediction, T): head i is supervised on future token i
            all_labels = seq_to_mtp(labels, n_future_tokens=n_heads_prediction, model_seq_len=T)
            # Loop across prediction heads
            for i in range(n_heads_prediction):
                labels = all_labels[:, i, :]
                if fuse_linear_and_cross_entropy:
                    current_loss = criterion(hidden_states[:, :, i, :], labels.contiguous(), self.lm_head.weight, self.lm_head.bias)
                else:
                    logits = all_logits[:, :, i, :]
                    current_loss = criterion(logits.view(labels.numel(), -1), labels.reshape(-1))
                if i == 0:  # next-token prediction (NTP) head
                    ntp_loss = current_loss
                else:       # multi-token prediction (MTP) heads
                    mtp_loss += current_loss
                loss += current_loss

        if not return_dict:
            output = (all_logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return MTPLMOutputWithPast(
            loss=loss,
            ntp_loss=ntp_loss,
            mtp_loss=mtp_loss,
            logits=all_logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

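Putting the pieces together, a hedged end-to-end sketch (it assumes a `config` as above, and that `seq_to_mtp` shifts `labels` so head i is supervised on token t + 1 + i, which is what the per-head loss loop implies):

    # Illustration only; shapes and config fields are inferred from the code above.
    model = MTPTransformerForCausalLM(config)

    # Training: every future-token head is evaluated; loss = ntp_loss + mtp_loss.
    model.train()
    input_ids = torch.randint(0, config.vocab_size, (2, 256))
    out = model(input_ids=input_ids, labels=input_ids)
    print(out.loss, out.ntp_loss, out.mtp_loss)

    # Inference: only the next-token head runs, so generate() behaves like a plain causal LM.
    model.eval()
    generated = model.generate(input_ids[:, :16], max_new_tokens=32)
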
fla/models/transformer_top/__pycache__/modeling_transformer.cpython-312.pyc
ADDED
Binary file (18.7 kB)

fla/modules/__pycache__/feature_map.cpython-312.pyc
ADDED
Binary file (17.6 kB)

logs/none_yagntt11/attempt_0/0/stderr.log
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ef32c39ad6f7ca02c833bf4d4f8196743faf7f289a49798fb5b451327ea3b019
+size 25537738

logs/none_yagntt11/attempt_0/1/stderr.log
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e2b1dba31ba3e451fdd9b07d71992f0a323cc218972d1006358219cbe1b65db2
+size 15389397

logs/none_yagntt11/attempt_0/2/stderr.log
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b06065c7a987a777cf34d02283c391ce448589bf548d0054c2fb2360d9bd0f84
+size 15389394

logs/none_yagntt11/attempt_0/3/stderr.log
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d2ae4d559f5ace303133096b8d857c05c92e5ba11ef04042d664753267a5c871
+size 15448342

logs/none_yagntt11/attempt_0/4/stderr.log
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:898546e9d6c18959efdf0095393dcb46d73c602f8ff6ceb3acd6230bf9dc8198
+size 15389392

logs/none_yagntt11/attempt_0/5/stderr.log
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fb65879d3b844e10b4a4a278450b49fb28cdddea127970882da65bf2d008f18d
+size 15389393

logs/none_yagntt11/attempt_0/5/stdout.log
ADDED
File without changes

logs/none_yagntt11/attempt_0/6/stderr.log
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cfa0202f405108b85b5c3d3c21eb0311b7cbdb7503729d63d74b8b7260a96093
+size 15389394

logs/none_yagntt11/attempt_0/7/stderr.log
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9d4af7c40646947d1d997442de38fc8e1b37755f4c4347ba8b758bad92f9df4d
+size 15389389

model-00001-of-00002.safetensors
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:14bf00f42c711dafca4cd11373c3e7eee50c53323ce810d9b4b4893e77c76b68
+size 4989532648

model-00002-of-00002.safetensors
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2fcd72cf1fcb3823fb818e67d13418d894973df4a14cf4a8bc0af1cc9466c20c
+size 2111988680

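The two shards above are the Git LFS pointers for the trained checkpoint (~7.1 GB total). A hedged loading sketch; the repo id is a placeholder, and `trust_remote_code=True` is an assumption based on the custom `MTPTransformer*` classes shipped in this commit:

    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained(
        "<repo-id>",             # placeholder: substitute the actual Hub repo id
        trust_remote_code=True,  # assumed necessary for the custom MTPTransformer classes
        torch_dtype="auto",
    )
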
tb/20250716-2210/wandb/run-20250716_221000-mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201/files/output.log
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6eff1add295e31de6ccda5b78a6d7949ea83e385c5400a7138cc4ee5c6078f7a
+size 15411430

tb/20250716-2210/wandb/run-20250716_221000-mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201/run-mtp_transformer-mtp.1B.batch16.seqlen4096.context4096.warmup2000.update1.steps200000.lr2e-4.cosine-202507162201.wandb
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b3c4838163e23cde9bacb4a0b52d47fb6cad45de8a3efa9704359a68e119d037
+size 265364709