# mol-lang-lab / app.py
# alidenewade's picture
# Update app.py
# 4ed4dfd verified
# raw
# history blame
# 10.2 kB
# app.py
# To run this app, save the code as app.py and run:
# streamlit run app.py
#
# You also need to install the following libraries:
# pip install streamlit torch transformers accelerate bitsandbytes rdkit-pypi py3Dmol pandas
import streamlit as st
import streamlit.components.v1 as components
import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer, pipeline, BitsAndBytesConfig
from rdkit import Chem
from rdkit.Chem import Draw, AllChem
from rdkit.Chem.Draw import MolToImage
import pandas as pd
import logging
# Set up logging to monitor quantization effects.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# --- Page Configuration ---
st.set_page_config(
    page_title="ChemBERTa SMILES Utilities",
    # Microscope emoji. The previous literal was the emoji's UTF-8 bytes
    # decoded with the wrong codec (mojibake).
    page_icon="🔬",
    layout="wide",
)
# --- Model Loading (Cached for Performance) ---
@st.cache_resource(show_spinner="Loading ChemBERTa model...")
def load_models():
    """
    Load the ChemBERTa tokenizer and build a fill-mask pipeline.

    The function is wrapped in ``st.cache_resource`` so the expensive load is
    performed once and the cached objects are reused afterwards.

    When CUDA is available the model is loaded in float16 with
    ``device_map="auto"`` and, if a quantization config could be created,
    8-bit weights. If that optimized load fails for any reason, the model is
    loaded again with default settings.

    Returns:
        tuple: ``(tokenizer, fill_mask_pipeline)``.
    """
    model_name = "seyonec/PubChem10M_SMILES_BPE_450k"
    use_cuda = torch.cuda.is_available()
    torch_dtype = torch.float16 if use_cuda else torch.float32

    quantization_config = None
    try:
        # BitsAndBytesConfig has no ``bnb_8bit_*`` options (compute dtype and
        # double quantization exist only for 4-bit), so only the 8-bit flag
        # is passed.
        quantization_config = BitsAndBytesConfig(load_in_8bit=True)
        logger.info("8-bit quantization configuration created.")
    except ImportError:
        logger.warning("bitsandbytes not available, falling back to standard loading.")
    except Exception as e:
        logger.warning(f"Quantization setup failed: {e}, using standard loading.")

    tokenizer = AutoTokenizer.from_pretrained(model_name)

    model_kwargs = {"torch_dtype": torch_dtype}
    if use_cuda:
        model_kwargs["device_map"] = "auto"
        if quantization_config:
            model_kwargs["quantization_config"] = quantization_config

    try:
        model = AutoModelForMaskedLM.from_pretrained(model_name, **model_kwargs)
        model.eval()
        # ``device`` and ``device_map`` conflict: a model placed by
        # ``device_map`` must not also be given a pipeline ``device``.
        # Only the CPU path (no device_map) pins the device explicitly.
        pipeline_kwargs = {} if use_cuda else {"device": -1}
        fill_mask_pipeline = pipeline(
            'fill-mask', model=model, tokenizer=tokenizer, **pipeline_kwargs
        )
        logger.info("Models loaded successfully with optimizations.")
        return tokenizer, fill_mask_pipeline
    except Exception as e:
        logger.error(f"Error loading optimized models: {e}. Retrying with standard loading.")
        # The tokenizer loaded above is still valid; only the model needs a
        # plain reload.
        model = AutoModelForMaskedLM.from_pretrained(model_name)
        device_idx = 0 if use_cuda else -1
        if use_cuda:
            model.to("cuda")
        fill_mask_pipeline = pipeline(
            'fill-mask', model=model, tokenizer=tokenizer, device=device_idx
        )
        return tokenizer, fill_mask_pipeline
# Load the tokenizer and fill-mask pipeline at module level; load_models() is
# wrapped in st.cache_resource, so the result is cached rather than rebuilt.
fill_mask_tokenizer, fill_mask_pipeline = load_models()
# --- Molecule & Visualization Helpers ---
def get_mol(smiles):
    """Convert a SMILES string to an RDKit Mol and try to Kekulize it.

    Args:
        smiles: SMILES string to parse.

    Returns:
        The parsed Mol, or ``None`` when ``Chem.MolFromSmiles`` cannot parse
        the input. Kekulization is best effort: if it fails the molecule is
        still returned, un-Kekulized.
    """
    mol = Chem.MolFromSmiles(smiles)
    if not mol:
        return mol
    try:
        Chem.Kekulize(mol)
    except Exception as e:
        # Best effort only: keep the parsed molecule, but record the failure
        # instead of silently discarding it.
        logger.warning("Could not Kekulize %r: %s", smiles, e)
    return mol
