File size: 10,183 Bytes
11e12c3
 
 
 
 
4ed4dfd
a81f473
35ed017
11e12c3
e56ed1f
98e9d9e
7962386
11e12c3
e56ed1f
 
ef610f3
c9fddab
98e9d9e
ef610f3
 
1850745
11e12c3
 
 
 
 
 
 
 
 
 
 
98e9d9e
11e12c3
 
98e9d9e
11e12c3
 
 
 
98e9d9e
 
 
 
11e12c3
98e9d9e
11e12c3
98e9d9e
11e12c3
98e9d9e
11e12c3
7962386
ef610f3
11e12c3
7962386
11e12c3
 
 
 
 
 
e56ed1f
c3644ec
11e12c3
 
 
 
 
 
 
 
 
 
c3644ec
 
11e12c3
 
 
 
 
 
 
 
 
98e9d9e
e56ed1f
 
 
11e12c3
 
 
 
 
e56ed1f
 
 
11e12c3
 
e56ed1f
11e12c3
e56ed1f
 
11e12c3
 
 
 
 
 
 
 
 
 
 
 
98e9d9e
11e12c3
 
 
 
 
 
 
 
 
 
 
 
 
 
e56ed1f
11e12c3
7962386
11e12c3
7962386
11e12c3
7962386
e56ed1f
11e12c3
98e9d9e
e56ed1f
11e12c3
 
 
 
 
 
 
 
35ed017
11e12c3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35ed017
 
11e12c3
35ed017
11e12c3
 
98e9d9e
11e12c3
b5c2863
11e12c3
 
 
 
98e9d9e
11e12c3
98e9d9e
 
 
11e12c3
 
 
e56ed1f
98e9d9e
 
11e12c3
 
 
98e9d9e
 
11e12c3
 
 
 
 
 
 
 
 
 
 
 
 
 
98e9d9e
11e12c3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4ed4dfd
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
# app.py
# To run this app, save the code as app.py and run:
# streamlit run app.py
#
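# You also need to install the following libraries (`accelerate` is required by
# transformers for device_map="auto"; RDKit is published on PyPI as `rdkit`,
# which supersedes the deprecated `rdkit-pypi` package):
# pip install streamlit torch transformers accelerate bitsandbytes rdkit py3Dmol pandas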

import streamlit as st
import streamlit.components.v1 as components
import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer, pipeline, BitsAndBytesConfig
from rdkit import Chem
from rdkit.Chem import Draw, AllChem
from rdkit.Chem.Draw import MolToImage
import pandas as pd
import logging

# Set up logging to monitor quantization effects
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# --- Page Configuration ---
# The page icon is the microscope emoji (U+1F52C). It is written as an escape
# sequence so that it cannot be garbled by a mis-decoded source file.
st.set_page_config(
    page_title="ChemBERTa SMILES Utilities",
    page_icon="\U0001F52C",
    layout="wide",
)

# --- Model Loading (Cached for Performance) ---

@st.cache_resource(show_spinner="Loading ChemBERTa model...")
def load_models():
    """
    Load the ChemBERTa tokenizer and wrap the model in a fill-mask pipeline.

    Decorated with ``st.cache_resource``, so the expensive load runs once per
    Streamlit server process. The result is reused across reruns and sessions.

    On CUDA the model is requested in 8-bit (bitsandbytes) with
    ``device_map="auto"``. On CPU it is loaded in float32. If that load or the
    pipeline construction fails, a plain full-precision load is attempted.

    Returns:
        tuple: ``(tokenizer, fill_mask_pipeline)``.
    """
    model_name = "seyonec/PubChem10M_SMILES_BPE_450k"
    use_cuda = torch.cuda.is_available()
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    model_kwargs = {"torch_dtype": torch.float16 if use_cuda else torch.float32}
    if use_cuda:
        # device_map="auto" lets accelerate place the weights (requires the
        # `accelerate` package).
        model_kwargs["device_map"] = "auto"
        try:
            # Only `load_in_8bit` applies to 8-bit loading. The compute-dtype and
            # double-quant options exist only for 4-bit (`bnb_4bit_*`).
            model_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
            logger.info("8-bit quantization configuration created.")
        except ImportError:
            logger.warning("bitsandbytes not available, falling back to standard loading.")
        except Exception as e:
            logger.warning("Quantization setup failed: %s, using standard loading.", e)

    try:
        model = AutoModelForMaskedLM.from_pretrained(model_name, **model_kwargs)
        model.eval()
        if "device_map" in model_kwargs:
            # A model dispatched via device_map is already placed. Passing an
            # explicit `device` to pipeline() is rejected by transformers for
            # such models, so it must be omitted here.
            fill_mask_pipeline = pipeline("fill-mask", model=model, tokenizer=tokenizer)
        else:
            fill_mask_pipeline = pipeline("fill-mask", model=model, tokenizer=tokenizer, device=-1)
        logger.info("Models loaded successfully with optimizations.")
        return tokenizer, fill_mask_pipeline
    except Exception as e:
        logger.error("Error loading optimized models: %s. Retrying with standard loading.", e)
        # Fallback: full-precision weights, moved to the GPU manually if present.
        model = AutoModelForMaskedLM.from_pretrained(model_name)
        device_idx = 0 if use_cuda else -1
        if use_cuda:
            model.to("cuda")
        fill_mask_pipeline = pipeline("fill-mask", model=model, tokenizer=tokenizer, device=device_idx)
        return tokenizer, fill_mask_pipeline

# Fetch the (tokenizer, fill-mask pipeline) pair at the top of every script run.
# load_models is wrapped in st.cache_resource, so only the first call loads weights.
(fill_mask_tokenizer, fill_mask_pipeline) = load_models()


# --- Molecule & Visualization Helpers ---

def get_mol(smiles):
    """
    Parse a SMILES string into an RDKit Mol and kekulize it.

    Kekulization is best-effort. If it fails, the parsed molecule is still
    returned and the failure is logged at debug level instead of being
    silently discarded.

    Args:
        smiles: SMILES string to parse.

    Returns:
        The RDKit Mol, or None if RDKit cannot parse ``smiles``.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None
    try:
        Chem.Kekulize(mol)
    except Exception as exc:
        # Best-effort: a non-kekulizable molecule is still usable for drawing.
        logger.debug("Kekulization failed for %r: %s", smiles, exc)
    return mol

def find_matches_one(mol, submol_smarts):
    """
    Return every match of a SMARTS pattern in ``mol``.

    Each match is a tuple of atom indices, as produced by
    ``GetSubstructMatches``. An empty list is returned when ``mol`` or
    ``submol_smarts`` is falsy, or when the SMARTS pattern cannot be parsed.
    """
    if mol and submol_smarts:
        pattern = Chem.MolFromSmarts(submol_smarts)
        if pattern:
            return mol.GetSubstructMatches(pattern)
    return []

def get_image_with_highlight(mol, atomset=None, size=(300, 300)):
    """
    Render ``mol`` as a 2D image, optionally highlighting some atoms.

    Args:
        mol: RDKit Mol to draw; None yields None.
        atomset: Iterable of atom indices to highlight. Entries whose string
            form is not all digits are dropped.
        size: Image size in pixels as (width, height).

    Returns:
        The image produced by ``MolToImage``, or None when ``mol`` is None.
    """
    if mol is None:
        return None
    highlight = []
    for atom_idx in atomset or []:
        if str(atom_idx).isdigit():
            highlight.append(int(atom_idx))
    # Semi-transparent green for every highlighted atom.
    # NOTE(review): confirm that the installed RDKit's MolToImage honours
    # highlightAtomColors; some versions may only apply highlightColor.
    colour_map = dict.fromkeys(highlight, (0, 1, 0, 0.5))
    return MolToImage(
        mol,
        size=size,
        fitImage=True,
        highlightAtoms=highlight,
        highlightAtomColors=colour_map,
    )

def generate_3d_view_html(smiles):
    """
    Build an interactive 3D view of a molecule as an embeddable HTML snippet.

    The steps are:
    - add hydrogens;
    - embed a conformer with a fixed random seed (42), so output is
      reproducible;
    - relax the conformer with MMFF;
    - render it with py3Dmol in a 350x350 stick-and-sphere viewer.

    Args:
        smiles: SMILES string to render.

    Returns:
        None for empty input. Otherwise an HTML string: the py3Dmol viewer on
        success, or a short ``<p>`` message when the SMILES is invalid or 3D
        generation fails.
    """
    import html

    if not smiles:
        return None
    mol = get_mol(smiles)
    if not mol:
        return "<p>Invalid SMILES for 3D view.</p>"
    try:
        # Imported lazily so that a missing py3Dmol install degrades to the
        # error message below instead of breaking the whole app at startup.
        import py3Dmol

        mol_3d = Chem.AddHs(mol)
        # EmbedMolecule returns the new conformer id, or -1 on failure. Without
        # a conformer, the MMFF step would fail with a less helpful error.
        if AllChem.EmbedMolecule(mol_3d, randomSeed=42, useRandomCoords=True) < 0:
            raise ValueError("could not embed 3D coordinates")
        AllChem.MMFFOptimizeMolecule(mol_3d)
        sdf_data = Chem.MolToMolBlock(mol_3d)

        viewer = py3Dmol.view(width=350, height=350)
        viewer.setBackgroundColor('#FFFFFF')
        viewer.addModel(sdf_data, "sdf")
        viewer.setStyle({'stick': {}, 'sphere': {'scale': 0.25}})
        viewer.zoomTo()
        return viewer._make_html()
    except Exception as e:
        logger.exception("Failed to generate 3D view for %s", smiles)
        # Escape the message: it can echo user-supplied SMILES back into the page.
        return f"<p>Error generating 3D view: {html.escape(str(e))}</p>"

# --- Core Application Logic ---

def run_masked_smiles_prediction(smiles_mask, substructure_smarts_highlight):
    """
    Run masked-token prediction on one SMILES string and store the results.

    The function:
    - asks the cached fill-mask pipeline for its 10 best candidates;
    - keeps the first 5 whose completed sequence RDKit can parse;
    - for each kept candidate, renders a 2D image (highlighting the first match
      of the SMARTS pattern, if any) and 3D viewer HTML.

    The resulting list of dicts (keys ``smiles``, ``score``, ``image_2d``,
    ``html_3d``) is written to ``st.session_state.prediction_results``.
    Problems are reported on the page via ``st.error`` / ``st.warning``, not
    raised.

    Args:
        smiles_mask: SMILES string containing exactly one mask token.
        substructure_smarts_highlight: SMARTS pattern to highlight; may be empty.
    """
    mask_token = fill_mask_tokenizer.mask_token
    mask_count = smiles_mask.count(mask_token)
    if mask_count != 1:
        # Clear earlier results so they are not shown under an unrelated error.
        st.session_state.prediction_results = []
        if mask_count == 0:
            st.error(f"Error: Input SMILES must contain a mask token (e.g., {mask_token}).")
        else:
            # With several masks the pipeline returns one candidate list per
            # mask (the other masks left unfilled), which the loop below
            # cannot handle.
            st.error(f"Error: Input SMILES must contain exactly one mask token ({mask_token}).")
        return

    with st.spinner("Predicting completions..."):
        try:
            with torch.no_grad():
                predictions = fill_mask_pipeline(smiles_mask, top_k=10)
        except Exception as e:
            st.session_state.prediction_results = []
            st.error(f"An error occurred during prediction: {e}")
            return
        finally:
            # Release cached GPU memory after inference, whether it succeeded or not.
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        results = []
        for pred in predictions:
            if len(results) >= 5:
                break
            predicted_smiles = pred['sequence']
            mol = get_mol(predicted_smiles)
            if not mol:
                continue  # skip completions that are not valid SMILES
            atom_matches = find_matches_one(mol, substructure_smarts_highlight)
            results.append({
                "smiles": predicted_smiles,
                "score": f"{pred['score']:.4f}",
                "image_2d": get_image_with_highlight(mol, atomset=atom_matches[0] if atom_matches else []),
                "html_3d": generate_3d_view_html(predicted_smiles)
            })

        if not results:
            st.warning("None of the predicted completions is a valid SMILES string.")
        st.session_state.prediction_results = results


# --- Streamlit UI Definition ---

# The title starts with the microscope emoji (U+1F52C). It is written as an
# escape sequence so that it cannot be garbled by a mis-decoded source file.
st.title("\U0001F52C ChemBERTa SMILES Utilities Dashboard (2D & 3D)")
st.markdown("A tool to predict masked tokens in SMILES strings and visualize molecules, powered by ChemBERTa and Streamlit.")

tab1, tab2 = st.tabs(["Masked SMILES Prediction", "Molecule Viewer (2D & 3D)"])

# --- Tab 1: Masked SMILES Prediction ---
with tab1:
    st.header("Predict and Visualize Masked SMILES")
    # Use the tokenizer's real mask token so the instructions, the default
    # example, the help text and the validation all agree.
    st.markdown(
        f"Enter a SMILES string with a `{fill_mask_tokenizer.mask_token}` token "
        "to predict possible completions."
    )

    with st.form(key="prediction_form"):
        col1, col2 = st.columns(2)
        with col1:
            smiles_input_masked = st.text_input(
                "SMILES String with Mask",
                value=f"C1=CC=CC{fill_mask_tokenizer.mask_token}C1",
                help=f"The mask token is `{fill_mask_tokenizer.mask_token}`"
            )
        with col2:
            substructure_input = st.text_input(
                "Substructure to Highlight (SMARTS)",
                value="C=C",
                help="Enter a SMARTS pattern to highlight in the 2D images."
            )

        predict_button = st.form_submit_button("Predict and Visualize", use_container_width=True)

    if predict_button:
        run_masked_smiles_prediction(smiles_input_masked, substructure_input)

    if 'prediction_results' in st.session_state and st.session_state.prediction_results:
        results = st.session_state.prediction_results
        # At most 5 results are stored, but fewer remain when some candidates
        # are not valid SMILES, so report the actual count.
        st.subheader(f"Top {len(results)} Valid Predictions")

        # Display results in a table
        df_data = [{"Predicted SMILES": r["smiles"], "Score": r["score"]} for r in results]
        st.dataframe(pd.DataFrame(df_data), use_container_width=True)

        st.markdown("---")

        # Display molecule visualizations: 2D image on the left, 3D viewer on the right.
        for i, res in enumerate(results):
            st.markdown(f"**Prediction {i+1}:** `{res['smiles']}` (Score: {res['score']})")
            col1, col2 = st.columns(2)
            with col1:
                st.subheader("2D Structure")
                if res["image_2d"]:
                    # NOTE(review): newer Streamlit releases deprecate
                    # use_column_width in favour of use_container_width; check
                    # this against the installed version before switching.
                    st.image(res["image_2d"], use_column_width=True)
                else:
                    st.warning("Could not generate 2D image.")
            with col2:
                st.subheader("3D Interactive Structure")
                if res["html_3d"]:
                    components.html(res["html_3d"], height=370)
                else:
                    st.warning("Could not generate 3D view.")
            st.markdown("---")

# --- Tab 2: Molecule Viewer ---
with tab2:
    st.header("Visualize a Molecule from SMILES")
    st.markdown("Enter a single SMILES string to display its 2D and 3D structures side-by-side.")

    with st.form(key="viewer_form"):
        # The default input is aspirin.
        smiles_input_viewer = st.text_input("SMILES String", value="CC(=O)Oc1ccccc1C(=O)O")
        view_button = st.form_submit_button("View Molecule", use_container_width=True)

    # Render only on an explicit submit with a non-empty SMILES string.
    if view_button and smiles_input_viewer:
        with st.spinner("Generating views..."):
            mol = get_mol(smiles_input_viewer)
            if mol:
                st.subheader(f"Visualizations for: `{smiles_input_viewer}`")
                left_col, right_col = st.columns(2)
                with left_col:
                    st.subheader("2D Structure")
                    img_2d = MolToImage(mol, size=(450, 450), fitImage=True)
                    st.image(img_2d, use_column_width=True)
                with right_col:
                    st.subheader("3D Interactive Structure")
                    html_3d = generate_3d_view_html(smiles_input_viewer)
                    components.html(html_3d, height=470)
            else:
                st.error("Invalid SMILES string provided.")