alidenewade committed on
Commit
c9fddab
·
verified ·
1 Parent(s): 45d1bdf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -16
app.py CHANGED
@@ -3,13 +3,17 @@ import streamlit as st
3
  import torch
4
  from transformers import AutoModelForMaskedLM, AutoTokenizer, pipeline, BitsAndBytesConfig
5
  from rdkit import Chem
6
- from rdkit.Chem import Draw, AllChem
7
  import pandas as pd
8
  import py3Dmol
9
  import re
10
  import logging
11
 
12
- # Set up logging
 
 
 
 
13
  logging.basicConfig(level=logging.INFO)
14
  logger = logging.getLogger(__name__)
15
 
@@ -83,6 +87,8 @@ apply_custom_styling()
83
 
84
 
85
  # --- Model Loading (from mol_app) ---
 
 
86
  @st.cache_resource(show_spinner="Loading ChemBERTa model...")
87
  def load_optimized_models():
88
  """Load models with quantization and other optimizations."""
@@ -108,6 +114,7 @@ def load_optimized_models():
108
  model_kwargs["quantization_config"] = quantization_config
109
  model_kwargs["device_map"] = "auto"
110
 
 
111
  model = AutoModelForMaskedLM.from_pretrained(model_name, **model_kwargs)
112
 
