Spaces:
Running
Running
File size: 13,718 Bytes
35ed017 1850745 7962386 35ed017 7962386 ef610f3 1850745 7962386 c9fddab 7962386 ef610f3 1850745 35ed017 7962386 35ed017 7962386 35ed017 7962386 35ed017 7962386 ef610f3 7962386 ef610f3 7962386 ef610f3 7962386 ef610f3 7962386 ef610f3 7962386 ef610f3 7962386 ef610f3 7962386 ef610f3 35ed017 7962386 1850745 7962386 35ed017 7962386 1850745 7962386 3752899 7962386 3752899 35ed017 3752899 c9fddab 7962386 35ed017 7962386 35ed017 7962386 35ed017 7962386 35ed017 7962386 b5c2863 7962386 b5c2863 7962386 35ed017 7962386 35ed017 7962386 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 |
import streamlit as st
import pandas as pd
from rdkit import Chem
from rdkit.Chem import Draw, AllChem
from rdkit.Chem.Draw import rdMolDraw2D
import py3Dmol
import io
import base64
import logging
import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer, pipeline, BitsAndBytesConfig
# Module-level logger for this app; root logging is configured at INFO so the
# model-loading and quantization messages below are emitted.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
# --- Page Configuration ---
# Browser tab title/icon and overall layout for the app.
st.set_page_config(
    page_title="Molecule Explorer & Predictor",
    # Microscope emoji (U+1F52C); the previous value was a mis-decoded
    # UTF-8 byte sequence that rendered as "π¬".
    page_icon="🔬",
    layout="wide",
    initial_sidebar_state="collapsed",
)
# Dark, minimalist theme (adapted from drug_app.txt): Roboto font, dark app
# background, underlined tabs and blue-accented buttons.
_CUSTOM_CSS = """
<style>
@import url('https://fonts.googleapis.com/css2?family=Roboto:wght@400;700&display=swap');
html, body, [class*="st-"] {
font-family: 'Roboto', sans-serif;
}
.stApp {
background-color: rgb(28, 28, 28);
color: white;
}
/* Tab styles */
.stTabs [data-baseweb="tab-list"] {
gap: 24px;
}
.stTabs [data-baseweb="tab"] {
height: 50px;
white-space: pre-wrap;
background: none;
border-radius: 0px;
border-bottom: 2px solid #333;
padding: 10px 4px;
color: #AAA;
}
.stTabs [data-baseweb="tab"]:hover {
background: #222;
color: #FFF;
}
.stTabs [aria-selected="true"] {
border-bottom: 2px solid #00A0FF; /* Highlight color for active tab */
color: #FFF;
}
/* Button styles */
.stButton>button {
border-color: #00A0FF;
color: #00A0FF;
}
.stButton>button:hover {
border-color: #FFF;
color: #FFF;
background-color: #00A0FF;
}
</style>
"""


def apply_custom_styling():
    """Inject the app-wide CSS theme into the page as raw HTML."""
    st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)


apply_custom_styling()
# --- Quantization Configuration ---
def get_quantization_config():
    """Build an 8-bit bitsandbytes quantization config for model loading.

    Returns:
        A ``BitsAndBytesConfig`` with ``load_in_8bit=True``, or ``None`` if the
        config could not be created. In that case the caller loads the model
        without quantization.
    """
    # Only ``load_in_8bit`` applies to 8-bit loading. The previously passed
    # ``bnb_8bit_compute_dtype`` / ``bnb_8bit_use_double_quant`` are not
    # BitsAndBytesConfig parameters (compute dtype and double quantization
    # exist only as the 4-bit ``bnb_4bit_*`` options), so they were ignored.
    try:
        quantization_config = BitsAndBytesConfig(load_in_8bit=True)
    except ImportError:
        logger.warning("bitsandbytes not available, falling back to standard loading")
        return None
    except Exception as e:
        logger.warning(f"Quantization setup failed: {e}, using standard loading")
        return None
    logger.info("8-bit quantization configuration loaded successfully")
    return quantization_config
