File size: 2,549 Bytes
6810eb1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import math
import argparse
import torch
import torch.nn.functional as F
from model.gpt_model import GPTModel
from data.dataset import TextDataset
from data import utils

def _accumulate_nll(model, loader, device, ignore_index=-100):
    """Sum the token-level cross-entropy of ``model`` over every batch in ``loader``.

    Args:
        model: Callable that maps an input batch to a tensor of logits whose
            last dimension indexes the classes.
        loader: Iterable yielding ``(inputs, targets)`` tensor pairs.
        device: Device each batch is moved to before the forward pass.
        ignore_index: Target value that is excluded from the loss. It is also
            excluded from the token count so the two stay consistent.

    Returns:
        A ``(total_loss, total_tokens)`` tuple: the summed negative
        log-likelihood and the number of targets that contributed to it.
    """
    total_loss = 0.0
    total_tokens = 0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            logits = model(inputs)
            # reshape (not view) also works on non-contiguous tensors, and the
            # class dimension is read from the logits themselves rather than
            # trusting the configured vocabulary size.
            flat_targets = targets.reshape(-1)
            loss = F.cross_entropy(
                logits.reshape(-1, logits.size(-1)),
                flat_targets,
                ignore_index=ignore_index,
                reduction="sum",
            )
            total_loss += loss.item()
            # Count only the targets cross_entropy actually scored; counting
            # ignored positions would bias the average (and perplexity) low.
            total_tokens += int((flat_targets != ignore_index).sum().item())
    return total_loss, total_tokens


def main():
    """Evaluate a model checkpoint and print its validation perplexity.

    Command-line arguments:
        --model: path to the checkpoint file (required).
        --config: path to the config file (required).
        --tokenizer: path to the tokenizer (required).
        --batch_size: number of dataset items per forward pass (default 1).

    The validation file is ``data.valid_path`` from the config, falling back
    to ``data.train_path``.

    Raises:
        KeyError: if the config's ``model`` section has no ``vocab_size``.
        ValueError: if the config names no validation/training path, or the
            dataset yields no scored target tokens.
    """
    parser = argparse.ArgumentParser(description="Evaluate a trained OpenGPT model on a validation set.")
    parser.add_argument("--model", type=str, required=True, help="Path to the model checkpoint (.pt file).")
    parser.add_argument("--config", type=str, required=True, help="Path to the model config file (YAML/JSON).")
    parser.add_argument("--tokenizer", type=str, required=True, help="Path to the trained tokenizer (.json or directory).")
    parser.add_argument("--batch_size", type=int, default=1,
                        help="Number of dataset items per forward pass (default: 1).")
    args = parser.parse_args()

    # Load configuration for model hyperparameters
    config = utils.load_config(args.config)
    model_conf = config.get("model", {})
    data_conf = config.get("data", {})
    vocab_size = model_conf["vocab_size"]
    max_pos = model_conf.get("max_position_embeddings", 512)
    hidden_dim = model_conf.get("embedding_dim", 768)
    n_layers = model_conf.get("n_layers", 12)
    n_heads = model_conf.get("n_heads", 12)
    dropout = model_conf.get("dropout", 0.0)

    # Fail fast, before building the model, if there is nothing to evaluate on.
    valid_path = data_conf.get("valid_path", data_conf.get("train_path"))
    if valid_path is None:
        raise ValueError(
            "Config must define data.valid_path (or data.train_path) to evaluate on."
        )
    block_size = data_conf.get("block_size", 128)

    # Initialize model and load checkpoint
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = GPTModel(vocab_size=vocab_size, max_position_embeddings=max_pos,
                     n_layers=n_layers, n_heads=n_heads, hidden_dim=hidden_dim,
                     dropout=dropout).to(device)
    utils.load_checkpoint(model, optimizer=None, filepath=args.model, device=device)
    # Switch to eval mode only after loading, so the mode holds no matter what
    # the checkpoint loader does to the module.
    model.eval()

    # Prepare validation dataset and loader
    dataset = TextDataset(valid_path, args.tokenizer, block_size)
    loader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=False)

    total_loss, total_tokens = _accumulate_nll(model, loader, device)
    if total_tokens == 0:
        raise ValueError(
            f"No target tokens found in {valid_path!r}; cannot compute perplexity."
        )

    avg_nll = total_loss / total_tokens  # average negative log-likelihood
    try:
        perplexity = math.exp(avg_nll)
    except OverflowError:
        # exp() overflows a float for very large average NLL; report infinity
        # instead of crashing after the whole evaluation pass has finished.
        perplexity = math.inf
    print(f"Validation Perplexity: {perplexity:.4f}")

# Run the evaluation only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()