Spaces:
				
			
			
	
			
			
		Sleeping
		
	
	
	
			
			
	
	
	
	
		
		
		Sleeping
		
	Update app.py
Browse files
    	
        app.py
    CHANGED
    
    | @@ -1,73 +1,137 @@ | |
| 1 | 
            -
            # app.py
         | 
| 2 | 
             
            import streamlit as st
         | 
| 3 | 
             
            import torch
         | 
| 4 | 
            -
            from transformers import AutoModelForMaskedLM, AutoTokenizer, pipeline
         | 
| 5 | 
             
            from rdkit import Chem
         | 
| 6 | 
            -
            from rdkit.Chem import Draw, AllChem
         | 
| 7 | 
             
            from rdkit.Chem.Draw import MolToImage
         | 
| 8 | 
             
            import pandas as pd
         | 
| 9 | 
             
            import io
         | 
| 10 | 
             
            import base64
         | 
| 11 | 
             
            import logging
         | 
| 12 | 
             
            import py3Dmol
         | 
|  | |
| 13 |  | 
| 14 | 
            -
            # Set up logging to monitor effects
         | 
| 15 | 
             
            logging.basicConfig(level=logging.INFO)
         | 
| 16 | 
             
            logger = logging.getLogger(__name__)
         | 
| 17 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 18 | 
             
            # --- Optimized Model Loading ---
         | 
| 19 | 
             
            @st.cache_resource
         | 
| 20 | 
             
            def load_optimized_models():
         | 
| 21 | 
            -
                """Load models  | 
| 22 | 
            -
                device = " | 
| 23 | 
            -
                torch_dtype =  | 
|  | |
| 24 |  | 
| 25 | 
             
                logger.info(f"Loading models on device: {device} with dtype: {torch_dtype}")
         | 
| 26 |  | 
| 27 | 
             
                # Model names
         | 
| 28 | 
             
                model_name = "seyonec/PubChem10M_SMILES_BPE_450k"
         | 
| 29 |  | 
| 30 | 
            -
                # Load tokenizer
         | 
| 31 | 
             
                fill_mask_tokenizer = AutoTokenizer.from_pretrained(model_name)
         | 
| 32 |  | 
| 33 | 
            -
                # Load model with  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 34 | 
             
                try:
         | 
|  | |
| 35 | 
             
                    fill_mask_model = AutoModelForMaskedLM.from_pretrained(
         | 
| 36 | 
             
                        model_name,
         | 
| 37 | 
            -
                         | 
| 38 | 
            -
                        device_map=None # No device mapping for plain CPU
         | 
| 39 | 
             
                    )
         | 
| 40 |  | 
| 41 | 
             
                    # Set model to evaluation mode for inference
         | 
| 42 | 
             
                    fill_mask_model.eval()
         | 
| 43 |  | 
| 44 | 
            -
                    # Create pipeline | 
|  | |
|  | |
|  | |
| 45 | 
             
                    fill_mask_pipeline = pipeline(
         | 
| 46 | 
             
                        'fill-mask',
         | 
| 47 | 
             
                        model=fill_mask_model,
         | 
| 48 | 
             
                        tokenizer=fill_mask_tokenizer,
         | 
| 49 | 
            -
                        device | 
| 50 | 
             
                    )
         | 
| 51 |  | 
| 52 | 
            -
                    logger.info("Models loaded successfully  | 
| 53 | 
             
                    return fill_mask_tokenizer, fill_mask_model, fill_mask_pipeline
         | 
| 54 |  | 
| 55 | 
             
                except Exception as e:
         | 
| 56 | 
            -
                    logger.error(f"Error loading models | 
| 57 | 
            -
                     | 
| 58 | 
            -
                     | 
| 59 | 
            -
                     | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 60 |  | 
| 61 | 
            -
             | 
| 62 | 
            -
             | 
|  | |
|  | |
| 63 |  | 
| 64 | 
            -
            # --- Memory Management Utilities  | 
| 65 | 
             
            def clear_gpu_cache():
         | 
| 66 | 
            -
                """ | 
| 67 | 
             
                if torch.cuda.is_available():
         | 
| 68 | 
             
                    torch.cuda.empty_cache()
         | 
| 69 |  | 
| 70 | 
            -
            # --- Helper Functions  | 
| 71 | 
             
            def get_mol(smiles):
         | 
| 72 | 
             
                """Converts SMILES to RDKit Mol object and Kekulizes it."""
         | 
| 73 | 
             
                mol = Chem.MolFromSmiles(smiles)
         | 
| @@ -109,50 +173,66 @@ def get_image_with_highlight(mol, atomset=None, size=(300, 300)): | |
| 109 | 
             
                                 highlightAtomColors={i: highlight_color for i in valid_atomset} if valid_atomset else {})
         | 
| 110 | 
             
                return img
         | 
| 111 |  | 
| 112 | 
            -
            def  | 
