ernie_demo_toy / ernie /aggregation_strategies.py
Jean Garcia-Gathright
added ernie files
a02c788
raw
history blame contribute delete
No virus
2.01 kB
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from statistics import mean
class AggregationStrategy:
    """Collapse several per-item softmax tuples into one tuple.

    Each input tuple holds one probability per class. The strategy can
    optionally rank the tuples by the probability of a single class, keep
    only the first ``max_items`` of that ranking, and then reduce every
    class position with ``method``.
    """

    def __init__(
        self,
        method,
        max_items=None,
        top_items=True,
        sorting_class_index=1
    ):
        """Store the aggregation settings.

        Args:
            method: Callable that receives the list of probabilities of one
                class (one entry per retained tuple) and returns the
                aggregated value for that class.
            max_items: Maximum number of tuples to keep after ranking.
                ``None`` disables both ranking and truncation.
            top_items: When true the ranking is descending, so the tuples
                with the highest probability for ``sorting_class_index``
                are kept; when false the lowest ones are kept.
            sorting_class_index: Position inside each tuple whose
                probability is used as the ranking key.
        """
        self.method = method
        self.max_items = max_items
        self.top_items = top_items
        self.sorting_class_index = sorting_class_index

    def aggregate(self, softmax_tuples):
        """Aggregate ``softmax_tuples`` class by class.

        Args:
            softmax_tuples: Iterable of probability sequences, one sequence
                per item.

        Returns:
            A tuple with one aggregated value per class position of the
            first retained sequence, or an empty tuple when no sequence is
            left to aggregate (empty input, or everything truncated away).

        Raises:
            KeyError: If a sequence has no entry at ``sorting_class_index``
                while ranking, or is shorter than the first retained one.
        """
        # Index-keyed mappings are kept (rather than plain tuples) so that a
        # lookup of a missing class index keeps raising KeyError.
        rows = [dict(enumerate(probabilities))
                for probabilities in softmax_tuples]

        if self.max_items is not None:
            rows = sorted(
                rows,
                key=lambda row: row[self.sorting_class_index],
                reverse=self.top_items
            )
            # Slicing is a no-op when fewer than max_items rows exist.
            rows = rows[:self.max_items]

        # Nothing to aggregate: avoid the IndexError that rows[0] would raise.
        if not rows:
            return ()

        # The first retained row defines which class positions are reported.
        return tuple(
            self.method([row[class_index] for row in rows])
            for class_index in rows[0]
        )
def _mean_of_top(item_count):
    """Build a strategy averaging the `item_count` items ranked highest on class 1."""
    return AggregationStrategy(
        method=mean,
        max_items=item_count,
        top_items=True,
        sorting_class_index=1,
    )


class AggregationStrategies:
    """Ready-made AggregationStrategy instances.

    ``Mean`` averages every class over all items. The ``MeanTop...`` variants
    first rank items by the probability at class index 1 (descending), keep
    at most the named number of items, and average those.
    """

    Mean = AggregationStrategy(method=mean)
    MeanTopFiveBinaryClassification = _mean_of_top(5)
    MeanTopTenBinaryClassification = _mean_of_top(10)
    MeanTopFifteenBinaryClassification = _mean_of_top(15)
    MeanTopTwentyBinaryClassification = _mean_of_top(20)