# app.py import streamlit as st import torch from transformers import AutoModelForMaskedLM, AutoTokenizer, pipeline, BitsAndBytesConfig from rdkit import Chem from rdkit.Chem import Draw, AllChem import pandas as pd import py3Dmol import re import logging # Set up logging logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) # --- Page Configuration --- st.set_page_config( page_title="ChemBERTa SMILES Utilities", page_icon="๐งช", layout="wide", ) # --- Custom Styling (from drug_app) --- def apply_custom_styling(): st.markdown( """ """, unsafe_allow_html=True ) apply_custom_styling() # --- Model Loading (from mol_app) --- @st.cache_resource(show_spinner="Loading ChemBERTa model...") def load_optimized_models(): """Load models with quantization and other optimizations.""" device = "cuda" if torch.cuda.is_available() else "cpu" torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32 try: quantization_config = BitsAndBytesConfig( load_in_8bit=True, bnb_8bit_compute_dtype=torch_dtype, bnb_8bit_use_double_quant=True, ) logger.info("8-bit quantization will be used.") except ImportError: quantization_config = None logger.warning("bitsandbytes not found. Model will be loaded without quantization.") model_name = "seyonec/PubChem10M_SMILES_BPE_450k" tokenizer = AutoTokenizer.from_pretrained(model_name) model_kwargs = {"torch_dtype": torch_dtype} if quantization_config and torch.cuda.is_available(): model_kwargs["quantization_config"] = quantization_config model_kwargs["device_map"] = "auto" model = AutoModelForMaskedLM.from_pretrained(model_name, **model_kwargs) pipe = pipeline( 'fill-mask', model=model, tokenizer=tokenizer, device=0 if device == "cuda" else -1 ) logger.info("ChemBERTa model loaded successfully.") return pipe, tokenizer fill_mask_pipeline, tokenizer = load_optimized_models() # --- Core Functions --- def get_mol(smiles): """Converts SMILES to RDKit Mol object.""" mol = Chem.MolFromSmiles(smiles) if mol: try: Chem.Kekulize(mol) except: pass return mol def find_matches_one(mol, submol_smarts): """Finds all matching atoms for a SMARTS pattern in a molecule.""" if not mol or not submol_smarts: return [] submol = Chem.MolFromSmarts(submol_smarts) if not submol: return [] matches = mol.GetSubstructMatches(submol) return matches # --- Visualization Function (Adapted from drug_app) --- def visualize_molecule_2d_3d(smiles: str, name: str, substructure_smarts=""): """Generates a side-by-side 2D SVG and 3D py3Dmol HTML view for a single molecule.""" log = "" try: mol = get_mol(smiles) if not mol: return f"
Invalid SMILES for {name}
", f"โ Invalid SMILES for {name}" # --- 2D Visualization --- drawer = Draw.rdMolDraw2D.MolDraw2DSVG(450, 350) opts = drawer.drawOptions() opts.clearBackground = False opts.addStereoAnnotation = True opts.baseFontSize = 0.9 # Highlighting atom_indices_to_highlight = [] if substructure_smarts: matches = find_matches_one(mol, substructure_smarts) if matches: atom_indices_to_highlight = list(matches[0]) # Highlight first match # Dark theme colors for 2D drawing opts.backgroundColour = (0.109, 0.109, 0.109) # rgb(28,28,28) opts.symbolColour = (1, 1, 1) opts.setAtomPalette({ -1: (1, 1, 1), # Default 6: (0.9, 0.9, 0.9), # Carbon 7: (0.5, 0.5, 1), # Nitrogen 8: (1, 0.2, 0.2), # Oxygen 16: (1, 0.8, 0.2), # Sulfur }) drawer.DrawMolecule(mol, highlightAtoms=atom_indices_to_highlight) drawer.FinishDrawing() svg_2d = drawer.GetDrawingText() # Fix colors for dark theme svg_2d = svg_2d.replace('stroke="black"', 'stroke="white"') svg_2d = svg_2d.replace('fill="black"', 'fill="white"') svg_2d = re.sub(r'fill:#(000000|000);', 'fill:white;', svg_2d) # --- 3D Visualization --- mol_3d = Chem.AddHs(mol) AllChem.EmbedMolecule(mol_3d, randomSeed=42) try: AllChem.MMFFOptimizeMolecule(mol_3d) except: AllChem.ETKDGv3().Embed(mol_3d) sdf_data = Chem.MolToMolBlock(mol_3d) viewer = py3Dmol.view(width=450, height=350) viewer.setBackgroundColor('#1C1C1C') viewer.addModel(sdf_data, "sdf") viewer.setStyle({'stick': {}, 'sphere': {'scale': 0.25}}) viewer.zoomTo() html_3d = viewer._make_html() # --- Combine Views --- combined_html = f"""Error visualizing {name}: {e}
", f"โ Error visualizing {name}: {e}" # --- Main Application Logic --- def predict_and_generate_visualizations(smiles_mask, substructure_smarts): """Predicts masked SMILES and returns a dataframe and HTML for visualizations.""" if tokenizer.mask_token not in smiles_mask: st.error(f"Error: Input SMILES must contain a mask token (e.g., `{tokenizer.mask_token}`).") return pd.DataFrame(), "", "Input error." status_log = "" try: with torch.no_grad(): predictions = fill_mask_pipeline(smiles_mask, top_k=15) if torch.cuda.is_available(): torch.cuda.empty_cache() except Exception as e: st.error(f"An error occurred during model prediction: {e}") return pd.DataFrame(), "", "Prediction error." results_data = [] combined_html = "" valid_predictions_count = 0 for i, pred in enumerate(predictions): if valid_predictions_count >= 5: break predicted_smiles = pred['sequence'] score = pred['score'] mol = get_mol(predicted_smiles) if mol: valid_predictions_count += 1 results_data.append({ "Rank": valid_predictions_count, "Predicted SMILES": predicted_smiles, "Score": f"{score:.4f}" }) html_view, log = visualize_molecule_2d_3d( predicted_smiles, f"Prediction #{valid_predictions_count}", substructure_smarts ) combined_html += html_view status_log += log df_results = pd.DataFrame(results_data) status_log += f"\nFound {valid_predictions_count} valid molecules from top predictions." return df_results, combined_html, status_log # --- Streamlit Interface --- st.title("๐งช ChemBERTa SMILES Utilities") st.markdown(""" Enter a SMILES string with a `