| 113 | 
            -
                """ | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 114 | 
             
                if mol is None:
         | 
| 115 | 
             
                    return None
         | 
| 116 | 
            -
                # Add 3D coordinates if not present
         | 
| 117 | 
            -
                AllChem.EmbedMolecule(mol, AllChem.ETKDG())
         | 
| 118 | 
            -
                AllChem.UFFOptimizeMolecule(mol)
         | 
| 119 | 
            -
                return Chem.MolToMolBlock(mol)
         | 
| 120 | 
            -
             | 
| 121 | 
            -
            def render_mol_3d(sdf_string, width=300, height=300):
         | 
| 122 | 
            -
                """Renders a 3D molecule using py3Dmol."""
         | 
| 123 | 
            -
                if sdf_string is None:
         | 
| 124 | 
            -
                    return ""
         | 
| 125 |  | 
| 126 | 
            -
                 | 
| 127 | 
            -
                 | 
| 128 | 
            -
                 | 
| 129 | 
            -
                 | 
|  | |
|  | |
|  | |
| 130 | 
             
                viewer.zoomTo()
         | 
| 131 | 
            -
                 | 
| 132 | 
            -
                return viewer | 
| 133 |  | 
| 134 | 
             
            # --- Streamlit Interface Functions ---
         | 
| 135 |  | 
| 136 | 
             
            def predict_and_visualize_masked_smiles(smiles_mask, substructure_smarts_highlight="CC=CC"):
         | 
| 137 | 
             
                """
         | 
| 138 | 
             
                Predicts masked tokens in a SMILES string, shows scores, and visualizes molecules.
         | 
| 139 | 
            -
                Returns 5 image paths and a status message.
         | 
| 140 | 
             
                """
         | 
|  | |
|  | |
|  | |
| 141 | 
             
                if fill_mask_tokenizer.mask_token not in smiles_mask:
         | 
| 142 | 
             
                    st.error("Error: Input SMILES must contain a mask token (e.g., <mask>).")
         | 
| 143 | 
            -
                    return | 
| 144 |  | 
| 145 | 
             
                try:
         | 
|  | |
| 146 | 
             
                    with torch.no_grad():
         | 
| 147 | 
             
                        predictions = fill_mask_pipeline(smiles_mask, top_k=10)
         | 
| 148 | 
             
                except Exception as e:
         | 
| 149 | 
             
                    clear_gpu_cache()
         | 
| 150 | 
             
                    st.error(f"Error during prediction: {str(e)}")
         | 
| 151 | 
            -
                    return | 
| 152 |  | 
| 153 | 
             
                results_data = []
         | 
| 154 | 
            -
                 | 
| 155 | 
            -
                image_3d_list = []
         | 
| 156 | 
             
                valid_predictions_count = 0
         | 
| 157 |  | 
| 158 | 
             
                for pred in predictions:
         | 
| @@ -165,157 +245,129 @@ def predict_and_visualize_masked_smiles(smiles_mask, substructure_smarts_highlig | |
| 165 | 
             
                    mol = get_mol(predicted_smiles)
         | 
| 166 | 
             
                    if mol:
         | 
| 167 | 
             
                        results_data.append({"Predicted SMILES": predicted_smiles, "Score": f"{score:.4f}"})
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 168 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 169 | 
             
                        atom_matches_indices = []
         | 
| 170 | 
             
                        if substructure_smarts_highlight:
         | 
| 171 | 
             
                            matches = find_matches_one(mol, substructure_smarts_highlight)
         | 
| 172 | 
             
                            if matches:
         | 
| 173 | 
            -
                                atom_matches_indices = list(matches[0]) | 
| 174 | 
            -
             | 
| 175 | 
            -
                        img_2d = get_image_with_highlight(mol, atomset=atom_matches_indices)
         | 
| 176 | 
            -
                        image_2d_list.append(img_2d)
         | 
| 177 | 
            -
                        
         | 
| 178 | 
            -
                        # For 3D, we need an SDF string
         | 
| 179 | 
            -
                        sdf_string = mol_to_sdf_string(mol)
         | 
| 180 | 
            -
                        img_3d_html = render_mol_3d(sdf_string, width=300, height=300)
         | 
| 181 | 
            -
                        image_3d_list.append(img_3d_html)
         | 
| 182 |  | 
| 183 | 
            -
                         | 
| 184 | 
            -
             | 
| 185 | 
            -
             | 
| 186 | 
            -
             | 
| 187 | 
            -
                     | 
| 188 | 
            -
             | 
| 189 | 
            -
             | 
| 190 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 191 |  | 
|  | |
| 192 | 
             
                clear_gpu_cache()
         | 
|  | |
| 193 |  | 
| 194 | 
            -
             | 
| 195 | 
            -
                return df_results, image_2d_list, image_3d_list, status_message
         | 
| 196 | 
            -
             | 
| 197 | 
            -
             | 
| 198 | 
            -
            def display_molecule_with_3d(smiles_string):
         | 
| 199 | 
             
                """
         | 
| 200 | 
            -
                Displays  | 
