Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -3,13 +3,17 @@ import streamlit as st
|
|
3 |
import torch
|
4 |
from transformers import AutoModelForMaskedLM, AutoTokenizer, pipeline, BitsAndBytesConfig
|
5 |
from rdkit import Chem
|
6 |
-
from rdkit.Chem import Draw, AllChem
|
7 |
import pandas as pd
|
8 |
import py3Dmol
|
9 |
import re
|
10 |
import logging
|
11 |
|
12 |
-
#
|
|
|
|
|
|
|
|
|
13 |
logging.basicConfig(level=logging.INFO)
|
14 |
logger = logging.getLogger(__name__)
|
15 |
|
@@ -83,6 +87,8 @@ apply_custom_styling()
|
|
83 |
|
84 |
|
85 |
# --- Model Loading (from mol_app) ---
|
|
|
|
|
86 |
@st.cache_resource(show_spinner="Loading ChemBERTa model...")
|
87 |
def load_optimized_models():
|
88 |
"""Load models with quantization and other optimizations."""
|
@@ -108,6 +114,7 @@ def load_optimized_models():
|
|
108 |
model_kwargs["quantization_config"] = quantization_config
|
109 |
model_kwargs["device_map"] = "auto"
|
110 |
|
|
|
111 |
model = AutoModelForMaskedLM.from_pretrained(model_name, **model_kwargs)
|
112 |
|
113 |
pipe = pipeline(
|
@@ -126,6 +133,8 @@ fill_mask_pipeline, tokenizer = load_optimized_models()
|
|
126 |
|
127 |
def get_mol(smiles):
|
128 |
"""Converts SMILES to RDKit Mol object."""
|
|
|
|
|
129 |
mol = Chem.MolFromSmiles(smiles)
|
130 |
if mol:
|
131 |
try:
|
@@ -192,7 +201,7 @@ def visualize_molecule_2d_3d(smiles: str, name: str, substructure_smarts=""):
|
|
192 |
AllChem.EmbedMolecule(mol_3d, randomSeed=42)
|
193 |
try:
|
194 |
AllChem.MMFFOptimizeMolecule(mol_3d)
|
195 |
-
except:
|
196 |
AllChem.ETKDGv3().Embed(mol_3d)
|
197 |
|
198 |
sdf_data = Chem.MolToMolBlock(mol_3d)
|
@@ -275,7 +284,7 @@ def predict_and_generate_visualizations(smiles_mask, substructure_smarts):
|
|
275 |
# --- Streamlit Interface ---
|
276 |
st.title("🧪 ChemBERTa SMILES Utilities")
|
277 |
st.markdown("""
|
278 |
-
Enter a SMILES string with a `<mask>` token
|
279 |
The model will generate the most likely atoms or fragments to fill the mask.
|
280 |
""")
|
281 |
|
@@ -301,25 +310,40 @@ with tab1:
|
|
301 |
|
302 |
submit_button = st.form_submit_button("π Predict and Visualize", use_container_width=True)
|
303 |
|
304 |
-
|
305 |
-
|
306 |
-
|
307 |
-
|
308 |
-
|
309 |
-
|
310 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
311 |
|
312 |
st.subheader("Top Predictions & Scores")
|
313 |
if 'results_df' in st.session_state and not st.session_state.results_df.empty:
|
314 |
-
st.dataframe(st.session_state.
|
315 |
-
st.subheader("Predicted Molecule Visualizations (Top 5 Valid)")
|
316 |
-
st.components.v1.html(st.session_state.results_html, height=1850, scrolling=True)
|
317 |
else:
|
318 |
st.info("No valid predictions to display. Try a different input.")
|
|
|
|
|
|
|
|
|
319 |
|
320 |
with st.expander("Show Logs"):
|
321 |
if 'status_log' in st.session_state:
|
322 |
-
st.text_area
|
|
|
323 |
|
324 |
with tab2:
|
325 |
st.header("Molecule Viewer")
|
@@ -340,4 +364,5 @@ with tab2:
|
|
340 |
|
341 |
with st.expander("Show Logs"):
|
342 |
if 'viewer_log' in st.session_state:
|
343 |
-
st.text_area
|
|
|
|
3 |
import torch
|
4 |
from transformers import AutoModelForMaskedLM, AutoTokenizer, pipeline, BitsAndBytesConfig
|
5 |
from rdkit import Chem
|
6 |
+
from rdkit.Chem import Draw, AllChem, rdBase
|
7 |
import pandas as pd
|
8 |
import py3Dmol
|
9 |
import re
|
10 |
import logging
|
11 |
|
12 |
+
# --- Setup ---
|
13 |
+
# Suppress RDKit console output for cleaner logs
|
14 |
+
rdBase.DisableLog('rdApp.error')
|
15 |
+
|
16 |
+
# Set up Python logging
|
17 |
logging.basicConfig(level=logging.INFO)
|
18 |
logger = logging.getLogger(__name__)
|
19 |
|
|
|
87 |
|
88 |
|
89 |
# --- Model Loading (from mol_app) ---
|
90 |
+
# NOTE: The "missing ScriptRunContext" warnings in the logs are expected when not
|
91 |
+
# running via the 'streamlit run' command. They can be safely ignored.
|
92 |
@st.cache_resource(show_spinner="Loading ChemBERTa model...")
|
93 |
def load_optimized_models():
|
94 |
"""Load models with quantization and other optimizations."""
|
|
|
114 |
model_kwargs["quantization_config"] = quantization_config
|
115 |
model_kwargs["device_map"] = "auto"
|
116 |
|
117 |
+
# The "Some weights of the model were not used" warning is expected and normal.
|
118 |
model = AutoModelForMaskedLM.from_pretrained(model_name, **model_kwargs)
|
119 |
|
120 |
pipe = pipeline(
|
|
|
133 |
|
134 |
def get_mol(smiles):
|
135 |
"""Converts SMILES to RDKit Mol object."""
|
136 |
+
# The SMILES Parse Errors in logs are expected; RDKit warns about invalid
|
137 |
+
# molecules generated by the model, which this function handles gracefully.
|
138 |
mol = Chem.MolFromSmiles(smiles)
|
139 |
if mol:
|
140 |
try:
|
|
|
201 |
AllChem.EmbedMolecule(mol_3d, randomSeed=42)
|
202 |
try:
|
203 |
AllChem.MMFFOptimizeMolecule(mol_3d)
|
204 |
+
except Exception: # Fallback if MMFF fails
|
205 |
AllChem.ETKDGv3().Embed(mol_3d)
|
206 |
|
207 |
sdf_data = Chem.MolToMolBlock(mol_3d)
|
|
|
284 |
# --- Streamlit Interface ---
|
285 |
st.title("🧪 ChemBERTa SMILES Utilities")
|
286 |
st.markdown("""
|
287 |
+
Enter a SMILES string with a `<mask>` token to predict possible completions.
|
288 |
The model will generate the most likely atoms or fragments to fill the mask.
|
289 |
""")
|
290 |
|
|
|
310 |
|
311 |
submit_button = st.form_submit_button("π Predict and Visualize", use_container_width=True)
|
312 |
|
313 |
+
# --- Robust Session State Management ---
# Prediction output is kept in st.session_state so it is still available when
# the script is executed again. A prediction is run in exactly two situations:
#   * the very first execution of a session ('app_initialized' not yet set),
#     so the page opens with results for the default inputs;
#   * the user pressed the form's submit button.
# The "Session state does not function" warning in logs is due to the execution
# environment and can be ignored.
is_first_run = 'app_initialized' not in st.session_state
if is_first_run or submit_button:
    # Only the spinner text differs between the two triggers.
    spinner_message = (
        "Running initial prediction..."
        if is_first_run
        else "Running predictions... This may take a moment."
    )
    with st.spinner(spinner_message):
        df, html, log = predict_and_generate_visualizations(smiles_input_masked, substructure_input)
    st.session_state.results_df = df
    st.session_state.results_html = html
    st.session_state.status_log = log
    # Setting the flag again on later submits is harmless (idempotent).
    st.session_state.app_initialized = True

st.subheader("Top Predictions & Scores")
if 'results_df' in st.session_state and not st.session_state.results_df.empty:
    # The attribute must match the key written above ('results_df'); a
    # misspelled attribute raises AttributeError on st.session_state.
    st.dataframe(st.session_state.results_df, use_container_width=True, hide_index=True)
else:
    st.info("No valid predictions to display. Try a different input.")

st.subheader("Predicted Molecule Visualizations (Top 5 Valid)")
# Checked independently of the table so that an empty/missing HTML payload
# simply renders nothing under the subheader.
if 'results_html' in st.session_state and st.session_state.results_html:
    st.components.v1.html(st.session_state.results_html, height=1850, scrolling=True)

with st.expander("Show Logs"):
    if 'status_log' in st.session_state:
        # An explicit label is passed to st.text_area to avoid Streamlit's
        # accessibility warning about unlabeled widgets; the key keeps this
        # widget distinct from the viewer tab's log area.
        st.text_area(label="Prediction Logs", value=st.session_state.status_log, height=200, key="log_area_pred")
|
347 |
|
348 |
with tab2:
|
349 |
st.header("Molecule Viewer")
|
|
|
364 |
|
365 |
with st.expander("Show Logs"):
|
366 |
if 'viewer_log' in st.session_state:
|
367 |
+
# FIX: Added a label to st.text_area to resolve the accessibility warning.
|
368 |
+
st.text_area(label="Viewer Logs", value=st.session_state.viewer_log, height=100, key="log_area_view")
|