# app.py
import gradio as gr
import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer, pipeline, BitsAndBytesConfig
from rdkit import Chem
from rdkit.Chem.Draw import MolToImage
# MolToImage returns a PIL Image, so PIL only needs an explicit import if images are
# manipulated directly.
# from PIL import Image
import pandas as pd
import logging
# Set up logging to monitor quantization effects
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# --- Quantization Configuration ---
def get_quantization_config():
"""
Configure 8-bit quantization for model optimization.
Falls back gracefully if bitsandbytes is not available.
"""
try:
        # 8-bit quantization - good balance of speed and quality.
        # Note: the compute-dtype and double-quantization options only apply to 4-bit
        # loading (bnb_4bit_*), so plain load_in_8bit is all that is needed here.
        quantization_config = BitsAndBytesConfig(
            load_in_8bit=True,
        )
logger.info("8-bit quantization configuration loaded successfully")
return quantization_config
except ImportError:
logger.warning("bitsandbytes not available, falling back to standard loading")
return None
except Exception as e:
logger.warning(f"Quantization setup failed: {e}, using standard loading")
return None
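# A quick sanity check (assumes a CUDA host with bitsandbytes installed): compare
# model.get_memory_footprint() for a quantized vs. a standard load -- the 8-bit model
# should report roughly half of the fp16 footprint.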
def get_torch_dtype():
"""Get appropriate torch dtype based on available hardware."""
if torch.cuda.is_available():
return torch.float16 # Use half precision on GPU
else:
return torch.float32 # Keep full precision on CPU
# --- Optimized Model Loading ---
def load_optimized_models():
"""Load models with quantization and other optimizations."""
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = get_torch_dtype()
quantization_config = get_quantization_config()
logger.info(f"Loading models on device: {device} with dtype: {torch_dtype}")
# Model name
model_name = "seyonec/PubChem10M_SMILES_BPE_450k"
# Load tokenizer (doesn't need quantization)
fill_mask_tokenizer = AutoTokenizer.from_pretrained(model_name)
# Load model with quantization if available
model_kwargs = {
"torch_dtype": torch_dtype,
}
    if quantization_config is not None and torch.cuda.is_available():  # bitsandbytes 8-bit quantization needs a CUDA device
model_kwargs["quantization_config"] = quantization_config
# device_map="auto" is often used with bitsandbytes for automatic distribution
model_kwargs["device_map"] = "auto"
elif torch.cuda.is_available():
model_kwargs["device_map"] = "auto" # For non-quantized GPU loading
else:
model_kwargs["device_map"] = None # For CPU
try:
# Masked LM Model
fill_mask_model = AutoModelForMaskedLM.from_pretrained(
model_name,
**model_kwargs
)
fill_mask_model.eval() # Set model to evaluation mode for inference
        # Create the fill-mask pipeline. If the model was dispatched with device_map/accelerate
        # (or loaded in 8-bit), passing device= as well can raise a ValueError, so only set it
        # for a plain CPU/GPU load and let the pipeline infer the device from the model otherwise.
        pipeline_kwargs = {}
        if getattr(fill_mask_model, "hf_device_map", None) is None:
            pipeline_kwargs["device"] = 0 if fill_mask_model.device.type == "cuda" else -1
        fill_mask_pipeline = pipeline(
            'fill-mask',
            model=fill_mask_model,
            tokenizer=fill_mask_tokenizer,
            **pipeline_kwargs,
        )
logger.info("Models loaded successfully with optimizations")
return fill_mask_tokenizer, fill_mask_model, fill_mask_pipeline
except Exception as e:
logger.error(f"Error loading optimized models: {e}")
# Fallback to standard loading
logger.info("Falling back to standard model loading...")
return load_standard_models(model_name)
def load_standard_models(model_name):
"""Fallback standard model loading without quantization."""
fill_mask_tokenizer = AutoTokenizer.from_pretrained(model_name)
fill_mask_model = AutoModelForMaskedLM.from_pretrained(model_name)
    # Determine device for standard loading; move the model before building the pipeline
    # so the pipeline does not have to move it again afterwards.
    device_idx = 0 if torch.cuda.is_available() else -1
    if torch.cuda.is_available():
        fill_mask_model.to("cuda")
    fill_mask_pipeline = pipeline('fill-mask', model=fill_mask_model, tokenizer=fill_mask_tokenizer, device=device_idx)
return fill_mask_tokenizer, fill_mask_model, fill_mask_pipeline
# Load models with optimizations
fill_mask_tokenizer, fill_mask_model, fill_mask_pipeline = load_optimized_models()
# --- Memory Management Utilities ---
def clear_gpu_cache():
"""Clear CUDA cache to free up memory."""
if torch.cuda.is_available():
torch.cuda.empty_cache()
# --- Helper Functions from Notebook (adapted) ---
def get_mol(smiles):
"""Converts SMILES to RDKit Mol object and Kekulizes it."""
mol = Chem.MolFromSmiles(smiles)
if mol is None:
return None
try:
Chem.Kekulize(mol)
    except Exception:
        # Kekulization can fail for some aromatic systems; fall back to the un-kekulized mol.
        pass
return mol
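# Example: get_mol("c1ccccc1") returns a kekulized benzene Mol, while an unparsable string
# such as get_mol("not a smiles") returns None, which callers below use to filter out
# invalid predictions.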
def find_matches_one(mol, submol_smarts):
"""Finds all matching atoms for a SMARTS pattern in a molecule."""
if not mol or not submol_smarts:
return []
submol = Chem.MolFromSmarts(submol_smarts)
if not submol:
return []
matches = mol.GetSubstructMatches(submol)
return matches
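# Example: find_matches_one(get_mol("CCO"), "CO") returns ((1, 2),) -- one tuple of atom
# indices per substructure match, in a form get_image_with_highlight can consume.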
def get_image_with_highlight(mol, atomset=None, size=(300, 300)):
"""Draws molecule with optional atom highlighting."""
if mol is None:
return None
highlight_color = (0, 1, 0, 0.5) # Green with some transparency
# Ensure atomset contains integers if not None or empty
valid_atomset = []
if atomset:
try:
valid_atomset = [int(a) for a in atomset]
except ValueError:
logger.warning(f"Invalid atom in atomset: {atomset}. Proceeding without highlighting problematic atoms.")
valid_atomset = [int(a) for a in atomset if str(a).isdigit()] # Filter out non-integers
img = MolToImage(mol, size=size, fitImage=True,
highlightAtoms=valid_atomset if valid_atomset else [],
highlightAtomColors={i: highlight_color for i in valid_atomset} if valid_atomset else {})
return img
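# Example: get_image_with_highlight(get_mol("CCO"), atomset=[1, 2]) returns a 300x300 PIL
# image with atoms 1 and 2 highlighted using the green highlight_color defined above.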
# --- Optimized Gradio Interface Functions ---
def predict_and_visualize_masked_smiles(smiles_mask, substructure_smarts_highlight="CC=CC"):
"""
Predicts masked tokens in a SMILES string, shows scores, and visualizes molecules.
Optimized with memory management. Returns 7 items for Gradio outputs.
"""
if fill_mask_tokenizer.mask_token not in smiles_mask:
# Return 7 items for the 7 output components
return pd.DataFrame(), None, None, None, None, None, "Error: Input SMILES must contain a mask token (e.g., <mask>)."
try:
# Use torch.no_grad() for inference to save memory
with torch.no_grad():
predictions = fill_mask_pipeline(smiles_mask, top_k=10) # Get more to filter for valid ones
except Exception as e:
clear_gpu_cache() # Clear cache on error
# Return 7 items
return pd.DataFrame(), None, None, None, None, None, f"Error during prediction: {str(e)}"
results_data = []
image_list = []
valid_predictions_count = 0
for pred in predictions:
if valid_predictions_count >= 5:
break
predicted_smiles = pred['sequence']
score = pred['score']
mol = get_mol(predicted_smiles)
if mol:
results_data.append({"Predicted SMILES": predicted_smiles, "Score": f"{score:.4f}"})
atom_matches_indices = []
if substructure_smarts_highlight:
matches = find_matches_one(mol, substructure_smarts_highlight)
if matches:
atom_matches_indices = list(matches[0]) # Highlight first match
img = get_image_with_highlight(mol, atomset=atom_matches_indices)
image_list.append(img)
valid_predictions_count += 1
# Pad image_list if fewer than 5 valid predictions
while len(image_list) < 5:
image_list.append(None)
df_results = pd.DataFrame(results_data)
# Clear cache after inference
clear_gpu_cache()
status_message = "Prediction successful." if valid_predictions_count > 0 else "No valid molecules found for top predictions."
# Unpack image_list into individual image outputs + df_results + status_message
return df_results, image_list[0], image_list[1], image_list[2], image_list[3], image_list[4], status_message
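# Example call (the mask token for this tokenizer is "<mask>"):
#   predict_and_visualize_masked_smiles("C1=CC=CC<mask>C1", "C=C")
# returns a DataFrame of predictions, five images (PIL or None), and a status string --
# the same seven values wired to the Gradio outputs below.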
def display_molecule_image(smiles_string):
"""
Displays a 2D image of a molecule from its SMILES string.
"""
if not smiles_string:
return None, "Please enter a SMILES string."
mol = get_mol(smiles_string)
if mol is None:
return None, "Invalid SMILES string."
img = MolToImage(mol, size=(400, 400), fitImage=True)
return img, "Molecule displayed."
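# Example: display_molecule_image("c1ccccc1") returns a 400x400 PIL image of benzene plus
# "Molecule displayed."; an unparsable string returns (None, "Invalid SMILES string.").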
# --- Gradio Interface Definition ---
with gr.Blocks(theme=gr.themes.Default()) as demo:
gr.Markdown("# ChemBERTa SMILES Utilities Dashboard")
with gr.Tab("Masked SMILES Prediction"):
gr.Markdown("Enter a SMILES string with a `<mask>` token (e.g., `C1=CC=CC<mask>C1`) to predict possible completions.")
with gr.Row():
smiles_input_masked = gr.Textbox(label="SMILES String with Mask", value="C1=CC=CC<mask>C1")
substructure_input = gr.Textbox(label="Substructure to Highlight (SMARTS)", value="C=C")
predict_button_masked = gr.Button("Predict and Visualize")
status_masked = gr.Textbox(label="Status", interactive=False)
predictions_table = gr.DataFrame(label="Top Predictions & Scores")
gr.Markdown("### Predicted Molecule Visualizations (Top 5 Valid)")
with gr.Row():
img_out_1 = gr.Image(label="Prediction 1", type="pil", interactive=False)
img_out_2 = gr.Image(label="Prediction 2", type="pil", interactive=False)
img_out_3 = gr.Image(label="Prediction 3", type="pil", interactive=False)
img_out_4 = gr.Image(label="Prediction 4", type="pil", interactive=False)
img_out_5 = gr.Image(label="Prediction 5", type="pil", interactive=False)
# Automatically populate on load for the default example
demo.load(
lambda: predict_and_visualize_masked_smiles("C1=CC=CC<mask>C1", "C=C"),
inputs=None,
outputs=[predictions_table, img_out_1, img_out_2, img_out_3, img_out_4, img_out_5, status_masked]
)
predict_button_masked.click(
predict_and_visualize_masked_smiles,
inputs=[smiles_input_masked, substructure_input],
outputs=[predictions_table, img_out_1, img_out_2, img_out_3, img_out_4, img_out_5, status_masked]
)
with gr.Tab("Molecule Viewer"):
gr.Markdown("Enter a SMILES string to display its 2D structure.")
smiles_input_viewer = gr.Textbox(label="SMILES String", value="C1=CC=CC=C1")
view_button_molecule = gr.Button("View Molecule")
status_viewer = gr.Textbox(label="Status", interactive=False)
molecule_image_output = gr.Image(label="Molecule Structure", type="pil", interactive=False)
# Automatically populate on load for the default example
demo.load(
lambda: display_molecule_image("C1=CC=CC=C1"),
inputs=None,
outputs=[molecule_image_output, status_viewer]
)
view_button_molecule.click(
display_molecule_image,
inputs=[smiles_input_viewer],
outputs=[molecule_image_output, status_viewer]
)
if __name__ == "__main__":
demo.launch()