# app_interactive.py
import streamlit as st
import torch
import os
import pandas as pd
from transformers import RobertaForMaskedLM, PreTrainedTokenizerFast
import re

# --- Configuration ---
CHECKPOINT_BASE_DIR = "./checkpoints"
PRESET_SENTENCE = "The quick brown fox jumps over the lazy dog near the river bank."
TOP_K = 5

# --- Initialize Session State ---
if 'masked_indices' not in st.session_state:
    st.session_state.masked_indices = set()
if 'tokens' not in st.session_state:
    st.session_state.tokens = []
if 'token_ids' not in st.session_state:
    st.session_state.token_ids = []
if 'input_sentence' not in st.session_state:
    st.session_state.input_sentence = PRESET_SENTENCE
if 'display_tokens' not in st.session_state:
    st.session_state.display_tokens = []

# --- Helper Functions ---
def sanitize_token_display(token):
    """Clean up token display by removing special characters like Ġ."""
    # Strip the leading 'Ġ' (the byte-level BPE marker for a word boundary)
    if isinstance(token, str) and token.startswith('Ġ'):
        return token[1:]
    # Special tokens such as <s>, </s>, <mask> and all other tokens
    # pass through unchanged
    return token

def find_checkpoints(base_dir):
    """Find valid checkpoint directories within the base directory."""
    checkpoints = []
    if not os.path.isdir(base_dir):
        return checkpoints
    for item in os.listdir(base_dir):
        path = os.path.join(base_dir, item)
        if os.path.isdir(path) and item.startswith("checkpoint-"):
            if os.path.exists(os.path.join(path, "pytorch_model.bin")) or \
               os.path.exists(os.path.join(path, "model.safetensors")):
                checkpoints.append(item)
    # Sort numerically by step, e.g. checkpoint-500 before checkpoint-1000
    checkpoints.sort(key=lambda x: int(re.search(r'(\d+)', x).group(1)))
    return checkpoints

@st.cache_resource
def load_model_and_tokenizer(checkpoint_name):
    """Load the model and tokenizer from the specified checkpoint directory name."""
    checkpoint_path = os.path.join(CHECKPOINT_BASE_DIR, checkpoint_name)
    if not os.path.isdir(checkpoint_path):
        st.error(f"Checkpoint directory not found: {checkpoint_path}")
        # Keep the arity consistent with the success path (3 values)
        return None, None, None
    try:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = RobertaForMaskedLM.from_pretrained(checkpoint_path).to(device)
        tokenizer = PreTrainedTokenizerFast.from_pretrained(checkpoint_path)
        model.eval()
        # st.success(f"Loaded {checkpoint_name} on {device}")
        return model, tokenizer, device
    except Exception as e:
        st.error(f"Error loading {checkpoint_name}: {e}")
        return None, None, None

def tokenize_text(text, tokenizer):
    """Tokenize the input text and return tokens and their IDs."""
    encoding = tokenizer(text, return_tensors="pt", add_special_tokens=True)
    input_ids = encoding.input_ids[0].tolist()
    # Map every id back to its string token
    tokens = tokenizer.convert_ids_to_tokens(input_ids)
    return tokens, input_ids

def toggle_token(index):
    """Toggle a token's masked status."""
    if index in st.session_state.masked_indices:
        st.session_state.masked_indices.remove(index)
    else:
        st.session_state.masked_indices.add(index)

def update_input_sentence():
    """Update the input sentence and reset masked indices."""
    st.session_state.input_sentence = st.session_state.input_text
    st.session_state.masked_indices = set()

