alidenewade commited on
Commit
41fa9c6
·
verified ·
1 Parent(s): ba49293

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +328 -14
app.py CHANGED
@@ -1,34 +1,103 @@
1
  # app.py
2
  import gradio as gr
3
  import torch
4
- from transformers import AutoModelForMaskedLM, AutoTokenizer, pipeline, BitsAndBytesConfig
5
  from rdkit import Chem
6
  from rdkit.Chem import Draw, rdFMCS
7
  from rdkit.Chem.Draw import MolToImage
8
  # PIL is imported as Image by rdkit.Chem.Draw.MolToImage, but explicit import is good practice if used directly.
9
  # from PIL import Image
10
  import pandas as pd
11
-
 
 
 
12
  import io
13
  import base64
14
  import logging
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  # Model names
16
  model_name = "seyonec/PubChem10M_SMILES_BPE_450k"
17
 
18
- # Load tokenizer (doesn't need quantization)
19
  fill_mask_tokenizer = AutoTokenizer.from_pretrained(model_name)
 
20
 
21
-
22
- # Load model with quantization if available
23
  model_kwargs = {
24
  "torch_dtype": torch_dtype,
25
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  **model_kwargs
27
  )
28
 
29
- # Set model to evaluation mode for inference
30
- fill_mask_model.eval()
 
31
 
 
 
 
 
 
 
 
 
32
 
33
  # Create optimized pipeline
34
  # Let pipeline infer device from model if possible, or set based on model's device
@@ -38,38 +107,283 @@ import logging
38
  fill_mask_pipeline = pipeline(
39
  'fill-mask',
40
  model=fill_mask_model,
 
 
 
41
  )
42
 
43
  logger.info("Models loaded successfully with optimizations")
44
- return fill_mask_tokenizer, fill_mask_model, fill_mask_pipeline
45
 
46
  except Exception as e:
47
  logger.error(f"Error loading optimized models: {e}")
 
 
 
 
 
 
 
 
 
48
  device_idx = 0 if torch.cuda.is_available() else -1
49
  fill_mask_pipeline = pipeline('fill-mask', model=fill_mask_model, tokenizer=fill_mask_tokenizer, device=device_idx)
50
 
51
-
52
-
53
 
54
  if torch.cuda.is_available():
55
  fill_mask_model.to("cuda")
 
56
 
57
-
58
- return fill_mask_tokenizer, fill_mask_model, fill_mask_pipeline
59
 
60
  # Load models with optimizations
61
- fill_mask_tokenizer, fill_mask_model, fill_mask_pipeline = load_optimized_models()
62
 
63
  # --- Memory Management Utilities ---
64
  def clear_gpu_cache():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
  # Unpack image_list into individual image outputs + df_results + status_message
66
  return df_results, image_list[0], image_list[1], image_list[2], image_list[3], image_list[4], status_message
67
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
  def display_molecule_image(smiles_string):
