Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,70 +1,69 @@
|
|
| 1 |
# --- IMPORTS ---
|
| 2 |
# Core and Data Handling
|
| 3 |
-
import gradio as gr
|
| 4 |
-
import pandas as pd
|
| 5 |
-
import numpy as np
|
| 6 |
-
import os
|
| 7 |
-
import glob
|
| 8 |
-
import time
|
| 9 |
-
import warnings
|
| 10 |
|
| 11 |
# Chemistry and Cheminformatics
|
| 12 |
-
from rdkit import Chem
|
| 13 |
-
from rdkit.Chem import Descriptors, Lipinski
|
| 14 |
-
from chembl_webresource_client.new_client import new_client
|
| 15 |
-
from padelpy import padeldescriptor
|
| 16 |
-
import mols2grid #
|
| 17 |
|
| 18 |
# Plotting and Visualization
|
| 19 |
-
import matplotlib.pyplot as plt
|
| 20 |
-
import seaborn as sns
|
| 21 |
-
from scipy import stats
|
| 22 |
-
from scipy.stats import mannwhitneyu
|
| 23 |
|
| 24 |
# Machine Learning Models and Metrics
|
| 25 |
-
from sklearn.model_selection import train_test_split
|
| 26 |
-
from sklearn.feature_selection import VarianceThreshold
|
| 27 |
-
from sklearn.linear_model import (
|
| 28 |
-
LinearRegression, Ridge, Lasso, ElasticNet, BayesianRidge,
|
| 29 |
-
HuberRegressor, PassiveAggressiveRegressor, OrthogonalMatchingPursuit,
|
| 30 |
-
LassoLars
|
| 31 |
)
|
| 32 |
-
from sklearn.tree import DecisionTreeRegressor
|
| 33 |
-
from sklearn.ensemble import (
|
| 34 |
-
RandomForestRegressor, GradientBoostingRegressor, ExtraTreesRegressor,
|
| 35 |
-
AdaBoostRegressor
|
| 36 |
)
|
| 37 |
-
from sklearn.neighbors import KNeighborsRegressor
|
| 38 |
-
from sklearn.dummy import DummyRegressor
|
| 39 |
-
from sklearn.metrics import (
|
| 40 |
-
mean_absolute_error, mean_squared_error, r2_score
|
| 41 |
)
|
| 42 |
|
| 43 |
# A placeholder class to store all results from a modeling run
|
| 44 |
-
class ModelRunResult:
    """Container bundling everything produced by a single modeling run.

    The class performs no validation; each argument is stored as given.

    Attributes:
        dataframe: Results object for the run (presumably a pandas
            DataFrame of model metrics -- confirm against the caller).
        plotter: Plotting helper associated with the run (type not visible
            here -- confirm against the caller).
        models: Mapping of model name to fitted model; consumers look
            entries up with ``models.get(name)``.
        selected_features: Feature names the models were trained on.
    """

    def __init__(self, dataframe, plotter, models, selected_features):
        self.dataframe = dataframe
        self.plotter = plotter
        self.models = models
        self.selected_features = selected_features

    def __repr__(self):
        # Show only type names: the payloads can be large (data frames, models).
        return (
            f"{type(self).__name__}("
            f"dataframe={type(self.dataframe).__name__}, "
            f"plotter={type(self.plotter).__name__}, "
            f"models={type(self.models).__name__}, "
            f"selected_features={type(self.selected_features).__name__})"
        )
|
| 50 |
|
| 51 |
# Optional Advanced Models
|
| 52 |
-
try:
|
| 53 |
-
import xgboost as xgb
|
| 54 |
-
import lightgbm as lgb
|
| 55 |
-
import catboost as cb
|
| 56 |
-
_has_extra_libs = True
|
| 57 |
-
except ImportError:
|
| 58 |
-
_has_extra_libs = False
|
| 59 |
-
warnings.warn("Optional libraries (xgboost, lightgbm, catboost) not found. Some models will be unavailable.")
|
| 60 |
|
| 61 |
# --- GLOBAL CONFIGURATION & SETUP ---
|
| 62 |
-
warnings.filterwarnings("ignore")
|
| 63 |
-
sns.set_theme(style='whitegrid')
|
| 64 |
|
| 65 |
# --- FINGERPRINT CONFIGURATION ---
|
| 66 |
DESCRIPTOR_DIR = "padel_descriptors"
|
| 67 |
-
|
| 68 |
# Check if the descriptor directory exists and contains files
|
| 69 |
if not os.path.isdir(DESCRIPTOR_DIR):
|
| 70 |
warnings.warn(
|
|
@@ -74,411 +73,893 @@ if not os.path.isdir(DESCRIPTOR_DIR):
|
|
| 74 |
xml_files = []
|
| 75 |
else:
|
| 76 |
xml_files = sorted(glob.glob(os.path.join(DESCRIPTOR_DIR, '*.xml')))
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
)
|
| 83 |
|
| 84 |
# The key is the filename without extension; the value is the full path to the file
|
| 85 |
# Map each fingerprint name (descriptor file name without its extension) to the
# full path of that descriptor XML file.
fp_config = {os.path.splitext(os.path.basename(path))[0]: path for path in xml_files}
# Iterating a dict yields its keys, so sorted() needs no list()/.keys() wrapper.
FP_list = sorted(fp_config)
|
| 87 |
|
| 88 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 89 |
# ==============================================================================
|
| 90 |
# === STEP 1: CORE DATA COLLECTION & EDA FUNCTIONS ===
|
| 91 |
# ==============================================================================
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
df
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
df
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
df_copy =
|
| 117 |
-
df_copy
|
| 118 |
-
df_copy
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
molar
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
df_copy['
|
| 128 |
-
|
| 129 |
-
lambda x: "inactive" if pd.notna(x) and x >= 10000 else ("active" if pd.notna(x) and x <= 1000 else "intermediate") #
|
| 130 |
)
|
| 131 |
-
return df_copy.drop(columns=['standard_value', 'standard_value_norm'])
|
| 132 |
-
|
| 133 |
-
def lipinski_descriptors(smiles_series):
    """Compute four Lipinski-style descriptors for every parseable SMILES.

    Args:
        smiles_series: Iterable of candidate SMILES. Entries that are not
            non-empty strings, or that RDKit cannot parse, are skipped.

    Returns:
        A ``(descriptors, valid_smiles)`` tuple. ``descriptors`` is a
        DataFrame with columns ``MW``, ``LogP``, ``NumHDonors`` and
        ``NumHAcceptors``, one row per parsed molecule in input order.
        ``valid_smiles`` lists the SMILES strings those rows came from.
        With nothing parseable, the frame is empty (columns only) and the
        list is empty.
    """
    column_names = ["MW", "LogP", "NumHDonors", "NumHAcceptors"]
    descriptor_rows, valid_smiles = [], []

    for smiles in smiles_series:
        # Type check first so truthiness is only evaluated on real strings.
        if not (isinstance(smiles, str) and smiles):
            continue
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:  # RDKit returns None for unparseable SMILES
            continue
        descriptor_rows.append([
            Descriptors.MolWt(mol),
            Descriptors.MolLogP(mol),
            Lipinski.NumHDonors(mol),
            Lipinski.NumHAcceptors(mol),
        ])
        valid_smiles.append(smiles)

    if not descriptor_rows:
        return pd.DataFrame(columns=column_names), []
    return pd.DataFrame(data=np.array(descriptor_rows), columns=column_names), valid_smiles
|
| 148 |
-
|
| 149 |
-
def clean_and_process_data(df):
    """Clean raw bioactivity rows and append pIC50 and Lipinski descriptors.

    Steps: back-fill ``canonical_smiles`` from ChEMBL when the column is
    missing or entirely null, coerce ``standard_value`` to numeric, drop rows
    lacking a value or a SMILES, keep one row per SMILES, compute pIC50 and
    the bioactivity class via ``pIC50_calc``, then join the descriptors from
    ``lipinski_descriptors``.

    The caller's frame is never modified. Values are coerced to numeric
    *before* de-duplication, so a non-numeric first measurement cannot shadow
    a valid later measurement of the same compound.

    Args:
        df: Raw bioactivity DataFrame with at least ``standard_value`` and
            either ``canonical_smiles`` or ``molecule_chembl_id``.

    Returns:
        ``(dataframe, message)``. The frame is empty when nothing survives
        cleaning; the message describes the outcome.

    Raises:
        gr.Error: If ``df`` is None/empty or the SMILES lookup fails.
    """
    if df is None or df.empty:
        raise gr.Error("No data to process. Please fetch data first.")

    # Work on a private copy: later steps add and overwrite columns.
    df = df.copy()

    if "canonical_smiles" not in df.columns or df["canonical_smiles"].isnull().all():
        try:
            chembl_ids = df["molecule_chembl_id"].dropna().unique().tolist()
            # Join by id rather than by position: the lookup returns one record
            # per molecule, while df may repeat a molecule across many rows.
            # "molecule_structures" can be present but None, hence the `or {}`.
            smiles_by_id = {
                record.get("molecule_chembl_id"):
                    (record.get("molecule_structures") or {}).get("canonical_smiles")
                for record in new_client.molecule.get(chembl_ids)
            }
            df["canonical_smiles"] = df["molecule_chembl_id"].map(smiles_by_id)
        except Exception as e:
            raise gr.Error(f"Could not fetch SMILES from ChEMBL: {e}") from e

    df["standard_value"] = pd.to_numeric(df["standard_value"], errors='coerce')
    df = df.dropna(subset=["standard_value", "canonical_smiles"])
    df = df.drop_duplicates(subset=["canonical_smiles"])

    df_processed = pIC50_calc(df)
    df_processed = df_processed[df_processed.pIC50.notna()]
    if df_processed.empty:
        return pd.DataFrame(), "No compounds remaining after pIC50 calculation."

    df_lipinski, valid_smiles = lipinski_descriptors(df_processed['canonical_smiles'])
    if not valid_smiles:
        return pd.DataFrame(), "No valid SMILES could be processed for Lipinski descriptors."

    # SMILES are unique here, so filtering keeps rows in the same order as the
    # descriptor rows and the positional concat below lines up one-to-one.
    keep = df_processed['canonical_smiles'].isin(set(valid_smiles))
    df_processed = df_processed[keep].reset_index(drop=True)
    df_final = pd.concat([df_processed, df_lipinski.reset_index(drop=True)], axis=1)
    return df_final, f"Processing complete. {len(df_final)} compounds remain after cleaning."
