# /// script
# requires-python = ">=3.11"
# dependencies = [
#     "coremltools",
#     "numpy",
#     "torch",
#     "transformers",
# ]
#
# [tool.uv.sources]
# transformers = { git = "https://github.com/huggingface/transformers.git", branch = "main" }
# ///
"""Convert an Answer.AI ModernBERT masked-LM checkpoint to a Core ML package.

The script loads ``answerdotai/<model>`` with ``AutoModelForMaskedLM``, traces
it with TorchScript, converts the trace to a Core ML ``.mlpackage`` whose
``input_ids`` sequence length may range from 1 to the checkpoint's
``max_position_embeddings``, and optionally applies 4-bit block-wise linear
weight quantization before saving.
"""

import argparse

import coremltools as ct
import numpy as np
import torch
from transformers import AutoModelForMaskedLM

# Checkpoint name (under the ``answerdotai`` org) used when --model is omitted.
DEFAULT_MODEL = "ModernBERT-base"


def log(text):
    """Print *text* wrapped in ANSI bold + bright-green escape codes."""
    print(f"\033[92m\033[1m{text}\033[0m")


def parse_args(argv=None):
    """Parse command-line options.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        An ``argparse.Namespace`` with ``model`` (str) and ``quantize`` (bool).
    """
    parser = argparse.ArgumentParser(
        prog="convert.py", description="Convert ModernBERT to CoreML"
    )
    parser.add_argument("--model", type=str, default=DEFAULT_MODEL, help="Model name")
    parser.add_argument("--quantize", action="store_true", help="Linear quantize model")
    return parser.parse_args(argv)


class Model(torch.nn.Module):
    """Wrap a ModernBERT masked-LM so the traced graph takes only ``input_ids``.

    The attention mask is synthesised inside ``forward`` as all ones. The
    exported model therefore has a single input, and padding tokens are NOT
    masked out.
    """

    def __init__(self, model_name: str = DEFAULT_MODEL):
        """Load ``answerdotai/<model_name>`` via ``AutoModelForMaskedLM``."""
        super().__init__()
        self.model = AutoModelForMaskedLM.from_pretrained(f"answerdotai/{model_name}")

    def forward(self, input_ids):
        """Return the masked-LM logits for *input_ids*, attending to every position."""
        attention_mask = torch.ones_like(input_ids)
        return self.model(input_ids=input_ids, attention_mask=attention_mask).logits


def convert(model_name: str, quantize: bool):
    """Load, trace and convert *model_name*; optionally int4-quantize its weights.

    Args:
        model_name: Checkpoint name under the ``answerdotai`` Hugging Face org.
        quantize: When true, apply symmetric int4 per-block (block size 32)
            linear weight quantization to the converted model.

    Returns:
        The converted ``coremltools`` ``MLModel`` (not yet saved).
    """
    log("Loading model…")
    model = Model(model_name).eval()

    log("Tracing model…")
    # Trace with a single-token example; the sequence axis is declared flexible
    # below via RangeDim.
    # NOTE(review): tracing specialises any Python-level control flow on this
    # length-1 example. Confirm Core ML outputs match PyTorch at longer lengths.
    example_input = (torch.zeros((1, 1), dtype=torch.int32),)
    with torch.no_grad():  # no autograd bookkeeping is needed for export
        traced_model = torch.jit.trace(model, example_input)

    log("Converting model…")
    max_len = model.model.config.max_position_embeddings
    input_shape = (1, ct.RangeDim(lower_bound=1, upper_bound=max_len, default=1))
    mlmodel = ct.convert(
        traced_model,
        inputs=[ct.TensorType(name="input_ids", shape=input_shape, dtype=np.int32)],
        outputs=[ct.TensorType(name="logits")],
        # macOS15 is also the target that enables int4 / per-block quantization.
        minimum_deployment_target=ct.target.macOS15,
    )

    if quantize:
        log("Quantizing model…")
        op_config = ct.optimize.coreml.OpLinearQuantizerConfig(
            mode="linear_symmetric",
            dtype="int4",
            granularity="per_block",
            block_size=32,
        )
        config = ct.optimize.coreml.OptimizationConfig(global_config=op_config)
        mlmodel = ct.optimize.coreml.linear_quantize_weights(mlmodel, config=config)

    return mlmodel


def main(argv=None):
    """Run the conversion and write ``<model>[-4bit].mlpackage`` to the CWD."""
    args = parse_args(argv)
    mlmodel = convert(args.model, args.quantize)

    mlmodel.author = "Finn Voorhees"
    mlmodel.short_description = "https://hf.co/finnvoorhees/ModernBERT-CoreML"

    log("Saving mlmodel…")
    suffix = "-4bit" if args.quantize else ""
    mlmodel.save(f"{args.model}{suffix}.mlpackage")
    log("Done!")


if __name__ == "__main__":
    main()