69
  """
70
  Displays a 2D image of a molecule from its SMILES string.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
  outputs=[predictions_table, img_out_1, img_out_2, img_out_3, img_out_4, img_out_5, status_masked]
72
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
  with gr.Tab("Molecule Viewer"):
74
  gr.Markdown("Enter a SMILES string to display its 2D structure.")
75
- smiles_input_viewer = gr.Textbox(label="SMILES String", value="C1=CC=CC=C1")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  # app.py
2
  import gradio as gr
3
  import torch
4
+ from transformers import AutoModelForMaskedLM, AutoTokenizer, pipeline, RobertaModel, RobertaTokenizer, BitsAndBytesConfig
5
  from rdkit import Chem
6
  from rdkit.Chem import Draw, rdFMCS
7
  from rdkit.Chem.Draw import MolToImage
8
  # PIL is imported as Image by rdkit.Chem.Draw.MolToImage, but explicit import is good practice if used directly.
9
  # from PIL import Image
10
  import pandas as pd
11
+ from bertviz import head_view # For potential future use or if other parts rely on it
12
+ from bertviz import neuron_view as neuron_view_function # Specific import for neuron_view function
13
+ # IPython.core.display.HTML is generally for notebooks. Gradio's gr.HTML handles HTML strings directly.
14
+ # from IPython.core.display import HTML
15
  import io
16
  import base64
17
  import logging
18
+
19
+ # Set up logging to monitor quantization effects
20
+ logging.basicConfig(level=logging.INFO)
21
+ logger = logging.getLogger(__name__)
22
+
23
+ # --- Quantization Configuration ---
24
+ def get_quantization_config():
25
+ """
26
+ Configure 8-bit quantization for model optimization.
27
+ Falls back gracefully if bitsandbytes is not available.
28
+ """
29
+ try:
30
+ # 8-bit quantization configuration - good balance of speed and quality
31
+ quantization_config = BitsAndBytesConfig(
32
+ load_in_8bit=True,
33
+ bnb_8bit_compute_dtype=torch.float16,
34
+ bnb_8bit_use_double_quant=True, # Nested quantization for better compression
35
+ )
36
+ logger.info("8-bit quantization configuration loaded successfully")
37
+ return quantization_config
38
+ except ImportError:
39
+ logger.warning("bitsandbytes not available, falling back to standard loading")
40
+ return None
41
+ except Exception as e:
42
+ logger.warning(f"Quantization setup failed: {e}, using standard loading")
43
+ return None
44
+
45
+ def get_torch_dtype():
46
+ """Get appropriate torch dtype based on available hardware."""
47
+ if torch.cuda.is_available():
48
+ return torch.float16 # Use half precision on GPU
49
+ else:
50
+ return torch.float32 # Keep full precision on CPU
51
+
52
+ # --- Optimized Model Loading ---
53
+ def load_optimized_models():
54
+ """Load models with quantization and other optimizations."""
55
+ device = "cuda" if torch.cuda.is_available() else "cpu"
56
+ torch_dtype = get_torch_dtype()
57
+ quantization_config = get_quantization_config()
58
+
59
+ logger.info(f"Loading models on device: {device} with dtype: {torch_dtype}")
60
+
61
  # Model names
62
  model_name = "seyonec/PubChem10M_SMILES_BPE_450k"
63
 
64
+ # Load tokenizers (these don't need quantization)
65
  fill_mask_tokenizer = AutoTokenizer.from_pretrained(model_name)
66
+ attention_tokenizer = RobertaTokenizer.from_pretrained(model_name)
67
 
68
+ # Load models with quantization if available
 
69
  model_kwargs = {
70
  "torch_dtype": torch_dtype,
71
  }
72
+
73
+ if quantization_config is not None and torch.cuda.is_available(): # Quantization typically for GPU
74
+ model_kwargs["quantization_config"] = quantization_config
75
+ # device_map="auto" is often used with bitsandbytes for automatic distribution
76
+ model_kwargs["device_map"] = "auto"
77
+ elif torch.cuda.is_available():
78
+ model_kwargs["device_map"] = "auto" # For non-quantized GPU loading
79
+ else:
80
+ model_kwargs["device_map"] = None # For CPU
81
+
82
+ try:
83
+ # Masked LM Model
84
+ fill_mask_model = AutoModelForMaskedLM.from_pretrained(
85
+ model_name,
86
  **model_kwargs
87
  )
88
 
89
+ # RoBERTa model for attention
90
+ attention_model_kwargs = model_kwargs.copy()
91
+ attention_model_kwargs["output_attentions"] = True
92
 
93
+ attention_model = RobertaModel.from_pretrained(
94
+ model_name,
95
+ **attention_model_kwargs
96
+ )
97
+
98
+ # Set models to evaluation mode for inference
99
+ fill_mask_model.eval()
100
+ attention_model.eval()
101
 
102
  # Create optimized pipeline
103
  # Let pipeline infer device from model if possible, or set based on model's device
 
107
  fill_mask_pipeline = pipeline(
108
  'fill-mask',
109
  model=fill_mask_model,
110
+ tokenizer=fill_mask_tokenizer,
111
+ device=pipeline_device, # Use model's device
112
+ # torch_dtype=torch_dtype # Pipeline might infer this or it might conflict
113
  )
114
 
115
  logger.info("Models loaded successfully with optimizations")
116
+ return fill_mask_tokenizer, fill_mask_model, fill_mask_pipeline, attention_model, attention_tokenizer
117
 
118
  except Exception as e:
119
  logger.error(f"Error loading optimized models: {e}")
120
+ # Fallback to standard loading
121
+ logger.info("Falling back to standard model loading...")
122
+ return load_standard_models(model_name)
123
+
124
+ def load_standard_models(model_name):
125
+ """Fallback standard model loading without quantization."""
126
+ fill_mask_tokenizer = AutoTokenizer.from_pretrained(model_name)
127
+ fill_mask_model = AutoModelForMaskedLM.from_pretrained(model_name)
128
+ # Determine device for standard loading
129
  device_idx = 0 if torch.cuda.is_available() else -1
130
  fill_mask_pipeline = pipeline('fill-mask', model=fill_mask_model, tokenizer=fill_mask_tokenizer, device=device_idx)
131
 
132
+ attention_model = RobertaModel.from_pretrained(model_name, output_attentions=True)
133
+ attention_tokenizer = RobertaTokenizer.from_pretrained(model_name)
134
 
135
  if torch.cuda.is_available():
136
  fill_mask_model.to("cuda")
137
+ attention_model.to("cuda")
138
 
139
+ return fill_mask_tokenizer, fill_mask_model, fill_mask_pipeline, attention_model, attention_tokenizer
 
140
 
141
  # Load models with optimizations
142
+ fill_mask_tokenizer, fill_mask_model, fill_mask_pipeline, attention_model, attention_tokenizer = load_optimized_models()
143
 
144
  # --- Memory Management Utilities ---
145
  def clear_gpu_cache():
146
+ """Clear CUDA cache to free up memory."""
147
+ if torch.cuda.is_available():
148
+ torch.cuda.empty_cache()
149
+
150
+ # --- Helper Functions from Notebook (adapted) ---
151
+ def get_mol(smiles):
152
+ """Converts SMILES to RDKit Mol object and Kekulizes it."""
153
+ mol = Chem.MolFromSmiles(smiles)
154
+ if mol is None:
155
+ return None
156
+ try:
157
+ Chem.Kekulize(mol)
158
+ except: # Kekulization can fail for some structures
159
+ pass
160
+ return mol
161
+
162
+ def find_matches_one(mol, submol_smarts):
163
+ """Finds all matching atoms for a SMARTS pattern in a molecule."""
164
+ if not mol or not submol_smarts:
165
+ return []
166
+ submol = Chem.MolFromSmarts(submol_smarts)
167
+ if not submol:
168
+ return []
169
+ matches = mol.GetSubstructMatches(submol)
170
+ return matches
171
+
172
+ def get_image_with_highlight(mol, atomset=None, size=(300, 300)):
173
+ """Draws molecule with optional atom highlighting."""
174
+ if mol is None:
175
+ return None
176
+ highlight_color = (0, 1, 0, 0.5) # Green with some transparency
177
+
178
+ # Ensure atomset contains integers if not None or empty
179
+ valid_atomset = []
180
+ if atomset:
181
+ try:
182
+ valid_atomset = [int(a) for a in atomset]
183
+ except ValueError:
184
+ logger.warning(f"Invalid atom in atomset: {atomset}. Proceeding without highlighting problematic atoms.")
185
+ valid_atomset = [int(a) for a in atomset if str(a).isdigit()] # Filter out non-integers
186
+
187
+ img = MolToImage(mol, size=size, fitImage=True,
188
+ highlightAtoms=valid_atomset if valid_atomset else [],
189
+ highlightAtomColors={i: highlight_color for i in valid_atomset} if valid_atomset else {})
190
+ return img
191
+
192
+ # --- Optimized Gradio Interface Functions ---
193
+
194
+ def predict_and_visualize_masked_smiles(smiles_mask, substructure_smarts_highlight="CC=CC"):
195
+ """
196
+ Predicts masked tokens in a SMILES string, shows scores, and visualizes molecules.
197
+ Optimized with memory management. Returns 7 items for Gradio outputs.
198
+ """
199
+ if fill_mask_tokenizer.mask_token not in smiles_mask:
200
+ # Return 7 items for the 7 output components
201
+ return pd.DataFrame(), None, None, None, None, None, "Error: Input SMILES must contain a mask token (e.g., <mask>)."
202
+
203
+ try:
204
+ # Use torch.no_grad() for inference to save memory
205
+ with torch.no_grad():
206
+ predictions = fill_mask_pipeline(smiles_mask, top_k=10) # Get more to filter for valid ones
207
+ except Exception as e:
208
+ clear_gpu_cache() # Clear cache on error
209
+ # Return 7 items
210
+ return pd.DataFrame(), None, None, None, None, None, f"Error during prediction: {str(e)}"
211
+
212
+ results_data = []
213
+ image_list = []
214
+ valid_predictions_count = 0
215
+
216
+ for pred in predictions:
217
+ if valid_predictions_count >= 5:
218
+ break
219
+
220
+ predicted_smiles = pred['sequence']
221
+ score = pred['score']
222
+
223
+ mol = get_mol(predicted_smiles)
224
+ if mol:
225
+ results_data.append({"Predicted SMILES": predicted_smiles, "Score": f"{score:.4f}"})
226
+
227
+ atom_matches_indices = []
228
+ if substructure_smarts_highlight:
229
+ matches = find_matches_one(mol, substructure_smarts_highlight)
230
+ if matches:
231
+ atom_matches_indices = list(matches[0]) # Highlight first match
232
+
233
+ img = get_image_with_highlight(mol, atomset=atom_matches_indices)
234
+ image_list.append(img)
235
+ valid_predictions_count += 1
236
+
237
+ # Pad image_list if fewer than 5 valid predictions
238
+ while len(image_list) < 5:
239
+ image_list.append(None)
240
+
241
+ df_results = pd.DataFrame(results_data)
242
+
243
+ # Clear cache after inference
244
+ clear_gpu_cache()
245
+
246
+ status_message = "Prediction successful." if valid_predictions_count > 0 else "No valid molecules found for top predictions."
247
+
248
  # Unpack image_list into individual image outputs + df_results + status_message
249
  return df_results, image_list[0], image_list[1], image_list[2], image_list[3], image_list[4], status_message
250
 
251
+
252
+ def visualize_attention_bertviz(sentence_a, sentence_b):
253
+ """
254
+ Generates and displays BertViz neuron-by-neuron attention view as HTML.
255
+ Optimized with memory management and mixed precision.
256
+ """
257
+ if not sentence_a or not sentence_b:
258
+ return "<p style='color:red;'>Please provide two SMILES strings.</p>"
259
+ try:
260
+ inputs = attention_tokenizer.encode_plus(sentence_a, sentence_b, return_tensors='pt', add_special_tokens=True)
261
+ input_ids = inputs['input_ids']
262
+
263
+ # Move to appropriate device if using GPU
264
+ if torch.cuda.is_available() and hasattr(attention_model, 'device'):
265
+ input_ids = input_ids.to(attention_model.device)
266
+
267
+ # Ensure model is in eval mode and use no_grad for inference
268
+ attention_model.eval()
269
+ with torch.no_grad():
270
+ # Use autocast for mixed precision if on CUDA
271
+ if torch.cuda.is_available() and hasattr(torch.cuda, 'amp') and hasattr(torch.cuda.amp, 'autocast'): # Check for amp
272
+ with torch.cuda.amp.autocast(dtype=torch.float16 if get_torch_dtype() == torch.float16 else None):
273
+ attention_outputs = attention_model(input_ids)
274
+ else:
275
+ attention_outputs = attention_model(input_ids)
276
+
277
+ attention = attention_outputs[-1] # Last item in the tuple is attentions
278
+ input_id_list = input_ids[0].tolist()
279
+ tokens = attention_tokenizer.convert_ids_to_tokens(input_id_list)
280
+
281
+ # Using the specifically imported neuron_view_function
282
+ html_object = neuron_view_function(attention, tokens)
283
+
284
+ # Extract HTML string from the IPython.core.display.HTML object
285
+ html_string = html_object.data # .data should provide the HTML string
286
+
287
+ # Add D3 and jQuery CDN links to the HTML string for better rendering in Gradio
288
+ html_with_deps = f"""
289
+ <script src="https://cdnjs.cloudflare.com/ajax/libs/jquery/3.5.1/jquery.min.js"></script>
290
+ <script src="https://cdnjs.cloudflare.com/ajax/libs/d3/5.16.0/d3.min.js"></script>
291
+ {html_string}
292
+ """
293
+
294
+ # Clear cache after attention computation
295
+ clear_gpu_cache()
296
+
297
+ return html_with_deps
298
+ except Exception as e:
299
+ clear_gpu_cache() # Clear cache on error
300
+ logger.error(f"Error in visualize_attention_bertviz: {e}", exc_info=True)
301
+ return f"<p style='color:red;'>Error generating attention visualization: {str(e)}</p>"
302
+
303
  def display_molecule_image(smiles_string):
304
  """