def get_predictions(model, tokenizer, device):
    """Get predictions for masked tokens."""
    if not st.session_state.masked_indices:
        return None, None, None, None

    # Create a copy of the token IDs
    masked_input_ids = st.session_state.token_ids.copy()

    # Apply masks
    for idx in st.session_state.masked_indices:
        masked_input_ids[idx] = tokenizer.mask_token_id

    # Convert to tensor
    masked_input_tensor = torch.tensor([masked_input_ids]).to(device)

    # Get predictions
    with torch.no_grad():
        outputs = model(input_ids=masked_input_tensor)
        logits = outputs.logits

    results = []
    top1_predictions = {}
    prediction_tokens = {}
    original_token_ranks = {}

    for masked_index in st.session_state.masked_indices:
        mask_logits = logits[0, masked_index, :]
        probabilities = torch.softmax(mask_logits, dim=-1)
        top_k_probs, top_k_indices = torch.topk(probabilities, TOP_K)

        # Save top-1 prediction for reconstruction
        top1_id = top_k_indices[0].item()
        top1_predictions[masked_index] = top1_id
        # Sanitize the token here
        raw_token = tokenizer.convert_ids_to_tokens(top1_id)
        prediction_tokens[masked_index] = sanitize_token_display(raw_token)

        original_token = st.session_state.tokens[masked_index]
        original_id = st.session_state.token_ids[masked_index]

        # Find the original token's rank within the top-K predictions
        original_token_rank = -1  # -1 means not in top K
        for rank, token_id in enumerate(top_k_indices.tolist()):
            predicted_token = tokenizer.convert_ids_to_tokens(token_id)
            if predicted_token.lower() == original_token.lower() or token_id == original_id:
                original_token_rank = rank
                break
        original_token_ranks[masked_index] = original_token_rank

        for rank, (prob, token_id) in enumerate(zip(top_k_probs.tolist(), top_k_indices.tolist())):
            predicted_token = tokenizer.convert_ids_to_tokens(token_id)
            # Sanitize the predicted token for the results table
            clean_predicted_token = sanitize_token_display(predicted_token)
            # Case-insensitive match
            is_match = predicted_token.lower() == original_token.lower()
            results.append({
                "Masked Index": masked_index,
                "Rank": rank + 1,
                "Predicted Token": clean_predicted_token,
                "Original Token": sanitize_token_display(original_token),
                "Exact Match": is_match,
                "Probability": f"{prob:.4f}",
            })

    # Reconstruct the sentence using top-1 predictions
    reconstructed_ids = masked_input_ids.copy()
    for idx in st.session_state.masked_indices:
        reconstructed_ids[idx] = top1_predictions[idx]
    reconstructed_text = tokenizer.decode(reconstructed_ids, skip_special_tokens=True)

    return results, reconstructed_text, prediction_tokens, original_token_ranks

# --- Streamlit App Layout ---
st.set_page_config(layout="wide", page_title="Interactive MLM Inference")

# Custom CSS to prevent text wrapping in buttons
# (a minimal rule assumed here; the exact original styling was not preserved)
st.markdown("""
    <style>
    div.stButton > button {
        white-space: nowrap;
    }
    </style>
""", unsafe_allow_html=True)

st.title("🧪 Interactive MLM Inference")

# --- Checkpoint Selection ---
available_checkpoints = find_checkpoints(CHECKPOINT_BASE_DIR)
if not available_checkpoints:
    st.error(f"No checkpoints found in '{CHECKPOINT_BASE_DIR}'. Please train a model first.")
    st.stop()

selected_checkpoint = st.selectbox(
    "Select Checkpoint:",
    available_checkpoints,
    index=len(available_checkpoints) - 1
)

# --- Load Model ---
if selected_checkpoint:
    model, tokenizer, device = load_model_and_tokenizer(selected_checkpoint)
else:
    model, tokenizer, device = None, None, None

# --- Interactive Inference Section ---
st.divider()
st.subheader("Interactive Token Masking")

# 1. Original text area
st.text_area(
    "Input Sentence:",
    value=st.session_state.input_sentence,
    key="input_text",
    on_change=update_input_sentence,
    height=100
)

