alidenewade committed on
Commit
df9da26
·
verified ·
1 Parent(s): ff9435d

Upload model_comparator.py

Browse files
Files changed (1) hide show
  1. model_comparator.py +333 -0
model_comparator.py ADDED
@@ -0,0 +1,333 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Import necessary libraries
2
+ import pandas as pd
3
+ import numpy as np
4
+ import time
5
+ from sklearn.model_selection import train_test_split
6
+
7
+ # For plotting
8
+ import matplotlib.pyplot as plt
9
+ import seaborn as sns
10
+ from scipy import stats
11
+
12
+ # Import models and metrics
13
+ from sklearn.linear_model import (
14
+ LinearRegression, Ridge, Lasso, ElasticNet, BayesianRidge,
15
+ HuberRegressor, PassiveAggressiveRegressor, OrthogonalMatchingPursuit,
16
+ LassoLars
17
+ )
18
+ from sklearn.tree import DecisionTreeRegressor
19
+ from sklearn.ensemble import (
20
+ RandomForestRegressor, GradientBoostingRegressor, ExtraTreesRegressor,
21
+ AdaBoostRegressor
22
+ )
23
+ from sklearn.neighbors import KNeighborsRegressor
24
+ from sklearn.dummy import DummyRegressor
25
+ from sklearn.metrics import (
26
+ mean_absolute_error, mean_squared_error, r2_score,
27
+ mean_absolute_percentage_error, mean_squared_log_error
28
+ )
29
+
30
+ # Import optional libraries
31
+ try:
32
+ import xgboost as xgb
33
+ import lightgbm as lgb
34
+ import catboost as cb
35
+ import mols2grid
36
+ _has_extra_libs = True
37
+ except ImportError:
38
+ _has_extra_libs = False
39
+
40
+ # --- Helper Functions ---
41
+
42
+ def _create_abbreviation(name: str) -> str:
43
+ """Creates a capitalized abbreviation from a model name."""
44
+ if name == 'Lasso Regression':
45
+ return 'LaR'
46
+ if name == 'Linear Regression':
47
+ return 'LR'
48
+ return "".join([word[0] for word in name.split()]).upper()
49
+
50
def _rmsle(y_true, y_pred):
    """Return the root mean squared logarithmic error of the predictions.

    Negative entries in either input are raised to zero before scoring.
    """
    # Floor both inputs at zero, then take the square root of the MSLE.
    actual, predicted = (np.maximum(values, 0) for values in (y_true, y_pred))
    msle = mean_squared_log_error(actual, predicted)
    return np.sqrt(msle)
