kikikara committed on
Commit
9a618da
·
verified ·
1 Parent(s): c9044fd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -36
app.py CHANGED
@@ -17,12 +17,11 @@ hf_logging.set_verbosity_error()
17
 
18
  MODEL_NAME = "bert-base-uncased"
19
  DEVICE = "cpu"
20
- SAVE_DIR = "์ €์žฅ์ €์žฅ1" # This folder name is from your setup
21
  LAYER_ID = 4
22
  SEED = 0
23
  CLF_NAME = "linear"
24
 
25
- # Class label mapping provided by user
26
  CLASS_LABEL_MAP = {
27
  0: "World",
28
  1: "Sports",
@@ -32,7 +31,6 @@ CLASS_LABEL_MAP = {
32
 
33
  TOKENIZER_GLOBAL, MODEL_GLOBAL = None, None
34
  W_GLOBAL, MU_GLOBAL, W_P_GLOBAL, B_P_GLOBAL = None, None, None, None
35
- # CLASS_NAMES_GLOBAL = None # We'll use CLASS_LABEL_MAP instead for clarity
36
  MODELS_LOADED_SUCCESSFULLY = False
37
  MODEL_LOADING_ERROR_MESSAGE = ""
38
 
@@ -73,24 +71,20 @@ except Exception as e:
73
 
74
  # Helper function: 3D PCA Visualization using Plotly
75
  def plot_token_pca_3d_plotly(token_embeddings_3d, tokens, scores, title="Token Embeddings 3D PCA (Colored by Importance)"):
76
- num_annotations = min(len(tokens), 20) # Annotate up to 20 most important tokens
77
-
78
- # Ensure scores is a 1D numpy array for Plotly marker color processing
79
  scores_array = np.array(scores).flatten()
80
-
81
- # Prepare text annotations (only for most important tokens to avoid clutter)
82
  text_annotations = [''] * len(tokens)
83
  if len(scores_array) > 0 and len(tokens) > 0:
84
  indices_to_annotate = np.argsort(scores_array)[-num_annotations:]
85
  for i in indices_to_annotate:
86
- if i < len(tokens): # Ensure index is valid
87
  text_annotations[i] = tokens[i]
88
 
89
  fig = go.Figure(data=[go.Scatter3d(
90
  x=token_embeddings_3d[:, 0],
91
  y=token_embeddings_3d[:, 1],
92
  z=token_embeddings_3d[:, 2],
93
- mode='markers+text', # Show markers, text for selected
94
  text=text_annotations,
95
  textfont=dict(size=9, color='#333333'),
96
  textposition='top center',
@@ -98,25 +92,26 @@ def plot_token_pca_3d_plotly(token_embeddings_3d, tokens, scores, title="Token E
98
  size=6,
99
  color=scores_array,
100
  colorscale='RdBu',
101
- reversescale=True, # Makes red high, blue low (like coolwarm_r)
102
  opacity=0.8,
103
  colorbar=dict(title='Importance', tickfont=dict(size=9), len=0.75, yanchor='middle')
104
  ),
105
- hoverinfo='text', # Show full token text on hover
106
- hovertext=[f"Token: {t}<br>Score: {s:.3f}" for t, s in zip(tokens, scores_array)] # Custom hover text
107
  )])
108
 
109
  fig.update_layout(
110
  title=dict(text=title, x=0.5, font=dict(size=16)),
111
  scene=dict(
112
- xaxis=dict(title='PCA Comp 1', titlefont=dict(size=10), tickfont=dict(size=9), backgroundcolor="rgba(230, 230, 230, 0.8)"),
113
- yaxis=dict(title='PCA Comp 2', titlefont=dict(size=10), tickfont=dict(size=9), backgroundcolor="rgba(230, 230, 230, 0.8)"),
114
- zaxis=dict(title='PCA Comp 3', titlefont=dict(size=10), tickfont=dict(size=9), backgroundcolor="rgba(230, 230, 230, 0.8)"),
 
115
  bgcolor="rgba(255, 255, 255, 0.95)",
116
- camera_eye=dict(x=1.5, y=1.5, z=0.5) # Initial camera angle
117
  ),
118
  margin=dict(l=5, r=5, b=5, t=45),
119
- paper_bgcolor='rgba(0,0,0,0)' # Transparent paper background
120
  )
121
  return fig
122
 
@@ -127,7 +122,7 @@ def create_empty_plotly_figure(message="N/A"):
127
  fig.update_layout(
128
  xaxis={'visible': False},
129
  yaxis={'visible': False},
130
- height=300, # Define a height for empty plot
131
  paper_bgcolor='rgba(0,0,0,0)',
132
  plot_bgcolor='rgba(0,0,0,0)'
133
  )
@@ -139,7 +134,8 @@ def analyze_sentence_for_gradio(sentence_text, top_k_value):
139
  error_html = f"<p style='color:red;'>Initialization Error: {html.escape(MODEL_LOADING_ERROR_MESSAGE)}</p>"
140
  empty_df = pd.DataFrame(columns=['token', 'score'])
141
  empty_fig = create_empty_plotly_figure("Model Loading Failed")
142
- return error_html, [], "Model Loading Failed", {"Error":"Model Loading Failed"}, [], empty_df, empty_fig
 
143
 
144
  try:
145
  tokenizer, model = TOKENIZER_GLOBAL, MODEL_GLOBAL
@@ -151,7 +147,7 @@ def analyze_sentence_for_gradio(sentence_text, top_k_value):
151
  if input_ids.shape[1] == 0:
152
  empty_df = pd.DataFrame(columns=['token', 'score'])
153
  empty_fig = create_empty_plotly_figure("Invalid Input")
154
- return "<p style='color:orange;'>Input Error: No valid tokens found.</p>", [], "Input Error", {"Error":"Input Error"}, [], empty_df, empty_fig
155
 
156
  input_embeds_detached = model.embeddings.word_embeddings(input_ids).clone().detach()
157
  input_embeds_for_grad = input_embeds_detached.clone().requires_grad_(True)
@@ -170,12 +166,12 @@ def analyze_sentence_for_gradio(sentence_text, top_k_value):
170
  if input_embeds_for_grad.grad is None:
171
  empty_df = pd.DataFrame(columns=['token', 'score'])
172
  empty_fig = create_empty_plotly_figure("Gradient Error")
173
- return "<p style='color:red;'>Analysis Error: Gradient calculation failed.</p>", [],"Analysis Error", {"Error":"Analysis Error"}, [], empty_df, empty_fig
174
 
175
  grads = input_embeds_for_grad.grad.clone().detach()
176
  scores = (grads * input_embeds_detached).norm(dim=2).squeeze(0)
177
  scores_np = scores.cpu().numpy()
178
- valid_scores_for_norm = scores_np[np.isfinite(scores_np)] # Renamed to avoid conflict
179
  scores_np = scores_np / (valid_scores_for_norm.max() + 1e-9) if len(valid_scores_for_norm) > 0 and valid_scores_for_norm.max() > 0 else np.zeros_like(scores_np)
180
 
181
  tokens_raw = tokenizer.convert_ids_to_tokens(input_ids[0], skip_special_tokens=False)
@@ -196,7 +192,7 @@ def analyze_sentence_for_gradio(sentence_text, top_k_value):
196
  html_tokens_list.append(f"<span style='font-weight:bold;'>{html.escape(clean_tok_str)}</span>")
197
  highlighted_text_data.append((clean_tok_str + " ", None))
198
  else:
199
- color = f"rgba(220, 50, 50, {current_score_clipped:.2f})" # Slightly adjusted red
200
  html_tokens_list.append(f"<span style='background-color:{color}; color:white; padding: 1px 3px; margin: 1px; border-radius: 4px; display:inline-block;'>{html.escape(clean_tok_str)}</span>")
201
  highlighted_text_data.append((clean_tok_str + " ", round(current_score_clipped, 3)))
202
 
@@ -227,7 +223,7 @@ def analyze_sentence_for_gradio(sentence_text, top_k_value):
227
  pca_tokens = [actual_tokens[i] for i in non_special_token_indices]
228
  if len(pca_tokens) > 0:
229
  pca_embeddings = actual_input_embeds[non_special_token_indices, :]
230
- pca_scores_for_plot = actual_scores_np[non_special_token_indices] # Use this for coloring
231
 
232
  pca = PCA(n_components=3, random_state=SEED)
233
  token_embeddings_3d = pca.fit_transform(pca_embeddings)
@@ -245,10 +241,10 @@ def analyze_sentence_for_gradio(sentence_text, top_k_value):
245
  print(f"analyze_sentence_for_gradio error: {e}\n{tb_str}")
246
  empty_df = pd.DataFrame(columns=['token', 'score'])
247
  empty_fig = create_empty_plotly_figure("Analysis Error")
248
- return error_html, [], "Analysis Failed", {"Error": str(e)}, [], empty_df, empty_fig
 
249
 
250
  # --- Gradio UI Definition (Translated and Enhanced) ---
251
- # Using a built-in theme and some CSS for aesthetics
252
  theme = gr.themes.Monochrome(
253
  primary_hue=gr.themes.colors.blue,
254
  secondary_hue=gr.themes.colors.sky,
@@ -260,14 +256,13 @@ theme = gr.themes.Monochrome(
260
  button_primary_text_color="white",
261
  )
262
 
263
-
264
  with gr.Blocks(title="AI Sentence Analyzer XAI ๐Ÿš€", theme=theme, css=".gradio-container {max-width: 98% !important;}") as demo:
265
  gr.Markdown("# ๐Ÿš€ AI Sentence Analyzer XAI: Exploring Model Explanations")
266
  gr.Markdown("Analyze English sentences to understand BERT model predictions through various XAI visualization techniques. "
267
  "Explore token importance and their distribution in the embedding space.")
268
 
269
  with gr.Row(equal_height=False):
270
- with gr.Column(scale=1, min_width=350): # Increased min_width slightly
271
  with gr.Group():
272
  gr.Markdown("### โœ๏ธ Input Sentence & Settings")
273
  input_sentence = gr.Textbox(lines=5, label="English Sentence to Analyze", placeholder="Enter the English sentence you want to analyze here...")
@@ -289,10 +284,6 @@ with gr.Blocks(title="AI Sentence Analyzer XAI ๐Ÿš€", theme=theme, css=".gradio-
289
  output_highlighted_text = gr.HighlightedText(
290
  label="Token Importance (Score: 0-1)",
291
  show_legend=True,
292
- # Color map can be more sophisticated if scores are categorical
293
- # For numerical scores (0-1), Gradio tries to infer intensity.
294
- # Example color map (if scores were categories like "LOW", "MEDIUM", "HIGH"):
295
- # color_map={"LOW": "lightblue", "MEDIUM": "lightgreen", "HIGH": "pink"},
296
  combine_adjacent=False
297
  )
298
  with gr.TabItem("๐Ÿ“Š Top-K Bar Plot", id=2):
@@ -301,8 +292,7 @@ with gr.Blocks(title="AI Sentence Analyzer XAI ๐Ÿš€", theme=theme, css=".gradio-
301
  x="token",
302
  y="score",
303
  tooltip=['token', 'score'],
304
- min_width=300,
305
- # title="Top-K Most Important Tokens" # BarPlot may not have a direct title prop
306
  )
307
  with gr.TabItem("๐ŸŒ Token Embeddings 3D PCA (Interactive)", id=3):
308
  output_pca_plot = gr.Plot(label="3D PCA of Token Embeddings (Colored by Importance Score)")
@@ -322,7 +312,7 @@ with gr.Blocks(title="AI Sentence Analyzer XAI ๐Ÿš€", theme=theme, css=".gradio-
322
  output_pca_plot
323
  ],
324
  fn=analyze_sentence_for_gradio,
325
- cache_examples=False # Set to True for faster loading of examples if inputs/outputs are static
326
  )
327
  gr.HTML("<p style='text-align: center; color: #4a5568;'>Explainable AI Demo powered by Gradio & Hugging Face Transformers</p>")
328
 
 
17
 
18
  MODEL_NAME = "bert-base-uncased"
19
  DEVICE = "cpu"
20
+ SAVE_DIR = "์ €์žฅ์ €์žฅ1"
21
  LAYER_ID = 4
22
  SEED = 0
23
  CLF_NAME = "linear"
24
 
 
25
  CLASS_LABEL_MAP = {
26
  0: "World",
27
  1: "Sports",
 
31
 
32
  TOKENIZER_GLOBAL, MODEL_GLOBAL = None, None
33
  W_GLOBAL, MU_GLOBAL, W_P_GLOBAL, B_P_GLOBAL = None, None, None, None
 
34
  MODELS_LOADED_SUCCESSFULLY = False
35
  MODEL_LOADING_ERROR_MESSAGE = ""
36
 
 
71
 
72
  # Helper function: 3D PCA Visualization using Plotly
73
  def plot_token_pca_3d_plotly(token_embeddings_3d, tokens, scores, title="Token Embeddings 3D PCA (Colored by Importance)"):
74
+ num_annotations = min(len(tokens), 20)
 
 
75
  scores_array = np.array(scores).flatten()
 
 
76
  text_annotations = [''] * len(tokens)
77
  if len(scores_array) > 0 and len(tokens) > 0:
78
  indices_to_annotate = np.argsort(scores_array)[-num_annotations:]
79
  for i in indices_to_annotate:
80
+ if i < len(tokens):
81
  text_annotations[i] = tokens[i]
82
 
83
  fig = go.Figure(data=[go.Scatter3d(
84
  x=token_embeddings_3d[:, 0],
85
  y=token_embeddings_3d[:, 1],
86
  z=token_embeddings_3d[:, 2],
87
+ mode='markers+text',
88
  text=text_annotations,
89
  textfont=dict(size=9, color='#333333'),
90
  textposition='top center',
 
92
  size=6,
93
  color=scores_array,
94
  colorscale='RdBu',
95
+ reversescale=True,
96
  opacity=0.8,
97
  colorbar=dict(title='Importance', tickfont=dict(size=9), len=0.75, yanchor='middle')
98
  ),
99
+ hoverinfo='text',
100
+ hovertext=[f"Token: {t}<br>Score: {s:.3f}" for t, s in zip(tokens, scores_array)]
101
  )])
102
 
103
  fig.update_layout(
104
  title=dict(text=title, x=0.5, font=dict(size=16)),
105
  scene=dict(
106
+ # ์ˆ˜์ •๋œ ๋ถ€๋ถ„: title ์†์„ฑ ๋‚ด์— text์™€ font๋ฅผ ํฌํ•จ
107
+ xaxis=dict(title=dict(text='PCA Comp 1', font=dict(size=10)), tickfont=dict(size=9), backgroundcolor="rgba(230, 230, 230, 0.8)"),
108
+ yaxis=dict(title=dict(text='PCA Comp 2', font=dict(size=10)), tickfont=dict(size=9), backgroundcolor="rgba(230, 230, 230, 0.8)"),
109
+ zaxis=dict(title=dict(text='PCA Comp 3', font=dict(size=10)), tickfont=dict(size=9), backgroundcolor="rgba(230, 230, 230, 0.8)"),
110
  bgcolor="rgba(255, 255, 255, 0.95)",
111
+ camera_eye=dict(x=1.5, y=1.5, z=0.5)
112
  ),
113
  margin=dict(l=5, r=5, b=5, t=45),
114
+ paper_bgcolor='rgba(0,0,0,0)'
115
  )
116
  return fig
117
 
 
122
  fig.update_layout(
123
  xaxis={'visible': False},
124
  yaxis={'visible': False},
125
+ height=300,
126
  paper_bgcolor='rgba(0,0,0,0)',
127
  plot_bgcolor='rgba(0,0,0,0)'
128
  )
 
134
  error_html = f"<p style='color:red;'>Initialization Error: {html.escape(MODEL_LOADING_ERROR_MESSAGE)}</p>"
135
  empty_df = pd.DataFrame(columns=['token', 'score'])
136
  empty_fig = create_empty_plotly_figure("Model Loading Failed")
137
+ # gr.Label์— ๋Œ€ํ•œ ์˜ค๋ฅ˜ ๋ฐ˜ํ™˜๊ฐ’ ์ˆ˜์ •
138
+ return error_html, [], "Model Loading Failed", {"Status":"Error", "Message":"Model Loading Failed"}, [], empty_df, empty_fig
139
 
140
  try:
141
  tokenizer, model = TOKENIZER_GLOBAL, MODEL_GLOBAL
 
147
  if input_ids.shape[1] == 0:
148
  empty_df = pd.DataFrame(columns=['token', 'score'])
149
  empty_fig = create_empty_plotly_figure("Invalid Input")
150
+ return "<p style='color:orange;'>Input Error: No valid tokens found.</p>", [], "Input Error", {"Status":"Error", "Message":"Invalid Input"}, [], empty_df, empty_fig
151
 
152
  input_embeds_detached = model.embeddings.word_embeddings(input_ids).clone().detach()
153
  input_embeds_for_grad = input_embeds_detached.clone().requires_grad_(True)
 
166
  if input_embeds_for_grad.grad is None:
167
  empty_df = pd.DataFrame(columns=['token', 'score'])
168
  empty_fig = create_empty_plotly_figure("Gradient Error")
169
+ return "<p style='color:red;'>Analysis Error: Gradient calculation failed.</p>", [],"Analysis Error", {"Status":"Error", "Message":"Gradient Error"}, [], empty_df, empty_fig
170
 
171
  grads = input_embeds_for_grad.grad.clone().detach()
172
  scores = (grads * input_embeds_detached).norm(dim=2).squeeze(0)
173
  scores_np = scores.cpu().numpy()
174
+ valid_scores_for_norm = scores_np[np.isfinite(scores_np)]
175
  scores_np = scores_np / (valid_scores_for_norm.max() + 1e-9) if len(valid_scores_for_norm) > 0 and valid_scores_for_norm.max() > 0 else np.zeros_like(scores_np)
176
 
177
  tokens_raw = tokenizer.convert_ids_to_tokens(input_ids[0], skip_special_tokens=False)
 
192
  html_tokens_list.append(f"<span style='font-weight:bold;'>{html.escape(clean_tok_str)}</span>")
193
  highlighted_text_data.append((clean_tok_str + " ", None))
194
  else:
195
+ color = f"rgba(220, 50, 50, {current_score_clipped:.2f})"
196
  html_tokens_list.append(f"<span style='background-color:{color}; color:white; padding: 1px 3px; margin: 1px; border-radius: 4px; display:inline-block;'>{html.escape(clean_tok_str)}</span>")
197
  highlighted_text_data.append((clean_tok_str + " ", round(current_score_clipped, 3)))
198
 
 
223
  pca_tokens = [actual_tokens[i] for i in non_special_token_indices]
224
  if len(pca_tokens) > 0:
225
  pca_embeddings = actual_input_embeds[non_special_token_indices, :]
226
+ pca_scores_for_plot = actual_scores_np[non_special_token_indices]
227
 
228
  pca = PCA(n_components=3, random_state=SEED)
229
  token_embeddings_3d = pca.fit_transform(pca_embeddings)
 
241
  print(f"analyze_sentence_for_gradio error: {e}\n{tb_str}")
242
  empty_df = pd.DataFrame(columns=['token', 'score'])
243
  empty_fig = create_empty_plotly_figure("Analysis Error")
244
+ # gr.Label์— ๋Œ€ํ•œ ์˜ค๋ฅ˜ ๋ฐ˜ํ™˜๊ฐ’ ์ˆ˜์ •
245
+ return error_html, [], "Analysis Failed", {"Status":"Error", "Message": str(e)}, [], empty_df, empty_fig
246
 
247
  # --- Gradio UI Definition (Translated and Enhanced) ---
 
248
  theme = gr.themes.Monochrome(
249
  primary_hue=gr.themes.colors.blue,
250
  secondary_hue=gr.themes.colors.sky,
 
256
  button_primary_text_color="white",
257
  )
258
 
 
259
  with gr.Blocks(title="AI Sentence Analyzer XAI ๐Ÿš€", theme=theme, css=".gradio-container {max-width: 98% !important;}") as demo:
260
  gr.Markdown("# ๐Ÿš€ AI Sentence Analyzer XAI: Exploring Model Explanations")
261
  gr.Markdown("Analyze English sentences to understand BERT model predictions through various XAI visualization techniques. "
262
  "Explore token importance and their distribution in the embedding space.")
263
 
264
  with gr.Row(equal_height=False):
265
+ with gr.Column(scale=1, min_width=350):
266
  with gr.Group():
267
  gr.Markdown("### โœ๏ธ Input Sentence & Settings")
268
  input_sentence = gr.Textbox(lines=5, label="English Sentence to Analyze", placeholder="Enter the English sentence you want to analyze here...")
 
284
  output_highlighted_text = gr.HighlightedText(
285
  label="Token Importance (Score: 0-1)",
286
  show_legend=True,
 
 
 
 
287
  combine_adjacent=False
288
  )
289
  with gr.TabItem("๐Ÿ“Š Top-K Bar Plot", id=2):
 
292
  x="token",
293
  y="score",
294
  tooltip=['token', 'score'],
295
+ min_width=300
 
296
  )
297
  with gr.TabItem("๐ŸŒ Token Embeddings 3D PCA (Interactive)", id=3):
298
  output_pca_plot = gr.Plot(label="3D PCA of Token Embeddings (Colored by Importance Score)")
 
312
  output_pca_plot
313
  ],
314
  fn=analyze_sentence_for_gradio,
315
+ cache_examples=False
316
  )
317
  gr.HTML("<p style='text-align: center; color: #4a5568;'>Explainable AI Demo powered by Gradio & Hugging Face Transformers</p>")
318