Spaces:
Running
Running
# app.py | |
# To run this app, save the code as app.py and run: | |
# streamlit run app.py | |
# | |
# You also need to install the following libraries: | |
# pip install streamlit torch transformers accelerate bitsandbytes rdkit py3Dmol pandas
import streamlit as st | |
import streamlit.components.v1 as components | |
import torch | |
from transformers import AutoModelForMaskedLM, AutoTokenizer, pipeline, BitsAndBytesConfig | |
from rdkit import Chem | |
from rdkit.Chem import Draw, AllChem | |
from rdkit.Chem.Draw import MolToImage | |
import pandas as pd | |
import logging | |
# Emit INFO-level records so the model-loading and quantization messages
# logged by load_models() show up in the server log.
logging.basicConfig(level="INFO")
logger = logging.getLogger(__name__)
# --- Page Configuration ---
# Keep this ahead of any other st.* call: older Streamlit versions require
# set_page_config to be the first Streamlit command in the script.
st.set_page_config(
    page_title="ChemBERTa SMILES Utilities",
    page_icon="🔬",  # microscope emoji (was mis-encoded as "π¬")
    layout="wide",
)
# --- Model Loading (Cached for Performance) --- | |
@st.cache_resource
def load_models():
    """Load the ChemBERTa tokenizer and build a fill-mask pipeline around it.

    ``st.cache_resource`` makes Streamlit run this once per server process
    and hand back the same objects on every rerun, instead of reloading the
    model each time a widget changes.

    With CUDA available, the model is loaded in float16 with
    ``device_map="auto"``. It also uses 8-bit bitsandbytes quantization if a
    config can be built. Without CUDA, it is loaded in float32 on the CPU. If
    that load raises, a plain default load is retried and moved to CUDA when
    available.

    Returns:
        tuple: ``(tokenizer, fill_mask_pipeline)``.
    """
    use_cuda = torch.cuda.is_available()
    model_name = "seyonec/PubChem10M_SMILES_BPE_450k"
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    model_kwargs = {"torch_dtype": torch.float16 if use_cuda else torch.float32}
    if use_cuda:
        # device_map="auto" (and bitsandbytes loading) requires `accelerate`.
        model_kwargs["device_map"] = "auto"
        try:
            # BitsAndBytesConfig has no bnb_8bit_* options. Compute dtype and
            # double quantization exist only as bnb_4bit_*, so only the 8-bit
            # switch is meaningful here.
            model_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
            logger.info("8-bit quantization configuration created.")
        except ImportError:
            logger.warning("bitsandbytes not available, falling back to standard loading.")
        except Exception as e:
            logger.warning(f"Quantization setup failed: {e}, using standard loading.")

    try:
        model = AutoModelForMaskedLM.from_pretrained(model_name, **model_kwargs)
        model.eval()
        if "device_map" in model_kwargs:
            # The weights were already placed by device_map. Let the pipeline
            # adopt that placement rather than asking it to move the model;
            # some transformers/bitsandbytes versions reject .to() on 8-bit
            # models.
            fill_mask_pipeline = pipeline("fill-mask", model=model, tokenizer=tokenizer)
        else:
            fill_mask_pipeline = pipeline(
                "fill-mask", model=model, tokenizer=tokenizer, device=-1
            )
        logger.info("Models loaded successfully with optimizations.")
        return tokenizer, fill_mask_pipeline
    except Exception as e:
        logger.exception(f"Error loading optimized models: {e}. Retrying with standard loading.")

    # Fallback: default-precision load with no quantization and no device_map.
    model = AutoModelForMaskedLM.from_pretrained(model_name)
    if use_cuda:
        model.to("cuda")
    fill_mask_pipeline = pipeline(
        "fill-mask", model=model, tokenizer=tokenizer, device=0 if use_cuda else -1
    )
    return tokenizer, fill_mask_pipeline


# Load the models once; st.cache_resource returns the cached objects on reruns.
fill_mask_tokenizer, fill_mask_pipeline = load_models()
# --- Molecule & Visualization Helpers --- | |
def get_mol(smiles):
    """Parse *smiles* into an RDKit molecule, kekulizing it when possible.

    Kekulization is best-effort: if ``Chem.Kekulize`` raises, the parsed
    molecule is returned unchanged.

    Returns:
        The parsed molecule, or ``None`` when RDKit cannot parse *smiles*.
    """
    molecule = Chem.MolFromSmiles(smiles)
    if not molecule:
        return molecule
    try:
        Chem.Kekulize(molecule)
    except Exception:
        # Deliberately ignored: keep the parsed molecule even if it cannot
        # be kekulized.
        pass
    return molecule
