# app.py
"""ChemBERTa SMILES utilities.

A Streamlit app with two tabs:

* Masked SMILES prediction: fill a single mask token with ChemBERTa and show
  the top chemically valid completions as 2D and 3D structures.
* Molecule viewer: render any SMILES string in 2D and 3D.

Run with ``streamlit run app.py``.
"""
import html
import importlib.util
import logging
import re

import pandas as pd
import py3Dmol
import streamlit as st
import streamlit.components.v1 as components
import torch
from rdkit import Chem
from rdkit.Chem import AllChem, Draw
from transformers import (
    AutoModelForMaskedLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    pipeline,
)

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

MODEL_NAME = "seyonec/PubChem10M_SMILES_BPE_450k"
TOP_K = 15  # raw candidates requested from the fill-mask pipeline
MAX_VALID = 5  # number of RDKit-parsable candidates that are displayed
PANEL_BG = "#1C1C1C"  # dark background shared by the 2D and 3D panels
RESULT_ROW_HEIGHT = 420  # approximate pixel height of one 2D/3D panel row
# RDKit SVG output may contain black fills in style attributes; the dark
# theme needs them white.
_BLACK_FILL_RE = re.compile(r'fill:#(000000|000);')

# --- Page Configuration ---
# Must be the first Streamlit command executed by the script.
st.set_page_config(
    page_title="ChemBERTa SMILES Utilities",
    page_icon="🧪",
    layout="wide",
)

# --- Custom Styling (from drug_app) ---
# NOTE(review): the stylesheet body was empty in this copy of the file; put
# app-wide CSS rules inside the <style> tag if overrides are needed.
_CUSTOM_CSS = """
<style>
</style>
"""


def apply_custom_styling() -> None:
    """Inject the app-wide CSS overrides into the page."""
    st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)


apply_custom_styling()


# --- Model Loading (from mol_app) ---
@st.cache_resource(show_spinner="Loading ChemBERTa model...")
def load_optimized_models():
    """Load the ChemBERTa fill-mask pipeline (cached across reruns and sessions).

    On a CUDA machine with ``bitsandbytes`` installed the model is loaded in
    8-bit and placed with ``device_map="auto"``. Otherwise it is loaded
    unquantized: float16 on GPU, float32 on CPU.

    Returns:
        tuple: ``(fill_mask_pipeline, tokenizer)``.
    """
    use_cuda = torch.cuda.is_available()
    torch_dtype = torch.float16 if use_cuda else torch.float32

    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model_kwargs = {"torch_dtype": torch_dtype}

    # Building a BitsAndBytesConfig never raises ImportError: bitsandbytes is
    # only imported when the model is loaded. So probe for the package
    # explicitly instead of relying on try/except around the config.
    quantized = use_cuda and importlib.util.find_spec("bitsandbytes") is not None
    if quantized:
        model_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
        model_kwargs["device_map"] = "auto"
        logger.info("8-bit quantization will be used.")
    elif use_cuda:
        logger.warning(
            "bitsandbytes not found. Model will be loaded without quantization."
        )
    else:
        logger.info("CUDA not available; loading the model on CPU without quantization.")

    model = AutoModelForMaskedLM.from_pretrained(MODEL_NAME, **model_kwargs)

    # A model already dispatched by accelerate (device_map) must not be moved
    # again, so an explicit device is only passed on the unquantized path.
    pipe_kwargs = {} if quantized else {"device": 0 if use_cuda else -1}
    pipe = pipeline("fill-mask", model=model, tokenizer=tokenizer, **pipe_kwargs)
    logger.info("ChemBERTa model loaded successfully.")
    return pipe, tokenizer


fill_mask_pipeline, tokenizer = load_optimized_models()


# --- Core Functions ---
def get_mol(smiles):
    """Convert *smiles* to an RDKit Mol, or return None if it does not parse.

    The molecule is kekulized when possible. A kekulization failure is logged
    and the sanitized (aromatic) molecule is returned as-is.
    """
    if not smiles:
        return None
    mol = Chem.MolFromSmiles(smiles)
    if mol is not None:
        try:
            Chem.Kekulize(mol)
        except Exception as exc:  # best effort: drawing still works unkekulized
            logger.debug("Kekulization failed for %s: %s", smiles, exc)
    return mol


