ag-nexla committed on
Commit
634cac7
·
1 Parent(s): 8e26d13

added onnx model

Browse files
Files changed (3) hide show
  1. export_to_onnx.py +50 -0
  2. model.onnx +3 -0
  3. modeling.py +7 -0
export_to_onnx.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Export the model stored in the current directory to ONNX.

Loads the tokenizer and model from ``MODEL_DIR``, traces the model with one
sample sentence, and writes the result to ``ONNX_PATH``. Progress is printed
to stdout; failures are reported on stderr and the process exits with
status 1.
"""

from pathlib import Path
import sys

import torch
from transformers import AutoModel, AutoTokenizer

MODEL_DIR = "."  # local dir
ONNX_PATH = Path("model.onnx")
SAMPLE_TEXT = "Export this model to ONNX!"
OPSET_VERSION = 14


def main() -> int:
    """Run the export and return a process exit status (0 = success, 1 = failure)."""
    try:
        print("Loading tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
        print("✓ Tokenizer loaded successfully")

        print("Loading model...")
        # NOTE(review): trust_remote_code=True executes the modeling code found
        # in MODEL_DIR. Only run this script against a checkpoint you trust.
        model = AutoModel.from_pretrained(MODEL_DIR, trust_remote_code=True)
        print("✓ Model loaded successfully")

        print("Setting model to evaluation mode...")
        model.eval()
        print("✓ Model set to evaluation mode")

        print("Tokenizing input text...")
        inputs = tokenizer(SAMPLE_TEXT, return_tensors="pt")
        print("✓ Input tokenized successfully")

        print("Exporting model to ONNX format...")
        # The sample input only fixes the traced graph; batch and sequence
        # dimensions are declared dynamic so other shapes can be fed later.
        torch.onnx.export(
            model,
            (inputs["input_ids"], inputs["attention_mask"]),
            str(ONNX_PATH),
            input_names=["input_ids", "attention_mask"],
            output_names=["last_hidden_state"],
            dynamic_axes={
                "input_ids": {0: "batch", 1: "seq"},
                "attention_mask": {0: "batch", 1: "seq"},
                "last_hidden_state": {0: "batch", 1: "seq"},
            },
            opset_version=OPSET_VERSION,
        )
        print("✓ Model exported to ONNX successfully")
        print(f"✓ ONNX file saved as: {ONNX_PATH}")

    except FileNotFoundError as e:
        print(f"❌ Error: Model files not found in current directory: {e}", file=sys.stderr)
        return 1
    except ImportError as e:
        # Presumably raised by the checkpoint's own modeling code when one of
        # its dependencies is missing -- the imports above run before this try.
        print(f"❌ Error: Failed to import required modules: {e}", file=sys.stderr)
        return 1
    except Exception as e:
        # Top-level boundary of a CLI script: report and exit non-zero.
        print(f"❌ Error during model export: {e}", file=sys.stderr)
        return 1

    return 0


if __name__ == "__main__":
    sys.exit(main())
model.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f2459b4649b9b26e9ec6a72b4fea0e1ecd25fe4b24dbefe3db6e1d7df191a844
3
+ size 436268726
modeling.py CHANGED
@@ -131,6 +131,13 @@ class ConstBERT(BertPreTrainedModel):
131
 
132
  return torch.nn.functional.normalize(Q, p=2, dim=2)
133
 
 
 
 
 
 
 
 
134
  def _doc(self, input_ids, attention_mask, keep_dims=True):
135
  assert keep_dims in [True, False, 'return_mask']
136
 
 
131
 
132
  return torch.nn.functional.normalize(Q, p=2, dim=2)
133
 
134
def forward(self, input_ids, attention_mask):
    """
    Entry point for ONNX export and standard PyTorch module calls.

    Delegates to ``_query`` with the same two arguments and returns its
    result unchanged, so existing model behavior is not altered.
    """
    query_output = self._query(input_ids, attention_mask)
    return query_output
140
+
141
  def _doc(self, input_ids, attention_mask, keep_dims=True):
142
  assert keep_dims in [True, False, 'return_mask']
143