def get_torch_dtype():
    """Pick the dtype used to load model weights.

    Returns:
        ``torch.float16`` (half precision) when CUDA is available,
        otherwise ``torch.float32`` (full precision on CPU).
    """
    return torch.float16 if torch.cuda.is_available() else torch.float32
# --- Optimized Model Loading with Streamlit Caching ---
@st.cache_resource(show_spinner="Loading molecular language model...")
def load_optimized_models():
    """Load the fill-mask tokenizer, model and pipeline with hardware-aware options.

    On CUDA the model is loaded with ``device_map="auto"`` and, when
    ``get_quantization_config()`` returns a config, 8-bit quantization. On CPU
    it is loaded without a device map. In both cases the dtype comes from
    ``get_torch_dtype()``. If loading the model or building the pipeline raises,
    this falls back to ``load_standard_models``.

    Returns:
        Tuple ``(tokenizer, model, fill_mask_pipeline)``.
    """
    use_cuda = torch.cuda.is_available()
    device = "cuda" if use_cuda else "cpu"
    torch_dtype = get_torch_dtype()
    quantization_config = get_quantization_config()
    logger.info(f"Loading models on device: {device} with dtype: {torch_dtype}")

    model_name = "seyonec/PubChem10M_SMILES_BPE_450k"
    fill_mask_tokenizer = AutoTokenizer.from_pretrained(model_name)

    model_kwargs = {"torch_dtype": torch_dtype}
    if use_cuda:
        # GPU: let accelerate place the weights; quantization is GPU-only here.
        model_kwargs["device_map"] = "auto"
        if quantization_config is not None:
            model_kwargs["quantization_config"] = quantization_config
    else:
        model_kwargs["device_map"] = None  # CPU: plain loading

    try:
        fill_mask_model = AutoModelForMaskedLM.from_pretrained(model_name, **model_kwargs)
        fill_mask_model.eval()
        # A model loaded with a device_map is already placed by accelerate.
        # transformers' pipeline rejects an explicit `device` for such models,
        # and 8-bit models cannot be moved at all. Passing one therefore sent
        # every GPU load to the fallback path. Only pin the device on CPU.
        pipeline_kwargs = {} if model_kwargs["device_map"] is not None else {"device": -1}
        fill_mask_pipeline = pipeline(
            'fill-mask',
            model=fill_mask_model,
            tokenizer=fill_mask_tokenizer,
            **pipeline_kwargs,
        )
        logger.info("Models loaded successfully with optimizations")
        return fill_mask_tokenizer, fill_mask_model, fill_mask_pipeline
    except Exception as e:
        # logger.exception keeps the traceback so the cause of the fallback is visible.
        logger.exception(f"Error loading optimized models: {e}")
        logger.info("Falling back to standard model loading...")
        return load_standard_models(model_name)
@st.cache_resource(show_spinner="Loading standard molecular language model...")
def load_standard_models(model_name="seyonec/PubChem10M_SMILES_BPE_450k"):
    """Load tokenizer, model and fill-mask pipeline without quantization.

    Fallback for ``load_optimized_models``. Runs on CUDA device 0 when CUDA is
    available, otherwise on the CPU (pipeline device ``-1``).

    Args:
        model_name: Hugging Face model id to load.

    Returns:
        Tuple ``(tokenizer, model, fill_mask_pipeline)``.
    """
    fill_mask_tokenizer = AutoTokenizer.from_pretrained(model_name)
    fill_mask_model = AutoModelForMaskedLM.from_pretrained(model_name)
    device_idx = 0 if torch.cuda.is_available() else -1
    # The pipeline moves the model to `device_idx` itself. The former
    # follow-up `.to("cuda")` was redundant, and it targets the *current* CUDA
    # device, which can differ from index 0 and split model and inputs.
    fill_mask_pipeline = pipeline(
        'fill-mask',
        model=fill_mask_model,
        tokenizer=fill_mask_tokenizer,
        device=device_idx,
    )
    logger.info("Standard models loaded successfully")
    return fill_mask_tokenizer, fill_mask_model, fill_mask_pipeline
