alidenewade commited on
Commit
28d310c
·
verified ·
1 Parent(s): 46db056

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +893 -412
app.py CHANGED
@@ -1,70 +1,69 @@
1
  # --- IMPORTS ---
2
  # Core and Data Handling
3
- import gradio as gr #
4
- import pandas as pd #
5
- import numpy as np #
6
- import os #
7
- import glob #
8
- import time #
9
- import warnings #
10
 
11
  # Chemistry and Cheminformatics
12
- from rdkit import Chem #
13
- from rdkit.Chem import Descriptors, Lipinski #
14
- from chembl_webresource_client.new_client import new_client #
15
- from padelpy import padeldescriptor #
16
- import mols2grid #
17
 
18
  # Plotting and Visualization
19
- import matplotlib.pyplot as plt #
20
- import seaborn as sns #
21
- from scipy import stats #
22
- from scipy.stats import mannwhitneyu #
23
 
24
  # Machine Learning Models and Metrics
25
- from sklearn.model_selection import train_test_split #
26
- from sklearn.feature_selection import VarianceThreshold #
27
- from sklearn.linear_model import ( #
28
- LinearRegression, Ridge, Lasso, ElasticNet, BayesianRidge, #
29
- HuberRegressor, PassiveAggressiveRegressor, OrthogonalMatchingPursuit, #
30
- LassoLars #
31
  )
32
- from sklearn.tree import DecisionTreeRegressor #
33
- from sklearn.ensemble import ( #
34
- RandomForestRegressor, GradientBoostingRegressor, ExtraTreesRegressor, #
35
- AdaBoostRegressor #
36
  )
37
- from sklearn.neighbors import KNeighborsRegressor #
38
- from sklearn.dummy import DummyRegressor #
39
- from sklearn.metrics import ( #
40
- mean_absolute_error, mean_squared_error, r2_score #
41
  )
42
 
43
  # A placeholder class to store all results from a modeling run
44
- class ModelRunResult: #
45
- def __init__(self, dataframe, plotter, models, selected_features): #
46
- self.dataframe = dataframe #
47
- self.plotter = plotter #
48
- self.models = models #
49
- self.selected_features = selected_features #
50
 
51
  # Optional Advanced Models
52
- try: #
53
- import xgboost as xgb #
54
- import lightgbm as lgb #
55
- import catboost as cb #
56
- _has_extra_libs = True #
57
- except ImportError: #
58
- _has_extra_libs = False #
59
- warnings.warn("Optional libraries (xgboost, lightgbm, catboost) not found. Some models will be unavailable.") #
60
 
61
  # --- GLOBAL CONFIGURATION & SETUP ---
62
- warnings.filterwarnings("ignore") #
63
- sns.set_theme(style='whitegrid') #
64
 
65
  # --- FINGERPRINT CONFIGURATION ---
66
  DESCRIPTOR_DIR = "padel_descriptors"
67
-
68
  # Check if the descriptor directory exists and contains files
69
  if not os.path.isdir(DESCRIPTOR_DIR):
