# app.py import gradio as gr import torch from transformers import AutoModelForMaskedLM, AutoTokenizer, pipeline, RobertaModel, RobertaTokenizer, BitsAndBytesConfig from rdkit import Chem from rdkit.Chem import Draw, rdFMCS from rdkit.Chem.Draw import MolToImage # PIL is imported as Image by rdkit.Chem.Draw.MolToImage, but explicit import is good practice if used directly. # from PIL import Image import pandas as pd from bertviz import head_view # For potential future use or if other parts rely on it from bertviz import neuron_view as neuron_view_function # Specific import for neuron_view function # IPython.core.display.HTML is generally for notebooks. Gradio's gr.HTML handles HTML strings directly. # from IPython.core.display import HTML import io import base64 import logging # Set up logging to monitor quantization effects logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) # --- Quantization Configuration --- def get_quantization_config(): """ Configure 8-bit quantization for model optimization. Falls back gracefully if bitsandbytes is not available. """ try: # 8-bit quantization configuration - good balance of speed and quality quantization_config = BitsAndBytesConfig( load_in_8bit=True, bnb_8bit_compute_dtype=torch.float16, bnb_8bit_use_double_quant=True, # Nested quantization for better compression ) logger.info("8-bit quantization configuration loaded successfully") return quantization_config except ImportError: logger.warning("bitsandbytes not available, falling back to standard loading") return None except Exception as e: logger.warning(f"Quantization setup failed: {e}, using standard loading") return None def get_torch_dtype(): """Get appropriate torch dtype based on available hardware.""" if torch.cuda.is_available(): return torch.float16 # Use half precision on GPU else: return torch.float32 # Keep full precision on CPU # --- Optimized Model Loading --- def load_optimized_models(): """Load models with quantization and other optimizations.""" device = "cuda" if torch.cuda.is_available() else "cpu" torch_dtype = get_torch_dtype() quantization_config = get_quantization_config() logger.info(f"Loading models on device: {device} with dtype: {torch_dtype}") # Model names model_name = "seyonec/PubChem10M_SMILES_BPE_450k" # Load tokenizers (these don't need quantization) fill_mask_tokenizer = AutoTokenizer.from_pretrained(model_name) attention_tokenizer = RobertaTokenizer.from_pretrained(model_name) # Load models with quantization if available model_kwargs = { "torch_dtype": torch_dtype, } if quantization_config is not None and torch.cuda.is_available(): # Quantization typically for GPU model_kwargs["quantization_config"] = quantization_config # device_map="auto" is often used with bitsandbytes for automatic distribution model_kwargs["device_map"] = "auto" elif torch.cuda.is_available(): model_kwargs["device_map"] = "auto" # For non-quantized GPU loading else: model_kwargs["device_map"] = None # For CPU try: # Masked LM Model fill_mask_model = AutoModelForMaskedLM.from_pretrained( model_name, **model_kwargs ) # RoBERTa model for attention attention_model_kwargs = model_kwargs.copy() attention_model_kwargs["output_attentions"] = True attention_model = RobertaModel.from_pretrained( model_name, **attention_model_kwargs ) # Set models to evaluation mode for inference fill_mask_model.eval() attention_model.eval() # Create optimized pipeline # Let pipeline infer device from model if possible, or set based on model's device pipeline_device = fill_mask_model.device.index if hasattr(fill_mask_model.device, 'type') and fill_mask_model.device.type == "cuda" else -1 fill_mask_pipeline = pipeline( 'fill-mask', model=fill_mask_model, tokenizer=fill_mask_tokenizer, device=pipeline_device, # Use model's device # torch_dtype=torch_dtype # Pipeline might infer this or it might conflict ) logger.info("Models loaded successfully with optimizations") return fill_mask_tokenizer, fill_mask_model, fill_mask_pipeline, attention_model, attention_tokenizer except Exception as e: logger.error(f"Error loading optimized models: {e}") # Fallback to standard loading logger.info("Falling back to standard model loading...") return load_standard_models(model_name) def load_standard_models(model_name): """Fallback standard model loading without quantization.""" fill_mask_tokenizer = AutoTokenizer.from_pretrained(model_name) fill_mask_model = AutoModelForMaskedLM.from_pretrained(model_name) # Determine device for standard loading device_idx = 0 if torch.cuda.is_available() else -1 fill_mask_pipeline = pipeline('fill-mask', model=fill_mask_model, tokenizer=fill_mask_tokenizer, device=device_idx) attention_model = RobertaModel.from_pretrained(model_name, output_attentions=True) attention_tokenizer = RobertaTokenizer.from_pretrained(model_name) if torch.cuda.is_available(): fill_mask_model.to("cuda") attention_model.to("cuda") return fill_mask_tokenizer, fill_mask_model, fill_mask_pipeline, attention_model, attention_tokenizer # Load models with optimizations fill_mask_tokenizer, fill_mask_model, fill_mask_pipeline, attention_model, attention_tokenizer = load_optimized_models() # --- Memory Management Utilities --- def clear_gpu_cache(): """Clear CUDA cache to free up memory.""" if torch.cuda.is_available(): torch.cuda.empty_cache() # --- Helper Functions from Notebook (adapted) --- def get_mol(smiles): """Converts SMILES to RDKit Mol object and Kekulizes it.""" mol = Chem.MolFromSmiles(smiles) if mol is None: return None try: Chem.Kekulize(mol) except: # Kekulization can fail for some structures pass return mol def find_matches_one(mol, submol_smarts): """Finds all matching atoms for a SMARTS pattern in a molecule.""" if not mol or not submol_smarts: return [] submol = Chem.MolFromSmarts(submol_smarts) if not submol: return [] matches = mol.GetSubstructMatches(submol) return matches def get_image_with_highlight(mol, atomset=None, size=(300, 300)): """Draws molecule with optional atom highlighting.""" if mol is None: return None highlight_color = (0, 1, 0, 0.5) # Green with some transparency # Ensure atomset contains integers if not None or empty valid_atomset = [] if atomset: try: valid_atomset = [int(a) for a in atomset] except ValueError: logger.warning(f"Invalid atom in atomset: {atomset}. Proceeding without highlighting problematic atoms.") valid_atomset = [int(a) for a in atomset if str(a).isdigit()] # Filter out non-integers img = MolToImage(mol, size=size, fitImage=True, highlightAtoms=valid_atomset if valid_atomset else [], highlightAtomColors={i: highlight_color for i in valid_atomset} if valid_atomset else {}) return img # --- Optimized Gradio Interface Functions --- def predict_and_visualize_masked_smiles(smiles_mask, substructure_smarts_highlight="CC=CC"): """ Predicts masked tokens in a SMILES string, shows scores, and visualizes molecules. Optimized with memory management. Returns 7 items for Gradio outputs. """ if fill_mask_tokenizer.mask_token not in smiles_mask: # Return 7 items for the 7 output components return pd.DataFrame(), None, None, None, None, None, "Error: Input SMILES must contain a mask token (e.g., )." try: # Use torch.no_grad() for inference to save memory with torch.no_grad(): predictions = fill_mask_pipeline(smiles_mask, top_k=10) # Get more to filter for valid ones except Exception as e: clear_gpu_cache() # Clear cache on error # Return 7 items return pd.DataFrame(), None, None, None, None, None, f"Error during prediction: {str(e)}" results_data = [] image_list = [] valid_predictions_count = 0 for pred in predictions: if valid_predictions_count >= 5: break predicted_smiles = pred['sequence'] score = pred['score'] mol = get_mol(predicted_smiles) if mol: results_data.append({"Predicted SMILES": predicted_smiles, "Score": f"{score:.4f}"}) atom_matches_indices = [] if substructure_smarts_highlight: matches = find_matches_one(mol, substructure_smarts_highlight) if matches: atom_matches_indices = list(matches[0]) # Highlight first match img = get_image_with_highlight(mol, atomset=atom_matches_indices) image_list.append(img) valid_predictions_count += 1 # Pad image_list if fewer than 5 valid predictions while len(image_list) < 5: image_list.append(None) df_results = pd.DataFrame(results_data) # Clear cache after inference clear_gpu_cache() status_message = "Prediction successful." if valid_predictions_count > 0 else "No valid molecules found for top predictions." # Unpack image_list into individual image outputs + df_results + status_message return df_results, image_list[0], image_list[1], image_list[2], image_list[3], image_list[4], status_message def visualize_attention_bertviz(sentence_a, sentence_b): """ Generates and displays BertViz neuron-by-neuron attention view as HTML. Optimized with memory management and mixed precision. """ if not sentence_a or not sentence_b: return "

