Spaces:
Sleeping
Sleeping
# app.py | |
import streamlit as st | |
import torch | |
from transformers import AutoModelForMaskedLM, AutoTokenizer, pipeline, BitsAndBytesConfig | |
from rdkit import Chem | |
from rdkit.Chem import Draw, AllChem, rdBase | |
import pandas as pd | |
import py3Dmol | |
import re | |
import logging | |
# --- Setup ---
# Python logging: emit INFO and above through the root handler, and create a
# module-level logger used by the functions below.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Mute RDKit's 'rdApp.error' log channel: the model frequently produces
# unparsable SMILES, and those failures are already handled in code.
rdBase.DisableLog('rdApp.error')
# --- Page Configuration ---
# Kept ahead of every other st.* call (older Streamlit versions require
# set_page_config to be the first Streamlit command in the script).
st.set_page_config(
    page_title="ChemBERTa SMILES Utilities",
    # Test-tube emoji; previously stored mis-encoded (UTF-8 bytes read as
    # ISO-8859-7), which is not a valid emoji icon.
    page_icon="🧪",
    layout="wide",
)
# --- Custom Styling (from drug_app) ---
def apply_custom_styling():
    """Inject the app's dark-theme stylesheet into the page.

    The CSS sets the Roboto font, a dark page background, and colours for
    Streamlit's tab (``.stTabs``) and button (``.stButton``) elements. It is
    emitted as a raw ``<style>`` tag through ``st.markdown`` with HTML enabled.
    """
    css = """
        <style>
        @import url('https://fonts.googleapis.com/css2?family=Roboto:wght@400;700&display=swap');
        html, body, [class*="st-"] {
            font-family: 'Roboto', sans-serif;
        }
        .stApp {
            background-color: rgb(28, 28, 28);
            color: white;
        }
        /* Tab styles */
        .stTabs [data-baseweb="tab-list"] {
            gap: 24px;
        }
        .stTabs [data-baseweb="tab"] {
            height: 50px;
            white-space: pre-wrap;
            background: none;
            border-radius: 0px;
            border-bottom: 2px solid #333;
            padding: 10px 4px;
            color: #AAA;
        }
        .stTabs [data-baseweb="tab"]:hover {
            background: #222;
            color: #FFF;
        }
        .stTabs [aria-selected="true"] {
            border-bottom: 2px solid #00A0FF; /* Highlight color for active tab */
            color: #FFF;
        }
        /* Button styles */
        .stButton>button {
            border-color: #00A0FF;
            color: #00A0FF;
            background-color: transparent;
        }
        .stButton>button:hover {
            border-color: #FFF;
            color: #FFF;
            background-color: #00A0FF;
        }
        </style>
        """
    st.markdown(css, unsafe_allow_html=True)


apply_custom_styling()
# --- Model Loading (from mol_app) ---
# NOTE: The "missing ScriptRunContext" warnings in the logs are expected when not
# running via the 'streamlit run' command. They can be safely ignored.
@st.cache_resource
def load_optimized_models():
    """Load the ChemBERTa fill-mask pipeline and its tokenizer.

    8-bit quantization is used only when a CUDA device is available AND the
    ``bitsandbytes`` package is installed. Otherwise the model is loaded
    unquantized: float16 on GPU, float32 on CPU.

    Wrapped in ``st.cache_resource`` so Streamlit builds the model once and
    reuses it across script reruns. It is not reloaded on every widget
    interaction.

    Returns:
        tuple: ``(pipe, tokenizer)`` -- the ``fill-mask`` pipeline and the
        tokenizer it was built with.
    """
    import importlib.util  # only needed at load time

    use_cuda = torch.cuda.is_available()
    torch_dtype = torch.float16 if use_cuda else torch.float32
    model_name = "seyonec/PubChem10M_SMILES_BPE_450k"

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model_kwargs = {"torch_dtype": torch_dtype}

    # Building an 8-bit BitsAndBytesConfig does not import bitsandbytes, so an
    # ImportError guard around it never fires. Probe for the package instead.
    # Compute-dtype / double-quant knobs exist only as bnb_4bit_* fields, so
    # load_in_8bit alone is the full 8-bit configuration.
    use_8bit = use_cuda and importlib.util.find_spec("bitsandbytes") is not None
    if use_8bit:
        model_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
        model_kwargs["device_map"] = "auto"
        logger.info("8-bit quantization will be used.")
    elif use_cuda:
        logger.warning("bitsandbytes not found. Model will be loaded without quantization.")
    else:
        logger.info("CUDA not available. Model will be loaded on CPU without quantization.")

    # The "Some weights of the model were not used" warning is expected and normal.
    model = AutoModelForMaskedLM.from_pretrained(model_name, **model_kwargs)

    # A model already placed by accelerate (device_map) must not also receive
    # an explicit pipeline `device`; transformers rejects that combination.
    pipeline_kwargs = {} if use_8bit else {"device": 0 if use_cuda else -1}
    pipe = pipeline('fill-mask', model=model, tokenizer=tokenizer, **pipeline_kwargs)
    logger.info("ChemBERTa model loaded successfully.")
    return pipe, tokenizer