def find_matches_one(mol, submol_smarts):
    """Return all atom-index matches of a SMARTS pattern in *mol*.

    Returns an empty sequence if *mol* is None, the pattern is empty, or the
    pattern is not valid SMARTS. Otherwise returns a tuple of atom-index tuples,
    as produced by ``GetSubstructMatches``.
    """
    if mol is None or not submol_smarts:
        return []
    submol = Chem.MolFromSmarts(submol_smarts)
    if submol is None:
        logger.warning("Invalid SMARTS pattern: %s", submol_smarts)
        return []
    return mol.GetSubstructMatches(submol)


# --- Visualization Helpers ---
def _svg_2d(mol, highlight_atoms) -> str:
    """Render *mol* as a 450x350 dark-theme SVG, highlighting *highlight_atoms*."""
    drawer = Draw.rdMolDraw2D.MolDraw2DSVG(450, 350)
    opts = drawer.drawOptions()
    opts.clearBackground = False
    opts.addStereoAnnotation = True
    opts.baseFontSize = 0.9
    # Dark theme colors for 2D drawing
    opts.backgroundColour = (0.109, 0.109, 0.109)  # rgb(28,28,28)
    opts.symbolColour = (1, 1, 1)
    opts.setAtomPalette({
        -1: (1, 1, 1),  # Default
        6: (0.9, 0.9, 0.9),  # Carbon
        7: (0.5, 0.5, 1),  # Nitrogen
        8: (1, 0.2, 0.2),  # Oxygen
        16: (1, 0.8, 0.2),  # Sulfur
    })
    drawer.DrawMolecule(mol, highlightAtoms=list(highlight_atoms))
    drawer.FinishDrawing()
    svg = drawer.GetDrawingText()
    # Force any remaining black strokes/fills to white for the dark theme.
    svg = svg.replace('stroke="black"', 'stroke="white"')
    svg = svg.replace('fill="black"', 'fill="white"')
    return _BLACK_FILL_RE.sub('fill:white;', svg)


def _embed_3d(mol):
    """Return a hydrogenated copy of *mol* with one 3D conformer, or None.

    Embedding uses ETKDGv3 with a fixed seed so the output is reproducible.
    If that fails, it is retried from random coordinates. The geometry is then
    relaxed with MMFF, or with UFF when MMFF lacks parameters for the molecule.
    A force-field failure is logged and the unoptimized embedding is kept.
    """
    mol_3d = Chem.AddHs(mol)
    params = AllChem.ETKDGv3()
    params.randomSeed = 42
    # EmbedMolecule signals failure with a -1 return code, not an exception.
    if AllChem.EmbedMolecule(mol_3d, params) != 0:
        params.useRandomCoords = True
        if AllChem.EmbedMolecule(mol_3d, params) != 0:
            return None
    try:
        if AllChem.MMFFHasAllMoleculeParams(mol_3d):
            AllChem.MMFFOptimizeMolecule(mol_3d)
        else:
            AllChem.UFFOptimizeMolecule(mol_3d)
    except Exception as exc:  # best effort: the embedded geometry is usable
        logger.warning("Force-field optimization failed: %s", exc)
    return mol_3d


def _html_3d(mol_3d) -> str:
    """Build an interactive 450x350 py3Dmol viewer for a molecule with a 3D conformer."""
    viewer = py3Dmol.view(width=450, height=350)
    viewer.setBackgroundColor(PANEL_BG)
    viewer.addModel(Chem.MolToMolBlock(mol_3d), "sdf")
    viewer.setStyle({'stick': {}, 'sphere': {'scale': 0.25}})
    viewer.zoomTo()
    # _make_html is py3Dmol's (private) standalone-HTML renderer.
    return viewer._make_html()


