alidenewade commited on
Commit
e56ed1f
·
verified ·
1 Parent(s): 3923798

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +267 -242
app.py CHANGED
@@ -1,87 +1,20 @@
 
1
  import streamlit as st
2
- import pandas as pd
 
3
  from rdkit import Chem
4
  from rdkit.Chem import Draw, AllChem
5
- from rdkit.Chem.Draw import rdMolDraw2D
6
- import py3Dmol
7
  import io
8
  import base64
9
  import logging
10
-
11
- import torch
12
- from transformers import AutoModelForMaskedLM, AutoTokenizer, pipeline, BitsAndBytesConfig
13
 
14
  # Set up logging to monitor quantization effects
15
  logging.basicConfig(level=logging.INFO)
16
  logger = logging.getLogger(__name__)
17
 
18
- # --- Page Configuration ---
19
- st.set_page_config(
20
- page_title="Molecule Explorer & Predictor",
21
- page_icon="🔬",
22
- layout="wide",
23
- initial_sidebar_state="collapsed",
24
- )
25
-
26
- # Custom CSS for a professional, minimalist look (adapted from drug_app.txt)
27
- def apply_custom_styling():
28
- st.markdown(
29
- """
30
- <style>
31
- @import url('https://fonts.googleapis.com/css2?family=Roboto:wght@400;700&display=swap');
32
-
33
- html, body, [class*="st-"] {
34
- font-family: 'Roboto', sans-serif;
35
- }
36
-
37
- .stApp {
38
- background-color: rgb(28, 28, 28);
39
- color: white;
40
- }
41
-
42
- /* Tab styles */
43
- .stTabs [data-baseweb="tab-list"] {
44
- gap: 24px;
45
- }
46
-
47
- .stTabs [data-baseweb="tab"] {
48
- height: 50px;
49
- white-space: pre-wrap;
50
- background: none;
51
- border-radius: 0px;
52
- border-bottom: 2px solid #333;
53
- padding: 10px 4px;
54
- color: #AAA;
55
- }
56
-
57
- .stTabs [data-baseweb="tab"]:hover {
58
- background: #222;
59
- color: #FFF;
60
- }
61
-
62
- .stTabs [aria-selected="true"] {
63
- border-bottom: 2px solid #00A0FF; /* Highlight color for active tab */
64
- color: #FFF;
65
- }
66
-
67
- /* Button styles */
68
- .stButton>button {
69
- border-color: #00A0FF;
70
- color: #00A0FF;
71
- }
72
-
73
- .stButton>button:hover {
74
- border-color: #FFF;
75
- color: #FFF;
76
- background-color: #00A0FF;
77
- }
78
- </style>
79
- """,
80
- unsafe_allow_html=True
81
- )
82
-
83
- apply_custom_styling()
84
-
85
  # --- Quantization Configuration ---
86
  def get_quantization_config():
