Spaces:
Sleeping
Sleeping
Upload model_comparator.py
Browse files- model_comparator.py +333 -0
model_comparator.py
ADDED
@@ -0,0 +1,333 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Import necessary libraries
|
2 |
+
import pandas as pd
|
3 |
+
import numpy as np
|
4 |
+
import time
|
5 |
+
from sklearn.model_selection import train_test_split
|
6 |
+
|
7 |
+
# For plotting
|
8 |
+
import matplotlib.pyplot as plt
|
9 |
+
import seaborn as sns
|
10 |
+
from scipy import stats
|
11 |
+
|
12 |
+
# Import models and metrics
|
13 |
+
from sklearn.linear_model import (
|
14 |
+
LinearRegression, Ridge, Lasso, ElasticNet, BayesianRidge,
|
15 |
+
HuberRegressor, PassiveAggressiveRegressor, OrthogonalMatchingPursuit,
|
16 |
+
LassoLars
|
17 |
+
)
|
18 |
+
from sklearn.tree import DecisionTreeRegressor
|
19 |
+
from sklearn.ensemble import (
|
20 |
+
RandomForestRegressor, GradientBoostingRegressor, ExtraTreesRegressor,
|
21 |
+
AdaBoostRegressor
|
22 |
+
)
|
23 |
+
from sklearn.neighbors import KNeighborsRegressor
|
24 |
+
from sklearn.dummy import DummyRegressor
|
25 |
+
from sklearn.metrics import (
|
26 |
+
mean_absolute_error, mean_squared_error, r2_score,
|
27 |
+
mean_absolute_percentage_error, mean_squared_log_error
|
28 |
+
)
|
29 |
+
|
30 |
+
# Import optional libraries
|
31 |
+
try:
|
32 |
+
import xgboost as xgb
|
33 |
+
import lightgbm as lgb
|
34 |
+
import catboost as cb
|
35 |
+
import mols2grid
|
36 |
+
_has_extra_libs = True
|
37 |
+
except ImportError:
|
38 |
+
_has_extra_libs = False
|
39 |
+
|
40 |
+
# --- Helper Functions ---
|
41 |
+
|
42 |
+
def _create_abbreviation(name: str) -> str:
|
43 |
+
"""Creates a capitalized abbreviation from a model name."""
|
44 |
+
if name == 'Lasso Regression':
|
45 |
+
return 'LaR'
|
46 |
+
if name == 'Linear Regression':
|
47 |
+
return 'LR'
|
48 |
+
return "".join([word[0] for word in name.split()]).upper()
|
49 |
+
|
50 |
+
def _rmsle(y_true, y_pred):
|
51 |
+
"""Calculates the Root Mean Squared Log Error."""
|
52 |
+
y_pred_clipped = np.maximum(y_pred, 0)
|
53 |
+
y_true_clipped = np.maximum(y_true, 0)
|
54 |
+
return np.sqrt(mean_squared_log_error(y_true_clipped, y_pred_clipped))
|
55 |
+
|
56 |
+
# --- Plotting Class ---
|
57 |
+
|
58 |
+
class ModelPlotter:
|
59 |
+
"""A class to handle plotting for trained regression models."""
|
60 |
+
def __init__(self, models: dict, X_test: pd.DataFrame, y_test: pd.Series, df_test: pd.DataFrame):
|
61 |
+
self._models = models
|
62 |
+
self._X_test = X_test # Numeric features for standard plots
|
63 |
+
self._y_test = y_test
|
64 |
+
self._df_test = df_test # Original test dataframe with all columns for molecule plotting
|
65 |
+
self.full_names = {abbr: model.__class__.__name__ for abbr, model in models.items()}
|
66 |
+
|
67 |
+
def plot(self, model_abbr: str):
|
68 |
+
"""
|
69 |
+
Generates a 2x2 grid of validation plots for a specified model.
|
70 |
+
|
71 |
+
Args:
|
72 |
+
model_abbr (str): The abbreviation of the model to plot (e.g., 'RFR').
|
73 |
+
"""
|
74 |
+
if model_abbr not in self._models:
|
75 |
+
raise ValueError(f"Model '{model_abbr}' not found. Available models: {list(self._models.keys())}")
|
76 |
+
|
77 |
+
model = self._models[model_abbr]
|
78 |
+
model_full_name = self.full_names.get(model_abbr, model_abbr)
|
79 |
+
|
80 |
+
y_pred = model.predict(self._X_test)
|
81 |
+
residuals = self._y_test - y_pred
|
82 |
+
|
83 |
+
sns.set_theme(style='whitegrid')
|
84 |
+
fig, axes = plt.subplots(2, 2, figsize=(15, 12))
|
85 |
+
fig.suptitle(f'Model Validation Plots for {model_full_name} ({model_abbr})', fontsize=20, y=1.03)
|
86 |
+
|
87 |
+
# 1. Actual vs. Predicted
|
88 |
+
sns.scatterplot(x=self._y_test, y=y_pred, ax=axes[0, 0], alpha=0.6)
|
89 |
+
axes[0, 0].set_title('Actual vs. Predicted Values', fontsize=14)
|
90 |
+
axes[0, 0].set_xlabel('Actual Values', fontsize=12)
|
91 |
+
axes[0, 0].set_ylabel('Predicted Values', fontsize=12)
|
92 |
+
lims = [min(self._y_test.min(), y_pred.min()), max(self._y_test.max(), y_pred.max())]
|
93 |
+
axes[0, 0].plot(lims, lims, 'r--', alpha=0.75, zorder=0)
|
94 |
+
|
95 |
+
# 2. Residuals vs. Predicted
|
96 |
+
sns.scatterplot(x=y_pred, y=residuals, ax=axes[0, 1], alpha=0.6)
|
97 |
+
axes[0, 1].axhline(y=0, color='r', linestyle='--')
|
98 |
+
axes[0, 1].set_title('Residuals vs. Predicted Values', fontsize=14)
|
99 |
+
axes[0, 1].set_xlabel('Predicted Values', fontsize=12)
|
100 |
+
axes[0, 1].set_ylabel('Residuals', fontsize=12)
|
101 |
+
|
102 |
+
# 3. Histogram of Residuals
|
103 |
+
sns.histplot(residuals, kde=True, ax=axes[1, 0])
|
104 |
+
axes[1, 0].set_title('Distribution of Residuals', fontsize=14)
|
105 |
+
axes[1, 0].set_xlabel('Residuals', fontsize=12)
|
106 |
+
axes[1, 0].set_ylabel('Frequency', fontsize=12)
|
107 |
+
|
108 |
+
# 4. Q-Q Plot
|
109 |
+
stats.probplot(residuals, dist="norm", plot=axes[1, 1])
|
110 |
+
axes[1, 1].get_lines()[0].set_markerfacecolor('#1f77b4')
|
111 |
+
axes[1, 1].get_lines()[0].set_markeredgecolor('#1f77b4')
|
112 |
+
axes[1, 1].get_lines()[1].set_color('r')
|
113 |
+
axes[1, 1].set_title('Normal Q-Q Plot of Residuals', fontsize=14)
|
114 |
+
axes[1, 1].set_xlabel('Theoretical Quantiles', fontsize=12)
|
115 |
+
axes[1, 1].set_ylabel('Sample Quantiles', fontsize=12)
|
116 |
+
|
117 |
+
plt.tight_layout()
|
118 |
+
plt.show()
|
119 |
+
|
120 |
+
def plot_feature_importance(self, model_abbr: str, top_n: int = 7):
|
121 |
+
"""
|
122 |
+
Plots the top N most important features for a specified model.
|
123 |
+
This function works for models with `feature_importances_` (e.g., RandomForest)
|
124 |
+
or `coef_` (e.g., LinearRegression) attributes.
|
125 |
+
|
126 |
+
Args:
|
127 |
+
model_abbr (str): The abbreviation of the model to plot (e.g., 'RFR').
|
128 |
+
top_n (int): The number of top features to display. Defaults to 7.
|
129 |
+
"""
|
130 |
+
if model_abbr not in self._models:
|
131 |
+
raise ValueError(f"Model '{model_abbr}' not found. Available models: {list(self._models.keys())}")
|
132 |
+
|
133 |
+
model = self._models[model_abbr]
|
134 |
+
model_full_name = self.full_names.get(model_abbr, model_abbr)
|
135 |
+
feature_names = self._X_test.columns
|
136 |
+
importance_type = ''
|
137 |
+
|
138 |
+
if hasattr(model, 'feature_importances_'):
|
139 |
+
importances = model.feature_importances_
|
140 |
+
importance_type = 'Importance'
|
141 |
+
elif hasattr(model, 'coef_'):
|
142 |
+
if model.coef_.ndim > 1:
|
143 |
+
importances = np.mean(np.abs(model.coef_), axis=0)
|
144 |
+
else:
|
145 |
+
importances = np.abs(model.coef_)
|
146 |
+
importance_type = 'Importance (Absolute Coefficient Value)'
|
147 |
+
else:
|
148 |
+
print(f"'{model_full_name}' does not support feature importance plotting (no 'feature_importances_' or 'coef_' attribute).")
|
149 |
+
return
|
150 |
+
|
151 |
+
feature_importance_df = pd.DataFrame({'Feature': feature_names, 'Importance': importances})
|
152 |
+
top_features = feature_importance_df.sort_values(by='Importance', ascending=False).head(top_n)
|
153 |
+
|
154 |
+
plt.figure(figsize=(12, top_n * 0.6))
|
155 |
+
sns.barplot(x='Importance', y='Feature', data=top_features, palette='viridis', orient='h')
|
156 |
+
|
157 |
+
plt.title(f'Top {top_n} Feature Importances for {model_full_name} ({model_abbr})', fontsize=16, pad=20)
|
158 |
+
plt.xlabel(importance_type, fontsize=12)
|
159 |
+
plt.ylabel('Feature', fontsize=12)
|
160 |
+
plt.tight_layout()
|
161 |
+
plt.show()
|
162 |
+
|
163 |
+
def plot_mols_for_top_features(self, model_abbr: str, smiles_col: str, top_n: int = 5, **kwargs):
|
164 |
+
"""
|
165 |
+
Displays an interactive grid of test set molecules, highlighting the top model features.
|
166 |
+
Requires the 'mols2grid' library.
|
167 |
+
|
168 |
+
Args:
|
169 |
+
model_abbr (str): The abbreviation of the model to use for feature importances.
|
170 |
+
smiles_col (str): The name of the column in the original DataFrame containing SMILES strings.
|
171 |
+
top_n (int): The number of top features to display in the grid's subset and tooltip.
|
172 |
+
**kwargs: Additional keyword arguments passed to mols2grid.display().
|
173 |
+
This can be used to customize 'subset', 'tooltip', 'rename', etc.
|
174 |
+
"""
|
175 |
+
if not _has_extra_libs or 'mols2grid' not in globals():
|
176 |
+
print("mols2grid library is not installed. Please install it using 'pip install mols2grid'.")
|
177 |
+
return
|
178 |
+
|
179 |
+
if model_abbr not in self._models:
|
180 |
+
raise ValueError(f"Model '{model_abbr}' not found. Available models: {list(self._models.keys())}")
|
181 |
+
|
182 |
+
if smiles_col not in self._df_test.columns:
|
183 |
+
raise ValueError(f"SMILES column '{smiles_col}' not found in the DataFrame. Please ensure it was present in the initial DataFrame.")
|
184 |
+
|
185 |
+
model = self._models[model_abbr]
|
186 |
+
|
187 |
+
# Get feature importances
|
188 |
+
if hasattr(model, 'feature_importances_'):
|
189 |
+
importances = model.feature_importances_
|
190 |
+
elif hasattr(model, 'coef_') and model.coef_.ndim == 1:
|
191 |
+
importances = np.abs(model.coef_)
|
192 |
+
elif hasattr(model, 'coef_'):
|
193 |
+
importances = np.mean(np.abs(model.coef_), axis=0)
|
194 |
+
else:
|
195 |
+
print(f"Cannot get feature importances for model '{model_abbr}'.")
|
196 |
+
return
|
197 |
+
|
198 |
+
feature_names = self._X_test.columns
|
199 |
+
feature_importance_df = pd.DataFrame({'Feature': feature_names, 'Importance': importances})
|
200 |
+
top_features = feature_importance_df.sort_values(by='Importance', ascending=False).head(top_n)
|
201 |
+
top_feature_names = top_features['Feature'].tolist()
|
202 |
+
|
203 |
+
# Prepare DataFrame for mols2grid
|
204 |
+
df_for_grid = self._df_test.copy()
|
205 |
+
|
206 |
+
# Set up default mols2grid display options, which user can override with kwargs
|
207 |
+
display_kwargs = {
|
208 |
+
"smiles_col": smiles_col,
|
209 |
+
"subset": ["img", self._y_test.name] + top_feature_names,
|
210 |
+
"tooltip": [self._y_test.name] + top_feature_names
|
211 |
+
}
|
212 |
+
display_kwargs.update(kwargs)
|
213 |
+
|
214 |
+
print(f"Generating molecular grid for top {top_n} features of model {model_abbr}...")
|
215 |
+
return mols2grid.display(df_for_grid, **display_kwargs)
|
216 |
+
|
217 |
+
# --- Result Container Class ---
|
218 |
+
class RegressionResult:
|
219 |
+
"""
|
220 |
+
A container for regression results designed for rich display in notebooks.
|
221 |
+
Access the results DataFrame via the `.dataframe` attribute, the
|
222 |
+
ModelPlotter object via the `.plotter` attribute, and the dictionary of
|
223 |
+
trained models via the `.models` attribute.
|
224 |
+
|
225 |
+
Example:
|
226 |
+
>>> result = regression(df, 'target')
|
227 |
+
>>> best_model = result.models['RFR']
|
228 |
+
>>> result.plotter.plot_feature_importance('RFR')
|
229 |
+
>>> result.plotter.plot_mols_for_top_features('RFR', smiles_col='SMILES')
|
230 |
+
"""
|
231 |
+
def __init__(self, results_df: pd.DataFrame, plotter: ModelPlotter, trained_models: dict):
|
232 |
+
self.dataframe = results_df
|
233 |
+
self.plotter = plotter
|
234 |
+
self.models = trained_models
|
235 |
+
|
236 |
+
def _repr_html_(self):
|
237 |
+
"""Returns the HTML representation of the results DataFrame."""
|
238 |
+
return self.dataframe.to_html(index=False, float_format='{:.4f}'.format)
|
239 |
+
|
240 |
+
def __repr__(self):
|
241 |
+
"""Returns the string representation for display in non-HTML environments."""
|
242 |
+
return self.dataframe.to_string(index=False)
|
243 |
+
|
244 |
+
# --- Main Function ---
|
245 |
+
def regression(df: pd.DataFrame, target_variable: str) -> RegressionResult:
|
246 |
+
"""
|
247 |
+
Trains, evaluates, and provides plotting for multiple regression models.
|
248 |
+
|
249 |
+
Args:
|
250 |
+
df (pd.DataFrame): The input dataframe. Must contain the target variable
|
251 |
+
and all features. For molecule plotting, it should also
|
252 |
+
contain a SMILES column.
|
253 |
+
target_variable (str): The name of the target column.
|
254 |
+
|
255 |
+
Returns:
|
256 |
+
RegressionResult: An object containing the performance metrics DataFrame,
|
257 |
+
a ModelPlotter, and a dictionary of trained models.
|
258 |
+
"""
|
259 |
+
# 1. Split data while keeping original structure for molecule plotting
|
260 |
+
indices = df.index
|
261 |
+
train_indices, test_indices = train_test_split(indices, test_size=0.2, random_state=42)
|
262 |
+
df_train = df.loc[train_indices]
|
263 |
+
df_test = df.loc[test_indices]
|
264 |
+
|
265 |
+
# 2. Prepare numeric data for training and evaluation
|
266 |
+
X_train = df_train.drop(columns=[target_variable]).apply(pd.to_numeric, errors='coerce').fillna(0)
|
267 |
+
y_train = df_train[target_variable]
|
268 |
+
X_test = df_test.drop(columns=[target_variable]).apply(pd.to_numeric, errors='coerce').fillna(0)
|
269 |
+
y_test = df_test[target_variable]
|
270 |
+
|
271 |
+
# ... (rest of the models are the same)
|
272 |
+
model_definitions = [
|
273 |
+
('Linear Regression', LinearRegression()),
|
274 |
+
('Ridge Regression', Ridge(random_state=42)),
|
275 |
+
('Lasso Regression', Lasso(random_state=42)),
|
276 |
+
('Elastic Net', ElasticNet(random_state=42)),
|
277 |
+
('Lasso Least Angle Regression', LassoLars(random_state=42)),
|
278 |
+
('Orthogonal Matching Pursuit', OrthogonalMatchingPursuit()),
|
279 |
+
('Bayesian Ridge', BayesianRidge()),
|
280 |
+
('Huber Regressor', HuberRegressor()),
|
281 |
+
('Passive Aggressive Regressor', PassiveAggressiveRegressor(random_state=42)),
|
282 |
+
('K Neighbors Regressor', KNeighborsRegressor()),
|
283 |
+
('Decision Tree Regressor', DecisionTreeRegressor(random_state=42)),
|
284 |
+
('Random Forest Regressor', RandomForestRegressor(random_state=42, n_jobs=-1)),
|
285 |
+
('Extra Trees Regressor', ExtraTreesRegressor(random_state=42, n_jobs=-1)),
|
286 |
+
('AdaBoost Regressor', AdaBoostRegressor(random_state=42)),
|
287 |
+
('Gradient Boosting Regressor', GradientBoostingRegressor(random_state=42)),
|
288 |
+
('Dummy Regressor', DummyRegressor(strategy='mean'))
|
289 |
+
]
|
290 |
+
if _has_extra_libs:
|
291 |
+
model_definitions.extend([
|
292 |
+
('Extreme Gradient Boosting', xgb.XGBRegressor(random_state=42, n_jobs=-1, verbosity=0)),
|
293 |
+
('Light Gradient Boosting Machine', lgb.LGBMRegressor(random_state=42, n_jobs=-1, verbosity=-1)),
|
294 |
+
('CatBoost Regressor', cb.CatBoostRegressor(random_state=42, verbose=0))
|
295 |
+
])
|
296 |
+
|
297 |
+
results_list = []
|
298 |
+
trained_models = {}
|
299 |
+
print("Starting model training and evaluation...")
|
300 |
+
|
301 |
+
for name, model in model_definitions:
|
302 |
+
abbr = _create_abbreviation(name)
|
303 |
+
start_time = time.time()
|
304 |
+
try:
|
305 |
+
model.fit(X_train, y_train)
|
306 |
+
training_time = time.time() - start_time
|
307 |
+
y_pred = model.predict(X_test)
|
308 |
+
|
309 |
+
results_list.append({
|
310 |
+
'Model Abbreviation': abbr, 'Model': name,
|
311 |
+
'MAE': mean_absolute_error(y_test, y_pred),
|
312 |
+
'MSE': mean_squared_error(y_test, y_pred),
|
313 |
+
'RMSE': np.sqrt(mean_squared_error(y_test, y_pred)),
|
314 |
+
'R2': r2_score(y_test, y_pred),
|
315 |
+
'RMSLE': _rmsle(y_test, y_pred),
|
316 |
+
'MAPE': mean_absolute_percentage_error(y_test, y_pred),
|
317 |
+
'TT (Sec)': training_time
|
318 |
+
})
|
319 |
+
trained_models[abbr] = model
|
320 |
+
except Exception as e:
|
321 |
+
print(f"Could not train {name}. Error: {e}")
|
322 |
+
|
323 |
+
|
324 |
+
print("Evaluation complete.")
|
325 |
+
print("=" * 50)
|
326 |
+
|
327 |
+
results_df = pd.DataFrame(results_list)
|
328 |
+
results_df = results_df.sort_values(by='R2', ascending=False).reset_index(drop=True)
|
329 |
+
|
330 |
+
# Pass the original df_test to the plotter for molecule visualization
|
331 |
+
plotter = ModelPlotter(trained_models, X_test, y_test, df_test)
|
332 |
+
|
333 |
+
return RegressionResult(results_df, plotter, trained_models)
|