# Source: Hugging Face Spaces page capture (Space status at capture time: Sleeping).
# app.py
import gradio as gr | |
import torch | |
from transformers import AutoModelForMaskedLM, AutoTokenizer, pipeline, BitsAndBytesConfig | |
from rdkit import Chem | |
from rdkit.Chem import Draw, rdFMCS | |
from rdkit.Chem.Draw import MolToImage | |
# PIL is imported as Image by rdkit.Chem.Draw.MolToImage, but explicit import is good practice if used directly. | |
# from PIL import Image | |
import pandas as pd | |
import io | |
import base64 | |
import logging | |
# Set up logging to monitor quantization effects.
# basicConfig configures the root logger at INFO level as an import-time side
# effect; `logger` is the module-level logger used by the functions below.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# --- Quantization Configuration ---
def get_quantization_config():
    """Build the 8-bit quantization config, or return None if it cannot be built.

    Only ``load_in_8bit`` is requested. The ``bnb_8bit_compute_dtype`` and
    ``bnb_8bit_use_double_quant`` keywords previously passed here are not
    documented ``BitsAndBytesConfig`` parameters (the compute-dtype and
    double-quant options exist only with the ``bnb_4bit_`` prefix), so they
    had no effect on 8-bit loading and could make construction fail on
    versions that reject unknown keywords.
    NOTE(review): confirm against the installed transformers version.

    Returns:
        A ``BitsAndBytesConfig`` with ``load_in_8bit=True``, or ``None`` when
        constructing it raises (a warning is logged in that case).
    """
    try:
        quantization_config = BitsAndBytesConfig(load_in_8bit=True)
    except ImportError:
        logger.warning("bitsandbytes not available, falling back to standard loading")
        return None
    except Exception as e:
        # Deliberate best-effort: any setup failure means "load without quantization".
        logger.warning("Quantization setup failed: %s, using standard loading", e)
        return None
    logger.info("8-bit quantization configuration loaded successfully")
    return quantization_config
def get_torch_dtype():
    """Return the torch dtype to load weights in: float16 with CUDA, else float32."""
    # Half precision on GPU, full precision on CPU.
    return torch.float16 if torch.cuda.is_available() else torch.float32
# --- Optimized Model Loading ---
def load_optimized_models():
    """Load the tokenizer, masked-LM model and fill-mask pipeline with optimizations.

    On a CUDA machine the model is loaded with ``device_map="auto"`` and, when
    ``get_quantization_config()`` returns a config, with that quantization
    config. On CPU the model is loaded without a device map. If anything in
    the optimized path raises, the error is logged and
    ``load_standard_models`` is used instead.

    Returns:
        tuple: ``(fill_mask_tokenizer, fill_mask_model, fill_mask_pipeline)``.
    """
    cuda_available = torch.cuda.is_available()
    device = "cuda" if cuda_available else "cpu"
    torch_dtype = get_torch_dtype()
    quantization_config = get_quantization_config()
    logger.info(f"Loading models on device: {device} with dtype: {torch_dtype}")

    model_name = "seyonec/PubChem10M_SMILES_BPE_450k"

    # The tokenizer does not need quantization.
    fill_mask_tokenizer = AutoTokenizer.from_pretrained(model_name)

    model_kwargs = {
        "torch_dtype": torch_dtype,
        # "auto" placement on GPU (quantized or not); no device map on CPU.
        "device_map": "auto" if cuda_available else None,
    }
    if quantization_config is not None and cuda_available:
        # Quantization is only requested when a GPU is present.
        model_kwargs["quantization_config"] = quantization_config

    try:
        fill_mask_model = AutoModelForMaskedLM.from_pretrained(model_name, **model_kwargs)
        fill_mask_model.eval()  # inference only

        pipeline_kwargs = {}
        if getattr(fill_mask_model, "hf_device_map", None) is None:
            # Model was not dispatched by a device map: tell the pipeline
            # where it lives (-1 means CPU).
            model_device = fill_mask_model.device
            pipeline_kwargs["device"] = (
                model_device.index if model_device.type == "cuda" else -1
            )
        # else: the model carries an `hf_device_map`, so no explicit `device`
        # is passed and the pipeline follows the model's placement.
        # NOTE(review): transformers releases have raised/warned when `device`
        # is given for such models, which previously forced the fallback path
        # on GPU - confirm against the installed version.

        fill_mask_pipeline = pipeline(
            "fill-mask",
            model=fill_mask_model,
            tokenizer=fill_mask_tokenizer,
            **pipeline_kwargs,
        )
    except Exception as e:
        logger.error("Error loading optimized models: %s", e)
        logger.info("Falling back to standard model loading...")
        return load_standard_models(model_name)

    logger.info("Models loaded successfully with optimizations")
    return fill_mask_tokenizer, fill_mask_model, fill_mask_pipeline