305
  Displays a 2D image of a molecule from its SMILES string.
306
+ """
307
+ if not smiles_string:
308
+ return None, "Please enter a SMILES string."
309
+ mol = get_mol(smiles_string)
310
+ if mol is None:
311
+ return None, "Invalid SMILES string."
312
+ img = MolToImage(mol, size=(400, 400), fitImage=True)
313
+ return img, "Molecule displayed."
314
+
315
+ # --- Gradio Interface Definition ---
316
+ with gr.Blocks(theme=gr.themes.Default()) as demo:
317
+ gr.Markdown("# ChemBERTa SMILES Utilities Dashboard")
318
+
319
+ with gr.Tab("Masked SMILES Prediction"):
320
+ gr.Markdown("Enter a SMILES string with a `<mask>` token (e.g., `C1=CC=CC<mask>C1`) to predict possible completions.")
321
+ with gr.Row():
322
+ smiles_input_masked = gr.Textbox(label="SMILES String with Mask", value="C1=CC=CC<mask>C1")
323
+ substructure_input = gr.Textbox(label="Substructure to Highlight (SMARTS)", value="C=C")
324
+ predict_button_masked = gr.Button("Predict and Visualize")
325
+
326
+ status_masked = gr.Textbox(label="Status", interactive=False)
327
+ predictions_table = gr.DataFrame(label="Top Predictions & Scores")
328
+
329
+ gr.Markdown("### Predicted Molecule Visualizations (Top 5 Valid)")
330
+ with gr.Row():
331
+ img_out_1 = gr.Image(label="Prediction 1", type="pil", interactive=False)
332
+ img_out_2 = gr.Image(label="Prediction 2", type="pil", interactive=False)
333
+ img_out_3 = gr.Image(label="Prediction 3", type="pil", interactive=False)
334
+ img_out_4 = gr.Image(label="Prediction 4", type="pil", interactive=False)
335
+ img_out_5 = gr.Image(label="Prediction 5", type="pil", interactive=False)
336
+
337
+ # Automatically populate on load for the default example
338
+ demo.load(
339
+ lambda: predict_and_visualize_masked_smiles("C1=CC=CC<mask>C1", "C=C"),
340
+ inputs=None,
341
+ outputs=[predictions_table, img_out_1, img_out_2, img_out_3, img_out_4, img_out_5, status_masked]
342
+ )
343
+ predict_button_masked.click(
344
+ predict_and_visualize_masked_smiles,
345
+ inputs=[smiles_input_masked, substructure_input],
346
  outputs=[predictions_table, img_out_1, img_out_2, img_out_3, img_out_4, img_out_5, status_masked]
347
  )
348
+
349
+ with gr.Tab("Attention Visualization"):
350
+ gr.Markdown("Enter two SMILES strings to visualize **neuron-by-neuron attention** between them using BertViz. This may take a moment to render.")
351
+ with gr.Row():
352
+ smiles_a_input_attn = gr.Textbox(label="SMILES String A", value="CCCCC[C@@H](Br)CC")
353
+ smiles_b_input_attn = gr.Textbox(label="SMILES String B", value="CCCCC[C@H](Br)CC")
354
+ visualize_button_attn = gr.Button("Visualize Attention")
355
+ attention_html_output = gr.HTML(label="Attention Neuron View") # Changed label for clarity
356
+
357
+ # Automatically populate on load for the default example
358
+ demo.load(
359
+ lambda: visualize_attention_bertviz("CCCCC[C@@H](Br)CC", "CCCCC[C@H](Br)CC"),
360
+ inputs=None,
361
+ outputs=[attention_html_output]
362
+ )
363
+ visualize_button_attn.click(
364
+ visualize_attention_bertviz,
365
+ inputs=[smiles_a_input_attn, smiles_b_input_attn],
366
+ outputs=[attention_html_output]
367
+ )
368
+
369
  with gr.Tab("Molecule Viewer"):
370
  gr.Markdown("Enter a SMILES string to display its 2D structure.")
371
+ smiles_input_viewer = gr.Textbox(label="SMILES String", value="C1=CC=CC=C1")
372
+ view_button_molecule = gr.Button("View Molecule")
373
+ status_viewer = gr.Textbox(label="Status", interactive=False)
374
+ molecule_image_output = gr.Image(label="Molecule Structure", type="pil", interactive=False)
375
+
376
+ # Automatically populate on load for the default example
377
+ demo.load(
378
+ lambda: display_molecule_image("C1=CC=CC=C1"),
379
+ inputs=None,
380
+ outputs=[molecule_image_output, status_viewer]
381
+ )
382
+ view_button_molecule.click(
383
+ display_molecule_image,
384
+ inputs=[smiles_input_viewer],
385
+ outputs=[molecule_image_output, status_viewer]
386
+ )
387
+
388
+ if __name__ == "__main__":
389
+ demo.launch()