# --- RDKit and Py3Dmol Visualization Functions ---
def mol_to_svg(mol, size=(400, 300)):
    """Render an RDKit molecule as an SVG string using RDKit's default colors.

    Args:
        mol: RDKit molecule to draw, or ``None``.
        size: ``(width, height)`` passed to ``MolDraw2DSVG``.

    Returns:
        The SVG markup as a string, or ``None`` when ``mol`` is ``None``.
    """
    if mol is None:
        return None
    # Default RDKit drawing options are used on purpose. The old commented-out
    # dark-theme color overrides were removed as dead code.
    drawer = rdMolDraw2D.MolDraw2DSVG(*size)
    drawer.DrawMolecule(mol)
    drawer.FinishDrawing()
    return drawer.GetDrawingText()
def mol_to_sdf(mol):
    """Generate a 3D conformer for ``mol`` and return it as a MolBlock string.

    Explicit hydrogens are added, a conformer is embedded with ETKDGv3 (fixed
    seed 42 for reproducibility, up to 50 embedding attempts), and geometry is
    optimized with the UFF force field.

    Args:
        mol: RDKit molecule, or ``None``.

    Returns:
        MolBlock text (readable as "sdf" by py3Dmol), or ``None`` if ``mol`` is
        falsy or the conformer could not be generated.
    """
    if not mol:
        return None
    mol_with_h = Chem.AddHs(mol)
    try:
        # Seed and attempt count must be set on the params object. The
        # params overload of EmbedMolecule does not accept the
        # `maxAttempts=` / `randomSeed=` keywords, so mixing them raised an
        # argument error on every call.
        params = AllChem.ETKDGv3()
        params.randomSeed = 42
        params.maxIterations = 50
        conf_id = AllChem.EmbedMolecule(mol_with_h, params)
        if conf_id == -1:
            # Embedding failed; retrying from random coordinates often
            # succeeds for larger or more constrained molecules.
            params.useRandomCoords = True
            conf_id = AllChem.EmbedMolecule(mol_with_h, params)
        if conf_id == -1:
            logger.warning(f"3D embedding failed for SMILES: {Chem.MolToSmiles(mol)}")
            return None
        AllChem.UFFOptimizeMolecule(mol_with_h)
        return Chem.MolToMolBlock(mol_with_h)
    except Exception as e:
        logger.error(f"Error generating 3D coordinates for SMILES: {Chem.MolToSmiles(mol)} - {e}")
        return None
def visualize_molecule_3d(mol_sdf: str, width='100%', height=400):
    """Build interactive py3Dmol HTML for a molecule given as SDF/MolBlock text.

    The molecule is shown as sticks plus small spheres on a dark background.
    On any rendering error an ``st.error`` message is shown and ``None`` is
    returned.

    Args:
        mol_sdf: SDF/MolBlock text; falsy values yield ``None``.
        width: Viewer width passed to ``py3Dmol.view``.
        height: Viewer height passed to ``py3Dmol.view``.

    Returns:
        The viewer's HTML string, or ``None``.
    """
    html_view = None
    if mol_sdf:
        try:
            view = py3Dmol.view(width=width, height=height)
            view.setBackgroundColor('#1C1C1C')
            view.addModel(mol_sdf, "sdf")
            stick_and_sphere = {'stick': {}, 'sphere': {'radius': 0.3}}
            view.setStyle(stick_and_sphere)
            view.zoomTo()
            html_view = view._make_html()
        except Exception as err:
            st.error(f"Error generating 3D visualization: {err}")
    return html_view
# --- Main Streamlit Application Layout ---
# Microscope emoji (U+1F52C); previously a mis-decoded byte sequence ("π¬").
st.title("🔬 Molecule Explorer & Predictor")

# Load the models once per session and reuse them on every script rerun.
if 'tokenizer' not in st.session_state:
    (
        st.session_state.tokenizer,
        st.session_state.model,
        st.session_state.pipeline,
    ) = load_optimized_models()
tokenizer = st.session_state.tokenizer
model = st.session_state.model
fill_mask_pipeline = st.session_state.pipeline
tab1, tab2 = st.tabs(["Molecule Viewer (2D & 3D)", "Masked SMILES Predictor"])

