Spaces:
Running
Running
import streamlit as st | |
import pandas as pd | |
from rdkit import Chem | |
from rdkit.Chem import Draw, AllChem | |
from rdkit.Chem.Draw import rdMolDraw2D | |
import py3Dmol | |
import io | |
import base64 | |
import logging | |
import torch | |
from transformers import AutoModelForMaskedLM, AutoTokenizer, pipeline, BitsAndBytesConfig | |
# Set up logging to monitor quantization effects
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# --- Page Configuration ---
# Browser-tab title/icon and page layout, set before any other element is rendered.
st.set_page_config(
    page_title="Molecule Explorer & Predictor",
    page_icon="🔬",
    layout="wide",
    initial_sidebar_state="collapsed",
)
# Custom CSS for a professional, minimalist look (adapted from drug_app.txt).
# Kept as a module constant so the markup is separate from the injection logic.
_CUSTOM_CSS = """
<style>
@import url('https://fonts.googleapis.com/css2?family=Roboto:wght@400;700&display=swap');
html, body, [class*="st-"] {
font-family: 'Roboto', sans-serif;
}
.stApp {
background-color: rgb(28, 28, 28);
color: white;
}
/* Tab styles */
.stTabs [data-baseweb="tab-list"] {
gap: 24px;
}
.stTabs [data-baseweb="tab"] {
height: 50px;
white-space: pre-wrap;
background: none;
border-radius: 0px;
border-bottom: 2px solid #333;
padding: 10px 4px;
color: #AAA;
}
.stTabs [data-baseweb="tab"]:hover {
background: #222;
color: #FFF;
}
.stTabs [aria-selected="true"] {
border-bottom: 2px solid #00A0FF; /* Highlight color for active tab */
color: #FFF;
}
/* Button styles */
.stButton>button {
border-color: #00A0FF;
color: #00A0FF;
}
.stButton>button:hover {
border-color: #FFF;
color: #FFF;
background-color: #00A0FF;
}
</style>
"""


def apply_custom_styling():
    """Inject the app-wide dark theme (Roboto font, tab and button styles) as raw HTML/CSS."""
    # unsafe_allow_html is required for Streamlit to emit the <style> tag verbatim.
    st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)


apply_custom_styling()
# --- Quantization Configuration ---
def get_quantization_config():
    """
    Build an 8-bit bitsandbytes quantization config for model loading.

    Returns:
        BitsAndBytesConfig | None: a config with ``load_in_8bit=True``, or
        ``None`` when the ``bitsandbytes`` backend is not installed or the
        config cannot be created, so the caller can load the model normally.
    """
    import importlib.util

    # Importing/constructing BitsAndBytesConfig from transformers does not
    # require the bitsandbytes package itself, so probe for the backend
    # explicitly; otherwise a missing backend only surfaces later, inside
    # from_pretrained, instead of triggering this documented fallback.
    if importlib.util.find_spec("bitsandbytes") is None:
        logger.warning("bitsandbytes not available, falling back to standard loading")
        return None
    try:
        # Only load_in_8bit applies to 8-bit loading. Compute dtype and nested
        # (double) quantization are 4-bit-only options (bnb_4bit_*); there are
        # no bnb_8bit_* equivalents on BitsAndBytesConfig.
        quantization_config = BitsAndBytesConfig(load_in_8bit=True)
        logger.info("8-bit quantization configuration loaded successfully")
        return quantization_config
    except ImportError:
        # Some transformers versions check the installed bitsandbytes version here.
        logger.warning("bitsandbytes not available, falling back to standard loading")
        return None
    except Exception as e:
        logger.warning(f"Quantization setup failed: {e}, using standard loading")
        return None
def get_torch_dtype():
    """
    Choose the floating-point dtype for model weights.

    Returns:
        torch.dtype: ``torch.float16`` when a CUDA GPU is available (half
        precision), otherwise ``torch.float32`` (full precision on CPU).
    """
    has_gpu = torch.cuda.is_available()
    return torch.float16 if has_gpu else torch.float32
