alidenewade committed on
Commit
c3673c6
·
verified ·
1 Parent(s): b102a7f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +114 -163
app.py CHANGED
@@ -7,13 +7,17 @@ import os
7
  import glob
8
  import time
9
  import warnings
10
- import base64
11
 
12
  # Chemistry and Cheminformatics
13
  from rdkit import Chem
14
- from rdkit.Chem import Descriptors, Lipinski, Draw, rdDepictor
15
  from chembl_webresource_client.new_client import new_client
16
  from padelpy import padeldescriptor
 
 
 
 
 
17
 
18
  # Plotting and Visualization
19
  import matplotlib.pyplot as plt
@@ -63,25 +67,20 @@ warnings.filterwarnings("ignore")
63
  sns.set_theme(style='whitegrid')
64
 
65
  # --- FINGERPRINT CONFIGURATION ---
66
- DESCRIPTOR_DIR = "padel_descriptors"
67
-
68
- # Check if the descriptor directory exists and contains files
69
- if not os.path.isdir(DESCRIPTOR_DIR):
70
- warnings.warn(
71
- f"The descriptor directory '{DESCRIPTOR_DIR}' was not found. "
72
- "Fingerprint calculation will be disabled. Please create this directory and upload your .xml files."
73
- )
74
- xml_files = []
75
- else:
76
- xml_files = sorted(glob.glob(os.path.join(DESCRIPTOR_DIR, '*.xml')))
77
 
 
78
  if not xml_files:
79
  warnings.warn(
80
- f"No descriptor .xml files found in the '{DESCRIPTOR_DIR}' directory. "
81
- "Fingerprint calculation will not be possible."
82
  )
83
-
84
- # The key is the filename without extension; the value is the full path to the file
85
  fp_config = {os.path.splitext(os.path.basename(file))[0]: file for file in xml_files}
86
  FP_list = sorted(list(fp_config.keys()))
87
 
@@ -155,6 +154,7 @@ def clean_and_process_data(df):
155
  raise gr.Error(f"Could not fetch SMILES from ChEMBL: {e}")
156
  df = df[df.standard_value.notna()]
157
  df = df[df.canonical_smiles.notna()]
 
158
  df.drop_duplicates(['canonical_smiles'], inplace=True)
159
  df["standard_value"] = pd.to_numeric(df["standard_value"], errors='coerce')
160
  df.dropna(subset=['standard_value'], inplace=True)
@@ -200,58 +200,68 @@ def mannwhitney_test(df, descriptor):
200
  # === STEP 2: FEATURE ENGINEERING FUNCTIONS ===
201
  # ==============================================================================
202
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
203
  def calculate_fingerprints(current_state, fingerprint_type, progress=gr.Progress()):
204
  input_df = current_state.get('cleaned_data')
205
- if input_df is None or input_df.empty:
206
- raise gr.Error("No cleaned data found. Please complete Step 1.")
207
- if not fingerprint_type:
208
- raise gr.Error("Please select a fingerprint type.")
209
-
210
- progress(0, desc="Starting...")
211
- yield f"🧪 Starting fingerprint calculation...", None, gr.update(visible=False), None, current_state
212
-
213
  try:
214
  smi_file, output_csv = 'molecule.smi', 'fingerprints.csv'
 
 
 
215
  input_df[['canonical_smiles', 'canonical_smiles']].to_csv(smi_file, sep='\t', index=False, header=False)
216
-
217
- if os.path.exists(output_csv):
218
- os.remove(output_csv)
219
  descriptortypes = fp_config.get(fingerprint_type)
220
- if not descriptortypes:
221
- raise gr.Error(f"Descriptor XML for '{fingerprint_type}' not found.")
222
-
223
- progress(0.3, desc="⚗️ Running PaDEL...")
224
- yield f"⚗️ Running PaDEL...", None, gr.update(visible=False), None, current_state
225
  padeldescriptor(mol_dir=smi_file, d_file=output_csv, descriptortypes=descriptortypes, detectaromaticity=True, standardizenitro=True, standardizetautomers=True, threads=-1, removesalt=True, log=False, fingerprints=True)
226
-
227
  if not os.path.exists(output_csv) or os.path.getsize(output_csv) == 0:
228
  raise gr.Error("PaDEL failed to produce an output file. Check molecule validity.")
229
 
230
- progress(0.7, desc="📊 Processing results...")
231
- yield "📊 Processing results...", None, gr.update(visible=False), None, current_state
232
  df_X = pd.read_csv(output_csv).rename(columns={'Name': 'canonical_smiles'})
233
-
 
 
234
  final_df = pd.merge(input_df[['canonical_smiles', 'pIC50']], df_X, on='canonical_smiles', how='inner')
235
-
236
- current_state['fingerprint_data'] = final_df
237
- current_state['fingerprint_type'] = fingerprint_type
238
-
239
- progress(0.9, desc="🖼️ Generating molecule grid...")
240
- mols_html = create_molecule_html_grid(final_df, 'canonical_smiles', ['pIC50'])
241
 
 
 
 
242
  success_msg = f"✅ Success! Generated {len(df_X.columns) -1} descriptors for {len(final_df)} molecules."
