"""Compare a Hugging Face PyTorch BERT model against its MLX port on a batch of text."""
import argparse
import time
from typing import List

import mlx.core as mx
import numpy as np
import torch
from transformers import AutoModel, AutoTokenizer

import model
def run_torch(bert_model: str, batch: List[str]):
    """Run a BERT-like model on *batch* with the PyTorch backend.

    Args:
        bert_model: Model identifier accepted by ``from_pretrained``
            (e.g. ``"bert-base-uncased"``).
        batch: Sentences encoded together; the tokenizer pads them to a
            common length.

    Returns:
        Tuple of numpy arrays ``(last_hidden_state, pooler_output)``,
        shaped ``(batch, seq_len, hidden)`` and ``(batch, hidden)``.
    """
    print(f"\n[PyTorch] Loading model and tokenizer: {bert_model}")
    start_time = time.time()
    tokenizer = AutoTokenizer.from_pretrained(bert_model)
    torch_model = AutoModel.from_pretrained(bert_model)
    # Explicit eval() guarantees deterministic inference (dropout off) even
    # if a checkpoint ever ships in training mode.
    torch_model.eval()
    load_time = time.time() - start_time
    print(f"[PyTorch] Model loaded in {load_time:.2f} seconds")
    print(f"[PyTorch] Tokenizing batch of {len(batch)} sentences")
    torch_tokens = tokenizer(batch, return_tensors="pt", padding=True)
    print(f"[PyTorch] Running model inference")
    inference_start = time.time()
    # no_grad(): we only need forward values, so skip autograd graph
    # construction — less memory, faster, and the timing below reflects
    # pure inference cost.
    with torch.no_grad():
        torch_forward = torch_model(**torch_tokens)
    inference_time = time.time() - inference_start
    print(f"[PyTorch] Inference completed in {inference_time:.4f} seconds")
    torch_output = torch_forward.last_hidden_state.detach().numpy()
    torch_pooled = torch_forward.pooler_output.detach().numpy()
    print(f"[PyTorch] Output shape: {torch_output.shape}")
    print(f"[PyTorch] Pooled output shape: {torch_pooled.shape}")
    # Print a small sample of the output to verify sensible values
    print(f"[PyTorch] Sample of output (first token, first 5 values): {torch_output[0, 0, :5]}")
    print(f"[PyTorch] Sample of pooled output (first 5 values): {torch_pooled[0, :5]}")
    return torch_output, torch_pooled
def run_mlx(bert_model: str, mlx_model: str, batch: List[str]):
    """Run the MLX BERT implementation on *batch*.

    Args:
        bert_model: Hugging Face identifier used for the tokenizer/config.
        mlx_model: Path to the converted MLX weights (``.npz`` file).
        batch: Sentences to encode together.

    Returns:
        Tuple of numpy arrays ``(last_hidden_state, pooler_output)``.
    """
    print(f"\n[MLX] Loading model and tokenizer with weights from: {mlx_model}")
    started = time.time()
    raw_output, raw_pooled = model.run(bert_model, mlx_model, batch)
    elapsed = time.time() - started
    print(f"[MLX] Model loaded and run in {elapsed:.2f} seconds")

    # np.array() materializes the MLX arrays on the host as numpy arrays so
    # they can be compared element-wise with the PyTorch outputs.
    output_np = np.array(raw_output)
    pooled_np = np.array(raw_pooled)

    print(f"[MLX] Output shape: {output_np.shape}")
    print(f"[MLX] Pooled output shape: {pooled_np.shape}")
    # Print a small sample of the output to verify sensible values
    print(f"[MLX] Sample of output (first token, first 5 values): {output_np[0, 0, :5]}")
    print(f"[MLX] Sample of pooled output (first 5 values): {pooled_np[0, :5]}")
    return output_np, pooled_np
