add compression rates
modeling_provence.py  (+10 -1)
@@ -158,6 +158,9 @@ class Provence(DebertaV2PreTrainedModel):
         reranking_scores = [
             [None for j in range(len(contexts[i]))] for i in range(len(queries))
         ]
+        compressions = [
+            [0 for j in range(len(contexts[i]))] for i in range(len(queries))
+        ]
         with torch.no_grad():
             for batch_start in tqdm(
                 range(0, len(dataset), batch_size), desc="Pruning contexts..."
@@ -225,18 +228,24 @@ class Provence(DebertaV2PreTrainedModel):
                     )
                 else:
                     selected_contexts[i][j] = selected_contexts[i][j][0]
+                len_original = len(contexts[i][j])
+                len_compressed = len(selected_contexts[i][j])
+                compressions[i][j] = (len_original-len_compressed)/len_original * 100
             if reorder:
                 idxs = np.argsort(reranking_scores[i])[::-1][:top_k]
                 selected_contexts[i] = [selected_contexts[i][j] for j in idxs]
                 reranking_scores[i] = [reranking_scores[i][j] for j in idxs]
+                compressions[i] = [compressions[i][j] for j in idxs]

         if type(context) == str:
             selected_contexts = selected_contexts[0][0]
             reranking_scores = reranking_scores[0][0]
+            compressions = compressions[0][0]

         return {
             "pruned_context": selected_contexts,
-            "reranking_score": reranking_scores
+            "reranking_score": reranking_scores,
+            "compression_rate": compressions,
         }
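For reference, a minimal sketch of how the new field reaches callers. It assumes the method shown in this diff is the model's process() entry point and uses an illustrative checkpoint id loaded with trust_remote_code, as is usual for custom-code models; neither of those details is part of this commit.

    from transformers import AutoModel

    # Illustrative checkpoint id; any Provence checkpoint shipping this code would behave the same.
    provence = AutoModel.from_pretrained(
        "naver/provence-reranker-debertav3-v1", trust_remote_code=True
    )

    out = provence.process(
        "What goes at the bottom of a shepherd's pie?",  # query
        "Shepherd's pie starts with a layer of minced lamb. "
        "Boil the potatoes separately before mashing them.",  # context passage
    )
    print(out["pruned_context"])    # context with query-irrelevant sentences removed
    print(out["reranking_score"])   # query-context relevance score
    print(out["compression_rate"])  # new field: percentage of characters pruned

As computed in the diff, the rate is character-based: (len_original - len_compressed) / len_original * 100, so 0 means nothing was pruned and the value approaches 100 as more of the passage is dropped.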