alidenewade commited on
Commit
11e12c3
·
verified ·
1 Parent(s): 425ba96

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +189 -384
app.py CHANGED
@@ -1,448 +1,253 @@
1
- #!/usr/bin/env python3
2
- """
3
- ChemBERTa SMILES Utilities Dashboard
4
- A Streamlit application for molecular prediction and visualization
5
- """
 
6
 
7
  import streamlit as st
 
8
  import torch
9
  from transformers import AutoModelForMaskedLM, AutoTokenizer, pipeline, BitsAndBytesConfig
10
  from rdkit import Chem
11
- from rdkit.Chem import Draw, rdFMCS, AllChem
12
  from rdkit.Chem.Draw import MolToImage
13
  import pandas as pd
14
- import io
15
- import base64
16
  import logging
17
- import streamlit.components.v1 as components
18
- import sys
19
- import os
20
-
21
- # Check if running in Streamlit context
22
- def is_streamlit_context():
23
- """Check if we're running in a Streamlit context"""
24
- try:
25
- import streamlit.runtime.scriptrunner as sr
26
- return sr.get_script_run_ctx() is not None
27
- except:
28
- return False
29
-
30
- # Only proceed if we're in a Streamlit context or being run by streamlit
31
- if not is_streamlit_context() and __name__ == "__main__":
32
- print("This app must be run with: streamlit run app.py")
33
- print("Please use the command: streamlit run app.py --server.port=7860 --server.address=0.0.0.0")
34
- sys.exit(1)
35
 
36
  # Set up logging to monitor quantization effects
37
  logging.basicConfig(level=logging.INFO)
38
  logger = logging.getLogger(__name__)
39
 
40
- # Page configuration - only if in streamlit context
41
- if is_streamlit_context() or 'streamlit' in sys.modules:
42
- st.set_page_config(
43
- page_title="ChemBERTa SMILES Utilities Dashboard",
44
- page_icon="🧪",
45
- layout="wide"
46
- )
47
-
48
- # --- Quantization Configuration ---
49
- @st.cache_resource
50
- def get_quantization_config():
51
  """
52
- Configure 8-bit quantization for model optimization.
53
- Falls back gracefully if bitsandbytes is not available.
54
  """
 
 
 
 
55
  try:
56
- # Only use quantization on CUDA
57
- if not torch.cuda.is_available():
58
- logger.info("CUDA not available, skipping quantization")
59
- return None
60
-
61
- # 8-bit quantization configuration - good balance of speed and quality
62
  quantization_config = BitsAndBytesConfig(
63
  load_in_8bit=True,
64
  bnb_8bit_compute_dtype=torch.float16,
65
- bnb_8bit_use_double_quant=True, # Nested quantization for better compression
66
  )
67
- logger.info("8-bit quantization configuration loaded successfully")
68
- return quantization_config
69
  except ImportError:
70
- logger.warning("bitsandbytes not available, falling back to standard loading")
71
- return None
72
  except Exception as e:
73
- logger.warning(f"Quantization setup failed: {e}, using standard loading")
74
- return None
75
-
76
- def get_torch_dtype():
77
- """Get appropriate torch dtype based on available hardware."""
78
- if torch.cuda.is_available():
79
- return torch.float16 # Use half precision on GPU
80
- else:
81
- return torch.float32 # Keep full precision on CPU
82
-
83
- # --- Optimized Model Loading ---
84
- @st.cache_resource
85
- def load_optimized_models():
86
- """Load models with quantization and other optimizations."""
87
- device = "cuda" if torch.cuda.is_available() else "cpu"
88
- torch_dtype = get_torch_dtype()
89
- quantization_config = get_quantization_config()
90
-
91
- logger.info(f"Loading models on device: {device} with dtype: {torch_dtype}")
92
 
93
- # Model names
94
  model_name = "seyonec/PubChem10M_SMILES_BPE_450k"
 
95
 
