alidenewade commited on
Commit
ef610f3
·
verified ·
1 Parent(s): e27071c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +198 -43
app.py CHANGED
@@ -1,28 +1,151 @@
1
  # app.py
2
  import gradio as gr
3
  import torch
4
- from transformers import AutoModelForMaskedLM, AutoTokenizer, pipeline, RobertaModel, RobertaTokenizer
5
  from rdkit import Chem
6
  from rdkit.Chem import Draw, rdFMCS
7
  from rdkit.Chem.Draw import MolToImage
8
- from PIL import Image # Corrected Line
 
9
  import pandas as pd
10
- from bertviz import head_view
11
- from IPython.core.display import HTML
 
 
12
  import io
13
  import base64
 
14
 
15
- # --- Model and Tokenizer Loading ---
16
- # Masked LM Model
17
- fill_mask_model_name = "seyonec/PubChem10M_SMILES_BPE_450k"
18
- fill_mask_tokenizer = AutoTokenizer.from_pretrained(fill_mask_model_name)
19
- fill_mask_model = AutoModelForMaskedLM.from_pretrained(fill_mask_model_name)
20
- fill_mask_pipeline = pipeline('fill-mask', model=fill_mask_model, tokenizer=fill_mask_tokenizer)
21
 
22
- # Roberta Model for Attention
23
- attention_model_name = 'seyonec/PubChem10M_SMILES_BPE_450k' # Can be same or different as needed
24
- attention_model = RobertaModel.from_pretrained(attention_model_name, output_attentions=True)
25
- attention_tokenizer = RobertaTokenizer.from_pretrained(attention_model_name)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
  # --- Helper Functions from Notebook (adapted) ---
28
  def get_mol(smiles):
@@ -51,24 +174,40 @@ def get_image_with_highlight(mol, atomset=None, size=(300, 300)):
51
  if mol is None:
52
  return None
53
  highlight_color = (0, 1, 0, 0.5) # Green with some transparency
 
 
 
 
 
 
 
 
 
 
54
  img = MolToImage(mol, size=size, fitImage=True,
55
- highlightAtoms=atomset if atomset else [],
56
- highlightAtomColors={i: highlight_color for i in atomset} if atomset else {})
57
  return img
58
 
59
- # --- Gradio Interface Functions ---
60
 
61
  def predict_and_visualize_masked_smiles(smiles_mask, substructure_smarts_highlight="CC=CC"):
62
  """
63
  Predicts masked tokens in a SMILES string, shows scores, and visualizes molecules.
 
64
  """
65
  if fill_mask_tokenizer.mask_token not in smiles_mask:
66
- return pd.DataFrame(), [None]*5, "Error: Input SMILES must contain a mask token (e.g., <mask>)."
 
67
 
68
  try:
69
- predictions = fill_mask_pipeline(smiles_mask, top_k=10) # Get more to filter for valid ones
 
 
70
  except Exception as e:
71
- return pd.DataFrame(), [None]*5, f"Error during prediction: {str(e)}"
 
 
72
 
73
  results_data = []
74
  image_list = []