Please provide two SMILES strings.

" try: inputs = attention_tokenizer.encode_plus(sentence_a, sentence_b, return_tensors='pt', add_special_tokens=True) input_ids = inputs['input_ids'] # Move to appropriate device if using GPU if torch.cuda.is_available() and hasattr(attention_model, 'device'): input_ids = input_ids.to(attention_model.device) # Ensure model is in eval mode and use no_grad for inference attention_model.eval() with torch.no_grad(): # Use autocast for mixed precision if on CUDA if torch.cuda.is_available() and hasattr(torch.cuda, 'amp') and hasattr(torch.cuda.amp, 'autocast'): # Check for amp with torch.cuda.amp.autocast(dtype=torch.float16 if get_torch_dtype() == torch.float16 else None): attention_outputs = attention_model(input_ids) else: attention_outputs = attention_model(input_ids) attention = attention_outputs[-1] # Last item in the tuple is attentions input_id_list = input_ids[0].tolist() tokens = attention_tokenizer.convert_ids_to_tokens(input_id_list) # Using the specifically imported neuron_view_function html_object = neuron_view_function(attention, tokens) # Extract HTML string from the IPython.core.display.HTML object html_string = html_object.data # .data should provide the HTML string # Add D3 and jQuery CDN links to the HTML string for better rendering in Gradio html_with_deps = f""" {html_string} """ # Clear cache after attention computation clear_gpu_cache() return html_with_deps except Exception as e: clear_gpu_cache() # Clear cache on error logger.error(f"Error in visualize_attention_bertviz: {e}", exc_info=True) return f"

Error generating attention visualization: {str(e)}

" def display_molecule_image(smiles_string): """ Displays a 2D image of a molecule from its SMILES string. """ if not smiles_string: return None, "Please enter a SMILES string." mol = get_mol(smiles_string) if mol is None: return None, "Invalid SMILES string." img = MolToImage(mol, size=(400, 400), fitImage=True) return img, "Molecule displayed." # --- Gradio Interface Definition --- with gr.Blocks(theme=gr.themes.Default()) as demo: gr.Markdown("# ChemBERTa SMILES Utilities Dashboard") with gr.Tab("Masked SMILES Prediction"): gr.Markdown("Enter a SMILES string with a `` token (e.g., `C1=CC=CCC1`) to predict possible completions.") with gr.Row(): smiles_input_masked = gr.Textbox(label="SMILES String with Mask", value="C1=CC=CCC1") substructure_input = gr.Textbox(label="Substructure to Highlight (SMARTS)", value="C=C") predict_button_masked = gr.Button("Predict and Visualize") status_masked = gr.Textbox(label="Status", interactive=False) predictions_table = gr.DataFrame(label="Top Predictions & Scores") gr.Markdown("### Predicted Molecule Visualizations (Top 5 Valid)") with gr.Row(): img_out_1 = gr.Image(label="Prediction 1", type="pil", interactive=False) img_out_2 = gr.Image(label="Prediction 2", type="pil", interactive=False) img_out_3 = gr.Image(label="Prediction 3", type="pil", interactive=False) img_out_4 = gr.Image(label="Prediction 4", type="pil", interactive=False) img_out_5 = gr.Image(label="Prediction 5", type="pil", interactive=False) # Automatically populate on load for the default example demo.load( lambda: predict_and_visualize_masked_smiles("C1=CC=CCC1", "C=C"), inputs=None, outputs=[predictions_table, img_out_1, img_out_2, img_out_3, img_out_4, img_out_5, status_masked] ) predict_button_masked.click( predict_and_visualize_masked_smiles, inputs=[smiles_input_masked, substructure_input], outputs=[predictions_table, img_out_1, img_out_2, img_out_3, img_out_4, img_out_5, status_masked] ) with gr.Tab("Attention Visualization"): gr.Markdown("Enter two SMILES strings to visualize **neuron-by-neuron attention** between them using BertViz. This may take a moment to render.") with gr.Row(): smiles_a_input_attn = gr.Textbox(label="SMILES String A", value="CCCCC[C@@H](Br)CC") smiles_b_input_attn = gr.Textbox(label="SMILES String B", value="CCCCC[C@H](Br)CC") visualize_button_attn = gr.Button("Visualize Attention") attention_html_output = gr.HTML(label="Attention Neuron View") # Changed label for clarity # Automatically populate on load for the default example demo.load( lambda: visualize_attention_bertviz("CCCCC[C@@H](Br)CC", "CCCCC[C@H](Br)CC"), inputs=None, outputs=[attention_html_output] ) visualize_button_attn.click( visualize_attention_bertviz, inputs=[smiles_a_input_attn, smiles_b_input_attn], outputs=[attention_html_output] ) with gr.Tab("Molecule Viewer"): gr.Markdown("Enter a SMILES string to display its 2D structure.") smiles_input_viewer = gr.Textbox(label="SMILES String", value="C1=CC=CC=C1") view_button_molecule = gr.Button("View Molecule") status_viewer = gr.Textbox(label="Status", interactive=False) molecule_image_output = gr.Image(label="Molecule Structure", type="pil", interactive=False) # Automatically populate on load for the default example demo.load( lambda: display_molecule_image("C1=CC=CC=C1"), inputs=None, outputs=[molecule_image_output, status_viewer] ) view_button_molecule.click( display_molecule_image, inputs=[smiles_input_viewer], outputs=[molecule_image_output, status_viewer] ) if __name__ == "__main__": demo.launch()