243
- progress(1, desc="Completed!")
244
- yield success_msg, final_df, gr.update(visible=True), gr.update(value=mols_html, visible=True), current_state
245
-
246
- except Exception as e:
247
- raise gr.Error(f"Calculation failed: {e}")
248
-
249
  finally:
250
- if os.path.exists('molecule.smi'):
251
- os.remove('molecule.smi')
252
- if os.path.exists('fingerprints.csv'):
253
- os.remove('fingerprints.csv')
254
-
255
 
256
  # ==============================================================================
257
  # === STEP 3: MODEL TRAINING & PREDICTION FUNCTIONS ===
@@ -290,20 +300,9 @@ def run_regression_suite(df: pd.DataFrame, progress=gr.Progress()):
290
  X_test = pd.DataFrame(selector.transform(X_test), columns=X_test.columns[selector.get_support()], index=X_test.index)
291
  selected_features = X_train.columns.tolist()
292
 
293
- model_defs = [
294
- ('Linear Regression', LinearRegression()),
295
- ('Ridge', Ridge(random_state=42)),
296
- ('Lasso', Lasso(random_state=42)),
297
- ('Random Forest', RandomForestRegressor(random_state=42, n_jobs=-1)),
298
- # ('Gradient Boosting', GradientBoostingRegressor(random_state=42)) # <-- Commented out
299
- ]
300
- if _has_extra_libs:
301
- model_defs.extend([
302
- # ('XGBoost', xgb.XGBRegressor(random_state=42, n_jobs=-1, verbosity=0)), # <-- Commented out
303
- ('LightGBM', lgb.LGBMRegressor(random_state=42, n_jobs=-1, verbosity=-1)),
304
- # ('CatBoost', cb.CatBoostRegressor(random_state=42, verbose=0)) # <-- Commented out
305
- ])
306
-
307
  results_list, trained_models = [], {}
308
  for i, (name, model) in enumerate(model_defs):
309
  progress(0.2 + (i / len(model_defs)) * 0.8, desc=f"Training {name}...")
@@ -315,139 +314,89 @@ def run_regression_suite(df: pd.DataFrame, progress=gr.Progress()):
315
  results_df = pd.DataFrame(results_list).sort_values(by='R²', ascending=False).reset_index(drop=True)
316
  plotter = ModelPlotter(trained_models, X_test, y_test)
317
  model_run_results = ModelRunResult(results_df, plotter, trained_models, selected_features)
318
-
319
  model_choices = results_df['Model'].tolist()
320
  yield "✅ Model training & evaluation complete.", model_run_results, gr.Dropdown(choices=model_choices, interactive=True)
321
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
322
  def predict_on_upload(uploaded_file, model_name, current_state, progress=gr.Progress()):
323
  if not uploaded_file: raise gr.Error("Please upload a file.")
324
  if not model_name: raise gr.Error("Please select a trained model.")
325
  model_run_results = current_state.get('model_results')
326
  fingerprint_type = current_state.get('fingerprint_type')
327
  if not model_run_results or not fingerprint_type: raise gr.Error("Please run Steps 2 and 3 first.")
328
-
329
  model = model_run_results.models.get(model_name)
330
  selected_features = model_run_results.selected_features
331
  if model is None: raise gr.Error(f"Model '{model_name}' not found.")
332
-
333
  smi_file, output_csv = 'predict.smi', 'predict_fp.csv'
334
  try:
335
  progress(0, desc="Reading & processing new molecules..."); yield "Reading uploaded file...", None, None
336
  df_new = pd.read_csv(uploaded_file.name)
337
  if 'canonical_smiles' not in df_new.columns: raise gr.Error("CSV must contain a 'canonical_smiles' column.")
338
  df_new = df_new.reset_index().rename(columns={'index': 'mol_id'})
339
-
340
  padel_input = pd.DataFrame({'smiles': df_new['canonical_smiles'], 'name': df_new['mol_id']})
341
  padel_input.to_csv(smi_file, sep='\t', index=False, header=False)
342
  if os.path.exists(output_csv): os.remove(output_csv)
343
-
344
  progress(0.3, desc="Calculating fingerprints..."); yield "Calculating fingerprints for new molecules...", None, None
345
  padeldescriptor(mol_dir=smi_file, d_file=output_csv, descriptortypes=fp_config.get(fingerprint_type), detectaromaticity=True, standardizenitro=True, threads=-1, removesalt=True, log=False, fingerprints=True)
346
  if not os.path.exists(output_csv) or os.path.getsize(output_csv) == 0: raise gr.Error("PaDEL calculation failed for the uploaded molecules.")
347
-
348
  progress(0.7, desc="Aligning features and predicting..."); yield "Aligning features and predicting...", None, None
349
  df_fp = pd.read_csv(output_csv).rename(columns={'Name': 'mol_id'})
