Spaces:
Running
Running
# app.py | |
import gradio as gr | |
import torch | |
from transformers import AutoModelForMaskedLM, AutoTokenizer, pipeline, BitsAndBytesConfig | |
from rdkit import Chem | |
from rdkit.Chem import Draw, rdFMCS | |
from rdkit.Chem.Draw import MolToImage | |
# PIL is imported as Image by rdkit.Chem.Draw.MolToImage, but explicit import is good practice if used directly. | |
# from PIL import Image | |
import pandas as pd | |
import io | |
import base64 | |
import logging | |
# Model names | |
model_name = "seyonec/PubChem10M_SMILES_BPE_450k" | |
# Load tokenizer (doesn't need quantization) | |
fill_mask_tokenizer = AutoTokenizer.from_pretrained(model_name) | |
# Load model with quantization if available | |
model_kwargs = { | |
"torch_dtype": torch_dtype, | |
} | |
**model_kwargs | |
) | |
# Set model to evaluation mode for inference | |
fill_mask_model.eval() | |
# Create optimized pipeline | |
# Let pipeline infer device from model if possible, or set based on model's device | |
pipeline_device = fill_mask_model.device.index if hasattr(fill_mask_model.device, 'type') and fill_mask_model.device.type == "cuda" else -1 | |
fill_mask_pipeline = pipeline( | |
'fill-mask', | |
model=fill_mask_model, | |
) | |
logger.info("Models loaded successfully with optimizations") | |
return fill_mask_tokenizer, fill_mask_model, fill_mask_pipeline | |
except Exception as e: | |
logger.error(f"Error loading optimized models: {e}") | |
device_idx = 0 if torch.cuda.is_available() else -1 | |
fill_mask_pipeline = pipeline('fill-mask', model=fill_mask_model, tokenizer=fill_mask_tokenizer, device=device_idx) | |
if torch.cuda.is_available(): | |
fill_mask_model.to("cuda") | |
return fill_mask_tokenizer, fill_mask_model, fill_mask_pipeline | |
# Load models with optimizations | |
fill_mask_tokenizer, fill_mask_model, fill_mask_pipeline = load_optimized_models() | |
# --- Memory Management Utilities --- | |
def clear_gpu_cache(): | |
# Unpack image_list into individual image outputs + df_results + status_message | |
return df_results, image_list[0], image_list[1], image_list[2], image_list[3], image_list[4], status_message | |
def display_molecule_image(smiles_string): | |
""" | |
Displays a 2D image of a molecule from its SMILES string. | |
outputs=[predictions_table, img_out_1, img_out_2, img_out_3, img_out_4, img_out_5, status_masked] | |
) | |
with gr.Tab("Molecule Viewer"): | |
gr.Markdown("Enter a SMILES string to display its 2D structure.") | |
smiles_input_viewer = gr.Textbox(label="SMILES String", value="C1=CC=CC=C1") |