# --- Optimized Model Loading ---
def load_optimized_models():
    """
    Load the fill-mask tokenizer, model and pipeline with hardware-specific optimizations.

    On CUDA the weights are loaded in half precision with ``device_map="auto"``
    and, when ``get_quantization_config()`` returns a config, quantized to
    8 bits. On CPU they stay in full precision. If loading the model or
    building the pipeline fails, ``load_standard_models`` is used instead.

    This function does not cache its result: every call loads the model
    again, so callers are responsible for caching the returned objects.

    Returns:
        tuple: ``(tokenizer, model, fill_mask_pipeline)``.
    """
    use_cuda = torch.cuda.is_available()
    device = "cuda" if use_cuda else "cpu"
    torch_dtype = get_torch_dtype()
    quantization_config = get_quantization_config()
    logger.info(f"Loading models on device: {device} with dtype: {torch_dtype}")

    model_name = "seyonec/PubChem10M_SMILES_BPE_450k"
    fill_mask_tokenizer = AutoTokenizer.from_pretrained(model_name)

    model_kwargs = {"torch_dtype": torch_dtype}
    if use_cuda:
        # Let accelerate place the weights on the GPU; quantization is only
        # applied on this GPU path.
        model_kwargs["device_map"] = "auto"
        if quantization_config is not None:
            model_kwargs["quantization_config"] = quantization_config
    else:
        model_kwargs["device_map"] = None  # plain CPU loading

    try:
        fill_mask_model = AutoModelForMaskedLM.from_pretrained(model_name, **model_kwargs)
        fill_mask_model.eval()

        pipeline_kwargs = {}
        if getattr(fill_mask_model, "hf_device_map", None) is None:
            # Model was not dispatched by accelerate: tell the pipeline where
            # it lives (-1 means CPU).
            model_device = fill_mask_model.device
            pipeline_kwargs["device"] = model_device.index if model_device.type == "cuda" else -1
        # Otherwise (device_map="auto") `device` must be omitted: transformers
        # raises ValueError when it is passed for an accelerate-dispatched model,
        # and the pipeline infers the device from hf_device_map itself.

        fill_mask_pipeline = pipeline(
            'fill-mask',
            model=fill_mask_model,
            tokenizer=fill_mask_tokenizer,
            **pipeline_kwargs,
        )
        logger.info("Models loaded successfully with optimizations")
        return fill_mask_tokenizer, fill_mask_model, fill_mask_pipeline
    except Exception as e:
        logger.error(f"Error loading optimized models: {e}")
        logger.info("Falling back to standard model loading...")
        return load_standard_models(model_name)
def load_standard_models(model_name="seyonec/PubChem10M_SMILES_BPE_450k"):
    """
    Fallback loader: default-precision model without quantization or device_map.

    The pipeline runs on GPU 0 when CUDA is available, otherwise on CPU.
    No caching is done here; callers must cache the returned objects themselves.

    Args:
        model_name: Hugging Face Hub id of the masked-LM checkpoint to load.

    Returns:
        tuple: ``(tokenizer, model, fill_mask_pipeline)``.
    """
    fill_mask_tokenizer = AutoTokenizer.from_pretrained(model_name)
    fill_mask_model = AutoModelForMaskedLM.from_pretrained(model_name)
    device_idx = 0 if torch.cuda.is_available() else -1
    # pipeline() moves the model onto `device` itself (for device >= 0), so no
    # separate .to("cuda") call is needed afterwards.
    fill_mask_pipeline = pipeline('fill-mask', model=fill_mask_model, tokenizer=fill_mask_tokenizer, device=device_idx)
    logger.info("Standard models loaded successfully")
    return fill_mask_tokenizer, fill_mask_model, fill_mask_pipeline
# --- RDKit and Py3Dmol Visualization Functions ---
def mol_to_svg(mol, size=(400, 300)):
    """
    Render an RDKit molecule as an SVG string using RDKit's default drawing style.

    Args:
        mol: RDKit ``Mol`` to draw. A falsy value (e.g. the ``None`` returned by
            ``Chem.MolFromSmiles`` for invalid input) yields ``None``.
        size: Canvas dimensions unpacked into ``rdMolDraw2D.MolDraw2DSVG``,
            normally ``(width, height)``.

    Returns:
        str | None: The SVG document text, or ``None`` when ``mol`` is falsy.
    """
    if not mol:
        return None
    drawer = rdMolDraw2D.MolDraw2DSVG(*size)
    # RDKit's default atom/bond colours are used deliberately; no custom
    # drawOptions are applied.
    drawer.DrawMolecule(mol)
    drawer.FinishDrawing()
    return drawer.GetDrawingText()