87
  """
@@ -111,10 +44,11 @@ def get_torch_dtype():
111
  else:
112
  return torch.float32 # Keep full precision on CPU
113
 
114
- # --- Optimized Model Loading with Streamlit Caching ---
115
- @st.cache_resource(show_spinner="Loading molecular language model...")
116
  def load_optimized_models():
117
- """Load models with quantization and other optimizations using Streamlit caching."""
 
118
  device = "cuda" if torch.cuda.is_available() else "cpu"
119
  torch_dtype = get_torch_dtype()
120
  quantization_config = get_quantization_config()
@@ -124,7 +58,7 @@ def load_optimized_models():
124
  # Model names
125
  model_name = "seyonec/PubChem10M_SMILES_BPE_450k"
126
 
127
- # Load tokenizer
128
  fill_mask_tokenizer = AutoTokenizer.from_pretrained(model_name)
129
 
130
  # Load model with quantization if available
@@ -134,6 +68,7 @@ def load_optimized_models():
134
 
135
  if quantization_config is not None and torch.cuda.is_available(): # Quantization typically for GPU
136
  model_kwargs["quantization_config"] = quantization_config
 
137
  model_kwargs["device_map"] = "auto"
138
  elif torch.cuda.is_available():
139
  model_kwargs["device_map"] = "auto" # For non-quantized GPU loading
@@ -141,216 +76,306 @@ def load_optimized_models():
141
  model_kwargs["device_map"] = None # For CPU
142
 
143
  try:
 
144
  fill_mask_model = AutoModelForMaskedLM.from_pretrained(
145
  model_name,
146
  **model_kwargs
147
  )
 
 
148
  fill_mask_model.eval()
149
 
 
 
150
  pipeline_device = fill_mask_model.device.index if hasattr(fill_mask_model.device, 'type') and fill_mask_model.device.type == "cuda" else -1
151
 
152
  fill_mask_pipeline = pipeline(
153
  'fill-mask',
154
  model=fill_mask_model,
155
  tokenizer=fill_mask_tokenizer,
156
- device=pipeline_device,
 
157
  )
 
158
  logger.info("Models loaded successfully with optimizations")
159
  return fill_mask_tokenizer, fill_mask_model, fill_mask_pipeline
 
160
  except Exception as e:
161
  logger.error(f"Error loading optimized models: {e}")
 
162
  logger.info("Falling back to standard model loading...")
163
  return load_standard_models(model_name)
164
 
165
- @st.cache_resource(show_spinner="Loading standard molecular language model...")
166
- def load_standard_models(model_name="seyonec/PubChem10M_SMILES_BPE_450k"):
167
- """Fallback standard model loading without quantization using Streamlit caching."""
168
  fill_mask_tokenizer = AutoTokenizer.from_pretrained(model_name)
169
  fill_mask_model = AutoModelForMaskedLM.from_pretrained(model_name)
 
170
  device_idx = 0 if torch.cuda.is_available() else -1
171
  fill_mask_pipeline = pipeline('fill-mask', model=fill_mask_model, tokenizer=fill_mask_tokenizer, device=device_idx)
 
172
  if torch.cuda.is_available():
173
  fill_mask_model.to("cuda")
174
- logger.info("Standard models loaded successfully")
175
  return fill_mask_tokenizer, fill_mask_model, fill_mask_pipeline
176
 
177
- # --- RDKit and Py3Dmol Visualization Functions ---
 
178
 
179
- def mol_to_svg(mol, size=(400, 300)):
180
- """Converts an RDKit molecule object to an SVG image string using default RDKit colors."""
181
- if not mol:
182
- return None
183
- drawer = rdMolDraw2D.MolDraw2DSVG(*size)
184
- # Removing custom color settings as per user request to use default RDKit colors
185
- # drawer.drawOptions().clearBackground = False # Keep background transparent/dark
186
- # drawer.drawOptions().addStereoAnnotation = True
187
- # drawer.drawOptions().baseFontSize = 0.8
188
-
189
- # # Set dark theme colors for RDKit drawing - REMOVED AS PER USER REQUEST
190
- # atom_colors = {
191
- # 6: (0.8, 0.8, 0.8), # Carbon (light gray)
192
- # 7: (0.2, 0.5, 1.0), # Nitrogen (blue)
193
- # 8: (1.0, 0.2, 0.2), # Oxygen (red)
194
- # 9: (0.2, 0.8, 0.2), # Fluorine (green)
195
- # 15: (1.0, 0.5, 0.0), # Phosphorus (orange)
196
- # 16: (1.0, 0.8, 0.0), # Sulfur (yellow)
197
- # 17: (0.2, 0.7, 0.2), # Chlorine (dark green)
198
- # 35: (0.5, 0.2, 0.8), # Bromine (purple)
199
- # 53: (0.8, 0.2, 0.5), # Iodine (pink/magenta)
200
- # }
201
- # # Set default atom color
202
- # drawer.drawOptions().setAtomColor(Chem.rdatomicnumlist.Get): (0.8, 0.8, 0.8) # Default to light gray for unknown atoms
203
- # for atom_num, color in atom_colors.items():
204
- # drawer.drawOptions().setAtomColor(atom_num, color)
205
-
206
- # drawer.drawOptions().bondColor = (0.7, 0.7, 0.7) # Bond color (medium gray)
207
- # drawer.drawOptions().highlightColour = (0.2, 0.6, 1.0) # Highlight color (blue)
208
-
209
- drawer.DrawMolecule(mol)
210
- drawer.FinishDrawing()
211
- svg = drawer.GetDrawingText()
212
- return svg
213
-
214
- def mol_to_sdf(mol):
215
- """Converts an RDKit molecule object to an SDF string."""
216
- if not mol:
217
  return None
218
- # Add hydrogens to the molecule
219
- mol_with_h = Chem.AddHs(mol)
220
-
221
- # Generate 3D coordinates using ETKDGv3, a common conformer generation method
222
- # MaxAttempts is increased for robustness, randomSeed for reproducibility
223
  try:
224
- AllChem.EmbedMolecule(mol_with_h, AllChem.ETKDGv3(), maxAttempts=50, randomSeed=42)
225
- # Optimize 3D coordinates using Universal Force Field (UFF)
226
- AllChem.UFFOptimizeMolecule(mol_with_h)
227
- sdf_string = Chem.MolToMolBlock(mol_with_h)
228
- return sdf_string
229
- except Exception as e:
230
- logger.error(f"Error generating 3D coordinates for SMILES: {Chem.MolToSmiles(mol)} - {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
231
  return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
232
 
233
- def visualize_molecule_3d(mol_sdf: str, width='100%', height=400):
234
  """