| 201 | 
             
                """
         | 
| 202 | 
             
                if not smiles_string:
         | 
| 203 | 
            -
                     | 
|  | |
|  | |
| 204 | 
             
                mol = get_mol(smiles_string)
         | 
| 205 | 
             
                if mol is None:
         | 
| 206 | 
            -
                     | 
| 207 | 
            -
             | 
| 208 | 
            -
                img_2d = MolToImage(mol, size=(400, 400), fitImage=True)
         | 
| 209 |  | 
| 210 | 
            -
                 | 
| 211 | 
            -
                img_3d_html = render_mol_3d(sdf_string, width=400, height=400)
         | 
| 212 |  | 
| 213 | 
            -
                 | 
| 214 | 
            -
             | 
| 215 | 
            -
             | 
| 216 | 
            -
            # --- Streamlit UI Definition ---
         | 
| 217 | 
            -
             | 
| 218 | 
            -
            # Set wide mode and background color
         | 
| 219 | 
            -
            st.set_page_config(layout="wide")
         | 
| 220 | 
            -
             | 
| 221 | 
            -
            st.markdown(
         | 
| 222 | 
            -
                """
         | 
| 223 | 
            -
                <style>
         | 
| 224 | 
            -
                .stApp {
         | 
| 225 | 
            -
                    background-color: rgb(28,28,28);
         | 
| 226 | 
            -
                    color: white; /* Ensure text is visible on dark background */
         | 
| 227 | 
            -
                }
         | 
| 228 | 
            -
                .stDataFrame {
         | 
| 229 | 
            -
                    color: black; /* Default DataFrame text color */
         | 
| 230 | 
            -
                }
         | 
| 231 | 
            -
                h1, h2, h3, h4, h5, h6, .stMarkdown {
         | 
| 232 | 
            -
                    color: white;
         | 
| 233 | 
            -
                }
         | 
| 234 | 
            -
                .css-1d391kg, .css-1dp5dn1 { /* Target Streamlit's main content and sidebar */
         | 
| 235 | 
            -
                    color: white;
         | 
| 236 | 
            -
                }
         | 
| 237 | 
            -
                .streamlit-expanderContent {
         | 
| 238 | 
            -
                    background-color: rgb(40,40,40); /* Slightly lighter background for expanders */
         | 
| 239 | 
            -
                    border-radius: 10px;
         | 
| 240 | 
            -
                    padding: 10px;
         | 
| 241 | 
            -
                }
         | 
| 242 | 
            -
                /* Style for text inputs and buttons */
         | 
| 243 | 
            -
                .stTextInput>div>div>input {
         | 
| 244 | 
            -
                    background-color: rgb(50,50,50);
         | 
| 245 | 
            -
                    color: white;
         | 
| 246 | 
            -
                    border-radius: 5px;
         | 
| 247 | 
            -
                    border: 1px solid rgb(70,70,70);
         | 
| 248 | 
            -
                }
         | 
| 249 | 
            -
                .stButton>button {
         | 
| 250 | 
            -
                    background-color: rgb(0,128,255); /* Blue button */
         | 
| 251 | 
            -
                    color: white;
         | 
| 252 | 
            -
                    border-radius: 8px;
         | 
| 253 | 
            -
                    padding: 10px 20px;
         | 
| 254 | 
            -
                    border: none;
         | 
| 255 | 
            -
                    transition: background-color 0.3s ease;
         | 
| 256 | 
            -
                }
         | 
| 257 | 
            -
                .stButton>button:hover {
         | 
| 258 | 
            -
                    background-color: rgb(0,100,200);
         | 
| 259 | 
            -
                }
         | 
| 260 | 
            -
                </style>
         | 
