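"""Compare a Hugging Face PyTorch BERT model against its MLX port.

Runs the same batch of sentences through both implementations and checks
that the final hidden states and pooled outputs agree within a small
numerical tolerance.

Example invocation (the script name is illustrative; the defaults match
the argparse definitions below):

    python test_bert.py --bert-model bert-base-uncased \
        --mlx-model weights/bert-base-uncased.npz \
        --text "First sentence." "Second sentence."
"""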
import argparse
import time
from typing import List

import numpy as np
from transformers import AutoModel, AutoTokenizer

import model


def run_torch(bert_model: str, batch: List[str]):
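    """Run the reference Hugging Face PyTorch model on `batch`.

    Returns the final hidden states and the pooled ([CLS]-based) output
    as numpy arrays.
    """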
    print(f"\n[PyTorch] Loading model and tokenizer: {bert_model}")
    start_time = time.time()
    tokenizer = AutoTokenizer.from_pretrained(bert_model)
    torch_model = AutoModel.from_pretrained(bert_model)
    load_time = time.time() - start_time
    print(f"[PyTorch] Model loaded in {load_time:.2f} seconds")
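
    # padding=True pads each sentence to the longest one in the batch and
    # returns an attention_mask so the model ignores the padded positions.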
    print(f"[PyTorch] Tokenizing batch of {len(batch)} sentences")
    torch_tokens = tokenizer(batch, return_tensors="pt", padding=True)

    print("[PyTorch] Running model inference")
    inference_start = time.time()
    torch_forward = torch_model(**torch_tokens)
    inference_time = time.time() - inference_start
    print(f"[PyTorch] Inference completed in {inference_time:.4f} seconds")
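
    # last_hidden_state holds per-token embeddings (batch, seq_len, hidden);
    # pooler_output is the [CLS] embedding passed through a dense + tanh layer.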
    torch_output = torch_forward.last_hidden_state.detach().numpy()
    torch_pooled = torch_forward.pooler_output.detach().numpy()
    print(f"[PyTorch] Output shape: {torch_output.shape}")
    print(f"[PyTorch] Pooled output shape: {torch_pooled.shape}")

    # Print a small sample of the output to verify sensible values
    print(f"[PyTorch] Sample of output (first token, first 5 values): {torch_output[0, 0, :5]}")
    print(f"[PyTorch] Sample of pooled output (first 5 values): {torch_pooled[0, :5]}")
    return torch_output, torch_pooled

def run_mlx(bert_model: str, mlx_model: str, batch: List[str]):
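    """Load the MLX port and run it on `batch` via the local `model` module.

    Returns the final hidden states and pooled output as numpy arrays.
    """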
    print(f"\n[MLX] Loading model and tokenizer with weights from: {mlx_model}")
    start_time = time.time()
    mlx_output, mlx_pooled = model.run(bert_model, mlx_model, batch)
    load_and_run_time = time.time() - start_time
    print(f"[MLX] Model loaded and run in {load_and_run_time:.2f} seconds")

    # Convert the MLX arrays to numpy for comparison; this also forces
    # evaluation of MLX's lazy computation graph.
    mlx_output_np = np.array(mlx_output)
    mlx_pooled_np = np.array(mlx_pooled)
    print(f"[MLX] Output shape: {mlx_output_np.shape}")
    print(f"[MLX] Pooled output shape: {mlx_pooled_np.shape}")

    # Print a small sample of the output to verify sensible values
    print(f"[MLX] Sample of output (first token, first 5 values): {mlx_output_np[0, 0, :5]}")
    print(f"[MLX] Sample of pooled output (first 5 values): {mlx_pooled_np[0, :5]}")
    return mlx_output_np, mlx_pooled_np

def compare_outputs(torch_output, torch_pooled, mlx_output, mlx_pooled):
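    """Compare the two implementations' outputs and report the differences.

    Returns True when both the hidden states and the pooled outputs agree
    within the chosen numerical tolerance.
    """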
    print("\n[Comparison] Comparing PyTorch and MLX outputs")

    # Check shapes
    print(f"[Comparison] Shape match - Output: {torch_output.shape == mlx_output.shape}")
    print(f"[Comparison] Shape match - Pooled: {torch_pooled.shape == mlx_pooled.shape}")
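
    # The difference calculations below assume the shapes actually match;
    # numpy would raise (or broadcast unexpectedly) on a mismatch.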
    # Calculate differences
    output_max_diff = np.max(np.abs(torch_output - mlx_output))
    output_mean_diff = np.mean(np.abs(torch_output - mlx_output))
    pooled_max_diff = np.max(np.abs(torch_pooled - mlx_pooled))
    pooled_mean_diff = np.mean(np.abs(torch_pooled - mlx_pooled))
    print(f"[Comparison] Output - Max absolute difference: {output_max_diff:.6f}")
    print(f"[Comparison] Output - Mean absolute difference: {output_mean_diff:.6f}")
    print(f"[Comparison] Pooled - Max absolute difference: {pooled_max_diff:.6f}")
    print(f"[Comparison] Pooled - Mean absolute difference: {pooled_mean_diff:.6f}")

    # Detailed comparison of first few values from first sentence
    print("\n[Comparison] Detailed comparison of first 5 values from first output token:")
    for i in range(5):
        torch_val = torch_output[0, 0, i]
        mlx_val = mlx_output[0, 0, i]
        diff = abs(torch_val - mlx_val)
        print(f"Index {i}: PyTorch={torch_val:.6f}, MLX={mlx_val:.6f}, Diff={diff:.6f}")

    # Check if outputs are close
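    # rtol/atol of 1e-4 leave room for ordinary float32 round-off between
    # the two backends while still catching genuine porting mistakes.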
    outputs_close = np.allclose(torch_output, mlx_output, rtol=1e-4, atol=1e-4)
    pooled_close = np.allclose(torch_pooled, mlx_pooled, rtol=1e-4, atol=1e-4)
    print(f"\n[Comparison] Outputs match within tolerance: {outputs_close}")
    print(f"[Comparison] Pooled outputs match within tolerance: {pooled_close}")
    return outputs_close and pooled_close

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Run a BERT-like model on a batch of text and compare PyTorch and MLX outputs."
    )
    parser.add_argument(
        "--bert-model",
        type=str,
        default="bert-base-uncased",
        help="The model identifier for a BERT-like model from Hugging Face Transformers.",
    )
    parser.add_argument(
        "--mlx-model",
        type=str,
        default="weights/bert-base-uncased.npz",
        help="The path to the stored MLX BERT weights (npz file).",
    )
    parser.add_argument(
        "--text",
        nargs="+",
        default=["This is an example of BERT working in MLX."],
        help="A batch of texts to process; pass each text as a separate (quoted) argument.",
    )
    parser.add_argument(
        "--verbose",
        action="store_true",
        help="Print detailed information about the model execution (currently unused).",
    )
    args = parser.parse_args()
    print(f"Testing BERT model: {args.bert_model}")
    print(f"MLX weights: {args.mlx_model}")
    print(f"Input text: {args.text}")

    # Run both implementations
    torch_output, torch_pooled = run_torch(args.bert_model, args.text)
    mlx_output, mlx_pooled = run_mlx(args.bert_model, args.mlx_model, args.text)

    # Compare outputs
    all_match = compare_outputs(torch_output, torch_pooled, mlx_output, mlx_pooled)

    if all_match:
        print("\n✅ TEST PASSED: PyTorch and MLX implementations produce equivalent results!")
    else:
        print("\n❌ TEST FAILED: PyTorch and MLX implementations produce different results.")
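        # Exit nonzero so callers and CI can detect the failure programmatically.
        raise SystemExit(1)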