70
  warnings.warn(
@@ -74,411 +73,893 @@ if not os.path.isdir(DESCRIPTOR_DIR):
74
  xml_files = []
75
  else:
76
  xml_files = sorted(glob.glob(os.path.join(DESCRIPTOR_DIR, '*.xml')))
77
-
78
- if not xml_files:
79
- warnings.warn(
80
- f"No descriptor .xml files found in the '{DESCRIPTOR_DIR}' directory. "
81
- "Fingerprint calculation will not be possible."
82
- )
83
 
84
  # The key is the filename without extension; the value is the full path to the file
85
  fp_config = {os.path.splitext(os.path.basename(file))[0]: file for file in xml_files}
86
  FP_list = sorted(list(fp_config.keys()))
87
 
88
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
  # ==============================================================================
90
  # === STEP 1: CORE DATA COLLECTION & EDA FUNCTIONS ===
91
  # ==============================================================================
92
-
93
- def get_target_chembl_id(query): #
94
- try: #
95
- target = new_client.target #
96
- res = target.search(query) #
97
- if not res: #
98
- return pd.DataFrame(), gr.Dropdown(choices=[], value=None), "No targets found for your query." #
99
- df = pd.DataFrame(res) #
100
- return df[["target_chembl_id", "pref_name", "organism"]], gr.Dropdown(choices=df["target_chembl_id"].tolist()), f"Found {len(df)} targets." #
101
- except Exception as e: #
102
- raise gr.Error(f"ChEMBL search failed: {e}") #
103
-
104
- def get_bioactivity_data(target_id): #
105
- try: #
106
- activity = new_client.activity #
107
- res = activity.filter(target_chembl_id=target_id).filter(standard_type="IC50") #
108
- if not res: #
109
- return pd.DataFrame(), "No IC50 bioactivity data found for this target." #
110
- df = pd.DataFrame(res) #
111
- return df, f"Fetched {len(df)} data points." #
112
- except Exception as e: #
113
- raise gr.Error(f"Failed to fetch bioactivity data: {e}") #
114
-
115
- def pIC50_calc(input_df): #
116
- df_copy = input_df.copy() #
117
- df_copy['standard_value'] = pd.to_numeric(df_copy['standard_value'], errors='coerce') #
118
- df_copy.dropna(subset=['standard_value'], inplace=True) #
119
- df_copy['standard_value_norm'] = df_copy['standard_value'].apply(lambda x: min(x, 100000000)) #
120
- pIC50_values = [] #
121
- for i in df_copy['standard_value_norm']: #
122
- if pd.notna(i) and i > 0: #
123
- molar = i * (10**-9) #
124
- pIC50_values.append(-np.log10(molar)) #
125
- else: #
126
- pIC50_values.append(np.nan) #
127
- df_copy['pIC50'] = pIC50_values #
128
- df_copy['bioactivity_class'] = df_copy['standard_value_norm'].apply( #
129
- lambda x: "inactive" if pd.notna(x) and x >= 10000 else ("active" if pd.notna(x) and x <= 1000 else "intermediate") #
130
  )
131
- return df_copy.drop(columns=['standard_value', 'standard_value_norm']) #
132
-
133
- def lipinski_descriptors(smiles_series): #
134
- moldata, valid_smiles = [], [] #
135
- for elem in smiles_series: #
136
- if elem and isinstance(elem, str): #
137
- mol = Chem.MolFromSmiles(elem) #
138
- if mol: #
139
- moldata.append(mol) #
140
- valid_smiles.append(elem) #
141
- descriptor_rows = [] #
142
- for mol in moldata: #
143
- row = [Descriptors.MolWt(mol), Descriptors.MolLogP(mol), Lipinski.NumHDonors(mol), Lipinski.NumHAcceptors(mol)] #
144
- descriptor_rows.append(row) #
145
- columnNames = ["MW", "LogP", "NumHDonors", "NumHAcceptors"] #
146
- if not descriptor_rows: return pd.DataFrame(columns=columnNames), [] #
147
- return pd.DataFrame(data=np.array(descriptor_rows), columns=columnNames), valid_smiles #
148
-
149
- def clean_and_process_data(df): #
150
- if df is None or df.empty: raise gr.Error("No data to process. Please fetch data first.") #
151
- if "canonical_smiles" not in df.columns or df["canonical_smiles"].isnull().all(): #
152
- try: #
153
- df["canonical_smiles"] = [c.get("molecule_structures", {}).get("canonical_smiles") for c in new_client.molecule.get(list(df["molecule_chembl_id"]))] #
154
- except Exception as e: #
155
- raise gr.Error(f"Could not fetch SMILES from ChEMBL: {e}") #
156
- df = df[df.standard_value.notna()] #
157
- df = df[df.canonical_smiles.notna()] #
158
- df.drop_duplicates(['canonical_smiles'], inplace=True) #
159
- df["standard_value"] = pd.to_numeric(df["standard_value"], errors='coerce') #
160
- df.dropna(subset=['standard_value'], inplace=True) #
161
- df_processed = pIC50_calc(df) #
162
- df_processed = df_processed[df_processed.pIC50.notna()] #
163
- if df_processed.empty: return pd.DataFrame(), "No compounds remaining after pIC50 calculation." #
164
- df_lipinski, valid_smiles = lipinski_descriptors(df_processed['canonical_smiles']) #
165
- if not valid_smiles: return pd.DataFrame(), "No valid SMILES could be processed for Lipinski descriptors." #
166
- df_processed = df_processed[df_processed['canonical_smiles'].isin(valid_smiles)].reset_index(drop=True) #
167
- df_lipinski = df_lipinski.reset_index(drop=True) #
168
- df_final = pd.concat([df_processed, df_lipinski], axis=1) #
169
- return df_final, f"Processing complete. {len(df_final)} compounds remain after cleaning." #
170
-
171
- def run_eda_analysis(df, selected_classes): #
172
- if df is None or df.empty: raise gr.Error("No data available for analysis.") #
173
- df_filtered = df[df.bioactivity_class.isin(selected_classes)].copy() #
174
- if df_filtered.empty: return (None, None, None, pd.DataFrame(), None, pd.DataFrame(), None, pd.DataFrame(), None, pd.DataFrame(), None, pd.DataFrame(), "No data for selected classes.") #
175
- plots = [create_frequency_plot(df_filtered), create_scatter_plot(df_filtered)] #
176
- stats_dfs = [] #
177
- for desc in ['pIC50', 'MW', 'LogP', 'NumHDonors', 'NumHAcceptors']: #
178
- plots.append(create_boxplot(df_filtered, desc)) #
179
- stats_dfs.append(mannwhitney_test(df_filtered, desc)) #
180
- plt.close('all') #
181
- return (plots[0], plots[1], plots[2], stats_dfs[0], plots[3], stats_dfs[1], plots[4], stats_dfs[2], plots[5], stats_dfs[3], plots[6], stats_dfs[4], f"EDA complete for {len(df_filtered)} compounds.") #
182
-
183
- def create_frequency_plot(df): #
184
- plt.figure(figsize=(5.5, 5.5)); sns.barplot(x=df['bioactivity_class'].value_counts().index, y=df['bioactivity_class'].value_counts().values, palette={'active': '#1f77b4', 'inactive': '#ff7f0e', 'intermediate': '#2ca02c'}); plt.xlabel('Bioactivity Class', fontsize=12); plt.ylabel('Frequency', fontsize=12); plt.title('Frequency of Bioactivity Classes', fontsize=14); return plt.gcf() #
185
- def create_scatter_plot(df): #
186
- plt.figure(figsize=(5.5, 5.5)); sns.scatterplot(data=df, x='MW', y='LogP', hue='bioactivity_class', size='pIC50', palette={'active': '#1f77b4', 'inactive': '#ff7f0e', 'intermediate': '#2ca02c'}, sizes=(20, 200), alpha=0.7); plt.xlabel('Molecular Weight (MW)', fontsize=12); plt.ylabel('LogP', fontsize=12); plt.title('Chemical Space: MW vs. LogP', fontsize=14); plt.legend(title='Bioactivity Class'); return plt.gcf() #
187
- def create_boxplot(df, descriptor): #
188
- plt.figure(figsize=(5.5, 5.5)); sns.boxplot(x='bioactivity_class', y=descriptor, data=df, palette={'active': '#1f77b4', 'inactive': '#ff7f0e', 'intermediate': '#2ca02c'}); plt.xlabel('Bioactivity Class', fontsize=12); plt.ylabel(descriptor, fontsize=12); plt.title(f'{descriptor} by Bioactivity Class', fontsize=14); return plt.gcf() #
189
- def mannwhitney_test(df, descriptor): #
190
- results = [] #
191
- for c1, c2 in [('active', 'inactive'), ('active', 'intermediate'), ('inactive', 'intermediate')]: #
192
- if c1 in df['bioactivity_class'].unique() and c2 in df['bioactivity_class'].unique(): #
193
- d1, d2 = df[df.bioactivity_class == c1][descriptor].dropna(), df[df.bioactivity_class == c2][descriptor].dropna() #
194
- if not d1.empty and not d2.empty: #
195
- stat, p = mannwhitneyu(d1, d2) #
196
- results.append({'Comparison': f'{c1.title()} vs {c2.title()}', 'Statistics': stat, 'p-value': p, 'Interpretation': 'Different distribution (p < 0.05)' if p <= 0.05 else 'Same distribution (p > 0.05)'}) #
197
- return pd.DataFrame(results) #
 
 
 
198
 
199
  # ==============================================================================
200
  # === STEP 2: FEATURE ENGINEERING FUNCTIONS ===
201
  # ==============================================================================
202
-
203
- def calculate_fingerprints(current_state, fingerprint_type, progress=gr.Progress()): #
204
- input_df = current_state.get('cleaned_data') #
205
- if input_df is None or input_df.empty: raise gr.Error("No cleaned data found. Please complete Step 1.") #
206
- if not fingerprint_type: raise gr.Error("Please select a fingerprint type.") #
207
- progress(0, desc="Starting..."); yield f"🧪 Starting fingerprint calculation...", None, gr.update(visible=False), None, current_state #
208
- try: #
209
- smi_file, output_csv = 'molecule.smi', 'fingerprints.csv' #
 
 
 
 
 
 
210
 
211
- input_df[['canonical_smiles', 'canonical_smiles']].to_csv(smi_file, sep='\t', index=False, header=False) #
 
 
 
 
212
 
213
- if os.path.exists(output_csv): os.remove(output_csv) #
214
- descriptortypes = fp_config.get(fingerprint_type) #
215
- if not descriptortypes: raise gr.Error(f"Descriptor XML for '{fingerprint_type}' not found.") #
216
 
217
- progress(0.3, desc="⚗️ Running PaDEL..."); yield f"⚗️ Running PaDEL...", None, gr.update(visible=False), None, current_state #
218
- padeldescriptor(mol_dir=smi_file, d_file=output_csv, descriptortypes=descriptortypes, detectaromaticity=True, standardizenitro=True, standardizetautomers=True, threads=-1, removesalt=True, log=False, fingerprints=True) #
219
- if not os.path.exists(output_csv) or os.path.getsize(output_csv) == 0: #
220
- raise gr.Error("PaDEL failed to produce an output file. Check molecule validity.") #
 
 
 
 
 
 
 
 
 
 
 
221
 
222
- progress(0.7, desc="📊 Processing results..."); yield "📊 Processing results...", None, gr.update(visible=False), None, current_state #
223
- df_X = pd.read_csv(output_csv).rename(columns={'Name': 'canonical_smiles'}) #
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
224
 
225
- final_df = pd.merge(input_df[['canonical_smiles', 'pIC50']], df_X, on='canonical_smiles', how='inner') #
 
 
226
 
227
- current_state['fingerprint_data'] = final_df; current_state['fingerprint_type'] = fingerprint_type #
228
- progress(0.9, desc="🖼️ Generating molecule grid...") #
229
- mols_html = mols2grid.display(final_df, smiles_col='canonical_smiles', subset=['img', 'pIC50'], rename={"pIC50": "pIC50"}, transform={"pIC50": lambda x: f"{x:.2f}"})._repr_html_() #
230
- success_msg = f"✅ Success! Generated {len(df_X.columns) -1} descriptors for {len(final_df)} molecules." #
231
- progress(1, desc="Completed!"); yield success_msg, final_df, gr.update(visible=True), gr.update(value=mols_html, visible=True), current_state #
232
- except Exception as e: raise gr.Error(f"Calculation failed: {e}") #
233
- finally: #
234
- if os.path.exists('molecule.smi'): os.remove('molecule.smi') #
235
- if os.path.exists('fingerprints.csv'): os.remove('fingerprints.csv') #
236
 
237
  # ==============================================================================
238
- # === STEP 3: MODEL TRAINING & PREDICTION FUNCTIONS ===
239
  # ==============================================================================
240
- class ModelPlotter: #
241
- def __init__(self, models: dict, X_test: pd.DataFrame, y_test: pd.Series): #
242
- self._models, self._X_test, self._y_test = models, X_test, y_test #
243
- def plot_validation(self, model_name: str): #
244
- if model_name not in self._models: raise ValueError(f"Model '{model_name}' not found.") #
245
- model, y_pred = self._models[model_name], self._models[model_name].predict(self._X_test) #
246
- residuals = self._y_test - y_pred #
247
- fig, axes = plt.subplots(2, 2, figsize=(12, 10)); fig.suptitle(f'Model Validation Plots for {model_name}', fontsize=16, y=1.02) #
248
- sns.scatterplot(x=self._y_test, y=y_pred, ax=axes[0, 0], alpha=0.6); axes[0, 0].set_title('Actual vs. Predicted'); axes[0, 0].set_xlabel('Actual pIC50'); axes[0, 0].set_ylabel('Predicted pIC50'); lims = [min(self._y_test.min(), y_pred.min()), max(self._y_test.max(), y_pred.max())]; axes[0, 0].plot(lims, lims, 'r--', alpha=0.75, zorder=0) #
249
- sns.scatterplot(x=y_pred, y=residuals, ax=axes[0, 1], alpha=0.6); axes[0, 1].axhline(y=0, color='r', linestyle='--'); axes[0, 1].set_title('Residuals vs. Predicted'); axes[0, 1].set_xlabel('Predicted pIC50'); axes[0, 1].set_ylabel('Residuals') #
250
- sns.histplot(residuals, kde=True, ax=axes[1, 0]); axes[1, 0].set_title('Distribution of Residuals') #
251
- stats.probplot(residuals, dist="norm", plot=axes[1, 1]); axes[1, 1].set_title('Normal Q-Q Plot') #
252
- plt.tight_layout(); return fig #
253
- def plot_feature_importance(self, model_name: str, top_n: int = 7): #
254
- if model_name not in self._models: raise ValueError(f"Model '{model_name}' not found.") #
255
- model = self._models[model_name] #
256
- if hasattr(model, 'feature_importances_'): importances = model.feature_importances_ #
257
- elif hasattr(model, 'coef_'): importances = np.abs(model.coef_) #
258
- else: return None #
259
- top_features = pd.DataFrame({'Feature': self._X_test.columns, 'Importance': importances}).sort_values(by='Importance', ascending=False).head(top_n) #
260
- plt.figure(figsize=(10, top_n * 0.5)); sns.barplot(x='Importance', y='Feature', data=top_features, palette='viridis', orient='h'); plt.title(f'Top {top_n} Features for {model_name}'); plt.tight_layout(); return plt.gcf() #
261
-
262
- def run_regression_suite(df: pd.DataFrame, progress=gr.Progress()): #
263
- progress(0, desc="Splitting data..."); yield "Splitting data (80/20 train/test split)...", None, None #
264
- X = df.drop(columns=['pIC50', 'canonical_smiles'], errors='ignore') #
265
- y = df['pIC50'] #
266
- X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42) #
267
-
268
- progress(0.1, desc="Selecting features..."); yield "Performing feature selection (removing low variance)...", None, None #
269
- selector = VarianceThreshold(threshold=0.1) #
270
- X_train = pd.DataFrame(selector.fit_transform(X_train), columns=X_train.columns[selector.get_support()], index=X_train.index) #
271
- X_test = pd.DataFrame(selector.transform(X_test), columns=X_test.columns[selector.get_support()], index=X_test.index) #
272
- selected_features = X_train.columns.tolist() #
273
-
274
- model_defs = [
275
- ('Linear Regression', LinearRegression()),
276
- ('Ridge', Ridge(random_state=42)),
277
- ('Lasso', Lasso(random_state=42)),
278
- ('Random Forest', RandomForestRegressor(random_state=42, n_jobs=-1)),
279
- # ('Gradient Boosting', GradientBoostingRegressor(random_state=42)) # <-- Commented out
280
- ]
281
- if _has_extra_libs:
282
- model_defs.extend([
283
- # ('XGBoost', xgb.XGBRegressor(random_state=42, n_jobs=-1, verbosity=0)), # <-- Commented out
284
- ('LightGBM', lgb.LGBMRegressor(random_state=42, n_jobs=-1, verbosity=-1)),
285
- # ('CatBoost', cb.CatBoostRegressor(random_state=42, verbose=0)) # <-- Commented out
286
- ])
287
 
288
- results_list, trained_models = [], {} #
289
- for i, (name, model) in enumerate(model_defs): #
290
- progress(0.2 + (i / len(model_defs)) * 0.8, desc=f"Training {name}...") #
291
- yield f"Training {i+1}/{len(model_defs)}: {name}...", None, None #
292
- start_time = time.time(); model.fit(X_train, y_train); y_pred = model.predict(X_test) #
293
- results_list.append({'Model': name, '': r2_score(y_test, y_pred), 'MAE': mean_absolute_error(y_test, y_pred), 'RMSE': np.sqrt(mean_squared_error(y_test, y_pred)), 'Time (s)': f"{time.time() - start_time:.2f}"}) #
294
- trained_models[name] = model #
295
-
296
- results_df = pd.DataFrame(results_list).sort_values(by='R²', ascending=False).reset_index(drop=True) #
297
- plotter = ModelPlotter(trained_models, X_test, y_test) #
298
- model_run_results = ModelRunResult(results_df, plotter, trained_models, selected_features) #
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
299
 
300
- model_choices = results_df['Model'].tolist() #
301
- yield "✅ Model training & evaluation complete.", model_run_results, gr.Dropdown(choices=model_choices, interactive=True) #
302
-
303
- def predict_on_upload(uploaded_file, model_name, current_state, progress=gr.Progress()): #
304
- if not uploaded_file: raise gr.Error("Please upload a file.") #
305
- if not model_name: raise gr.Error("Please select a trained model.") #
306
- model_run_results = current_state.get('model_results') #
307
- fingerprint_type = current_state.get('fingerprint_type') #
308
- if not model_run_results or not fingerprint_type: raise gr.Error("Please run Steps 2 and 3 first.") #
309
 