| 261 | 
            -
                """,
         | 
| 262 | 
            -
                unsafe_allow_html=True
         | 
| 263 | 
            -
            )
         | 
| 264 | 
            -
             | 
| 265 | 
            -
             | 
| 266 | 
            -
            st.title("ChemBERTa SMILES Utilities Dashboard")
         | 
| 267 | 
            -
             | 
| 268 | 
            -
            tab1, tab2 = st.tabs(["Masked SMILES Prediction", "Molecule Viewer"])
         | 
| 269 | 
            -
             | 
| 270 | 
            -
            with tab1:
         | 
| 271 | 
            -
                st.markdown("Enter a SMILES string with a `<mask>` token (e.g., `C1=CC=CC<mask>C1`) to predict possible completions.")
         | 
| 272 |  | 
| 273 | 
            -
                col1, col2 = st.columns([2, 1])
         | 
| 274 | 
             
                with col1:
         | 
| 275 | 
            -
                     | 
|  | |
|  | |
|  | |
| 276 | 
             
                with col2:
         | 
| 277 | 
            -
                     | 
| 278 | 
            -
             | 
| 279 | 
            -
             | 
| 280 | 
            -
             | 
| 281 | 
            -
                         | 
| 282 | 
            -
                             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 283 | 
             
                        )
         | 
| 284 | 
            -
             | 
| 285 | 
            -
                        
         | 
| 286 | 
            -
             | 
| 287 | 
            -
                             | 
| 288 | 
            -
             | 
| 289 | 
            -
             | 
| 290 | 
            -
             | 
| 291 | 
            -
             | 
| 292 | 
            -
             | 
| 293 | 
            -
                                    st.markdown(f"**Prediction {i+1}**")
         | 
| 294 | 
            -
                                    cols_img = st.columns(2)
         | 
| 295 | 
            -
                                    with cols_img[0]:
         | 
| 296 | 
            -
                                        st.image(img_2d_list[i], caption=f"2D Prediction {i+1}", use_column_width=True)
         | 
| 297 | 
            -
                                    with cols_img[1]:
         | 
| 298 | 
            -
                                        st.components.v1.html(img_3d_list[i], height=300)
         | 
| 299 | 
            -
                                else:
         | 
| 300 | 
            -
                                    if i < len(df_predictions): # Only show 'No visualization' if there was a prediction attempt
         | 
| 301 | 
            -
                                         st.markdown(f"**Prediction {i+1}**: No visualization available (invalid SMILES or error).")
         | 
| 302 | 
            -
             | 
| 303 | 
            -
             | 
| 304 | 
            -
            with tab2:
         | 
| 305 | 
            -
                st.markdown("Enter a SMILES string to display its 2D and 3D structure.")
         | 
| 306 | 
            -
                smiles_input_viewer = st.text_input("SMILES String", value="C1=CC=CC=C1", key="viewer_smiles_input")
         | 
| 307 |  | 
| 308 | 
            -
                 | 
| 309 | 
            -
                     | 
| 310 | 
            -
             | 
| 311 | 
            -
             | 
| 312 | 
            -
             | 
| 313 | 
            -
                         | 
| 314 | 
            -
             | 
| 315 | 
            -
             | 
| 316 | 
            -
             | 
| 317 | 
            -
             | 
| 318 | 
            -
             | 
| 319 | 
            -
             | 
| 320 | 
            -
                            st.warning("Could not display molecule. Please check the SMILES string.")
         | 
| 321 |  | 
|  | |
|  | 
|  | |
|  | |
| 1 | 
             
            import streamlit as st
         | 
| 2 | 
             
            import torch
         | 
| 3 | 
            +
            from transformers import AutoModelForMaskedLM, AutoTokenizer, pipeline, BitsAndBytesConfig
         | 
| 4 | 
             
            from rdkit import Chem
         | 
| 5 | 
            +
            from rdkit.Chem import Draw, rdFMCS, AllChem
         | 
| 6 | 
             
            from rdkit.Chem.Draw import MolToImage
         | 
| 7 | 
             
            import pandas as pd
         | 
| 8 | 
             
            import io
         | 
| 9 | 
             
            import base64
         | 
| 10 | 
             
            import logging
         | 
| 11 | 
             
            import py3Dmol
         | 
| 12 | 
            +
            from stmol import showmol
         | 
| 13 |  | 
| 14 | 
            +
            # Set up logging to monitor quantization effects
         | 
| 15 | 
             
            logging.basicConfig(level=logging.INFO)
         | 
| 16 | 
             
            logger = logging.getLogger(__name__)
         | 
| 17 |  | 
| 18 | 
            +
            # Page configuration
         | 
| 19 | 
            +
            st.set_page_config(
         | 
| 20 | 
            +
                page_title="ChemBERTa SMILES Utilities Dashboard",
         | 
| 21 | 
            +
                page_icon="🧪",
         | 
| 22 | 
            +
                layout="wide"
         | 
| 23 | 
            +
            )
         | 
| 24 | 
            +
             | 
| 25 | 
            +
            # --- Quantization Configuration ---
         | 
| 26 | 
            +
            @st.cache_resource
         | 
| 27 | 
            +
            def get_quantization_config():
         | 
| 28 | 
            +
                """
         | 
| 29 | 
            +
                Configure 8-bit quantization for model optimization.
         | 
| 30 | 
            +
                Falls back gracefully if bitsandbytes is not available.
         | 
