# drug-discovery-app / model_comparator.py
# Uploaded by alidenewade ("Upload model_comparator.py", revision df9da26).
# Import necessary libraries
import pandas as pd
import numpy as np
import time
from sklearn.model_selection import train_test_split
# For plotting
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import stats
# Import models and metrics
from sklearn.linear_model import (
LinearRegression, Ridge, Lasso, ElasticNet, BayesianRidge,
HuberRegressor, PassiveAggressiveRegressor, OrthogonalMatchingPursuit,
LassoLars
)
from sklearn.tree import DecisionTreeRegressor
from sklearn.ensemble import (
RandomForestRegressor, GradientBoostingRegressor, ExtraTreesRegressor,
AdaBoostRegressor
)
from sklearn.neighbors import KNeighborsRegressor
from sklearn.dummy import DummyRegressor
from sklearn.metrics import (
mean_absolute_error, mean_squared_error, r2_score,
mean_absolute_percentage_error, mean_squared_log_error
)
# Import optional libraries.
# The two groups are probed independently so that a missing molecule-grid
# package does not disable the gradient-boosting models (and vice versa).
# OSError is caught as well: some of these packages raise it at import time
# when a native shared library they depend on cannot be loaded.
try:
    import xgboost as xgb
    import lightgbm as lgb
    import catboost as cb
    _has_extra_libs = True  # all three boosting libraries are usable
except (ImportError, OSError):
    _has_extra_libs = False
try:
    import mols2grid
    _has_mols2grid = True  # interactive molecule grids are available
except (ImportError, OSError):
    _has_mols2grid = False
# --- Helper Functions ---
def _create_abbreviation(name: str) -> str:
    """Build an upper-case acronym from the first letter of each word in *name*.

    'Lasso Regression' is special-cased to 'LaR' so it does not collide with
    'Linear Regression', which maps to 'LR'.
    """
    special_cases = {'Lasso Regression': 'LaR', 'Linear Regression': 'LR'}
    if name in special_cases:
        return special_cases[name]
    initials = (word[0] for word in name.split())
    return ''.join(initials).upper()
def _rmsle(y_true, y_pred):
    """Return the root mean squared logarithmic error of *y_pred* against *y_true*.

    Both inputs are clipped element-wise at 0 before scoring. Logarithmic
    errors are not meaningful for negative values, and scikit-learn's
    ``mean_squared_log_error`` can raise a ValueError on them.
    """
    true_nonneg, pred_nonneg = (np.maximum(values, 0) for values in (y_true, y_pred))
    msle = mean_squared_log_error(true_nonneg, pred_nonneg)
    return np.sqrt(msle)
