File size: 5,889 Bytes
e17d2b6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
import argparse
import time
from typing import List

import model
import numpy as np
import mlx.core as mx
from transformers import AutoModel, AutoTokenizer


def run_torch(bert_model: str, batch: List[str]):
    """Run the Hugging Face (PyTorch) model on a batch of text.

    Loads the tokenizer and model named by ``bert_model``, tokenizes
    ``batch`` with padding, runs a single forward pass and prints timing
    and shape information along the way.

    Args:
        bert_model: Model identifier passed to ``from_pretrained``.
        batch: Sentences to tokenize and feed to the model together.

    Returns:
        A ``(last_hidden_state, pooler_output)`` tuple, both converted to
        NumPy arrays.
    """
    # Imported locally so this function brings its own dependency into
    # scope; the tokenizer is already asked for PyTorch tensors
    # (return_tensors="pt"), so torch has to be installed anyway.
    import torch

    print(f"\n[PyTorch] Loading model and tokenizer: {bert_model}")
    start_time = time.time()
    tokenizer = AutoTokenizer.from_pretrained(bert_model)
    torch_model = AutoModel.from_pretrained(bert_model)
    load_time = time.time() - start_time
    print(f"[PyTorch] Model loaded in {load_time:.2f} seconds")

    print(f"[PyTorch] Tokenizing batch of {len(batch)} sentences")
    torch_tokens = tokenizer(batch, return_tensors="pt", padding=True)

    print("[PyTorch] Running model inference")
    inference_start = time.time()
    # Inference only: disable autograd so no gradient graph is recorded.
    # This saves memory and keeps the measured time to the forward pass.
    with torch.no_grad():
        torch_forward = torch_model(**torch_tokens)
    inference_time = time.time() - inference_start
    print(f"[PyTorch] Inference completed in {inference_time:.4f} seconds")

    torch_output = torch_forward.last_hidden_state.detach().numpy()
    torch_pooled = torch_forward.pooler_output.detach().numpy()

    print(f"[PyTorch] Output shape: {torch_output.shape}")
    print(f"[PyTorch] Pooled output shape: {torch_pooled.shape}")

    # Print a small sample of the output to verify sensible values
    print(f"[PyTorch] Sample of output (first token, first 5 values): {torch_output[0, 0, :5]}")
    print(f"[PyTorch] Sample of pooled output (first 5 values): {torch_pooled[0, :5]}")

    return torch_output, torch_pooled


def run_mlx(bert_model: str, mlx_model: str, batch: List[str]):
    """Run the MLX implementation on a batch of text.

    Delegates loading and inference to ``model.run`` and times that single
    call, then converts both results to NumPy arrays and prints their
    shapes and a small sample of values.

    Args:
        bert_model: Model identifier forwarded to ``model.run``.
        mlx_model: Path of the MLX weights, forwarded to ``model.run``.
        batch: Sentences forwarded to ``model.run``.

    Returns:
        A ``(output, pooled)`` tuple of NumPy arrays built from the two
        values returned by ``model.run``.
    """
    print(f"\n[MLX] Loading model and tokenizer with weights from: {mlx_model}")
    began = time.time()
    raw_hidden, raw_pooled = model.run(bert_model, mlx_model, batch)
    elapsed = time.time() - began
    print(f"[MLX] Model loaded and run in {elapsed:.2f} seconds")

    # Materialize both results as NumPy arrays so they can be compared
    # against the PyTorch outputs.
    hidden_np, pooled_np = (np.array(raw) for raw in (raw_hidden, raw_pooled))

    print(f"[MLX] Output shape: {hidden_np.shape}")
    print(f"[MLX] Pooled output shape: {pooled_np.shape}")

    # Show a handful of values as a quick sanity check.
    hidden_sample = hidden_np[0, 0, :5]
    pooled_sample = pooled_np[0, :5]
    print(f"[MLX] Sample of output (first token, first 5 values): {hidden_sample}")
    print(f"[MLX] Sample of pooled output (first 5 values): {pooled_sample}")

    return hidden_np, pooled_np