113
  pipe = pipeline(
@@ -126,6 +133,8 @@ fill_mask_pipeline, tokenizer = load_optimized_models()
126
 
127
  def get_mol(smiles):
128
  """Converts SMILES to RDKit Mol object."""
 
 
129
  mol = Chem.MolFromSmiles(smiles)
130
  if mol:
131
  try:
@@ -192,7 +201,7 @@ def visualize_molecule_2d_3d(smiles: str, name: str, substructure_smarts=""):
192
  AllChem.EmbedMolecule(mol_3d, randomSeed=42)
193
  try:
194
  AllChem.MMFFOptimizeMolecule(mol_3d)
195
- except:
196
  AllChem.ETKDGv3().Embed(mol_3d)
197
 
198
  sdf_data = Chem.MolToMolBlock(mol_3d)
@@ -275,7 +284,7 @@ def predict_and_generate_visualizations(smiles_mask, substructure_smarts):
275
  # --- Streamlit Interface ---
276
  st.title("🧪 ChemBERTa SMILES Utilities")
277
  st.markdown("""
278
- Enter a SMILES string with a `<mask>` token (e.g., `C1=CC=CC<mask>C1`) to predict possible completions.
279
  The model will generate the most likely atoms or fragments to fill the mask.
280
  """)
281
 
@@ -301,25 +310,40 @@ with tab1:
301
 
302
  submit_button = st.form_submit_button("🚀 Predict and Visualize", use_container_width=True)
303
 
304
- if 'results_df' not in st.session_state or submit_button:
305
- if submit_button or 'results_df' not in st.session_state:
306
- with st.spinner("Running predictions... This may take a moment."):
307
- df, html, log = predict_and_generate_visualizations(smiles_input_masked, substructure_input)
308
- st.session_state.results_df = df
309
- st.session_state.results_html = html
310
- st.session_state.status_log = log
 
 
 
 
 
 
 
 
 
 
 
 
311
 
312
  st.subheader("Top Predictions & Scores")
313
  if 'results_df' in st.session_state and not st.session_state.results_df.empty:
314
- st.dataframe(st.session_state.results_df, use_container_width=True, hide_index=True)
315
- st.subheader("Predicted Molecule Visualizations (Top 5 Valid)")
316
- st.components.v1.html(st.session_state.results_html, height=1850, scrolling=True)
317
  else:
318
  st.info("No valid predictions to display. Try a different input.")
 
 
 
 
319
 
320
  with st.expander("Show Logs"):
321
  if 'status_log' in st.session_state:
322
- st.text_area("", st.session_state.status_log, height=200, key="log_area_pred")
 
323
 
324
  with tab2:
325
  st.header("Molecule Viewer")
@@ -340,4 +364,5 @@ with tab2:
340
 
341
  with st.expander("Show Logs"):
342
  if 'viewer_log' in st.session_state:
343
- st.text_area("", st.session_state.viewer_log, height=100, key="log_area_view")
 
 
3
  import torch
4
  from transformers import AutoModelForMaskedLM, AutoTokenizer, pipeline, BitsAndBytesConfig
5
  from rdkit import Chem
6
+ from rdkit.Chem import Draw, AllChem, rdBase
7
  import pandas as pd
8
  import py3Dmol
9
  import re
10
  import logging
11
 
12
+ # --- Setup ---
13
+ # Suppress RDKit console output for cleaner logs
14
+ rdBase.DisableLog('rdApp.error')
15
+
16
+ # Set up Python logging
17
  logging.basicConfig(level=logging.INFO)
18
  logger = logging.getLogger(__name__)
19
 
 
87
 
88
 
89
  # --- Model Loading (from mol_app) ---
90
+ # NOTE: The "missing ScriptRunContext" warnings in the logs are expected when not
91
+ # running via the 'streamlit run' command. They can be safely ignored.
92
  @st.cache_resource(show_spinner="Loading ChemBERTa model...")
93
  def load_optimized_models():
94
  """Load models with quantization and other optimizations."""
 
114
  model_kwargs["quantization_config"] = quantization_config
115
  model_kwargs["device_map"] = "auto"
116
 
117
+ # The "Some weights of the model were not used" warning is expected and normal.
118
  model = AutoModelForMaskedLM.from_pretrained(model_name, **model_kwargs)
119
 
120
  pipe = pipeline(
 
133
 
134
  def get_mol(smiles):
135
  """Converts SMILES to RDKit Mol object."""
136
+ # The SMILES Parse Errors in logs are expected; RDKit warns about invalid
137
+ # molecules generated by the model, which this function handles gracefully.
138
  mol = Chem.MolFromSmiles(smiles)
139
  if mol:
140
  try:
 
201
  AllChem.EmbedMolecule(mol_3d, randomSeed=42)
202
  try:
203
  AllChem.MMFFOptimizeMolecule(mol_3d)
204
+ except Exception: # Fallback if MMFF fails
205
  AllChem.ETKDGv3().Embed(mol_3d)
206
 
207
  sdf_data = Chem.MolToMolBlock(mol_3d)
 
284
  # --- Streamlit Interface ---
285
  st.title("🧪 ChemBERTa SMILES Utilities")
286
  st.markdown("""
287
+ Enter a SMILES string with a `<mask>` token to predict possible completions.
288
  The model will generate the most likely atoms or fragments to fill the mask.
289
  """)
290
 
 
310
 
311
  submit_button = st.form_submit_button("🚀 Predict and Visualize", use_container_width=True)
312
 
313
+ # --- Robust Session State Management ---
314
+ # This ensures the app loads with default predictions on the very first run,
315
+ # and only updates when the user clicks the button.
316
+ # The "Session state does not function" warning in logs is due to the execution
317
+ # environment and can be ignored.
318
+ if 'app_initialized' not in st.session_state:
319
+ with st.spinner("Running initial prediction..."):
320
+ df, html, log = predict_and_generate_visualizations(smiles_input_masked, substructure_input)
321
+ st.session_state.results_df = df
322
+ st.session_state.results_html = html
323
+ st.session_state.status_log = log
324
+ st.session_state.app_initialized = True
325
+
326
+ if submit_button:
327
+ with st.spinner("Running predictions... This may take a moment."):
328
+ df, html, log = predict_and_generate_visualizations(smiles_input_masked, substructure_input)
329
+ st.session_state.results_df = df
330
+ st.session_state.results_html = html
331
+ st.session_state.status_log = log
332
 
333
  st.subheader("Top Predictions & Scores")
334
  if 'results_df' in st.session_state and not st.session_state.results_df.empty:
335
+ st.dataframe(st.session_state.results_df, use_container_width=True, hide_index=True)
 
 
336
  else:
337
  st.info("No valid predictions to display. Try a different input.")
338
+
339
+ st.subheader("Predicted Molecule Visualizations (Top 5 Valid)")
340
+ if 'results_html' in st.session_state and st.session_state.results_html:
341
+ st.components.v1.html(st.session_state.results_html, height=1850, scrolling=True)
342
 
343
  with st.expander("Show Logs"):
344
  if 'status_log' in st.session_state:
345
+ # FIX: Added a label to st.text_area to resolve the accessibility warning.
346
+ st.text_area(label="Prediction Logs", value=st.session_state.status_log, height=200, key="log_area_pred")
347
 
348
  with tab2:
349
  st.header("Molecule Viewer")
 
364
 
365
  with st.expander("Show Logs"):
366
  if 'viewer_log' in st.session_state:
367
+ # FIX: Added a label to st.text_area to resolve the accessibility warning.
368
+ st.text_area(label="Viewer Logs", value=st.session_state.viewer_log, height=100, key="log_area_view")