alidenewade commited on
Commit
98e9d9e
·
verified ·
1 Parent(s): 87445f0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +229 -177
app.py CHANGED
@@ -1,73 +1,137 @@
1
- # app.py
2
  import streamlit as st
3
  import torch
4
- from transformers import AutoModelForMaskedLM, AutoTokenizer, pipeline
5
  from rdkit import Chem
6
- from rdkit.Chem import Draw, AllChem
7
  from rdkit.Chem.Draw import MolToImage
8
  import pandas as pd
9
  import io
10
  import base64
11
  import logging
12
  import py3Dmol
 
13
 
14
- # Set up logging to monitor effects
15
  logging.basicConfig(level=logging.INFO)
16
  logger = logging.getLogger(__name__)
17
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  # --- Optimized Model Loading ---
19
  @st.cache_resource
20
  def load_optimized_models():
21
- """Load models for CPU directly, bypassing quantization and GPU checks."""
22
- device = "cpu" # Force device to CPU
23
- torch_dtype = torch.float32 # Force full precision for CPU
 
24
 
25
  logger.info(f"Loading models on device: {device} with dtype: {torch_dtype}")
26
 
27
  # Model names
28
  model_name = "seyonec/PubChem10M_SMILES_BPE_450k"
29
 
30
- # Load tokenizer
31
  fill_mask_tokenizer = AutoTokenizer.from_pretrained(model_name)
32
 
33
- # Load model with standard settings for CPU
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  try:
 
35
  fill_mask_model = AutoModelForMaskedLM.from_pretrained(
36
  model_name,
37
- torch_dtype=torch_dtype,
38
- device_map=None # No device mapping for plain CPU
39
  )
40
 
41
  # Set model to evaluation mode for inference
42
  fill_mask_model.eval()
43
 
44
- # Create pipeline for CPU
 
 
 
45
  fill_mask_pipeline = pipeline(
46
  'fill-mask',
47
  model=fill_mask_model,
48
  tokenizer=fill_mask_tokenizer,
49
- device=-1 # -1 means CPU
50
  )
51
 
52
- logger.info("Models loaded successfully for CPU.")
53
  return fill_mask_tokenizer, fill_mask_model, fill_mask_pipeline
54
 
55
  except Exception as e:
56
- logger.error(f"Error loading models on CPU: {e}")
57
- st.error(f"Failed to load language model. Please try again. Error: {e}")
58
- # Re-raise or handle as appropriate for app startup
59
- raise # Critical error, app cannot proceed
 
 
 
 
 
 
 
 
60
 
61
- # Load models with optimizations
62
- fill_mask_tokenizer, fill_mask_model, fill_mask_pipeline = load_optimized_models()
 
 
63
 
64
- # --- Memory Management Utilities (now mostly a placeholder for CPU) ---
65
  def clear_gpu_cache():
66
- """Placeholder for GPU cache clearing. Not effective on CPU."""
67
  if torch.cuda.is_available():
68
  torch.cuda.empty_cache()
69
 
70
- # --- Helper Functions from Notebook (adapted) ---
71
  def get_mol(smiles):
72
  """Converts SMILES to RDKit Mol object and Kekulizes it."""
73
  mol = Chem.MolFromSmiles(smiles)
@@ -109,50 +173,66 @@ def get_image_with_highlight(mol, atomset=None, size=(300, 300)):
109
  highlightAtomColors={i: highlight_color for i in valid_atomset} if valid_atomset else {})
110
  return img
111
 
112
- def mol_to_sdf_string(mol):
113
- """Converts an RDKit Mol object to an SDF string."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
114
  if mol is None:
115
  return None
116
- # Add 3D coordinates if not present
117
- AllChem.EmbedMolecule(mol, AllChem.ETKDG())
118
- AllChem.UFFOptimizeMolecule(mol)
119
- return Chem.MolToMolBlock(mol)
120
-
121
- def render_mol_3d(sdf_string, width=300, height=300):
122
- """Renders a 3D molecule using py3Dmol."""
123
- if sdf_string is None:
124
- return ""
125
 
126
- viewer = py3Dmol.view(width=width, height=height)
127
- viewer.setBackgroundColor('#1C1C1C')
128
- viewer.addModel(sdf_string, 'sdf')
129
- viewer.setStyle({'stick':{}}) # Display as sticks
 
 
 
130
  viewer.zoomTo()
131
- # Embed the viewer HTML into Streamlit
132
- return viewer.to_html()
133
 
134
  # --- Streamlit Interface Functions ---
135
 
136
  def predict_and_visualize_masked_smiles(smiles_mask, substructure_smarts_highlight="CC=CC"):
137
  """
