Spaces:
Running
Running
File size: 10,183 Bytes
11e12c3 4ed4dfd a81f473 35ed017 11e12c3 e56ed1f 98e9d9e 7962386 11e12c3 e56ed1f ef610f3 c9fddab 98e9d9e ef610f3 1850745 11e12c3 98e9d9e 11e12c3 98e9d9e 11e12c3 98e9d9e 11e12c3 98e9d9e 11e12c3 98e9d9e 11e12c3 98e9d9e 11e12c3 7962386 ef610f3 11e12c3 7962386 11e12c3 e56ed1f c3644ec 11e12c3 c3644ec 11e12c3 98e9d9e e56ed1f 11e12c3 e56ed1f 11e12c3 e56ed1f 11e12c3 e56ed1f 11e12c3 98e9d9e 11e12c3 e56ed1f 11e12c3 7962386 11e12c3 7962386 11e12c3 7962386 e56ed1f 11e12c3 98e9d9e e56ed1f 11e12c3 35ed017 11e12c3 35ed017 11e12c3 35ed017 11e12c3 98e9d9e 11e12c3 b5c2863 11e12c3 98e9d9e 11e12c3 98e9d9e 11e12c3 e56ed1f 98e9d9e 11e12c3 98e9d9e 11e12c3 98e9d9e 11e12c3 4ed4dfd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 |
# app.py
# To run this app, save the code as app.py and run:
# streamlit run app.py
#
# You also need to install the following libraries:
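# pip install streamlit torch transformers accelerate bitsandbytes rdkit py3Dmol pandas
# (accelerate is needed for device_map="auto" / 8-bit loading; "rdkit" is the current
#  PyPI name of the package formerly published as "rdkit-pypi")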
import streamlit as st
import streamlit.components.v1 as components
import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer, pipeline, BitsAndBytesConfig
from rdkit import Chem
from rdkit.Chem import Draw, AllChem
from rdkit.Chem.Draw import MolToImage
import pandas as pd
import logging
# Set up logging to monitor quantization effects
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# --- Page Configuration ---
st.set_page_config(
    page_title="ChemBERTa SMILES Utilities",
    # Microscope emoji; the literal had been reduced to the mojibake "π¬"
    # (its UTF-8 bytes decoded with the wrong codec).
    page_icon="🔬",
    layout="wide",
)
# --- Model Loading (Cached for Performance) ---
@st.cache_resource(show_spinner="Loading ChemBERTa model...")
def load_models():
    """
    Load the ChemBERTa tokenizer and build a fill-mask pipeline.

    The function is wrapped in ``st.cache_resource`` so the (expensive)
    objects are built once and then reused on later script reruns.

    Loading strategy:
      * With CUDA available the model is loaded in float16 with
        ``device_map="auto"`` and, if the config can be created, 8-bit
        bitsandbytes quantization.
      * Without CUDA the model is loaded in float32 on the CPU.
      * If the first attempt raises for any reason, the error is logged and
        the model is reloaded with default settings.

    Returns:
        tuple: ``(tokenizer, fill_mask_pipeline)``.
    """
    model_name = "seyonec/PubChem10M_SMILES_BPE_450k"
    use_cuda = torch.cuda.is_available()
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    model_kwargs = {"torch_dtype": torch.float16 if use_cuda else torch.float32}
    if use_cuda:
        model_kwargs["device_map"] = "auto"
        try:
            # Only load_in_8bit is passed: compute-dtype / double-quant options
            # exist for 4-bit loading only (bnb_4bit_*), so the former
            # bnb_8bit_* keyword arguments were never valid options.
            model_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
            logger.info("8-bit quantization configuration created.")
        except ImportError:
            logger.warning("bitsandbytes not available, falling back to standard loading.")
        except Exception as e:
            logger.warning(f"Quantization setup failed: {e}, using standard loading.")

    try:
        model = AutoModelForMaskedLM.from_pretrained(model_name, **model_kwargs)
        model.eval()
        if "device_map" in model_kwargs:
            # The model was already placed by accelerate; pipeline() must not
            # be given an explicit device for such a model, otherwise this
            # path fails and the model gets loaded a second time below.
            fill_mask_pipeline = pipeline('fill-mask', model=model, tokenizer=tokenizer)
        else:
            fill_mask_pipeline = pipeline('fill-mask', model=model, tokenizer=tokenizer, device=-1)
        logger.info("Models loaded successfully with optimizations.")
        return tokenizer, fill_mask_pipeline
    except Exception as e:
        logger.error(f"Error loading optimized models: {e}. Retrying with standard loading.")
        # The tokenizer loaded above is still valid; only the model is rebuilt.
        model = AutoModelForMaskedLM.from_pretrained(model_name)
        device_idx = 0 if use_cuda else -1
        if use_cuda:
            model.to("cuda")
        fill_mask_pipeline = pipeline('fill-mask', model=model, tokenizer=tokenizer, device=device_idx)
        return tokenizer, fill_mask_pipeline
# Load the tokenizer and fill-mask pipeline at module level. load_models() is
# wrapped in st.cache_resource, and the two globals are used by
# run_masked_smiles_prediction and the UI code further down.
fill_mask_tokenizer, fill_mask_pipeline = load_models()
# --- Molecule & Visualization Helpers ---
def get_mol(smiles):
    """Convert a SMILES string to an RDKit Mol object and Kekulize it.

    Args:
        smiles: SMILES string to parse.

    Returns:
        The Mol returned by ``Chem.MolFromSmiles`` (Kekulized in place when
        that succeeds), or ``None`` when parsing produced no molecule.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None
    try:
        Chem.Kekulize(mol)
    except Exception as e:
        # Best effort: the un-Kekulized molecule is still returned, but the
        # failure is logged instead of being silently discarded.
        logger.warning(f"Could not Kekulize {smiles!r}: {e}")
    return mol
def find_matches_one(mol, submol_smarts):
    """Return every substructure match of a SMARTS pattern in ``mol``.

    An empty list is returned when the molecule or the pattern is missing,
    or when the SMARTS string cannot be parsed.
    """
    if mol and submol_smarts:
        pattern = Chem.MolFromSmarts(submol_smarts)
        if pattern:
            return mol.GetSubstructMatches(pattern)
    return []
def get_image_with_highlight(mol, atomset=None, size=(300, 300)):
    """Render a 2D image of ``mol``, optionally highlighting some atoms.

    Entries of ``atomset`` whose string form is not all digits are ignored;
    the remaining atom indices are highlighted in translucent green.
    Returns ``None`` when ``mol`` is ``None``.
    """
    if mol is None:
        return None

    highlight_ids = []
    if atomset:
        for candidate in atomset:
            if str(candidate).isdigit():
                highlight_ids.append(int(candidate))

    highlight_colors = dict.fromkeys(highlight_ids, (0, 1, 0, 0.5))
    return MolToImage(
        mol,
        size=size,
        fitImage=True,
        highlightAtoms=highlight_ids,
        highlightAtomColors=highlight_colors,
    )
def generate_3d_view_html(smiles):
    """Generate an interactive 3D molecule view using py3Dmol.

    Args:
        smiles: SMILES string of the molecule to display.

    Returns:
        ``None`` for an empty ``smiles``; otherwise an HTML string: the
        py3Dmol viewer markup on success, or a short ``<p>`` message when the
        SMILES is invalid or the 3D view could not be generated.
    """
    if not smiles:
        return None
    mol = get_mol(smiles)
    if not mol:
        return "<p>Invalid SMILES for 3D view.</p>"
    try:
        # py3Dmol is not imported at file level, so the name used below was
        # undefined and every call ended in the error branch. Importing it
        # here also turns a missing installation into a readable message.
        import py3Dmol

        mol_3d = Chem.AddHs(mol)
        # EmbedMolecule returns -1 when no conformer could be generated.
        if AllChem.EmbedMolecule(mol_3d, randomSeed=42, useRandomCoords=True) < 0:
            raise ValueError("could not embed 3D coordinates")
        AllChem.MMFFOptimizeMolecule(mol_3d)
        sdf_data = Chem.MolToMolBlock(mol_3d)

        viewer = py3Dmol.view(width=350, height=350)
        viewer.setBackgroundColor('#FFFFFF')
        viewer.addModel(sdf_data, "sdf")
        viewer.setStyle({'stick': {}, 'sphere': {'scale': 0.25}})
        viewer.zoomTo()
        return viewer._make_html()
    except Exception as e:
        logger.error(f"Failed to generate 3D view for {smiles}: {e}")
        return f"<p>Error generating 3D view: {e}</p>"
# --- Core Application Logic ---
def run_masked_smiles_prediction(smiles_mask, substructure_smarts_highlight):
    """
    Handle the logic for the masked SMILES prediction tab.

    Runs the fill-mask pipeline on ``smiles_mask``, keeps up to five
    candidates that parse as molecules, renders a 2D image (with the first
    match of ``substructure_smarts_highlight`` highlighted) and a 3D view for
    each, and stores the list in ``st.session_state.prediction_results``.

    Args:
        smiles_mask: SMILES string containing exactly one mask token.
        substructure_smarts_highlight: SMARTS pattern to highlight (may be empty).

    Returns:
        None. Problems are reported to the user through ``st.error`` /
        ``st.warning``.
    """
    mask_token = fill_mask_tokenizer.mask_token
    if mask_token not in smiles_mask:
        st.error(f"Error: Input SMILES must contain a mask token (e.g., {fill_mask_tokenizer.mask_token}).")
        return
    if smiles_mask.count(mask_token) > 1:
        # With several masks the pipeline returns one list per mask, which the
        # loop below cannot consume (it would fail on pred['sequence']).
        st.error(f"Error: Input SMILES must contain exactly one mask token ({mask_token}).")
        return

    with st.spinner("Predicting completions..."):
        try:
            with torch.no_grad():
                predictions = fill_mask_pipeline(smiles_mask, top_k=10)
        except Exception as e:
            st.error(f"An error occurred during prediction: {e}")
            return
        finally:
            # Release cached GPU memory whether or not the prediction succeeded.
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        results = []
        for pred in predictions:
            if len(results) >= 5:
                break
            predicted_smiles = pred['sequence']
            mol = get_mol(predicted_smiles)
            if not mol:
                # Skip completions that are not valid molecules.
                continue
            atom_matches = find_matches_one(mol, substructure_smarts_highlight)
            results.append({
                "smiles": predicted_smiles,
                "score": f"{pred['score']:.4f}",
                "image_2d": get_image_with_highlight(mol, atomset=atom_matches[0] if atom_matches else []),
                "html_3d": generate_3d_view_html(predicted_smiles),
            })

    if not results:
        st.warning("None of the predicted completions is a valid molecule.")
    st.session_state.prediction_results = results
# --- Streamlit UI Definition ---
# The leading microscope emoji had been reduced to the mojibake "π¬".
st.title("🔬 ChemBERTa SMILES Utilities Dashboard (2D & 3D)")
st.markdown("A tool to predict masked tokens in SMILES strings and visualize molecules, powered by ChemBERTa and Streamlit.")
# One tab per feature; both tab objects are used as context managers below.
tab1, tab2 = st.tabs(["Masked SMILES Prediction", "Molecule Viewer (2D & 3D)"])
# --- Tab 1: Masked SMILES Prediction ---
with tab1:
    st.header("Predict and Visualize Masked SMILES")
    st.markdown("Enter a SMILES string with a `<mask>` token to predict possible completions.")

    # Input form: masked SMILES on the left, highlight pattern on the right.
    with st.form(key="prediction_form"):
        mask_col, smarts_col = st.columns(2)
        smiles_input_masked = mask_col.text_input(
            "SMILES String with Mask",
            value="C1=CC=CC<mask>C1",
            help=f"The mask token is `{fill_mask_tokenizer.mask_token}`",
        )
        substructure_input = smarts_col.text_input(
            "Substructure to Highlight (SMARTS)",
            value="C=C",
            help="Enter a SMARTS pattern to highlight in the 2D images.",
        )
        predict_button = st.form_submit_button("Predict and Visualize", use_container_width=True)

    if predict_button:
        run_masked_smiles_prediction(smiles_input_masked, substructure_input)

    # Results are kept in session state, so they stay visible across reruns.
    stored_results = st.session_state.get('prediction_results')
    if stored_results:
        st.subheader("Top 5 Valid Predictions")

        # Summary table of the predicted SMILES and their scores
        summary_table = pd.DataFrame({
            "Predicted SMILES": [item["smiles"] for item in stored_results],
            "Score": [item["score"] for item in stored_results],
        })
        st.dataframe(summary_table, use_container_width=True)
        st.markdown("---")

        # One 2D / 3D pair per prediction, followed by a separator
        for i, res in enumerate(stored_results):
            st.markdown(f"**Prediction {i+1}:** `{res['smiles']}` (Score: {res['score']})")
            image_col, view_col = st.columns(2)

            image_col.subheader("2D Structure")
            if res["image_2d"]:
                image_col.image(res["image_2d"], use_column_width=True)
            else:
                image_col.warning("Could not generate 2D image.")

            with view_col:
                st.subheader("3D Interactive Structure")
                if res["html_3d"]:
                    components.html(res["html_3d"], height=370)
                else:
                    st.warning("Could not generate 3D view.")

            st.markdown("---")
# --- Tab 2: Molecule Viewer ---
with tab2:
    st.header("Visualize a Molecule from SMILES")
    st.markdown("Enter a single SMILES string to display its 2D and 3D structures side-by-side.")

    with st.form(key="viewer_form"):
        smiles_input_viewer = st.text_input("SMILES String", value="CC(=O)Oc1ccccc1C(=O)O")  # Aspirin
        view_button = st.form_submit_button("View Molecule", use_container_width=True)

    # Render only after a submit with a non-empty SMILES string.
    if view_button and smiles_input_viewer:
        with st.spinner("Generating views..."):
            viewer_mol = get_mol(smiles_input_viewer)
            if viewer_mol:
                st.subheader(f"Visualizations for: `{smiles_input_viewer}`")
                flat_col, spatial_col = st.columns(2)

                flat_col.subheader("2D Structure")
                flat_col.image(
                    MolToImage(viewer_mol, size=(450, 450), fitImage=True),
                    use_column_width=True,
                )

                with spatial_col:
                    st.subheader("3D Interactive Structure")
                    components.html(generate_3d_view_html(smiles_input_viewer), height=470)
            else:
                st.error("Invalid SMILES string provided.")
|