ag-nexla committed on
Commit
2336bf5
·
1 Parent(s): 09fb7a6

added custom handler

Browse files
Files changed (5) hide show
  1. export_to_onnx.py +6 -3
  2. handler.py +94 -0
  3. model.onnx +2 -2
  4. modeling.py +31 -17
  5. requirements.txt +4 -0
export_to_onnx.py CHANGED
@@ -16,9 +16,12 @@ try:
16
  model.eval()
17
  print("✓ Model set to evaluation mode")
18
 
19
- print("Preparing dummy input for ONNX export...")
20
- dummy_input_ids = torch.ones((1, 512), dtype=torch.long)
21
- dummy_attention_mask = torch.ones((1, 512), dtype=torch.long)
 
 
 
22
  print("✓ Dummy input prepared")
23
 
24
  print("Exporting model to ONNX format...")
 
16
  model.eval()
17
  print("✓ Model set to evaluation mode")
18
 
19
+ # Use the doc_maxlen from the *loaded model's* colbert_config
20
+ actual_doc_maxlen = model.colbert_config.doc_maxlen
21
+ print(f"DEBUG: model.colbert_config.doc_maxlen = {actual_doc_maxlen}")
22
+ print(f"Preparing dummy input for ONNX export with doc_maxlen={actual_doc_maxlen}...")
23
+ dummy_input_ids = torch.ones((1, actual_doc_maxlen), dtype=torch.long)
24
+ dummy_attention_mask = torch.ones((1, actual_doc_maxlen), dtype=torch.long)
25
  print("✓ Dummy input prepared")
26
 
27
  print("Exporting model to ONNX format...")
handler.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # handler.py
2
+ import os
3
+ import onnxruntime as ort
4
+ import numpy as np
5
+ from transformers import AutoTokenizer
6
+ from typing import Dict, List, Any
7
+ from colbert_configuration import ColBERTConfig # Import ColBERTConfig
8
+
9
+ # Assuming modeling.py and colbert_configuration.py are in the same directory
10
+ # We'll use local imports since this handler will run within the model's directory context
11
+ # For ConstBERT to be recognized, you need to ensure these are importable.
12
+ # If you run into issues, consider a custom Docker image or ensuring the model
13
+ # is loadable via AutoModel.from_pretrained if it has auto_map in config.json
14
+ # For simplicity, this handler does not instantiate ConstBERT at all: it runs the exported model.onnx directly with ONNX Runtime.
15
+
16
+ # Note: The EndpointHandler class must be named exactly this.
17
class EndpointHandler:
    """Custom inference handler that serves the exported ``model.onnx``.

    The handler tokenizes incoming text, pads/truncates every sequence to the
    ``doc_maxlen`` stored in the checkpoint's ColBERT config, runs the ONNX
    session and returns the first model output as nested Python lists.

    Note: the class must be named exactly ``EndpointHandler``.
    """

    def __init__(self, path=""):  # path will be '/repository' on HF Endpoints
        """Load tokenizer, ColBERT config and ONNX session from ``path``.

        Args:
            path: Directory containing the tokenizer files, the ColBERT
                config and ``model.onnx``.
        """
        # Load the tokenizer.
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        print(f"Tokenizer loaded from: {path}")

        # Load the checkpoint's ColBERT config so that padding here uses the
        # same doc_maxlen the export script used for its dummy inputs.
        self.colbert_config = ColBERTConfig.load_from_checkpoint(path)
        self.doc_max_length = self.colbert_config.doc_maxlen
        print(f"ColBERTConfig doc_maxlen loaded as: {self.doc_max_length}")

        # Load the ONNX model.
        onnx_model_path = os.path.join(path, "model.onnx")
        self.session = ort.InferenceSession(onnx_model_path)
        print(f"ONNX model loaded from: {onnx_model_path}")

        # Remember which inputs the graph declares; __call__ feeds exactly
        # these names. (Loop variable renamed so it no longer shadows the
        # builtin `input`.)
        self.input_names = [
            model_input.name for model_input in self.session.get_inputs()
        ]
        print(f"ONNX input names: {self.input_names}")

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Run inference for one request.

        Args:
            data: The request payload. Expected to contain "inputs"
                (str or list of str). The payload is not modified.

        Returns:
            One dictionary per input text, each holding the raw multi-vector
            output for that input, e.g. ``[{"embedding": [[...], ...]}, ...]``.
            An empty "inputs" list yields an empty result list.

        Raises:
            ValueError: If the payload has no "inputs" entry, or if the ONNX
                graph declares an input the tokenizer does not produce.
        """
        # Read without pop() so the caller's payload dict is left untouched.
        inputs = data.get("inputs")
        if inputs is None:
            raise ValueError("No 'inputs' found in the request payload.")

        # Ensure inputs is a list.
        if isinstance(inputs, str):
            inputs = [inputs]

        # Nothing to embed: answer directly instead of sending an empty
        # batch through the tokenizer and the ONNX session.
        if isinstance(inputs, (list, tuple)) and not inputs:
            return []

        # Tokenize with fixed-length padding/truncation to doc_max_length.
        tokenized_inputs = self.tokenizer(
            inputs,
            padding="max_length",
            truncation=True,
            max_length=self.doc_max_length,
            return_tensors="np",
        )

        # Feed exactly the inputs the graph declares rather than a
        # hard-coded pair of names; extra tokenizer outputs are dropped.
        missing = [
            name for name in self.input_names if name not in tokenized_inputs
        ]
        if missing:
            raise ValueError(
                f"Tokenizer output lacks ONNX model inputs: {missing}"
            )
        onnx_inputs = {name: tokenized_inputs[name] for name in self.input_names}

        # Run ONNX inference; the first output is the multi-vector embedding.
        outputs = self.session.run(None, onnx_inputs)
        multi_vector_embeddings = outputs[0]

        # Convert each batch item to nested lists so it is JSON serializable.
        return [
            {"embedding": per_input.tolist()}
            for per_input in multi_vector_embeddings
        ]
model.onnx CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:30dae9a99d07f56c103a09173deaa9f76f141976ca20dd8f7e5a5cce8152dee8
3
- size 436269030
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d515b85a59a302d13d04b3a45c6211b3e1893a2718c13598231acc18825f0f02
3
+ size 436300888
modeling.py CHANGED
@@ -60,6 +60,7 @@ class ConstBERT(BertPreTrainedModel):
60
  super().__init__(config)
61
 
62
  self.config = config
 
63
  self.dim = colbert_config.dim
64
  self.linear = nn.Linear(config.hidden_size, colbert_config.dim, bias=False)
65
  self.doc_project = nn.Linear(colbert_config.doc_maxlen, 32, bias=False)
@@ -132,33 +133,46 @@ class ConstBERT(BertPreTrainedModel):
132
  def forward(self, input_ids, attention_mask):
133
  """