138
  Predicts masked tokens in a SMILES string, shows scores, and visualizes molecules.
139
- Returns 5 image paths and a status message.
140
  """
 
 
 
141
  if fill_mask_tokenizer.mask_token not in smiles_mask:
142
  st.error("Error: Input SMILES must contain a mask token (e.g., <mask>).")
143
- return pd.DataFrame(), [None]*5, [None]*5, "Error: Input SMILES must contain a mask token (e.g., <mask>)."
144
 
145
  try:
 
146
  with torch.no_grad():
147
  predictions = fill_mask_pipeline(smiles_mask, top_k=10)
148
  except Exception as e:
149
  clear_gpu_cache()
150
  st.error(f"Error during prediction: {str(e)}")
151
- return pd.DataFrame(), [None]*5, [None]*5, f"Error during prediction: {str(e)}"
152
 
153
  results_data = []
154
- image_2d_list = []
155
- image_3d_list = []
156
  valid_predictions_count = 0
157
 
158
  for pred in predictions:
@@ -165,157 +245,129 @@ def predict_and_visualize_masked_smiles(smiles_mask, substructure_smarts_highlig
165
  mol = get_mol(predicted_smiles)
166
  if mol:
167
  results_data.append({"Predicted SMILES": predicted_smiles, "Score": f"{score:.4f}"})
 
 
 
 
 
 
168
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
169
  atom_matches_indices = []
170
  if substructure_smarts_highlight:
171
  matches = find_matches_one(mol, substructure_smarts_highlight)
172
  if matches:
173
- atom_matches_indices = list(matches[0]) # Highlight first match
174
-
175
- img_2d = get_image_with_highlight(mol, atomset=atom_matches_indices)
176
- image_2d_list.append(img_2d)
177
-
178
- # For 3D, we need an SDF string
179
- sdf_string = mol_to_sdf_string(mol)
180
- img_3d_html = render_mol_3d(sdf_string, width=300, height=300)
181
- image_3d_list.append(img_3d_html)
182
 
183
- valid_predictions_count += 1
184
-
185
- # Pad image lists if fewer than 5 valid predictions
186
- while len(image_2d_list) < 5:
187
- image_2d_list.append(None)
188
- image_3d_list.append(None)
189
-
190
- df_results = pd.DataFrame(results_data)
 
 
 
 
 
 
 
191
 
 
192
  clear_gpu_cache()
 
193
 
194
- status_message = "Prediction successful." if valid_predictions_count > 0 else "No valid molecules found for top predictions."
195
- return df_results, image_2d_list, image_3d_list, status_message
196
-
197
-
198
- def display_molecule_with_3d(smiles_string):
199
  """
