nadiinchi committed on
Commit
421f913
·
verified ·
1 Parent(s): 7e3a445

add compression rates

Browse files
Files changed (1) hide show
  1. modeling_provence.py +10 -1
modeling_provence.py CHANGED
@@ -158,6 +158,9 @@ class Provence(DebertaV2PreTrainedModel):
158
  reranking_scores = [
159
  [None for j in range(len(contexts[i]))] for i in range(len(queries))
160
  ]
 
 
 
161
  with torch.no_grad():
162
  for batch_start in tqdm(
163
  range(0, len(dataset), batch_size), desc="Pruning contexts..."
@@ -225,18 +228,24 @@ class Provence(DebertaV2PreTrainedModel):
225
  )
226
  else:
227
  selected_contexts[i][j] = selected_contexts[i][j][0]
 
 
 
228
  if reorder:
229
  idxs = np.argsort(reranking_scores[i])[::-1][:top_k]
230
  selected_contexts[i] = [selected_contexts[i][j] for j in idxs]
231
  reranking_scores[i] = [reranking_scores[i][j] for j in idxs]
 
232
 
233
  if type(context) == str:
234
  selected_contexts = selected_contexts[0][0]
235
  reranking_scores = reranking_scores[0][0]
 
236
 
237
  return {
238
  "pruned_context": selected_contexts,
239
- "reranking_score": reranking_scores
 
240
  }
241
 
242
 
 
158
  reranking_scores = [
159
  [None for j in range(len(contexts[i]))] for i in range(len(queries))
160
  ]
161
+ compressions = [
162
+ [0 for j in range(len(contexts[i]))] for i in range(len(queries))
163
+ ]
164
  with torch.no_grad():
165
  for batch_start in tqdm(
166
  range(0, len(dataset), batch_size), desc="Pruning contexts..."
 
228
  )
229
  else:
230
  selected_contexts[i][j] = selected_contexts[i][j][0]
231
+ len_original = len(contexts[i][j])
232
+ len_compressed = len(selected_contexts[i][j])
233
+ compressions[i][j] = (len_original-len_compressed)/len_original * 100
234
  if reorder:
235
  idxs = np.argsort(reranking_scores[i])[::-1][:top_k]
236
  selected_contexts[i] = [selected_contexts[i][j] for j in idxs]
237
  reranking_scores[i] = [reranking_scores[i][j] for j in idxs]
238
+ compressions[i] = [compressions[i][j] for j in idxs]
239
 
240
  if type(context) == str:
241
  selected_contexts = selected_contexts[0][0]
242
  reranking_scores = reranking_scores[0][0]
243
+ compressions = compressions[0][0]
244
 
245
  return {
246
  "pruned_context": selected_contexts,
247
+ "reranking_score": reranking_scores,
248
+ "compression_rate": compressions,
249
  }
250
 
251