"""Streamlit app: visualize molecules (2D/3D) and predict masked SMILES tokens.

Tab 1 renders a SMILES string as a 2D RDKit drawing and an interactive py3Dmol
3D view. Tab 2 runs a ChemBERTa-style fill-mask model over a SMILES string that
contains the tokenizer's mask token and shows the top-k completions.
"""
import importlib.util
import io
import base64
import logging

import pandas as pd
import py3Dmol
import streamlit as st
import streamlit.components.v1 as components
import torch
from rdkit import Chem
from rdkit.Chem import Draw, AllChem
from rdkit.Chem.Draw import rdMolDraw2D
from transformers import AutoModelForMaskedLM, AutoTokenizer, pipeline, BitsAndBytesConfig

# Set up logging to monitor quantization effects
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Hugging Face model id used by both the optimized and the fallback loaders.
DEFAULT_MODEL_NAME = "seyonec/PubChem10M_SMILES_BPE_450k"

# --- Page Configuration ---
# Must be the first Streamlit command executed by the script.
st.set_page_config(
    page_title="Molecule Explorer & Predictor",
    page_icon="🔬",
    layout="wide",
    initial_sidebar_state="collapsed",
)


# Custom CSS for a professional, minimalist look (adapted from drug_app.txt)
def apply_custom_styling():
    """Inject the app's custom CSS via raw HTML markdown.

    NOTE(review): the stylesheet body below is empty, so this currently injects
    nothing; restore the intended <style> block if one was meant to be here.
    """
    st.markdown(
        """ """,
        unsafe_allow_html=True
    )


apply_custom_styling()


# --- Quantization Configuration ---
def get_quantization_config():
    """Return an 8-bit bitsandbytes quantization config, or None.

    Returns:
        A ``BitsAndBytesConfig`` requesting 8-bit loading, or ``None`` when the
        ``bitsandbytes`` package is not installed or the config cannot be built,
        in which case callers load the model without quantization.
    """
    # Creating the config object does not by itself verify that bitsandbytes is
    # installed; probe for it so a missing package is detected here rather than
    # deep inside from_pretrained.
    if importlib.util.find_spec("bitsandbytes") is None:
        logger.warning("bitsandbytes not available, falling back to standard loading")
        return None
    try:
        # LLM.int8 quantization. Compute-dtype / nested ("double") quantization
        # knobs exist only for 4-bit loading (bnb_4bit_*), so none are set here.
        quantization_config = BitsAndBytesConfig(load_in_8bit=True)
        logger.info("8-bit quantization configuration loaded successfully")
        return quantization_config
    except Exception as e:
        logger.warning(f"Quantization setup failed: {e}, using standard loading")
        return None


def get_torch_dtype():
    """Return float16 when CUDA is available, otherwise float32."""
    if torch.cuda.is_available():
        return torch.float16  # Use half precision on GPU
    return torch.float32  # Keep full precision on CPU


