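"""Convert Hugging Face BERT weights into an MLX-compatible .npz archive.

Example invocation (assuming this script is saved as convert.py; the file name
itself is not fixed by the code):

    python convert.py --bert-model bert-base-uncased \
        --mlx-model weights/bert-base-uncased.npz
"""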
import argparse
import os

import numpy
from transformers import AutoModel, AutoConfig


def replace_key(key: str) -> str:
    """Rename a Hugging Face BERT parameter key to the naming used by the MLX model."""
    key = key.replace(".layer.", ".layers.")
    key = key.replace(".self.key.", ".key_proj.")
    key = key.replace(".self.query.", ".query_proj.")
    key = key.replace(".self.value.", ".value_proj.")
    key = key.replace(".attention.output.dense.", ".attention.out_proj.")
    key = key.replace(".attention.output.LayerNorm.", ".ln1.")
    key = key.replace(".output.LayerNorm.", ".ln2.")
    key = key.replace(".intermediate.dense.", ".linear1.")
    key = key.replace(".output.dense.", ".linear2.")
    key = key.replace(".LayerNorm.", ".norm.")
    key = key.replace("pooler.dense.", "pooler.")
    return key
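
# For example (illustrative; assumes the standard Hugging Face BERT key layout),
# replace_key maps "encoder.layer.0.attention.self.query.weight"
# to "encoder.layers.0.attention.query_proj.weight".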


def convert(bert_model: str, mlx_model: str) -> None:
    # Load the model and its configuration
    model = AutoModel.from_pretrained(bert_model)
    config = AutoConfig.from_pretrained(bert_model)

    # Create the output directory if it doesn't exist
    output_dir = os.path.dirname(mlx_model)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # Save the config alongside the weights
    config_path = os.path.join(output_dir, "config.json")
    with open(config_path, "w") as f:
        f.write(config.to_json_string())

    print(f"Saved model config to {config_path}")

    # Rename the keys and save the tensors as an .npz archive
    tensors = {
        replace_key(key): tensor.numpy() for key, tensor in model.state_dict().items()
    }
    numpy.savez(mlx_model, **tensors)
    print(f"Saved model weights to {mlx_model}")
    print(f"Model vocab size: {config.vocab_size}, hidden size: {config.hidden_size}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Convert BERT weights to MLX.")
    parser.add_argument(
        "--bert-model",
        type=str,
        default="bert-base-uncased",
        help="The huggingface name of the BERT model to save. Any BERT-like model can be specified.",
    )
    parser.add_argument(
        "--mlx-model",
        type=str,
        default="weights/bert-base-uncased.npz",
        help="The output path for the MLX BERT weights.",
    )
    args = parser.parse_args()

    convert(args.bert_model, args.mlx_model)
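
# A quick way to inspect the converted archive (illustrative sketch; assumes the
# default --mlx-model path and a standard BERT checkpoint, whose pooler weight
# ends up under the key "pooler.weight" after renaming):
#
#     import numpy
#     weights = numpy.load("weights/bert-base-uncased.npz")
#     print(len(weights.files))
#     print(weights["pooler.weight"].shape)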