Spaces:
Sleeping
Sleeping
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,473 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# --- IMPORTS ---
|
2 |
+
# Core and Data Handling
|
3 |
+
import gradio as gr #
|
4 |
+
import pandas as pd #
|
5 |
+
import numpy as np #
|
6 |
+
import os #
|
7 |
+
import glob #
|
8 |
+
import time #
|
9 |
+
import warnings #
|
10 |
+
|
11 |
+
# Chemistry and Cheminformatics
|
12 |
+
from rdkit import Chem #
|
13 |
+
from rdkit.Chem import Descriptors, Lipinski #
|
14 |
+
from chembl_webresource_client.new_client import new_client #
|
15 |
+
from padelpy import padeldescriptor #
|
16 |
+
import mols2grid #
|
17 |
+
|
18 |
+
# Plotting and Visualization
|
19 |
+
import matplotlib.pyplot as plt #
|
20 |
+
import seaborn as sns #
|
21 |
+
from scipy import stats #
|
22 |
+
from scipy.stats import mannwhitneyu #
|
23 |
+
|
24 |
+
# Machine Learning Models and Metrics
|
25 |
+
from sklearn.model_selection import train_test_split #
|
26 |
+
from sklearn.feature_selection import VarianceThreshold #
|
27 |
+
from sklearn.linear_model import ( #
|
28 |
+
LinearRegression, Ridge, Lasso, ElasticNet, BayesianRidge, #
|
29 |
+
HuberRegressor, PassiveAggressiveRegressor, OrthogonalMatchingPursuit, #
|
30 |
+
LassoLars #
|
31 |
+
)
|
32 |
+
from sklearn.tree import DecisionTreeRegressor #
|
33 |
+
from sklearn.ensemble import ( #
|
34 |
+
RandomForestRegressor, GradientBoostingRegressor, ExtraTreesRegressor, #
|
35 |
+
AdaBoostRegressor #
|
36 |
+
)
|
37 |
+
from sklearn.neighbors import KNeighborsRegressor #
|
38 |
+
from sklearn.dummy import DummyRegressor #
|
39 |
+
from sklearn.metrics import ( #
|
40 |
+
mean_absolute_error, mean_squared_error, r2_score #
|
41 |
+
)
|
42 |
+
|
43 |
+
# A placeholder class to store all results from a modeling run
|
44 |
+
class ModelRunResult: #
|
45 |
+
def __init__(self, dataframe, plotter, models, selected_features): #
|
46 |
+
self.dataframe = dataframe #
|
47 |
+
self.plotter = plotter #
|
48 |
+
self.models = models #
|
49 |
+
self.selected_features = selected_features #
|
50 |
+
|
51 |
+
# Optional Advanced Models
|
52 |
+
try: #
|
53 |
+
import xgboost as xgb #
|
54 |
+
import lightgbm as lgb #
|
55 |
+
import catboost as cb #
|
56 |
+
_has_extra_libs = True #
|
57 |
+
except ImportError: #
|
58 |
+
_has_extra_libs = False #
|
59 |
+
warnings.warn("Optional libraries (xgboost, lightgbm, catboost) not found. Some models will be unavailable.") #
|
60 |
+
|
61 |
+
# --- GLOBAL CONFIGURATION & SETUP ---
|
62 |
+
warnings.filterwarnings("ignore") #
|
63 |
+
sns.set_theme(style='whitegrid') #
|
64 |
+
|
65 |
+
# --- FINGERPRINT CONFIGURATION ---
|
66 |
+
DESCRIPTOR_DIR = "padel_descriptors"
|
67 |
+
|
68 |
+
# Check if the descriptor directory exists and contains files
|
69 |
+
if not os.path.isdir(DESCRIPTOR_DIR):
|
70 |
+
warnings.warn(
|
71 |
+
f"The descriptor directory '{DESCRIPTOR_DIR}' was not found. "
|
72 |
+
"Fingerprint calculation will be disabled. Please create this directory and upload your .xml files."
|
73 |
+
)
|
74 |
+
xml_files = []
|
75 |
+
else:
|
76 |
+
xml_files = sorted(glob.glob(os.path.join(DESCRIPTOR_DIR, '*.xml')))
|
77 |
+
|
78 |
+
if not xml_files:
|
79 |
+
warnings.warn(
|
80 |
+
f"No descriptor .xml files found in the '{DESCRIPTOR_DIR}' directory. "
|
81 |
+
"Fingerprint calculation will not be possible."
|
82 |
+
)
|
83 |
+
|
84 |
+
# The key is the filename without extension; the value is the full path to the file
|
85 |
+
fp_config = {os.path.splitext(os.path.basename(file))[0]: file for file in xml_files}
|
86 |
+
FP_list = sorted(list(fp_config.keys()))
|
87 |
+
|
88 |
+
|
89 |
+
# ==============================================================================
|
90 |
+
# === STEP 1: CORE DATA COLLECTION & EDA FUNCTIONS ===
|
91 |
+
# ==============================================================================
|
92 |
+
|
93 |
+
def get_target_chembl_id(query): #
|
94 |
+
try: #
|
95 |
+
target = new_client.target #
|
96 |
+
res = target.search(query) #
|
97 |
+
if not res: #
|
98 |
+
return pd.DataFrame(), gr.Dropdown(choices=[], value=None), "No targets found for your query." #
|
99 |
+
df = pd.DataFrame(res) #
|
100 |
+
return df[["target_chembl_id", "pref_name", "organism"]], gr.Dropdown(choices=df["target_chembl_id"].tolist()), f"Found {len(df)} targets." #
|
101 |
+
except Exception as e: #
|
102 |
+
raise gr.Error(f"ChEMBL search failed: {e}") #
|
103 |
+
|
104 |
+
def get_bioactivity_data(target_id): #
|
105 |
+
try: #
|
106 |
+
activity = new_client.activity #
|
107 |
+
res = activity.filter(target_chembl_id=target_id).filter(standard_type="IC50") #
|
108 |
+
if not res: #
|
109 |
+
return pd.DataFrame(), "No IC50 bioactivity data found for this target." #
|
110 |
+
df = pd.DataFrame(res) #
|
111 |
+
return df, f"Fetched {len(df)} data points." #
|
112 |
+
except Exception as e: #
|
113 |
+
raise gr.Error(f"Failed to fetch bioactivity data: {e}") #
|
114 |
+
|
115 |
+
def pIC50_calc(input_df): #
|
116 |
+
df_copy = input_df.copy() #
|
117 |
+
df_copy['standard_value'] = pd.to_numeric(df_copy['standard_value'], errors='coerce') #
|
118 |
+
df_copy.dropna(subset=['standard_value'], inplace=True) #
|
119 |
+
df_copy['standard_value_norm'] = df_copy['standard_value'].apply(lambda x: min(x, 100000000)) #
|
120 |
+
pIC50_values = [] #
|
121 |
+
for i in df_copy['standard_value_norm']: #
|
122 |
+
if pd.notna(i) and i > 0: #
|
123 |
+
molar = i * (10**-9) #
|
124 |
+
pIC50_values.append(-np.log10(molar)) #
|
125 |
+
else: #
|
126 |
+
pIC50_values.append(np.nan) #
|
127 |
+
df_copy['pIC50'] = pIC50_values #
|
128 |
+
df_copy['bioactivity_class'] = df_copy['standard_value_norm'].apply( #
|
129 |
+
lambda x: "inactive" if pd.notna(x) and x >= 10000 else ("active" if pd.notna(x) and x <= 1000 else "intermediate") #
|
130 |
+
)
|
131 |
+
return df_copy.drop(columns=['standard_value', 'standard_value_norm']) #
|
132 |
+
|
133 |
+
def lipinski_descriptors(smiles_series): #
|
134 |
+
moldata, valid_smiles = [], [] #
|
135 |
+
for elem in smiles_series: #
|
136 |
+
if elem and isinstance(elem, str): #
|
137 |
+
mol = Chem.MolFromSmiles(elem) #
|
138 |
+
if mol: #
|
139 |
+
moldata.append(mol) #
|
140 |
+
valid_smiles.append(elem) #
|
141 |
+
descriptor_rows = [] #
|
142 |
+
for mol in moldata: #
|
143 |
+
row = [Descriptors.MolWt(mol), Descriptors.MolLogP(mol), Lipinski.NumHDonors(mol), Lipinski.NumHAcceptors(mol)] #
|
144 |
+
descriptor_rows.append(row) #
|
145 |
+
columnNames = ["MW", "LogP", "NumHDonors", "NumHAcceptors"] #
|
146 |
+
if not descriptor_rows: return pd.DataFrame(columns=columnNames), [] #
|
147 |
+
return pd.DataFrame(data=np.array(descriptor_rows), columns=columnNames), valid_smiles #
|
148 |
+
|
149 |
+
def clean_and_process_data(df): #
|
150 |
+
if df is None or df.empty: raise gr.Error("No data to process. Please fetch data first.") #
|
151 |
+
if "canonical_smiles" not in df.columns or df["canonical_smiles"].isnull().all(): #
|
152 |
+
try: #
|
153 |
+
df["canonical_smiles"] = [c.get("molecule_structures", {}).get("canonical_smiles") for c in new_client.molecule.get(list(df["molecule_chembl_id"]))] #
|
154 |
+
except Exception as e: #
|
155 |
+
raise gr.Error(f"Could not fetch SMILES from ChEMBL: {e}") #
|
156 |
+
df = df[df.standard_value.notna()] #
|
157 |
+
df = df[df.canonical_smiles.notna()] #
|
158 |
+
df.drop_duplicates(['canonical_smiles'], inplace=True) #
|
159 |
+
df["standard_value"] = pd.to_numeric(df["standard_value"], errors='coerce') #
|
160 |
+
df.dropna(subset=['standard_value'], inplace=True) #
|
161 |
+
df_processed = pIC50_calc(df) #
|
162 |
+
df_processed = df_processed[df_processed.pIC50.notna()] #
|
163 |
+
if df_processed.empty: return pd.DataFrame(), "No compounds remaining after pIC50 calculation." #
|
164 |
+
df_lipinski, valid_smiles = lipinski_descriptors(df_processed['canonical_smiles']) #
|
165 |
+
if not valid_smiles: return pd.DataFrame(), "No valid SMILES could be processed for Lipinski descriptors." #
|
166 |
+
df_processed = df_processed[df_processed['canonical_smiles'].isin(valid_smiles)].reset_index(drop=True) #
|
167 |
+
df_lipinski = df_lipinski.reset_index(drop=True) #
|
168 |
+
df_final = pd.concat([df_processed, df_lipinski], axis=1) #
|
169 |
+
return df_final, f"Processing complete. {len(df_final)} compounds remain after cleaning." #
|
170 |
+
|
171 |
+
def run_eda_analysis(df, selected_classes): #
|
172 |
+
if df is None or df.empty: raise gr.Error("No data available for analysis.") #
|
173 |
+
df_filtered = df[df.bioactivity_class.isin(selected_classes)].copy() #
|
174 |
+
if df_filtered.empty: return (None, None, None, pd.DataFrame(), None, pd.DataFrame(), None, pd.DataFrame(), None, pd.DataFrame(), None, pd.DataFrame(), "No data for selected classes.") #
|
175 |
+
plots = [create_frequency_plot(df_filtered), create_scatter_plot(df_filtered)] #
|
176 |
+
stats_dfs = [] #
|
177 |
+
for desc in ['pIC50', 'MW', 'LogP', 'NumHDonors', 'NumHAcceptors']: #
|
178 |
+
plots.append(create_boxplot(df_filtered, desc)) #
|
179 |
+
stats_dfs.append(mannwhitney_test(df_filtered, desc)) #
|
180 |
+
plt.close('all') #
|
181 |
+
return (plots[0], plots[1], plots[2], stats_dfs[0], plots[3], stats_dfs[1], plots[4], stats_dfs[2], plots[5], stats_dfs[3], plots[6], stats_dfs[4], f"EDA complete for {len(df_filtered)} compounds.") #
|
182 |
+
|
183 |
+
def create_frequency_plot(df): #
|
184 |
+
plt.figure(figsize=(5.5, 5.5)); sns.barplot(x=df['bioactivity_class'].value_counts().index, y=df['bioactivity_class'].value_counts().values, palette={'active': '#1f77b4', 'inactive': '#ff7f0e', 'intermediate': '#2ca02c'}); plt.xlabel('Bioactivity Class', fontsize=12); plt.ylabel('Frequency', fontsize=12); plt.title('Frequency of Bioactivity Classes', fontsize=14); return plt.gcf() #
|
185 |
+
def create_scatter_plot(df): #
|
186 |
+
plt.figure(figsize=(5.5, 5.5)); sns.scatterplot(data=df, x='MW', y='LogP', hue='bioactivity_class', size='pIC50', palette={'active': '#1f77b4', 'inactive': '#ff7f0e', 'intermediate': '#2ca02c'}, sizes=(20, 200), alpha=0.7); plt.xlabel('Molecular Weight (MW)', fontsize=12); plt.ylabel('LogP', fontsize=12); plt.title('Chemical Space: MW vs. LogP', fontsize=14); plt.legend(title='Bioactivity Class'); return plt.gcf() #
|
187 |
+
def create_boxplot(df, descriptor): #
|
188 |
+
plt.figure(figsize=(5.5, 5.5)); sns.boxplot(x='bioactivity_class', y=descriptor, data=df, palette={'active': '#1f77b4', 'inactive': '#ff7f0e', 'intermediate': '#2ca02c'}); plt.xlabel('Bioactivity Class', fontsize=12); plt.ylabel(descriptor, fontsize=12); plt.title(f'{descriptor} by Bioactivity Class', fontsize=14); return plt.gcf() #
|
189 |
+
def mannwhitney_test(df, descriptor): #
|
190 |
+
results = [] #
|
191 |
+
for c1, c2 in [('active', 'inactive'), ('active', 'intermediate'), ('inactive', 'intermediate')]: #
|
192 |
+
if c1 in df['bioactivity_class'].unique() and c2 in df['bioactivity_class'].unique(): #
|
193 |
+
d1, d2 = df[df.bioactivity_class == c1][descriptor].dropna(), df[df.bioactivity_class == c2][descriptor].dropna() #
|
194 |
+
if not d1.empty and not d2.empty: #
|
195 |
+
stat, p = mannwhitneyu(d1, d2) #
|
196 |
+
results.append({'Comparison': f'{c1.title()} vs {c2.title()}', 'Statistics': stat, 'p-value': p, 'Interpretation': 'Different distribution (p < 0.05)' if p <= 0.05 else 'Same distribution (p > 0.05)'}) #
|
197 |
+
return pd.DataFrame(results) #
|
198 |
+
|
199 |
+
# ==============================================================================
|
200 |
+
# === STEP 2: FEATURE ENGINEERING FUNCTIONS ===
|
201 |
+
# ==============================================================================
|
202 |
+
|
203 |
+
def calculate_fingerprints(current_state, fingerprint_type, progress=gr.Progress()): #
|
204 |
+
input_df = current_state.get('cleaned_data') #
|
205 |
+
if input_df is None or input_df.empty: raise gr.Error("No cleaned data found. Please complete Step 1.") #
|
206 |
+
if not fingerprint_type: raise gr.Error("Please select a fingerprint type.") #
|
207 |
+
progress(0, desc="Starting..."); yield f"🧪 Starting fingerprint calculation...", None, gr.update(visible=False), None, current_state #
|
208 |
+
try: #
|
209 |
+
smi_file, output_csv = 'molecule.smi', 'fingerprints.csv' #
|
210 |
+
|
211 |
+
input_df[['canonical_smiles', 'canonical_smiles']].to_csv(smi_file, sep='\t', index=False, header=False) #
|
212 |
+
|
213 |
+
if os.path.exists(output_csv): os.remove(output_csv) #
|
214 |
+
descriptortypes = fp_config.get(fingerprint_type) #
|
215 |
+
if not descriptortypes: raise gr.Error(f"Descriptor XML for '{fingerprint_type}' not found.") #
|
216 |
+
|
217 |
+
progress(0.3, desc="⚗️ Running PaDEL..."); yield f"⚗️ Running PaDEL...", None, gr.update(visible=False), None, current_state #
|
218 |
+
padeldescriptor(mol_dir=smi_file, d_file=output_csv, descriptortypes=descriptortypes, detectaromaticity=True, standardizenitro=True, standardizetautomers=True, threads=-1, removesalt=True, log=False, fingerprints=True) #
|
219 |
+
if not os.path.exists(output_csv) or os.path.getsize(output_csv) == 0: #
|
220 |
+
raise gr.Error("PaDEL failed to produce an output file. Check molecule validity.") #
|
221 |
+
|
222 |
+
progress(0.7, desc="📊 Processing results..."); yield "📊 Processing results...", None, gr.update(visible=False), None, current_state #
|
223 |
+
df_X = pd.read_csv(output_csv).rename(columns={'Name': 'canonical_smiles'}) #
|
224 |
+
|
225 |
+
final_df = pd.merge(input_df[['canonical_smiles', 'pIC50']], df_X, on='canonical_smiles', how='inner') #
|
226 |
+
|
227 |
+
current_state['fingerprint_data'] = final_df; current_state['fingerprint_type'] = fingerprint_type #
|
228 |
+
progress(0.9, desc="🖼️ Generating molecule grid...") #
|
229 |
+
mols_html = mols2grid.display(final_df, smiles_col='canonical_smiles', subset=['img', 'pIC50'], rename={"pIC50": "pIC50"}, transform={"pIC50": lambda x: f"{x:.2f}"})._repr_html_() #
|
230 |
+
success_msg = f"✅ Success! Generated {len(df_X.columns) -1} descriptors for {len(final_df)} molecules." #
|
231 |
+
progress(1, desc="Completed!"); yield success_msg, final_df, gr.update(visible=True), gr.update(value=mols_html, visible=True), current_state #
|
232 |
+
except Exception as e: raise gr.Error(f"Calculation failed: {e}") #
|
233 |
+
finally: #
|
234 |
+
if os.path.exists('molecule.smi'): os.remove('molecule.smi') #
|
235 |
+
if os.path.exists('fingerprints.csv'): os.remove('fingerprints.csv') #
|
236 |
+
|
237 |
+
# ==============================================================================
|
238 |
+
# === STEP 3: MODEL TRAINING & PREDICTION FUNCTIONS ===
|
239 |
+
# ==============================================================================
|
240 |
+
class ModelPlotter: #
|
241 |
+
def __init__(self, models: dict, X_test: pd.DataFrame, y_test: pd.Series): #
|
242 |
+
self._models, self._X_test, self._y_test = models, X_test, y_test #
|
243 |
+
def plot_validation(self, model_name: str): #
|
244 |
+
if model_name not in self._models: raise ValueError(f"Model '{model_name}' not found.") #
|
245 |
+
model, y_pred = self._models[model_name], self._models[model_name].predict(self._X_test) #
|
246 |
+
residuals = self._y_test - y_pred #
|
247 |
+
fig, axes = plt.subplots(2, 2, figsize=(12, 10)); fig.suptitle(f'Model Validation Plots for {model_name}', fontsize=16, y=1.02) #
|
248 |
+
sns.scatterplot(x=self._y_test, y=y_pred, ax=axes[0, 0], alpha=0.6); axes[0, 0].set_title('Actual vs. Predicted'); axes[0, 0].set_xlabel('Actual pIC50'); axes[0, 0].set_ylabel('Predicted pIC50'); lims = [min(self._y_test.min(), y_pred.min()), max(self._y_test.max(), y_pred.max())]; axes[0, 0].plot(lims, lims, 'r--', alpha=0.75, zorder=0) #
|
249 |
+
sns.scatterplot(x=y_pred, y=residuals, ax=axes[0, 1], alpha=0.6); axes[0, 1].axhline(y=0, color='r', linestyle='--'); axes[0, 1].set_title('Residuals vs. Predicted'); axes[0, 1].set_xlabel('Predicted pIC50'); axes[0, 1].set_ylabel('Residuals') #
|
250 |
+
sns.histplot(residuals, kde=True, ax=axes[1, 0]); axes[1, 0].set_title('Distribution of Residuals') #
|
251 |
+
stats.probplot(residuals, dist="norm", plot=axes[1, 1]); axes[1, 1].set_title('Normal Q-Q Plot') #
|
252 |
+
plt.tight_layout(); return fig #
|
253 |
+
def plot_feature_importance(self, model_name: str, top_n: int = 7): #
|
254 |
+
if model_name not in self._models: raise ValueError(f"Model '{model_name}' not found.") #
|
255 |
+
model = self._models[model_name] #
|
256 |
+
if hasattr(model, 'feature_importances_'): importances = model.feature_importances_ #
|
257 |
+
elif hasattr(model, 'coef_'): importances = np.abs(model.coef_) #
|
258 |
+
else: return None #
|
259 |
+
top_features = pd.DataFrame({'Feature': self._X_test.columns, 'Importance': importances}).sort_values(by='Importance', ascending=False).head(top_n) #
|
260 |
+
plt.figure(figsize=(10, top_n * 0.5)); sns.barplot(x='Importance', y='Feature', data=top_features, palette='viridis', orient='h'); plt.title(f'Top {top_n} Features for {model_name}'); plt.tight_layout(); return plt.gcf() #
|
261 |
+
|
262 |
+
def run_regression_suite(df: pd.DataFrame, progress=gr.Progress()): #
|
263 |
+
progress(0, desc="Splitting data..."); yield "Splitting data (80/20 train/test split)...", None, None #
|
264 |
+
X = df.drop(columns=['pIC50', 'canonical_smiles'], errors='ignore') #
|
265 |
+
y = df['pIC50'] #
|
266 |
+
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42) #
|
267 |
+
|
268 |
+
progress(0.1, desc="Selecting features..."); yield "Performing feature selection (removing low variance)...", None, None #
|
269 |
+
selector = VarianceThreshold(threshold=0.1) #
|
270 |
+
X_train = pd.DataFrame(selector.fit_transform(X_train), columns=X_train.columns[selector.get_support()], index=X_train.index) #
|
271 |
+
X_test = pd.DataFrame(selector.transform(X_test), columns=X_test.columns[selector.get_support()], index=X_test.index) #
|
272 |
+
selected_features = X_train.columns.tolist() #
|
273 |
+
|
274 |
+
model_defs = [('Linear Regression', LinearRegression()), ('Ridge', Ridge(random_state=42)), ('Lasso', Lasso(random_state=42)), ('Random Forest', RandomForestRegressor(random_state=42, n_jobs=-1)), ('Gradient Boosting', GradientBoostingRegressor(random_state=42))] #
|
275 |
+
if _has_extra_libs: model_defs.extend([('XGBoost', xgb.XGBRegressor(random_state=42, n_jobs=-1, verbosity=0)), ('LightGBM', lgb.LGBMRegressor(random_state=42, n_jobs=-1, verbosity=-1)), ('CatBoost', cb.CatBoostRegressor(random_state=42, verbose=0))]) #
|
276 |
+
|
277 |
+
results_list, trained_models = [], {} #
|
278 |
+
for i, (name, model) in enumerate(model_defs): #
|
279 |
+
progress(0.2 + (i / len(model_defs)) * 0.8, desc=f"Training {name}...") #
|
280 |
+
yield f"Training {i+1}/{len(model_defs)}: {name}...", None, None #
|
281 |
+
start_time = time.time(); model.fit(X_train, y_train); y_pred = model.predict(X_test) #
|
282 |
+
results_list.append({'Model': name, 'R²': r2_score(y_test, y_pred), 'MAE': mean_absolute_error(y_test, y_pred), 'RMSE': np.sqrt(mean_squared_error(y_test, y_pred)), 'Time (s)': f"{time.time() - start_time:.2f}"}) #
|
283 |
+
trained_models[name] = model #
|
284 |
+
|
285 |
+
results_df = pd.DataFrame(results_list).sort_values(by='R²', ascending=False).reset_index(drop=True) #
|
286 |
+
plotter = ModelPlotter(trained_models, X_test, y_test) #
|
287 |
+
model_run_results = ModelRunResult(results_df, plotter, trained_models, selected_features) #
|
288 |
+
|
289 |
+
model_choices = results_df['Model'].tolist() #
|
290 |
+
yield "✅ Model training & evaluation complete.", model_run_results, gr.Dropdown(choices=model_choices, interactive=True) #
|
291 |
+
|
292 |
+
def predict_on_upload(uploaded_file, model_name, current_state, progress=gr.Progress()): #
|
293 |
+
if not uploaded_file: raise gr.Error("Please upload a file.") #
|
294 |
+
if not model_name: raise gr.Error("Please select a trained model.") #
|
295 |
+
model_run_results = current_state.get('model_results') #
|
296 |
+
fingerprint_type = current_state.get('fingerprint_type') #
|
297 |
+
if not model_run_results or not fingerprint_type: raise gr.Error("Please run Steps 2 and 3 first.") #
|
298 |
+
|
299 |
+
model = model_run_results.models.get(model_name) #
|
300 |
+
selected_features = model_run_results.selected_features #
|
301 |
+
if model is None: raise gr.Error(f"Model '{model_name}' not found.") #
|
302 |
+
|
303 |
+
smi_file, output_csv = 'predict.smi', 'predict_fp.csv' #
|
304 |
+
try: #
|
305 |
+
progress(0, desc="Reading & processing new molecules..."); yield "Reading uploaded file...", None, None #
|
306 |
+
df_new = pd.read_csv(uploaded_file.name) #
|
307 |
+
if 'canonical_smiles' not in df_new.columns: raise gr.Error("CSV must contain a 'canonical_smiles' column.") #
|
308 |
+
df_new = df_new.reset_index().rename(columns={'index': 'mol_id'}) #
|
309 |
+
|
310 |
+
padel_input = pd.DataFrame({'smiles': df_new['canonical_smiles'], 'name': df_new['mol_id']}) #
|
311 |
+
padel_input.to_csv(smi_file, sep='\t', index=False, header=False) #
|
312 |
+
if os.path.exists(output_csv): os.remove(output_csv) #
|
313 |
+
|
314 |
+
progress(0.3, desc="Calculating fingerprints..."); yield "Calculating fingerprints for new molecules...", None, None #
|
315 |
+
padeldescriptor(mol_dir=smi_file, d_file=output_csv, descriptortypes=fp_config.get(fingerprint_type), detectaromaticity=True, standardizenitro=True, threads=-1, removesalt=True, log=False, fingerprints=True) #
|
316 |
+
if not os.path.exists(output_csv) or os.path.getsize(output_csv) == 0: raise gr.Error("PaDEL calculation failed for the uploaded molecules.") #
|
317 |
+
|
318 |
+
progress(0.7, desc="Aligning features and predicting..."); yield "Aligning features and predicting...", None, None #
|
319 |
+
df_fp = pd.read_csv(output_csv).rename(columns={'Name': 'mol_id'}) #
|
320 |
+
|
321 |
+
X_new = df_fp.set_index('mol_id') #
|
322 |
+
X_new_aligned = X_new.reindex(columns=selected_features, fill_value=0)[selected_features] #
|
323 |
+
|
324 |
+
predictions = model.predict(X_new_aligned) #
|
325 |
+
|
326 |
+
results_subset = pd.DataFrame({'mol_id': X_new_aligned.index, 'predicted_pIC50': predictions}) #
|
327 |
+
df_results = pd.merge(df_new, results_subset, on='mol_id', how='left') #
|
328 |
+
|
329 |
+
progress(0.9, desc="Generating visualization..."); yield "Generating visualization...", None, None #
|
330 |
+
|
331 |
+
df_grid_view = df_results.dropna(subset=['predicted_pIC50']).copy() #
|
332 |
+
mols_html = "<h3>No molecules with successful predictions to display.</h3>" #
|
333 |
+
if not df_grid_view.empty: #
|
334 |
+
df_grid_view.rename(columns={"predicted_pIC50": "Predicted pIC50"}, inplace=True) #
|
335 |
+
mols_html = mols2grid.display( #
|
336 |
+
df_grid_view, #
|
337 |
+
smiles_col='canonical_smiles', #
|
338 |
+
subset=['img', 'Predicted pIC50'], #
|
339 |
+
transform={"Predicted pIC50": lambda x: f"{x:.2f}"} #
|
340 |
+
)._repr_html_() #
|
341 |
+
|
342 |
+
progress(1, desc="Complete!"); yield "✅ Prediction complete.", df_results[['canonical_smiles', 'predicted_pIC50']], mols_html #
|
343 |
+
finally: #
|
344 |
+
if os.path.exists(smi_file): os.remove(smi_file) #
|
345 |
+
if os.path.exists(output_csv): os.remove(output_csv) #
|
346 |
+
|
347 |
+
# ==============================================================================
|
348 |
+
# === GRADIO INTERFACE ===
|
349 |
+
# ==============================================================================
|
350 |
+
with gr.Blocks(theme=gr.themes.Default(primary_hue="blue", secondary_hue="sky"), title="Comprehensive Drug Discovery Workflow") as demo: #
|
351 |
+
gr.Markdown("# 🧪 Comprehensive Drug Discovery Workflow") #
|
352 |
+
gr.Markdown("A 3-step application to fetch, analyze, and model chemical bioactivity data.") #
|
353 |
+
app_state = gr.State({}) #
|
354 |
+
with gr.Tabs(): #
|
355 |
+
with gr.Tab("Step 1: Data Collection & EDA"): #
|
356 |
+
gr.Markdown("## Fetch Bioactivity Data from ChEMBL and Perform Exploratory Analysis") #
|
357 |
+
with gr.Row(): #
|
358 |
+
query_input = gr.Textbox(label="Target Query", placeholder="e.g., acetylcholinesterase, BRAF kinase", scale=3) #
|
359 |
+
fetch_btn = gr.Button("Fetch Targets", variant="primary", scale=1) #
|
360 |
+
status_step1_fetch = gr.Textbox(label="Status", interactive=False) #
|
361 |
+
target_id_table = gr.Dataframe(label="Available Targets", interactive=False, headers=["target_chembl_id", "pref_name", "organism"]) #
|
362 |
+
with gr.Row(): #
|
363 |
+
selected_target_dropdown = gr.Dropdown(label="Select Target ChEMBL ID", interactive=True, scale=3) #
|
364 |
+
process_btn = gr.Button("Process Data & Run EDA", variant="primary", scale=1, interactive=False) #
|
365 |
+
status_step1_process = gr.Textbox(label="Status", interactive=False) #
|
366 |
+
gr.Markdown("### Filtered Data & Analysis") #
|
367 |
+
bioactivity_class_selector = gr.CheckboxGroup(["active", "inactive", "intermediate"], label="Filter by Bioactivity Class", value=["active", "inactive", "intermediate"]) #
|
368 |
+
df_output_s1 = gr.Dataframe(label="Cleaned Bioactivity Data") #
|
369 |
+
with gr.Tabs(): #
|
370 |
+
with gr.Tab("Chemical Space Overview"): #
|
371 |
+
with gr.Row(): #
|
372 |
+
freq_plot_output = gr.Plot(label="Frequency of Bioactivity Classes") #
|
373 |
+
scatter_plot_output = gr.Plot(label="Scatter Plot: MW vs LogP") #
|
374 |
+
with gr.Tab("pIC50 Analysis"): #
|
375 |
+
with gr.Row(): #
|
376 |
+
pic50_plot_output = gr.Plot(label="pIC50 Box Plot") #
|
377 |
+
pic50_stats_output = gr.Dataframe(label="Mann-Whitney U Test Results for pIC50") #
|
378 |
+
with gr.Tab("Molecular Weight Analysis"): #
|
379 |
+
with gr.Row(): #
|
380 |
+
mw_plot_output = gr.Plot(label="MW Box Plot") #
|
381 |
+
mw_stats_output = gr.Dataframe(label="Mann-Whitney U Test Results for MW") #
|
382 |
+
with gr.Tab("LogP Analysis"): #
|
383 |
+
with gr.Row(): #
|
384 |
+
logp_plot_output = gr.Plot(label="LogP Box Plot") #
|
385 |
+
logp_stats_output = gr.Dataframe(label="Mann-Whitney U Test Results for LogP") #
|
386 |
+
with gr.Tab("H-Bond Donor/Acceptor Analysis"): #
|
387 |
+
with gr.Row(): #
|
388 |
+
hdonors_plot_output = gr.Plot(label="H-Donors Box Plot") #
|
389 |
+
hacceptors_plot_output = gr.Plot(label="H-Acceptors Box Plot") #
|
390 |
+
with gr.Row(): #
|
391 |
+
hdonors_stats_output = gr.Dataframe(label="Stats for H-Donors") #
|
392 |
+
hacceptors_stats_output = gr.Dataframe(label="Stats for H-Acceptors") #
|
393 |
+
with gr.Tab("Step 2: Feature Engineering"): #
|
394 |
+
gr.Markdown("## Calculate Molecular Fingerprints using PaDEL") #
|
395 |
+
with gr.Row(): #
|
396 |
+
fingerprint_dropdown = gr.Dropdown(choices=FP_list, value='PubChem' if 'PubChem' in FP_list else None, label="Select Fingerprint Method", scale=3) #
|
397 |
+
calculate_fp_btn = gr.Button("Calculate Fingerprints", variant="primary", scale=1) #
|
398 |
+
status_step2 = gr.Textbox(label="Status", interactive=False) #
|
399 |
+
output_df_s2 = gr.Dataframe(label="Final Processed Data", wrap=True) #
|
400 |
+
download_s2 = gr.DownloadButton("Download Feature Data (CSV)", variant="secondary", visible=False) #
|
401 |
+
mols_grid_s2 = gr.HTML(label="Interactive Molecule Viewer") #
|
402 |
+
with gr.Tab("Step 3: Model Training & Prediction"): #
|
403 |
+
gr.Markdown("## Train Regression Models and Predict pIC50") #
|
404 |
+
with gr.Tabs(): #
|
405 |
+
with gr.Tab("Model Training & Evaluation"): #
|
406 |
+
train_models_btn = gr.Button("Train All Models", variant="primary") #
|
407 |
+
status_step3_train = gr.Textbox(label="Status", interactive=False) #
|
408 |
+
model_results_df = gr.DataFrame(label="Ranked Model Results", interactive=False) #
|
409 |
+
with gr.Row(): #
|
410 |
+
model_selector_s3 = gr.Dropdown(label="Select Model to Analyze", interactive=False) #
|
411 |
+
feature_count_s3 = gr.Number(label="Top Features to Show", value=7, minimum=3, maximum=20, step=1) #
|
412 |
+
with gr.Tabs(): #
|
413 |
+
with gr.Tab("Validation Plots"): validation_plot_s3 = gr.Plot(label="Model Validation Plots") #
|
414 |
+
with gr.Tab("Feature Importance"): feature_plot_s3 = gr.Plot(label="Top Feature Importances") #
|
415 |
+
with gr.Tab("Predict on New Data"): #
|
416 |
+
gr.Markdown("Upload a CSV with a `canonical_smiles` column to predict pIC50.") #
|
417 |
+
with gr.Row(): #
|
418 |
+
upload_predict_file = gr.File(label="Upload CSV for Prediction", file_types=[".csv"]) #
|
419 |
+
predict_btn_s3 = gr.Button("Run Prediction", variant="primary") #
|
420 |
+
status_step3_predict = gr.Textbox(label="Status", interactive=False) #
|
421 |
+
prediction_results_df = gr.DataFrame(label="Prediction Results") #
|
422 |
+
prediction_mols_grid = gr.HTML(label="Interactive Molecular Grid of Predictions") #
|
423 |
+
|
424 |
+
# --- EVENT HANDLERS ---
|
425 |
+
def enable_process_button(target_id): return gr.update(interactive=bool(target_id)) #
|
426 |
+
def process_and_analyze_wrapper(target_id, selected_classes, current_state, progress=gr.Progress()): #
|
427 |
+
if not target_id: raise gr.Error("Please select a target ChEMBL ID first.") #
|
428 |
+
progress(0, desc="Fetching data..."); raw_data, msg1 = get_bioactivity_data(target_id); yield {status_step1_process: gr.update(value=msg1)} #
|
429 |
+
progress(0.3, desc="Cleaning data..."); processed_data, msg2 = clean_and_process_data(raw_data); yield {df_output_s1: processed_data, status_step1_process: gr.update(value=msg2)} #
|
430 |
+
current_state['cleaned_data'] = processed_data #
|
431 |
+
progress(0.6, desc="Running EDA..."); plots_and_stats = run_eda_analysis(processed_data, selected_classes); msg3 = plots_and_stats[-1] #
|
432 |
+
progress(1, desc="Done!") #
|
433 |
+
filtered_data = processed_data[processed_data.bioactivity_class.isin(selected_classes)] if not processed_data.empty else pd.DataFrame() #
|
434 |
+
outputs = [filtered_data] + list(plots_and_stats[:-1]) + [msg3, current_state] #
|
435 |
+
output_components = [df_output_s1, freq_plot_output, scatter_plot_output, pic50_plot_output, pic50_stats_output, mw_plot_output, mw_stats_output, logp_plot_output, logp_stats_output, hdonors_plot_output, hdonors_stats_output, hacceptors_plot_output, hacceptors_stats_output, status_step1_process, app_state] #
|
436 |
+
yield dict(zip(output_components, outputs)) #
|
437 |
+
def update_analysis_on_filter_change(selected_classes, current_state): #
|
438 |
+
cleaned_data = current_state.get('cleaned_data') #
|
439 |
+
if cleaned_data is None or cleaned_data.empty: return (pd.DataFrame(),) + (None,) * 11 + ("No data available.",) #
|
440 |
+
plots_and_stats = run_eda_analysis(cleaned_data, selected_classes); msg = plots_and_stats[-1] #
|
441 |
+
filtered_data = cleaned_data[cleaned_data.bioactivity_class.isin(selected_classes)] #
|
442 |
+
return (filtered_data,) + plots_and_stats[:-1] + (msg,) #
|
443 |
+
def handle_model_training(current_state, progress=gr.Progress(track_tqdm=True)): #
|
444 |
+
fingerprint_data = current_state.get('fingerprint_data') #
|
445 |
+
if fingerprint_data is None or fingerprint_data.empty: raise gr.Error("No feature data. Please complete Step 2.") #
|
446 |
+
for status_msg, model_results, model_choices_update in run_regression_suite(fingerprint_data, progress=progress): #
|
447 |
+
if model_results: current_state['model_results'] = model_results #
|
448 |
+
yield status_msg, model_results.dataframe if model_results else None, model_choices_update, current_state #
|
449 |
+
def save_dataframe_as_csv(df): #
|
450 |
+
if df is None or df.empty: return None #
|
451 |
+
filename = "feature_engineered_data.csv"; df.to_csv(filename, index=False); return gr.File(value=filename, visible=True) #
|
452 |
+
def update_analysis_plots(model_name, feature_count, current_state): #
|
453 |
+
model_results = current_state.get('model_results') #
|
454 |
+
if not model_results or not model_name: return None, None #
|
455 |
+
plotter = model_results.plotter; validation_fig = plotter.plot_validation(model_name); feature_fig = plotter.plot_feature_importance(model_name, int(feature_count)); plt.close('all'); return validation_fig, feature_fig #
|
456 |
+
|
457 |
+
fetch_btn.click(fn=get_target_chembl_id, inputs=query_input, outputs=[target_id_table, selected_target_dropdown, status_step1_fetch], show_progress="minimal") #
|
458 |
+
selected_target_dropdown.change(fn=enable_process_button, inputs=selected_target_dropdown, outputs=process_btn, show_progress="hidden") #
|
459 |
+
process_btn.click(fn=process_and_analyze_wrapper, inputs=[selected_target_dropdown, bioactivity_class_selector, app_state], outputs=[df_output_s1, freq_plot_output, scatter_plot_output, pic50_plot_output, pic50_stats_output, mw_plot_output, mw_stats_output, logp_plot_output, logp_stats_output, hdonors_plot_output, hdonors_stats_output, hacceptors_plot_output, hacceptors_stats_output, status_step1_process, app_state]) #
|
460 |
+
bioactivity_class_selector.change(fn=update_analysis_on_filter_change, inputs=[bioactivity_class_selector, app_state], outputs=[df_output_s1, freq_plot_output, scatter_plot_output, pic50_plot_output, pic50_stats_output, mw_plot_output, mw_stats_output, logp_plot_output, logp_stats_output, hdonors_plot_output, hdonors_stats_output, hacceptors_plot_output, hacceptors_stats_output, status_step1_process], show_progress="minimal") #
|
461 |
+
calculate_fp_btn.click(fn=calculate_fingerprints, inputs=[app_state, fingerprint_dropdown], outputs=[status_step2, output_df_s2, download_s2, mols_grid_s2, app_state]) #
|
462 |
+
|
463 |
+
@download_s2.click(inputs=app_state, outputs=download_s2, show_progress="hidden") #
|
464 |
+
def download_handler(current_state): #
|
465 |
+
df_to_download = current_state.get('fingerprint_data') #
|
466 |
+
return save_dataframe_as_csv(df_to_download) #
|
467 |
+
|
468 |
+
train_models_btn.click(fn=handle_model_training, inputs=[app_state], outputs=[status_step3_train, model_results_df, model_selector_s3, app_state]) #
|
469 |
+
for listener in [model_selector_s3.change, feature_count_s3.change]: listener(fn=update_analysis_plots, inputs=[model_selector_s3, feature_count_s3, app_state], outputs=[validation_plot_s3, feature_plot_s3], show_progress="minimal") #
|
470 |
+
predict_btn_s3.click(fn=predict_on_upload, inputs=[upload_predict_file, model_selector_s3, app_state], outputs=[status_step3_predict, prediction_results_df, prediction_mols_grid]) #
|
471 |
+
|
472 |
+
if __name__ == "__main__": #
|
473 |
+
demo.launch(debug=True) #
|