alidenewade committed on
Commit
1850745
·
verified ·
1 Parent(s): d81a373

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +234 -0
app.py ADDED
@@ -0,0 +1,234 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
import base64
import io

import gradio as gr
import pandas as pd
import torch
from bertviz import head_view
from IPython.core.display import HTML
from PIL import Image
from rdkit import Chem
from rdkit.Chem import Draw, rdFMCS
from rdkit.Chem.Draw import MolToImage
from transformers import AutoModelForMaskedLM, AutoTokenizer, pipeline, RobertaModel, RobertaTokenizer
14
+
15
# --- Model and Tokenizer Loading ---
# Masked-LM checkpoint: its tokenizer and model are wrapped in a fill-mask
# pipeline that the prediction tab calls.
fill_mask_model_name = "seyonec/PubChem10M_SMILES_BPE_450k"
fill_mask_model = AutoModelForMaskedLM.from_pretrained(fill_mask_model_name)
fill_mask_tokenizer = AutoTokenizer.from_pretrained(fill_mask_model_name)
fill_mask_pipeline = pipeline(
    'fill-mask',
    model=fill_mask_model,
    tokenizer=fill_mask_tokenizer,
)

# RoBERTa encoder used by the attention tab. It is loaded with
# output_attentions=True so the forward pass also returns attention weights.
# The checkpoint name is kept separate so it can differ from the masked-LM one.
attention_model_name = 'seyonec/PubChem10M_SMILES_BPE_450k'
attention_tokenizer = RobertaTokenizer.from_pretrained(attention_model_name)
attention_model = RobertaModel.from_pretrained(
    attention_model_name,
    output_attentions=True,
)
26
+
27
+ # --- Helper Functions from Notebook (adapted) ---
28
def get_mol(smiles):
    """Convert a SMILES string to an RDKit Mol object and try to kekulize it.

    Args:
        smiles: SMILES text to parse.

    Returns:
        The parsed molecule, or ``None`` when ``smiles`` is empty/``None`` or
        ``Chem.MolFromSmiles`` returns ``None`` for it.
    """
    # Nothing to parse: report "no molecule" instead of handing an empty or
    # missing value to RDKit.
    if not smiles:
        return None

    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None

    try:
        Chem.Kekulize(mol)
    except Exception:
        # Kekulization is best-effort: when it fails the molecule is still
        # returned as parsed. Catching Exception (rather than a bare except)
        # lets KeyboardInterrupt / SystemExit propagate.
        pass
    return mol
38
+
39
def find_matches_one(mol, submol_smarts):
    """Return all substructure matches of a SMARTS pattern within a molecule.

    Args:
        mol: Molecule to search; a falsy value yields no matches.
        submol_smarts: SMARTS pattern text; a falsy value yields no matches.

    Returns:
        The result of ``mol.GetSubstructMatches`` for the parsed pattern, or
        an empty list when either argument is falsy or the pattern does not
        parse.
    """
    if mol and submol_smarts:
        pattern = Chem.MolFromSmarts(submol_smarts)
        if pattern:
            return mol.GetSubstructMatches(pattern)
    return []
48
+
49
def get_image_with_highlight(mol, atomset=None, size=(300, 300)):
    """Render a molecule as an image, optionally highlighting some atoms.

    Args:
        mol: Molecule to draw; ``None`` produces no image.
        atomset: Optional collection of atom indices to highlight.
        size: ``(width, height)`` of the rendered image.

    Returns:
        The image produced by ``MolToImage``, or ``None`` when ``mol`` is
        ``None``.
    """
    if mol is None:
        return None

    # RGBA tuple: green at half opacity.
    green = (0, 1, 0, 0.5)
    if atomset:
        highlighted_atoms = atomset
        atom_colors = {atom_idx: green for atom_idx in atomset}
    else:
        highlighted_atoms = []
        atom_colors = {}

    return MolToImage(
        mol,
        size=size,
        fitImage=True,
        highlightAtoms=highlighted_atoms,
        highlightAtomColors=atom_colors,
    )
