updated onnx model
Browse files- model.onnx +2 -2
- modeling.py +4 -5
model.onnx
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
-
size
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:30dae9a99d07f56c103a09173deaa9f76f141976ca20dd8f7e5a5cce8152dee8
|
3 |
+
size 436269030
|
modeling.py
CHANGED
@@ -124,9 +124,7 @@ class ConstBERT(BertPreTrainedModel):
|
|
124 |
# Q = self.query_project(Q) #(64, 128,8)
|
125 |
# Q = Q.permute(0, 2, 1) #(64,8,128)
|
126 |
Q = self.linear(Q)
|
127 |
-
|
128 |
-
|
129 |
-
mask = torch.tensor(self.mask(input_ids, skiplist=[]), device=self.device).unsqueeze(2).float()
|
130 |
Q = Q * mask
|
131 |
|
132 |
return torch.nn.functional.normalize(Q, p=2, dim=2)
|
@@ -165,8 +163,9 @@ class ConstBERT(BertPreTrainedModel):
|
|
165 |
return D
|
166 |
|
167 |
def mask(self, input_ids, skiplist):
|
168 |
-
|
169 |
-
|
|
|
170 |
|
171 |
def query(self, *args, to_cpu=False, **kw_args):
|
172 |
with torch.no_grad():
|
|
|
124 |
# Q = self.query_project(Q) #(64, 128,8)
|
125 |
# Q = Q.permute(0, 2, 1) #(64,8,128)
|
126 |
Q = self.linear(Q)
|
127 |
+
mask = self.mask(input_ids, skiplist=[]).unsqueeze(2)
|
|
|
|
|
128 |
Q = Q * mask
|
129 |
|
130 |
return torch.nn.functional.normalize(Q, p=2, dim=2)
|
|
|
163 |
return D
|
164 |
|
165 |
def mask(self, input_ids, skiplist):
|
166 |
+
# For ONNX export and inference, skiplist should be empty
|
167 |
+
# Create mask: 1 where input_ids != pad_token, else 0
|
168 |
+
return (input_ids != self.pad_token).float()
|
169 |
|
170 |
def query(self, *args, to_cpu=False, **kw_args):
|
171 |
with torch.no_grad():
|