']:
        # Special tokens are displayed verbatim
        return token
    else:
        # Assumed behavior: strip the byte-level BPE leading-space marker ('Ġ') for display
        return token.replace("Ġ", "")
def find_checkpoints(base_dir):
"""Finds valid checkpoint directories within the base directory."""
checkpoints = []
if not os.path.isdir(base_dir):
return checkpoints
for item in os.listdir(base_dir):
path = os.path.join(base_dir, item)
if os.path.isdir(path) and item.startswith("checkpoint-"):
if os.path.exists(os.path.join(path, "pytorch_model.bin")) or \
os.path.exists(os.path.join(path, "model.safetensors")):
checkpoints.append(item)
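    # Sort numerically by training step (assumes names like "checkpoint-500")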
checkpoints.sort(key=lambda x: int(re.search(r'(\d+)', x).group(1)))
return checkpoints
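# Cache the loaded model/tokenizer across Streamlit reruns; weights are reloaded
# only when the selected checkpoint name changes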
@st.cache_resource
def load_model_and_tokenizer(checkpoint_name):
"""Loads the model and tokenizer from the specified checkpoint directory name."""
checkpoint_path = os.path.join(CHECKPOINT_BASE_DIR, checkpoint_name)
if not os.path.isdir(checkpoint_path):
st.error(f"Checkpoint directory not found: {checkpoint_path}")
        # Return a 3-tuple so callers can always unpack (model, tokenizer, device)
        return None, None, None
try:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = RobertaForMaskedLM.from_pretrained(checkpoint_path).to(device)
tokenizer = PreTrainedTokenizerFast.from_pretrained(checkpoint_path)
model.eval()
#st.success(f"Loaded {checkpoint_name} on {device}")
return model, tokenizer, device
except Exception as e:
st.error(f"Error loading {checkpoint_name}: {e}")
return None, None, None
def tokenize_text(text, tokenizer):
    """Tokenize the input text and return the tokens and their IDs."""
    encoding = tokenizer(text, return_tensors="pt", add_special_tokens=True)
    input_ids = encoding.input_ids[0].tolist()
    # convert_ids_to_tokens accepts a list of IDs and returns the matching tokens
    tokens = tokenizer.convert_ids_to_tokens(input_ids)
    return tokens, input_ids
def toggle_token(index):
"""Toggle a token's masked status."""
if index in st.session_state.masked_indices:
st.session_state.masked_indices.remove(index)
else:
st.session_state.masked_indices.add(index)
def update_input_sentence():
"""Update the input sentence and reset masked indices."""
st.session_state.input_sentence = st.session_state.input_text
st.session_state.masked_indices = set()
def get_predictions(model, tokenizer, device):
"""Get predictions for masked tokens."""
if not st.session_state.masked_indices:
return None, None, None, None
# Create a copy of the token IDs
masked_input_ids = st.session_state.token_ids.copy()
# Apply masks
for idx in st.session_state.masked_indices:
masked_input_ids[idx] = tokenizer.mask_token_id
# Convert to tensor
masked_input_tensor = torch.tensor([masked_input_ids]).to(device)
# Get predictions
with torch.no_grad():
outputs = model(input_ids=masked_input_tensor)
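        # outputs.logits has shape [batch=1, seq_len, vocab_size]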
logits = outputs.logits
results = []
top1_predictions = {}
prediction_tokens = {}
original_token_ranks = {}
for masked_index in st.session_state.masked_indices:
mask_logits = logits[0, masked_index, :]
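        # Convert this position's logits into a probability distribution over the vocabulary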
probabilities = torch.softmax(mask_logits, dim=-1)
top_k_probs, top_k_indices = torch.topk(probabilities, TOP_K)
# Save top-1 prediction for reconstruction
top1_id = top_k_indices[0].item()
top1_predictions[masked_index] = top1_id
# Sanitize the token here
raw_token = tokenizer.convert_ids_to_tokens(top1_id)
prediction_tokens[masked_index] = sanitize_token_display(raw_token)
original_token = st.session_state.tokens[masked_index]
original_id = st.session_state.token_ids[masked_index]
# Check if original token is in top K predictions
original_token_in_top_k = False
original_token_rank = -1 # -1 means not in top K
for rank, token_id in enumerate(top_k_indices.tolist()):
predicted_token = tokenizer.convert_ids_to_tokens(token_id)
if predicted_token.lower() == original_token.lower() or token_id == original_id:
original_token_in_top_k = True
original_token_rank = rank
break
original_token_ranks[masked_index] = original_token_rank
for rank, (prob, token_id) in enumerate(zip(top_k_probs.tolist(), top_k_indices.tolist())):
predicted_token = tokenizer.convert_ids_to_tokens(token_id)
# Sanitize the predicted token for the results table
clean_predicted_token = sanitize_token_display(predicted_token)
# Case insensitive match
is_match = predicted_token.lower() == original_token.lower()
results.append({
"Masked Index": masked_index,
"Rank": rank + 1,
"Predicted Token": clean_predicted_token, # Use sanitized token
"Original Token": sanitize_token_display(original_token), # Sanitize original token
"Exact Match": is_match,
"Probability": f"{prob:.4f}"
})
# Reconstruct the sentence using top-1 predictions
reconstructed_ids = masked_input_ids.copy()
for idx in st.session_state.masked_indices:
reconstructed_ids[idx] = top1_predictions[idx]
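    # Decode the full sequence back to text, dropping <s>, </s>, and padding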
reconstructed_text = tokenizer.decode(reconstructed_ids, skip_special_tokens=True)
return results, reconstructed_text, prediction_tokens, original_token_ranks
# --- Streamlit App Layout ---
st.set_page_config(layout="wide", page_title="Interactive MLM Inference")
# Initialize session state on first run (the default sentence is an arbitrary example)
if "input_sentence" not in st.session_state:
    st.session_state.input_sentence = "The quick brown fox jumps over the lazy dog."
if "masked_indices" not in st.session_state:
    st.session_state.masked_indices = set()
# Custom CSS to prevent text wrapping in buttons
st.markdown("""
""", unsafe_allow_html=True)
st.title("🧪 Interactive MLM Inference")
# --- Checkpoint Selection ---
available_checkpoints = find_checkpoints(CHECKPOINT_BASE_DIR)
if not available_checkpoints:
st.error(f"No checkpoints found in '{CHECKPOINT_BASE_DIR}'. Please train a model first.")
st.stop()
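# Default the dropdown to the most recent checkpoint (largest step number)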
selected_checkpoint = st.selectbox(
"Select Checkpoint:",
available_checkpoints,
index=len(available_checkpoints) - 1
)
# --- Load Model ---
if selected_checkpoint:
model, tokenizer, device = load_model_and_tokenizer(selected_checkpoint)
else:
model, tokenizer, device = None, None, None
# --- Interactive Inference Section ---
st.divider()
st.subheader("Interactive Token Masking")
# 1. Original text area
st.text_area(
"Input Sentence:",
value=st.session_state.input_sentence,
key="input_text",
on_change=update_input_sentence,
height=100
)
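# Only enable the interactive masking UI once a model has loaded successfully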
if model and tokenizer and device:
# Tokenize the input text
st.session_state.tokens, st.session_state.token_ids = tokenize_text(
st.session_state.input_sentence,
tokenizer
)
# Create sanitized display tokens
st.session_state.display_tokens = [sanitize_token_display(token) for token in st.session_state.tokens]
# 2. Interactive token display
st.subheader("Click on tokens to mask/unmask them:")
# Group tokens into rows (adjust number as needed)
tokens_per_row = 12
# Calculate how many rows we need
num_rows = (len(st.session_state.tokens) + tokens_per_row - 1) // tokens_per_row
for row in range(num_rows):
# Create columns for this row
start_idx = row * tokens_per_row
end_idx = min(start_idx + tokens_per_row, len(st.session_state.tokens))
row_tokens = st.session_state.tokens[start_idx:end_idx]
# Create equal-width columns
cols = st.columns(len(row_tokens))
for j, col in enumerate(cols):
idx = start_idx + j
token = st.session_state.tokens[idx]
# Skip special tokens for masking
is_special = token in [
tokenizer.cls_token,
tokenizer.sep_token,
tokenizer.pad_token
]
is_masked = idx in st.session_state.masked_indices
# Create a button for each token
button_key = f"token_{idx}"
button_label = sanitize_token_display(token) if not is_masked else "[MASK]"
if col.button(
button_label,
key=button_key,
disabled=is_special,
help=f"Token ID: {st.session_state.token_ids[idx]}"
):
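                # Flip this token's mask state and rerun so the button label updates immediately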
toggle_token(idx)
st.rerun()
# 3. Prediction area
if st.session_state.masked_indices:
results, reconstructed_text, prediction_tokens, original_token_ranks = get_predictions(model, tokenizer, device)
st.subheader("Predictions:")
st.markdown("**Reconstructed sentence with predictions:**")
# Create HTML for highlighting predictions
html = ""
# Use the original tokenization to match masked positions
for i, token in enumerate(st.session_state.tokens):
# Skip special tokens
if token in [tokenizer.cls_token, tokenizer.sep_token, tokenizer.pad_token]:
continue
if i in st.session_state.masked_indices:
# This was a masked token
original_token = sanitize_token_display(st.session_state.tokens[i])
predicted_token = prediction_tokens[i] # This is already sanitized in get_predictions
original_rank = original_token_ranks[i]
                # Color based on the original token's rank in the predictions
                if original_rank == 0:  # rank 0 means it was the top prediction
                    # Green: the model's top choice matches the original token
                    html += f'<span style="color:green">{predicted_token}</span> '
                elif original_rank != -1:
                    # Blue: the original token is in the top K, but not rank 1
                    html += f'<span style="color:blue">{predicted_token}</span> '
                else:
                    # Red: the original token is not in the top K predictions
                    html += f'<span style="color:red">{predicted_token}</span> '
else:
# Not a masked token, display normally
sanitized_token = sanitize_token_display(token)
html += f"{sanitized_token} "
html += "
"
# Display the highlighted text
st.markdown(html, unsafe_allow_html=True)
# Show detailed predictions
st.markdown("**Top predictions for each masked token:**")
for masked_idx in st.session_state.masked_indices:
original_token = st.session_state.tokens[masked_idx]
original_rank = original_token_ranks[masked_idx]
# Create a note about whether the original token was in top predictions
if original_rank == 0:
rank_note = "✅ Original token was the top prediction"
elif original_rank != -1:
rank_note = f"ℹ️ Original token was prediction #{original_rank+1}"
else:
rank_note = "❌ Original token not in top 5 predictions"
# Sanitize the token display
clean_original_token = sanitize_token_display(original_token)
st.markdown(f"**Token {clean_original_token} at position {masked_idx}** - {rank_note}")
# The dataframe is already sanitized in the get_predictions function
df = pd.DataFrame([r for r in results if r['Masked Index'] == masked_idx])
df = df[["Rank", "Predicted Token", "Probability"]]
            # Highlight the row containing the original token if it's in the top K
            if original_rank != -1:
                # Row-wise styling: shade the row whose Rank matches the original token's rank
                styled_df = df.style.apply(
                    lambda row: ['background-color: #c3e6cb'] * len(row)
                    if row["Rank"] == original_rank + 1 else [''] * len(row),
                    axis=1,
                )
                st.dataframe(styled_df, use_container_width=True)
else:
st.dataframe(df, use_container_width=True)
else:
st.info("Click on tokens above to mask them and see predictions.")
else:
st.warning("Please select a valid checkpoint to enable interactive masking.")
st.divider()
st.caption("Interactive app for RoBERTa Masked Language Modeling.")