58
+
59
+ # --- Gradio Interface Functions ---
60
+
61
def predict_and_visualize_masked_smiles(smiles_mask, substructure_smarts_highlight="CC=CC"):
    """Predict the masked token of a SMILES string and draw the valid results.

    The fill-mask pipeline is asked for its top 10 completions; the first five
    that parse as molecules are kept, tabulated, and rendered (with the first
    match of the optional SMARTS pattern highlighted).

    Args:
        smiles_mask: SMILES text containing exactly one mask token.
        substructure_smarts_highlight: SMARTS pattern to highlight in each
            rendered molecule; a falsy value disables highlighting.

    Returns:
        A flat 7-tuple ``(table, img1, img2, img3, img4, img5, status)``:
        a DataFrame of predicted SMILES and formatted scores, five image
        slots (``None`` where there is no valid prediction), and a status
        message. The tuple is flat because the Gradio handlers in this file
        bind seven separate output components to this function.
    """
    max_images = 5
    no_images = [None] * max_images

    mask_token = fill_mask_tokenizer.mask_token
    mask_count = smiles_mask.count(mask_token) if smiles_mask else 0
    if mask_count == 0:
        return (pd.DataFrame(), *no_images,
                "Error: Input SMILES must contain a mask token (e.g., <mask>).")
    if mask_count > 1:
        # With several masks the pipeline returns one prediction list per
        # mask, which the flat loop below cannot consume.
        return (pd.DataFrame(), *no_images,
                "Error: Input SMILES must contain exactly one mask token.")

    try:
        # Ask for more than we display so invalid molecules can be skipped.
        predictions = fill_mask_pipeline(smiles_mask, top_k=10)
    except Exception as e:
        return (pd.DataFrame(), *no_images, f"Error during prediction: {str(e)}")

    results_data = []
    image_list = []

    for pred in predictions:
        if len(image_list) >= max_images:
            break

        predicted_smiles = pred['sequence']
        mol = get_mol(predicted_smiles)
        if not mol:
            # Completion does not parse as a molecule; skip it.
            continue

        results_data.append({"Predicted SMILES": predicted_smiles, "Score": f"{pred['score']:.4f}"})

        atom_matches = []
        if substructure_smarts_highlight:
            matches = find_matches_one(mol, substructure_smarts_highlight)
            if matches:
                atom_matches = list(matches[0])  # Highlight first match only

        image_list.append(get_image_with_highlight(mol, atomset=atom_matches))

    valid_predictions_count = len(image_list)
    # Pad so every image slot always receives a value.
    image_list.extend([None] * (max_images - valid_predictions_count))

    if valid_predictions_count > 0:
        status = "Prediction successful."
    else:
        status = "No valid molecules found for top predictions."
    return (pd.DataFrame(results_data), *image_list, status)
104
+
105
+
106
def visualize_attention_bertviz(sentence_a, sentence_b):
    """Build the BertViz attention head view for a pair of SMILES strings.

    Args:
        sentence_a: First SMILES string.
        sentence_b: Second SMILES string.

    Returns:
        An HTML string: the head view preceded by jQuery and D3 script tags,
        a prompt when either input is empty, or an error message when
        anything in the tokenize/forward/render path raises.
    """
    if not sentence_a or not sentence_b:
        return "Please provide two SMILES strings."
    try:
        inputs = attention_tokenizer.encode_plus(
            sentence_a, sentence_b, return_tensors='pt', add_special_tokens=True
        )
        input_ids = inputs['input_ids']

        # Inference only: eval mode and no gradient tracking.
        attention_model.eval()
        with torch.no_grad():
            attention_outputs = attention_model(input_ids)

        attention = attention_outputs[-1]  # Last item in the output is attentions
        tokens = attention_tokenizer.convert_ids_to_tokens(input_ids[0].tolist())

        # head_view() renders straight into a notebook and returns None unless
        # html_action='return' is passed; it also has no display_mode
        # parameter, so passing one raises TypeError.
        html_object = head_view(attention, tokens, html_action='return')

        # Extract the HTML string from the IPython HTML object.
        html_string = html_object.data

        # Prepend jQuery and D3 from a CDN for the visualization's scripts.
        # NOTE(review): whether Gradio's HTML component executes embedded
        # scripts is not established here - confirm in the deployed app.
        html_with_deps = f"""
        <script src="https://cdnjs.cloudflare.com/ajax/libs/jquery/2.0.0/jquery.min.js"></script>
        <script src="https://cdnjs.cloudflare.com/ajax/libs/d3/3.5.8/d3.min.js"></script>
        {html_string}
        """
        return html_with_deps
    except Exception as e:
        return f"Error generating attention visualization: {str(e)}"