def mol_to_sdf(mol):
    """
    Generate a 3D conformer for ``mol`` and return it as an MDL mol block.

    Hydrogens are added, coordinates are embedded with ETKDGv3 using a fixed
    random seed (reproducible geometry), then relaxed with the UFF force field.

    Args:
        mol: RDKit ``Mol``. A falsy value yields ``None``.

    Returns:
        str | None: The mol block text (consumed as "sdf" by py3Dmol), or
        ``None`` if ``mol`` is falsy or 3D embedding fails.
    """
    if not mol:
        return None
    # Explicit hydrogens are needed for a sensible 3D geometry.
    mol_with_h = Chem.AddHs(mol)
    try:
        # EmbedMolecule(mol, params) accepts no extra keyword arguments, so the
        # seed must be set on the parameter object itself; mixing a params
        # object with maxAttempts=/randomSeed= matches no overload and raises.
        params = AllChem.ETKDGv3()
        params.randomSeed = 42
        conf_id = AllChem.EmbedMolecule(mol_with_h, params)
        if conf_id == -1:
            # EmbedMolecule signals failure by returning -1 (it does not raise);
            # retrying from random starting coordinates is the usual fallback.
            params.useRandomCoords = True
            conf_id = AllChem.EmbedMolecule(mol_with_h, params)
        if conf_id == -1:
            logger.error(f"Error generating 3D coordinates for SMILES: {Chem.MolToSmiles(mol)} - embedding failed")
            return None
        # Optimize 3D coordinates using Universal Force Field (UFF)
        AllChem.UFFOptimizeMolecule(mol_with_h)
        return Chem.MolToMolBlock(mol_with_h)
    except Exception as e:
        logger.error(f"Error generating 3D coordinates for SMILES: {Chem.MolToSmiles(mol)} - {e}")
        return None
def visualize_molecule_3d(mol_sdf: str, width='100%', height=400):
    """
    Build embeddable HTML for an interactive py3Dmol view of a molecule.

    Args:
        mol_sdf: Mol block / SDF text to display. Empty or ``None`` yields ``None``.
        width: Viewer width forwarded to ``py3Dmol.view``.
        height: Viewer height forwarded to ``py3Dmol.view``.

    Returns:
        str | None: The viewer HTML, or ``None`` when there is nothing to show
        or rendering raised. In the latter case the error is also shown in the
        Streamlit UI via ``st.error``.
    """
    if not mol_sdf:
        return None
    try:
        view = py3Dmol.view(width=width, height=height)
        # Same colour as the app's rgb(28, 28, 28) page background.
        view.setBackgroundColor('#1C1C1C')
        view.addModel(mol_sdf, "sdf")
        # Ball-and-stick look: sticks for bonds plus small spheres on atoms.
        ball_and_stick = {'stick': {}, 'sphere': {'radius': 0.3}}
        view.setStyle(ball_and_stick)
        view.zoomTo()
        return view._make_html()
    except Exception as exc:
        st.error(f"Error generating 3D visualization: {exc}")
        return None
# --- Main Streamlit Application Layout ---
st.title("🔬 Molecule Explorer & Predictor")


@st.cache_resource(show_spinner="Loading model...")
def _load_shared_models():
    """Load the fill-mask models once per server process and share them across sessions."""
    # st.session_state is scoped to a single browser session, so relying on it
    # alone reloaded the model for every new visitor. st.cache_resource keeps
    # one copy per process. NOTE(review): the shared pipeline is used
    # concurrently by all sessions; confirm this is acceptable for the
    # deployment.
    return load_optimized_models()


# The session_state copy is kept so code reading these keys keeps working.
if 'tokenizer' not in st.session_state:
    st.session_state.tokenizer, st.session_state.model, st.session_state.pipeline = _load_shared_models()
