alidenewade commited on
Commit
89eae1b
·
verified ·
1 Parent(s): e7bd3ae

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -15
app.py CHANGED
@@ -68,26 +68,27 @@ sns.set_theme(style='whitegrid')
68
 
69
  # --- FINGERPRINT CONFIGURATION ---
70
  # Create a dummy PubChem.xml if no XML files are found, to ensure fp_config is populated
71
- # Check if the 'padel_descriptors' directory exists, create it if not
72
- if not os.path.exists('padel_descriptors'):
73
- os.makedirs('padel_descriptors')
 
74
 
75
  # Check for XML files within the 'padel_descriptors' folder
76
- xml_files = sorted(glob.glob('padel_descriptors/*.xml'))
77
-
78
  if not xml_files:
79
- # If no XML files found in the directory, try to create a dummy one.
80
  try:
81
- with open('padel_descriptors/PubChem.xml', 'w') as f:
 
82
  f.write('')
83
- xml_files = sorted(glob.glob('padel_descriptors/*.xml')) # Re-check after creating
84
  except IOError:
85
- warnings.warn("Could not create a dummy 'PubChem.xml' file in 'padel_descriptors' folder. Fingerprint calculation might fail if no .xml files are present.")
86
 
87
  if not xml_files:
88
  warnings.warn(
89
- "No descriptor .xml files found in the 'padel_descriptors' folder. Fingerprint calculation will not be possible. "
90
- "Please place descriptor XML files in the 'padel_descriptors' folder."
 
91
  )
92
  fp_config = {os.path.splitext(os.path.basename(file))[0]: file for file in xml_files}
93
  FP_list = sorted(list(fp_config.keys()))
@@ -207,7 +208,7 @@ def mannwhitney_test(df, descriptor):
207
  # ==============================================================================
208
  # === STEP 2: FEATURE ENGINEERING FUNCTIONS ===
209
  # ==============================================================================
210
-
211
  def create_molecule_grid_html(df, smiles_col='canonical_smiles', max_mols=20):
212
  html_parts = ['<div style="display: flex; flex-wrap: wrap; gap: 10px;">']
213
  for idx, row in df.head(max_mols).iterrows():
@@ -326,6 +327,7 @@ def run_regression_suite(df: pd.DataFrame, progress=gr.Progress()):
326
  model_choices = results_df['Model'].tolist()
327
  yield "✅ Model training & evaluation complete.", model_run_results, gr.Dropdown(choices=model_choices, interactive=True)
328
 
 
329
  def create_prediction_grid_html(df, smiles_col='canonical_smiles', pred_col='predicted_pIC50', max_mols=20):
330
  html_parts = ['<div style="display: flex; flex-wrap: wrap; gap: 10px;">']
331
  for idx, row in df.head(max_mols).iterrows():