235
- Generates an interactive 3D molecule visualization using py3Dmol.
236
- Accepts an SDF string.
237
  """
238
- if not mol_sdf:
239
- return None
 
 
240
  try:
241
- viewer = py3Dmol.view(width=width, height=height)
242
- viewer.setBackgroundColor('#1C1C1C') # Dark background
243
- viewer.addModel(mol_sdf, "sdf")
244
- viewer.setStyle({'stick':{}, 'sphere':{'radius':0.3}}) # Stick and Sphere representation
245
- viewer.zoomTo()
246
- html_view = viewer._make_html()
247
- return html_view
248
  except Exception as e:
249
- st.error(f"Error generating 3D visualization: {e}")
250
- return None
 
251
 
252
- # --- Main Streamlit Application Layout ---
 
 
 
253
 
254
- st.title("🔬 Molecule Explorer & Predictor")
 
 
255
 
256
- # Initialize session state for consistent data across reruns
257
- if 'tokenizer' not in st.session_state:
258
- st.session_state.tokenizer, st.session_state.model, st.session_state.pipeline = load_optimized_models()
259
 
260
- tokenizer = st.session_state.tokenizer
261
- model = st.session_state.model
262
- fill_mask_pipeline = st.session_state.pipeline
263
 
264
- tab1, tab2 = st.tabs(["Molecule Viewer (2D & 3D)", "Masked SMILES Predictor"])
 
 
 
 
265
 
266
- with tab1:
267
- st.header("Visualize Molecules in 2D and 3D")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
268
 
269
- smiles_input = st.text_input("Enter SMILES string:", "CCO", help="e.g., CCO (ethanol), C1=CC=CC=C1 (benzene)")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
270
 
271
- if st.button("View Molecule"):
272
- if smiles_input:
273
- mol = Chem.MolFromSmiles(smiles_input)
274
- if mol:
275
- st.subheader("2D Structure")
276
- svg = mol_to_svg(mol)
277
- if svg:
278
- st.image(svg, use_column_width=True)
279
- else:
280
- st.warning("Could not generate 2D image.")
281
-
282
- st.subheader("3D Structure (Interactive)")
283
- sdf_string = mol_to_sdf(mol)
284
- if sdf_string:
285
- html_3d = visualize_molecule_3d(sdf_string)
286
- if html_3d:
287
- st.components.v1.html(html_3d, width=700, height=500, scrolling=False)
 
 
 
 
 
 
 
 
 
288
  else:
289
- st.warning("Could not generate 3D visualization.")
290
- else:
291
- st.warning("Could not generate 3D SDF data.")
292
- else:
293
- st.error("Invalid SMILES string. Please enter a valid chemical structure.")
294
- else:
295
- st.info("Please enter a SMILES string to view the molecule.")
296
 
297
  with tab2:
298
- st.header("Masked SMILES Prediction")
299
-
300
- masked_smiles_input = st.text_input(
301
- "Enter masked SMILES string (use `<mask>` for the masked token):",
302
- "C1=CC=CC<mask>C1",
303
- help="Example: 'C1=CC=CC<mask>C1' (masked benzene), 'CCO<mask>C' (masked ether)"
304
- )
305
- top_k_predictions = st.slider("Number of predictions to show:", 1, 10, 5)
306
-
307
- if st.button("Predict Masked Token"):
308
- if masked_smiles_input and "<mask>" in masked_smiles_input:
309
- try:
310
- # Perform prediction using the loaded pipeline
311
- predictions = fill_mask_pipeline(masked_smiles_input, top_k=top_k_predictions)
312
-
313
- prediction_data = []
314
- for pred in predictions:
315
- token_str = pred['token_str']
316
- sequence = pred['sequence']
317
- score = pred['score']
318
-
319
- mol = Chem.MolFromSmiles(sequence)
320
- img_svg = None
321
- if mol:
322
- img_svg = mol_to_svg(mol, size=(200,150)) # Smaller image for table
323
-
324
- prediction_data.append({
325
- "Predicted Token": token_str,
326
- "Full SMILES": sequence,
327
- "Confidence Score": f"{score:.4f}",
328
- "Structure SVG": img_svg # Store SVG string
329
- })
330
-
331
- df_predictions = pd.DataFrame(prediction_data)
332
-
333
- st.subheader("Predictions:")
334
-
335
- # Create a version of the dataframe without the SVG for initial display
336
- display_df = df_predictions.drop(columns=["Structure SVG"])
337
- st.dataframe(display_df, use_container_width=True, hide_index=True)
338
-
339
- st.subheader("Predicted Structures:")
340
- # Determine the number of columns based on the number of predictions, up to a max
341
- num_cols = min(len(df_predictions), 5) # Display up to 5 images per row
342
- cols = st.columns(num_cols)
343
-
344
- for i, row in df_predictions.iterrows():
345
- with cols[i % num_cols]: # Distribute images into columns
346
- st.markdown(f"**{row['Predicted Token']}** (Score: {row['Confidence Score']})")
347
- if row['Structure SVG']:
348
- st.image(row['Structure SVG'], use_column_width='auto')
349
- else:
350
- st.write("*(Invalid SMILES)*")
351
-
352
- except Exception as e:
353
- st.error(f"An error occurred during prediction: {e}")
354
- st.info("Please ensure your masked SMILES is valid and contains `<mask>`.")
355
- else:
356
- st.info("Please enter a masked SMILES string (e.g., `C1=CC=CC<mask>C1`).")
 
1
+ # app.py
2
  import streamlit as st
3
+ import torch
4
+ from transformers import AutoModelForMaskedLM, AutoTokenizer, pipeline, BitsAndBytesConfig
5
  from rdkit import Chem
6
  from rdkit.Chem import Draw, AllChem
7
+ from rdkit.Chem.Draw import MolToImage
8
+ import pandas as pd
9
  import io
10
  import base64
11
  import logging
12
+ import py3Dmol
 
 
13
 
14
  # Set up logging to monitor quantization effects
15
  logging.basicConfig(level=logging.INFO)
16
  logger = logging.getLogger(__name__)
17
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  # --- Quantization Configuration ---
19
  def get_quantization_config():
20
  """
 
