alidenewade commited on
Commit
35ed017
Β·
verified Β·
1 Parent(s): 54eac43

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +286 -240
app.py CHANGED
@@ -1,142 +1,137 @@
1
  # app.py
2
- import gradio as gr
3
  import torch
4
  from transformers import AutoModelForMaskedLM, AutoTokenizer, pipeline, BitsAndBytesConfig
5
  from rdkit import Chem
6
- from rdkit.Chem import Draw, rdFMCS
7
- from rdkit.Chem.Draw import MolToImage
8
- # PIL is imported as Image by rdkit.Chem.Draw.MolToImage, but explicit import is good practice if used directly.
9
- # from PIL import Image
10
  import pandas as pd
11
- import io
12
- import base64
13
  import logging
14
 
15
- # Set up logging to monitor quantization effects
16
  logging.basicConfig(level=logging.INFO)
17
  logger = logging.getLogger(__name__)
18
 
19
- # --- Quantization Configuration ---
20
- def get_quantization_config():
21
- """
22
- Configure 8-bit quantization for model optimization.
23
- Falls back gracefully if bitsandbytes is not available.
24
- """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  try:
26
- # 8-bit quantization configuration - good balance of speed and quality
27
  quantization_config = BitsAndBytesConfig(
28
  load_in_8bit=True,
29
- bnb_8bit_compute_dtype=torch.float16,
30
- bnb_8bit_use_double_quant=True, # Nested quantization for better compression
31
  )
32
- logger.info("8-bit quantization configuration loaded successfully")
33
- return quantization_config
34
  except ImportError:
35
- logger.warning("bitsandbytes not available, falling back to standard loading")
36
- return None
37
- except Exception as e:
38
- logger.warning(f"Quantization setup failed: {e}, using standard loading")
39
- return None
40
-
41
- def get_torch_dtype():
42
- """Get appropriate torch dtype based on available hardware."""
43
- if torch.cuda.is_available():
44
- return torch.float16 # Use half precision on GPU
45
- else:
46
- return torch.float32 # Keep full precision on CPU
47
-
48
- # --- Optimized Model Loading ---
49
- def load_optimized_models():
50
- """Load models with quantization and other optimizations."""
51
- device = "cuda" if torch.cuda.is_available() else "cpu"
52
- torch_dtype = get_torch_dtype()
53
- quantization_config = get_quantization_config()
54
 
55
- logger.info(f"Loading models on device: {device} with dtype: {torch_dtype}")
56
-
57
- # Model names
58
  model_name = "seyonec/PubChem10M_SMILES_BPE_450k"
59
-
60
- # Load tokenizer (doesn't need quantization)
61
- fill_mask_tokenizer = AutoTokenizer.from_pretrained(model_name)
62
-
63
- # Load model with quantization if available
64
- model_kwargs = {
65
- "torch_dtype": torch_dtype,
66
- }
67
-
68
- if quantization_config is not None and torch.cuda.is_available(): # Quantization typically for GPU
69
  model_kwargs["quantization_config"] = quantization_config
70
- # device_map="auto" is often used with bitsandbytes for automatic distribution
71
  model_kwargs["device_map"] = "auto"
72
- elif torch.cuda.is_available():
73
- model_kwargs["device_map"] = "auto" # For non-quantized GPU loading
74
- else:
75
- model_kwargs["device_map"] = None # For CPU
76
 
77
- try:
78
- # Masked LM Model
79
- fill_mask_model = AutoModelForMaskedLM.from_pretrained(
80
- model_name,
81
- **model_kwargs
82
- )
 
 
 
 
83
 
84
- # Set model to evaluation mode for inference
85
- fill_mask_model.eval()
86
 
87
- # Create optimized pipeline
88
- # Let pipeline infer device from model if possible, or set based on model's device
89
- pipeline_device = fill_mask_model.device.index if hasattr(fill_mask_model.device, 'type') and fill_mask_model.device.type == "cuda" else -1
90
 
