# app.py
import gradio as gr
import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer, pipeline, RobertaModel, RobertaTokenizer, BitsAndBytesConfig
from rdkit import Chem
from rdkit.Chem import Draw, rdFMCS
from rdkit.Chem.Draw import MolToImage
# MolToImage returns a PIL Image, so an explicit `from PIL import Image` is only
# needed if PIL is used directly.
# from PIL import Image
import pandas as pd
from bertviz import head_view  # renders attention for stock Hugging Face models
# Note: bertviz's neuron view requires the special model classes bundled in
# bertviz.transformers_neuron_view, so the head view is used in this app instead.
# IPython.core.display.HTML is generally for notebooks. Gradio's gr.HTML handles HTML strings directly.
# from IPython.core.display import HTML
import io
import base64
import logging
# Set up logging to monitor quantization effects
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# --- Quantization Configuration ---
def get_quantization_config():
"""
Configure 8-bit quantization for model optimization.
Falls back gracefully if bitsandbytes is not available.
"""
try:
        # 8-bit quantization - a good balance of speed and quality.
        # Note: the compute-dtype and double-quant knobs exist only for 4-bit
        # loading (bnb_4bit_*); load_in_8bit is the complete 8-bit configuration.
        quantization_config = BitsAndBytesConfig(
            load_in_8bit=True,
        )
logger.info("8-bit quantization configuration loaded successfully")
return quantization_config
    except Exception as e:
        # Covers a missing bitsandbytes install as well as bad config values.
        logger.warning(f"Quantization setup failed: {e}, using standard loading")
        return None
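# If memory were tighter still, 4-bit NF4 loading is an alternative. A minimal
# sketch, assuming recent transformers/bitsandbytes versions (not enabled here):
#   BitsAndBytesConfig(
#       load_in_4bit=True,
#       bnb_4bit_quant_type="nf4",
#       bnb_4bit_compute_dtype=torch.float16,
#       bnb_4bit_use_double_quant=True,  # nested quantization for extra savings
#   )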
def get_torch_dtype():
"""Get appropriate torch dtype based on available hardware."""
if torch.cuda.is_available():
return torch.float16 # Use half precision on GPU
else:
return torch.float32 # Keep full precision on CPU
# --- Optimized Model Loading ---
def load_optimized_models():
"""Load models with quantization and other optimizations."""
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = get_torch_dtype()
quantization_config = get_quantization_config()
logger.info(f"Loading models on device: {device} with dtype: {torch_dtype}")
# Model names
model_name = "seyonec/PubChem10M_SMILES_BPE_450k"
# Load tokenizers (these don't need quantization)
fill_mask_tokenizer = AutoTokenizer.from_pretrained(model_name)
attention_tokenizer = RobertaTokenizer.from_pretrained(model_name)
# Load models with quantization if available
model_kwargs = {
"torch_dtype": torch_dtype,
}
if quantization_config is not None and torch.cuda.is_available(): # Quantization typically for GPU
model_kwargs["quantization_config"] = quantization_config
# device_map="auto" is often used with bitsandbytes for automatic distribution
model_kwargs["device_map"] = "auto"
elif torch.cuda.is_available():
model_kwargs["device_map"] = "auto" # For non-quantized GPU loading
else:
model_kwargs["device_map"] = None # For CPU
try:
# Masked LM Model
fill_mask_model = AutoModelForMaskedLM.from_pretrained(
model_name,
**model_kwargs
)
# RoBERTa model for attention
attention_model_kwargs = model_kwargs.copy()
attention_model_kwargs["output_attentions"] = True
attention_model = RobertaModel.from_pretrained(
model_name,
**attention_model_kwargs
)
# Set models to evaluation mode for inference
fill_mask_model.eval()
attention_model.eval()
# Create optimized pipeline
        # When the model was loaded via accelerate (hf_device_map is set, as with
        # device_map="auto" or quantized loading), the pipeline must not be given
        # a device argument; otherwise pin it to the model's device.
        pipeline_kwargs = {}
        if getattr(fill_mask_model, "hf_device_map", None) is None:
            pipeline_kwargs["device"] = 0 if fill_mask_model.device.type == "cuda" else -1
        fill_mask_pipeline = pipeline(
            'fill-mask',
            model=fill_mask_model,
            tokenizer=fill_mask_tokenizer,
            **pipeline_kwargs,
        )
logger.info("Models loaded successfully with optimizations")
return fill_mask_tokenizer, fill_mask_model, fill_mask_pipeline, attention_model, attention_tokenizer
except Exception as e:
logger.error(f"Error loading optimized models: {e}")
# Fallback to standard loading
logger.info("Falling back to standard model loading...")
return load_standard_models(model_name)
def load_standard_models(model_name):
"""Fallback standard model loading without quantization."""
fill_mask_tokenizer = AutoTokenizer.from_pretrained(model_name)
fill_mask_model = AutoModelForMaskedLM.from_pretrained(model_name)
    attention_model = RobertaModel.from_pretrained(model_name, output_attentions=True)
    attention_tokenizer = RobertaTokenizer.from_pretrained(model_name)
    # Move models to the GPU (if present) before wrapping the LM in a pipeline.
    if torch.cuda.is_available():
        fill_mask_model.to("cuda")
        attention_model.to("cuda")
    device_idx = 0 if torch.cuda.is_available() else -1
    fill_mask_pipeline = pipeline('fill-mask', model=fill_mask_model, tokenizer=fill_mask_tokenizer, device=device_idx)
return fill_mask_tokenizer, fill_mask_model, fill_mask_pipeline, attention_model, attention_tokenizer
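# Illustrative sanity check for either loading path: the fill-mask pipeline
# returns a list of dicts with 'sequence', 'score', and 'token_str' keys, e.g.
#   fill_mask_pipeline("C1=CC=CC<mask>C1")[0]["sequence"]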
# Load models with optimizations
fill_mask_tokenizer, fill_mask_model, fill_mask_pipeline, attention_model, attention_tokenizer = load_optimized_models()
# --- Memory Management Utilities ---
def clear_gpu_cache():
"""Clear CUDA cache to free up memory."""
if torch.cuda.is_available():
torch.cuda.empty_cache()
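        # empty_cache() returns cached blocks to the driver but cannot free
        # tensors Python still references; compare torch.cuda.memory_allocated()
        # with torch.cuda.memory_reserved() when debugging memory use.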
# --- Helper Functions from Notebook (adapted) ---
def get_mol(smiles):
"""Converts SMILES to RDKit Mol object and Kekulizes it."""
mol = Chem.MolFromSmiles(smiles)
if mol is None:
return None
try:
Chem.Kekulize(mol)
    except Exception:  # Kekulization can fail for some structures; keep the mol as-is
        pass
return mol
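# Example (illustrative): get_mol("c1ccccc1") yields a kekulized benzene Mol,
# while get_mol("not_a_smiles") yields None.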
def find_matches_one(mol, submol_smarts):
"""Finds all matching atoms for a SMARTS pattern in a molecule."""
if not mol or not submol_smarts:
return []
submol = Chem.MolFromSmarts(submol_smarts)
if not submol:
return []
matches = mol.GetSubstructMatches(submol)
return matches
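# Example (illustrative): find_matches_one(get_mol("CCO"), "CO") returns
# ((1, 2),), the matching carbon-oxygen atom indices; an invalid SMARTS gives [].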
def get_image_with_highlight(mol, atomset=None, size=(300, 300)):
"""Draws molecule with optional atom highlighting."""
if mol is None:
return None
highlight_color = (0, 1, 0, 0.5) # Green with some transparency
# Ensure atomset contains integers if not None or empty
valid_atomset = []
if atomset:
try:
valid_atomset = [int(a) for a in atomset]
except ValueError:
logger.warning(f"Invalid atom in atomset: {atomset}. Proceeding without highlighting problematic atoms.")
valid_atomset = [int(a) for a in atomset if str(a).isdigit()] # Filter out non-integers
img = MolToImage(mol, size=size, fitImage=True,
highlightAtoms=valid_atomset if valid_atomset else [],
highlightAtomColors={i: highlight_color for i in valid_atomset} if valid_atomset else {})
return img
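# Note: highlightAtoms/highlightAtomColors are forwarded by MolToImage to the
# underlying RDKit drawer; colors are RGBA tuples with components in [0, 1].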
# --- Optimized Gradio Interface Functions ---
def predict_and_visualize_masked_smiles(smiles_mask, substructure_smarts_highlight="CC=CC"):
"""
Predicts masked tokens in a SMILES string, shows scores, and visualizes molecules.
Optimized with memory management. Returns 7 items for Gradio outputs.
"""
if fill_mask_tokenizer.mask_token not in smiles_mask:
# Return 7 items for the 7 output components
return pd.DataFrame(), None, None, None, None, None, "Error: Input SMILES must contain a mask token (e.g., <mask>)."
try:
# Use torch.no_grad() for inference to save memory
with torch.no_grad():
predictions = fill_mask_pipeline(smiles_mask, top_k=10) # Get more to filter for valid ones
except Exception as e:
clear_gpu_cache() # Clear cache on error
# Return 7 items
return pd.DataFrame(), None, None, None, None, None, f"Error during prediction: {str(e)}"
results_data = []
image_list = []
valid_predictions_count = 0
for pred in predictions:
if valid_predictions_count >= 5:
break
predicted_smiles = pred['sequence']
score = pred['score']
mol = get_mol(predicted_smiles)
if mol:
results_data.append({"Predicted SMILES": predicted_smiles, "Score": f"{score:.4f}"})
atom_matches_indices = []
if substructure_smarts_highlight:
matches = find_matches_one(mol, substructure_smarts_highlight)
if matches:
atom_matches_indices = list(matches[0]) # Highlight first match
img = get_image_with_highlight(mol, atomset=atom_matches_indices)
image_list.append(img)
valid_predictions_count += 1
# Pad image_list if fewer than 5 valid predictions
while len(image_list) < 5:
image_list.append(None)
df_results = pd.DataFrame(results_data)
# Clear cache after inference
clear_gpu_cache()
status_message = "Prediction successful." if valid_predictions_count > 0 else "No valid molecules found for top predictions."
# Unpack image_list into individual image outputs + df_results + status_message
return df_results, image_list[0], image_list[1], image_list[2], image_list[3], image_list[4], status_message
def visualize_attention_bertviz(sentence_a, sentence_b):
"""
Generates and displays BertViz neuron-by-neuron attention view as HTML.
Optimized with memory management and mixed precision.
"""
if not sentence_a or not sentence_b:
return "<p style='color:red;'>Please provide two SMILES strings.</p>"
try:
inputs = attention_tokenizer.encode_plus(sentence_a, sentence_b, return_tensors='pt', add_special_tokens=True)
input_ids = inputs['input_ids']
        # Move inputs to the model's device (a no-op on CPU)
        input_ids = input_ids.to(attention_model.device)
# Ensure model is in eval mode and use no_grad for inference
attention_model.eval()
with torch.no_grad():
            # Mixed precision on CUDA (get_torch_dtype() is float16 there);
            # plain full-precision forward pass on CPU.
            if torch.cuda.is_available():
                with torch.cuda.amp.autocast(dtype=torch.float16):
                    attention_outputs = attention_model(input_ids)
            else:
                attention_outputs = attention_model(input_ids)
        attention = attention_outputs.attentions  # per-layer attention tensors
input_id_list = input_ids[0].tolist()
tokens = attention_tokenizer.convert_ids_to_tokens(input_id_list)
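        # `attention` is a tuple with one tensor per layer, each shaped
        # (batch, num_heads, seq_len, seq_len); head_view consumes it directly.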
        # head_view accepts standard attention outputs; html_action='return'
        # hands back an IPython HTML object instead of rendering in a notebook.
        html_object = head_view(attention, tokens, html_action='return')
        html_string = html_object.data  # the raw HTML string
# Add D3 and jQuery CDN links to the HTML string for better rendering in Gradio
html_with_deps = f"""
<script src="https://cdnjs.cloudflare.com/ajax/libs/jquery/3.5.1/jquery.min.js"></script>
<script src="https://cdnjs.cloudflare.com/ajax/libs/d3/5.16.0/d3.min.js"></script>
{html_string}
"""
# Clear cache after attention computation
clear_gpu_cache()
return html_with_deps
except Exception as e:
clear_gpu_cache() # Clear cache on error
logger.error(f"Error in visualize_attention_bertviz: {e}", exc_info=True)
return f"<p style='color:red;'>Error generating attention visualization: {str(e)}</p>"
def display_molecule_image(smiles_string):
"""
Displays a 2D image of a molecule from its SMILES string.
"""
if not smiles_string:
return None, "Please enter a SMILES string."
mol = get_mol(smiles_string)
if mol is None:
return None, "Invalid SMILES string."
img = MolToImage(mol, size=(400, 400), fitImage=True)
return img, "Molecule displayed."
# --- Gradio Interface Definition ---
with gr.Blocks(theme=gr.themes.Default()) as demo:
gr.Markdown("# ChemBERTa SMILES Utilities Dashboard")
with gr.Tab("Masked SMILES Prediction"):
gr.Markdown("Enter a SMILES string with a `<mask>` token (e.g., `C1=CC=CC<mask>C1`) to predict possible completions.")
with gr.Row():
smiles_input_masked = gr.Textbox(label="SMILES String with Mask", value="C1=CC=CC<mask>C1")
substructure_input = gr.Textbox(label="Substructure to Highlight (SMARTS)", value="C=C")
predict_button_masked = gr.Button("Predict and Visualize")
status_masked = gr.Textbox(label="Status", interactive=False)
predictions_table = gr.DataFrame(label="Top Predictions & Scores")
gr.Markdown("### Predicted Molecule Visualizations (Top 5 Valid)")
with gr.Row():
img_out_1 = gr.Image(label="Prediction 1", type="pil", interactive=False)
img_out_2 = gr.Image(label="Prediction 2", type="pil", interactive=False)
img_out_3 = gr.Image(label="Prediction 3", type="pil", interactive=False)
img_out_4 = gr.Image(label="Prediction 4", type="pil", interactive=False)
img_out_5 = gr.Image(label="Prediction 5", type="pil", interactive=False)
# Automatically populate on load for the default example
demo.load(
lambda: predict_and_visualize_masked_smiles("C1=CC=CC<mask>C1", "C=C"),
inputs=None,
outputs=[predictions_table, img_out_1, img_out_2, img_out_3, img_out_4, img_out_5, status_masked]
)
predict_button_masked.click(
predict_and_visualize_masked_smiles,
inputs=[smiles_input_masked, substructure_input],
outputs=[predictions_table, img_out_1, img_out_2, img_out_3, img_out_4, img_out_5, status_masked]
)
with gr.Tab("Attention Visualization"):
gr.Markdown("Enter two SMILES strings to visualize **neuron-by-neuron attention** between them using BertViz. This may take a moment to render.")
with gr.Row():
smiles_a_input_attn = gr.Textbox(label="SMILES String A", value="CCCCC[C@@H](Br)CC")
smiles_b_input_attn = gr.Textbox(label="SMILES String B", value="CCCCC[C@H](Br)CC")
visualize_button_attn = gr.Button("Visualize Attention")
        attention_html_output = gr.HTML(label="Attention Head View")
# Automatically populate on load for the default example
demo.load(
lambda: visualize_attention_bertviz("CCCCC[C@@H](Br)CC", "CCCCC[C@H](Br)CC"),
inputs=None,
outputs=[attention_html_output]
)
visualize_button_attn.click(
visualize_attention_bertviz,
inputs=[smiles_a_input_attn, smiles_b_input_attn],
outputs=[attention_html_output]
)
with gr.Tab("Molecule Viewer"):
gr.Markdown("Enter a SMILES string to display its 2D structure.")
smiles_input_viewer = gr.Textbox(label="SMILES String", value="C1=CC=CC=C1")
view_button_molecule = gr.Button("View Molecule")
status_viewer = gr.Textbox(label="Status", interactive=False)
molecule_image_output = gr.Image(label="Molecule Structure", type="pil", interactive=False)
# Automatically populate on load for the default example
demo.load(
lambda: display_molecule_image("C1=CC=CC=C1"),
inputs=None,
outputs=[molecule_image_output, status_viewer]
)
view_button_molecule.click(
display_molecule_image,
inputs=[smiles_input_viewer],
outputs=[molecule_image_output, status_viewer]
)
if __name__ == "__main__":
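    # Optional: demo.queue() before launch() serializes requests, which keeps
    # GPU memory spikes down on shared or CPU-only hardware.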
demo.launch() |