96
- try:
97
- # Load tokenizer (doesn't need quantization)
98
- fill_mask_tokenizer = AutoTokenizer.from_pretrained(model_name)
99
-
100
- # Load model with quantization if available
101
- model_kwargs = {
102
- "torch_dtype": torch_dtype,
103
- }
104
-
105
- if quantization_config is not None and torch.cuda.is_available():
106
- model_kwargs["quantization_config"] = quantization_config
107
- model_kwargs["device_map"] = "auto"
108
- else:
109
- # For CPU or non-quantized loading
110
- model_kwargs["device_map"] = None
111
-
112
- # Masked LM Model
113
- fill_mask_model = AutoModelForMaskedLM.from_pretrained(
114
- model_name,
115
- **model_kwargs
116
- )
117
-
118
- # Move to device if not using device_map
119
- if model_kwargs["device_map"] is None and torch.cuda.is_available():
120
- fill_mask_model.to(device)
121
-
122
- # Set model to evaluation mode for inference
123
- fill_mask_model.eval()
124
-
125
- # Create pipeline with proper device handling
126
- pipeline_device = 0 if torch.cuda.is_available() else -1
127
-
128
- fill_mask_pipeline = pipeline(
129
- 'fill-mask',
130
- model=fill_mask_model,
131
- tokenizer=fill_mask_tokenizer,
132
- device=pipeline_device,
133
- )
134
 
135
- logger.info("Models loaded successfully with optimizations")
136
- return fill_mask_tokenizer, fill_mask_model, fill_mask_pipeline
137
-
138
- except Exception as e:
139
- logger.error(f"Error loading optimized models: {e}")
140
- # Fallback to standard loading
141
- logger.info("Falling back to standard model loading...")
142
- return load_standard_models(model_name)
143
-
144
- def load_standard_models(model_name):
145
- """Fallback standard model loading without quantization."""
146
  try:
147
- fill_mask_tokenizer = AutoTokenizer.from_pretrained(model_name)
148
- fill_mask_model = AutoModelForMaskedLM.from_pretrained(
149
- model_name,
150
- torch_dtype=torch.float32
151
- )
152
-
153
- # Determine device for standard loading
 
 
 
154
  device_idx = 0 if torch.cuda.is_available() else -1
155
-
156
  if torch.cuda.is_available():
157
- fill_mask_model.to("cuda")
158
-
159
- fill_mask_pipeline = pipeline(
160
- 'fill-mask',
161
- model=fill_mask_model,
162
- tokenizer=fill_mask_tokenizer,
163
- device=device_idx
164
- )
 
165
 
166
- return fill_mask_tokenizer, fill_mask_model, fill_mask_pipeline
167
- except Exception as e:
168
- logger.error(f"Failed to load models: {e}")
169
- if is_streamlit_context():
170
- st.error(f"Failed to load models: {e}")
171
- return None, None, None
172
-
173
- # --- Memory Management Utilities ---
174
- def clear_gpu_cache():
175
- """Clear CUDA cache to free up memory."""
176
- if torch.cuda.is_available():
177
- torch.cuda.empty_cache()
178
-
179
- # --- Helper Functions ---
180
  def get_mol(smiles):
181
  """Converts SMILES to RDKit Mol object and Kekulizes it."""
182
  mol = Chem.MolFromSmiles(smiles)
183
- if mol is None:
184
- return None
185
- try:
186
- Chem.Kekulize(mol)
187
- except: # Kekulization can fail for some structures
188
- pass
189
  return mol
190
 
191
  def find_matches_one(mol, submol_smarts):
192
- """Finds all matching atoms for a SMARTS pattern in a molecule."""
193
- if not mol or not submol_smarts:
194
- return []
195
  submol = Chem.MolFromSmarts(submol_smarts)
196
- if not submol:
197
- return []
198
- matches = mol.GetSubstructMatches(submol)
199
- return matches
200
 
201
  def get_image_with_highlight(mol, atomset=None, size=(300, 300)):
