# mol-lang-lab / app.py
# alidenewade's picture
# Update app.py
# c9fddab verified
# raw
# history blame
# 13.5 kB
# app.py
import streamlit as st
import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer, pipeline, BitsAndBytesConfig
from rdkit import Chem
from rdkit.Chem import Draw, AllChem, rdBase
import pandas as pd
import py3Dmol
import re
import logging
# --- Setup ---
# Silence RDKit's error channel: SMILES proposed by the model are frequently
# unparsable, and each failed parse would otherwise be written to the console.
rdBase.DisableLog('rdApp.error')

# Root logging at INFO plus a module-level logger for the code below.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# --- Page Configuration ---
st.set_page_config(
    page_title="ChemBERTa SMILES Utilities",
    # Test-tube emoji; the literal had been stored as mis-decoded UTF-8 bytes.
    page_icon="🧪",
    layout="wide",
)
# --- Custom Styling (from drug_app) ---
def apply_custom_styling():
    """Inject the app's dark-theme stylesheet into the page.

    The CSS sets the Roboto font, a dark page background with white text,
    and restyles Streamlit tabs and buttons. It is emitted through
    ``st.markdown`` with ``unsafe_allow_html=True`` so the ``<style>`` tag
    is passed through as HTML.
    """
    # The stylesheet is kept flush-left so the emitted text is exactly the
    # markup below, with no indentation added by the surrounding function.
    dark_theme_css = """
<style>
@import url('https://fonts.googleapis.com/css2?family=Roboto:wght@400;700&display=swap');
html, body, [class*="st-"] {
font-family: 'Roboto', sans-serif;
}
.stApp {
background-color: rgb(28, 28, 28);
color: white;
}
/* Tab styles */
.stTabs [data-baseweb="tab-list"] {
gap: 24px;
}
.stTabs [data-baseweb="tab"] {
height: 50px;
white-space: pre-wrap;
background: none;
border-radius: 0px;
border-bottom: 2px solid #333;
padding: 10px 4px;
color: #AAA;
}
.stTabs [data-baseweb="tab"]:hover {
background: #222;
color: #FFF;
}
.stTabs [aria-selected="true"] {
border-bottom: 2px solid #00A0FF; /* Highlight color for active tab */
color: #FFF;
}
/* Button styles */
.stButton>button {
border-color: #00A0FF;
color: #00A0FF;
background-color: transparent;
}
.stButton>button:hover {
border-color: #FFF;
color: #FFF;
background-color: #00A0FF;
}
</style>
"""
    st.markdown(dark_theme_css, unsafe_allow_html=True)


# Apply the theme once per script run, before any widgets are rendered.
apply_custom_styling()
# --- Model Loading (from mol_app) ---
# NOTE: The "missing ScriptRunContext" warnings in the logs are expected when not
# running via the 'streamlit run' command. They can be safely ignored.
@st.cache_resource(show_spinner="Loading ChemBERTa model...")
def load_optimized_models():
    """Load the ChemBERTa fill-mask pipeline and its tokenizer.

    The model is loaded in float16 on CUDA and float32 on CPU. When CUDA is
    available *and* the ``bitsandbytes`` package can be found, the weights
    are loaded in 8-bit with ``device_map="auto"``; otherwise the model is
    loaded unquantized and the pipeline places it on GPU 0 or on the CPU.

    The result is cached by ``st.cache_resource``, so the load runs once per
    server process rather than once per script rerun.

    Returns:
        tuple: ``(pipe, tokenizer)`` where ``pipe`` is a ``'fill-mask'``
        pipeline and ``tokenizer`` is the matching tokenizer.
    """
    # Local import keeps this block self-contained; importlib is stdlib.
    import importlib.util

    model_name = "seyonec/PubChem10M_SMILES_BPE_450k"
    use_cuda = torch.cuda.is_available()
    torch_dtype = torch.float16 if use_cuda else torch.float32

    # Probe for bitsandbytes explicitly. Constructing BitsAndBytesConfig is
    # not a reliable availability check, so catching ImportError around it
    # (as was done before) never triggered the fallback.
    has_bitsandbytes = importlib.util.find_spec("bitsandbytes") is not None
    use_8bit = use_cuda and has_bitsandbytes

    model_kwargs = {"torch_dtype": torch_dtype}
    if use_8bit:
        # Only `load_in_8bit` applies to the 8-bit path; the compute-dtype and
        # double-quant settings exist solely as 4-bit (`bnb_4bit_*`) options.
        model_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
        model_kwargs["device_map"] = "auto"
        logger.info("8-bit quantization will be used.")
    elif use_cuda:
        logger.warning("bitsandbytes not found. Model will be loaded without quantization.")
    else:
        logger.info("CUDA not available. Model will be loaded on CPU without quantization.")

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # The "Some weights of the model were not used" warning is expected and normal.
    model = AutoModelForMaskedLM.from_pretrained(model_name, **model_kwargs)

    # A model dispatched through `device_map` (and 8-bit weights in particular)
    # must not be moved again, so `device` is passed only for the plain load.
    pipeline_kwargs = {}
    if not use_8bit:
        pipeline_kwargs["device"] = 0 if use_cuda else -1

    pipe = pipeline(
        'fill-mask',
        model=model,
        tokenizer=tokenizer,
        **pipeline_kwargs,
    )
    logger.info("ChemBERTa model loaded successfully.")
    return pipe, tokenizer