55
+
56
+ # --- Plotting Class ---
57
+
58
class ModelPlotter:
    """Validation and feature-importance plots for trained regression models.

    Models are addressed by abbreviation (the keys of `models`); predictions
    are always computed on the test split handed to the constructor.
    """

    def __init__(self, models: dict, X_test: pd.DataFrame, y_test: pd.Series, df_test: pd.DataFrame):
        """
        Args:
            models (dict): Maps model abbreviation -> fitted estimator.
            X_test (pd.DataFrame): Numeric test features passed to `predict`.
            y_test (pd.Series): Test targets aligned with `X_test`.
            df_test (pd.DataFrame): Test rows with all original columns,
                used for molecule plotting.
        """
        self._models = models
        self._X_test = X_test  # Numeric features for standard plots
        self._y_test = y_test
        self._df_test = df_test  # Original test dataframe with all columns for molecule plotting
        # Abbreviation -> estimator class name, used in plot titles.
        self.full_names = {abbr: model.__class__.__name__ for abbr, model in models.items()}

    # --- Private helpers ---

    def _get_model(self, model_abbr: str):
        """Return the model stored under `model_abbr`, raising ValueError if unknown."""
        if model_abbr not in self._models:
            raise ValueError(f"Model '{model_abbr}' not found. Available models: {list(self._models.keys())}")
        return self._models[model_abbr]

    @staticmethod
    def _importance_scores(model):
        """Return `(scores, axis_label)` for `model`, or `(None, '')` if it exposes neither attribute."""
        if hasattr(model, 'feature_importances_'):
            return model.feature_importances_, 'Importance'
        if hasattr(model, 'coef_'):
            magnitudes = np.abs(model.coef_)
            if magnitudes.ndim > 1:
                # Reduce a 2-D coefficient array to one score per column.
                magnitudes = np.mean(magnitudes, axis=0)
            return magnitudes, 'Importance (Absolute Coefficient Value)'
        return None, ''

    def _top_features(self, importances, top_n: int) -> pd.DataFrame:
        """Return the `top_n` highest-scoring features as a Feature/Importance frame."""
        ranking = pd.DataFrame({'Feature': self._X_test.columns, 'Importance': importances})
        return ranking.sort_values(by='Importance', ascending=False).head(top_n)

    # --- Public plotting API ---

    def plot(self, model_abbr: str):
        """
        Generates a 2x2 grid of validation plots for a specified model.

        The panels are: actual vs. predicted, residuals vs. predicted,
        a histogram of residuals, and a normal Q-Q plot of residuals.

        Args:
            model_abbr (str): The abbreviation of the model to plot (e.g., 'RFR').

        Raises:
            ValueError: If `model_abbr` is not one of the stored models.
        """
        model = self._get_model(model_abbr)
        model_full_name = self.full_names.get(model_abbr, model_abbr)

        y_pred = model.predict(self._X_test)
        residuals = self._y_test - y_pred

        sns.set_theme(style='whitegrid')
        fig, axes = plt.subplots(2, 2, figsize=(15, 12))
        fig.suptitle(f'Model Validation Plots for {model_full_name} ({model_abbr})', fontsize=20, y=1.03)

        # 1. Actual vs. Predicted
        sns.scatterplot(x=self._y_test, y=y_pred, ax=axes[0, 0], alpha=0.6)
        axes[0, 0].set_title('Actual vs. Predicted Values', fontsize=14)
        axes[0, 0].set_xlabel('Actual Values', fontsize=12)
        axes[0, 0].set_ylabel('Predicted Values', fontsize=12)
        # Identity line spanning the joint range of actual and predicted values.
        lims = [min(self._y_test.min(), y_pred.min()), max(self._y_test.max(), y_pred.max())]
        axes[0, 0].plot(lims, lims, 'r--', alpha=0.75, zorder=0)

        # 2. Residuals vs. Predicted
        sns.scatterplot(x=y_pred, y=residuals, ax=axes[0, 1], alpha=0.6)
        axes[0, 1].axhline(y=0, color='r', linestyle='--')
        axes[0, 1].set_title('Residuals vs. Predicted Values', fontsize=14)
        axes[0, 1].set_xlabel('Predicted Values', fontsize=12)
        axes[0, 1].set_ylabel('Residuals', fontsize=12)

        # 3. Histogram of Residuals
        sns.histplot(residuals, kde=True, ax=axes[1, 0])
        axes[1, 0].set_title('Distribution of Residuals', fontsize=14)
        axes[1, 0].set_xlabel('Residuals', fontsize=12)
        axes[1, 0].set_ylabel('Frequency', fontsize=12)

        # 4. Q-Q Plot
        stats.probplot(residuals, dist="norm", plot=axes[1, 1])
        # probplot draws the sample points first and the fitted line second.
        qq_points, qq_line = axes[1, 1].get_lines()[0], axes[1, 1].get_lines()[1]
        qq_points.set_markerfacecolor('#1f77b4')
        qq_points.set_markeredgecolor('#1f77b4')
        qq_line.set_color('r')
        axes[1, 1].set_title('Normal Q-Q Plot of Residuals', fontsize=14)
        axes[1, 1].set_xlabel('Theoretical Quantiles', fontsize=12)
        axes[1, 1].set_ylabel('Sample Quantiles', fontsize=12)

        plt.tight_layout()
        plt.show()

    def plot_feature_importance(self, model_abbr: str, top_n: int = 7):
        """
        Plots the top N most important features for a specified model.
        This function works for models with `feature_importances_` (e.g., RandomForest)
        or `coef_` (e.g., LinearRegression) attributes. For any other model a
        message is printed and nothing is plotted.

        Args:
            model_abbr (str): The abbreviation of the model to plot (e.g., 'RFR').
            top_n (int): The number of top features to display. Defaults to 7.

        Raises:
            ValueError: If `model_abbr` is not one of the stored models.
        """
        model = self._get_model(model_abbr)
        model_full_name = self.full_names.get(model_abbr, model_abbr)

        importances, importance_type = self._importance_scores(model)
        if importances is None:
            print(f"'{model_full_name}' does not support feature importance plotting (no 'feature_importances_' or 'coef_' attribute).")
            return

        top_features = self._top_features(importances, top_n)
        # Size and title follow the number of bars actually drawn, which is
        # smaller than top_n when the model has fewer features.
        n_shown = len(top_features)

        plt.figure(figsize=(12, max(n_shown, 1) * 0.6))
        sns.barplot(x='Importance', y='Feature', data=top_features, palette='viridis', orient='h')

        plt.title(f'Top {n_shown} Feature Importances for {model_full_name} ({model_abbr})', fontsize=16, pad=20)
        plt.xlabel(importance_type, fontsize=12)
        plt.ylabel('Feature', fontsize=12)
        plt.tight_layout()
        plt.show()

    def plot_mols_for_top_features(self, model_abbr: str, smiles_col: str, top_n: int = 5, **kwargs):
        """
        Displays an interactive grid of test set molecules, highlighting the top model features.
        Requires the 'mols2grid' library.

        Args:
            model_abbr (str): The abbreviation of the model to use for feature importances.
            smiles_col (str): The name of the column in the original DataFrame containing SMILES strings.
            top_n (int): The number of top features to display in the grid's subset and tooltip.
            **kwargs: Additional keyword arguments passed to mols2grid.display().
                      This can be used to customize 'subset', 'tooltip', 'rename', etc.

        Returns:
            Whatever `mols2grid.display` returns, or None when mols2grid is
            missing or the model exposes no feature importances.

        Raises:
            ValueError: If `model_abbr` is unknown or `smiles_col` is not a
                column of the test DataFrame.
        """
        # Import locally so availability depends on mols2grid alone, not on
        # whether the other optional libraries imported successfully.
        try:
            import mols2grid
        except ImportError:
            print("mols2grid library is not installed. Please install it using 'pip install mols2grid'.")
            return

        model = self._get_model(model_abbr)

        if smiles_col not in self._df_test.columns:
            raise ValueError(f"SMILES column '{smiles_col}' not found in the DataFrame. Please ensure it was present in the initial DataFrame.")

        # Get feature importances
        importances, _ = self._importance_scores(model)
        if importances is None:
            print(f"Cannot get feature importances for model '{model_abbr}'.")
            return

        top_feature_names = self._top_features(importances, top_n)['Feature'].tolist()

        # Prepare DataFrame for mols2grid (copy so the stored frame is never mutated)
        df_for_grid = self._df_test.copy()

        # Set up default mols2grid display options, which user can override with kwargs
        display_kwargs = {
            "smiles_col": smiles_col,
            "subset": ["img", self._y_test.name] + top_feature_names,
            "tooltip": [self._y_test.name] + top_feature_names
        }
        display_kwargs.update(kwargs)

        print(f"Generating molecular grid for top {top_n} features of model {model_abbr}...")
        return mols2grid.display(df_for_grid, **display_kwargs)