fill_mask_pipeline, tokenizer = load_optimized_models()
# --- Core Functions ---
def get_mol(smiles):
    """Parse a SMILES string into a kekulized RDKit ``Mol``.

    Args:
        smiles: SMILES text. Surrounding whitespace is ignored.

    Returns:
        The parsed molecule, or ``None`` if *smiles* is not a string, is blank,
        fails to parse, or yields a molecule with no atoms.
    """
    # RDKit parses "" into an atom-less molecule rather than None. Reject
    # blank input up front so callers don't treat it as a valid structure.
    if not isinstance(smiles, str) or not smiles.strip():
        return None
    # SMILES parse errors are expected for invalid model outputs. RDKit returns
    # None for them, and callers skip those predictions.
    mol = Chem.MolFromSmiles(smiles.strip())
    if mol is None or mol.GetNumAtoms() == 0:
        return None
    try:
        Chem.Kekulize(mol)
    except Exception as exc:
        # Best effort only: the aromatic form still draws and embeds.
        logger.debug("Kekulization failed for %r: %s", smiles, exc)
    return mol
def find_matches_one(mol, submol_smarts):
    """Return every substructure match of a SMARTS pattern within *mol*.

    Args:
        mol: RDKit molecule to search. A falsy value yields no matches.
        submol_smarts: SMARTS query. Empty or unparsable patterns yield no matches.

    Returns:
        The atom-index tuples from ``GetSubstructMatches``, or ``[]`` if there
        is nothing to search or the pattern cannot be parsed.
    """
    if mol and submol_smarts:
        pattern = Chem.MolFromSmarts(submol_smarts)
        if pattern:
            return mol.GetSubstructMatches(pattern)
    return []
# --- Visualization Function (Adapted from drug_app) ---
def _render_2d_svg(mol, substructure_smarts=""):
    """Return a 450x350 dark-themed SVG of *mol*, highlighting the first SMARTS match."""
    drawer = Draw.rdMolDraw2D.MolDraw2DSVG(450, 350)
    opts = drawer.drawOptions()
    opts.clearBackground = False
    opts.addStereoAnnotation = True
    opts.baseFontSize = 0.9
    # Dark theme colors for 2D drawing
    opts.backgroundColour = (0.109, 0.109, 0.109)  # rgb(28,28,28)
    opts.symbolColour = (1, 1, 1)
    opts.setAtomPalette({
        -1: (1, 1, 1),  # Default
        6: (0.9, 0.9, 0.9),  # Carbon
        7: (0.5, 0.5, 1),  # Nitrogen
        8: (1, 0.2, 0.2),  # Oxygen
        16: (1, 0.8, 0.2),  # Sulfur
    })

    # Highlight only the first match. An empty or invalid pattern highlights nothing.
    matches = find_matches_one(mol, substructure_smarts)
    highlight_atoms = list(matches[0]) if matches else []

    drawer.DrawMolecule(mol, highlightAtoms=highlight_atoms)
    drawer.FinishDrawing()
    svg = drawer.GetDrawingText()

    # Recolor remaining black strokes/fills so they stay visible on the dark background.
    svg = svg.replace('stroke="black"', 'stroke="white"')
    svg = svg.replace('fill="black"', 'fill="white"')
    return re.sub(r'fill:#(000000|000);', 'fill:white;', svg)


def _embed_3d(mol):
    """Return an H-added copy of *mol* with one 3D conformer, or None if embedding fails."""
    mol_3d = Chem.AddHs(mol)
    if AllChem.EmbedMolecule(mol_3d, randomSeed=42) == -1:
        # EmbedMolecule signals failure with -1 rather than raising. Retry with
        # ETKDGv3 parameters starting from random coordinates.
        params = AllChem.ETKDGv3()
        params.randomSeed = 42
        params.useRandomCoords = True
        if AllChem.EmbedMolecule(mol_3d, params) == -1:
            return None
    try:
        AllChem.MMFFOptimizeMolecule(mol_3d)
    except Exception as exc:
        # Optimization is cosmetic; the embedded geometry is still usable.
        logger.debug("MMFF optimization failed: %s", exc)
    return mol_3d