| 31 | 
            +
                """
         | 
| 32 | 
            +
                try:
         | 
| 33 | 
            +
                    # 8-bit quantization configuration - good balance of speed and quality
         | 
| 34 | 
            +
                    quantization_config = BitsAndBytesConfig(
         | 
| 35 | 
            +
                        load_in_8bit=True,
         | 
| 36 | 
            +
                        bnb_8bit_compute_dtype=torch.float16,
         | 
| 37 | 
            +
                        bnb_8bit_use_double_quant=True,  # Nested quantization for better compression
         | 
| 38 | 
            +
                    )
         | 
| 39 | 
            +
                    logger.info("8-bit quantization configuration loaded successfully")
         | 
| 40 | 
            +
                    return quantization_config
         | 
| 41 | 
            +
                except ImportError:
         | 
| 42 | 
            +
                    logger.warning("bitsandbytes not available, falling back to standard loading")
         | 
| 43 | 
            +
                    return None
         | 
| 44 | 
            +
                except Exception as e:
         | 
| 45 | 
            +
                    logger.warning(f"Quantization setup failed: {e}, using standard loading")
         | 
| 46 | 
            +
                    return None
         | 
| 47 | 
            +
             | 
| 48 | 
            +
            def get_torch_dtype():
         | 
| 49 | 
            +
                """Get appropriate torch dtype based on available hardware."""
         | 
| 50 | 
            +
                if torch.cuda.is_available():
         | 
| 51 | 
            +
                    return torch.float16  # Use half precision on GPU
         | 
| 52 | 
            +
                else:
         | 
| 53 | 
            +
                    return torch.float32  # Keep full precision on CPU
         | 
| 54 | 
            +
             | 
| 55 | 
             
            # --- Optimized Model Loading ---
         | 
| 56 | 
             
            @st.cache_resource
         | 
| 57 | 
             
            def load_optimized_models():
         | 
| 58 | 
            +
                """Load models with quantization and other optimizations."""
         | 
| 59 | 
            +
                device = "cuda" if torch.cuda.is_available() else "cpu"
         | 
| 60 | 
            +
                torch_dtype = get_torch_dtype()
         | 
| 61 | 
            +
                quantization_config = get_quantization_config()
         | 
| 62 |  | 
| 63 | 
             
                logger.info(f"Loading models on device: {device} with dtype: {torch_dtype}")
         | 
| 64 |  | 
| 65 | 
             
                # Model names
         | 
| 66 | 
             
                model_name = "seyonec/PubChem10M_SMILES_BPE_450k"
         | 
| 67 |  | 
| 68 | 
            +
                # Load tokenizer (doesn't need quantization)
         | 
| 69 | 
             
                fill_mask_tokenizer = AutoTokenizer.from_pretrained(model_name)
         | 
| 70 |  | 
| 71 | 
            +
                # Load model with quantization if available
         | 
| 72 | 
            +
                model_kwargs = {
         | 
| 73 | 
            +
                    "torch_dtype": torch_dtype,
         | 
| 74 | 
            +
                }
         | 
| 75 | 
            +
             | 
| 76 | 
            +
                if quantization_config is not None and torch.cuda.is_available(): # Quantization typically for GPU
         | 
| 77 | 
            +
                    model_kwargs["quantization_config"] = quantization_config
         | 
| 78 | 
            +
                    # device_map="auto" is often used with bitsandbytes for automatic distribution
         | 
| 79 | 
            +
                    model_kwargs["device_map"] = "auto"
         | 
| 80 | 
            +
                elif torch.cuda.is_available():
         | 
| 81 | 
            +
                    model_kwargs["device_map"] = "auto" # For non-quantized GPU loading
         | 
| 82 | 
            +
                else:
         | 
| 83 | 
            +
                    model_kwargs["device_map"] = None # For CPU
         | 
| 84 | 
            +
             | 
| 85 | 
             
                try:
         | 
| 86 | 
            +
                    # Masked LM Model
         | 
| 87 | 
             
                    fill_mask_model = AutoModelForMaskedLM.from_pretrained(
         | 
| 88 | 
             
                        model_name,
         | 
| 89 | 
            +
                        **model_kwargs
         | 
|  | |
| 90 | 
             
                    )
         | 
| 91 |  | 
| 92 | 
             
                    # Set model to evaluation mode for inference
         | 
| 93 | 
             
                    fill_mask_model.eval()
         | 
| 94 |  | 
| 95 | 
            +
                    # Create optimized pipeline
         | 
| 96 | 
            +
                    # Let pipeline infer device from model if possible, or set based on model's device
         | 
| 97 | 
            +
                    pipeline_device = fill_mask_model.device.index if hasattr(fill_mask_model.device, 'type') and fill_mask_model.device.type == "cuda" else -1
         | 
| 98 | 
            +
             | 
| 99 | 
             
                    fill_mask_pipeline = pipeline(
         | 
| 100 | 
             
                        'fill-mask',
         | 
| 101 | 
             
                        model=fill_mask_model,
         | 
| 102 | 
             
                        tokenizer=fill_mask_tokenizer,
         | 
| 103 | 
            +
                        device=pipeline_device, # Use model's device
         | 
| 104 | 
             
                    )
         | 
| 105 |  | 
| 106 | 
            +
                    logger.info("Models loaded successfully with optimizations")
         | 
| 107 | 
             
                    return fill_mask_tokenizer, fill_mask_model, fill_mask_pipeline
         | 
| 108 |  | 
| 109 | 
             
                except Exception as e:
         | 
| 110 | 
            +
                    logger.error(f"Error loading optimized models: {e}")
         | 
| 111 | 
            +
                    # Fallback to standard loading
         | 
| 112 | 
            +
                    logger.info("Falling back to standard model loading...")
         | 
| 113 | 
            +
                    return load_standard_models(model_name)
         | 
| 114 | 
            +
             | 
| 115 | 
            +
            def load_standard_models(model_name):
         | 
| 116 | 
            +
                """Fallback standard model loading without quantization."""
         | 
| 117 | 
            +
                fill_mask_tokenizer = AutoTokenizer.from_pretrained(model_name)
         | 
| 118 | 
            +
                fill_mask_model = AutoModelForMaskedLM.from_pretrained(model_name)
         | 
| 119 | 
            +
                # Determine device for standard loading
         | 
| 120 | 
            +
                device_idx = 0 if torch.cuda.is_available() else -1
         | 
| 121 | 
            +
                fill_mask_pipeline = pipeline('fill-mask', model=fill_mask_model, tokenizer=fill_mask_tokenizer, device=device_idx)
         | 
| 122 |  | 
| 123 | 
            +
                if torch.cuda.is_available():
         | 
| 124 | 
            +
                    fill_mask_model.to("cuda")
         | 
| 125 | 
            +
             | 
| 126 | 
            +
                return fill_mask_tokenizer, fill_mask_model, fill_mask_pipeline
         | 
| 127 |  | 
| 128 | 
            +
            # --- Memory Management Utilities ---
         | 
| 129 | 
             
            def clear_gpu_cache():
         | 
| 130 | 
            +
                """Clear CUDA cache to free up memory."""
         | 
| 131 | 
             
                if torch.cuda.is_available():
         | 
| 132 | 
             
                    torch.cuda.empty_cache()
         | 
| 133 |  | 
| 134 | 
            +
            # --- Helper Functions ---
         | 
| 135 | 
             
            def get_mol(smiles):
         | 
| 136 | 
             
                """Converts SMILES to RDKit Mol object and Kekulizes it."""
         | 
| 137 | 
             
                mol = Chem.MolFromSmiles(smiles)
         | 
|  | |
| 173 | 
             
                                 highlightAtomColors={i: highlight_color for i in valid_atomset} if valid_atomset else {})
         | 
| 174 | 
             
                return img
         | 
| 175 |  | 
| 176 | 
            +
            def generate_3d_structure(mol):
         | 
| 177 | 
            +
                """Generate 3D coordinates for a molecule."""
         | 
| 178 | 
            +
                if mol is None:
         | 
| 179 | 
            +
                    return None
         | 
| 180 | 
            +
                
         | 
| 181 | 
            +
                # Create a copy to avoid modifying the original
         | 
| 182 | 
            +
                mol_3d = Chem.Mol(mol)
         | 
| 183 | 
            +
                
         | 
| 184 | 
            +
                # Add hydrogens
         | 
| 185 | 
            +
                mol_3d = Chem.AddHs(mol_3d)
         | 
| 186 | 
            +
                
         | 
| 187 | 
            +
                # Generate 3D coordinates
         | 
| 188 | 
            +
                try:
         | 
| 189 | 
            +
                    AllChem.EmbedMolecule(mol_3d, randomSeed=42)
         | 
| 190 | 
            +
                    AllChem.UFFOptimizeMolecule(mol_3d)
         | 
| 191 | 
            +
                    return mol_3d
         | 
| 192 | 
            +
                except:
         | 
| 193 | 
            +
                    # If 3D generation fails, return None
         | 
| 194 | 
            +
                    return None
         | 
| 195 | 
            +
             | 
| 196 | 
            +
            def mol_to_3d_html(mol):
         | 
| 197 | 
            +
                """Convert molecule to 3D HTML representation using py3Dmol."""
         | 
| 198 | 
             
                if mol is None:
         | 
| 199 | 
             
                    return None
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 200 |  | 
| 201 | 
            +
                # Generate SDF string
         | 
| 202 | 
            +
                sdf = Chem.MolToMolBlock(mol)
         | 
| 203 | 
            +
                
         | 
| 204 | 
            +
                # Create 3D viewer
         | 
| 205 | 
            +
                viewer = py3Dmol.view(width=400, height=400)
         | 
| 206 | 
            +
                viewer.addModel(sdf, 'sdf')
         | 
| 207 | 
            +
                viewer.setStyle({'stick': {}})
         | 
| 208 | 
             
                viewer.zoomTo()
         | 
| 209 | 
            +
                
         | 
| 210 | 
            +
                return viewer
         | 
| 211 |  | 
| 212 | 
             
            # --- Streamlit Interface Functions ---
         | 
| 213 |  | 
| 214 | 
             
            def predict_and_visualize_masked_smiles(smiles_mask, substructure_smarts_highlight="CC=CC"):
         | 
| 215 | 
             
                """
         | 