@@ -85,13 +224,13 @@ def predict_and_visualize_masked_smiles(smiles_mask, substructure_smarts_highlig
85
  if mol:
86
  results_data.append({"Predicted SMILES": predicted_smiles, "Score": f"{score:.4f}"})
87
 
88
- atom_matches = []
89
  if substructure_smarts_highlight:
90
  matches = find_matches_one(mol, substructure_smarts_highlight)
91
  if matches:
92
- atom_matches = list(matches[0]) # Highlight first match
93
 
94
- img = get_image_with_highlight(mol, atomset=atom_matches)
95
  image_list.append(img)
96
  valid_predictions_count += 1
97
 
@@ -100,50 +239,66 @@ def predict_and_visualize_masked_smiles(smiles_mask, substructure_smarts_highlig
100
  image_list.append(None)
101
 
102
  df_results = pd.DataFrame(results_data)
103
- return df_results, image_list, "Prediction successful." if valid_predictions_count > 0 else "No valid molecules found for top predictions."
 
 
 
 
 
 
 
104
 
105
 
106
  def visualize_attention_bertviz(sentence_a, sentence_b):
107
  """
108
- Generates and displays BertViz attention head view as HTML.
 
109
  """
110
  if not sentence_a or not sentence_b:
111
- return "Please provide two SMILES strings."
112
  try:
113
  inputs = attention_tokenizer.encode_plus(sentence_a, sentence_b, return_tensors='pt', add_special_tokens=True)
114
  input_ids = inputs['input_ids']
115
 
116
- # Ensure model is in eval mode and no_grad for inference
 
 
 
 
117
  attention_model.eval()
118
  with torch.no_grad():
119
- attention_outputs = attention_model(input_ids)
 
 
 
 
 
120
 
121
  attention = attention_outputs[-1] # Last item in the tuple is attentions
122
  input_id_list = input_ids[0].tolist()
123
  tokens = attention_tokenizer.convert_ids_to_tokens(input_id_list)
124
 
125
- html_object = head_view(attention, tokens, display_mode="light") # Use light mode for better Gradio compatibility
 
126
 
127
  # Extract HTML string from the IPython.core.display.HTML object
128
- html_string = html_object.data
129
-
130
- # Embed JavaScript directly if needed, or ensure Gradio's HTML component handles it.
131
- # BertViz often requires D3.js and jQuery. Gradio's HTML component might not execute all JS.
132
- # For robustness, it's better if head_view produces self-contained HTML or if Gradio supports JS execution.
133
- # A common workaround is to serve the HTML and use an iframe, or save to file and link.
134
- # Here, we'll return the raw HTML string and let Gradio's gr.HTML handle it.
135
 
136
  # Add D3 and jQuery CDN links to the HTML string for better rendering in Gradio
137
- # This is a common workaround if Gradio's HTML component doesn't include these by default
138
- # Note: This might still have limitations depending on Gradio's sandboxing.
139
  html_with_deps = f"""
140
- <script src="https://cdnjs.cloudflare.com/ajax/libs/jquery/2.0.0/jquery.min.js"></script>
141
- <script src="https://cdnjs.cloudflare.com/ajax/libs/d3/3.5.8/d3.min.js"></script>
142
  {html_string}
143
  """
 
 
 
 
144
  return html_with_deps
145
  except Exception as e:
146
- return f"Error generating attention visualization: {str(e)}"
 
 
147
 
148
  def display_molecule_image(smiles_string):
149
  """
@@ -192,12 +347,12 @@ with gr.Blocks(theme=gr.themes.Default()) as demo:
192
  )
193
 
194
  with gr.Tab("Attention Visualization"):
195
- gr.Markdown("Enter two SMILES strings to visualize attention between them using BertViz. This may take a moment to render.")
196
  with gr.Row():
197
  smiles_a_input_attn = gr.Textbox(label="SMILES String A", value="CCCCC[C@@H](Br)CC")
198
  smiles_b_input_attn = gr.Textbox(label="SMILES String B", value="CCCCC[C@H](Br)CC")
199
  visualize_button_attn = gr.Button("Visualize Attention")
200
- attention_html_output = gr.HTML(label="Attention Head View")
201
 
202
  # Automatically populate on load for the default example
203
  demo.load(
 
1
  # app.py
2
  import gradio as gr
3
  import torch
4
+ from transformers import AutoModelForMaskedLM, AutoTokenizer, pipeline, RobertaModel, RobertaTokenizer, BitsAndBytesConfig
5
  from rdkit import Chem
6
  from rdkit.Chem import Draw, rdFMCS
7
  from rdkit.Chem.Draw import MolToImage
8
+ # PIL is imported as Image by rdkit.Chem.Draw.MolToImage, but explicit import is good practice if used directly.
9
+ # from PIL import Image
10
  import pandas as pd
11
+ from bertviz import head_view # For potential future use or if other parts rely on it
12
+ from bertviz import neuron_view as neuron_view_function # Specific import for neuron_view function
13
+ # IPython.core.display.HTML is generally for notebooks. Gradio's gr.HTML handles HTML strings directly.
14
+ # from IPython.core.display import HTML
15
  import io
16
  import base64
17
+ import logging
18
 
19
+ # Set up logging to monitor quantization effects
20
+ logging.basicConfig(level=logging.INFO)
21
+ logger = logging.getLogger(__name__)
 
 
 
22
 
23
+ # --- Quantization Configuration ---
24
+ def get_quantization_config():
25
+ """
26
+ Configure 8-bit quantization for model optimization.
27
+ Falls back gracefully if bitsandbytes is not available.
28
+ """
29
+ try:
30
+ # 8-bit quantization configuration - good balance of speed and quality
31
+ quantization_config = BitsAndBytesConfig(
32
+ load_in_8bit=True,
33
+ bnb_8bit_compute_dtype=torch.float16,
34
+ bnb_8bit_use_double_quant=True, # Nested quantization for better compression
35
+ )
36
+ logger.info("8-bit quantization configuration loaded successfully")
37
+ return quantization_config
38
+ except ImportError:
39
+ logger.warning("bitsandbytes not available, falling back to standard loading")
40
+ return None
41
+ except Exception as e:
42
+ logger.warning(f"Quantization setup failed: {e}, using standard loading")
43
+ return None
44
+
45
+ def get_torch_dtype():
46
+ """Get appropriate torch dtype based on available hardware."""
47
+ if torch.cuda.is_available():
48
+ return torch.float16 # Use half precision on GPU
49
+ else:
50
+ return torch.float32 # Keep full precision on CPU
51
+
52
+ # --- Optimized Model Loading ---
53
+ def load_optimized_models():
54
+ """Load models with quantization and other optimizations."""
55
+ device = "cuda" if torch.cuda.is_available() else "cpu"
56
+ torch_dtype = get_torch_dtype()
57
+ quantization_config = get_quantization_config()
58
+
59
+ logger.info(f"Loading models on device: {device} with dtype: {torch_dtype}")
60
+
61
+ # Model names
62
+ model_name = "seyonec/PubChem10M_SMILES_BPE_450k"
63
+
64
+ # Load tokenizers (these don't need quantization)
65
+ fill_mask_tokenizer = AutoTokenizer.from_pretrained(model_name)
66
+ attention_tokenizer = RobertaTokenizer.from_pretrained(model_name)
67
+
68
+ # Load models with quantization if available
69
+ model_kwargs = {
70
+ "torch_dtype": torch_dtype,
71
+ }
72
+
73
+ if quantization_config is not None and torch.cuda.is_available(): # Quantization typically for GPU
74
+ model_kwargs["quantization_config"] = quantization_config
75
+ # device_map="auto" is often used with bitsandbytes for automatic distribution
76
+ model_kwargs["device_map"] = "auto"
77
+ elif torch.cuda.is_available():
78
+ model_kwargs["device_map"] = "auto" # For non-quantized GPU loading
79
+ else:
80
+ model_kwargs["device_map"] = None # For CPU
81
+
82
+ try:
83
+ # Masked LM Model
84
+ fill_mask_model = AutoModelForMaskedLM.from_pretrained(
85
+ model_name,
86
+ **model_kwargs
87
+ )
88
+
89
+ # RoBERTa model for attention
90
+ attention_model_kwargs = model_kwargs.copy()
91
+ attention_model_kwargs["output_attentions"] = True
92
+
93
+ attention_model = RobertaModel.from_pretrained(
94
+ model_name,
95
+ **attention_model_kwargs
96
+ )
97
+
98
+ # Set models to evaluation mode for inference
99
+ fill_mask_model.eval()
100
+ attention_model.eval()
101
+
102
+ # Create optimized pipeline
103
+ # Let pipeline infer device from model if possible, or set based on model's device
104
+ pipeline_device = fill_mask_model.device.index if hasattr(fill_mask_model.device, 'type') and fill_mask_model.device.type == "cuda" else -1
105
+
106
+
107
+ fill_mask_pipeline = pipeline(
108
+ 'fill-mask',
109
+ model=fill_mask_model,
110
+ tokenizer=fill_mask_tokenizer,
111
+ device=pipeline_device, # Use model's device
112
+ # torch_dtype=torch_dtype # Pipeline might infer this or it might conflict
113
+ )
114
+
115
+ logger.info("Models loaded successfully with optimizations")
116
+ return fill_mask_tokenizer, fill_mask_model, fill_mask_pipeline, attention_model, attention_tokenizer
117
+
118
+ except Exception as e:
119
+ logger.error(f"Error loading optimized models: {e}")
120
+ # Fallback to standard loading
121
+ logger.info("Falling back to standard model loading...")
122
+ return load_standard_models(model_name)
123
+
124
+ def load_standard_models(model_name):
125
+ """Fallback standard model loading without quantization."""
126
+ fill_mask_tokenizer = AutoTokenizer.from_pretrained(model_name)
127
+ fill_mask_model = AutoModelForMaskedLM.from_pretrained(model_name)
128
+ # Determine device for standard loading
129
+ device_idx = 0 if torch.cuda.is_available() else -1
130
+ fill_mask_pipeline = pipeline('fill-mask', model=fill_mask_model, tokenizer=fill_mask_tokenizer, device=device_idx)
131
+
132
+ attention_model = RobertaModel.from_pretrained(model_name, output_attentions=True)
133
+ attention_tokenizer = RobertaTokenizer.from_pretrained(model_name)
134
+
135
+ if torch.cuda.is_available():
136
+ fill_mask_model.to("cuda")
137
+ attention_model.to("cuda")
138
+
139
+ return fill_mask_tokenizer, fill_mask_model, fill_mask_pipeline, attention_model, attention_tokenizer
140
+
141
+ # Load models with optimizations
142
+ fill_mask_tokenizer, fill_mask_model, fill_mask_pipeline, attention_model, attention_tokenizer = load_optimized_models()
143
+
144
+ # --- Memory Management Utilities ---
145
+ def clear_gpu_cache():
146
+ """Clear CUDA cache to free up memory."""
147
+ if torch.cuda.is_available():
148
+ torch.cuda.empty_cache()
149
 
150
  # --- Helper Functions from Notebook (adapted) ---
151
  def get_mol(smiles):
 
174
  if mol is None:
175
  return None
176
  highlight_color = (0, 1, 0, 0.5) # Green with some transparency
177
+
178
+ # Ensure atomset contains integers if not None or empty
179
+ valid_atomset = []
180
+ if atomset:
181
+ try:
182
+ valid_atomset = [int(a) for a in atomset]
183
+ except ValueError:
184
+ logger.warning(f"Invalid atom in atomset: {atomset}. Proceeding without highlighting problematic atoms.")
185
+ valid_atomset = [int(a) for a in atomset if str(a).isdigit()] # Filter out non-integers
186
+
187
  img = MolToImage(mol, size=size, fitImage=True,
188
+ highlightAtoms=valid_atomset if valid_atomset else [],
189
+ highlightAtomColors={i: highlight_color for i in valid_atomset} if valid_atomset else {})
190
  return img
191
 
192
+ # --- Optimized Gradio Interface Functions ---
193
 
194
  def predict_and_visualize_masked_smiles(smiles_mask, substructure_smarts_highlight="CC=CC"):
195
  """
196
  Predicts masked tokens in a SMILES string, shows scores, and visualizes molecules.
197
+ Optimized with memory management. Returns 7 items for Gradio outputs.
198
  """
199
  if fill_mask_tokenizer.mask_token not in smiles_mask:
200
+ # Return 7 items for the 7 output components
201
+ return pd.DataFrame(), None, None, None, None, None, "Error: Input SMILES must contain a mask token (e.g., <mask>)."
202
 
203
  try:
204
+ # Use torch.no_grad() for inference to save memory
205
+ with torch.no_grad():
206
+ predictions = fill_mask_pipeline(smiles_mask, top_k=10) # Get more to filter for valid ones
207
  except Exception as e:
208
+ clear_gpu_cache() # Clear cache on error
209
+ # Return 7 items
210
+ return pd.DataFrame(), None, None, None, None, None, f"Error during prediction: {str(e)}"
211
 
212
  results_data = []
213
  image_list = []
 
224
  if mol:
225
  results_data.append({"Predicted SMILES": predicted_smiles, "Score": f"{score:.4f}"})
226
 
227
+ atom_matches_indices = []
228
  if substructure_smarts_highlight:
229
  matches = find_matches_one(mol, substructure_smarts_highlight)
230
  if matches:
231
+ atom_matches_indices = list(matches[0]) # Highlight first match
232
 
233
+ img = get_image_with_highlight(mol, atomset=atom_matches_indices)
234
  image_list.append(img)
235
  valid_predictions_count += 1
236
 
 
239
  image_list.append(None)
240
 
241
  df_results = pd.DataFrame(results_data)
242
+
243
+ # Clear cache after inference
244
+ clear_gpu_cache()
245
+
246
+ status_message = "Prediction successful." if valid_predictions_count > 0 else "No valid molecules found for top predictions."
247
+
248
+ # Unpack image_list into individual image outputs + df_results + status_message
249
+ return df_results, image_list[0], image_list[1], image_list[2], image_list[3], image_list[4], status_message
250
 
251
 
252
  def visualize_attention_bertviz(sentence_a, sentence_b):
253
  """
254
+ Generates and displays BertViz neuron-by-neuron attention view as HTML.
255
+ Optimized with memory management and mixed precision.
256
  """
257
  if not sentence_a or not sentence_b:
258
+ return "<p style='color:red;'>Please provide two SMILES strings.</p>"
259
  try:
260
  inputs = attention_tokenizer.encode_plus(sentence_a, sentence_b, return_tensors='pt', add_special_tokens=True)
261
  input_ids = inputs['input_ids']
262
 
263
+ # Move to appropriate device if using GPU
264
+ if torch.cuda.is_available() and hasattr(attention_model, 'device'):
265
+ input_ids = input_ids.to(attention_model.device)
266
+
267
+ # Ensure model is in eval mode and use no_grad for inference
268
  attention_model.eval()
269
  with torch.no_grad():
270
+ # Use autocast for mixed precision if on CUDA
271
+ if torch.cuda.is_available() and hasattr(torch.cuda, 'amp') and hasattr(torch.cuda.amp, 'autocast'): # Check for amp
272
+ with torch.cuda.amp.autocast(dtype=torch.float16 if get_torch_dtype() == torch.float16 else None):
273
+ attention_outputs = attention_model(input_ids)
274
+ else:
275
+ attention_outputs = attention_model(input_ids)
276
 
277
  attention = attention_outputs[-1] # Last item in the tuple is attentions
278
  input_id_list = input_ids[0].tolist()
279
  tokens = attention_tokenizer.convert_ids_to_tokens(input_id_list)
280
 
281
+ # Using the specifically imported neuron_view_function
282
+ html_object = neuron_view_function(attention, tokens)
283
 
284
  # Extract HTML string from the IPython.core.display.HTML object
285
+ html_string = html_object.data # .data should provide the HTML string
 
 
 
 
 
 
286
 
287
  # Add D3 and jQuery CDN links to the HTML string for better rendering in Gradio
 
 
288
  html_with_deps = f"""
289
+ <script src="https://cdnjs.cloudflare.com/ajax/libs/jquery/3.5.1/jquery.min.js"></script>
290
+ <script src="https://cdnjs.cloudflare.com/ajax/libs/d3/5.16.0/d3.min.js"></script>
291
  {html_string}
292
  """
293
+
294
+ # Clear cache after attention computation
295
+ clear_gpu_cache()
296
+
297
  return html_with_deps
298
  except Exception as e:
299
+ clear_gpu_cache() # Clear cache on error
300
+ logger.error(f"Error in visualize_attention_bertviz: {e}", exc_info=True)
301
+ return f"<p style='color:red;'>Error generating attention visualization: {str(e)}</p>"
302
 
303
  def display_molecule_image(smiles_string):
304
  """
 
347
  )
348
 
349
  with gr.Tab("Attention Visualization"):
350
+ gr.Markdown("Enter two SMILES strings to visualize **neuron-by-neuron attention** between them using BertViz. This may take a moment to render.")
351
  with gr.Row():
352
  smiles_a_input_attn = gr.Textbox(label="SMILES String A", value="CCCCC[C@@H](Br)CC")
353
  smiles_b_input_attn = gr.Textbox(label="SMILES String B", value="CCCCC[C@H](Br)CC")
354
  visualize_button_attn = gr.Button("Visualize Attention")
355
+ attention_html_output = gr.HTML(label="Attention Neuron View") # Changed label for clarity
356
 
357
  # Automatically populate on load for the default example
358
  demo.load(