File size: 13,718 Bytes
35ed017
1850745
7962386
 
 
35ed017
7962386
 
ef610f3
1850745
7962386
 
c9fddab
7962386
ef610f3
 
1850745
35ed017
 
7962386
 
35ed017
7962386
35ed017
 
7962386
35ed017
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7962386
 
 
 
 
 
ef610f3
7962386
ef610f3
 
7962386
 
ef610f3
7962386
 
ef610f3
7962386
 
 
 
 
 
 
 
 
 
 
 
ef610f3
7962386
 
 
 
 
 
 
 
 
 
 
ef610f3
7962386
 
 
 
 
 
 
 
 
 
ef610f3
 
7962386
 
 
 
ef610f3
35ed017
7962386
 
 
 
 
1850745
7962386
35ed017
7962386
 
 
 
 
 
 
 
1850745
7962386
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3752899
7962386
 
 
3752899
 
 
 
35ed017
3752899
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c9fddab
7962386
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35ed017
7962386
35ed017
7962386
35ed017
7962386
 
 
35ed017
7962386
 
 
b5c2863
7962386
b5c2863
7962386
 
35ed017
7962386
35ed017
7962386
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
import streamlit as st
import pandas as pd
from rdkit import Chem
from rdkit.Chem import Draw, AllChem
from rdkit.Chem.Draw import rdMolDraw2D
import py3Dmol
import io
import base64
import logging

import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer, pipeline, BitsAndBytesConfig

# Set up logging to monitor quantization effects
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# --- Page Configuration ---
st.set_page_config(
    page_title="Molecule Explorer & Predictor",
    page_icon="🔬",  # microscope emoji (U+1F52C)
    layout="wide",
    initial_sidebar_state="collapsed",
)

# Custom CSS for a professional, minimalist look (adapted from drug_app.txt)
def apply_custom_styling():
    """Inject the app-wide dark theme stylesheet into the page.

    Covers the Roboto font import, the dark app background, tab-strip styling
    (inactive/hover/selected) and the accent-coloured buttons. The raw
    ``<style>`` element is emitted through ``st.markdown`` with
    ``unsafe_allow_html=True`` so Streamlit renders it instead of escaping it.
    """
    stylesheet = """
        <style>
        @import url('https://fonts.googleapis.com/css2?family=Roboto:wght@400;700&display=swap');
        
        html, body, [class*="st-"] {
            font-family: 'Roboto', sans-serif;
        }

        .stApp {
            background-color: rgb(28, 28, 28);
            color: white;
        }

        /* Tab styles */
        .stTabs [data-baseweb="tab-list"] {
            gap: 24px;
        }

        .stTabs [data-baseweb="tab"] {
            height: 50px;
            white-space: pre-wrap;
            background: none;
            border-radius: 0px;
            border-bottom: 2px solid #333;
            padding: 10px 4px;
            color: #AAA;
        }
        
        .stTabs [data-baseweb="tab"]:hover {
            background: #222;
            color: #FFF;
        }

        .stTabs [aria-selected="true"] {
            border-bottom: 2px solid #00A0FF; /* Highlight color for active tab */
            color: #FFF;
        }
        
        /* Button styles */
        .stButton>button {
            border-color: #00A0FF;
            color: #00A0FF;
        }
        
        .stButton>button:hover {
            border-color: #FFF;
            color: #FFF;
            background-color: #00A0FF;
        }
        </style>
        """
    st.markdown(stylesheet, unsafe_allow_html=True)


apply_custom_styling()

# --- Quantization Configuration ---
def get_quantization_config():
    """Build an 8-bit ``BitsAndBytesConfig`` for model loading.

    Returns:
        BitsAndBytesConfig | None: the 8-bit quantization config, or ``None``
        when the ``bitsandbytes`` package is not installed or the config cannot
        be constructed. On ``None`` the caller loads the model unquantized.
    """
    import importlib.util

    # Check for the package explicitly instead of relying on the config
    # constructor to raise; otherwise a missing install only surfaces later,
    # deep inside from_pretrained.
    if importlib.util.find_spec("bitsandbytes") is None:
        logger.warning("bitsandbytes not available, falling back to standard loading")
        return None

    try:
        # Plain 8-bit (LLM.int8) loading. The compute-dtype and nested ("double")
        # quantization knobs exist only as bnb_4bit_* options for 4-bit loading,
        # so none are passed here (unknown kwargs would just be ignored).
        quantization_config = BitsAndBytesConfig(load_in_8bit=True)
        logger.info("8-bit quantization configuration loaded successfully")
        return quantization_config
    except ImportError:
        logger.warning("bitsandbytes not available, falling back to standard loading")
        return None
    except Exception as e:
        logger.warning(f"Quantization setup failed: {e}, using standard loading")
        return None

