Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -7,13 +7,17 @@ import os
|
|
7 |
import glob
|
8 |
import time
|
9 |
import warnings
|
10 |
-
import base64
|
11 |
|
12 |
# Chemistry and Cheminformatics
|
13 |
from rdkit import Chem
|
14 |
-
from rdkit.Chem import Descriptors, Lipinski
|
15 |
from chembl_webresource_client.new_client import new_client
|
16 |
from padelpy import padeldescriptor
|
|
|
|
|
|
|
|
|
|
|
17 |
|
18 |
# Plotting and Visualization
|
19 |
import matplotlib.pyplot as plt
|
@@ -63,25 +67,20 @@ warnings.filterwarnings("ignore")
|
|
63 |
sns.set_theme(style='whitegrid')
|
64 |
|
65 |
# --- FINGERPRINT CONFIGURATION ---
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
"
|
73 |
-
)
|
74 |
-
xml_files = []
|
75 |
-
else:
|
76 |
-
xml_files = sorted(glob.glob(os.path.join(DESCRIPTOR_DIR, '*.xml')))
|
77 |
|
|
|
78 |
if not xml_files:
|
79 |
warnings.warn(
|
80 |
-
|
81 |
-
"
|
82 |
)
|
83 |
-
|
84 |
-
# The key is the filename without extension; the value is the full path to the file
|
85 |
fp_config = {os.path.splitext(os.path.basename(file))[0]: file for file in xml_files}
|
86 |
FP_list = sorted(list(fp_config.keys()))
|
87 |
|
@@ -155,6 +154,7 @@ def clean_and_process_data(df):
|
|
155 |
raise gr.Error(f"Could not fetch SMILES from ChEMBL: {e}")
|
156 |
df = df[df.standard_value.notna()]
|
157 |
df = df[df.canonical_smiles.notna()]
|
|
|
158 |
df.drop_duplicates(['canonical_smiles'], inplace=True)
|
159 |
df["standard_value"] = pd.to_numeric(df["standard_value"], errors='coerce')
|
160 |
df.dropna(subset=['standard_value'], inplace=True)
|
@@ -200,58 +200,68 @@ def mannwhitney_test(df, descriptor):
|
|
200 |
# === STEP 2: FEATURE ENGINEERING FUNCTIONS ===
|
201 |
# ==============================================================================
|
202 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
203 |
def calculate_fingerprints(current_state, fingerprint_type, progress=gr.Progress()):
|
204 |
input_df = current_state.get('cleaned_data')
|
205 |
-
if input_df is None or input_df.empty:
|
206 |
-
|
207 |
-
|
208 |
-
raise gr.Error("Please select a fingerprint type.")
|
209 |
-
|
210 |
-
progress(0, desc="Starting...")
|
211 |
-
yield f"🧪 Starting fingerprint calculation...", None, gr.update(visible=False), None, current_state
|
212 |
-
|
213 |
try:
|
214 |
smi_file, output_csv = 'molecule.smi', 'fingerprints.csv'
|
|
|
|
|
|
|
215 |
input_df[['canonical_smiles', 'canonical_smiles']].to_csv(smi_file, sep='\t', index=False, header=False)
|
216 |
-
|
217 |
-
if os.path.exists(output_csv):
|
218 |
-
os.remove(output_csv)
|
219 |
descriptortypes = fp_config.get(fingerprint_type)
|
220 |
-
if not descriptortypes:
|
221 |
-
|
222 |
-
|
223 |
-
progress(0.3, desc="⚗️ Running PaDEL...")
|
224 |
-
yield f"⚗️ Running PaDEL...", None, gr.update(visible=False), None, current_state
|
225 |
padeldescriptor(mol_dir=smi_file, d_file=output_csv, descriptortypes=descriptortypes, detectaromaticity=True, standardizenitro=True, standardizetautomers=True, threads=-1, removesalt=True, log=False, fingerprints=True)
|
226 |
-
|
227 |
if not os.path.exists(output_csv) or os.path.getsize(output_csv) == 0:
|
228 |
raise gr.Error("PaDEL failed to produce an output file. Check molecule validity.")
|
229 |
|
230 |
-
progress(0.7, desc="📊 Processing results...")
|
231 |
-
yield "📊 Processing results...", None, gr.update(visible=False), None, current_state
|
232 |
df_X = pd.read_csv(output_csv).rename(columns={'Name': 'canonical_smiles'})
|
233 |
-
|
|
|
|
|
234 |
final_df = pd.merge(input_df[['canonical_smiles', 'pIC50']], df_X, on='canonical_smiles', how='inner')
|
235 |
-
|
236 |
-
current_state['fingerprint_data'] = final_df
|
237 |
-
current_state['fingerprint_type'] = fingerprint_type
|
238 |
-
|
239 |
-
progress(0.9, desc="🖼️ Generating molecule grid...")
|
240 |
-
mols_html = create_molecule_html_grid(final_df, 'canonical_smiles', ['pIC50'])
|
241 |
|
|
|
|
|
|
|
242 |
success_msg = f"✅ Success! Generated {len(df_X.columns) -1} descriptors for {len(final_df)} molecules."
|
243 |
-
progress(1, desc="Completed!")
|
244 |
-
|
245 |
-
|
246 |
-
except Exception as e:
|
247 |
-
raise gr.Error(f"Calculation failed: {e}")
|
248 |
-
|
249 |
finally:
|
250 |
-
if os.path.exists('molecule.smi'):
|
251 |
-
|
252 |
-
if os.path.exists('fingerprints.csv'):
|
253 |
-
os.remove('fingerprints.csv')
|
254 |
-
|
255 |
|
256 |
# ==============================================================================
|
257 |
# === STEP 3: MODEL TRAINING & PREDICTION FUNCTIONS ===
|
@@ -290,20 +300,9 @@ def run_regression_suite(df: pd.DataFrame, progress=gr.Progress()):
|
|
290 |
X_test = pd.DataFrame(selector.transform(X_test), columns=X_test.columns[selector.get_support()], index=X_test.index)
|
291 |
selected_features = X_train.columns.tolist()
|
292 |
|
293 |
-
model_defs = [
|
294 |
-
|
295 |
-
|
296 |
-
('Lasso', Lasso(random_state=42)),
|
297 |
-
('Random Forest', RandomForestRegressor(random_state=42, n_jobs=-1)),
|
298 |
-
# ('Gradient Boosting', GradientBoostingRegressor(random_state=42)) # <-- Commented out
|
299 |
-
]
|
300 |
-
if _has_extra_libs:
|
301 |
-
model_defs.extend([
|
302 |
-
# ('XGBoost', xgb.XGBRegressor(random_state=42, n_jobs=-1, verbosity=0)), # <-- Commented out
|
303 |
-
('LightGBM', lgb.LGBMRegressor(random_state=42, n_jobs=-1, verbosity=-1)),
|
304 |
-
# ('CatBoost', cb.CatBoostRegressor(random_state=42, verbose=0)) # <-- Commented out
|
305 |
-
])
|
306 |
-
|
307 |
results_list, trained_models = [], {}
|
308 |
for i, (name, model) in enumerate(model_defs):
|
309 |
progress(0.2 + (i / len(model_defs)) * 0.8, desc=f"Training {name}...")
|
@@ -315,139 +314,89 @@ def run_regression_suite(df: pd.DataFrame, progress=gr.Progress()):
|
|
315 |
results_df = pd.DataFrame(results_list).sort_values(by='R²', ascending=False).reset_index(drop=True)
|
316 |
plotter = ModelPlotter(trained_models, X_test, y_test)
|
317 |
model_run_results = ModelRunResult(results_df, plotter, trained_models, selected_features)
|
318 |
-
|
319 |
model_choices = results_df['Model'].tolist()
|
320 |
yield "✅ Model training & evaluation complete.", model_run_results, gr.Dropdown(choices=model_choices, interactive=True)
|
321 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
322 |
def predict_on_upload(uploaded_file, model_name, current_state, progress=gr.Progress()):
|
323 |
if not uploaded_file: raise gr.Error("Please upload a file.")
|
324 |
if not model_name: raise gr.Error("Please select a trained model.")
|
325 |
model_run_results = current_state.get('model_results')
|
326 |
fingerprint_type = current_state.get('fingerprint_type')
|
327 |
if not model_run_results or not fingerprint_type: raise gr.Error("Please run Steps 2 and 3 first.")
|
328 |
-
|
329 |
model = model_run_results.models.get(model_name)
|
330 |
selected_features = model_run_results.selected_features
|
331 |
if model is None: raise gr.Error(f"Model '{model_name}' not found.")
|
332 |
-
|
333 |
smi_file, output_csv = 'predict.smi', 'predict_fp.csv'
|
334 |
try:
|
335 |
progress(0, desc="Reading & processing new molecules..."); yield "Reading uploaded file...", None, None
|
336 |
df_new = pd.read_csv(uploaded_file.name)
|
337 |
if 'canonical_smiles' not in df_new.columns: raise gr.Error("CSV must contain a 'canonical_smiles' column.")
|
338 |
df_new = df_new.reset_index().rename(columns={'index': 'mol_id'})
|
339 |
-
|
340 |
padel_input = pd.DataFrame({'smiles': df_new['canonical_smiles'], 'name': df_new['mol_id']})
|
341 |
padel_input.to_csv(smi_file, sep='\t', index=False, header=False)
|
342 |
if os.path.exists(output_csv): os.remove(output_csv)
|
343 |
-
|
344 |
progress(0.3, desc="Calculating fingerprints..."); yield "Calculating fingerprints for new molecules...", None, None
|
345 |
padeldescriptor(mol_dir=smi_file, d_file=output_csv, descriptortypes=fp_config.get(fingerprint_type), detectaromaticity=True, standardizenitro=True, threads=-1, removesalt=True, log=False, fingerprints=True)
|
346 |
if not os.path.exists(output_csv) or os.path.getsize(output_csv) == 0: raise gr.Error("PaDEL calculation failed for the uploaded molecules.")
|
347 |
-
|
348 |
progress(0.7, desc="Aligning features and predicting..."); yield "Aligning features and predicting...", None, None
|
349 |
df_fp = pd.read_csv(output_csv).rename(columns={'Name': 'mol_id'})
|
350 |
-
|
351 |
X_new = df_fp.set_index('mol_id')
|
352 |
X_new_aligned = X_new.reindex(columns=selected_features, fill_value=0)[selected_features]
|
353 |
-
|
354 |
predictions = model.predict(X_new_aligned)
|
355 |
-
|
356 |
results_subset = pd.DataFrame({'mol_id': X_new_aligned.index, 'predicted_pIC50': predictions})
|
357 |
df_results = pd.merge(df_new, results_subset, on='mol_id', how='left')
|
358 |
|
359 |
progress(0.9, desc="Generating visualization..."); yield "Generating visualization...", None, None
|
360 |
|
|
|
|
|
|
|
361 |
df_grid_view = df_results.dropna(subset=['predicted_pIC50']).copy()
|
362 |
-
mols_html =
|
363 |
-
|
364 |
-
|
365 |
-
|
366 |
-
mol_id_col='mol_id'
|
367 |
-
)
|
368 |
-
|
369 |
progress(1, desc="Complete!"); yield "✅ Prediction complete.", df_results[['canonical_smiles', 'predicted_pIC50']], mols_html
|
370 |
finally:
|
371 |
if os.path.exists(smi_file): os.remove(smi_file)
|
372 |
if os.path.exists(output_csv): os.remove(output_csv)
|
373 |
|
374 |
-
# ==============================================================================
|
375 |
-
# === HELPER FUNCTIONS ===
|
376 |
-
# ==============================================================================
|
377 |
-
def create_molecule_html_grid(df: pd.DataFrame, smiles_col: str, data_cols: list, mol_id_col: str = None):
|
378 |
-
"""
|
379 |
-
Generates a self-contained HTML grid for a DataFrame of molecules.
|
380 |
-
|
381 |
-
Args:
|
382 |
-
df: DataFrame containing molecule data.
|
383 |
-
smiles_col: The name of the column with the SMILES strings.
|
384 |
-
data_cols: A list of column names to display alongside the molecule.
|
385 |
-
mol_id_col: Optional column to use as a title for each molecule entry.
|
386 |
-
|
387 |
-
Returns:
|
388 |
-
An HTML string for display in Gradio's gr.HTML component.
|
389 |
-
"""
|
390 |
-
if df.empty:
|
391 |
-
return "<h3>No molecules to display.</h3>"
|
392 |
-
|
393 |
-
# Step 1: Filter valid molecules with 2D conformers
|
394 |
-
valid_mols = []
|
395 |
-
for smiles in df[smiles_col]:
|
396 |
-
mol = Chem.MolFromSmiles(smiles)
|
397 |
-
if mol:
|
398 |
-
try:
|
399 |
-
rdDepictor.Compute2DCoords(mol)
|
400 |
-
_ = mol.GetConformer() # Ensure conformer exists
|
401 |
-
valid_mols.append((smiles, mol))
|
402 |
-
except Exception as e:
|
403 |
-
print(f"[Warning] Skipping molecule due to depiction error: {smiles} – {e}")
|
404 |
-
continue
|
405 |
-
|
406 |
-
if not valid_mols:
|
407 |
-
return "<h3>No valid molecules could be rendered.</h3>"
|
408 |
-
|
409 |
-
# Step 2: Generate SVGs
|
410 |
-
images = []
|
411 |
-
smiles_list = []
|
412 |
-
for smiles, mol in valid_mols:
|
413 |
-
try:
|
414 |
-
svg = Draw.MolToSVG(mol, width=200, height=200)
|
415 |
-
images.append(svg)
|
416 |
-
smiles_list.append(smiles)
|
417 |
-
except Exception as e:
|
418 |
-
print(f"[Warning] Failed to draw molecule: {smiles} – {e}")
|
419 |
-
continue
|
420 |
-
|
421 |
-
# Filter the DataFrame to include only valid molecules
|
422 |
-
df = df[df[smiles_col].isin(smiles_list)].copy()
|
423 |
-
df['image'] = images
|
424 |
-
|
425 |
-
# Step 3: Build HTML
|
426 |
-
html = '<div style="display: flex; flex-wrap: wrap; gap: 20px;">'
|
427 |
-
for _, row in df.iterrows():
|
428 |
-
if not row['image']:
|
429 |
-
continue
|
430 |
-
|
431 |
-
html += '<div style="border: 1px solid #ddd; border-radius: 5px; padding: 10px; text-align: center; width: 220px;">'
|
432 |
-
html += row['image'] # SVG
|
433 |
-
|
434 |
-
# Optional: molecule ID
|
435 |
-
if mol_id_col and mol_id_col in row:
|
436 |
-
html += f'<strong>{row[mol_id_col]}</strong><br>'
|
437 |
-
|
438 |
-
# Show other data values
|
439 |
-
for col in data_cols:
|
440 |
-
if col in row:
|
441 |
-
value = row[col]
|
442 |
-
if isinstance(value, float):
|
443 |
-
value = f"{value:.2f}"
|
444 |
-
html += f'<span><strong>{col}:</strong> {value}</span><br>'
|
445 |
-
html += '</div>'
|
446 |
-
|
447 |
-
html += '</div>'
|
448 |
-
return html
|
449 |
-
|
450 |
-
|
451 |
# ==============================================================================
|
452 |
# === GRADIO INTERFACE ===
|
453 |
# ==============================================================================
|
@@ -457,6 +406,7 @@ with gr.Blocks(theme=gr.themes.Default(primary_hue="blue", secondary_hue="sky"),
|
|
457 |
app_state = gr.State({})
|
458 |
with gr.Tabs():
|
459 |
with gr.Tab("Step 1: Data Collection & EDA"):
|
|
|
460 |
gr.Markdown("## Fetch Bioactivity Data from ChEMBL and Perform Exploratory Analysis")
|
461 |
with gr.Row():
|
462 |
query_input = gr.Textbox(label="Target Query", placeholder="e.g., acetylcholinesterase, BRAF kinase", scale=3)
|
@@ -495,6 +445,7 @@ with gr.Blocks(theme=gr.themes.Default(primary_hue="blue", secondary_hue="sky"),
|
|
495 |
hdonors_stats_output = gr.Dataframe(label="Stats for H-Donors")
|
496 |
hacceptors_stats_output = gr.Dataframe(label="Stats for H-Acceptors")
|
497 |
with gr.Tab("Step 2: Feature Engineering"):
|
|
|
498 |
gr.Markdown("## Calculate Molecular Fingerprints using PaDEL")
|
499 |
with gr.Row():
|
500 |
fingerprint_dropdown = gr.Dropdown(choices=FP_list, value='PubChem' if 'PubChem' in FP_list else None, label="Select Fingerprint Method", scale=3)
|
@@ -504,6 +455,7 @@ with gr.Blocks(theme=gr.themes.Default(primary_hue="blue", secondary_hue="sky"),
|
|
504 |
download_s2 = gr.DownloadButton("Download Feature Data (CSV)", variant="secondary", visible=False)
|
505 |
mols_grid_s2 = gr.HTML(label="Interactive Molecule Viewer")
|
506 |
with gr.Tab("Step 3: Model Training & Prediction"):
|
|
|
507 |
gr.Markdown("## Train Regression Models and Predict pIC50")
|
508 |
with gr.Tabs():
|
509 |
with gr.Tab("Model Training & Evaluation"):
|
@@ -563,12 +515,11 @@ with gr.Blocks(theme=gr.themes.Default(primary_hue="blue", secondary_hue="sky"),
|
|
563 |
process_btn.click(fn=process_and_analyze_wrapper, inputs=[selected_target_dropdown, bioactivity_class_selector, app_state], outputs=[df_output_s1, freq_plot_output, scatter_plot_output, pic50_plot_output, pic50_stats_output, mw_plot_output, mw_stats_output, logp_plot_output, logp_stats_output, hdonors_plot_output, hdonors_stats_output, hacceptors_plot_output, hacceptors_stats_output, status_step1_process, app_state])
|
564 |
bioactivity_class_selector.change(fn=update_analysis_on_filter_change, inputs=[bioactivity_class_selector, app_state], outputs=[df_output_s1, freq_plot_output, scatter_plot_output, pic50_plot_output, pic50_stats_output, mw_plot_output, mw_stats_output, logp_plot_output, logp_stats_output, hdonors_plot_output, hdonors_stats_output, hacceptors_plot_output, hacceptors_stats_output, status_step1_process], show_progress="minimal")
|
565 |
calculate_fp_btn.click(fn=calculate_fingerprints, inputs=[app_state, fingerprint_dropdown], outputs=[status_step2, output_df_s2, download_s2, mols_grid_s2, app_state])
|
566 |
-
|
567 |
@download_s2.click(inputs=app_state, outputs=download_s2, show_progress="hidden")
|
568 |
def download_handler(current_state):
|
569 |
df_to_download = current_state.get('fingerprint_data')
|
570 |
return save_dataframe_as_csv(df_to_download)
|
571 |
-
|
572 |
train_models_btn.click(fn=handle_model_training, inputs=[app_state], outputs=[status_step3_train, model_results_df, model_selector_s3, app_state])
|
573 |
for listener in [model_selector_s3.change, feature_count_s3.change]: listener(fn=update_analysis_plots, inputs=[model_selector_s3, feature_count_s3, app_state], outputs=[validation_plot_s3, feature_plot_s3], show_progress="minimal")
|
574 |
predict_btn_s3.click(fn=predict_on_upload, inputs=[upload_predict_file, model_selector_s3, app_state], outputs=[status_step3_predict, prediction_results_df, prediction_mols_grid])
|
|
|
7 |
import glob
|
8 |
import time
|
9 |
import warnings
|
|
|
10 |
|
11 |
# Chemistry and Cheminformatics
|
12 |
from rdkit import Chem
|
13 |
+
from rdkit.Chem import Descriptors, Lipinski
|
14 |
from chembl_webresource_client.new_client import new_client
|
15 |
from padelpy import padeldescriptor
|
16 |
+
from rdkit.Chem.Draw import rdMolDraw2D
|
17 |
+
from rdkit.Chem import Draw
|
18 |
+
import base64
|
19 |
+
from io import BytesIO
|
20 |
+
|
21 |
|
22 |
# Plotting and Visualization
|
23 |
import matplotlib.pyplot as plt
|
|
|
67 |
sns.set_theme(style='whitegrid')
|
68 |
|
69 |
# --- FINGERPRINT CONFIGURATION ---
|
70 |
+
# Create a dummy PubChem.xml if no XML files are found, to ensure fp_config is populated
|
71 |
+
if not glob.glob('*.xml'):
|
72 |
+
try:
|
73 |
+
with open('PubChem.xml', 'w') as f:
|
74 |
+
f.write('')
|
75 |
+
except IOError:
|
76 |
+
warnings.warn("Could not create a dummy 'PubChem.xml' file. Fingerprint calculation might fail if no .xml files are present.")
|
|
|
|
|
|
|
|
|
77 |
|
78 |
+
xml_files = sorted(glob.glob('*.xml'))
|
79 |
if not xml_files:
|
80 |
warnings.warn(
|
81 |
+
"No descriptor .xml files found. Fingerprint calculation will not be possible. "
|
82 |
+
"Please place descriptor XML files in the same directory as the script."
|
83 |
)
|
|
|
|
|
84 |
fp_config = {os.path.splitext(os.path.basename(file))[0]: file for file in xml_files}
|
85 |
FP_list = sorted(list(fp_config.keys()))
|
86 |
|
|
|
154 |
raise gr.Error(f"Could not fetch SMILES from ChEMBL: {e}")
|
155 |
df = df[df.standard_value.notna()]
|
156 |
df = df[df.canonical_smiles.notna()]
|
157 |
+
# DEBUG FIX: Added drop_duplicates to align with notebook logic and ensure unique SMILES for merging.
|
158 |
df.drop_duplicates(['canonical_smiles'], inplace=True)
|
159 |
df["standard_value"] = pd.to_numeric(df["standard_value"], errors='coerce')
|
160 |
df.dropna(subset=['standard_value'], inplace=True)
|
|
|
200 |
# === STEP 2: FEATURE ENGINEERING FUNCTIONS ===
|
201 |
# ==============================================================================
|
202 |
|
203 |
+
def create_molecule_grid_html(df, smiles_col='canonical_smiles', max_mols=20):
    """Build an HTML grid of molecule cards from the leading rows of ``df``.

    Only the first ``max_mols`` rows are inspected. Every row whose SMILES
    parses with RDKit becomes one card containing a 200x200 PNG depiction
    (embedded as a base64 data URI), the row's ``pIC50`` value formatted to
    two decimals, and the SMILES text. Rows that cannot be drawn are skipped.

    Args:
        df: DataFrame holding a SMILES column and a numeric ``pIC50`` column.
        smiles_col: Name of the column that contains the SMILES strings.
        max_mols: Maximum number of leading rows to inspect.

    Returns:
        One HTML string. When no row can be drawn, this is just the empty
        flex container.

    Raises:
        KeyError: If a processed row lacks ``smiles_col`` or ``pIC50``.
    """
    # Local import keeps this block self-contained; the SMILES text comes
    # from external data and must not be injected into the page verbatim.
    from html import escape

    html_parts = ['<div style="display: flex; flex-wrap: wrap; gap: 10px;">']
    for _, row in df.head(max_mols).iterrows():
        smiles = row[smiles_col]
        pic50 = row['pIC50']

        # Missing values surface as None/NaN; handing those to RDKit raises
        # instead of returning None, which would abort the whole grid.
        if not isinstance(smiles, str):
            continue

        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            continue

        # Render the depiction and embed it inline so the HTML is standalone.
        buffered = BytesIO()
        Draw.MolToImage(mol, size=(200, 200)).save(buffered, format="PNG")
        img_str = base64.b64encode(buffered.getvalue()).decode()

        html_parts.append(
            '<div style="border: 1px solid #ccc; padding: 10px; '
            'border-radius: 5px; text-align: center;">'
            f'<img src="data:image/png;base64,{img_str}" alt="Molecule" '
            'style="max-width: 200px;">'
            f'<p><strong>pIC50:</strong> {pic50:.2f}</p>'
            '<p style="font-size: 10px; word-break: break-all;">'
            f'{escape(smiles)}</p>'
            '</div>'
        )

    html_parts.append('</div>')
    return ''.join(html_parts)
|
227 |
+
|
228 |
def calculate_fingerprints(current_state, fingerprint_type, progress=gr.Progress()):
    """Compute PaDEL fingerprints for the cleaned dataset, streaming status.

    This is a generator used as a Gradio event handler. Every yield is a
    5-tuple ``(status message, fingerprint DataFrame or None, download-button
    update, molecule-grid HTML update or None, state dict)``.

    Args:
        current_state: Session state dict. ``'cleaned_data'`` is read; on
            success ``'fingerprint_data'`` and ``'fingerprint_type'`` are
            written into it.
        fingerprint_type: Key into ``fp_config`` selecting the descriptor XML.
        progress: Gradio progress tracker.

    Yields:
        Intermediate status tuples, then the final tuple with the merged
        fingerprint table and the molecule grid.

    Raises:
        gr.Error: If there is no cleaned data, no fingerprint type is chosen,
            the descriptor XML is unknown, PaDEL writes no output, or any
            other step fails.
    """
    # Local import keeps this block self-contained.
    import tempfile

    input_df = current_state.get('cleaned_data')
    if input_df is None or input_df.empty:
        raise gr.Error("No cleaned data found. Please complete Step 1.")
    if not fingerprint_type:
        raise gr.Error("Please select a fingerprint type.")

    progress(0, desc="Starting...")
    yield "🧪 Starting fingerprint calculation...", None, gr.update(visible=False), None, current_state

    # A private scratch directory per call: fixed file names in the working
    # directory would be overwritten or deleted by a concurrent session.
    # The directory and its contents are removed when the block exits.
    with tempfile.TemporaryDirectory(prefix='padel_fp_') as work_dir:
        smi_file = os.path.join(work_dir, 'molecule.smi')
        output_csv = os.path.join(work_dir, 'fingerprints.csv')
        try:
            # The SMILES string doubles as the molecule name so PaDEL's
            # 'Name' column can be merged back onto the input table.
            # NOTE(review): relies on SMILES being unique in cleaned_data.
            input_df[['canonical_smiles', 'canonical_smiles']].to_csv(
                smi_file, sep='\t', index=False, header=False
            )

            descriptortypes = fp_config.get(fingerprint_type)
            if not descriptortypes:
                raise gr.Error(f"Descriptor XML for '{fingerprint_type}' not found.")

            progress(0.3, desc="⚗️ Running PaDEL...")
            yield "⚗️ Running PaDEL...", None, gr.update(visible=False), None, current_state
            padeldescriptor(
                mol_dir=smi_file,
                d_file=output_csv,
                descriptortypes=descriptortypes,
                detectaromaticity=True,
                standardizenitro=True,
                standardizetautomers=True,
                threads=-1,
                removesalt=True,
                log=False,
                fingerprints=True,
            )
            if not os.path.exists(output_csv) or os.path.getsize(output_csv) == 0:
                raise gr.Error("PaDEL failed to produce an output file. Check molecule validity.")

            progress(0.7, desc="📊 Processing results...")
            yield "📊 Processing results...", None, gr.update(visible=False), None, current_state
            df_X = pd.read_csv(output_csv).rename(columns={'Name': 'canonical_smiles'})

            # Inner merge keeps only molecules present in the PaDEL output.
            final_df = pd.merge(
                input_df[['canonical_smiles', 'pIC50']],
                df_X,
                on='canonical_smiles',
                how='inner',
            )

            current_state['fingerprint_data'] = final_df
            current_state['fingerprint_type'] = fingerprint_type

            progress(0.9, desc="🖼️ Generating molecule grid...")
            mols_html = create_molecule_grid_html(final_df)
            success_msg = f"✅ Success! Generated {len(df_X.columns) -1} descriptors for {len(final_df)} molecules."
            progress(1, desc="Completed!")
            yield success_msg, final_df, gr.update(visible=True), gr.update(value=mols_html, visible=True), current_state
        except gr.Error:
            # Already a user-facing message; do not wrap it a second time.
            raise
        except Exception as e:
            raise gr.Error(f"Calculation failed: {e}") from e
|
|
|
|
|
|
|
265 |
|
266 |
# ==============================================================================
|
267 |
# === STEP 3: MODEL TRAINING & PREDICTION FUNCTIONS ===
|
|
|
300 |
X_test = pd.DataFrame(selector.transform(X_test), columns=X_test.columns[selector.get_support()], index=X_test.index)
|
301 |
selected_features = X_train.columns.tolist()
|
302 |
|
303 |
+
model_defs = [('Linear Regression', LinearRegression()), ('Ridge', Ridge(random_state=42)), ('Lasso', Lasso(random_state=42)), ('Random Forest', RandomForestRegressor(random_state=42, n_jobs=-1)), ('Gradient Boosting', GradientBoostingRegressor(random_state=42))]
|
304 |
+
if _has_extra_libs: model_defs.extend([('XGBoost', xgb.XGBRegressor(random_state=42, n_jobs=-1, verbosity=0)), ('LightGBM', lgb.LGBMRegressor(random_state=42, n_jobs=-1, verbosity=-1)), ('CatBoost', cb.CatBoostRegressor(random_state=42, verbose=0))])
|
305 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
306 |
results_list, trained_models = [], {}
|
307 |
for i, (name, model) in enumerate(model_defs):
|
308 |
progress(0.2 + (i / len(model_defs)) * 0.8, desc=f"Training {name}...")
|
|
|
314 |
results_df = pd.DataFrame(results_list).sort_values(by='R²', ascending=False).reset_index(drop=True)
|
315 |
plotter = ModelPlotter(trained_models, X_test, y_test)
|
316 |
model_run_results = ModelRunResult(results_df, plotter, trained_models, selected_features)
|
317 |
+
|
318 |
model_choices = results_df['Model'].tolist()
|
319 |
yield "✅ Model training & evaluation complete.", model_run_results, gr.Dropdown(choices=model_choices, interactive=True)
|
320 |
|
321 |
+
def create_prediction_grid_html(df, smiles_col='canonical_smiles', pred_col='predicted_pIC50', max_mols=20):
    """Build an HTML grid of molecule cards annotated with predicted pIC50.

    Only the first ``max_mols`` rows are inspected. A row is rendered when
    its prediction is not missing and its SMILES parses with RDKit; the card
    holds a 200x200 PNG depiction (base64 data URI), the prediction formatted
    to two decimals, and the SMILES text. All other rows are skipped.

    Args:
        df: DataFrame holding the SMILES and prediction columns.
        smiles_col: Name of the column that contains the SMILES strings.
        pred_col: Name of the column that contains the predicted values.
        max_mols: Maximum number of leading rows to inspect.

    Returns:
        One HTML string. When no row can be drawn, this is just the empty
        flex container.

    Raises:
        KeyError: If a processed row lacks ``smiles_col`` or ``pred_col``.
    """
    # Local import keeps this block self-contained; the SMILES text comes
    # from an uploaded file and must not be injected into the page verbatim.
    from html import escape

    html_parts = ['<div style="display: flex; flex-wrap: wrap; gap: 10px;">']
    for _, row in df.head(max_mols).iterrows():
        smiles = row[smiles_col]
        pred_pic50 = row[pred_col]

        # No prediction means there is nothing to show for this molecule.
        if pd.isna(pred_pic50):
            continue
        # Blank CSV cells arrive as NaN; RDKit raises on non-string input
        # rather than returning None, which would abort the whole grid.
        if not isinstance(smiles, str):
            continue

        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            continue

        # Render the depiction and embed it inline so the HTML is standalone.
        buffered = BytesIO()
        Draw.MolToImage(mol, size=(200, 200)).save(buffered, format="PNG")
        img_str = base64.b64encode(buffered.getvalue()).decode()

        html_parts.append(
            '<div style="border: 1px solid #ccc; padding: 10px; '
            'border-radius: 5px; text-align: center;">'
            f'<img src="data:image/png;base64,{img_str}" alt="Molecule" '
            'style="max-width: 200px;">'
            f'<p><strong>Predicted pIC50:</strong> {pred_pic50:.2f}</p>'
            '<p style="font-size: 10px; word-break: break-all;">'
            f'{escape(smiles)}</p>'
            '</div>'
        )

    html_parts.append('</div>')
    return ''.join(html_parts)
|
347 |
+
|
348 |
def predict_on_upload(uploaded_file, model_name, current_state, progress=gr.Progress()):
    """Predict pIC50 for molecules in an uploaded CSV, streaming status.

    This is a generator used as a Gradio event handler. Every yield is a
    3-tuple ``(status message, results DataFrame or None, grid HTML or None)``.

    Args:
        uploaded_file: The uploaded CSV, either as a path string or as an
            object exposing the path via ``.name``. The CSV must contain a
            ``canonical_smiles`` column.
        model_name: Key of the trained model to use.
        current_state: Session state dict; ``'model_results'`` and
            ``'fingerprint_type'`` are read.
        progress: Gradio progress tracker.

    Yields:
        Intermediate status tuples, then the final tuple with a
        ``canonical_smiles`` / ``predicted_pIC50`` table and the grid HTML.

    Raises:
        gr.Error: If inputs or earlier steps are missing, the CSV lacks the
            SMILES column, the descriptor XML is unknown, or PaDEL yields no
            usable fingerprints.
    """
    # Local import keeps this block self-contained.
    import tempfile

    if not uploaded_file:
        raise gr.Error("Please upload a file.")
    if not model_name:
        raise gr.Error("Please select a trained model.")
    model_run_results = current_state.get('model_results')
    fingerprint_type = current_state.get('fingerprint_type')
    if not model_run_results or not fingerprint_type:
        raise gr.Error("Please run Steps 2 and 3 first.")

    model = model_run_results.models.get(model_name)
    selected_features = model_run_results.selected_features
    if model is None:
        raise gr.Error(f"Model '{model_name}' not found.")

    # Without this check a missing XML would reach PaDEL as None.
    descriptortypes = fp_config.get(fingerprint_type)
    if not descriptortypes:
        raise gr.Error(f"Descriptor XML for '{fingerprint_type}' not found.")

    # The upload may be a plain path or a file-like wrapper with ``.name``.
    upload_path = getattr(uploaded_file, 'name', uploaded_file)

    # A private scratch directory per call: fixed file names in the working
    # directory would be overwritten or deleted by a concurrent session.
    # The directory and its contents are removed when the block exits.
    with tempfile.TemporaryDirectory(prefix='padel_predict_') as work_dir:
        smi_file = os.path.join(work_dir, 'predict.smi')
        output_csv = os.path.join(work_dir, 'predict_fp.csv')

        progress(0, desc="Reading & processing new molecules...")
        yield "Reading uploaded file...", None, None
        df_new = pd.read_csv(upload_path)
        if 'canonical_smiles' not in df_new.columns:
            raise gr.Error("CSV must contain a 'canonical_smiles' column.")
        # The row position becomes the molecule name handed to PaDEL, so
        # results can be joined back even when SMILES repeat.
        df_new = df_new.reset_index().rename(columns={'index': 'mol_id'})

        padel_input = pd.DataFrame({'smiles': df_new['canonical_smiles'], 'name': df_new['mol_id']})
        padel_input.to_csv(smi_file, sep='\t', index=False, header=False)

        progress(0.3, desc="Calculating fingerprints...")
        yield "Calculating fingerprints for new molecules...", None, None
        # Same PaDEL settings as the training fingerprints, including
        # standardizetautomers, so the features are comparable.
        padeldescriptor(
            mol_dir=smi_file,
            d_file=output_csv,
            descriptortypes=descriptortypes,
            detectaromaticity=True,
            standardizenitro=True,
            standardizetautomers=True,
            threads=-1,
            removesalt=True,
            log=False,
            fingerprints=True,
        )
        if not os.path.exists(output_csv) or os.path.getsize(output_csv) == 0:
            raise gr.Error("PaDEL calculation failed for the uploaded molecules.")

        progress(0.7, desc="Aligning features and predicting...")
        yield "Aligning features and predicting...", None, None
        df_fp = pd.read_csv(output_csv).rename(columns={'Name': 'mol_id'})

        # Keep exactly the features the model was trained on; columns that
        # are absent from the new output are filled with 0.
        X_new = df_fp.set_index('mol_id')
        X_new_aligned = X_new.reindex(columns=selected_features, fill_value=0)
        # Rows with missing feature values cannot be scored reliably; drop
        # them so they end up with a missing prediction after the left merge.
        # NOTE(review): presumably these are molecules PaDEL could not
        # process - confirm against real PaDEL output.
        X_new_aligned = X_new_aligned.dropna(axis=0, how='any')
        if X_new_aligned.empty:
            raise gr.Error("PaDEL calculation failed for the uploaded molecules.")

        predictions = model.predict(X_new_aligned)

        results_subset = pd.DataFrame({'mol_id': X_new_aligned.index, 'predicted_pIC50': predictions})
        df_results = pd.merge(df_new, results_subset, on='mol_id', how='left')

        progress(0.9, desc="Generating visualization...")
        yield "Generating visualization...", None, None

        # Only molecules that actually received a prediction are drawn.
        df_grid_view = df_results.dropna(subset=['predicted_pIC50']).copy()
        mols_html = "<h3>No molecules with successful predictions to display.</h3>"
        if not df_grid_view.empty:
            mols_html = create_prediction_grid_html(df_grid_view)

        progress(1, desc="Complete!")
        yield "✅ Prediction complete.", df_results[['canonical_smiles', 'predicted_pIC50']], mols_html
|
399 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
400 |
# ==============================================================================
|
401 |
# === GRADIO INTERFACE ===
|
402 |
# ==============================================================================
|
|
|
406 |
app_state = gr.State({})
|
407 |
with gr.Tabs():
|
408 |
with gr.Tab("Step 1: Data Collection & EDA"):
|
409 |
+
# UI Definition for Step 1...
|
410 |
gr.Markdown("## Fetch Bioactivity Data from ChEMBL and Perform Exploratory Analysis")
|
411 |
with gr.Row():
|
412 |
query_input = gr.Textbox(label="Target Query", placeholder="e.g., acetylcholinesterase, BRAF kinase", scale=3)
|
|
|
445 |
hdonors_stats_output = gr.Dataframe(label="Stats for H-Donors")
|
446 |
hacceptors_stats_output = gr.Dataframe(label="Stats for H-Acceptors")
|
447 |
with gr.Tab("Step 2: Feature Engineering"):
|
448 |
+
# UI Definition for Step 2...
|
449 |
gr.Markdown("## Calculate Molecular Fingerprints using PaDEL")
|
450 |
with gr.Row():
|
451 |
fingerprint_dropdown = gr.Dropdown(choices=FP_list, value='PubChem' if 'PubChem' in FP_list else None, label="Select Fingerprint Method", scale=3)
|
|
|
455 |
download_s2 = gr.DownloadButton("Download Feature Data (CSV)", variant="secondary", visible=False)
|
456 |
mols_grid_s2 = gr.HTML(label="Interactive Molecule Viewer")
|
457 |
with gr.Tab("Step 3: Model Training & Prediction"):
|
458 |
+
# UI Definition for Step 3...
|
459 |
gr.Markdown("## Train Regression Models and Predict pIC50")
|
460 |
with gr.Tabs():
|
461 |
with gr.Tab("Model Training & Evaluation"):
|
|
|
515 |
process_btn.click(fn=process_and_analyze_wrapper, inputs=[selected_target_dropdown, bioactivity_class_selector, app_state], outputs=[df_output_s1, freq_plot_output, scatter_plot_output, pic50_plot_output, pic50_stats_output, mw_plot_output, mw_stats_output, logp_plot_output, logp_stats_output, hdonors_plot_output, hdonors_stats_output, hacceptors_plot_output, hacceptors_stats_output, status_step1_process, app_state])
|
516 |
bioactivity_class_selector.change(fn=update_analysis_on_filter_change, inputs=[bioactivity_class_selector, app_state], outputs=[df_output_s1, freq_plot_output, scatter_plot_output, pic50_plot_output, pic50_stats_output, mw_plot_output, mw_stats_output, logp_plot_output, logp_stats_output, hdonors_plot_output, hdonors_stats_output, hacceptors_plot_output, hacceptors_stats_output, status_step1_process], show_progress="minimal")
|
517 |
calculate_fp_btn.click(fn=calculate_fingerprints, inputs=[app_state, fingerprint_dropdown], outputs=[status_step2, output_df_s2, download_s2, mols_grid_s2, app_state])
|
518 |
+
# The download button click handler was incorrect; it should take the DataFrame from the state.
|
519 |
@download_s2.click(inputs=app_state, outputs=download_s2, show_progress="hidden")
|
520 |
def download_handler(current_state):
|
521 |
df_to_download = current_state.get('fingerprint_data')
|
522 |
return save_dataframe_as_csv(df_to_download)
|
|
|
523 |
train_models_btn.click(fn=handle_model_training, inputs=[app_state], outputs=[status_step3_train, model_results_df, model_selector_s3, app_state])
|
524 |
for listener in [model_selector_s3.change, feature_count_s3.change]: listener(fn=update_analysis_plots, inputs=[model_selector_s3, feature_count_s3, app_state], outputs=[validation_plot_s3, feature_plot_s3], show_progress="minimal")
|
525 |
predict_btn_s3.click(fn=predict_on_upload, inputs=[upload_predict_file, model_selector_s3, app_state], outputs=[status_step3_predict, prediction_results_df, prediction_mols_grid])
|