134
  Forward method for ONNX export and PyTorch compatibility.
135
- This simply calls the existing _query method, preserving all current model behavior.
136
  """
137
- return self._query(input_ids, attention_mask)
138
 
139
  def _doc(self, input_ids, attention_mask, keep_dims=True):
140
  assert keep_dims in [True, False, 'return_mask']
141
 
142
  input_ids, attention_mask = input_ids.to(self.device), attention_mask.to(self.device)
143
- D = self.bert(input_ids, attention_mask=attention_mask)[0]
144
- D = D.permute(0, 2, 1) #(64, 128,180)
145
- D = self.doc_project(D) #(64, 128,16)
146
- D = D.permute(0, 2, 1) #(64,16,128)
147
- D = self.linear(D)
148
- mask = torch.ones(D.shape[0], D.shape[1], device=self.device).unsqueeze(2).float()
149
-
150
- # mask = torch.tensor(self.mask(input_ids, skiplist=self.skiplist), device=self.device).unsqueeze(2).float()
151
- D = D * mask
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
152
  D = torch.nn.functional.normalize(D, p=2, dim=2)
153
  if self.use_gpu:
154
  D = D.half()
155
 
156
- if keep_dims is False:
157
- D, mask = D.cpu(), mask.bool().cpu().squeeze(-1)
158
- D = [d[mask[idx]] for idx, d in enumerate(D)]
159
-
160
- elif keep_dims == 'return_mask':
161
- return D, mask.bool()
162
 
163
  return D
164
 
 
60
  super().__init__(config)
61
 
62
  self.config = config
63
+ self.colbert_config = colbert_config
64
  self.dim = colbert_config.dim
65
  self.linear = nn.Linear(config.hidden_size, colbert_config.dim, bias=False)
66
  self.doc_project = nn.Linear(colbert_config.doc_maxlen, 32, bias=False)
 
133
  def forward(self, input_ids, attention_mask):
134
  """
135
  Forward method for ONNX export and PyTorch compatibility.
136
+ This will now call _doc to produce a fixed number of vectors.
137
  """
138
+ return self._doc(input_ids, attention_mask)
139
 
140
  def _doc(self, input_ids, attention_mask, keep_dims=True):
141
  assert keep_dims in [True, False, 'return_mask']
142
 
143
  input_ids, attention_mask = input_ids.to(self.device), attention_mask.to(self.device)
144
+ D = self.bert(input_ids, attention_mask=attention_mask)[0] # Shape: (batch_size, seq_len, hidden_size)
145
+
146
+ # First, apply linear layer to project hidden_size to colbert_config.dim (128)
147
+ D = self.linear(D) # Shape: (batch_size, seq_len, dim)
148
+
149
+ # Now, permute to put seq_len in the feature dimension for doc_project
150
+ D = D.permute(0, 2, 1) # Shape: (batch_size, dim, seq_len)
151
+
152
+ # Apply doc_project to reduce seq_len (e.g., 250) to fixed length (32)
153
+ # The nn.Linear(in_features, out_features) operates on the last dimension.
154
+ # So it expects the last dimension to be seq_len (doc_maxlen).
155
+ # It will transform it to (batch_size, dim, 32)
156
+ D = self.doc_project(D) # Shape: (batch_size, dim, 32)
157
+
158
+ # Permute back to (batch_size, 32, dim)
159
+ D = D.permute(0, 2, 1) # Shape: (batch_size, 32, dim)
160
+
161
+ # Apply mask (assuming it's still needed in this part of the flow)
162
+ # The mask now needs to be applied correctly to the (batch_size, 32, dim) shape
163
+ # For now, let's simplify mask application or ensure it's handled correctly if it remains a static shape.
164
+ # Given the fixed output, the original masking might be less critical here, or needs to be re-evaluated.
165
+
166
+ # Temporarily removing original mask logic in _doc to avoid immediate conflict.
167
+ # If a learned mask is needed on the 32 vectors, it needs separate logic.
168
+ # mask = torch.ones(D.shape[0], D.shape[1], device=self.device).unsqueeze(2).float()
169
+ # D = D * mask
170
+
171
  D = torch.nn.functional.normalize(D, p=2, dim=2)
172
  if self.use_gpu:
173
  D = D.half()
174
 
175
+ # Removed keep_dims conditional branches as _doc now consistently returns fixed 32 vectors.
 
 
 
 
 
176
 
177
  return D
178
 
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ onnxruntime
2
+ transformers
3
+ numpy
4
+ torch # Required by your modeling.py for ConstBERT logic