91
- fill_mask_pipeline = pipeline(
92
- 'fill-mask',
93
- model=fill_mask_model,
94
- tokenizer=fill_mask_tokenizer,
95
- device=pipeline_device, # Use model's device
96
- # torch_dtype=torch_dtype # Pipeline might infer this or it might conflict
97
- )
98
 
99
- logger.info("Models loaded successfully with optimizations")
100
- return fill_mask_tokenizer, fill_mask_model, fill_mask_pipeline
101
-
102
- except Exception as e:
103
- logger.error(f"Error loading optimized models: {e}")
104
- # Fallback to standard loading
105
- logger.info("Falling back to standard model loading...")
106
- return load_standard_models(model_name)
107
-
108
- def load_standard_models(model_name):
109
- """Fallback standard model loading without quantization."""
110
- fill_mask_tokenizer = AutoTokenizer.from_pretrained(model_name)
111
- fill_mask_model = AutoModelForMaskedLM.from_pretrained(model_name)
112
- # Determine device for standard loading
113
- device_idx = 0 if torch.cuda.is_available() else -1
114
- fill_mask_pipeline = pipeline('fill-mask', model=fill_mask_model, tokenizer=fill_mask_tokenizer, device=device_idx)
115
-
116
- if torch.cuda.is_available():
117
- fill_mask_model.to("cuda")
118
-
119
- return fill_mask_tokenizer, fill_mask_model, fill_mask_pipeline
120
-
121
- # Load models with optimizations
122
- fill_mask_tokenizer, fill_mask_model, fill_mask_pipeline = load_optimized_models()
123
-
124
- # --- Memory Management Utilities ---
125
- def clear_gpu_cache():
126
- """Clear CUDA cache to free up memory."""
127
- if torch.cuda.is_available():
128
- torch.cuda.empty_cache()
129
-
130
- # --- Helper Functions from Notebook (adapted) ---
131
  def get_mol(smiles):
132
- """Converts SMILES to RDKit Mol object and Kekulizes it."""
133
  mol = Chem.MolFromSmiles(smiles)
134
- if mol is None:
135
- return None
136
- try:
137
- Chem.Kekulize(mol)
138
- except: # Kekulization can fail for some structures
139
- pass
140
  return mol
141
 
142
  def find_matches_one(mol, submol_smarts):
@@ -149,149 +144,200 @@ def find_matches_one(mol, submol_smarts):
149
  matches = mol.GetSubstructMatches(submol)
150
  return matches
151
 
152
- def get_image_with_highlight(mol, atomset=None, size=(300, 300)):
153
- """Draws molecule with optional atom highlighting."""
154
- if mol is None:
155
- return None
156
- highlight_color = (0, 1, 0, 0.5) # Green with some transparency
157
-
158
- # Ensure atomset contains integers if not None or empty
159
- valid_atomset = []
160
- if atomset:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
161
  try:
162
- valid_atomset = [int(a) for a in atomset]
163
- except ValueError:
164
- logger.warning(f"Invalid atom in atomset: {atomset}. Proceeding without highlighting problematic atoms.")
165
- valid_atomset = [int(a) for a in atomset if str(a).isdigit()] # Filter out non-integers
166
-
167
- img = MolToImage(mol, size=size, fitImage=True,
168
- highlightAtoms=valid_atomset if valid_atomset else [],
169
- highlightAtomColors={i: highlight_color for i in valid_atomset} if valid_atomset else {})
170
- return img
171
-
172
- # --- Optimized Gradio Interface Functions ---
173
-
174
- def predict_and_visualize_masked_smiles(smiles_mask, substructure_smarts_highlight="CC=CC"):
175
- """
176
- Predicts masked tokens in a SMILES string, shows scores, and visualizes molecules.
177
- Optimized with memory management. Returns 7 items for Gradio outputs.
178
- """
179
- if fill_mask_tokenizer.mask_token not in smiles_mask:
180
- # Return 7 items for the 7 output components
181
- return pd.DataFrame(), None, None, None, None, None, "Error: Input SMILES must contain a mask token (e.g., <mask>)."
 
 
 
 
 
 
 
 
 
 
182
 
 
 
 
 
 
 
 
 
 
183
  try:
