Upload Br-T-1-m-pre.py
Main model file (Python)
- Br-T-1-m-pre.py +237 -0
Br-T-1-m-pre.py
ADDED
@@ -0,0 +1,237 @@
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import math
from torch.utils.data import DataLoader, Dataset
from datasets import load_dataset

##############################
# 1) Dataset Loading + Prep  #
##############################

raw_dataset = load_dataset("wikitext", "wikitext-2-raw-v1")

# Only a limited number of examples are used for the demo.
train_data = raw_dataset["train"][:3000]
val_data = raw_dataset["validation"][:500]

# Skip blank lines: with this whitespace tokenizer they encode to all-padding
# sequences and carry no training signal.
train_texts = [text for text in train_data["text"] if text.strip()]
val_texts = [text for text in val_data["text"] if text.strip()]

class MyTokenizer:
    """Minimal whitespace tokenizer with a fixed-size vocabulary."""

    def __init__(self, vocab_size=15000, max_length=64):
        self.vocab_size = vocab_size
        self.max_length = max_length
        # Define the [PAD] and [UNK] special tokens.
        self.PAD = "[PAD]"
        self.UNK = "[UNK]"
        self.pad_id = 0
        self.unk_id = 1
        self.word2id = {self.PAD: self.pad_id, self.UNK: self.unk_id}
        self.id2word = {self.pad_id: self.PAD, self.unk_id: self.UNK}

    def build_vocab(self, all_texts):
        from collections import Counter
        freq = Counter()
        for line in all_texts:
            tokens = line.strip().split()
            freq.update(tokens)
        # Keep the most frequent words, leaving room for the special tokens.
        most_common = freq.most_common(self.vocab_size - len(self.word2id))
        idx = len(self.word2id)
        for word, _ in most_common:
            if word not in self.word2id:
                self.word2id[word] = idx
                self.id2word[idx] = word
                idx += 1

    def encode(self, text):
        # Map words to ids, truncate to max_length, and right-pad with [PAD].
        tokens = text.strip().split()
        token_ids = [self.word2id.get(t, self.unk_id) for t in tokens]
        token_ids = token_ids[:self.max_length]
        token_ids += [self.pad_id] * (self.max_length - len(token_ids))
        return token_ids

    def decode(self, token_ids):
        words = []
        for tid in token_ids:
            if tid in self.id2word and self.id2word[tid] != self.PAD:
                words.append(self.id2word[tid])
        return " ".join(words)

# Create the tokenizer and build the vocabulary from the training split.
my_tokenizer = MyTokenizer(vocab_size=15000, max_length=64)
my_tokenizer.build_vocab(train_texts)
print(f"Vocab size: {len(my_tokenizer.word2id)}")
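
# Optional sanity check of the tokenizer: encode() pads/truncates to max_length,
# decode() drops the padding, and out-of-vocabulary words come back as [UNK].
print("Round trip:", my_tokenizer.decode(my_tokenizer.encode("the history of the game")))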

def tokenize_function(text):
    return {"input_ids": my_tokenizer.encode(text)}

train_encodings = list(map(tokenize_function, train_texts))
val_encodings = list(map(tokenize_function, val_texts))

class WikiTextDataset(Dataset):
    def __init__(self, encodings):
        self.encodings = encodings

    def __len__(self):
        return len(self.encodings)

    def __getitem__(self, idx):
        item = self.encodings[idx]
        input_ids = torch.tensor(item["input_ids"], dtype=torch.long)
        # 1 for real tokens, 0 for padding (not consumed by the model below,
        # but kept so each batch carries an explicit padding mask).
        attn_mask = (input_ids != my_tokenizer.pad_id).long()
        return {"input_ids": input_ids, "attention_mask": attn_mask}

train_dataset = WikiTextDataset(train_encodings)
val_dataset = WikiTextDataset(val_encodings)

train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False)

##############################
# 2) Transformer Model #
##############################

def generate_square_subsequent_mask(sz):
    # Upper-triangular -inf mask that blocks attention to future positions.
    mask = torch.triu(torch.ones(sz, sz), diagonal=1)
    mask = mask.masked_fill(mask == 1, float('-inf'))
    return mask
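
# For example, generate_square_subsequent_mask(3) is
#   [[0., -inf, -inf],
#    [0.,   0., -inf],
#    [0.,   0.,   0.]]
# so position i can only attend to positions <= i (causal language modelling).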

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        # Sinusoidal encoding as in "Attention Is All You Need":
        #   PE(pos, 2i)   = sin(pos / 10000^(2i / d_model))
        #   PE(pos, 2i+1) = cos(pos / 10000^(2i / d_model))
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)  # (1, max_len, d_model)
        self.register_buffer('pe', pe)

    def forward(self, x):
        seq_len = x.size(1)
        return x + self.pe[:, :seq_len, :]

class TransformerLM(nn.Module):
    def __init__(self, vocab_size, d_model=256, nhead=8, num_layers=4, dim_feedforward=1024, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model)
        encoder_layer = nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout, activation='relu')
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers)
        self.fc_out = nn.Linear(d_model, vocab_size)

    def forward(self, input_ids, attention_mask=None):
        embedded = self.embedding(input_ids)  # (batch, seq_len, d_model)
        embedded = self.pos_encoder(embedded)
        embedded = embedded.permute(1, 0, 2)  # (seq_len, batch, d_model)
        seq_len = embedded.size(0)
        # Causal mask so each position only attends to earlier positions.
        mask = generate_square_subsequent_mask(seq_len).to(embedded.device)
        encoded = self.transformer_encoder(embedded, mask=mask)
        encoded = encoded.permute(1, 0, 2)  # (batch, seq_len, d_model)
        logits = self.fc_out(encoded)
        return logits

vocab_size = len(my_tokenizer.word2id)
model = TransformerLM(
    vocab_size=vocab_size,
    d_model=256,
    nhead=8,
    num_layers=4,
    dim_feedforward=1024,
    dropout=0.1
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
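
# Optional shape check: a dummy batch of token ids should produce per-token
# logits of shape (batch, seq_len, vocab_size).
with torch.no_grad():
    _dummy = torch.zeros(2, 8, dtype=torch.long, device=device)
    print("Logits shape:", model(_dummy).shape)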

##############################
# 3) Training (Mixed Precision) #
##############################

optimizer = optim.Adam(model.parameters(), lr=1e-4)
# Ignore padding positions when computing the cross-entropy loss.
loss_fn = nn.CrossEntropyLoss(ignore_index=my_tokenizer.pad_id)

# Only a single epoch is run for the demo (real training needs far more).
num_epochs = 1
scaler = torch.cuda.amp.GradScaler() if device.type == 'cuda' else None

for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    for batch in train_loader:
        input_ids = batch["input_ids"].to(device)
        optimizer.zero_grad()
        if scaler:
            # Mixed-precision path: autocast the forward pass and scale the loss.
            with torch.cuda.amp.autocast():
                logits = model(input_ids)
                # Next-token prediction: shift logits and labels by one position.
                shift_logits = logits[:, :-1, :].contiguous()
                shift_labels = input_ids[:, 1:].contiguous()
                loss = loss_fn(shift_logits.view(-1, vocab_size), shift_labels.view(-1))
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            logits = model(input_ids)
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = input_ids[:, 1:].contiguous()
            loss = loss_fn(shift_logits.view(-1, vocab_size), shift_labels.view(-1))
            loss.backward()
            optimizer.step()
        total_loss += loss.item()
    avg_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch+1}, Loss: {avg_loss:.4f}")

##############################
# 4) Question-Answer Generation #
##############################

def generate_text(prompt: str, max_new_tokens=30, temperature=1.5, top_k=200):
    model.eval()
    # Tokenize the prompt and convert it to a tensor, dropping the padding that
    # encode() appends so new tokens are generated right after the prompt.
    prompt_ids = [t for t in my_tokenizer.encode(prompt) if t != my_tokenizer.pad_id]
    input_ids = torch.tensor([prompt_ids], dtype=torch.long).to(device)
    original_length = input_ids.shape[1]
    with torch.no_grad():
        for _ in range(max_new_tokens):
            logits = model(input_ids)  # (batch, seq_len, vocab_size)
            next_token_logits = logits[:, -1, :] / temperature

            # Top-k sampling: keep the k most likely tokens, set the rest to -inf.
            k = min(top_k, next_token_logits.size(-1))
            values, indices = torch.topk(next_token_logits, k)
            filtered_logits = torch.full_like(next_token_logits, float('-inf'))
            filtered_logits.scatter_(1, indices, values)
            probs = F.softmax(filtered_logits, dim=-1)

            next_token = torch.multinomial(probs, num_samples=1)
            input_ids = torch.cat([input_ids, next_token], dim=1)
    full_output = my_tokenizer.decode(input_ids[0].tolist())
    # Only the newly generated tokens:
    new_tokens = my_tokenizer.decode(input_ids[0][original_length:].tolist())
    return full_output, new_tokens

question_prompt = "What is the capital of France?"
full_generated_text, new_generated_text = generate_text(question_prompt, max_new_tokens=20, temperature=1.5, top_k=200)

print("\nPrompt:", question_prompt)
print("Full Output (Prompt + Generated):", full_generated_text)
print("Generated Tokens Only:", new_generated_text)

##############################
# 5) Evaluation #
##############################

model.eval()
total_val_loss = 0
with torch.no_grad():
    for batch in val_loader:
        input_ids = batch["input_ids"].to(device)
        logits = model(input_ids)
        shift_logits = logits[:, :-1, :].contiguous()
        shift_labels = input_ids[:, 1:].contiguous()
        loss = loss_fn(shift_logits.view(-1, vocab_size), shift_labels.view(-1))
        total_val_loss += loss.item()
avg_val_loss = total_val_loss / len(val_loader)
print(f"Validation Loss: {avg_val_loss:.4f}")
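# Perplexity (the exponential of the mean cross-entropy) is the usual way to
# report language-model quality on held-out text.
print(f"Validation Perplexity: {math.exp(avg_val_loss):.2f}")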

torch.save(model, "brt-1 mmlu.pth")
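
# Note: torch.save(model, ...) pickles the entire nn.Module, so torch.load()
# will need this file's class definitions to be importable at load time;
# saving model.state_dict() instead is the more portable convention.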