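"""Evaluate a trained OpenGPT model checkpoint on a validation set.

Example invocation (the script and file names below are illustrative, not
taken from the repository):

    python evaluate.py --model checkpoints/gpt.pt \
        --config configs/gpt.yaml --tokenizer tokenizer.json

Prints the validation perplexity, i.e. exp(average per-token negative
log-likelihood).
"""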
import argparse
import math

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader

from model.gpt_model import GPTModel
from data.dataset import TextDataset
from data import utils


def main():
    parser = argparse.ArgumentParser(description="Evaluate a trained OpenGPT model on a validation set.")
    parser.add_argument("--model", type=str, required=True, help="Path to the model checkpoint (.pt file).")
    parser.add_argument("--config", type=str, required=True, help="Path to the model config file (YAML/JSON).")
    parser.add_argument("--tokenizer", type=str, required=True, help="Path to the trained tokenizer (.json or directory).")
    args = parser.parse_args()
    # Load model hyperparameters and data settings from the config file
    config = utils.load_config(args.config)
    model_conf = config.get("model", {})
    data_conf = config.get("data", {})

    vocab_size = model_conf["vocab_size"]
    max_pos = model_conf.get("max_position_embeddings", 512)
    hidden_dim = model_conf.get("embedding_dim", 768)
    n_layers = model_conf.get("n_layers", 12)
    n_heads = model_conf.get("n_heads", 12)
    dropout = model_conf.get("dropout", 0.0)
    # Initialize the model, load the trained weights, and switch to eval mode
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = GPTModel(vocab_size=vocab_size, max_position_embeddings=max_pos,
                     n_layers=n_layers, n_heads=n_heads, hidden_dim=hidden_dim,
                     dropout=dropout).to(device)
    utils.load_checkpoint(model, optimizer=None, filepath=args.model, device=device)
    model.eval()
    # Prepare the validation dataset and loader (fall back to the training
    # split if no validation path is configured)
    valid_path = data_conf.get("valid_path", data_conf.get("train_path"))
    block_size = data_conf.get("block_size", 128)
    dataset = TextDataset(valid_path, args.tokenizer, block_size)
    loader = DataLoader(dataset, batch_size=1, shuffle=False)

    total_loss = 0.0
    total_tokens = 0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            outputs = model(inputs)
            # Sum the cross-entropy over every token in the batch so the final
            # average is weighted per token, not per batch
            loss = F.cross_entropy(outputs.view(-1, vocab_size), targets.view(-1),
                                   reduction="sum")
            total_loss += loss.item()
            total_tokens += targets.numel()

    # Perplexity is the exponential of the average per-token negative log-likelihood
    avg_nll = total_loss / total_tokens
    perplexity = math.exp(avg_nll)
    print(f"Validation Perplexity: {perplexity:.4f}")


if __name__ == "__main__":
    main()