alidenewade commited on
Commit
0732634
·
verified ·
1 Parent(s): 28d310c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +412 -893
app.py CHANGED
@@ -1,69 +1,70 @@
1
  # --- IMPORTS ---
2
  # Core and Data Handling
3
- import gradio as gr
4
- import pandas as pd
5
- import numpy as np
6
- import os
7
- import glob
8
- import time
9
- import warnings
10
 
11
  # Chemistry and Cheminformatics
12
- from rdkit import Chem
13
- from rdkit.Chem import Descriptors, Lipinski
14
- from chembl_webresource_client.new_client import new_client
15
- from padelpy import padeldescriptor
16
- # import mols2grid # This line will be removed
17
 
18
  # Plotting and Visualization
19
- import matplotlib.pyplot as plt
20
- import seaborn as sns
21
- from scipy import stats
22
- from scipy.stats import mannwhitneyu
23
 
24
  # Machine Learning Models and Metrics
25
- from sklearn.model_selection import train_test_split
26
- from sklearn.feature_selection import VarianceThreshold
27
- from sklearn.linear_model import (
28
- LinearRegression, Ridge, Lasso, ElasticNet, BayesianRidge,
29
- HuberRegressor, PassiveAggressiveRegressor, OrthogonalMatchingPursuit,
30
- LassoLars
31
  )
32
- from sklearn.tree import DecisionTreeRegressor
33
- from sklearn.ensemble import (
34
- RandomForestRegressor, GradientBoostingRegressor, ExtraTreesRegressor,
35
- AdaBoostRegressor
36
  )
37
- from sklearn.neighbors import KNeighborsRegressor
38
- from sklearn.dummy import DummyRegressor
39
- from sklearn.metrics import (
40
- mean_absolute_error, mean_squared_error, r2_score
41
  )
42
 
43
  # A placeholder class to store all results from a modeling run
44
- class ModelRunResult:
45
- def __init__(self, dataframe, plotter, models, selected_features):
46
- self.dataframe = dataframe
47
- self.plotter = plotter
48
- self.models = models
49
- self.selected_features = selected_features
50
 
51
  # Optional Advanced Models
52
- try:
53
- import xgboost as xgb
54
- import lightgbm as lgb
55
- import catboost as cb
56
- _has_extra_libs = True
57
- except ImportError:
58
- _has_extra_libs = False
59
- warnings.warn("Optional libraries (xgboost, lightgbm, catboost) not found. Some models will be unavailable.")
60
 
61
  # --- GLOBAL CONFIGURATION & SETUP ---
62
- warnings.filterwarnings("ignore")
63
- sns.set_theme(style='whitegrid')
64
 
65
  # --- FINGERPRINT CONFIGURATION ---
66
  DESCRIPTOR_DIR = "padel_descriptors"
 
67
  # Check if the descriptor directory exists and contains files
68
  if not os.path.isdir(DESCRIPTOR_DIR):