44
  else:
45
  return torch.float32 # Keep full precision on CPU
46
 
47
+ # --- Optimized Model Loading ---
48
+ @st.cache_resource
49
  def load_optimized_models():
50
+ """Load models with quantization and other optimizations.
51
+ Uses st.cache_resource to avoid reloading models on every rerun."""
52
  device = "cuda" if torch.cuda.is_available() else "cpu"
53
  torch_dtype = get_torch_dtype()
54
  quantization_config = get_quantization_config()
 
58
  # Model names
59
  model_name = "seyonec/PubChem10M_SMILES_BPE_450k"
60
 
61
+ # Load tokenizer (doesn't need quantization)
62
  fill_mask_tokenizer = AutoTokenizer.from_pretrained(model_name)
63
 
64
  # Load model with quantization if available
 
68
 
69
  if quantization_config is not None and torch.cuda.is_available(): # Quantization typically for GPU
70
  model_kwargs["quantization_config"] = quantization_config
71
+ # device_map="auto" is often used with bitsandbytes for automatic distribution
72
  model_kwargs["device_map"] = "auto"
73
  elif torch.cuda.is_available():
74
  model_kwargs["device_map"] = "auto" # For non-quantized GPU loading
 
76
  model_kwargs["device_map"] = None # For CPU
77
 
78
  try:
79
+ # Masked LM Model
80
  fill_mask_model = AutoModelForMaskedLM.from_pretrained(
81
  model_name,
82
  **model_kwargs
83
  )
84
+
85
+ # Set model to evaluation mode for inference
86
  fill_mask_model.eval()
87
 
88
+ # Create optimized pipeline
89
+ # Let pipeline infer device from model if possible, or set based on model's device
90
  pipeline_device = fill_mask_model.device.index if hasattr(fill_mask_model.device, 'type') and fill_mask_model.device.type == "cuda" else -1
