alidenewade committed on
Commit
ba49293
·
verified ·
1 Parent(s): 4ed4dfd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +58 -236
app.py CHANGED
@@ -1,253 +1,75 @@
1
  # app.py
2
- # To run this app, save the code as app.py and run:
3
- # streamlit run app.py
4
- #
5
- # You also need to install the following libraries:
6
- # pip install streamlit torch transformers bitsandbytes rdkit-pypi py3Dmol pandas
7
-
8
- import streamlit as st
9
- import streamlit.components.v1 as components
10
  import torch
11
  from transformers import AutoModelForMaskedLM, AutoTokenizer, pipeline, BitsAndBytesConfig
12
  from rdkit import Chem
13
- from rdkit.Chem import Draw, AllChem
14
  from rdkit.Chem.Draw import MolToImage
 
 
15
  import pandas as pd
 
 
 
16
  import logging
 
 
 
 
 
17
 
18
- # Set up logging to monitor quantization effects
19
- logging.basicConfig(level=logging.INFO)
20
- logger = logging.getLogger(__name__)
21
 
22
- # --- Page Configuration ---
23
- st.set_page_config(
24
- page_title="ChemBERTa SMILES Utilities",
25
- page_icon="🔬",
26
- layout="wide",
27
- )
28
 
29
- # --- Model Loading (Cached for Performance) ---
 
30
 
31
- @st.cache_resource(show_spinner="Loading ChemBERTa model...")
32
- def load_models():
33
- """
34
- Load the tokenizer and model, wrapped in a Streamlit cache resource decorator
35
- to ensure it only runs once per session.
36
- """
37
- device = "cuda" if torch.cuda.is_available() else "cpu"
38
- torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
39
- quantization_config = None
40
-
41
- try:
42
- quantization_config = BitsAndBytesConfig(
43
- load_in_8bit=True,
44
- bnb_8bit_compute_dtype=torch.float16,
45
- bnb_8bit_use_double_quant=True,
46
  )
47
- logger.info("8-bit quantization configuration created.")
48
- except ImportError:
49
- logger.warning("bitsandbytes not available, falling back to standard loading.")
50
- except Exception as e:
51
- logger.warning(f"Quantization setup failed: {e}, using standard loading.")
52
 
53
- model_name = "seyonec/PubChem10M_SMILES_BPE_450k"
54
- tokenizer = AutoTokenizer.from_pretrained(model_name)
55
-
56
- model_kwargs = {"torch_dtype": torch_dtype}
57
- if quantization_config and torch.cuda.is_available():
58
- model_kwargs["quantization_config"] = quantization_config
59
- model_kwargs["device_map"] = "auto"
60
- elif torch.cuda.is_available():
61
- model_kwargs["device_map"] = "auto"
62
-
63
- try:
64
- model = AutoModelForMaskedLM.from_pretrained(model_name, **model_kwargs)
65
- model.eval()
66
- pipeline_device = model.device.index if hasattr(model.device, 'type') and model.device.type == "cuda" else -1
67
- fill_mask_pipeline = pipeline('fill-mask', model=model, tokenizer=tokenizer, device=pipeline_device)
68
- logger.info("Models loaded successfully with optimizations.")
69
- return tokenizer, fill_mask_pipeline
70
- except Exception as e:
71
- logger.error(f"Error loading optimized models: {e}. Retrying with standard loading.")
72
- tokenizer = AutoTokenizer.from_pretrained(model_name)
73
- model = AutoModelForMaskedLM.from_pretrained(model_name)
74
- device_idx = 0 if torch.cuda.is_available() else -1
75
- if torch.cuda.is_available():
76
- model.to("cuda")
77
- fill_mask_pipeline = pipeline('fill-mask', model=model, tokenizer=tokenizer, device=device_idx)
78
- return tokenizer, fill_mask_pipeline
79
-
80
- # Load the models once
81
- fill_mask_tokenizer, fill_mask_pipeline = load_models()
82
-
83
-
84
- # --- Molecule & Visualization Helpers ---
85
-
86
- def get_mol(smiles):
87
- """Converts SMILES to RDKit Mol object and Kekulizes it."""
88
- mol = Chem.MolFromSmiles(smiles)
89
- if mol:
90
- try:
91
- Chem.Kekulize(mol)
92
- except Exception:
93
- pass
94
- return mol
95
-
96
- def find_matches_one(mol, submol_smarts):
97
- """Finds all matching atoms for a SMARTS pattern."""
98
- if not mol or not submol_smarts: return []
99
- submol = Chem.MolFromSmarts(submol_smarts)
100
- return mol.GetSubstructMatches(submol) if submol else []
101
-
102
- def get_image_with_highlight(mol, atomset=None, size=(300, 300)):
103
- """Draws a 2D molecule image with optional atom highlighting."""
104
- if mol is None: return None
105
- valid_atomset = [int(a) for a in atomset if str(a).isdigit()] if atomset else []
106
- return MolToImage(mol, size=size, fitImage=True,
107
- highlightAtoms=valid_atomset,
108
- highlightAtomColors={i: (0, 1, 0, 0.5) for i in valid_atomset})
109
-
110
- def generate_3d_view_html(smiles):
111
- """Generates an interactive 3D molecule view using py3Dmol."""
112
- if not smiles: return None
113
- mol = get_mol(smiles)
114
- if not mol: return "<p>Invalid SMILES for 3D view.</p>"
115
- try:
116
- mol_3d = Chem.AddHs(mol)
117
- AllChem.EmbedMolecule(mol_3d, randomSeed=42, useRandomCoords=True)
118
- AllChem.MMFFOptimizeMolecule(mol_3d)
119
- sdf_data = Chem.MolToMolBlock(mol_3d)
120
-
121
- viewer = py3Dmol.view(width=350, height=350)
122
- viewer.setBackgroundColor('#FFFFFF')
123
- viewer.addModel(sdf_data, "sdf")
124
- viewer.setStyle({'stick': {}, 'sphere': {'scale': 0.25}})
125
- viewer.zoomTo()
126
- return viewer._make_html()
127
  except Exception as e:
