davda54 committed (verified)
Commit 971057e · 1 Parent(s): 6c21030

Upload folder using huggingface_hub

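A commit like this is typically produced with the `upload_folder` helper from `huggingface_hub`. A minimal sketch; the local path and repo id below are placeholders, not taken from this commit:

from huggingface_hub import HfApi

api = HfApi()
api.upload_folder(
    folder_path="./gptbert-checkpoint",           # local folder with config.json, modeling_gptbert.py, weights, tokenizer files
    repo_id="your-username/your-gptbert-model",   # hypothetical target repository
    repo_type="model",
    commit_message="Upload folder using huggingface_hub",
)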
__init__.py ADDED
File without changes
config.json ADDED
@@ -0,0 +1,49 @@
1
+ {
2
+ "architectures": [
3
+ "GptBertFoCausalLM"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "configuration_gptbert.GptBertConfig",
7
+ "AutoModel": "modeling_gptbert.GptBertModel",
8
+ "AutoModelForCausalLM": "modeling_gptbert.GptBertForCausalLM",
9
+ "AutoModelForMaskedLM": "modeling_gptbert.GptBertForMaskedLM",
10
+ "AutoModelForSequenceClassification": "modeling_gptbert.GptBertForSequenceClassification",
11
+ "AutoModelForTokenClassification": "modeling_gptbert.GptBertForTokenClassification",
12
+ "AutoModelForQuestionAnswering": "modeling_gptbert.GptBertForQuestionAnswering",
13
+ "AutoModelForMultipleChoice": "modeling_gptbert.GptBertForMultipleChoice"
14
+ },
15
+ "attention_dropout": 0.0,
16
+ "attention_output_dropout_p": 0.0,
17
+ "attention_inter_norm_affine": false,
18
+ "attention_inter_norm_eps": 1e-07,
19
+ "attention_pre_norm_affine": false,
20
+ "attention_pre_norm_eps": 1e-07,
21
+ "attention_probabilities_dropout_p": 0.0,
22
+ "classifier_post_norm_affine": false,
23
+ "classifier_post_norm_eps": 1e-07,
24
+ "classifier_pre_norm_affine": false,
25
+ "classifier_pre_norm_eps": 1e-07,
26
+ "d_qk": 64,
27
+ "d_v": 64,
28
+ "embedding_dropout_p": 0.1,
29
+ "feed_forward_dropout_p": 0.0,
30
+ "feed_forward_inter_norm_affine": false,
31
+ "feed_forward_inter_norm_eps": 1e-07,
32
+ "feed_forward_pre_norm_affine": false,
33
+ "feed_forward_pre_norm_eps": 1e-07,
34
+ "hidden_size": 640,
35
+ "intermediate_size": 1664,
36
+ "max_sequence_length": 16384,
37
+ "num_attention_heads": 10,
38
+ "num_kv_heads": 10,
39
+ "num_layers": 24,
40
+ "rope_theta": 160000,
41
+ "vocab_size": 51200,
42
+ "word_norm_affine": true,
43
+ "word_norm_eps": 1e-07,
44
+ "short_long_ratio": 4,
45
+ "window_length": 8192,
46
+ "is_decoder": false,
47
+ "not_flex": true,
48
+ "hidden_dropout_prob": 0.2
49
+ }
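The `auto_map` above registers this repository's custom classes with the `transformers` Auto factories, so the checkpoint is loaded with `trust_remote_code=True`. A hedged usage sketch; the repo id is a placeholder:

from transformers import AutoTokenizer, AutoModelForMaskedLM

repo_id = "your-username/your-gptbert-model"   # hypothetical
tokenizer = AutoTokenizer.from_pretrained(repo_id)
model = AutoModelForMaskedLM.from_pretrained(repo_id, trust_remote_code=True)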
configuration_gptbert.py ADDED
@@ -0,0 +1,113 @@
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ from pathlib import Path
5
+ import copy
6
+ from transformers.configuration_utils import PretrainedConfig
7
+
8
+
9
+ class GptBertConfig(PretrainedConfig):
10
+
11
+ def __init__(
12
+ self,
13
+ config_file: Path | str | None = None,
14
+ **kwargs
15
+ ):
16
+ super().__init__(**kwargs)
17
+
18
+ self.model: str
19
+
20
+ # General information
21
+ self.model = "base"
22
+
23
+ # Vocabulary
24
+ self.vocab_size = 16384
25
+ self.max_sequence_length = 512
26
+
27
+ # Model dimensions
28
+ self.hidden_size = 768
29
+ self.intermediate_size = 2048
30
+ self.num_attention_heads = 12
31
+ self.num_layers = 12
32
+ self.d_qk = 64
33
+
34
+ # Dropout probabilities
35
+ self.embedding_dropout_p = 0.1
36
+ self.attention_probabilities_dropout_p = 0.1
37
+ self.attention_output_dropout_p = 0.1
38
+ self.feed_forward_dropout_p = 0.1
39
+ self.attention_dropout = 0.1
40
+ self.hidden_dropout_prob = 0.2
41
+
42
+ # Position embedding
43
+ self.rope_theta = 160_000
44
+
45
+ # Norms
46
+ self.word_norm_eps = 1e-7
47
+ self.word_norm_affine = False
48
+
49
+ self.attention_pre_norm_eps = 1e-7
50
+ self.attention_pre_norm_affine = False
51
+
52
+ self.attention_inter_norm_eps = 1e-7
53
+ self.attention_inter_norm_affine = True
54
+
55
+ self.feed_forward_pre_norm_eps = 1e-7
56
+ self.feed_forward_pre_norm_affine = False
57
+
58
+ self.feed_forward_inter_norm_eps = 1e-7
59
+ self.feed_forward_inter_norm_affine = False
60
+
61
+ self.classifier_pre_norm_eps = 1e-7
62
+ self.classifier_pre_norm_affine = False
63
+
64
+ self.classifier_post_norm_eps = 1e-7
65
+ self.classifier_post_norm_affine = False
66
+
67
+ if config_file is not None:
68
+ if type(config_file) is str:
69
+ config_file = Path(config_file)
70
+ assert isinstance(config_file, Path), "The config_file should either be a Path or a str"
71
+ with config_file.open("r") as file:
72
+ config = json.load(file)
73
+
74
+ for attr, value in config.items():
75
+ if isinstance(value, str):
76
+ value = value.lower()
77
+ setattr(self, attr, value)
78
+
79
+ for attr, value in kwargs.items():
80
+ if isinstance(value, str):
81
+ value = value.lower()
82
+ setattr(self, attr, value)
83
+
84
+ def __repr__(self) -> str:
85
+ return str(self.to_json_string())
86
+
87
+ def to_dict(self) -> dict:
88
+ """Serializes this instance to a Python dictionary."""
89
+ output: dict
90
+
91
+ output = copy.deepcopy(self.__dict__)
92
+ return output
93
+
94
+ def to_json_string(self) -> str:
95
+ """Serializes this instance to a JSON string."""
96
+ return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
97
+
98
+ def to_json_file(self, json_file_path: Path | str) -> None:
99
+ """Save this instance to a json file."""
100
+ if isinstance(json_file_path, str):
101
+ json_file_path: Path = Path(json_file_path)
102
+ with json_file_path.open("w", encoding='utf-8') as writer:
103
+ writer.write(self.to_json_string())
104
+
105
+ @classmethod
106
+ def create_base_config(cls, json_file_path: Path | str | None = None) -> GptBertConfig:
107
+ config: GptBertConfig
108
+
109
+ config = GptBertConfig()
110
+ if json_file_path is not None:
111
+ config.to_json_file(json_file_path)
112
+
113
+ return config
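A quick sketch of how this config class is meant to be used, relying only on the methods defined above: defaults are set in `__init__`, then values from `config_file` and finally any extra keyword arguments override them.

config = GptBertConfig(hidden_size=640, num_layers=24, vocab_size=51200)
config.to_json_file("gptbert_config.json")                    # serialize the current settings
reloaded = GptBertConfig(config_file="gptbert_config.json")   # file values override the built-in defaults
assert reloaded.hidden_size == 640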
modeling_gptbert.py ADDED
@@ -0,0 +1,1216 @@
1
+ from __future__ import annotations
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from torch.nn import functional as F
6
+ from torch import _softmax_backward_data as _softmax_backward_data
7
+
8
+ from functools import partial
9
+
10
+ from .configuration_gptbert import GptBertConfig
11
+ from transformers.modeling_utils import PreTrainedModel
12
+ from transformers.activations import gelu_new
13
+ from transformers.modeling_outputs import (
14
+ MaskedLMOutput,
15
+ MultipleChoiceModelOutput,
16
+ QuestionAnsweringModelOutput,
17
+ SequenceClassifierOutput,
18
+ TokenClassifierOutput,
19
+ BaseModelOutput,
20
+ CausalLMOutput
21
+ )
22
+ import math
23
+ from typing import TYPE_CHECKING, Optional, Union, Tuple, List
24
+
25
+ try:
26
+ from torch.nn.attention.flex_attention import flex_attention, create_block_mask
27
+ except ImportError:
28
+ pass
29
+
30
+
31
+ class ModelOutput:
32
+
33
+ def __init__(
34
+ self,
35
+ logits: torch.Tensor | None = None,
36
+ loss: torch.Tensor | float | None = None,
37
+ perplexity: torch.Tensor | float | None = None,
38
+ accuracy: float | None = None,
39
+ z_loss: torch.Tensor | float | None = None,
40
+ **kwargs
41
+ ):
42
+ self.logits: torch.Tensor | None
43
+ self.loss: torch.Tensor | float | None
44
+ self.perplexity: torch.Tensor | float | None
45
+ self.accuracy: float | None
46
+ self.z_loss: torch.Tensor | float | None
47
+
48
+ self.logits = logits
49
+ self.loss = loss
50
+ self.perplexity = perplexity
51
+ self.accuracy = accuracy
52
+ self.z_loss = z_loss
53
+
54
+ for attr, value in kwargs.items():
55
+ setattr(self, attr, value)
56
+
57
+
58
+ class CastedLinear(nn.Linear):
59
+
60
+ def __init__(self, in_features, out_features, bias):
61
+ super().__init__(in_features, out_features, bias=bias)
62
+
63
+ def reset_parameters(self) -> None:
64
+ std: float = math.sqrt(2.0 / (self.in_features + self.out_features))
65
+ nn.init.trunc_normal_(self.weight, mean=0.0, std=std, a=-2*std, b=2*std)
66
+
67
+ def forward(self, x):
68
+ return F.linear(x, self.weight.type_as(x), bias=self.bias.type_as(x) if self.bias is not None else None)
69
+
70
+
71
+ class CastedLinearIn(nn.Linear):
72
+
73
+ def __init__(self, in_features, out_features, bias):
74
+ super().__init__(in_features, out_features, bias=bias)
75
+ self.scale = nn.Parameter(torch.ones(in_features))
76
+
77
+ def reset_parameters(self) -> None:
78
+ std: float = math.sqrt(2.0 / (self.in_features + self.out_features))
79
+ nn.init.trunc_normal_(self.weight, mean=0.0, std=std, a=-2*std, b=2*std)
80
+
81
+ def forward(self, x):
82
+ return F.linear(x, (self.weight * (self.scale + 1.0).unsqueeze(0)).type_as(x), bias=self.bias.type_as(x) if self.bias is not None else None)
83
+
84
+
85
+ class CastedLinearOut(nn.Linear):
86
+
87
+ def __init__(self, in_features, out_features, bias):
88
+ super().__init__(in_features, out_features, bias=bias)
89
+ self.scale = nn.Parameter(torch.ones(out_features))
90
+
91
+ def reset_parameters(self) -> None:
92
+ std: float = math.sqrt(2.0 / (self.in_features + self.out_features))
93
+ nn.init.trunc_normal_(self.weight, mean=0.0, std=std, a=-2*std, b=2*std)
94
+
95
+ def forward(self, x):
96
+ return F.linear(x, (self.scale.unsqueeze(1) * self.weight).type_as(x), bias=self.bias.type_as(x) if self.bias is not None else None)
97
+
98
+
99
+ class MultiCastedLinearOrtho(nn.Module):
100
+
101
+ def __init__(self, in_features, out_features, bias):
102
+ super().__init__()
103
+ self.in_features = in_features
104
+ self.out_features = out_features
105
+
106
+ self.weights = nn.ParameterList()
107
+ for out_feature in out_features:
108
+ self.weights.append(nn.Parameter(torch.empty((out_feature, in_features))))
109
+
110
+ if bias:
111
+ self.bias = nn.Parameter(torch.zeros(sum(out_features)))
112
+ else:
113
+ self.bias = self.register_parameter("bias", None)
114
+
115
+ self.reset_parameters()
116
+
117
+ def reset_parameters(self) -> None:
118
+ for i, weight in enumerate(self.weights):
119
+ std: float = math.sqrt(2.0 / (self.in_features + self.out_features[i]))
120
+ nn.init.trunc_normal_(weight, mean=0.0, std=std, a=-2*std, b=2*std)
121
+
122
+ def forward(self, x):
123
+ return F.linear(x, torch.cat([weight for weight in self.weights], dim=0).type_as(x), bias=self.bias.type_as(x) if self.bias is not None else None)
124
+
125
+
126
+ class MultiCastedLinearOrthoIn(nn.Module):
127
+
128
+ def __init__(self, in_features, out_features, bias):
129
+ super().__init__()
130
+ self.in_features = in_features
131
+ self.out_features = out_features
132
+
133
+ self.weights = nn.ParameterList()
134
+ for out_feature in out_features:
135
+ self.weights.append(nn.Parameter(torch.empty((out_feature, in_features))))
136
+
137
+ if bias:
138
+ self.bias = nn.Parameter(torch.zeros(sum(out_features)))
139
+ else:
140
+ self.bias = self.register_parameter("bias", None)
141
+
142
+ self.scale = nn.Parameter(torch.ones(in_features))
143
+
144
+ self.reset_parameters()
145
+
146
+ def reset_parameters(self) -> None:
147
+ for weight in self.weights:
148
+ std = 0.5 * (self.in_features ** -0.5)
149
+ bound = (3 ** 0.5) * std
150
+ with torch.no_grad():
151
+ weight.uniform_(-bound, bound)
152
+
153
+ def forward(self, x):
154
+ return F.linear(x, (torch.cat([weight for weight in self.weights], dim=0) * (self.scale + 1.0).unsqueeze(0)).type_as(x), bias=self.bias.type_as(x) if self.bias is not None else None)
155
+
156
+
157
+ class MultiCastedLinearOrthoOut(nn.Module):
158
+
159
+ def __init__(self, in_features, out_features, bias):
160
+ super().__init__()
161
+ self.in_features = in_features
162
+ self.out_features = out_features
163
+
164
+ self.weights = nn.ParameterList()
165
+ for out_feature in out_features:
166
+ self.weights.append(nn.Parameter(torch.empty((out_feature, in_features))))
167
+
168
+ if bias:
169
+ self.bias = nn.Parameter(torch.zeros(sum(out_features)))
170
+ else:
171
+ self.bias = self.register_parameter("bias", None)
172
+
173
+ self.scale = nn.Parameter(torch.ones(sum(out_features)))
174
+
175
+ self.reset_parameters()
176
+
177
+ def reset_parameters(self) -> None:
178
+ for weight in self.weights:
179
+ std = 0.5 * (self.in_features ** -0.5)
180
+ bound = (3 ** 0.5) * std
181
+ with torch.no_grad():
182
+ weight.uniform_(-bound, bound)
183
+
184
+ def forward(self, x):
185
+ return F.linear(x, (self.scale.unsqueeze(1) * torch.cat([weight for weight in self.weights], dim=0)).type_as(x), bias=self.bias.type_as(x) if self.bias is not None else None)
186
+
187
+
188
+ class GeGLU(nn.Module):
189
+ def forward(self, x):
190
+ x, gate = x.chunk(2, dim=-1)
191
+ x = x * gelu_new(gate)
192
+ return x
193
+
194
+
195
+ class MaskedSoftmax(torch.autograd.Function):
196
+ @staticmethod
197
+ def forward(ctx, x: torch.Tensor, mask: torch.BoolTensor, dim: int) -> torch.Tensor:
198
+ ctx.dim: int
199
+
200
+ ctx.dim = dim
201
+ x.masked_fill_(mask, float('-inf'))
202
+ x = torch.softmax(x, ctx.dim)
203
+ x.masked_fill_(mask, 0.0)
204
+ ctx.save_for_backward(x)
205
+ return x
206
+
207
+ @staticmethod
208
+ def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor, None, None]:
209
+ output: torch.Tensor
210
+
211
+ output, = ctx.saved_tensors
212
+ inputGrad: torch.Tensor = _softmax_backward_data(grad_output, output, ctx.dim, output.dtype)
213
+ return inputGrad, None, None
214
+
215
+
216
+ class Encoder(nn.Module):
217
+
218
+ def __init__(self, config) -> None:
219
+ super().__init__()
220
+
221
+ self.layers: nn.ModuleList[Layer]
222
+
223
+ self.layers = nn.ModuleList([Layer(config, i) for i in range(config.num_layers)])
224
+
225
+ for i, layer in enumerate(self.layers):
226
+ for weight in layer.mlp.up_proj.weights:
227
+ weight.data *= math.sqrt(1.0 / (2.0 * (i + 1)))
228
+ layer.mlp.down_proj.weight.data *= math.sqrt(1.0 / (2.0 * (i + 1)))
229
+
230
+ self.short_long_ratio = config.short_long_ratio
231
+
232
+ def set_window_length(self, config) -> None:
233
+ for i, layer in enumerate(self.layers):
234
+ if (i+1) % self.short_long_ratio == 0:
235
+ layer.set_window_length(config.window_length, config.not_flex)
236
+ else:
237
+ layer.set_window_length(256, config.not_flex)
238
+
239
+ def forward(self, hidden_layer: torch.Tensor, embeddings: torch.Tensor, mask: torch.Tensor | None = None) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
240
+ hidden_states: List[torch.Tensor]
241
+ attention_probs: List[torch.Tensor]
242
+
243
+ hidden_states = []
244
+ attention_probs = []
245
+ v1 = None
246
+
247
+ for layer in self.layers:
248
+ hidden_layer, v1, attention_p = layer(hidden_layer, embeddings, v1, mask)
249
+ hidden_states.append(hidden_layer)
250
+ attention_probs.append(attention_p)
251
+
252
+ return hidden_states, attention_probs
253
+
254
+
255
+ class Layer(nn.Module):
256
+
257
+ def __init__(self, config, layer_idx: int) -> None:
258
+ super().__init__()
259
+
260
+ self.attention: SelfAttention
261
+ self.mlp: FeedForward
262
+
263
+ self.attention = SelfAttention(config, layer_idx)
264
+ self.mlp = FeedForward(config)
265
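+ # Learned mixing coefficients used in forward(): entries 0, 1, 3 and 5 blend the hidden states with the token embeddings for the attention-value, query/key, MLP and residual inputs, while softplus of entries 2 and 4 gates the MLP and residual branches.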
+ self.lambdas = nn.Parameter(torch.tensor([0., 0., 1., 0., 1., 0.]))
266
+
267
+ def set_window_length(self, window_length: int, not_flex: bool) -> None:
268
+ self.attention.set_window_length(window_length, not_flex)
269
+
270
+ def forward(self, hidden_layer: torch.Tensor, embeddings: torch.Tensor, v1: torch.Tensor | None, mask: torch.Tensor | None = None) -> Tuple[torch.Tensor, torch.Tensor]:
271
+ output: torch.Tensor
272
+ attention_p: torch.Tensor
273
+
274
+ attention_output = (1 - self.lambdas[0]) * hidden_layer + self.lambdas[0] * embeddings
275
+ qk_layer = (1 - self.lambdas[1]) * hidden_layer + self.lambdas[1] * embeddings
276
+ mlp_layer = F.softplus(self.lambdas[2]) * ((1 - self.lambdas[3]) * hidden_layer + self.lambdas[3] * embeddings)
277
+
278
+ attention_output, v1, attention_p = self.attention(attention_output, qk_layer, v1, mask)
279
+ mlp_layer = mlp_layer + attention_output
280
+ hidden_layer = F.softplus(self.lambdas[4]) * ((1 - self.lambdas[5]) * hidden_layer + self.lambdas[5] * embeddings)
281
+ output = hidden_layer + attention_output + self.mlp(mlp_layer)
282
+
283
+ return output, v1, attention_p
284
+
285
+
286
+ class Embedding(nn.Module):
287
+
288
+ def __init__(self, config) -> None:
289
+ super().__init__()
290
+
291
+ assert hasattr(config, "vocab_size"), "The config must have a vocab_size attribute!"
292
+ assert hasattr(config, "hidden_size"), "The config must have a hidden_size attribute!"
293
+ assert hasattr(config, "embedding_dropout_p"), "The model must have a embedding_dropout_p attribute!"
294
+
295
+ self.word_embedding: nn.Embedding
296
+ self.word_norm: nn.LayerNorm
297
+ self.dropout: nn.Dropout
298
+
299
+ self.word_embedding = nn.Embedding(config.vocab_size, config.hidden_size)
300
+ self.word_norm = nn.LayerNorm(config.hidden_size, eps=config.word_norm_eps, elementwise_affine=False, bias=False)
301
+ self.word_scale = nn.Parameter(torch.zeros(config.hidden_size))
302
+
303
+ self.dropout = nn.Dropout(config.embedding_dropout_p)
304
+
305
+ self.initialize(config.hidden_size, config.vocab_size)
306
+
307
+ @torch.no_grad()
308
+ def initialize(self, hidden_size: int, vocab_size: int) -> None:
309
+ std: float
310
+
311
+ std = math.sqrt(2.0 / (hidden_size + vocab_size))
312
+ nn.init.trunc_normal_(self.word_embedding.weight, mean=0.0, std=std, a=-2*std, b=2*std)
313
+
314
+ def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
315
+ word_embedding: torch.Tensor
316
+
317
+ word_embedding = self.word_embedding(input_ids)
318
+ word_embedding = self.word_norm(word_embedding)
319
+ word_embedding = (word_embedding * (self.word_scale + 1.0).unsqueeze(0).unsqueeze(0))
320
+
321
+ return self.dropout(word_embedding)
322
+
323
+
324
+ class MaskClassifier(nn.Module):
325
+
326
+ def __init__(self, config, embedding_weights: nn.Parameter) -> None:
327
+ super().__init__()
328
+
329
+ self.projection: CastedLinear
330
+ self.emb2vocab: CastedLinear
331
+ self.pre_norm: nn.LayerNorm
332
+ self.post_norm: nn.LayerNorm
333
+
334
+ self.pre_norm = nn.LayerNorm(config.hidden_size, eps=config.classifier_pre_norm_eps, elementwise_affine=config.classifier_pre_norm_affine)
335
+ self.projection = CastedLinearIn(config.hidden_size, config.hidden_size, bias=False)
336
+ self.post_norm = nn.LayerNorm(config.hidden_size, eps=config.classifier_post_norm_eps, elementwise_affine=config.classifier_post_norm_affine)
337
+ self.emb2vocab = CastedLinearIn(config.hidden_size, config.vocab_size, bias=True)
338
+
339
+ self.initialize(config.hidden_size, config.vocab_size, embedding_weights)
340
+
341
+ @torch.no_grad()
342
+ def initialize(self, hidden_size: int, vocab_size: int, embedding_weights: nn.Parameter) -> None:
343
+ proj_std: float = math.sqrt(2.0 / (hidden_size + 4*hidden_size))
344
+
345
+ nn.init.trunc_normal_(self.projection.weight, mean=0.0, std=proj_std, a=-2*proj_std, b=2*proj_std)
346
+ self.emb2vocab.weight = embedding_weights
347
+ self.emb2vocab.bias.zero_()
348
+
349
+ def project(self, hidden_layer: torch.Tensor) -> torch.Tensor:
350
+ projection: torch.Tensor
351
+
352
+ projection = self.projection(hidden_layer)
353
+ projection = gelu_new(projection)
354
+ projection = self.post_norm(projection)
355
+
356
+ return projection
357
+
358
+ def calculate_output(self, hidden_layer: torch.Tensor) -> torch.Tensor:
359
+ return self.emb2vocab(hidden_layer)
360
+
361
+ def forward(self, hidden_layer: torch.Tensor, labels: torch.Tensor | None = None) -> torch.Tensor:
362
+ output: torch.Tensor
363
+
364
+ if labels is not None:
365
+ hidden_layer = torch.index_select(hidden_layer.flatten(0, 1), 0, torch.nonzero(labels.flatten() != -100).squeeze())
366
+
367
+ hidden_layer = self.pre_norm(hidden_layer)
368
+ hidden_layer = self.project(hidden_layer)
369
+ output = self.calculate_output(hidden_layer)
370
+
371
+ return output
372
+
373
+
374
+ class SelfAttention(nn.Module):
375
+
376
+ def __init__(self, config, layer_idx) -> None:
377
+ super().__init__()
378
+ self.d_qk = config.d_qk
379
+ self.d_v = config.d_v
380
+ self.num_attention_heads = config.num_attention_heads
381
+ self.num_kv_heads = config.num_kv_heads
382
+ self.hidden_size = config.hidden_size
383
+
384
+ self.q_out_dim = self.d_qk * self.num_attention_heads
385
+ self.k_out_dim = self.d_qk * self.num_kv_heads
386
+ self.v_out_dim = self.d_v * self.num_kv_heads
387
+
388
+ self.qk_proj = MultiCastedLinearOrthoIn(self.hidden_size, [self.q_out_dim, self.k_out_dim], bias=False)
389
+ self.v_proj = CastedLinearIn(self.hidden_size, self.v_out_dim, bias=False)
390
+ self.out_proj = CastedLinearIn(self.d_v*self.num_attention_heads, self.hidden_size, bias=False)
391
+
392
+ self.pre_v_norm = nn.LayerNorm(config.hidden_size, eps=config.attention_pre_norm_eps, elementwise_affine=config.attention_pre_norm_affine)
393
+ self.pre_qk_norm = nn.LayerNorm(config.hidden_size, eps=config.attention_pre_norm_eps, elementwise_affine=config.attention_pre_norm_affine)
394
+ self.inter_norm = nn.LayerNorm(self.d_v * self.num_attention_heads, eps=config.attention_inter_norm_eps, elementwise_affine=config.attention_inter_norm_affine)
395
+ self.q_norm = nn.LayerNorm(config.d_qk, eps=config.attention_pre_norm_eps, elementwise_affine=False, bias=False)
396
+ self.k_norm = nn.LayerNorm(config.d_qk, eps=config.attention_pre_norm_eps, elementwise_affine=False, bias=False)
397
+ self.k_scale = nn.Parameter(torch.ones(self.num_kv_heads, config.d_qk))
398
+ self.q_scale = nn.Parameter(torch.ones(self.num_attention_heads, config.d_qk))
399
+
400
+ self.dropout = nn.Dropout(config.attention_output_dropout_p)
401
+
402
+ theta = 160_000 if (layer_idx + 1) % config.short_long_ratio == 0 else 10_000
403
+
404
+ self.rope_embedding = RotaryPositionalEmbeddings(config, theta)
405
+ self.scale: float = 1.0 / math.sqrt(self.d_qk)
406
+
407
+ self.dropout = nn.Dropout(config.attention_dropout if hasattr(config, "attention_dropout") else 0.0)
408
+
409
+ self.lambdas = nn.Parameter(torch.tensor([0.5]))
410
+
411
+ self.initialize()
412
+
413
+ self.sequence_length = config.max_sequence_length
414
+ self.is_causal = config.is_decoder
415
+ self.not_flex = config.not_flex
416
+
417
+ @torch.no_grad()
418
+ def initialize(self) -> None:
419
+ std: float = math.sqrt(2.0 / (self.hidden_size + 4*self.hidden_size))
420
+ for weight in self.qk_proj.weights:
421
+ nn.init.trunc_normal_(weight, mean=0.0, std=std, a=-2*std, b=2*std)
422
+ nn.init.trunc_normal_(self.v_proj.weight, mean=0.0, std=std, a=-2*std, b=2*std)
423
+ self.out_proj.weight.data.zero_()
424
+
425
+ def set_window_length(self, window_length: int, not_flex: bool) -> None:
426
+ self.window_length: int = window_length
427
+ if not not_flex:
428
+ self.block_mask = self.create_block_mask(window_length)
429
+
430
+ def causal_mask_mode(self, window_length, b, _, q_idx, kv_idx):
431
+ return (q_idx >= kv_idx) & ((q_idx - kv_idx) < window_length)
432
+
433
+ def bidirectional_mask_mode(self, window_length, b, _, q_idx, kv_idx):
434
+ return ((q_idx - kv_idx) < window_length) & ((kv_idx - q_idx) < window_length)
435
+
436
+ def create_block_mask(self, window_length: int) -> torch.Tensor:
437
+ if self.is_causal:
438
+ return create_block_mask(
439
+ partial(self.causal_mask_mode, self.window_length),
440
+ 1, 1, self.sequence_length, self.sequence_length, device=self.k_scale.device
441
+ )
442
+ else:
443
+ return create_block_mask(
444
+ partial(self.bidirectional_mask_mode, self.window_length),
445
+ 1, 1, self.sequence_length, self.sequence_length, device=self.k_scale.device
446
+ )
447
+
448
+ def attention_operation(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, padding_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
449
+ attention_scores: torch.Tensor
450
+ attention_probabilities: torch.Tensor
451
+ batch_size: int
452
+ query_length: int
453
+ key_length: int
454
+
455
+ batch_size, _, query_length, _ = query.size()
456
+ _, _, key_length, _ = key.size()
457
+
458
+ if self.is_causal:
459
+ window_mask = ~torch.ones(query_length, key_length, dtype=torch.bool, device=self.k_scale.device).tril().triu(diagonal=-self.window_length).view(1, 1, query_length, key_length)
460
+ else:
461
+ window_mask = ~torch.ones(query_length, key_length, dtype=torch.bool, device=self.k_scale.device).tril(diagonal=self.window_length).triu(diagonal=-self.window_length).view(1, 1, query_length, key_length)
462
+
463
+ if padding_mask is not None:
464
+ attention_mask = padding_mask | window_mask
465
+ else:
466
+ attention_mask = window_mask
467
+
468
+ attention_scores = torch.bmm(query.flatten(0, 1), key.transpose(-1, -2).flatten(0, 1)) * self.scale # shape: [B*H, T, T]
469
+ attention_scores = attention_scores.view(batch_size, self.num_attention_heads, query_length, key_length)
470
+
471
+ attention_probabilities = MaskedSoftmax.apply(attention_scores, attention_mask, -1)
472
+ attention_probabilities = self.dropout(attention_probabilities)
473
+
474
+ value = torch.bmm(attention_probabilities.flatten(0, 1), value.flatten(0, 1))
475
+ value = value.view(batch_size, self.num_attention_heads, query_length, self.d_v)
476
+
477
+ return value, attention_probabilities.detach()
478
+
479
+ def forward(self, hidden_layer: torch.Tensor, qk_layer: torch.Tensor, v1: torch.Tensor | None, mask: torch.Tensor | None = None, doc_ids: torch.Tensor | None = None) -> Tuple[torch.Tensor, torch.Tensor]:
480
+ hidden_layer = self.pre_v_norm(hidden_layer)
481
+ qk_layer = self.pre_qk_norm(qk_layer)
482
+
483
+ query, key = self.qk_proj(qk_layer).tensor_split([self.q_out_dim], dim=-1)
484
+ value = self.v_proj(hidden_layer)
485
+
486
+ query_length: int = hidden_layer.size(0)
487
+ key_length: int = hidden_layer.size(0)
488
+ batch_size: int = hidden_layer.size(1)
489
+
490
+ query = query.reshape(query_length, batch_size, self.num_attention_heads, self.d_qk).permute(1, 2, 0, 3) # shape: [B, H, T, D]
491
+ key = key.reshape(key_length, batch_size, self.num_kv_heads, self.d_qk).permute(1, 2, 0, 3) # shape: [B, H, T, D]
492
+ value = value.reshape(key_length, batch_size, self.num_kv_heads, self.d_v).permute(1, 2, 0, 3) # shape: [B, H, T, D]
493
+
494
+ query, key = ((self.q_scale + 1.0).unsqueeze(1).unsqueeze(0) * self.q_norm(query.float())).type_as(query), ((self.k_scale + 1.0).unsqueeze(1).unsqueeze(0) * self.k_norm(key.float())).type_as(key)
495
+
496
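+ # Value-residual connection: the first layer's value states (v1) are cached and blended into the value states of every later layer via lambdas[0].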
+ if v1 is None:
497
+ v1 = value
498
+ value = (1 - self.lambdas[0]) * value + self.lambdas[0] * v1
499
+
500
+ query = self.rope_embedding(query)
501
+ key = self.rope_embedding(key)
502
+
503
+ if self.not_flex:
504
+ output, attention_probabilities = self.attention_operation(query, key, value, mask)
505
+ else:
506
+ def document_score_mod(score, b, _, q_idx, kv_idx):
507
+ return torch.where(doc_ids[q_idx] == doc_ids[kv_idx], score, -float("inf"))
508
+
509
+ if self.is_causal:
510
+ block_mask = create_block_mask(
511
+ partial(self.causal_mask_mode, self.window_length),
512
+ 1, 1, query_length, key_length, device=self.k_scale.device
513
+ )
514
+ else:
515
+ block_mask = create_block_mask(
516
+ partial(self.bidirectional_mask_mode, self.window_length),
517
+ 1, 1, query_length, key_length, device=self.k_scale.device
518
+ )
519
+
520
+ output = flex_attention(
521
+ query, key, value, block_mask=block_mask, enable_gqa=True
522
+ )
523
+ attention_probabilities = None
524
+
525
+ output = output.permute(2, 0, 1, 3).flatten(2, 3) # shape: [T, B, H*D]
526
+ output = self.inter_norm(output)
527
+ output = self.out_proj(output)
528
+
529
+ return self.dropout(output), v1, attention_probabilities
530
+
531
+
532
+ class FeedForward(nn.Module):
533
+
534
+ def __init__(self, config) -> None:
535
+ super().__init__()
536
+
537
+ self.up_proj: CastedLinear
538
+ self.down_proj: CastedLinear
539
+ self.pre_norm: nn.LayerNorm
540
+ self.inter_norm: nn.LayerNorm
541
+ self.activation: GeGLU
542
+ self.dropout: nn.Dropout
543
+
544
+ self.pre_norm = nn.LayerNorm(config.hidden_size, eps=config.feed_forward_pre_norm_eps, elementwise_affine=config.feed_forward_pre_norm_affine)
545
+ self.up_proj = MultiCastedLinearOrthoIn(config.hidden_size, [config.intermediate_size, config.intermediate_size], bias=False)
546
+ self.activation = GeGLU()
547
+ self.inter_norm = nn.LayerNorm(config.intermediate_size, eps=config.feed_forward_inter_norm_eps, elementwise_affine=config.feed_forward_inter_norm_affine)
548
+ self.down_proj = CastedLinearIn(config.intermediate_size, config.hidden_size, bias=False)
549
+ self.dropout = nn.Dropout(config.feed_forward_dropout_p)
550
+
551
+ self.initialize(config.hidden_size)
552
+
553
+ @torch.no_grad()
554
+ def initialize(self, hidden_size: int) -> None:
555
+ std: float = math.sqrt(2.0 / (5*hidden_size))
556
+
557
+ for weight in self.up_proj.weights:
558
+ nn.init.trunc_normal_(weight, mean=0.0, std=std, a=-2*std, b=2*std)
559
+ self.down_proj.weight.data.zero_()
560
+
561
+ def up_project(self, hidden_layer: torch.Tensor) -> torch.Tensor:
562
+ hidden_layer = self.pre_norm(hidden_layer)
563
+ return self.up_proj(hidden_layer)
564
+
565
+ def activate(self, projection: torch.Tensor) -> torch.Tensor:
566
+ activated_projection: torch.Tensor
567
+
568
+ activated_projection = self.activation(projection)
569
+ activated_projection = self.inter_norm(activated_projection.float()).type_as(projection)
570
+
571
+ return activated_projection
572
+
573
+ def down_project(self, activated_projection: torch.Tensor) -> torch.Tensor:
574
+ output: torch.Tensor
575
+
576
+ output = self.down_proj(activated_projection)
577
+
578
+ return self.dropout(output)
579
+
580
+ def forward(self, hidden_layer: torch.Tensor) -> torch.Tensor:
581
+ output: torch.Tensor
582
+
583
+ output = self.up_project(hidden_layer)
584
+ output = self.activate(output)
585
+ output = self.down_project(output)
586
+
587
+ return output
588
+
589
+
590
+ class RotaryPositionalEmbeddings(nn.Module):
591
+
592
+ def __init__(self, config, theta: int) -> None:
593
+ super().__init__()
594
+
595
+ assert hasattr(config, "d_qk"), "The config must have a d_qk attribute!"
596
+ assert hasattr(config, "max_sequence_length"), "The config must have a max_sequence_length attribute!"
597
+
598
+ self.inv_freq: torch.Tensor
599
+ self.cos_matrix: torch.Tensor
600
+ self.sin_matrix: torch.Tensor
601
+ head_size: int
602
+ max_seq_len: int
603
+ inv_freq: torch.Tensor
604
+ pos: torch.Tensor
605
+ embedding: torch.Tensor
606
+
607
+ head_size = config.d_qk
608
+ assert head_size % 2 == 0
609
+ max_seq_len = config.max_sequence_length
610
+
611
+ inv_freq = 1.0 / (theta ** (torch.arange(0, head_size, 2, dtype=torch.float32) / head_size))
612
+ pos = torch.arange(max_seq_len, dtype=torch.float32)
613
+ embedding = torch.einsum('n, d -> nd', pos, inv_freq)
614
+ embedding = torch.cat([embedding, embedding], dim=-1).unsqueeze(0)
615
+ self.register_buffer("cos_matrix", embedding.cos(), persistent=False)
616
+ self.register_buffer("sin_matrix", embedding.sin(), persistent=False)
617
+
618
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
619
+ seq_len: int
620
+ cos_matrix: torch.Tensor
621
+ sin_matrix: torch.Tensor
622
+ x_rotate_half: torch.Tensor
623
+ out: torch.Tensor
624
+
625
+ hidden_layer = x.float()
626
+
627
+ seq_len = x.shape[2]
628
+
629
+ cos_matrix = self.cos_matrix[:, None, :seq_len, :]
630
+ sin_matrix = self.sin_matrix[:, None, :seq_len, :]
631
+
632
+ x_rotate_half = torch.cat(
633
+ [
634
+ -hidden_layer[:, :, :, x.size(-1) // 2:],
635
+ hidden_layer[:, :, :, :x.size(-1) // 2]
636
+ ],
637
+ dim=-1
638
+ )
639
+
640
+ out = hidden_layer * cos_matrix + x_rotate_half * sin_matrix
641
+ return out.type_as(x)
642
+
643
+
644
+ #
645
+ # HuggingFace wrappers
646
+ #
647
+
648
+ class GptBertPreTrainedModel(PreTrainedModel):
649
+ config_class = GptBertConfig
650
+ supports_gradient_checkpointing = False
651
+
652
+ def _set_gradient_checkpointing(self, module, value=False):
653
+ raise NotImplementedError("Gradient checkpointing is not supported by this model")
654
+
655
+ def _init_weights(self, module):
656
+ std = math.sqrt(2.0 / (5.0 * self.hidden_size))
657
+
658
+ if isinstance(module, nn.Linear):
659
+ nn.init.trunc_normal_(module.weight.data, mean=0.0, std=std, a=-2*std, b=2*std)
660
+ if module.bias is not None:
661
+ module.bias.data.zero_()
662
+ elif isinstance(module, nn.Embedding):
663
+ nn.init.trunc_normal_(module.weight.data, mean=0.0, std=std, a=-2*std, b=2*std)
664
+ elif isinstance(module, nn.LayerNorm):
665
+ if module.bias is not None: module.bias.data.zero_()
666
+ if module.weight is not None: module.weight.data.fill_(1.0)
667
+
668
+
669
+ class GptBertModel(GptBertPreTrainedModel):
670
+
671
+ def __init__(self, config, add_mlm_layer=False, **kwargs):
672
+ super().__init__(config, **kwargs)
673
+ self.config = config
674
+ self.hidden_size = config.hidden_size
675
+
676
+ self.embedding = Embedding(config)
677
+ self.encoder = Encoder(config)
678
+ self.classifier = MaskClassifier(config, self.embedding.word_embedding.weight) if add_mlm_layer else None
679
+ self.set_window_length(config)
680
+
681
+ def set_window_length(self, config) -> None:
682
+ self.encoder.set_window_length(config)
683
+
684
+ def get_input_embeddings(self):
685
+ return self.embedding.word_embedding
686
+
687
+ def set_input_embeddings(self, value):
688
+ self.embedding.word_embedding = value
689
+
690
+ def get_contextualized_embeddings(
691
+ self,
692
+ input_ids: Optional[torch.Tensor] = None,
693
+ attention_mask: Optional[torch.Tensor] = None
694
+ ) -> List[torch.Tensor]:
695
+ if input_ids is not None:
696
+ input_shape = input_ids.size()
697
+ else:
698
+ raise ValueError("You have to specify input_ids")
699
+
700
+ batch_size, seq_length = input_shape
701
+ device = input_ids.device
702
+
703
+ # if attention_mask is None:
704
+ # attention_mask = torch.zeros(batch_size, seq_length, dtype=torch.bool, device=device)
705
+ if attention_mask is not None:
706
+ attention_mask = ~attention_mask.bool()
707
+
708
+ if len(attention_mask.size()) == 2:
709
+ attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
710
+ elif len(attention_mask.size()) == 3:
711
+ attention_mask = attention_mask.unsqueeze(1)
712
+
713
+ if self.config.is_decoder:
714
+ attention_mask = attention_mask | torch.triu(torch.ones(seq_length, seq_length, dtype=torch.bool, device=device), 1).unsqueeze(0).unsqueeze(0)
715
+
716
+ static_embeddings = self.embedding(input_ids.t())
717
+ contextualized_embeddings, attention_probs = self.encoder(static_embeddings, static_embeddings, attention_mask)
718
+ contextualized_embeddings = [e.transpose(0, 1) for e in contextualized_embeddings]
719
+ last_layer = contextualized_embeddings[-1]
720
+ contextualized_embeddings = [contextualized_embeddings[0]] + [
721
+ contextualized_embeddings[i] - contextualized_embeddings[i - 1]
722
+ for i in range(1, len(contextualized_embeddings))
723
+ ]
724
+ return last_layer, contextualized_embeddings, attention_probs
725
+
726
+ def forward(
727
+ self,
728
+ input_ids: Optional[torch.Tensor] = None,
729
+ attention_mask: Optional[torch.Tensor] = None,
730
+ token_type_ids: Optional[torch.Tensor] = None,
731
+ position_ids: Optional[torch.Tensor] = None,
732
+ output_hidden_states: Optional[bool] = None,
733
+ output_attentions: Optional[bool] = None,
734
+ return_dict: Optional[bool] = None,
735
+ **kwargs
736
+ ) -> Union[Tuple[torch.Tensor], BaseModelOutput]:
737
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
738
+
739
+ sequence_output, contextualized_embeddings, attention_probs = self.get_contextualized_embeddings(input_ids, attention_mask)
740
+
741
+ if not return_dict:
742
+ return (
743
+ sequence_output,
744
+ *([contextualized_embeddings] if output_hidden_states else []),
745
+ *([attention_probs] if output_attentions else [])
746
+ )
747
+
748
+ return BaseModelOutput(
749
+ last_hidden_state=sequence_output,
750
+ hidden_states=contextualized_embeddings if output_hidden_states else None,
751
+ attentions=attention_probs if output_attentions else None
752
+ )
753
+
754
+
755
+ class GptBertForMaskedLM(GptBertModel):
756
+ _keys_to_ignore_on_load_unexpected = ["head"]
757
+
758
+ def __init__(self, config, **kwargs):
759
+ super().__init__(config, add_mlm_layer=True, **kwargs)
760
+
761
+ def get_output_embeddings(self):
762
+ return self.classifier.emb2vocab.weight
763
+
764
+ def set_output_embeddings(self, new_embeddings):
765
+ self.classifier.emb2vocab.weight = new_embeddings
766
+
767
+ def forward(
768
+ self,
769
+ input_ids: Optional[torch.Tensor] = None,
770
+ attention_mask: Optional[torch.Tensor] = None,
771
+ token_type_ids: Optional[torch.Tensor] = None,
772
+ position_ids: Optional[torch.Tensor] = None,
773
+ output_hidden_states: Optional[bool] = None,
774
+ output_attentions: Optional[bool] = None,
775
+ return_dict: Optional[bool] = None,
776
+ labels: Optional[torch.LongTensor] = None,
777
+ **kwargs
778
+ ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
779
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
780
+
781
+ sequence_output, contextualized_embeddings, attention_probs = self.get_contextualized_embeddings(input_ids, attention_mask)
782
+ subword_prediction = self.classifier(sequence_output)
783
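+ # Soft-cap the logits into the range (0, 30) with a scaled sigmoid before computing the loss.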
+ subword_prediction = 30 * torch.sigmoid(subword_prediction / 7.5)
784
+
785
+ masked_lm_loss = None
786
+ if labels is not None:
787
+ labels_flatten = labels[:, 1:].flatten()
788
+ subword_prediction_flatten = subword_prediction[:, :-1].flatten(0, 1)
789
+ masked_lm_loss = F.cross_entropy(subword_prediction_flatten, labels_flatten)
790
+
791
+ if not return_dict:
792
+ output = (
793
+ subword_prediction,
794
+ *([contextualized_embeddings] if output_hidden_states else []),
795
+ *([attention_probs] if output_attentions else [])
796
+ )
797
+ return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
798
+
799
+ return MaskedLMOutput(
800
+ loss=masked_lm_loss,
801
+ logits=subword_prediction,
802
+ hidden_states=contextualized_embeddings if output_hidden_states else None,
803
+ attentions=attention_probs if output_attentions else None
804
+ )
805
+
806
+
807
+ class Classifier(nn.Module):
808
+ def __init__(self, config, num_labels: int):
809
+ super().__init__()
810
+
811
+ drop_out = getattr(config, "cls_dropout", None)
812
+ drop_out = config.hidden_dropout_prob if drop_out is None else drop_out
813
+
814
+ self.projection: CastedLinear
815
+ self.emb2vocab: CastedLinear
816
+ self.pre_norm: nn.LayerNorm
817
+ self.post_norm: nn.LayerNorm
818
+
819
+ self.pre_norm = nn.LayerNorm(config.hidden_size, eps=config.classifier_pre_norm_eps, elementwise_affine=config.classifier_pre_norm_affine)
820
+ self.projection = CastedLinear(config.hidden_size, config.hidden_size, bias=False)
821
+ self.post_norm = nn.LayerNorm(config.hidden_size, eps=config.classifier_post_norm_eps, elementwise_affine=config.classifier_post_norm_affine)
822
+ self.emb2vocab = CastedLinear(config.hidden_size, num_labels, bias=True)
823
+ self.dropout = nn.Dropout(drop_out)
824
+
825
+ self.initialize(config.hidden_size, config.intermediate_size, num_labels)
826
+
827
+ @torch.no_grad()
828
+ def initialize(self, hidden_size: int, intermediate_size: int, vocab_size: int) -> None:
829
+ proj_std: float = math.sqrt(2.0 / (hidden_size + intermediate_size))
830
+
831
+ nn.init.trunc_normal_(self.projection.weight, mean=0.0, std=proj_std, a=-2*proj_std, b=2*proj_std)
832
+ nn.init.trunc_normal_(self.emb2vocab.weight, mean=0.0, std=proj_std, a=-2*proj_std, b=2*proj_std)
833
+ self.emb2vocab.bias.zero_()
834
+
835
+ def project(self, hidden_layer: torch.Tensor) -> torch.Tensor:
836
+ projection: torch.Tensor
837
+
838
+ projection = self.pre_norm(hidden_layer)
839
+ projection = self.dropout(projection)
840
+ projection = self.projection(hidden_layer)
841
+ projection = gelu_new(projection)
842
+ projection = self.post_norm(projection)
843
+
844
+ return projection
845
+
846
+ def calculate_output(self, hidden_layer: torch.Tensor) -> torch.Tensor:
847
+ return self.emb2vocab(hidden_layer)
848
+
849
+ def forward(self, hidden_layer: torch.Tensor) -> torch.Tensor:
850
+ output: torch.Tensor
851
+ projection: torch.Tensor
852
+
853
+ projection = self.project(hidden_layer)
854
+ output = self.calculate_output(projection)
855
+
856
+ return output
857
+
858
+
859
+ class GptBertForCausalLM(GptBertModel):
860
+ _keys_to_ignore_on_load_unexpected = ["head"]
861
+
862
+ def __init__(self, config, **kwargs):
863
+ config.is_decoder = True
864
+ super().__init__(config, add_mlm_layer=True, **kwargs)
865
+
866
+ def get_output_embeddings(self):
867
+ return self.classifier.emb2vocab.weight
868
+
869
+ def set_output_embeddings(self, new_embeddings):
870
+ self.classifier.emb2vocab.weight = new_embeddings
871
+
872
+ def get_input_embeddings(self):
873
+ return self.embedding.word_embedding
874
+
875
+ def set_input_embeddings(self, value):
876
+ self.embedding.word_embedding = value
877
+
878
+ def set_decoder(self, decoder):
879
+ self.encoder = decoder
880
+
881
+ def get_decoder(self):
882
+ return self.encoder
883
+
884
+ def can_generate(self):
885
+ return True
886
+
887
+ def forward(
888
+ self,
889
+ input_ids: torch.LongTensor = None,
890
+ attention_mask: Optional[torch.Tensor] = None,
891
+ position_ids: Optional[torch.LongTensor] = None,
892
+ token_type_ids: Optional[torch.Tensor] = None,
893
+ past_key_values: Optional[torch.Tensor] = None,
894
+ inputs_embeds: Optional[torch.FloatTensor] = None,
895
+ labels: Optional[torch.LongTensor] = None,
896
+ use_cache: Optional[bool] = None,
897
+ cache_position: Optional[torch.LongTensor] = None,
898
+ output_attentions: Optional[bool] = None,
899
+ output_hidden_states: Optional[bool] = None,
900
+ return_dict: Optional[bool] = None
901
+ ) -> Union[Tuple, CausalLMOutput]:
902
+
903
+ assert inputs_embeds is None, "inputs_embeds is not supported for now"
904
+ assert past_key_values is None, "past_key_values is not supported for now"
905
+ assert not use_cache, "use_cache is not supported for now"
906
+
907
+ sequence_output, contextualized_embeddings, attention_probs = self.get_contextualized_embeddings(input_ids, attention_mask)
908
+ subword_prediction = self.classifier(sequence_output)
909
+ subword_prediction = 30 * torch.sigmoid(subword_prediction / 7.5)
910
+
911
+ masked_lm_loss = None
912
+ if labels is not None:
913
+ labels_flatten = labels[:, 1:].flatten()
914
+ subword_prediction_flatten = subword_prediction[:, :-1].flatten(0, 1)
915
+ masked_lm_loss = F.cross_entropy(subword_prediction_flatten, labels_flatten)
916
+
917
+ if not return_dict:
918
+ output = (
919
+ subword_prediction,
920
+ *([contextualized_embeddings] if output_hidden_states else []),
921
+ *([attention_probs] if output_attentions else [])
922
+ )
923
+ return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
924
+
925
+ return CausalLMOutput(
926
+ loss=masked_lm_loss,
927
+ logits=subword_prediction,
928
+ hidden_states=contextualized_embeddings if output_hidden_states else None,
929
+ attentions=attention_probs if output_attentions else None
930
+ )
931
+
932
+ def prepare_inputs_for_generation(
933
+ self,
934
+ input_ids: torch.Tensor,
935
+ past_key_values: Optional[torch.Tensor] = None,
936
+ attention_mask: Optional[torch.Tensor] = None,
937
+ inputs_embeds: Optional[torch.Tensor] = None,
938
+ cache_position: Optional[torch.LongTensor] = None,
939
+ position_ids: Optional[torch.LongTensor] = None,
940
+ use_cache: bool = True,
941
+ num_logits_to_keep: Optional[int] = None,
942
+ **kwargs,
943
+ ):
944
+ # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
945
+ # Exception 1: when passing input_embeds, input_ids may be missing entries
946
+ # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
947
+ if past_key_values is not None:
948
+ if inputs_embeds is not None: # Exception 1
949
+ input_ids = input_ids[:, -cache_position.shape[0] :]
950
+ elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
951
+ input_ids = input_ids[:, cache_position]
952
+
953
+ if attention_mask is not None and position_ids is None:
954
+ # create position_ids on the fly for batch generation
955
+ position_ids = attention_mask.long().cumsum(-1) - 1
956
+ position_ids.masked_fill_(attention_mask == 0, 1)
957
+ if past_key_values:
958
+ position_ids = position_ids[:, -input_ids.shape[1] :]
959
+
960
+ # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture.
961
+ position_ids = position_ids.clone(memory_format=torch.contiguous_format)
962
+
963
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
964
+ if inputs_embeds is not None and cache_position[0] == 0:
965
+ model_inputs = {"inputs_embeds": inputs_embeds}
966
+ else:
967
+ model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases
968
+
969
+ if num_logits_to_keep is not None:
970
+ model_inputs["num_logits_to_keep"] = num_logits_to_keep
971
+
972
+ model_inputs.update(
973
+ {
974
+ "position_ids": position_ids,
975
+ "cache_position": cache_position,
976
+ "past_key_values": past_key_values,
977
+ "use_cache": use_cache,
978
+ "attention_mask": attention_mask,
979
+ }
980
+ )
981
+ return model_inputs
982
+
983
+
984
+ class GptBertForSequenceClassification(GptBertModel):
985
+ _keys_to_ignore_on_load_unexpected = ["classifier"]
986
+ _keys_to_ignore_on_load_missing = ["head"]
987
+
988
+ def __init__(self, config, **kwargs):
989
+ super().__init__(config, add_mlm_layer=False, **kwargs)
990
+
991
+ self.num_labels = config.num_labels
992
+ self.head = Classifier(config, self.num_labels)
993
+
994
+ def forward(
995
+ self,
996
+ input_ids: Optional[torch.Tensor] = None,
997
+ attention_mask: Optional[torch.Tensor] = None,
998
+ token_type_ids: Optional[torch.Tensor] = None,
999
+ position_ids: Optional[torch.Tensor] = None,
1000
+ output_attentions: Optional[bool] = None,
1001
+ output_hidden_states: Optional[bool] = None,
1002
+ return_dict: Optional[bool] = None,
1003
+ labels: Optional[torch.LongTensor] = None,
1004
+ **kwargs
1005
+ ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
1006
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1007
+
1008
+ sequence_output, contextualized_embeddings, attention_probs = self.get_contextualized_embeddings(input_ids, attention_mask)
1009
+ logits = self.head(sequence_output[:, 0, :])
1010
+
1011
+ loss = None
1012
+ if labels is not None:
1013
+ if self.config.problem_type is None:
1014
+ if self.num_labels == 1:
1015
+ self.config.problem_type = "regression"
1016
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1017
+ self.config.problem_type = "single_label_classification"
1018
+ else:
1019
+ self.config.problem_type = "multi_label_classification"
1020
+
1021
+ if self.config.problem_type == "regression":
1022
+ loss_fct = nn.MSELoss()
1023
+ if self.num_labels == 1:
1024
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
1025
+ else:
1026
+ loss = loss_fct(logits, labels)
1027
+ elif self.config.problem_type == "single_label_classification":
1028
+ loss_fct = nn.CrossEntropyLoss()
1029
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1030
+ elif self.config.problem_type == "multi_label_classification":
1031
+ loss_fct = nn.BCEWithLogitsLoss()
1032
+ loss = loss_fct(logits, labels)
1033
+
1034
+ if not return_dict:
1035
+ output = (
1036
+ logits,
1037
+ *([contextualized_embeddings] if output_hidden_states else []),
1038
+ *([attention_probs] if output_attentions else [])
1039
+ )
1040
+ return ((loss,) + output) if loss is not None else output
1041
+
1042
+ return SequenceClassifierOutput(
1043
+ loss=loss,
1044
+ logits=logits,
1045
+ hidden_states=contextualized_embeddings if output_hidden_states else None,
1046
+ attentions=attention_probs if output_attentions else None
1047
+ )
1048
+
1049
+
1050
+ class GptBertForTokenClassification(GptBertModel):
1051
+ _keys_to_ignore_on_load_unexpected = ["classifier"]
1052
+ _keys_to_ignore_on_load_missing = ["head"]
1053
+
1054
+ def __init__(self, config, **kwargs):
1055
+ super().__init__(config, add_mlm_layer=False, **kwargs)
1056
+
1057
+ self.num_labels = config.num_labels
1058
+ self.head = Classifier(config, self.num_labels)
1059
+
1060
+ def forward(
1061
+ self,
1062
+ input_ids: Optional[torch.Tensor] = None,
1063
+ attention_mask: Optional[torch.Tensor] = None,
1064
+ token_type_ids: Optional[torch.Tensor] = None,
1065
+ position_ids: Optional[torch.Tensor] = None,
1066
+ output_attentions: Optional[bool] = None,
1067
+ output_hidden_states: Optional[bool] = None,
1068
+ return_dict: Optional[bool] = None,
1069
+ labels: Optional[torch.LongTensor] = None,
1070
+ **kwargs
1071
+ ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
1072
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1073
+
1074
+ sequence_output, contextualized_embeddings, attention_probs = self.get_contextualized_embeddings(input_ids, attention_mask)
1075
+ logits = self.head(sequence_output)
1076
+
1077
+ loss = None
1078
+ if labels is not None:
1079
+ loss_fct = nn.CrossEntropyLoss()
1080
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1081
+
1082
+ if not return_dict:
1083
+ output = (
1084
+ logits,
1085
+ *([contextualized_embeddings] if output_hidden_states else []),
1086
+ *([attention_probs] if output_attentions else [])
1087
+ )
1088
+ return ((loss,) + output) if loss is not None else output
1089
+
1090
+ return TokenClassifierOutput(
1091
+ loss=loss,
1092
+ logits=logits,
1093
+ hidden_states=contextualized_embeddings if output_hidden_states else None,
1094
+ attentions=attention_probs if output_attentions else None
1095
+ )
1096
+
1097
+
1098
+ class GptBertForQuestionAnswering(GptBertModel):
1099
+ _keys_to_ignore_on_load_unexpected = ["classifier"]
1100
+ _keys_to_ignore_on_load_missing = ["head"]
1101
+
1102
+ def __init__(self, config, **kwargs):
1103
+ super().__init__(config, add_mlm_layer=False, **kwargs)
1104
+
1105
+ self.num_labels = config.num_labels
1106
+ self.head = Classifier(config, self.num_labels)
1107
+
1108
+ def forward(
1109
+ self,
1110
+ input_ids: Optional[torch.Tensor] = None,
1111
+ attention_mask: Optional[torch.Tensor] = None,
1112
+ token_type_ids: Optional[torch.Tensor] = None,
1113
+ position_ids: Optional[torch.Tensor] = None,
1114
+ output_attentions: Optional[bool] = None,
1115
+ output_hidden_states: Optional[bool] = None,
1116
+ return_dict: Optional[bool] = None,
1117
+ start_positions: Optional[torch.Tensor] = None,
1118
+ end_positions: Optional[torch.Tensor] = None,
1119
+ **kwargs
1120
+ ) -> Union[Tuple[torch.Tensor], QuestionAnsweringModelOutput]:
1121
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1122
+
1123
+ sequence_output, contextualized_embeddings, attention_probs = self.get_contextualized_embeddings(input_ids, attention_mask)
1124
+ logits = self.head(sequence_output)
1125
+
1126
+ start_logits, end_logits = logits.split(1, dim=-1)
1127
+ start_logits = start_logits.squeeze(-1).contiguous()
1128
+ end_logits = end_logits.squeeze(-1).contiguous()
1129
+
1130
+ total_loss = None
1131
+ if start_positions is not None and end_positions is not None:
1132
+ # If we are on multi-GPU, split add a dimension
1133
+ if len(start_positions.size()) > 1:
1134
+ start_positions = start_positions.squeeze(-1)
1135
+ if len(end_positions.size()) > 1:
1136
+ end_positions = end_positions.squeeze(-1)
1137
+
1138
+ # sometimes the start/end positions are outside our model inputs, we ignore these terms
1139
+ ignored_index = start_logits.size(1)
1140
+ start_positions = start_positions.clamp(0, ignored_index)
1141
+ end_positions = end_positions.clamp(0, ignored_index)
1142
+
1143
+ loss_fct = nn.CrossEntropyLoss(ignore_index=ignored_index)
1144
+ start_loss = loss_fct(start_logits, start_positions)
1145
+ end_loss = loss_fct(end_logits, end_positions)
1146
+ total_loss = (start_loss + end_loss) / 2
1147
+
1148
+ if not return_dict:
1149
+ output = (
1150
+ start_logits,
1151
+ end_logits,
1152
+ *([contextualized_embeddings] if output_hidden_states else []),
1153
+ *([attention_probs] if output_attentions else [])
1154
+ )
1155
+ return ((total_loss,) + output) if total_loss is not None else output
1156
+
1157
+ return QuestionAnsweringModelOutput(
1158
+ loss=total_loss,
1159
+ start_logits=start_logits,
1160
+ end_logits=end_logits,
1161
+ hidden_states=contextualized_embeddings if output_hidden_states else None,
1162
+ attentions=attention_probs if output_attentions else None
1163
+ )
1164
+
1165
+
1166
+ class GptBertForMultipleChoice(GptBertModel):
1167
+ _keys_to_ignore_on_load_unexpected = ["classifier"]
1168
+ _keys_to_ignore_on_load_missing = ["head"]
1169
+
1170
+ def __init__(self, config, **kwargs):
1171
+ super().__init__(config, add_mlm_layer=False, **kwargs)
1172
+
1173
+ self.num_labels = getattr(config, "num_labels", 2)
1174
+ self.head = Classifier(config, self.num_labels)
1175
+
1176
+ def forward(
1177
+ self,
1178
+ input_ids: Optional[torch.Tensor] = None,
1179
+ attention_mask: Optional[torch.Tensor] = None,
1180
+ token_type_ids: Optional[torch.Tensor] = None,
1181
+ position_ids: Optional[torch.Tensor] = None,
1182
+ labels: Optional[torch.Tensor] = None,
1183
+ output_attentions: Optional[bool] = None,
1184
+ output_hidden_states: Optional[bool] = None,
1185
+ return_dict: Optional[bool] = None,
1186
+ **kwargs
1187
+ ) -> Union[Tuple[torch.Tensor], MultipleChoiceModelOutput]:
1188
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1189
+ num_choices = input_ids.shape[1]
1190
+
1191
+ flat_input_ids = input_ids.view(-1, input_ids.size(-1))
1192
+ flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
1193
+
1194
+ sequence_output, contextualized_embeddings, attention_probs = self.get_contextualized_embeddings(flat_input_ids, flat_attention_mask)
1195
+ logits = self.head(sequence_output)
1196
+ reshaped_logits = logits.view(-1, num_choices)
1197
+
1198
+ loss = None
1199
+ if labels is not None:
1200
+ loss_fct = nn.CrossEntropyLoss()
1201
+ loss = loss_fct(reshaped_logits, labels)
1202
+
1203
+ if not return_dict:
1204
+ output = (
1205
+ reshaped_logits,
1206
+ *([contextualized_embeddings] if output_hidden_states else []),
1207
+ *([attention_probs] if output_attentions else [])
1208
+ )
1209
+ return ((loss,) + output) if loss is not None else output
1210
+
1211
+ return MultipleChoiceModelOutput(
1212
+ loss=loss,
1213
+ logits=reshaped_logits,
1214
+ hidden_states=contextualized_embeddings if output_hidden_states else None,
1215
+ attentions=attention_probs if output_attentions else None
1216
+ )
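With a model and tokenizer loaded through the Auto classes as sketched earlier (trust_remote_code=True), the forward interface defined above behaves like any other transformers masked-LM head. A minimal shape check, assuming `model` and `tokenizer` from that earlier sketch:

import torch

enc = tokenizer("Hello world", return_tensors="pt")
with torch.no_grad():
    out = model(input_ids=enc["input_ids"], attention_mask=enc["attention_mask"])
print(out.logits.shape)   # [batch_size, sequence_length, vocab_size]
print(out.loss)           # None, since no labels were passed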
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7babbe2f0f59391dedc55eb4609596d3981e87bc62f877633851e494720cb95e
3
+ size 597611298
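pytorch_model.bin is stored through Git LFS; the lines above are only the pointer. A hedged sketch of fetching the real file and checking it against the recorded sha256 (placeholder repo id):

import hashlib
from huggingface_hub import hf_hub_download

path = hf_hub_download(repo_id="your-username/your-gptbert-model", filename="pytorch_model.bin")
h = hashlib.sha256()
with open(path, "rb") as f:
    for chunk in iter(lambda: f.read(1 << 20), b""):   # hash in 1 MiB chunks
        h.update(chunk)
print(h.hexdigest() == "7babbe2f0f59391dedc55eb4609596d3981e87bc62f877633851e494720cb95e")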
special_tokens_map.json ADDED
@@ -0,0 +1 @@
1
+ {"bos_token": "<s>", "eos_token": "</s>", "unk_token": "<unk>", "sep_token": "</s>", "pad_token": "<oad>", "cls_token": "<s>", "mask_token": "<mask>"}
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
tokenizer_config.json ADDED
@@ -0,0 +1,10 @@
1
+ {
2
+ "tokenizer_class": "PreTrainedTokenizerFast",
3
+ "bos_token": "<s>",
4
+ "eos_token": "</s>",
5
+ "unk_token": "<unk>",
6
+ "sep_token": "</s>",
7
+ "pad_token": "<pad>",
8
+ "cls_token": "<s>",
9
+ "mask_token": "<mask>"
10
+ }
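Because `tokenizer_class` is set to `PreTrainedTokenizerFast`, `AutoTokenizer` loads the accompanying tokenizer.json together with the special tokens declared above. A short sketch (placeholder repo id):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("your-username/your-gptbert-model")
print(tok.cls_token, tok.mask_token, tok.pad_token)   # "<s>", "<mask>", "<pad>"
batch = tok(["Hello world"], return_tensors="pt")
print(batch["input_ids"].shape)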