Upload 3 files
- convert.py +66 -0
- model.py +203 -0
- test.py +143 -0
convert.py
ADDED
@@ -0,0 +1,66 @@
import argparse
import os

import numpy
from transformers import AutoModel, AutoConfig


def replace_key(key: str) -> str:
    key = key.replace(".layer.", ".layers.")
    key = key.replace(".self.key.", ".key_proj.")
    key = key.replace(".self.query.", ".query_proj.")
    key = key.replace(".self.value.", ".value_proj.")
    key = key.replace(".attention.output.dense.", ".attention.out_proj.")
    key = key.replace(".attention.output.LayerNorm.", ".ln1.")
    key = key.replace(".output.LayerNorm.", ".ln2.")
    key = key.replace(".intermediate.dense.", ".linear1.")
    key = key.replace(".output.dense.", ".linear2.")
    key = key.replace(".LayerNorm.", ".norm.")
    key = key.replace("pooler.dense.", "pooler.")
    return key
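
# Illustrative examples of the mapping (keys are made up for
# demonstration, not read from a real checkpoint):
#   replace_key("encoder.layer.0.attention.self.query.weight")
#     -> "encoder.layers.0.attention.query_proj.weight"
#   replace_key("encoder.layer.0.output.LayerNorm.bias")
#     -> "encoder.layers.0.ln2.bias"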


def convert(bert_model: str, mlx_model: str) -> None:
    # Load the model and its configuration
    model = AutoModel.from_pretrained(bert_model)
    config = AutoConfig.from_pretrained(bert_model)

    # Create the output directory if it doesn't exist
    output_dir = os.path.dirname(mlx_model)
    if output_dir and not os.path.exists(output_dir):
        os.makedirs(output_dir)

    # Save the config alongside the weights
    config_path = os.path.join(output_dir, "config.json")
    with open(config_path, "w") as f:
        f.write(config.to_json_string())

    print(f"Saved model config to {config_path}")

    # Rename the keys and save the tensors
    tensors = {
        replace_key(key): tensor.numpy() for key, tensor in model.state_dict().items()
    }
    numpy.savez(mlx_model, **tensors)
    print(f"Saved model weights to {mlx_model}")
    print(f"Model vocab size: {config.vocab_size}, hidden size: {config.hidden_size}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Convert BERT weights to MLX.")
    parser.add_argument(
        "--bert-model",
        type=str,
        default="bert-base-uncased",
        help="The Hugging Face name of the BERT model to save. Any BERT-like model can be specified.",
    )
    parser.add_argument(
        "--mlx-model",
        type=str,
        default="weights/bert-base-uncased.npz",
        help="The output path for the MLX BERT weights.",
    )
    args = parser.parse_args()

    convert(args.bert_model, args.mlx_model)
model.py
ADDED
@@ -0,0 +1,203 @@
import argparse
import json
import time
from pathlib import Path
from typing import List, Optional, Tuple

import mlx.core as mx
import mlx.nn as nn
from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizerBase


class TransformerEncoderLayer(nn.Module):
    """
    A transformer encoder layer with (the original BERT) post-normalization.
    """

    def __init__(
        self,
        dims: int,
        num_heads: int,
        mlp_dims: Optional[int] = None,
        layer_norm_eps: float = 1e-12,
    ):
        super().__init__()
        mlp_dims = mlp_dims or dims * 4
        self.attention = nn.MultiHeadAttention(dims, num_heads, bias=True)
        self.ln1 = nn.LayerNorm(dims, eps=layer_norm_eps)
        self.ln2 = nn.LayerNorm(dims, eps=layer_norm_eps)
        self.linear1 = nn.Linear(dims, mlp_dims)
        self.linear2 = nn.Linear(mlp_dims, dims)
        self.gelu = nn.GELU()

    def __call__(self, x, mask):
        attention_out = self.attention(x, x, x, mask)
        add_and_norm = self.ln1(x + attention_out)

        ff = self.linear1(add_and_norm)
        ff_gelu = self.gelu(ff)
        ff_out = self.linear2(ff_gelu)
        x = self.ln2(ff_out + add_and_norm)

        return x


class TransformerEncoder(nn.Module):
    def __init__(
        self, num_layers: int, dims: int, num_heads: int, mlp_dims: Optional[int] = None
    ):
        super().__init__()
        self.layers = [
            TransformerEncoderLayer(dims, num_heads, mlp_dims)
            for _ in range(num_layers)
        ]

    def __call__(self, x, mask):
        for layer in self.layers:
            x = layer(x, mask)

        return x


class BertEmbeddings(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
        self.token_type_embeddings = nn.Embedding(
            config.type_vocab_size, config.hidden_size
        )
        self.position_embeddings = nn.Embedding(
            config.max_position_embeddings, config.hidden_size
        )
        self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def __call__(
        self, input_ids: mx.array, token_type_ids: Optional[mx.array] = None
    ) -> mx.array:
        words = self.word_embeddings(input_ids)
        position = self.position_embeddings(
            mx.broadcast_to(mx.arange(input_ids.shape[1]), input_ids.shape)
        )

        if token_type_ids is None:
            # If token_type_ids is not provided, default to zeros
            token_type_ids = mx.zeros_like(input_ids)

        token_types = self.token_type_embeddings(token_type_ids)

        embeddings = position + words + token_types
        return self.norm(embeddings)


class Bert(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.embeddings = BertEmbeddings(config)
        self.encoder = TransformerEncoder(
            num_layers=config.num_hidden_layers,
            dims=config.hidden_size,
            num_heads=config.num_attention_heads,
            mlp_dims=config.intermediate_size,
        )
        self.pooler = nn.Linear(config.hidden_size, config.hidden_size)

    def __call__(
        self,
        input_ids: mx.array,
        token_type_ids: Optional[mx.array] = None,
        attention_mask: Optional[mx.array] = None,
    ) -> Tuple[mx.array, mx.array]:
        x = self.embeddings(input_ids, token_type_ids)

        if attention_mask is not None:
            # Convert 0's to -infs, 1's to 0's, and make it broadcastable
            attention_mask = mx.log(attention_mask)
            attention_mask = mx.expand_dims(attention_mask, (1, 2))

        y = self.encoder(x, attention_mask)
        return y, mx.tanh(self.pooler(y[:, 0]))
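
# Note on the mask trick above: log(1) = 0 leaves attended positions
# unchanged, while log(0) = -inf drives padded positions to zero weight
# after the attention softmax, e.g.:
#   mx.log(mx.array([1.0, 1.0, 0.0]))  ->  array([0, 0, -inf])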


def load_model(
    bert_model: str, weights_path: str
) -> Tuple[Bert, PreTrainedTokenizerBase]:
    if not Path(weights_path).exists():
        raise ValueError(f"No model weights found in {weights_path}")

    # First check if there's a local config saved by convert.py
    config_path = Path(weights_path).parent / "config.json"
    if config_path.exists():
        with open(config_path, "r") as f:
            config_dict = json.load(f)
        config = AutoConfig.for_model(**config_dict)
        print(f"Loaded local config from {config_path}")
    else:
        # If there is no local config, fall back to the Hugging Face one
        config = AutoConfig.from_pretrained(bert_model)
        print(f"Loaded config from Hugging Face for {bert_model}")

    # Create the model and load the converted weights
    print(
        f"Creating model with vocab_size={config.vocab_size}, "
        f"hidden_size={config.hidden_size}"
    )
    model = Bert(config)
    model.load_weights(weights_path)

    tokenizer = AutoTokenizer.from_pretrained(bert_model)

    return model, tokenizer
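
# Minimal usage sketch (assumes convert.py has already written the
# default weights file; shapes shown are for bert-base models):
#   model, tokenizer = load_model(
#       "bert-base-uncased", "weights/bert-base-uncased.npz"
#   )
#   tokens = tokenizer(["hello"], return_tensors="np", padding=True)
#   hidden, pooled = model(**{k: mx.array(v) for k, v in tokens.items()})
#   # hidden: (1, seq_len, 768), pooled: (1, 768)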


def run(bert_model: str, mlx_model: str, batch: List[str]):
    # Time model loading
    load_start = time.time()
    model, tokenizer = load_model(bert_model, mlx_model)
    load_time = time.time() - load_start
    print(f"[MLX] Model loaded in {load_time:.2f} seconds")

    # Time tokenization
    print(f"[MLX] Tokenizing batch of {len(batch)} sentences")
    token_start = time.time()
    tokens = tokenizer(batch, return_tensors="np", padding=True)
    token_time = time.time() - token_start
    print(f"[MLX] Tokenization completed in {token_time:.4f} seconds")

    print(f"[MLX] Tokens shape: input_ids={tokens['input_ids'].shape}")
    tokens = {key: mx.array(v) for key, v in tokens.items()}

    # Time inference
    print("[MLX] Running model inference")
    infer_start = time.time()
    output, pooled = model(**tokens)
    mx.eval(output, pooled)  # Force evaluation of MLX's lazy arrays
    infer_time = time.time() - infer_start
    print(f"[MLX] Inference completed in {infer_time:.4f} seconds")

    return output, pooled


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run the BERT model using MLX.")
    parser.add_argument(
        "--bert-model",
        type=str,
        default="bert-base-uncased",
        help="The Hugging Face name of the BERT model to run.",
    )
    parser.add_argument(
        "--mlx-model",
        type=str,
        default="weights/bert-base-uncased.npz",
        help="The path of the stored MLX BERT weights (npz file).",
    )
    parser.add_argument(
        "--text",
        type=str,
        default="This is an example of BERT working in MLX",
        help="The text to generate embeddings for.",
    )
    args = parser.parse_args()

    # run() expects a batch, so wrap the single --text argument in a list
    run(args.bert_model, args.mlx_model, [args.text])
test.py
ADDED
@@ -0,0 +1,143 @@
import argparse
import time
from typing import List

import numpy as np
from transformers import AutoModel, AutoTokenizer

import model


def run_torch(bert_model: str, batch: List[str]):
    print(f"\n[PyTorch] Loading model and tokenizer: {bert_model}")
    start_time = time.time()
    tokenizer = AutoTokenizer.from_pretrained(bert_model)
    torch_model = AutoModel.from_pretrained(bert_model)
    load_time = time.time() - start_time
    print(f"[PyTorch] Model loaded in {load_time:.2f} seconds")

    print(f"[PyTorch] Tokenizing batch of {len(batch)} sentences")
    torch_tokens = tokenizer(batch, return_tensors="pt", padding=True)

    print("[PyTorch] Running model inference")
    inference_start = time.time()
    torch_forward = torch_model(**torch_tokens)
    inference_time = time.time() - inference_start
    print(f"[PyTorch] Inference completed in {inference_time:.4f} seconds")

    torch_output = torch_forward.last_hidden_state.detach().numpy()
    torch_pooled = torch_forward.pooler_output.detach().numpy()

    print(f"[PyTorch] Output shape: {torch_output.shape}")
    print(f"[PyTorch] Pooled output shape: {torch_pooled.shape}")

    # Print a small sample of the output to verify sensible values
    print(f"[PyTorch] Sample of output (first token, first 5 values): {torch_output[0, 0, :5]}")
    print(f"[PyTorch] Sample of pooled output (first 5 values): {torch_pooled[0, :5]}")

    return torch_output, torch_pooled


def run_mlx(bert_model: str, mlx_model: str, batch: List[str]):
    print(f"\n[MLX] Loading model and tokenizer with weights from: {mlx_model}")
    start_time = time.time()
    mlx_output, mlx_pooled = model.run(bert_model, mlx_model, batch)
    load_and_run_time = time.time() - start_time
    print(f"[MLX] Model loaded and run in {load_and_run_time:.2f} seconds")

    # Convert the MLX arrays to numpy for comparison
    mlx_output_np = np.array(mlx_output)
    mlx_pooled_np = np.array(mlx_pooled)

    print(f"[MLX] Output shape: {mlx_output_np.shape}")
    print(f"[MLX] Pooled output shape: {mlx_pooled_np.shape}")

    # Print a small sample of the output to verify sensible values
    print(f"[MLX] Sample of output (first token, first 5 values): {mlx_output_np[0, 0, :5]}")
    print(f"[MLX] Sample of pooled output (first 5 values): {mlx_pooled_np[0, :5]}")

    return mlx_output_np, mlx_pooled_np


def compare_outputs(torch_output, torch_pooled, mlx_output, mlx_pooled):
    print("\n[Comparison] Comparing PyTorch and MLX outputs")

    # Check shapes
    print(f"[Comparison] Shape match - Output: {torch_output.shape == mlx_output.shape}")
    print(f"[Comparison] Shape match - Pooled: {torch_pooled.shape == mlx_pooled.shape}")

    # Calculate differences
    output_max_diff = np.max(np.abs(torch_output - mlx_output))
    output_mean_diff = np.mean(np.abs(torch_output - mlx_output))
    pooled_max_diff = np.max(np.abs(torch_pooled - mlx_pooled))
    pooled_mean_diff = np.mean(np.abs(torch_pooled - mlx_pooled))

    print(f"[Comparison] Output - Max absolute difference: {output_max_diff:.6f}")
    print(f"[Comparison] Output - Mean absolute difference: {output_mean_diff:.6f}")
    print(f"[Comparison] Pooled - Max absolute difference: {pooled_max_diff:.6f}")
    print(f"[Comparison] Pooled - Mean absolute difference: {pooled_mean_diff:.6f}")

    # Detailed comparison of the first few values from the first sentence
    print("\n[Comparison] Detailed comparison of first 5 values from first output token:")
    for i in range(5):
        torch_val = torch_output[0, 0, i]
        mlx_val = mlx_output[0, 0, i]
        diff = abs(torch_val - mlx_val)
        print(f"Index {i}: PyTorch={torch_val:.6f}, MLX={mlx_val:.6f}, Diff={diff:.6f}")

    # Check if the outputs are close
    outputs_close = np.allclose(torch_output, mlx_output, rtol=1e-4, atol=1e-4)
    pooled_close = np.allclose(torch_pooled, mlx_pooled, rtol=1e-4, atol=1e-4)

    print(f"\n[Comparison] Outputs match within tolerance: {outputs_close}")
    print(f"[Comparison] Pooled outputs match within tolerance: {pooled_close}")

    return outputs_close and pooled_close
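
# Note: np.allclose(a, b, rtol, atol) checks |a - b| <= atol + rtol * |b|
# elementwise, so with rtol = atol = 1e-4 small float32 kernel differences
# between PyTorch and MLX should still pass.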


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Run a BERT-like model for a batch of text and compare PyTorch and MLX outputs."
    )
    parser.add_argument(
        "--bert-model",
        type=str,
        default="bert-base-uncased",
        help="The model identifier for a BERT-like model from Hugging Face Transformers.",
    )
    parser.add_argument(
        "--mlx-model",
        type=str,
        default="weights/bert-base-uncased.npz",
        help="The path of the stored MLX BERT weights (npz file).",
    )
    parser.add_argument(
        "--text",
        nargs="+",
        default=["This is an example of BERT working in MLX."],
        help="A batch of texts to process. Multiple texts should be separated by spaces.",
    )
    parser.add_argument(
        "--verbose",
        action="store_true",
        help="Print detailed information about the model execution.",
    )

    args = parser.parse_args()

    print(f"Testing BERT model: {args.bert_model}")
    print(f"MLX weights: {args.mlx_model}")
    print(f"Input text: {args.text}")

    # Run both implementations
    torch_output, torch_pooled = run_torch(args.bert_model, args.text)
    mlx_output, mlx_pooled = run_mlx(args.bert_model, args.mlx_model, args.text)

    # Compare outputs
    all_match = compare_outputs(torch_output, torch_pooled, mlx_output, mlx_pooled)

    if all_match:
        print("\n✅ TEST PASSED: PyTorch and MLX implementations produce equivalent results!")
    else:
        print("\n❌ TEST FAILED: PyTorch and MLX implementations produce different results.")
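The intended workflow, judging from the defaults above, appears to be: run "python convert.py" first (which writes weights/bert-base-uncased.npz and weights/config.json), then "python test.py" to check the MLX port against the PyTorch reference, optionally passing --text with one or more sentences to change the test batch.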