def compare_outputs(torch_output, torch_pooled, mlx_output, mlx_pooled):
    """Compare PyTorch and MLX outputs and print a difference report.

    Prints whether the shapes agree, the max/mean absolute differences, a
    value-by-value comparison of the leading values of the first output
    token, and the result of an ``np.allclose`` check with
    ``rtol=1e-4`` and ``atol=1e-4``.

    Args:
        torch_output: NumPy array of PyTorch hidden states. It is indexed
            as ``[0, 0, i]``, so it must have at least three dimensions.
        torch_pooled: NumPy array of the PyTorch pooled output.
        mlx_output: NumPy array of the MLX hidden states.
        mlx_pooled: NumPy array of the MLX pooled output.

    Returns:
        True when both pairs have identical shapes and are element-wise
        close within the tolerance; False otherwise (including any shape
        mismatch).
    """
    print("\n[Comparison] Comparing PyTorch and MLX outputs")

    # Check shapes
    output_shapes_match = torch_output.shape == mlx_output.shape
    pooled_shapes_match = torch_pooled.shape == mlx_pooled.shape
    print(f"[Comparison] Shape match - Output: {output_shapes_match}")
    print(f"[Comparison] Shape match - Pooled: {pooled_shapes_match}")

    # Element-wise arithmetic on mismatched shapes either raises or, worse,
    # broadcasts silently and can report a bogus "match". Treat any shape
    # mismatch as a definite failure instead.
    if not (output_shapes_match and pooled_shapes_match):
        print(
            "[Comparison] Shapes differ "
            f"(output: {torch_output.shape} vs {mlx_output.shape}, "
            f"pooled: {torch_pooled.shape} vs {mlx_pooled.shape}); "
            "skipping element-wise comparison"
        )
        return False

    # Calculate differences (each absolute-difference array computed once)
    output_abs_diff = np.abs(torch_output - mlx_output)
    pooled_abs_diff = np.abs(torch_pooled - mlx_pooled)
    output_max_diff = np.max(output_abs_diff)
    output_mean_diff = np.mean(output_abs_diff)
    pooled_max_diff = np.max(pooled_abs_diff)
    pooled_mean_diff = np.mean(pooled_abs_diff)

    print(f"[Comparison] Output - Max absolute difference: {output_max_diff:.6f}")
    print(f"[Comparison] Output - Mean absolute difference: {output_mean_diff:.6f}")
    print(f"[Comparison] Pooled - Max absolute difference: {pooled_max_diff:.6f}")
    print(f"[Comparison] Pooled - Mean absolute difference: {pooled_mean_diff:.6f}")

    # Detailed comparison of first few values from first sentence.
    # Clamp to the last dimension so narrow outputs do not raise IndexError.
    sample_count = min(5, torch_output.shape[-1])
    print(
        f"\n[Comparison] Detailed comparison of first {sample_count} values "
        "from first output token:"
    )
    for i in range(sample_count):
        torch_val = torch_output[0, 0, i]
        mlx_val = mlx_output[0, 0, i]
        diff = abs(torch_val - mlx_val)
        print(f"Index {i}: PyTorch={torch_val:.6f}, MLX={mlx_val:.6f}, Diff={diff:.6f}")

    # Check if outputs are close
    outputs_close = np.allclose(torch_output, mlx_output, rtol=1e-4, atol=1e-4)
    pooled_close = np.allclose(torch_pooled, mlx_pooled, rtol=1e-4, atol=1e-4)

    print(f"\n[Comparison] Outputs match within tolerance: {outputs_close}")
    print(f"[Comparison] Pooled outputs match within tolerance: {pooled_close}")

    return bool(outputs_close and pooled_close)


if __name__ == "__main__":
    # Command-line entry point: run the same batch through the PyTorch and
    # MLX implementations, compare the results, and report pass/fail both
    # on stdout and through the process exit status.
    parser = argparse.ArgumentParser(
        description="Run a BERT-like model for a batch of text and compare PyTorch and MLX outputs."
    )
    parser.add_argument(
        "--bert-model",
        type=str,
        default="bert-base-uncased",
        help="The model identifier for a BERT-like model from Hugging Face Transformers.",
    )
    parser.add_argument(
        "--mlx-model",
        type=str,
        default="weights/bert-base-uncased.npz",
        help="The path of the stored MLX BERT weights (npz file).",
    )
    parser.add_argument(
        "--text",
        nargs="+",
        default=["This is an example of BERT working in MLX."],
        help="A batch of texts to process. Multiple texts should be separated by spaces.",
    )
    # NOTE: --verbose is accepted so existing invocations keep working, but
    # nothing below reads args.verbose yet.
    parser.add_argument(
        "--verbose",
        action="store_true",
        help="Print detailed information about the model execution.",
    )

    args = parser.parse_args()

    print(f"Testing BERT model: {args.bert_model}")
    print(f"MLX weights: {args.mlx_model}")
    print(f"Input text: {args.text}")

    # Run both implementations
    torch_output, torch_pooled = run_torch(args.bert_model, args.text)
    mlx_output, mlx_pooled = run_mlx(args.bert_model, args.mlx_model, args.text)

    # Compare outputs
    all_match = compare_outputs(torch_output, torch_pooled, mlx_output, mlx_pooled)

    if all_match:
        print("\n✅ TEST PASSED: PyTorch and MLX implementations produce equivalent results!")
    else:
        print("\n❌ TEST FAILED: PyTorch and MLX implementations produce different results.")
        # Exit non-zero so shells and CI can detect the failure; previously
        # the script always exited with status 0.
        raise SystemExit(1)