|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from transformers import AutoModelForMaskedLM |
|
import torch |
|
import coremltools as ct |
|
import numpy as np |
|
import argparse |
|
|
|
|
|
def log(text):
    """Print *text* to stdout as a bold, bright-green progress message.

    The ANSI attributes are reset at the end of the message so later output
    is unaffected. The stream is flushed immediately so each progress line
    shows up when its step starts, even when stdout is piped or redirected
    (and therefore block-buffered) during the long load/trace/convert steps.
    """
    print(f"\033[92m\033[1m{text}\033[0m", flush=True)
|
|
|
|
|
# Command-line interface: `--model` picks the answerdotai checkpoint name and
# `--quantize` toggles the int4 weight-quantization step further down.
parser = argparse.ArgumentParser(
    description="Convert ModernBERT to CoreML",
    prog="convert.py",
)
parser.add_argument(
    "--model",
    help="Model name",
    default="ModernBERT-base",
    type=str,
)
parser.add_argument(
    "--quantize",
    help="Linear quantize model",
    action="store_true",
)
args = parser.parse_args()
|
|
|
|
|
class Model(torch.nn.Module):
    """Wrapper around a Hugging Face masked-LM that returns only its logits.

    Wrapping the model gives ``torch.jit.trace`` a single tensor input
    (``input_ids``) and a single tensor output. This matches the one input
    and one output declared for the CoreML conversion.
    """

    def __init__(self, model_name=None):
        """Load ``answerdotai/<model_name>`` with ``AutoModelForMaskedLM``.

        Args:
            model_name: Repository name under the ``answerdotai`` organisation,
                e.g. ``"ModernBERT-base"``. When omitted, falls back to the
                ``--model`` CLI value (``args.model``), which is the original
                behaviour.
        """
        super().__init__()
        if model_name is None:
            # Looked up only when needed, so callers that pass a name do not
            # depend on the module-level CLI globals.
            model_name = args.model
        self.model = AutoModelForMaskedLM.from_pretrained(f"answerdotai/{model_name}")

    def forward(self, input_ids):
        """Run the masked-LM model and return the ``logits`` of its output.

        Args:
            input_ids: Integer token-id tensor; presumably (batch, seq_len),
                since the script traces it with shape (1, 1). TODO confirm.

        Returns:
            The ``.logits`` attribute of the underlying model's output.
        """
        # The exported model has no mask input, so every position is marked
        # as attended. Any padding tokens in the input would also be attended.
        attention_mask = torch.ones_like(input_ids)
        return self.model(input_ids=input_ids, attention_mask=attention_mask).logits
|
|
|
|
|
log("Loading model…")
model = Model()
# Switch to eval mode before tracing (affects modules such as dropout).
model.eval()

log("Tracing model…")
# One zero-valued int32 token, shape (1, 1), used as the tracing example.
# The sequence dimension is declared flexible later via ct.RangeDim.
example_input = (torch.zeros(1, 1, dtype=torch.int32),)
traced_model = torch.jit.trace(model, example_input)
|
|
|
log("Converting model…")
# Batch size is fixed at 1. The sequence length may range from 1 up to the
# model config's max_position_embeddings, with 1 as the default.
input_shape = (
    1,
    ct.RangeDim(
        default=1,
        lower_bound=1,
        upper_bound=model.model.config.max_position_embeddings,
    ),
)
mlmodel = ct.convert(
    traced_model,
    minimum_deployment_target=ct.target.macOS15,
    inputs=[
        ct.TensorType(
            dtype=np.int32,
            name="input_ids",
            shape=input_shape,
        )
    ],
    outputs=[ct.TensorType(name="logits")],
)
|
|
|
if args.quantize:
    log("Quantizing model…")
    # One config applied to every op (global_config): symmetric linear
    # quantization to int4, with per-block granularity of 32 elements.
    op_config = ct.optimize.coreml.OpLinearQuantizerConfig(
        block_size=32,
        dtype="int4",
        granularity="per_block",
        mode="linear_symmetric",
    )
    config = ct.optimize.coreml.OptimizationConfig(global_config=op_config)
    mlmodel = ct.optimize.coreml.linear_quantize_weights(
        mlmodel,
        config=config,
    )
|
|
|
# Package metadata shown alongside the model.
mlmodel.author = "Finn Voorhees"
mlmodel.short_description = "https://hf.co/finnvoorhees/ModernBERT-CoreML"

log("Saving mlmodel…")
# Quantized builds get a "-4bit" suffix so they don't overwrite the
# full-precision package.
mlmodel.save(
    f"{args.model}-4bit.mlpackage" if args.quantize else f"{args.model}.mlpackage"
)

log("Done!")
|
|