def _panel(title: str, body_html: str) -> str:
    """Wrap trusted *body_html* in a dark panel under an HTML-escaped *title*."""
    return (
        f'<div style="flex:0 0 auto;background:{PANEL_BG};color:white;'
        'padding:8px;border-radius:8px;">'
        '<h4 style="margin:0 0 8px 0;font-family:sans-serif;">'
        f"{html.escape(title)}</h4>"
        f"{body_html}</div>"
    )


def _error_html(message: str) -> str:
    """Render *message* (HTML-escaped) as a red error line."""
    return (
        '<div style="color:#ff6b6b;font-family:sans-serif;padding:8px;">'
        f"{html.escape(message)}</div>"
    )


# --- Visualization Function (Adapted from drug_app) ---
def visualize_molecule_2d_3d(smiles: str, name: str, substructure_smarts=""):
    """Generate side-by-side 2D SVG and 3D py3Dmol views for one molecule.

    Args:
        smiles: SMILES string of the molecule.
        name: Display name used in panel titles and log messages.
        substructure_smarts: Optional SMARTS pattern. The atoms of its first
            match are highlighted in the 2D view.

    Returns:
        tuple: ``(html, log)``. *log* is one or more newline-terminated status
        lines. On failure *html* is an error snippet; this function does not
        raise.
    """
    try:
        mol = get_mol(smiles)
        if mol is None:
            message = f"Invalid SMILES for {name}"
            return _error_html(message), f"❌ {message}\n"

        highlight = []
        if substructure_smarts:
            matches = find_matches_one(mol, substructure_smarts)
            if matches:
                highlight = list(matches[0])  # Highlight first match
        svg_2d = _svg_2d(mol, highlight)

        mol_3d = _embed_3d(mol)
        if mol_3d is None:
            html_3d = (
                '<p style="font-family:sans-serif;">'
                "Could not generate 3D coordinates for this molecule.</p>"
            )
            log = f"⚠️ Generated 2D view for {name}; 3D embedding failed.\n"
        else:
            html_3d = _html_3d(mol_3d)
            log = f"✅ Generated 2D/3D view for {name}.\n"

        combined_html = (
            '<div style="display:flex;flex-wrap:wrap;gap:16px;margin-bottom:16px;">'
            + _panel(f"{name} (2D Structure)", svg_2d)
            + _panel(f"{name} (3D Interactive)", html_3d)
            + "</div>"
        )
        return combined_html, log
    except Exception as e:
        logger.exception("Error visualizing %s", name)
        message = f"Error visualizing {name}: {e}"
        return _error_html(message), f"❌ {message}\n"


# --- Main Application Logic ---
def predict_and_generate_visualizations(smiles_mask, substructure_smarts):
    """Fill the mask in *smiles_mask* and visualize the top valid completions.

    Args:
        smiles_mask: SMILES string containing exactly one mask token.
        substructure_smarts: Optional SMARTS pattern to highlight in 2D views.

    Returns:
        tuple: ``(results_df, combined_html, status_log)``. On input or
        prediction errors an empty DataFrame and empty HTML are returned, and
        the error is shown with ``st.error``.
    """
    mask = tokenizer.mask_token
    # With several masks the pipeline returns one candidate list per mask,
    # which the loop below does not handle; require exactly one.
    if not smiles_mask or smiles_mask.count(mask) != 1:
        st.error(f"Error: Input SMILES must contain exactly one mask token (e.g., `{mask}`).")
        return pd.DataFrame(), "", "Input error."

    try:
        with torch.no_grad():
            predictions = fill_mask_pipeline(smiles_mask, top_k=TOP_K)
    except Exception as e:
        logger.exception("Model prediction failed")
        st.error(f"An error occurred during model prediction: {e}")
        return pd.DataFrame(), "", "Prediction error."
    finally:
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    results_data = []
    html_parts = []
    log_parts = []
    for pred in predictions:
        if len(results_data) >= MAX_VALID:
            break
        # SMILES contains no whitespace. A decoded BPE token may carry a
        # leading space, so remove all whitespace; the displayed string is
        # then exactly what RDKit parses.
        predicted_smiles = "".join(pred['sequence'].split())
        if get_mol(predicted_smiles) is None:
            continue
        rank = len(results_data) + 1
        results_data.append({
            "Rank": rank,
            "Predicted SMILES": predicted_smiles,
            "Score": f"{pred['score']:.4f}",
        })
        html_view, log = visualize_molecule_2d_3d(
            predicted_smiles, f"Prediction #{rank}", substructure_smarts
        )
        html_parts.append(html_view)
        log_parts.append(log)

    log_parts.append(f"\nFound {len(results_data)} valid molecules from top predictions.")
    return pd.DataFrame(results_data), "".join(html_parts), "".join(log_parts)