# --- Plotting Class ---
class ModelPlotter:
    """Validation and interpretation plots for a set of trained regression models.

    Args:
        models (dict): Mapping of model abbreviation (e.g. 'RFR') to fitted estimator.
        X_test (pd.DataFrame): Numeric test features the models predict on.
        y_test (pd.Series): True target values for the test rows.
        df_test (pd.DataFrame): The original test rows with all columns
            (including non-numeric ones such as SMILES), used for molecule grids.
    """

    def __init__(self, models: dict, X_test: pd.DataFrame, y_test: pd.Series, df_test: pd.DataFrame):
        self._models = models
        self._X_test = X_test  # Numeric features for standard plots
        self._y_test = y_test
        self._df_test = df_test  # Original test dataframe with all columns for molecule plotting
        # Estimator class name per abbreviation, used in plot titles.
        self.full_names = {abbr: model.__class__.__name__ for abbr, model in models.items()}

    # --- internal helpers ---

    def _get_model(self, model_abbr: str):
        """Return the fitted model registered under ``model_abbr``.

        Raises:
            ValueError: If ``model_abbr`` is not one of the trained models.
        """
        if model_abbr not in self._models:
            raise ValueError(f"Model '{model_abbr}' not found. Available models: {list(self._models.keys())}")
        return self._models[model_abbr]

    @staticmethod
    def _feature_importances(model):
        """Return ``(importances, axis_label)`` for ``model``, or ``None`` if unsupported.

        Uses ``feature_importances_`` when present. Otherwise it uses the
        absolute values of ``coef_``; a 2-D ``coef_`` is averaged over its
        first axis to give one value per feature.
        """
        if hasattr(model, 'feature_importances_'):
            return np.asarray(model.feature_importances_), 'Importance'
        if hasattr(model, 'coef_'):
            magnitudes = np.abs(np.asarray(model.coef_))
            if magnitudes.ndim > 1:
                magnitudes = magnitudes.mean(axis=0)
            return magnitudes, 'Importance (Absolute Coefficient Value)'
        return None

    def _top_features(self, importances, top_n: int) -> pd.DataFrame:
        """Return up to ``top_n`` rows of a Feature/Importance table, largest importance first."""
        table = pd.DataFrame({'Feature': self._X_test.columns, 'Importance': importances})
        return table.sort_values(by='Importance', ascending=False).head(top_n)

    # --- public plotting API ---

    def plot(self, model_abbr: str):
        """
        Generates a 2x2 grid of validation plots for a specified model:
        actual vs. predicted, residuals vs. predicted, residual histogram,
        and a normal Q-Q plot of the residuals.
        Args:
            model_abbr (str): The abbreviation of the model to plot (e.g., 'RFR').
        Raises:
            ValueError: If ``model_abbr`` is not one of the trained models.
        """
        model = self._get_model(model_abbr)
        model_full_name = self.full_names.get(model_abbr, model_abbr)
        y_pred = model.predict(self._X_test)
        residuals = self._y_test - y_pred
        sns.set_theme(style='whitegrid')
        fig, axes = plt.subplots(2, 2, figsize=(15, 12))
        fig.suptitle(f'Model Validation Plots for {model_full_name} ({model_abbr})', fontsize=20, y=1.03)
        # 1. Actual vs. Predicted (the dashed diagonal marks perfect predictions)
        sns.scatterplot(x=self._y_test, y=y_pred, ax=axes[0, 0], alpha=0.6)
        axes[0, 0].set_title('Actual vs. Predicted Values', fontsize=14)
        axes[0, 0].set_xlabel('Actual Values', fontsize=12)
        axes[0, 0].set_ylabel('Predicted Values', fontsize=12)
        lims = [min(self._y_test.min(), y_pred.min()), max(self._y_test.max(), y_pred.max())]
        axes[0, 0].plot(lims, lims, 'r--', alpha=0.75, zorder=0)
        # 2. Residuals vs. Predicted
        sns.scatterplot(x=y_pred, y=residuals, ax=axes[0, 1], alpha=0.6)
        axes[0, 1].axhline(y=0, color='r', linestyle='--')
        axes[0, 1].set_title('Residuals vs. Predicted Values', fontsize=14)
        axes[0, 1].set_xlabel('Predicted Values', fontsize=12)
        axes[0, 1].set_ylabel('Residuals', fontsize=12)
        # 3. Histogram of Residuals
        sns.histplot(residuals, kde=True, ax=axes[1, 0])
        axes[1, 0].set_title('Distribution of Residuals', fontsize=14)
        axes[1, 0].set_xlabel('Residuals', fontsize=12)
        axes[1, 0].set_ylabel('Frequency', fontsize=12)
        # 4. Q-Q Plot; probplot draws the sample points as line 0 and the fit as line 1
        stats.probplot(residuals, dist="norm", plot=axes[1, 1])
        axes[1, 1].get_lines()[0].set_markerfacecolor('#1f77b4')
        axes[1, 1].get_lines()[0].set_markeredgecolor('#1f77b4')
        axes[1, 1].get_lines()[1].set_color('r')
        axes[1, 1].set_title('Normal Q-Q Plot of Residuals', fontsize=14)
        axes[1, 1].set_xlabel('Theoretical Quantiles', fontsize=12)
        axes[1, 1].set_ylabel('Sample Quantiles', fontsize=12)
        plt.tight_layout()
        plt.show()

    def plot_feature_importance(self, model_abbr: str, top_n: int = 7):
        """
        Plots the top N most important features for a specified model.
        This function works for models with `feature_importances_` (e.g., RandomForest)
        or `coef_` (e.g., LinearRegression) attributes; for other models it prints a
        message and returns without plotting.
        Args:
            model_abbr (str): The abbreviation of the model to plot (e.g., 'RFR').
            top_n (int): The number of top features to display. Defaults to 7.
        Raises:
            ValueError: If ``model_abbr`` is not one of the trained models.
        """
        model = self._get_model(model_abbr)
        model_full_name = self.full_names.get(model_abbr, model_abbr)
        extracted = self._feature_importances(model)
        if extracted is None:
            print(f"'{model_full_name}' does not support feature importance plotting (no 'feature_importances_' or 'coef_' attribute).")
            return
        importances, importance_type = extracted
        top_features = self._top_features(importances, top_n)
        # There may be fewer features than requested; title and size use the real count.
        shown = len(top_features)
        # Clamp the height so small (or zero) counts still produce a usable figure.
        plt.figure(figsize=(12, max(3.0, shown * 0.6)))
        sns.barplot(x='Importance', y='Feature', data=top_features, palette='viridis', orient='h')
        plt.title(f'Top {shown} Feature Importances for {model_full_name} ({model_abbr})', fontsize=16, pad=20)
        plt.xlabel(importance_type, fontsize=12)
        plt.ylabel('Feature', fontsize=12)
        plt.tight_layout()
        plt.show()

    def plot_mols_for_top_features(self, model_abbr: str, smiles_col: str, top_n: int = 5, **kwargs):
        """
        Displays an interactive grid of test set molecules, highlighting the top model features.
        Requires the 'mols2grid' library.
        Args:
            model_abbr (str): The abbreviation of the model to use for feature importances.
            smiles_col (str): The name of the column in the original DataFrame containing SMILES strings.
            top_n (int): The number of top features to display in the grid's subset and tooltip.
            **kwargs: Additional keyword arguments passed to mols2grid.display().
                This can be used to customize 'subset', 'tooltip', 'rename', etc.
        Returns:
            The object returned by ``mols2grid.display``, or None if mols2grid is
            unavailable or the model exposes no feature importances.
        Raises:
            ValueError: If the model is unknown or ``smiles_col`` is missing.
        """
        # Only mols2grid itself is required here, not the optional boosting libraries.
        grid_lib = globals().get('mols2grid')
        if grid_lib is None:
            print("mols2grid library is not installed. Please install it using 'pip install mols2grid'.")
            return
        model = self._get_model(model_abbr)
        if smiles_col not in self._df_test.columns:
            raise ValueError(f"SMILES column '{smiles_col}' not found in the DataFrame. Please ensure it was present in the initial DataFrame.")
        extracted = self._feature_importances(model)
        if extracted is None:
            print(f"Cannot get feature importances for model '{model_abbr}'.")
            return
        importances, _ = extracted
        top_feature_names = self._top_features(importances, top_n)['Feature'].tolist()
        target_name = self._y_test.name
        # Default display options; anything in kwargs overrides them.
        display_kwargs = {
            "smiles_col": smiles_col,
            "subset": ["img", target_name] + top_feature_names,
            "tooltip": [target_name] + top_feature_names,
        }
        display_kwargs.update(kwargs)
        print(f"Generating molecular grid for top {top_n} features of model {model_abbr}...")
        return grid_lib.display(self._df_test.copy(), **display_kwargs)