184
- # Use torch.no_grad() for inference to save memory
185
  with torch.no_grad():
186
- predictions = fill_mask_pipeline(smiles_mask, top_k=10) # Get more to filter for valid ones
 
 
187
  except Exception as e:
188
- clear_gpu_cache() # Clear cache on error
189
- # Return 7 items
190
- return pd.DataFrame(), None, None, None, None, None, f"Error during prediction: {str(e)}"
191
 
192
  results_data = []
193
- image_list = []
194
  valid_predictions_count = 0
195
 
196
- for pred in predictions:
197
  if valid_predictions_count >= 5:
198
  break
199
 
200
  predicted_smiles = pred['sequence']
201
  score = pred['score']
202
-
203
  mol = get_mol(predicted_smiles)
204
- if mol:
205
- results_data.append({"Predicted SMILES": predicted_smiles, "Score": f"{score:.4f}"})
206
 
207
- atom_matches_indices = []
208
- if substructure_smarts_highlight:
209
- matches = find_matches_one(mol, substructure_smarts_highlight)
210
- if matches:
211
- atom_matches_indices = list(matches[0]) # Highlight first match
212
-
213
- img = get_image_with_highlight(mol, atomset=atom_matches_indices)
214
- image_list.append(img)
215
  valid_predictions_count += 1
216
-
217
- # Pad image_list if fewer than 5 valid predictions
218
- while len(image_list) < 5:
219
- image_list.append(None)
 
 
 
 
 
 
 
 
 
220
 
221
  df_results = pd.DataFrame(results_data)
222
-
223
- # Clear cache after inference
224
- clear_gpu_cache()
225
-
226
- status_message = "Prediction successful." if valid_predictions_count > 0 else "No valid molecules found for top predictions."
227
-
228
- # Unpack image_list into individual image outputs + df_results + status_message
229
- return df_results, image_list[0], image_list[1], image_list[2], image_list[3], image_list[4], status_message
230
-
231
- def display_molecule_image(smiles_string):
232
- """
233
- Displays a 2D image of a molecule from its SMILES string.
234
- """
235
- if not smiles_string:
236
- return None, "Please enter a SMILES string."
237
- mol = get_mol(smiles_string)
238
- if mol is None:
239
- return None, "Invalid SMILES string."
240
- img = MolToImage(mol, size=(400, 400), fitImage=True)
241
- return img, "Molecule displayed."
242
-
243
- # --- Gradio Interface Definition ---
244
- with gr.Blocks(theme=gr.themes.Default()) as demo:
245
- gr.Markdown("# ChemBERTa SMILES Utilities Dashboard")
246
-
247
- with gr.Tab("Masked SMILES Prediction"):
248
- gr.Markdown("Enter a SMILES string with a `<mask>` token (e.g., `C1=CC=CC<mask>C1`) to predict possible completions.")
249
- with gr.Row():
250
- smiles_input_masked = gr.Textbox(label="SMILES String with Mask", value="C1=CC=CC<mask>C1")
251
- substructure_input = gr.Textbox(label="Substructure to Highlight (SMARTS)", value="C=C")
252
- predict_button_masked = gr.Button("Predict and Visualize")
253
-
254
- status_masked = gr.Textbox(label="Status", interactive=False)
255
- predictions_table = gr.DataFrame(label="Top Predictions & Scores")
256
-
257
- gr.Markdown("### Predicted Molecule Visualizations (Top 5 Valid)")
258
- with gr.Row():
259
- img_out_1 = gr.Image(label="Prediction 1", type="pil", interactive=False)
260
- img_out_2 = gr.Image(label="Prediction 2", type="pil", interactive=False)
261
- img_out_3 = gr.Image(label="Prediction 3", type="pil", interactive=False)
262
- img_out_4 = gr.Image(label="Prediction 4", type="pil", interactive=False)
263
- img_out_5 = gr.Image(label="Prediction 5", type="pil", interactive=False)
264
-
265
- # Automatically populate on load for the default example
266
- demo.load(
267
- lambda: predict_and_visualize_masked_smiles("C1=CC=CC<mask>C1", "C=C"),
268
- inputs=None,
269
- outputs=[predictions_table, img_out_1, img_out_2, img_out_3, img_out_4, img_out_5, status_masked]
270
- )
271
- predict_button_masked.click(
272
- predict_and_visualize_masked_smiles,
273
- inputs=[smiles_input_masked, substructure_input],
274
- outputs=[predictions_table, img_out_1, img_out_2, img_out_3, img_out_4, img_out_5, status_masked]
275
- )
276
-
277
- with gr.Tab("Molecule Viewer"):
278
- gr.Markdown("Enter a SMILES string to display its 2D structure.")
279
- smiles_input_viewer = gr.Textbox(label="SMILES String", value="C1=CC=CC=C1")
280
- view_button_molecule = gr.Button("View Molecule")
281
- status_viewer = gr.Textbox(label="Status", interactive=False)
282
- molecule_image_output = gr.Image(label="Molecule Structure", type="pil", interactive=False)
283
-
284
- # Automatically populate on load for the default example
285
- demo.load(
286
- lambda: display_molecule_image("C1=CC=CC=C1"),
287
- inputs=None,
288
- outputs=[molecule_image_output, status_viewer]
289
- )
290
- view_button_molecule.click(
291
- display_molecule_image,
292
- inputs=[smiles_input_viewer],
293
- outputs=[molecule_image_output, status_viewer]
294
- )
295
-
296
- if __name__ == "__main__":
297
- demo.launch()
 
