added onnx model
Browse files- export_to_onnx.py +50 -0
- model.onnx +3 -0
- modeling.py +7 -0
export_to_onnx.py
ADDED
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import AutoModel, AutoTokenizer
|
2 |
+
from pathlib import Path
|
3 |
+
import torch
|
4 |
+
import sys
|
5 |
+
|
6 |
+
# Export the checkpoint found in the current directory to ONNX.
#
# Steps: load the tokenizer and model from ".", switch the model to eval mode,
# tokenize one sample sentence to obtain example inputs, and hand everything
# to torch.onnx.export.  Progress is reported on stdout; any failure is
# reported on stderr and the script exits with status 1.
try:
    print("Loading tokenizer...")
    model_name = "."  # local dir
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    print("✓ Tokenizer loaded successfully")

    print("Loading model...")
    # trust_remote_code=True allows transformers to run the custom modeling
    # code that ships next to the checkpoint.
    model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
    print("✓ Model loaded successfully")

    print("Setting model to evaluation mode...")
    model.eval()
    print("✓ Model set to evaluation mode")

    print("Tokenizing input text...")
    # The sample sentence only provides example tensors for the export call.
    inputs = tokenizer("Export this model to ONNX!", return_tensors="pt")
    print("✓ Input tokenized successfully")

    print("Exporting model to ONNX format...")
    onnx_path = "model.onnx"
    # Axes 0 and 1 are declared dynamic on every input/output so the exported
    # graph is not pinned to the shape of the sample input above.
    torch.onnx.export(
        model,
        (inputs["input_ids"], inputs["attention_mask"]),
        onnx_path,
        input_names=["input_ids", "attention_mask"],
        output_names=["last_hidden_state"],
        dynamic_axes={
            "input_ids": {0: "batch", 1: "seq"},
            "attention_mask": {0: "batch", 1: "seq"},
            "last_hidden_state": {0: "batch", 1: "seq"},
        },
        opset_version=14,
    )
    print("✓ Model exported to ONNX successfully")
    print(f"✓ ONNX file saved as: {onnx_path}")

except FileNotFoundError as e:
    print(f"✗ Error: Model files not found in current directory: {e}", file=sys.stderr)
    sys.exit(1)
except ImportError as e:
    print(f"✗ Error: Failed to import required modules: {e}", file=sys.stderr)
    sys.exit(1)
except Exception as e:
    # Top-level boundary of the script: report whatever went wrong and exit
    # with a non-zero status instead of printing a traceback.
    print(f"✗ Error during model export: {e}", file=sys.stderr)
    sys.exit(1)
|
model.onnx
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:f2459b4649b9b26e9ec6a72b4fea0e1ecd25fe4b24dbefe3db6e1d7df191a844
|
3 |
+
size 436268726
|
modeling.py
CHANGED
@@ -131,6 +131,13 @@ class ConstBERT(BertPreTrainedModel):
|
|
131 |
|
132 |
return torch.nn.functional.normalize(Q, p=2, dim=2)
|
133 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
134 |
def _doc(self, input_ids, attention_mask, keep_dims=True):
|
135 |
assert keep_dims in [True, False, 'return_mask']
|
136 |
|
|
|
131 |
|
132 |
return torch.nn.functional.normalize(Q, p=2, dim=2)
|
133 |
|
134 |
+
def forward(self, input_ids, attention_mask):
    """
    Module entry point, provided for ONNX export and PyTorch compatibility.

    Passes ``input_ids`` and ``attention_mask`` straight through to
    ``self._query`` and returns its result unchanged.
    """
    query_output = self._query(input_ids, attention_mask)
    return query_output
|
140 |
+
|
141 |
def _doc(self, input_ids, attention_mask, keep_dims=True):
|
142 |
assert keep_dims in [True, False, 'return_mask']
|
143 |
|