Spaces:
Runtime error
Runtime error
| import numpy as np | |
class BaseProbInference:
    """Base class whose hook methods a concrete task subclass overrides.

    Every hook defined here raises ``NotImplementedError``; only
    ``__init__`` and ``dataset_part`` carry real logic.
    """

    def __init__(self, prompt_version):
        """Resolve the prompt version and initialise per-task state.

        Args:
            prompt_version: A prompt-version identifier, or the literal
                string ``"default"`` to use ``self.default_prompt_version()``.
        """
        # Resolve the prompt version before anything else is assigned, so the
        # subclass hook runs against the same (empty) instance state as before.
        wants_default = prompt_version == "default"
        self.prompt_version = (
            self.default_prompt_version() if wants_default else prompt_version
        )

        # Raw data slots; all start as None here and are presumably filled
        # in by loading code elsewhere — TODO confirm.
        self.raw_data_result = None
        self.raw_data_sample = None
        self.raw_data_dev = None

        # Task-level settings with their base-class defaults.
        self.can_be_stratified = False
        self.CHOICES = None
        self.num_base_shot = 1

    def default_prompt_version(self):
        """Hook: return the prompt version used when ``"default"`` is passed."""
        raise NotImplementedError

    def dataset_signature(self):
        """Hook: describe where each part of the data comes from.

        Expected shape::

            {
                "result": (dataset_name, subset, split),  # produces the final result
                "sample": (dataset_name, subset, split),  # pool for ICL few-shot examples
            }
        """
        raise NotImplementedError

    def dataset_part(self, part):
        """Return the entry of ``dataset_signature()`` stored under ``part``."""
        signature = self.dataset_signature()
        return signature[part]

    def dataset_preprocess(self, raw_data):
        """Hook: preprocess ``raw_data``; the contract is subclass-defined."""
        raise NotImplementedError

    def handcrafted_exemplars(self):
        """Hook for hand-written exemplars; the contract is subclass-defined."""
        raise NotImplementedError

    # NOTE(review): "seperator" is misspelled, but the name is public API.
    def exemplar_seperator(self):
        """Hook for the separator placed between exemplars (subclass-defined)."""
        raise NotImplementedError

    def multiple_choice_promptify(self, query, choice):
        """Hook: build a prompt from ``query`` and one ``choice`` (subclass-defined)."""
        raise NotImplementedError
def merge_choice_info(choice_info, keys=("lm_log_p", "norm_lm_log_p")):
    """Transpose per-choice metric dicts into one list per metric.

    NOTE(review): indentation was lost in the scraped source and no decorator
    or ``self`` is visible, so this is written as a module-level function; if
    it originally lived inside ``BaseProbInference`` it needs ``@staticmethod``.

    Args:
        choice_info: Iterable of mappings, one per answer choice, each holding
            a value under every name in ``keys``. May be a one-shot iterator.
        keys: Metric names to collect. The default is the pair of metrics
            this function previously hard-coded.

    Returns:
        Dict mapping each key (in ``keys`` order) to the list of that metric's
        values, in the iteration order of ``choice_info``.

    Raises:
        KeyError: If an entry of ``choice_info`` lacks one of ``keys``.
    """
    # Materialize once: the result walks the entries once per key, and a
    # generator would otherwise be exhausted after the first key, leaving
    # every later metric silently empty.
    choices = list(choice_info)
    return {k: [info[k] for info in choices] for k in keys}
def choice_info_to_predictions(info, keys=("lm_log_p", "norm_lm_log_p")):
    """Pick the index of the highest-scoring choice for each metric.

    NOTE(review): indentation was lost in the scraped source and no decorator
    or ``self`` is visible, so this is written as a module-level function; if
    it originally lived inside ``BaseProbInference`` it needs ``@staticmethod``.

    Args:
        info: Mapping from metric name to a sequence of per-choice scores,
            e.g. the dict produced by ``merge_choice_info``.
        keys: Metric names to score. The default is the pair of metrics this
            function previously hard-coded.

    Returns:
        Dict mapping each key (in ``keys`` order) to the builtin ``int`` index
        of that metric's maximum score. Ties resolve to the first maximal
        index, and a NaN score wins (``np.argmax`` returns the first NaN).

    Raises:
        KeyError: If ``info`` lacks one of ``keys``.
        ValueError: If a score sequence is empty (``np.argmax`` rejects it).
    """
    # int() turns numpy's integer scalar into a plain Python int for callers.
    return {k: int(np.argmax(info[k])) for k in keys}