mol-lang-lab / app.py
alidenewade's picture
Update app.py
ba49293 verified
raw
history blame
2.6 kB
# app.py
import gradio as gr
import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer, pipeline, BitsAndBytesConfig
from rdkit import Chem
from rdkit.Chem import Draw, rdFMCS
from rdkit.Chem.Draw import MolToImage
# rdkit.Chem.Draw.MolToImage returns a PIL image, so PIL need not be imported here; import it explicitly (below) only if this module uses PIL directly.
# from PIL import Image
import pandas as pd
import io
import base64
import logging
# Model identifier handed to from_pretrained (Hub repo id or local path)
model_name = "seyonec/PubChem10M_SMILES_BPE_450k"
# The tokenizer needs no quantization, so it is loaded straight from the checkpoint
fill_mask_tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path=model_name,
)
# Load model with quantization if available
# NOTE(review): this copy of the loader looks truncated. The lines that open the
# enclosing function (presumably `def load_optimized_models():`, given the call
# further down), its `try:` (there is an `except` below), the definition of
# `torch_dtype`, and the opening of the model-loading call are not present here.
# NOTE(review): BitsAndBytesConfig is imported but no quantization config is put
# into model_kwargs in this view; only the dtype is passed. Confirm whether
# quantization is actually applied.
model_kwargs = {
"torch_dtype": torch_dtype,
}
# NOTE(review): the two lines below are the tail of a call whose opening line is
# missing. Presumably it is `fill_mask_model = AutoModelForMaskedLM.from_pretrained(model_name,`,
# since fill_mask_model is used next and AutoModelForMaskedLM is imported but
# otherwise unused here.
**model_kwargs
)
# Set model to evaluation mode for inference
fill_mask_model.eval()
# Create optimized pipeline
# Let pipeline infer device from model if possible, or set based on model's device
# NOTE(review): pipeline_device is computed here but never passed to pipeline()
# below, so it is currently unused.
pipeline_device = fill_mask_model.device.index if hasattr(fill_mask_model.device, 'type') and fill_mask_model.device.type == "cuda" else -1
# NOTE(review): no tokenizer= or device= is passed here (the fallback below passes
# both). Confirm that pipeline() can infer the tokenizer from the model object.
fill_mask_pipeline = pipeline(
'fill-mask',
model=fill_mask_model,
)
# NOTE(review): `logger` is not defined in this view (only `import logging` is).
# Presumably a module-level logging.getLogger(...) exists elsewhere.
logger.info("Models loaded successfully with optimizations")
# Returns the (tokenizer, model, pipeline) triple.
return fill_mask_tokenizer, fill_mask_model, fill_mask_pipeline
except Exception as e:
# Fallback path: log the error, then build a fill-mask pipeline from the already
# loaded model and tokenizer with an explicit device index.
# NOTE(review): if the exception came from loading the model itself,
# fill_mask_model may be unbound here, and this branch would then fail with
# NameError/UnboundLocalError instead of recovering. Consider reloading the model
# in this branch. logger.exception(...) would also keep the traceback that
# logger.error drops.
logger.error(f"Error loading optimized models: {e}")
# 0 = first CUDA device, -1 = CPU (transformers pipeline device convention)
device_idx = 0 if torch.cuda.is_available() else -1
fill_mask_pipeline = pipeline('fill-mask', model=fill_mask_model, tokenizer=fill_mask_tokenizer, device=device_idx)
# NOTE(review): pipeline(device=0) above likely already placed the model on the
# GPU, so this .to("cuda") is probably redundant. Confirm.
if torch.cuda.is_available():
fill_mask_model.to("cuda")
return fill_mask_tokenizer, fill_mask_model, fill_mask_pipeline
# Build the tokenizer, masked-LM model and fill-mask pipeline in a single call
# (the loader falls back to a plain pipeline if the optimized path fails).
(
    fill_mask_tokenizer,
    fill_mask_model,
    fill_mask_pipeline,
) = load_optimized_models()
# --- Memory Management Utilities ---
def clear_gpu_cache():
# NOTE(review): the body of clear_gpu_cache appears to be missing from this copy.
# The return below references df_results, image_list and status_message, none of
# which are defined in this view. Its 7-value tuple matches the 7-entry outputs
# list of the masked-prediction UI wiring (predictions_table, img_out_1..5,
# status_masked) rather than a cache-clearing helper. Confirm where it belongs.
# Return df_results, then the first five entries of image_list as separate image
# outputs, then status_message (7 values in total). This raises IndexError if
# image_list has fewer than five entries.
return df_results, image_list[0], image_list[1], image_list[2], image_list[3], image_list[4], status_message
def display_molecule_image(smiles_string):
"""
Displays a 2D image of a molecule from its SMILES string.
outputs=[predictions_table, img_out_1, img_out_2, img_out_3, img_out_4, img_out_5, status_masked]
)
with gr.Tab("Molecule Viewer"):
gr.Markdown("Enter a SMILES string to display its 2D structure.")
smiles_input_viewer = gr.Textbox(label="SMILES String", value="C1=CC=CC=C1")