def _render_3d_html(mol_3d):
    """Return standalone py3Dmol HTML showing *mol_3d* as sticks and small spheres."""
    viewer = py3Dmol.view(width=450, height=350)
    viewer.setBackgroundColor('#1C1C1C')
    viewer.addModel(Chem.MolToMolBlock(mol_3d), "sdf")
    viewer.setStyle({'stick': {}, 'sphere': {'scale': 0.25}})
    viewer.zoomTo()
    return viewer._make_html()


def visualize_molecule_2d_3d(smiles: str, name: str, substructure_smarts=""):
    """Build a side-by-side 2D SVG / 3D py3Dmol HTML panel for one molecule.

    Args:
        smiles: SMILES of the molecule to draw.
        name: Label shown in the panel headings and log messages.
        substructure_smarts: Optional SMARTS whose first match is highlighted in 2D.

    Returns:
        tuple: ``(html, log)``. On invalid SMILES or any rendering error,
        ``html`` is a short ``<p>`` message and ``log`` describes the failure.
        If only 3D embedding fails, the 2D view is still returned.
    """
    from html import escape  # local: a module-level `html` is rebound by the UI script

    safe_name = escape(name)
    try:
        mol = get_mol(smiles)
        if not mol:
            return f"<p>Invalid SMILES for {safe_name}</p>", f"❌ Invalid SMILES for {name}\n"

        svg_2d = _render_2d_svg(mol, substructure_smarts)

        mol_3d = _embed_3d(mol)
        if mol_3d is None:
            html_3d = '<p style="color: #AAA;">3D embedding failed for this molecule.</p>'
            log = f"⚠️ Generated 2D view for {name}; 3D embedding failed.\n"
        else:
            html_3d = _render_3d_html(mol_3d)
            log = f"✅ Generated 2D/3D view for {name}.\n"

        combined_html = f"""
        <div style="display: flex; flex-direction: row; align-items: center; justify-content: space-around; border: 1px solid #444; border-radius: 10px; padding: 10px; margin-bottom: 20px; background-color: #2b2b2b;">
            <div style="text-align: center;">
                <h4 style="color: white; font-family: 'Roboto', sans-serif;">{safe_name} (2D Structure)</h4>
                <div style="background-color: #1C1C1C; padding: 10px; border-radius: 5px;">{svg_2d}</div>
            </div>
            <div style="text-align: center;">
                <h4 style="color: white; font-family: 'Roboto', sans-serif;">{safe_name} (3D Interactive)</h4>
                {html_3d}
            </div>
        </div>
        """
        return combined_html, log
    except Exception as e:
        logger.exception("Error visualizing %s", name)
        return (
            f"<p>Error visualizing {safe_name}: {escape(str(e))}</p>",
            f"❌ Error visualizing {name}: {e}\n",
        )
# --- Main Application Logic ---
def predict_and_generate_visualizations(smiles_mask, substructure_smarts, *, top_k=15, max_results=5):
    """Fill the mask in a SMILES string and visualize the RDKit-valid completions.

    Args:
        smiles_mask: SMILES text containing exactly one ``tokenizer.mask_token``.
        substructure_smarts: Optional SMARTS pattern highlighted in each 2D view.
        top_k: Number of candidate fills requested from the fill-mask pipeline.
        max_results: Maximum number of valid predictions to keep and render.

    Returns:
        tuple: ``(df, html, log)``:
        - ``df``: a DataFrame with Rank / Predicted SMILES / Score columns.
        - ``html``: the concatenated 2D+3D panels.
        - ``log``: a status log.
        On input or prediction errors an ``st.error`` is shown and
        ``(empty DataFrame, "", short error text)`` is returned.
    """
    mask_token = tokenizer.mask_token
    mask_count = smiles_mask.count(mask_token)
    if mask_count == 0:
        st.error(f"Error: Input SMILES must contain a mask token (e.g., `{mask_token}`).")
        return pd.DataFrame(), "", "Input error."
    if mask_count > 1:
        # With several masks the pipeline returns one candidate list per mask
        # (a list of lists), which the ranking loop below cannot consume.
        st.error(
            f"Error: Input SMILES must contain exactly one mask token "
            f"(`{mask_token}`); found {mask_count}."
        )
        return pd.DataFrame(), "", "Input error."

    try:
        with torch.no_grad():
            predictions = fill_mask_pipeline(smiles_mask, top_k=top_k)
    except Exception as e:
        logger.exception("Fill-mask prediction failed")
        st.error(f"An error occurred during model prediction: {e}")
        return pd.DataFrame(), "", "Prediction error."
    finally:
        # Release cached GPU memory whether or not the call succeeded.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    results_data = []
    html_parts = []
    log_parts = []
    for pred in predictions:
        if len(results_data) >= max_results:
            break
        predicted_smiles = pred['sequence'].strip()
        # Many model completions are not chemically valid; skip those.
        if not get_mol(predicted_smiles):
            continue
        rank = len(results_data) + 1
        results_data.append({
            "Rank": rank,
            "Predicted SMILES": predicted_smiles,
            "Score": f"{pred['score']:.4f}",
        })
        html_view, log = visualize_molecule_2d_3d(
            predicted_smiles,
            f"Prediction #{rank}",
            substructure_smarts,
        )
        html_parts.append(html_view)
        log_parts.append(log)

    log_parts.append(f"\nFound {len(results_data)} valid molecules from top predictions.")
    return pd.DataFrame(results_data), "".join(html_parts), "".join(log_parts)