147
+
148
def display_molecule_image(smiles_string):
    """Render the 2D structure of the molecule described by a SMILES string.

    Args:
        smiles_string: SMILES text entered by the user.

    Returns:
        A ``(image, status)`` pair. ``image`` is ``None`` when the input is
        empty or cannot be turned into a molecule; ``status`` says which.
    """
    if smiles_string:
        molecule = get_mol(smiles_string)
        if molecule is not None:
            rendered = MolToImage(molecule, size=(400, 400), fitImage=True)
            return rendered, "Molecule displayed."
        return None, "Invalid SMILES string."
    return None, "Please enter a SMILES string."
159
+
160
# --- Gradio Interface Definition ---
# Three tabs share one Blocks app. Each tab wires a button to its handler and
# registers a load event that runs the same handler on the tab's default
# example values.
with gr.Blocks(theme=gr.themes.Default()) as demo:
    gr.Markdown("# ChemBERTa SMILES Utilities Dashboard")

    # Tab 1: masked-token prediction with molecule renderings.
    with gr.Tab("Masked SMILES Prediction"):
        gr.Markdown("Enter a SMILES string with a `<mask>` token (e.g., `C1=CC=CC<mask>C1`) to predict possible completions.")
        with gr.Row():
            smiles_input_masked = gr.Textbox(
                label="SMILES String with Mask",
                value="C1=CC=CC<mask>C1",
            )
            substructure_input = gr.Textbox(
                label="Substructure to Highlight (SMARTS)",
                value="C=C",
            )
        predict_button_masked = gr.Button("Predict and Visualize")

        status_masked = gr.Textbox(label="Status", interactive=False)
        predictions_table = gr.DataFrame(label="Top Predictions & Scores")

        gr.Markdown("### Predicted Molecule Visualizations (Top 5 Valid)")
        with gr.Row():
            prediction_images = [
                gr.Image(label=f"Prediction {slot}", type="pil", interactive=False)
                for slot in range(1, 6)
            ]
        # Keep the individual component names available at module level.
        img_out_1, img_out_2, img_out_3, img_out_4, img_out_5 = prediction_images

        masked_outputs = [predictions_table, *prediction_images, status_masked]

        # Populate the tab with the default example when the page loads.
        demo.load(
            lambda: predict_and_visualize_masked_smiles("C1=CC=CC<mask>C1", "C=C"),
            inputs=None,
            outputs=masked_outputs,
        )
        predict_button_masked.click(
            predict_and_visualize_masked_smiles,
            inputs=[smiles_input_masked, substructure_input],
            outputs=masked_outputs,
        )

    # Tab 2: BertViz attention head view for a pair of SMILES strings.
    with gr.Tab("Attention Visualization"):
        gr.Markdown("Enter two SMILES strings to visualize attention between them using BertViz. This may take a moment to render.")
        with gr.Row():
            smiles_a_input_attn = gr.Textbox(
                label="SMILES String A",
                value="CCCCC[C@@H](Br)CC",
            )
            smiles_b_input_attn = gr.Textbox(
                label="SMILES String B",
                value="CCCCC[C@H](Br)CC",
            )
        visualize_button_attn = gr.Button("Visualize Attention")
        attention_html_output = gr.HTML(label="Attention Head View")

        attention_outputs_ui = [attention_html_output]

        # Populate the tab with the default example when the page loads.
        demo.load(
            lambda: visualize_attention_bertviz("CCCCC[C@@H](Br)CC", "CCCCC[C@H](Br)CC"),
            inputs=None,
            outputs=attention_outputs_ui,
        )
        visualize_button_attn.click(
            visualize_attention_bertviz,
            inputs=[smiles_a_input_attn, smiles_b_input_attn],
            outputs=attention_outputs_ui,
        )

    # Tab 3: plain 2D rendering of a single molecule.
    with gr.Tab("Molecule Viewer"):
        gr.Markdown("Enter a SMILES string to display its 2D structure.")
        smiles_input_viewer = gr.Textbox(label="SMILES String", value="C1=CC=CC=C1")
        view_button_molecule = gr.Button("View Molecule")
        status_viewer = gr.Textbox(label="Status", interactive=False)
        molecule_image_output = gr.Image(
            label="Molecule Structure",
            type="pil",
            interactive=False,
        )

        viewer_outputs = [molecule_image_output, status_viewer]

        # Populate the tab with the default example when the page loads.
        demo.load(
            lambda: display_molecule_image("C1=CC=CC=C1"),
            inputs=None,
            outputs=viewer_outputs,
        )
        view_button_molecule.click(
            display_molecule_image,
            inputs=[smiles_input_viewer],
            outputs=viewer_outputs,
        )

# Start the app only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()