200
- Displays a 2D image and 3D visualization of a molecule from its SMILES string.
201
  """
202
  if not smiles_string:
203
- return None, None, "Please enter a SMILES string."
 
 
204
  mol = get_mol(smiles_string)
205
  if mol is None:
206
- return None, None, "Invalid SMILES string."
207
-
208
- img_2d = MolToImage(mol, size=(400, 400), fitImage=True)
209
 
210
- sdf_string = mol_to_sdf_string(mol)
211
- img_3d_html = render_mol_3d(sdf_string, width=400, height=400)
212
 
213
- return img_2d, img_3d_html, "Molecule displayed."
214
-
215
-
216
- # --- Streamlit UI Definition ---
217
-
218
- # Set wide mode and background color
219
- st.set_page_config(layout="wide")
220
-
221
- st.markdown(
222
- """
223
- <style>
224
- .stApp {
225
- background-color: rgb(28,28,28);
226
- color: white; /* Ensure text is visible on dark background */
227
- }
228
- .stDataFrame {
229
- color: black; /* Default DataFrame text color */
230
- }
231
- h1, h2, h3, h4, h5, h6, .stMarkdown {
232
- color: white;
233
- }
234
- .css-1d391kg, .css-1dp5dn1 { /* Target Streamlit's main content and sidebar */
235
- color: white;
236
- }
237
- .streamlit-expanderContent {
238
- background-color: rgb(40,40,40); /* Slightly lighter background for expanders */
239
- border-radius: 10px;
240
- padding: 10px;
241
- }
242
- /* Style for text inputs and buttons */
243
- .stTextInput>div>div>input {
244
- background-color: rgb(50,50,50);
245
- color: white;
246
- border-radius: 5px;
247
- border: 1px solid rgb(70,70,70);
248
- }
249
- .stButton>button {
250
- background-color: rgb(0,128,255); /* Blue button */
251
- color: white;
252
- border-radius: 8px;
253
- padding: 10px 20px;
254
- border: none;
255
- transition: background-color 0.3s ease;
256
- }
257
- .stButton>button:hover {
258
- background-color: rgb(0,100,200);
259
- }
260
- </style>
261
- """,
262
- unsafe_allow_html=True
263
- )
264
-
265
-
266
- st.title("ChemBERTa SMILES Utilities Dashboard")
267
-
268
- tab1, tab2 = st.tabs(["Masked SMILES Prediction", "Molecule Viewer"])
269
-
270
- with tab1:
271
- st.markdown("Enter a SMILES string with a `<mask>` token (e.g., `C1=CC=CC<mask>C1`) to predict possible completions.")
272
 
273
- col1, col2 = st.columns([2, 1])
274
  with col1:
275
- smiles_input_masked = st.text_input("SMILES String with Mask", value="C1=CC=CC<mask>C1")
 
 
 
276
  with col2:
277
- substructure_input = st.text_input("Substructure to Highlight (SMARTS)", value="C=C")
278
-
279
- if st.button("Predict and Visualize", key="predict_button"):
280
- with st.spinner("Predicting and visualizing..."):
281
- df_predictions, img_2d_list, img_3d_list, status_msg = predict_and_visualize_masked_smiles(
282
- smiles_input_masked, substructure_input
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
283
  )
284
- st.write(status_msg)
285
-
286
- if not df_predictions.empty:
287
- st.subheader("Top Predictions & Scores")
288
- st.dataframe(df_predictions, use_container_width=True)
289
-
290
- st.subheader("Predicted Molecule Visualizations (Top 5 Valid)")
291
- for i in range(5):
292
- if img_2d_list[i] is not None:
293
- st.markdown(f"**Prediction {i+1}**")
294
- cols_img = st.columns(2)
295
- with cols_img[0]:
296
- st.image(img_2d_list[i], caption=f"2D Prediction {i+1}", use_column_width=True)
297
- with cols_img[1]:
298
- st.components.v1.html(img_3d_list[i], height=300)
299
- else:
300
- if i < len(df_predictions): # Only show 'No visualization' if there was a prediction attempt
301
- st.markdown(f"**Prediction {i+1}**: No visualization available (invalid SMILES or error).")
302
-
303
-
304
- with tab2:
305
- st.markdown("Enter a SMILES string to display its 2D and 3D structure.")
306
- smiles_input_viewer = st.text_input("SMILES String", value="C1=CC=CC=C1", key="viewer_smiles_input")
307
 
308
- if st.button("View Molecule", key="view_button"):
309
- with st.spinner("Displaying molecule..."):
310
- img_2d_viewer, img_3d_viewer_html, status_viewer_msg = display_molecule_with_3d(smiles_input_viewer)
311
- st.write(status_viewer_msg)
312
-
313
- if img_2d_viewer is not None:
314
- cols_viewer = st.columns(2)
315
- with cols_viewer[0]:
316
- st.image(img_2d_viewer, caption="2D Molecule Structure", use_column_width=True)
317
- with cols_viewer[1]:
318
- st.components.v1.html(img_3d_viewer_html, height=400)
319
- else:
320
- st.warning("Could not display molecule. Please check the SMILES string.")
321
 
 
 
 
 
1
  import streamlit as st
2
  import torch
3
+ from transformers import AutoModelForMaskedLM, AutoTokenizer, pipeline, BitsAndBytesConfig
4
  from rdkit import Chem
5
+ from rdkit.Chem import Draw, rdFMCS, AllChem
6
  from rdkit.Chem.Draw import MolToImage
7
  import pandas as pd
8
  import io
9
  import base64
10
  import logging
11
  import py3Dmol
12
+ from stmol import showmol
13
 
14
+ # Set up logging to monitor quantization effects
15
  logging.basicConfig(level=logging.INFO)
16
  logger = logging.getLogger(__name__)
17
 
18
+ # Page configuration
19
+ st.set_page_config(
20
+ page_title="ChemBERTa SMILES Utilities Dashboard",
21
+ page_icon="🧪",
22
+ layout="wide"
23
+ )
24
+
25
+ # --- Quantization Configuration ---
26
+ @st.cache_resource
27
+ def get_quantization_config():
28
+ """
29
+ Configure 8-bit quantization for model optimization.
30
+ Falls back gracefully if bitsandbytes is not available.
31
+ """
32
+ try:
33
+ # 8-bit quantization configuration - good balance of speed and quality
34
+ quantization_config = BitsAndBytesConfig(
35
+ load_in_8bit=True,
36
+ bnb_8bit_compute_dtype=torch.float16,
37
+ bnb_8bit_use_double_quant=True, # Nested quantization for better compression
38
+ )
39
+ logger.info("8-bit quantization configuration loaded successfully")
40
+ return quantization_config
41
+ except ImportError:
42
+ logger.warning("bitsandbytes not available, falling back to standard loading")
43
+ return None
44
+ except Exception as e:
45
+ logger.warning(f"Quantization setup failed: {e}, using standard loading")
46
+ return None
47
+
48
+ def get_torch_dtype():
49
+ """Get appropriate torch dtype based on available hardware."""
50
+ if torch.cuda.is_available():
51
+ return torch.float16 # Use half precision on GPU
52
+ else:
53
+ return torch.float32 # Keep full precision on CPU
54
+
55
  # --- Optimized Model Loading ---
56
  @st.cache_resource
57
  def load_optimized_models():
58
+ """Load models with quantization and other optimizations."""
59
+ device = "cuda" if torch.cuda.is_available() else "cpu"
60
+ torch_dtype = get_torch_dtype()
61
+ quantization_config = get_quantization_config()
62
 
63
  logger.info(f"Loading models on device: {device} with dtype: {torch_dtype}")
64
 
65
  # Model names
66
  model_name = "seyonec/PubChem10M_SMILES_BPE_450k"
67
 
68
+ # Load tokenizer (doesn't need quantization)
69
  fill_mask_tokenizer = AutoTokenizer.from_pretrained(model_name)
70
 
71
+ # Load model with quantization if available
72
+ model_kwargs = {
73
+ "torch_dtype": torch_dtype,
74
+ }
75
+
76
+ if quantization_config is not None and torch.cuda.is_available(): # Quantization typically for GPU
77
+ model_kwargs["quantization_config"] = quantization_config
78
+ # device_map="auto" is often used with bitsandbytes for automatic distribution
79
+ model_kwargs["device_map"] = "auto"
80
+ elif torch.cuda.is_available():
81
+ model_kwargs["device_map"] = "auto" # For non-quantized GPU loading
82
+ else:
83
+ model_kwargs["device_map"] = None # For CPU
84
+
85
  try:
86
+ # Masked LM Model
87
  fill_mask_model = AutoModelForMaskedLM.from_pretrained(
88
  model_name,
89
+ **model_kwargs
 
90
  )
91
 
92
  # Set model to evaluation mode for inference
93
  fill_mask_model.eval()
94
 
95
+ # Create optimized pipeline
96
+ # Let pipeline infer device from model if possible, or set based on model's device
97
+ pipeline_device = fill_mask_model.device.index if hasattr(fill_mask_model.device, 'type') and fill_mask_model.device.type == "cuda" else -1
98
+
99
  fill_mask_pipeline = pipeline(
100
  'fill-mask',
101
  model=fill_mask_model,
102
  tokenizer=fill_mask_tokenizer,
103
+ device=pipeline_device, # Use model's device
104
  )
105
 
106
+ logger.info("Models loaded successfully with optimizations")
107
  return fill_mask_tokenizer, fill_mask_model, fill_mask_pipeline
108
 
109
  except Exception as e:
110
+ logger.error(f"Error loading optimized models: {e}")
111
+ # Fallback to standard loading
112
+ logger.info("Falling back to standard model loading...")
113
+ return load_standard_models(model_name)
114
+
115
+ def load_standard_models(model_name):
116
+ """Fallback standard model loading without quantization."""
117
+ fill_mask_tokenizer = AutoTokenizer.from_pretrained(model_name)
118
+ fill_mask_model = AutoModelForMaskedLM.from_pretrained(model_name)
119
+ # Determine device for standard loading
120
+ device_idx = 0 if torch.cuda.is_available() else -1
121
+ fill_mask_pipeline = pipeline('fill-mask', model=fill_mask_model, tokenizer=fill_mask_tokenizer, device=device_idx)
122
 
123
+ if torch.cuda.is_available():
124
+ fill_mask_model.to("cuda")
125
+
126
+ return fill_mask_tokenizer, fill_mask_model, fill_mask_pipeline
127
 
128
+ # --- Memory Management Utilities ---
129
  def clear_gpu_cache():
130
+ """Clear CUDA cache to free up memory."""
131
  if torch.cuda.is_available():
132
  torch.cuda.empty_cache()
133
 
134
+ # --- Helper Functions ---
135
  def get_mol(smiles):
136
  """Converts SMILES to RDKit Mol object and Kekulizes it."""
137
  mol = Chem.MolFromSmiles(smiles)
 
173
  highlightAtomColors={i: highlight_color for i in valid_atomset} if valid_atomset else {})
174
  return img
175
 
176
+ def generate_3d_structure(mol):
177
+ """Generate 3D coordinates for a molecule."""
178
+ if mol is None:
179
+ return None
180
+
181
+ # Create a copy to avoid modifying the original
182
+ mol_3d = Chem.Mol(mol)
183
+
184
+ # Add hydrogens
185
+ mol_3d = Chem.AddHs(mol_3d)
186
+
187
+ # Generate 3D coordinates
188
+ try:
189
+ AllChem.EmbedMolecule(mol_3d, randomSeed=42)
190
+ AllChem.UFFOptimizeMolecule(mol_3d)
191
+ return mol_3d
192
+ except:
193
+ # If 3D generation fails, return None
194
+ return None
195
+
196
+ def mol_to_3d_html(mol):
197
+ """Convert molecule to 3D HTML representation using py3Dmol."""
198
  if mol is None:
199
  return None
 
 
 
 
 
 
 
 
 
200
 
201
+ # Generate SDF string
202
+ sdf = Chem.MolToMolBlock(mol)
203
+
204
+ # Create 3D viewer
205
+ viewer = py3Dmol.view(width=400, height=400)
206
+ viewer.addModel(sdf, 'sdf')
207
+ viewer.setStyle({'stick': {}})
208
  viewer.zoomTo()
209
+
210
+ return viewer
211
 
212
  # --- Streamlit Interface Functions ---
213
 
214
  def predict_and_visualize_masked_smiles(smiles_mask, substructure_smarts_highlight="CC=CC"):
215
  """