69
  warnings.warn(
@@ -73,893 +74,411 @@ if not os.path.isdir(DESCRIPTOR_DIR):
73
  xml_files = []
74
  else:
75
  xml_files = sorted(glob.glob(os.path.join(DESCRIPTOR_DIR, '*.xml')))
76
- if not xml_files:
77
- warnings.warn(
78
- f"No descriptor .xml files found in the '{DESCRIPTOR_DIR}' directory. "
79
- "Fingerprint calculation will not be possible."
80
- )
 
81
 
82
  # The key is the filename without extension; the value is the full path to the file
83
  fp_config = {os.path.splitext(os.path.basename(file))[0]: file for file in xml_files}
84
  FP_list = sorted(list(fp_config.keys()))
85
 
86
 
87
- # --- NEW MOLECULE DISPLAY FUNCTIONS ---
88
- def create_molecule_grid_html(df, smiles_col='canonical_smiles', additional_cols=None, max_mols=50):
89
- """
90
- Create a custom HTML grid for displaying molecules using RDKit and base64 encoding
91
- This works better in Hugging Face Spaces than mols2grid
92
- """
93
- from rdkit import Chem
94
- from rdkit.Chem import Draw
95
- import base64
96
- from io import BytesIO
97
-
98
- if df.empty:
99
- return "<h3>No molecules to display.</h3>"
100
-
101
- # Limit number of molecules for performance
102
- df_display = df.head(max_mols).copy()
103
-
104
- # Additional columns to display
105
- if additional_cols is None:
106
- additional_cols = []
107
-
108
- html_parts = ["""
109
- <div style="display: grid; grid-template-columns: repeat(auto-fill, minmax(300px, 1fr)); gap: 15px; padding: 10px;">
110
- """]
111
-
112
- for idx, row in df_display.iterrows():
113
- try:
114
- # Generate molecule image
115
- mol = Chem.MolFromSmiles(row[smiles_col])
116
- if mol is None:
117
- continue
118
-
119
- # Draw molecule
120
- img = Draw.MolToImage(mol, size=(250, 200))
121
-
122
- # Convert to base64
123
- buffer = BytesIO()
124
- img.save(buffer, format='PNG')
125
- img_str = base64.b64encode(buffer.getvalue()).decode()
126
-
127
- # Create card HTML
128
- card_html = f"""
129
- <div style="border: 1px solid #ddd; border-radius: 8px; padding: 10px; background: white;">
130
- <img src="data:image/png;base64,{img_str}" style="width: 100%; height: auto;" alt="Molecule"/>
131
- <div style="margin-top: 8px; font-size: 12px;">
132
- <strong>SMILES:</strong> {row[smiles_col][:50]}{'...' if len(str(row[smiles_col])) > 50 else ''}<br/>
133
- """
134
-
135
- # Add additional columns
136
- for col in additional_cols:
137
- if col in row and pd.notna(row[col]):
138
- value = row[col]
139
- if isinstance(value, float):
140
- value = f"{value:.2f}"
141
- card_html += f"<strong>{col}:</strong> {value}<br/>"
142
-
143
- card_html += """
144
- </div>
145
- </div>
146
- """
147
- html_parts.append(card_html)
148
-
149
- except Exception as e:
150
- print(f"Error processing molecule {idx}: {e}")
151
- continue
152
-
153
- html_parts.append("</div>")
154
-
155
- if len(html_parts) == 2: # Only header and footer, no molecules processed
156
- return "<h3>No valid molecules to display.</h3>"
157
-
158
- return "".join(html_parts)
159
-
160
- def create_simple_molecule_table(df, smiles_col='canonical_smiles', additional_cols=None, max_mols=20):
161
- """
162
- Create a simple HTML table with molecule images - fallback option
163
- """
164
- from rdkit import Chem
165
- from rdkit.Chem import Draw
166
- import base64
167
- from io import BytesIO
168
-
169
- if df.empty:
170
- return "<h3>No molecules to display.</h3>"
171
-
172
- df_display = df.head(max_mols).copy()
173
-
174
- if additional_cols is None:
175
- additional_cols = []
176
-
177
- # Start HTML table
178
- html = """
179
- <table style="border-collapse: collapse; width: 100%;">
180
- <thead>
181
- <tr style="background-color: #f2f2f2;">
182
- <th style="border: 1px solid #ddd; padding: 8px;">Structure</th>
183
- <th style="border: 1px solid #ddd; padding: 8px;">SMILES</th>
184
- """
185
-
186
- for col in additional_cols:
187
- html += f'<th style="border: 1px solid #ddd; padding: 8px;">{col}</th>'
188
-
189
- html += """
190
- </tr>
191
- </thead>
192
- <tbody>
193
- """
194
-
195
- for idx, row in df_display.iterrows():
196
- try:
197
- mol = Chem.MolFromSmiles(row[smiles_col])
198
- if mol is None:
199
- continue
200
-
201
- # Generate image
202
- img = Draw.MolToImage(mol, size=(200, 150))
203
- buffer = BytesIO()
204
- img.save(buffer, format='PNG')
205
- img_str = base64.b64encode(buffer.getvalue()).decode()
206
-
207
- html += f"""
208
- <tr>
209
- <td style="border: 1px solid #ddd; padding: 8px; text-align: center;">
210
- <img src="data:image/png;base64,{img_str}" style="max-width: 200px; height: auto;" alt="Molecule"/>
211
- </td>
212
- <td style="border: 1px solid #ddd; padding: 8px; font-family: monospace; font-size: 11px;">
213
- {row[smiles_col][:100]}{'...' if len(str(row[smiles_col])) > 100 else ''}
214
- </td>
215
- """
216
-
217
- for col in additional_cols:
218
- value = row[col] if col in row and pd.notna(row[col]) else "N/A"
219
- if isinstance(value, float):
220
- value = f"{value:.2f}"
221
- html += f'<td style="border: 1px solid #ddd; padding: 8px;">{value}</td>'
222
-
223
- html += "</tr>"
224
-
225
- except Exception as e:
226
- print(f"Error processing molecule {idx}: {e}")
227
- continue
228
-
229
- html += "</tbody></table>"
230
- return html
231
-
232
  # ==============================================================================
233
  # === STEP 1: CORE DATA COLLECTION & EDA FUNCTIONS ===
234
  # ==============================================================================
235
- def get_target_chembl_id(query):
236
- try:
237
- target = new_client.target
238
- res = target.search(query)
239
- if not res:
240
- return pd.DataFrame(), gr.Dropdown(choices=[], value=None), "No targets found for your query."
241
- df = pd.DataFrame(res)
242
- return df[["target_chembl_id", "pref_name", "organism"]], gr.Dropdown(choices=df["target_chembl_id"].tolist()), f"Found {len(df)} targets."
243
- except Exception as e:
244
- raise gr.Error(f"ChEMBL search failed: {e}")
245
-
246
- def get_bioactivity_data(target_id):
247
- try:
248
- activity = new_client.activity
249
- res = activity.filter(target_chembl_id=target_id).filter(standard_type="IC50")
250
- if not res:
251
- return pd.DataFrame(), "No IC50 bioactivity data found for this target."
252
- df = pd.DataFrame(res)
253
- return df, f"Fetched {len(df)} data points."
254
- except Exception as e:
255
- raise gr.Error(f"Failed to fetch bioactivity data: {e}")
256
-
257
- def pIC50_calc(input_df):
258
- df_copy = input_df.copy()
259
- df_copy['standard_value'] = pd.to_numeric(df_copy['standard_value'], errors='coerce')
260
- df_copy.dropna(subset=['standard_value'], inplace=True)
261
- df_copy['standard_value_norm'] = df_copy['standard_value'].apply(lambda x: min(x, 100000000))
262
- pIC50_values = []
263
- for i in df_copy['standard_value_norm']:
264
- if pd.notna(i) and i > 0:
265
- molar = i * (10**-9)
266
- pIC50_values.append(-np.log10(molar))
267
- else:
268
- pIC50_values.append(np.nan)
269
- df_copy['pIC50'] = pIC50_values
270
- df_copy['bioactivity_class'] = df_copy['standard_value_norm'].apply(
271
- lambda x: "inactive" if pd.notna(x) and x >= 10000 else ("active" if pd.notna(x) and x <= 1000 else "intermediate")
 
272
  )
273
- return df_copy.drop(columns=['standard_value', 'standard_value_norm'])
274
-
275
- def lipinski_descriptors(smiles_series):
276
- moldata, valid_smiles = [], []
277
- for elem in smiles_series:
278
- if elem and isinstance(elem, str):
279
- mol = Chem.MolFromSmiles(elem)
280
- if mol:
281
- moldata.append(mol)
282
- valid_smiles.append(elem)
283
- descriptor_rows = []
284
- for mol in moldata:
285
- row = [Descriptors.MolWt(mol), Descriptors.MolLogP(mol), Lipinski.NumHDonors(mol), Lipinski.NumHAcceptors(mol)]
286
- descriptor_rows.append(row)
287
- columnNames = ["MW", "LogP", "NumHDonors", "NumHAcceptors"]
288
- if not descriptor_rows: return pd.DataFrame(columns=columnNames), []
289
- return pd.DataFrame(data=np.array(descriptor_rows), columns=columnNames), valid_smiles
290
-
291
- def clean_and_process_data(df):
292
- if df is None or df.empty: raise gr.Error("No data to process. Please fetch data first.")
293
- if "canonical_smiles" not in df.columns or df["canonical_smiles"].isnull().all():
294
- try:
295
- df["canonical_smiles"] = [c.get("molecule_structures", {}).get("canonical_smiles") for c in new_client.molecule.get(list(df["molecule_chembl_id"]))]
296
- except Exception as e:
297
- raise gr.Error(f"Could not fetch SMILES from ChEMBL: {e}")
298
- df = df[df.standard_value.notna()]
299
- df = df[df.canonical_smiles.notna()]
300
- df.drop_duplicates(['canonical_smiles'], inplace=True)
301
- df["standard_value"] = pd.to_numeric(df["standard_value"], errors='coerce')
302
- df.dropna(subset=['standard_value'], inplace=True)
303
- df_processed = pIC50_calc(df)
304
- df_processed = df_processed[df_processed.pIC50.notna()]
305
- if df_processed.empty: return pd.DataFrame(), "No compounds remaining after pIC50 calculation."
306
- df_lipinski, valid_smiles = lipinski_descriptors(df_processed['canonical_smiles'])
307
- if not valid_smiles: return pd.DataFrame(), "No valid SMILES could be processed for Lipinski descriptors."
308
- df_processed = df_processed[df_processed['canonical_smiles'].isin(valid_smiles)].reset_index(drop=True)
309
- df_lipinski = df_lipinski.reset_index(drop=True)
310
- df_final = pd.concat([df_processed, df_lipinski], axis=1)
311
- return df_final, f"Processing complete. {len(df_final)} compounds remain after cleaning."
312
-
313
- def run_eda_analysis(df, selected_classes):
314
- if df is None or df.empty: raise gr.Error("No data available for analysis.")
315
- df_filtered = df[df.bioactivity_class.isin(selected_classes)].copy()
316
- if df_filtered.empty: return (None, None, None, pd.DataFrame(), None, pd.DataFrame(), None, pd.DataFrame(), None, pd.DataFrame(), None, pd.DataFrame(), "No data for selected classes.")
317
- plots = [create_frequency_plot(df_filtered), create_scatter_plot(df_filtered)]
318
- stats_dfs = []
319
- for desc in ['pIC50', 'MW', 'LogP', 'NumHDonors', 'NumHAcceptors']:
320
- plots.append(create_boxplot(df_filtered, desc))
321
- stats_dfs.append(mannwhitney_test(df_filtered, desc))
322
- plt.close('all')
323
- return (plots[0], plots[1], plots[2], stats_dfs[0], plots[3], stats_dfs[1], plots[4], stats_dfs[2], plots[5], stats_dfs[3], plots[6], stats_dfs[4], f"EDA complete for {len(df_filtered)} compounds.")
324
-
325
- def create_frequency_plot(df):
326
- plt.figure(figsize=(5.5, 5.5)); sns.barplot(x=df['bioactivity_class'].value_counts().index, y=df['bioactivity_class'].value_counts().values, palette={'active': '#1f77b4', 'inactive': '#ff7f0e', 'intermediate': '#2ca02c'}); plt.xlabel('Bioactivity Class', fontsize=12); plt.ylabel('Frequency', fontsize=12); plt.title('Frequency of Bioactivity Classes', fontsize=14); return plt.gcf()
327
-
328
- def create_scatter_plot(df):
329
- plt.figure(figsize=(5.5, 5.5)); sns.scatterplot(data=df, x='MW', y='LogP', hue='bioactivity_class', size='pIC50', palette={'active': '#1f77b4', 'inactive': '#ff7f0e', 'intermediate': '#2ca02c'}, sizes=(20, 200), alpha=0.7); plt.xlabel('Molecular Weight (MW)', fontsize=12); plt.ylabel('LogP', fontsize=12); plt.title('Chemical Space: MW vs. LogP', fontsize=14); plt.legend(title='Bioactivity Class'); return plt.gcf()
330
-
331
- def create_boxplot(df, descriptor):
332
- plt.figure(figsize=(5.5, 5.5)); sns.boxplot(x='bioactivity_class', y=descriptor, data=df, palette={'active': '#1f77b4', 'inactive': '#ff7f0e', 'intermediate': '#2ca02c'}); plt.xlabel('Bioactivity Class', fontsize=12); plt.ylabel(descriptor, fontsize=12); plt.title(f'{descriptor} by Bioactivity Class', fontsize=14); return plt.gcf()
333
-
334
- def mannwhitney_test(df, descriptor):
335
- results = []
336
- for c1, c2 in [('active', 'inactive'), ('active', 'intermediate'), ('inactive', 'intermediate')]:
337
- if c1 in df['bioactivity_class'].unique() and c2 in df['bioactivity_class'].unique():
338
- d1, d2 = df[df.bioactivity_class == c1][descriptor].dropna(), df[df.bioactivity_class == c2][descriptor].dropna()
339
- if not d1.empty and not d2.empty:
340
- stat, p = mannwhitneyu(d1, d2)
341
- results.append({'Comparison': f'{c1.title()} vs {c2.title()}', 'Statistics': stat, 'p-value': p, 'Interpretation': 'Different distribution (p < 0.05)' if p <= 0.05 else 'Same distribution (p > 0.05)'})
342
- return pd.DataFrame(results)
343
 
344
  # ==============================================================================
345
  # === STEP 2: FEATURE ENGINEERING FUNCTIONS ===
346
  # ==============================================================================
347
- def calculate_fingerprints(current_state, fingerprint_type, progress=gr.Progress()):
348
- input_df = current_state.get('cleaned_data')
349
- if input_df is None or input_df.empty:
350
- raise gr.Error("No cleaned data found. Please complete Step 1.")
351
- if not fingerprint_type:
352
- raise gr.Error("Please select a fingerprint type.")
353
-
354
- progress(0, desc="Starting...")
355
- yield f"🧪 Starting fingerprint calculation...", None, gr.update(visible=False), None, current_state
356
-
357
- try:
358
- smi_file, output_csv = 'molecule.smi', 'fingerprints.csv'
359
-
360
- input_df[['canonical_smiles', 'canonical_smiles']].to_csv(smi_file, sep='\t', index=False, header=False)
361
-
362
- if os.path.exists(output_csv):
363
- os.remove(output_csv)
364
- descriptortypes = fp_config.get(fingerprint_type)
365
- if not descriptortypes:
366
- raise gr.Error(f"Descriptor XML for '{fingerprint_type}' not found.")
367
-
368
- progress(0.3, desc="⚗️ Running PaDEL...")
369
- yield f"⚗️ Running PaDEL...", None, gr.update(visible=False), None, current_state
370
-
371
- padeldescriptor(
372
- mol_dir=smi_file,
373
- d_file=output_csv,
374
- descriptortypes=descriptortypes,
375
- detectaromaticity=True,
376
- standardizenitro=True,
377
- standardizetautomers=True,
378
- threads=-1,
379
- removesalt=True,
380
- log=False,
381
- fingerprints=True
382
- )
383
-
384
- if not os.path.exists(output_csv) or os.path.getsize(output_csv) == 0:
385
- raise gr.Error("PaDEL failed to produce an output file. Check molecule validity.")
386
 
387
- progress(0.7, desc="📊 Processing results...")
388
- yield "📊 Processing results...", None, gr.update(visible=False), None, current_state
389
-
390
- df_X = pd.read_csv(output_csv).rename(columns={'Name': 'canonical_smiles'})
391
- final_df = pd.merge(input_df[['canonical_smiles', 'pIC50']], df_X, on='canonical_smiles', how='inner')
 
 
392
 
393
- current_state['fingerprint_data'] = final_df
394
- current_state['fingerprint_type'] = fingerprint_type
395
 
396
- progress(0.9, desc="🖼️ Generating molecule grid...")
 
 
397
 
398
- # Use custom molecule display instead of mols2grid
399
- try:
400
- # Try the grid layout first
401
- mols_html = create_molecule_grid_html(
402
- final_df,
403
- smiles_col='canonical_smiles',
404
- additional_cols=['pIC50'],
405
- max_mols=50
406
- )
407
- except Exception as e:
408
- print(f"Grid layout failed: {e}, trying table layout...")
409
- # Fallback to table layout
410
- mols_html = create_simple_molecule_table(
411
- final_df,
412
- smiles_col='canonical_smiles',
413
- additional_cols=['pIC50'],
414
- max_mols=20
415
- )
416
 
417
- success_msg = f"✅ Success! Generated {len(df_X.columns) -1} descriptors for {len(final_df)} molecules."
418
- progress(1, desc="Completed!")
419
- yield success_msg, final_df, gr.update(visible=True), gr.update(value=mols_html, visible=True), current_state
420
 
421
- except Exception as e:
422
- raise gr.Error(f"Calculation failed: {e}")
423
- finally:
424
- if os.path.exists('molecule.smi'):
425
- os.remove('molecule.smi')
426
- if os.path.exists('fingerprints.csv'):
427
- os.remove('fingerprints.csv')
 
 
428
 
429
  # ==============================================================================
430
- # === STEP 3: MODELING FUNCTIONS ===
431
  # ==============================================================================
432
-
433
- # Model definitions
434
- regression_models = {
435
- "Linear Regression": LinearRegression,
436
- "Ridge Regression": Ridge,
437
- "Lasso Regression": Lasso,
438
- "Elastic Net": ElasticNet,
439
- "Bayesian Ridge": BayesianRidge,
440
- "Huber Regressor": HuberRegressor,
441
- "Passive Aggressive Regressor": PassiveAggressiveRegressor,
442
- "Orthogonal Matching Pursuit": OrthogonalMatchingPursuit,
443
- "Lasso Lars": LassoLars,
444
- "Decision Tree Regressor": DecisionTreeRegressor,
445
- "Random Forest Regressor": RandomForestRegressor,
446
- "Gradient Boosting Regressor": GradientBoostingRegressor,
447
- "Extra Trees Regressor": ExtraTreesRegressor,
448
- "AdaBoost Regressor": AdaBoostRegressor,
449
- "K-Neighbors Regressor": KNeighborsRegressor,
450
- "Dummy Regressor (Mean)": DummyRegressor,
451
- }
452
-
453
- if _has_extra_libs:
454
- regression_models.update({
455
- "XGBoost Regressor": xgb.XGBRegressor,
456
- "LightGBM Regressor": lgb.LGBMRegressor,
457
- "CatBoost Regressor": cb.CatBoostRegressor,
458
- })
459
-
460
- def handle_model_training(current_state, progress=gr.Progress()):
461
- df_fp = current_state.get('fingerprint_data')
462
- if df_fp is None or df_fp.empty:
463
- raise gr.Error("No fingerprint data found. Please complete Step 2.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
464
 
465
- yield "⚙️ Starting model training...", None, None, current_state
466
- progress(0, desc="Starting model training...")
467
-
468
- try:
469
- X = df_fp.drop(columns=['canonical_smiles', 'pIC50'])
470
- y = df_fp['pIC50']
471
-
472
- if X.empty:
473
- raise gr.Error("No features found for training. Check fingerprint calculation.")
474
-
475
- # Remove features with zero variance
476
- sel = VarianceThreshold(threshold=0.0)
477
- X_filtered = sel.fit_transform(X)
478
- selected_features = X.columns[sel.get_support()]
479
-
480
- if len(selected_features) == 0:
481
- raise gr.Error("No features with non-zero variance found. Cannot train models.")
482
-
483
- X_filtered = pd.DataFrame(X_filtered, columns=selected_features)
484
-
485
- X_train, X_test, y_train, y_test = train_test_split(X_filtered, y, test_size=0.2, random_state=42)
486
-
487
- results = []
488
- trained_models = {}
489
-
490
- for i, (name, model_class) in enumerate(regression_models.items()):
491
- current_progress = (i + 1) / len(regression_models)
492
- progress(current_progress, desc=f"Training {name}...")
493
- yield f"Training {name}...", None, None, current_state
494
-
495
- model = model_class()
496
- model.fit(X_train, y_train)
497
- y_pred = model.predict(X_test)
498
-
499
- mae = mean_absolute_error(y_test, y_pred)
500
- mse = mean_squared_error(y_test, y_pred)
501
- r2 = r2_score(y_test, y_pred)
502
-
503
- results.append({
504
- "Model": name,
505
- "MAE": mae,
506
- "MSE": mse,
507
- "R2": r2
508
- })
509
- trained_models[name] = model
510
-
511
- results_df = pd.DataFrame(results)
512
-
513
- # Store results in the state
514
- current_state['model_results'] = ModelRunResult(
515
- dataframe=results_df,
516
- plotter=None, # No plot generated here directly, but could be added
517
- models=trained_models,
518
- selected_features=selected_features
519
- )
520
- current_state['X_train'] = X_train
521
- current_state['y_train'] = y_train
522
- current_state['X_test'] = X_test
523
- current_state['y_test'] = y_test
524
-
525
- progress(1, desc="Model training complete!")
526
- yield "✅ Model training complete!", results_df, gr.update(choices=list(trained_models.keys()), value=list(trained_models.keys())[0]), current_state
527
-
528
- except Exception as e:
529
- raise gr.Error(f"Model training failed: {e}")
530
-
531
- def update_analysis_plots(model_name, feature_count, current_state):
532
- model_run_results = current_state.get('model_results')
533
- X_train = current_state.get('X_train')
534
- y_train = current_state.get('y_train')
535
- X_test = current_state.get('X_test')
536
- y_test = current_state.get('y_test')
537
-
538
- if not model_run_results or not model_name or X_test is None or y_test is None:
539
- return None, None, None, None, "Please train models first."
540
-
541
- model = model_run_results.models.get(model_name)
542
- if model is None:
543
- return None, None, None, None, f"Model '{model_name}' not found."
544
-
545
- y_pred_test = model.predict(X_test)
546
- y_pred_train = model.predict(X_train)
547
-
548
- # Calculate R2, MAE, MSE for test set
549
- r2_test = r2_score(y_test, y_pred_test)
550
- mae_test = mean_absolute_error(y_test, y_pred_test)
551
- mse_test = mean_squared_error(y_test, y_pred_test)
552
-
553
- # Plot Y-obs vs Y-pred
554
- fig_obs_pred, ax_obs_pred = plt.subplots(figsize=(6, 6))
555
- sns.scatterplot(x=y_test, y=y_pred_test, ax=ax_obs_pred, alpha=0.7, color='blue', label='Test Data')
556
- sns.scatterplot(x=y_train, y=y_pred_train, ax=ax_obs_pred, alpha=0.7, color='green', label='Train Data')
557
- ax_obs_pred.plot([min(y_test.min(), y_train.min()), max(y_test.max(), y_train.max())],
558
- [min(y_test.min(), y_train.min()), max(y_test.max(), y_train.max())],
559
- color='red', linestyle='--', label='Ideal Prediction')
560
- ax_obs_pred.set_xlabel("Observed pIC50", fontsize=12)
561
- ax_obs_pred.set_ylabel("Predicted pIC50", fontsize=12)
562
- ax_obs_pred.set_title(f"{model_name}: Observed vs. Predicted pIC50", fontsize=14)
563
- ax_obs_pred.legend()
564
- plt.close(fig_obs_pred)
565
-
566
- # Plot Residuals
567
- residuals = y_test - y_pred_test
568
- fig_residuals, ax_residuals = plt.subplots(figsize=(6, 6))
569
- sns.scatterplot(x=y_pred_test, y=residuals, ax=ax_residuals, alpha=0.7, color='purple')
570
- ax_residuals.axhline(y=0, color='red', linestyle='--')
571
- ax_residuals.set_xlabel("Predicted pIC50", fontsize=12)
572
- ax_residuals.set_ylabel("Residuals (Observed - Predicted)", fontsize=12)
573
- ax_residuals.set_title(f"{model_name}: Residuals Plot", fontsize=14)
574
- plt.close(fig_residuals)
575
-
576
- # Feature Importance Plot (if applicable)
577
- fig_feature_importance = None
578
- if hasattr(model, 'feature_importances_') and feature_count > 0:
579
- feature_importances = pd.Series(model.feature_importances_, index=model_run_results.selected_features)
580
- top_features = feature_importances.nlargest(feature_count)
581
-
582
- fig_feature_importance, ax_fi = plt.subplots(figsize=(8, 6))
583
- sns.barplot(x=top_features.values, y=top_features.index, ax=ax_fi, palette='viridis')
584
- ax_fi.set_xlabel("Importance", fontsize=12)
585
- ax_fi.set_ylabel("Feature", fontsize=12)
586
- ax_fi.set_title(f"{model_name}: Top {feature_count} Feature Importances", fontsize=14)
587
- plt.tight_layout()
588
- plt.close(fig_feature_importance)
589
- elif hasattr(model, 'coef_') and feature_count > 0 and model_name not in ["Dummy Regressor (Mean)", "K-Neighbors Regressor"]:
590
- # For linear models, coefficients can be used as importance
591
- feature_importances = pd.Series(model.coef_, index=model_run_results.selected_features)
592
- top_features = feature_importances.abs().nlargest(feature_count) # Use absolute value for ranking
593
- top_features_values = feature_importances[top_features.index] # Get actual signed values
594
-
595
- fig_feature_importance, ax_fi = plt.subplots(figsize=(8, 6))
596
- sns.barplot(x=top_features_values.values, y=top_features_values.index, ax=ax_fi, palette='coolwarm')
597
- ax_fi.set_xlabel("Coefficient Value", fontsize=12)
598
- ax_fi.set_ylabel("Feature", fontsize=12)
599
- ax_fi.set_title(f"{model_name}: Top {feature_count} Feature Coefficients", fontsize=14)
600
- plt.tight_layout()
601
- plt.close(fig_feature_importance)
602
-
603
-
604
- return fig_obs_pred, fig_residuals, fig_feature_importance, \
605
- f"R2 (Test): {r2_test:.4f}, MAE (Test): {mae_test:.4f}, MSE (Test): {mse_test:.4f}", \
606
- "Plots updated."
607
-
608
- # Updated prediction function
609
- def predict_on_upload(uploaded_file, model_name, current_state, progress=gr.Progress()):
610
- if not uploaded_file:
611
- raise gr.Error("Please upload a file.")
612
- if not model_name:
613
- raise gr.Error("Please select a trained model.")
614
 
615
- model_run_results = current_state.get('model_results')
616
- fingerprint_type = current_state.get('fingerprint_type')
617
- if not model_run_results or not fingerprint_type:
618
- raise gr.Error("Please run Steps 2 and 3 first.")
 
 
 
 
 
619
 
620
- model = model_run_results.models.get(model_name)
621
- selected_features = model_run_results.selected_features
622
- if model is None:
623
- raise gr.Error(f"Model '{model_name}' not found.")
624
 
625
- smi_file, output_csv = 'predict.smi', 'predict_fp.csv'
626
- try:
627
- progress(0, desc="Reading & processing new molecules...")
628
- yield "Reading uploaded file...", None, None
629
-
630
- df_new = pd.read_csv(uploaded_file.name)
631
- if 'canonical_smiles' not in df_new.columns:
632
- raise gr.Error("CSV must contain a 'canonical_smiles' column.")
633
- df_new = df_new.reset_index().rename(columns={'index': 'mol_id'})
634
-
635
- padel_input = pd.DataFrame({
636
- 'smiles': df_new['canonical_smiles'],
637
- 'name': df_new['mol_id']
638
- })
639
- padel_input.to_csv(smi_file, sep='\t', index=False, header=False)
640
- if os.path.exists(output_csv):
641
- os.remove(output_csv)
642
 
643
- progress(0.3, desc="Calculating fingerprints...")
644
- yield "Calculating fingerprints for new molecules...", None, None
 
645
 
646
- padeldescriptor(
647
- mol_dir=smi_file,
648
- d_file=output_csv,
649
- descriptortypes=fp_config.get(fingerprint_type),
650
- detectaromaticity=True,
651
- standardizenitro=True,
652
- threads=-1,
653
- removesalt=True,
654
- log=False,
655
- fingerprints=True
656
- )
657
 
658
- if not os.path.exists(output_csv) or os.path.getsize(output_csv) == 0:
659
- raise gr.Error("PaDEL calculation failed for the uploaded molecules.")
660
 
661
- progress(0.7, desc="Aligning features and predicting...")
662
- yield "Aligning features and predicting...", None, None
663
 
664
- df_fp = pd.read_csv(output_csv).rename(columns={'Name': 'mol_id'})
665
- X_new = df_fp.set_index('mol_id')
666
- X_new_aligned = X_new.reindex(columns=selected_features, fill_value=0)[selected_features]
667
- predictions = model.predict(X_new_aligned)
668
- results_subset = pd.DataFrame({
669
- 'mol_id': X_new_aligned.index,
670
- 'predicted_pIC50': predictions
671
- })
672
- df_results = pd.merge(df_new, results_subset, on='mol_id', how='left')
673
-
674
- progress(0.9, desc="Generating visualization...")
675
- yield "Generating visualization...", None, None
676
 
677
- df_grid_view = df_results.dropna(subset=['predicted_pIC50']).copy()
678
- mols_html = "<h3>No molecules with successful predictions to display.</h3>"
679
-
680
- if not df_grid_view.empty:
681
- try:
682
- # Use custom molecule display
683
- mols_html = create_molecule_grid_html(
684
- df_grid_view,
685
- smiles_col='canonical_smiles',
686
- additional_cols=['predicted_pIC50'],
687
- max_mols=50
688
- )
689
- except Exception as e:
690
- print(f"Grid layout failed: {e}, trying table layout...")
691
- mols_html = create_simple_molecule_table(
692
- df_grid_view,
693
- smiles_col='canonical_smiles',
694
- additional_cols=['predicted_pIC50'],
695
- max_mols=20
696
- )
697
 
698
- progress(1, desc="Complete!")
699
- yield " Prediction complete.", df_results[['canonical_smiles', 'predicted_pIC50']], mols_html
 
 
 
 
 
 
 
 
700
 
701
- finally:
702
- if os.path.exists(smi_file):
703
- os.remove(smi_file)
704
- if os.path.exists(output_csv):
705
- os.remove(output_csv)
706
 
707
  # ==============================================================================
708
- # === GRADIO INTERFACE LAYOUT ===
709
  # ==============================================================================
710
-
711
- with gr.Blocks(css=".container { max-width: 1200px; margin: auto; }") as demo:
712
- app_state = gr.State({}) # Store application state (e.g., fetched data, trained models)
713
-
714
- gr.Markdown("# 💊 Bioactivity Prediction App")
715
- gr.Markdown("---")
716
-
717
- with gr.Tabs():
718
- with gr.TabItem("Step 1: Data Collection & EDA"):
719
- with gr.Row():
720
- with gr.Column():
721
- gr.Markdown("## 1.1 ChEMBL Target Search")
722
- target_query = gr.Textbox(
723
- label="Enter target protein name (e.g., 'EGFR', 'acetylcholinesterase')",
724
- placeholder="EGFR"
725
- )
726
- search_target_btn = gr.Button("Search ChEMBL")
727
- target_output_df = gr.DataFrame(
728
- label="Search Results",
729
- headers=["target_chembl_id", "pref_name", "organism"],
730
- max_rows=5
731
- )
732
- status_step1_search = gr.Textbox(label="Status", interactive=False)
733
-
734
- gr.Markdown("## 1.2 Fetch Bioactivity Data")
735
- chembl_id_selector = gr.Dropdown(
736
- label="Select Target ChEMBL ID",
737
- choices=[], interactive=True
738
- )
739
- fetch_data_btn = gr.Button("Fetch Bioactivity Data (IC50)")
740
- bioactivity_output_df = gr.DataFrame(label="Raw Bioactivity Data", max_rows=5)
741
- status_step1_fetch = gr.Textbox(label="Status", interactive=False)
742
-
743
- gr.Markdown("## 1.3 Clean & Process Data")
744
- process_data_btn = gr.Button("Process & Calculate pIC50/Lipinski Descriptors")
745
- cleaned_data_output_df = gr.DataFrame(label="Cleaned & Processed Data", max_rows=5)
746
- status_step1_process_clean = gr.Textbox(label="Status", interactive=False)
747
- download_s1 = gr.DownloadButton("Download Cleaned Data (CSV)", visible=False, interactive=False)
748
-
749
- with gr.Column():
750
- gr.Markdown("## 1.4 Exploratory Data Analysis (EDA)")
751
- bioactivity_class_selector = gr.CheckboxGroup(
752
- label="Select Bioactivity Classes for EDA",
753
- choices=["active", "inactive", "intermediate"],
754
- value=["active", "inactive", "intermediate"],
755
- interactive=True
756
- )
757
- run_eda_btn = gr.Button("Run EDA")
758
- status_step1_process = gr.Textbox(label="Status", interactive=False)
759
-
760
- gr.Markdown("### Bioactivity Class Frequency")
761
- freq_plot_output = gr.Plot(label="Bioactivity Class Frequency")
762
-
763
- gr.Markdown("### Chemical Space (MW vs LogP)")
764
- scatter_plot_output = gr.Plot(label="MW vs LogP Scatter Plot")
765
-
766
- gr.Markdown("### pIC50 Distribution")
767
- pic50_plot_output = gr.Plot(label="pIC50 Box Plot")
768
- pic50_stats_output = gr.DataFrame(label="pIC50 Mann-Whitney U Test", max_rows=5)
769
-
770
- gr.Markdown("### Molecular Weight (MW) Distribution")
771
- mw_plot_output = gr.Plot(label="MW Box Plot")
772
- mw_stats_output = gr.DataFrame(label="MW Mann-Whitney U Test", max_rows=5)
773
-
774
- gr.Markdown("### LogP Distribution")
775
- logp_plot_output = gr.Plot(label="LogP Box Plot")
776
- logp_stats_output = gr.DataFrame(label="LogP Mann-Whitney U Test", max_rows=5)
777
-
778
- gr.Markdown("### Hydrogen Donors Distribution")
779
- hdonors_plot_output = gr.Plot(label="Hydrogen Donors Box Plot")
780
- hdonors_stats_output = gr.DataFrame(label="Hydrogen Donors Mann-Whitney U Test", max_rows=5)
781
-
782
- gr.Markdown("### Hydrogen Acceptors Distribution")
783
- hacceptors_plot_output = gr.Plot(label="Hydrogen Acceptors Box Plot")
784
- hacceptors_stats_output = gr.DataFrame(label="Hydrogen Acceptors Mann-Whitney U Test", max_rows=5)
785
-
786
- with gr.TabItem("Step 2: Feature Engineering (Fingerprints)"):
787
- gr.Markdown("## 2.1 Calculate Molecular Fingerprints")
788
- gr.Markdown(f"Available Fingerprint Types: {', '.join(FP_list)}")
789
- fingerprint_dropdown = gr.Dropdown(
790
- label="Select Fingerprint Type",
791
- choices=FP_list,
792
- interactive=True,
793
- value=FP_list[0] if FP_list else None # Set default if available
794
- )
795
- calculate_fp_btn = gr.Button("Calculate Fingerprints")
796
- status_step2 = gr.Textbox(label="Status", interactive=False)
797
- output_df_s2 = gr.DataFrame(label="Fingerprint Data (First 5 rows, with pIC50)", max_rows=5)
798
- download_s2 = gr.DownloadButton("Download Fingerprint Data (CSV)", visible=False, interactive=False)
799
-
800
- # Using HTML component for custom molecule display
801
- mols_grid_s2 = gr.HTML(label="Molecules with pIC50", visible=True)
802
-
803
-
804
- with gr.TabItem("Step 3: Model Training & Evaluation"):
805
- gr.Markdown("## 3.1 Train Regression Models")
806
- train_models_btn = gr.Button("Train Models")
807
- status_step3_train = gr.Textbox(label="Status", interactive=False)
808
- model_results_df = gr.DataFrame(label="Model Performance Metrics", max_rows=10)
809
-
810
- gr.Markdown("## 3.2 Model Analysis")
811
- model_selector_s3 = gr.Dropdown(
812
- label="Select Model for Detailed Analysis",
813
- choices=[],
814
- interactive=True
815
- )
816
- feature_count_s3 = gr.Slider(
817
- minimum=0, maximum=50, step=1, value=10,
818
- label="Number of Top Features to Display", interactive=True
819
- )
820
- model_metrics_summary = gr.Textbox(label="Selected Model Metrics (Test Set)", interactive=False)
821
- obs_pred_plot = gr.Plot(label="Observed vs. Predicted pIC50")
822
- residuals_plot = gr.Plot(label="Residuals Plot")
823
- feature_importance_plot = gr.Plot(label="Feature Importance/Coefficients")
824
-
825
- with gr.TabItem("Step 4: Prediction on New Data"):
826
- gr.Markdown("## 4.1 Upload New Molecules & Predict")
827
- upload_new_mols = gr.File(label="Upload CSV with 'canonical_smiles' column")
828
- model_selector_s4 = gr.Dropdown(
829
- label="Select Trained Model for Prediction",
830
- choices=[],
831
- interactive=True
832
- )
833
- predict_btn = gr.Button("Predict pIC50 for New Molecules")
834
- status_step4 = gr.Textbox(label="Status", interactive=False)
835
- predictions_output_df = gr.DataFrame(label="Predictions for New Molecules", max_rows=10)
836
- download_s4 = gr.DownloadButton("Download Predictions (CSV)", visible=False, interactive=False)
837
- # Using HTML component for custom molecule display
838
- mols_grid_s4 = gr.HTML(label="New Molecules with Predicted pIC50", visible=True)
839
 
840
  # --- EVENT HANDLERS ---
841
-
842
- # Step 1 Callbacks
843
- search_target_btn.click(
844
- fn=lambda query: (get_target_chembl_id(query), gr.update(visible=True)), # Return two values
845
- inputs=target_query,
846
- outputs=[target_output_df, chembl_id_selector, status_step1_search]
847
- )
848
-
849
- chembl_id_selector.change(
850
- fn=lambda x: gr.update(value=x),
851
- inputs=chembl_id_selector,
852
- outputs=chembl_id_selector # This is just to ensure the value is updated
853
- )
854
-
855
- fetch_data_btn.click(
856
- fn=lambda target_id: (get_bioactivity_data(target_id), gr.update(value=target_id)),
857
- inputs=chembl_id_selector,
858
- outputs=[bioactivity_output_df, status_step1_fetch]
859
- ).then(
860
- fn=lambda df: (df, df.copy()), # Pass the fetched df to the state
861
- inputs=bioactivity_output_df,
862
- outputs=[gr.State(value={}, key='raw_data'), app_state]
863
- )
864
-
865
- process_data_btn.click(
866
- fn=lambda current_state: clean_and_process_data(current_state.get('raw_data')),
867
- inputs=app_state,
868
- outputs=[cleaned_data_output_df, status_step1_process_clean]
869
- ).then(
870
- fn=lambda df: (gr.update(visible=True), df), # Show download button and update state
871
- inputs=cleaned_data_output_df,
872
- outputs=[download_s1, gr.State(value={}, key='cleaned_data'), app_state]
873
- )
874
-
875
- @download_s1.click(inputs=app_state, outputs=download_s1, show_progress="hidden")
876
- def download_handler_s1(current_state):
877
- df_to_download = current_state.get('cleaned_data')
878
- if df_to_download is None:
879
- raise gr.Error("No data to download. Please process data first.")
880
- # Create a dummy file path for Gradio to handle the download
881
- file_path = "cleaned_data.csv"
882
- df_to_download.to_csv(file_path, index=False)
883
- return gr.File(file_path, visible=True)
884
-
885
- run_eda_btn.click(
886
- fn=lambda df, classes: run_eda_analysis(df, classes),
887
- inputs=[cleaned_data_output_df, bioactivity_class_selector],
888
- outputs=[freq_plot_output, scatter_plot_output, pic50_plot_output, pic50_stats_output,
889
- mw_plot_output, mw_stats_output, logp_plot_output, logp_stats_output,
890
- hdonors_plot_output, hdonors_stats_output, hacceptors_plot_output, hacceptors_stats_output,
891
- status_step1_process]
892
- )
893
 
894
- # Update EDA plots on filter change (if data is already processed)
895
- bioactivity_class_selector.change(
896
- fn=lambda current_state, selected_classes: run_eda_analysis(current_state.get('cleaned_data'), selected_classes),
897
- inputs=[app_state, bioactivity_class_selector],
898
- outputs=[freq_plot_output, scatter_plot_output, pic50_plot_output, pic50_stats_output,
899
- mw_plot_output, mw_stats_output, logp_plot_output, logp_stats_output,
900
- hdonors_plot_output, hdonors_stats_output, hacceptors_plot_output, hacceptors_stats_output,
901
- status_step1_process],
902
- show_progress="minimal"
903
- )
904
-
905
-
906
- # Step 2 Callbacks
907
- calculate_fp_btn.click(
908
- fn=calculate_fingerprints,
909
- inputs=[app_state, fingerprint_dropdown],
910
- outputs=[status_step2, output_df_s2, download_s2, mols_grid_s2, app_state]
911
- )
912
 
913
- @download_s2.click(inputs=app_state, outputs=download_s2, show_progress="hidden")
914
- def download_handler(current_state):
915
- df_to_download = current_state.get('fingerprint_data')
916
- if df_to_download is None:
917
- raise gr.Error("No data to download. Please calculate fingerprints first.")
918
- file_path = "fingerprint_data.csv"
919
- df_to_download.to_csv(file_path, index=False)
920
- return gr.File(file_path, visible=True)
921
-
922
-
923
- # Step 3 Callbacks
924
- train_models_btn.click(
925
- fn=handle_model_training,
926
- inputs=[app_state],
927
- outputs=[status_step3_train, model_results_df, model_selector_s3, app_state]
928
- )
929
-
930
- # Update plots when model or feature count changes
931
- for listener in [model_selector_s3.change, feature_count_s3.change]:
932
- listener(
933
- fn=update_analysis_plots,
934
- inputs=[model_selector_s3, feature_count_s3, app_state],
935
- outputs=[obs_pred_plot, residuals_plot, feature_importance_plot, model_metrics_summary, status_step3_train]
936
- )
937
-
938
- # Update model selector in Step 4 when models are trained
939
- model_selector_s3.change(
940
- fn=lambda choice: gr.update(choices=model_selector_s3.choices, value=choice),
941
- inputs=model_selector_s3,
942
- outputs=model_selector_s4
943
- )
944
-
945
-
946
- # Step 4 Callbacks
947
- predict_btn.click(
948
- fn=predict_on_upload,
949
- inputs=[upload_new_mols, model_selector_s4, app_state],
950
- outputs=[status_step4, predictions_output_df, mols_grid_s4]
951
- ).then(
952
- fn=lambda: gr.update(visible=True),
953
- outputs=download_s4
954
- )
955
-
956
- @download_s4.click(inputs=predictions_output_df, outputs=download_s4, show_progress="hidden")
957
- def download_predictions(df_predictions):
958
- if df_predictions is None or df_predictions.empty:
959
- raise gr.Error("No predictions to download.")
960
- file_path = "predictions.csv"
961
- df_predictions.to_csv(file_path, index=False)
962
- return gr.File(file_path, visible=True)
963
-
964
 
965
- demo.launch()
 
 
1
  # --- IMPORTS ---
2
  # Core and Data Handling
3
+ import gradio as gr #
4
+ import pandas as pd #
5
+ import numpy as np #
6
+ import os #
7
+ import glob #
8
+ import time #
9
+ import warnings #
10
 
11
  # Chemistry and Cheminformatics
12
+ from rdkit import Chem #
13
+ from rdkit.Chem import Descriptors, Lipinski #
14
+ from chembl_webresource_client.new_client import new_client #
15
+ from padelpy import padeldescriptor #
16
+ import mols2grid #
17
 
18
  # Plotting and Visualization
19
+ import matplotlib.pyplot as plt #
20
+ import seaborn as sns #
21
+ from scipy import stats #
22
+ from scipy.stats import mannwhitneyu #
23
 
24
  # Machine Learning Models and Metrics
25
+ from sklearn.model_selection import train_test_split #
26
+ from sklearn.feature_selection import VarianceThreshold #
27
+ from sklearn.linear_model import ( #
28
+ LinearRegression, Ridge, Lasso, ElasticNet, BayesianRidge, #
29
+ HuberRegressor, PassiveAggressiveRegressor, OrthogonalMatchingPursuit, #
30
+ LassoLars #
31
  )
32
+ from sklearn.tree import DecisionTreeRegressor #
33
+ from sklearn.ensemble import ( #
34
+ RandomForestRegressor, GradientBoostingRegressor, ExtraTreesRegressor, #
35
+ AdaBoostRegressor #
36
  )
37
+ from sklearn.neighbors import KNeighborsRegressor #
38
+ from sklearn.dummy import DummyRegressor #
39
+ from sklearn.metrics import ( #
40
+ mean_absolute_error, mean_squared_error, r2_score #
41
  )
42
 
43
  # A placeholder class to store all results from a modeling run
44
+ class ModelRunResult: #
45
+ def __init__(self, dataframe, plotter, models, selected_features): #
46
+ self.dataframe = dataframe #
47
+ self.plotter = plotter #
48
+ self.models = models #
49
+ self.selected_features = selected_features #
50
 
51
  # Optional Advanced Models
52
+ try: #
53
+ import xgboost as xgb #
54
+ import lightgbm as lgb #
55
+ import catboost as cb #
56
+ _has_extra_libs = True #
57
+ except ImportError: #
58
+ _has_extra_libs = False #
59
+ warnings.warn("Optional libraries (xgboost, lightgbm, catboost) not found. Some models will be unavailable.") #
60
 
61
  # --- GLOBAL CONFIGURATION & SETUP ---
62
+ warnings.filterwarnings("ignore") #
63
+ sns.set_theme(style='whitegrid') #
64
 
65
  # --- FINGERPRINT CONFIGURATION ---
66
  DESCRIPTOR_DIR = "padel_descriptors"
67
+
68
  # Check if the descriptor directory exists and contains files
69
  if not os.path.isdir(DESCRIPTOR_DIR):
70
  warnings.warn(
 
74
  xml_files = []
75
  else:
76
  xml_files = sorted(glob.glob(os.path.join(DESCRIPTOR_DIR, '*.xml')))
77
+
78
+ if not xml_files:
79
+ warnings.warn(
80
+ f"No descriptor .xml files found in the '{DESCRIPTOR_DIR}' directory. "
81
+ "Fingerprint calculation will not be possible."
82
+ )
83
 
84
  # The key is the filename without extension; the value is the full path to the file
85
  fp_config = {os.path.splitext(os.path.basename(file))[0]: file for file in xml_files}
86
  FP_list = sorted(list(fp_config.keys()))
87
 
88
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
  # ==============================================================================
90
  # === STEP 1: CORE DATA COLLECTION & EDA FUNCTIONS ===
91
  # ==============================================================================
92
+
93
+ def get_target_chembl_id(query): #
94
+ try: #
95
+ target = new_client.target #
96
+ res = target.search(query) #
97
+ if not res: #
98
+ return pd.DataFrame(), gr.Dropdown(choices=[], value=None), "No targets found for your query." #
99
+ df = pd.DataFrame(res) #
100
+ return df[["target_chembl_id", "pref_name", "organism"]], gr.Dropdown(choices=df["target_chembl_id"].tolist()), f"Found {len(df)} targets." #
101
+ except Exception as e: #
102
+ raise gr.Error(f"ChEMBL search failed: {e}") #
103
+
104
+ def get_bioactivity_data(target_id): #
105
+ try: #
106
+ activity = new_client.activity #
107
+ res = activity.filter(target_chembl_id=target_id).filter(standard_type="IC50") #
108
+ if not res: #
109
+ return pd.DataFrame(), "No IC50 bioactivity data found for this target." #
110
+ df = pd.DataFrame(res) #
111
+ return df, f"Fetched {len(df)} data points." #
112
+ except Exception as e: #
113
+ raise gr.Error(f"Failed to fetch bioactivity data: {e}") #
114
+
115
+ def pIC50_calc(input_df): #
116
+ df_copy = input_df.copy() #
117
+ df_copy['standard_value'] = pd.to_numeric(df_copy['standard_value'], errors='coerce') #
118
+ df_copy.dropna(subset=['standard_value'], inplace=True) #
119
+ df_copy['standard_value_norm'] = df_copy['standard_value'].apply(lambda x: min(x, 100000000)) #
120
+ pIC50_values = [] #
121
+ for i in df_copy['standard_value_norm']: #
122
+ if pd.notna(i) and i > 0: #
123
+ molar = i * (10**-9) #
124
+ pIC50_values.append(-np.log10(molar)) #
125
+ else: #
126
+ pIC50_values.append(np.nan) #
127
+ df_copy['pIC50'] = pIC50_values #
128
+ df_copy['bioactivity_class'] = df_copy['standard_value_norm'].apply( #
129
+ lambda x: "inactive" if pd.notna(x) and x >= 10000 else ("active" if pd.notna(x) and x <= 1000 else "intermediate") #
130
  )
131
+ return df_copy.drop(columns=['standard_value', 'standard_value_norm']) #
132
+
133
+ def lipinski_descriptors(smiles_series): #
134
+ moldata, valid_smiles = [], [] #
135
+ for elem in smiles_series: #
136
+ if elem and isinstance(elem, str): #
137
+ mol = Chem.MolFromSmiles(elem) #
138
+ if mol: #
139
+ moldata.append(mol) #
140
+ valid_smiles.append(elem) #
141
+ descriptor_rows = [] #
142
+ for mol in moldata: #
143
+ row = [Descriptors.MolWt(mol), Descriptors.MolLogP(mol), Lipinski.NumHDonors(mol), Lipinski.NumHAcceptors(mol)] #
144
+ descriptor_rows.append(row) #
145
+ columnNames = ["MW", "LogP", "NumHDonors", "NumHAcceptors"] #
146
+ if not descriptor_rows: return pd.DataFrame(columns=columnNames), [] #
147
+ return pd.DataFrame(data=np.array(descriptor_rows), columns=columnNames), valid_smiles #
148
+
149
+ def clean_and_process_data(df): #
150
+ if df is None or df.empty: raise gr.Error("No data to process. Please fetch data first.") #
151
+ if "canonical_smiles" not in df.columns or df["canonical_smiles"].isnull().all(): #
152
+ try: #
153
+ df["canonical_smiles"] = [c.get("molecule_structures", {}).get("canonical_smiles") for c in new_client.molecule.get(list(df["molecule_chembl_id"]))] #
154
+ except Exception as e: #
155
+ raise gr.Error(f"Could not fetch SMILES from ChEMBL: {e}") #
156
+ df = df[df.standard_value.notna()] #
157
+ df = df[df.canonical_smiles.notna()] #
158
+ df.drop_duplicates(['canonical_smiles'], inplace=True) #
159
+ df["standard_value"] = pd.to_numeric(df["standard_value"], errors='coerce') #
160
+ df.dropna(subset=['standard_value'], inplace=True) #
161
+ df_processed = pIC50_calc(df) #
162
+ df_processed = df_processed[df_processed.pIC50.notna()] #
163
+ if df_processed.empty: return pd.DataFrame(), "No compounds remaining after pIC50 calculation." #
164
+ df_lipinski, valid_smiles = lipinski_descriptors(df_processed['canonical_smiles']) #
165
+ if not valid_smiles: return pd.DataFrame(), "No valid SMILES could be processed for Lipinski descriptors." #
166
+ df_processed = df_processed[df_processed['canonical_smiles'].isin(valid_smiles)].reset_index(drop=True) #
167
+ df_lipinski = df_lipinski.reset_index(drop=True) #
168
+ df_final = pd.concat([df_processed, df_lipinski], axis=1) #
169
+ return df_final, f"Processing complete. {len(df_final)} compounds remain after cleaning." #
170
+
171
+ def run_eda_analysis(df, selected_classes): #
172
+ if df is None or df.empty: raise gr.Error("No data available for analysis.") #
173
+ df_filtered = df[df.bioactivity_class.isin(selected_classes)].copy() #
174
+ if df_filtered.empty: return (None, None, None, pd.DataFrame(), None, pd.DataFrame(), None, pd.DataFrame(), None, pd.DataFrame(), None, pd.DataFrame(), "No data for selected classes.") #
175
+ plots = [create_frequency_plot(df_filtered), create_scatter_plot(df_filtered)] #
176
+ stats_dfs = [] #
177
+ for desc in ['pIC50', 'MW', 'LogP', 'NumHDonors', 'NumHAcceptors']: #
178
+ plots.append(create_boxplot(df_filtered, desc)) #
179
+ stats_dfs.append(mannwhitney_test(df_filtered, desc)) #
180
+ plt.close('all') #
181
+ return (plots[0], plots[1], plots[2], stats_dfs[0], plots[3], stats_dfs[1], plots[4], stats_dfs[2], plots[5], stats_dfs[3], plots[6], stats_dfs[4], f"EDA complete for {len(df_filtered)} compounds.") #
182
+
183
+ def create_frequency_plot(df): #
184
+ plt.figure(figsize=(5.5, 5.5)); sns.barplot(x=df['bioactivity_class'].value_counts().index, y=df['bioactivity_class'].value_counts().values, palette={'active': '#1f77b4', 'inactive': '#ff7f0e', 'intermediate': '#2ca02c'}); plt.xlabel('Bioactivity Class', fontsize=12); plt.ylabel('Frequency', fontsize=12); plt.title('Frequency of Bioactivity Classes', fontsize=14); return plt.gcf() #
185
+ def create_scatter_plot(df): #
186
+ plt.figure(figsize=(5.5, 5.5)); sns.scatterplot(data=df, x='MW', y='LogP', hue='bioactivity_class', size='pIC50', palette={'active': '#1f77b4', 'inactive': '#ff7f0e', 'intermediate': '#2ca02c'}, sizes=(20, 200), alpha=0.7); plt.xlabel('Molecular Weight (MW)', fontsize=12); plt.ylabel('LogP', fontsize=12); plt.title('Chemical Space: MW vs. LogP', fontsize=14); plt.legend(title='Bioactivity Class'); return plt.gcf() #
187
+ def create_boxplot(df, descriptor): #
188
+ plt.figure(figsize=(5.5, 5.5)); sns.boxplot(x='bioactivity_class', y=descriptor, data=df, palette={'active': '#1f77b4', 'inactive': '#ff7f0e', 'intermediate': '#2ca02c'}); plt.xlabel('Bioactivity Class', fontsize=12); plt.ylabel(descriptor, fontsize=12); plt.title(f'{descriptor} by Bioactivity Class', fontsize=14); return plt.gcf() #
189
+ def mannwhitney_test(df, descriptor): #
190
+ results = [] #
191
+ for c1, c2 in [('active', 'inactive'), ('active', 'intermediate'), ('inactive', 'intermediate')]: #
192
+ if c1 in df['bioactivity_class'].unique() and c2 in df['bioactivity_class'].unique(): #
193
+ d1, d2 = df[df.bioactivity_class == c1][descriptor].dropna(), df[df.bioactivity_class == c2][descriptor].dropna() #
194
+ if not d1.empty and not d2.empty: #
195
+ stat, p = mannwhitneyu(d1, d2) #
196
+ results.append({'Comparison': f'{c1.title()} vs {c2.title()}', 'Statistics': stat, 'p-value': p, 'Interpretation': 'Different distribution (p < 0.05)' if p <= 0.05 else 'Same distribution (p > 0.05)'}) #
197
+ return pd.DataFrame(results) #
 
 
 
198
 
199
  # ==============================================================================
200
  # === STEP 2: FEATURE ENGINEERING FUNCTIONS ===
201
  # ==============================================================================
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
202
 
203
+ def calculate_fingerprints(current_state, fingerprint_type, progress=gr.Progress()): #
204
+ input_df = current_state.get('cleaned_data') #
205
+ if input_df is None or input_df.empty: raise gr.Error("No cleaned data found. Please complete Step 1.") #
206
+ if not fingerprint_type: raise gr.Error("Please select a fingerprint type.") #
207
+ progress(0, desc="Starting..."); yield f"🧪 Starting fingerprint calculation...", None, gr.update(visible=False), None, current_state #
208
+ try: #
209
+ smi_file, output_csv = 'molecule.smi', 'fingerprints.csv' #
210
 
211
+ input_df[['canonical_smiles', 'canonical_smiles']].to_csv(smi_file, sep='\t', index=False, header=False) #
 
212
 
213
+ if os.path.exists(output_csv): os.remove(output_csv) #
214
+ descriptortypes = fp_config.get(fingerprint_type) #
215
+ if not descriptortypes: raise gr.Error(f"Descriptor XML for '{fingerprint_type}' not found.") #
216
 
217
+ progress(0.3, desc="⚗️ Running PaDEL..."); yield f"⚗️ Running PaDEL...", None, gr.update(visible=False), None, current_state #
218
+ padeldescriptor(mol_dir=smi_file, d_file=output_csv, descriptortypes=descriptortypes, detectaromaticity=True, standardizenitro=True, standardizetautomers=True, threads=-1, removesalt=True, log=False, fingerprints=True) #
219
+ if not os.path.exists(output_csv) or os.path.getsize(output_csv) == 0: #
220
+ raise gr.Error("PaDEL failed to produce an output file. Check molecule validity.") #
221
+
222
+ progress(0.7, desc="📊 Processing results..."); yield "📊 Processing results...", None, gr.update(visible=False), None, current_state #
223
+ df_X = pd.read_csv(output_csv).rename(columns={'Name': 'canonical_smiles'}) #
 
 
 
 
 
 
 
 
 
 
 
224
 
225
+ final_df = pd.merge(input_df[['canonical_smiles', 'pIC50']], df_X, on='canonical_smiles', how='inner') #
 
 
226
 
227
+ current_state['fingerprint_data'] = final_df; current_state['fingerprint_type'] = fingerprint_type #
228
+ progress(0.9, desc="🖼️ Generating molecule grid...") #
229
+ mols_html = mols2grid.display(final_df, smiles_col='canonical_smiles', subset=['img', 'pIC50'], rename={"pIC50": "pIC50"}, transform={"pIC50": lambda x: f"{x:.2f}"})._repr_html_() #
230
+ success_msg = f"✅ Success! Generated {len(df_X.columns) -1} descriptors for {len(final_df)} molecules." #
231
+ progress(1, desc="Completed!"); yield success_msg, final_df, gr.update(visible=True), gr.update(value=mols_html, visible=True), current_state #
232
+ except Exception as e: raise gr.Error(f"Calculation failed: {e}") #
233
+ finally: #
234
+ if os.path.exists('molecule.smi'): os.remove('molecule.smi') #
235
+ if os.path.exists('fingerprints.csv'): os.remove('fingerprints.csv') #
236
 
237
  # ==============================================================================
238
+ # === STEP 3: MODEL TRAINING & PREDICTION FUNCTIONS ===
239
  # ==============================================================================
240
+ class ModelPlotter: #
241
+ def __init__(self, models: dict, X_test: pd.DataFrame, y_test: pd.Series): #
242
+ self._models, self._X_test, self._y_test = models, X_test, y_test #
243
+ def plot_validation(self, model_name: str): #
244
+ if model_name not in self._models: raise ValueError(f"Model '{model_name}' not found.") #
245
+ model, y_pred = self._models[model_name], self._models[model_name].predict(self._X_test) #
246
+ residuals = self._y_test - y_pred #
247
+ fig, axes = plt.subplots(2, 2, figsize=(12, 10)); fig.suptitle(f'Model Validation Plots for {model_name}', fontsize=16, y=1.02) #
248
+ sns.scatterplot(x=self._y_test, y=y_pred, ax=axes[0, 0], alpha=0.6); axes[0, 0].set_title('Actual vs. Predicted'); axes[0, 0].set_xlabel('Actual pIC50'); axes[0, 0].set_ylabel('Predicted pIC50'); lims = [min(self._y_test.min(), y_pred.min()), max(self._y_test.max(), y_pred.max())]; axes[0, 0].plot(lims, lims, 'r--', alpha=0.75, zorder=0) #
249
+ sns.scatterplot(x=y_pred, y=residuals, ax=axes[0, 1], alpha=0.6); axes[0, 1].axhline(y=0, color='r', linestyle='--'); axes[0, 1].set_title('Residuals vs. Predicted'); axes[0, 1].set_xlabel('Predicted pIC50'); axes[0, 1].set_ylabel('Residuals') #
250
+ sns.histplot(residuals, kde=True, ax=axes[1, 0]); axes[1, 0].set_title('Distribution of Residuals') #
251
+ stats.probplot(residuals, dist="norm", plot=axes[1, 1]); axes[1, 1].set_title('Normal Q-Q Plot') #
252
+ plt.tight_layout(); return fig #
253
+ def plot_feature_importance(self, model_name: str, top_n: int = 7): #
254
+ if model_name not in self._models: raise ValueError(f"Model '{model_name}' not found.") #
255
+ model = self._models[model_name] #
256
+ if hasattr(model, 'feature_importances_'): importances = model.feature_importances_ #
257
+ elif hasattr(model, 'coef_'): importances = np.abs(model.coef_) #
258
+ else: return None #
259
+ top_features = pd.DataFrame({'Feature': self._X_test.columns, 'Importance': importances}).sort_values(by='Importance', ascending=False).head(top_n) #
260
+ plt.figure(figsize=(10, top_n * 0.5)); sns.barplot(x='Importance', y='Feature', data=top_features, palette='viridis', orient='h'); plt.title(f'Top {top_n} Features for {model_name}'); plt.tight_layout(); return plt.gcf() #
261
+
262
+ def run_regression_suite(df: pd.DataFrame, progress=gr.Progress()): #
263
+ progress(0, desc="Splitting data..."); yield "Splitting data (80/20 train/test split)...", None, None #
264
+ X = df.drop(columns=['pIC50', 'canonical_smiles'], errors='ignore') #
265
+ y = df['pIC50'] #
266
+ X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42) #
267
+
268
+ progress(0.1, desc="Selecting features..."); yield "Performing feature selection (removing low variance)...", None, None #
269
+ selector = VarianceThreshold(threshold=0.1) #
270
+ X_train = pd.DataFrame(selector.fit_transform(X_train), columns=X_train.columns[selector.get_support()], index=X_train.index) #
271
+ X_test = pd.DataFrame(selector.transform(X_test), columns=X_test.columns[selector.get_support()], index=X_test.index) #
272
+ selected_features = X_train.columns.tolist() #
273
+
274
+ model_defs = [
275
+ ('Linear Regression', LinearRegression()),
276
+ ('Ridge', Ridge(random_state=42)),
277
+ ('Lasso', Lasso(random_state=42)),
278
+ ('Random Forest', RandomForestRegressor(random_state=42, n_jobs=-1)),
279
+ # ('Gradient Boosting', GradientBoostingRegressor(random_state=42)) # <-- Commented out
280
+ ]
281
+ if _has_extra_libs:
282
+ model_defs.extend([
283
+ # ('XGBoost', xgb.XGBRegressor(random_state=42, n_jobs=-1, verbosity=0)), # <-- Commented out
284
+ ('LightGBM', lgb.LGBMRegressor(random_state=42, n_jobs=-1, verbosity=-1)),
285
+ # ('CatBoost', cb.CatBoostRegressor(random_state=42, verbose=0)) # <-- Commented out
286
+ ])
287
 
288
+ results_list, trained_models = [], {} #
289
+ for i, (name, model) in enumerate(model_defs): #
290
+ progress(0.2 + (i / len(model_defs)) * 0.8, desc=f"Training {name}...") #
291
+ yield f"Training {i+1}/{len(model_defs)}: {name}...", None, None #
292
+ start_time = time.time(); model.fit(X_train, y_train); y_pred = model.predict(X_test) #
293
+ results_list.append({'Model': name, '': r2_score(y_test, y_pred), 'MAE': mean_absolute_error(y_test, y_pred), 'RMSE': np.sqrt(mean_squared_error(y_test, y_pred)), 'Time (s)': f"{time.time() - start_time:.2f}"}) #
294
+ trained_models[name] = model #
295
+
296
+ results_df = pd.DataFrame(results_list).sort_values(by='R²', ascending=False).reset_index(drop=True) #
297
+ plotter = ModelPlotter(trained_models, X_test, y_test) #
298
+ model_run_results = ModelRunResult(results_df, plotter, trained_models, selected_features) #
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
299
 
300
+ model_choices = results_df['Model'].tolist() #
301
+ yield "✅ Model training & evaluation complete.", model_run_results, gr.Dropdown(choices=model_choices, interactive=True) #
302
+
303
+ def predict_on_upload(uploaded_file, model_name, current_state, progress=gr.Progress()): #
304
+ if not uploaded_file: raise gr.Error("Please upload a file.") #
305
+ if not model_name: raise gr.Error("Please select a trained model.") #
306
+ model_run_results = current_state.get('model_results') #
307
+ fingerprint_type = current_state.get('fingerprint_type') #
308
+ if not model_run_results or not fingerprint_type: raise gr.Error("Please run Steps 2 and 3 first.") #
309
 
310
+ model = model_run_results.models.get(model_name) #
311
+ selected_features = model_run_results.selected_features #
312
+ if model is None: raise gr.Error(f"Model '{model_name}' not found.") #
 
313
 
314
+ smi_file, output_csv = 'predict.smi', 'predict_fp.csv' #
315
+ try: #
316
+ progress(0, desc="Reading & processing new molecules..."); yield "Reading uploaded file...", None, None #
317
+ df_new = pd.read_csv(uploaded_file.name) #
318
+ if 'canonical_smiles' not in df_new.columns: raise gr.Error("CSV must contain a 'canonical_smiles' column.") #
319
+ df_new = df_new.reset_index().rename(columns={'index': 'mol_id'}) #
 
 
 
 
 
 
 
 
 
 
 
320
 
321
+ padel_input = pd.DataFrame({'smiles': df_new['canonical_smiles'], 'name': df_new['mol_id']}) #
322
+ padel_input.to_csv(smi_file, sep='\t', index=False, header=False) #
323
+ if os.path.exists(output_csv): os.remove(output_csv) #
324
 
325
+ progress(0.3, desc="Calculating fingerprints..."); yield "Calculating fingerprints for new molecules...", None, None #
326
+ padeldescriptor(mol_dir=smi_file, d_file=output_csv, descriptortypes=fp_config.get(fingerprint_type), detectaromaticity=True, standardizenitro=True, threads=-1, removesalt=True, log=False, fingerprints=True) #
327
+ if not os.path.exists(output_csv) or os.path.getsize(output_csv) == 0: raise gr.Error("PaDEL calculation failed for the uploaded molecules.") #
 
 
 
 
 
 
 
 
328
 
329
+ progress(0.7, desc="Aligning features and predicting..."); yield "Aligning features and predicting...", None, None #
330
+ df_fp = pd.read_csv(output_csv).rename(columns={'Name': 'mol_id'}) #
331
 
332
+ X_new = df_fp.set_index('mol_id') #
333
+ X_new_aligned = X_new.reindex(columns=selected_features, fill_value=0)[selected_features] #
334
 
335
+ predictions = model.predict(X_new_aligned) #
 
 
 
 
 
 
 
 
 
 
 
336
 
337
+ results_subset = pd.DataFrame({'mol_id': X_new_aligned.index, 'predicted_pIC50': predictions}) #
338
+ df_results = pd.merge(df_new, results_subset, on='mol_id', how='left') #
339
+
340
+ progress(0.9, desc="Generating visualization..."); yield "Generating visualization...", None, None #
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
341
 
342
+ df_grid_view = df_results.dropna(subset=['predicted_pIC50']).copy() #
343
+ mols_html = "<h3>No molecules with successful predictions to display.</h3>" #
344
+ if not df_grid_view.empty: #
345
+ df_grid_view.rename(columns={"predicted_pIC50": "Predicted pIC50"}, inplace=True) #
346
+ mols_html = mols2grid.display( #
347
+ df_grid_view, #
348
+ smiles_col='canonical_smiles', #
349
+ subset=['img', 'Predicted pIC50'], #
350
+ transform={"Predicted pIC50": lambda x: f"{x:.2f}"} #
351
+ )._repr_html_() #
352
 
353
+ progress(1, desc="Complete!"); yield "✅ Prediction complete.", df_results[['canonical_smiles', 'predicted_pIC50']], mols_html #
354
+ finally: #
355
+ if os.path.exists(smi_file): os.remove(smi_file) #
356
+ if os.path.exists(output_csv): os.remove(output_csv) #
 
357
 
358
  # ==============================================================================
359
+ # === GRADIO INTERFACE ===
360
  # ==============================================================================
361
+ with gr.Blocks(theme=gr.themes.Default(primary_hue="blue", secondary_hue="sky"), title="Comprehensive Drug Discovery Workflow") as demo: #
362
+ gr.Markdown("# 🧪 Comprehensive Drug Discovery Workflow") #
363
+ gr.Markdown("A 3-step application to fetch, analyze, and model chemical bioactivity data.") #
364
+ app_state = gr.State({}) #
365
+ with gr.Tabs(): #
366
+ with gr.Tab("Step 1: Data Collection & EDA"): #
367
+ gr.Markdown("## Fetch Bioactivity Data from ChEMBL and Perform Exploratory Analysis") #
368
+ with gr.Row(): #
369
+ query_input = gr.Textbox(label="Target Query", placeholder="e.g., acetylcholinesterase, BRAF kinase", scale=3) #
370
+ fetch_btn = gr.Button("Fetch Targets", variant="primary", scale=1) #
371
+ status_step1_fetch = gr.Textbox(label="Status", interactive=False) #
372
+ target_id_table = gr.Dataframe(label="Available Targets", interactive=False, headers=["target_chembl_id", "pref_name", "organism"]) #
373
+ with gr.Row(): #
374
+ selected_target_dropdown = gr.Dropdown(label="Select Target ChEMBL ID", interactive=True, scale=3) #
375
+ process_btn = gr.Button("Process Data & Run EDA", variant="primary", scale=1, interactive=False) #
376
+ status_step1_process = gr.Textbox(label="Status", interactive=False) #
377
+ gr.Markdown("### Filtered Data & Analysis") #
378
+ bioactivity_class_selector = gr.CheckboxGroup(["active", "inactive", "intermediate"], label="Filter by Bioactivity Class", value=["active", "inactive", "intermediate"]) #
379
+ df_output_s1 = gr.Dataframe(label="Cleaned Bioactivity Data") #
380
+ with gr.Tabs(): #
381
+ with gr.Tab("Chemical Space Overview"): #
382
+ with gr.Row(): #
383
+ freq_plot_output = gr.Plot(label="Frequency of Bioactivity Classes") #
384
+ scatter_plot_output = gr.Plot(label="Scatter Plot: MW vs LogP") #
385
+ with gr.Tab("pIC50 Analysis"): #
386
+ with gr.Row(): #
387
+ pic50_plot_output = gr.Plot(label="pIC50 Box Plot") #
388
+ pic50_stats_output = gr.Dataframe(label="Mann-Whitney U Test Results for pIC50") #
389
+ with gr.Tab("Molecular Weight Analysis"): #
390
+ with gr.Row(): #
391
+ mw_plot_output = gr.Plot(label="MW Box Plot") #
392
+ mw_stats_output = gr.Dataframe(label="Mann-Whitney U Test Results for MW") #
393
+ with gr.Tab("LogP Analysis"): #
394
+ with gr.Row(): #
395
+ logp_plot_output = gr.Plot(label="LogP Box Plot") #
396
+ logp_stats_output = gr.Dataframe(label="Mann-Whitney U Test Results for LogP") #
397
+ with gr.Tab("H-Bond Donor/Acceptor Analysis"): #
398
+ with gr.Row(): #
399
+ hdonors_plot_output = gr.Plot(label="H-Donors Box Plot") #
400
+ hacceptors_plot_output = gr.Plot(label="H-Acceptors Box Plot") #
401
+ with gr.Row(): #
402
+ hdonors_stats_output = gr.Dataframe(label="Stats for H-Donors") #
403
+ hacceptors_stats_output = gr.Dataframe(label="Stats for H-Acceptors") #
404
+ with gr.Tab("Step 2: Feature Engineering"): #
405
+ gr.Markdown("## Calculate Molecular Fingerprints using PaDEL") #
406
+ with gr.Row(): #
407
+ fingerprint_dropdown = gr.Dropdown(choices=FP_list, value='PubChem' if 'PubChem' in FP_list else None, label="Select Fingerprint Method", scale=3) #
408
+ calculate_fp_btn = gr.Button("Calculate Fingerprints", variant="primary", scale=1) #
409
+ status_step2 = gr.Textbox(label="Status", interactive=False) #
410
+ output_df_s2 = gr.Dataframe(label="Final Processed Data", wrap=True) #
411
+ download_s2 = gr.DownloadButton("Download Feature Data (CSV)", variant="secondary", visible=False) #
412
+ mols_grid_s2 = gr.HTML(label="Interactive Molecule Viewer") #
413
+ with gr.Tab("Step 3: Model Training & Prediction"): #
414
+ gr.Markdown("## Train Regression Models and Predict pIC50") #
415
+ with gr.Tabs(): #
416
+ with gr.Tab("Model Training & Evaluation"): #
417
+ train_models_btn = gr.Button("Train All Models", variant="primary") #
418
+ status_step3_train = gr.Textbox(label="Status", interactive=False) #
419
+ model_results_df = gr.DataFrame(label="Ranked Model Results", interactive=False) #
420
+ with gr.Row(): #
421
+ model_selector_s3 = gr.Dropdown(label="Select Model to Analyze", interactive=False) #
422
+ feature_count_s3 = gr.Number(label="Top Features to Show", value=7, minimum=3, maximum=20, step=1) #
423
+ with gr.Tabs(): #
424
+ with gr.Tab("Validation Plots"): validation_plot_s3 = gr.Plot(label="Model Validation Plots") #
425
+ with gr.Tab("Feature Importance"): feature_plot_s3 = gr.Plot(label="Top Feature Importances") #
426
+ with gr.Tab("Predict on New Data"): #
427
+ gr.Markdown("Upload a CSV with a `canonical_smiles` column to predict pIC50.") #
428
+ with gr.Row(): #
429
+ upload_predict_file = gr.File(label="Upload CSV for Prediction", file_types=[".csv"]) #
430
+ predict_btn_s3 = gr.Button("Run Prediction", variant="primary") #
431
+ status_step3_predict = gr.Textbox(label="Status", interactive=False) #
432
+ prediction_results_df = gr.DataFrame(label="Prediction Results") #
433
+ prediction_mols_grid = gr.HTML(label="Interactive Molecular Grid of Predictions") #
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
434
 
435
  # --- EVENT HANDLERS ---
436
+ def enable_process_button(target_id): return gr.update(interactive=bool(target_id)) #
437
+ def process_and_analyze_wrapper(target_id, selected_classes, current_state, progress=gr.Progress()): #
438
+ if not target_id: raise gr.Error("Please select a target ChEMBL ID first.") #
439
+ progress(0, desc="Fetching data..."); raw_data, msg1 = get_bioactivity_data(target_id); yield {status_step1_process: gr.update(value=msg1)} #
440
+ progress(0.3, desc="Cleaning data..."); processed_data, msg2 = clean_and_process_data(raw_data); yield {df_output_s1: processed_data, status_step1_process: gr.update(value=msg2)} #
441
+ current_state['cleaned_data'] = processed_data #
442
+ progress(0.6, desc="Running EDA..."); plots_and_stats = run_eda_analysis(processed_data, selected_classes); msg3 = plots_and_stats[-1] #
443
+ progress(1, desc="Done!") #
444
+ filtered_data = processed_data[processed_data.bioactivity_class.isin(selected_classes)] if not processed_data.empty else pd.DataFrame() #
445
+ outputs = [filtered_data] + list(plots_and_stats[:-1]) + [msg3, current_state] #
446
+ output_components = [df_output_s1, freq_plot_output, scatter_plot_output, pic50_plot_output, pic50_stats_output, mw_plot_output, mw_stats_output, logp_plot_output, logp_stats_output, hdonors_plot_output, hdonors_stats_output, hacceptors_plot_output, hacceptors_stats_output, status_step1_process, app_state] #
447
+ yield dict(zip(output_components, outputs)) #
448
+ def update_analysis_on_filter_change(selected_classes, current_state): #
449
+ cleaned_data = current_state.get('cleaned_data') #
450
+ if cleaned_data is None or cleaned_data.empty: return (pd.DataFrame(),) + (None,) * 11 + ("No data available.",) #
451
+ plots_and_stats = run_eda_analysis(cleaned_data, selected_classes); msg = plots_and_stats[-1] #
452
+ filtered_data = cleaned_data[cleaned_data.bioactivity_class.isin(selected_classes)] #
453
+ return (filtered_data,) + plots_and_stats[:-1] + (msg,) #
454
+ def handle_model_training(current_state, progress=gr.Progress(track_tqdm=True)): #
455
+ fingerprint_data = current_state.get('fingerprint_data') #
456
+ if fingerprint_data is None or fingerprint_data.empty: raise gr.Error("No feature data. Please complete Step 2.") #
457
+ for status_msg, model_results, model_choices_update in run_regression_suite(fingerprint_data, progress=progress): #
458
+ if model_results: current_state['model_results'] = model_results #
459
+ yield status_msg, model_results.dataframe if model_results else None, model_choices_update, current_state #
460
+ def save_dataframe_as_csv(df): #
461
+ if df is None or df.empty: return None #
462
+ filename = "feature_engineered_data.csv"; df.to_csv(filename, index=False); return gr.File(value=filename, visible=True) #
463
+ def update_analysis_plots(model_name, feature_count, current_state): #
464
+ model_results = current_state.get('model_results') #
465
+ if not model_results or not model_name: return None, None #
466
+ plotter = model_results.plotter; validation_fig = plotter.plot_validation(model_name); feature_fig = plotter.plot_feature_importance(model_name, int(feature_count)); plt.close('all'); return validation_fig, feature_fig #
467
+
468
+ fetch_btn.click(fn=get_target_chembl_id, inputs=query_input, outputs=[target_id_table, selected_target_dropdown, status_step1_fetch], show_progress="minimal") #
469
+ selected_target_dropdown.change(fn=enable_process_button, inputs=selected_target_dropdown, outputs=process_btn, show_progress="hidden") #
470
+ process_btn.click(fn=process_and_analyze_wrapper, inputs=[selected_target_dropdown, bioactivity_class_selector, app_state], outputs=[df_output_s1, freq_plot_output, scatter_plot_output, pic50_plot_output, pic50_stats_output, mw_plot_output, mw_stats_output, logp_plot_output, logp_stats_output, hdonors_plot_output, hdonors_stats_output, hacceptors_plot_output, hacceptors_stats_output, status_step1_process, app_state]) #
471
+ bioactivity_class_selector.change(fn=update_analysis_on_filter_change, inputs=[bioactivity_class_selector, app_state], outputs=[df_output_s1, freq_plot_output, scatter_plot_output, pic50_plot_output, pic50_stats_output, mw_plot_output, mw_stats_output, logp_plot_output, logp_stats_output, hdonors_plot_output, hdonors_stats_output, hacceptors_plot_output, hacceptors_stats_output, status_step1_process], show_progress="minimal") #
472
+ calculate_fp_btn.click(fn=calculate_fingerprints, inputs=[app_state, fingerprint_dropdown], outputs=[status_step2, output_df_s2, download_s2, mols_grid_s2, app_state]) #
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
473
 
474
+ @download_s2.click(inputs=app_state, outputs=download_s2, show_progress="hidden") #
475
+ def download_handler(current_state): #
476
+ df_to_download = current_state.get('fingerprint_data') #
477
+ return save_dataframe_as_csv(df_to_download) #
 
 
 
 
 
 
 
 
 
 
 
 
 
 
478
 
479
+ train_models_btn.click(fn=handle_model_training, inputs=[app_state], outputs=[status_step3_train, model_results_df, model_selector_s3, app_state]) #
480
+ for listener in [model_selector_s3.change, feature_count_s3.change]: listener(fn=update_analysis_plots, inputs=[model_selector_s3, feature_count_s3, app_state], outputs=[validation_plot_s3, feature_plot_s3], show_progress="minimal") #
481
+ predict_btn_s3.click(fn=predict_on_upload, inputs=[upload_predict_file, model_selector_s3, app_state], outputs=[status_step3_predict, prediction_results_df, prediction_mols_grid]) #
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
482
 
483
+ if __name__ == "__main__": #
484
+ demo.launch(debug=True) #