# Module-level handles used by the prediction code and the UI below.
fill_mask_pipeline, tokenizer = load_optimized_models()
# --- Core Functions ---
def get_mol(smiles):
    """Parse a SMILES string into an RDKit molecule.

    Args:
        smiles: SMILES text to parse.

    Returns:
        The parsed molecule, kekulized in place when possible, or the falsy
        value returned by ``Chem.MolFromSmiles`` when parsing fails.
    """
    # The SMILES Parse Errors in logs are expected; RDKit warns about invalid
    # molecules generated by the model, which this function handles gracefully.
    mol = Chem.MolFromSmiles(smiles)
    if not mol:
        return mol

    try:
        Chem.Kekulize(mol)
    except Exception:
        # Kekulization is best-effort: keep the molecule as parsed. Catching
        # Exception (not a bare except) lets KeyboardInterrupt/SystemExit
        # propagate instead of being silently discarded.
        logger.debug("Kekulization failed for %r; using molecule as parsed.", smiles)
    return mol
def find_matches_one(mol, submol_smarts):
    """Return every substructure match of a SMARTS pattern in ``mol``.

    Args:
        mol: Molecule to search; a falsy value yields no matches.
        submol_smarts: SMARTS pattern text; empty or unparsable patterns
            yield no matches.

    Returns:
        The result of ``mol.GetSubstructMatches`` for the parsed pattern, or
        an empty list when either input is missing or the pattern is invalid.
    """
    if mol and submol_smarts:
        pattern = Chem.MolFromSmarts(submol_smarts)
        if pattern:
            return mol.GetSubstructMatches(pattern)
    return []
# --- Visualization Function (Adapted from drug_app) ---
def _embed_3d_conformer(mol):
    """Return a copy of ``mol`` with explicit hydrogens and one 3D conformer.

    Embedding is first attempted with a fixed seed. If RDKit reports failure
    (return code -1), it is retried with ETKDGv3 parameters using random
    starting coordinates. MMFF optimisation afterwards is best-effort.

    Raises:
        ValueError: if neither embedding attempt produces a conformer.
    """
    mol_3d = Chem.AddHs(mol)

    # EmbedMolecule signals failure through its return value, not an exception.
    if AllChem.EmbedMolecule(mol_3d, randomSeed=42) == -1:
        # ETKDGv3() returns a parameters object; it has to be handed to
        # EmbedMolecule (it has no Embed method of its own).
        params = AllChem.ETKDGv3()
        params.randomSeed = 42
        params.useRandomCoords = True
        if AllChem.EmbedMolecule(mol_3d, params) == -1:
            raise ValueError("could not generate 3D coordinates")

    try:
        AllChem.MMFFOptimizeMolecule(mol_3d)
    except Exception:
        # Deliberate best-effort: the embedded geometry is still displayable
        # without force-field refinement.
        logger.debug("MMFF optimisation failed; using unoptimised geometry.")
    return mol_3d