def get_torch_dtype():
    """Pick the dtype used when loading model weights.

    Returns:
        torch.dtype: ``torch.float16`` when a CUDA device is available (half
        precision on GPU), otherwise ``torch.float32`` (full precision on CPU).
    """
    return torch.float16 if torch.cuda.is_available() else torch.float32

# --- Optimized Model Loading with Streamlit Caching ---
@st.cache_resource(show_spinner="Loading molecular language model...")
def load_optimized_models():
    """Load the fill-mask tokenizer, model and pipeline with hardware-aware settings.

    - On CUDA: weights are loaded with the dtype from ``get_torch_dtype()``. They are
      dispatched with ``device_map="auto"`` and, when ``get_quantization_config()``
      returns a config, quantized to 8-bit.
    - On CPU: weights are loaded with the CPU dtype and no device map.

    If loading the model or building the pipeline raises, the error is logged and
    ``load_standard_models`` is used instead.

    Returns:
        tuple: ``(tokenizer, model, fill_mask_pipeline)``.
    """
    cuda_available = torch.cuda.is_available()
    device = "cuda" if cuda_available else "cpu"
    torch_dtype = get_torch_dtype()

    logger.info(f"Loading models on device: {device} with dtype: {torch_dtype}")

    model_name = "seyonec/PubChem10M_SMILES_BPE_450k"

    fill_mask_tokenizer = AutoTokenizer.from_pretrained(model_name)

    model_kwargs = {
        "torch_dtype": torch_dtype,
    }
    if cuda_available:
        # Let accelerate place the weights; 8-bit quantization is GPU-only, so the
        # config is only requested here.
        model_kwargs["device_map"] = "auto"
        quantization_config = get_quantization_config()
        if quantization_config is not None:
            model_kwargs["quantization_config"] = quantization_config
    else:
        model_kwargs["device_map"] = None  # For CPU

    try:
        fill_mask_model = AutoModelForMaskedLM.from_pretrained(
            model_name,
            **model_kwargs
        )
        fill_mask_model.eval()

        pipeline_kwargs = {}
        # A model dispatched by accelerate (device_map set, exposes hf_device_map)
        # is already placed and may be quantized. transformers refuses an explicit
        # `device` for such a model, so only pass it for non-dispatched models.
        if getattr(fill_mask_model, "hf_device_map", None) is None:
            model_device = fill_mask_model.device
            pipeline_kwargs["device"] = model_device.index if model_device.type == "cuda" else -1

        fill_mask_pipeline = pipeline(
            'fill-mask',
            model=fill_mask_model,
            tokenizer=fill_mask_tokenizer,
            **pipeline_kwargs,
        )
        logger.info("Models loaded successfully with optimizations")
        return fill_mask_tokenizer, fill_mask_model, fill_mask_pipeline
    except Exception as e:
        # Keep the traceback: this fallback otherwise hides why the optimized path failed.
        logger.exception(f"Error loading optimized models: {e}")
        logger.info("Falling back to standard model loading...")
        return load_standard_models(model_name)

@st.cache_resource(show_spinner="Loading standard molecular language model...")
def load_standard_models(model_name="seyonec/PubChem10M_SMILES_BPE_450k"):
    """Load tokenizer, model and fill-mask pipeline without quantization.

    Used as the fallback when optimized loading fails. The pipeline is bound to
    GPU 0 when CUDA is available and to the CPU (``device=-1``) otherwise.

    Args:
        model_name: Hugging Face model id to load.

    Returns:
        tuple: ``(tokenizer, model, fill_mask_pipeline)``.
    """
    use_cuda = torch.cuda.is_available()

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForMaskedLM.from_pretrained(model_name)
    fill_mask = pipeline(
        'fill-mask',
        model=model,
        tokenizer=tokenizer,
        device=0 if use_cuda else -1,
    )

    if use_cuda:
        model.to("cuda")

    logger.info("Standard models loaded successfully")
    return tokenizer, model, fill_mask