# --- Result Container Class ---
class RegressionResult:
    """
    Bundle of a model-comparison run that renders nicely in notebooks.

    Attributes:
        dataframe (pd.DataFrame): The performance metrics table.
        plotter (ModelPlotter): Plotting helper for the trained models.
        models (dict): Trained models keyed by their abbreviation.

    Example:
        >>> result = regression(df, 'target')
        >>> best_model = result.models['RFR']
        >>> result.plotter.plot_feature_importance('RFR')
        >>> result.plotter.plot_mols_for_top_features('RFR', smiles_col='SMILES')
    """

    def __init__(self, results_df: pd.DataFrame, plotter: ModelPlotter, trained_models: dict):
        self.models = trained_models
        self.plotter = plotter
        self.dataframe = results_df

    def _repr_html_(self):
        """Render the metrics table as HTML without the index, floats to 4 decimal places."""
        four_decimals = '{:.4f}'.format
        return self.dataframe.to_html(index=False, float_format=four_decimals)

    def __repr__(self):
        """Render the metrics table as plain text (no index) for non-HTML consoles."""
        text = self.dataframe.to_string(index=False)
        return text
# --- Main Function ---
def _build_model_definitions(random_state: int = 42) -> list:
    """Return ``(display name, unfitted estimator)`` pairs for every benchmarked model.

    XGBoost, LightGBM and CatBoost are appended only when ``_has_extra_libs``
    is true, i.e. when their optional imports succeeded.
    """
    definitions = [
        ('Linear Regression', LinearRegression()),
        ('Ridge Regression', Ridge(random_state=random_state)),
        ('Lasso Regression', Lasso(random_state=random_state)),
        ('Elastic Net', ElasticNet(random_state=random_state)),
        ('Lasso Least Angle Regression', LassoLars(random_state=random_state)),
        ('Orthogonal Matching Pursuit', OrthogonalMatchingPursuit()),
        ('Bayesian Ridge', BayesianRidge()),
        ('Huber Regressor', HuberRegressor()),
        ('Passive Aggressive Regressor', PassiveAggressiveRegressor(random_state=random_state)),
        ('K Neighbors Regressor', KNeighborsRegressor()),
        ('Decision Tree Regressor', DecisionTreeRegressor(random_state=random_state)),
        ('Random Forest Regressor', RandomForestRegressor(random_state=random_state, n_jobs=-1)),
        ('Extra Trees Regressor', ExtraTreesRegressor(random_state=random_state, n_jobs=-1)),
        ('AdaBoost Regressor', AdaBoostRegressor(random_state=random_state)),
        ('Gradient Boosting Regressor', GradientBoostingRegressor(random_state=random_state)),
        ('Dummy Regressor', DummyRegressor(strategy='mean')),
    ]
    if _has_extra_libs:
        definitions.extend([
            ('Extreme Gradient Boosting', xgb.XGBRegressor(random_state=random_state, n_jobs=-1, verbosity=0)),
            ('Light Gradient Boosting Machine', lgb.LGBMRegressor(random_state=random_state, n_jobs=-1, verbosity=-1)),
            ('CatBoost Regressor', cb.CatBoostRegressor(random_state=random_state, verbose=0)),
        ])
    return definitions