216
  Predicts masked tokens in a SMILES string, shows scores, and visualizes molecules.
 
217
  """
218
+ # Load models
219
+ fill_mask_tokenizer, fill_mask_model, fill_mask_pipeline = load_optimized_models()
220
+
221
  if fill_mask_tokenizer.mask_token not in smiles_mask:
222
  st.error("Error: Input SMILES must contain a mask token (e.g., <mask>).")
223
+ return
224
 
225
  try:
226
+ # Use torch.no_grad() for inference to save memory
227
  with torch.no_grad():
228
  predictions = fill_mask_pipeline(smiles_mask, top_k=10)
229
  except Exception as e:
230
  clear_gpu_cache()
231
  st.error(f"Error during prediction: {str(e)}")
232
+ return
233
 
234
  results_data = []
235
+ valid_predictions = []
 
236
  valid_predictions_count = 0
237
 
238
  for pred in predictions:
 
245
  mol = get_mol(predicted_smiles)
246
  if mol:
247
  results_data.append({"Predicted SMILES": predicted_smiles, "Score": f"{score:.4f}"})
248
+ valid_predictions.append((mol, predicted_smiles, score))
249
+ valid_predictions_count += 1
250
+
251
+ if valid_predictions_count == 0:
252
+ st.warning("No valid molecules found for top predictions.")
253
+ return
254
 
255
+ # Display results table
256
+ df_results = pd.DataFrame(results_data)
257
+ st.subheader("Top Predictions & Scores")
258
+ st.dataframe(df_results, use_container_width=True)
259
+
260
+ # Display molecule visualizations
261
+ st.subheader("Predicted Molecule Visualizations")
262
+
263
+ for i, (mol, smiles, score) in enumerate(valid_predictions):
264
+ st.write(f"**Prediction {i+1}:** {smiles} (Score: {score:.4f})")
265
+
266
+ col1, col2 = st.columns(2)
267
+
268
+ with col1:
269
+ st.write("**2D Structure:**")
270
  atom_matches_indices = []
271
  if substructure_smarts_highlight:
272
  matches = find_matches_one(mol, substructure_smarts_highlight)
273
  if matches:
274
+ atom_matches_indices = list(matches[0])
 
 
 
 
 
 
 
 
275
 
276
+ img_2d = get_image_with_highlight(mol, atomset=atom_matches_indices)
277
+ if img_2d:
278
+ st.image(img_2d, use_column_width=True)
279
+
280
+ with col2:
281
+ st.write("**3D Structure:**")
282
+ mol_3d = generate_3d_structure(mol)
283
+ if mol_3d:
284
+ viewer_3d = mol_to_3d_html(mol_3d)
285
+ if viewer_3d:
286
+ showmol(viewer_3d, height=400, width=400)
287
+ else:
288
+ st.write("3D structure generation failed for this molecule.")
289
+
290
+ st.divider()
291
 
292
+ # Clear cache after inference
293
  clear_gpu_cache()
294
+ st.success("Prediction successful!")
295
 
296
+ def display_molecule_image(smiles_string):
 
 
 
 
297
  """
