Bochkov committed (verified)
Commit e603189 · 1 Parent(s): ae7f181

Upload 3 files

Files changed (2):
  1. config.json +14 -3
  2. modeling_bvv_best.py +243 -0
config.json CHANGED
@@ -1,9 +1,20 @@
 {
-  "model_type": "PreTrainedTokenizerFast",
-  "model_type": "gpt2",
+  "architectures": ["BVVBestForCausalLM"],
+  "auto_map": {
+    "AutoConfig": "modeling_bvv_best.BVVBestConfig",
+    "AutoModel": "modeling_bvv_best.BVVBestForCausalLM",
+    "AutoModelForCausalLM": "modeling_bvv_best.BVVBestForCausalLM"
+  },
+  "model_type": "bvv_best",
+  "vocab_size": 131072,
+  "block_size": 1024,
+  "n_embd": 1024,
+  "n_layer": 16,
+  "n_head": 32,
+  "pad_id": 57344,
   "bos_token": "<s>",
   "eos_token": "</s>",
   "unk_token": "<unk>",
   "pad_token": "<pad>",
-  "vocab_size": 131072
+  "torch_dtype": "float32"
 }
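The auto_map entries added above are what let the Transformers Auto classes resolve the custom BVVBestConfig and BVVBestForCausalLM classes from modeling_bvv_best.py at load time. A minimal loading sketch, with a placeholder repo id (substitute the actual Hub path):

from transformers import AutoTokenizer, AutoModelForCausalLM

repo_id = "<user>/<model>"  # placeholder, not the actual repo id
tokenizer = AutoTokenizer.from_pretrained(repo_id)
# trust_remote_code=True is required so the auto_map above can import the custom classes from the repo
model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)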
modeling_bvv_best.py ADDED
@@ -0,0 +1,243 @@
+from transformers import PreTrainedModel, PretrainedConfig
+from transformers.modeling_outputs import CausalLMOutput
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+
+class BVVBestConfig(PretrainedConfig):
+    model_type = "bvv_best"
+
+    def __init__(
+        self,
+        vocab_size=131072,
+        n_embd=1024,
+        n_head=32,
+        n_layer=16,
+        block_size=1024,
+        pad_id=57344,
+        **kwargs
+    ):
+        super().__init__(**kwargs)
+        self.vocab_size = vocab_size
+        self.block_size = block_size
+        self.n_embd = n_embd
+        self.n_layer = n_layer
+        self.n_head = n_head
+        self.pad_id = pad_id
+
+class RotaryEmbedding(nn.Module):
+    def __init__(self, dim):  # dim = head_dim (not n_embd!)
+        super().__init__()
+        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
+
+    def forward(self, seq_len, device):
+        t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
+        freqs = torch.einsum('i,j->ij', t, self.inv_freq)
+        emb = torch.cat([freqs, freqs], dim=-1)  # (seq_len, dim)
+        return emb
+
+def apply_rotary_emb(x, rot_emb):
+    # x: (B, n_head, seq_len, head_dim)
+    # rot_emb: (seq_len, head_dim)
+    seq_len = x.shape[-2]
+    rot_emb = rot_emb[:seq_len]
+
+    cos = torch.cos(rot_emb).unsqueeze(0).unsqueeze(0)  # (1, 1, seq_len, head_dim)
+    sin = torch.sin(rot_emb).unsqueeze(0).unsqueeze(0)
+
+    # Split the last dimension into (head_dim/2, 2) pairs and rotate each pair
+    x_shape = x.shape
+    x = x.reshape(*x_shape[:-1], -1, 2)  # (..., head_dim/2, 2)
+    x1 = x[..., 0]
+    x2 = x[..., 1]
+
+    cos = cos.reshape(*cos.shape[:-1], -1, 2)[..., 0]
+    sin = sin.reshape(*sin.shape[:-1], -1, 2)[..., 0]
+
+    x1_rot = x1 * cos - x2 * sin
+    x2_rot = x1 * sin + x2 * cos
+
+    x_rot = torch.stack([x1_rot, x2_rot], dim=-1)
+    return x_rot.reshape(x_shape)
+
+class MultiHeadSelfAttention(nn.Module):
+    def __init__(self, n_embd, n_head, block_size):
+        super().__init__()
+        assert n_embd % n_head == 0
+        self.n_embd = n_embd
+        self.n_head = n_head
+        self.head_dim = n_embd // n_head
+
+        self.q_proj = nn.Linear(n_embd, n_embd, bias=False)
+        self.k_proj = nn.Linear(n_embd, n_embd, bias=False)
+        self.v_proj = nn.Linear(n_embd, n_embd, bias=False)
+        self.o_proj = nn.Linear(n_embd, n_embd, bias=False)
+
+        self.rotary_emb = RotaryEmbedding(self.head_dim)
+        self.dropout = nn.Dropout(0.0)
+
+        # Causal mask: lower-triangular matrix of ones
+        self.register_buffer(
+            "tril", torch.tril(torch.ones(block_size, block_size)), persistent=False
+        )
+
+    def forward(self, x):
+        # x: (B, T, n_embd)
+        B, T, C = x.shape
+
+        q = self.q_proj(x)  # (B, T, n_embd)
+        k = self.k_proj(x)
+        v = self.v_proj(x)
+
+        q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2)  # (B, n_head, T, head_dim)
+        k = k.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
+        v = v.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
+
+        # Rotary embeddings
+        rot_emb = self.rotary_emb(seq_len=T, device=x.device)  # (T, head_dim)
+        q = apply_rotary_emb(q, rot_emb)
+        k = apply_rotary_emb(k, rot_emb)
+
+        # Scaled dot-product attention with causal mask
+        attn_scores = torch.matmul(q, k.transpose(-2, -1)) * (self.head_dim ** -0.5)  # (B, n_head, T, T)
+        attn_scores = attn_scores.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
+        attn_probs = F.softmax(attn_scores, dim=-1)
+        attn_probs = self.dropout(attn_probs)
+
+        out = torch.matmul(attn_probs, v)  # (B, n_head, T, head_dim)
+        out = out.transpose(1, 2).contiguous().view(B, T, C)  # (B, T, n_embd)
+
+        return self.o_proj(out)
+
+class TransformerMLP(nn.Module):
+    def __init__(self, n_embd):
+        super().__init__()
+        self.net = nn.Sequential(
+            nn.Linear(n_embd, 4 * n_embd),
+            nn.GELU(),
+            nn.Linear(4 * n_embd, n_embd),
+            nn.Dropout(0.0),
+        )
+
+    def forward(self, x):
+        return self.net(x)
+
+class TransformerBlock(nn.Module):
+    def __init__(self, n_embd, n_head, block_size):
+        super().__init__()
+        self.self_attn = MultiHeadSelfAttention(n_embd, n_head, block_size)
+        self.mlp = TransformerMLP(n_embd)
+        self.input_layernorm = nn.LayerNorm(n_embd)
+        self.post_attention_layernorm = nn.LayerNorm(n_embd)
+
+    def forward(self, x):
+        # Pre-norm residual connections
+        x = x + self.self_attn(self.input_layernorm(x))
+        x = x + self.mlp(self.post_attention_layernorm(x))
+        return x
+
+class BVVBestForCausalLM(PreTrainedModel):
+    config_class = BVVBestConfig
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.token_embeddings = nn.Embedding(config.vocab_size, config.n_embd)
+
+        self.transformer_layers = nn.Sequential(*[
+            TransformerBlock(config.n_embd, n_head=config.n_head, block_size=config.block_size)
+            for _ in range(config.n_layer)
+        ])
+        self.final_layernorm = nn.LayerNorm(config.n_embd)
+        self.lm_head = nn.Linear(config.n_embd, config.vocab_size)
+
+        self.apply(self._init_weights)
+
+    def _init_weights(self, module):
+        if isinstance(module, nn.Linear):
+            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
+            if module.bias is not None:
+                torch.nn.init.zeros_(module.bias)
+        elif isinstance(module, nn.Embedding):
+            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
+
+    def forward(self, idx, targets=None):
+        B, T = idx.shape
+
+        x = self.token_embeddings(idx)
+
+        x = self.transformer_layers(x)
+        x = self.final_layernorm(x)
+        logits = self.lm_head(x)
+
+        loss = None
+        if targets is not None:
+            logits_flat = logits.reshape(-1, logits.size(-1))
+            targets_flat = targets.reshape(-1)
+            # ignore_index masks out <pad> (id 57344) positions in the loss
+            loss = F.cross_entropy(logits_flat, targets_flat, ignore_index=self.config.pad_id)
+
+        return CausalLMOutput(
+            logits=logits,
+            loss=loss,
+        )
+
+    def generate(self,
+                 input_ids=None,
+                 max_new_tokens=None,
+                 max_length=None,
+                 temperature=1.0,
+                 top_k=None,
+                 top_p=None,
+                 do_sample=True,
+                 pad_token_id=None,
+                 eos_token_id=None,
+                 **kwargs):
+
+        if input_ids is None:
+            raise ValueError("input_ids must be provided")
+
+        idx = input_ids
+
+        if max_new_tokens is None:
+            if max_length is not None:
+                max_new_tokens = max_length - idx.shape[1]
+            else:
+                max_new_tokens = 50
+
+        with torch.no_grad():
+            for _ in range(max_new_tokens):
+                # Crop the context to the last block_size tokens
+                idx_cond = idx[:, -self.config.block_size:]
+
+                outputs = self(idx_cond)
+                logits = outputs.logits[:, -1, :] / temperature
+
+                # Top-k filtering
+                if top_k is not None:
+                    v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
+                    logits[logits < v[:, [-1]]] = float('-inf')
+
+                # Nucleus (top-p) filtering
+                if top_p is not None:
+                    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+                    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
+
+                    sorted_indices_to_remove = cumulative_probs > top_p
+                    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+                    sorted_indices_to_remove[..., 0] = 0
+
+                    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
+                    logits[indices_to_remove] = float('-inf')
+
+                probs = F.softmax(logits, dim=-1)
+
+                if do_sample:
+                    idx_next = torch.multinomial(probs, num_samples=1)
+                else:
+                    idx_next = torch.argmax(logits, dim=-1, keepdim=True)
+
+                idx = torch.cat((idx, idx_next), dim=1)
+
+                if eos_token_id is not None and (idx_next == eos_token_id).any():
+                    break
+
+        return idx
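The custom generate() above accepts max_new_tokens, temperature, top_k, top_p, do_sample, and eos_token_id. A minimal generation sketch, assuming model and tokenizer were loaded as in the snippet after the config.json diff; the prompt string is purely illustrative:

prompt = "Once upon a time"  # illustrative prompt
input_ids = tokenizer(prompt, return_tensors="pt").input_ids
output_ids = model.generate(
    input_ids,
    max_new_tokens=64,
    temperature=0.8,
    top_k=50,
    top_p=0.95,
    do_sample=True,
    eos_token_id=tokenizer.eos_token_id,
)
print(tokenizer.decode(output_ids[0]))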