cstr commited on
Commit
e17d2b6
·
verified ·
1 Parent(s): 6ab012c

Upload 3 files

Browse files
Files changed (3) hide show
  1. convert.py +66 -0
  2. model.py +203 -0
  3. test.py +143 -0
convert.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ from pathlib import Path
4
+
5
+ import numpy
6
+ from transformers import AutoModel, AutoConfig
7
+
8
+
9
+ def replace_key(key: str) -> str:
10
+ key = key.replace(".layer.", ".layers.")
11
+ key = key.replace(".self.key.", ".key_proj.")
12
+ key = key.replace(".self.query.", ".query_proj.")
13
+ key = key.replace(".self.value.", ".value_proj.")
14
+ key = key.replace(".attention.output.dense.", ".attention.out_proj.")
15
+ key = key.replace(".attention.output.LayerNorm.", ".ln1.")
16
+ key = key.replace(".output.LayerNorm.", ".ln2.")
17
+ key = key.replace(".intermediate.dense.", ".linear1.")
18
+ key = key.replace(".output.dense.", ".linear2.")
19
+ key = key.replace(".LayerNorm.", ".norm.")
20
+ key = key.replace("pooler.dense.", "pooler.")
21
+ return key
22
+
23
+
24
+ def convert(bert_model: str, mlx_model: str) -> None:
25
+ # Load model and its configuration
26
+ model = AutoModel.from_pretrained(bert_model)
27
+ config = AutoConfig.from_pretrained(bert_model)
28
+
29
+ # Create output directory if it doesn't exist
30
+ output_dir = os.path.dirname(mlx_model)
31
+ if output_dir and not os.path.exists(output_dir):
32
+ os.makedirs(output_dir)
33
+
34
+ # Save config as well
35
+ config_path = os.path.join(output_dir, "config.json")
36
+ with open(config_path, "w") as f:
37
+ f.write(config.to_json_string())
38
+
39
+ print(f"Saved model config to {config_path}")
40
+
41
+ # Save the tensors
42
+ tensors = {
43
+ replace_key(key): tensor.numpy() for key, tensor in model.state_dict().items()
44
+ }
45
+ numpy.savez(mlx_model, **tensors)
46
+ print(f"Saved model weights to {mlx_model}")
47
+ print(f"Model vocab size: {config.vocab_size}, hidden size: {config.hidden_size}")
48
+
49
+
50
+ if __name__ == "__main__":
51
+ parser = argparse.ArgumentParser(description="Convert BERT weights to MLX.")
52
+ parser.add_argument(
53
+ "--bert-model",
54
+ type=str,
55
+ default="bert-base-uncased",
56
+ help="The huggingface name of the BERT model to save. Any BERT-like model can be specified.",
57
+ )
58
+ parser.add_argument(
59
+ "--mlx-model",
60
+ type=str,
61
+ default="weights/bert-base-uncased.npz",
62
+ help="The output path for the MLX BERT weights.",
63
+ )
64
+ args = parser.parse_args()
65
+
66
+ convert(args.bert_model, args.mlx_model)
model.py ADDED
@@ -0,0 +1,203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import json
3
+ from dataclasses import dataclass
4
+ from pathlib import Path
5
+ from typing import List, Optional, Tuple
6
+
7
+ import mlx.core as mx
8
+ import mlx.nn as nn
9
+ import numpy
10
+ import numpy as np
11
+ from mlx.utils import tree_unflatten
12
+ from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizerBase
13
+
14
+
15
+ class TransformerEncoderLayer(nn.Module):
16
+ """
17
+ A transformer encoder layer with (the original BERT) post-normalization.
18
+ """
19
+
20
+ def __init__(
21
+ self,
22
+ dims: int,
23
+ num_heads: int,
24
+ mlp_dims: Optional[int] = None,
25
+ layer_norm_eps: float = 1e-12,
26
+ ):
27
+ super().__init__()
28
+ mlp_dims = mlp_dims or dims * 4
29
+ self.attention = nn.MultiHeadAttention(dims, num_heads, bias=True)
30
+ self.ln1 = nn.LayerNorm(dims, eps=layer_norm_eps)
31
+ self.ln2 = nn.LayerNorm(dims, eps=layer_norm_eps)
32
+ self.linear1 = nn.Linear(dims, mlp_dims)
33
+ self.linear2 = nn.Linear(mlp_dims, dims)
34
+ self.gelu = nn.GELU()
35
+
36
+ def __call__(self, x, mask):
37
+ attention_out = self.attention(x, x, x, mask)
38
+ add_and_norm = self.ln1(x + attention_out)
39
+
40
+ ff = self.linear1(add_and_norm)
41
+ ff_gelu = self.gelu(ff)
42
+ ff_out = self.linear2(ff_gelu)
43
+ x = self.ln2(ff_out + add_and_norm)
44
+
45
+ return x
46
+
47
+
48
+ class TransformerEncoder(nn.Module):
49
+ def __init__(
50
+ self, num_layers: int, dims: int, num_heads: int, mlp_dims: Optional[int] = None
51
+ ):
52
+ super().__init__()
53
+ self.layers = [
54
+ TransformerEncoderLayer(dims, num_heads, mlp_dims)
55
+ for i in range(num_layers)
56
+ ]
57
+
58
+ def __call__(self, x, mask):
59
+ for layer in self.layers:
60
+ x = layer(x, mask)
61
+
62
+ return x
63
+
64
+
65
+ class BertEmbeddings(nn.Module):
66
+ def __init__(self, config):
67
+ super().__init__()
68
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
69
+ self.token_type_embeddings = nn.Embedding(
70
+ config.type_vocab_size, config.hidden_size
71
+ )
72
+ self.position_embeddings = nn.Embedding(
73
+ config.max_position_embeddings, config.hidden_size
74
+ )
75
+ self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
76
+
77
+ def __call__(
78
+ self, input_ids: mx.array, token_type_ids: mx.array = None
79
+ ) -> mx.array:
80
+ words = self.word_embeddings(input_ids)
81
+ position = self.position_embeddings(
82
+ mx.broadcast_to(mx.arange(input_ids.shape[1]), input_ids.shape)
83
+ )
84
+
85
+ if token_type_ids is None:
86
+ # If token_type_ids is not provided, default to zeros
87
+ token_type_ids = mx.zeros_like(input_ids)
88
+
89
+ token_types = self.token_type_embeddings(token_type_ids)
90
+
91
+ embeddings = position + words + token_types
92
+ return self.norm(embeddings)
93
+
94
+
95
+ class Bert(nn.Module):
96
+ def __init__(self, config):
97
+ super().__init__()
98
+ self.embeddings = BertEmbeddings(config)
99
+ self.encoder = TransformerEncoder(
100
+ num_layers=config.num_hidden_layers,
101
+ dims=config.hidden_size,
102
+ num_heads=config.num_attention_heads,
103
+ mlp_dims=config.intermediate_size,
104
+ )
105
+ self.pooler = nn.Linear(config.hidden_size, config.hidden_size)
106
+
107
+ def __call__(
108
+ self,
109
+ input_ids: mx.array,
110
+ token_type_ids: mx.array = None,
111
+ attention_mask: mx.array = None,
112
+ ) -> Tuple[mx.array, mx.array]:
113
+ x = self.embeddings(input_ids, token_type_ids)
114
+
115
+ if attention_mask is not None:
116
+ # convert 0's to -infs, 1's to 0's, and make it broadcastable
117
+ attention_mask = mx.log(attention_mask)
118
+ attention_mask = mx.expand_dims(attention_mask, (1, 2))
119
+
120
+ y = self.encoder(x, attention_mask)
121
+ return y, mx.tanh(self.pooler(y[:, 0]))
122
+
123
+
124
+ def load_model(
125
+ bert_model: str, weights_path: str
126
+ ) -> Tuple[Bert, PreTrainedTokenizerBase]:
127
+ if not Path(weights_path).exists():
128
+ raise ValueError(f"No model weights found in {weights_path}")
129
+
130
+ # First check if there's a local config
131
+ config_path = Path(weights_path).parent / "config.json"
132
+ if config_path.exists():
133
+ with open(config_path, "r") as f:
134
+ config_dict = json.load(f)
135
+ config = AutoConfig.for_model(**config_dict)
136
+ print(f"Loaded local config from {config_path}")
137
+ else:
138
+ # If no local config, use the HuggingFace one
139
+ config = AutoConfig.from_pretrained(bert_model)
140
+ print(f"Loaded config from HuggingFace for {bert_model}")
141
+
142
+ # Create and update the model
143
+ print(f"Creating model with vocab_size={config.vocab_size}, hidden_size={config.hidden_size}")
144
+ model = Bert(config)
145
+ model.load_weights(weights_path)
146
+
147
+ tokenizer = AutoTokenizer.from_pretrained(bert_model)
148
+
149
+ return model, tokenizer
150
+
151
+
152
+ def run(bert_model: str, mlx_model: str, batch: List[str]):
153
+ import time
154
+
155
+ # Time model loading
156
+ load_start = time.time()
157
+ model, tokenizer = load_model(bert_model, mlx_model)
158
+ load_time = time.time() - load_start
159
+ print(f"[MLX] Model loaded in {load_time:.2f} seconds")
160
+
161
+ # Time tokenization
162
+ print(f"[MLX] Tokenizing batch of {len(batch)} sentences")
163
+ token_start = time.time()
164
+ tokens = tokenizer(batch, return_tensors="np", padding=True)
165
+ token_time = time.time() - token_start
166
+ print(f"[MLX] Tokenization completed in {token_time:.4f} seconds")
167
+
168
+ print(f"[MLX] Tokens shape: input_ids={tokens['input_ids'].shape}")
169
+ tokens = {key: mx.array(v) for key, v in tokens.items()}
170
+
171
+ # Time inference
172
+ print(f"[MLX] Running model inference")
173
+ infer_start = time.time()
174
+ output, pooled = model(**tokens)
175
+ mx.eval(output, pooled) # Force evaluation of lazy arrays
176
+ infer_time = time.time() - infer_start
177
+ print(f"[MLX] Inference completed in {infer_time:.4f} seconds")
178
+
179
+ return output, pooled
180
+
181
+
182
+ if __name__ == "__main__":
183
+ parser = argparse.ArgumentParser(description="Run the BERT model using MLX.")
184
+ parser.add_argument(
185
+ "--bert-model",
186
+ type=str,
187
+ default="bert-base-uncased",
188
+ help="The huggingface name of the BERT model to save.",
189
+ )
190
+ parser.add_argument(
191
+ "--mlx-model",
192
+ type=str,
193
+ default="weights/bert-base-uncased.npz",
194
+ help="The path of the stored MLX BERT weights (npz file).",
195
+ )
196
+ parser.add_argument(
197
+ "--text",
198
+ type=str,
199
+ default="This is an example of BERT working in MLX",
200
+ help="The text to generate embeddings for.",
201
+ )
202
+ args = parser.parse_args()
203
+ run(args.bert_model, args.mlx_model, args.text)
test.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import time
3
+ from typing import List
4
+
5
+ import model
6
+ import numpy as np
7
+ import mlx.core as mx
8
+ from transformers import AutoModel, AutoTokenizer
9
+
10
+
11
+ def run_torch(bert_model: str, batch: List[str]):
12
+ print(f"\n[PyTorch] Loading model and tokenizer: {bert_model}")
13
+ start_time = time.time()
14
+ tokenizer = AutoTokenizer.from_pretrained(bert_model)
15
+ torch_model = AutoModel.from_pretrained(bert_model)
16
+ load_time = time.time() - start_time
17
+ print(f"[PyTorch] Model loaded in {load_time:.2f} seconds")
18
+
19
+ print(f"[PyTorch] Tokenizing batch of {len(batch)} sentences")
20
+ torch_tokens = tokenizer(batch, return_tensors="pt", padding=True)
21
+
22
+ print(f"[PyTorch] Running model inference")
23
+ inference_start = time.time()
24
+ torch_forward = torch_model(**torch_tokens)
25
+ inference_time = time.time() - inference_start
26
+ print(f"[PyTorch] Inference completed in {inference_time:.4f} seconds")
27
+
28
+ torch_output = torch_forward.last_hidden_state.detach().numpy()
29
+ torch_pooled = torch_forward.pooler_output.detach().numpy()
30
+
31
+ print(f"[PyTorch] Output shape: {torch_output.shape}")
32
+ print(f"[PyTorch] Pooled output shape: {torch_pooled.shape}")
33
+
34
+ # Print a small sample of the output to verify sensible values
35
+ print(f"[PyTorch] Sample of output (first token, first 5 values): {torch_output[0, 0, :5]}")
36
+ print(f"[PyTorch] Sample of pooled output (first 5 values): {torch_pooled[0, :5]}")
37
+
38
+ return torch_output, torch_pooled
39
+
40
+
41
+ def run_mlx(bert_model: str, mlx_model: str, batch: List[str]):
42
+ print(f"\n[MLX] Loading model and tokenizer with weights from: {mlx_model}")
43
+ start_time = time.time()
44
+ mlx_output, mlx_pooled = model.run(bert_model, mlx_model, batch)
45
+ load_and_run_time = time.time() - start_time
46
+ print(f"[MLX] Model loaded and run in {load_and_run_time:.2f} seconds")
47
+
48
+ # Convert from MLX arrays to numpy for comparison
49
+ # The correct way to convert MLX arrays to numpy
50
+ mlx_output_np = np.array(mlx_output)
51
+ mlx_pooled_np = np.array(mlx_pooled)
52
+
53
+ print(f"[MLX] Output shape: {mlx_output_np.shape}")
54
+ print(f"[MLX] Pooled output shape: {mlx_pooled_np.shape}")
55
+
56
+ # Print a small sample of the output to verify sensible values
57
+ print(f"[MLX] Sample of output (first token, first 5 values): {mlx_output_np[0, 0, :5]}")
58
+ print(f"[MLX] Sample of pooled output (first 5 values): {mlx_pooled_np[0, :5]}")
59
+
60
+ return mlx_output_np, mlx_pooled_np
61
+
62
+
63
+ def compare_outputs(torch_output, torch_pooled, mlx_output, mlx_pooled):
64
+ print("\n[Comparison] Comparing PyTorch and MLX outputs")
65
+
66
+ # Check shapes
67
+ print(f"[Comparison] Shape match - Output: {torch_output.shape == mlx_output.shape}")
68
+ print(f"[Comparison] Shape match - Pooled: {torch_pooled.shape == mlx_pooled.shape}")
69
+
70
+ # Calculate differences
71
+ output_max_diff = np.max(np.abs(torch_output - mlx_output))
72
+ output_mean_diff = np.mean(np.abs(torch_output - mlx_output))
73
+ pooled_max_diff = np.max(np.abs(torch_pooled - mlx_pooled))
74
+ pooled_mean_diff = np.mean(np.abs(torch_pooled - mlx_pooled))
75
+
76
+ print(f"[Comparison] Output - Max absolute difference: {output_max_diff:.6f}")
77
+ print(f"[Comparison] Output - Mean absolute difference: {output_mean_diff:.6f}")
78
+ print(f"[Comparison] Pooled - Max absolute difference: {pooled_max_diff:.6f}")
79
+ print(f"[Comparison] Pooled - Mean absolute difference: {pooled_mean_diff:.6f}")
80
+
81
+ # Detailed comparison of first few values from first sentence
82
+ print("\n[Comparison] Detailed comparison of first 5 values from first output token:")
83
+ for i in range(5):
84
+ torch_val = torch_output[0, 0, i]
85
+ mlx_val = mlx_output[0, 0, i]
86
+ diff = abs(torch_val - mlx_val)
87
+ print(f"Index {i}: PyTorch={torch_val:.6f}, MLX={mlx_val:.6f}, Diff={diff:.6f}")
88
+
89
+ # Check if outputs are close
90
+ outputs_close = np.allclose(torch_output, mlx_output, rtol=1e-4, atol=1e-4)
91
+ pooled_close = np.allclose(torch_pooled, mlx_pooled, rtol=1e-4, atol=1e-4)
92
+
93
+ print(f"\n[Comparison] Outputs match within tolerance: {outputs_close}")
94
+ print(f"[Comparison] Pooled outputs match within tolerance: {pooled_close}")
95
+
96
+ return outputs_close and pooled_close
97
+
98
+
99
+ if __name__ == "__main__":
100
+ parser = argparse.ArgumentParser(
101
+ description="Run a BERT-like model for a batch of text and compare PyTorch and MLX outputs."
102
+ )
103
+ parser.add_argument(
104
+ "--bert-model",
105
+ type=str,
106
+ default="bert-base-uncased",
107
+ help="The model identifier for a BERT-like model from Hugging Face Transformers.",
108
+ )
109
+ parser.add_argument(
110
+ "--mlx-model",
111
+ type=str,
112
+ default="weights/bert-base-uncased.npz",
113
+ help="The path of the stored MLX BERT weights (npz file).",
114
+ )
115
+ parser.add_argument(
116
+ "--text",
117
+ nargs="+",
118
+ default=["This is an example of BERT working in MLX."],
119
+ help="A batch of texts to process. Multiple texts should be separated by spaces.",
120
+ )
121
+ parser.add_argument(
122
+ "--verbose",
123
+ action="store_true",
124
+ help="Print detailed information about the model execution.",
125
+ )
126
+
127
+ args = parser.parse_args()
128
+
129
+ print(f"Testing BERT model: {args.bert_model}")
130
+ print(f"MLX weights: {args.mlx_model}")
131
+ print(f"Input text: {args.text}")
132
+
133
+ # Run both implementations
134
+ torch_output, torch_pooled = run_torch(args.bert_model, args.text)
135
+ mlx_output, mlx_pooled = run_mlx(args.bert_model, args.mlx_model, args.text)
136
+
137
+ # Compare outputs
138
+ all_match = compare_outputs(torch_output, torch_pooled, mlx_output, mlx_pooled)
139
+
140
+ if all_match:
141
+ print("\n✅ TEST PASSED: PyTorch and MLX implementations produce equivalent results!")
142
+ else:
143
+ print("\n❌ TEST FAILED: PyTorch and MLX implementations produce different results.")