def find_matches_one(mol, submol_smarts):
    """Return every match of the SMARTS pattern *submol_smarts* in *mol*.

    Args:
        mol: RDKit molecule to search; a falsy value yields no matches.
        submol_smarts: SMARTS pattern string; empty or ``None`` yields no
            matches.

    Returns:
        The value of ``mol.GetSubstructMatches`` for the parsed pattern, or
        an empty list when an input is missing or the SMARTS does not parse.
    """
    if mol and submol_smarts:
        pattern = Chem.MolFromSmarts(submol_smarts)
        if pattern:
            return mol.GetSubstructMatches(pattern)
    return []
def get_image_with_highlight(mol, atomset=None, size=(300, 300)):
    """Render *mol* as a 2D image, optionally highlighting some atoms.

    Args:
        mol: RDKit molecule to draw; ``None`` short-circuits to ``None``.
        atomset: Iterable of atom indices to highlight. Entries whose string
            form is not all digits are dropped.
        size: ``(width, height)`` passed to ``MolToImage``.

    Returns:
        The image returned by ``MolToImage``, or ``None`` when *mol* is
        ``None``.
    """
    if mol is None:
        return None
    highlighted = []
    for atom_idx in atomset or []:
        if str(atom_idx).isdigit():
            highlighted.append(int(atom_idx))
    # Same translucent green for every highlighted atom.
    colors = dict.fromkeys(highlighted, (0, 1, 0, 0.5))
    return MolToImage(
        mol,
        size=size,
        fitImage=True,
        highlightAtoms=highlighted,
        highlightAtomColors=colors,
    )
def generate_3d_view_html(smiles):
    """Build an embeddable HTML snippet with an interactive 3D view of *smiles*.

    Hydrogens are added and a conformer is embedded with a fixed random seed
    (42). The geometry is MMFF-optimized and rendered with py3Dmol as sticks
    plus small spheres.

    Args:
        smiles: SMILES string to render.

    Returns:
        ``None`` for empty input. Otherwise an HTML string: either the
        py3Dmol viewer or a short ``<p>`` message explaining why no view
        could be produced.
    """
    import html

    if not smiles:
        return None
    mol = get_mol(smiles)
    if not mol:
        return "<p>Invalid SMILES for 3D view.</p>"
    try:
        # Local import: without it the name py3Dmol was never bound, so every
        # call failed with NameError. A missing package is now reported
        # through the except branch below.
        import py3Dmol

        mol_3d = Chem.AddHs(mol)
        # EmbedMolecule returns -1 when no conformer could be generated;
        # optimizing a conformer-less molecule would raise a less clear error.
        if AllChem.EmbedMolecule(mol_3d, randomSeed=42, useRandomCoords=True) == -1:
            return "<p>Could not compute 3D coordinates for this molecule.</p>"
        AllChem.MMFFOptimizeMolecule(mol_3d)
        sdf_data = Chem.MolToMolBlock(mol_3d)
        viewer = py3Dmol.view(width=350, height=350)
        viewer.setBackgroundColor('#FFFFFF')
        viewer.addModel(sdf_data, "sdf")
        viewer.setStyle({'stick': {}, 'sphere': {'scale': 0.25}})
        viewer.zoomTo()
        return viewer._make_html()
    except Exception as e:
        logger.error(f"Failed to generate 3D view for {smiles}: {e}")
        # Escape: the message is injected into HTML rendered by the page.
        return f"<p>Error generating 3D view: {html.escape(str(e))}</p>"
# --- Core Application Logic --- | |
def run_masked_smiles_prediction(smiles_mask, substructure_smarts_highlight, *, top_k=10, max_results=5):
    """Predict completions for a masked SMILES string and store them for display.

    Runs the fill-mask pipeline and walks its candidates in ranking order. It
    keeps the first *max_results* whose completed SMILES RDKit can parse, and
    renders each in 2D (all matches of *substructure_smarts_highlight*
    highlighted) and 3D.

    The kept results go to ``st.session_state.prediction_results`` as a list
    of dicts with keys ``smiles``, ``score``, ``image_2d`` and ``html_3d``.
    Problems are reported with ``st.error`` / ``st.warning`` instead of
    raising.

    Args:
        smiles_mask: SMILES string containing exactly one mask token.
        substructure_smarts_highlight: SMARTS pattern to highlight; may be empty.
        top_k: Number of candidate completions requested from the model.
        max_results: Maximum number of valid molecules to keep.
    """
    # Clear any previous run so an invalid or failed request does not leave
    # stale molecules on screen.
    st.session_state.prediction_results = []

    mask_token = fill_mask_tokenizer.mask_token
    mask_count = smiles_mask.count(mask_token)
    if mask_count == 0:
        st.error(f"Error: Input SMILES must contain a mask token (e.g., {mask_token}).")
        return
    if mask_count > 1:
        # With several masks the fill-mask pipeline returns one candidate list
        # per mask, which the processing below does not support.
        st.error(f"Error: Input SMILES must contain exactly one mask token ({mask_token}).")
        return

    with st.spinner("Predicting completions..."):
        try:
            with torch.no_grad():
                predictions = fill_mask_pipeline(smiles_mask, top_k=top_k)
        except Exception as e:
            st.error(f"An error occurred during prediction: {e}")
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            return

        results = []
        for pred in predictions:
            if len(results) >= max_results:
                break
            # Decoded BPE text may carry stray surrounding whitespace, which
            # RDKit's SMILES parser may not accept.
            predicted_smiles = pred['sequence'].strip()
            mol = get_mol(predicted_smiles)
            if not mol:
                continue
            matches = find_matches_one(mol, substructure_smarts_highlight)
            # Highlight every occurrence of the pattern, not just the first.
            highlight_atoms = sorted({idx for match in matches for idx in match})
            results.append({
                "smiles": predicted_smiles,
                "score": f"{pred['score']:.4f}",
                "image_2d": get_image_with_highlight(mol, atomset=highlight_atoms),
                "html_3d": generate_3d_view_html(predicted_smiles),
            })

    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    if not results:
        st.warning("None of the predicted completions is a valid SMILES string.")
    st.session_state.prediction_results = results