def _score_predictions(y_true, y_pred) -> dict:
    """Compute the evaluation metrics reported for each model, in display order."""
    mse = mean_squared_error(y_true, y_pred)
    return {
        'MAE': mean_absolute_error(y_true, y_pred),
        'MSE': mse,
        'RMSE': np.sqrt(mse),
        'R2': r2_score(y_true, y_pred),
        'RMSLE': _rmsle(y_true, y_pred),
        'MAPE': mean_absolute_percentage_error(y_true, y_pred),
    }


def regression(df: pd.DataFrame, target_variable: str, test_size: float = 0.2,
               random_state: int = 42) -> RegressionResult:
    """
    Trains, evaluates, and provides plotting for multiple regression models.
    Args:
        df (pd.DataFrame): The input dataframe. Must contain the target variable
            and all features. For molecule plotting, it should also
            contain a SMILES column. Columns with no numeric values in the
            training split (such as SMILES) are not used as features.
        target_variable (str): The name of the target column.
        test_size (float): Fraction of rows held out for evaluation. Defaults to 0.2.
        random_state (int): Seed for the train/test split and for every model
            that accepts one. Defaults to 42.
    Returns:
        RegressionResult: An object containing the performance metrics DataFrame
            (sorted by R2, best first), a ModelPlotter, and a dictionary of
            trained models keyed by abbreviation.
    """
    # 1. Split by position, not by index label: with a non-unique index,
    #    df.loc[labels] returns every row sharing a label, which can duplicate
    #    rows and leak them between the train and test sets.
    positions = np.arange(len(df))
    train_pos, test_pos = train_test_split(positions, test_size=test_size, random_state=random_state)
    df_train = df.iloc[train_pos]
    df_test = df.iloc[test_pos]

    # 2. Numeric feature matrices. Unparseable entries become NaN and then 0.
    #    Columns with no numeric value at all in the training split (e.g. a
    #    SMILES string column) would only be a constant-zero feature, so they
    #    are dropped here; they remain available in df_test for molecule plots.
    train_features = df_train.drop(columns=[target_variable]).apply(pd.to_numeric, errors='coerce')
    has_numeric = train_features.notna().any()
    X_train = train_features.loc[:, has_numeric].fillna(0)
    feature_cols = X_train.columns
    X_test = (
        df_test.drop(columns=[target_variable])
        .apply(pd.to_numeric, errors='coerce')[feature_cols]
        .fillna(0)
    )
    y_train = df_train[target_variable]
    y_test = df_test[target_variable]

    # 3. Fit and score every model; one failing estimator must not abort the benchmark.
    results_list = []
    trained_models = {}
    print("Starting model training and evaluation...")
    for name, model in _build_model_definitions(random_state):
        abbr = _create_abbreviation(name)
        start_time = time.perf_counter()
        try:
            model.fit(X_train, y_train)
            training_time = time.perf_counter() - start_time  # fit time only
            y_pred = model.predict(X_test)
            metrics = _score_predictions(y_test, y_pred)
        except Exception as e:
            print(f"Could not train {name}. Error: {e}")
            continue
        results_list.append({'Model Abbreviation': abbr, 'Model': name, **metrics, 'TT (Sec)': training_time})
        trained_models[abbr] = model
    print("Evaluation complete.")
    print("=" * 50)

    # Explicit columns keep the frame (and the sort) valid even if every model failed.
    columns = ['Model Abbreviation', 'Model', 'MAE', 'MSE', 'RMSE', 'R2', 'RMSLE', 'MAPE', 'TT (Sec)']
    results_df = (
        pd.DataFrame(results_list, columns=columns)
        .sort_values(by='R2', ascending=False)
        .reset_index(drop=True)
    )
    # Pass the original df_test to the plotter for molecule visualization
    plotter = ModelPlotter(trained_models, X_test, y_test, df_test)
    return RegressionResult(results_df, plotter, trained_models)