import argparse
import json
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional, Tuple

import mlx.core as mx
import mlx.nn as nn
import numpy as np
from mlx.utils import tree_unflatten
from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizerBase


class TransformerEncoderLayer(nn.Module):
""" |
|
A transformer encoder layer with (the original BERT) post-normalization. |
|
""" |
|
|
|
def __init__( |
|
self, |
|
dims: int, |
|
num_heads: int, |
|
mlp_dims: Optional[int] = None, |
|
layer_norm_eps: float = 1e-12, |
|
): |
|
super().__init__() |
|
mlp_dims = mlp_dims or dims * 4 |
|
self.attention = nn.MultiHeadAttention(dims, num_heads, bias=True) |
|
self.ln1 = nn.LayerNorm(dims, eps=layer_norm_eps) |
|
self.ln2 = nn.LayerNorm(dims, eps=layer_norm_eps) |
|
self.linear1 = nn.Linear(dims, mlp_dims) |
|
self.linear2 = nn.Linear(mlp_dims, dims) |
|
self.gelu = nn.GELU() |
|
|
|
def __call__(self, x, mask): |
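        # Each sublayer is a residual block in the post-norm style: the sublayer
        # output is added to its input and the sum is layer-normalized.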
        attention_out = self.attention(x, x, x, mask)
        add_and_norm = self.ln1(x + attention_out)

        ff = self.linear1(add_and_norm)
        ff_gelu = self.gelu(ff)
        ff_out = self.linear2(ff_gelu)
        x = self.ln2(ff_out + add_and_norm)

        return x


class TransformerEncoder(nn.Module):
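    """
    A stack of `num_layers` post-norm transformer encoder layers applied in sequence.
    """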
    def __init__(
        self, num_layers: int, dims: int, num_heads: int, mlp_dims: Optional[int] = None
    ):
        super().__init__()
        self.layers = [
            TransformerEncoderLayer(dims, num_heads, mlp_dims)
            for _ in range(num_layers)
        ]

    def __call__(self, x, mask):
        for layer in self.layers:
            x = layer(x, mask)

        return x


class BertEmbeddings(nn.Module):
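    """
    The BERT input embeddings: the sum of word, position, and token-type
    embeddings, followed by LayerNorm.
    """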
    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
        self.token_type_embeddings = nn.Embedding(
            config.type_vocab_size, config.hidden_size
        )
        self.position_embeddings = nn.Embedding(
            config.max_position_embeddings, config.hidden_size
        )
        self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def __call__(
        self, input_ids: mx.array, token_type_ids: Optional[mx.array] = None
    ) -> mx.array:
        words = self.word_embeddings(input_ids)
        position = self.position_embeddings(
            mx.broadcast_to(mx.arange(input_ids.shape[1]), input_ids.shape)
        )

        if token_type_ids is None:
            # If no segment ids are given, every token belongs to segment 0.
            token_type_ids = mx.zeros_like(input_ids)

        token_types = self.token_type_embeddings(token_type_ids)

        embeddings = position + words + token_types
        return self.norm(embeddings)


class Bert(nn.Module):
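    """
    The BERT encoder: input embeddings, a stack of transformer layers, and a
    tanh pooler applied to the hidden state of the first ([CLS]) token.
    """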
    def __init__(self, config):
        super().__init__()
        self.embeddings = BertEmbeddings(config)
        self.encoder = TransformerEncoder(
            num_layers=config.num_hidden_layers,
            dims=config.hidden_size,
            num_heads=config.num_attention_heads,
            mlp_dims=config.intermediate_size,
        )
        self.pooler = nn.Linear(config.hidden_size, config.hidden_size)

    def __call__(
        self,
        input_ids: mx.array,
        token_type_ids: Optional[mx.array] = None,
        attention_mask: Optional[mx.array] = None,
    ) -> Tuple[mx.array, mx.array]:
        x = self.embeddings(input_ids, token_type_ids)

        if attention_mask is not None:
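            # Convert the 1/0 padding mask into an additive bias: log(1) = 0 for
            # real tokens, log(0) = -inf for padding, reshaped so it broadcasts
            # over heads and query positions.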
            attention_mask = mx.log(attention_mask)
            attention_mask = mx.expand_dims(attention_mask, (1, 2))

        y = self.encoder(x, attention_mask)
        return y, mx.tanh(self.pooler(y[:, 0]))


def load_model(
    bert_model: str, weights_path: str
) -> Tuple[Bert, PreTrainedTokenizerBase]:
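    """
    Build a Bert model from the npz weights at `weights_path`. The config is
    read from a config.json stored next to the weights if present, otherwise
    fetched from the Hugging Face Hub for `bert_model`. Returns the model and
    the matching tokenizer.
    """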
    if not Path(weights_path).exists():
        raise ValueError(f"No model weights found in {weights_path}")

    config_path = Path(weights_path).parent / "config.json"
    if config_path.exists():
        with open(config_path, "r") as f:
            config_dict = json.load(f)
        config = AutoConfig.for_model(**config_dict)
        print(f"Loaded local config from {config_path}")
    else:
        config = AutoConfig.from_pretrained(bert_model)
        print(f"Loaded config from Hugging Face for {bert_model}")

    print(
        f"Creating model with vocab_size={config.vocab_size}, "
        f"hidden_size={config.hidden_size}"
    )
    model = Bert(config)
    model.load_weights(weights_path)

    tokenizer = AutoTokenizer.from_pretrained(bert_model)

    return model, tokenizer


def run(bert_model: str, mlx_model: str, batch: List[str]):
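    """
    Load the model and tokenizer, tokenize `batch`, run a forward pass, and
    return the per-token outputs and the pooled [CLS] embeddings, printing
    timings for each stage.
    """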
    import time

    load_start = time.time()
    model, tokenizer = load_model(bert_model, mlx_model)
    load_time = time.time() - load_start
    print(f"[MLX] Model loaded in {load_time:.2f} seconds")

    print(f"[MLX] Tokenizing batch of {len(batch)} sentences")
    token_start = time.time()
    tokens = tokenizer(batch, return_tensors="np", padding=True)
    token_time = time.time() - token_start
    print(f"[MLX] Tokenization completed in {token_time:.4f} seconds")

    print(f"[MLX] Tokens shape: input_ids={tokens['input_ids'].shape}")
    tokens = {key: mx.array(v) for key, v in tokens.items()}

    print("[MLX] Running model inference")
    infer_start = time.time()
    output, pooled = model(**tokens)
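    # MLX evaluates lazily; mx.eval forces the computation so the timing below
    # reflects the actual forward pass.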
    mx.eval(output, pooled)
    infer_time = time.time() - infer_start
    print(f"[MLX] Inference completed in {infer_time:.4f} seconds")

    return output, pooled


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run the BERT model using MLX.")
    parser.add_argument(
        "--bert-model",
        type=str,
        default="bert-base-uncased",
        help="The Hugging Face name of the BERT model to use for the tokenizer and config.",
    )
    parser.add_argument(
        "--mlx-model",
        type=str,
        default="weights/bert-base-uncased.npz",
        help="The path of the stored MLX BERT weights (npz file).",
    )
    parser.add_argument(
        "--text",
        type=str,
        default="This is an example of BERT working in MLX",
        help="The text to generate embeddings for.",
    )
    args = parser.parse_args()
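    # Example invocation (assuming the weights were already converted to
    # weights/bert-base-uncased.npz and this script is saved as model.py;
    # both names are illustrative):
    #   python model.py --bert-model bert-base-uncased \
    #       --mlx-model weights/bert-base-uncased.npz \
    #       --text "This is an example of BERT working in MLX"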
    run(args.bert_model, args.mlx_model, [args.text])