def compare_outputs(torch_output, torch_pooled, mlx_output, mlx_pooled):
    """Report shape and numeric agreement between the two backends.

    Prints shape matches, max/mean absolute differences, and a per-index
    breakdown of the first five hidden values of the first token.

    Returns:
        True when both the full hidden states and the pooled outputs agree
        within ``rtol=1e-4, atol=1e-4``.
    """
    print("\n[Comparison] Comparing PyTorch and MLX outputs")
    print(f"[Comparison] Shape match - Output: {torch_output.shape == mlx_output.shape}")
    print(f"[Comparison] Shape match - Pooled: {torch_pooled.shape == mlx_pooled.shape}")

    # Element-wise absolute differences, computed once and reused for both
    # the max and mean summaries.
    output_abs = np.abs(torch_output - mlx_output)
    pooled_abs = np.abs(torch_pooled - mlx_pooled)
    print(f"[Comparison] Output - Max absolute difference: {output_abs.max():.6f}")
    print(f"[Comparison] Output - Mean absolute difference: {output_abs.mean():.6f}")
    print(f"[Comparison] Pooled - Max absolute difference: {pooled_abs.max():.6f}")
    print(f"[Comparison] Pooled - Mean absolute difference: {pooled_abs.mean():.6f}")

    print("\n[Comparison] Detailed comparison of first 5 values from first output token:")
    for i in range(5):
        t_val, m_val = torch_output[0, 0, i], mlx_output[0, 0, i]
        print(f"Index {i}: PyTorch={t_val:.6f}, MLX={m_val:.6f}, Diff={abs(t_val - m_val):.6f}")

    tolerance = dict(rtol=1e-4, atol=1e-4)
    outputs_close = np.allclose(torch_output, mlx_output, **tolerance)
    pooled_close = np.allclose(torch_pooled, mlx_pooled, **tolerance)
    print(f"\n[Comparison] Outputs match within tolerance: {outputs_close}")
    print(f"[Comparison] Pooled outputs match within tolerance: {pooled_close}")
    return outputs_close and pooled_close
def parse_cli_args(argv=None) -> argparse.Namespace:
    """Parse command-line options for the comparison script.

    Args:
        argv: Optional argument list; defaults to ``sys.argv[1:]``
            (standard ``argparse`` behavior). Exposed for testing.

    Returns:
        Parsed ``argparse.Namespace`` with ``bert_model``, ``mlx_model``,
        ``text`` and ``verbose`` attributes.
    """
    parser = argparse.ArgumentParser(
        description="Run a BERT-like model for a batch of text and compare PyTorch and MLX outputs."
    )
    parser.add_argument(
        "--bert-model",
        type=str,
        default="bert-base-uncased",
        help="The model identifier for a BERT-like model from Hugging Face Transformers.",
    )
    parser.add_argument(
        "--mlx-model",
        type=str,
        default="weights/bert-base-uncased.npz",
        help="The path of the stored MLX BERT weights (npz file).",
    )
    parser.add_argument(
        "--text",
        nargs="+",
        default=["This is an example of BERT working in MLX."],
        help="A batch of texts to process. Multiple texts should be separated by spaces.",
    )
    # NOTE(review): --verbose is accepted but currently unused anywhere in
    # this script; kept for CLI backward compatibility.
    parser.add_argument(
        "--verbose",
        action="store_true",
        help="Print detailed information about the model execution.",
    )
    return parser.parse_args(argv)


def main() -> None:
    """Run both backends on the requested batch and report agreement."""
    args = parse_cli_args()
    print(f"Testing BERT model: {args.bert_model}")
    print(f"MLX weights: {args.mlx_model}")
    print(f"Input text: {args.text}")
    # Run both implementations
    torch_output, torch_pooled = run_torch(args.bert_model, args.text)
    mlx_output, mlx_pooled = run_mlx(args.bert_model, args.mlx_model, args.text)
    # Compare outputs
    all_match = compare_outputs(torch_output, torch_pooled, mlx_output, mlx_pooled)
    if all_match:
        print("\n✅ TEST PASSED: PyTorch and MLX implementations produce equivalent results!")
    else:
        print("\n❌ TEST FAILED: PyTorch and MLX implementations produce different results.")


if __name__ == "__main__":
    main()
|