# --- Streamlit Interface ---
st.title("🧪 ChemBERTa SMILES Utilities")
st.markdown(f"""
Enter a SMILES string with a `{tokenizer.mask_token}` token (e.g., `C1=CC=CC{tokenizer.mask_token}C1`) to predict possible completions.
The model will generate the most likely atoms or fragments to fill the mask.
""")

tab1, tab2 = st.tabs(["Masked SMILES Prediction", "Molecule Viewer"])

with tab1:
    st.header("Masked SMILES Prediction")
    with st.form("prediction_form"):
        col1, col2 = st.columns(2)
        with col1:
            smiles_input_masked = st.text_input(
                "SMILES String with Mask",
                value=f"C1=CC=CC{tokenizer.mask_token}C1",
                help=f"Use `{tokenizer.mask_token}` as the mask token.",
            )
        with col2:
            substructure_input = st.text_input(
                "Substructure to Highlight (SMARTS)",
                value="C=C",
                help="Enter a SMARTS pattern to highlight in the 2D view.",
            )
        submit_button = st.form_submit_button("🚀 Predict and Visualize", use_container_width=True)

    # Predict on explicit submit, and once on first load so the default
    # example is shown immediately.
    if submit_button or 'results_df' not in st.session_state:
        with st.spinner("Running predictions... This may take a moment."):
            df, results_html, log = predict_and_generate_visualizations(
                smiles_input_masked, substructure_input
            )
        st.session_state.results_df = df
        st.session_state.results_html = results_html
        st.session_state.status_log = log

    st.subheader("Top Predictions & Scores")
    results_df = st.session_state.results_df
    if not results_df.empty:
        st.dataframe(results_df, use_container_width=True, hide_index=True)
        st.subheader(f"Predicted Molecule Visualizations (Top {MAX_VALID} Valid)")
        components.html(
            st.session_state.results_html,
            height=RESULT_ROW_HEIGHT * len(results_df) + 20,
            scrolling=True,
        )
    else:
        st.info("No valid predictions to display. Try a different input.")

    with st.expander("Show Logs"):
        st.text_area(
            "Prediction log",
            st.session_state.status_log,
            height=200,
            key="log_area_pred",
            label_visibility="collapsed",
        )

with tab2:
    st.header("Molecule Viewer")
    st.markdown("Enter a single SMILES string to display its 2D and 3D structure.")
    with st.form("viewer_form"):
        smiles_input_viewer = st.text_input("SMILES String", value="CC(=O)Oc1ccccc1C(=O)O")  # Aspirin
        viewer_submit = st.form_submit_button("👁️ View Molecule", use_container_width=True)

    if viewer_submit:
        with st.spinner("Generating visualization..."):
            html_view, log = visualize_molecule_2d_3d(smiles_input_viewer, "Molecule")
        st.session_state.viewer_html = html_view
        st.session_state.viewer_log = log

    if 'viewer_html' in st.session_state:
        components.html(st.session_state.viewer_html, height=450)

    with st.expander("Show Logs"):
        if 'viewer_log' in st.session_state:
            st.text_area(
                "Viewer log",
                st.session_state.viewer_log,
                height=100,
                key="log_area_view",
                label_visibility="collapsed",
            )