# app_interactive.py
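"""Interactive Streamlit app for masked-language-model (MLM) inference.

Loads a RoBERTa MLM checkpoint from ./checkpoints (directories named
checkpoint-<step> containing pytorch_model.bin or model.safetensors), lets you
click individual tokens of a sentence to mask them, and shows the model's
top-5 predictions for each masked position, color-coded by how highly the
original token ranks.

Run with:
    streamlit run app_interactive.py
"""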
import os
import re
import pandas as pd
import streamlit as st
import torch
from transformers import RobertaForMaskedLM, PreTrainedTokenizerFast
# --- Configuration ---
CHECKPOINT_BASE_DIR = "./checkpoints"
PRESET_SENTENCE = "The quick brown fox jumps over the lazy dog near the river bank."
TOP_K = 5
# --- Initialize Session State ---
if 'masked_indices' not in st.session_state:
    st.session_state.masked_indices = set()
if 'tokens' not in st.session_state:
    st.session_state.tokens = []
if 'token_ids' not in st.session_state:
    st.session_state.token_ids = []
if 'input_sentence' not in st.session_state:
    st.session_state.input_sentence = PRESET_SENTENCE
if 'display_tokens' not in st.session_state:
    st.session_state.display_tokens = []
# --- Helper Functions ---
def sanitize_token_display(token):
    """Clean up token display by removing special characters like Ġ."""
    # Strip the leading 'Ġ' (byte-level BPE marker for a preceding space)
    if isinstance(token, str) and token.startswith('Ġ'):
        return token[1:]
    # Special tokens (<s>, </s>, <pad>) and everything else pass through unchanged
    return token
def find_checkpoints(base_dir):
    """Finds valid checkpoint directories within the base directory."""
    checkpoints = []
    if not os.path.isdir(base_dir):
        return checkpoints
    for item in os.listdir(base_dir):
        path = os.path.join(base_dir, item)
        if os.path.isdir(path) and item.startswith("checkpoint-"):
            # Only keep checkpoints that actually contain saved weights
            if os.path.exists(os.path.join(path, "pytorch_model.bin")) or \
               os.path.exists(os.path.join(path, "model.safetensors")):
                checkpoints.append(item)
    # Sort numerically by training step (e.g. checkpoint-500 before checkpoint-1000)
    checkpoints.sort(key=lambda x: int(re.search(r'(\d+)', x).group(1)))
    return checkpoints
@st.cache_resource
def load_model_and_tokenizer(checkpoint_name):
    """Loads the model and tokenizer from the specified checkpoint directory name."""
    checkpoint_path = os.path.join(CHECKPOINT_BASE_DIR, checkpoint_name)
    if not os.path.isdir(checkpoint_path):
        st.error(f"Checkpoint directory not found: {checkpoint_path}")
        # Return three values so the caller's unpacking always succeeds
        return None, None, None
    try:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = RobertaForMaskedLM.from_pretrained(checkpoint_path).to(device)
        tokenizer = PreTrainedTokenizerFast.from_pretrained(checkpoint_path)
        model.eval()
        return model, tokenizer, device
    except Exception as e:
        st.error(f"Error loading {checkpoint_name}: {e}")
        return None, None, None
def tokenize_text(text, tokenizer):
    """Tokenize the input text and return tokens and their IDs."""
    encoding = tokenizer(text, return_tensors="pt", add_special_tokens=True)
    input_ids = encoding.input_ids[0].tolist()
    # convert_ids_to_tokens accepts a list of ids and returns the matching tokens
    tokens = tokenizer.convert_ids_to_tokens(input_ids)
    return tokens, input_ids
def toggle_token(index):
    """Toggle a token's masked status."""
    if index in st.session_state.masked_indices:
        st.session_state.masked_indices.remove(index)
    else:
        st.session_state.masked_indices.add(index)
def update_input_sentence():
    """Update the input sentence and reset masked indices."""
    st.session_state.input_sentence = st.session_state.input_text
    st.session_state.masked_indices = set()
def get_predictions(model, tokenizer, device):
    """Get top-K predictions for the currently masked tokens."""
    if not st.session_state.masked_indices:
        return None, None, None, None
    # Create a copy of the token IDs and apply the masks
    masked_input_ids = st.session_state.token_ids.copy()
    for idx in st.session_state.masked_indices:
        masked_input_ids[idx] = tokenizer.mask_token_id
    # Convert to tensor and run the model
    masked_input_tensor = torch.tensor([masked_input_ids]).to(device)
    with torch.no_grad():
        outputs = model(input_ids=masked_input_tensor)
        logits = outputs.logits
    results = []
    top1_predictions = {}
    prediction_tokens = {}
    original_token_ranks = {}
    for masked_index in st.session_state.masked_indices:
        mask_logits = logits[0, masked_index, :]
        probabilities = torch.softmax(mask_logits, dim=-1)
        top_k_probs, top_k_indices = torch.topk(probabilities, TOP_K)
        # Save the top-1 prediction for sentence reconstruction
        top1_id = top_k_indices[0].item()
        top1_predictions[masked_index] = top1_id
        raw_token = tokenizer.convert_ids_to_tokens(top1_id)
        prediction_tokens[masked_index] = sanitize_token_display(raw_token)
        original_token = st.session_state.tokens[masked_index]
        original_id = st.session_state.token_ids[masked_index]
        # Find the rank of the original token within the top-K predictions (-1 if absent)
        original_token_rank = -1
        for rank, token_id in enumerate(top_k_indices.tolist()):
            predicted_token = tokenizer.convert_ids_to_tokens(token_id)
            if predicted_token.lower() == original_token.lower() or token_id == original_id:
                original_token_rank = rank
                break
        original_token_ranks[masked_index] = original_token_rank
        for rank, (prob, token_id) in enumerate(zip(top_k_probs.tolist(), top_k_indices.tolist())):
            predicted_token = tokenizer.convert_ids_to_tokens(token_id)
            # Sanitize tokens for the results table; the match is case-insensitive
            clean_predicted_token = sanitize_token_display(predicted_token)
            is_match = predicted_token.lower() == original_token.lower()
            results.append({
                "Masked Index": masked_index,
                "Rank": rank + 1,
                "Predicted Token": clean_predicted_token,
                "Original Token": sanitize_token_display(original_token),
                "Exact Match": is_match,
                "Probability": f"{prob:.4f}"
            })
    # Reconstruct the sentence using the top-1 predictions
    reconstructed_ids = masked_input_ids.copy()
    for idx in st.session_state.masked_indices:
        reconstructed_ids[idx] = top1_predictions[idx]
    reconstructed_text = tokenizer.decode(reconstructed_ids, skip_special_tokens=True)
    return results, reconstructed_text, prediction_tokens, original_token_ranks
# --- Streamlit App Layout ---
st.set_page_config(layout="wide", page_title="Interactive MLM Inference")
# Custom CSS to prevent text wrapping in buttons
st.markdown("""
<style>
.stButton button {
    white-space: nowrap;
    overflow: hidden;
    text-overflow: ellipsis;
    min-width: 80px;
}
</style>
""", unsafe_allow_html=True)
st.title("🧪 Interactive MLM Inference")
# --- Checkpoint Selection ---
available_checkpoints = find_checkpoints(CHECKPOINT_BASE_DIR)
if not available_checkpoints:
    st.error(f"No checkpoints found in '{CHECKPOINT_BASE_DIR}'. Please train a model first.")
    st.stop()
selected_checkpoint = st.selectbox(
    "Select Checkpoint:",
    available_checkpoints,
    index=len(available_checkpoints) - 1  # default to the most recent checkpoint
)
# --- Load Model ---
if selected_checkpoint:
    model, tokenizer, device = load_model_and_tokenizer(selected_checkpoint)
else:
    model, tokenizer, device = None, None, None
# --- Interactive Inference Section ---
st.divider()
st.subheader("Interactive Token Masking")
# 1. Original text area
st.text_area(
    "Input Sentence:",
    value=st.session_state.input_sentence,
    key="input_text",
    on_change=update_input_sentence,
    height=100
)
if model and tokenizer and device:
    # Tokenize the input text
    st.session_state.tokens, st.session_state.token_ids = tokenize_text(
        st.session_state.input_sentence,
        tokenizer
    )
    # Create sanitized display tokens
    st.session_state.display_tokens = [sanitize_token_display(token) for token in st.session_state.tokens]
    # 2. Interactive token display
    st.subheader("Click on tokens to mask/unmask them:")
    # Group tokens into rows (adjust number as needed)
    tokens_per_row = 12
    # Calculate how many rows we need
    num_rows = (len(st.session_state.tokens) + tokens_per_row - 1) // tokens_per_row
    for row in range(num_rows):
        # Create columns for this row
        start_idx = row * tokens_per_row
        end_idx = min(start_idx + tokens_per_row, len(st.session_state.tokens))
        row_tokens = st.session_state.tokens[start_idx:end_idx]
        # Create equal-width columns
        cols = st.columns(len(row_tokens))
        for j, col in enumerate(cols):
            idx = start_idx + j
            token = st.session_state.tokens[idx]
            # Skip special tokens for masking
            is_special = token in [
                tokenizer.cls_token,
                tokenizer.sep_token,
                tokenizer.pad_token
            ]
            is_masked = idx in st.session_state.masked_indices
            # Create a button for each token
            button_key = f"token_{idx}"
            button_label = sanitize_token_display(token) if not is_masked else "[MASK]"
            if col.button(
                button_label,
                key=button_key,
                disabled=is_special,
                help=f"Token ID: {st.session_state.token_ids[idx]}"
            ):
                toggle_token(idx)
                st.rerun()
    # 3. Prediction area
    if st.session_state.masked_indices:
        results, reconstructed_text, prediction_tokens, original_token_ranks = get_predictions(model, tokenizer, device)
        st.subheader("Predictions:")
        st.markdown("**Reconstructed sentence with predictions:**")
        # Create HTML for highlighting predictions
        html = "<div style='padding: 10px; border-radius: 5px; border: 1px solid #ccc;'>"
        # Use the original tokenization to match masked positions
        for i, token in enumerate(st.session_state.tokens):
            # Skip special tokens
            if token in [tokenizer.cls_token, tokenizer.sep_token, tokenizer.pad_token]:
                continue
            if i in st.session_state.masked_indices:
                # This was a masked token
                original_token = sanitize_token_display(st.session_state.tokens[i])
                predicted_token = prediction_tokens[i]  # Already sanitized in get_predictions
                original_rank = original_token_ranks[i]
                # Color based on the original token's rank in the predictions
                if original_rank == 0:
                    # Green: the original token was the top prediction
                    html += f"<span style='background-color: #c3e6cb; padding: 2px 4px; border-radius: 3px; margin: 0 2px;'>{predicted_token}</span>"
                elif original_rank != -1:
                    # Blue: the original token was in the top 5 but not first
                    html += f"<span style='background-color: #b8daff; padding: 2px 4px; border-radius: 3px; margin: 0 2px;'>{predicted_token}</span>"
                else:
                    # Red: the original token was not in the top 5
                    html += f"<span style='background-color: #f8d7da; padding: 2px 4px; border-radius: 3px; margin: 0 2px;'>{predicted_token}</span>"
            else:
                # Not a masked token, display normally
                sanitized_token = sanitize_token_display(token)
                html += f"{sanitized_token} "
        html += "</div>"
        # Display the highlighted text
        st.markdown(html, unsafe_allow_html=True)
        # Show detailed predictions
        st.markdown("**Top predictions for each masked token:**")
        for masked_idx in st.session_state.masked_indices:
            original_token = st.session_state.tokens[masked_idx]
            original_rank = original_token_ranks[masked_idx]
            # Note whether the original token was among the top predictions
            if original_rank == 0:
                rank_note = "✅ Original token was the top prediction"
            elif original_rank != -1:
                rank_note = f"ℹ️ Original token was prediction #{original_rank + 1}"
            else:
                rank_note = "❌ Original token not in top 5 predictions"
            # Sanitize the token display
            clean_original_token = sanitize_token_display(original_token)
            st.markdown(f"**Token {clean_original_token} at position {masked_idx}** - {rank_note}")
            # The dataframe rows are already sanitized in get_predictions
            df = pd.DataFrame([r for r in results if r['Masked Index'] == masked_idx])
            df = df[["Rank", "Predicted Token", "Probability"]]
            # Highlight the row of the original token if it is in the top 5
            if original_rank != -1:
                # Use a pandas Styler to highlight the matching row
                styled_df = df.style.apply(
                    lambda x: ['background-color: #c3e6cb' if i == original_rank else '' for i in range(len(x))],
                    axis=0
                )
                st.dataframe(styled_df, use_container_width=True)
            else:
                st.dataframe(df, use_container_width=True)
    else:
        st.info("Click on tokens above to mask them and see predictions.")
else:
    st.warning("Please select a valid checkpoint to enable interactive masking.")
st.divider()
st.caption("Interactive app for RoBERTa Masked Language Modeling.")