kikikara committed on
Commit
234dafc
·
verified ·
1 Parent(s): 72b741e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -38
app.py CHANGED
@@ -3,7 +3,7 @@ import os
3
  import joblib
4
  import torch
5
  import numpy as np
6
- import html
7
  from transformers import AutoTokenizer, AutoModel, logging as hf_logging
8
  import pandas as pd
9
  import matplotlib
@@ -17,7 +17,7 @@ hf_logging.set_verbosity_error()
17
 
18
  MODEL_NAME = "bert-base-uncased"
19
  DEVICE = "cpu"
20
- SAVE_DIR = "μ €μž₯μ €μž₯1" # This folder name is from your setup
21
  LAYER_ID = 4
22
  SEED = 0
23
  CLF_NAME = "linear"
@@ -127,14 +127,14 @@ def create_empty_plotly_figure(message="N/A"):
127
  )
128
  return fig
129
 
130
- # --- Core Analysis Function (returns 7 items for Gradio UI) ---
131
  def analyze_sentence_for_gradio(sentence_text, top_k_value):
132
  if not MODELS_LOADED_SUCCESSFULLY:
133
- error_html = f"<p style='color:red;'>Initialization Error: {html.escape(MODEL_LOADING_ERROR_MESSAGE)}</p>"
134
  empty_df = pd.DataFrame(columns=['token', 'score'])
135
  empty_fig = create_empty_plotly_figure("Model Loading Failed")
136
  error_label_output = {"Status": "Error", "Message": "Model Loading Failed. Check logs."}
137
- return error_html, [], "Model Loading Failed", error_label_output, [], empty_df, empty_fig
138
 
139
  try:
140
  tokenizer, model = TOKENIZER_GLOBAL, MODEL_GLOBAL
@@ -147,7 +147,7 @@ def analyze_sentence_for_gradio(sentence_text, top_k_value):
147
  empty_df = pd.DataFrame(columns=['token', 'score'])
148
  empty_fig = create_empty_plotly_figure("Invalid Input")
149
  error_label_output = {"Status": "Error", "Message": "Invalid input, no valid tokens."}
150
- return "<p style='color:orange;'>Input Error: No valid tokens found.</p>", [], "Input Error", error_label_output, [], empty_df, empty_fig
151
 
152
  input_embeds_detached = model.embeddings.word_embeddings(input_ids).clone().detach()
153
  input_embeds_for_grad = input_embeds_detached.clone().requires_grad_(True)
@@ -167,7 +167,7 @@ def analyze_sentence_for_gradio(sentence_text, top_k_value):
167
  empty_df = pd.DataFrame(columns=['token', 'score'])
168
  empty_fig = create_empty_plotly_figure("Gradient Error")
169
  error_label_output = {"Status": "Error", "Message": "Gradient calculation failed."}
170
- return "<p style='color:red;'>Analysis Error: Gradient calculation failed.</p>", [],"Analysis Error", error_label_output, [], empty_df, empty_fig
171
 
172
  grads = input_embeds_for_grad.grad.clone().detach()
173
  scores = (grads * input_embeds_detached).norm(dim=2).squeeze(0)
@@ -180,7 +180,8 @@ def analyze_sentence_for_gradio(sentence_text, top_k_value):
180
  actual_scores_np = scores_np[:len(actual_tokens)]
181
  actual_input_embeds = input_embeds_detached[0, :len(actual_tokens), :].cpu().numpy()
182
 
183
- html_tokens_list, highlighted_text_data = [], []
 
184
  cls_token_id, sep_token_id = tokenizer.cls_token_id, tokenizer.sep_token_id
185
 
186
  for i, tok_str in enumerate(actual_tokens):
@@ -190,15 +191,10 @@ def analyze_sentence_for_gradio(sentence_text, top_k_value):
190
  current_token_id = input_ids[0, i].item()
191
 
192
  if current_token_id == cls_token_id or current_token_id == sep_token_id:
193
- html_tokens_list.append(f"<span style='font-weight:bold;'>{html.escape(clean_tok_str)}</span>")
194
  highlighted_text_data.append((clean_tok_str + " ", None))
195
  else:
196
- color = f"rgba(220, 50, 50, {current_score_clipped:.2f})"
197
- html_tokens_list.append(f"<span style='background-color:{color}; color:white; padding: 1px 3px; margin: 1px; border-radius: 4px; display:inline-block;'>{html.escape(clean_tok_str)}</span>")
198
  highlighted_text_data.append((clean_tok_str + " ", round(current_score_clipped, 3)))
