# Scraped page header (not part of the program):
# mol-lang-lab / app.py
# alidenewade's picture
# Update app.py
# e56ed1f verified
# raw
# history blame
# 14.1 kB
# app.py
import streamlit as st
import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer, pipeline, BitsAndBytesConfig
from rdkit import Chem
from rdkit.Chem import Draw, AllChem
from rdkit.Chem.Draw import MolToImage
import pandas as pd
import io
import base64
import logging
import py3Dmol
# Set up logging to monitor quantization effects.
# The root logger is configured at INFO level so the info/warning messages
# emitted by the loading functions in this file are not filtered out.
logging.basicConfig(level=logging.INFO)
# Module-level logger used by the functions in this file.
logger = logging.getLogger(__name__)
# --- Quantization Configuration ---
def get_quantization_config():
    """
    Build the 8-bit bitsandbytes quantization config used for model loading.

    Only ``load_in_8bit`` is passed. The previously supplied
    ``bnb_8bit_compute_dtype`` and ``bnb_8bit_use_double_quant`` keywords are
    not BitsAndBytesConfig options (compute dtype and double quantization are
    4-bit-only settings, spelled ``bnb_4bit_*``), so they configured nothing.

    Returns:
        A ``BitsAndBytesConfig`` instance, or ``None`` when it cannot be
        created (for example because bitsandbytes is not installed). Callers
        treat ``None`` as "load without quantization".
    """
    try:
        quantization_config = BitsAndBytesConfig(load_in_8bit=True)
    except ImportError:
        logger.warning("bitsandbytes not available, falling back to standard loading")
        return None
    except Exception as e:
        # Quantization is an optional optimization; never let it block startup.
        logger.warning(f"Quantization setup failed: {e}, using standard loading")
        return None

    logger.info("8-bit quantization configuration loaded successfully")
    return quantization_config
def get_torch_dtype():
    """Choose the torch dtype for model weights from the available hardware."""
    # Half precision when a CUDA device is present, full precision otherwise.
    return torch.float16 if torch.cuda.is_available() else torch.float32
# --- Optimized Model Loading ---
@st.cache_resource
def load_optimized_models():
    """Load the tokenizer, masked-LM model and fill-mask pipeline with optimizations.

    Wrapped in ``st.cache_resource`` so the models are not reloaded on every
    Streamlit rerun. On a CUDA machine the model is loaded with
    ``device_map="auto"`` (and 8-bit quantization when a config is available);
    on CPU it is loaded normally.

    Returns:
        Tuple ``(fill_mask_tokenizer, fill_mask_model, fill_mask_pipeline)``.
        If the optimized path raises, the result of
        ``load_standard_models`` is returned instead.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    torch_dtype = get_torch_dtype()
    quantization_config = get_quantization_config()
    logger.info(f"Loading models on device: {device} with dtype: {torch_dtype}")

    model_name = "seyonec/PubChem10M_SMILES_BPE_450k"

    # The tokenizer is not affected by quantization.
    fill_mask_tokenizer = AutoTokenizer.from_pretrained(model_name)

    model_kwargs = {"torch_dtype": torch_dtype}
    if torch.cuda.is_available():
        # Let accelerate place the weights; quantization is only requested on GPU.
        model_kwargs["device_map"] = "auto"
        if quantization_config is not None:
            model_kwargs["quantization_config"] = quantization_config
    else:
        model_kwargs["device_map"] = None

    try:
        fill_mask_model = AutoModelForMaskedLM.from_pretrained(model_name, **model_kwargs)
        # Inference only.
        fill_mask_model.eval()

        pipeline_kwargs = {"model": fill_mask_model, "tokenizer": fill_mask_tokenizer}
        # A model dispatched through a device map must not be given an explicit
        # `device`: the pipeline cannot move it and rejects the combination,
        # which previously pushed every GPU load into the fallback below.
        if getattr(fill_mask_model, "hf_device_map", None) is None:
            model_device = fill_mask_model.device
            pipeline_kwargs["device"] = model_device.index if model_device.type == "cuda" else -1
        fill_mask_pipeline = pipeline("fill-mask", **pipeline_kwargs)

        logger.info("Models loaded successfully with optimizations")
        return fill_mask_tokenizer, fill_mask_model, fill_mask_pipeline
    except Exception as e:
        logger.error(f"Error loading optimized models: {e}")
        logger.info("Falling back to standard model loading...")
        return load_standard_models(model_name)
def load_standard_models(model_name):
    """Plain, non-quantized loader used when the optimized path fails.

    Returns the tuple ``(tokenizer, model, fill-mask pipeline)`` for
    ``model_name``; the pipeline targets GPU 0 when CUDA is available and the
    CPU otherwise.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForMaskedLM.from_pretrained(model_name)

    cuda_ready = torch.cuda.is_available()
    if cuda_ready:
        device_idx = 0
    else:
        device_idx = -1

    fill_mask = pipeline('fill-mask', model=model, tokenizer=tokenizer, device=device_idx)
    if cuda_ready:
        model.to("cuda")
    return tokenizer, model, fill_mask
