Bertug1911 committed on
Commit 86f7a43 · verified
1 Parent(s): 80e26ee

Upload Br-T-1-m-pre.py


Main model file (Python)

Files changed (1)
  1. Br-T-1-m-pre.py +237 -0
Br-T-1-m-pre.py ADDED
@@ -0,0 +1,237 @@
+ import torch
+ import torch.nn as nn
+ import torch.optim as optim
+ import torch.nn.functional as F
+ import math
+ from torch.utils.data import DataLoader
+ from datasets import load_dataset
+
+ ##############################
+ # 1) Dataset Loading + Processing #
+ ##############################
+
+ raw_dataset = load_dataset("wikitext", "wikitext-2-raw-v1")
+
+ # A limited number of samples is used for the demo.
+ train_data = raw_dataset["train"][:3000]
+ val_data = raw_dataset["validation"][:500]
+
+ train_texts = list(train_data["text"])
+ val_texts = list(val_data["text"])
+
+ class MyTokenizer:
+     def __init__(self, vocab_size=15000, max_length=64):
+         self.vocab_size = vocab_size
+         self.max_length = max_length
+         # Define the [PAD] and [UNK] special tokens.
+         self.PAD = "[PAD]"
+         self.UNK = "[UNK]"
+         self.pad_id = 0
+         self.unk_id = 1
+         self.word2id = {self.PAD: self.pad_id, self.UNK: self.unk_id}
+         self.id2word = {self.pad_id: self.PAD, self.unk_id: self.UNK}
+
+     def build_vocab(self, all_texts):
+         from collections import Counter
+         freq = Counter()
+         for line in all_texts:
+             tokens = line.strip().split()
+             freq.update(tokens)
+         most_common = freq.most_common(self.vocab_size - len(self.word2id))
+         idx = len(self.word2id)
+         for word, count in most_common:
+             if word not in self.word2id:
+                 self.word2id[word] = idx
+                 self.id2word[idx] = word
+                 idx += 1
+
+     def encode(self, text):
+         tokens = text.strip().split()
+         token_ids = [self.word2id.get(t, self.unk_id) for t in tokens]
+         token_ids = token_ids[:self.max_length]
+         token_ids += [self.pad_id] * (self.max_length - len(token_ids))
+         return token_ids
+
+     def decode(self, token_ids):
+         words = []
+         for tid in token_ids:
+             if tid in self.id2word and self.id2word[tid] != self.PAD:
+                 words.append(self.id2word[tid])
+         return " ".join(words)
+
+ # Create the tokenizer and build the vocabulary.
+ my_tokenizer = MyTokenizer(vocab_size=15000, max_length=64)
+ my_tokenizer.build_vocab(train_texts)
+ print(f"Vocab size: {len(my_tokenizer.word2id)}")
+
+ def tokenize_function(text):
+     return {"input_ids": my_tokenizer.encode(text)}
+
+ train_encodings = list(map(tokenize_function, train_texts))
+ val_encodings = list(map(tokenize_function, val_texts))
+
+ class WikiTextDataset:
+     def __init__(self, encodings):
+         self.encodings = encodings
+     def __len__(self):
+         return len(self.encodings)
+     def __getitem__(self, idx):
+         item = self.encodings[idx]
+         input_ids = torch.tensor(item["input_ids"], dtype=torch.long)
+         attn_mask = (input_ids != my_tokenizer.pad_id).long()
+         return {"input_ids": input_ids, "attention_mask": attn_mask}
+
+ train_dataset = WikiTextDataset(train_encodings)
+ val_dataset = WikiTextDataset(val_encodings)
+
+ train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
+ val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False)
+
+ ##############################
+ # 2) Transformer Model #
+ ##############################
+
+ def generate_square_subsequent_mask(sz):
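+     # Causal mask: upper-triangular -inf entries keep each position from attending to future tokens.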
+     mask = torch.triu(torch.ones(sz, sz), diagonal=1)
+     mask = mask.masked_fill(mask == 1, float('-inf'))
+     return mask
+
+ class PositionalEncoding(nn.Module):
+     def __init__(self, d_model, max_len=5000):
+         super().__init__()
+         pe = torch.zeros(max_len, d_model)
+         position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
+         div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))
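+         # The next two lines implement the standard sinusoidal encoding:
+         #   PE[pos, 2i]   = sin(pos / 10000^(2i / d_model))
+         #   PE[pos, 2i+1] = cos(pos / 10000^(2i / d_model))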
+         pe[:, 0::2] = torch.sin(position * div_term)
+         pe[:, 1::2] = torch.cos(position * div_term)
+         pe = pe.unsqueeze(0)  # (1, max_len, d_model)
+         self.register_buffer('pe', pe)
+
+     def forward(self, x):
+         seq_len = x.size(1)
+         return x + self.pe[:, :seq_len, :]
+
+ class TransformerLM(nn.Module):
+     def __init__(self, vocab_size, d_model=256, nhead=8, num_layers=4, dim_feedforward=1024, dropout=0.1):
+         super().__init__()
+         self.d_model = d_model
+         self.embedding = nn.Embedding(vocab_size, d_model)
+         self.pos_encoder = PositionalEncoding(d_model)
+         encoder_layer = nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout, activation='relu')
+         self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers)
+         self.fc_out = nn.Linear(d_model, vocab_size)
+
+     def forward(self, input_ids, attention_mask=None):
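+         # Note: attention_mask is accepted but not applied here; with right-padded sequences
+         # and the causal mask below, real positions never attend to padding positions anyway.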
+         embedded = self.embedding(input_ids)  # (batch, seq_len, d_model)
+         embedded = self.pos_encoder(embedded)
+         embedded = embedded.permute(1, 0, 2)  # (seq_len, batch, d_model)
+         seq_len = embedded.size(0)
+         mask = generate_square_subsequent_mask(seq_len).to(embedded.device)
+         encoded = self.transformer_encoder(embedded, mask=mask)
+         encoded = encoded.permute(1, 0, 2)  # (batch, seq_len, d_model)
+         logits = self.fc_out(encoded)
+         return logits
+
+ vocab_size = len(my_tokenizer.word2id)
+ model = TransformerLM(
+     vocab_size=vocab_size,
+     d_model=256,
+     nhead=8,
+     num_layers=4,
+     dim_feedforward=1024,
+     dropout=0.1
+ )
+
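+ # Rough size estimate (assuming the vocab caps at 15k): the embedding and output head
+ # contribute about 7.7M parameters and the 4 encoder layers about 3.2M, so roughly 11M total.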
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ model.to(device)
+
+ ##############################
+ # 3) Training (Mixed Precision) #
+ ##############################
+
+ optimizer = optim.Adam(model.parameters(), lr=1e-4)
+ # Ignore padding positions when computing the loss.
+ loss_fn = nn.CrossEntropyLoss(ignore_index=my_tokenizer.pad_id)
+
+ # A single epoch is used for the demo (real applications need far more).
+ num_epochs = 1
+ scaler = torch.cuda.amp.GradScaler() if device.type == 'cuda' else None
+
+ for epoch in range(num_epochs):
+     model.train()
+     total_loss = 0
+     for batch in train_loader:
+         input_ids = batch["input_ids"].to(device)
+         optimizer.zero_grad()
+         if scaler:
+             with torch.cuda.amp.autocast():
+                 logits = model(input_ids)
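+                 # Next-token objective: positions 0..n-2 predict tokens 1..n-1.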
+                 shift_logits = logits[:, :-1, :].contiguous()
+                 shift_labels = input_ids[:, 1:].contiguous()
+                 loss = loss_fn(shift_logits.view(-1, vocab_size), shift_labels.view(-1))
+             scaler.scale(loss).backward()
+             scaler.step(optimizer)
+             scaler.update()
+         else:
+             logits = model(input_ids)
+             shift_logits = logits[:, :-1, :].contiguous()
+             shift_labels = input_ids[:, 1:].contiguous()
+             loss = loss_fn(shift_logits.view(-1, vocab_size), shift_labels.view(-1))
+             loss.backward()
+             optimizer.step()
+         total_loss += loss.item()
+     avg_loss = total_loss / len(train_loader)
+     print(f"Epoch {epoch+1}, Loss: {avg_loss:.4f}")
+
+ ##############################
+ # 4) Question-Answer Generation #
+ ##############################
+
+ def generate_text(prompt: str, max_new_tokens=30, temperature=1.5, top_k=200):
+     model.eval()
+     # Tokenize the prompt, drop padding, and convert it to a tensor so that
+     # generation continues right after the last prompt token.
+     prompt_ids = [tid for tid in my_tokenizer.encode(prompt) if tid != my_tokenizer.pad_id]
+     input_ids = torch.tensor([prompt_ids], dtype=torch.long).to(device)
+     original_length = input_ids.shape[1]
+     with torch.no_grad():
+         for _ in range(max_new_tokens):
+             logits = model(input_ids)  # (batch, seq_len, vocab_size)
+             next_token_logits = logits[:, -1, :] / temperature
+
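+             # Top-k filtering: keep the k most likely tokens, set the rest to -inf, then sample.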
+             k = min(top_k, next_token_logits.size(-1))
+             values, indices = torch.topk(next_token_logits, k)
+             filtered_logits = torch.full_like(next_token_logits, float('-inf'))
+             filtered_logits.scatter_(1, indices, values)
+             probs = F.softmax(filtered_logits, dim=-1)
+
+             next_token = torch.multinomial(probs, num_samples=1)
+             input_ids = torch.cat([input_ids, next_token], dim=1)
+     full_output = my_tokenizer.decode(input_ids[0].cpu().tolist())
+     # Only the newly generated tokens:
+     new_tokens = my_tokenizer.decode(input_ids[0][original_length:].cpu().tolist())
+     return full_output, new_tokens
+
+ question_prompt = "What is the capital of France?"
+ full_generated_text, new_generated_text = generate_text(question_prompt, max_new_tokens=20, temperature=1.5, top_k=200)
+
+ print("\nPrompt:", question_prompt)
+ print("Full Output (Prompt + Generated):", full_generated_text)
+ print("Generated Tokens Only:", new_generated_text)
+
+ ##############################
+ # 5) Evaluation #
+ ##############################
+
+ model.eval()
+ total_val_loss = 0
+ with torch.no_grad():
+     for batch in val_loader:
+         input_ids = batch["input_ids"].to(device)
+         logits = model(input_ids)
+         shift_logits = logits[:, :-1, :].contiguous()
+         shift_labels = input_ids[:, 1:].contiguous()
+         loss = loss_fn(shift_logits.view(-1, vocab_size), shift_labels.view(-1))
+         total_val_loss += loss.item()
+ avg_val_loss = total_val_loss / len(val_loader)
+ print(f"Validation Loss: {avg_val_loss:.4f}")
+
+ torch.save(model, "brt-1 mmlu.pth")
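
Loading note: torch.save(model, ...) above stores the full model object, so the class definitions from this file must be available when the checkpoint is loaded back. Below is a minimal loading sketch, not part of the commit; reusing my_tokenizer and the model classes from this file is an assumption, as is the unchanged checkpoint name.

import torch

# Hypothetical loading sketch (assumes TransformerLM, PositionalEncoding and
# my_tokenizer from Br-T-1-m-pre.py are already defined or imported).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Newer PyTorch releases may require torch.load(..., weights_only=False) for full-model checkpoints.
model = torch.load("brt-1 mmlu.pth", map_location=device)
model.eval()

# Greedy next-token check on a padding-stripped prompt.
prompt_ids = [tid for tid in my_tokenizer.encode("What is the capital of France?")
              if tid != my_tokenizer.pad_id]
input_ids = torch.tensor([prompt_ids], dtype=torch.long, device=device)
with torch.no_grad():
    logits = model(input_ids)
next_id = int(logits[0, -1].argmax())
print(my_tokenizer.id2word.get(next_id, "[UNK]"))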