"""Interactive Streamlit app for masked-language-model (MLM) inference with a
RoBERTa checkpoint: click tokens to mask them and inspect the model's top-k
predictions."""

import os
import re

import pandas as pd
import streamlit as st
import torch
from transformers import RobertaForMaskedLM, PreTrainedTokenizerFast

CHECKPOINT_BASE_DIR = "./checkpoints"
PRESET_SENTENCE = "The quick brown fox jumps over the lazy dog near the river bank."
TOP_K = 5

# Initialize session state so values persist across Streamlit reruns.
if 'masked_indices' not in st.session_state:
    st.session_state.masked_indices = set()
if 'tokens' not in st.session_state:
    st.session_state.tokens = []
if 'token_ids' not in st.session_state:
    st.session_state.token_ids = []
if 'input_sentence' not in st.session_state:
    st.session_state.input_sentence = PRESET_SENTENCE
if 'display_tokens' not in st.session_state:
    st.session_state.display_tokens = []


def sanitize_token_display(token):
    """Clean up a token for display by stripping the byte-level BPE space marker 'Ġ'.

    Special tokens such as <s>, </s>, and <pad> pass through unchanged.
    """
    if isinstance(token, str) and token.startswith('Ġ'):
        return token[1:]
    return token

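
# For example, sanitize_token_display('Ġfox') returns 'fox', while special
# tokens like '<s>' are returned as-is.

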
def find_checkpoints(base_dir):
    """Find valid checkpoint directories within the base directory."""
    checkpoints = []
    if not os.path.isdir(base_dir):
        return checkpoints
    for item in os.listdir(base_dir):
        path = os.path.join(base_dir, item)
        if os.path.isdir(path) and item.startswith("checkpoint-"):
            # Accept either the legacy .bin or the newer safetensors weight file.
            if os.path.exists(os.path.join(path, "pytorch_model.bin")) or \
               os.path.exists(os.path.join(path, "model.safetensors")):
                checkpoints.append(item)
    # Sort numerically by training step, e.g. checkpoint-500 before checkpoint-1000.
    checkpoints.sort(key=lambda x: int(re.search(r'(\d+)', x).group(1)))
    return checkpoints

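
# Expected on-disk layout (assuming Hugging Face Trainer-style output):
#   ./checkpoints/checkpoint-500/model.safetensors
#   ./checkpoints/checkpoint-1000/pytorch_model.bin

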
@st.cache_resource
def load_model_and_tokenizer(checkpoint_name):
    """Load the model, tokenizer, and device for the given checkpoint directory name."""
    checkpoint_path = os.path.join(CHECKPOINT_BASE_DIR, checkpoint_name)
    if not os.path.isdir(checkpoint_path):
        st.error(f"Checkpoint directory not found: {checkpoint_path}")
        return None, None, None
    try:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = RobertaForMaskedLM.from_pretrained(checkpoint_path).to(device)
        tokenizer = PreTrainedTokenizerFast.from_pretrained(checkpoint_path)
        model.eval()
        return model, tokenizer, device
    except Exception as e:
        st.error(f"Error loading {checkpoint_name}: {e}")
        return None, None, None

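
# st.cache_resource memoizes the loaded model per checkpoint name, so switching
# back to a previously selected checkpoint does not reload weights from disk.

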
def tokenize_text(text, tokenizer):
    """Tokenize the input text and return the tokens and their IDs."""
    encoding = tokenizer(text, return_tensors="pt", add_special_tokens=True)
    input_ids = encoding.input_ids[0].tolist()
    tokens = tokenizer.convert_ids_to_tokens(input_ids)
    return tokens, input_ids

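
# Example (assuming a byte-level BPE vocabulary): tokenize_text("The fox", tokenizer)
# might yield ['<s>', 'The', 'Ġfox', '</s>'] together with the four matching IDs.

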
def toggle_token(index):
    """Toggle a token's masked status."""
    if index in st.session_state.masked_indices:
        st.session_state.masked_indices.remove(index)
    else:
        st.session_state.masked_indices.add(index)


def update_input_sentence():
    """Update the input sentence and reset the masked indices."""
    st.session_state.input_sentence = st.session_state.input_text
    st.session_state.masked_indices = set()


def get_predictions(model, tokenizer, device):
    """Run the model on the masked input and collect top-k predictions."""
    if not st.session_state.masked_indices:
        return None, None, None, None

    # Replace every selected position with the tokenizer's mask token.
    masked_input_ids = st.session_state.token_ids.copy()
    for idx in st.session_state.masked_indices:
        masked_input_ids[idx] = tokenizer.mask_token_id

    masked_input_tensor = torch.tensor([masked_input_ids]).to(device)

    with torch.no_grad():
        outputs = model(input_ids=masked_input_tensor)
        logits = outputs.logits

    results = []
    top1_predictions = {}
    prediction_tokens = {}
    original_token_ranks = {}

    for masked_index in st.session_state.masked_indices:
        mask_logits = logits[0, masked_index, :]
        probabilities = torch.softmax(mask_logits, dim=-1)
        top_k_probs, top_k_indices = torch.topk(probabilities, TOP_K)

        top1_id = top_k_indices[0].item()
        top1_predictions[masked_index] = top1_id

        raw_token = tokenizer.convert_ids_to_tokens(top1_id)
        prediction_tokens[masked_index] = sanitize_token_display(raw_token)

        original_token = st.session_state.tokens[masked_index]
        original_id = st.session_state.token_ids[masked_index]

        # Find where (if anywhere) the original token lands in the top-k list.
        original_token_rank = -1
        for rank, token_id in enumerate(top_k_indices.tolist()):
            predicted_token = tokenizer.convert_ids_to_tokens(token_id)
            if predicted_token.lower() == original_token.lower() or token_id == original_id:
                original_token_rank = rank
                break

        original_token_ranks[masked_index] = original_token_rank

        for rank, (prob, token_id) in enumerate(zip(top_k_probs.tolist(), top_k_indices.tolist())):
            predicted_token = tokenizer.convert_ids_to_tokens(token_id)
            clean_predicted_token = sanitize_token_display(predicted_token)
            is_match = predicted_token.lower() == original_token.lower()
            results.append({
                "Masked Index": masked_index,
                "Rank": rank + 1,
                "Predicted Token": clean_predicted_token,
                "Original Token": sanitize_token_display(original_token),
                "Exact Match": is_match,
                "Probability": f"{prob:.4f}"
            })

    # Rebuild the sentence with the top-1 prediction substituted at each mask.
    reconstructed_ids = masked_input_ids.copy()
    for idx in st.session_state.masked_indices:
        reconstructed_ids[idx] = top1_predictions[idx]
    reconstructed_text = tokenizer.decode(reconstructed_ids, skip_special_tokens=True)

    return results, reconstructed_text, prediction_tokens, original_token_ranks

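
# Example of the prediction flow (token rankings here are hypothetical):
# masking 'fox' in the preset sentence feeds "...quick brown <mask> jumps..."
# to the model, and the top-5 softmax entries at the mask position might be
# 'fox', 'dog', 'cat', 'wolf', 'rabbit'.

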
st.set_page_config(layout="wide", page_title="Interactive MLM Inference")

# Keep token buttons compact so long tokens don't distort the grid layout.
st.markdown("""
<style>
.stButton button {
    white-space: nowrap;
    overflow: hidden;
    text-overflow: ellipsis;
    min-width: 80px;
}
</style>
""", unsafe_allow_html=True)

st.title("🧪 Interactive MLM Inference")

available_checkpoints = find_checkpoints(CHECKPOINT_BASE_DIR)

if not available_checkpoints:
    st.error(f"No checkpoints found in '{CHECKPOINT_BASE_DIR}'. Please train a model first.")
    st.stop()

# Default to the most recent (highest-step) checkpoint.
selected_checkpoint = st.selectbox(
    "Select Checkpoint:",
    available_checkpoints,
    index=len(available_checkpoints) - 1
)

if selected_checkpoint:
    model, tokenizer, device = load_model_and_tokenizer(selected_checkpoint)
else:
    model, tokenizer, device = None, None, None

st.divider()
st.subheader("Interactive Token Masking")

st.text_area(
    "Input Sentence:",
    value=st.session_state.input_sentence,
    key="input_text",
    on_change=update_input_sentence,
    height=100
)

if model and tokenizer and device:
    # Re-tokenize on every rerun so edits to the sentence are picked up.
    st.session_state.tokens, st.session_state.token_ids = tokenize_text(
        st.session_state.input_sentence,
        tokenizer
    )
    st.session_state.display_tokens = [sanitize_token_display(token) for token in st.session_state.tokens]

    st.subheader("Click on tokens to mask/unmask them:")

    # Lay the tokens out as a grid of buttons, twelve per row.
    tokens_per_row = 12
    num_rows = (len(st.session_state.tokens) + tokens_per_row - 1) // tokens_per_row

    for row in range(num_rows):
        start_idx = row * tokens_per_row
        end_idx = min(start_idx + tokens_per_row, len(st.session_state.tokens))
        row_tokens = st.session_state.tokens[start_idx:end_idx]

        cols = st.columns(len(row_tokens))
        for j, col in enumerate(cols):
            idx = start_idx + j
            token = st.session_state.tokens[idx]

            # Special tokens (<s>, </s>, <pad>) cannot be masked.
            is_special = token in [
                tokenizer.cls_token,
                tokenizer.sep_token,
                tokenizer.pad_token
            ]
            is_masked = idx in st.session_state.masked_indices

            button_label = "[MASK]" if is_masked else sanitize_token_display(token)
            if col.button(
                button_label,
                key=f"token_{idx}",
                disabled=is_special,
                help=f"Token ID: {st.session_state.token_ids[idx]}"
            ):
                toggle_token(idx)
                st.rerun()

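    # At this point st.session_state.masked_indices holds the user's current
    # mask selection; compute and render predictions whenever it is non-empty.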
    if st.session_state.masked_indices:
        results, reconstructed_text, prediction_tokens, original_token_ranks = get_predictions(model, tokenizer, device)

        st.subheader("Predictions:")
        st.markdown("**Reconstructed sentence with predictions:**")

        # Color-code each filled mask: green when the original token was the
        # top prediction, blue when it appeared elsewhere in the top k, and
        # red when it missed the top k entirely.
        html = "<div style='padding: 10px; border-radius: 5px; border: 1px solid #ccc;'>"
        for i, token in enumerate(st.session_state.tokens):
            if token in [tokenizer.cls_token, tokenizer.sep_token, tokenizer.pad_token]:
                continue
            if i in st.session_state.masked_indices:
                predicted_token = prediction_tokens[i]
                original_rank = original_token_ranks[i]

                if original_rank == 0:
                    html += f"<span style='background-color: #c3e6cb; padding: 2px 4px; border-radius: 3px; margin: 0 2px;'>{predicted_token}</span>"
                elif original_rank != -1:
                    html += f"<span style='background-color: #b8daff; padding: 2px 4px; border-radius: 3px; margin: 0 2px;'>{predicted_token}</span>"
                else:
                    html += f"<span style='background-color: #f8d7da; padding: 2px 4px; border-radius: 3px; margin: 0 2px;'>{predicted_token}</span>"
            else:
                html += f"{sanitize_token_display(token)} "
        html += "</div>"

        st.markdown(html, unsafe_allow_html=True)

st.markdown("**Top predictions for each masked token:**")
|
|
|
|
for masked_idx in st.session_state.masked_indices:
|
|
original_token = st.session_state.tokens[masked_idx]
|
|
original_rank = original_token_ranks[masked_idx]
|
|
|
|
|
|
if original_rank == 0:
|
|
rank_note = "✅ Original token was the top prediction"
|
|
elif original_rank != -1:
|
|
rank_note = f"ℹ️ Original token was prediction #{original_rank+1}"
|
|
else:
|
|
rank_note = "❌ Original token not in top 5 predictions"
|
|
|
|
|
|
clean_original_token = sanitize_token_display(original_token)
|
|
st.markdown(f"**Token {clean_original_token} at position {masked_idx}** - {rank_note}")
|
|
|
|
|
|
df = pd.DataFrame([r for r in results if r['Masked Index'] == masked_idx])
|
|
df = df[["Rank", "Predicted Token", "Probability"]]
|
|
|
|
|
|
if original_rank != -1:
|
|
|
|
styled_df = df.style.apply(lambda x: ['background-color: #c3e6cb' if i == original_rank else '' for i in range(len(x))], axis=0)
|
|
st.dataframe(styled_df, use_container_width=True)
|
|
else:
|
|
st.dataframe(df, use_container_width=True)
|
|
else:
|
|
st.info("Click on tokens above to mask them and see predictions.")
|
|
else:
    st.warning("Please select a valid checkpoint to enable interactive masking.")

st.divider()
st.caption("Interactive app for RoBERTa Masked Language Modeling.")
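
# Launch with `streamlit run <this_script>.py`; checkpoints are discovered
# under CHECKPOINT_BASE_DIR when the app starts.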