# --- Optimized Model Loading with Streamlit Caching ---
@st.cache_resource(show_spinner="Loading molecular language model...")
def load_optimized_models():
    """Load tokenizer, masked-LM model and fill-mask pipeline with optimizations.

    On GPU the model is loaded with ``device_map="auto"`` (plus 8-bit
    quantization when available) and half precision; on CPU it is loaded in
    float32. Any failure while loading the model or building the pipeline falls
    back to ``load_standard_models``.

    Returns:
        Tuple ``(tokenizer, model, fill_mask_pipeline)``.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    torch_dtype = get_torch_dtype()
    quantization_config = get_quantization_config()
    logger.info(f"Loading models on device: {device} with dtype: {torch_dtype}")

    model_name = DEFAULT_MODEL_NAME
    fill_mask_tokenizer = AutoTokenizer.from_pretrained(model_name)

    model_kwargs = {"torch_dtype": torch_dtype}
    if torch.cuda.is_available():
        # Let accelerate place the weights; quantization is only used on GPU.
        model_kwargs["device_map"] = "auto"
        if quantization_config is not None:
            model_kwargs["quantization_config"] = quantization_config
    else:
        model_kwargs["device_map"] = None  # For CPU

    try:
        fill_mask_model = AutoModelForMaskedLM.from_pretrained(model_name, **model_kwargs)
        fill_mask_model.eval()

        # A model dispatched through a device_map already lives on its devices and
        # must not be moved again: passing `device=` to pipeline() for such a model
        # raises (and 8-bit models cannot be .to()-moved at all). Only pass an
        # explicit device when no device_map was applied.
        pipeline_kwargs = {}
        if getattr(fill_mask_model, "hf_device_map", None) is None:
            pipeline_kwargs["device"] = (
                fill_mask_model.device.index
                if fill_mask_model.device.type == "cuda"
                else -1
            )

        fill_mask_pipeline = pipeline(
            'fill-mask',
            model=fill_mask_model,
            tokenizer=fill_mask_tokenizer,
            **pipeline_kwargs,
        )
        logger.info("Models loaded successfully with optimizations")
        return fill_mask_tokenizer, fill_mask_model, fill_mask_pipeline
    except Exception as e:
        logger.error(f"Error loading optimized models: {e}")
        logger.info("Falling back to standard model loading...")
        return load_standard_models(model_name)


@st.cache_resource(show_spinner="Loading standard molecular language model...")
def load_standard_models(model_name=DEFAULT_MODEL_NAME):
    """Load tokenizer, model and fill-mask pipeline without quantization.

    Args:
        model_name: Hugging Face model id to load.

    Returns:
        Tuple ``(tokenizer, model, fill_mask_pipeline)``.
    """
    fill_mask_tokenizer = AutoTokenizer.from_pretrained(model_name)
    fill_mask_model = AutoModelForMaskedLM.from_pretrained(model_name)
    device_idx = 0 if torch.cuda.is_available() else -1
    # With an explicit device index the pipeline moves the model onto that
    # device itself, so no separate .to("cuda") call is needed.
    fill_mask_pipeline = pipeline(
        'fill-mask',
        model=fill_mask_model,
        tokenizer=fill_mask_tokenizer,
        device=device_idx,
    )
    logger.info("Standard models loaded successfully")
    return fill_mask_tokenizer, fill_mask_model, fill_mask_pipeline


# --- RDKit and Py3Dmol Visualization Functions ---
def mol_to_svg(mol, size=(400, 300)):
    """Render an RDKit molecule to an SVG string using default RDKit colors.

    Args:
        mol: RDKit ``Mol`` (falsy values yield ``None``).
        size: ``(width, height)`` of the drawing in pixels.

    Returns:
        The SVG document text, or ``None`` if ``mol`` is falsy.
    """
    if not mol:
        return None
    drawer = rdMolDraw2D.MolDraw2DSVG(*size)
    drawer.DrawMolecule(mol)
    drawer.FinishDrawing()
    return drawer.GetDrawingText()


def mol_to_sdf(mol):
    """Build a 3D MolBlock (SDF record) for an RDKit molecule.

    Hydrogens are added, 3D coordinates are embedded with ETKDGv3 (fixed seed
    for reproducibility) and then relaxed with the UFF force field.

    Args:
        mol: RDKit ``Mol`` (falsy values yield ``None``).

    Returns:
        The MolBlock string, or ``None`` if embedding/optimization fails.
    """
    if not mol:
        return None
    # Add hydrogens to the molecule so the 3D geometry is chemically sensible
    mol_with_h = Chem.AddHs(mol)
    try:
        # EmbedMolecule takes EITHER an EmbedParameters object OR individual
        # keyword options; mixing the two fails, so the seed goes on the params.
        params = AllChem.ETKDGv3()
        params.randomSeed = 42
        if AllChem.EmbedMolecule(mol_with_h, params) == -1:
            # Retry from random starting coordinates, which succeeds for some
            # molecules where the default embedding does not.
            params.useRandomCoords = True
            if AllChem.EmbedMolecule(mol_with_h, params) == -1:
                logger.error(f"Could not embed 3D coordinates for SMILES: {Chem.MolToSmiles(mol)}")
                return None
        # Optimize 3D coordinates using Universal Force Field (UFF)
        AllChem.UFFOptimizeMolecule(mol_with_h)
        return Chem.MolToMolBlock(mol_with_h)
    except Exception as e:
        logger.error(f"Error generating 3D coordinates for SMILES: {Chem.MolToSmiles(mol)} - {e}")
        return None


def visualize_molecule_3d(mol_sdf: str, width='100%', height=400):
    """Build an interactive py3Dmol viewer for an SDF string.

    Args:
        mol_sdf: MolBlock/SDF text (falsy values yield ``None``).
        width: Viewer width passed to ``py3Dmol.view``.
        height: Viewer height passed to ``py3Dmol.view``.

    Returns:
        Standalone HTML for the viewer, or ``None`` on failure (an error is
        shown in the app in that case).
    """
    if not mol_sdf:
        return None
    try:
        viewer = py3Dmol.view(width=width, height=height)
        viewer.setBackgroundColor('#1C1C1C')  # Dark background
        viewer.addModel(mol_sdf, "sdf")
        viewer.setStyle({'stick': {}, 'sphere': {'radius': 0.3}})  # Stick and Sphere representation
        viewer.zoomTo()
        return viewer._make_html()
    except Exception as e:
        st.error(f"Error generating 3D visualization: {e}")
        return None


# --- Main Streamlit Application Layout ---
st.title("🔬 Molecule Explorer & Predictor")

# Initialize session state for consistent data across reruns
if 'tokenizer' not in st.session_state:
    st.session_state.tokenizer, st.session_state.model, st.session_state.pipeline = load_optimized_models()

tokenizer = st.session_state.tokenizer
model = st.session_state.model
fill_mask_pipeline = st.session_state.pipeline

# The tokenizer's mask placeholder (presumably "<mask>" for this RoBERTa-style
# BPE tokenizer). Read it from the tokenizer rather than hard-coding it.
mask_token = tokenizer.mask_token

tab1, tab2 = st.tabs(["Molecule Viewer (2D & 3D)", "Masked SMILES Predictor"])

with tab1:
    st.header("Visualize Molecules in 2D and 3D")
    smiles_input = st.text_input("Enter SMILES string:", "CCO", help="e.g., CCO (ethanol), C1=CC=CC=C1 (benzene)")

    if st.button("View Molecule"):
        if smiles_input:
            mol = Chem.MolFromSmiles(smiles_input)
            if mol:
                st.subheader("2D Structure")
                svg = mol_to_svg(mol)
                if svg:
                    st.image(svg, use_column_width=True)
                else:
                    st.warning("Could not generate 2D image.")

                st.subheader("3D Structure (Interactive)")
                sdf_string = mol_to_sdf(mol)
                if sdf_string:
                    html_3d = visualize_molecule_3d(sdf_string)
                    if html_3d:
                        components.html(html_3d, width=700, height=500, scrolling=False)
                    else:
                        st.warning("Could not generate 3D visualization.")
                else:
                    st.warning("Could not generate 3D SDF data.")
            else:
                st.error("Invalid SMILES string. Please enter a valid chemical structure.")
        else:
            st.info("Please enter a SMILES string to view the molecule.")

with tab2:
    st.header("Masked SMILES Prediction")
    masked_smiles_input = st.text_input(
        f"Enter masked SMILES string (use `{mask_token}` for the masked token):",
        f"C1=CC=CC{mask_token}C1",
        help=f"Example: 'C1=CC=CC{mask_token}C1' (masked benzene), 'CCO{mask_token}C' (masked ether)"
    )

    top_k_predictions = st.slider("Number of predictions to show:", 1, 10, 5)

    if st.button("Predict Masked Token"):
        # Exactly one mask is required: with several masks the pipeline returns
        # one result list per mask, which the table code below does not handle.
        mask_count = masked_smiles_input.count(mask_token) if masked_smiles_input else 0
        if mask_count == 0:
            st.info(f"Please enter a masked SMILES string (e.g., `C1=CC=CC{mask_token}C1`).")
        elif mask_count > 1:
            st.info(f"Please use exactly one `{mask_token}` token; multiple masks are not supported.")
        else:
            try:
                # Perform prediction using the loaded pipeline
                predictions = fill_mask_pipeline(masked_smiles_input, top_k=top_k_predictions)

                prediction_data = []
                for pred in predictions:
                    sequence = pred['sequence']
                    mol = Chem.MolFromSmiles(sequence)
                    # Smaller image for the per-prediction grid; None if invalid SMILES
                    img_svg = mol_to_svg(mol, size=(200, 150)) if mol else None
                    prediction_data.append({
                        "Predicted Token": pred['token_str'],
                        "Full SMILES": sequence,
                        "Confidence Score": f"{pred['score']:.4f}",
                        "Structure SVG": img_svg  # Store SVG string
                    })

                df_predictions = pd.DataFrame(prediction_data)

                st.subheader("Predictions:")
                # Table view without the raw SVG markup
                display_df = df_predictions.drop(columns=["Structure SVG"])
                st.dataframe(display_df, use_container_width=True, hide_index=True)

                st.subheader("Predicted Structures:")
                # Up to 5 images per row; at least one column so st.columns never gets 0
                num_cols = max(1, min(len(df_predictions), 5))
                cols = st.columns(num_cols)

                for i, row in df_predictions.iterrows():
                    with cols[i % num_cols]:  # Distribute images into columns
                        st.markdown(f"**{row['Predicted Token']}** (Score: {row['Confidence Score']})")
                        if row['Structure SVG']:
                            st.image(row['Structure SVG'], use_column_width='auto')
                        else:
                            st.write("*(Invalid SMILES)*")

            except Exception as e:
                st.error(f"An error occurred during prediction: {e}")
                st.info(f"Please ensure your masked SMILES is valid and contains `{mask_token}`.")