tokenizer = st.session_state.tokenizer
model = st.session_state.model
fill_mask_pipeline = st.session_state.pipeline
tab1, tab2 = st.tabs(["Molecule Viewer (2D & 3D)", "Masked SMILES Predictor"])

with tab1:
    st.header("Visualize Molecules in 2D and 3D")
    smiles_input = st.text_input("Enter SMILES string:", "CCO", help="e.g., CCO (ethanol), C1=CC=CC=C1 (benzene)")
    if st.button("View Molecule"):
        # Ignore whitespace around pasted input so a blank entry is reported as
        # empty instead of being handed to RDKit.
        smiles = smiles_input.strip()
        if not smiles:
            st.info("Please enter a SMILES string to view the molecule.")
        else:
            mol = Chem.MolFromSmiles(smiles)
            if not mol:
                # MolFromSmiles returns None for unparsable SMILES.
                st.error("Invalid SMILES string. Please enter a valid chemical structure.")
            else:
                st.subheader("2D Structure")
                svg = mol_to_svg(mol)
                if svg:
                    st.image(svg, use_column_width=True)
                else:
                    st.warning("Could not generate 2D image.")

                st.subheader("3D Structure (Interactive)")
                sdf_string = mol_to_sdf(mol)
                if not sdf_string:
                    st.warning("Could not generate 3D SDF data.")
                else:
                    html_3d = visualize_molecule_3d(sdf_string)
                    if html_3d:
                        st.components.v1.html(html_3d, width=700, height=500, scrolling=False)
                    else:
                        st.warning("Could not generate 3D visualization.")
with tab2:
    st.header("Masked SMILES Prediction")
    masked_smiles_input = st.text_input(
        "Enter masked SMILES string (use `<mask>` for the masked token):",
        "C1=CC=CC<mask>C1",
        help="Example: 'C1=CC=CC<mask>C1' (masked benzene), 'CCO<mask>C' (masked ether)"
    )
    top_k_predictions = st.slider("Number of predictions to show:", 1, 10, 5)
    if st.button("Predict Masked Token"):
        masked_smiles = masked_smiles_input.strip()
        mask_count = masked_smiles.count("<mask>")
        if mask_count == 0:
            st.info("Please enter a masked SMILES string (e.g., `C1=CC=CC<mask>C1`).")
        elif mask_count > 1:
            # With several masks the fill-mask pipeline returns one list of
            # candidates per mask, not a flat list of dicts, which the table
            # and grid below cannot render.
            st.warning("Please use exactly one `<mask>` token; multiple masks are not supported.")
        else:
            try:
                # Perform prediction using the loaded pipeline
                predictions = fill_mask_pipeline(masked_smiles, top_k=top_k_predictions)
                prediction_data = []
                for pred in predictions:
                    sequence = pred['sequence']
                    mol = Chem.MolFromSmiles(sequence)
                    prediction_data.append({
                        "Predicted Token": pred['token_str'],
                        "Full SMILES": sequence,
                        "Confidence Score": f"{pred['score']:.4f}",
                        # Smaller drawing for the grid; None when the completed SMILES is invalid.
                        "Structure SVG": mol_to_svg(mol, size=(200, 150)) if mol else None,
                    })
                df_predictions = pd.DataFrame(prediction_data)

                st.subheader("Predictions:")
                # The table shows every column except the raw SVG markup.
                display_df = df_predictions.drop(columns=["Structure SVG"])
                st.dataframe(display_df, use_container_width=True, hide_index=True)

                st.subheader("Predicted Structures:")
                num_cols = min(len(df_predictions), 5)  # at most 5 images per row
                cols = st.columns(num_cols)
                for i, row in df_predictions.iterrows():
                    with cols[i % num_cols]:  # distribute round-robin across the columns
                        st.markdown(f"**{row['Predicted Token']}** (Score: {row['Confidence Score']})")
                        if row['Structure SVG']:
                            st.image(row['Structure SVG'], use_column_width='auto')
                        else:
                            st.write("*(Invalid SMILES)*")
            except Exception as e:
                st.error(f"An error occurred during prediction: {e}")
                st.info("Please ensure your masked SMILES is valid and contains `<mask>`.")