# Tab 1: parse a SMILES string and show its 2D depiction and a 3D viewer.
with tab1:
    st.header("Visualize Molecules in 2D and 3D")
    smiles_input = st.text_input(
        "Enter SMILES string:",
        "CCO",
        help="e.g., CCO (ethanol), C1=CC=CC=C1 (benzene)",
    )
    if st.button("View Molecule"):
        if not smiles_input:
            st.info("Please enter a SMILES string to view the molecule.")
        else:
            mol = Chem.MolFromSmiles(smiles_input)
            if not mol:
                st.error("Invalid SMILES string. Please enter a valid chemical structure.")
            else:
                # 2D depiction
                st.subheader("2D Structure")
                svg = mol_to_svg(mol)
                if not svg:
                    st.warning("Could not generate 2D image.")
                else:
                    st.image(svg, use_column_width=True)

                # 3D conformer rendered in an embedded py3Dmol viewer
                st.subheader("3D Structure (Interactive)")
                sdf_string = mol_to_sdf(mol)
                if not sdf_string:
                    st.warning("Could not generate 3D SDF data.")
                else:
                    html_3d = visualize_molecule_3d(sdf_string)
                    if not html_3d:
                        st.warning("Could not generate 3D visualization.")
                    else:
                        st.components.v1.html(html_3d, width=700, height=500, scrolling=False)
# Tab 2: fill a single `<mask>` in a SMILES string with the language model and
# show the top-k candidates as a table plus a grid of 2D structures.
with tab2:
    st.header("Masked SMILES Prediction")
    masked_smiles_input = st.text_input(
        "Enter masked SMILES string (use `<mask>` for the masked token):",
        "C1=CC=CC<mask>C1",
        help="Example: 'C1=CC=CC<mask>C1' (masked benzene), 'CCO<mask>C' (masked ether)"
    )
    top_k_predictions = st.slider("Number of predictions to show:", 1, 10, 5)
    if st.button("Predict Masked Token"):
        # With several masks the fill-mask pipeline returns one list of
        # candidates per mask (a list of lists). The per-candidate dict access
        # below would then fail with an opaque error, so reject that up front.
        mask_count = (masked_smiles_input or "").count("<mask>")
        if mask_count == 0:
            st.info("Please enter a masked SMILES string (e.g., `C1=CC=CC<mask>C1`).")
        elif mask_count > 1:
            st.warning("Please use exactly one `<mask>` token; multiple masks are not supported.")
        else:
            try:
                predictions = fill_mask_pipeline(masked_smiles_input, top_k=top_k_predictions)
                prediction_data = []
                for pred in predictions:
                    sequence = pred['sequence']
                    mol = Chem.MolFromSmiles(sequence)
                    prediction_data.append({
                        "Predicted Token": pred['token_str'],
                        "Full SMILES": sequence,
                        "Confidence Score": f"{pred['score']:.4f}",
                        # Smaller depiction for the grid; None if the SMILES is invalid.
                        "Structure SVG": mol_to_svg(mol, size=(200, 150)) if mol else None,
                    })

                if not prediction_data:
                    # Guards st.columns(0) / modulo-by-zero below.
                    st.warning("The model returned no predictions.")
                else:
                    st.subheader("Predictions:")
                    # Table view without the raw SVG markup.
                    display_df = pd.DataFrame(prediction_data).drop(columns=["Structure SVG"])
                    st.dataframe(display_df, use_container_width=True, hide_index=True)

                    st.subheader("Predicted Structures:")
                    num_cols = min(len(prediction_data), 5)  # at most 5 images per row
                    cols = st.columns(num_cols)
                    for i, row in enumerate(prediction_data):
                        with cols[i % num_cols]:
                            st.markdown(f"**{row['Predicted Token']}** (Score: {row['Confidence Score']})")
                            if row['Structure SVG']:
                                st.image(row['Structure SVG'], use_column_width='auto')
                            else:
                                st.write("*(Invalid SMILES)*")
            except Exception as e:
                logger.exception("Masked SMILES prediction failed")
                st.error(f"An error occurred during prediction: {e}")
                st.info("Please ensure your masked SMILES is valid and contains `<mask>`.")
|