@@ -392,9 +394,6 @@ def predict_on_upload(uploaded_file, model_name, current_state, progress=gr.Prog
392
 
393
  progress(0.9, desc="Generating visualization..."); yield "Generating visualization...", None, None
394
 
395
- # DEBUG FIX: The main fix for the KeyError.
396
- # Create a copy, rename the column *before* calling mols2grid.
397
- # This is more robust than relying on the library's 'rename' parameter.
398
  df_grid_view = df_results.dropna(subset=['predicted_pIC50']).copy()
399
  mols_html = "<h3>No molecules with successful predictions to display.</h3>"
400
  if not df_grid_view.empty:
@@ -480,6 +479,8 @@ with gr.Blocks(theme=gr.themes.Default(primary_hue="blue", secondary_hue="sky"),
480
  gr.Markdown("Upload a CSV with a `canonical_smiles` column to predict pIC50.")
481
  with gr.Row():
482
  upload_predict_file = gr.File(label="Upload CSV for Prediction", file_types=[".csv"])
 
 
483
  predict_btn_s3 = gr.Button("Run Prediction", variant="primary")
484
  status_step3_predict = gr.Textbox(label="Status", interactive=False)
485
  prediction_results_df = gr.DataFrame(label="Prediction Results")
@@ -517,6 +518,22 @@ with gr.Blocks(theme=gr.themes.Default(primary_hue="blue", secondary_hue="sky"),
517
  model_results = current_state.get('model_results')
518
  if not model_results or not model_name: return None, None
519
  plotter = model_results.plotter; validation_fig = plotter.plot_validation(model_name); feature_fig = plotter.plot_feature_importance(model_name, int(feature_count)); plt.close('all'); return validation_fig, feature_fig
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
520
 
521
  fetch_btn.click(fn=get_target_chembl_id, inputs=query_input, outputs=[target_id_table, selected_target_dropdown, status_step1_fetch], show_progress="minimal")
522
  selected_target_dropdown.change(fn=enable_process_button, inputs=selected_target_dropdown, outputs=process_btn, show_progress="hidden")
@@ -531,6 +548,9 @@ with gr.Blocks(theme=gr.themes.Default(primary_hue="blue", secondary_hue="sky"),
531
  train_models_btn.click(fn=handle_model_training, inputs=[app_state], outputs=[status_step3_train, model_results_df, model_selector_s3, app_state])
532
  for listener in [model_selector_s3.change, feature_count_s3.change]: listener(fn=update_analysis_plots, inputs=[model_selector_s3, feature_count_s3, app_state], outputs=[validation_plot_s3, feature_plot_s3], show_progress="minimal")
533
  predict_btn_s3.click(fn=predict_on_upload, inputs=[upload_predict_file, model_selector_s3, app_state], outputs=[status_step3_predict, prediction_results_df, prediction_mols_grid])
 
 
 
534
 
535
  if __name__ == "__main__":
536
  demo.launch(debug=True)
 
68
 
69
  # --- FINGERPRINT CONFIGURATION ---
70
  # Create a dummy PubChem.xml if no XML files are found, to ensure fp_config is populated
71
+ # Updated path for XML files to 'padel_descriptors/*.xml'
72
+ padel_descriptors_dir = 'padel_descriptors'
73
+ if not os.path.exists(padel_descriptors_dir):
74
+ os.makedirs(padel_descriptors_dir)
75
 
76
  # Check for XML files within the 'padel_descriptors' folder
77
+ xml_files = sorted(glob.glob(os.path.join(padel_descriptors_dir, '*.xml')))
 
78
  if not xml_files:
 
79
  try:
80
+ # Create a dummy PubChem.xml inside 'padel_descriptors' if no XML files are found
81
+ with open(os.path.join(padel_descriptors_dir, 'PubChem.xml'), 'w') as f:
82
  f.write('')
83
+ xml_files = sorted(glob.glob(os.path.join(padel_descriptors_dir, '*.xml'))) # Re-scan after creating dummy
84
  except IOError:
85
+ warnings.warn("Could not create a dummy 'PubChem.xml' file in 'padel_descriptors'. Fingerprint calculation might fail if no .xml files are present.")
86
 
87
  if not xml_files:
88
  warnings.warn(
89
+ "No descriptor .xml files found in 'padel_descriptors' directory. "
90
+ "Fingerprint calculation will not be possible. "
91
+ "Please place descriptor XML files in the 'padel_descriptors' directory."
92
  )
93
  fp_config = {os.path.splitext(os.path.basename(file))[0]: file for file in xml_files}
94
  FP_list = sorted(list(fp_config.keys()))
 
208
  # ==============================================================================
209
  # === STEP 2: FEATURE ENGINEERING FUNCTIONS ===
210
  # ==============================================================================
211
+ # Replacement for mols2grid.display in Step 2
212
  def create_molecule_grid_html(df, smiles_col='canonical_smiles', max_mols=20):
213
  html_parts = ['<div style="display: flex; flex-wrap: wrap; gap: 10px;">']
214
  for idx, row in df.head(max_mols).iterrows():
 
327
  model_choices = results_df['Model'].tolist()
328
  yield "✅ Model training & evaluation complete.", model_run_results, gr.Dropdown(choices=model_choices, interactive=True)
329
 
330
+ # Replacement for mols2grid.display in Step 3
331
  def create_prediction_grid_html(df, smiles_col='canonical_smiles', pred_col='predicted_pIC50', max_mols=20):
332
  html_parts = ['<div style="display: flex; flex-wrap: wrap; gap: 10px;">']
333
  for idx, row in df.head(max_mols).iterrows():
 
394
 
395
  progress(0.9, desc="Generating visualization..."); yield "Generating visualization...", None, None
396
 
 
 
 
397
  df_grid_view = df_results.dropna(subset=['predicted_pIC50']).copy()
398
  mols_html = "<h3>No molecules with successful predictions to display.</h3>"
399
  if not df_grid_view.empty:
 
479
  gr.Markdown("Upload a CSV with a `canonical_smiles` column to predict pIC50.")
480
  with gr.Row():
481
  upload_predict_file = gr.File(label="Upload CSV for Prediction", file_types=[".csv"])
482
+ # Add the example data button
483
+ load_example_btn = gr.Button("Load Example Data (example_data.csv)", variant="secondary")
484
  predict_btn_s3 = gr.Button("Run Prediction", variant="primary")
485
  status_step3_predict = gr.Textbox(label="Status", interactive=False)
486
  prediction_results_df = gr.DataFrame(label="Prediction Results")
 
518
  model_results = current_state.get('model_results')
519
  if not model_results or not model_name: return None, None
520
  plotter = model_results.plotter; validation_fig = plotter.plot_validation(model_name); feature_fig = plotter.plot_feature_importance(model_name, int(feature_count)); plt.close('all'); return validation_fig, feature_fig
521
+
522
+ # New function to load example data
523
+ def load_example_data():
524
+ example_file_path = "example_data.csv"
525
+ # Create a dummy example_data.csv if it doesn't exist for demonstration
526
+ if not os.path.exists(example_file_path):
527
+ dummy_data = pd.DataFrame({
528
+ 'canonical_smiles': [
529
+ 'CCO',
530
+ 'CC(=O)Oc1ccccc1C(=O)O',
531
+ 'Cc1ccc(cc1)C(C)C(=O)O'
532
+ ]
533
+ })
534
+ dummy_data.to_csv(example_file_path, index=False)
535
+ return gr.File(value=example_file_path, interactive=True)
536
+
537
 
538
  fetch_btn.click(fn=get_target_chembl_id, inputs=query_input, outputs=[target_id_table, selected_target_dropdown, status_step1_fetch], show_progress="minimal")
539
  selected_target_dropdown.change(fn=enable_process_button, inputs=selected_target_dropdown, outputs=process_btn, show_progress="hidden")
 
548
  train_models_btn.click(fn=handle_model_training, inputs=[app_state], outputs=[status_step3_train, model_results_df, model_selector_s3, app_state])
549
  for listener in [model_selector_s3.change, feature_count_s3.change]: listener(fn=update_analysis_plots, inputs=[model_selector_s3, feature_count_s3, app_state], outputs=[validation_plot_s3, feature_plot_s3], show_progress="minimal")
550
  predict_btn_s3.click(fn=predict_on_upload, inputs=[upload_predict_file, model_selector_s3, app_state], outputs=[status_step3_predict, prediction_results_df, prediction_mols_grid])
551
+
552
+ # New event handler for the example data button
553
+ load_example_btn.click(fn=load_example_data, outputs=upload_predict_file, show_progress="hidden")
554
 
555
  if __name__ == "__main__":
556
  demo.launch(debug=True)