xue wang committed
Commit 20231c4 · verified · 1 parent: 7138a59

Upload 4 files

Files changed (4)
  1. config.json +90 -0
  2. model.py +1713 -0
  3. model.safetensors +3 -0
  4. model_config.py +100 -0
config.json ADDED
@@ -0,0 +1,90 @@
+ {
+
+   "architectures": [
+     "YingLong"
+   ],
+
+   "auto_map": {
+     "AutoConfig": "model_config.YingLongConfig",
+     "AutoModelForCausalLM": "model.GPT"
+   },
+   "_mlp_class": "LLaMAMLP",
+   "_norm_class": "FusedRMSNorm",
+   "average_2": false,
+   "betas": [
+     0.9,
+     0.95
+   ],
+   "bias": false,
+   "block_size": 8224,
+   "block_size_val": 2080,
+   "condense_ratio": 1,
+   "decay_lr": true,
+   "discount": true,
+   "eval_iters": 100,
+   "eval_step_interval": 1000,
+   "forecasting_patch": 1,
+   "global_batch_size": 512,
+   "grad_clip": 1.0,
+   "group": "70m-test_all",
+   "haar_loss_match": false,
+   "haar_trans": true,
+   "haar_trans_inv": true,
+   "haar_trans_norm": "backward",
+   "half_diff": false,
+   "imputation": false,
+   "inner_norm": false,
+   "inter_control": false,
+   "intermediate_size": 2048,
+   "is_diff": false,
+   "is_smape": false,
+   "learning_rate": 0.0005,
+   "log_step_interval": 10,
+   "max_step": 100000,
+   "mean_replace": false,
+   "micro_batch_size": 128,
+   "micro_batch_size_val": 512,
+   "min_lr": 1e-05,
+   "mix_train": false,
+   "multi_loss": false,
+   "n_cot": 1,
+   "n_embd": 512,
+   "n_head": 16,
+   "n_layer": 8,
+   "n_query_groups": 4,
+   "name": "50m-unet",
+   "new_arch": false,
+   "new_tokenizer": false,
+   "norm_eps": 1e-05,
+   "num_of_devices": 4,
+   "num_of_nodes": 1,
+   "org": "Ali-Could",
+   "ou": false,
+   "ou_mean": false,
+   "ou_prod": false,
+   "padded_vocab_size": "None",
+   "padding_multiple": 1,
+   "parallel_residual": false,
+   "patch_size": 32,
+   "pid": false,
+   "quantitle": true,
+   "rollback_win": 256,
+   "rolling_patch": 6,
+   "rope_base": 10000,
+   "rotary_percentage": 1.0,
+   "save_step_interval": 10000,
+   "scaling": true,
+   "seed0": 3407,
+   "shared_attention_norm": false,
+   "stats_encoding": false,
+   "stats_encoding_new": false,
+   "sum_divided": false,
+   "triple_diff": false,
+   "triple_diff_new": false,
+   "unet": true,
+   "vocab_size": 1,
+   "vq": false,
+   "warmup_steps": 2000,
+   "weight_decay": 0.1,
+   "yj_trans": false
+ }
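
Note: "auto_map" above routes AutoConfig to model_config.YingLongConfig and AutoModelForCausalLM to model.GPT, so the checkpoint has to be loaded with trust_remote_code=True, and the custom code additionally imports flash-attn, xformers and the flash-attn fused kernels. A minimal loading sketch (the repository id is a placeholder, not stated anywhere in this commit):

    from transformers import AutoConfig, AutoModelForCausalLM

    repo_id = "<namespace>/<this-repo>"  # placeholder; substitute the actual Hub repo id
    config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)           # -> model_config.YingLongConfig
    model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)  # -> model.GPT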
model.py ADDED
@@ -0,0 +1,1713 @@
1
+ """Full definition of a GPT NeoX Language Model, all of it in this single file.
2
+
3
+ Based on the nanoGPT implementation: https://github.com/karpathy/nanoGPT and
4
+ https://github.com/EleutherAI/gpt-neox/tree/main/megatron/model.
5
+ """
6
+ import math, random
7
+ import numpy as np
8
+ from typing import Any, List, Optional, Tuple
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ from lightning_utilities.core.imports import RequirementCache
13
+ from typing_extensions import Self
14
+ from flash_attn import flash_attn_func
15
+ # from lit_gpt.config import Config
16
+ from xformers.ops import SwiGLU
17
+
18
+ import torch.nn.functional as F
19
+ # from .fused_rotary_embedding import apply_rotary_emb_func
20
+ RoPECache = Tuple[torch.Tensor, torch.Tensor]
21
+ KVCache = Tuple[torch.Tensor, torch.Tensor]
22
+ PretokenCache = torch.Tensor
23
+ # Tuple[torch.Tensor, torch.Tensor]
24
+ FlashAttention2Available = RequirementCache("flash-attn>=2.0.0.post1")
25
+ from einops import rearrange
26
+ from transformers import PreTrainedModel, Cache, DynamicCache
27
+
28
+ from huggingface_hub import PyTorchModelHubMixin
29
+ from .model_config import YingLongConfig
30
+
31
+ # from torch.distributions import Normal, LowRankMultivariateNormal, kl_divergence,MultivariateNormal
32
+
33
+ class quantitleLoss(torch.nn.Module):
34
+ def __init__(self,
35
+ qSize = 99,
36
+ patch_size = 16,
37
+ *args,**kwargs) -> None:
38
+
39
+ super().__init__()
40
+ self.qSize = qSize
41
+ self.patch_size = patch_size
42
+
43
+
44
+ q = np.array([i+1 for i in range(self.qSize)])
45
+ q = q / (self.qSize + 1)
46
+ q = q.reshape((1,1,-1))
47
+
48
+ q_variance = q*(1-q)
49
+
50
+ self.register_buffer('q', torch.tensor(q))
51
+ self.register_buffer('q_variance', torch.tensor(q_variance))
52
+
53
+
54
+ def forward(self, input: torch.Tensor, target: torch.Tensor,rel_loss = False) -> torch.Tensor:
55
+
56
+
57
+
58
+ target = target.unsqueeze(-1)
59
+ input = input[:,:target.shape[1],:,:]
60
+
61
+
62
+ posPart = input - target
63
+ negPart = -posPart
64
+
65
+ raw_loss = torch.maximum(self.q * negPart, (1-self.q) * posPart)
66
+
67
+ target_absmean = torch.mean(target.abs(),dim = (1,2),keepdims = True)
68
+ raw_loss = raw_loss / torch.sqrt(self.q_variance) / (target_absmean + 1e-4)
69
+
70
+ return torch.mean(raw_loss)
71
+
72
+
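
The quantitleLoss class above is a pinball (quantile) loss evaluated on a grid of qSize levels q = 1/(qSize+1), ..., qSize/(qSize+1), rescaled by sqrt(q(1-q)) and by the mean absolute target. A minimal usage sketch with hypothetical shapes, assuming predictions carry a trailing quantile axis:

    loss_fn = quantitleLoss(qSize=99, patch_size=32)
    pred = torch.randn(2, 8, 32, 99)   # (batch, patches, patch_size, quantile levels)
    target = torch.randn(2, 8, 32)     # (batch, patches, patch_size)
    loss = loss_fn(pred, target)       # scalar tensor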
73
+ def haarMatrix_unnormalized(n):
74
+ # Allow only size n of power 2
75
+ n = 2**np.ceil(np.log2(n))
76
+ if n > 2:
77
+ h = haarMatrix(n / 2)
78
+ else:
79
+ return np.array([[1, 1], [1, -1]])
80
+
81
+ # calculate upper haar part
82
+ h_n = np.kron(h, [1, 1])
83
+ # calculate lower haar part
84
+ # if normalized:
85
+ # h_i = np.sqrt(n/2)*np.kron(np.eye(len(h)), [1, -1])
86
+ # else:
87
+ h_i = np.kron(np.eye(len(h)), [1, -1])
88
+ # combine parts
89
+ h = np.vstack((h_n, h_i))
90
+ return h
91
+
92
+
93
+ def haarMatrix(n,normalized = 'ortho'):
94
+ h = haarMatrix_unnormalized(n)
95
+ scaler = np.diag(1/np.sqrt(np.diag(h @ h.transpose())))
96
+ if normalized == 'ortho':
97
+ return scaler @ h
98
+ elif normalized == 'forward':
99
+ return scaler @ h/ np.sqrt(n)
100
+
101
+ else:
102
+ return scaler @ h * np.sqrt(n)
103
+ # else:
104
+ # scaler = 1
105
+
106
+
107
+
108
+
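
For n a power of two, haarMatrix(n) with the default normalized='ortho' has orthonormal rows, so the transform is inverted by its transpose; the 'forward'/'backward' options add an extra 1/sqrt(n) or sqrt(n) scale that GPT.forward compensates for after the inverse transform. A quick sanity-check sketch (np is numpy, imported at the top of this file):

    H = haarMatrix(8)
    assert np.allclose(H @ H.T, np.eye(8))   # orthonormal rows
    x = np.random.randn(8)
    assert np.allclose(H.T @ (H @ x), x)     # transpose inverts the transform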
109
+
110
+ class Tokenizer(torch.nn.Module):
111
+ def __init__(self, config: YingLongConfig, *args,**kwargs) -> None:
112
+ super().__init__()
113
+
114
+ self.config = config
115
+ self.tokenizer = nn.Linear(config.patch_size,self.config.n_embd)
116
+
117
+ self.patch_size = config.patch_size
118
+ self.mask0 = nn.Linear(1,config.n_embd)
119
+
120
+ self.register_buffer('mask_token', torch.zeros(1000))
121
+ if self.config.haar_trans:
122
+ self.register_buffer('haar_transform',torch.Tensor(haarMatrix(self.config.patch_size,normalized = self.config.haar_trans_norm)))
123
+
124
+
125
+
126
+ def forward(self,x,
127
+ future_token = 0,
128
+ prev_token = 0,
129
+ factor = 0.2,
130
+ sequential = False,
131
+ *args, **kwargs):
132
+
133
+
134
+ b = x.shape[0]
135
+
136
+ x_raw = rearrange(x, "b (l c) -> b l c", c = self.patch_size)
137
+ x_raw_0 = rearrange(x, "b (l c) -> b l c", c = self.patch_size).detach().clone()
138
+
139
+ if future_token == 0:
140
+ if not sequential:
141
+ masks = torch.randperm(x_raw.shape[1])
142
+ unmasks,masks = masks[:int(x_raw.shape[1]*factor)],masks[int(x_raw.shape[1]*factor):]
143
+ else:
144
+ masks = [_ for _ in range(x_raw.shape[1])]
145
+ factor = np.random.rand()*0.6 + 0.2
146
+ unmasks,masks = masks[:int(x_raw.shape[1]*factor)],masks[int(x_raw.shape[1]*factor):]
147
+
148
+
149
+
150
+ x_raw_remains = x_raw[:,unmasks,:]
151
+
152
+ mean = x_raw_remains.mean(dim = (-2,-1),keepdims = True)
153
+ std = x_raw_remains.std(dim = (-2,-1),keepdims = True)
154
+ x_raw = (x_raw - mean)/ (std + 1e-4)
155
+
156
+
157
+ if self.config.haar_trans:
158
+ x_featured = torch.einsum('blc,ac->bla',x_raw,self.haar_transform)
159
+ x_featured = self.tokenizer(x_featured)
160
+ else:
161
+ x_featured = self.tokenizer(x_raw)
162
+
163
+
164
+ x_featured[:,masks,:] = self.mask0(self.mask_token[0].unsqueeze(0))
165
+
166
+
167
+
168
+ else:
169
+
170
+
171
+ factor = 1
172
+ more_rows = future_token // self.patch_size + 1
173
+ prev_more_rows = prev_token // self.patch_size + 1
174
+
175
+ mean = x_raw[:,prev_more_rows:-more_rows,:].mean(dim = (-2,-1),keepdims = True)
176
+ std = x_raw[:,prev_more_rows:-more_rows,:].std(dim = (-2,-1),keepdims = True)
177
+ x_raw = (x_raw - mean)/ (std + 1e-4)
178
+
179
+
180
+ if self.config.haar_trans:
181
+ x_featured = torch.einsum('blc,ac->bla',x_raw,self.haar_transform)
182
+ x_featured = self.tokenizer(x_featured)
183
+ else:
184
+ x_featured = self.tokenizer(x_raw)
185
+
186
+
187
+ masks = [jj for jj in range(x_featured.shape[1])]
188
+ masks = masks[-more_rows:]
189
+
190
+ # if not mean_replace:
191
+ x_featured[:,-more_rows:] = self.mask0(self.mask_token[:len(masks)].unsqueeze(-1)).repeat(x_featured.shape[0],1,1)
192
+ x_featured[:,:prev_more_rows] = self.mask0(self.mask_token[:prev_more_rows].unsqueeze(-1)).repeat(x_featured.shape[0],1,1)
193
+
194
+
195
+ return x_featured, x_raw_0, masks, mean, std, x_raw
196
+
197
+
198
+
199
+ class model_tmp(PreTrainedModel):
200
+ config_class = YingLongConfig
201
+ base_model_prefix = "model"
202
+ # supports_gradient_checkpointing = True
203
+ # _no_split_modules = ["TimeMoeDecoderLayer"]
204
+ # _skip_keys_device_placement = "past_key_values"
205
+ _supports_flash_attn_2 = True
206
+ _supports_sdpa = False
207
+ _supports_cache_class = True
208
+
209
+ # class GPT(nn.Module,PreTrainedModel,PyTorchModelHubMixin):
210
+ class GPT(model_tmp):
211
+ def __init__(self, config: YingLongConfig, *args,**kwargs) -> None:
212
+
213
+
214
+ # config_class = YingLongConfig
215
+ # base_model_prefix = "model"
216
+ # # supports_gradient_checkpointing = True
217
+ # # _no_split_modules = ["TimeMoeDecoderLayer"]
218
+ # # _skip_keys_device_placement = "past_key_values"
219
+ # _supports_flash_attn_2 = True
220
+ # _supports_sdpa = False
221
+ # _supports_cache_class = True
222
+ super().__init__(config)
223
+
224
+ self.config = config
225
+ self.patch_size = config.patch_size
226
+ self.unet = config.unet
227
+
228
+
229
+ if self.config._norm_class == "RMSNorm":
230
+ # from .model import RMSNorm
231
+ self.config.norm_class = RMSNorm
232
+ elif self.config._norm_class == "FusedRMSNorm":
233
+ # from .model import FusedRMSNorm
234
+ self.config.norm_class = FusedRMSNorm
235
+ elif self.config._norm_class == 'BatchNorm':
236
+ # from .model import iBatchNorm
237
+ self.config.norm_class = iBatchNorm
238
+
239
+
240
+
241
+ if self.config._mlp_class == "GptNeoxMLP":
242
+ # from .model import GptNeoxMLP
243
+ self.config.mlp_class = GptNeoxMLP
244
+ elif self.config._mlp_class == "LLaMAMLP":
245
+ # from .model import LLaMAMLP
246
+ self.config.mlp_class = LLaMAMLP
247
+
248
+
249
+ if config.stats_encoding:
250
+ self.stat_tokens = 1
251
+ else:
252
+ self.stat_tokens = 0
253
+
254
+
255
+
256
+
257
+
258
+ self.tokenizer = Tokenizer(config)
259
+
260
+ # self.lm_head = nn.Sequential(config.norm_class(config.n_embd, eps=config.norm_eps),
261
+ # nn.Linear(config.n_embd, config.n_embd*4),
262
+ # nn.ReLU(),
263
+ # nn.Linear(config.n_embd*4, 99*self.patch_size),
264
+ # )
265
+
266
+
267
+ self.lm_head = nn.Linear(config.n_embd, 99*self.patch_size)
268
+
269
+
270
+ # self.gate = nn.Linear(config.n_embd, 1)
271
+
272
+
273
+ self.quantitleLoss = quantitleLoss(99,patch_size = self.patch_size)
274
+
275
+
276
+
277
+ if self.unet:
278
+ assert config.n_layer%2 == 0
279
+ self.unet_projection = nn.ModuleList(nn.Sequential(nn.Linear(config.n_embd*2,config.n_embd),
280
+ config.norm_class(config.n_embd, eps=config.norm_eps),
281
+ )
282
+ for _ in range(config.n_layer//2)
283
+ )
284
+ self.unet_merge = nn.ModuleList(nn.Sequential(nn.Linear(config.n_embd*2,config.n_embd),
285
+ config.norm_class(config.n_embd, eps=config.norm_eps),
286
+ )
287
+ for _ in range(config.n_layer//2)
288
+ )
289
+
290
+
291
+
292
+ self.transformer = nn.ModuleDict(dict(h = nn.ModuleList(Block(config)
293
+ for _ in range(config.n_layer))
294
+ )
295
+ )
296
+
297
+
298
+
299
+ self.rope_cache: Optional[RoPECache] = None
300
+ self.mask_cache: Optional[torch.Tensor] = None
301
+ self.kv_caches: List[KVCache] = []
302
+
303
+
304
+ def _init_weights(self, module: nn.Module) -> None:
305
+ """Meant to be used with `gpt.apply(gpt._init_weights)`."""
306
+ # GPT-NeoX https://arxiv.org/pdf/2204.06745.pdf
307
+ if isinstance(module, nn.Embedding):
308
+ torch.nn.init.normal_(module.weight, mean=0.0, std=math.sqrt(2.0 / 5 / self.config.n_embd))
309
+ # RWKV: set it to 1e-4
310
+ # torch.nn.init.uniform_(module.weight, -1e-4, 1e-4)
311
+ elif isinstance(module, nn.Linear):
312
+ torch.nn.init.normal_(module.weight, mean=0.0, std=math.sqrt(2.0 / 5 / self.config.n_embd))
313
+ if module.bias is not None:
314
+ torch.nn.init.zeros_(module.bias)
315
+ # GPT-NeoX
316
+ for name, p in module.named_parameters():
317
+ if (name == "proj.weight" and isinstance(module, LLaMAMLP)) or (name == "w3.weight" and isinstance(module, SwiGLU) or (name=="proj.weight" and isinstance(module, BidirectedlSelfAttention))): #if use xformer swiglu, fc2 layer will be renamed to w3
318
+ nn.init.normal_(p, mean=0.0, std=1 / math.sqrt(self.config.n_embd) / self.config.n_layer)
319
+
320
+
321
+ def reset_cache(self) -> None:
322
+ self.kv_caches.clear()
323
+ if self.mask_cache is not None and self.mask_cache.device.type == "xla":
324
+ # https://github.com/Lightning-AI/lit-gpt/pull/83#issuecomment-1558150179
325
+ self.rope_cache = None
326
+ self.mask_cache = None
327
+
328
+ def forward(
329
+ self, idx: torch.Tensor,
330
+ max_seq_length: Optional[int] = None,
331
+ input_pos: Optional[torch.Tensor] = None,
332
+ next_token: torch.Tensor = None,
333
+ future_token: int = 0,
334
+ prev_token: int = 0,
335
+ val: bool = False,
336
+ print_intermediate: bool = False,
337
+ cot_rounds: int = -1,
338
+ sequential: bool = False,
339
+ *args,**kwargs,
340
+ ) -> torch.Tensor:
341
+
342
+ if future_token > 0:
343
+ more_rows = future_token // self.patch_size + 1
344
+ idx = torch.cat((idx,torch.zeros(idx.shape[0],more_rows*self.patch_size).to(idx.device)),dim = -1).bfloat16()
345
+ if prev_token > 0:
346
+ more_rows = prev_token // self.patch_size + 1
347
+ idx = torch.cat((torch.zeros(idx.shape[0],more_rows*self.patch_size).to(idx.device),idx),dim = -1).bfloat16()
348
+
349
+ B, T = idx.size()
350
+
351
+ use_kv_cache = input_pos is not None
352
+
353
+ block_size = self.config.block_size
354
+ if max_seq_length is None:
355
+ max_seq_length = block_size
356
+
357
+
358
+ if use_kv_cache: # not relevant otherwise
359
+ assert (
360
+ max_seq_length >= T
361
+ ), f"Cannot forward sequence of length {T}, max seq length is only {max_seq_length}"
362
+ assert max_seq_length <= block_size, f"Cannot attend to {max_seq_length}, block size is only {block_size}"
363
+ if self.rope_cache is None:
364
+ self.rope_cache = self.build_rope_cache(idx)
365
+ if use_kv_cache and self.mask_cache is None:
366
+ self.mask_cache = self.build_mask_cache(idx)
367
+ cos, sin = self.rope_cache
368
+ if use_kv_cache:
369
+ if self.stat_tokens:
370
+ if len(input_pos) == 1:
371
+ idx = idx[:,input_pos]
372
+ input_pos = input_pos.add_(1)
373
+ else:
374
+ input_pos = torch.arange(0, input_pos[-1]+2, device=idx.device)
375
+
376
+ cos = cos.index_select(0, input_pos)
377
+ sin = sin.index_select(0, input_pos)
378
+ mask = self.mask_cache.index_select(2, input_pos)
379
+ mask = mask[:, :, :, :max_seq_length]
380
+
381
+ else:
382
+ cos = cos.index_select(0, input_pos)
383
+ sin = sin.index_select(0, input_pos)
384
+ idx = idx[:,input_pos]
385
+ else:
386
+ cos = cos[:max(T,1024) + self.stat_tokens]
387
+ sin = sin[:max(T,1024) + self.stat_tokens]
388
+ mask = None
389
+
390
+ idx_ori = idx
391
+
392
+
393
+
394
+ if use_kv_cache:
395
+ pass
396
+ else:
397
+ x,x_raw,masks,mean,std,x_0 = self.tokenizer(idx,
398
+ future_token =future_token,
399
+ prev_token = prev_token,
400
+ sequential = sequential,
401
+ )
402
+
403
+
404
+
405
+
406
+ if self.unet:
407
+ skips = []
408
+
409
+
410
+
411
+ res_intermediate = []
412
+ target_intermediate = []
413
+ if not use_kv_cache:
414
+
415
+ if cot_rounds <0:
416
+ cot_rounds = self.config.n_cot
417
+
418
+ res_list = []
419
+ gate_list = []
420
+ for rep in range(cot_rounds):
421
+ for block_idx in range(len( self.transformer.h)):
422
+
423
+
424
+
425
+ block = self.transformer.h[block_idx]
426
+
427
+ if self.unet and block_idx >=len(self.transformer.h) //2:
428
+ x = self.unet_projection[block_idx - len(self.transformer.h) //2](torch.cat((skips.pop(),x),dim = -1))
429
+
430
+ x, *_ = block(x, (cos, sin), max_seq_length)
431
+
432
+ if self.unet and block_idx <len(self.transformer.h) //2:
433
+ skips.append(x)
434
+ x_delay = torch.cat((x[:,0,:].unsqueeze(1),x[:,:-1,:]),dim = 1)
435
+ x = self.unet_merge[block_idx](torch.cat((x_delay,x),dim = -1))
436
+ # if block_idx <len(self.transformer.h) //2:
437
+ # x_delay = torch.cat((x[:,0,:].unsqueeze(1),x[:,:-1,:]),dim = 1)
438
+ # x = self.unet_merge[block_idx](torch.cat((x_delay,x),dim = -1))
439
+
440
+
441
+
442
+
443
+ # res_list.append(self.lm_head(x).unsqueeze(-1))
444
+ # gate_list.append(self.gate(x).unsqueeze(-1))
445
+ # gate_list.append(self.gate(x))
446
+ # if print_intermediate:
447
+ # res_intermediate.append(res_list[-1])
448
+ # if print_intermediate:
449
+ # res_tmp = self.lm_head(x[:,self.stat_tokens:])
450
+ # res_tmp = rearrange(res_tmp,'b c (l1 l2) -> b c l1 l2', l2 = 99)
451
+ # if self.config.haar_trans_inv:
452
+
453
+ # res_tmp = torch.einsum('bcal,ad->bcdl',res_tmp,self.tokenizer.haar_transform)
454
+ # if self.config.haar_trans_norm == "backward":
455
+ # res_tmp = res_tmp / np.sqrt(res_tmp.shape[-2])
456
+ # elif self.config.haar_trans_norm == "forward":
457
+ # res_tmp = res_tmp * np.sqrt(res_tmp.shape[-2])
458
+ # res_tmp = res_tmp * (std.unsqueeze(-1) + 1e-4) + mean.unsqueeze(-1)
459
+ # res_intermediate.append(res_tmp[:,masks,:,:])
460
+
461
+
462
+
463
+
464
+
465
+ else:
466
+ self.kv_caches = self.kv_caches or self.build_kv_caches(x, max_seq_length, cos.size(-1) * 2)
467
+ for block_idx in range(len( self.transformer.h)):
468
+ block = self.transformer.h[block_idx]
469
+ if self.unet and block_idx >=len(self.transformer.h) //2:
470
+ x = F.silu(skips.pop()) * x
471
+ x, self.kv_caches[block_idx] = block(x, (cos, sin), max_seq_length, mask, input_pos, self.kv_caches[block_idx])
472
+ if self.unet and block_idx <len(self.transformer.h) //2:
473
+ skips.append(x)
474
+
475
+
476
+
477
+
478
+ res = self.lm_head(x)
479
+ # gate = torch.cat(gate_list,dim = -1)
480
+ # gate = F.softmax(gate,dim = -1)
481
+ # res = torch.cat(res_list,dim = -1) * gate
482
+ # res = res.sum(dim = -1)
483
+
484
+
485
+ res = rearrange(res,'b c (l1 l2) -> b c l1 l2', l2 = 99)
486
+
487
+
488
+
489
+ if self.config.haar_trans_inv:
490
+ # print('res',res.shape,self.tokenizer.haar_transform.shape)
491
+ res = torch.einsum('bcal,ad->bcdl',res,self.tokenizer.haar_transform)
492
+ if self.config.haar_trans_norm == "backward":
493
+ res = res / np.sqrt(res.shape[-2])
494
+ elif self.config.haar_trans_norm == "forward":
495
+ res = res * np.sqrt(res.shape[-2])
496
+
497
+
498
+ res = res * (std.unsqueeze(-1) + 1e-4) + mean.unsqueeze(-1)
499
+
500
+
501
+
502
+
503
+ if future_token == 0:
504
+ return res[:,masks,:,:], x_raw[:,masks,:],res_intermediate,target_intermediate
505
+ else:
506
+ return res[:,masks,:,:],res_intermediate
507
+
508
+ def generate(self,*args,**kwargs):
509
+
510
+ res, _ = self.forward(*args,**kwargs)
511
+ # logits_all,res_intermediate = model(idx = x_train, future_token = (pred_len//32 + 1)* 32, prev_token = 0,print_intermediate = False,cot_rounds = 1)
512
+
513
+ res = rearrange(res, 'b l c d -> b (l c) d')
514
+ return res[:,:kwargs['future_token'],:]
515
+
516
+
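
generate() above is a thin wrapper over forward(): it zero-pads the input by enough whole patches to cover future_token steps, runs the masked-reconstruction pass, and returns the first future_token forecast steps over all 99 quantile levels. A hedged usage sketch (device, dtype, checkpoint loading and input length are assumptions; the flash-attn path needs a CUDA device and bf16/fp16 tensors):

    # model: a GPT instance, e.g. obtained via AutoModelForCausalLM with trust_remote_code=True
    model = model.cuda().bfloat16().eval()
    history = torch.randn(4, 2048).bfloat16().cuda()             # (batch, context length)
    with torch.no_grad():
        forecast = model.generate(idx=history, future_token=96)  # (batch, 96, 99)
    median = forecast[..., 49]                                   # q = 0.5 on the 99-level quantile grid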
517
+ @classmethod
518
+ def from_name(cls, name: str, **kwargs: Any) -> Self:
519
+ return cls(Config.from_name(name, **kwargs))
520
+
521
+ def build_rope_cache(self, idx: torch.Tensor) -> RoPECache:
522
+ return build_rope_cache(
523
+ seq_len=self.config.block_size + self.stat_tokens,
524
+ n_elem=int(self.config.rotary_percentage * self.config.head_size),
525
+ dtype=torch.bfloat16,
526
+ device=idx.device,
527
+ base = self.config.rope_base,
528
+ condense_ratio=self.config.condense_ratio,
529
+ )
530
+
531
+ def build_mask_cache(self, idx: torch.Tensor) -> torch.Tensor:
532
+ ones = torch.ones((self.config.block_size+self.stat_tokens, self.config.block_size+self.stat_tokens), device=idx.device, dtype=torch.bool)
533
+ return torch.tril(ones).unsqueeze(0).unsqueeze(0)
534
+
535
+ def build_kv_caches(self, idx: torch.Tensor, max_seq_length: int, rope_cache_length: int) -> List[KVCache]:
536
+ B = idx.size(0)
537
+ heads = 1 if self.config.n_query_groups == 1 else self.config.n_query_groups
538
+
539
+ k_cache_shape = (
540
+ B,
541
+ max_seq_length,
542
+ heads,
543
+ rope_cache_length + self.config.head_size - int(self.config.rotary_percentage * self.config.head_size),
544
+ )
545
+ v_cache_shape = (B, max_seq_length, heads, self.config.head_size)
546
+ device = idx.device
547
+ return [
548
+ (torch.zeros(k_cache_shape, device=device), torch.zeros(v_cache_shape, device=device))
549
+ for _ in range(self.config.n_layer)
550
+ ]
551
+
552
+
553
+ class Block(nn.Module):
554
+ def __init__(self, config:YingLongConfig) -> None:
555
+ super().__init__()
556
+ self.norm_1 = config.norm_class(config.n_embd, eps=config.norm_eps)
557
+ self.attn = BidirectedlSelfAttention(config)
558
+ if not config.shared_attention_norm:
559
+ self.norm_2 = config.norm_class(config.n_embd, eps=config.norm_eps)
560
+ self.mlp = config.mlp_class(config)
561
+ self.config = config
562
+ def forward(
563
+ self,
564
+ x: torch.Tensor,
565
+ rope: RoPECache,
566
+ max_seq_length: int,
567
+ mask: Optional[torch.Tensor] = None,
568
+ input_pos: Optional[torch.Tensor] = None,
569
+ kv_cache: Optional[KVCache] = None,
570
+ ) -> Tuple[torch.Tensor, Optional[KVCache]]:
571
+
572
+ n_1 = self.norm_1(x)
573
+ h, new_kv_cache = self.attn(n_1, rope, max_seq_length, mask, input_pos, kv_cache)
574
+ if self.config.parallel_residual:
575
+ n_2 = n_1 if self.config.shared_attention_norm else self.norm_2(x)
576
+ x = x + h + self.mlp(n_2)
577
+ else:
578
+ if self.config.shared_attention_norm:
579
+ raise NotImplementedError(
580
+ "No checkpoint amongst the ones we support uses this configuration"
581
+ " (non-parallel residual and shared attention norm)."
582
+ )
583
+
584
+ x = x + h
585
+ x = x + self.mlp(self.norm_2(x))
586
+ return x, new_kv_cache
587
+
588
+
589
+ class BidirectedlSelfAttention(nn.Module):
590
+ def __init__(self, config:YingLongConfig) -> None:
591
+ super().__init__()
592
+ shape = (config.n_head + 2 * config.n_query_groups) * config.head_size
593
+ # key, query, value projections for all heads, but in a batch
594
+ self.attn = nn.Linear(config.n_embd, shape, bias=config.bias)
595
+ # output projection
596
+ self.proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
597
+
598
+ self.config = config
599
+
600
+ def forward(
601
+ self,
602
+ x: torch.Tensor,
603
+ rope: RoPECache,
604
+ max_seq_length: int,
605
+ mask: Optional[torch.Tensor] = None,
606
+ input_pos: Optional[torch.Tensor] = None,
607
+ kv_cache: Optional[KVCache] = None,
608
+ ) -> Tuple[torch.Tensor, Optional[KVCache]]:
609
+
610
+
611
+ B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
612
+
613
+ qkv = self.attn(x)
614
+
615
+ # assemble into a number of query groups to support MHA, MQA and GQA together (see `config.n_query_groups`)
616
+ q_per_kv = self.config.n_head // self.config.n_query_groups
617
+ total_qkv = q_per_kv + 2 # each group has 1+ queries, 1 key, and 1 value
618
+ qkv = qkv.view(B, T, self.config.n_query_groups, total_qkv, self.config.head_size) # (B, T, n_query_groups, total_qkv, hs)
619
+ # qkv = qkv.permute(0, 2, 3, 1, 4) # (B, n_query_groups, total_qkv, T, hs)
620
+
621
+ # split batched computation into three
622
+ q, k, v = qkv.split((q_per_kv, 1, 1), dim=-2)
623
+
624
+ # repeat k and v if necessary
625
+ # Peiyuan: we do not need to do this, as flash attention 2 already supports GQA
626
+ # if self.config.n_query_groups != 1: # doing this would require a full kv cache with MQA (inefficient!)
627
+ # # for MHA this is a no-op
628
+ # k = k.expand(B, self.config.n_query_groups, q_per_kv, T, self.config.head_size)
629
+ # v = v.expand(B, self.config.n_query_groups, q_per_kv, T, self.config.head_size)
630
+
631
+ q = q.reshape(B, T, -1, self.config.head_size) # (B, T, nh_q, hs)
632
+ k = k.reshape(B, T, -1, self.config.head_size)
633
+ v = v.reshape(B, T, -1, self.config.head_size)
634
+
635
+ cos, sin = rope
636
+
637
+ # applying RoPE in fp32 significantly stabilizes training
638
+ # fused rope expect (batch_size, seqlen, nheads, headdim)
639
+ q = apply_rotary_emb_func(q, cos, sin, False, True)
640
+ k = apply_rotary_emb_func(k, cos, sin, False, True)
641
+
642
+ # n_elem = int(self.config.rotary_percentage * self.config.head_size)
643
+
644
+ # q_roped = apply_rope(q[..., :n_elem], cos.repeat(1,2), sin.repeat(1,2))
645
+ # k_roped = apply_rope(k[..., :n_elem], cos.repeat(1,2), sin.repeat(1,2))
646
+ # print( (q_roped - q).sum())
647
+ # q = torch.cat((q_roped, q[..., n_elem:]), dim=-1)
648
+ # k = torch.cat((k_roped, k[..., n_elem:]), dim=-1)
649
+
650
+ if kv_cache is not None:
651
+ cache_k, cache_v = kv_cache
652
+ cache_k, cache_v = cache_k.to(dtype=k.dtype), cache_v.to(dtype=v.dtype)
653
+ # check if reached token limit
654
+ if input_pos[-1] >= max_seq_length:
655
+ input_pos = torch.tensor(max_seq_length - 1, device=input_pos.device)
656
+ # shift 1 position to the left
657
+ cache_k = torch.roll(cache_k, -1, dims=1)
658
+ cache_v = torch.roll(cache_v, -1, dims=1)
659
+
660
+ k = cache_k.index_copy_(1, input_pos, k)
661
+ v = cache_v.index_copy_(1, input_pos, v)
662
+ kv_cache = k, v
663
+
664
+
665
+
666
+ y = self.scaled_dot_product_attention(q, k, v, mask=mask)
667
+
668
+ y = y.reshape(B, T, C) # re-assemble all head outputs side by side
669
+
670
+ # output projection
671
+ y = self.proj(y)
672
+
673
+ return y, kv_cache
674
+
675
+ def scaled_dot_product_attention(
676
+ self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: Optional[torch.Tensor] = None
677
+ ):
678
+ scale = 1.0 / math.sqrt(self.config.head_size)
679
+
680
+ if (
681
+ FlashAttention2Available
682
+ and mask is None
683
+ and q.device.type == "cuda"
684
+ and q.dtype in (torch.float16, torch.bfloat16)
685
+ ):
686
+ from flash_attn import flash_attn_func
687
+
688
+ return flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=scale, causal=False)
689
+ q = q.transpose(1, 2)
690
+ k = k.transpose(1, 2)
691
+ v = v.transpose(1, 2)
692
+ if q.size() != k.size():
693
+ k = k.repeat_interleave(q.shape[1]//k.shape[1], dim=1)
694
+ v = v.repeat_interleave(q.shape[1]//v.shape[1], dim=1)
695
+ y = torch.nn.functional.scaled_dot_product_attention(
696
+ q, k, v, attn_mask=mask, dropout_p=0.0, scale=scale, is_causal=False
697
+ )
698
+ return y.transpose(1, 2)
699
+
700
+
701
+ class GptNeoxMLP(nn.Module):
702
+ def __init__(self, config:YingLongConfig) -> None:
703
+ super().__init__()
704
+ self.fc = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias)
705
+ self.proj = nn.Linear(config.intermediate_size, config.n_embd, bias=config.bias)
706
+
707
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
708
+ x = self.fc(x)
709
+ x = torch.nn.functional.gelu(x)
710
+ return self.proj(x)
711
+
712
+
713
+ class LLaMAMLP(nn.Module):
714
+ def __init__(self, config:YingLongConfig) -> None:
715
+ super().__init__()
716
+ # self.fc_1 = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias)
717
+ # self.fc_2 = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias)
718
+ # self.proj = nn.Linear(config.intermediate_size, config.n_embd, bias=config.bias)
719
+ self.swiglu = SwiGLU(config.n_embd,config.intermediate_size, bias=False, _pack_weights=False)
720
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
721
+ # x_fc_1 = self.fc_1(x)
722
+ # x_fc_2 = self.fc_2(x)
723
+ # x = torch.nn.functional.silu(x_fc_1) * x_fc_2
724
+ # return self.proj(x)
725
+ return self.swiglu(x)
726
+
727
+
728
+ def build_rope_cache(
729
+ seq_len: int, n_elem: int, dtype: torch.dtype, device: torch.device, base: int = 10000, condense_ratio: int = 1
730
+ ) -> RoPECache:
731
+ """Enhanced Transformer with Rotary Position Embedding.
732
+
733
+ Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/
734
+ transformers/rope/__init__.py. MIT License:
735
+ https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license.
736
+ """
737
+ # $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
738
+ theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, device=device) / n_elem))
739
+
740
+ # Create position indexes `[0, 1, ..., seq_len - 1]`
741
+ seq_idx = torch.arange(seq_len, device=device) / condense_ratio
742
+
743
+ # Calculate the product of position index and $\theta_i$
744
+ idx_theta = torch.outer(seq_idx, theta)
745
+
746
+ cos, sin = torch.cos(idx_theta), torch.sin(idx_theta)
747
+
748
+ # print(' print(seq_idx.shape,theta.shape,sin.shape,cos.shape,idx_theta.shape)',seq_idx.shape,theta.shape,sin.shape,cos.shape,idx_theta.shape)
749
+ # added by peiyuan to ensure same data type with q, k, to use fused rotary embedding
750
+ if dtype == torch.bfloat16:
751
+ return cos.bfloat16(), sin.bfloat16()
752
+ # this is to mimic the behaviour of complex32, else we will get different results
753
+ if dtype in (torch.float16, torch.bfloat16, torch.int8):
754
+ return cos.half(), sin.half()
755
+ return cos, sin
756
+
757
+
758
+ def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
759
+ head_size = x.size(-1)
760
+ x1 = x[..., : head_size // 2] # (B, nh, T, hs/2)
761
+ x2 = x[..., head_size // 2 :] # (B, nh, T, hs/2)
762
+ rotated = torch.cat((-x2, x1), dim=-1) # (B, nh, T, hs)
763
+ roped = (x * cos) + (rotated * sin)
764
+ return roped.type_as(x)
765
+
766
+
767
+
768
+
769
+ import torch
770
+ # Copyright (c) 2022, Tri Dao.
771
+ # Adapted from https://github.com/NVIDIA/apex/blob/master/apex/contrib/layer_norm/layer_norm.py AND https://github.com/Dao-AILab/flash-attention/blob/7a983df74215e035e566e37125b0a71e3618f39d/flash_attn/ops/layer_norm.py#L16
772
+
773
+ import dropout_layer_norm
774
+ import torch
775
+ from torch.nn import init
776
+
777
+
778
+ def maybe_align(x, alignment_in_bytes=16):
779
+ """Assume that x already has last dim divisible by alignment_in_bytes"""
780
+ # TD [2023-07-04] I'm not 100% sure that clone will align the memory
781
+ # https://discuss.pytorch.org/t/how-to-ensure-that-tensor-data-ptr-is-aligned-to-16-bytes/183440
782
+ return x if x.data_ptr() % alignment_in_bytes == 0 else x.clone()
783
+
784
+
785
+ def _dropout_add_layer_norm_forward(
786
+ x0,
787
+ residual,
788
+ gamma,
789
+ beta,
790
+ rowscale,
791
+ colscale,
792
+ dropout_p,
793
+ epsilon,
794
+ residual_in_fp32=False,
795
+ is_rms_norm=False,
796
+ ):
797
+ """Assume that arguments are contiguous and aligned to 16 bytes"""
798
+ hidden_size = gamma.numel()
799
+ x0mat = x0.view((-1, hidden_size))
800
+ residualmat = residual.view((-1, hidden_size)) if residual is not None else None
801
+ rowscale = rowscale.view(-1) if rowscale is not None else None
802
+ zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd(
803
+ x0mat,
804
+ residualmat,
805
+ gamma,
806
+ beta,
807
+ rowscale,
808
+ colscale,
809
+ None,
810
+ None,
811
+ dropout_p,
812
+ epsilon,
813
+ 1.0,
814
+ 0,
815
+ None,
816
+ residual_in_fp32,
817
+ is_rms_norm,
818
+ )
819
+ # dmask is None if dropout_p == 0.0
820
+ # xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype
821
+ return zmat, xmat if xmat is not None else x0mat, dmask, mu, rsigma
822
+
823
+
824
+ def _dropout_add_layer_norm_backward(
825
+ dz,
826
+ dx,
827
+ x,
828
+ x0,
829
+ dmask,
830
+ mu,
831
+ rsigma,
832
+ gamma,
833
+ rowscale,
834
+ colscale,
835
+ dropout_p,
836
+ has_residual,
837
+ is_rms_norm=False,
838
+ ):
839
+ """Assume that arguments are contiguous and aligned to 16 bytes
840
+ dx == None means that it was a post-norm architecture
841
+ (x = drop(x0) + residual was not returned in the fwd).
842
+ x0 must not be None if we have colscale.
843
+ """
844
+ hidden_size = gamma.numel()
845
+ xmat = x.view((-1, hidden_size))
846
+ dzmat = dz.view(xmat.shape)
847
+ dxmat = dx.view(xmat.shape) if dx is not None else None
848
+ x0mat = x0.view((-1, hidden_size)) if x0 is not None else None
849
+ rowscale = rowscale.view(-1) if rowscale is not None else None
850
+ if colscale is not None:
851
+ assert x0 is not None, "x0 is required to compute the gradient of colscale"
852
+ dx0mat, dresidualmat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd(
853
+ dzmat,
854
+ dxmat,
855
+ xmat,
856
+ x0mat,
857
+ dmask,
858
+ mu,
859
+ rsigma,
860
+ gamma,
861
+ rowscale,
862
+ colscale,
863
+ None,
864
+ None,
865
+ dropout_p,
866
+ 1.0,
867
+ 0,
868
+ has_residual,
869
+ is_rms_norm,
870
+ )
871
+ # dresidualmat is None if not has_residual
872
+ if colscale is None:
873
+ return dx0mat, dresidualmat, dgamma, dbeta
874
+ else:
875
+ dcolscale = rest[0]
876
+ return dx0mat, dresidualmat, dgamma, dbeta, dcolscale
877
+
878
+
879
+ def _dropout_add_layer_norm_subset_forward(
880
+ x0,
881
+ residual,
882
+ gamma,
883
+ beta,
884
+ colscale,
885
+ x0_subset,
886
+ out_subset,
887
+ dropout_p,
888
+ epsilon,
889
+ rowscale_const,
890
+ out_numrows,
891
+ residual_in_fp32=False,
892
+ is_rms_norm=False,
893
+ ):
894
+ """Assume that arguments are contiguous and aligned to 16 bytes"""
895
+ hidden_size = gamma.numel()
896
+ x0mat = x0.view((-1, hidden_size))
897
+ residualmat = residual.view((-1, hidden_size)) if residual is not None else None
898
+ x0_subset = x0_subset.view(-1) if x0_subset is not None else None
899
+ out_subset = out_subset.view(-1) if out_subset is not None else None
900
+ zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd(
901
+ x0mat,
902
+ residualmat,
903
+ gamma,
904
+ beta,
905
+ None,
906
+ colscale,
907
+ x0_subset,
908
+ out_subset,
909
+ dropout_p,
910
+ epsilon,
911
+ rowscale_const,
912
+ out_numrows,
913
+ None,
914
+ residual_in_fp32,
915
+ is_rms_norm,
916
+ )
917
+ # dmask is None if dropout_p == 0.0
918
+ # xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype
919
+ return zmat, xmat if xmat is not None else x0mat, dmask, mu, rsigma
920
+
921
+
922
+ def _dropout_add_layer_norm_subset_backward(
923
+ dz,
924
+ dx,
925
+ x,
926
+ x0,
927
+ dmask,
928
+ mu,
929
+ rsigma,
930
+ gamma,
931
+ colscale,
932
+ x0_subset,
933
+ out_subset,
934
+ dropout_p,
935
+ rowscale_const,
936
+ x0_numrows,
937
+ has_residual,
938
+ is_rms_norm=False,
939
+ ):
940
+ """Assume that arguments are contiguous and aligned to 16 bytes
941
+ dx == None means that it was a post-norm architecture
942
+ (x = drop(x0) + residual was not returned in the fwd).
943
+ x0 must not be None if we have colscale.
944
+ """
945
+ hidden_size = gamma.numel()
946
+ xmat = x.view((-1, hidden_size))
947
+ dzmat = dz.view(-1, hidden_size)
948
+ dxmat = dx.view(xmat.shape) if dx is not None else None
949
+ x0mat = x0.view((-1, hidden_size)) if x0 is not None else None
950
+ x0_subset = x0_subset.view(-1) if x0_subset is not None else None
951
+ out_subset = out_subset.view(-1) if out_subset is not None else None
952
+ if colscale is not None:
953
+ assert x0 is not None, "x0 is required to compute the gradient of colscale"
954
+ dx0mat, dresidualmat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd(
955
+ dzmat,
956
+ dxmat,
957
+ xmat,
958
+ x0mat,
959
+ dmask,
960
+ mu,
961
+ rsigma,
962
+ gamma,
963
+ None,
964
+ colscale,
965
+ x0_subset,
966
+ out_subset,
967
+ dropout_p,
968
+ rowscale_const,
969
+ x0_numrows,
970
+ has_residual,
971
+ is_rms_norm,
972
+ )
973
+ # dresidualmat is None if not has_residual
974
+ if colscale is None:
975
+ return dx0mat, dresidualmat, dgamma, dbeta
976
+ else:
977
+ dcolscale = rest[0]
978
+ return dx0mat, dresidualmat, dgamma, dbeta, dcolscale
979
+
980
+
981
+ def _dropout_add_layer_norm_parallel_residual_forward(
982
+ x0,
983
+ x1,
984
+ residual,
985
+ gamma0,
986
+ beta0,
987
+ gamma1,
988
+ beta1,
989
+ dropout_p,
990
+ epsilon,
991
+ residual_in_fp32=False,
992
+ is_rms_norm=False,
993
+ ):
994
+ """Assume that arguments are contiguous and aligned to 16 bytes"""
995
+ hidden_size = gamma0.numel()
996
+ x0mat = x0.view((-1, hidden_size))
997
+ x1mat = x1.view((-1, hidden_size)) if x1 is not None else None
998
+ residualmat = residual.view((-1, hidden_size)) if residual is not None else None
999
+ (
1000
+ z0mat,
1001
+ z1mat,
1002
+ xmat,
1003
+ dmask0,
1004
+ dmask1,
1005
+ mu,
1006
+ rsigma,
1007
+ ) = dropout_layer_norm.dropout_add_ln_parallel_residual_fwd(
1008
+ x0mat,
1009
+ x1mat,
1010
+ residualmat,
1011
+ gamma0,
1012
+ beta0,
1013
+ gamma1,
1014
+ beta1,
1015
+ dropout_p,
1016
+ epsilon,
1017
+ None,
1018
+ residual_in_fp32,
1019
+ is_rms_norm,
1020
+ )
1021
+ # dmask0 and dmask1 are None if dropout_p == 0.0
1022
+ # xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype
1023
+ return z0mat, z1mat, xmat if xmat is not None else x0mat, dmask0, dmask1, mu, rsigma
1024
+
1025
+
1026
+ def _dropout_add_layer_norm_parallel_residual_backward(
1027
+ dz0,
1028
+ dz1,
1029
+ dx,
1030
+ x,
1031
+ dmask0,
1032
+ dmask1,
1033
+ mu,
1034
+ rsigma,
1035
+ gamma0,
1036
+ gamma1,
1037
+ dropout_p,
1038
+ has_x1,
1039
+ has_residual,
1040
+ is_rms_norm=False,
1041
+ ):
1042
+ """Assume that arguments are contiguous and aligned to 16 bytes
1043
+ dx == None means that it was a post-norm architecture
1044
+ (x = drop(x0) + residual was not returned in the fwd).
1045
+ """
1046
+ hidden_size = gamma0.numel()
1047
+ xmat = x.view((-1, hidden_size))
1048
+ dz0mat = dz0.view(xmat.shape)
1049
+ dz1mat = dz1.view(xmat.shape) if dz1 is not None else None
1050
+ dxmat = dx.view(xmat.shape) if dx is not None else None
1051
+ (
1052
+ dx0mat,
1053
+ dx1mat,
1054
+ dresidualmat,
1055
+ dgamma0,
1056
+ dbeta0,
1057
+ dgamma1,
1058
+ dbeta1,
1059
+ *rest,
1060
+ ) = dropout_layer_norm.dropout_add_ln_parallel_residual_bwd(
1061
+ dz0mat,
1062
+ dz1mat,
1063
+ dxmat,
1064
+ xmat,
1065
+ dmask0,
1066
+ dmask1,
1067
+ mu,
1068
+ rsigma,
1069
+ gamma0,
1070
+ gamma1,
1071
+ dropout_p,
1072
+ has_x1,
1073
+ has_residual,
1074
+ is_rms_norm,
1075
+ )
1076
+ # dresidualmat is None if not has_residual
1077
+ return dx0mat, dx1mat, dresidualmat, dgamma0, dbeta0, dgamma1, dbeta1
1078
+
1079
+
1080
+ class DropoutAddLayerNormFn(torch.autograd.Function):
1081
+ @staticmethod
1082
+ def forward(
1083
+ ctx,
1084
+ x0,
1085
+ residual,
1086
+ gamma,
1087
+ beta,
1088
+ rowscale,
1089
+ colscale,
1090
+ dropout_p,
1091
+ epsilon,
1092
+ residual_in_fp32=False,
1093
+ prenorm=False,
1094
+ is_rms_norm=False,
1095
+ return_dmask=False,
1096
+ ):
1097
+ x0 = maybe_align(x0.contiguous(), 16)
1098
+ residual = maybe_align(residual.contiguous(), 16) if residual is not None else None
1099
+ gamma = maybe_align(gamma.contiguous(), 16)
1100
+ beta = maybe_align(beta.contiguous(), 16) if beta is not None else None
1101
+ rowscale = maybe_align(rowscale.contiguous(), 16) if rowscale is not None else None
1102
+ colscale = maybe_align(colscale.contiguous(), 16) if colscale is not None else None
1103
+ zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_forward(
1104
+ x0,
1105
+ residual,
1106
+ gamma,
1107
+ beta,
1108
+ rowscale,
1109
+ colscale,
1110
+ dropout_p,
1111
+ epsilon,
1112
+ residual_in_fp32,
1113
+ is_rms_norm,
1114
+ )
1115
+ # Only need to save x0 if we need to compute gradient wrt colscale
1116
+ x0_saved = x0 if colscale is not None else None
1117
+ ctx.save_for_backward(
1118
+ xmat.view(x0.shape), x0_saved, dmask, gamma, mu, rsigma, rowscale, colscale
1119
+ )
1120
+ ctx.prenorm = prenorm
1121
+ ctx.dropout_p = dropout_p
1122
+ ctx.has_residual = residual is not None
1123
+ ctx.is_rms_norm = is_rms_norm
1124
+ ctx.has_beta = beta is not None
1125
+ if not return_dmask:
1126
+ return (
1127
+ zmat.view(x0.shape) if not prenorm else (zmat.view(x0.shape), xmat.view(x0.shape))
1128
+ )
1129
+ else:
1130
+ dmask = (
1131
+ dmask.view(x0.shape)
1132
+ if dropout_p > 0.0
1133
+ else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device)
1134
+ )
1135
+ ctx.mark_non_differentiable(dmask)
1136
+ return (
1137
+ (zmat.view(x0.shape), dmask)
1138
+ if not prenorm
1139
+ else (zmat.view(x0.shape), xmat.view(x0.shape), dmask)
1140
+ )
1141
+
1142
+ @staticmethod
1143
+ def backward(ctx, dz, *args):
1144
+ # assert dz.is_contiguous()
1145
+ dz = maybe_align(dz.contiguous(), 16) # this happens!
1146
+ dx = maybe_align(args[0].contiguous(), 16) if ctx.prenorm else None
1147
+ x, x0, dmask, gamma, mu, rsigma, rowscale, colscale = ctx.saved_tensors
1148
+ # x0 is None if colscale is None
1149
+ dropout_p = ctx.dropout_p
1150
+ has_residual = ctx.has_residual
1151
+ dx0mat, dresidualmat, dgamma, dbeta, *rest = _dropout_add_layer_norm_backward(
1152
+ dz,
1153
+ dx,
1154
+ x,
1155
+ x0,
1156
+ dmask,
1157
+ mu,
1158
+ rsigma,
1159
+ gamma,
1160
+ rowscale,
1161
+ colscale,
1162
+ dropout_p,
1163
+ has_residual,
1164
+ ctx.is_rms_norm,
1165
+ )
1166
+ dx0 = dx0mat.view(x.shape)
1167
+ dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None
1168
+ dcolscale = rest[0] if colscale is not None else None
1169
+ return (
1170
+ dx0,
1171
+ dresidual,
1172
+ dgamma,
1173
+ dbeta if ctx.has_beta else None,
1174
+ None,
1175
+ dcolscale,
1176
+ None,
1177
+ None,
1178
+ None,
1179
+ None,
1180
+ None,
1181
+ None,
1182
+ )
1183
+
1184
+
1185
+ class DropoutAddLayerNormSubsetFn(torch.autograd.Function):
1186
+ @staticmethod
1187
+ def forward(
1188
+ ctx,
1189
+ x0,
1190
+ residual,
1191
+ gamma,
1192
+ beta,
1193
+ colscale,
1194
+ x0_subset,
1195
+ out_subset,
1196
+ dropout_p,
1197
+ epsilon,
1198
+ rowscale_const,
1199
+ out_numrows,
1200
+ residual_in_fp32=False,
1201
+ prenorm=False,
1202
+ is_rms_norm=False,
1203
+ return_dmask=False,
1204
+ ):
1205
+ x0 = maybe_align(x0.contiguous(), 16)
1206
+ residual = maybe_align(residual.contiguous(), 16) if residual is not None else None
1207
+ gamma = maybe_align(gamma.contiguous(), 16)
1208
+ beta = maybe_align(beta.contiguous(), 16) if beta is not None else None
1209
+ colscale = maybe_align(colscale.contiguous(), 16) if colscale is not None else None
1210
+ zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_subset_forward(
1211
+ x0,
1212
+ residual,
1213
+ gamma,
1214
+ beta,
1215
+ colscale,
1216
+ x0_subset,
1217
+ out_subset,
1218
+ dropout_p,
1219
+ epsilon,
1220
+ rowscale_const,
1221
+ out_numrows,
1222
+ residual_in_fp32,
1223
+ is_rms_norm,
1224
+ )
1225
+ # Only need to save x0 if we need to compute gradient wrt colscale
1226
+ x0_saved = x0 if colscale is not None else None
1227
+ x_shape = (-1, *x0.shape[1:])
1228
+ ctx.save_for_backward(
1229
+ xmat.view(x_shape), x0_saved, dmask, gamma, mu, rsigma, colscale, x0_subset, out_subset
1230
+ )
1231
+ ctx.prenorm = prenorm
1232
+ ctx.dropout_p = dropout_p
1233
+ ctx.rowscale_const = rowscale_const
1234
+ ctx.x0_numrows = x0.shape[:-1].numel()
1235
+ ctx.has_residual = residual is not None
1236
+ ctx.is_rms_norm = is_rms_norm
1237
+ ctx.has_beta = beta is not None
1238
+ z_shape = (-1, *x0.shape[1:])
1239
+ if not return_dmask:
1240
+ return zmat.view(z_shape) if not prenorm else (zmat.view(z_shape), xmat.view(x0.shape))
1241
+ else:
1242
+ z = zmat.view(z_shape)
1243
+ dmask = (
1244
+ dmask.view(x0.shape)
1245
+ if dropout_p > 0.0
1246
+ else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device)
1247
+ )
1248
+ ctx.mark_non_differentiable(dmask)
1249
+ return (z, dmask) if not prenorm else (z, xmat.view(x_shape), dmask)
1250
+
1251
+ @staticmethod
1252
+ def backward(ctx, dz, *args):
1253
+ # assert dz.is_contiguous()
1254
+ dz = maybe_align(dz.contiguous(), 16) # this happens!
1255
+ dx = maybe_align(args[0].contiguous(), 16) if ctx.prenorm else None
1256
+ x, x0, dmask, gamma, mu, rsigma, colscale, x0_subset, out_subset = ctx.saved_tensors
1257
+ # x0 is None if colscale is None
1258
+ dropout_p = ctx.dropout_p
1259
+ has_residual = ctx.has_residual
1260
+ dx0mat, dresidualmat, dgamma, dbeta, *rest = _dropout_add_layer_norm_subset_backward(
1261
+ dz,
1262
+ dx,
1263
+ x,
1264
+ x0,
1265
+ dmask,
1266
+ mu,
1267
+ rsigma,
1268
+ gamma,
1269
+ colscale,
1270
+ x0_subset,
1271
+ out_subset,
1272
+ dropout_p,
1273
+ ctx.rowscale_const,
1274
+ ctx.x0_numrows,
1275
+ has_residual,
1276
+ ctx.is_rms_norm,
1277
+ )
1278
+ dx0 = dx0mat.view(-1, *x.shape[1:])
1279
+ dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None
1280
+ dcolscale = rest[0] if colscale is not None else None
1281
+ return (
1282
+ dx0,
1283
+ dresidual,
1284
+ dgamma,
1285
+ dbeta if ctx.has_beta else None,
1286
+ dcolscale,
1287
+ None,
1288
+ None,
1289
+ None,
1290
+ None,
1291
+ None,
1292
+ None,
1293
+ None,
1294
+ None,
1295
+ None,
1296
+ None,
1297
+ )
1298
+
1299
+
1300
+ class DropoutAddLayerNormParallelResidualFn(torch.autograd.Function):
1301
+ @staticmethod
1302
+ def forward(
1303
+ ctx,
1304
+ x0,
1305
+ x1,
1306
+ residual,
1307
+ gamma0,
1308
+ beta0,
1309
+ gamma1,
1310
+ beta1,
1311
+ dropout_p,
1312
+ epsilon,
1313
+ residual_in_fp32=False,
1314
+ prenorm=False,
1315
+ is_rms_norm=False,
1316
+ return_dmask=False,
1317
+ ):
1318
+ x0 = maybe_align(x0.contiguous(), 16)
1319
+ x1 = maybe_align(x1.contiguous(), 16) if x1 is not None else None
1320
+ residual = maybe_align(residual.contiguous(), 16) if residual is not None else None
1321
+ gamma0 = maybe_align(gamma0.contiguous(), 16)
1322
+ beta0 = maybe_align(beta0.contiguous(), 16) if beta0 is not None else None
1323
+ gamma1 = maybe_align(gamma1.contiguous(), 16) if gamma1 is not None else None
1324
+ beta1 = maybe_align(beta1.contiguous(), 16) if beta1 is not None else None
1325
+ (
1326
+ z0mat,
1327
+ z1mat,
1328
+ xmat,
1329
+ dmask0,
1330
+ dmask1,
1331
+ mu,
1332
+ rsigma,
1333
+ ) = _dropout_add_layer_norm_parallel_residual_forward(
1334
+ x0,
1335
+ x1,
1336
+ residual,
1337
+ gamma0,
1338
+ beta0,
1339
+ gamma1,
1340
+ beta1,
1341
+ dropout_p,
1342
+ epsilon,
1343
+ residual_in_fp32,
1344
+ is_rms_norm,
1345
+ )
1346
+ ctx.save_for_backward(xmat.view(x0.shape), dmask0, dmask1, gamma0, gamma1, mu, rsigma)
1347
+ ctx.prenorm = prenorm
1348
+ ctx.dropout_p = dropout_p
1349
+ ctx.has_x1 = x1 is not None
1350
+ ctx.has_residual = residual is not None
1351
+ ctx.is_rms_norm = is_rms_norm
1352
+ ctx.has_beta = beta0 is not None
1353
+ z = (z0mat.view(x0.shape), z1mat.view(x0.shape) if z1mat is not None else None)
1354
+ if not return_dmask:
1355
+ return z if not prenorm else (*z, xmat.view(x0.shape))
1356
+ else:
1357
+ dmask0 = (
1358
+ dmask0.view(x0.shape)
1359
+ if dropout_p > 0.0
1360
+ else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device)
1361
+ )
1362
+ dmask1 = (
1363
+ dmask1.view(x0.shape)
1364
+ if dropout_p > 0.0 and x1 is not None
1365
+ else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device)
1366
+ )
1367
+ ctx.mark_non_differentiable(dmask0)
1368
+ ctx.mark_non_differentiable(dmask1)
1369
+ return (
1370
+ (*z, dmask0, dmask1) if not prenorm else (*z, xmat.view(x0.shape), dmask0, dmask1)
1371
+ )
1372
+
1373
+ @staticmethod
1374
+ def backward(ctx, dz0, dz1, *args):
1375
+ dz0 = maybe_align(dz0.contiguous(), 16) # this happens!
1376
+ dz1 = maybe_align(dz1.contiguous(), 16) if dz1 is not None else None
1377
+ dx = maybe_align(args[0].contiguous(), 16) if ctx.prenorm else None
1378
+ x, dmask0, dmask1, gamma0, gamma1, mu, rsigma = ctx.saved_tensors
1379
+ dropout_p = ctx.dropout_p
1380
+ has_x1 = ctx.has_x1
1381
+ has_residual = ctx.has_residual
1382
+ (
1383
+ dx0mat,
1384
+ dx1mat,
1385
+ dresidualmat,
1386
+ dgamma0,
1387
+ dbeta0,
1388
+ dgamma1,
1389
+ dbeta1,
1390
+ ) = _dropout_add_layer_norm_parallel_residual_backward(
1391
+ dz0,
1392
+ dz1,
1393
+ dx,
1394
+ x,
1395
+ dmask0,
1396
+ dmask1,
1397
+ mu,
1398
+ rsigma,
1399
+ gamma0,
1400
+ gamma1,
1401
+ dropout_p,
1402
+ has_x1,
1403
+ has_residual,
1404
+ ctx.is_rms_norm,
1405
+ )
1406
+ dx0 = dx0mat.view(x.shape)
1407
+ dx1 = dx1mat.view(x.shape) if dx1mat is not None else None
1408
+ dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None
1409
+ return (
1410
+ dx0,
1411
+ dx1,
1412
+ dresidual,
1413
+ dgamma0,
1414
+ dbeta0 if ctx.has_beta else None,
1415
+ dgamma1,
1416
+ dbeta1 if ctx.has_beta else None,
1417
+ None,
1418
+ None,
1419
+ None,
1420
+ None,
1421
+ None,
1422
+ None,
1423
+ )
1424
+
1425
+
1426
+ def layer_norm(x, weight, bias, epsilon):
1427
+ return DropoutAddLayerNormFn.apply(x, None, weight, bias, None, None, 0.0, epsilon, False)
1428
+
1429
+
1430
+ def dropout_add_layer_norm(
1431
+ x0,
1432
+ residual,
1433
+ weight,
1434
+ bias,
1435
+ dropout_p,
1436
+ epsilon,
1437
+ rowscale=None,
1438
+ layerscale=None,
1439
+ prenorm=False,
1440
+ residual_in_fp32=False,
1441
+ return_dropout_mask=False,
1442
+ ):
1443
+ """residual_in_fp32 only has an effect if residual is None.
1444
+ Otherwise residual dtype is residual.dtype.
1445
+ """
1446
+ return DropoutAddLayerNormFn.apply(
1447
+ x0,
1448
+ residual,
1449
+ weight,
1450
+ bias,
1451
+ rowscale,
1452
+ layerscale,
1453
+ dropout_p,
1454
+ epsilon,
1455
+ residual_in_fp32,
1456
+ prenorm,
1457
+ False,
1458
+ return_dropout_mask,
1459
+ )
1460
+
1461
+
1462
+ def dropout_add_layer_norm_subset(
1463
+ x0,
1464
+ residual,
1465
+ weight,
1466
+ bias,
1467
+ dropout_p,
1468
+ epsilon,
1469
+ layerscale=None,
1470
+ x0_subset=None,
1471
+ out_subset=None,
1472
+ rowscale_const=1.0,
1473
+ out_numrows=0,
1474
+ prenorm=False,
1475
+ residual_in_fp32=False,
1476
+ return_dropout_mask=False,
1477
+ ):
1478
+ """residual_in_fp32 only has an effect if residual is None.
1479
+ Otherwise residual dtype is residual.dtype.
1480
+ """
1481
+ return DropoutAddLayerNormSubsetFn.apply(
1482
+ x0,
1483
+ residual,
1484
+ weight,
1485
+ bias,
1486
+ layerscale,
1487
+ x0_subset,
1488
+ out_subset,
1489
+ dropout_p,
1490
+ epsilon,
1491
+ rowscale_const,
1492
+ out_numrows,
1493
+ residual_in_fp32,
1494
+ prenorm,
1495
+ False,
1496
+ return_dropout_mask,
1497
+ )
1498
+
1499
+
1500
+ def dropout_add_layer_norm_parallel_residual(
1501
+ x0,
1502
+ x1,
1503
+ residual,
1504
+ weight0,
1505
+ bias0,
1506
+ weight1,
1507
+ bias1,
1508
+ dropout_p,
1509
+ epsilon,
1510
+ prenorm=False,
1511
+ residual_in_fp32=False,
1512
+ return_dropout_mask=False,
1513
+ ):
1514
+ """residual_in_fp32 only has an effect if residual is None.
1515
+ Otherwise residual dtype is residual.dtype.
1516
+ """
1517
+ return DropoutAddLayerNormParallelResidualFn.apply(
1518
+ x0,
1519
+ x1,
1520
+ residual,
1521
+ weight0,
1522
+ bias0,
1523
+ weight1,
1524
+ bias1,
1525
+ dropout_p,
1526
+ epsilon,
1527
+ residual_in_fp32,
1528
+ prenorm,
1529
+ False,
1530
+ return_dropout_mask,
1531
+ )
1532
+
1533
+
1534
+ class DropoutAddLayerNorm(torch.nn.Module):
1535
+ def __init__(
1536
+ self,
1537
+ hidden_size,
1538
+ prenorm=False,
1539
+ p=0.0,
1540
+ eps=1e-5,
1541
+ residual_in_fp32=False,
1542
+ device=None,
1543
+ dtype=None,
1544
+ ):
1545
+ factory_kwargs = {"device": device, "dtype": dtype}
1546
+ super().__init__()
1547
+ self.prenorm = prenorm
1548
+ self.p = p
1549
+ self.eps = eps
1550
+ self.residual_in_fp32 = residual_in_fp32
1551
+ self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
1552
+ self.bias = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
1553
+ self.reset_parameters()
1554
+
1555
+ def reset_parameters(self):
1556
+ init.ones_(self.weight)
1557
+ init.zeros_(self.bias)
1558
+
1559
+ def forward(self, x0, residual=None):
1560
+ return dropout_add_layer_norm(
1561
+ x0,
1562
+ residual,
1563
+ self.weight,
1564
+ self.bias,
1565
+ self.p if self.training else 0.0,
1566
+ self.eps,
1567
+ prenorm=self.prenorm,
1568
+ residual_in_fp32=self.residual_in_fp32,
1569
+ )
1570
+
1571
+ def rms_norm(x, weight, epsilon):
1572
+ return DropoutAddLayerNormFn.apply(
1573
+ x, None, weight, None, None, None, 0.0, epsilon, False, False, True
1574
+ )
1575
+ class FusedRMSNorm(torch.nn.Module):
1576
+ def __init__(self, size: int, dim: int = -1, eps: float = 1e-5):
1577
+ super().__init__()
1578
+ self.eps = eps
1579
+ self.weight = torch.nn.Parameter(torch.ones(size))
1580
+ self.dim = dim
1581
+ self.reset_parameters()
1582
+
1583
+ def reset_parameters(self):
1584
+ init.ones_(self.weight)
1585
+
1586
+ def forward(self, x):
1587
+ return rms_norm(x, self.weight, self.eps)
1588
+
1589
+
1590
+ class RMSNorm(torch.nn.Module):
1591
+ """Root Mean Square Layer Normalization.
1592
+
1593
+ Derived from https://github.com/bzhangGo/rmsnorm/blob/master/rmsnorm_torch.py. BSD 3-Clause License:
1594
+ https://github.com/bzhangGo/rmsnorm/blob/master/LICENSE.
1595
+ """
1596
+
1597
+ def __init__(self, size: int, dim: int = -1, eps: float = 1e-5) -> None:
1598
+ super().__init__()
1599
+ self.weight = torch.nn.Parameter(torch.ones(size))
1600
+ self.eps = eps
1601
+ self.dim = dim
1602
+
1603
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
1604
+ # NOTE: the original RMSNorm paper implementation is not equivalent
1605
+ norm_x = torch.mean(x * x, dim=self.dim, keepdim=True)
1606
+ x_normed = x * torch.rsqrt(norm_x + self.eps)
1607
+ return self.weight * x_normed
1608
+
1609
+ def reset_parameters(self):
1610
+ torch.nn.init.ones_(self.weight)
1611
+
1612
+
1613
+
1614
+
1615
+
1616
+
1617
+ # Copyright (c) 2023, Tri Dao.
1618
+
1619
+ import math
1620
+ from typing import Optional, Tuple
1621
+
1622
+ import rotary_emb
1623
+ import torch
1624
+ from einops import rearrange, repeat
1625
+
1626
+ class ApplyRotaryEmb(torch.autograd.Function):
1627
+ @staticmethod
1628
+ def forward(ctx, x, cos, sin, interleaved=False, inplace=False,future_token = 0):
1629
+ """
1630
+ x: (batch_size, seqlen, nheads, headdim)
1631
+ cos, sin: (seqlen, rotary_dim / 2)
1632
+ interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
1633
+ of 1st half and 2nd half (GPT-NeoX style).
1634
+ rotary_dim must be <= headdim
1635
+ Apply rotary embedding to the first rotary_dim of x.
1636
+ """
1637
+ batch, seqlen, nheads, headdim = x.shape
1638
+ rotary_seqlen, rotary_dim = cos.shape
1639
+ rotary_dim *= 2
1640
+
1641
+
1642
+ # print('debug', x.shape, cos.shape)
1643
+ # debug output: torch.Size([224, 96, 12, 64]) torch.Size([1, 32])
1644
+ # debug output: 2049 2048
1645
+ assert rotary_dim <= headdim
1646
+ # print(seqlen,rotary_seqlen)
1647
+ assert seqlen <= rotary_seqlen
1648
+ assert sin.shape == (rotary_seqlen, rotary_dim // 2)
1649
+ x_ro = x[..., :rotary_dim]
1650
+ x1, x2 = x_ro.chunk(2, dim=-1) if not interleaved else (x_ro[..., ::2], x_ro[..., 1::2])
1651
+ out = torch.empty_like(x) if not inplace else x
1652
+ out_ro = out[..., :rotary_dim]
1653
+ if inplace:
1654
+ o1, o2 = x1, x2
1655
+ else:
1656
+ o1, o2 = (
1657
+ out_ro.chunk(2, dim=-1)
1658
+ if not interleaved
1659
+ else (out_ro[..., ::2], out_ro[..., 1::2])
1660
+ )
1661
+ rotary_emb.apply_rotary(
1662
+ x1,
1663
+ x2,
1664
+ rearrange(cos[:seqlen], "s d -> s 1 d"),
1665
+ rearrange(sin[:seqlen], "s d -> s 1 d"),
1666
+ o1,
1667
+ o2,
1668
+ False,
1669
+ )
1670
+ if not inplace and rotary_dim < headdim:
1671
+ out[..., rotary_dim:].copy_(x[..., rotary_dim:])
1672
+ ctx.save_for_backward(cos, sin)
1673
+ ctx.interleaved = interleaved
1674
+ ctx.inplace = inplace
1675
+ return out if not inplace else x
1676
+
1677
+ @staticmethod
1678
+ def backward(ctx, do):
1679
+ cos, sin = ctx.saved_tensors
1680
+ _, seqlen, _, headdim = do.shape
1681
+ rotary_dim = cos.shape[-1]
1682
+ rotary_dim *= 2
1683
+ inplace = ctx.inplace
1684
+ do_ro = do[..., :rotary_dim]
1685
+ do1, do2 = (
1686
+ do_ro.chunk(2, dim=-1) if not ctx.interleaved else (do_ro[..., ::2], do_ro[..., 1::2])
1687
+ )
1688
+ dx = torch.empty_like(do) if not inplace else do
1689
+ if inplace:
1690
+ dx1, dx2 = do1, do2
1691
+ else:
1692
+ dx_ro = dx[..., :rotary_dim]
1693
+ dx1, dx2 = (
1694
+ dx_ro.chunk(2, dim=-1)
1695
+ if not ctx.interleaved
1696
+ else (dx_ro[..., ::2], dx_ro[..., 1::2])
1697
+ )
1698
+ rotary_emb.apply_rotary(
1699
+ do1,
1700
+ do2,
1701
+ rearrange(cos[:seqlen], "s d -> s 1 d"),
1702
+ rearrange(sin[:seqlen], "s d -> s 1 d"),
1703
+ dx1,
1704
+ dx2,
1705
+ True,
1706
+ )
1707
+ if not inplace and rotary_dim < headdim:
1708
+ dx[..., rotary_dim:].copy_(do[..., rotary_dim:])
1709
+ return dx, None, None, None, None
1710
+
1711
+
1712
+ apply_rotary_emb_func = ApplyRotaryEmb.apply
1713
+
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:da44e17cefff6bb6b59af0cb6164a51e7eeda2dd625925cb11743e74eae8e812
+ size 72538452
model_config.py ADDED
@@ -0,0 +1,100 @@
+ from typing import List
+ from transformers import PretrainedConfig
+
+
+ class YingLongConfig(PretrainedConfig):
+     model_type = "yinglong"
+     # keys_to_ignore_at_inference = ["past_key_values"]
+
+     def __init__(
+         self,
+         # input_token_len: int = 1,
+         # hidden_size: int = 1024,
+         # intermediate_size: int = 2048,
+         # output_token_lens: List[int] = [1, 8, 32, 64],
+         # num_hidden_layers: int = 8,
+         # num_attention_heads: int = 8,
+         # hidden_act: str = "silu",
+         # use_cache: bool = True,
+         # rope_theta: int = 10000,
+         # attention_dropout: float = 0.0,
+         # initializer_range: float = 0.02,
+         # max_position_embeddings: int = 10000,
+         #####
+         bias=False,
+         condense_ratio=1,
+         haar_trans=True,
+         haar_trans_inv=True,
+         haar_trans_norm='backward',
+         half_diff=False,
+         intermediate_size=1024,
+         n_embd=256,
+         n_head=16,
+         n_layer=6,
+         n_query_groups=4,
+         norm_eps=1e-5,
+         org='Alibaba',
+         patch_size=32,
+         rope_base=10000,
+         rotary_percentage=1.0,
+         shared_attention_norm=False,
+         unet=True,
+         _mlp_class="LLaMAMLP",
+         _norm_class="FusedRMSNorm",
+         *args,
+         **kwargs,
+     ):
+
+         # self.input_token_len = input_token_len
+         # self.hidden_size = hidden_size
+         # self.intermediate_size = intermediate_size
+         # self.num_hidden_layers = num_hidden_layers
+         # self.num_attention_heads = num_attention_heads
+         # self.hidden_act = hidden_act
+         # self.output_token_lens = output_token_lens;
+         # self.use_cache = use_cache
+         # self.rope_theta = rope_theta
+         # self.attention_dropout = attention_dropout
+         # self.initializer_range = initializer_range
+         # self.max_position_embeddings = max_position_embeddings
+         self.org = 'Alibaba'
+         self.patch_size = patch_size
+         self.unet = unet
+
+         self.n_embd = n_embd
+         self.intermediate_size = intermediate_size
+         self.n_head = n_head
+         self.n_layer = n_layer
+         self.n_query_groups = n_query_groups
+         self.norm_eps = norm_eps
+         self.bias = bias
+         self.shared_attention_norm = shared_attention_norm
+
+         self.condense_ratio = condense_ratio
+         self.rope_base = rope_base
+         self.rotary_percentage = rotary_percentage
+
+         self.haar_trans = haar_trans
+         self.haar_trans_inv = haar_trans_inv
+         self.haar_trans_norm = haar_trans_norm
+         self.half_diff = half_diff
+
+         self._norm_class = _norm_class
+         self._mlp_class = _mlp_class
+
+         assert self.n_embd % self.n_head == 0
+         assert self.n_head % self.n_query_groups == 0
+
+         self.head_size = self.n_embd // self.n_head
+         self.rope_n_elem = int(self.rotary_percentage * self.head_size)
+         self.rope_condense_ratio = self.condense_ratio
+
+         super().__init__(
+             **kwargs,
+         )
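
A minimal sketch of constructing the config directly and reading the fields derived at the end of __init__ (the printed values follow from the keyword arguments shown here, not from the shipped config.json):

    cfg = YingLongConfig(n_embd=512, n_head=16, n_layer=8, n_query_groups=4, patch_size=32)
    print(cfg.head_size)     # 512 // 16 = 32
    print(cfg.rope_n_elem)   # int(1.0 * 32) = 32
    print(cfg.model_type)    # "yinglong"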