# --- RDKit and Py3Dmol Visualization Functions ---

def mol_to_svg(mol, size=(400, 300)):
    """Draw *mol* as a 2D depiction and return the SVG markup.

    RDKit's default drawing options and atom colours are used.

    Args:
        mol: RDKit molecule to draw. A falsy value (e.g. ``None`` from a failed
            ``Chem.MolFromSmiles``) short-circuits.
        size: canvas dimensions, unpacked positionally into ``MolDraw2DSVG``.

    Returns:
        str | None: the SVG document text, or ``None`` when *mol* is falsy.
    """
    if not mol:
        return None

    canvas = rdMolDraw2D.MolDraw2DSVG(*size)
    canvas.DrawMolecule(mol)
    canvas.FinishDrawing()
    return canvas.GetDrawingText()

def mol_to_sdf(mol):
    """Embed *mol* in 3D and return it as a MolBlock string.

    Hydrogens are added, a conformer is generated with ETKDGv3 (fixed seed for
    reproducibility), and the geometry is then relaxed with UFF.

    Args:
        mol: RDKit molecule. A falsy value short-circuits.

    Returns:
        str | None: MolBlock text with explicit hydrogens and 3D coordinates.
        Returns ``None`` when *mol* is falsy, embedding fails, or RDKit raises
        during optimization or serialization (the failure is logged).
    """
    if not mol:
        return None
    # Add hydrogens to the molecule
    mol_with_h = Chem.AddHs(mol)

    try:
        # EmbedMolecule has two overloads: (mol, **keyword options) and
        # (mol, EmbedParameters). Mixing an EmbedParameters object with keyword
        # options matches neither and raises, so the seed and attempt limit must
        # be set on the parameters object itself.
        params = AllChem.ETKDGv3()
        params.randomSeed = 42      # reproducible conformers
        params.maxIterations = 50   # embedding attempts per conformer

        conf_id = AllChem.EmbedMolecule(mol_with_h, params)
        if conf_id == -1:
            # No conformer was produced; UFF would fail without coordinates.
            logger.warning(f"3D embedding failed for SMILES: {Chem.MolToSmiles(mol)}")
            return None

        # Optimize 3D coordinates using Universal Force Field (UFF)
        AllChem.UFFOptimizeMolecule(mol_with_h)
        return Chem.MolToMolBlock(mol_with_h)
    except Exception as e:
        logger.error(f"Error generating 3D coordinates for SMILES: {Chem.MolToSmiles(mol)} - {e}")
        return None

def visualize_molecule_3d(mol_sdf: str, width='100%', height=400):
    """Build an interactive py3Dmol viewer for an SDF/MolBlock string.

    The model is shown as sticks plus small spheres (radius 0.3) on a dark
    background and zoomed to fit.

    Args:
        mol_sdf: molecule in SDF format; a falsy value short-circuits.
        width: viewer width passed to ``py3Dmol.view``.
        height: viewer height passed to ``py3Dmol.view``.

    Returns:
        str | None: the viewer's HTML, or ``None`` when *mol_sdf* is falsy or
        py3Dmol raises (in which case the error is shown via ``st.error``).
    """
    if not mol_sdf:
        return None

    representation = {'stick': {}, 'sphere': {'radius': 0.3}}
    try:
        view = py3Dmol.view(width=width, height=height)
        view.setBackgroundColor('#1C1C1C')
        view.addModel(mol_sdf, "sdf")
        view.setStyle(representation)
        view.zoomTo()
        return view._make_html()
    except Exception as e:
        st.error(f"Error generating 3D visualization: {e}")
        return None

# --- Main Streamlit Application Layout ---

st.title("🔬 Molecule Explorer & Predictor")

# Keep the tokenizer/model/pipeline in session state so reruns of this script
# reuse them. load_optimized_models is also wrapped in st.cache_resource, so the
# heavy load itself happens once per process.
if 'tokenizer' not in st.session_state:
    st.session_state.tokenizer, st.session_state.model, st.session_state.pipeline = load_optimized_models()

tokenizer = st.session_state.tokenizer
model = st.session_state.model
fill_mask_pipeline = st.session_state.pipeline

tab1, tab2 = st.tabs(["Molecule Viewer (2D & 3D)", "Masked SMILES Predictor"])