91
 
92
  fill_mask_pipeline = pipeline(
93
  'fill-mask',
94
  model=fill_mask_model,
95
  tokenizer=fill_mask_tokenizer,
96
+ device=pipeline_device, # Use model's device
97
+ # torch_dtype=torch_dtype # Pipeline might infer this or it might conflict
98
  )
99
+
100
  logger.info("Models loaded successfully with optimizations")
101
  return fill_mask_tokenizer, fill_mask_model, fill_mask_pipeline
102
+
103
  except Exception as e:
104
  logger.error(f"Error loading optimized models: {e}")
105
+ # Fallback to standard loading
106
  logger.info("Falling back to standard model loading...")
107
  return load_standard_models(model_name)
108
 
109
+ def load_standard_models(model_name):
110
+ """Fallback standard model loading without quantization."""
 
111
  fill_mask_tokenizer = AutoTokenizer.from_pretrained(model_name)
112
  fill_mask_model = AutoModelForMaskedLM.from_pretrained(model_name)
113
+ # Determine device for standard loading
114
  device_idx = 0 if torch.cuda.is_available() else -1
115
  fill_mask_pipeline = pipeline('fill-mask', model=fill_mask_model, tokenizer=fill_mask_tokenizer, device=device_idx)
116
+
117
  if torch.cuda.is_available():
118
  fill_mask_model.to("cuda")
119
+
120
  return fill_mask_tokenizer, fill_mask_model, fill_mask_pipeline
121
 
122
+ # Load models with optimizations
123
+ fill_mask_tokenizer, fill_mask_model, fill_mask_pipeline = load_optimized_models()
124
 
125
+ # --- Memory Management Utilities ---
126
+ def clear_gpu_cache():
127
+ """Clear CUDA cache to free up memory."""
128
+ if torch.cuda.is_available():
129
+ torch.cuda.empty_cache()
130
+
131
+ # --- Helper Functions from Notebook (adapted) ---
132
+ def get_mol(smiles):
133
+ """Converts SMILES to RDKit Mol object and Kekulizes it."""
134
+ mol = Chem.MolFromSmiles(smiles)
135
+ if mol is None:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
136
  return None
 
 
 
 
 
137
  try:
138
+ Chem.Kekulize(mol)
139
+ except: # Kekulization can fail for some structures
140
+ pass
141
+ return mol
142
+
143
+ def find_matches_one(mol, submol_smarts):
144
+ """Finds all matching atoms for a SMARTS pattern in a molecule."""
145
+ if not mol or not submol_smarts:
146
+ return []
147
+ submol = Chem.MolFromSmarts(submol_smarts)
148
+ if not submol:
149
+ return []
150
+ matches = mol.GetSubstructMatches(submol)
151
+ return matches
152
+
153
+ def get_image_with_highlight(mol, atomset=None, size=(300, 300)):
154
+ """Draws molecule with optional atom highlighting."""
155
+ if mol is None:
156
+ return None
157
+ highlight_color = (0, 1, 0, 0.5) # Green with some transparency
158
+
159
+ # Ensure atomset contains integers if not None or empty
160
+ valid_atomset = []
161
+ if atomset:
162
+ try:
163
+ valid_atomset = [int(a) for a in atomset]
164
+ except ValueError:
165
+ logger.warning(f"Invalid atom in atomset: {atomset}. Proceeding without highlighting problematic atoms.")
166
+ valid_atomset = [int(a) for a in atomset if str(a).isdigit()] # Filter out non-integers
167
+
168
+ img = MolToImage(mol, size=size, fitImage=True,
169
+ highlightAtoms=valid_atomset if valid_atomset else [],
170
+ highlightAtomColors={i: highlight_color for i in valid_atomset} if valid_atomset else {})
171
+ return img
172
+
173
+ def mol_to_sdf_string(mol):
174
+ """Converts an RDKit Mol object to an SDF string."""
175
+ if mol is None:
176
  return None
