import torch
import torch.nn as nn

from model.layers import TransformerBlock


class GPTModel(nn.Module):
    """
    GPT-style language model (decoder-only Transformer) with learned token
    and positional embeddings.
    """

    def __init__(self, vocab_size: int, max_position_embeddings: int, n_layers: int,
                 n_heads: int, hidden_dim: int, dropout: float = 0.1):
        super().__init__()

        # Token embeddings plus learned positional embeddings; dropout is applied
        # to their sum in forward().
        self.tok_embedding = nn.Embedding(vocab_size, hidden_dim)
        self.pos_embedding = nn.Embedding(max_position_embeddings, hidden_dim)
        self.dropout = nn.Dropout(dropout)

        # Stack of decoder-only Transformer blocks.
        self.layers = nn.ModuleList([
            TransformerBlock(hidden_dim, n_heads, dropout) for _ in range(n_layers)
        ])

        # Final layer norm and projection from hidden states to vocabulary logits.
        self.ln_f = nn.LayerNorm(hidden_dim)
        self.output_proj = nn.Linear(hidden_dim, vocab_size, bias=False)

    def forward(self, x):
        """
        x: Tensor of token IDs with shape (batch_size, seq_length).
        Returns: Logits of shape (batch_size, seq_length, vocab_size).
        """
        batch_size, seq_length = x.shape

        # Positions 0..seq_length-1, shaped (1, seq_length) so the positional
        # embeddings broadcast across the batch dimension.
        tok_emb = self.tok_embedding(x)
        positions = torch.arange(0, seq_length, device=x.device).unsqueeze(0)
        pos_emb = self.pos_embedding(positions)
        h = self.dropout(tok_emb + pos_emb)

        # Pass the hidden states through each Transformer block in order.
        for layer in self.layers:
            h = layer(h)

        # Final layer norm, then project to per-token vocabulary logits.
        h = self.ln_f(h)
        logits = self.output_proj(h)
        return logits
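

# Minimal usage sketch with illustrative hyperparameters: builds a tiny model and
# runs a dummy forward pass to check output shapes. It assumes
# model.layers.TransformerBlock accepts (hidden_dim, n_heads, dropout) as
# constructed in __init__ above; none of the values below are prescribed here.
if __name__ == "__main__":
    vocab_size = 50257  # illustrative; match your tokenizer's vocabulary size
    model = GPTModel(
        vocab_size=vocab_size,
        max_position_embeddings=128,
        n_layers=2,
        n_heads=4,
        hidden_dim=64,
        dropout=0.1,
    )
    dummy_ids = torch.randint(0, vocab_size, (2, 16))  # (batch_size=2, seq_length=16)
    logits = model(dummy_ids)
    print(logits.shape)  # expected: torch.Size([2, 16, 50257])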