Spaces:
Runtime error
Runtime error
File size: 2,011 Bytes
a02c788 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 |
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from statistics import mean
class AggregationStrategy:
    """Combine several per-row probability vectors into a single vector.

    Each input row is a sequence of per-class values (presumably softmax
    outputs, per the parameter name of ``aggregate``). Optionally only
    ``max_items`` rows are kept, ranked by the value at
    ``sorting_class_index``; then ``method`` is applied column-wise, giving
    one aggregated value per class.
    """

    def __init__(
        self,
        method,
        max_items=None,
        top_items=True,
        sorting_class_index=1
    ):
        """Store the aggregation settings.

        Args:
            method: Callable reducing a list of numbers to one value
                (e.g. ``statistics.mean``); called once per class column.
            max_items: If not ``None``, rows are ranked and at most this many
                are kept before aggregating. ``None`` uses every row as given.
            top_items: When ranking, ``True`` keeps the rows with the highest
                value at ``sorting_class_index``; ``False`` keeps the lowest.
            sorting_class_index: Column used to rank rows when ``max_items``
                is set. Negative indices count from the end of each row.
        """
        self.method = method
        self.max_items = max_items
        self.top_items = top_items
        self.sorting_class_index = sorting_class_index

    def aggregate(self, softmax_tuples):
        """Aggregate rows of per-class values into one tuple.

        Args:
            softmax_tuples: Iterable of rows; each row is an iterable of
                numbers, one per class.

        Returns:
            A tuple with ``self.method`` applied to each column, in column
            order. The number of columns is taken from the first (kept) row.

        Raises:
            ValueError: If no rows remain to aggregate (empty input, or a
                ``max_items`` such as ``0`` that truncates everything away).
            IndexError: If a kept row is shorter than the first kept row, or
                ``sorting_class_index`` is out of range for a row.
        """
        # Materialize each row as a tuple so rows supplied as iterators can be
        # indexed repeatedly (both for ranking and for the column pass).
        rows = [tuple(row) for row in softmax_tuples]

        if self.max_items is not None:
            # sorted() is stable (also with reverse=True), so rows with equal
            # scores keep their input order; slicing past the end is harmless.
            rows = sorted(
                rows,
                key=lambda row: row[self.sorting_class_index],
                reverse=self.top_items,
            )[:self.max_items]

        if not rows:
            raise ValueError(
                "aggregate() needs at least one probability row to aggregate"
            )

        # The first row fixes the class count; extra trailing entries in a
        # longer later row are ignored.
        num_classes = len(rows[0])
        return tuple(
            self.method([row[class_index] for row in rows])
            for class_index in range(num_classes)
        )
class AggregationStrategies:
    """Namespace of ready-made ``AggregationStrategy`` presets.

    ``Mean`` reduces every row with ``statistics.mean``. Each
    ``MeanTop*BinaryClassification`` preset ranks rows by the value at
    column 1 (highest first) and averages only the top N rows, with
    N = 5, 10, 15 and 20 respectively.
    """

    # Column-wise mean over all rows; no ranking or truncation.
    Mean = AggregationStrategy(method=mean)

    # Top-N presets: rank by column 1, descending, keep N rows, then average.
    MeanTopFiveBinaryClassification = AggregationStrategy(
        method=mean, max_items=5, top_items=True, sorting_class_index=1)
    MeanTopTenBinaryClassification = AggregationStrategy(
        method=mean, max_items=10, top_items=True, sorting_class_index=1)
    MeanTopFifteenBinaryClassification = AggregationStrategy(
        method=mean, max_items=15, top_items=True, sorting_class_index=1)
    MeanTopTwentyBinaryClassification = AggregationStrategy(
        method=mean, max_items=20, top_items=True, sorting_class_index=1)
|