|
| 170 |
-
|
| 171 |
-
def run_eda_analysis(df, selected_classes):
    """Build every EDA plot and statistics table for the chosen classes.

    Args:
        df: Processed DataFrame containing ``bioactivity_class`` plus the
            descriptor columns plotted below.
        selected_classes: Bioactivity class labels to keep. ``None`` is
            treated as an empty selection.

    Returns:
        A 13-tuple in fixed order: frequency plot, scatter plot, then a
        (box plot, Mann-Whitney table) pair for each of pIC50, MW, LogP,
        NumHDonors and NumHAcceptors, and finally a status message. When no
        rows match, plots are ``None`` and tables are empty DataFrames.

    Raises:
        gr.Error: If ``df`` is None or empty.
    """
    if df is None or df.empty:
        raise gr.Error("No data available for analysis.")

    descriptors = ['pIC50', 'MW', 'LogP', 'NumHDonors', 'NumHAcceptors']
    # isin() rejects None, so fall back to an empty selection.
    df_filtered = df[df.bioactivity_class.isin(selected_classes or [])].copy()

    if df_filtered.empty:
        outputs = [None, None]
        for _ in descriptors:
            outputs.extend([None, pd.DataFrame()])
        return (*outputs, "No data for selected classes.")

    outputs = [create_frequency_plot(df_filtered), create_scatter_plot(df_filtered)]
    for desc in descriptors:
        outputs.append(create_boxplot(df_filtered, desc))
        outputs.append(mannwhitney_test(df_filtered, desc))

    # Release pyplot's global figure registry; the Figure objects returned
    # above stay usable.
    plt.close('all')
    return (*outputs, f"EDA complete for {len(df_filtered)} compounds.")
