# Spaces: Running
# app.py
import streamlit as st | |
import torch | |
from transformers import AutoModelForMaskedLM, AutoTokenizer, pipeline, BitsAndBytesConfig | |
from rdkit import Chem | |
from rdkit.Chem import Draw, AllChem | |
from rdkit.Chem.Draw import MolToImage | |
import pandas as pd | |
import io | |
import base64 | |
import logging | |
import py3Dmol | |
# Emit INFO-level records so the model-loading and quantization messages
# logged below are visible in the application log.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# --- Quantization Configuration --- | |
def get_quantization_config():
    """Build the 8-bit bitsandbytes quantization config, if possible.

    Returns:
        A ``BitsAndBytesConfig`` with ``load_in_8bit=True``, or ``None`` when
        constructing it raises (for example because bitsandbytes is not
        installed). ``None`` tells the caller to load the model without
        quantization.
    """
    try:
        # Only ``load_in_8bit`` is passed. The earlier ``bnb_8bit_compute_dtype``
        # and ``bnb_8bit_use_double_quant`` keywords are not BitsAndBytesConfig
        # parameters (compute dtype and double quantization are 4-bit options,
        # spelled ``bnb_4bit_*``), so they never influenced the 8-bit setup.
        quantization_config = BitsAndBytesConfig(load_in_8bit=True)
    except ImportError:
        logger.warning("bitsandbytes not available, falling back to standard loading")
        return None
    except Exception as e:
        # Best effort: any other setup failure also degrades to standard loading.
        logger.warning(f"Quantization setup failed: {e}, using standard loading")
        return None

    logger.info("8-bit quantization configuration loaded successfully")
    return quantization_config
def get_torch_dtype():
    """Choose the model dtype from the hardware that is available.

    Returns:
        ``torch.float16`` when CUDA is available (half precision on GPU),
        otherwise ``torch.float32`` (full precision on CPU).
    """
    return torch.float16 if torch.cuda.is_available() else torch.float32
# --- Optimized Model Loading --- | |
# show_spinner=False: this function is called before st.set_page_config(), and
# the cache spinner would otherwise be emitted ahead of the page config call.
@st.cache_resource(show_spinner=False)
def load_optimized_models():
    """Load the fill-mask tokenizer, model and pipeline with optimizations.

    The result is cached with ``st.cache_resource`` so the models are loaded
    once per server process instead of on every Streamlit rerun.

    On a CUDA machine the model is loaded with ``device_map="auto"`` and, when
    a quantization config could be built, with 8-bit quantization. On CPU it
    is loaded without a device map. If the optimized load raises, the function
    falls back to ``load_standard_models``.

    Returns:
        Tuple ``(tokenizer, model, fill_mask_pipeline)``.
    """
    on_gpu = torch.cuda.is_available()
    device = "cuda" if on_gpu else "cpu"
    torch_dtype = get_torch_dtype()
    quantization_config = get_quantization_config()
    logger.info(f"Loading models on device: {device} with dtype: {torch_dtype}")

    model_name = "seyonec/PubChem10M_SMILES_BPE_450k"

    # The tokenizer does not need quantization.
    fill_mask_tokenizer = AutoTokenizer.from_pretrained(model_name)

    model_kwargs = {
        "torch_dtype": torch_dtype,
        # Automatic placement on GPU; no device map on CPU.
        "device_map": "auto" if on_gpu else None,
    }
    if quantization_config is not None and on_gpu:
        # Quantization is only applied when a GPU is present.
        model_kwargs["quantization_config"] = quantization_config

    try:
        fill_mask_model = AutoModelForMaskedLM.from_pretrained(
            model_name,
            **model_kwargs
        )
        # Inference only.
        fill_mask_model.eval()

        pipeline_kwargs = {}
        if getattr(fill_mask_model, "hf_device_map", None) is None:
            # The model was not placed via a device map, so tell the pipeline
            # which device it lives on.
            model_device = fill_mask_model.device
            pipeline_kwargs["device"] = (
                model_device.index if model_device.type == "cuda" else -1
            )
        # Otherwise the model carries a device map and must not be moved by
        # the pipeline, so no ``device`` argument is passed.

        fill_mask_pipeline = pipeline(
            'fill-mask',
            model=fill_mask_model,
            tokenizer=fill_mask_tokenizer,
            **pipeline_kwargs
        )
        logger.info("Models loaded successfully with optimizations")
        return fill_mask_tokenizer, fill_mask_model, fill_mask_pipeline
    except Exception as e:
        logger.error(f"Error loading optimized models: {e}")
        logger.info("Falling back to standard model loading...")
        return load_standard_models(model_name)