1
  # app.py
2
+ import streamlit as st
3
  import torch
4
  from transformers import AutoModelForMaskedLM, AutoTokenizer, pipeline, BitsAndBytesConfig
5
  from rdkit import Chem
6
+ from rdkit.Chem import Draw, AllChem
 
 
 
7
  import pandas as pd
8
+ import py3Dmol
9
+ import re
10
  import logging
11
 
12
+ # Set up logging
13
  logging.basicConfig(level=logging.INFO)
14
  logger = logging.getLogger(__name__)
15
 
16
+ # --- Page Configuration ---
17
+ st.set_page_config(
18
+ page_title="ChemBERTa SMILES Utilities",
19
+ page_icon="πŸ§ͺ",
20
+ layout="wide",
21
+ )
22
+
23
+ # --- Custom Styling (from drug_app) ---
24
+ def apply_custom_styling():
25
+ st.markdown(
26
+ """
27
+ <style>
28
+ @import url('https://fonts.googleapis.com/css2?family=Roboto:wght@400;700&display=swap');
29
+
30
+ html, body, [class*="st-"] {
31
+ font-family: 'Roboto', sans-serif;
32
+ }
33
+
34
+ .stApp {
35
+ background-color: rgb(28, 28, 28);
36
+ color: white;
37
+ }
38
+
39
+ /* Tab styles */
40
+ .stTabs [data-baseweb="tab-list"] {
41
+ gap: 24px;
42
+ }
43
+
44
+ .stTabs [data-baseweb="tab"] {
45
+ height: 50px;
46
+ white-space: pre-wrap;
47
+ background: none;
48
+ border-radius: 0px;
49
+ border-bottom: 2px solid #333;
50
+ padding: 10px 4px;
51
+ color: #AAA;
52
+ }
53
+
54
+ .stTabs [data-baseweb="tab"]:hover {
55
+ background: #222;
56
+ color: #FFF;
57
+ }
58
+
59
+ .stTabs [aria-selected="true"] {
60
+ border-bottom: 2px solid #00A0FF; /* Highlight color for active tab */
61
+ color: #FFF;
62
+ }
63
+
64
+ /* Button styles */
65
+ .stButton>button {
66
+ border-color: #00A0FF;
67
+ color: #00A0FF;
68
+ background-color: transparent;
69
+ }
70
+
71
+ .stButton>button:hover {
72
+ border-color: #FFF;
73
+ color: #FFF;
74
+ background-color: #00A0FF;
75
+ }
76
+
77
+ </style>
78
+ """,
79
+ unsafe_allow_html=True
80
+ )
81
+
82
+ apply_custom_styling()
83
+
84
+
85
+ # --- Model Loading (from mol_app) ---
86
+ @st.cache_resource(show_spinner="Loading ChemBERTa model...")
87
+ def load_optimized_models():
88
+ """Load models with quantization and other optimizations."""
89
+ device = "cuda" if torch.cuda.is_available() else "cpu"
90
+ torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
91
+
92
  try:
 
