#!/usr/bin/env python
# -*- coding: utf-8 -*-

from statistics import mean


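class AggregationStrategy:
    """Aggregate a list of softmax tuples into a single tuple.

    method: callable that reduces a list of probabilities to one value.
    max_items: if set, only this many items are aggregated.
    top_items: keep the highest-ranked items when True, the lowest otherwise.
    sorting_class_index: class whose probability is used for ranking.
    """
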
    def __init__(
        self,
        method,
        max_items=None,
        top_items=True,
        sorting_class_index=1
    ):
        self.method = method
        self.max_items = max_items
        self.top_items = top_items
        self.sorting_class_index = sorting_class_index

    def aggregate(self, softmax_tuples):
        """Combine per-item softmax tuples into one tuple, class by class."""
        softmax_tuples = list(softmax_tuples)
        if not softmax_tuples:
            raise ValueError("softmax_tuples must not be empty")

        if self.max_items is not None:
            # Rank items by the probability of the sorting class and keep
            # the first max_items (highest first when top_items is True).
            softmax_tuples = sorted(
                softmax_tuples,
                key=lambda x: x[self.sorting_class_index],
                reverse=self.top_items
            )[:self.max_items]

        # Apply the aggregation method to each class column in turn.
        # The number of classes is taken from the first item.
        return tuple(
            self.method([probabilities[i] for probabilities in softmax_tuples])
            for i in range(len(softmax_tuples[0]))
        )


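class AggregationStrategies:
    """Predefined strategies.

    The MeanTopN variants average only the N items with the highest
    probability for class index 1.
    """
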
    Mean = AggregationStrategy(method=mean)
    MeanTopFiveBinaryClassification = AggregationStrategy(
        method=mean,
        max_items=5,
        top_items=True,
        sorting_class_index=1
    )
    MeanTopTenBinaryClassification = AggregationStrategy(
        method=mean,
        max_items=10,
        top_items=True,
        sorting_class_index=1
    )
    MeanTopFifteenBinaryClassification = AggregationStrategy(
        method=mean,
        max_items=15,
        top_items=True,
        sorting_class_index=1
    )
    MeanTopTwentyBinaryClassification = AggregationStrategy(
        method=mean,
        max_items=20,
        top_items=True,
        sorting_class_index=1
    )