| 216 | 
             
                Predicts masked tokens in a SMILES string, shows scores, and visualizes molecules.
         | 
|  | |
| 217 | 
             
                """
         | 
| 218 | 
            +
                # Load models
         | 
| 219 | 
            +
                fill_mask_tokenizer, fill_mask_model, fill_mask_pipeline = load_optimized_models()
         | 
| 220 | 
            +
                
         | 
| 221 | 
             
                if fill_mask_tokenizer.mask_token not in smiles_mask:
         | 
| 222 | 
             
                    st.error("Error: Input SMILES must contain a mask token (e.g., <mask>).")
         | 
| 223 | 
            +
                    return
         | 
| 224 |  | 
| 225 | 
             
                try:
         | 
| 226 | 
            +
                    # Use torch.no_grad() for inference to save memory
         | 
| 227 | 
             
                    with torch.no_grad():
         | 
| 228 | 
             
                        predictions = fill_mask_pipeline(smiles_mask, top_k=10)
         | 
| 229 | 
             
                except Exception as e:
         | 
| 230 | 
             
                    clear_gpu_cache()
         | 
| 231 | 
             
                    st.error(f"Error during prediction: {str(e)}")
         | 
| 232 | 
            +
                    return
         | 
| 233 |  | 
| 234 | 
             
                results_data = []
         | 
| 235 | 
            +
                valid_predictions = []
         | 
|  | |
| 236 | 
             
                valid_predictions_count = 0
         | 
| 237 |  | 
| 238 | 
             
                for pred in predictions:
         | 
|  | |
| 245 | 
             
                    mol = get_mol(predicted_smiles)
         | 
| 246 | 
             
                    if mol:
         | 
| 247 | 
             
                        results_data.append({"Predicted SMILES": predicted_smiles, "Score": f"{score:.4f}"})
         | 
| 248 | 
            +
                        valid_predictions.append((mol, predicted_smiles, score))
         | 
| 249 | 
            +
                        valid_predictions_count += 1
         | 
| 250 | 
            +
             | 
| 251 | 
            +
                if valid_predictions_count == 0:
         | 
| 252 | 
            +
                    st.warning("No valid molecules found for top predictions.")
         | 
| 253 | 
            +
                    return
         | 
| 254 |  | 
| 255 | 
            +
                # Display results table
         | 
| 256 | 
            +
                df_results = pd.DataFrame(results_data)
         | 
| 257 | 
            +
                st.subheader("Top Predictions & Scores")
         | 
| 258 | 
            +
                st.dataframe(df_results, use_container_width=True)
         | 
| 259 | 
            +
             | 
| 260 | 
            +
                # Display molecule visualizations
         | 
| 261 | 
            +
                st.subheader("Predicted Molecule Visualizations")
         | 
| 262 | 
            +
                
         | 
| 263 | 
            +
                for i, (mol, smiles, score) in enumerate(valid_predictions):
         | 
| 264 | 
            +
                    st.write(f"**Prediction {i+1}:** {smiles} (Score: {score:.4f})")
         | 
| 265 | 
            +
                    
         | 
| 266 | 
            +
                    col1, col2 = st.columns(2)
         | 
| 267 | 
            +
                    
         | 
| 268 | 
            +
                    with col1:
         | 
| 269 | 
            +
                        st.write("**2D Structure:**")
         | 
| 270 | 
             
                        atom_matches_indices = []
         | 
| 271 | 
             
                        if substructure_smarts_highlight:
         | 
| 272 | 
             
                            matches = find_matches_one(mol, substructure_smarts_highlight)
         | 
| 273 | 
             
                            if matches:
         | 
| 274 | 
            +
                                atom_matches_indices = list(matches[0])
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 275 |  | 
| 276 | 
            +
                        img_2d = get_image_with_highlight(mol, atomset=atom_matches_indices)
         | 
| 277 | 
            +
                        if img_2d:
         | 
| 278 | 
            +
                            st.image(img_2d, use_column_width=True)
         | 
| 279 | 
            +
                    
         | 
| 280 | 
            +
                    with col2:
         | 
| 281 | 
            +
                        st.write("**3D Structure:**")
         | 
| 282 | 
            +
                        mol_3d = generate_3d_structure(mol)
         | 
| 283 | 
            +
                        if mol_3d:
         | 
| 284 | 
            +
                            viewer_3d = mol_to_3d_html(mol_3d)
         | 
| 285 | 
            +
                            if viewer_3d:
         | 
| 286 | 
            +
                                showmol(viewer_3d, height=400, width=400)
         | 
| 287 | 
            +
                        else:
         | 
| 288 | 
            +
                            st.write("3D structure generation failed for this molecule.")
         | 
| 289 | 
            +
                    
         | 
| 290 | 
            +
                    st.divider()
         | 
| 291 |  | 
| 292 | 
            +
                # Clear cache after inference
         | 
| 293 | 
             
                clear_gpu_cache()
         | 
| 294 | 
            +
                st.success("Prediction successful!")
         | 
| 295 |  | 
| 296 | 
            +
            def display_molecule_image(smiles_string):
         | 
|  | |
|  | |
|  | |
|  | |
| 297 | 
             
                """
         | 