350
-
351
  X_new = df_fp.set_index('mol_id')
352
  X_new_aligned = X_new.reindex(columns=selected_features, fill_value=0)[selected_features]
353
-
354
  predictions = model.predict(X_new_aligned)
355
-
356
  results_subset = pd.DataFrame({'mol_id': X_new_aligned.index, 'predicted_pIC50': predictions})
357
  df_results = pd.merge(df_new, results_subset, on='mol_id', how='left')
358
 
359
  progress(0.9, desc="Generating visualization..."); yield "Generating visualization...", None, None
360
 
 
 
 
361
  df_grid_view = df_results.dropna(subset=['predicted_pIC50']).copy()
362
- mols_html = create_molecule_html_grid(
363
- df_grid_view,
364
- smiles_col='canonical_smiles',
365
- data_cols=['predicted_pIC50'],
366
- mol_id_col='mol_id'
367
- )
368
-
369
  progress(1, desc="Complete!"); yield "✅ Prediction complete.", df_results[['canonical_smiles', 'predicted_pIC50']], mols_html
370
  finally:
371
  if os.path.exists(smi_file): os.remove(smi_file)
372
  if os.path.exists(output_csv): os.remove(output_csv)
373
 
374
- # ==============================================================================
375
- # === HELPER FUNCTIONS ===
376
- # ==============================================================================
377
- def create_molecule_html_grid(df: pd.DataFrame, smiles_col: str, data_cols: list, mol_id_col: str = None):
378
- """
379
- Generates a self-contained HTML grid for a DataFrame of molecules.
380
-
381
- Args:
382
- df: DataFrame containing molecule data.
383
- smiles_col: The name of the column with the SMILES strings.
384
- data_cols: A list of column names to display alongside the molecule.
385
- mol_id_col: Optional column to use as a title for each molecule entry.
386
-
387
- Returns:
388
- An HTML string for display in Gradio's gr.HTML component.
389
- """
390
- if df.empty:
391
- return "<h3>No molecules to display.</h3>"
392
-
393
- # Step 1: Filter valid molecules with 2D conformers
394
- valid_mols = []
395
- for smiles in df[smiles_col]:
396
- mol = Chem.MolFromSmiles(smiles)
397
- if mol:
398
- try:
399
- rdDepictor.Compute2DCoords(mol)
400
- _ = mol.GetConformer() # Ensure conformer exists
401
- valid_mols.append((smiles, mol))
402
- except Exception as e:
403
- print(f"[Warning] Skipping molecule due to depiction error: {smiles} – {e}")
404
- continue
405
-
406
- if not valid_mols:
407
- return "<h3>No valid molecules could be rendered.</h3>"
408
-
409
- # Step 2: Generate SVGs
410
- images = []
411
- smiles_list = []
412
- for smiles, mol in valid_mols:
413
- try:
414
- svg = Draw.MolToSVG(mol, width=200, height=200)
415
- images.append(svg)
416
- smiles_list.append(smiles)
417
- except Exception as e:
418
- print(f"[Warning] Failed to draw molecule: {smiles} – {e}")
419
- continue
420
-
421
- # Filter the DataFrame to include only valid molecules
422
- df = df[df[smiles_col].isin(smiles_list)].copy()
423
- df['image'] = images
424
-
425
- # Step 3: Build HTML
426
- html = '<div style="display: flex; flex-wrap: wrap; gap: 20px;">'
427
- for _, row in df.iterrows():
428
- if not row['image']:
429
- continue
430
-
431
- html += '<div style="border: 1px solid #ddd; border-radius: 5px; padding: 10px; text-align: center; width: 220px;">'
432
- html += row['image'] # SVG
433
-
434
- # Optional: molecule ID
435
- if mol_id_col and mol_id_col in row:
436
- html += f'<strong>{row[mol_id_col]}</strong><br>'
437
-
438
- # Show other data values
439
- for col in data_cols:
440
- if col in row:
441
- value = row[col]
442
- if isinstance(value, float):
443
- value = f"{value:.2f}"
444
- html += f'<span><strong>{col}:</strong> {value}</span><br>'
445
- html += '</div>'
446
-
447
- html += '</div>'
448
- return html
449
-
450
-
451
  # ==============================================================================
452
  # === GRADIO INTERFACE ===
453
  # ==============================================================================