93
  quantization_config = BitsAndBytesConfig(
94
  load_in_8bit=True,
95
+ bnb_8bit_compute_dtype=torch_dtype,
96
+ bnb_8bit_use_double_quant=True,
97
  )
98
+ logger.info("8-bit quantization will be used.")
 
99
  except ImportError:
100
+ quantization_config = None
101
+ logger.warning("bitsandbytes not found. Model will be loaded without quantization.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
 
 
 
 
103
  model_name = "seyonec/PubChem10M_SMILES_BPE_450k"
104
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
105
+
106
+ model_kwargs = {"torch_dtype": torch_dtype}
107
+ if quantization_config and torch.cuda.is_available():
 
 
 
 
 
 
108
  model_kwargs["quantization_config"] = quantization_config
 
109
  model_kwargs["device_map"] = "auto"
 
 
 
 
110
 
111
+ model = AutoModelForMaskedLM.from_pretrained(model_name, **model_kwargs)
112
+
113
+ pipe = pipeline(
114
+ 'fill-mask',
115
+ model=model,
116
+ tokenizer=tokenizer,
117
+ device=0 if device == "cuda" else -1
118
+ )
119
+ logger.info("ChemBERTa model loaded successfully.")
120
+ return pipe, tokenizer
121
 
122
+ fill_mask_pipeline, tokenizer = load_optimized_models()
 
123
 
 
 
 
124
 
125
+ # --- Core Functions ---
 
 
 
 
 
 
126
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
127
  def get_mol(smiles):
128
+ """Converts SMILES to RDKit Mol object."""
129
  mol = Chem.MolFromSmiles(smiles)
130
+ if mol:
131
+ try:
132
+ Chem.Kekulize(mol)
133
+ except:
134
+ pass
 
135
  return mol
136
 
137
  def find_matches_one(mol, submol_smarts):
 
144
  matches = mol.GetSubstructMatches(submol)
145
  return matches
146
 
147
+ # --- Visualization Function (Adapted from drug_app) ---
148
+ def visualize_molecule_2d_3d(smiles: str, name: str, substructure_smarts=""):
149
+ """Generates a side-by-side 2D SVG and 3D py3Dmol HTML view for a single molecule."""
150
+ log = ""
151
+ try:
152
+ mol = get_mol(smiles)
153
+ if not mol:
154
+ return f"<p>Invalid SMILES for {name}</p>", f"❌ Invalid SMILES for {name}"
155
+
156
+ # --- 2D Visualization ---
157
+ drawer = Draw.rdMolDraw2D.MolDraw2DSVG(450, 350)
158
+ opts = drawer.drawOptions()
159
+ opts.clearBackground = False
160
+ opts.addStereoAnnotation = True
161
+ opts.baseFontSize = 0.9
162
+
163
+ # Highlighting
164
+ atom_indices_to_highlight = []
165
+ if substructure_smarts:
166
+ matches = find_matches_one(mol, substructure_smarts)
167
+ if matches:
168
+ atom_indices_to_highlight = list(matches[0]) # Highlight first match
169
+
170
+ # Dark theme colors for 2D drawing
171
+ opts.backgroundColour = (0.109, 0.109, 0.109) # rgb(28,28,28)
172
+ opts.symbolColour = (1, 1, 1)
173
+ opts.setAtomPalette({
174
+ -1: (1, 1, 1), # Default
175
+ 6: (0.9, 0.9, 0.9), # Carbon
176
+ 7: (0.5, 0.5, 1), # Nitrogen
177
+ 8: (1, 0.2, 0.2), # Oxygen
178
+ 16: (1, 0.8, 0.2), # Sulfur
179
+ })
180
+
181
+ drawer.DrawMolecule(mol, highlightAtoms=atom_indices_to_highlight)
182
+ drawer.FinishDrawing()
183
+ svg_2d = drawer.GetDrawingText()
184
+
185
+ # Fix colors for dark theme
186
+ svg_2d = svg_2d.replace('stroke="black"', 'stroke="white"')
187
+ svg_2d = svg_2d.replace('fill="black"', 'fill="white"')
188
+ svg_2d = re.sub(r'fill:#(000000|000);', 'fill:white;', svg_2d)
189
+
190
+ # --- 3D Visualization ---
191
+ mol_3d = Chem.AddHs(mol)
192
+ AllChem.EmbedMolecule(mol_3d, randomSeed=42)
193
  try:
194
+ AllChem.MMFFOptimizeMolecule(mol_3d)
195
+ except:
196
+ AllChem.ETKDGv3().Embed(mol_3d)
197
+
198
+ sdf_data = Chem.MolToMolBlock(mol_3d)
199
+
200
+ viewer = py3Dmol.view(width=450, height=350)
201
+ viewer.setBackgroundColor('#1C1C1C')
202
+ viewer.addModel(sdf_data, "sdf")
203
+ viewer.setStyle({'stick': {}, 'sphere': {'scale': 0.25}})
204
+ viewer.zoomTo()
205
+ html_3d = viewer._make_html()
206
+
207
+ # --- Combine Views ---
208
+ combined_html = f"""
209
+ <div style="display: flex; flex-direction: row; align-items: center; justify-content: space-around; border: 1px solid #444; border-radius: 10px; padding: 10px; margin-bottom: 20px; background-color: #2b2b2b;">
210
+ <div style="text-align: center;">
211
+ <h4 style="color: white; font-family: 'Roboto', sans-serif;">{name} (2D Structure)</h4>
212
+ <div style="background-color: #1C1C1C; padding: 10px; border-radius: 5px;">{svg_2d}</div>
213
+ </div>
214
+ <div style="text-align: center;">
215
+ <h4 style="color: white; font-family: 'Roboto', sans-serif;">{name} (3D Interactive)</h4>
216
+ {html_3d}
217
+ </div>
218
+ </div>
219
+ """
220
+ log += f"βœ… Generated 2D/3D view for {name}.\n"
221
+ return combined_html, log
222
+ except Exception as e:
223
+ return f"<p>Error visualizing {name}: {e}</p>", f"❌ Error visualizing {name}: {e}"
224
 
225
+
226
+ # --- Main Application Logic ---
227
+ def predict_and_generate_visualizations(smiles_mask, substructure_smarts):
228
+ """Predicts masked SMILES and returns a dataframe and HTML for visualizations."""
229
+ if tokenizer.mask_token not in smiles_mask:
230
+ st.error(f"Error: Input SMILES must contain a mask token (e.g., `{tokenizer.mask_token}`).")
231
+ return pd.DataFrame(), "", "Input error."
232
+
233
+ status_log = ""
234
  try:
 
235
  with torch.no_grad():
236
+ predictions = fill_mask_pipeline(smiles_mask, top_k=15)
237
+ if torch.cuda.is_available():
238
+ torch.cuda.empty_cache()
239
  except Exception as e:
240
+ st.error(f"An error occurred during model prediction: {e}")
241
+ return pd.DataFrame(), "", "Prediction error."
 
242
 
243
  results_data = []
244
+ combined_html = ""
245
  valid_predictions_count = 0
246
 
247
+ for i, pred in enumerate(predictions):
248
  if valid_predictions_count >= 5:
249
  break
250
 
251
  predicted_smiles = pred['sequence']
252
  score = pred['score']
 
253
  mol = get_mol(predicted_smiles)
 
 
254
 
255
+ if mol:
 
 
 
 
 
 
 
256
  valid_predictions_count += 1
257
+ results_data.append({
258
+ "Rank": valid_predictions_count,
259
+ "Predicted SMILES": predicted_smiles,
260
+ "Score": f"{score:.4f}"
261
+ })
262
+
263
+ html_view, log = visualize_molecule_2d_3d(
264
+ predicted_smiles,
265
+ f"Prediction #{valid_predictions_count}",
266
+ substructure_smarts
267
+ )
268
+ combined_html += html_view
269
+ status_log += log
270
 
271
  df_results = pd.DataFrame(results_data)
272
+ status_log += f"\nFound {valid_predictions_count} valid molecules from top predictions."
273
+ return df_results, combined_html, status_log
274
+
275
+ # --- Streamlit Interface ---
276
+ st.title("πŸ§ͺ ChemBERTa SMILES Utilities")
277
+ st.markdown("""
278
+ Enter a SMILES string with a `<mask>` token (e.g., `C1=CC=CC<mask>C1`) to predict possible completions.
279
+ The model will generate the most likely atoms or fragments to fill the mask.
280
+ """)
281
+
282
+ tab1, tab2 = st.tabs(["Masked SMILES Prediction", "Molecule Viewer"])
283
+
284
+ with tab1:
285
+ st.header("Masked SMILES Prediction")
286
+
287
+ with st.form("prediction_form"):
288
+ col1, col2 = st.columns(2)
289
+ with col1:
290
+ smiles_input_masked = st.text_input(
291
+ "SMILES String with Mask",
292
+ value=f"C1=CC=CC{tokenizer.mask_token}C1",
293
+ help=f"Use `{tokenizer.mask_token}` as the mask token."
294
+ )
295
+ with col2:
296
+ substructure_input = st.text_input(
297
+ "Substructure to Highlight (SMARTS)",
298
+ value="C=C",
299
+ help="Enter a SMARTS pattern to highlight in the 2D view."
300
+ )
301
+
302
+ submit_button = st.form_submit_button("πŸš€ Predict and Visualize", use_container_width=True)
303
+
304
+ if 'results_df' not in st.session_state or submit_button:
305
+ if submit_button or 'results_df' not in st.session_state:
306
+ with st.spinner("Running predictions... This may take a moment."):
307
+ df, html, log = predict_and_generate_visualizations(smiles_input_masked, substructure_input)
308
+ st.session_state.results_df = df
309
+ st.session_state.results_html = html
310
+ st.session_state.status_log = log
311
+
312
+ st.subheader("Top Predictions & Scores")
313
+ if 'results_df' in st.session_state and not st.session_state.results_df.empty:
314
+ st.dataframe(st.session_state.results_df, use_container_width=True, hide_index=True)
315
+ st.subheader("Predicted Molecule Visualizations (Top 5 Valid)")
316
+ st.components.v1.html(st.session_state.results_html, height=1850, scrolling=True)
317
+ else:
318
+ st.info("No valid predictions to display. Try a different input.")
319
+
320
+ with st.expander("Show Logs"):
321
+ if 'status_log' in st.session_state:
322
+ st.text_area("", st.session_state.status_log, height=200, key="log_area_pred")
323
+
324
+ with tab2:
325
+ st.header("Molecule Viewer")
326
+ st.markdown("Enter a single SMILES string to display its 2D and 3D structure.")
327
+
328
+ with st.form("viewer_form"):
329
+ smiles_input_viewer = st.text_input("SMILES String", value="CC(=O)Oc1ccccc1C(=O)O") # Aspirin
330
+ viewer_submit = st.form_submit_button("πŸ‘οΈ View Molecule", use_container_width=True)
331
+
332
+ if viewer_submit:
333
+ with st.spinner("Generating visualization..."):
334
+ html_view, log = visualize_molecule_2d_3d(smiles_input_viewer, "Molecule")
335
+ st.session_state.viewer_html = html_view
336
+ st.session_state.viewer_log = log
337
+
338
+ if 'viewer_html' in st.session_state:
339
+ st.components.v1.html(st.session_state.viewer_html, height=450)
340
+
341
+ with st.expander("Show Logs"):
342
+ if 'viewer_log' in st.session_state:
343
+ st.text_area("", st.session_state.viewer_log, height=100, key="log_area_view")