| 298 | 
            +
                Displays both 2D and 3D images of a molecule from its SMILES string.
         | 
| 299 | 
             
                """
         | 
| 300 | 
             
                if not smiles_string:
         | 
| 301 | 
            +
                    st.error("Please enter a SMILES string.")
         | 
| 302 | 
            +
                    return
         | 
| 303 | 
            +
                
         | 
| 304 | 
             
                mol = get_mol(smiles_string)
         | 
| 305 | 
             
                if mol is None:
         | 
| 306 | 
            +
                    st.error("Invalid SMILES string.")
         | 
| 307 | 
            +
                    return
         | 
|  | |
| 308 |  | 
| 309 | 
            +
                st.success("Molecule displayed successfully!")
         | 
|  | |
| 310 |  | 
| 311 | 
            +
                col1, col2 = st.columns(2)
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 312 |  | 
|  | |
| 313 | 
             
                with col1:
         | 
| 314 | 
            +
                    st.subheader("2D Structure")
         | 
| 315 | 
            +
                    img_2d = MolToImage(mol, size=(400, 400), fitImage=True)
         | 
| 316 | 
            +
                    st.image(img_2d, use_column_width=True)
         | 
| 317 | 
            +
                
         | 
| 318 | 
             
                with col2:
         | 
| 319 | 
            +
                    st.subheader("3D Structure")
         | 
| 320 | 
            +
                    mol_3d = generate_3d_structure(mol)
         | 
| 321 | 
            +
                    if mol_3d:
         | 
| 322 | 
            +
                        viewer_3d = mol_to_3d_html(mol_3d)
         | 
| 323 | 
            +
                        if viewer_3d:
         | 
| 324 | 
            +
                            showmol(viewer_3d, height=400, width=400)
         | 
| 325 | 
            +
                    else:
         | 
| 326 | 
            +
                        st.write("3D structure generation failed for this molecule.")
         | 
| 327 | 
            +
             | 
| 328 | 
            +
            # --- Main Streamlit App ---
         | 
| 329 | 
            +
            def main():
         | 
| 330 | 
            +
                st.title("🧪 ChemBERTa SMILES Utilities Dashboard")
         | 
| 331 | 
            +
                
         | 
| 332 | 
            +
                # Sidebar for navigation
         | 
| 333 | 
            +
                st.sidebar.title("Navigation")
         | 
| 334 | 
            +
                tab_selection = st.sidebar.selectbox(
         | 
| 335 | 
            +
                    "Choose a tool:",
         | 
| 336 | 
            +
                    ["Masked SMILES Prediction", "Molecule Viewer"]
         | 
| 337 | 
            +
                )
         | 
| 338 | 
            +
                
         | 
| 339 | 
            +
                if tab_selection == "Masked SMILES Prediction":
         | 
| 340 | 
            +
                    st.header("Masked SMILES Prediction")
         | 
| 341 | 
            +
                    st.markdown("Enter a SMILES string with a `<mask>` token (e.g., `C1=CC=CC<mask>C1`) to predict possible completions.")
         | 
| 342 | 
            +
                    
         | 
| 343 | 
            +
                    col1, col2 = st.columns(2)
         | 
| 344 | 
            +
                    with col1:
         | 
| 345 | 
            +
                        smiles_input_masked = st.text_input(
         | 
| 346 | 
            +
                            "SMILES String with Mask", 
         | 
| 347 | 
            +
                            value="C1=CC=CC<mask>C1"
         | 
| 348 | 
             
                        )
         | 
| 349 | 
            +
                    with col2:
         | 
| 350 | 
            +
                        substructure_input = st.text_input(
         | 
| 351 | 
            +
                            "Substructure to Highlight (SMARTS)", 
         | 
| 352 | 
            +
                            value="C=C"
         | 
| 353 | 
            +
                        )
         | 
| 354 | 
            +
                    
         | 
| 355 | 
            +
                    if st.button("Predict and Visualize", type="primary"):
         | 
| 356 | 
            +
                        with st.spinner("Predicting masked SMILES..."):
         | 
| 357 | 
            +
                            predict_and_visualize_masked_smiles(smiles_input_masked, substructure_input)
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 358 |  | 
| 359 | 
            +
                elif tab_selection == "Molecule Viewer":
         | 
| 360 | 
            +
                    st.header("Molecule Viewer")
         | 
| 361 | 
            +
                    st.markdown("Enter a SMILES string to display its 2D and 3D structure.")
         | 
| 362 | 
            +
                    
         | 
| 363 | 
            +
                    smiles_input_viewer = st.text_input(
         | 
| 364 | 
            +
                        "SMILES String", 
         | 
| 365 | 
            +
                        value="C1=CC=CC=C1"
         | 
| 366 | 
            +
                    )
         | 
| 367 | 
            +
                    
         | 
| 368 | 
            +
                    if st.button("View Molecule", type="primary"):
         | 
| 369 | 
            +
                        with st.spinner("Generating molecule structures..."):
         | 
| 370 | 
            +
                            display_molecule_image(smiles_input_viewer)
         | 
|  | |
| 371 |  | 
| 372 | 
            +
            if __name__ == "__main__":
         | 
| 373 | 
            +
                main()
         | 