@@ -457,6 +406,7 @@ with gr.Blocks(theme=gr.themes.Default(primary_hue="blue", secondary_hue="sky"),
457
  app_state = gr.State({})
458
  with gr.Tabs():
459
  with gr.Tab("Step 1: Data Collection & EDA"):
 
460
  gr.Markdown("## Fetch Bioactivity Data from ChEMBL and Perform Exploratory Analysis")
461
  with gr.Row():
462
  query_input = gr.Textbox(label="Target Query", placeholder="e.g., acetylcholinesterase, BRAF kinase", scale=3)
@@ -495,6 +445,7 @@ with gr.Blocks(theme=gr.themes.Default(primary_hue="blue", secondary_hue="sky"),
495
  hdonors_stats_output = gr.Dataframe(label="Stats for H-Donors")
496
  hacceptors_stats_output = gr.Dataframe(label="Stats for H-Acceptors")
497
  with gr.Tab("Step 2: Feature Engineering"):
 
498
  gr.Markdown("## Calculate Molecular Fingerprints using PaDEL")
499
  with gr.Row():
500
  fingerprint_dropdown = gr.Dropdown(choices=FP_list, value='PubChem' if 'PubChem' in FP_list else None, label="Select Fingerprint Method", scale=3)
@@ -504,6 +455,7 @@ with gr.Blocks(theme=gr.themes.Default(primary_hue="blue", secondary_hue="sky"),
504
  download_s2 = gr.DownloadButton("Download Feature Data (CSV)", variant="secondary", visible=False)
505
  mols_grid_s2 = gr.HTML(label="Interactive Molecule Viewer")
506
  with gr.Tab("Step 3: Model Training & Prediction"):
 
507
  gr.Markdown("## Train Regression Models and Predict pIC50")
508
  with gr.Tabs():
509
  with gr.Tab("Model Training & Evaluation"):
@@ -563,12 +515,11 @@ with gr.Blocks(theme=gr.themes.Default(primary_hue="blue", secondary_hue="sky"),
563
  process_btn.click(fn=process_and_analyze_wrapper, inputs=[selected_target_dropdown, bioactivity_class_selector, app_state], outputs=[df_output_s1, freq_plot_output, scatter_plot_output, pic50_plot_output, pic50_stats_output, mw_plot_output, mw_stats_output, logp_plot_output, logp_stats_output, hdonors_plot_output, hdonors_stats_output, hacceptors_plot_output, hacceptors_stats_output, status_step1_process, app_state])
564
  bioactivity_class_selector.change(fn=update_analysis_on_filter_change, inputs=[bioactivity_class_selector, app_state], outputs=[df_output_s1, freq_plot_output, scatter_plot_output, pic50_plot_output, pic50_stats_output, mw_plot_output, mw_stats_output, logp_plot_output, logp_stats_output, hdonors_plot_output, hdonors_stats_output, hacceptors_plot_output, hacceptors_stats_output, status_step1_process], show_progress="minimal")
565
  calculate_fp_btn.click(fn=calculate_fingerprints, inputs=[app_state, fingerprint_dropdown], outputs=[status_step2, output_df_s2, download_s2, mols_grid_s2, app_state])
566
-
567
  @download_s2.click(inputs=app_state, outputs=download_s2, show_progress="hidden")
568
  def download_handler(current_state):
569
  df_to_download = current_state.get('fingerprint_data')
570
  return save_dataframe_as_csv(df_to_download)
571
-
572
  train_models_btn.click(fn=handle_model_training, inputs=[app_state], outputs=[status_step3_train, model_results_df, model_selector_s3, app_state])
573
  for listener in [model_selector_s3.change, feature_count_s3.change]: listener(fn=update_analysis_plots, inputs=[model_selector_s3, feature_count_s3, app_state], outputs=[validation_plot_s3, feature_plot_s3], show_progress="minimal")
574
  predict_btn_s3.click(fn=predict_on_upload, inputs=[upload_predict_file, model_selector_s3, app_state], outputs=[status_step3_predict, prediction_results_df, prediction_mols_grid])
 
7
  import glob
8
  import time
9
  import warnings
 
10
 
11
  # Chemistry and Cheminformatics
12
  from rdkit import Chem
13
+ from rdkit.Chem import Descriptors, Lipinski
14
  from chembl_webresource_client.new_client import new_client
15
  from padelpy import padeldescriptor
16
+ from rdkit.Chem.Draw import rdMolDraw2D
17
+ from rdkit.Chem import Draw
18
+ import base64
19
+ from io import BytesIO
20
+
21
 
22
  # Plotting and Visualization
23
  import matplotlib.pyplot as plt
 
67
  sns.set_theme(style='whitegrid')
68
 
69
  # --- FINGERPRINT CONFIGURATION ---
# --- FINGERPRINT CONFIGURATION ---
# Every '*.xml' file in the current working directory is treated as a PaDEL
# descriptor definition. The file's base name (without extension) is the
# fingerprint name; fp_config maps that name to the file path.
xml_files = sorted(glob.glob('*.xml'))

if not xml_files:
    # Keep the existing fallback: write an empty 'PubChem.xml' so that
    # fp_config is populated. The placeholder contains no descriptor
    # definitions, so say so explicitly instead of creating it silently.
    try:
        with open('PubChem.xml', 'w') as f:
            f.write('')
    except IOError:
        warnings.warn("Could not create a dummy 'PubChem.xml' file. Fingerprint calculation might fail if no .xml files are present.")
    else:
        warnings.warn(
            "No descriptor .xml files found; created an empty placeholder 'PubChem.xml'. "
            "Replace it with real descriptor XML files before calculating fingerprints."
        )
    # Pick up the placeholder if it was written.
    xml_files = sorted(glob.glob('*.xml'))

if not xml_files:
    warnings.warn(
        "No descriptor .xml files found. Fingerprint calculation will not be possible. "
        "Please place descriptor XML files in the same directory as the script."
    )

fp_config = {os.path.splitext(os.path.basename(file))[0]: file for file in xml_files}
FP_list = sorted(fp_config)
86
 
 
154
  raise gr.Error(f"Could not fetch SMILES from ChEMBL: {e}")
155
  df = df[df.standard_value.notna()]
156
  df = df[df.canonical_smiles.notna()]
157
+ # DEBUG FIX: Added drop_duplicates to align with notebook logic and ensure unique SMILES for merging.
158
  df.drop_duplicates(['canonical_smiles'], inplace=True)
159
  df["standard_value"] = pd.to_numeric(df["standard_value"], errors='coerce')
160
  df.dropna(subset=['standard_value'], inplace=True)
 
200
  # === STEP 2: FEATURE ENGINEERING FUNCTIONS ===
201
  # ==============================================================================
202
 
def create_molecule_grid_html(df, smiles_col='canonical_smiles', max_mols=20):
    """Render the first ``max_mols`` molecules of ``df`` as an HTML card grid.

    Each card holds a 200x200 PNG drawn by RDKit (embedded as a base64 data
    URI), the row's ``pIC50`` value formatted to two decimals, and the
    HTML-escaped SMILES string.

    Args:
        df: DataFrame with a SMILES column and a ``pIC50`` column.
        smiles_col: Name of the column holding the SMILES strings.
        max_mols: Only the first ``max_mols`` rows are considered.

    Returns:
        One HTML string: a flex container ``<div>`` with one card per rendered
        molecule. Rows whose SMILES is missing, blank, or rejected by RDKit
        are skipped, so the container may be empty.
    """
    from html import escape  # function-scope import keeps this block self-contained

    html_parts = ['<div style="display: flex; flex-wrap: wrap; gap: 10px;">']
    for _, row in df.head(max_mols).iterrows():
        smiles = row[smiles_col]
        pic50 = row['pIC50']

        # Missing values arrive as NaN/None; only real strings go to RDKit.
        if not isinstance(smiles, str) or not smiles.strip():
            continue
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            continue

        # Draw the molecule and embed the PNG bytes as a base64 data URI.
        img = Draw.MolToImage(mol, size=(200, 200))
        buffered = BytesIO()
        img.save(buffered, format="PNG")
        img_str = base64.b64encode(buffered.getvalue()).decode()

        html_parts.append(
            '<div style="border: 1px solid #ccc; padding: 10px; border-radius: 5px; text-align: center;">'
            f'<img src="data:image/png;base64,{img_str}" alt="Molecule" style="max-width: 200px;">'
            f'<p><strong>pIC50:</strong> {pic50:.2f}</p>'
            # Escape the SMILES so its text cannot be interpreted as markup.
            f'<p style="font-size: 10px; word-break: break-all;">{escape(smiles)}</p>'
            '</div>'
        )
    html_parts.append('</div>')
    return ''.join(html_parts)
227
+
def calculate_fingerprints(current_state, fingerprint_type, progress=gr.Progress()):
    """Compute PaDEL fingerprints for the cleaned Step 1 data (generator).

    Yields 5-tuples matching the Gradio outputs wired to this handler:
    (status message, fingerprint DataFrame or None, download-button update,
    molecule-grid update or None, state dict).

    Args:
        current_state: Session state dict; ``'cleaned_data'`` is read, and
            ``'fingerprint_data'`` / ``'fingerprint_type'`` are written on success.
        fingerprint_type: Key into ``fp_config`` selecting the descriptor XML.
        progress: Gradio progress tracker.

    Raises:
        gr.Error: If there is no cleaned data, no fingerprint type is
            selected, the descriptor XML is unknown, PaDEL produces no
            output, or any other step fails.
    """
    input_df = current_state.get('cleaned_data')
    if input_df is None or input_df.empty:
        raise gr.Error("No cleaned data found. Please complete Step 1.")
    if not fingerprint_type:
        raise gr.Error("Please select a fingerprint type.")

    progress(0, desc="Starting...")
    yield "🧪 Starting fingerprint calculation...", None, gr.update(visible=False), None, current_state

    # Scratch files exchanged with PaDEL; always removed in the finally clause.
    smi_file, output_csv = 'molecule.smi', 'fingerprints.csv'
    try:
        descriptortypes = fp_config.get(fingerprint_type)
        if not descriptortypes:
            raise gr.Error(f"Descriptor XML for '{fingerprint_type}' not found.")

        # The SMILES string is written twice per row: once as the structure and
        # once as the molecule name, so PaDEL's 'Name' column can be merged back
        # on 'canonical_smiles'. Assumes SMILES are unique (Step 1 drops duplicates).
        input_df[['canonical_smiles', 'canonical_smiles']].to_csv(smi_file, sep='\t', index=False, header=False)

        # Remove any stale output so the existence check below is meaningful.
        if os.path.exists(output_csv):
            os.remove(output_csv)

        progress(0.3, desc="⚗️ Running PaDEL...")
        yield "⚗️ Running PaDEL...", None, gr.update(visible=False), None, current_state
        padeldescriptor(mol_dir=smi_file, d_file=output_csv, descriptortypes=descriptortypes, detectaromaticity=True, standardizenitro=True, standardizetautomers=True, threads=-1, removesalt=True, log=False, fingerprints=True)

        if not os.path.exists(output_csv) or os.path.getsize(output_csv) == 0:
            raise gr.Error("PaDEL failed to produce an output file. Check molecule validity.")

        progress(0.7, desc="📊 Processing results...")
        yield "📊 Processing results...", None, gr.update(visible=False), None, current_state
        df_X = pd.read_csv(output_csv).rename(columns={'Name': 'canonical_smiles'})

        # Inner merge: keep only molecules for which PaDEL returned a row.
        final_df = pd.merge(input_df[['canonical_smiles', 'pIC50']], df_X, on='canonical_smiles', how='inner')

        current_state['fingerprint_data'] = final_df
        current_state['fingerprint_type'] = fingerprint_type

        progress(0.9, desc="🖼️ Generating molecule grid...")
        mols_html = create_molecule_grid_html(final_df)
        # One column of df_X is the molecule name; the rest are descriptors.
        success_msg = f"✅ Success! Generated {len(df_X.columns) -1} descriptors for {len(final_df)} molecules."
        progress(1, desc="Completed!")
        yield success_msg, final_df, gr.update(visible=True), gr.update(value=mols_html, visible=True), current_state
    except gr.Error:
        # Already a user-facing message; do not re-wrap it as "Calculation failed".
        raise
    except Exception as e:
        raise gr.Error(f"Calculation failed: {e}") from e
    finally:
        for scratch_path in (smi_file, output_csv):
            if os.path.exists(scratch_path):
                os.remove(scratch_path)
 
 
 
265
 
266
  # ==============================================================================
267
  # === STEP 3: MODEL TRAINING & PREDICTION FUNCTIONS ===
 
300
  X_test = pd.DataFrame(selector.transform(X_test), columns=X_test.columns[selector.get_support()], index=X_test.index)
301
  selected_features = X_train.columns.tolist()
302
 
303
+ model_defs = [('Linear Regression', LinearRegression()), ('Ridge', Ridge(random_state=42)), ('Lasso', Lasso(random_state=42)), ('Random Forest', RandomForestRegressor(random_state=42, n_jobs=-1)), ('Gradient Boosting', GradientBoostingRegressor(random_state=42))]
304
+ if _has_extra_libs: model_defs.extend([('XGBoost', xgb.XGBRegressor(random_state=42, n_jobs=-1, verbosity=0)), ('LightGBM', lgb.LGBMRegressor(random_state=42, n_jobs=-1, verbosity=-1)), ('CatBoost', cb.CatBoostRegressor(random_state=42, verbose=0))])
305
+
 
 
 
 
 
 
 
 
 
 
 
306
  results_list, trained_models = [], {}
307
  for i, (name, model) in enumerate(model_defs):
308
  progress(0.2 + (i / len(model_defs)) * 0.8, desc=f"Training {name}...")
 
314
  results_df = pd.DataFrame(results_list).sort_values(by='R²', ascending=False).reset_index(drop=True)
315
  plotter = ModelPlotter(trained_models, X_test, y_test)
316
  model_run_results = ModelRunResult(results_df, plotter, trained_models, selected_features)
317
+
318
  model_choices = results_df['Model'].tolist()
319
  yield "✅ Model training & evaluation complete.", model_run_results, gr.Dropdown(choices=model_choices, interactive=True)
320
 
def create_prediction_grid_html(df, smiles_col='canonical_smiles', pred_col='predicted_pIC50', max_mols=20):
    """Render predicted molecules from ``df`` as an HTML card grid.

    Each card holds a 200x200 PNG drawn by RDKit (embedded as a base64 data
    URI), the predicted value formatted to two decimals, and the HTML-escaped
    SMILES string.

    Args:
        df: DataFrame with a SMILES column and a prediction column.
        smiles_col: Name of the column holding the SMILES strings.
        pred_col: Name of the column holding the predicted pIC50 values.
        max_mols: Only the first ``max_mols`` rows are considered.

    Returns:
        One HTML string: a flex container ``<div>`` with one card per rendered
        molecule. Rows with a missing prediction, a missing or blank SMILES,
        or a SMILES rejected by RDKit are skipped, so the container may be
        empty.
    """
    from html import escape  # function-scope import keeps this block self-contained

    html_parts = ['<div style="display: flex; flex-wrap: wrap; gap: 10px;">']
    for _, row in df.head(max_mols).iterrows():
        smiles = row[smiles_col]
        pred_pic50 = row[pred_col]
        if pd.isna(pred_pic50):
            continue

        # SMILES come from an uploaded CSV: blanks load as NaN, and only real
        # strings may be handed to RDKit.
        if not isinstance(smiles, str) or not smiles.strip():
            continue
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            continue

        # Draw the molecule and embed the PNG bytes as a base64 data URI.
        img = Draw.MolToImage(mol, size=(200, 200))
        buffered = BytesIO()
        img.save(buffered, format="PNG")
        img_str = base64.b64encode(buffered.getvalue()).decode()

        html_parts.append(
            '<div style="border: 1px solid #ccc; padding: 10px; border-radius: 5px; text-align: center;">'
            f'<img src="data:image/png;base64,{img_str}" alt="Molecule" style="max-width: 200px;">'
            f'<p><strong>Predicted pIC50:</strong> {pred_pic50:.2f}</p>'
            # Escape user-supplied text so it cannot be interpreted as markup.
            f'<p style="font-size: 10px; word-break: break-all;">{escape(smiles)}</p>'
            '</div>'
        )
    html_parts.append('</div>')
    return ''.join(html_parts)
347
+
def predict_on_upload(uploaded_file, model_name, current_state, progress=gr.Progress()):
    """Predict pIC50 for molecules in an uploaded CSV (generator).

    Yields 3-tuples matching the Gradio outputs wired to this handler:
    (status message, results DataFrame or None, molecule-grid HTML or None).

    Args:
        uploaded_file: Uploaded CSV; either an object exposing the path as
            ``.name`` or the path itself. Must contain 'canonical_smiles'.
        model_name: Key of a trained model in the stored model results.
        current_state: Session state dict; ``'model_results'`` and
            ``'fingerprint_type'`` are read.
        progress: Gradio progress tracker.

    Raises:
        gr.Error: On missing inputs, an unknown model or descriptor XML, a CSV
            without usable SMILES, or a PaDEL run that produces no output.
    """
    if not uploaded_file:
        raise gr.Error("Please upload a file.")
    if not model_name:
        raise gr.Error("Please select a trained model.")
    model_run_results = current_state.get('model_results')
    fingerprint_type = current_state.get('fingerprint_type')
    if not model_run_results or not fingerprint_type:
        raise gr.Error("Please run Steps 2 and 3 first.")

    model = model_run_results.models.get(model_name)
    selected_features = model_run_results.selected_features
    if model is None:
        raise gr.Error(f"Model '{model_name}' not found.")

    descriptortypes = fp_config.get(fingerprint_type)
    if not descriptortypes:
        raise gr.Error(f"Descriptor XML for '{fingerprint_type}' not found.")

    # Scratch files exchanged with PaDEL; always removed in the finally clause.
    smi_file, output_csv = 'predict.smi', 'predict_fp.csv'
    try:
        progress(0, desc="Reading & processing new molecules...")
        yield "Reading uploaded file...", None, None
        # Accept both a file wrapper (path in .name) and a plain path string.
        csv_path = getattr(uploaded_file, 'name', uploaded_file)
        df_new = pd.read_csv(csv_path)
        if 'canonical_smiles' not in df_new.columns:
            raise gr.Error("CSV must contain a 'canonical_smiles' column.")
        # The row position becomes 'mol_id', used as the molecule name in PaDEL
        # so fingerprints can be joined back to the uploaded rows.
        df_new = df_new.reset_index().rename(columns={'index': 'mol_id'})

        # Rows without a SMILES cannot be fingerprinted; they stay in the
        # results (left merge below) with an empty prediction.
        with_smiles = df_new[df_new['canonical_smiles'].notna()]
        if with_smiles.empty:
            raise gr.Error("The uploaded CSV contains no SMILES to predict on.")
        padel_input = pd.DataFrame({'smiles': with_smiles['canonical_smiles'], 'name': with_smiles['mol_id']})
        padel_input.to_csv(smi_file, sep='\t', index=False, header=False)
        # Remove any stale output so the existence check below is meaningful.
        if os.path.exists(output_csv):
            os.remove(output_csv)

        progress(0.3, desc="Calculating fingerprints...")
        yield "Calculating fingerprints for new molecules...", None, None
        # Same PaDEL options as the training-time call in calculate_fingerprints
        # (including standardizetautomers) so features are computed identically.
        padeldescriptor(mol_dir=smi_file, d_file=output_csv, descriptortypes=descriptortypes, detectaromaticity=True, standardizenitro=True, standardizetautomers=True, threads=-1, removesalt=True, log=False, fingerprints=True)
        if not os.path.exists(output_csv) or os.path.getsize(output_csv) == 0:
            raise gr.Error("PaDEL calculation failed for the uploaded molecules.")

        progress(0.7, desc="Aligning features and predicting...")
        yield "Aligning features and predicting...", None, None
        df_fp = pd.read_csv(output_csv).rename(columns={'Name': 'mol_id'})

        # Reorder/limit columns to the features selected at training time;
        # features absent from the new fingerprints are filled with 0.
        X_new = df_fp.set_index('mol_id')
        X_new_aligned = X_new.reindex(columns=selected_features, fill_value=0)

        predictions = model.predict(X_new_aligned)

        results_subset = pd.DataFrame({'mol_id': X_new_aligned.index, 'predicted_pIC50': predictions})
        # Left merge keeps every uploaded row; rows PaDEL skipped get NaN.
        df_results = pd.merge(df_new, results_subset, on='mol_id', how='left')

        progress(0.9, desc="Generating visualization...")
        yield "Generating visualization...", None, None

        # Only molecules that received a prediction are drawn.
        df_grid_view = df_results.dropna(subset=['predicted_pIC50']).copy()
        mols_html = "<h3>No molecules with successful predictions to display.</h3>"
        if not df_grid_view.empty:
            mols_html = create_prediction_grid_html(df_grid_view)

        progress(1, desc="Complete!")
        yield "✅ Prediction complete.", df_results[['canonical_smiles', 'predicted_pIC50']], mols_html
    finally:
        for scratch_path in (smi_file, output_csv):
            if os.path.exists(scratch_path):
                os.remove(scratch_path)
399
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
400
  # ==============================================================================
401
  # === GRADIO INTERFACE ===
402
  # ==============================================================================
 
406
  app_state = gr.State({})
407
  with gr.Tabs():
408
  with gr.Tab("Step 1: Data Collection & EDA"):
409
+ # UI Definition for Step 1...
410
  gr.Markdown("## Fetch Bioactivity Data from ChEMBL and Perform Exploratory Analysis")
411
  with gr.Row():
412
  query_input = gr.Textbox(label="Target Query", placeholder="e.g., acetylcholinesterase, BRAF kinase", scale=3)
 
445
  hdonors_stats_output = gr.Dataframe(label="Stats for H-Donors")
446
  hacceptors_stats_output = gr.Dataframe(label="Stats for H-Acceptors")
447
  with gr.Tab("Step 2: Feature Engineering"):
448
+ # UI Definition for Step 2...
449
  gr.Markdown("## Calculate Molecular Fingerprints using PaDEL")
450
  with gr.Row():
451
  fingerprint_dropdown = gr.Dropdown(choices=FP_list, value='PubChem' if 'PubChem' in FP_list else None, label="Select Fingerprint Method", scale=3)
 
455
  download_s2 = gr.DownloadButton("Download Feature Data (CSV)", variant="secondary", visible=False)
456
  mols_grid_s2 = gr.HTML(label="Interactive Molecule Viewer")
457
  with gr.Tab("Step 3: Model Training & Prediction"):
458
+ # UI Definition for Step 3...
459
  gr.Markdown("## Train Regression Models and Predict pIC50")
460
  with gr.Tabs():
461
  with gr.Tab("Model Training & Evaluation"):
 
515
  process_btn.click(fn=process_and_analyze_wrapper, inputs=[selected_target_dropdown, bioactivity_class_selector, app_state], outputs=[df_output_s1, freq_plot_output, scatter_plot_output, pic50_plot_output, pic50_stats_output, mw_plot_output, mw_stats_output, logp_plot_output, logp_stats_output, hdonors_plot_output, hdonors_stats_output, hacceptors_plot_output, hacceptors_stats_output, status_step1_process, app_state])
516
  bioactivity_class_selector.change(fn=update_analysis_on_filter_change, inputs=[bioactivity_class_selector, app_state], outputs=[df_output_s1, freq_plot_output, scatter_plot_output, pic50_plot_output, pic50_stats_output, mw_plot_output, mw_stats_output, logp_plot_output, logp_stats_output, hdonors_plot_output, hdonors_stats_output, hacceptors_plot_output, hacceptors_stats_output, status_step1_process], show_progress="minimal")
517
  calculate_fp_btn.click(fn=calculate_fingerprints, inputs=[app_state, fingerprint_dropdown], outputs=[status_step2, output_df_s2, download_s2, mols_grid_s2, app_state])
518
+ # The previous download-button click handler was incorrect; the handler must read the DataFrame from app_state.
519
  @download_s2.click(inputs=app_state, outputs=download_s2, show_progress="hidden")
520
  def download_handler(current_state):
521
  df_to_download = current_state.get('fingerprint_data')
522
  return save_dataframe_as_csv(df_to_download)
 
523
  train_models_btn.click(fn=handle_model_training, inputs=[app_state], outputs=[status_step3_train, model_results_df, model_selector_s3, app_state])
524
  for listener in [model_selector_s3.change, feature_count_s3.change]: listener(fn=update_analysis_plots, inputs=[model_selector_s3, feature_count_s3, app_state], outputs=[validation_plot_s3, feature_plot_s3], show_progress="minimal")
525
  predict_btn_s3.click(fn=predict_on_upload, inputs=[upload_predict_file, model_selector_s3, app_state], outputs=[status_step3_predict, prediction_results_df, prediction_mols_grid])