202
- """Draws molecule with optional atom highlighting."""
203
- if mol is None:
204
- return None
205
- highlight_color = (0, 1, 0, 0.5) # Green with some transparency
206
-
207
- # Ensure atomset contains integers if not None or empty
208
- valid_atomset = []
209
- if atomset:
210
- try:
211
- valid_atomset = [int(a) for a in atomset]
212
- except (ValueError, TypeError):
213
- logger.warning(f"Invalid atom in atomset: {atomset}. Proceeding without highlighting problematic atoms.")
214
- valid_atomset = [int(a) for a in atomset if str(a).isdigit()] # Filter out non-integers
215
-
216
- img = MolToImage(mol, size=size, fitImage=True,
217
- highlightAtoms=valid_atomset if valid_atomset else [],
218
- highlightAtomColors={i: highlight_color for i in valid_atomset} if valid_atomset else {})
219
- return img
220
-
221
- def generate_3d_structure(mol):
222
- """Generate 3D coordinates for a molecule."""
223
- if mol is None:
224
- return None
225
-
226
- # Create a copy to avoid modifying the original
227
- mol_3d = Chem.Mol(mol)
228
-
229
- # Add hydrogens
230
- mol_3d = Chem.AddHs(mol_3d)
231
-
232
- # Generate 3D coordinates
233
  try:
234
- AllChem.EmbedMolecule(mol_3d, randomSeed=42)
235
- AllChem.UFFOptimizeMolecule(mol_3d)
236
- return mol_3d
237
- except:
238
- # If 3D generation fails, return None
239
- return None
240
-
241
- def mol_to_3d_html(mol):
242
- """Convert molecule to 3D HTML representation using py3Dmol."""
243
- if mol is None:
244
- return None
245
-
246
- # Generate SDF string
247
- sdf = Chem.MolToMolBlock(mol)
248
-
249
- # Create 3D viewer HTML
250
- html_template = """
251
- <div id="3dmolviewer_{id}" style="height: 400px; width: 100%; position: relative;" class="viewer_3Dmoljs"></div>
252
- <script src="https://cdnjs.cloudflare.com/ajax/libs/3Dmol/2.0.4/3Dmol-min.js"></script>
253
- <script>
254
- let viewer_{id} = $3Dmol.createViewer(document.getElementById('3dmolviewer_{id}'), {{
255
- defaultcolors: $3Dmol.rasmolElementColors
256
- }});
257
- viewer_{id}.addModel(`{sdf}`, 'sdf');
258
- viewer_{id}.setStyle({{}}, {{stick: {{}}}});
259
- viewer_{id}.zoomTo();
260
- viewer_{id}.render();
261
- </script>
262
- """
263
-
264
- import random
265
- viewer_id = random.randint(1000, 9999)
266
-
267
- html_content = html_template.format(id=viewer_id, sdf=sdf.replace('`', '\\`'))
268
-
269
- return html_content
270
 
271
- # --- Streamlit Interface Functions ---
272
 
273
- def predict_and_visualize_masked_smiles(smiles_mask, substructure_smarts_highlight="CC=CC"):
274
  """
275
- Predicts masked tokens in a SMILES string, shows scores, and visualizes molecules.
276
  """
277
- # Load models when needed
278
- try:
279
- models = load_optimized_models()
280
- if models[0] is None: # Check if loading failed
281
- st.error("Failed to load models. Please check the logs.")
282
- return
283
- fill_mask_tokenizer, fill_mask_model, fill_mask_pipeline = models
284
- except Exception as e:
285
- st.error(f"Error loading models: {str(e)}")
286
- return
287
-
288
  if fill_mask_tokenizer.mask_token not in smiles_mask:
289
- st.error("Error: Input SMILES must contain a mask token (e.g., <mask>).")
290
  return
291
 
292
- try:
293
- # Use torch.no_grad() for inference to save memory
294
- with torch.no_grad():
295
- predictions = fill_mask_pipeline(smiles_mask, top_k=10)
296
- except Exception as e:
297
- clear_gpu_cache()
298
- st.error(f"Error during prediction: {str(e)}")
299
- return
300
 
301
- results_data = []
302
- valid_predictions = []
303
- valid_predictions_count = 0
 
 
 
 
 
 
 
 
 
 
 
 
 
304
 
305
- for pred in predictions:
306
- if valid_predictions_count >= 5:
307
- break
308
 
309
- predicted_smiles = pred['sequence']
310
- score = pred['score']
311
 
312
- mol = get_mol(predicted_smiles)
313
- if mol:
314
- results_data.append({"Predicted SMILES": predicted_smiles, "Score": f"{score:.4f}"})
315
- valid_predictions.append((mol, predicted_smiles, score))
316
- valid_predictions_count += 1
317
 
