Bochkov committed
Commit 9ab5614 · verified · 1 Parent(s): 9a26907

Upload 3 files

Files changed (3)
  1. config.json +14 -3
  2. modeling_bvv_pro.py +156 -0
  3. tokenizer_config.json +5 -2
config.json CHANGED
@@ -1,9 +1,20 @@
 {
-  "model_type": "PreTrainedTokenizerFast",
-  "model_type": "gpt2",
+  "architectures": ["BVVProForCausalLM"],
+  "auto_map": {
+    "AutoConfig": "modeling_bvv_pro.BVVProConfig",
+    "AutoModel": "modeling_bvv_pro.BVVProForCausalLM",
+    "AutoModelForCausalLM": "modeling_bvv_pro.BVVProForCausalLM"
+  },
+  "model_type": "bvv_pro",
+  "vocab_size": 65536,
+  "block_size": 1024,
+  "n_embd": 1024,
+  "n_layer": 8,
+  "n_head": 8,
+  "pad_id": 57344,
   "bos_token": "<s>",
   "eos_token": "</s>",
   "unk_token": "<unk>",
   "pad_token": "<pad>",
-  "vocab_size": 65536
+  "torch_dtype": "float32"
  }
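
With the "auto_map" entries above, transformers can resolve the custom classes straight from the repository. A minimal loading sketch, assuming the files live in a Hub repo; the repo id "Bochkov/bvv_pro" is a placeholder, not something recorded in this commit:

    from transformers import AutoConfig, AutoModelForCausalLM

    repo_id = "Bochkov/bvv_pro"  # placeholder repo id

    # trust_remote_code=True lets transformers import modeling_bvv_pro.py
    # from the repo, as declared in "auto_map".
    config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)
    print(type(model).__name__)  # BVVProForCausalLM
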
modeling_bvv_pro.py ADDED
@@ -0,0 +1,156 @@
+from transformers import PreTrainedModel, PretrainedConfig
+from transformers.modeling_outputs import CausalLMOutput
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+
+
+class BVVProConfig(PretrainedConfig):
+    model_type = "bvv_pro"
+
+    def __init__(
+        self,
+        vocab_size=65536,
+        n_embd=1024,
+        n_head=8,
+        n_layer=8,
+        block_size=1024,
+        pad_id=57344,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.vocab_size = vocab_size
+        self.block_size = block_size
+        self.n_embd = n_embd
+        self.n_layer = n_layer
+        self.n_head = n_head
+        self.pad_id = pad_id
+
+
+class SimpleSelfAttentionHead(nn.Module):
+    """A single causal self-attention head."""
+
+    def __init__(self, head_size, n_embd, block_size):
+        super().__init__()
+        self.q_proj = nn.Linear(n_embd, head_size, bias=False)
+        self.k_proj = nn.Linear(n_embd, head_size, bias=False)
+        self.v_proj = nn.Linear(n_embd, head_size, bias=False)
+        self.o_proj = nn.Linear(head_size, head_size, bias=False)
+
+        self.dropout = nn.Dropout(0.0)
+        # Lower-triangular mask enforcing causality up to block_size tokens.
+        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
+
+    def forward(self, x):
+        B, T, C = x.shape
+        q = self.q_proj(x)  # (B, T, head_size)
+        k = self.k_proj(x)
+        v = self.v_proj(x)
+
+        attn_scores = q @ k.transpose(-2, -1) * (k.shape[-1] ** -0.5)  # (B, T, T)
+        # Mask future positions; the dtype's minimum is used instead of -inf
+        # so the mask stays finite under reduced precision.
+        attn_scores = attn_scores.masked_fill(self.tril[:T, :T] == 0, torch.finfo(attn_scores.dtype).min)
+
+        attn_probs = F.softmax(attn_scores, dim=-1)
+        attn_probs = self.dropout(attn_probs)
+
+        out = attn_probs @ v    # (B, T, head_size)
+        out = self.o_proj(out)  # (B, T, head_size)
+        return out
+
+
+class SimpleMultiHeadSelfAttention(nn.Module):
+    def __init__(self, n_embd, n_head, block_size):
+        super().__init__()
+        self.head_size = n_embd // n_head
+        self.heads = nn.ModuleList(
+            [SimpleSelfAttentionHead(self.head_size, n_embd, block_size) for _ in range(n_head)]
+        )
+        self.out_proj = nn.Linear(n_head * self.head_size, n_embd)
+        self.dropout = nn.Dropout(0.0)
+
+    def forward(self, x):
+        out = torch.cat([head(x) for head in self.heads], dim=-1)
+        out = self.dropout(self.out_proj(out))
+        return out
+
+
+class TransformerMLP(nn.Module):
+    def __init__(self, n_embd):
+        super().__init__()
+        self.net = nn.Sequential(
+            nn.Linear(n_embd, 4 * n_embd),
+            nn.GELU(),
+            nn.Linear(4 * n_embd, n_embd),
+            nn.Dropout(0.0),
+        )
+
+    def forward(self, x):
+        return self.net(x)
+
+
+class TransformerBlock(nn.Module):
+    """Pre-norm block: attention and MLP, each wrapped in a residual connection."""
+
+    def __init__(self, n_embd, n_head, block_size):
+        super().__init__()
+        self.self_attn = SimpleMultiHeadSelfAttention(n_embd, n_head, block_size)
+        self.mlp = TransformerMLP(n_embd)
+        self.input_layernorm = nn.LayerNorm(n_embd)
+        self.post_attention_layernorm = nn.LayerNorm(n_embd)
+
+    def forward(self, x):
+        x = x + self.self_attn(self.input_layernorm(x))
+        x = x + self.mlp(self.post_attention_layernorm(x))
+        return x
+
+
+class BVVProForCausalLM(PreTrainedModel):
+    config_class = BVVProConfig
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.token_embeddings = nn.Embedding(config.vocab_size, config.n_embd)
+        self.position_embeddings = nn.Embedding(config.block_size, config.n_embd)
+        self.transformer_layers = nn.Sequential(*[
+            TransformerBlock(config.n_embd, n_head=config.n_head, block_size=config.block_size)
+            for _ in range(config.n_layer)
+        ])
+        self.final_layernorm = nn.LayerNorm(config.n_embd)
+        self.lm_head = nn.Linear(config.n_embd, config.vocab_size)
+
+        self.apply(self._init_weights)
+
+    def _init_weights(self, module):
+        if isinstance(module, nn.Linear):
+            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
+            if module.bias is not None:
+                torch.nn.init.zeros_(module.bias)
+        elif isinstance(module, nn.Embedding):
+            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
+
+    def forward(self, idx, targets=None):
+        B, T = idx.shape
+
+        # Learned absolute position embeddings added to the token embeddings.
+        positions = torch.arange(0, T, device=idx.device).unsqueeze(0).expand(B, T)
+        x = self.token_embeddings(idx) + self.position_embeddings(positions)
+
+        x = self.transformer_layers(x)
+        x = self.final_layernorm(x)
+        logits = self.lm_head(x)
+
+        loss = None
+        if targets is not None:
+            logits_flat = logits.reshape(-1, logits.size(-1))
+            targets_flat = targets.reshape(-1)
+            # Padding positions (pad_id, 57344 by default) are excluded from the loss.
+            loss = F.cross_entropy(logits_flat, targets_flat, ignore_index=self.config.pad_id)
+
+        return CausalLMOutput(
+            logits=logits,
+            loss=loss,
+        )
+
+    def generate(self, idx, max_new_tokens):
+        with torch.no_grad():
+            for _ in range(max_new_tokens):
+                # Crop the running context to the last block_size tokens.
+                idx_cond = idx[:, -self.config.block_size:]
+                outputs = self(idx_cond)
+                logits = outputs.logits[:, -1, :]
+                probs = F.softmax(logits, dim=-1)
+                idx_next = torch.multinomial(probs, num_samples=1)
+                idx = torch.cat((idx, idx_next), dim=1)
+
+        return idx
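
For a quick sanity check of the new module, a minimal sketch that instantiates the model directly (randomly initialized, no checkpoint) and samples a few tokens with the generate method defined above; the dummy prompt ids are purely illustrative:

    import torch
    from modeling_bvv_pro import BVVProConfig, BVVProForCausalLM  # assumes the file is importable locally

    config = BVVProConfig()  # defaults match config.json: 8 layers, 8 heads, n_embd 1024
    model = BVVProForCausalLM(config).eval()

    prompt = torch.tensor([[1, 2, 3, 4]], dtype=torch.long)  # dummy token ids for illustration
    out = model.generate(prompt, max_new_tokens=8)
    print(out.shape)  # torch.Size([1, 12]): the 4 prompt tokens plus 8 sampled tokens
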
tokenizer_config.json CHANGED
@@ -1,6 +1,9 @@
 {
+  "tokenizer_class": "PreTrainedTokenizerFast",
+  "model_type": "gpt2",
+  "bos_token": "<s>",
+  "eos_token": "</s>",
   "unk_token": "<unk>",
   "pad_token": "<pad>",
-  "bos_token": "<s>",
-  "eos_token": "</s>"
+  "vocab_size": 65536
  }
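
Since tokenizer_config.json now names the tokenizer class explicitly, the tokenizer can be loaded through the standard API as well. A sketch under the same assumptions as above (placeholder repo id, and a tokenizer.json with the 65536-entry vocabulary already present in the repo, which this commit does not include):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("Bochkov/bvv_pro")  # placeholder repo id
    ids = tokenizer("hello world", return_tensors="pt").input_ids
    print(ids.shape)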