import onnx
import onnxruntime
import torch
from transformers import BertForTokenClassification
from .config_train import model_load_path, onnx_path, tokenizer
# Convert Model to ONNX
def convert_to_onnx(model_path, tokenizer):
"""Convert the fine-tuned BERT token classification model to ONNX."""
model = BertForTokenClassification.from_pretrained(model_path)
model.eval()
# Dummy input
dummy_sentence = "Tôi muốn đi cắm trại ngắm hoàng hôn trên biển cùng gia đình"
inputs = tokenizer(dummy_sentence, return_tensors="pt", padding=True, truncation=True)
dummy_input_ids = inputs["input_ids"]
dummy_attention_mask = inputs["attention_mask"]
# Export ONNX model
torch.onnx.export(
model,
(inputs["input_ids"], inputs["attention_mask"]), # Tuple of model inputs
onnx_path,
export_params=True,
opset_version=14, # Use Opset 14 or higher
input_names=["input_ids", "attention_mask"],
output_names=["logits"],
dynamic_axes={"input_ids": {0: "batch_size", 1: "sequence_length"},
"attention_mask": {0: "batch_size", 1: "sequence_length"},
"logits": {0: "batch_size", 1: "sequence_length"}},
)
print(f"✅ ONNX model saved to {onnx_path}")
convert_to_onnx(model_load_path, tokenizer)
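

# A minimal sanity-check sketch (an addition, not part of the original refactor) that
# puts the already-imported onnx and onnxruntime packages to use: onnx.checker validates
# the exported graph, and an ONNX Runtime session compares logits against PyTorch.
# The function name `verify_onnx` and the sample sentence are illustrative assumptions.
def verify_onnx(model_path, tokenizer, sentence="Tôi muốn đi cắm trại"):
    """Check the exported ONNX graph and compare ONNX Runtime output with PyTorch."""
    # Structural validation of the exported graph
    onnx.checker.check_model(onnx.load(onnx_path))

    # Run the ONNX model through ONNX Runtime
    inputs = tokenizer(sentence, return_tensors="pt", padding=True, truncation=True)
    session = onnxruntime.InferenceSession(onnx_path)
    ort_logits = session.run(
        ["logits"],
        {
            "input_ids": inputs["input_ids"].numpy(),
            "attention_mask": inputs["attention_mask"].numpy(),
        },
    )[0]

    # Run the original PyTorch model on the same input
    model = BertForTokenClassification.from_pretrained(model_path)
    model.eval()
    with torch.no_grad():
        pt_logits = model(**inputs).logits.numpy()

    # Small numerical differences between the exporter and runtime are expected
    print("✅ ONNX output matches PyTorch:", bool((abs(ort_logits - pt_logits) < 1e-4).all()))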