310
- model = model_run_results.models.get(model_name) #
311
- selected_features = model_run_results.selected_features #
312
- if model is None: raise gr.Error(f"Model '{model_name}' not found.") #
 
313
 
314
- smi_file, output_csv = 'predict.smi', 'predict_fp.csv' #
315
- try: #
316
- progress(0, desc="Reading & processing new molecules..."); yield "Reading uploaded file...", None, None #
317
- df_new = pd.read_csv(uploaded_file.name) #
318
- if 'canonical_smiles' not in df_new.columns: raise gr.Error("CSV must contain a 'canonical_smiles' column.") #
319
- df_new = df_new.reset_index().rename(columns={'index': 'mol_id'}) #
320
 
321
- padel_input = pd.DataFrame({'smiles': df_new['canonical_smiles'], 'name': df_new['mol_id']}) #
322
- padel_input.to_csv(smi_file, sep='\t', index=False, header=False) #
323
- if os.path.exists(output_csv): os.remove(output_csv) #
 
324
 
325
- progress(0.3, desc="Calculating fingerprints..."); yield "Calculating fingerprints for new molecules...", None, None #
326
- padeldescriptor(mol_dir=smi_file, d_file=output_csv, descriptortypes=fp_config.get(fingerprint_type), detectaromaticity=True, standardizenitro=True, threads=-1, removesalt=True, log=False, fingerprints=True) #
327
- if not os.path.exists(output_csv) or os.path.getsize(output_csv) == 0: raise gr.Error("PaDEL calculation failed for the uploaded molecules.") #
 
 
 
 
328
 
329
- progress(0.7, desc="Aligning features and predicting..."); yield "Aligning features and predicting...", None, None #
330
- df_fp = pd.read_csv(output_csv).rename(columns={'Name': 'mol_id'}) #
331
 
332
- X_new = df_fp.set_index('mol_id') #
333
- X_new_aligned = X_new.reindex(columns=selected_features, fill_value=0)[selected_features] #
 
 
 
 
 
 
 
 
 
334
 
335
- predictions = model.predict(X_new_aligned) #
 
336
 
337
- results_subset = pd.DataFrame({'mol_id': X_new_aligned.index, 'predicted_pIC50': predictions}) #
338
- df_results = pd.merge(df_new, results_subset, on='mol_id', how='left') #
339
-
340
- progress(0.9, desc="Generating visualization..."); yield "Generating visualization...", None, None #
 
 
 
 
 
 
 
 
 
 
 
341
 
342
- df_grid_view = df_results.dropna(subset=['predicted_pIC50']).copy() #
343
- mols_html = "<h3>No molecules with successful predictions to display.</h3>" #
344
- if not df_grid_view.empty: #
345
- df_grid_view.rename(columns={"predicted_pIC50": "Predicted pIC50"}, inplace=True) #
346
- mols_html = mols2grid.display( #
347
- df_grid_view, #
348
- smiles_col='canonical_smiles', #
349
- subset=['img', 'Predicted pIC50'], #
350
- transform={"Predicted pIC50": lambda x: f"{x:.2f}"} #
351
- )._repr_html_() #
352
 
353
- progress(1, desc="Complete!"); yield "✅ Prediction complete.", df_results[['canonical_smiles', 'predicted_pIC50']], mols_html #
354
- finally: #
355
- if os.path.exists(smi_file): os.remove(smi_file) #
356
- if os.path.exists(output_csv): os.remove(output_csv) #
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
357
 
358
  # ==============================================================================
359
- # === GRADIO INTERFACE ===
360
  # ==============================================================================
361
- with gr.Blocks(theme=gr.themes.Default(primary_hue="blue", secondary_hue="sky"), title="Comprehensive Drug Discovery Workflow") as demo: #
362
- gr.Markdown("# 🧪 Comprehensive Drug Discovery Workflow") #
363
- gr.Markdown("A 3-step application to fetch, analyze, and model chemical bioactivity data.") #
364
- app_state = gr.State({}) #
365
- with gr.Tabs(): #
366
- with gr.Tab("Step 1: Data Collection & EDA"): #
367
- gr.Markdown("## Fetch Bioactivity Data from ChEMBL and Perform Exploratory Analysis") #
368
- with gr.Row(): #
369
- query_input = gr.Textbox(label="Target Query", placeholder="e.g., acetylcholinesterase, BRAF kinase", scale=3) #
370
- fetch_btn = gr.Button("Fetch Targets", variant="primary", scale=1) #
371
- status_step1_fetch = gr.Textbox(label="Status", interactive=False) #
372
- target_id_table = gr.Dataframe(label="Available Targets", interactive=False, headers=["target_chembl_id", "pref_name", "organism"]) #
373
- with gr.Row(): #
374
- selected_target_dropdown = gr.Dropdown(label="Select Target ChEMBL ID", interactive=True, scale=3) #
375
- process_btn = gr.Button("Process Data & Run EDA", variant="primary", scale=1, interactive=False) #
376
- status_step1_process = gr.Textbox(label="Status", interactive=False) #
377
- gr.Markdown("### Filtered Data & Analysis") #
378
- bioactivity_class_selector = gr.CheckboxGroup(["active", "inactive", "intermediate"], label="Filter by Bioactivity Class", value=["active", "inactive", "intermediate"]) #
379
- df_output_s1 = gr.Dataframe(label="Cleaned Bioactivity Data") #
380
- with gr.Tabs(): #
381
- with gr.Tab("Chemical Space Overview"): #
382
- with gr.Row(): #
383
- freq_plot_output = gr.Plot(label="Frequency of Bioactivity Classes") #
384
- scatter_plot_output = gr.Plot(label="Scatter Plot: MW vs LogP") #
385
- with gr.Tab("pIC50 Analysis"): #
386
- with gr.Row(): #
387
- pic50_plot_output = gr.Plot(label="pIC50 Box Plot") #
388
- pic50_stats_output = gr.Dataframe(label="Mann-Whitney U Test Results for pIC50") #
389
- with gr.Tab("Molecular Weight Analysis"): #
390
- with gr.Row(): #
391
- mw_plot_output = gr.Plot(label="MW Box Plot") #
392
- mw_stats_output = gr.Dataframe(label="Mann-Whitney U Test Results for MW") #
393
- with gr.Tab("LogP Analysis"): #
394
- with gr.Row(): #
395
- logp_plot_output = gr.Plot(label="LogP Box Plot") #
396
- logp_stats_output = gr.Dataframe(label="Mann-Whitney U Test Results for LogP") #
397
- with gr.Tab("H-Bond Donor/Acceptor Analysis"): #
398
- with gr.Row(): #
399
- hdonors_plot_output = gr.Plot(label="H-Donors Box Plot") #
400
- hacceptors_plot_output = gr.Plot(label="H-Acceptors Box Plot") #
401
- with gr.Row(): #
402
- hdonors_stats_output = gr.Dataframe(label="Stats for H-Donors") #
403
- hacceptors_stats_output = gr.Dataframe(label="Stats for H-Acceptors") #
404
- with gr.Tab("Step 2: Feature Engineering"): #
405
- gr.Markdown("## Calculate Molecular Fingerprints using PaDEL") #
406
- with gr.Row(): #
407
- fingerprint_dropdown = gr.Dropdown(choices=FP_list, value='PubChem' if 'PubChem' in FP_list else None, label="Select Fingerprint Method", scale=3) #
408
- calculate_fp_btn = gr.Button("Calculate Fingerprints", variant="primary", scale=1) #
409
- status_step2 = gr.Textbox(label="Status", interactive=False) #
410
- output_df_s2 = gr.Dataframe(label="Final Processed Data", wrap=True) #
411
- download_s2 = gr.DownloadButton("Download Feature Data (CSV)", variant="secondary", visible=False) #
412
- mols_grid_s2 = gr.HTML(label="Interactive Molecule Viewer") #
413
- with gr.Tab("Step 3: Model Training & Prediction"): #
414
- gr.Markdown("## Train Regression Models and Predict pIC50") #
415
- with gr.Tabs(): #
416
- with gr.Tab("Model Training & Evaluation"): #
417
- train_models_btn = gr.Button("Train All Models", variant="primary") #
418
- status_step3_train = gr.Textbox(label="Status", interactive=False) #
419
- model_results_df = gr.DataFrame(label="Ranked Model Results", interactive=False) #
420
- with gr.Row(): #
421
- model_selector_s3 = gr.Dropdown(label="Select Model to Analyze", interactive=False) #
422
- feature_count_s3 = gr.Number(label="Top Features to Show", value=7, minimum=3, maximum=20, step=1) #
423
- with gr.Tabs(): #
424
- with gr.Tab("Validation Plots"): validation_plot_s3 = gr.Plot(label="Model Validation Plots") #
425
- with gr.Tab("Feature Importance"): feature_plot_s3 = gr.Plot(label="Top Feature Importances") #
426
- with gr.Tab("Predict on New Data"): #
427
- gr.Markdown("Upload a CSV with a `canonical_smiles` column to predict pIC50.") #
428
- with gr.Row(): #
429
- upload_predict_file = gr.File(label="Upload CSV for Prediction", file_types=[".csv"]) #
430
- predict_btn_s3 = gr.Button("Run Prediction", variant="primary") #
431
- status_step3_predict = gr.Textbox(label="Status", interactive=False) #
432
- prediction_results_df = gr.DataFrame(label="Prediction Results") #
433
- prediction_mols_grid = gr.HTML(label="Interactive Molecular Grid of Predictions") #
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
434
 
435
  # --- EVENT HANDLERS ---