199
 
200
- html_output_str = " ".join(html_tokens_list).replace(" ##", "")
201
-
202
  top_tokens_for_df, top_tokens_for_barplot_list = [], []
203
  valid_indices = [idx for idx, token_id in enumerate(input_ids[0,:len(actual_tokens)].tolist())
204
  if token_id not in [cls_token_id, sep_token_id]]
@@ -230,22 +226,22 @@ def analyze_sentence_for_gradio(sentence_text, top_k_value):
230
  token_embeddings_3d = pca.fit_transform(pca_embeddings)
231
  pca_fig = plot_token_pca_3d_plotly(token_embeddings_3d, pca_tokens, pca_scores_for_plot)
232
 
233
- return (html_output_str, highlighted_text_data,
234
  prediction_summary_text, prediction_details_for_label,
235
  top_tokens_for_df, barplot_df,
236
- pca_fig)
237
 
238
  except Exception as e:
239
  import traceback
240
  tb_str = traceback.format_exc()
241
- error_html = f"<p style='color:red;'>Analysis Error: {html.escape(str(e))}</p><pre>{html.escape(tb_str)}</pre>"
242
  print(f"analyze_sentence_for_gradio error: {e}\n{tb_str}")
243
  empty_df = pd.DataFrame(columns=['token', 'score'])
244
  empty_fig = create_empty_plotly_figure("Analysis Error")
245
  error_label_output = {"Status": "Error", "Message": f"Analysis failed: {str(e)}"}
246
- return error_html, [], "Analysis Failed", error_label_output, [], empty_df, empty_fig
247
 