def visualize_molecule_2d_3d(smiles: str, name: str, substructure_smarts=""):
    """Generates a side-by-side 2D SVG and 3D py3Dmol HTML view for a single molecule.

    Args:
        smiles: SMILES string of the molecule to draw.
        name: Label used in the panel headings and in the log line.
        substructure_smarts: Optional SMARTS pattern; atoms of its first
            match are highlighted in the 2D drawing.

    Returns:
        tuple[str, str]: ``(html, log)``. On success ``html`` is the combined
        2D/3D panel; for an unparsable SMILES or any rendering error it is a
        short ``<p>`` message and ``log`` describes the failure.
    """
    try:
        mol = get_mol(smiles)
        if not mol:
            return f"<p>Invalid SMILES for {name}</p>", f"❌ Invalid SMILES for {name}"

        # --- 2D Visualization ---
        drawer = Draw.rdMolDraw2D.MolDraw2DSVG(450, 350)
        opts = drawer.drawOptions()
        opts.clearBackground = False
        opts.addStereoAnnotation = True
        opts.baseFontSize = 0.9

        # Highlighting: only the first SMARTS match is marked.
        atom_indices_to_highlight = []
        if substructure_smarts:
            matches = find_matches_one(mol, substructure_smarts)
            if matches:
                atom_indices_to_highlight = list(matches[0])

        # Dark theme colors for 2D drawing
        opts.backgroundColour = (0.109, 0.109, 0.109)  # rgb(28,28,28)
        opts.symbolColour = (1, 1, 1)
        opts.setAtomPalette({
            -1: (1, 1, 1),       # Default
            6: (0.9, 0.9, 0.9),  # Carbon
            7: (0.5, 0.5, 1),    # Nitrogen
            8: (1, 0.2, 0.2),    # Oxygen
            16: (1, 0.8, 0.2),   # Sulfur
        })
        drawer.DrawMolecule(mol, highlightAtoms=atom_indices_to_highlight)
        drawer.FinishDrawing()
        svg_2d = drawer.GetDrawingText()

        # Fix colors for dark theme: recolor remaining black strokes/fills.
        svg_2d = svg_2d.replace('stroke="black"', 'stroke="white"')
        svg_2d = svg_2d.replace('fill="black"', 'fill="white"')
        svg_2d = re.sub(r'fill:#(000000|000);', 'fill:white;', svg_2d)

        # --- 3D Visualization ---
        mol_3d = _embed_3d_conformer(mol)
        sdf_data = Chem.MolToMolBlock(mol_3d)
        viewer = py3Dmol.view(width=450, height=350)
        viewer.setBackgroundColor('#1C1C1C')
        viewer.addModel(sdf_data, "sdf")
        viewer.setStyle({'stick': {}, 'sphere': {'scale': 0.25}})
        viewer.zoomTo()
        html_3d = viewer._make_html()

        # --- Combine Views ---
        combined_html = f"""
<div style="display: flex; flex-direction: row; align-items: center; justify-content: space-around; border: 1px solid #444; border-radius: 10px; padding: 10px; margin-bottom: 20px; background-color: #2b2b2b;">
<div style="text-align: center;">
<h4 style="color: white; font-family: 'Roboto', sans-serif;">{name} (2D Structure)</h4>
<div style="background-color: #1C1C1C; padding: 10px; border-radius: 5px;">{svg_2d}</div>
</div>
<div style="text-align: center;">
<h4 style="color: white; font-family: 'Roboto', sans-serif;">{name} (3D Interactive)</h4>
{html_3d}
</div>
</div>
"""
        # Check-mark emoji restored from mis-decoded UTF-8 bytes.
        return combined_html, f"✅ Generated 2D/3D view for {name}.\n"
    except Exception as e:
        # Any drawing/embedding failure is reported in place of the panel so
        # one bad molecule does not break the rest of the page.
        return f"<p>Error visualizing {name}: {e}</p>", f"❌ Error visualizing {name}: {e}"
# --- Main Application Logic ---
def predict_and_generate_visualizations(smiles_mask, substructure_smarts):
    """Predicts masked SMILES and returns a dataframe and HTML for visualizations.

    The fill-mask pipeline is queried for its top 15 completions; they are
    scanned in order and the first five that parse as molecules are kept,
    tabulated and rendered.

    Args:
        smiles_mask: SMILES string containing exactly one mask token.
        substructure_smarts: SMARTS pattern forwarded to the 2D renderer for
            highlighting.

    Returns:
        tuple: ``(df_results, combined_html, status_log)``. On an input or
        prediction error an ``st.error`` is shown and the tuple is an empty
        DataFrame, an empty string and a short error label.
    """
    max_valid = 5
    mask_token = tokenizer.mask_token
    mask_count = smiles_mask.count(mask_token)

    if mask_count == 0:
        st.error(f"Error: Input SMILES must contain a mask token (e.g., `{tokenizer.mask_token}`).")
        return pd.DataFrame(), "", "Input error."
    if mask_count > 1:
        # With several masks the pipeline returns one candidate list per mask
        # (a list of lists), which the loop below cannot consume.
        st.error(
            f"Error: Input SMILES must contain exactly one `{mask_token}` token "
            f"(found {mask_count})."
        )
        return pd.DataFrame(), "", "Input error."

    try:
        with torch.no_grad():
            predictions = fill_mask_pipeline(smiles_mask, top_k=15)
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    except Exception as e:
        st.error(f"An error occurred during model prediction: {e}")
        return pd.DataFrame(), "", "Prediction error."

    results_data = []
    html_parts = []
    log_parts = []
    for pred in predictions:
        if len(results_data) >= max_valid:
            break
        predicted_smiles = pred['sequence']
        # Skip completions that do not parse as a molecule.
        if not get_mol(predicted_smiles):
            continue

        rank = len(results_data) + 1
        results_data.append({
            "Rank": rank,
            "Predicted SMILES": predicted_smiles,
            "Score": f"{pred['score']:.4f}",
        })
        html_view, log = visualize_molecule_2d_3d(
            predicted_smiles,
            f"Prediction #{rank}",
            substructure_smarts,
        )
        html_parts.append(html_view)
        log_parts.append(log)

    df_results = pd.DataFrame(results_data)
    status_log = "".join(log_parts)
    status_log += f"\nFound {len(results_data)} valid molecules from top predictions."
    return df_results, "".join(html_parts), status_log