436
- def enable_process_button(target_id): return gr.update(interactive=bool(target_id)) #
437
- def process_and_analyze_wrapper(target_id, selected_classes, current_state, progress=gr.Progress()): #
438
- if not target_id: raise gr.Error("Please select a target ChEMBL ID first.") #
439
- progress(0, desc="Fetching data..."); raw_data, msg1 = get_bioactivity_data(target_id); yield {status_step1_process: gr.update(value=msg1)} #
440
- progress(0.3, desc="Cleaning data..."); processed_data, msg2 = clean_and_process_data(raw_data); yield {df_output_s1: processed_data, status_step1_process: gr.update(value=msg2)} #
441
- current_state['cleaned_data'] = processed_data #
442
- progress(0.6, desc="Running EDA..."); plots_and_stats = run_eda_analysis(processed_data, selected_classes); msg3 = plots_and_stats[-1] #
443
- progress(1, desc="Done!") #
444
- filtered_data = processed_data[processed_data.bioactivity_class.isin(selected_classes)] if not processed_data.empty else pd.DataFrame() #
445
- outputs = [filtered_data] + list(plots_and_stats[:-1]) + [msg3, current_state] #
446
- output_components = [df_output_s1, freq_plot_output, scatter_plot_output, pic50_plot_output, pic50_stats_output, mw_plot_output, mw_stats_output, logp_plot_output, logp_stats_output, hdonors_plot_output, hdonors_stats_output, hacceptors_plot_output, hacceptors_stats_output, status_step1_process, app_state] #
447
- yield dict(zip(output_components, outputs)) #
448
- def update_analysis_on_filter_change(selected_classes, current_state): #
449
- cleaned_data = current_state.get('cleaned_data') #
450
- if cleaned_data is None or cleaned_data.empty: return (pd.DataFrame(),) + (None,) * 11 + ("No data available.",) #
451
- plots_and_stats = run_eda_analysis(cleaned_data, selected_classes); msg = plots_and_stats[-1] #
452
- filtered_data = cleaned_data[cleaned_data.bioactivity_class.isin(selected_classes)] #
453
- return (filtered_data,) + plots_and_stats[:-1] + (msg,) #
454
- def handle_model_training(current_state, progress=gr.Progress(track_tqdm=True)): #
455
- fingerprint_data = current_state.get('fingerprint_data') #
456
- if fingerprint_data is None or fingerprint_data.empty: raise gr.Error("No feature data. Please complete Step 2.") #
457
- for status_msg, model_results, model_choices_update in run_regression_suite(fingerprint_data, progress=progress): #
458
- if model_results: current_state['model_results'] = model_results #
459
- yield status_msg, model_results.dataframe if model_results else None, model_choices_update, current_state #
460
- def save_dataframe_as_csv(df): #
461
- if df is None or df.empty: return None #
462
- filename = "feature_engineered_data.csv"; df.to_csv(filename, index=False); return gr.File(value=filename, visible=True) #
463
- def update_analysis_plots(model_name, feature_count, current_state): #
464
- model_results = current_state.get('model_results') #
465
- if not model_results or not model_name: return None, None #
466
- plotter = model_results.plotter; validation_fig = plotter.plot_validation(model_name); feature_fig = plotter.plot_feature_importance(model_name, int(feature_count)); plt.close('all'); return validation_fig, feature_fig #
467
-
468
- fetch_btn.click(fn=get_target_chembl_id, inputs=query_input, outputs=[target_id_table, selected_target_dropdown, status_step1_fetch], show_progress="minimal") #
469
- selected_target_dropdown.change(fn=enable_process_button, inputs=selected_target_dropdown, outputs=process_btn, show_progress="hidden") #
470
- process_btn.click(fn=process_and_analyze_wrapper, inputs=[selected_target_dropdown, bioactivity_class_selector, app_state], outputs=[df_output_s1, freq_plot_output, scatter_plot_output, pic50_plot_output, pic50_stats_output, mw_plot_output, mw_stats_output, logp_plot_output, logp_stats_output, hdonors_plot_output, hdonors_stats_output, hacceptors_plot_output, hacceptors_stats_output, status_step1_process, app_state]) #
471
- bioactivity_class_selector.change(fn=update_analysis_on_filter_change, inputs=[bioactivity_class_selector, app_state], outputs=[df_output_s1, freq_plot_output, scatter_plot_output, pic50_plot_output, pic50_stats_output, mw_plot_output, mw_stats_output, logp_plot_output, logp_stats_output, hdonors_plot_output, hdonors_stats_output, hacceptors_plot_output, hacceptors_stats_output, status_step1_process], show_progress="minimal") #
472
- calculate_fp_btn.click(fn=calculate_fingerprints, inputs=[app_state, fingerprint_dropdown], outputs=[status_step2, output_df_s2, download_s2, mols_grid_s2, app_state]) #
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
473
 
474
- @download_s2.click(inputs=app_state, outputs=download_s2, show_progress="hidden") #
475
- def download_handler(current_state): #
476
- df_to_download = current_state.get('fingerprint_data') #
477
- return save_dataframe_as_csv(df_to_download) #
 
 
 
 
 
 
 
 
 
 
 
 
 
 
478
 
479
- train_models_btn.click(fn=handle_model_training, inputs=[app_state], outputs=[status_step3_train, model_results_df, model_selector_s3, app_state]) #
480
- for listener in [model_selector_s3.change, feature_count_s3.change]: listener(fn=update_analysis_plots, inputs=[model_selector_s3, feature_count_s3, app_state], outputs=[validation_plot_s3, feature_plot_s3], show_progress="minimal") #
481
- predict_btn_s3.click(fn=predict_on_upload, inputs=[upload_predict_file, model_selector_s3, app_state], outputs=[status_step3_predict, prediction_results_df, prediction_mols_grid]) #
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
482
 
483
- if __name__ == "__main__": #
484
- demo.launch(debug=True) #
 
1
  # --- IMPORTS ---
2
  # Core and Data Handling
3
+ import gradio as gr
4
+ import pandas as pd
5
+ import numpy as np
6
+ import os
7
+ import glob
8
+ import time
9
+ import warnings
10
 
11
  # Chemistry and Cheminformatics
12
+ from rdkit import Chem
13
+ from rdkit.Chem import Descriptors, Lipinski
14
+ from chembl_webresource_client.new_client import new_client
15
+ from padelpy import padeldescriptor
16
+ # import mols2grid # This line will be removed
17
 
18
  # Plotting and Visualization
19
+ import matplotlib.pyplot as plt
20
+ import seaborn as sns
21
+ from scipy import stats
22
+ from scipy.stats import mannwhitneyu
23
 
24
  # Machine Learning Models and Metrics
25
+ from sklearn.model_selection import train_test_split
26
+ from sklearn.feature_selection import VarianceThreshold
27
+ from sklearn.linear_model import (
28
+ LinearRegression, Ridge, Lasso, ElasticNet, BayesianRidge,
29
+ HuberRegressor, PassiveAggressiveRegressor, OrthogonalMatchingPursuit,
30
+ LassoLars
31
  )
32
+ from sklearn.tree import DecisionTreeRegressor
33
+ from sklearn.ensemble import (
34
+ RandomForestRegressor, GradientBoostingRegressor, ExtraTreesRegressor,
35
+ AdaBoostRegressor
36
  )
37
+ from sklearn.neighbors import KNeighborsRegressor
38
+ from sklearn.dummy import DummyRegressor
39
+ from sklearn.metrics import (
40
+ mean_absolute_error, mean_squared_error, r2_score
41
  )
42
 
43
  # A placeholder class to store all results from a modeling run
44
+ class ModelRunResult:
45
+ def __init__(self, dataframe, plotter, models, selected_features):
46
+ self.dataframe = dataframe
47
+ self.plotter = plotter
48
+ self.models = models
49
+ self.selected_features = selected_features
50
 
51
  # Optional Advanced Models
52
+ try:
53
+ import xgboost as xgb
54
+ import lightgbm as lgb
55
+ import catboost as cb
56
+ _has_extra_libs = True
57
+ except ImportError:
58
+ _has_extra_libs = False
59
+ warnings.warn("Optional libraries (xgboost, lightgbm, catboost) not found. Some models will be unavailable.")
60
 
61
  # --- GLOBAL CONFIGURATION & SETUP ---
62
+ warnings.filterwarnings("ignore")
63
+ sns.set_theme(style='whitegrid')
64
 
65
  # --- FINGERPRINT CONFIGURATION ---
66
  DESCRIPTOR_DIR = "padel_descriptors"
 
67
  # Check if the descriptor directory exists and contains files
68
  if not os.path.isdir(DESCRIPTOR_DIR):
