# constbert-onnx / export_to_onnx.py
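"""Export the local Hugging Face model in the current directory to ONNX.

Loads the tokenizer and model from ".", traces a single example input, and
writes the graph to model.onnx with dynamic batch and sequence dimensions.
Run this script from the directory that contains the model config and weights.
"""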
import sys

import torch
from transformers import AutoModel, AutoTokenizer

try:
print("Loading tokenizer...")
model_name = "." # local dir
tokenizer = AutoTokenizer.from_pretrained(model_name)
print("βœ“ Tokenizer loaded successfully")
print("Loading model...")
model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
print("βœ“ Model loaded successfully")
print("Setting model to evaluation mode...")
model.eval()
print("βœ“ Model set to evaluation mode")
print("Tokenizing input text...")
inputs = tokenizer("Export this model to ONNX!", return_tensors="pt")
print("βœ“ Input tokenized successfully")
print("Exporting model to ONNX format...")
# Export ONNX
torch.onnx.export(
model,
(inputs["input_ids"], inputs["attention_mask"]),
"model.onnx",
input_names=["input_ids", "attention_mask"],
output_names=["last_hidden_state"],
dynamic_axes={
"input_ids": {0: "batch", 1: "seq"},
"attention_mask": {0: "batch", 1: "seq"},
"last_hidden_state": {0: "batch", 1: "seq"},
},
opset_version=14,
)
print("βœ“ Model exported to ONNX successfully")
print(f"βœ“ ONNX file saved as: model.onnx")
except FileNotFoundError as e:
print(f"❌ Error: Model files not found in current directory: {e}")
sys.exit(1)
except ImportError as e:
print(f"❌ Error: Failed to import required modules: {e}")
sys.exit(1)
except Exception as e:
print(f"❌ Error during model export: {e}")
sys.exit(1)