# --- Streamlit Interface ---
# Title emoji restored from mis-decoded UTF-8 bytes (test tube).
st.title("🧪 ChemBERTa SMILES Utilities")
st.markdown("""
Enter a SMILES string with a `<mask>` token to predict possible completions.
The model will generate the most likely atoms or fragments to fill the mask.
""")
# Two tabs: mask completion, and a standalone single-molecule viewer.
tab1, tab2 = st.tabs(["Masked SMILES Prediction", "Molecule Viewer"])
with tab1:
    st.header("Masked SMILES Prediction")

    # Inputs are wrapped in a form so edits do not rerun the prediction until
    # the submit button is pressed.
    with st.form("prediction_form"):
        col1, col2 = st.columns(2)
        with col1:
            smiles_input_masked = st.text_input(
                "SMILES String with Mask",
                value=f"C1=CC=CC{tokenizer.mask_token}C1",
                help=f"Use `{tokenizer.mask_token}` as the mask token."
            )
        with col2:
            substructure_input = st.text_input(
                "Substructure to Highlight (SMARTS)",
                value="C=C",
                help="Enter a SMARTS pattern to highlight in the 2D view."
            )
        # Rocket emoji restored from mis-decoded UTF-8 bytes.
        submit_button = st.form_submit_button("🚀 Predict and Visualize", use_container_width=True)

    # --- Robust Session State Management ---
    # This ensures the app loads with default predictions on the very first run,
    # and only updates when the user clicks the button.
    # The "Session state does not function" warning in logs is due to the execution
    # environment and can be ignored.
    first_run = 'app_initialized' not in st.session_state
    if first_run or submit_button:
        spinner_text = (
            "Running initial prediction..."
            if first_run
            else "Running predictions... This may take a moment."
        )
        with st.spinner(spinner_text):
            df, html_out, log = predict_and_generate_visualizations(smiles_input_masked, substructure_input)
        st.session_state.results_df = df
        st.session_state.results_html = html_out
        st.session_state.status_log = log
        st.session_state.app_initialized = True

    st.subheader("Top Predictions & Scores")
    if 'results_df' in st.session_state and not st.session_state.results_df.empty:
        # Reads `results_df`; the former `results__df` (double underscore) was
        # never set and raised AttributeError whenever results existed.
        st.dataframe(st.session_state.results_df, use_container_width=True, hide_index=True)
    else:
        st.info("No valid predictions to display. Try a different input.")

    st.subheader("Predicted Molecule Visualizations (Top 5 Valid)")
    if 'results_html' in st.session_state and st.session_state.results_html:
        st.components.v1.html(st.session_state.results_html, height=1850, scrolling=True)

    with st.expander("Show Logs"):
        if 'status_log' in st.session_state:
            # An explicit label avoids Streamlit's empty-label accessibility warning.
            st.text_area(label="Prediction Logs", value=st.session_state.status_log, height=200, key="log_area_pred")
with tab2:
    st.header("Molecule Viewer")
    st.markdown("Enter a single SMILES string to display its 2D and 3D structure.")

    with st.form("viewer_form"):
        smiles_input_viewer = st.text_input("SMILES String", value="CC(=O)Oc1ccccc1C(=O)O")  # Aspirin
        # Eye emoji restored from mis-decoded UTF-8 bytes.
        viewer_submit = st.form_submit_button("👁️ View Molecule", use_container_width=True)

    if viewer_submit:
        with st.spinner("Generating visualization..."):
            html_view, log = visualize_molecule_2d_3d(smiles_input_viewer, "Molecule")
        # Kept in session state so the view survives reruns triggered elsewhere.
        st.session_state.viewer_html = html_view
        st.session_state.viewer_log = log

    if 'viewer_html' in st.session_state:
        st.components.v1.html(st.session_state.viewer_html, height=450)

    with st.expander("Show Logs"):
        if 'viewer_log' in st.session_state:
            # An explicit label avoids Streamlit's empty-label accessibility warning.
            st.text_area(label="Viewer Logs", value=st.session_state.viewer_log, height=100, key="log_area_view")