# Load models with optimizations.
# These module-level names are read by predict_and_visualize_masked_smiles.
# load_optimized_models is wrapped in st.cache_resource, so the call returns
# the cached objects on later Streamlit reruns.
fill_mask_tokenizer, fill_mask_model, fill_mask_pipeline = load_optimized_models()
# --- Memory Management Utilities ---
def clear_gpu_cache():
    """Empty PyTorch's CUDA cache; does nothing when CUDA is unavailable."""
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
# --- Helper Functions from Notebook (adapted) ---
def get_mol(smiles):
    """Parse a SMILES string into an RDKit Mol and try to Kekulize it.

    Args:
        smiles: SMILES string to parse.

    Returns:
        The RDKit Mol, or ``None`` when RDKit cannot parse ``smiles``. If
        Kekulization fails the molecule is still returned, un-Kekulized.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None
    try:
        Chem.Kekulize(mol)
    except Exception as exc:
        # Kekulization is best effort. Catch Exception rather than using a bare
        # ``except:`` so KeyboardInterrupt/SystemExit still propagate, and
        # leave a trace of the failure instead of hiding it completely.
        logger.debug("Kekulization failed for %r: %s", smiles, exc)
    return mol
def find_matches_one(mol, submol_smarts):
    """Return all substructure matches of a SMARTS pattern within ``mol``.

    An empty list is returned when the molecule or the pattern is missing or
    empty, or when the SMARTS string cannot be parsed; otherwise the result of
    ``mol.GetSubstructMatches`` is returned.
    """
    if mol and submol_smarts:
        pattern = Chem.MolFromSmarts(submol_smarts)
        if pattern:
            return mol.GetSubstructMatches(pattern)
    return []
def get_image_with_highlight(mol, atomset=None, size=(300, 300)):
    """Draw ``mol`` as a 2D image, optionally highlighting the given atoms.

    ``atomset`` entries are converted to ints; if any entry fails conversion a
    warning is logged and only entries whose string form is all digits are
    kept. Returns ``None`` when ``mol`` is ``None``.
    """
    if mol is None:
        return None

    # Green with some transparency.
    green = (0, 1, 0, 0.5)

    atom_ids = []
    if atomset:
        try:
            atom_ids = [int(entry) for entry in atomset]
        except ValueError:
            logger.warning(f"Invalid atom in atomset: {atomset}. Proceeding without highlighting problematic atoms.")
            # Keep only the entries that look like plain non-negative integers.
            atom_ids = [int(entry) for entry in atomset if str(entry).isdigit()]

    colour_by_atom = {atom_id: green for atom_id in atom_ids}
    return MolToImage(
        mol,
        size=size,
        fitImage=True,
        highlightAtoms=atom_ids,
        highlightAtomColors=colour_by_atom,
    )
def mol_to_sdf_string(mol):
    """Generate 3D coordinates for a molecule and return them as a MOL block.

    Args:
        mol: RDKit Mol, or ``None``.

    Returns:
        MOL-block text for the embedded molecule, or ``None`` when ``mol`` is
        ``None`` or no 3D conformer could be generated.
    """
    if mol is None:
        return None

    # Work on a copy so the caller's molecule does not gain a conformer as a
    # side effect of asking for its 3D text representation.
    mol_3d = Chem.Mol(mol)

    # EmbedMolecule returns the new conformer id, or -1 when embedding fails.
    # Optimizing without a conformer raises, so stop here in that case.
    if AllChem.EmbedMolecule(mol_3d, AllChem.ETKDG()) < 0:
        logger.warning("3D embedding failed; no 3D structure available")
        return None

    try:
        AllChem.UFFOptimizeMolecule(mol_3d)
    except (ValueError, RuntimeError) as exc:
        # The force-field step only refines the geometry; keep the embedded
        # coordinates if it cannot run.
        logger.warning("UFF optimization failed, using unoptimized geometry: %s", exc)

    return Chem.MolToMolBlock(mol_3d)
def render_mol_3d(sdf_string, width=300, height=300):
    """Build the HTML for a py3Dmol stick-model viewer of a molecule.

    Args:
        sdf_string: Molecule text passed to the viewer in 'sdf' format, or
            ``None``.
        width: Viewer width passed to ``py3Dmol.view``.
        height: Viewer height passed to ``py3Dmol.view``.

    Returns:
        The viewer HTML as a string, or ``""`` when ``sdf_string`` is ``None``.
    """
    if sdf_string is None:
        return ""
    viewer = py3Dmol.view(width=width, height=height)
    viewer.addModel(sdf_string, 'sdf')
    viewer.setStyle({'stick': {}})  # Display as sticks
    viewer.zoomTo()
    # py3Dmol.view does not define to_html(): unknown attribute names are
    # forwarded as JavaScript viewer calls and return the view object itself,
    # not markup. _make_html() is the method that produces the HTML string.
    return viewer._make_html()
# --- Streamlit Interface Functions ---
def predict_and_visualize_masked_smiles(smiles_mask, substructure_smarts_highlight="CC=CC"):
    """
    Predict the masked token in a SMILES string and visualize the valid results.

    The fill-mask pipeline is asked for its top 10 completions; the first 5
    that parse as molecules are kept, each with a 2D image and the output of
    the 3D renderer.

    Args:
        smiles_mask: SMILES string containing exactly one mask token.
        substructure_smarts_highlight: Optional SMARTS pattern; the atoms of
            its first match are highlighted in each 2D image.

    Returns:
        A 4-tuple ``(df_results, image_2d_list, image_3d_list, status_message)``:
        a DataFrame with "Predicted SMILES" and "Score" columns, two lists of
        length 5 (padded with ``None``) holding the 2D images and the 3D
        viewer output, and a status string. On error the DataFrame is empty,
        both lists contain only ``None``, and the error is also shown through
        ``st.error``.
    """
    max_results = 5

    def _failure(message):
        # Show the problem in the UI and return the empty-result shape.
        st.error(message)
        return pd.DataFrame(), [None] * max_results, [None] * max_results, message

    if not smiles_mask or fill_mask_tokenizer.mask_token not in smiles_mask:
        return _failure("Error: Input SMILES must contain a mask token (e.g., <mask>).")

    try:
        with torch.no_grad():
            predictions = fill_mask_pipeline(smiles_mask, top_k=10)
    except Exception as e:
        clear_gpu_cache()
        return _failure(f"Error during prediction: {str(e)}")

    # With several mask tokens the pipeline returns one candidate list per
    # mask instead of a flat list of dicts; the loop below cannot use that.
    if predictions and isinstance(predictions[0], list):
        clear_gpu_cache()
        return _failure("Error: Input SMILES must contain exactly one mask token.")

    results_data = []
    image_2d_list = []
    image_3d_list = []
    for pred in predictions:
        if len(image_2d_list) >= max_results:
            break
        predicted_smiles = pred['sequence']
        mol = get_mol(predicted_smiles)
        if not mol:
            # Not a parseable molecule; try the next candidate.
            continue

        results_data.append({"Predicted SMILES": predicted_smiles, "Score": f"{pred['score']:.4f}"})

        atom_matches_indices = []
        if substructure_smarts_highlight:
            matches = find_matches_one(mol, substructure_smarts_highlight)
            if matches:
                atom_matches_indices = list(matches[0])  # Highlight first match
        image_2d_list.append(get_image_with_highlight(mol, atomset=atom_matches_indices))

        try:
            sdf_string = mol_to_sdf_string(mol)
            img_3d_html = render_mol_3d(sdf_string, width=300, height=300)
        except (ValueError, RuntimeError) as exc:
            # A molecule that cannot be embedded in 3D should not discard the
            # other predictions; fall back to an empty 3D view for it.
            logger.warning("3D rendering failed for %s: %s", predicted_smiles, exc)
            img_3d_html = ""
        image_3d_list.append(img_3d_html)

    valid_predictions_count = len(image_2d_list)

    # Pad both lists to a fixed length so callers can index 0..4 safely.
    padding = [None] * (max_results - valid_predictions_count)
    image_2d_list.extend(padding)
    image_3d_list.extend(padding)

    df_results = pd.DataFrame(results_data)
    clear_gpu_cache()
    status_message = "Prediction successful." if valid_predictions_count > 0 else "No valid molecules found for top predictions."
    return df_results, image_2d_list, image_3d_list, status_message
def display_molecule_with_3d(smiles_string):
    """
    Build the 2D image and 3D viewer output for a single SMILES string.

    Returns a tuple ``(image_2d, viewer_3d, status_message)``. The first two
    items are ``None`` when the input is empty or cannot be parsed.
    """
    if smiles_string:
        mol = get_mol(smiles_string)
        if mol is not None:
            picture = MolToImage(mol, size=(400, 400), fitImage=True)
            viewer_markup = render_mol_3d(mol_to_sdf_string(mol), width=400, height=400)
            return picture, viewer_markup, "Molecule displayed."
        return None, None, "Invalid SMILES string."
    return None, None, "Please enter a SMILES string."
# --- Streamlit UI Definition ---
# Page setup: wide layout, then a dark theme injected as raw CSS.
st.set_page_config(layout="wide")

# The stylesheet lives in a named constant so the st.markdown call stays short.
_DARK_THEME_CSS = """
<style>
.stApp {
background-color: rgb(28,28,28);
color: white; /* Ensure text is visible on dark background */
}
.stDataFrame {
color: black; /* Default DataFrame text color */
}
h1, h2, h3, h4, h5, h6, .stMarkdown {
color: white;
}
.css-1d391kg, .css-1dp5dn1 { /* Target Streamlit's main content and sidebar */
color: white;
}
.streamlit-expanderContent {
background-color: rgb(40,40,40); /* Slightly lighter background for expanders */
border-radius: 10px;
padding: 10px;
}
/* Style for text inputs and buttons */
.stTextInput>div>div>input {
background-color: rgb(50,50,50);
color: white;
border-radius: 5px;
border: 1px solid rgb(70,70,70);
}
.stButton>button {
background-color: rgb(0,128,255); /* Blue button */
color: white;
border-radius: 8px;
padding: 10px 20px;
border: none;
transition: background-color 0.3s ease;
}
.stButton>button:hover {
background-color: rgb(0,100,200);
}
</style>
"""
st.markdown(_DARK_THEME_CSS, unsafe_allow_html=True)

# Title and the two tabs that the sections below fill in.
st.title("ChemBERTa SMILES Utilities Dashboard")
tab1, tab2 = st.tabs(["Masked SMILES Prediction", "Molecule Viewer"])
# Tab 1: masked-token prediction with 2D/3D views of the valid completions.
with tab1:
    st.markdown("Enter a SMILES string with a `<mask>` token (e.g., `C1=CC=CC<mask>C1`) to predict possible completions.")

    col1, col2 = st.columns([2, 1])
    with col1:
        smiles_input_masked = st.text_input("SMILES String with Mask", value="C1=CC=CC<mask>C1")
    with col2:
        substructure_input = st.text_input("Substructure to Highlight (SMARTS)", value="C=C")

    if st.button("Predict and Visualize", key="predict_button"):
        with st.spinner("Predicting and visualizing..."):
            prediction_outcome = predict_and_visualize_masked_smiles(
                smiles_input_masked, substructure_input
            )
        df_predictions, img_2d_list, img_3d_list, status_msg = prediction_outcome
        st.write(status_msg)

        if not df_predictions.empty:
            st.subheader("Top Predictions & Scores")
            st.dataframe(df_predictions, use_container_width=True)

            st.subheader("Predicted Molecule Visualizations (Top 5 Valid)")
            for i in range(5):
                image_2d = img_2d_list[i]
                if image_2d is None:
                    # Only show 'No visualization' if there was a prediction attempt.
                    if i < len(df_predictions):
                        st.markdown(f"**Prediction {i+1}**: No visualization available (invalid SMILES or error).")
                    continue

                st.markdown(f"**Prediction {i+1}**")
                pane_2d, pane_3d = st.columns(2)
                with pane_2d:
                    st.image(image_2d, caption=f"2D Prediction {i+1}", use_column_width=True)
                with pane_3d:
                    st.components.v1.html(img_3d_list[i], height=300)
# Tab 2: show the 2D image and 3D viewer for one SMILES string.
with tab2:
    st.markdown("Enter a SMILES string to display its 2D and 3D structure.")
    smiles_input_viewer = st.text_input("SMILES String", value="C1=CC=CC=C1", key="viewer_smiles_input")

    view_requested = st.button("View Molecule", key="view_button")
    if view_requested:
        with st.spinner("Displaying molecule..."):
            viewer_outcome = display_molecule_with_3d(smiles_input_viewer)
        img_2d_viewer, img_3d_viewer_html, status_viewer_msg = viewer_outcome
        st.write(status_viewer_msg)

        if img_2d_viewer is None:
            st.warning("Could not display molecule. Please check the SMILES string.")
        else:
            pane_2d, pane_3d = st.columns(2)
            with pane_2d:
                st.image(img_2d_viewer, caption="2D Molecule Structure", use_column_width=True)
            with pane_3d:
                st.components.v1.html(img_3d_viewer_html, height=400)