# ModernBERT-CoreML / convert.py
# Add conversion script (commit 7945f87, by finnvoorhees)
# /// script
# requires-python = ">=3.11"
# dependencies = [
# "coremltools",
# "torch",
# "transformers",
# ]
#
# [tool.uv.sources]
# transformers = { git = "https://github.com/huggingface/transformers.git", branch = "main" }
# ///
from transformers import AutoModelForMaskedLM
import torch
import coremltools as ct
import numpy as np
import argparse
def log(text):
    """Print *text* to stdout in bold green (ANSI escape codes)."""
    bold_green = "\033[92m\033[1m"
    reset = "\033[0m"
    print(f"{bold_green}{text}{reset}")
# Command-line interface: choose the checkpoint and optionally quantize it.
parser = argparse.ArgumentParser(prog="convert.py", description="Convert ModernBERT to CoreML")
parser.add_argument(
    "--model",
    type=str,
    default="ModernBERT-base",
    help="Model name",
)
parser.add_argument(
    "--quantize",
    action="store_true",
    help="Linear quantize model",
)
args = parser.parse_args()
class Model(torch.nn.Module):
    """Wrapper around a Hugging Face masked-LM checkpoint.

    Builds the attention mask internally so the traced graph exposes a
    single ``input_ids`` input, which keeps the CoreML interface simple.
    """

    def __init__(self):
        super().__init__()
        # args.model is resolved at construction time from the CLI flags.
        checkpoint = f"answerdotai/{args.model}"
        self.model = AutoModelForMaskedLM.from_pretrained(checkpoint)

    def forward(self, input_ids):
        """Return MLM logits; all positions are attended (all-ones mask)."""
        mask = torch.ones_like(input_ids)
        outputs = self.model(input_ids=input_ids, attention_mask=mask)
        return outputs.logits
log("Loading model…")
model = Model().eval()

log("Tracing model…")
# A single (1, 1) int32 token tensor is enough to trace the graph;
# the sequence dimension is made flexible at conversion time.
dummy_ids = torch.zeros(1, 1, dtype=torch.int32)
traced_model = torch.jit.trace(model, (dummy_ids,))
log("Converting model…")
# Sequence length is flexible: anything from a single token up to the
# model's configured maximum position embeddings.
max_length = model.model.config.max_position_embeddings
seq_dim = ct.RangeDim(lower_bound=1, upper_bound=max_length, default=1)
mlmodel = ct.convert(
    traced_model,
    inputs=[ct.TensorType(name="input_ids", shape=(1, seq_dim), dtype=np.int32)],
    outputs=[ct.TensorType(name="logits")],
    minimum_deployment_target=ct.target.macOS15,
)
if args.quantize:
    log("Quantizing model…")
    # 4-bit symmetric linear weight quantization, per block of 32.
    quant_config = ct.optimize.coreml.OptimizationConfig(
        global_config=ct.optimize.coreml.OpLinearQuantizerConfig(
            mode="linear_symmetric",
            dtype="int4",
            granularity="per_block",
            block_size=32,
        )
    )
    mlmodel = ct.optimize.coreml.linear_quantize_weights(mlmodel, config=quant_config)
# Attach package metadata, then write the .mlpackage next to the script.
mlmodel.author = "Finn Voorhees"
mlmodel.short_description = "https://hf.co/finnvoorhees/ModernBERT-CoreML"
log("Saving mlmodel…")
suffix = "-4bit" if args.quantize else ""
mlmodel.save(f"{args.model}{suffix}.mlpackage")
log("Done!")