318
- if valid_predictions_count == 0:
319
- st.warning("No valid molecules found for top predictions.")
320
- return
321
 
322
- # Display results table
323
- df_results = pd.DataFrame(results_data)
324
- st.subheader("Top Predictions & Scores")
325
- st.dataframe(df_results, use_container_width=True)
326
 
327
- # Display molecule visualizations
328
- st.subheader("Predicted Molecule Visualizations")
329
-
330
- for i, (mol, smiles, score) in enumerate(valid_predictions):
331
- st.write(f"**Prediction {i+1}:** {smiles} (Score: {score:.4f})")
332
-
333
- col1, col2 = st.columns(2)
334
-
335
- with col1:
336
- st.write("**2D Structure:**")
337
- atom_matches_indices = []
338
- if substructure_smarts_highlight:
339
- matches = find_matches_one(mol, substructure_smarts_highlight)
340
- if matches:
341
- atom_matches_indices = list(matches[0])
342
-
343
- img_2d = get_image_with_highlight(mol, atomset=atom_matches_indices)
344
- if img_2d:
345
- st.image(img_2d, use_column_width=True)
346
-
347
- with col2:
348
- st.write("**3D Structure:**")
349
- mol_3d = generate_3d_structure(mol)
350
- if mol_3d:
351
- html_3d = mol_to_3d_html(mol_3d)
352
- if html_3d:
353
- components.html(html_3d, height=450)
354
- else:
355
- st.write("3D structure generation failed for this molecule.")
356
-
357
- st.divider()
358
-
359
- # Clear cache after inference
360
- clear_gpu_cache()
361
- st.success("Prediction successful!")
362
-
363
- def display_molecule_image(smiles_string):
364
- """
365
- Displays both 2D and 3D images of a molecule from its SMILES string.
366
- """
367
- if not smiles_string:
368
- st.error("Please enter a SMILES string.")
369
- return
370
-
371
- mol = get_mol(smiles_string)
372
- if mol is None:
373
- st.error("Invalid SMILES string.")
374
- return
375
-
376
- st.success("Molecule displayed successfully!")
377
-
378
- col1, col2 = st.columns(2)
379
-
380
- with col1:
381
- st.subheader("2D Structure")
382
- img_2d = MolToImage(mol, size=(400, 400), fitImage=True)
383
- st.image(img_2d, use_column_width=True)
384
-
385
- with col2:
386
- st.subheader("3D Structure")
387
- mol_3d = generate_3d_structure(mol)
388
- if mol_3d:
389
- html_3d = mol_to_3d_html(mol_3d)
390
- if html_3d:
391
- components.html(html_3d, height=450)
392
- else:
393
- st.write("3D structure generation failed for this molecule.")
394
-
395
- # --- Main Streamlit App ---
396
- def main():
397
- # Only run if in Streamlit context
398
- if not is_streamlit_context():
399
- return
400
-
401
- # Initialize session state
402
- if 'initialized' not in st.session_state:
403
- st.session_state.initialized = True
404
-
405
- st.title("🧪 ChemBERTa SMILES Utilities Dashboard")
406
-
407
- # Sidebar for navigation
408
- st.sidebar.title("Navigation")
409
- tab_selection = st.sidebar.selectbox(
410
- "Choose a tool:",
411
- ["Masked SMILES Prediction", "Molecule Viewer"]
412
- )
413
-
414
- if tab_selection == "Masked SMILES Prediction":
415
- st.header("Masked SMILES Prediction")
416
- st.markdown("Enter a SMILES string with a `<mask>` token (e.g., `C1=CC=CC<mask>C1`) to predict possible completions.")
417
-
418
  col1, col2 = st.columns(2)
419
  with col1:
420
  smiles_input_masked = st.text_input(
421
- "SMILES String with Mask",
422
- value="C1=CC=CC<mask>C1"
 
423
  )
424
  with col2:
425
  substructure_input = st.text_input(
426
- "Substructure to Highlight (SMARTS)",
427
- value="C=C"
 
428
  )
429
 