|
| 182 |
-
|
| 183 |
-
def create_frequency_plot(df):
    """Return a bar chart Figure of compound counts per bioactivity class.

    Args:
        df: DataFrame with a ``bioactivity_class`` column.

    Returns:
        The matplotlib Figure holding the chart.
    """
    palette = {'active': '#1f77b4', 'inactive': '#ff7f0e', 'intermediate': '#2ca02c'}
    counts = df['bioactivity_class'].value_counts()  # computed once, used for x and y

    fig, ax = plt.subplots(figsize=(5.5, 5.5))
    # seaborn deprecates `palette` without `hue`; colouring by the x variable
    # through `hue` gives the same per-class colours.
    sns.barplot(x=counts.index, y=counts.values, hue=counts.index,
                dodge=False, palette=palette, ax=ax)
    legend = ax.get_legend()
    if legend is not None:
        legend.remove()  # the legend would only repeat the x tick labels

    ax.set_xlabel('Bioactivity Class', fontsize=12)
    ax.set_ylabel('Frequency', fontsize=12)
    ax.set_title('Frequency of Bioactivity Classes', fontsize=14)
    return fig
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
def
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
|
|
|
|
|
|
|
|
|
| 198 |
|
| 199 |
# ==============================================================================
|
| 200 |
# === STEP 2: FEATURE ENGINEERING FUNCTIONS ===
|
| 201 |
# ==============================================================================
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
input_df
|
| 205 |
-
|
| 206 |
-
if not fingerprint_type:
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 210 |
|
| 211 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 212 |
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
if not descriptortypes: raise gr.Error(f"Descriptor XML for '{fingerprint_type}' not found.") #
|
| 216 |
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 221 |
|
| 222 |
-
progress(0.7, desc="📊 Processing results...")
|
| 223 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 224 |
|
| 225 |
-
|
|
|
|
|
|
|
| 226 |
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
if os.path.exists('molecule.smi'): os.remove('molecule.smi') #
|
| 235 |
-
if os.path.exists('fingerprints.csv'): os.remove('fingerprints.csv') #
|
| 236 |
|
| 237 |
# ==============================================================================
|
| 238 |
-
# === STEP 3:
|
| 239 |
# ==============================================================================
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
|
| 264 |
-
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
|
| 272 |
-
selected_features = X_train.columns.tolist() #
|
| 273 |
-
|
| 274 |
-
model_defs = [
|
| 275 |
-
('Linear Regression', LinearRegression()),
|
| 276 |
-
('Ridge', Ridge(random_state=42)),
|
| 277 |
-
('Lasso', Lasso(random_state=42)),
|
| 278 |
-
('Random Forest', RandomForestRegressor(random_state=42, n_jobs=-1)),
|
| 279 |
-
# ('Gradient Boosting', GradientBoostingRegressor(random_state=42)) # <-- Commented out
|
| 280 |
-
]
|
| 281 |
-
if _has_extra_libs:
|
| 282 |
-
model_defs.extend([
|
| 283 |
-
# ('XGBoost', xgb.XGBRegressor(random_state=42, n_jobs=-1, verbosity=0)), # <-- Commented out
|
| 284 |
-
('LightGBM', lgb.LGBMRegressor(random_state=42, n_jobs=-1, verbosity=-1)),
|
| 285 |
-
# ('CatBoost', cb.CatBoostRegressor(random_state=42, verbose=0)) # <-- Commented out
|
| 286 |
-
])
|
| 287 |
|
| 288 |
-
|
| 289 |
-
|
| 290 |
-
|
| 291 |
-
|
| 292 |
-
|
| 293 |
-
|
| 294 |
-
|
| 295 |
-
|
| 296 |
-
|
| 297 |
-
|
| 298 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 299 |
|
| 300 |
-
|
| 301 |
-
|
| 302 |
-
|
| 303 |
-
|
| 304 |
-
if not uploaded_file: raise gr.Error("Please upload a file.") #
|
| 305 |
-
if not model_name: raise gr.Error("Please select a trained model.") #
|
| 306 |
-
model_run_results = current_state.get('model_results') #
|
| 307 |
-
fingerprint_type = current_state.get('fingerprint_type') #
|
| 308 |
-
if not model_run_results or not fingerprint_type: raise gr.Error("Please run Steps 2 and 3 first.") #
|
| 309 |
|
| 310 |
-
model = model_run_results.models.get(model_name)
|
| 311 |
-
selected_features = model_run_results.selected_features
|
| 312 |
-
if model is None:
|
|
|
|
| 313 |
|
| 314 |
-
smi_file, output_csv = 'predict.smi', 'predict_fp.csv'
|
| 315 |
-
try:
|
| 316 |
-
progress(0, desc="Reading & processing new molecules...")
|
| 317 |
-
|
| 318 |
-
if 'canonical_smiles' not in df_new.columns: raise gr.Error("CSV must contain a 'canonical_smiles' column.") #
|
| 319 |
-
df_new = df_new.reset_index().rename(columns={'index': 'mol_id'}) #
|
| 320 |
|
| 321 |
-
|
| 322 |
-
|
| 323 |
-
|
|
|
|
| 324 |
|
| 325 |
-
|
| 326 |
-
|
| 327 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 328 |
|
| 329 |
-
progress(0.
|
| 330 |
-
|
| 331 |
|
| 332 |
-
|
| 333 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 334 |
|
| 335 |
-
|
|
|
|
| 336 |
|
| 337 |
-
|
| 338 |
-
|
| 339 |
-
|
| 340 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 341 |
|
| 342 |
-
df_grid_view = df_results.dropna(subset=['predicted_pIC50']).copy()
|
| 343 |
-
mols_html = "<h3>No molecules with successful predictions to display.</h3>"
|
| 344 |
-
if not df_grid_view.empty: #
|
| 345 |
-
df_grid_view.rename(columns={"predicted_pIC50": "Predicted pIC50"}, inplace=True) #
|
| 346 |
-
mols_html = mols2grid.display( #
|
| 347 |
-
df_grid_view, #
|
| 348 |
-
smiles_col='canonical_smiles', #
|
| 349 |
-
subset=['img', 'Predicted pIC50'], #
|
| 350 |
-
transform={"Predicted pIC50": lambda x: f"{x:.2f}"} #
|
| 351 |
-
)._repr_html_() #
|
| 352 |
|
| 353 |
-
|
| 354 |
-
|
| 355 |
-
|
| 356 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 357 |
|
| 358 |
# ==============================================================================
|
| 359 |
-
# === GRADIO INTERFACE ===
|
| 360 |
# ==============================================================================
|
| 361 |
-
|
| 362 |
-
|
| 363 |
-
gr.
|
| 364 |
-
|
| 365 |
-
|
| 366 |
-
|
| 367 |
-
|
| 368 |
-
|
| 369 |
-
|
| 370 |
-
|
| 371 |
-
|
| 372 |
-
|
| 373 |
-
|
| 374 |
-
|
| 375 |
-
|
| 376 |
-
|
| 377 |
-
|
| 378 |
-
|
| 379 |
-
|
| 380 |
-
|
| 381 |
-
|
| 382 |
-
|
| 383 |
-
|
| 384 |
-
|
| 385 |
-
|
| 386 |
-
|
| 387 |
-
|
| 388 |
-
|
| 389 |
-
|
| 390 |
-
|
| 391 |
-
|
| 392 |
-
|
| 393 |
-
|
| 394 |
-
|
| 395 |
-
|
| 396 |
-
|
| 397 |
-
|
| 398 |
-
|
| 399 |
-
|
| 400 |
-
|
| 401 |
-
|
| 402 |
-
|
| 403 |
-
|
| 404 |
-
|
| 405 |
-
|
| 406 |
-
|
| 407 |
-
|
| 408 |
-
|
| 409 |
-
|
| 410 |
-
|
| 411 |
-
|
| 412 |
-
|
| 413 |
-
|
| 414 |
-
|
| 415 |
-
|
| 416 |
-
|
| 417 |
-
|
| 418 |
-
|
| 419 |
-
|
| 420 |
-
|
| 421 |
-
|
| 422 |
-
|
| 423 |
-
|
| 424 |
-
|
| 425 |
-
|
| 426 |
-
|
| 427 |
-
gr.
|
| 428 |
-
|
| 429 |
-
|
| 430 |
-
|
| 431 |
-
|
| 432 |
-
|
| 433 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 434 |
|
| 435 |
# --- EVENT HANDLERS ---
|
| 436 |
-
|
| 437 |
-
|
| 438 |
-
|
| 439 |
-
|
| 440 |
-
|
| 441 |
-
|
| 442 |
-
|
| 443 |
-
|
| 444 |
-
|
| 445 |
-
|
| 446 |
-
|
| 447 |
-
|
| 448 |
-
|
| 449 |
-
|
| 450 |
-
|
| 451 |
-
|
| 452 |
-
|
| 453 |
-
|
| 454 |
-
|
| 455 |
-
|
| 456 |
-
|
| 457 |
-
|
| 458 |
-
|
| 459 |
-
|
| 460 |
-
|
| 461 |
-
|
| 462 |
-
|
| 463 |
-
|
| 464 |
-
|
| 465 |
-
|
| 466 |
-
|
| 467 |
-
|
| 468 |
-
|
| 469 |
-
|
| 470 |
-
|
| 471 |
-
|
| 472 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 473 |
|
| 474 |
-
|
| 475 |
-
|
| 476 |
-
|
| 477 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 478 |
|
| 479 |
-
|
| 480 |
-
|
| 481 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 482 |
|
| 483 |
-
|
| 484 |
-
demo.launch(debug=True) #
|
|
|
|
| 1 |
# --- IMPORTS ---
|
| 2 |
# Core and Data Handling
|
| 3 |
+
import gradio as gr
|
| 4 |
+
import pandas as pd
|
| 5 |
+
import numpy as np
|
| 6 |
+
import os
|
| 7 |
+
import glob
|
| 8 |
+
import time
|
| 9 |
+
import warnings
|
| 10 |
|
| 11 |
# Chemistry and Cheminformatics
|
| 12 |
+
from rdkit import Chem
|
| 13 |
+
from rdkit.Chem import Descriptors, Lipinski
|
| 14 |
+
from chembl_webresource_client.new_client import new_client
|
| 15 |
+
from padelpy import padeldescriptor
|
| 16 |
+
# import mols2grid # This line will be removed
|
| 17 |
|
| 18 |
# Plotting and Visualization
|
| 19 |
+
import matplotlib.pyplot as plt
|
| 20 |
+
import seaborn as sns
|
| 21 |
+
from scipy import stats
|
| 22 |
+
from scipy.stats import mannwhitneyu
|
| 23 |
|
| 24 |
# Machine Learning Models and Metrics
|
| 25 |
+
from sklearn.model_selection import train_test_split
|
| 26 |
+
from sklearn.feature_selection import VarianceThreshold
|
| 27 |
+
from sklearn.linear_model import (
|
| 28 |
+
LinearRegression, Ridge, Lasso, ElasticNet, BayesianRidge,
|
| 29 |
+
HuberRegressor, PassiveAggressiveRegressor, OrthogonalMatchingPursuit,
|
| 30 |
+
LassoLars
|
| 31 |
)
|
| 32 |
+
from sklearn.tree import DecisionTreeRegressor
|
| 33 |
+
from sklearn.ensemble import (
|
| 34 |
+
RandomForestRegressor, GradientBoostingRegressor, ExtraTreesRegressor,
|
| 35 |
+
AdaBoostRegressor
|
| 36 |
)
|
| 37 |
+
from sklearn.neighbors import KNeighborsRegressor
|
| 38 |
+
from sklearn.dummy import DummyRegressor
|
| 39 |
+
from sklearn.metrics import (
|
| 40 |
+
mean_absolute_error, mean_squared_error, r2_score
|
| 41 |
)
|
| 42 |
|
| 43 |
# A placeholder class to store all results from a modeling run
|
| 44 |
+
class ModelRunResult:
    """Container bundling everything produced by a single modeling run.

    The class performs no validation; each argument is stored as given.

    Attributes:
        dataframe: Results object for the run (presumably a pandas
            DataFrame of model metrics -- confirm against the caller).
        plotter: Plotting helper associated with the run (type not visible
            here -- confirm against the caller).
        models: Mapping of model name to fitted model; consumers look
            entries up with ``models.get(name)``.
        selected_features: Feature names the models were trained on.
    """

    def __init__(self, dataframe, plotter, models, selected_features):
        self.dataframe = dataframe
        self.plotter = plotter
        self.models = models
        self.selected_features = selected_features

    def __repr__(self):
        # Show only type names: the payloads can be large (data frames, models).
        return (
            f"{type(self).__name__}("
            f"dataframe={type(self.dataframe).__name__}, "
            f"plotter={type(self.plotter).__name__}, "
            f"models={type(self.models).__name__}, "
            f"selected_features={type(self.selected_features).__name__})"
        )
|
| 50 |
|
| 51 |
# Optional Advanced Models
|
| 52 |
+
try:
|
| 53 |
+
import xgboost as xgb
|
| 54 |
+
import lightgbm as lgb
|
| 55 |
+
import catboost as cb
|
| 56 |
+
_has_extra_libs = True
|
| 57 |
+
except ImportError:
|
| 58 |
+
_has_extra_libs = False
|
| 59 |
+
warnings.warn("Optional libraries (xgboost, lightgbm, catboost) not found. Some models will be unavailable.")
|
| 60 |
|
| 61 |
# --- GLOBAL CONFIGURATION & SETUP ---
|
| 62 |
+
warnings.filterwarnings("ignore")
|
| 63 |
+
sns.set_theme(style='whitegrid')
|
| 64 |
|
| 65 |
# --- FINGERPRINT CONFIGURATION ---
|
| 66 |
DESCRIPTOR_DIR = "padel_descriptors"
|
|
|
|
| 67 |
# Check if the descriptor directory exists and contains files
|
| 68 |
if not os.path.isdir(DESCRIPTOR_DIR):
|
| 69 |
warnings.warn(
|
|
|
|
| 73 |
xml_files = []
|
| 74 |
else:
|
| 75 |
xml_files = sorted(glob.glob(os.path.join(DESCRIPTOR_DIR, '*.xml')))
|
| 76 |
+
if not xml_files:
|
| 77 |
+
warnings.warn(
|
| 78 |
+
f"No descriptor .xml files found in the '{DESCRIPTOR_DIR}' directory. "
|
| 79 |
+
"Fingerprint calculation will not be possible."
|
| 80 |
+
)
|
|
|
|
| 81 |
|
| 82 |
# The key is the filename without extension; the value is the full path to the file
|
| 83 |
# Map each fingerprint name (descriptor file name without its extension) to the
# full path of that descriptor XML file.
fp_config = {os.path.splitext(os.path.basename(path))[0]: path for path in xml_files}
# Iterating a dict yields its keys, so sorted() needs no list()/.keys() wrapper.
FP_list = sorted(fp_config)
|
| 85 |
|
| 86 |
|
| 87 |
+
# --- NEW MOLECULE DISPLAY FUNCTIONS ---
|
| 88 |
+
def create_molecule_grid_html(df, smiles_col='canonical_smiles', additional_cols=None, max_mols=50):
    """Render molecules as an HTML grid of cards with inline PNG depictions.

    Images are drawn with RDKit and embedded as base64 data URIs, so the
    markup is self-contained and needs no JavaScript.

    Args:
        df: DataFrame with one molecule per row. ``None`` or an empty frame
            yields a placeholder message.
        smiles_col: Name of the column holding SMILES strings.
        additional_cols: Optional column names to list under each structure.
            Missing columns and null values are skipped; floats are shown
            with two decimals.
        max_mols: Only the first ``max_mols`` rows are rendered.

    Returns:
        An HTML string: the grid, or an ``<h3>`` message when there is
        nothing to show. Rows whose SMILES cannot be parsed or drawn are
        skipped (failures are printed, not raised).
    """
    if df is None or df.empty:
        return "<h3>No molecules to display.</h3>"

    # Imported after the empty check so the trivial path does not need RDKit.
    import base64
    from html import escape
    from io import BytesIO

    from rdkit import Chem
    from rdkit.Chem import Draw

    extra_cols = list(additional_cols) if additional_cols is not None else []
    cards = []

    for idx, row in df.head(max_mols).iterrows():
        try:
            smiles = row[smiles_col]
            mol = Chem.MolFromSmiles(smiles)
            if mol is None:
                continue

            buffer = BytesIO()
            Draw.MolToImage(mol, size=(250, 200)).save(buffer, format='PNG')
            img_str = base64.b64encode(buffer.getvalue()).decode()

            # Cell contents may come from user-supplied files: escape all
            # text before it is interpolated into markup.
            smiles_text = str(smiles)
            shown = escape(smiles_text[:50]) + ('...' if len(smiles_text) > 50 else '')
            details = [f"<strong>SMILES:</strong> {shown}<br/>"]
            for col in extra_cols:
                if col in row and pd.notna(row[col]):
                    value = row[col]
                    if isinstance(value, float):
                        value = f"{value:.2f}"
                    details.append(
                        f"<strong>{escape(str(col))}:</strong> {escape(str(value))}<br/>"
                    )

            cards.append(
                '<div style="border: 1px solid #ddd; border-radius: 8px; '
                'padding: 10px; background: white;">'
                f'<img src="data:image/png;base64,{img_str}" '
                'style="width: 100%; height: auto;" alt="Molecule"/>'
                '<div style="margin-top: 8px; font-size: 12px;">'
                + "".join(details)
                + '</div></div>'
            )
        except Exception as e:
            # Best effort: one bad row must not prevent the rest from rendering.
            print(f"Error processing molecule {idx}: {e}")
            continue

    if not cards:
        return "<h3>No valid molecules to display.</h3>"

    grid_open = (
        '<div style="display: grid; '
        'grid-template-columns: repeat(auto-fill, minmax(300px, 1fr)); '
        'gap: 15px; padding: 10px;">'
    )
    return grid_open + "".join(cards) + "</div>"
|
| 159 |
+
|
| 160 |
+
def create_simple_molecule_table(df, smiles_col='canonical_smiles', additional_cols=None, max_mols=20):
    """Render molecules as a plain HTML table with inline PNG depictions.

    A simpler alternative to the card grid: one table row per molecule with
    the structure image, the SMILES and any requested extra columns.

    Args:
        df: DataFrame with one molecule per row. ``None`` or an empty frame
            yields a placeholder message.
        smiles_col: Name of the column holding SMILES strings.
        additional_cols: Optional column names to add as table columns.
            Missing or null values are shown as ``N/A``; floats are shown
            with two decimals.
        max_mols: Only the first ``max_mols`` rows are rendered.

    Returns:
        An HTML string. Rows whose SMILES cannot be parsed or drawn are
        skipped (failures are printed, not raised); if every row is skipped
        the table contains only its header.
    """
    if df is None or df.empty:
        return "<h3>No molecules to display.</h3>"

    # Imported after the empty check so the trivial path does not need RDKit.
    import base64
    from html import escape
    from io import BytesIO

    from rdkit import Chem
    from rdkit.Chem import Draw

    extra_cols = list(additional_cols) if additional_cols is not None else []
    cell_style = "border: 1px solid #ddd; padding: 8px;"

    header_cells = [
        f'<th style="{cell_style}">Structure</th>',
        f'<th style="{cell_style}">SMILES</th>',
    ]
    header_cells.extend(
        f'<th style="{cell_style}">{escape(str(col))}</th>' for col in extra_cols
    )

    body_rows = []
    for idx, row in df.head(max_mols).iterrows():
        try:
            smiles = row[smiles_col]
            mol = Chem.MolFromSmiles(smiles)
            if mol is None:
                continue

            buffer = BytesIO()
            Draw.MolToImage(mol, size=(200, 150)).save(buffer, format='PNG')
            img_str = base64.b64encode(buffer.getvalue()).decode()

            # Cell contents may come from user-supplied files: escape all
            # text before it is interpolated into markup.
            smiles_text = str(smiles)
            shown = escape(smiles_text[:100]) + ('...' if len(smiles_text) > 100 else '')
            cells = [
                f'<td style="{cell_style} text-align: center;">'
                f'<img src="data:image/png;base64,{img_str}" '
                'style="max-width: 200px; height: auto;" alt="Molecule"/></td>',
                f'<td style="{cell_style} font-family: monospace; font-size: 11px;">'
                f'{shown}</td>',
            ]
            for col in extra_cols:
                value = row[col] if col in row and pd.notna(row[col]) else "N/A"
                if isinstance(value, float):
                    value = f"{value:.2f}"
                cells.append(f'<td style="{cell_style}">{escape(str(value))}</td>')

            body_rows.append("<tr>" + "".join(cells) + "</tr>")
        except Exception as e:
            # Best effort: one bad row must not prevent the rest from rendering.
            print(f"Error processing molecule {idx}: {e}")
            continue

    return (
        '<table style="border-collapse: collapse; width: 100%;">'
        '<thead><tr style="background-color: #f2f2f2;">'
        + "".join(header_cells)
        + '</tr></thead><tbody>'
        + "".join(body_rows)
        + "</tbody></table>"
    )
|
| 231 |
+
|
| 232 |
# ==============================================================================
|
| 233 |
# === STEP 1: CORE DATA COLLECTION & EDA FUNCTIONS ===
|
| 234 |
# ==============================================================================
|
| 235 |
+
def get_target_chembl_id(query):
    """Search ChEMBL for targets matching a free-text query.

    Args:
        query: Search text. A blank or missing query returns the empty
            result without contacting ChEMBL.

    Returns:
        ``(targets, dropdown, message)``: a DataFrame with the columns
        ``target_chembl_id``, ``pref_name`` and ``organism``; a
        ``gr.Dropdown`` whose choices are the matching target ids; and a
        status message. With no matches the frame is empty and the dropdown
        has no choices.

    Raises:
        gr.Error: If the ChEMBL request or the result handling fails.
    """
    if query is None or not str(query).strip():
        return pd.DataFrame(), gr.Dropdown(choices=[], value=None), "Please enter a target search term."

    try:
        res = new_client.target.search(str(query).strip())
        if not res:
            return pd.DataFrame(), gr.Dropdown(choices=[], value=None), "No targets found for your query."

        df = pd.DataFrame(res)
        return (
            df[["target_chembl_id", "pref_name", "organism"]],
            gr.Dropdown(choices=df["target_chembl_id"].tolist()),
            f"Found {len(df)} targets.",
        )
    except Exception as e:
        # UI boundary: surface any failure as a Gradio error, keeping the cause.
        raise gr.Error(f"ChEMBL search failed: {e}") from e
|
| 245 |
+
|
| 246 |
+
def get_bioactivity_data(target_id):
    """Fetch all IC50 activity records for one ChEMBL target.

    Args:
        target_id: ChEMBL target id. A missing or empty id returns the empty
            result without contacting ChEMBL.

    Returns:
        ``(activities, message)``: a DataFrame built from the returned
        records (empty when there are none) and a status message.

    Raises:
        gr.Error: If the ChEMBL request or the result handling fails.
    """
    if not target_id:
        return pd.DataFrame(), "Please select a target first."

    try:
        res = new_client.activity.filter(target_chembl_id=target_id).filter(standard_type="IC50")
        if not res:
            return pd.DataFrame(), "No IC50 bioactivity data found for this target."

        df = pd.DataFrame(res)
        return df, f"Fetched {len(df)} data points."
    except Exception as e:
        # UI boundary: surface any failure as a Gradio error, keeping the cause.
        raise gr.Error(f"Failed to fetch bioactivity data: {e}") from e
|
| 256 |
+
|
| 257 |
+
def pIC50_calc(input_df):
    """Convert ``standard_value`` into pIC50 and a bioactivity class.

    The value is multiplied by 1e-9 before taking ``-log10``, i.e. it is
    treated as a nanomolar concentration. Values are capped at 100,000,000
    first so extreme entries cannot produce negative pIC50 values.

    Args:
        input_df: DataFrame with a ``standard_value`` column. It is not
            modified.

    Returns:
        A copy of the input without rows whose value is non-numeric, with
        ``standard_value`` removed and two columns added:

        * ``pIC50`` -- ``-log10(value * 1e-9)``, or NaN when the value is
          not positive.
        * ``bioactivity_class`` -- ``"inactive"`` for values >= 10000,
          ``"active"`` for values <= 1000, ``"intermediate"`` otherwise.
    """
    df_copy = input_df.copy()
    df_copy['standard_value'] = pd.to_numeric(df_copy['standard_value'], errors='coerce')
    df_copy.dropna(subset=['standard_value'], inplace=True)

    capped = df_copy['standard_value'].clip(upper=100000000)

    # Mask non-positive values to NaN up front so log10 never sees them.
    positive = capped.where(capped > 0)
    df_copy['pIC50'] = (-np.log10(positive * (10**-9))).to_numpy()

    values = capped.to_numpy()
    df_copy['bioactivity_class'] = np.where(
        values >= 10000,
        "inactive",
        np.where(values <= 1000, "active", "intermediate"),
    )
    return df_copy.drop(columns=['standard_value'])
|
| 274 |
+
|
| 275 |
+
def lipinski_descriptors(smiles_series):
    """Compute four Lipinski-style descriptors for every parseable SMILES.

    Args:
        smiles_series: Iterable of candidate SMILES. Entries that are not
            non-empty strings, or that RDKit cannot parse, are skipped.

    Returns:
        A ``(descriptors, valid_smiles)`` tuple. ``descriptors`` is a
        DataFrame with columns ``MW``, ``LogP``, ``NumHDonors`` and
        ``NumHAcceptors``, one row per parsed molecule in input order.
        ``valid_smiles`` lists the SMILES strings those rows came from.
        With nothing parseable, the frame is empty (columns only) and the
        list is empty.
    """
    column_names = ["MW", "LogP", "NumHDonors", "NumHAcceptors"]
    descriptor_rows, valid_smiles = [], []

    for smiles in smiles_series:
        # Type check first so truthiness is only evaluated on real strings.
        if not (isinstance(smiles, str) and smiles):
            continue
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:  # RDKit returns None for unparseable SMILES
            continue
        descriptor_rows.append([
            Descriptors.MolWt(mol),
            Descriptors.MolLogP(mol),
            Lipinski.NumHDonors(mol),
            Lipinski.NumHAcceptors(mol),
        ])
        valid_smiles.append(smiles)

    if not descriptor_rows:
        return pd.DataFrame(columns=column_names), []
    return pd.DataFrame(data=np.array(descriptor_rows), columns=column_names), valid_smiles
|
| 290 |
+
|
| 291 |
+
def clean_and_process_data(df):
    """Clean raw bioactivity rows and append pIC50 and Lipinski descriptors.

    Steps: back-fill ``canonical_smiles`` from ChEMBL when the column is
    missing or entirely null, coerce ``standard_value`` to numeric, drop rows
    lacking a value or a SMILES, keep one row per SMILES, compute pIC50 and
    the bioactivity class via ``pIC50_calc``, then join the descriptors from
    ``lipinski_descriptors``.

    The caller's frame is never modified. Values are coerced to numeric
    *before* de-duplication, so a non-numeric first measurement cannot shadow
    a valid later measurement of the same compound.

    Args:
        df: Raw bioactivity DataFrame with at least ``standard_value`` and
            either ``canonical_smiles`` or ``molecule_chembl_id``.

    Returns:
        ``(dataframe, message)``. The frame is empty when nothing survives
        cleaning; the message describes the outcome.

    Raises:
        gr.Error: If ``df`` is None/empty or the SMILES lookup fails.
    """
    if df is None or df.empty:
        raise gr.Error("No data to process. Please fetch data first.")

    # Work on a private copy: later steps add and overwrite columns.
    df = df.copy()

    if "canonical_smiles" not in df.columns or df["canonical_smiles"].isnull().all():
        try:
            chembl_ids = df["molecule_chembl_id"].dropna().unique().tolist()
            # Join by id rather than by position: the lookup returns one record
            # per molecule, while df may repeat a molecule across many rows.
            # "molecule_structures" can be present but None, hence the `or {}`.
            smiles_by_id = {
                record.get("molecule_chembl_id"):
                    (record.get("molecule_structures") or {}).get("canonical_smiles")
                for record in new_client.molecule.get(chembl_ids)
            }
            df["canonical_smiles"] = df["molecule_chembl_id"].map(smiles_by_id)
        except Exception as e:
            raise gr.Error(f"Could not fetch SMILES from ChEMBL: {e}") from e

    df["standard_value"] = pd.to_numeric(df["standard_value"], errors='coerce')
    df = df.dropna(subset=["standard_value", "canonical_smiles"])
    df = df.drop_duplicates(subset=["canonical_smiles"])

    df_processed = pIC50_calc(df)
    df_processed = df_processed[df_processed.pIC50.notna()]
    if df_processed.empty:
        return pd.DataFrame(), "No compounds remaining after pIC50 calculation."

    df_lipinski, valid_smiles = lipinski_descriptors(df_processed['canonical_smiles'])
    if not valid_smiles:
        return pd.DataFrame(), "No valid SMILES could be processed for Lipinski descriptors."

    # SMILES are unique here, so filtering keeps rows in the same order as the
    # descriptor rows and the positional concat below lines up one-to-one.
    keep = df_processed['canonical_smiles'].isin(set(valid_smiles))
    df_processed = df_processed[keep].reset_index(drop=True)
    df_final = pd.concat([df_processed, df_lipinski.reset_index(drop=True)], axis=1)
    return df_final, f"Processing complete. {len(df_final)} compounds remain after cleaning."
|
| 312 |
+
|
| 313 |
+
def run_eda_analysis(df, selected_classes):
    """Build every EDA plot and statistics table for the chosen classes.

    Args:
        df: Processed DataFrame containing ``bioactivity_class`` plus the
            descriptor columns plotted below.
        selected_classes: Bioactivity class labels to keep. ``None`` is
            treated as an empty selection.

    Returns:
        A 13-tuple in fixed order: frequency plot, scatter plot, then a
        (box plot, Mann-Whitney table) pair for each of pIC50, MW, LogP,
        NumHDonors and NumHAcceptors, and finally a status message. When no
        rows match, plots are ``None`` and tables are empty DataFrames.

    Raises:
        gr.Error: If ``df`` is None or empty.
    """
    if df is None or df.empty:
        raise gr.Error("No data available for analysis.")

    descriptors = ['pIC50', 'MW', 'LogP', 'NumHDonors', 'NumHAcceptors']
    # isin() rejects None, so fall back to an empty selection.
    df_filtered = df[df.bioactivity_class.isin(selected_classes or [])].copy()

    if df_filtered.empty:
        outputs = [None, None]
        for _ in descriptors:
            outputs.extend([None, pd.DataFrame()])
        return (*outputs, "No data for selected classes.")

    outputs = [create_frequency_plot(df_filtered), create_scatter_plot(df_filtered)]
    for desc in descriptors:
        outputs.append(create_boxplot(df_filtered, desc))
        outputs.append(mannwhitney_test(df_filtered, desc))

    # Release pyplot's global figure registry; the Figure objects returned
    # above stay usable.
    plt.close('all')
    return (*outputs, f"EDA complete for {len(df_filtered)} compounds.")
|
| 324 |
+
|
| 325 |
+
def create_frequency_plot(df):
    """Return a bar chart figure counting compounds per ``bioactivity_class``."""
    class_colors = {
        'active': '#1f77b4',
        'inactive': '#ff7f0e',
        'intermediate': '#2ca02c',
    }
    plt.figure(figsize=(5.5, 5.5))
    class_counts = df['bioactivity_class'].value_counts()
    sns.barplot(x=class_counts.index, y=class_counts.values, palette=class_colors)
    plt.xlabel('Bioactivity Class', fontsize=12)
    plt.ylabel('Frequency', fontsize=12)
    plt.title('Frequency of Bioactivity Classes', fontsize=14)
    return plt.gcf()
|
| 327 |
+
|
| 328 |
+
def create_scatter_plot(df):
    """Return a MW-vs-LogP scatter figure, coloured by class and sized by pIC50."""
    class_colors = {
        'active': '#1f77b4',
        'inactive': '#ff7f0e',
        'intermediate': '#2ca02c',
    }
    plt.figure(figsize=(5.5, 5.5))
    sns.scatterplot(
        data=df,
        x='MW',
        y='LogP',
        hue='bioactivity_class',
        size='pIC50',
        palette=class_colors,
        sizes=(20, 200),
        alpha=0.7,
    )
    plt.xlabel('Molecular Weight (MW)', fontsize=12)
    plt.ylabel('LogP', fontsize=12)
    plt.title('Chemical Space: MW vs. LogP', fontsize=14)
    plt.legend(title='Bioactivity Class')
    return plt.gcf()
|
| 330 |
+
|
| 331 |
+
def create_boxplot(df, descriptor):
    """Return a box plot figure of ``descriptor`` grouped by ``bioactivity_class``."""
    class_colors = {
        'active': '#1f77b4',
        'inactive': '#ff7f0e',
        'intermediate': '#2ca02c',
    }
    plt.figure(figsize=(5.5, 5.5))
    sns.boxplot(x='bioactivity_class', y=descriptor, data=df, palette=class_colors)
    plt.xlabel('Bioactivity Class', fontsize=12)
    plt.ylabel(descriptor, fontsize=12)
    plt.title(f'{descriptor} by Bioactivity Class', fontsize=14)
    return plt.gcf()
|
| 333 |
+
|
| 334 |
+
def mannwhitney_test(df, descriptor):
    """Run pairwise Mann-Whitney U tests on ``descriptor`` between bioactivity classes.

    Only pairs whose two classes both occur in ``df`` and both have at least
    one non-null value for ``descriptor`` are tested.

    Returns:
        DataFrame with one row per tested pair and the columns ``Comparison``,
        ``Statistics``, ``p-value`` and ``Interpretation``. It is an empty
        DataFrame when no pair could be tested.
    """
    class_pairs = (
        ('active', 'inactive'),
        ('active', 'intermediate'),
        ('inactive', 'intermediate'),
    )
    present_classes = df['bioactivity_class'].unique()

    rows = []
    for first, second in class_pairs:
        if first not in present_classes or second not in present_classes:
            continue

        sample_a = df.loc[df.bioactivity_class == first, descriptor].dropna()
        sample_b = df.loc[df.bioactivity_class == second, descriptor].dropna()
        if sample_a.empty or sample_b.empty:
            continue

        stat, p = mannwhitneyu(sample_a, sample_b)
        if p <= 0.05:
            verdict = 'Different distribution (p < 0.05)'
        else:
            verdict = 'Same distribution (p > 0.05)'

        rows.append({
            'Comparison': f'{first.title()} vs {second.title()}',
            'Statistics': stat,
            'p-value': p,
            'Interpretation': verdict,
        })

    return pd.DataFrame(rows)
|
| 343 |
|
| 344 |
# ==============================================================================
|
| 345 |
# === STEP 2: FEATURE ENGINEERING FUNCTIONS ===
|
| 346 |
# ==============================================================================
|
| 347 |
+
def calculate_fingerprints(current_state, fingerprint_type, progress=gr.Progress()):
    """Compute PaDEL fingerprints for the cleaned compounds, streaming status updates.

    Generator used as a Gradio event handler. Every yield is a 5-tuple:
    (status message, fingerprint DataFrame or None, download-button update,
    molecule-grid HTML update or None, state dict).

    Args:
        current_state: Shared state dict. Reads ``'cleaned_data'``; on success
            writes ``'fingerprint_data'`` and ``'fingerprint_type'``.
        fingerprint_type: Key looked up in ``fp_config`` to select the PaDEL
            descriptor definition.
        progress: Gradio progress tracker.

    Raises:
        gr.Error: If the cleaned data or fingerprint type is missing, if no
            descriptor definition exists for ``fingerprint_type``, if PaDEL
            produces no output, or if any other step fails.
    """
    input_df = current_state.get('cleaned_data')
    if input_df is None or input_df.empty:
        raise gr.Error("No cleaned data found. Please complete Step 1.")
    if not fingerprint_type:
        raise gr.Error("Please select a fingerprint type.")

    smi_file, output_csv = 'molecule.smi', 'fingerprints.csv'

    progress(0, desc="Starting...")
    yield "🧪 Starting fingerprint calculation...", None, gr.update(visible=False), None, current_state

    try:
        # The SMILES string is written twice so the second column (the molecule
        # name) is the SMILES itself; PaDEL's 'Name' column is renamed back to
        # 'canonical_smiles' below and used as the merge key.
        input_df[['canonical_smiles', 'canonical_smiles']].to_csv(
            smi_file, sep='\t', index=False, header=False
        )

        # Remove any stale output so the existence check below reflects this run.
        if os.path.exists(output_csv):
            os.remove(output_csv)

        descriptortypes = fp_config.get(fingerprint_type)
        if not descriptortypes:
            raise gr.Error(f"Descriptor XML for '{fingerprint_type}' not found.")

        progress(0.3, desc="⚗️ Running PaDEL...")
        yield "⚗️ Running PaDEL...", None, gr.update(visible=False), None, current_state

        padeldescriptor(
            mol_dir=smi_file,
            d_file=output_csv,
            descriptortypes=descriptortypes,
            detectaromaticity=True,
            standardizenitro=True,
            standardizetautomers=True,
            threads=-1,
            removesalt=True,
            log=False,
            fingerprints=True
        )

        if not os.path.exists(output_csv) or os.path.getsize(output_csv) == 0:
            raise gr.Error("PaDEL failed to produce an output file. Check molecule validity.")

        progress(0.7, desc="📊 Processing results...")
        yield "📊 Processing results...", None, gr.update(visible=False), None, current_state

        df_X = pd.read_csv(output_csv).rename(columns={'Name': 'canonical_smiles'})
        final_df = pd.merge(
            input_df[['canonical_smiles', 'pIC50']], df_X,
            on='canonical_smiles', how='inner'
        )

        current_state['fingerprint_data'] = final_df
        current_state['fingerprint_type'] = fingerprint_type

        progress(0.9, desc="🖼️ Generating molecule grid...")

        # Use custom molecule display instead of mols2grid
        try:
            # Try the grid layout first
            mols_html = create_molecule_grid_html(
                final_df,
                smiles_col='canonical_smiles',
                additional_cols=['pIC50'],
                max_mols=50
            )
        except Exception as grid_error:
            print(f"Grid layout failed: {grid_error}, trying table layout...")
            # Fallback to table layout
            mols_html = create_simple_molecule_table(
                final_df,
                smiles_col='canonical_smiles',
                additional_cols=['pIC50'],
                max_mols=20
            )

        # One column of df_X is the molecule name, hence the "- 1".
        success_msg = f"✅ Success! Generated {len(df_X.columns) -1} descriptors for {len(final_df)} molecules."
        progress(1, desc="Completed!")
        yield success_msg, final_df, gr.update(visible=True), gr.update(value=mols_html, visible=True), current_state

    except gr.Error:
        # Already a user-facing message; do not wrap it a second time.
        raise
    except Exception as e:
        raise gr.Error(f"Calculation failed: {e}") from e
    finally:
        # Always remove the scratch files, whether the run succeeded or not.
        if os.path.exists(smi_file):
            os.remove(smi_file)
        if os.path.exists(output_csv):
            os.remove(output_csv)
|
|
|
|
|
|
|
| 428 |
|
| 429 |
# ==============================================================================
|
| 430 |
+
# === STEP 3: MODELING FUNCTIONS ===
|
| 431 |
# ==============================================================================
|
| 432 |
+
|
| 433 |
+
# Model definitions
|
| 434 |
+
# Registry mapping a display name to the regressor class trained under that
# name. Insertion order is the order models are iterated during training.
regression_models = {
    # Linear models
    "Linear Regression": LinearRegression,
    "Ridge Regression": Ridge,
    "Lasso Regression": Lasso,
    "Elastic Net": ElasticNet,
    "Bayesian Ridge": BayesianRidge,
    "Huber Regressor": HuberRegressor,
    "Passive Aggressive Regressor": PassiveAggressiveRegressor,
    "Orthogonal Matching Pursuit": OrthogonalMatchingPursuit,
    "Lasso Lars": LassoLars,
    # Tree and ensemble models
    "Decision Tree Regressor": DecisionTreeRegressor,
    "Random Forest Regressor": RandomForestRegressor,
    "Gradient Boosting Regressor": GradientBoostingRegressor,
    "Extra Trees Regressor": ExtraTreesRegressor,
    "AdaBoost Regressor": AdaBoostRegressor,
    # Neighbours and baseline
    "K-Neighbors Regressor": KNeighborsRegressor,
    "Dummy Regressor (Mean)": DummyRegressor,
}

# The boosting libraries are registered only when the availability flag is set.
if _has_extra_libs:
    regression_models["XGBoost Regressor"] = xgb.XGBRegressor
    regression_models["LightGBM Regressor"] = lgb.LGBMRegressor
    regression_models["CatBoost Regressor"] = cb.CatBoostRegressor
|
| 459 |
+
|
| 460 |
+
def handle_model_training(current_state, progress=gr.Progress()):
    """Train every registered regressor on the fingerprint data, streaming status.

    Generator used as a Gradio event handler. Every yield is a 4-tuple:
    (status message, metrics DataFrame or None, model-dropdown update or None,
    state dict).

    Args:
        current_state: Shared state dict. Reads ``'fingerprint_data'``; on
            success writes ``'model_results'`` (a ``ModelRunResult``) and the
            ``'X_train'``, ``'y_train'``, ``'X_test'``, ``'y_test'`` splits.
        progress: Gradio progress tracker.

    Raises:
        gr.Error: If no fingerprint data is available, no usable features
            remain, or any model fails to train.
    """
    df_fp = current_state.get('fingerprint_data')
    if df_fp is None or df_fp.empty:
        raise gr.Error("No fingerprint data found. Please complete Step 2.")

    yield "⚙️ Starting model training...", None, None, current_state
    progress(0, desc="Starting model training...")

    try:
        X = df_fp.drop(columns=['canonical_smiles', 'pIC50'])
        y = df_fp['pIC50']

        if X.empty:
            raise gr.Error("No features found for training. Check fingerprint calculation.")

        # Remove features with zero variance
        sel = VarianceThreshold(threshold=0.0)
        X_filtered = sel.fit_transform(X)
        selected_features = X.columns[sel.get_support()]

        if len(selected_features) == 0:
            raise gr.Error("No features with non-zero variance found. Cannot train models.")

        # fit_transform returns a bare array; restore the surviving column names.
        X_filtered = pd.DataFrame(X_filtered, columns=selected_features)

        X_train, X_test, y_train, y_test = train_test_split(
            X_filtered, y, test_size=0.2, random_state=42
        )

        results = []
        trained_models = {}
        total_models = len(regression_models)

        for i, (name, model_class) in enumerate(regression_models.items()):
            # Report the fraction of models already finished, so the bar only
            # reaches 100% once the last model has actually been trained.
            progress(i / total_models, desc=f"Training {name}...")
            yield f"Training {name}...", None, None, current_state

            model = model_class()
            model.fit(X_train, y_train)
            y_pred = model.predict(X_test)

            results.append({
                "Model": name,
                "MAE": mean_absolute_error(y_test, y_pred),
                "MSE": mean_squared_error(y_test, y_pred),
                "R2": r2_score(y_test, y_pred)
            })
            trained_models[name] = model

        results_df = pd.DataFrame(results)

        # Store results in the state
        current_state['model_results'] = ModelRunResult(
            dataframe=results_df,
            plotter=None,  # No plot generated here directly, but could be added
            models=trained_models,
            selected_features=selected_features
        )
        current_state['X_train'] = X_train
        current_state['y_train'] = y_train
        current_state['X_test'] = X_test
        current_state['y_test'] = y_test

        progress(1, desc="Model training complete!")
        model_names = list(trained_models.keys())
        yield (
            "✅ Model training complete!",
            results_df,
            gr.update(choices=model_names, value=model_names[0]),
            current_state,
        )

    except gr.Error:
        # Already a user-facing message; do not wrap it a second time.
        raise
    except Exception as e:
        raise gr.Error(f"Model training failed: {e}") from e
|
| 530 |
+
|
| 531 |
+
def update_analysis_plots(model_name, feature_count, current_state):
    """Build the diagnostic plots and test-set metrics for one trained model.

    Args:
        model_name: Key of the model inside the stored run's ``models`` mapping.
        feature_count: How many top features or coefficients to chart.
        current_state: State dict holding ``'model_results'`` and the
            ``'X_train'``, ``'y_train'``, ``'X_test'``, ``'y_test'`` splits.

    Returns:
        A 5-tuple: observed-vs-predicted figure, residuals figure, feature
        importance figure (``None`` when the model exposes neither
        ``feature_importances_`` nor a usable ``coef_``), metrics summary
        string, status message. When nothing can be plotted, the first four
        items are ``None`` and the last one explains why.
    """
    run = current_state.get('model_results')
    X_train = current_state.get('X_train')
    y_train = current_state.get('y_train')
    X_test = current_state.get('X_test')
    y_test = current_state.get('y_test')

    if not run or not model_name or X_test is None or y_test is None:
        return None, None, None, None, "Please train models first."

    model = run.models.get(model_name)
    if model is None:
        return None, None, None, None, f"Model '{model_name}' not found."

    y_pred_test = model.predict(X_test)
    y_pred_train = model.predict(X_train)

    # Test-set metrics shown in the summary box.
    r2_test = r2_score(y_test, y_pred_test)
    mae_test = mean_absolute_error(y_test, y_pred_test)
    mse_test = mean_squared_error(y_test, y_pred_test)

    # --- Observed vs. predicted ---
    fig_obs_pred, ax_obs_pred = plt.subplots(figsize=(6, 6))
    sns.scatterplot(x=y_test, y=y_pred_test, ax=ax_obs_pred, alpha=0.7,
                    color='blue', label='Test Data')
    sns.scatterplot(x=y_train, y=y_pred_train, ax=ax_obs_pred, alpha=0.7,
                    color='green', label='Train Data')
    lower = min(y_test.min(), y_train.min())
    upper = max(y_test.max(), y_train.max())
    # Diagonal across the observed range: points on it are perfect predictions.
    ax_obs_pred.plot([lower, upper], [lower, upper],
                     color='red', linestyle='--', label='Ideal Prediction')
    ax_obs_pred.set_xlabel("Observed pIC50", fontsize=12)
    ax_obs_pred.set_ylabel("Predicted pIC50", fontsize=12)
    ax_obs_pred.set_title(f"{model_name}: Observed vs. Predicted pIC50", fontsize=14)
    ax_obs_pred.legend()
    plt.close(fig_obs_pred)

    # --- Residuals ---
    residuals = y_test - y_pred_test
    fig_residuals, ax_residuals = plt.subplots(figsize=(6, 6))
    sns.scatterplot(x=y_pred_test, y=residuals, ax=ax_residuals, alpha=0.7, color='purple')
    ax_residuals.axhline(y=0, color='red', linestyle='--')
    ax_residuals.set_xlabel("Predicted pIC50", fontsize=12)
    ax_residuals.set_ylabel("Residuals (Observed - Predicted)", fontsize=12)
    ax_residuals.set_title(f"{model_name}: Residuals Plot", fontsize=14)
    plt.close(fig_residuals)

    def _ranked_bar_figure(series, palette, x_label, title):
        """Draw a horizontal bar chart of ``series`` and return the closed figure."""
        fig, ax_fi = plt.subplots(figsize=(8, 6))
        sns.barplot(x=series.values, y=series.index, ax=ax_fi, palette=palette)
        ax_fi.set_xlabel(x_label, fontsize=12)
        ax_fi.set_ylabel("Feature", fontsize=12)
        ax_fi.set_title(title, fontsize=14)
        plt.tight_layout()
        plt.close(fig)
        return fig

    # --- Feature importance / coefficients (only when the model exposes them) ---
    fig_feature_importance = None
    if hasattr(model, 'feature_importances_') and feature_count > 0:
        importances = pd.Series(model.feature_importances_, index=run.selected_features)
        fig_feature_importance = _ranked_bar_figure(
            importances.nlargest(feature_count),
            'viridis',
            "Importance",
            f"{model_name}: Top {feature_count} Feature Importances",
        )
    elif hasattr(model, 'coef_') and feature_count > 0 and model_name not in ["Dummy Regressor (Mean)", "K-Neighbors Regressor"]:
        # For linear models, coefficients can be used as importance.
        coefficients = pd.Series(model.coef_, index=run.selected_features)
        # Rank by magnitude, but chart the signed values.
        strongest = coefficients.abs().nlargest(feature_count)
        fig_feature_importance = _ranked_bar_figure(
            coefficients[strongest.index],
            'coolwarm',
            "Coefficient Value",
            f"{model_name}: Top {feature_count} Feature Coefficients",
        )

    metrics_summary = f"R2 (Test): {r2_test:.4f}, MAE (Test): {mae_test:.4f}, MSE (Test): {mse_test:.4f}"
    return fig_obs_pred, fig_residuals, fig_feature_importance, metrics_summary, "Plots updated."
|
| 607 |
+
|
| 608 |
+
# Updated prediction function
|
| 609 |
+
def predict_on_upload(uploaded_file, model_name, current_state, progress=gr.Progress()):
    """Predict pIC50 for molecules in an uploaded CSV, streaming status updates.

    Generator used as a Gradio event handler. Every yield is a 3-tuple:
    (status message, predictions DataFrame or None, molecule-grid HTML or None).

    Args:
        uploaded_file: Uploaded CSV, either a path string or an object whose
            ``name`` attribute is the path. It must contain a
            ``canonical_smiles`` column.
        model_name: Key of the trained model to use.
        current_state: State dict; reads ``'model_results'`` and
            ``'fingerprint_type'``.
        progress: Gradio progress tracker.

    Raises:
        gr.Error: If an input is missing, the model or descriptor definition
            cannot be found, the CSV lacks the SMILES column, or PaDEL produces
            no output.
    """
    if not uploaded_file:
        raise gr.Error("Please upload a file.")
    if not model_name:
        raise gr.Error("Please select a trained model.")

    model_run_results = current_state.get('model_results')
    fingerprint_type = current_state.get('fingerprint_type')
    if not model_run_results or not fingerprint_type:
        raise gr.Error("Please run Steps 2 and 3 first.")

    model = model_run_results.models.get(model_name)
    selected_features = model_run_results.selected_features
    if model is None:
        raise gr.Error(f"Model '{model_name}' not found.")

    # Same guard as calculate_fingerprints: never call PaDEL without an
    # explicit descriptor definition.
    descriptortypes = fp_config.get(fingerprint_type)
    if not descriptortypes:
        raise gr.Error(f"Descriptor XML for '{fingerprint_type}' not found.")

    smi_file, output_csv = 'predict.smi', 'predict_fp.csv'
    try:
        progress(0, desc="Reading & processing new molecules...")
        yield "Reading uploaded file...", None, None

        # Accept both a plain path and a file-like wrapper exposing `.name`.
        upload_path = getattr(uploaded_file, 'name', uploaded_file)
        df_new = pd.read_csv(upload_path)
        if 'canonical_smiles' not in df_new.columns:
            raise gr.Error("CSV must contain a 'canonical_smiles' column.")
        # Row position becomes a stable id used to join predictions back.
        df_new = df_new.reset_index().rename(columns={'index': 'mol_id'})

        padel_input = pd.DataFrame({
            'smiles': df_new['canonical_smiles'],
            'name': df_new['mol_id']
        })
        padel_input.to_csv(smi_file, sep='\t', index=False, header=False)
        if os.path.exists(output_csv):
            os.remove(output_csv)

        progress(0.3, desc="Calculating fingerprints...")
        yield "Calculating fingerprints for new molecules...", None, None

        # Settings mirror calculate_fingerprints (including tautomer
        # standardisation) so new molecules are featurised the same way as
        # the training set.
        padeldescriptor(
            mol_dir=smi_file,
            d_file=output_csv,
            descriptortypes=descriptortypes,
            detectaromaticity=True,
            standardizenitro=True,
            standardizetautomers=True,
            threads=-1,
            removesalt=True,
            log=False,
            fingerprints=True
        )

        if not os.path.exists(output_csv) or os.path.getsize(output_csv) == 0:
            raise gr.Error("PaDEL calculation failed for the uploaded molecules.")

        progress(0.7, desc="Aligning features and predicting...")
        yield "Aligning features and predicting...", None, None

        df_fp = pd.read_csv(output_csv).rename(columns={'Name': 'mol_id'})
        X_new = df_fp.set_index('mol_id')
        # Keep exactly the training columns, in training order; absent ones become 0.
        X_new_aligned = X_new.reindex(columns=selected_features, fill_value=0)

        # Only predict rows whose fingerprint is complete. The others get no
        # prediction (NaN after the left merge) instead of failing the batch.
        complete_rows = X_new_aligned.notna().all(axis=1)
        X_predictable = X_new_aligned[complete_rows]
        if X_predictable.empty:
            results_subset = pd.DataFrame({
                'mol_id': pd.Series(dtype=df_new['mol_id'].dtype),
                'predicted_pIC50': pd.Series(dtype=float)
            })
        else:
            results_subset = pd.DataFrame({
                'mol_id': X_predictable.index,
                'predicted_pIC50': model.predict(X_predictable)
            })
        df_results = pd.merge(df_new, results_subset, on='mol_id', how='left')

        progress(0.9, desc="Generating visualization...")
        yield "Generating visualization...", None, None

        df_grid_view = df_results.dropna(subset=['predicted_pIC50']).copy()
        mols_html = "<h3>No molecules with successful predictions to display.</h3>"

        if not df_grid_view.empty:
            try:
                # Use custom molecule display
                mols_html = create_molecule_grid_html(
                    df_grid_view,
                    smiles_col='canonical_smiles',
                    additional_cols=['predicted_pIC50'],
                    max_mols=50
                )
            except Exception as e:
                print(f"Grid layout failed: {e}, trying table layout...")
                mols_html = create_simple_molecule_table(
                    df_grid_view,
                    smiles_col='canonical_smiles',
                    additional_cols=['predicted_pIC50'],
                    max_mols=20
                )

        progress(1, desc="Complete!")
        yield "✅ Prediction complete.", df_results[['canonical_smiles', 'predicted_pIC50']], mols_html

    finally:
        # Always remove the scratch files, whether the run succeeded or not.
        if os.path.exists(smi_file):
            os.remove(smi_file)
        if os.path.exists(output_csv):
            os.remove(output_csv)
|
| 706 |
|
| 707 |
# ==============================================================================
|
| 708 |
+
# === GRADIO INTERFACE LAYOUT ===
|
| 709 |
# ==============================================================================
|
| 710 |
+
|
| 711 |
+
# UI definition: four tabs (data collection/EDA, fingerprints, modelling,
# prediction) followed by the event wiring. All cross-step data lives in the
# `app_state` dict under the keys 'raw_data', 'cleaned_data',
# 'fingerprint_data', 'fingerprint_type' and 'model_results'.
with gr.Blocks(css=".container { max-width: 1200px; margin: auto; }") as demo:
    app_state = gr.State({})  # Store application state (e.g., fetched data, trained models)

    gr.Markdown("# 💊 Bioactivity Prediction App")
    gr.Markdown("---")

    # NOTE: `max_rows` is no longer passed to gr.DataFrame. Per the Gradio 4
    # Dataframe API (see references) it is not an accepted argument, and this
    # file already requires Gradio 4 because it uses gr.DownloadButton.
    with gr.Tabs():
        with gr.TabItem("Step 1: Data Collection & EDA"):
            with gr.Row():
                with gr.Column():
                    gr.Markdown("## 1.1 ChEMBL Target Search")
                    target_query = gr.Textbox(
                        label="Enter target protein name (e.g., 'EGFR', 'acetylcholinesterase')",
                        placeholder="EGFR"
                    )
                    search_target_btn = gr.Button("Search ChEMBL")
                    target_output_df = gr.DataFrame(
                        label="Search Results",
                        headers=["target_chembl_id", "pref_name", "organism"]
                    )
                    status_step1_search = gr.Textbox(label="Status", interactive=False)

                    gr.Markdown("## 1.2 Fetch Bioactivity Data")
                    chembl_id_selector = gr.Dropdown(
                        label="Select Target ChEMBL ID",
                        choices=[], interactive=True
                    )
                    fetch_data_btn = gr.Button("Fetch Bioactivity Data (IC50)")
                    bioactivity_output_df = gr.DataFrame(label="Raw Bioactivity Data")
                    status_step1_fetch = gr.Textbox(label="Status", interactive=False)

                    gr.Markdown("## 1.3 Clean & Process Data")
                    process_data_btn = gr.Button("Process & Calculate pIC50/Lipinski Descriptors")
                    cleaned_data_output_df = gr.DataFrame(label="Cleaned & Processed Data")
                    status_step1_process_clean = gr.Textbox(label="Status", interactive=False)
                    # Hidden until data exists, but clickable once revealed.
                    download_s1 = gr.DownloadButton("Download Cleaned Data (CSV)", visible=False, interactive=True)

                with gr.Column():
                    gr.Markdown("## 1.4 Exploratory Data Analysis (EDA)")
                    bioactivity_class_selector = gr.CheckboxGroup(
                        label="Select Bioactivity Classes for EDA",
                        choices=["active", "inactive", "intermediate"],
                        value=["active", "inactive", "intermediate"],
                        interactive=True
                    )
                    run_eda_btn = gr.Button("Run EDA")
                    status_step1_process = gr.Textbox(label="Status", interactive=False)

                    gr.Markdown("### Bioactivity Class Frequency")
                    freq_plot_output = gr.Plot(label="Bioactivity Class Frequency")

                    gr.Markdown("### Chemical Space (MW vs LogP)")
                    scatter_plot_output = gr.Plot(label="MW vs LogP Scatter Plot")

                    gr.Markdown("### pIC50 Distribution")
                    pic50_plot_output = gr.Plot(label="pIC50 Box Plot")
                    pic50_stats_output = gr.DataFrame(label="pIC50 Mann-Whitney U Test")

                    gr.Markdown("### Molecular Weight (MW) Distribution")
                    mw_plot_output = gr.Plot(label="MW Box Plot")
                    mw_stats_output = gr.DataFrame(label="MW Mann-Whitney U Test")

                    gr.Markdown("### LogP Distribution")
                    logp_plot_output = gr.Plot(label="LogP Box Plot")
                    logp_stats_output = gr.DataFrame(label="LogP Mann-Whitney U Test")

                    gr.Markdown("### Hydrogen Donors Distribution")
                    hdonors_plot_output = gr.Plot(label="Hydrogen Donors Box Plot")
                    hdonors_stats_output = gr.DataFrame(label="Hydrogen Donors Mann-Whitney U Test")

                    gr.Markdown("### Hydrogen Acceptors Distribution")
                    hacceptors_plot_output = gr.Plot(label="Hydrogen Acceptors Box Plot")
                    hacceptors_stats_output = gr.DataFrame(label="Hydrogen Acceptors Mann-Whitney U Test")

        with gr.TabItem("Step 2: Feature Engineering (Fingerprints)"):
            gr.Markdown("## 2.1 Calculate Molecular Fingerprints")
            gr.Markdown(f"Available Fingerprint Types: {', '.join(FP_list)}")
            fingerprint_dropdown = gr.Dropdown(
                label="Select Fingerprint Type",
                choices=FP_list,
                interactive=True,
                value=FP_list[0] if FP_list else None  # Set default if available
            )
            calculate_fp_btn = gr.Button("Calculate Fingerprints")
            status_step2 = gr.Textbox(label="Status", interactive=False)
            output_df_s2 = gr.DataFrame(label="Fingerprint Data (First 5 rows, with pIC50)")
            download_s2 = gr.DownloadButton("Download Fingerprint Data (CSV)", visible=False, interactive=True)

            # Using HTML component for custom molecule display
            mols_grid_s2 = gr.HTML(label="Molecules with pIC50", visible=True)

        with gr.TabItem("Step 3: Model Training & Evaluation"):
            gr.Markdown("## 3.1 Train Regression Models")
            train_models_btn = gr.Button("Train Models")
            status_step3_train = gr.Textbox(label="Status", interactive=False)
            model_results_df = gr.DataFrame(label="Model Performance Metrics")

            gr.Markdown("## 3.2 Model Analysis")
            model_selector_s3 = gr.Dropdown(
                label="Select Model for Detailed Analysis",
                choices=[],
                interactive=True
            )
            feature_count_s3 = gr.Slider(
                minimum=0, maximum=50, step=1, value=10,
                label="Number of Top Features to Display", interactive=True
            )
            model_metrics_summary = gr.Textbox(label="Selected Model Metrics (Test Set)", interactive=False)
            obs_pred_plot = gr.Plot(label="Observed vs. Predicted pIC50")
            residuals_plot = gr.Plot(label="Residuals Plot")
            feature_importance_plot = gr.Plot(label="Feature Importance/Coefficients")

        with gr.TabItem("Step 4: Prediction on New Data"):
            gr.Markdown("## 4.1 Upload New Molecules & Predict")
            upload_new_mols = gr.File(label="Upload CSV with 'canonical_smiles' column")
            model_selector_s4 = gr.Dropdown(
                label="Select Trained Model for Prediction",
                choices=[],
                interactive=True
            )
            predict_btn = gr.Button("Predict pIC50 for New Molecules")
            status_step4 = gr.Textbox(label="Status", interactive=False)
            predictions_output_df = gr.DataFrame(label="Predictions for New Molecules")
            download_s4 = gr.DownloadButton("Download Predictions (CSV)", visible=False, interactive=True)
            # Using HTML component for custom molecule display
            mols_grid_s4 = gr.HTML(label="New Molecules with Predicted pIC50", visible=True)

    # --- EVENT HANDLERS ---

    def _store_in_state(key, df, current_state):
        """Return a copy of the state dict with `df` stored under `key`."""
        updated = dict(current_state or {})
        updated[key] = df
        return updated

    # EDA outputs are shared by the button and the class-filter listener.
    eda_outputs = [
        freq_plot_output, scatter_plot_output, pic50_plot_output, pic50_stats_output,
        mw_plot_output, mw_stats_output, logp_plot_output, logp_stats_output,
        hdonors_plot_output, hdonors_stats_output, hacceptors_plot_output, hacceptors_stats_output,
        status_step1_process,
    ]

    # Step 1 Callbacks
    # NOTE(review): this lambda returns two values for three outputs. The right
    # fix depends on what get_target_chembl_id returns, which is defined outside
    # this section. If it already returns (table, dropdown update, status),
    # pass it directly as `fn`. TODO confirm.
    search_target_btn.click(
        fn=lambda query: (get_target_chembl_id(query), gr.update(visible=True)),  # Return two values
        inputs=target_query,
        outputs=[target_output_df, chembl_id_selector, status_step1_search]
    )

    # The former chembl_id_selector.change handler only wrote the dropdown's
    # own value back to itself and has been dropped.

    # NOTE(review): the status box receives the target id here. If
    # get_bioactivity_data returns (table, status), pass it directly as `fn`.
    # TODO confirm.
    fetch_data_btn.click(
        fn=lambda target_id: (get_bioactivity_data(target_id), gr.update(value=target_id)),
        inputs=chembl_id_selector,
        outputs=[bioactivity_output_df, status_step1_fetch]
    ).then(
        # Merge the fetched table into the state dict. Replacing the dict with
        # the DataFrame would break every later current_state.get(...) lookup.
        fn=lambda df, current_state: _store_in_state('raw_data', df, current_state),
        inputs=[bioactivity_output_df, app_state],
        outputs=app_state
    )

    process_data_btn.click(
        fn=lambda current_state: clean_and_process_data(current_state.get('raw_data')),
        inputs=app_state,
        outputs=[cleaned_data_output_df, status_step1_process_clean]
    ).then(
        # Show the download button and remember the cleaned table for Step 2.
        fn=lambda df, current_state: (
            gr.update(visible=True),
            _store_in_state('cleaned_data', df, current_state),
        ),
        inputs=[cleaned_data_output_df, app_state],
        outputs=[download_s1, app_state]
    )

    @download_s1.click(inputs=app_state, outputs=download_s1, show_progress="hidden")
    def download_handler_s1(current_state):
        """Write the cleaned data to CSV and attach the file to the button."""
        df_to_download = current_state.get('cleaned_data')
        if df_to_download is None:
            raise gr.Error("No data to download. Please process data first.")
        file_path = "cleaned_data.csv"
        df_to_download.to_csv(file_path, index=False)
        # Update the button itself rather than returning a File component.
        return gr.update(value=file_path, visible=True)

    run_eda_btn.click(
        fn=lambda df, classes: run_eda_analysis(df, classes),
        inputs=[cleaned_data_output_df, bioactivity_class_selector],
        outputs=eda_outputs
    )

    # Update EDA plots on filter change (if data is already processed)
    bioactivity_class_selector.change(
        fn=lambda current_state, selected_classes: run_eda_analysis(current_state.get('cleaned_data'), selected_classes),
        inputs=[app_state, bioactivity_class_selector],
        outputs=eda_outputs,
        show_progress="minimal"
    )

    # Step 2 Callbacks
    calculate_fp_btn.click(
        fn=calculate_fingerprints,
        inputs=[app_state, fingerprint_dropdown],
        outputs=[status_step2, output_df_s2, download_s2, mols_grid_s2, app_state]
    )

    @download_s2.click(inputs=app_state, outputs=download_s2, show_progress="hidden")
    def download_handler(current_state):
        """Write the fingerprint table to CSV and attach the file to the button."""
        df_to_download = current_state.get('fingerprint_data')
        if df_to_download is None:
            raise gr.Error("No data to download. Please calculate fingerprints first.")
        file_path = "fingerprint_data.csv"
        df_to_download.to_csv(file_path, index=False)
        return gr.update(value=file_path, visible=True)

    # Step 3 Callbacks
    train_models_btn.click(
        fn=handle_model_training,
        inputs=[app_state],
        outputs=[status_step3_train, model_results_df, model_selector_s3, app_state]
    )

    # Update plots when model or feature count changes
    for listener in [model_selector_s3.change, feature_count_s3.change]:
        listener(
            fn=update_analysis_plots,
            inputs=[model_selector_s3, feature_count_s3, app_state],
            outputs=[obs_pred_plot, residuals_plot, feature_importance_plot, model_metrics_summary, status_step3_train]
        )

    def _sync_prediction_models(choice, current_state):
        """Mirror the trained-model list, and the current pick, into Step 4."""
        run = (current_state or {}).get('model_results')
        # Read the names from the stored run: the Python-side
        # `model_selector_s3.choices` is still the empty build-time list.
        names = list(run.models) if run is not None else []
        return gr.update(choices=names, value=choice if choice in names else None)

    # Update model selector in Step 4 when models are trained
    model_selector_s3.change(
        fn=_sync_prediction_models,
        inputs=[model_selector_s3, app_state],
        outputs=model_selector_s4
    )

    # Step 4 Callbacks
    predict_btn.click(
        fn=predict_on_upload,
        inputs=[upload_new_mols, model_selector_s4, app_state],
        outputs=[status_step4, predictions_output_df, mols_grid_s4]
    ).then(
        fn=lambda: gr.update(visible=True),
        outputs=download_s4
    )

    @download_s4.click(inputs=predictions_output_df, outputs=download_s4, show_progress="hidden")
    def download_predictions(df_predictions):
        """Write the predictions table to CSV and attach the file to the button."""
        if df_predictions is None or df_predictions.empty:
            raise gr.Error("No predictions to download.")
        file_path = "predictions.csv"
        df_predictions.to_csv(file_path, index=False)
        return gr.update(value=file_path, visible=True)

demo.launch()
|
|
|