dejanseo commited on
Commit
a08869d
·
verified ·
1 Parent(s): 6ce6815

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +357 -0
app.py ADDED
@@ -0,0 +1,357 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app_interactive.py
2
+ import streamlit as st
3
+ import torch
4
+ import random
5
+ import os
6
+ import pandas as pd
7
+ from transformers import RobertaForMaskedLM, PreTrainedTokenizerFast
8
+ import re
9
+
10
+ # --- Configuration ---
11
+ CHECKPOINT_BASE_DIR = "./checkpoints"
12
+ PRESET_SENTENCE = "The quick brown fox jumps over the lazy dog near the river bank."
13
+ TOP_K = 5
14
+
15
+ # --- Initialize Session State ---
16
+ if 'masked_indices' not in st.session_state:
17
+ st.session_state.masked_indices = set()
18
+ if 'tokens' not in st.session_state:
19
+ st.session_state.tokens = []
20
+ if 'token_ids' not in st.session_state:
21
+ st.session_state.token_ids = []
22
+ if 'input_sentence' not in st.session_state:
23
+ st.session_state.input_sentence = PRESET_SENTENCE
24
+ if 'display_tokens' not in st.session_state:
25
+ st.session_state.display_tokens = []
26
+
27
+ # --- Helper Functions ---
28
+ def sanitize_token_display(token):
29
+ """Clean up token display by removing special characters like Ġ."""
30
+ # Replace the 'Ġ' character with a more readable indicator
31
+ if isinstance(token, str) and token.startswith('Ġ'):
32
+ return token[1:] # Remove the Ġ character
33
+ # Handle other special tokens if needed
34
+ elif token in ['<s>', '</s>', '<pad>']:
35
+ return token
36
+ else:
37
+ return token
38
+
39
+ def find_checkpoints(base_dir):
40
+ """Finds valid checkpoint directories within the base directory."""
41
+ checkpoints = []
42
+ if not os.path.isdir(base_dir):
43
+ return checkpoints
44
+ for item in os.listdir(base_dir):
45
+ path = os.path.join(base_dir, item)
46
+ if os.path.isdir(path) and item.startswith("checkpoint-"):
47
+ if os.path.exists(os.path.join(path, "pytorch_model.bin")) or \
48
+ os.path.exists(os.path.join(path, "model.safetensors")):
49
+ checkpoints.append(item)
50
+ checkpoints.sort(key=lambda x: int(re.search(r'(\d+)', x).group(1)))
51
+ return checkpoints
52
+
53
+ @st.cache_resource
54
+ def load_model_and_tokenizer(checkpoint_name):
55
+ """Loads the model and tokenizer from the specified checkpoint directory name."""
56
+ checkpoint_path = os.path.join(CHECKPOINT_BASE_DIR, checkpoint_name)
57
+ if not os.path.isdir(checkpoint_path):
58
+ st.error(f"Checkpoint directory not found: {checkpoint_path}")
59
+ return None, None
60
+ try:
61
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
62
+ model = RobertaForMaskedLM.from_pretrained(checkpoint_path).to(device)
63
+ tokenizer = PreTrainedTokenizerFast.from_pretrained(checkpoint_path)
64
+ model.eval()
65
+ #st.success(f"Loaded {checkpoint_name} on {device}")
66
+ return model, tokenizer, device
67
+ except Exception as e:
68
+ st.error(f"Error loading {checkpoint_name}: {e}")
69
+ return None, None, None
70
+
71
+ def tokenize_text(text, tokenizer):
72
+ """Tokenize the input text and return tokens and their IDs."""
73
+ encoding = tokenizer(text, return_tensors="pt", add_special_tokens=True)
74
+ input_ids = encoding.input_ids[0].tolist()
75
+
76
+ # Get individual tokens
77
+ tokens = []
78
+ for id in input_ids:
79
+ token = tokenizer.convert_ids_to_tokens(id)
80
+ tokens.append(token)
81
+
82
+ return tokens, input_ids
83
+
84
+ def toggle_token(index):
85
+ """Toggle a token's masked status."""
86
+ if index in st.session_state.masked_indices:
87
+ st.session_state.masked_indices.remove(index)
88
+ else:
89
+ st.session_state.masked_indices.add(index)
90
+
91
+ def update_input_sentence():
92
+ """Update the input sentence and reset masked indices."""
93
+ st.session_state.input_sentence = st.session_state.input_text
94
+ st.session_state.masked_indices = set()
95
+
96
+ def get_predictions(model, tokenizer, device):
97
+ """Get predictions for masked tokens."""
98
+ if not st.session_state.masked_indices:
99
+ return None, None, None, None
100
+
101
+ # Create a copy of the token IDs
102
+ masked_input_ids = st.session_state.token_ids.copy()
103
+
104
+ # Apply masks
105
+ for idx in st.session_state.masked_indices:
106
+ masked_input_ids[idx] = tokenizer.mask_token_id
107
+
108
+ # Convert to tensor
109
+ masked_input_tensor = torch.tensor([masked_input_ids]).to(device)
110
+
111
+ # Get predictions
112
+ with torch.no_grad():
113
+ outputs = model(input_ids=masked_input_tensor)
114
+ logits = outputs.logits
115
+
116
+ results = []
117
+ top1_predictions = {}
118
+ prediction_tokens = {}
119
+ original_token_ranks = {}
120
+
121
+ for masked_index in st.session_state.masked_indices:
122
+ mask_logits = logits[0, masked_index, :]
123
+ probabilities = torch.softmax(mask_logits, dim=-1)
124
+ top_k_probs, top_k_indices = torch.topk(probabilities, TOP_K)
125
+
126
+ # Save top-1 prediction for reconstruction
127
+ top1_id = top_k_indices[0].item()
128
+ top1_predictions[masked_index] = top1_id
129
+
130
+ # Sanitize the token here
131
+ raw_token = tokenizer.convert_ids_to_tokens(top1_id)
132
+ prediction_tokens[masked_index] = sanitize_token_display(raw_token)
133
+
134
+ original_token = st.session_state.tokens[masked_index]
135
+ original_id = st.session_state.token_ids[masked_index]
136
+
137
+ # Check if original token is in top K predictions
138
+ original_token_in_top_k = False
139
+ original_token_rank = -1 # -1 means not in top K
140
+
141
+ for rank, token_id in enumerate(top_k_indices.tolist()):
142
+ predicted_token = tokenizer.convert_ids_to_tokens(token_id)
143
+ if predicted_token.lower() == original_token.lower() or token_id == original_id:
144
+ original_token_in_top_k = True
145
+ original_token_rank = rank
146
+ break
147
+
148
+ original_token_ranks[masked_index] = original_token_rank
149
+
150
+ for rank, (prob, token_id) in enumerate(zip(top_k_probs.tolist(), top_k_indices.tolist())):
151
+ predicted_token = tokenizer.convert_ids_to_tokens(token_id)
152
+ # Sanitize the predicted token for the results table
153
+ clean_predicted_token = sanitize_token_display(predicted_token)
154
+
155
+ # Case insensitive match
156
+ is_match = predicted_token.lower() == original_token.lower()
157
+ results.append({
158
+ "Masked Index": masked_index,
159
+ "Rank": rank + 1,
160
+ "Predicted Token": clean_predicted_token, # Use sanitized token
161
+ "Original Token": sanitize_token_display(original_token), # Sanitize original token
162
+ "Exact Match": is_match,
163
+ "Probability": f"{prob:.4f}"
164
+ })
165
+
166
+ # Reconstruct the sentence using top-1 predictions
167
+ reconstructed_ids = masked_input_ids.copy()
168
+ for idx in st.session_state.masked_indices:
169
+ reconstructed_ids[idx] = top1_predictions[idx]
170
+
171
+ reconstructed_text = tokenizer.decode(reconstructed_ids, skip_special_tokens=True)
172
+
173
+ return results, reconstructed_text, prediction_tokens, original_token_ranks
174
+
175
+ # --- Streamlit App Layout ---
176
+
177
+ st.set_page_config(layout="wide", page_title="Interactive MLM Inference")
178
+
179
+ # Custom CSS to prevent text wrapping in buttons
180
+ st.markdown("""
181
+ <style>
182
+ .stButton button {
183
+ white-space: nowrap;
184
+ overflow: hidden;
185
+ text-overflow: ellipsis;
186
+ min-width: 80px;
187
+ }
188
+ </style>
189
+ """, unsafe_allow_html=True)
190
+
191
+ st.title("🧪 Interactive MLM Inference")
192
+
193
+ # --- Checkpoint Selection ---
194
+ available_checkpoints = find_checkpoints(CHECKPOINT_BASE_DIR)
195
+
196
+ if not available_checkpoints:
197
+ st.error(f"No checkpoints found in '{CHECKPOINT_BASE_DIR}'. Please train a model first.")
198
+ st.stop()
199
+
200
+ selected_checkpoint = st.selectbox(
201
+ "Select Checkpoint:",
202
+ available_checkpoints,
203
+ index=len(available_checkpoints) - 1
204
+ )
205
+
206
+ # --- Load Model ---
207
+ if selected_checkpoint:
208
+ model, tokenizer, device = load_model_and_tokenizer(selected_checkpoint)
209
+ else:
210
+ model, tokenizer, device = None, None, None
211
+
212
+ # --- Interactive Inference Section ---
213
+ st.divider()
214
+ st.subheader("Interactive Token Masking")
215
+
216
+ # 1. Original text area
217
+ st.text_area(
218
+ "Input Sentence:",
219
+ value=st.session_state.input_sentence,
220
+ key="input_text",
221
+ on_change=update_input_sentence,
222
+ height=100
223
+ )
224
+
225
+ if model and tokenizer and device:
226
+ # Tokenize the input text
227
+ st.session_state.tokens, st.session_state.token_ids = tokenize_text(
228
+ st.session_state.input_sentence,
229
+ tokenizer
230
+ )
231
+
232
+ # Create sanitized display tokens
233
+ st.session_state.display_tokens = [sanitize_token_display(token) for token in st.session_state.tokens]
234
+
235
+ # 2. Interactive token display
236
+ st.subheader("Click on tokens to mask/unmask them:")
237
+
238
+ # Group tokens into rows (adjust number as needed)
239
+ tokens_per_row = 12
240
+
241
+ # Calculate how many rows we need
242
+ num_rows = (len(st.session_state.tokens) + tokens_per_row - 1) // tokens_per_row
243
+
244
+ for row in range(num_rows):
245
+ # Create columns for this row
246
+ start_idx = row * tokens_per_row
247
+ end_idx = min(start_idx + tokens_per_row, len(st.session_state.tokens))
248
+ row_tokens = st.session_state.tokens[start_idx:end_idx]
249
+
250
+ # Create equal-width columns
251
+ cols = st.columns(len(row_tokens))
252
+
253
+ for j, col in enumerate(cols):
254
+ idx = start_idx + j
255
+ token = st.session_state.tokens[idx]
256
+
257
+ # Skip special tokens for masking
258
+ is_special = token in [
259
+ tokenizer.cls_token,
260
+ tokenizer.sep_token,
261
+ tokenizer.pad_token
262
+ ]
263
+
264
+ is_masked = idx in st.session_state.masked_indices
265
+
266
+ # Create a button for each token
267
+ button_key = f"token_{idx}"
268
+ button_label = sanitize_token_display(token) if not is_masked else "[MASK]"
269
+
270
+ if col.button(
271
+ button_label,
272
+ key=button_key,
273
+ disabled=is_special,
274
+ help=f"Token ID: {st.session_state.token_ids[idx]}"
275
+ ):
276
+ toggle_token(idx)
277
+ st.rerun()
278
+
279
+ # 3. Prediction area
280
+ if st.session_state.masked_indices:
281
+ results, reconstructed_text, prediction_tokens, original_token_ranks = get_predictions(model, tokenizer, device)
282
+
283
+ st.subheader("Predictions:")
284
+ st.markdown("**Reconstructed sentence with predictions:**")
285
+
286
+ # Create HTML for highlighting predictions
287
+ html = "<div style='padding: 10px; border-radius: 5px; border: 1px solid #ccc;'>"
288
+
289
+ # Use the original tokenization to match masked positions
290
+ for i, token in enumerate(st.session_state.tokens):
291
+ # Skip special tokens
292
+ if token in [tokenizer.cls_token, tokenizer.sep_token, tokenizer.pad_token]:
293
+ continue
294
+
295
+ if i in st.session_state.masked_indices:
296
+ # This was a masked token
297
+ original_token = sanitize_token_display(st.session_state.tokens[i])
298
+ predicted_token = prediction_tokens[i] # This is already sanitized in get_predictions
299
+ original_rank = original_token_ranks[i]
300
+
301
+ # Color based on original token's rank in predictions
302
+ if original_rank == 0: # Rank 0 means it was the top prediction
303
+ # Green for top prediction (rank 1)
304
+ html += f"<span style='background-color: #c3e6cb; padding: 2px 4px; border-radius: 3px; margin: 0 2px;'>{predicted_token}</span>"
305
+ elif original_rank != -1: # In top 5 but not top
306
+ # Blue for in top 5 but not top
307
+ html += f"<span style='background-color: #b8daff; padding: 2px 4px; border-radius: 3px; margin: 0 2px;'>{predicted_token}</span>"
308
+ else: # Not in top 5
309
+ # Red for not in top 5
310
+ html += f"<span style='background-color: #f8d7da; padding: 2px 4px; border-radius: 3px; margin: 0 2px;'>{predicted_token}</span>"
311
+ else:
312
+ # Not a masked token, display normally
313
+ sanitized_token = sanitize_token_display(token)
314
+ html += f"{sanitized_token} "
315
+
316
+ html += "</div>"
317
+
318
+ # Display the highlighted text
319
+ st.markdown(html, unsafe_allow_html=True)
320
+
321
+ # Show detailed predictions
322
+ st.markdown("**Top predictions for each masked token:**")
323
+
324
+ for masked_idx in st.session_state.masked_indices:
325
+ original_token = st.session_state.tokens[masked_idx]
326
+ original_rank = original_token_ranks[masked_idx]
327
+
328
+ # Create a note about whether the original token was in top predictions
329
+ if original_rank == 0:
330
+ rank_note = "✅ Original token was the top prediction"
331
+ elif original_rank != -1:
332
+ rank_note = f"ℹ️ Original token was prediction #{original_rank+1}"
333
+ else:
334
+ rank_note = "❌ Original token not in top 5 predictions"
335
+
336
+ # Sanitize the token display
337
+ clean_original_token = sanitize_token_display(original_token)
338
+ st.markdown(f"**Token {clean_original_token} at position {masked_idx}** - {rank_note}")
339
+
340
+ # The dataframe is already sanitized in the get_predictions function
341
+ df = pd.DataFrame([r for r in results if r['Masked Index'] == masked_idx])
342
+ df = df[["Rank", "Predicted Token", "Probability"]]
343
+
344
+ # Highlight the row with the original token if it's in top 5
345
+ if original_rank != -1:
346
+ # Use pandas styler to highlight the row
347
+ styled_df = df.style.apply(lambda x: ['background-color: #c3e6cb' if i == original_rank else '' for i in range(len(x))], axis=0)
348
+ st.dataframe(styled_df, use_container_width=True)
349
+ else:
350
+ st.dataframe(df, use_container_width=True)
351
+ else:
352
+ st.info("Click on tokens above to mask them and see predictions.")
353
+ else:
354
+ st.warning("Please select a valid checkpoint to enable interactive masking.")
355
+
356
+ st.divider()
357
+ st.caption("Interactive app for RoBERTa Masked Language Modeling.")