from collections import defaultdict

import numpy as np
import pandas as pd
from scipy.stats import spearmanr

from constants import ASSAY_LIST, ASSAY_HIGHER_IS_BETTER

FOLD_COL = "hierarchical_cluster_IgG_isotype_stratified_fold"


def recall_at_k(y_true: np.ndarray, y_pred: np.ndarray, frac: float = 0.1) -> float:
    """Calculate recall, TP / (TP + FN), for the top fraction of true values.

    A recall of 1 means that the top fraction of true values is also the top
    fraction of predicted values. There is no penalty for ranking the top k
    differently.

    Args:
        y_true (np.ndarray): true values with shape (num_data,)
        y_pred (np.ndarray): predicted values with shape (num_data,)
        frac (float, optional): fraction of data points to treat as the top.
            Defaults to 0.1.

    Returns:
        float: recall at top k of the data
    """
    top_k = int(len(y_true) * frac)
    y_true, y_pred = np.array(y_true).flatten(), np.array(y_pred).flatten()
    true_top_k = np.argsort(y_true)[-top_k:]
    predicted_top_k = np.argsort(y_pred)[-top_k:]
    return len(set(true_top_k) & set(predicted_top_k)) / top_k


def get_metrics(
    predictions_series: pd.Series, target_series: pd.Series, assay_col: str
) -> dict[str, float]:
    """Compute Spearman correlation and top-10% recall for a single assay."""
    results_dict = {
        "spearman": spearmanr(
            predictions_series, target_series, nan_policy="omit"
        ).correlation
    }
    # Top 10% recall; flip the sign for assays where lower values are better so
    # that "top" always refers to the most desirable antibodies.
    y_true = target_series.values
    y_pred = predictions_series.values
    if not ASSAY_HIGHER_IS_BETTER[assay_col]:
        y_true = -1 * y_true
        y_pred = -1 * y_pred
    results_dict["top_10_recall"] = recall_at_k(y_true=y_true, y_pred=y_pred, frac=0.1)
    return results_dict


def get_metrics_cross_validation(
    predictions_series: pd.Series,
    target_series: pd.Series,
    folds_series: pd.Series,
    assay_col: str,
) -> dict[str, float]:
    """Compute metrics per fold and average them across the 5 folds."""
    results_dict = defaultdict(list)
    if folds_series.nunique() != 5:
        raise ValueError(f"Expected 5 folds, got {folds_series.nunique()}")
    for fold in folds_series.unique():
        predictions_series_fold = predictions_series[folds_series == fold]
        target_series_fold = target_series[folds_series == fold]
        results = get_metrics(predictions_series_fold, target_series_fold, assay_col)
        # Collect the results for this fold
        for key, value in results.items():
            results_dict[key].append(value)
    # Average the per-fold results for each metric (could also add std dev later)
    for key, values in results_dict.items():
        results_dict[key] = np.mean(values)
    return results_dict


def _get_result_for_assay(df_merged, assay_col, dataset_name):
    """Return a dictionary with the results for a single assay."""
    if dataset_name == "GDPa1_cross_validation":
        results = get_metrics_cross_validation(
            df_merged[assay_col + "_pred"],
            df_merged[assay_col + "_true"],
            df_merged[FOLD_COL],
            assay_col,
        )
    elif dataset_name == "GDPa1":
        results = get_metrics(
            df_merged[assay_col + "_pred"], df_merged[assay_col + "_true"], assay_col
        )
    elif dataset_name == "Heldout Test Set":
        # Record NaNs for now; they'll appear on the leaderboard and can be
        # handled separately.
        results = {"spearman": np.nan, "top_10_recall": np.nan}
    else:
        raise ValueError(f"Unknown dataset_name: {dataset_name}")
    results["assay"] = assay_col
    return results


def _get_error_result(assay_col, dataset_name, error):
    """Return a dictionary with the error message in place of metrics.

    Used when _get_result_for_assay fails.
    """
    print(f"Error evaluating {assay_col}: {error}")
    # Record a failed result carrying the error information instead of metric values
    error_result = {
        "dataset": dataset_name,
        "assay": assay_col,
    }
    error_result.update({"spearman": error, "top_10_recall": error})
    return error_result


def evaluate(predictions_df, target_df, dataset_name="GDPa1"):
    """Evaluate a single model whose predictions dataframe has one column per property,
    e.g. my_model.csv has columns antibody_name, HIC, Tm2.

    Note (Lood): copied from the GitHub repo, which should be moved over here.
    """
    properties_in_preds = [col for col in predictions_df.columns if col in ASSAY_LIST]
    df_merged = pd.merge(
        target_df[["antibody_name", FOLD_COL] + ASSAY_LIST],
        predictions_df[["antibody_name"] + properties_in_preds],
        on="antibody_name",
        how="left",
        suffixes=("_true", "_pred"),
    )
    results_list = []
    # Process each property one by one for better error handling
    for assay_col in properties_in_preds:
        try:
            results = _get_result_for_assay(df_merged, assay_col, dataset_name)
            results_list.append(results)
        except Exception as e:
            error_result = _get_error_result(assay_col, dataset_name, e)
            results_list.append(error_result)
    results_df = pd.DataFrame(results_list)
    return results_df