128
- logger.error(f"Failed to generate 3D view for {smiles}: {e}")
129
- return f"<p>Error generating 3D view: {e}</p>"
 
130
 
131
- # --- Core Application Logic ---
132
 
133
- def run_masked_smiles_prediction(smiles_mask, substructure_smarts_highlight):
134
- """
135
- Handles the logic for the masked SMILES prediction tab.
136
- """
137
- if fill_mask_tokenizer.mask_token not in smiles_mask:
138
- st.error(f"Error: Input SMILES must contain a mask token (e.g., {fill_mask_tokenizer.mask_token}).")
139
- return
140
-
141
- with st.spinner("Predicting completions..."):
142
- try:
143
- with torch.no_grad():
144
- predictions = fill_mask_pipeline(smiles_mask, top_k=10)
145
- except Exception as e:
146
- st.error(f"An error occurred during prediction: {e}")
147
- if torch.cuda.is_available(): torch.cuda.empty_cache()
148
- return
149
-
150
- results = []
151
- for pred in predictions:
152
- if len(results) >= 5: break
153
- predicted_smiles = pred['sequence']
154
- mol = get_mol(predicted_smiles)
155
- if mol:
156
- atom_matches = find_matches_one(mol, substructure_smarts_highlight)
157
- results.append({
158
- "smiles": predicted_smiles,
159
- "score": f"{pred['score']:.4f}",
160
- "image_2d": get_image_with_highlight(mol, atomset=atom_matches[0] if atom_matches else []),
161
- "html_3d": generate_3d_view_html(predicted_smiles)
162
- })
163
-
164
- if torch.cuda.is_available(): torch.cuda.empty_cache()
165
- st.session_state.prediction_results = results
166
-
167
-
168
- # --- Streamlit UI Definition ---
169
-
170
- st.title("🔬 ChemBERTa SMILES Utilities Dashboard (2D & 3D)")
171
- st.markdown("A tool to predict masked tokens in SMILES strings and visualize molecules, powered by ChemBERTa and Streamlit.")
172
-
173
- tab1, tab2 = st.tabs(["Masked SMILES Prediction", "Molecule Viewer (2D & 3D)"])
174
-
175
- # --- Tab 1: Masked SMILES Prediction ---
176
- with tab1:
177
- st.header("Predict and Visualize Masked SMILES")
178
- st.markdown("Enter a SMILES string with a `<mask>` token to predict possible completions.")
179
-
180
- with st.form(key="prediction_form"):
181
- col1, col2 = st.columns(2)
182
- with col1:
183
- smiles_input_masked = st.text_input(
184
- "SMILES String with Mask",
185
- value="C1=CC=CC<mask>C1",
186
- help=f"The mask token is `{fill_mask_tokenizer.mask_token}`"
187
- )
188
- with col2:
189
- substructure_input = st.text_input(
190
- "Substructure to Highlight (SMARTS)",
191
- value="C=C",
192
- help="Enter a SMARTS pattern to highlight in the 2D images."
193
- )
194
-
195
- predict_button = st.form_submit_button("Predict and Visualize", use_container_width=True)
196
-
197
- if predict_button:
198
- run_masked_smiles_prediction(smiles_input_masked, substructure_input)
199
-
200
- if 'prediction_results' in st.session_state and st.session_state.prediction_results:
201
- results = st.session_state.prediction_results
202
- st.subheader("Top 5 Valid Predictions")
203
-
204
- # Display results in a table
205
- df_data = [{"Predicted SMILES": r["smiles"], "Score": r["score"]} for r in results]
206
- st.dataframe(pd.DataFrame(df_data), use_container_width=True)
207
-
208
- st.markdown("---")
209
-
210
- # Display molecule visualizations
211
- for i, res in enumerate(results):
212
- st.markdown(f"**Prediction {i+1}:** `{res['smiles']}` (Score: {res['score']})")
213
- col1, col2 = st.columns(2)
214
- with col1:
215
- st.subheader("2D Structure")
216
- if res["image_2d"]:
217
- st.image(res["image_2d"], use_column_width=True)
218
- else:
219
- st.warning("Could not generate 2D image.")
220
- with col2:
221
- st.subheader("3D Interactive Structure")
222
- if res["html_3d"]:
223
- components.html(res["html_3d"], height=370)
224
- else:
225
- st.warning("Could not generate 3D view.")
226
- st.markdown("---")
227
-
228
- # --- Tab 2: Molecule Viewer ---
229
- with tab2:
230
- st.header("Visualize a Molecule from SMILES")
231
- st.markdown("Enter a single SMILES string to display its 2D and 3D structures side-by-side.")
232
-
233
- with st.form(key="viewer_form"):
234
- smiles_input_viewer = st.text_input("SMILES String", value="CC(=O)Oc1ccccc1C(=O)O") # Aspirin
235
- view_button = st.form_submit_button("View Molecule", use_container_width=True)
236
-
237
- if view_button and smiles_input_viewer:
238
- with st.spinner("Generating views..."):
239
- mol = get_mol(smiles_input_viewer)
240
- if not mol:
241
- st.error("Invalid SMILES string provided.")
242
- else:
243
- st.subheader(f"Visualizations for: `{smiles_input_viewer}`")
244
- col1, col2 = st.columns(2)
245
- with col1:
246
- st.subheader("2D Structure")
247
- img_2d = MolToImage(mol, size=(450, 450), fitImage=True)
248
- st.image(img_2d, use_column_width=True)
249
- with col2:
250
- st.subheader("3D Interactive Structure")
251
- html_3d = generate_3d_view_html(smiles_input_viewer)
252
- components.html(html_3d, height=470)
253
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  # app.py
2
+ import gradio as gr
 
 
 
 
 
 
 