def load_standard_models(model_name):
    """Load tokenizer, model and fill-mask pipeline without quantization.

    Used as the fallback when the optimized loading path fails.

    Args:
        model_name: Name or path of the pretrained masked-LM checkpoint.

    Returns:
        Tuple ``(tokenizer, model, fill_mask_pipeline)``.
    """
    cuda_available = torch.cuda.is_available()

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForMaskedLM.from_pretrained(model_name)

    # Pipeline device index: first GPU when CUDA is available, -1 for CPU.
    fill_mask = pipeline(
        'fill-mask',
        model=model,
        tokenizer=tokenizer,
        device=0 if cuda_available else -1,
    )
    if cuda_available:
        model.to("cuda")
    return tokenizer, model, fill_mask
# Load the tokenizer, model and pipeline at import time; the helper functions
# below read these module-level names.
_loaded_models = load_optimized_models()
fill_mask_tokenizer, fill_mask_model, fill_mask_pipeline = _loaded_models
# --- Memory Management Utilities --- | |
def clear_gpu_cache():
    """Empty PyTorch's CUDA cache; does nothing when CUDA is unavailable."""
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
# --- Helper Functions from Notebook (adapted) --- | |
def get_mol(smiles):
    """Parse a SMILES string into an RDKit Mol and try to kekulize it.

    Args:
        smiles: SMILES string to parse.

    Returns:
        The RDKit Mol, or ``None`` when ``Chem.MolFromSmiles`` cannot parse
        the input. Kekulization is best effort: if it fails the molecule is
        still returned, just not kekulized.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None
    try:
        Chem.Kekulize(mol)
    except Exception as exc:
        # Kekulization can fail for some structures; keep the molecule as-is.
        # ``Exception`` (not a bare ``except``) so KeyboardInterrupt and
        # SystemExit still propagate.
        logger.debug("Kekulization failed for %r: %s", smiles, exc)
    return mol
def find_matches_one(mol, submol_smarts):
    """Return every substructure match of a SMARTS pattern in a molecule.

    Args:
        mol: RDKit Mol to search (may be ``None``).
        submol_smarts: SMARTS pattern string (may be empty or ``None``).

    Returns:
        The result of ``mol.GetSubstructMatches`` for the parsed pattern, or
        an empty list when the molecule or pattern is missing or the SMARTS
        cannot be parsed.
    """
    if mol and submol_smarts:
        pattern = Chem.MolFromSmarts(submol_smarts)
        if pattern:
            return mol.GetSubstructMatches(pattern)
    return []
def get_image_with_highlight(mol, atomset=None, size=(300, 300)):
    """Draw a molecule as a 2D image, optionally highlighting some atoms.

    Args:
        mol: RDKit Mol to draw; ``None`` yields ``None``.
        atomset: Optional iterable of atom indices to highlight.
        size: ``(width, height)`` of the image in pixels.

    Returns:
        The image produced by ``MolToImage``, or ``None`` when ``mol`` is
        ``None``.
    """
    if mol is None:
        return None

    # Green with some transparency.
    highlight_color = (0, 1, 0, 0.5)

    atom_ids = []
    if atomset:
        try:
            atom_ids = list(map(int, atomset))
        except ValueError:
            logger.warning(f"Invalid atom in atomset: {atomset}. Proceeding without highlighting problematic atoms.")
            # Keep only the entries whose text form is all digits.
            atom_ids = [int(a) for a in atomset if str(a).isdigit()]

    return MolToImage(
        mol,
        size=size,
        fitImage=True,
        highlightAtoms=atom_ids,
        highlightAtomColors=dict.fromkeys(atom_ids, highlight_color),
    )
def mol_to_sdf_string(mol):
    """Generate 3D coordinates for a molecule and return it as a mol block.

    The input molecule is not modified; embedding and optimization run on a
    copy. Hydrogens are not added, so the output contains the same atoms as
    the input.

    Args:
        mol: RDKit Mol, or ``None``.

    Returns:
        The mol block text with 3D coordinates, or ``None`` when ``mol`` is
        ``None`` or no 3D conformer could be generated.
    """
    if mol is None:
        return None

    # Work on a copy so the caller's molecule keeps its original conformers.
    mol_3d = Chem.Mol(mol)

    params = AllChem.ETKDG()
    # EmbedMolecule returns a negative id when embedding fails; optimizing a
    # molecule without a conformer would raise, so check before going on.
    if AllChem.EmbedMolecule(mol_3d, params) < 0:
        # Retry from random starting coordinates, which can help with
        # structures the default embedding cannot handle.
        params.useRandomCoords = True
        if AllChem.EmbedMolecule(mol_3d, params) < 0:
            logger.warning("3D embedding failed; no 3D structure generated")
            return None

    try:
        AllChem.UFFOptimizeMolecule(mol_3d)
    except (ValueError, RuntimeError) as exc:
        # Optimization is a refinement; fall back to the embedded coordinates.
        logger.warning("UFF optimization failed, using unoptimized coordinates: %s", exc)

    return Chem.MolToMolBlock(mol_3d)
def render_mol_3d(sdf_string, width=300, height=300):
    """Build embeddable HTML for a 3D stick view of a molecule using py3Dmol.

    Args:
        sdf_string: Mol block / SDF text of the molecule; ``None`` or an
            empty string yields an empty result.
        width: Viewer width in pixels.
        height: Viewer height in pixels.

    Returns:
        HTML string for the viewer, or ``""`` when there is nothing to show.
    """
    if not sdf_string:
        return ""

    viewer = py3Dmol.view(width=width, height=height)
    viewer.addModel(sdf_string, 'sdf')
    viewer.setStyle({'stick': {}})  # Display as sticks
    viewer.zoomTo()

    # NOTE(review): py3Dmol forwards unknown public method names to the
    # JavaScript viewer and returns the view object, so ``viewer.to_html()``
    # would not produce an HTML string. ``_make_html()`` is the method the
    # view itself uses to build its HTML representation - confirm against the
    # installed py3Dmol version.
    return viewer._make_html()
# --- Streamlit Interface Functions --- | |
def predict_and_visualize_masked_smiles(smiles_mask, substructure_smarts_highlight="CC=CC"):
    """Predict the masked token of a SMILES string and visualize the results.

    The fill-mask pipeline is asked for its top 10 completions. The first 5
    that parse as valid molecules are kept, drawn in 2D (with the first match
    of the highlight pattern marked) and rendered in 3D.

    Args:
        smiles_mask: SMILES string containing exactly one mask token.
        substructure_smarts_highlight: SMARTS pattern whose first match is
            highlighted in the 2D images; falsy disables highlighting.

    Returns:
        Tuple ``(df_results, image_2d_list, image_3d_list, status_message)``:
        a DataFrame with "Predicted SMILES" and "Score" columns, a list of
        five 2D images, a list of five 3D viewer HTML strings, and a status
        string. Both lists are padded with ``None`` up to five entries. On
        error the DataFrame is empty and both lists hold five ``None``.
    """
    mask_token = fill_mask_tokenizer.mask_token
    mask_count = smiles_mask.count(mask_token)
    if mask_count == 0:
        st.error("Error: Input SMILES must contain a mask token (e.g., <mask>).")
        return pd.DataFrame(), [None]*5, [None]*5, "Error: Input SMILES must contain a mask token (e.g., <mask>)."
    if mask_count > 1:
        # With several masks the pipeline returns one result list per mask,
        # which the loop below cannot consume; reject the input up front.
        message = "Error: Input SMILES must contain exactly one mask token."
        st.error(message)
        return pd.DataFrame(), [None]*5, [None]*5, message

    try:
        with torch.no_grad():
            predictions = fill_mask_pipeline(smiles_mask, top_k=10)
    except Exception as e:
        clear_gpu_cache()
        st.error(f"Error during prediction: {str(e)}")
        return pd.DataFrame(), [None]*5, [None]*5, f"Error during prediction: {str(e)}"

    results_data = []
    image_2d_list = []
    image_3d_list = []
    valid_predictions_count = 0

    for pred in predictions:
        if valid_predictions_count >= 5:
            break
        predicted_smiles = pred['sequence']
        score = pred['score']
        mol = get_mol(predicted_smiles)
        if not mol:
            # Skip completions that are not valid molecules.
            continue

        results_data.append({"Predicted SMILES": predicted_smiles, "Score": f"{score:.4f}"})

        atom_matches_indices = []
        if substructure_smarts_highlight:
            matches = find_matches_one(mol, substructure_smarts_highlight)
            if matches:
                atom_matches_indices = list(matches[0])  # Highlight first match
        image_2d_list.append(get_image_with_highlight(mol, atomset=atom_matches_indices))

        # A 3D failure for one molecule must not abort the whole prediction;
        # an empty string is the "no 3D view" value render_mol_3d also uses.
        try:
            img_3d_html = render_mol_3d(mol_to_sdf_string(mol), width=300, height=300)
        except Exception as exc:
            logger.warning("3D rendering failed for %s: %s", predicted_smiles, exc)
            img_3d_html = ""
        image_3d_list.append(img_3d_html)

        valid_predictions_count += 1

    # Pad both lists so callers can always index positions 0-4.
    while len(image_2d_list) < 5:
        image_2d_list.append(None)
        image_3d_list.append(None)

    df_results = pd.DataFrame(results_data)
    clear_gpu_cache()
    status_message = "Prediction successful." if valid_predictions_count > 0 else "No valid molecules found for top predictions."
    return df_results, image_2d_list, image_3d_list, status_message
def display_molecule_with_3d(smiles_string):
    """Produce a 2D image and a 3D viewer for a single SMILES string.

    Args:
        smiles_string: SMILES of the molecule to display.

    Returns:
        Tuple ``(image_2d, viewer_html, status_message)``. The first two
        entries are ``None`` when the input is empty or cannot be parsed.
    """
    if not smiles_string:
        return None, None, "Please enter a SMILES string."

    molecule = get_mol(smiles_string)
    if molecule is None:
        return None, None, "Invalid SMILES string."

    image_2d = MolToImage(molecule, size=(400, 400), fitImage=True)
    viewer_html = render_mol_3d(mol_to_sdf_string(molecule), width=400, height=400)
    return image_2d, viewer_html, "Molecule displayed."
# --- Streamlit UI Definition --- | |
# Set wide mode and background color | |
# Dark-theme stylesheet injected into the page: dark app background, white
# text, and custom styling for expanders, text inputs and buttons.
_DARK_THEME_CSS = """
<style>
.stApp {
    background-color: rgb(28,28,28);
    color: white; /* Ensure text is visible on dark background */
}
.stDataFrame {
    color: black; /* Default DataFrame text color */
}
h1, h2, h3, h4, h5, h6, .stMarkdown {
    color: white;
}
.css-1d391kg, .css-1dp5dn1 { /* Target Streamlit's main content and sidebar */
    color: white;
}
.streamlit-expanderContent {
    background-color: rgb(40,40,40); /* Slightly lighter background for expanders */
    border-radius: 10px;
    padding: 10px;
}
/* Style for text inputs and buttons */
.stTextInput>div>div>input {
    background-color: rgb(50,50,50);
    color: white;
    border-radius: 5px;
    border: 1px solid rgb(70,70,70);
}
.stButton>button {
    background-color: rgb(0,128,255); /* Blue button */
    color: white;
    border-radius: 8px;
    padding: 10px 20px;
    border: none;
    transition: background-color 0.3s ease;
}
.stButton>button:hover {
    background-color: rgb(0,100,200);
}
</style>
"""

# Wide layout first, then the stylesheet.
st.set_page_config(layout="wide")
st.markdown(_DARK_THEME_CSS, unsafe_allow_html=True)
st.title("ChemBERTa SMILES Utilities Dashboard")
tab1, tab2 = st.tabs(["Masked SMILES Prediction", "Molecule Viewer"])

# Tab 1: masked-token prediction with 2D/3D views of the top valid results.
with tab1:
    st.markdown("Enter a SMILES string with a `<mask>` token (e.g., `C1=CC=CC<mask>C1`) to predict possible completions.")

    col1, col2 = st.columns([2, 1])
    with col1:
        smiles_input_masked = st.text_input("SMILES String with Mask", value="C1=CC=CC<mask>C1")
    with col2:
        substructure_input = st.text_input("Substructure to Highlight (SMARTS)", value="C=C")

    if st.button("Predict and Visualize", key="predict_button"):
        with st.spinner("Predicting and visualizing..."):
            prediction_output = predict_and_visualize_masked_smiles(
                smiles_input_masked, substructure_input
            )
        df_predictions, img_2d_list, img_3d_list, status_msg = prediction_output
        st.write(status_msg)

        if not df_predictions.empty:
            st.subheader("Top Predictions & Scores")
            st.dataframe(df_predictions, use_container_width=True)
            st.subheader("Predicted Molecule Visualizations (Top 5 Valid)")

            # Walk the first five slots of the padded image lists in lockstep.
            slots = zip(img_2d_list[:5], img_3d_list[:5])
            for rank, (img_2d, img_3d) in enumerate(slots, start=1):
                if img_2d is None:
                    # Only mention a missing view when a prediction row
                    # exists for this slot.
                    if rank <= len(df_predictions):
                        st.markdown(f"**Prediction {rank}**: No visualization available (invalid SMILES or error).")
                    continue

                st.markdown(f"**Prediction {rank}**")
                left_col, right_col = st.columns(2)
                with left_col:
                    st.image(img_2d, caption=f"2D Prediction {rank}", use_column_width=True)
                with right_col:
                    st.components.v1.html(img_3d, height=300)
# Tab 2: show the 2D drawing and 3D viewer for a single SMILES string.
with tab2:
    st.markdown("Enter a SMILES string to display its 2D and 3D structure.")
    smiles_input_viewer = st.text_input("SMILES String", value="C1=CC=CC=C1", key="viewer_smiles_input")

    if st.button("View Molecule", key="view_button"):
        with st.spinner("Displaying molecule..."):
            viewer_output = display_molecule_with_3d(smiles_input_viewer)
        img_2d_viewer, img_3d_viewer_html, status_viewer_msg = viewer_output
        st.write(status_viewer_msg)

        if img_2d_viewer is None:
            st.warning("Could not display molecule. Please check the SMILES string.")
        else:
            left_col, right_col = st.columns(2)
            with left_col:
                st.image(img_2d_viewer, caption="2D Molecule Structure", use_column_width=True)
            with right_col:
                st.components.v1.html(img_3d_viewer_html, height=400)