248
- # --- Gradio UI Definition (Tabs removed, visualizations shown sequentially or in rows) ---
249
  theme = gr.themes.Monochrome(
250
  primary_hue=gr.themes.colors.blue,
251
  secondary_hue=gr.themes.colors.sky,
@@ -262,7 +258,6 @@ with gr.Blocks(title="AI Sentence Analyzer XAI πŸš€", theme=theme, css=".gradio-
262
  gr.Markdown("Analyze English sentences to understand BERT model predictions through various XAI visualization techniques. "
263
  "Explore token importance and their distribution in the embedding space.")
264
 
265
- # Inputs and Summary Outputs Row
266
  with gr.Row(equal_height=False):
267
  with gr.Column(scale=1, min_width=350):
268
  with gr.Group():
@@ -278,18 +273,13 @@ with gr.Blocks(title="AI Sentence Analyzer XAI πŸš€", theme=theme, css=".gradio-
278
  with gr.Accordion("⭐ Top-K Important Tokens (Table)", open=True):
279
  output_top_tokens_df = gr.DataFrame(headers=["Token", "Score"], label="Most Important Tokens",
280
  row_count=(1,"dynamic"), col_count=(2,"fixed"), interactive=False, wrap=True)
281
- gr.Markdown("---") # Separator
282
 
283
- # Visualization Section Title
284
  gr.Markdown("## πŸ“Š Detailed Visualizations")
 
 
285
 
286
- # HTML Highlight (Custom) - Full Width
287
- with gr.Group():
288
- gr.Markdown("### 🎨 HTML Highlight (Custom)")
289
- output_html_visualization = gr.HTML(label="Token Importance (Gradient x Input based)")
290
-
291
- # Highlighted Text (Gradio) - Full Width
292
- with gr.Group():
293
  gr.Markdown("### πŸ–οΈ Highlighted Text (Gradio)")
294
  output_highlighted_text = gr.HighlightedText(
295
  label="Token Importance (Score: 0-1)",
@@ -297,9 +287,8 @@ with gr.Blocks(title="AI Sentence Analyzer XAI πŸš€", theme=theme, css=".gradio-
297
  combine_adjacent=False
298
  )
299
 
300
- # BarPlot and PCA Plot Side-by-Side
301
- with gr.Row():
302
- with gr.Column(scale=1, min_width=400): # Adjusted min_width for BarPlot
303
  with gr.Group():
304
  gr.Markdown("### πŸ“Š Top-K Bar Plot")
305
  output_top_tokens_barplot = gr.BarPlot(
@@ -307,14 +296,14 @@ with gr.Blocks(title="AI Sentence Analyzer XAI πŸš€", theme=theme, css=".gradio-
307
  x="token",
308
  y="score",
309
  tooltip=['token', 'score'],
310
- min_width=300 # BarPlot itself can define min_width
311
  )
312
- with gr.Column(scale=1, min_width=400): # Adjusted min_width for PCA
313
  with gr.Group():
314
  gr.Markdown("### 🌐 Token Embeddings 3D PCA (Interactive)")
315
  output_pca_plot = gr.Plot(label="3D PCA of Token Embeddings (Colored by Importance Score)")
316
 
317
- gr.Markdown("---") # Separator
318
 
319
  gr.Examples(
320
  examples=[
@@ -323,8 +312,8 @@ with gr.Blocks(title="AI Sentence Analyzer XAI πŸš€", theme=theme, css=".gradio-
323
  ["I was thoroughly disappointed with the lackluster performance and predictable plot.", 4]
324
  ],
325
  inputs=[input_sentence, input_top_k],
326
- outputs=[
327
- output_html_visualization, output_highlighted_text,
328
  output_prediction_summary, output_prediction_details,
329
  output_top_tokens_df, output_top_tokens_barplot,
330
  output_pca_plot
@@ -337,8 +326,8 @@ with gr.Blocks(title="AI Sentence Analyzer XAI πŸš€", theme=theme, css=".gradio-
337
  submit_button.click(
338
  fn=analyze_sentence_for_gradio,
339
  inputs=[input_sentence, input_top_k],
340
- outputs=[
341
- output_html_visualization, output_highlighted_text,
342
  output_prediction_summary, output_prediction_details,
343
  output_top_tokens_df, output_top_tokens_barplot,
344
  output_pca_plot
 
3
  import joblib
4
  import torch
5
  import numpy as np
6
+ import html # Kept because html.escape may still be used when building highlighted_text_data
7
  from transformers import AutoTokenizer, AutoModel, logging as hf_logging
8
  import pandas as pd
9
  import matplotlib
 
17
 
18
  MODEL_NAME = "bert-base-uncased"
19
  DEVICE = "cpu"
20
+ SAVE_DIR = "μ €μž₯μ €μž₯1"
21
  LAYER_ID = 4
22
  SEED = 0
23
  CLF_NAME = "linear"
 
127
  )
128
  return fig
129
 
130
+ # --- Core Analysis Function (returns 6 items for Gradio UI) ---
131
  def analyze_sentence_for_gradio(sentence_text, top_k_value):
132
  if not MODELS_LOADED_SUCCESSFULLY:
133
+ # HTML output removed, adjust error return
134
  empty_df = pd.DataFrame(columns=['token', 'score'])
135
  empty_fig = create_empty_plotly_figure("Model Loading Failed")
136
  error_label_output = {"Status": "Error", "Message": "Model Loading Failed. Check logs."}
137
+ return [], "Model Loading Failed", error_label_output, [], empty_df, empty_fig # 6 items
138
 
139
  try:
140
  tokenizer, model = TOKENIZER_GLOBAL, MODEL_GLOBAL
 
147
  empty_df = pd.DataFrame(columns=['token', 'score'])
148
  empty_fig = create_empty_plotly_figure("Invalid Input")
149
  error_label_output = {"Status": "Error", "Message": "Invalid input, no valid tokens."}
150
+ return [], "Input Error", error_label_output, [], empty_df, empty_fig # 6 items
151
 
152
  input_embeds_detached = model.embeddings.word_embeddings(input_ids).clone().detach()
153
  input_embeds_for_grad = input_embeds_detached.clone().requires_grad_(True)
 
167
  empty_df = pd.DataFrame(columns=['token', 'score'])
168
  empty_fig = create_empty_plotly_figure("Gradient Error")
169
  error_label_output = {"Status": "Error", "Message": "Gradient calculation failed."}
170
+ return [],"Analysis Error", error_label_output, [], empty_df, empty_fig # 6 items
171
 
172
  grads = input_embeds_for_grad.grad.clone().detach()
173
  scores = (grads * input_embeds_detached).norm(dim=2).squeeze(0)
 
180
  actual_scores_np = scores_np[:len(actual_tokens)]
181
  actual_input_embeds = input_embeds_detached[0, :len(actual_tokens), :].cpu().numpy()
182
 
183
+ # HTML generation logic removed
184
+ highlighted_text_data = []
185
  cls_token_id, sep_token_id = tokenizer.cls_token_id, tokenizer.sep_token_id
186
 
187
  for i, tok_str in enumerate(actual_tokens):
 
191
  current_token_id = input_ids[0, i].item()
192
 
193
  if current_token_id == cls_token_id or current_token_id == sep_token_id:
 
194
  highlighted_text_data.append((clean_tok_str + " ", None))
195
  else:
 
 
196
  highlighted_text_data.append((clean_tok_str + " ", round(current_score_clipped, 3)))
197
 
 
 
198
  top_tokens_for_df, top_tokens_for_barplot_list = [], []
199
  valid_indices = [idx for idx, token_id in enumerate(input_ids[0,:len(actual_tokens)].tolist())
200
  if token_id not in [cls_token_id, sep_token_id]]
 
226
  token_embeddings_3d = pca.fit_transform(pca_embeddings)
227
  pca_fig = plot_token_pca_3d_plotly(token_embeddings_3d, pca_tokens, pca_scores_for_plot)
228
 
229
+ return (highlighted_text_data, # HTML output removed
230
  prediction_summary_text, prediction_details_for_label,
231
  top_tokens_for_df, barplot_df,
232
+ pca_fig) # 6 items
233
 
234
  except Exception as e:
235
  import traceback
236
  tb_str = traceback.format_exc()
237
+ # HTML output removed
238
  print(f"analyze_sentence_for_gradio error: {e}\n{tb_str}")
239
  empty_df = pd.DataFrame(columns=['token', 'score'])
240
  empty_fig = create_empty_plotly_figure("Analysis Error")
241
  error_label_output = {"Status": "Error", "Message": f"Analysis failed: {str(e)}"}
242
+ return [], "Analysis Failed", error_label_output, [], empty_df, empty_fig # 6 items
243
 
244
+ # --- Gradio UI Definition (HTML Highlight Tab removed) ---
245
  theme = gr.themes.Monochrome(
246
  primary_hue=gr.themes.colors.blue,
247
  secondary_hue=gr.themes.colors.sky,
 
258
  gr.Markdown("Analyze English sentences to understand BERT model predictions through various XAI visualization techniques. "
259
  "Explore token importance and their distribution in the embedding space.")
260
 
 
261
  with gr.Row(equal_height=False):
262
  with gr.Column(scale=1, min_width=350):
263
  with gr.Group():
 
273
  with gr.Accordion("⭐ Top-K Important Tokens (Table)", open=True):
274
  output_top_tokens_df = gr.DataFrame(headers=["Token", "Score"], label="Most Important Tokens",
275
  row_count=(1,"dynamic"), col_count=(2,"fixed"), interactive=False, wrap=True)
276
+ gr.Markdown("---")
277
 
 
278
  gr.Markdown("## πŸ“Š Detailed Visualizations")
279
+
280
+ # HTML Highlight (Custom) section removed
281
 
282
+ with gr.Group(): # HighlightedText
 
 
 
 
 
 
283
  gr.Markdown("### πŸ–οΈ Highlighted Text (Gradio)")
284
  output_highlighted_text = gr.HighlightedText(
285
  label="Token Importance (Score: 0-1)",
 
287
  combine_adjacent=False
288
  )
289
 
290
+ with gr.Row(): # BarPlot and PCA Plot Side-by-Side
291
+ with gr.Column(scale=1, min_width=400):
 
292
  with gr.Group():
293
  gr.Markdown("### πŸ“Š Top-K Bar Plot")
294
  output_top_tokens_barplot = gr.BarPlot(
 
296
  x="token",
297
  y="score",
298
  tooltip=['token', 'score'],
299
+ min_width=300
300
  )
301
+ with gr.Column(scale=1, min_width=400):
302
  with gr.Group():
303
  gr.Markdown("### 🌐 Token Embeddings 3D PCA (Interactive)")
304
  output_pca_plot = gr.Plot(label="3D PCA of Token Embeddings (Colored by Importance Score)")
305
 
306
+ gr.Markdown("---")
307
 
308
  gr.Examples(
309
  examples=[
 
312
  ["I was thoroughly disappointed with the lackluster performance and predictable plot.", 4]
313
  ],
314
  inputs=[input_sentence, input_top_k],
315
+ outputs=[ # output_html_visualization removed
316
+ output_highlighted_text,
317
  output_prediction_summary, output_prediction_details,
318
  output_top_tokens_df, output_top_tokens_barplot,
319
  output_pca_plot
 
326
  submit_button.click(
327
  fn=analyze_sentence_for_gradio,
328
  inputs=[input_sentence, input_top_k],
329
+ outputs=[ # output_html_visualization removed
330
+ output_highlighted_text,
331
  output_prediction_summary, output_prediction_details,
332
  output_top_tokens_df, output_top_tokens_barplot,
333
  output_pca_plot