with tab1:
    st.header("Visualize Molecules in 2D and 3D")

    smiles_input = st.text_input("Enter SMILES string:", "CCO", help="e.g., CCO (ethanol), C1=CC=CC=C1 (benzene)")

    if st.button("View Molecule"):
        if not smiles_input:
            st.info("Please enter a SMILES string to view the molecule.")
        else:
            mol = Chem.MolFromSmiles(smiles_input)
            if not mol:
                # MolFromSmiles returns None for unparsable input.
                st.error("Invalid SMILES string. Please enter a valid chemical structure.")
            else:
                # 2D depiction
                st.subheader("2D Structure")
                svg = mol_to_svg(mol)
                if not svg:
                    st.warning("Could not generate 2D image.")
                else:
                    st.image(svg, use_column_width=True)

                # 3D viewer: needs an embedded conformer first, then the py3Dmol HTML.
                st.subheader("3D Structure (Interactive)")
                sdf_string = mol_to_sdf(mol)
                if not sdf_string:
                    st.warning("Could not generate 3D SDF data.")
                else:
                    html_3d = visualize_molecule_3d(sdf_string)
                    if not html_3d:
                        st.warning("Could not generate 3D visualization.")
                    else:
                        st.components.v1.html(html_3d, width=700, height=500, scrolling=False)

with tab2:
    st.header("Masked SMILES Prediction")

    masked_smiles_input = st.text_input(
        "Enter masked SMILES string (use `<mask>` for the masked token):",
        "C1=CC=CC<mask>C1",
        help="Example: 'C1=CC=CC<mask>C1' (masked benzene), 'CCO<mask>C' (masked ether)"
    )
    top_k_predictions = st.slider("Number of predictions to show:", 1, 10, 5)

    if st.button("Predict Masked Token"):
        # The fill-mask pipeline returns a flat list of candidate dicts only for a
        # single mask; with several masks it returns one list per mask, which the
        # rendering below cannot handle. So exactly one mask is required.
        mask_count = masked_smiles_input.count("<mask>") if masked_smiles_input else 0
        if mask_count == 0:
            st.info("Please enter a masked SMILES string (e.g., `C1=CC=CC<mask>C1`).")
        elif mask_count > 1:
            st.warning("Only one `<mask>` token is supported per prediction. Please remove the extra `<mask>` tokens.")
        else:
            try:
                # Perform prediction using the loaded pipeline
                predictions = fill_mask_pipeline(masked_smiles_input, top_k=top_k_predictions)

                prediction_data = []
                for pred in predictions:
                    sequence = pred['sequence']
                    mol = Chem.MolFromSmiles(sequence)
                    prediction_data.append({
                        "Predicted Token": pred['token_str'],
                        "Full SMILES": sequence,
                        "Confidence Score": f"{pred['score']:.4f}",
                        # Smaller depiction for the grid; None when the completed SMILES does not parse.
                        "Structure SVG": mol_to_svg(mol, size=(200, 150)) if mol else None,
                    })

                # Fixed columns so the drop below works even if no predictions came back.
                df_predictions = pd.DataFrame(
                    prediction_data,
                    columns=["Predicted Token", "Full SMILES", "Confidence Score", "Structure SVG"],
                )

                st.subheader("Predictions:")
                # The table omits the raw SVG markup; structures are rendered below.
                display_df = df_predictions.drop(columns=["Structure SVG"])
                st.dataframe(display_df, use_container_width=True, hide_index=True)

                st.subheader("Predicted Structures:")
                num_cols = min(len(prediction_data), 5)  # Display up to 5 images per row
                if num_cols == 0:
                    # st.columns(0) would raise, so report the empty result instead.
                    st.write("*(No predictions returned)*")
                else:
                    cols = st.columns(num_cols)
                    for i, entry in enumerate(prediction_data):
                        with cols[i % num_cols]:  # Distribute images into columns
                            st.markdown(f"**{entry['Predicted Token']}** (Score: {entry['Confidence Score']})")
                            if entry['Structure SVG']:
                                st.image(entry['Structure SVG'], use_column_width='auto')
                            else:
                                st.write("*(Invalid SMILES)*")

            except Exception as e:
                st.error(f"An error occurred during prediction: {e}")
                st.info("Please ensure your masked SMILES is valid and contains `<mask>`.")