if model and tokenizer and device:
    # Tokenize the input text
    st.session_state.tokens, st.session_state.token_ids = tokenize_text(
        st.session_state.input_sentence, tokenizer
    )

    # Create sanitized display tokens
    st.session_state.display_tokens = [
        sanitize_token_display(token) for token in st.session_state.tokens
    ]

    # 2. Interactive token display
    st.subheader("Click on tokens to mask/unmask them:")

    # Group tokens into rows (adjust number as needed)
    tokens_per_row = 12

    # Calculate how many rows we need (ceiling division)
    num_rows = (len(st.session_state.tokens) + tokens_per_row - 1) // tokens_per_row

    for row in range(num_rows):
        # Create columns for this row
        start_idx = row * tokens_per_row
        end_idx = min(start_idx + tokens_per_row, len(st.session_state.tokens))
        row_tokens = st.session_state.tokens[start_idx:end_idx]

        # Create equal-width columns
        cols = st.columns(len(row_tokens))

        for j, col in enumerate(cols):
            idx = start_idx + j
            token = st.session_state.tokens[idx]

            # Special tokens cannot be masked
            is_special = token in [
                tokenizer.cls_token,
                tokenizer.sep_token,
                tokenizer.pad_token
            ]
            is_masked = idx in st.session_state.masked_indices

            # Create a button for each token
            button_key = f"token_{idx}"
            button_label = sanitize_token_display(token) if not is_masked else "[MASK]"

            if col.button(
                button_label,
                key=button_key,
                disabled=is_special,
                help=f"Token ID: {st.session_state.token_ids[idx]}"
            ):
                toggle_token(idx)
                st.rerun()

    # 3. Prediction area
    if st.session_state.masked_indices:
        results, reconstructed_text, prediction_tokens, original_token_ranks = get_predictions(
            model, tokenizer, device
        )

        st.subheader("Predictions:")
        st.markdown("**Reconstructed sentence with predictions:**")

        # Create HTML for highlighting predictions. The markup below is an
        # assumed reconstruction: green/blue/red backgrounds matching the
        # rank-based color coding described in the branches further down.
        html = "<div style='line-height: 2.2;'>"
" # Use the original tokenization to match masked positions for i, token in enumerate(st.session_state.tokens): # Skip special tokens if token in [tokenizer.cls_token, tokenizer.sep_token, tokenizer.pad_token]: continue if i in st.session_state.masked_indices: # This was a masked token original_token = sanitize_token_display(st.session_state.tokens[i]) predicted_token = prediction_tokens[i] # This is already sanitized in get_predictions original_rank = original_token_ranks[i] # Color based on original token's rank in predictions if original_rank == 0: # Rank 0 means it was the top prediction # Green for top prediction (rank 1) html += f"{predicted_token}" elif original_rank != -1: # In top 5 but not top # Blue for in top 5 but not top html += f"{predicted_token}" else: # Not in top 5 # Red for not in top 5 html += f"{predicted_token}" else: # Not a masked token, display normally sanitized_token = sanitize_token_display(token) html += f"{sanitized_token} " html += "
" # Display the highlighted text st.markdown(html, unsafe_allow_html=True) # Show detailed predictions st.markdown("**Top predictions for each masked token:**") for masked_idx in st.session_state.masked_indices: original_token = st.session_state.tokens[masked_idx] original_rank = original_token_ranks[masked_idx] # Create a note about whether the original token was in top predictions if original_rank == 0: rank_note = "✅ Original token was the top prediction" elif original_rank != -1: rank_note = f"ℹ️ Original token was prediction #{original_rank+1}" else: rank_note = "❌ Original token not in top 5 predictions" # Sanitize the token display clean_original_token = sanitize_token_display(original_token) st.markdown(f"**Token {clean_original_token} at position {masked_idx}** - {rank_note}") # The dataframe is already sanitized in the get_predictions function df = pd.DataFrame([r for r in results if r['Masked Index'] == masked_idx]) df = df[["Rank", "Predicted Token", "Probability"]] # Highlight the row with the original token if it's in top 5 if original_rank != -1: # Use pandas styler to highlight the row styled_df = df.style.apply(lambda x: ['background-color: #c3e6cb' if i == original_rank else '' for i in range(len(x))], axis=0) st.dataframe(styled_df, use_container_width=True) else: st.dataframe(df, use_container_width=True) else: st.info("Click on tokens above to mask them and see predictions.") else: st.warning("Please select a valid checkpoint to enable interactive masking.") st.divider() st.caption("Interactive app for RoBERTa Masked Language Modeling.")