430
- if st.button("Predict and Visualize", type="primary"):
431
- with st.spinner("Predicting masked SMILES..."):
432
- predict_and_visualize_masked_smiles(smiles_input_masked, substructure_input)
433
-
434
- elif tab_selection == "Molecule Viewer":
435
- st.header("Molecule Viewer")
436
- st.markdown("Enter a SMILES string to display its 2D and 3D structure.")
437
-
438
- smiles_input_viewer = st.text_input(
439
- "SMILES String",
440
- value="C1=CC=CC=C1"
441
- )
 
 
442
 
443
- if st.button("View Molecule", type="primary"):
444
- with st.spinner("Generating molecule structures..."):
445
- display_molecule_image(smiles_input_viewer)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
446
 
447
- if __name__ == "__main__":
448
- main()
 
1
+ # app.py
2
+ # To run this app, save the code as app.py and run:
3
+ # streamlit run app.py
4
+ #
5
+ # You also need to install the following libraries:
6
+ # pip install streamlit torch transformers bitsandbytes rdkit-pypi py3Dmol pandas
7
 
8
  import streamlit as st
9
+ import streamlit.components.v1 as components
10
  import torch
11
  from transformers import AutoModelForMaskedLM, AutoTokenizer, pipeline, BitsAndBytesConfig
12
  from rdkit import Chem
13
+ from rdkit.Chem import Draw, AllChem
14
  from rdkit.Chem.Draw import MolToImage
15
  import pandas as pd
 
 
16
  import logging
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
  # Set up logging to monitor quantization effects
19
  logging.basicConfig(level=logging.INFO)
20
  logger = logging.getLogger(__name__)
21
 
22
+ # --- Page Configuration ---
23
+ st.set_page_config(
24
+ page_title="ChemBERTa SMILES Utilities",
25
+ page_icon="🔬",
26
+ layout="wide",
27
+ )
28
+
29
+ # --- Model Loading (Cached for Performance) ---
30
+
31
+ @st.cache_resource(show_spinner="Loading ChemBERTa model...")
32
+ def load_models():
33
  """
34
+ Load the tokenizer and model, wrapped in a Streamlit cache resource decorator
35
+ to ensure it only runs once per session.
36
  """
37
+ device = "cuda" if torch.cuda.is_available() else "cpu"
38
+ torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
39
+ quantization_config = None
40
+
41
  try:
 
 
 
 
 
 
42
  quantization_config = BitsAndBytesConfig(
43
  load_in_8bit=True,
44
  bnb_8bit_compute_dtype=torch.float16,
45
+ bnb_8bit_use_double_quant=True,
46
  )
47
+ logger.info("8-bit quantization configuration created.")
 
48
  except ImportError:
49
+ logger.warning("bitsandbytes not available, falling back to standard loading.")
 
50
  except Exception as e:
51
+ logger.warning(f"Quantization setup failed: {e}, using standard loading.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
 
 
53
  model_name = "seyonec/PubChem10M_SMILES_BPE_450k"
54
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
55
 
56
+ model_kwargs = {"torch_dtype": torch_dtype}
57
+ if quantization_config and torch.cuda.is_available():
58
+ model_kwargs["quantization_config"] = quantization_config
59
+ model_kwargs["device_map"] = "auto"
60
+ elif torch.cuda.is_available():
61
+ model_kwargs["device_map"] = "auto"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
 
 
 
 
 
 
 
 
 
 
 
 
63
  try:
64
+ model = AutoModelForMaskedLM.from_pretrained(model_name, **model_kwargs)
65
+ model.eval()
66
+ pipeline_device = model.device.index if hasattr(model.device, 'type') and model.device.type == "cuda" else -1
67
+ fill_mask_pipeline = pipeline('fill-mask', model=model, tokenizer=tokenizer, device=pipeline_device)
68
+ logger.info("Models loaded successfully with optimizations.")
69
+ return tokenizer, fill_mask_pipeline
70
+ except Exception as e:
71
+ logger.error(f"Error loading optimized models: {e}. Retrying with standard loading.")
72
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
73
+ model = AutoModelForMaskedLM.from_pretrained(model_name)
74
  device_idx = 0 if torch.cuda.is_available() else -1
 
75
  if torch.cuda.is_available():
76
+ model.to("cuda")
77
+ fill_mask_pipeline = pipeline('fill-mask', model=model, tokenizer=tokenizer, device=device_idx)
78
+ return tokenizer, fill_mask_pipeline
79
+
80
+ # Load the models once
81
+ fill_mask_tokenizer, fill_mask_pipeline = load_models()
82
+
83
+
84
+ # --- Molecule & Visualization Helpers ---
85
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
  def get_mol(smiles):
87
  """Converts SMILES to RDKit Mol object and Kekulizes it."""
88
  mol = Chem.MolFromSmiles(smiles)
89
+ if mol:
90
+ try:
91
+ Chem.Kekulize(mol)
92
+ except Exception:
93
+ pass
 
94
  return mol
95
 
96
  def find_matches_one(mol, submol_smarts):
97
+ """Finds all matching atoms for a SMARTS pattern."""
98
+ if not mol or not submol_smarts: return []
 
99
  submol = Chem.MolFromSmarts(submol_smarts)
100
+ return mol.GetSubstructMatches(submol) if submol else []
 
 
 
101
 
102
  def get_image_with_highlight(mol, atomset=None, size=(300, 300)):
103
+ """Draws a 2D molecule image with optional atom highlighting."""
104
+ if mol is None: return None
105
+ valid_atomset = [int(a) for a in atomset if str(a).isdigit()] if atomset else []
106
+ return MolToImage(mol, size=size, fitImage=True,
107
+ highlightAtoms=valid_atomset,
108
+ highlightAtomColors={i: (0, 1, 0, 0.5) for i in valid_atomset})
109
+
110
+ def generate_3d_view_html(smiles):
111
+ """Generates an interactive 3D molecule view using py3Dmol."""
112
+ if not smiles: return None
113
+ mol = get_mol(smiles)
114
+ if not mol: return "<p>Invalid SMILES for 3D view.</p>"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
115
  try:
116
+ mol_3d = Chem.AddHs(mol)
117
+ AllChem.EmbedMolecule(mol_3d, randomSeed=42, useRandomCoords=True)
118
+ AllChem.MMFFOptimizeMolecule(mol_3d)
119
+ sdf_data = Chem.MolToMolBlock(mol_3d)
120
+
121
+ viewer = py3Dmol.view(width=350, height=350)
122
+ viewer.setBackgroundColor('#FFFFFF')
123
+ viewer.addModel(sdf_data, "sdf")
124
+ viewer.setStyle({'stick': {}, 'sphere': {'scale': 0.25}})
125
+ viewer.zoomTo()
126
+ return viewer._make_html()
127
+ except Exception as e:
128
+ logger.error(f"Failed to generate 3D view for {smiles}: {e}")
129
+ return f"<p>Error generating 3D view: {e}</p>"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
130
 
131
+ # --- Core Application Logic ---
132
 
133
+ def run_masked_smiles_prediction(smiles_mask, substructure_smarts_highlight):
134
  """
135
+ Handles the logic for the masked SMILES prediction tab.
136
  """
 
 
 
 
 
 
 
 
 
 
 
137
  if fill_mask_tokenizer.mask_token not in smiles_mask:
138
+ st.error(f"Error: Input SMILES must contain a mask token (e.g., {fill_mask_tokenizer.mask_token}).")
139
  return
140
 
141
+ with st.spinner("Predicting completions..."):
142
+ try:
143
+ with torch.no_grad():
144
+ predictions = fill_mask_pipeline(smiles_mask, top_k=10)
145
+ except Exception as e:
146
+ st.error(f"An error occurred during prediction: {e}")
147
+ if torch.cuda.is_available(): torch.cuda.empty_cache()
148
+ return
149
 
150
+ results = []
151
+ for pred in predictions:
152
+ if len(results) >= 5: break
153
+ predicted_smiles = pred['sequence']
154
+ mol = get_mol(predicted_smiles)
155
+ if mol:
156
+ atom_matches = find_matches_one(mol, substructure_smarts_highlight)
157
+ results.append({
158
+ "smiles": predicted_smiles,
159
+ "score": f"{pred['score']:.4f}",
160
+ "image_2d": get_image_with_highlight(mol, atomset=atom_matches[0] if atom_matches else []),
161
+ "html_3d": generate_3d_view_html(predicted_smiles)
162
+ })
163
+
164
+ if torch.cuda.is_available(): torch.cuda.empty_cache()
165
+ st.session_state.prediction_results = results
166
 
 
 
 
167
 
168
+ # --- Streamlit UI Definition ---
 
169
 
170
+ st.title("🔬 ChemBERTa SMILES Utilities Dashboard (2D & 3D)")
171
+ st.markdown("A tool to predict masked tokens in SMILES strings and visualize molecules, powered by ChemBERTa and Streamlit.")
 
 
 
172
 
173
+ tab1, tab2 = st.tabs(["Masked SMILES Prediction", "Molecule Viewer (2D & 3D)"])
 
 
174
 
175
+ # --- Tab 1: Masked SMILES Prediction ---
176
+ with tab1:
177
+ st.header("Predict and Visualize Masked SMILES")
178
+ st.markdown("Enter a SMILES string with a `<mask>` token to predict possible completions.")
179
 
180
+ with st.form(key="prediction_form"):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
181
  col1, col2 = st.columns(2)
182
  with col1:
183
  smiles_input_masked = st.text_input(
184
+ "SMILES String with Mask",
185
+ value="C1=CC=CC<mask>C1",
186
+ help=f"The mask token is `{fill_mask_tokenizer.mask_token}`"
187
  )
188
  with col2:
189
  substructure_input = st.text_input(
190
+ "Substructure to Highlight (SMARTS)",
191
+ value="C=C",
192
+ help="Enter a SMARTS pattern to highlight in the 2D images."
193
  )
194
 
195
+ predict_button = st.form_submit_button("Predict and Visualize", use_container_width=True)
196
+
197
+ if predict_button:
198
+ run_masked_smiles_prediction(smiles_input_masked, substructure_input)
199
+
200
+ if 'prediction_results' in st.session_state and st.session_state.prediction_results:
201
+ results = st.session_state.prediction_results
202
+ st.subheader("Top 5 Valid Predictions")
203
+
204
+ # Display results in a table
205
+ df_data = [{"Predicted SMILES": r["smiles"], "Score": r["score"]} for r in results]
206
+ st.dataframe(pd.DataFrame(df_data), use_container_width=True)
207
+
208
+ st.markdown("---")
209
 
210
+ # Display molecule visualizations
211
+ for i, res in enumerate(results):
212
+ st.markdown(f"**Prediction {i+1}:** `{res['smiles']}` (Score: {res['score']})")
213
+ col1, col2 = st.columns(2)
214
+ with col1:
215
+ st.subheader("2D Structure")
216
+ if res["image_2d"]:
217
+ st.image(res["image_2d"], use_column_width=True)
218
+ else:
219
+ st.warning("Could not generate 2D image.")
220
+ with col2:
221
+ st.subheader("3D Interactive Structure")
222
+ if res["html_3d"]:
223
+ components.html(res["html_3d"], height=370)
224
+ else:
225
+ st.warning("Could not generate 3D view.")
226
+ st.markdown("---")
227
+
228
+ # --- Tab 2: Molecule Viewer ---
229
+ with tab2:
230
+ st.header("Visualize a Molecule from SMILES")
231
+ st.markdown("Enter a single SMILES string to display its 2D and 3D structures side-by-side.")
232
+
233
+ with st.form(key="viewer_form"):
234
+ smiles_input_viewer = st.text_input("SMILES String", value="CC(=O)Oc1ccccc1C(=O)O") # Aspirin
235
+ view_button = st.form_submit_button("View Molecule", use_container_width=True)
236
+
237
+ if view_button and smiles_input_viewer:
238
+ with st.spinner("Generating views..."):
239
+ mol = get_mol(smiles_input_viewer)
240
+ if not mol:
241
+ st.error("Invalid SMILES string provided.")
242
+ else:
243
+ st.subheader(f"Visualizations for: `{smiles_input_viewer}`")
244
+ col1, col2 = st.columns(2)
245
+ with col1:
246
+ st.subheader("2D Structure")
247
+ img_2d = MolToImage(mol, size=(450, 450), fitImage=True)
248
+ st.image(img_2d, use_column_width=True)
249
+ with col2:
250
+ st.subheader("3D Interactive Structure")
251
+ html_3d = generate_3d_view_html(smiles_input_viewer)
252
+ components.html(html_3d, height=470)
253