# --- Streamlit Interface ---
# Top-level page layout: a prediction tab (masked SMILES -> ranked completions
# with 2D/3D views) and a single-molecule viewer tab. Results are kept in
# st.session_state so they survive reruns triggered by unrelated widgets.
st.title("🧪 ChemBERTa SMILES Utilities")
st.markdown(f"""
Enter a SMILES string with a `{tokenizer.mask_token}` token to predict possible completions.
The model will generate the most likely atoms or fragments to fill the mask.
""")

tab1, tab2 = st.tabs(["Masked SMILES Prediction", "Molecule Viewer"])

with tab1:
    st.header("Masked SMILES Prediction")
    with st.form("prediction_form"):
        col1, col2 = st.columns(2)
        with col1:
            smiles_input_masked = st.text_input(
                "SMILES String with Mask",
                value=f"C1=CC=CC{tokenizer.mask_token}C1",
                help=f"Use `{tokenizer.mask_token}` as the mask token."
            )
        with col2:
            substructure_input = st.text_input(
                "Substructure to Highlight (SMARTS)",
                value="C=C",
                help="Enter a SMARTS pattern to highlight in the 2D view."
            )
        submit_button = st.form_submit_button("🔍 Predict and Visualize", use_container_width=True)

    # --- Robust Session State Management ---
    # This ensures the app loads with default predictions on the very first run,
    # and only updates when the user clicks the button.
    # The "Session state does not function" warning in logs is due to the execution
    # environment and can be ignored.
    if 'app_initialized' not in st.session_state:
        with st.spinner("Running initial prediction..."):
            df, html, log = predict_and_generate_visualizations(smiles_input_masked, substructure_input)
            st.session_state.results_df = df
            st.session_state.results_html = html
            st.session_state.status_log = log
            st.session_state.app_initialized = True

    if submit_button:
        with st.spinner("Running predictions... This may take a moment."):
            df, html, log = predict_and_generate_visualizations(smiles_input_masked, substructure_input)
            st.session_state.results_df = df
            st.session_state.results_html = html
            st.session_state.status_log = log

    st.subheader("Top Predictions & Scores")
    if 'results_df' in st.session_state and not st.session_state.results_df.empty:
        # Was `results__df` (typo), which raised AttributeError whenever results existed.
        st.dataframe(st.session_state.results_df, use_container_width=True, hide_index=True)
    else:
        st.info("No valid predictions to display. Try a different input.")

    st.subheader("Predicted Molecule Visualizations (Top 5 Valid)")
    if 'results_html' in st.session_state and st.session_state.results_html:
        st.components.v1.html(st.session_state.results_html, height=1850, scrolling=True)

    with st.expander("Show Logs"):
        if 'status_log' in st.session_state:
            # A non-empty label is required for accessibility.
            st.text_area(label="Prediction Logs", value=st.session_state.status_log, height=200, key="log_area_pred")

with tab2:
    st.header("Molecule Viewer")
    st.markdown("Enter a single SMILES string to display its 2D and 3D structure.")
    with st.form("viewer_form"):
        smiles_input_viewer = st.text_input("SMILES String", value="CC(=O)Oc1ccccc1C(=O)O")  # Aspirin
        viewer_submit = st.form_submit_button("👁️ View Molecule", use_container_width=True)

    if viewer_submit:
        with st.spinner("Generating visualization..."):
            html_view, log = visualize_molecule_2d_3d(smiles_input_viewer, "Molecule")
            st.session_state.viewer_html = html_view
            st.session_state.viewer_log = log

    if 'viewer_html' in st.session_state:
        st.components.v1.html(st.session_state.viewer_html, height=450)

    with st.expander("Show Logs"):
        if 'viewer_log' in st.session_state:
            # A non-empty label is required for accessibility.
            st.text_area(label="Viewer Logs", value=st.session_state.viewer_log, height=100, key="log_area_view")