69
  warnings.warn(
 
73
  xml_files = []
74
  else:
75
  xml_files = sorted(glob.glob(os.path.join(DESCRIPTOR_DIR, '*.xml')))
76
+ if not xml_files:
77
+ warnings.warn(
78
+ f"No descriptor .xml files found in the '{DESCRIPTOR_DIR}' directory. "
79
+ "Fingerprint calculation will not be possible."
80
+ )
 
81
 
82
  # The key is the filename without extension; the value is the full path to the file
83
  fp_config = {os.path.splitext(os.path.basename(file))[0]: file for file in xml_files}
84
  FP_list = sorted(list(fp_config.keys()))
85
 
86
 
87
+ # --- NEW MOLECULE DISPLAY FUNCTIONS ---
88
+ def create_molecule_grid_html(df, smiles_col='canonical_smiles', additional_cols=None, max_mols=50):
89
+ """
90
+ Create a custom HTML grid for displaying molecules using RDKit and base64 encoding
91
+ This works better in Hugging Face Spaces than mols2grid
92
+ """
93
+ from rdkit import Chem
94
+ from rdkit.Chem import Draw
95
+ import base64
96
+ from io import BytesIO
97
+
98
+ if df.empty:
99
+ return "<h3>No molecules to display.</h3>"
100
+
101
+ # Limit number of molecules for performance
102
+ df_display = df.head(max_mols).copy()
103
+
104
+ # Additional columns to display
105
+ if additional_cols is None:
106
+ additional_cols = []
107
+
108
+ html_parts = ["""
109
+ <div style="display: grid; grid-template-columns: repeat(auto-fill, minmax(300px, 1fr)); gap: 15px; padding: 10px;">
110
+ """]
111
+
112
+ for idx, row in df_display.iterrows():
113
+ try:
114
+ # Generate molecule image
115
+ mol = Chem.MolFromSmiles(row[smiles_col])
116
+ if mol is None:
117
+ continue
118
+
119
+ # Draw molecule
120
+ img = Draw.MolToImage(mol, size=(250, 200))
121
+
122
+ # Convert to base64
123
+ buffer = BytesIO()
124
+ img.save(buffer, format='PNG')
125
+ img_str = base64.b64encode(buffer.getvalue()).decode()
126
+
127
+ # Create card HTML
128
+ card_html = f"""
129
+ <div style="border: 1px solid #ddd; border-radius: 8px; padding: 10px; background: white;">
130
+ <img src="data:image/png;base64,{img_str}" style="width: 100%; height: auto;" alt="Molecule"/>
131
+ <div style="margin-top: 8px; font-size: 12px;">
132
+ <strong>SMILES:</strong> {row[smiles_col][:50]}{'...' if len(str(row[smiles_col])) > 50 else ''}<br/>
133
+ """
134
+
135
+ # Add additional columns
136
+ for col in additional_cols:
137
+ if col in row and pd.notna(row[col]):
138
+ value = row[col]
139
+ if isinstance(value, float):
140
+ value = f"{value:.2f}"
141
+ card_html += f"<strong>{col}:</strong> {value}<br/>"
142
+
143
+ card_html += """
144
+ </div>
145
+ </div>
146
+ """
147
+ html_parts.append(card_html)
148
+
149
+ except Exception as e:
150
+ print(f"Error processing molecule {idx}: {e}")
151
+ continue
152
+
153
+ html_parts.append("</div>")
154
+
155
+ if len(html_parts) == 2: # Only header and footer, no molecules processed
156
+ return "<h3>No valid molecules to display.</h3>"
157
+
158
+ return "".join(html_parts)
159
+
160
+ def create_simple_molecule_table(df, smiles_col='canonical_smiles', additional_cols=None, max_mols=20):
161
+ """
162
+ Create a simple HTML table with molecule images - fallback option
163
+ """
164
+ from rdkit import Chem
165
+ from rdkit.Chem import Draw
166
+ import base64
167
+ from io import BytesIO
168
+
169
+ if df.empty:
170
+ return "<h3>No molecules to display.</h3>"
171
+
172
+ df_display = df.head(max_mols).copy()
173
+
174
+ if additional_cols is None:
175
+ additional_cols = []
176
+
177
+ # Start HTML table
178
+ html = """
179
+ <table style="border-collapse: collapse; width: 100%;">
180
+ <thead>
181
+ <tr style="background-color: #f2f2f2;">
182
+ <th style="border: 1px solid #ddd; padding: 8px;">Structure</th>
183
+ <th style="border: 1px solid #ddd; padding: 8px;">SMILES</th>
184
+ """
185
+
186
+ for col in additional_cols:
187
+ html += f'<th style="border: 1px solid #ddd; padding: 8px;">{col}</th>'
188
+
189
+ html += """
190
+ </tr>
191
+ </thead>
192
+ <tbody>
193
+ """
194
+
195
+ for idx, row in df_display.iterrows():
196
+ try:
197
+ mol = Chem.MolFromSmiles(row[smiles_col])
198
+ if mol is None:
199
+ continue
200
+
201
+ # Generate image
202
+ img = Draw.MolToImage(mol, size=(200, 150))
203
+ buffer = BytesIO()
204
+ img.save(buffer, format='PNG')
205
+ img_str = base64.b64encode(buffer.getvalue()).decode()
206
+
207
+ html += f"""
208
+ <tr>
209
+ <td style="border: 1px solid #ddd; padding: 8px; text-align: center;">
210
+ <img src="data:image/png;base64,{img_str}" style="max-width: 200px; height: auto;" alt="Molecule"/>
211
+ </td>
212
+ <td style="border: 1px solid #ddd; padding: 8px; font-family: monospace; font-size: 11px;">
213
+ {row[smiles_col][:100]}{'...' if len(str(row[smiles_col])) > 100 else ''}
214
+ </td>
215
+ """
216
+
217
+ for col in additional_cols:
218
+ value = row[col] if col in row and pd.notna(row[col]) else "N/A"
219
+ if isinstance(value, float):
220
+ value = f"{value:.2f}"
221
+ html += f'<td style="border: 1px solid #ddd; padding: 8px;">{value}</td>'
222
+
223
+ html += "</tr>"
224
+
225
+ except Exception as e:
226
+ print(f"Error processing molecule {idx}: {e}")
227
+ continue
228
+
229
+ html += "</tbody></table>"
230
+ return html
231
+
232
  # ==============================================================================
233
  # === STEP 1: CORE DATA COLLECTION & EDA FUNCTIONS ===
234
  # ==============================================================================
235
+ def get_target_chembl_id(query):
236
+ try:
237
+ target = new_client.target
238
+ res = target.search(query)
239
+ if not res:
240
+ return pd.DataFrame(), gr.Dropdown(choices=[], value=None), "No targets found for your query."
241
+ df = pd.DataFrame(res)
242
+ return df[["target_chembl_id", "pref_name", "organism"]], gr.Dropdown(choices=df["target_chembl_id"].tolist()), f"Found {len(df)} targets."
243
+ except Exception as e:
244
+ raise gr.Error(f"ChEMBL search failed: {e}")
245
+
246
+ def get_bioactivity_data(target_id):
247
+ try:
248
+ activity = new_client.activity
249
+ res = activity.filter(target_chembl_id=target_id).filter(standard_type="IC50")
250
+ if not res:
251
+ return pd.DataFrame(), "No IC50 bioactivity data found for this target."
252
+ df = pd.DataFrame(res)
253
+ return df, f"Fetched {len(df)} data points."
254
+ except Exception as e:
255
+ raise gr.Error(f"Failed to fetch bioactivity data: {e}")
256
+
257
+ def pIC50_calc(input_df):
258
+ df_copy = input_df.copy()
259
+ df_copy['standard_value'] = pd.to_numeric(df_copy['standard_value'], errors='coerce')
260
+ df_copy.dropna(subset=['standard_value'], inplace=True)
261
+ df_copy['standard_value_norm'] = df_copy['standard_value'].apply(lambda x: min(x, 100000000))
262
+ pIC50_values = []
263
+ for i in df_copy['standard_value_norm']:
264
+ if pd.notna(i) and i > 0:
265
+ molar = i * (10**-9)
266
+ pIC50_values.append(-np.log10(molar))
267
+ else:
268
+ pIC50_values.append(np.nan)
269
+ df_copy['pIC50'] = pIC50_values
270
+ df_copy['bioactivity_class'] = df_copy['standard_value_norm'].apply(
271
+ lambda x: "inactive" if pd.notna(x) and x >= 10000 else ("active" if pd.notna(x) and x <= 1000 else "intermediate")
 
272
  )
273
+ return df_copy.drop(columns=['standard_value', 'standard_value_norm'])
274
+
275
+ def lipinski_descriptors(smiles_series):
276
+ moldata, valid_smiles = [], []
277
+ for elem in smiles_series:
278
+ if elem and isinstance(elem, str):
279
+ mol = Chem.MolFromSmiles(elem)
280
+ if mol:
281
+ moldata.append(mol)
282
+ valid_smiles.append(elem)
283
+ descriptor_rows = []
284
+ for mol in moldata:
285
+ row = [Descriptors.MolWt(mol), Descriptors.MolLogP(mol), Lipinski.NumHDonors(mol), Lipinski.NumHAcceptors(mol)]
286
+ descriptor_rows.append(row)
287
+ columnNames = ["MW", "LogP", "NumHDonors", "NumHAcceptors"]
288
+ if not descriptor_rows: return pd.DataFrame(columns=columnNames), []
289
+ return pd.DataFrame(data=np.array(descriptor_rows), columns=columnNames), valid_smiles
290
+
291
+ def clean_and_process_data(df):
292
+ if df is None or df.empty: raise gr.Error("No data to process. Please fetch data first.")
293
+ if "canonical_smiles" not in df.columns or df["canonical_smiles"].isnull().all():
294
+ try:
295
+ df["canonical_smiles"] = [c.get("molecule_structures", {}).get("canonical_smiles") for c in new_client.molecule.get(list(df["molecule_chembl_id"]))]
296
+ except Exception as e:
297
+ raise gr.Error(f"Could not fetch SMILES from ChEMBL: {e}")
298
+ df = df[df.standard_value.notna()]
299
+ df = df[df.canonical_smiles.notna()]
300
+ df.drop_duplicates(['canonical_smiles'], inplace=True)
301
+ df["standard_value"] = pd.to_numeric(df["standard_value"], errors='coerce')
302
+ df.dropna(subset=['standard_value'], inplace=True)
303
+ df_processed = pIC50_calc(df)
304
+ df_processed = df_processed[df_processed.pIC50.notna()]
305
+ if df_processed.empty: return pd.DataFrame(), "No compounds remaining after pIC50 calculation."
306
+ df_lipinski, valid_smiles = lipinski_descriptors(df_processed['canonical_smiles'])
307
+ if not valid_smiles: return pd.DataFrame(), "No valid SMILES could be processed for Lipinski descriptors."
308
+ df_processed = df_processed[df_processed['canonical_smiles'].isin(valid_smiles)].reset_index(drop=True)
309
+ df_lipinski = df_lipinski.reset_index(drop=True)
310
+ df_final = pd.concat([df_processed, df_lipinski], axis=1)
311
+ return df_final, f"Processing complete. {len(df_final)} compounds remain after cleaning."
312
+
313
+ def run_eda_analysis(df, selected_classes):
314
+ if df is None or df.empty: raise gr.Error("No data available for analysis.")
315
+ df_filtered = df[df.bioactivity_class.isin(selected_classes)].copy()
316
+ if df_filtered.empty: return (None, None, None, pd.DataFrame(), None, pd.DataFrame(), None, pd.DataFrame(), None, pd.DataFrame(), None, pd.DataFrame(), "No data for selected classes.")
317
+ plots = [create_frequency_plot(df_filtered), create_scatter_plot(df_filtered)]
318
+ stats_dfs = []
319
+ for desc in ['pIC50', 'MW', 'LogP', 'NumHDonors', 'NumHAcceptors']:
320
+ plots.append(create_boxplot(df_filtered, desc))
321
+ stats_dfs.append(mannwhitney_test(df_filtered, desc))
322
+ plt.close('all')
323
+ return (plots[0], plots[1], plots[2], stats_dfs[0], plots[3], stats_dfs[1], plots[4], stats_dfs[2], plots[5], stats_dfs[3], plots[6], stats_dfs[4], f"EDA complete for {len(df_filtered)} compounds.")
324
+
325
+ def create_frequency_plot(df):
326
+ plt.figure(figsize=(5.5, 5.5)); sns.barplot(x=df['bioactivity_class'].value_counts().index, y=df['bioactivity_class'].value_counts().values, palette={'active': '#1f77b4', 'inactive': '#ff7f0e', 'intermediate': '#2ca02c'}); plt.xlabel('Bioactivity Class', fontsize=12); plt.ylabel('Frequency', fontsize=12); plt.title('Frequency of Bioactivity Classes', fontsize=14); return plt.gcf()
327
+
328
+ def create_scatter_plot(df):
329
+ plt.figure(figsize=(5.5, 5.5)); sns.scatterplot(data=df, x='MW', y='LogP', hue='bioactivity_class', size='pIC50', palette={'active': '#1f77b4', 'inactive': '#ff7f0e', 'intermediate': '#2ca02c'}, sizes=(20, 200), alpha=0.7); plt.xlabel('Molecular Weight (MW)', fontsize=12); plt.ylabel('LogP', fontsize=12); plt.title('Chemical Space: MW vs. LogP', fontsize=14); plt.legend(title='Bioactivity Class'); return plt.gcf()
330
+
331
+ def create_boxplot(df, descriptor):
332
+ plt.figure(figsize=(5.5, 5.5)); sns.boxplot(x='bioactivity_class', y=descriptor, data=df, palette={'active': '#1f77b4', 'inactive': '#ff7f0e', 'intermediate': '#2ca02c'}); plt.xlabel('Bioactivity Class', fontsize=12); plt.ylabel(descriptor, fontsize=12); plt.title(f'{descriptor} by Bioactivity Class', fontsize=14); return plt.gcf()
333
+
334
+ def mannwhitney_test(df, descriptor):
335
+ results = []
336
+ for c1, c2 in [('active', 'inactive'), ('active', 'intermediate'), ('inactive', 'intermediate')]:
337
+ if c1 in df['bioactivity_class'].unique() and c2 in df['bioactivity_class'].unique():
338
+ d1, d2 = df[df.bioactivity_class == c1][descriptor].dropna(), df[df.bioactivity_class == c2][descriptor].dropna()
339
+ if not d1.empty and not d2.empty:
340
+ stat, p = mannwhitneyu(d1, d2)
341
+ results.append({'Comparison': f'{c1.title()} vs {c2.title()}', 'Statistics': stat, 'p-value': p, 'Interpretation': 'Different distribution (p < 0.05)' if p <= 0.05 else 'Same distribution (p > 0.05)'})
342
+ return pd.DataFrame(results)
343
 
344
  # ==============================================================================
345
  # === STEP 2: FEATURE ENGINEERING FUNCTIONS ===
346
  # ==============================================================================
347
+ def calculate_fingerprints(current_state, fingerprint_type, progress=gr.Progress()):
348
+ input_df = current_state.get('cleaned_data')
349
+ if input_df is None or input_df.empty:
350
+ raise gr.Error("No cleaned data found. Please complete Step 1.")
351
+ if not fingerprint_type:
352
+ raise gr.Error("Please select a fingerprint type.")
353
+
354
+ progress(0, desc="Starting...")
355
+ yield f"��� Starting fingerprint calculation...", None, gr.update(visible=False), None, current_state
356
+
357
+ try:
358
+ smi_file, output_csv = 'molecule.smi', 'fingerprints.csv'
359
+
360
+ input_df[['canonical_smiles', 'canonical_smiles']].to_csv(smi_file, sep='\t', index=False, header=False)
361
 
362
+ if os.path.exists(output_csv):
363
+ os.remove(output_csv)
364
+ descriptortypes = fp_config.get(fingerprint_type)
365
+ if not descriptortypes:
366
+ raise gr.Error(f"Descriptor XML for '{fingerprint_type}' not found.")
367
 
368
+ progress(0.3, desc="⚗️ Running PaDEL...")
369
+ yield f"⚗️ Running PaDEL...", None, gr.update(visible=False), None, current_state
 
370
 
371
+ padeldescriptor(
372
+ mol_dir=smi_file,
373
+ d_file=output_csv,
374
+ descriptortypes=descriptortypes,
375
+ detectaromaticity=True,
376
+ standardizenitro=True,
377
+ standardizetautomers=True,
378
+ threads=-1,
379
+ removesalt=True,
380
+ log=False,
381
+ fingerprints=True
382
+ )
383
+
384
+ if not os.path.exists(output_csv) or os.path.getsize(output_csv) == 0:
385
+ raise gr.Error("PaDEL failed to produce an output file. Check molecule validity.")
386
 
387
+ progress(0.7, desc="📊 Processing results...")
388
+ yield "📊 Processing results...", None, gr.update(visible=False), None, current_state
389
+
390
+ df_X = pd.read_csv(output_csv).rename(columns={'Name': 'canonical_smiles'})
391
+ final_df = pd.merge(input_df[['canonical_smiles', 'pIC50']], df_X, on='canonical_smiles', how='inner')
392
+
393
+ current_state['fingerprint_data'] = final_df
394
+ current_state['fingerprint_type'] = fingerprint_type
395
+
396
+ progress(0.9, desc="🖼️ Generating molecule grid...")
397
+
398
+ # Use custom molecule display instead of mols2grid
399
+ try:
400
+ # Try the grid layout first
401
+ mols_html = create_molecule_grid_html(
402
+ final_df,
403
+ smiles_col='canonical_smiles',
404
+ additional_cols=['pIC50'],
405
+ max_mols=50
406
+ )
407
+ except Exception as e:
408
+ print(f"Grid layout failed: {e}, trying table layout...")
409
+ # Fallback to table layout
410
+ mols_html = create_simple_molecule_table(
411
+ final_df,
412
+ smiles_col='canonical_smiles',
413
+ additional_cols=['pIC50'],
414
+ max_mols=20
415
+ )
416
 
417
+ success_msg = f"✅ Success! Generated {len(df_X.columns) -1} descriptors for {len(final_df)} molecules."
418
+ progress(1, desc="Completed!")
419
+ yield success_msg, final_df, gr.update(visible=True), gr.update(value=mols_html, visible=True), current_state
420
 
421
+ except Exception as e:
422
+ raise gr.Error(f"Calculation failed: {e}")
423
+ finally:
424
+ if os.path.exists('molecule.smi'):
425
+ os.remove('molecule.smi')
426
+ if os.path.exists('fingerprints.csv'):
427
+ os.remove('fingerprints.csv')
 
 
428
 
429
  # ==============================================================================
430
+ # === STEP 3: MODELING FUNCTIONS ===
431
  # ==============================================================================
432
+
433
+ # Model definitions
434
+ regression_models = {
435
+ "Linear Regression": LinearRegression,
436
+ "Ridge Regression": Ridge,
437
+ "Lasso Regression": Lasso,
438
+ "Elastic Net": ElasticNet,
439
+ "Bayesian Ridge": BayesianRidge,
440
+ "Huber Regressor": HuberRegressor,
441
+ "Passive Aggressive Regressor": PassiveAggressiveRegressor,
442
+ "Orthogonal Matching Pursuit": OrthogonalMatchingPursuit,
443
+ "Lasso Lars": LassoLars,
444
+ "Decision Tree Regressor": DecisionTreeRegressor,
445
+ "Random Forest Regressor": RandomForestRegressor,
446
+ "Gradient Boosting Regressor": GradientBoostingRegressor,
447
+ "Extra Trees Regressor": ExtraTreesRegressor,
448
+ "AdaBoost Regressor": AdaBoostRegressor,
449
+ "K-Neighbors Regressor": KNeighborsRegressor,
450
+ "Dummy Regressor (Mean)": DummyRegressor,
451
+ }
452
+
453
+ if _has_extra_libs:
454
+ regression_models.update({
455
+ "XGBoost Regressor": xgb.XGBRegressor,
456
+ "LightGBM Regressor": lgb.LGBMRegressor,
457
+ "CatBoost Regressor": cb.CatBoostRegressor,
458
+ })
459
+
460
+ def handle_model_training(current_state, progress=gr.Progress()):
461
+ df_fp = current_state.get('fingerprint_data')
462
+ if df_fp is None or df_fp.empty:
463
+ raise gr.Error("No fingerprint data found. Please complete Step 2.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
464
 
465
+ yield "⚙️ Starting model training...", None, None, current_state
466
+ progress(0, desc="Starting model training...")
467
+
468
+ try:
469
+ X = df_fp.drop(columns=['canonical_smiles', 'pIC50'])
470
+ y = df_fp['pIC50']
471
+
472
+ if X.empty:
473
+ raise gr.Error("No features found for training. Check fingerprint calculation.")
474
+
475
+ # Remove features with zero variance
476
+ sel = VarianceThreshold(threshold=0.0)
477
+ X_filtered = sel.fit_transform(X)
478
+ selected_features = X.columns[sel.get_support()]
479
+
480
+ if len(selected_features) == 0:
481
+ raise gr.Error("No features with non-zero variance found. Cannot train models.")
482
+
483
+ X_filtered = pd.DataFrame(X_filtered, columns=selected_features)
484
+
485
+ X_train, X_test, y_train, y_test = train_test_split(X_filtered, y, test_size=0.2, random_state=42)
486
+
487
+ results = []
488
+ trained_models = {}
489
+
490
+ for i, (name, model_class) in enumerate(regression_models.items()):
491
+ current_progress = (i + 1) / len(regression_models)
492
+ progress(current_progress, desc=f"Training {name}...")
493
+ yield f"Training {name}...", None, None, current_state
494
+
495
+ model = model_class()
496
+ model.fit(X_train, y_train)
497
+ y_pred = model.predict(X_test)
498
+
499
+ mae = mean_absolute_error(y_test, y_pred)
500
+ mse = mean_squared_error(y_test, y_pred)
501
+ r2 = r2_score(y_test, y_pred)
502
+
503
+ results.append({
504
+ "Model": name,
505
+ "MAE": mae,
506
+ "MSE": mse,
507
+ "R2": r2
508
+ })
509
+ trained_models[name] = model
510
+
511
+ results_df = pd.DataFrame(results)
512
+
513
+ # Store results in the state
514
+ current_state['model_results'] = ModelRunResult(
515
+ dataframe=results_df,
516
+ plotter=None, # No plot generated here directly, but could be added
517
+ models=trained_models,
518
+ selected_features=selected_features
519
+ )
520
+ current_state['X_train'] = X_train
521
+ current_state['y_train'] = y_train
522
+ current_state['X_test'] = X_test
523
+ current_state['y_test'] = y_test
524
+
525
+ progress(1, desc="Model training complete!")
526
+ yield "✅ Model training complete!", results_df, gr.update(choices=list(trained_models.keys()), value=list(trained_models.keys())[0]), current_state
527
+
528
+ except Exception as e:
529
+ raise gr.Error(f"Model training failed: {e}")
530
+
531
+ def update_analysis_plots(model_name, feature_count, current_state):
532
+ model_run_results = current_state.get('model_results')
533
+ X_train = current_state.get('X_train')
534
+ y_train = current_state.get('y_train')
535
+ X_test = current_state.get('X_test')
536
+ y_test = current_state.get('y_test')
537
+
538
+ if not model_run_results or not model_name or X_test is None or y_test is None:
539
+ return None, None, None, None, "Please train models first."
540
+
541
+ model = model_run_results.models.get(model_name)
542
+ if model is None:
543
+ return None, None, None, None, f"Model '{model_name}' not found."
544
+
545
+ y_pred_test = model.predict(X_test)
546
+ y_pred_train = model.predict(X_train)
547
+
548
+ # Calculate R2, MAE, MSE for test set
549
+ r2_test = r2_score(y_test, y_pred_test)
550
+ mae_test = mean_absolute_error(y_test, y_pred_test)
551
+ mse_test = mean_squared_error(y_test, y_pred_test)
552
+
553
+ # Plot Y-obs vs Y-pred
554
+ fig_obs_pred, ax_obs_pred = plt.subplots(figsize=(6, 6))
555
+ sns.scatterplot(x=y_test, y=y_pred_test, ax=ax_obs_pred, alpha=0.7, color='blue', label='Test Data')
556
+ sns.scatterplot(x=y_train, y=y_pred_train, ax=ax_obs_pred, alpha=0.7, color='green', label='Train Data')
557
+ ax_obs_pred.plot([min(y_test.min(), y_train.min()), max(y_test.max(), y_train.max())],
558
+ [min(y_test.min(), y_train.min()), max(y_test.max(), y_train.max())],
559
+ color='red', linestyle='--', label='Ideal Prediction')
560
+ ax_obs_pred.set_xlabel("Observed pIC50", fontsize=12)
561
+ ax_obs_pred.set_ylabel("Predicted pIC50", fontsize=12)
562
+ ax_obs_pred.set_title(f"{model_name}: Observed vs. Predicted pIC50", fontsize=14)
563
+ ax_obs_pred.legend()
564
+ plt.close(fig_obs_pred)
565
+
566
+ # Plot Residuals
567
+ residuals = y_test - y_pred_test
568
+ fig_residuals, ax_residuals = plt.subplots(figsize=(6, 6))
569
+ sns.scatterplot(x=y_pred_test, y=residuals, ax=ax_residuals, alpha=0.7, color='purple')
570
+ ax_residuals.axhline(y=0, color='red', linestyle='--')
571
+ ax_residuals.set_xlabel("Predicted pIC50", fontsize=12)
572
+ ax_residuals.set_ylabel("Residuals (Observed - Predicted)", fontsize=12)
573
+ ax_residuals.set_title(f"{model_name}: Residuals Plot", fontsize=14)
574
+ plt.close(fig_residuals)
575
+
576
+ # Feature Importance Plot (if applicable)
577
+ fig_feature_importance = None
578
+ if hasattr(model, 'feature_importances_') and feature_count > 0:
579
+ feature_importances = pd.Series(model.feature_importances_, index=model_run_results.selected_features)
580
+ top_features = feature_importances.nlargest(feature_count)
581
+
582
+ fig_feature_importance, ax_fi = plt.subplots(figsize=(8, 6))
583
+ sns.barplot(x=top_features.values, y=top_features.index, ax=ax_fi, palette='viridis')
584
+ ax_fi.set_xlabel("Importance", fontsize=12)
585
+ ax_fi.set_ylabel("Feature", fontsize=12)
586
+ ax_fi.set_title(f"{model_name}: Top {feature_count} Feature Importances", fontsize=14)
587
+ plt.tight_layout()
588
+ plt.close(fig_feature_importance)
589
+ elif hasattr(model, 'coef_') and feature_count > 0 and model_name not in ["Dummy Regressor (Mean)", "K-Neighbors Regressor"]:
590
+ # For linear models, coefficients can be used as importance
591
+ feature_importances = pd.Series(model.coef_, index=model_run_results.selected_features)
592
+ top_features = feature_importances.abs().nlargest(feature_count) # Use absolute value for ranking
593
+ top_features_values = feature_importances[top_features.index] # Get actual signed values
594
+
595
+ fig_feature_importance, ax_fi = plt.subplots(figsize=(8, 6))
596
+ sns.barplot(x=top_features_values.values, y=top_features_values.index, ax=ax_fi, palette='coolwarm')
597
+ ax_fi.set_xlabel("Coefficient Value", fontsize=12)
598
+ ax_fi.set_ylabel("Feature", fontsize=12)
599
+ ax_fi.set_title(f"{model_name}: Top {feature_count} Feature Coefficients", fontsize=14)
600
+ plt.tight_layout()
601
+ plt.close(fig_feature_importance)
602
+
603
+
604
+ return fig_obs_pred, fig_residuals, fig_feature_importance, \
605
+ f"R2 (Test): {r2_test:.4f}, MAE (Test): {mae_test:.4f}, MSE (Test): {mse_test:.4f}", \
606
+ "Plots updated."
607
+
608
+ # Updated prediction function
609
+ def predict_on_upload(uploaded_file, model_name, current_state, progress=gr.Progress()):
610
+ if not uploaded_file:
611
+ raise gr.Error("Please upload a file.")
612
+ if not model_name:
613
+ raise gr.Error("Please select a trained model.")
614
 
615
+ model_run_results = current_state.get('model_results')
616
+ fingerprint_type = current_state.get('fingerprint_type')
617
+ if not model_run_results or not fingerprint_type:
618
+ raise gr.Error("Please run Steps 2 and 3 first.")
 
 
 
 
 
619
 
620
+ model = model_run_results.models.get(model_name)
621
+ selected_features = model_run_results.selected_features
622
+ if model is None:
623
+ raise gr.Error(f"Model '{model_name}' not found.")
624
 
625
+ smi_file, output_csv = 'predict.smi', 'predict_fp.csv'
626
+ try:
627
+ progress(0, desc="Reading & processing new molecules...")
628
+ yield "Reading uploaded file...", None, None
 
 
629
 
630
+ df_new = pd.read_csv(uploaded_file.name)
631
+ if 'canonical_smiles' not in df_new.columns:
632
+ raise gr.Error("CSV must contain a 'canonical_smiles' column.")
633
+ df_new = df_new.reset_index().rename(columns={'index': 'mol_id'})
634
 
635
+ padel_input = pd.DataFrame({
636
+ 'smiles': df_new['canonical_smiles'],
637
+ 'name': df_new['mol_id']
638
+ })
639
+ padel_input.to_csv(smi_file, sep='\t', index=False, header=False)
640
+ if os.path.exists(output_csv):
641
+ os.remove(output_csv)
642
 
643
+ progress(0.3, desc="Calculating fingerprints...")
644
+ yield "Calculating fingerprints for new molecules...", None, None
645
 
646
+ padeldescriptor(
647
+ mol_dir=smi_file,
648
+ d_file=output_csv,
649
+ descriptortypes=fp_config.get(fingerprint_type),
650
+ detectaromaticity=True,
651
+ standardizenitro=True,
652
+ threads=-1,
653
+ removesalt=True,
654
+ log=False,
655
+ fingerprints=True
656
+ )
657
 
658
+ if not os.path.exists(output_csv) or os.path.getsize(output_csv) == 0:
659
+ raise gr.Error("PaDEL calculation failed for the uploaded molecules.")
660
 
661
+ progress(0.7, desc="Aligning features and predicting...")
662
+ yield "Aligning features and predicting...", None, None
663
+
664
+ df_fp = pd.read_csv(output_csv).rename(columns={'Name': 'mol_id'})
665
+ X_new = df_fp.set_index('mol_id')
666
+ X_new_aligned = X_new.reindex(columns=selected_features, fill_value=0)[selected_features]
667
+ predictions = model.predict(X_new_aligned)
668
+ results_subset = pd.DataFrame({
669
+ 'mol_id': X_new_aligned.index,
670
+ 'predicted_pIC50': predictions
671
+ })
672
+ df_results = pd.merge(df_new, results_subset, on='mol_id', how='left')
673
+
674
+ progress(0.9, desc="Generating visualization...")
675
+ yield "Generating visualization...", None, None
676
 
677
+ df_grid_view = df_results.dropna(subset=['predicted_pIC50']).copy()
678
+ mols_html = "<h3>No molecules with successful predictions to display.</h3>"
 
 
 
 
 
 
 
 
679
 
680
+ if not df_grid_view.empty:
681
+ try:
682
+ # Use custom molecule display
683
+ mols_html = create_molecule_grid_html(
684
+ df_grid_view,
685
+ smiles_col='canonical_smiles',
686
+ additional_cols=['predicted_pIC50'],
687
+ max_mols=50
688
+ )
689
+ except Exception as e:
690
+ print(f"Grid layout failed: {e}, trying table layout...")
691
+ mols_html = create_simple_molecule_table(
692
+ df_grid_view,
693
+ smiles_col='canonical_smiles',
694
+ additional_cols=['predicted_pIC50'],
695
+ max_mols=20
696
+ )
697
+
698
+ progress(1, desc="Complete!")
699
+ yield "✅ Prediction complete.", df_results[['canonical_smiles', 'predicted_pIC50']], mols_html
700
+
701
+ finally:
702
+ if os.path.exists(smi_file):
703
+ os.remove(smi_file)
704
+ if os.path.exists(output_csv):
705
+ os.remove(output_csv)
706
 
707
  # ==============================================================================
708
+ # === GRADIO INTERFACE LAYOUT ===
709
  # ==============================================================================
710
+
711
+ with gr.Blocks(css=".container { max-width: 1200px; margin: auto; }") as demo:
712
+ app_state = gr.State({}) # Store application state (e.g., fetched data, trained models)
713
+
714
+ gr.Markdown("# 💊 Bioactivity Prediction App")
715
+ gr.Markdown("---")
716
+
717
+ with gr.Tabs():
718
+ with gr.TabItem("Step 1: Data Collection & EDA"):
719
+ with gr.Row():
720
+ with gr.Column():
721
+ gr.Markdown("## 1.1 ChEMBL Target Search")
722
+ target_query = gr.Textbox(
723
+ label="Enter target protein name (e.g., 'EGFR', 'acetylcholinesterase')",
724
+ placeholder="EGFR"
725
+ )
726
+ search_target_btn = gr.Button("Search ChEMBL")
727
+ target_output_df = gr.DataFrame(
728
+ label="Search Results",
729
+ headers=["target_chembl_id", "pref_name", "organism"],
730
+ max_rows=5
731
+ )
732
+ status_step1_search = gr.Textbox(label="Status", interactive=False)
733
+
734
+ gr.Markdown("## 1.2 Fetch Bioactivity Data")
735
+ chembl_id_selector = gr.Dropdown(
736
+ label="Select Target ChEMBL ID",
737
+ choices=[], interactive=True
738
+ )
739
+ fetch_data_btn = gr.Button("Fetch Bioactivity Data (IC50)")
740
+ bioactivity_output_df = gr.DataFrame(label="Raw Bioactivity Data", max_rows=5)
741
+ status_step1_fetch = gr.Textbox(label="Status", interactive=False)
742
+
743
+ gr.Markdown("## 1.3 Clean & Process Data")
744
+ process_data_btn = gr.Button("Process & Calculate pIC50/Lipinski Descriptors")
745
+ cleaned_data_output_df = gr.DataFrame(label="Cleaned & Processed Data", max_rows=5)
746
+ status_step1_process_clean = gr.Textbox(label="Status", interactive=False)
747
+ download_s1 = gr.DownloadButton("Download Cleaned Data (CSV)", visible=False, interactive=False)
748
+
749
+ with gr.Column():
750
+ gr.Markdown("## 1.4 Exploratory Data Analysis (EDA)")
751
+ bioactivity_class_selector = gr.CheckboxGroup(
752
+ label="Select Bioactivity Classes for EDA",
753
+ choices=["active", "inactive", "intermediate"],
754
+ value=["active", "inactive", "intermediate"],
755
+ interactive=True
756
+ )
757
+ run_eda_btn = gr.Button("Run EDA")
758
+ status_step1_process = gr.Textbox(label="Status", interactive=False)
759
+
760
+ gr.Markdown("### Bioactivity Class Frequency")
761
+ freq_plot_output = gr.Plot(label="Bioactivity Class Frequency")
762
+
763
+ gr.Markdown("### Chemical Space (MW vs LogP)")
764
+ scatter_plot_output = gr.Plot(label="MW vs LogP Scatter Plot")
765
+
766
+ gr.Markdown("### pIC50 Distribution")
767
+ pic50_plot_output = gr.Plot(label="pIC50 Box Plot")
768
+ pic50_stats_output = gr.DataFrame(label="pIC50 Mann-Whitney U Test", max_rows=5)
769
+
770
+ gr.Markdown("### Molecular Weight (MW) Distribution")
771
+ mw_plot_output = gr.Plot(label="MW Box Plot")
772
+ mw_stats_output = gr.DataFrame(label="MW Mann-Whitney U Test", max_rows=5)
773
+
774
+ gr.Markdown("### LogP Distribution")
775
+ logp_plot_output = gr.Plot(label="LogP Box Plot")
776
+ logp_stats_output = gr.DataFrame(label="LogP Mann-Whitney U Test", max_rows=5)
777
+
778
+ gr.Markdown("### Hydrogen Donors Distribution")
779
+ hdonors_plot_output = gr.Plot(label="Hydrogen Donors Box Plot")
780
+ hdonors_stats_output = gr.DataFrame(label="Hydrogen Donors Mann-Whitney U Test", max_rows=5)
781
+
782
+ gr.Markdown("### Hydrogen Acceptors Distribution")
783
+ hacceptors_plot_output = gr.Plot(label="Hydrogen Acceptors Box Plot")
784
+ hacceptors_stats_output = gr.DataFrame(label="Hydrogen Acceptors Mann-Whitney U Test", max_rows=5)
785
+
786
+ with gr.TabItem("Step 2: Feature Engineering (Fingerprints)"):
787
+ gr.Markdown("## 2.1 Calculate Molecular Fingerprints")
788
+ gr.Markdown(f"Available Fingerprint Types: {', '.join(FP_list)}")
789
+ fingerprint_dropdown = gr.Dropdown(
790
+ label="Select Fingerprint Type",
791
+ choices=FP_list,
792
+ interactive=True,
793
+ value=FP_list[0] if FP_list else None # Set default if available
794
+ )
795
+ calculate_fp_btn = gr.Button("Calculate Fingerprints")
796
+ status_step2 = gr.Textbox(label="Status", interactive=False)
797
+ output_df_s2 = gr.DataFrame(label="Fingerprint Data (First 5 rows, with pIC50)", max_rows=5)
798
+ download_s2 = gr.DownloadButton("Download Fingerprint Data (CSV)", visible=False, interactive=False)
799
+
800
+ # Using HTML component for custom molecule display
801
+ mols_grid_s2 = gr.HTML(label="Molecules with pIC50", visible=True)
802
+
803
+
804
+ with gr.TabItem("Step 3: Model Training & Evaluation"):
805
+ gr.Markdown("## 3.1 Train Regression Models")
806
+ train_models_btn = gr.Button("Train Models")
807
+ status_step3_train = gr.Textbox(label="Status", interactive=False)
808
+ model_results_df = gr.DataFrame(label="Model Performance Metrics", max_rows=10)
809
+
810
+ gr.Markdown("## 3.2 Model Analysis")
811
+ model_selector_s3 = gr.Dropdown(
812
+ label="Select Model for Detailed Analysis",
813
+ choices=[],
814
+ interactive=True
815
+ )
816
+ feature_count_s3 = gr.Slider(
817
+ minimum=0, maximum=50, step=1, value=10,
818
+ label="Number of Top Features to Display", interactive=True
819
+ )
820
+ model_metrics_summary = gr.Textbox(label="Selected Model Metrics (Test Set)", interactive=False)
821
+ obs_pred_plot = gr.Plot(label="Observed vs. Predicted pIC50")
822
+ residuals_plot = gr.Plot(label="Residuals Plot")
823
+ feature_importance_plot = gr.Plot(label="Feature Importance/Coefficients")
824
+
825
+ with gr.TabItem("Step 4: Prediction on New Data"):
826
+ gr.Markdown("## 4.1 Upload New Molecules & Predict")
827
+ upload_new_mols = gr.File(label="Upload CSV with 'canonical_smiles' column")
828
+ model_selector_s4 = gr.Dropdown(
829
+ label="Select Trained Model for Prediction",
830
+ choices=[],
831
+ interactive=True
832
+ )
833
+ predict_btn = gr.Button("Predict pIC50 for New Molecules")
834
+ status_step4 = gr.Textbox(label="Status", interactive=False)
835
+ predictions_output_df = gr.DataFrame(label="Predictions for New Molecules", max_rows=10)
836
+ download_s4 = gr.DownloadButton("Download Predictions (CSV)", visible=False, interactive=False)
837
+ # Using HTML component for custom molecule display
838
+ mols_grid_s4 = gr.HTML(label="New Molecules with Predicted pIC50", visible=True)
839
 
840
  # --- EVENT HANDLERS ---
841
+
842
+ # Step 1 Callbacks
843
+ search_target_btn.click(
844
+ fn=lambda query: (get_target_chembl_id(query), gr.update(visible=True)), # Return two values
845
+ inputs=target_query,
846
+ outputs=[target_output_df, chembl_id_selector, status_step1_search]
847
+ )
848
+
849
+ chembl_id_selector.change(
850
+ fn=lambda x: gr.update(value=x),
851
+ inputs=chembl_id_selector,
852
+ outputs=chembl_id_selector # This is just to ensure the value is updated
853
+ )
854
+
855
+ fetch_data_btn.click(
856
+ fn=lambda target_id: (get_bioactivity_data(target_id), gr.update(value=target_id)),
857
+ inputs=chembl_id_selector,
858
+ outputs=[bioactivity_output_df, status_step1_fetch]
859
+ ).then(
860
+ fn=lambda df: (df, df.copy()), # Pass the fetched df to the state
861
+ inputs=bioactivity_output_df,
862
+ outputs=[gr.State(value={}, key='raw_data'), app_state]
863
+ )
864
+
865
+ process_data_btn.click(
866
+ fn=lambda current_state: clean_and_process_data(current_state.get('raw_data')),
867
+ inputs=app_state,
868
+ outputs=[cleaned_data_output_df, status_step1_process_clean]
869
+ ).then(
870
+ fn=lambda df: (gr.update(visible=True), df), # Show download button and update state
871
+ inputs=cleaned_data_output_df,
872
+ outputs=[download_s1, gr.State(value={}, key='cleaned_data'), app_state]
873
+ )
874
+
875
+ @download_s1.click(inputs=app_state, outputs=download_s1, show_progress="hidden")
876
+ def download_handler_s1(current_state):
877
+ df_to_download = current_state.get('cleaned_data')
878
+ if df_to_download is None:
879
+ raise gr.Error("No data to download. Please process data first.")
880
+ # Create a dummy file path for Gradio to handle the download
881
+ file_path = "cleaned_data.csv"
882
+ df_to_download.to_csv(file_path, index=False)
883
+ return gr.File(file_path, visible=True)
884
+
885
+ run_eda_btn.click(
886
+ fn=lambda df, classes: run_eda_analysis(df, classes),
887
+ inputs=[cleaned_data_output_df, bioactivity_class_selector],
888
+ outputs=[freq_plot_output, scatter_plot_output, pic50_plot_output, pic50_stats_output,
889
+ mw_plot_output, mw_stats_output, logp_plot_output, logp_stats_output,
890
+ hdonors_plot_output, hdonors_stats_output, hacceptors_plot_output, hacceptors_stats_output,
891
+ status_step1_process]
892
+ )
893
 
894
+ # Update EDA plots on filter change (if data is already processed)
895
+ bioactivity_class_selector.change(
896
+ fn=lambda current_state, selected_classes: run_eda_analysis(current_state.get('cleaned_data'), selected_classes),
897
+ inputs=[app_state, bioactivity_class_selector],
898
+ outputs=[freq_plot_output, scatter_plot_output, pic50_plot_output, pic50_stats_output,
899
+ mw_plot_output, mw_stats_output, logp_plot_output, logp_stats_output,
900
+ hdonors_plot_output, hdonors_stats_output, hacceptors_plot_output, hacceptors_stats_output,
901
+ status_step1_process],
902
+ show_progress="minimal"
903
+ )
904
+
905
+
906
+ # Step 2 Callbacks
907
+ calculate_fp_btn.click(
908
+ fn=calculate_fingerprints,
909
+ inputs=[app_state, fingerprint_dropdown],
910
+ outputs=[status_step2, output_df_s2, download_s2, mols_grid_s2, app_state]
911
+ )
912
 
913
+ @download_s2.click(inputs=app_state, outputs=download_s2, show_progress="hidden")
914
+ def download_handler(current_state):
915
+ df_to_download = current_state.get('fingerprint_data')
916
+ if df_to_download is None:
917
+ raise gr.Error("No data to download. Please calculate fingerprints first.")
918
+ file_path = "fingerprint_data.csv"
919
+ df_to_download.to_csv(file_path, index=False)
920
+ return gr.File(file_path, visible=True)
921
+
922
+
923
+ # Step 3 Callbacks
924
+ train_models_btn.click(
925
+ fn=handle_model_training,
926
+ inputs=[app_state],
927
+ outputs=[status_step3_train, model_results_df, model_selector_s3, app_state]
928
+ )
929
+
930
+ # Update plots when model or feature count changes
931
+ for listener in [model_selector_s3.change, feature_count_s3.change]:
932
+ listener(
933
+ fn=update_analysis_plots,
934
+ inputs=[model_selector_s3, feature_count_s3, app_state],
935
+ outputs=[obs_pred_plot, residuals_plot, feature_importance_plot, model_metrics_summary, status_step3_train]
936
+ )
937
+
938
+ # Update model selector in Step 4 when models are trained
939
+ model_selector_s3.change(
940
+ fn=lambda choice: gr.update(choices=model_selector_s3.choices, value=choice),
941
+ inputs=model_selector_s3,
942
+ outputs=model_selector_s4
943
+ )
944
+
945
+
946
+ # Step 4 Callbacks
947
+ predict_btn.click(
948
+ fn=predict_on_upload,
949
+ inputs=[upload_new_mols, model_selector_s4, app_state],
950
+ outputs=[status_step4, predictions_output_df, mols_grid_s4]
951
+ ).then(
952
+ fn=lambda: gr.update(visible=True),
953
+ outputs=download_s4
954
+ )
955
+
956
+ @download_s4.click(inputs=predictions_output_df, outputs=download_s4, show_progress="hidden")
957
+ def download_predictions(df_predictions):
958
+ if df_predictions is None or df_predictions.empty:
959
+ raise gr.Error("No predictions to download.")
960
+ file_path = "predictions.csv"
961
+ df_predictions.to_csv(file_path, index=False)
962
+ return gr.File(file_path, visible=True)
963
+
964
 
965
+ demo.launch()