import onnx
import onnxruntime
import torch
from transformers import BertForTokenClassification

from .config_train import model_load_path, onnx_path, tokenizer
# Convert the fine-tuned model to ONNX
def convert_to_onnx(model_path, tokenizer):
    """Convert the fine-tuned BERT token classification model to ONNX."""
    model = BertForTokenClassification.from_pretrained(model_path)
    model.eval()

    # Dummy input used to trace the graph
    # ("I want to go camping and watch the sunset on the beach with my family")
    dummy_sentence = "Tôi muốn đi cắm trại ngắm hoàng hôn trên biển cùng gia đình"
    inputs = tokenizer(dummy_sentence, return_tensors="pt", padding=True, truncation=True)
    dummy_input_ids = inputs["input_ids"]
    dummy_attention_mask = inputs["attention_mask"]

    # Export the ONNX model
    torch.onnx.export(
        model,
        (dummy_input_ids, dummy_attention_mask),  # tuple of model inputs
        onnx_path,
        export_params=True,
        opset_version=14,  # use opset 14 or higher
        input_names=["input_ids", "attention_mask"],
        output_names=["logits"],
        dynamic_axes={
            "input_ids": {0: "batch_size", 1: "sequence_length"},
            "attention_mask": {0: "batch_size", 1: "sequence_length"},
            "logits": {0: "batch_size", 1: "sequence_length"},
        },
    )
    print(f"✅ ONNX model saved to {onnx_path}")


convert_to_onnx(model_load_path, tokenizer)
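

# Optional sanity check — a minimal sketch, not part of the original module. It assumes
# the export above has already run and that `model_load_path`, `onnx_path`, and
# `tokenizer` are the objects imported from .config_train. The `verify_onnx` helper name
# is hypothetical: it validates the exported graph with onnx.checker and compares
# ONNX Runtime logits against the original PyTorch model.
def verify_onnx(model_path, tokenizer):
    """Check the exported ONNX graph and compare its output with PyTorch."""
    # Structural validation of the exported graph
    onnx.checker.check_model(onnx.load(str(onnx_path)))

    sample = "Tôi muốn đi cắm trại ngắm hoàng hôn trên biển cùng gia đình"
    encoded = tokenizer(sample, return_tensors="pt", padding=True, truncation=True)

    # Reference logits from the original PyTorch model
    model = BertForTokenClassification.from_pretrained(model_path)
    model.eval()
    with torch.no_grad():
        torch_logits = model(**encoded).logits

    # Logits from the exported ONNX model via ONNX Runtime
    session = onnxruntime.InferenceSession(str(onnx_path), providers=["CPUExecutionProvider"])
    ort_inputs = {
        "input_ids": encoded["input_ids"].numpy(),
        "attention_mask": encoded["attention_mask"].numpy(),
    }
    ort_logits = session.run(["logits"], ort_inputs)[0]

    # The two outputs should agree within a small numerical tolerance
    if torch.allclose(torch_logits, torch.from_numpy(ort_logits), atol=1e-4):
        print("✅ ONNX Runtime output matches PyTorch output")
    else:
        print("⚠️ ONNX Runtime and PyTorch outputs diverge beyond tolerance")


# Uncomment to verify the exported model after conversion:
# verify_onnx(model_load_path, tokenizer)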