codewithdark committed on
Commit aba9604 · verified
1 Parent(s): e13544e

Update modeling_arabic-gpt.py

Files changed (1)
  1. modeling_arabic-gpt.py +194 -49
modeling_arabic-gpt.py CHANGED
@@ -10,8 +10,6 @@ from tqdm import tqdm
 from transformers import PreTrainedModel
 from transformers import PretrainedConfig
 
-from transformers import PretrainedConfig
-
 class ArabicGPTConfig(PretrainedConfig):
     model_type = "arabic-gpt"
 
@@ -35,9 +33,6 @@ class ArabicGPTConfig(PretrainedConfig):
         self.tie_word_embeddings = True
 
 
-import torch
-import torch.nn as nn
-from transformers import PreTrainedModel
 
 class ArabicGPTModel(PreTrainedModel):
     config_class = ArabicGPTConfig
@@ -72,59 +67,209 @@ class ArabicGPTModel(PreTrainedModel):
     def tie_weights(self):
         self.model.lm_head.weight = self.model.token_embedding.weight
 
-class ArabicGPTConfig(PretrainedConfig):
-    model_type = "arabic-gpt"
 
-    def __init__(self,
-                 vocab_size=32000,
-                 max_seq_len=1024,
-                 embed_dim=768,
-                 num_heads=12,
-                 num_layers=12,
-                 ff_dim=3072,
-                 dropout=0.1,
-                 **kwargs):
-        super().__init__(**kwargs)
-        self.vocab_size = vocab_size
-        self.max_seq_len = max_seq_len
-        self.embed_dim = embed_dim
-        self.num_heads = num_heads
-        self.num_layers = num_layers
-        self.ff_dim = ff_dim
-        self.dropout = dropout
-        self.tie_word_embeddings = True
+# Part 2: GPT Model Implementation
+class AttentionHead(nn.Module):
+    def __init__(self, embed_dim, head_dim, mask=True):
+        super().__init__()
+        self.q = nn.Linear(embed_dim, head_dim)
+        self.k = nn.Linear(embed_dim, head_dim)
+        self.v = nn.Linear(embed_dim, head_dim)
+        self.mask = mask
+        self.scale = head_dim ** -0.5
 
+    def forward(self, x):
+        # x shape: (batch, seq_len, embed_dim)
+        batch_size, seq_len, _ = x.shape
 
-class ArabicGPTModel(PreTrainedModel):
-    config_class = ArabicGPTConfig
+        # Linear projections
+        q = self.q(x)  # (batch, seq_len, head_dim)
+        k = self.k(x)  # (batch, seq_len, head_dim)
+        v = self.v(x)  # (batch, seq_len, head_dim)
 
-    def __init__(self, config: ArabicGPTConfig):
-        super().__init__(config)
-        self.model = ArabicGPT(
-            vocab_size=config.vocab_size,
-            max_seq_len=config.max_seq_len,
-            embed_dim=config.embed_dim,
-            num_heads=config.num_heads,
-            num_layers=config.num_layers,
-            ff_dim=config.ff_dim,
-            dropout=config.dropout,
+        # Compute attention scores
+        attn = torch.bmm(q, k.transpose(1, 2)) * self.scale  # (batch, seq_len, seq_len)
+
+        # Apply causal mask for decoder
+        if self.mask:
+            mask = torch.triu(torch.ones(seq_len, seq_len, device=x.device), diagonal=1).bool()
+            attn.masked_fill_(mask, float('-inf'))
+
+        # Apply softmax and get weighted values
+        attn = F.softmax(attn, dim=-1)
+        output = torch.bmm(attn, v)  # (batch, seq_len, head_dim)
+
+        return output
+
+class MultiHeadAttention(nn.Module):
+    def __init__(self, embed_dim, num_heads, mask=True):
+        super().__init__()
+        self.heads = nn.ModuleList([
+            AttentionHead(embed_dim, embed_dim // num_heads, mask)
+            for _ in range(num_heads)
+        ])
+        self.linear = nn.Linear(embed_dim, embed_dim)
+
+    def forward(self, x):
+        # Concatenate outputs from all heads
+        heads_output = torch.cat([head(x) for head in self.heads], dim=-1)
+        # Final linear projection
+        output = self.linear(heads_output)
+        return output
+
+class FeedForward(nn.Module):
+    def __init__(self, embed_dim, ff_dim):
+        super().__init__()
+        self.net = nn.Sequential(
+            nn.Linear(embed_dim, ff_dim),
+            nn.GELU(),
+            nn.Linear(ff_dim, embed_dim)
         )
 
     def forward(self, x):
-        return self.model(x)
+        return self.net(x)
+
+class TransformerBlock(nn.Module):
+    def __init__(self, embed_dim, num_heads, ff_dim, dropout=0.1):
+        super().__init__()
+        self.attn = MultiHeadAttention(embed_dim, num_heads)
+        self.ff = FeedForward(embed_dim, ff_dim)
+        self.norm1 = nn.LayerNorm(embed_dim)
+        self.norm2 = nn.LayerNorm(embed_dim)
+        self.dropout = nn.Dropout(dropout)
+
+    def forward(self, x):
+        # Self-attention with residual connection and layer norm
+        attn_output = self.attn(self.norm1(x))
+        x = x + self.dropout(attn_output)
+
+        # Feed-forward with residual connection and layer norm
+        ff_output = self.ff(self.norm2(x))
+        x = x + self.dropout(ff_output)
+
+        return x
+
+class ArabicGPT(nn.Module):
+    def __init__(self, vocab_size, max_seq_len=1024, embed_dim=768, num_heads=12,
+                 num_layers=12, ff_dim=3072, dropout=0.1):
+        super().__init__()
+        self.max_seq_len = max_seq_len
+        self.token_embedding = nn.Embedding(vocab_size, embed_dim)
+        self.position_embedding = nn.Embedding(max_seq_len, embed_dim)
+
+        # Transformer blocks
+        self.blocks = nn.ModuleList([
+            TransformerBlock(embed_dim, num_heads, ff_dim, dropout)
+            for _ in range(num_layers)
+        ])
+
+        # Final layer norm
+        self.norm = nn.LayerNorm(embed_dim)
+
+        # Language model head
+        self.lm_head = nn.Linear(embed_dim, vocab_size, bias=False)
+
+        # Share weights between token embedding and LM head
+        # self.lm_head.weight = self.token_embedding.weight
+
+        # Initialize weights
+        self.apply(self._init_weights)
+
+    def _init_weights(self, module):
+        if isinstance(module, nn.Linear):
+            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
+            if module.bias is not None:
+                torch.nn.init.zeros_(module.bias)
+        elif isinstance(module, nn.Embedding):
+            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
+        elif isinstance(module, nn.LayerNorm):
+            torch.nn.init.zeros_(module.bias)
+            torch.nn.init.ones_(module.weight)
+
+    def forward(self, x):
+        # x shape: (batch, seq_len)
+        batch_size, seq_len = x.shape
+
+        # Get positions
+        positions = torch.arange(0, seq_len, device=x.device).unsqueeze(0).expand(batch_size, -1)
+
+        # Get token and position embeddings
+        token_embed = self.token_embedding(x)
+        pos_embed = self.position_embedding(positions)
+
+        # Combine embeddings
+        x = token_embed + pos_embed
+
+        # Apply transformer blocks
+        for block in self.blocks:
+            x = block(x)
+
+        # Apply final layer norm
+        x = self.norm(x)
+
+        # Get logits
+        logits = self.lm_head(x)
+
+        return logits
 
     def generate(self, prompt_ids, max_new_tokens, temperature=1.0, top_k=50, top_p=0.9):
-        return self.model.generate(prompt_ids, max_new_tokens, temperature=1.0, top_k=50, top_p=0.9)
+        """Generate text using the model."""
+        self.eval()
+        with torch.no_grad():
+            # Convert prompt to tensor if needed
+            if not isinstance(prompt_ids, torch.Tensor):
+                prompt_ids = torch.tensor(prompt_ids, dtype=torch.long)
 
-    def get_input_embeddings(self):
-        return self.model.token_embedding
+            # Move to device and add batch dimension if needed
+            if len(prompt_ids.shape) == 1:
+                prompt_ids = prompt_ids.unsqueeze(0)
+            prompt_ids = prompt_ids.to(next(self.parameters()).device)
 
-    def set_input_embeddings(self, new_embeddings):
-        self.model.token_embedding = new_embeddings
-
-    def get_output_embeddings(self):
-        return self.model.lm_head
-
-    def tie_weights(self):
-        self.model.lm_head.weight = self.model.token_embedding.weight
+            # Start with prompt
+            generated_ids = prompt_ids.clone()
+
+            # Generate new tokens
+            for _ in range(max_new_tokens):
+                # Take last context up to max sequence length
+                input_ids = generated_ids[:, -self.max_seq_len:]
+
+                # Get logits for next token
+                logits = self(input_ids)
+                next_token_logits = logits[:, -1, :]
+
+                # Apply temperature
+                if temperature > 0:
+                    next_token_logits = next_token_logits / temperature
+
+                # Apply top-k filtering
+                if top_k > 0:
+                    indices_to_remove = next_token_logits < torch.topk(next_token_logits, top_k)[0][..., -1, None]
+                    next_token_logits[indices_to_remove] = float('-inf')
+
+                # Apply top-p (nucleus) filtering
+                if top_p < 1.0:
+                    sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
+                    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
+
+                    # Remove tokens with cumulative probability above the threshold
+                    sorted_indices_to_remove = cumulative_probs > top_p
+                    # Shift the indices to the right to keep the first token above threshold
+                    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+                    sorted_indices_to_remove[..., 0] = 0
+
+                    indices_to_remove = sorted_indices[sorted_indices_to_remove]
+                    next_token_logits[:, indices_to_remove] = float('-inf')
+
+                # Sample next token
+                probs = F.softmax(next_token_logits, dim=-1)
+                next_token = torch.multinomial(probs, num_samples=1)
+
+                # Append next token to generated
+                generated_ids = torch.cat([generated_ids, next_token], dim=1)
+
+                # Stop if EOS token
+                if next_token.item() == 2:  # Standard EOS token id
+                    break
+
+            return generated_ids
 
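Usage note (not part of the diff): a minimal smoke-test sketch of the classes in this file. It assumes the file's top-level imports (lines 1-9, not shown in any hunk) include torch, torch.nn as nn, and torch.nn.functional as F, and that the module has been made importable, for example under a hypothetical rename to modeling_arabic_gpt.py, since the hyphen in modeling_arabic-gpt.py blocks a plain import. The tiny config values below are only for a quick check, not the committed defaults.

import torch

# Hypothetical import path: assumes the module was made importable, e.g. by
# renaming modeling_arabic-gpt.py to modeling_arabic_gpt.py.
from modeling_arabic_gpt import ArabicGPTConfig, ArabicGPTModel

# Tiny config for a quick check; the committed defaults are vocab_size=32000,
# embed_dim=768, num_heads=12, num_layers=12, ff_dim=3072, max_seq_len=1024.
config = ArabicGPTConfig(vocab_size=1000, max_seq_len=128, embed_dim=64,
                         num_heads=4, num_layers=2, ff_dim=256)
model = ArabicGPTModel(config)

# Dummy prompt ids; a real prompt would come from a tokenizer, which is not
# part of this file.
prompt_ids = torch.randint(0, config.vocab_size, (1, 8))

logits = model(prompt_ids)                                  # (1, 8, vocab_size) raw logits
generated = model.generate(prompt_ids, max_new_tokens=16)   # delegates to ArabicGPT.generate
print(logits.shape, generated.shape)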
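Because ArabicGPTConfig and ArabicGPTModel subclass PretrainedConfig and PreTrainedModel, they can round-trip through the standard Hugging Face save/load API. The sketch below is not part of this commit: it reuses the import assumption above, the checkpoint directory name is made up, and the AutoConfig.register / AutoModel.register calls assume a recent transformers release.

from transformers import AutoConfig, AutoModel

# Same import caveat as above: ArabicGPTConfig / ArabicGPTModel assumed in scope.

# Register the custom classes locally so the Auto* factories can resolve the
# "arabic-gpt" model_type recorded in a saved config.json.
AutoConfig.register("arabic-gpt", ArabicGPTConfig)
AutoModel.register(ArabicGPTConfig, ArabicGPTModel)

config = ArabicGPTConfig(vocab_size=1000, embed_dim=64, num_heads=4,
                         num_layers=2, ff_dim=256)
model = ArabicGPTModel(config)

# Hypothetical local path; save_pretrained writes config.json plus the weights.
model.save_pretrained("./arabic-gpt-smoke-test")
reloaded = AutoModel.from_pretrained("./arabic-gpt-smoke-test")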