def find_matches_one(mol, submol_smarts):
    """Return every substructure match of a SMARTS pattern inside ``mol``.

    Args:
        mol: Molecule to search (anything exposing ``GetSubstructMatches``).
        submol_smarts: SMARTS pattern string.

    Returns:
        The result of ``mol.GetSubstructMatches`` for the parsed pattern, or
        an empty list when either argument is falsy or the pattern cannot be
        parsed.
    """
    if not (mol and submol_smarts):
        return []
    pattern = Chem.MolFromSmarts(submol_smarts)
    if not pattern:
        return []
    return mol.GetSubstructMatches(pattern)
def get_image_with_highlight(mol, atomset=None, size=(300, 300)):
    """Render a 2D depiction of ``mol``, optionally highlighting atoms.

    Args:
        mol: Molecule to draw; ``None`` yields ``None``.
        atomset: Optional iterable of atom indices. Entries whose string
            form is not all digits are ignored.
        size: ``(width, height)`` passed to ``MolToImage``.

    Returns:
        The image produced by ``MolToImage``, or ``None`` if ``mol`` is None.
    """
    if mol is None:
        return None

    highlighted = []
    if atomset:
        for atom in atomset:
            if str(atom).isdigit():
                highlighted.append(int(atom))

    # Every highlighted atom gets the same half-transparent green.
    green = (0, 1, 0, 0.5)
    colours = dict.fromkeys(highlighted, green)

    return MolToImage(
        mol,
        size=size,
        fitImage=True,
        highlightAtoms=highlighted,
        highlightAtomColors=colours,
    )
def generate_3d_view_html(smiles):
    """Generate an interactive 3D molecule view as HTML using py3Dmol.

    Args:
        smiles: SMILES string of the molecule to display.

    Returns:
        ``None`` when ``smiles`` is empty; otherwise an HTML string that is
        either the py3Dmol viewer markup or a ``<p>`` element describing why
        the view could not be produced.
    """
    from html import escape

    if not smiles:
        return None
    mol = get_mol(smiles)
    if not mol:
        return "<p>Invalid SMILES for 3D view.</p>"
    try:
        # py3Dmol is not imported at the top of the file; import it here so
        # the viewer is actually available. A missing package is reported
        # through the generic error path below.
        import py3Dmol

        mol_3d = Chem.AddHs(mol)
        # EmbedMolecule returns -1 when no conformer could be generated;
        # optimizing or exporting without a conformer would fail.
        conformer_id = AllChem.EmbedMolecule(mol_3d, randomSeed=42, useRandomCoords=True)
        if conformer_id < 0:
            logger.error(f"Failed to generate 3D view for {smiles}: embedding failed")
            return "<p>Error generating 3D view: could not embed 3D coordinates.</p>"
        AllChem.MMFFOptimizeMolecule(mol_3d)
        sdf_data = Chem.MolToMolBlock(mol_3d)

        viewer = py3Dmol.view(width=350, height=350)
        viewer.setBackgroundColor('#FFFFFF')
        viewer.addModel(sdf_data, "sdf")
        viewer.setStyle({'stick': {}, 'sphere': {'scale': 0.25}})
        viewer.zoomTo()
        return viewer._make_html()
    except Exception as e:
        logger.error(f"Failed to generate 3D view for {smiles}: {e}")
        # The message ends up inside HTML, so escape it.
        return f"<p>Error generating 3D view: {escape(str(e))}</p>"
# --- Core Application Logic ---
def run_masked_smiles_prediction(smiles_mask, substructure_smarts_highlight):
    """
    Handle the logic for the masked SMILES prediction tab.

    Runs the fill-mask pipeline on ``smiles_mask``, keeps up to five
    completions that RDKit can parse, renders their 2D/3D views and stores
    them in ``st.session_state.prediction_results``.

    Args:
        smiles_mask: SMILES string containing exactly one mask token.
        substructure_smarts_highlight: SMARTS pattern whose first match is
            highlighted in each 2D image (may be empty).

    Returns:
        None. Problems are reported to the user through ``st.error`` /
        ``st.warning``; in those cases the stored results are left empty.
    """
    max_results = 5

    # Drop results from an earlier run so a failed submission never leaves
    # stale predictions on screen.
    st.session_state.prediction_results = []

    mask_token = fill_mask_tokenizer.mask_token
    if mask_token not in smiles_mask:
        st.error(f"Error: Input SMILES must contain a mask token (e.g., {mask_token}).")
        return
    if smiles_mask.count(mask_token) > 1:
        # With several masks the pipeline returns one list per mask, which
        # the result handling below does not support.
        st.error(f"Error: Input SMILES must contain only one mask token ({mask_token}).")
        return

    with st.spinner("Predicting completions..."):
        try:
            with torch.no_grad():
                predictions = fill_mask_pipeline(smiles_mask, top_k=10)
        except Exception as e:
            st.error(f"An error occurred during prediction: {e}")
            return
        finally:
            # Release cached GPU memory whether or not inference succeeded.
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        results = []
        for pred in predictions:
            if len(results) >= max_results:
                break
            predicted_smiles = pred['sequence']
            mol = get_mol(predicted_smiles)
            if not mol:
                # Skip completions that are not valid molecules.
                continue
            atom_matches = find_matches_one(mol, substructure_smarts_highlight)
            results.append({
                "smiles": predicted_smiles,
                "score": f"{pred['score']:.4f}",
                "image_2d": get_image_with_highlight(
                    mol, atomset=atom_matches[0] if atom_matches else []
                ),
                "html_3d": generate_3d_view_html(predicted_smiles),
            })

    if not results:
        st.warning("No valid molecules were found among the predicted completions.")
    st.session_state.prediction_results = results