3
  import torch
4
  from transformers import AutoModelForMaskedLM, AutoTokenizer, pipeline, BitsAndBytesConfig
5
  from rdkit import Chem
6
+ from rdkit.Chem import Draw, rdFMCS
7
  from rdkit.Chem.Draw import MolToImage
8
+ # PIL is imported as Image by rdkit.Chem.Draw.MolToImage, but explicit import is good practice if used directly.
9
+ # from PIL import Image
10
  import pandas as pd
11
+
12
+ import io
13
+ import base64
14
  import logging
15
+ # Model names
16
+ model_name = "seyonec/PubChem10M_SMILES_BPE_450k"
17
+
18
+ # Load tokenizer (doesn't need quantization)
19
+ fill_mask_tokenizer = AutoTokenizer.from_pretrained(model_name)
20
 
 
 
 
21
 
22
+ # Load model with quantization if available
23
+ model_kwargs = {
24
+ "torch_dtype": torch_dtype,
25
+ }
26
+ **model_kwargs
27
+ )
28
 
29
+ # Set model to evaluation mode for inference
30
+ fill_mask_model.eval()
31
 
32
+
33
+ # Create optimized pipeline
34
+ # Let pipeline infer device from model if possible, or set based on model's device
35
+ pipeline_device = fill_mask_model.device.index if hasattr(fill_mask_model.device, 'type') and fill_mask_model.device.type == "cuda" else -1
36
+
37
+
38
+ fill_mask_pipeline = pipeline(
39
+ 'fill-mask',
40
+ model=fill_mask_model,
 
 
 
 
 
 
41
  )
 
 
 
 
 
42
 
43
+ logger.info("Models loaded successfully with optimizations")
44
+ return fill_mask_tokenizer, fill_mask_model, fill_mask_pipeline
45
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  except Exception as e:
47
+ logger.error(f"Error loading optimized models: {e}")
48
+ device_idx = 0 if torch.cuda.is_available() else -1
49
+ fill_mask_pipeline = pipeline('fill-mask', model=fill_mask_model, tokenizer=fill_mask_tokenizer, device=device_idx)
50
 
 
51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
 
53
+
54
+ if torch.cuda.is_available():
55
+ fill_mask_model.to("cuda")
56
+
57
+
58
+ return fill_mask_tokenizer, fill_mask_model, fill_mask_pipeline
59
+
60
+ # Load models with optimizations
61
+ fill_mask_tokenizer, fill_mask_model, fill_mask_pipeline = load_optimized_models()
62
+
63
+ # --- Memory Management Utilities ---
64
+ def clear_gpu_cache():
65
+ # Unpack image_list into individual image outputs + df_results + status_message
66
+ return df_results, image_list[0], image_list[1], image_list[2], image_list[3], image_list[4], status_message
67
+
68
+ def display_molecule_image(smiles_string):
69
+ """
70
+ Displays a 2D image of a molecule from its SMILES string.
71
+ outputs=[predictions_table, img_out_1, img_out_2, img_out_3, img_out_4, img_out_5, status_masked]
72
+ )
73
+ with gr.Tab("Molecule Viewer"):
74
+ gr.Markdown("Enter a SMILES string to display its 2D structure.")
75
+ smiles_input_viewer = gr.Textbox(label="SMILES String", value="C1=CC=CC=C1")