177
+ # Add 3D coordinates if not present
178
+ AllChem.EmbedMolecule(mol, AllChem.ETKDG())
179
+ AllChem.UFFOptimizeMolecule(mol)
180
+ return Chem.MolToMolBlock(mol)
181
+
182
+ def render_mol_3d(sdf_string, width=300, height=300):
183
+ """Renders a 3D molecule using py3Dmol."""
184
+ if sdf_string is None:
185
+ return ""
186
+
187
+ viewer = py3Dmol.view(width=width, height=height)
188
+ viewer.addModel(sdf_string, 'sdf')
189
+ viewer.setStyle({'stick':{}}) # Display as sticks
190
+ viewer.zoomTo()
191
+ # Embed the viewer HTML into Streamlit
192
+ return viewer.to_html()
193
+
194
+ # --- Streamlit Interface Functions ---
195
 
196
+ def predict_and_visualize_masked_smiles(smiles_mask, substructure_smarts_highlight="CC=CC"):
197
  """
198
+ Predicts masked tokens in a SMILES string, shows scores, and visualizes molecules.
199
+ Returns 5 image paths and a status message.
200
  """
201
+ if fill_mask_tokenizer.mask_token not in smiles_mask:
202
+ st.error("Error: Input SMILES must contain a mask token (e.g., <mask>).")
203
+ return pd.DataFrame(), [None]*5, [None]*5, "Error: Input SMILES must contain a mask token (e.g., <mask>)."
204
+
205
  try:
206
+ with torch.no_grad():
207
+ predictions = fill_mask_pipeline(smiles_mask, top_k=10)
 
 
 
 
 
208
  except Exception as e:
209
+ clear_gpu_cache()
210
+ st.error(f"Error during prediction: {str(e)}")
211
+ return pd.DataFrame(), [None]*5, [None]*5, f"Error during prediction: {str(e)}"
212
 
213
+ results_data = []
214
+ image_2d_list = []
215
+ image_3d_list = []
216
+ valid_predictions_count = 0
217
 
218
+ for pred in predictions:
219
+ if valid_predictions_count >= 5:
220
+ break
221
 
222
+ predicted_smiles = pred['sequence']
223
+ score = pred['score']
 
224
 
225
+ mol = get_mol(predicted_smiles)
226
+ if mol:
227
+ results_data.append({"Predicted SMILES": predicted_smiles, "Score": f"{score:.4f}"})
228
 
229
+ atom_matches_indices = []
230
+ if substructure_smarts_highlight:
231
+ matches = find_matches_one(mol, substructure_smarts_highlight)
232
+ if matches:
233
+ atom_matches_indices = list(matches[0]) # Highlight first match
234
 
235
+ img_2d = get_image_with_highlight(mol, atomset=atom_matches_indices)
236
+ image_2d_list.append(img_2d)
237
+
238
+ # For 3D, we need an SDF string
239
+ sdf_string = mol_to_sdf_string(mol)
240
+ img_3d_html = render_mol_3d(sdf_string, width=300, height=300)
241
+ image_3d_list.append(img_3d_html)
242
+
243
+ valid_predictions_count += 1
244
+
245
+ # Pad image lists if fewer than 5 valid predictions
246
+ while len(image_2d_list) < 5:
247
+ image_2d_list.append(None)
248
+ image_3d_list.append(None)
249
+
250
+ df_results = pd.DataFrame(results_data)
251
+
252
+ clear_gpu_cache()
253
+
254
+ status_message = "Prediction successful." if valid_predictions_count > 0 else "No valid molecules found for top predictions."
255
+ return df_results, image_2d_list, image_3d_list, status_message
256
+
257
+
258
+ def display_molecule_with_3d(smiles_string):
259
+ """
260
+ Displays a 2D image and 3D visualization of a molecule from its SMILES string.
261
+ """
262
+ if not smiles_string:
263
+ return None, None, "Please enter a SMILES string."
264
+ mol = get_mol(smiles_string)
265
+ if mol is None:
266
+ return None, None, "Invalid SMILES string."
267
+
268
+ img_2d = MolToImage(mol, size=(400, 400), fitImage=True)
269
+
270
+ sdf_string = mol_to_sdf_string(mol)
271
+ img_3d_html = render_mol_3d(sdf_string, width=400, height=400)
272
 