# --- Streamlit UI Definition ---
# Title emoji is the microscope (it had been mis-encoded as "π¬").
st.title("🔬 ChemBERTa SMILES Utilities Dashboard (2D & 3D)")
st.markdown("A tool to predict masked tokens in SMILES strings and visualize molecules, powered by ChemBERTa and Streamlit.")
# Two independent tools, one per tab.
tab1, tab2 = st.tabs(["Masked SMILES Prediction", "Molecule Viewer (2D & 3D)"])
# --- Tab 1: Masked SMILES Prediction ---
with tab1:
    st.header("Predict and Visualize Masked SMILES")
    st.markdown("Enter a SMILES string with a `<mask>` token to predict possible completions.")
    # A form batches both inputs so the model only runs on explicit submit.
    with st.form(key="prediction_form"):
        col1, col2 = st.columns(2)
        with col1:
            smiles_input_masked = st.text_input(
                "SMILES String with Mask",
                value="C1=CC=CC<mask>C1",
                help=f"The mask token is `{fill_mask_tokenizer.mask_token}`",
            )
        with col2:
            substructure_input = st.text_input(
                "Substructure to Highlight (SMARTS)",
                value="C=C",
                help="Enter a SMARTS pattern to highlight in the 2D images.",
            )
        predict_button = st.form_submit_button("Predict and Visualize", use_container_width=True)

    if predict_button:
        run_masked_smiles_prediction(smiles_input_masked, substructure_input)

    # Results are kept in session_state, so they persist across reruns
    # triggered by other widgets (e.g. the viewer form in tab 2).
    results = st.session_state.get("prediction_results")
    if results:
        # Fewer than five predictions may be valid; report the real count.
        st.subheader(f"Top {len(results)} Valid Predictions")

        # Display results in a table
        df_data = [{"Predicted SMILES": r["smiles"], "Score": r["score"]} for r in results]
        st.dataframe(pd.DataFrame(df_data), use_container_width=True)
        st.markdown("---")

        # Display molecule visualizations
        for i, res in enumerate(results):
            st.markdown(f"**Prediction {i+1}:** `{res['smiles']}` (Score: {res['score']})")
            col1, col2 = st.columns(2)
            with col1:
                st.subheader("2D Structure")
                if res["image_2d"]:
                    # use_column_width is deprecated in st.image; use the same
                    # use_container_width flag as the other widgets here.
                    st.image(res["image_2d"], use_container_width=True)
                else:
                    st.warning("Could not generate 2D image.")
            with col2:
                st.subheader("3D Interactive Structure")
                if res["html_3d"]:
                    components.html(res["html_3d"], height=370)
                else:
                    st.warning("Could not generate 3D view.")
            st.markdown("---")
# --- Tab 2: Molecule Viewer ---
with tab2:
    st.header("Visualize a Molecule from SMILES")
    st.markdown("Enter a single SMILES string to display its 2D and 3D structures side-by-side.")
    with st.form(key="viewer_form"):
        smiles_input_viewer = st.text_input("SMILES String", value="CC(=O)Oc1ccccc1C(=O)O")  # Aspirin
        view_button = st.form_submit_button("View Molecule", use_container_width=True)

    # Strip surrounding whitespace so blank input is skipped rather than parsed.
    viewer_smiles = smiles_input_viewer.strip()
    if view_button and viewer_smiles:
        with st.spinner("Generating views..."):
            mol = get_mol(viewer_smiles)
            if not mol:
                st.error("Invalid SMILES string provided.")
            else:
                st.subheader(f"Visualizations for: `{viewer_smiles}`")
                col1, col2 = st.columns(2)
                with col1:
                    st.subheader("2D Structure")
                    img_2d = MolToImage(mol, size=(450, 450), fitImage=True)
                    # use_column_width is deprecated in st.image.
                    st.image(img_2d, use_container_width=True)
                with col2:
                    st.subheader("3D Interactive Structure")
                    html_3d = generate_3d_view_html(viewer_smiles)
                    components.html(html_3d, height=470)