ag-nexla committed on
Commit
09fb7a6
·
1 Parent(s): 9cb83c5

updated onnx model

Browse files
Files changed (2) hide show
  1. model.onnx +2 -2
  2. modeling.py +4 -5
model.onnx CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:503b7a157de83c5ae3fc63dac56be01bf724cdc4c7f2141febfbe2d65ce8468d
3
- size 436270743
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:30dae9a99d07f56c103a09173deaa9f76f141976ca20dd8f7e5a5cce8152dee8
3
+ size 436269030
modeling.py CHANGED
@@ -124,9 +124,7 @@ class ConstBERT(BertPreTrainedModel):
124
  # Q = self.query_project(Q) #(64, 128,8)
125
  # Q = Q.permute(0, 2, 1) #(64,8,128)
126
  Q = self.linear(Q)
127
- # mask = torch.ones(Q.shape[0], Q.shape[1], device=self.device).unsqueeze(2).float()
128
-
129
- mask = torch.tensor(self.mask(input_ids, skiplist=[]), device=self.device).unsqueeze(2).float()
130
  Q = Q * mask
131
 
132
  return torch.nn.functional.normalize(Q, p=2, dim=2)
@@ -165,8 +163,9 @@ class ConstBERT(BertPreTrainedModel):
165
  return D
166
 
167
  def mask(self, input_ids, skiplist):
168
- mask = [[(x not in skiplist) and (x != self.pad_token) for x in d] for d in input_ids.cpu().tolist()]
169
- return mask
 
170
 
171
  def query(self, *args, to_cpu=False, **kw_args):
172
  with torch.no_grad():
 
124
  # Q = self.query_project(Q) #(64, 128,8)
125
  # Q = Q.permute(0, 2, 1) #(64,8,128)
126
  Q = self.linear(Q)
127
+ mask = self.mask(input_ids, skiplist=[]).unsqueeze(2)
 
 
128
  Q = Q * mask
129
 
130
  return torch.nn.functional.normalize(Q, p=2, dim=2)
 
163
  return D
164
 
165
  def mask(self, input_ids, skiplist):
166
+ # For ONNX export and inference, skiplist should be empty
167
+ # Create mask: 1 where input_ids != pad_token, else 0
168
+ return (input_ids != self.pad_token).float()
169
 
170
  def query(self, *args, to_cpu=False, **kw_args):
171
  with torch.no_grad():