def load_standard_models(model_name):
    """Load tokenizer, model and fill-mask pipeline with default settings (no quantization).

    Args:
        model_name: Model identifier passed to ``from_pretrained``.

    Returns:
        tuple: ``(tokenizer, model, fill-mask pipeline)``.
    """
    use_cuda = torch.cuda.is_available()

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForMaskedLM.from_pretrained(model_name)

    # Pipeline device index: first GPU when CUDA is available, otherwise CPU (-1).
    fill_mask = pipeline(
        'fill-mask',
        model=model,
        tokenizer=tokenizer,
        device=0 if use_cuda else -1,
    )
    if use_cuda:
        model.to("cuda")
    return tokenizer, model, fill_mask
# Load the tokenizer, model and pipeline once at import time; the Gradio
# callbacks below use these module-level objects.
(
    fill_mask_tokenizer,
    fill_mask_model,
    fill_mask_pipeline,
) = load_optimized_models()
# --- Memory Management Utilities ---
def clear_gpu_cache():
    """Empty the CUDA cache; does nothing when CUDA is unavailable."""
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
# --- Helper Functions from Notebook (adapted) ---
def get_mol(smiles):
    """Convert a SMILES string to an RDKit Mol and try to kekulize it.

    Args:
        smiles: SMILES string to parse.

    Returns:
        The parsed Mol, or ``None`` when ``Chem.MolFromSmiles`` returns None.
        A kekulization failure is not an error: the molecule is returned
        as parsed.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None
    try:
        Chem.Kekulize(mol)
    except Exception as exc:
        # Kekulization can fail for some structures; keep going with the
        # molecule as parsed, but leave a trace instead of swallowing silently.
        # `Exception` (not a bare except) lets KeyboardInterrupt/SystemExit through.
        logger.debug("Kekulization failed for %r: %s", smiles, exc)
    return mol
def find_matches_one(mol, submol_smarts):
    """Return all substructure matches of a SMARTS pattern in ``mol``.

    Args:
        mol: Molecule to search; a falsy value yields no matches.
        submol_smarts: SMARTS pattern string; a falsy value yields no matches.

    Returns:
        The result of ``mol.GetSubstructMatches`` for the parsed pattern, or
        an empty list when either argument is falsy or the pattern does not
        parse.
    """
    if mol and submol_smarts:
        pattern = Chem.MolFromSmarts(submol_smarts)
        if pattern:
            return mol.GetSubstructMatches(pattern)
    return []
def get_image_with_highlight(mol, atomset=None, size=(300, 300)):
    """Draw a molecule, optionally highlighting a set of atoms in green.

    Args:
        mol: Molecule to draw; ``None`` yields ``None``.
        atomset: Optional iterable of atom indices to highlight. Entries that
            ``int()`` cannot convert (including ``None``) are skipped and a
            single warning is logged.
        size: ``(width, height)`` of the image.

    Returns:
        The image produced by ``MolToImage``, or ``None`` when ``mol`` is None.
    """
    if mol is None:
        return None

    highlight_color = (0, 1, 0, 0.5)  # Green with some transparency

    # Keep only entries that convert cleanly to int; TypeError covers None and
    # other non-numeric objects that the ValueError-only handler let through.
    valid_atomset = []
    rejected = False
    for atom in atomset or ():
        try:
            valid_atomset.append(int(atom))
        except (TypeError, ValueError):
            rejected = True
    if rejected:
        logger.warning(
            f"Invalid atom in atomset: {atomset}. Proceeding without highlighting problematic atoms."
        )

    # NOTE(review): current RDKit MolToImage reads `highlightColor` for the
    # highlight colour and does not appear to consume `highlightAtomColors`;
    # both are passed so the green highlight applies on either drawing path.
    # Confirm against the installed RDKit version.
    return MolToImage(
        mol,
        size=size,
        fitImage=True,
        highlightAtoms=valid_atomset,
        highlightColor=highlight_color,
        highlightAtomColors={idx: highlight_color for idx in valid_atomset},
    )
# --- Optimized Gradio Interface Functions ---
def _prediction_failure(message):
    """Build the 7-item output tuple (empty table, five empty images, status) for a failure."""
    return pd.DataFrame(), None, None, None, None, None, message


def predict_and_visualize_masked_smiles(smiles_mask, substructure_smarts_highlight="CC=CC"):
    """Predict the masked token in a SMILES string and draw the top valid results.

    The fill-mask pipeline is asked for 10 candidates; the first 5 that parse
    as molecules are kept. Each kept molecule is drawn, with the first match
    of ``substructure_smarts_highlight`` highlighted when there is one.

    Args:
        smiles_mask: SMILES string containing exactly one mask token.
        substructure_smarts_highlight: SMARTS pattern to highlight; a falsy
            value disables highlighting.

    Returns:
        A 7-tuple for the Gradio outputs: a DataFrame with columns
        "Predicted SMILES" and "Score", five image slots (``None`` where
        there are fewer than five valid predictions), and a status message.
        On invalid input or a prediction error the DataFrame is empty, all
        image slots are ``None`` and the status message describes the error.
    """
    max_images = 5
    mask_token = fill_mask_tokenizer.mask_token

    # Empty/None input is treated the same as input without a mask token.
    if not smiles_mask or mask_token not in smiles_mask:
        return _prediction_failure("Error: Input SMILES must contain a mask token (e.g., <mask>).")
    # With several masks the pipeline returns one candidate list per mask
    # (a list of lists), which the single-mask processing below cannot consume.
    if smiles_mask.count(mask_token) > 1:
        return _prediction_failure("Error: Input SMILES must contain exactly one mask token.")

    try:
        # Inference only: no gradients needed.
        with torch.no_grad():
            predictions = fill_mask_pipeline(smiles_mask, top_k=10)  # Get more to filter for valid ones
    except Exception as e:
        return _prediction_failure(f"Error during prediction: {str(e)}")
    finally:
        # Clear the cache after inference on both the success and error paths.
        clear_gpu_cache()

    results_data = []
    image_list = []
    for pred in predictions:
        if len(image_list) >= max_images:
            break
        predicted_smiles = pred['sequence']
        mol = get_mol(predicted_smiles)
        if not mol:
            continue  # candidate is not a valid molecule

        results_data.append({"Predicted SMILES": predicted_smiles, "Score": f"{pred['score']:.4f}"})

        atom_matches_indices = []
        if substructure_smarts_highlight:
            matches = find_matches_one(mol, substructure_smarts_highlight)
            if matches:
                atom_matches_indices = list(matches[0])  # Highlight first match
        image_list.append(get_image_with_highlight(mol, atomset=atom_matches_indices))

    valid_predictions_count = len(results_data)
    # Pad so there is always one value per image output component.
    image_list.extend([None] * (max_images - len(image_list)))

    status_message = (
        "Prediction successful."
        if valid_predictions_count > 0
        else "No valid molecules found for top predictions."
    )
    return (pd.DataFrame(results_data), *image_list, status_message)
def display_molecule_image(smiles_string):
    """Render a 400x400 2D image for a SMILES string.

    Args:
        smiles_string: SMILES string to draw.

    Returns:
        ``(image, status)``: the drawn image and "Molecule displayed.", or
        ``None`` with an explanatory message when the input is empty or
        cannot be parsed.
    """
    image, status = None, "Please enter a SMILES string."
    if smiles_string:
        molecule = get_mol(smiles_string)
        if molecule is None:
            status = "Invalid SMILES string."
        else:
            image = MolToImage(molecule, size=(400, 400), fitImage=True)
            status = "Molecule displayed."
    return image, status
# --- Gradio Interface Definition ---
# NOTE(review): the original indentation was lost, so the nesting of rows and
# tabs below is reconstructed; component creation order is unchanged.
with gr.Blocks(theme=gr.themes.Default()) as demo:
    gr.Markdown("# ChemBERTa SMILES Utilities Dashboard")

    # Tab 1: fill in a masked SMILES token and draw the valid candidates.
    with gr.Tab("Masked SMILES Prediction"):
        gr.Markdown("Enter a SMILES string with a `<mask>` token (e.g., `C1=CC=CC<mask>C1`) to predict possible completions.")
        with gr.Row():
            smiles_input_masked = gr.Textbox(
                label="SMILES String with Mask",
                value="C1=CC=CC<mask>C1",
            )
            substructure_input = gr.Textbox(
                label="Substructure to Highlight (SMARTS)",
                value="C=C",
            )
        predict_button_masked = gr.Button("Predict and Visualize")
        status_masked = gr.Textbox(label="Status", interactive=False)
        predictions_table = gr.DataFrame(label="Top Predictions & Scores")

        gr.Markdown("### Predicted Molecule Visualizations (Top 5 Valid)")
        with gr.Row():
            img_out_1 = gr.Image(label="Prediction 1", type="pil", interactive=False)
            img_out_2 = gr.Image(label="Prediction 2", type="pil", interactive=False)
            img_out_3 = gr.Image(label="Prediction 3", type="pil", interactive=False)
            img_out_4 = gr.Image(label="Prediction 4", type="pil", interactive=False)
            img_out_5 = gr.Image(label="Prediction 5", type="pil", interactive=False)

        # Output order matches the 7-tuple returned by
        # predict_and_visualize_masked_smiles: table, five images, status.
        masked_outputs = [
            predictions_table,
            img_out_1,
            img_out_2,
            img_out_3,
            img_out_4,
            img_out_5,
            status_masked,
        ]

        # Populate the tab on page load with the default example.
        demo.load(
            lambda: predict_and_visualize_masked_smiles("C1=CC=CC<mask>C1", "C=C"),
            inputs=None,
            outputs=masked_outputs,
        )
        predict_button_masked.click(
            predict_and_visualize_masked_smiles,
            inputs=[smiles_input_masked, substructure_input],
            outputs=masked_outputs,
        )

    # Tab 2: draw a single molecule from its SMILES string.
    with gr.Tab("Molecule Viewer"):
        gr.Markdown("Enter a SMILES string to display its 2D structure.")
        smiles_input_viewer = gr.Textbox(label="SMILES String", value="C1=CC=CC=C1")
        view_button_molecule = gr.Button("View Molecule")
        status_viewer = gr.Textbox(label="Status", interactive=False)
        molecule_image_output = gr.Image(label="Molecule Structure", type="pil", interactive=False)

        viewer_outputs = [molecule_image_output, status_viewer]

        # Populate the tab on page load with the default example.
        demo.load(
            lambda: display_molecule_image("C1=CC=CC=C1"),
            inputs=None,
            outputs=viewer_outputs,
        )
        view_button_molecule.click(
            display_molecule_image,
            inputs=[smiles_input_viewer],
            outputs=viewer_outputs,
        )

if __name__ == "__main__":
    demo.launch()