# model/gpt_model.py
import torch
import torch.nn as nn
from model.layers import TransformerBlock


class GPTModel(nn.Module):
"""
GPT-style Language Model (decoder-only Transformer).
"""
def __init__(self, vocab_size: int, max_position_embeddings: int, n_layers: int,
n_heads: int, hidden_dim: int, dropout: float = 0.1):
super().__init__()
# Embedding layers
self.tok_embedding = nn.Embedding(vocab_size, hidden_dim)
self.pos_embedding = nn.Embedding(max_position_embeddings, hidden_dim)
self.dropout = nn.Dropout(dropout)
# Transformer blocks
self.layers = nn.ModuleList([
TransformerBlock(hidden_dim, n_heads, dropout) for _ in range(n_layers)
])
# Final layer normalization
self.ln_f = nn.LayerNorm(hidden_dim)
# Output projection to vocabulary size
        self.output_proj = nn.Linear(hidden_dim, vocab_size, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
x: Tensor of token IDs with shape (batch_size, seq_length).
Returns: Logits of shape (batch_size, seq_length, vocab_size).
"""
batch_size, seq_length = x.shape
# Token and positional embeddings
tok_emb = self.tok_embedding(x) # (batch, seq_len, hidden_dim)
positions = torch.arange(0, seq_length, device=x.device).unsqueeze(0)
pos_emb = self.pos_embedding(positions) # (1, seq_len, hidden_dim)
h = self.dropout(tok_emb + pos_emb) # (batch, seq_len, hidden_dim)
# Transformer decoder blocks
for layer in self.layers:
h = layer(h)
# Final layer norm
h = self.ln_f(h)
# Compute logits
logits = self.output_proj(h) # (batch, seq_len, vocab_size)
return logits
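

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only): the hyperparameters and tensor
# sizes below are assumptions for a quick smoke test, not values shipped with
# this repository. It assumes model.layers.TransformerBlock accepts
# (hidden_dim, n_heads, dropout), as used in __init__ above.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    model = GPTModel(
        vocab_size=50257,             # assumed GPT-2-style vocabulary size
        max_position_embeddings=128,  # assumed context window for the test
        n_layers=2,
        n_heads=4,
        hidden_dim=64,
    )
    dummy_ids = torch.randint(0, 50257, (2, 16))  # (batch_size=2, seq_length=16)
    with torch.no_grad():
        logits = model(dummy_ids)
    print(logits.shape)  # expected: torch.Size([2, 16, 50257])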