from typing import List


def preprocess_logits_for_metrics(logits, labels):
    """Reduce model logits to predicted indices before metric computation.

    The signature appears to match the ``preprocess_logits_for_metrics`` hook
    of ``transformers.Trainer`` (NOTE(review): confirm against the caller).
    ``labels`` is accepted for that interface but is not used here.

    Args:
        logits: An object exposing ``argmax(dim=...)`` (presumably a torch
            tensor), or a tuple whose first element is such an object.
        labels: Unused.

    Returns:
        ``logits.argmax(dim=-1)``, the index of the largest value along the
        last dimension.
    """
    # Some model outputs are tuples; only the first element is treated as logits.
    if isinstance(logits, tuple):
        logits = logits[0]
    return logits.argmax(dim=-1)


def dataset_split_selector(data) -> List[str]:
    """Choose which split names of *data* to use.

    This is a function for automating the process of selecting a data split.
    It will be further updated.

    Args:
        data: A mapping of split name to split (anything with ``.keys()``,
            e.g. a dict; presumably a ``datasets.DatasetDict`` — verify).

    Returns:
        - ``[<the only split name>]`` when *data* has exactly one split.
        - ``['train_prefs', 'test_prefs']`` when a ``'train_prefs'`` split
          exists among several splits.
        - ``['train', 'test']`` otherwise.
    """
    split_names = list(data.keys())
    if len(split_names) == 1:
        # Return the split that actually exists instead of assuming it is
        # named 'train' (e.g. a dataset holding only 'train_prefs').
        return [split_names[0]]
    if 'train_prefs' in split_names:
        return ['train_prefs', 'test_prefs']
    return ['train', 'test']