298
+ Displays both 2D and 3D images of a molecule from its SMILES string.
299
  """
300
  if not smiles_string:
301
+ st.error("Please enter a SMILES string.")
302
+ return
303
+
304
  mol = get_mol(smiles_string)
305
  if mol is None:
306
+ st.error("Invalid SMILES string.")
307
+ return
 
308
 
309
+ st.success("Molecule displayed successfully!")
 
310
 
311
+ col1, col2 = st.columns(2)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
312
 
 
313
  with col1:
314
+ st.subheader("2D Structure")
315
+ img_2d = MolToImage(mol, size=(400, 400), fitImage=True)
316
+ st.image(img_2d, use_column_width=True)
317
+
318
  with col2:
319
+ st.subheader("3D Structure")
320
+ mol_3d = generate_3d_structure(mol)
321
+ if mol_3d:
322
+ viewer_3d = mol_to_3d_html(mol_3d)
323
+ if viewer_3d:
324
+ showmol(viewer_3d, height=400, width=400)
325
+ else:
326
+ st.write("3D structure generation failed for this molecule.")
327
+
328
+ # --- Main Streamlit App ---
329
+ def main():
330
+ st.title("🧪 ChemBERTa SMILES Utilities Dashboard")
331
+
332
+ # Sidebar for navigation
333
+ st.sidebar.title("Navigation")
334
+ tab_selection = st.sidebar.selectbox(
335
+ "Choose a tool:",
336
+ ["Masked SMILES Prediction", "Molecule Viewer"]
337
+ )
338
+
339
+ if tab_selection == "Masked SMILES Prediction":
340
+ st.header("Masked SMILES Prediction")
341
+ st.markdown("Enter a SMILES string with a `<mask>` token (e.g., `C1=CC=CC<mask>C1`) to predict possible completions.")
342
+
343
+ col1, col2 = st.columns(2)
344
+ with col1:
345
+ smiles_input_masked = st.text_input(
346
+ "SMILES String with Mask",
347
+ value="C1=CC=CC<mask>C1"
348
  )
349
+ with col2:
350
+ substructure_input = st.text_input(
351
+ "Substructure to Highlight (SMARTS)",
352
+ value="C=C"
353
+ )
354
+
355
+ if st.button("Predict and Visualize", type="primary"):
356
+ with st.spinner("Predicting masked SMILES..."):
357
+ predict_and_visualize_masked_smiles(smiles_input_masked, substructure_input)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
358
 
359
+ elif tab_selection == "Molecule Viewer":
360
+ st.header("Molecule Viewer")
361
+ st.markdown("Enter a SMILES string to display its 2D and 3D structure.")
362
+
363
+ smiles_input_viewer = st.text_input(
364
+ "SMILES String",
365
+ value="C1=CC=CC=C1"
366
+ )
367
+
368
+ if st.button("View Molecule", type="primary"):
369
+ with st.spinner("Generating molecule structures..."):
370
+ display_molecule_image(smiles_input_viewer)
 
371
 
372
+ if __name__ == "__main__":
373
+ main()