216
+
217
+ # --- Result Container Class ---
218
class RegressionResult:
    """
    Bundle of regression outputs with notebook-friendly display.

    Three public attributes are exposed:
        dataframe: the results DataFrame.
        plotter:   the ModelPlotter object.
        models:    the dictionary of trained models.

    Example:
        >>> result = regression(df, 'target')
        >>> best_model = result.models['RFR']
        >>> result.plotter.plot_feature_importance('RFR')
        >>> result.plotter.plot_mols_for_top_features('RFR', smiles_col='SMILES')
    """

    def __init__(self, results_df: pd.DataFrame, plotter: ModelPlotter, trained_models: dict):
        self.dataframe, self.plotter, self.models = results_df, plotter, trained_models

    def _repr_html_(self):
        """Render the results DataFrame as an HTML table (no index, floats to 4 decimals)."""
        four_decimals = '{:.4f}'.format
        return self.dataframe.to_html(float_format=four_decimals, index=False)

    def __repr__(self):
        """Render the results DataFrame as plain text (no index) for non-HTML environments."""
        table = self.dataframe
        return table.to_string(index=False)
243
+
244
+ # --- Main Function ---
245
def regression(df: pd.DataFrame, target_variable: str) -> RegressionResult:
    """
    Trains, evaluates, and provides plotting for multiple regression models.

    The data is split 80/20 (fixed random_state=42). Every non-target column
    is coerced to numeric; values that cannot be converted become 0, so a
    fully non-numeric column (e.g. SMILES) ends up as an all-zero feature.
    A model that raises during fit/predict/scoring is reported and skipped.

    Args:
        df (pd.DataFrame): The input dataframe. Must contain the target variable
                           and all features. For molecule plotting, it should also
                           contain a SMILES column.
        target_variable (str): The name of the target column.

    Returns:
        RegressionResult: An object containing the performance metrics DataFrame
                          (sorted by R2, best first), a ModelPlotter, and a
                          dictionary of trained models keyed by abbreviation.
    """
    # 1. Split data while keeping original structure for molecule plotting.
    #    Split row positions rather than index labels so a non-unique index
    #    cannot duplicate rows; for a unique index the selection is identical.
    positions = np.arange(len(df))
    train_positions, test_positions = train_test_split(positions, test_size=0.2, random_state=42)
    df_train = df.iloc[train_positions]
    df_test = df.iloc[test_positions]

    # 2. Prepare numeric data for training and evaluation
    X_train = df_train.drop(columns=[target_variable]).apply(pd.to_numeric, errors='coerce').fillna(0)
    y_train = df_train[target_variable]
    X_test = df_test.drop(columns=[target_variable]).apply(pd.to_numeric, errors='coerce').fillna(0)
    y_test = df_test[target_variable]

    # 3. Candidate models: (display name, unfitted estimator)
    model_definitions = [
        ('Linear Regression', LinearRegression()),
        ('Ridge Regression', Ridge(random_state=42)),
        ('Lasso Regression', Lasso(random_state=42)),
        ('Elastic Net', ElasticNet(random_state=42)),
        ('Lasso Least Angle Regression', LassoLars(random_state=42)),
        ('Orthogonal Matching Pursuit', OrthogonalMatchingPursuit()),
        ('Bayesian Ridge', BayesianRidge()),
        ('Huber Regressor', HuberRegressor()),
        ('Passive Aggressive Regressor', PassiveAggressiveRegressor(random_state=42)),
        ('K Neighbors Regressor', KNeighborsRegressor()),
        ('Decision Tree Regressor', DecisionTreeRegressor(random_state=42)),
        ('Random Forest Regressor', RandomForestRegressor(random_state=42, n_jobs=-1)),
        ('Extra Trees Regressor', ExtraTreesRegressor(random_state=42, n_jobs=-1)),
        ('AdaBoost Regressor', AdaBoostRegressor(random_state=42)),
        ('Gradient Boosting Regressor', GradientBoostingRegressor(random_state=42)),
        ('Dummy Regressor', DummyRegressor(strategy='mean'))
    ]
    if _has_extra_libs:
        model_definitions.extend([
            ('Extreme Gradient Boosting', xgb.XGBRegressor(random_state=42, n_jobs=-1, verbosity=0)),
            ('Light Gradient Boosting Machine', lgb.LGBMRegressor(random_state=42, n_jobs=-1, verbosity=-1)),
            ('CatBoost Regressor', cb.CatBoostRegressor(random_state=42, verbose=0))
        ])

    # Fixed column order; also keeps the frame sortable when no model succeeds.
    result_columns = [
        'Model Abbreviation', 'Model', 'MAE', 'MSE', 'RMSE',
        'R2', 'RMSLE', 'MAPE', 'TT (Sec)',
    ]
    results_list = []
    trained_models = {}
    print("Starting model training and evaluation...")

    for name, model in model_definitions:
        abbr = _create_abbreviation(name)
        start_time = time.perf_counter()
        try:
            model.fit(X_train, y_train)
            # Only the fit is timed; prediction and scoring are excluded.
            training_time = time.perf_counter() - start_time
            y_pred = model.predict(X_test)

            mse = mean_squared_error(y_test, y_pred)
            results_list.append({
                'Model Abbreviation': abbr, 'Model': name,
                'MAE': mean_absolute_error(y_test, y_pred),
                'MSE': mse,
                'RMSE': np.sqrt(mse),
                'R2': r2_score(y_test, y_pred),
                'RMSLE': _rmsle(y_test, y_pred),
                'MAPE': mean_absolute_percentage_error(y_test, y_pred),
                'TT (Sec)': training_time
            })
            trained_models[abbr] = model
        except Exception as e:
            # Best-effort: one failing estimator must not abort the comparison.
            print(f"Could not train {name}. Error: {e}")

    print("Evaluation complete.")
    print("=" * 50)

    results_df = pd.DataFrame(results_list, columns=result_columns)
    results_df = results_df.sort_values(by='R2', ascending=False).reset_index(drop=True)

    # Pass the original df_test to the plotter for molecule visualization
    plotter = ModelPlotter(trained_models, X_test, y_test, df_test)

    return RegressionResult(results_df, plotter, trained_models)