273
+ return img_2d, img_3d_html, "Molecule displayed."
274
+
275
+
276
+ # --- Streamlit UI Definition ---
277
+
278
+ # Set wide mode and background color
279
+ st.set_page_config(layout="wide")
280
+
281
+ st.markdown(
282
+ """
283
+ <style>
284
+ .stApp {
285
+ background-color: rgb(28,28,28);
286
+ color: white; /* Ensure text is visible on dark background */
287
+ }
288
+ .stDataFrame {
289
+ color: black; /* Default DataFrame text color */
290
+ }
291
+ h1, h2, h3, h4, h5, h6, .stMarkdown {
292
+ color: white;
293
+ }
294
+ .css-1d391kg, .css-1dp5dn1 { /* Target Streamlit's main content and sidebar */
295
+ color: white;
296
+ }
297
+ .streamlit-expanderContent {
298
+ background-color: rgb(40,40,40); /* Slightly lighter background for expanders */
299
+ border-radius: 10px;
300
+ padding: 10px;
301
+ }
302
+ /* Style for text inputs and buttons */
303
+ .stTextInput>div>div>input {
304
+ background-color: rgb(50,50,50);
305
+ color: white;
306
+ border-radius: 5px;
307
+ border: 1px solid rgb(70,70,70);
308
+ }
309
+ .stButton>button {
310
+ background-color: rgb(0,128,255); /* Blue button */
311
+ color: white;
312
+ border-radius: 8px;
313
+ padding: 10px 20px;
314
+ border: none;
315
+ transition: background-color 0.3s ease;
316
+ }
317
+ .stButton>button:hover {
318
+ background-color: rgb(0,100,200);
319
+ }
320
+ </style>
321
+ """,
322
+ unsafe_allow_html=True
323
+ )
324
+
325
+
326
+ st.title("ChemBERTa SMILES Utilities Dashboard")
327
+
328
+ tab1, tab2 = st.tabs(["Masked SMILES Prediction", "Molecule Viewer"])
329
+
330
+ with tab1:
331
+ st.markdown("Enter a SMILES string with a `<mask>` token (e.g., `C1=CC=CC<mask>C1`) to predict possible completions.")
332
 
333
+ col1, col2 = st.columns([2, 1])
334
+ with col1:
335
+ smiles_input_masked = st.text_input("SMILES String with Mask", value="C1=CC=CC<mask>C1")
336
+ with col2:
337
+ substructure_input = st.text_input("Substructure to Highlight (SMARTS)", value="C=C")
338
+
339
+ if st.button("Predict and Visualize", key="predict_button"):
340
+ with st.spinner("Predicting and visualizing..."):
341
+ df_predictions, img_2d_list, img_3d_list, status_msg = predict_and_visualize_masked_smiles(
342
+ smiles_input_masked, substructure_input
343
+ )
344
+ st.write(status_msg)
345
+
346
+ if not df_predictions.empty:
347
+ st.subheader("Top Predictions & Scores")
348
+ st.dataframe(df_predictions, use_container_width=True)
349
+
350
+ st.subheader("Predicted Molecule Visualizations (Top 5 Valid)")
351
+ for i in range(5):
352
+ if img_2d_list[i] is not None:
353
+ st.markdown(f"**Prediction {i+1}**")
354
+ cols_img = st.columns(2)
355
+ with cols_img[0]:
356
+ st.image(img_2d_list[i], caption=f"2D Prediction {i+1}", use_column_width=True)
357
+ with cols_img[1]:
358
+ st.components.v1.html(img_3d_list[i], height=300)
359
  else:
360
+ if i < len(df_predictions): # Only show 'No visualization' if there was a prediction attempt
361
+ st.markdown(f"**Prediction {i+1}**: No visualization available (invalid SMILES or error).")
362
+
 
 
 
 
363
 
364
  with tab2:
365
+ st.markdown("Enter a SMILES string to display its 2D and 3D structure.")
366
+ smiles_input_viewer = st.text_input("SMILES String", value="C1=CC=CC=C1", key="viewer_smiles_input")
367
+
368
+ if st.button("View Molecule", key="view_button"):
369
+ with st.spinner("Displaying molecule..."):
370
+ img_2d_viewer, img_3d_viewer_html, status_viewer_msg = display_molecule_with_3d(smiles_input_viewer)
371
+ st.write(status_viewer_msg)
372
+
373
+ if img_2d_viewer is not None:
374
+ cols_viewer = st.columns(2)
375
+ with cols_viewer[0]:
376
+ st.image(img_2d_viewer, caption="2D Molecule Structure", use_column_width=True)
377
+ with cols_viewer[1]:
378
+ st.components.v1.html(img_3d_viewer_html, height=400)
379
+ else:
380
+ st.warning("Could not display molecule. Please check the SMILES string.")
381
+