# --- Streamlit UI Definition ---
# The title starts with the microscope emoji; the previous literal was the
# emoji's UTF-8 bytes decoded with the wrong codec (mojibake).
st.title("🔬 ChemBERTa SMILES Utilities Dashboard (2D & 3D)")
st.markdown("A tool to predict masked tokens in SMILES strings and visualize molecules, powered by ChemBERTa and Streamlit.")
# One tab per feature; the sections below fill them in.
tab1, tab2 = st.tabs(["Masked SMILES Prediction", "Molecule Viewer (2D & 3D)"])
# --- Tab 1: Masked SMILES Prediction ---
with tab1:
    st.header("Predict and Visualize Masked SMILES")
    st.markdown("Enter a SMILES string with a `<mask>` token to predict possible completions.")

    # Input form: masked SMILES on the left, optional SMARTS highlight on the right.
    with st.form(key="prediction_form"):
        mask_col, smarts_col = st.columns(2)
        with mask_col:
            smiles_input_masked = st.text_input(
                "SMILES String with Mask",
                value="C1=CC=CC<mask>C1",
                help=f"The mask token is `{fill_mask_tokenizer.mask_token}`",
            )
        with smarts_col:
            substructure_input = st.text_input(
                "Substructure to Highlight (SMARTS)",
                value="C=C",
                help="Enter a SMARTS pattern to highlight in the 2D images.",
            )
        predict_button = st.form_submit_button("Predict and Visualize", use_container_width=True)

    if predict_button:
        run_masked_smiles_prediction(smiles_input_masked, substructure_input)

    # Results live in session state so they survive script reruns.
    stored_results = st.session_state.get("prediction_results")
    if stored_results:
        st.subheader("Top 5 Valid Predictions")

        # Summary table of SMILES and scores.
        summary_rows = []
        for entry in stored_results:
            summary_rows.append({"Predicted SMILES": entry["smiles"], "Score": entry["score"]})
        st.dataframe(pd.DataFrame(summary_rows), use_container_width=True)
        st.markdown("---")

        # One 2D/3D pair per prediction.
        for rank, entry in enumerate(stored_results, start=1):
            st.markdown(f"**Prediction {rank}:** `{entry['smiles']}` (Score: {entry['score']})")
            flat_col, spatial_col = st.columns(2)
            with flat_col:
                st.subheader("2D Structure")
                if not entry["image_2d"]:
                    st.warning("Could not generate 2D image.")
                else:
                    st.image(entry["image_2d"], use_column_width=True)
            with spatial_col:
                st.subheader("3D Interactive Structure")
                if not entry["html_3d"]:
                    st.warning("Could not generate 3D view.")
                else:
                    components.html(entry["html_3d"], height=370)
            st.markdown("---")
# --- Tab 2: Molecule Viewer ---
with tab2:
    st.header("Visualize a Molecule from SMILES")
    st.markdown("Enter a single SMILES string to display its 2D and 3D structures side-by-side.")

    with st.form(key="viewer_form"):
        # The default value is aspirin.
        smiles_input_viewer = st.text_input("SMILES String", value="CC(=O)Oc1ccccc1C(=O)O")
        view_button = st.form_submit_button("View Molecule", use_container_width=True)

    # Render only after the form is submitted with a non-empty SMILES.
    if view_button and smiles_input_viewer:
        with st.spinner("Generating views..."):
            mol = get_mol(smiles_input_viewer)
            if mol:
                st.subheader(f"Visualizations for: `{smiles_input_viewer}`")
                flat_col, spatial_col = st.columns(2)
                with flat_col:
                    st.subheader("2D Structure")
                    st.image(
                        MolToImage(mol, size=(450, 450), fitImage=True),
                        use_column_width=True,
                    )
                with spatial_col:
                    st.subheader("3D Interactive Structure")
                    components.html(generate_3d_view_html(smiles_input_viewer), height=470)
            else:
                st.error("Invalid SMILES string provided.")