kikikara committed on
Commit 72b741e · verified
1 Parent(s): 6e12229

Update app.py

Files changed (1)
  1. app.py +41 -27
app.py CHANGED
@@ -17,7 +17,7 @@ hf_logging.set_verbosity_error()
 
 MODEL_NAME = "bert-base-uncased"
 DEVICE = "cpu"
-SAVE_DIR = "저장저장1"
+SAVE_DIR = "저장저장1"  # This folder name is from your setup
 LAYER_ID = 4
 SEED = 0
 CLF_NAME = "linear"
@@ -133,7 +133,6 @@ def analyze_sentence_for_gradio(sentence_text, top_k_value):
         error_html = f"<p style='color:red;'>Initialization Error: {html.escape(MODEL_LOADING_ERROR_MESSAGE)}</p>"
         empty_df = pd.DataFrame(columns=['token', 'score'])
         empty_fig = create_empty_plotly_figure("Model Loading Failed")
-        # Fixed the error return value for gr.Label (a plain dictionary or string)
         error_label_output = {"Status": "Error", "Message": "Model Loading Failed. Check logs."}
         return error_html, [], "Model Loading Failed", error_label_output, [], empty_df, empty_fig
 
@@ -215,8 +214,7 @@ def analyze_sentence_for_gradio(sentence_text, top_k_value):
         predicted_class_label_str = CLASS_LABEL_MAP.get(pred_idx, f"Unknown Index ({pred_idx})")
 
         prediction_summary_text = f"Predicted Class: {predicted_class_label_str}\nProbability: {pred_prob_val:.3f}"
-        # Modified part: a dictionary format suited to gr.Label (class name: probability)
-        prediction_details_for_label = {predicted_class_label_str: float(f"{pred_prob_val:.3f}")}  # pass the probability as a float
+        prediction_details_for_label = {predicted_class_label_str: float(f"{pred_prob_val:.3f}")}
 
         pca_fig = create_empty_plotly_figure("PCA Plot N/A\n(Not enough non-special tokens for 3D)")
         non_special_token_indices = [idx for idx, token_id in enumerate(input_ids[0,:len(actual_tokens)].tolist())
@@ -244,11 +242,10 @@ def analyze_sentence_for_gradio(sentence_text, top_k_value):
         print(f"analyze_sentence_for_gradio error: {e}\n{tb_str}")
         empty_df = pd.DataFrame(columns=['token', 'score'])
         empty_fig = create_empty_plotly_figure("Analysis Error")
-        # Fixed the error return value for gr.Label
         error_label_output = {"Status": "Error", "Message": f"Analysis failed: {str(e)}"}
         return error_html, [], "Analysis Failed", error_label_output, [], empty_df, empty_fig
 
-# --- Gradio UI Definition (Translated and Enhanced) ---
+# --- Gradio UI Definition (Tabs removed, visualizations shown sequentially or in rows) ---
 theme = gr.themes.Monochrome(
     primary_hue=gr.themes.colors.blue,
     secondary_hue=gr.themes.colors.sky,
@@ -265,6 +262,7 @@ with gr.Blocks(title="AI Sentence Analyzer XAI 🚀", theme=theme, css=".gradio-
     gr.Markdown("Analyze English sentences to understand BERT model predictions through various XAI visualization techniques. "
                 "Explore token importance and their distribution in the embedding space.")
 
+    # Inputs and Summary Outputs Row
     with gr.Row(equal_height=False):
         with gr.Column(scale=1, min_width=350):
             with gr.Group():
@@ -276,32 +274,48 @@ with gr.Blocks(title="AI Sentence Analyzer XAI 🚀", theme=theme, css=".gradio-
         with gr.Column(scale=2):
             with gr.Accordion("🎯 Prediction Outcome", open=True):
                 output_prediction_summary = gr.Textbox(label="Prediction Summary", lines=2, interactive=False)
-                output_prediction_details = gr.Label(label="Prediction Details & Confidence")  # label name changed
+                output_prediction_details = gr.Label(label="Prediction Details & Confidence")
             with gr.Accordion("⭐ Top-K Important Tokens (Table)", open=True):
                 output_top_tokens_df = gr.DataFrame(headers=["Token", "Score"], label="Most Important Tokens",
                                                     row_count=(1,"dynamic"), col_count=(2,"fixed"), interactive=False, wrap=True)
+    gr.Markdown("---")  # Separator
 
-    with gr.Tabs() as tabs:
-        with gr.TabItem("🎨 HTML Highlight (Custom)", id=0):
-            output_html_visualization = gr.HTML(label="Token Importance (Gradient x Input based)")
-        with gr.TabItem("🖍️ Highlighted Text (Gradio)", id=1):
-            output_highlighted_text = gr.HighlightedText(
-                label="Token Importance (Score: 0-1)",
-                show_legend=True,
-                combine_adjacent=False
-            )
-        with gr.TabItem("📊 Top-K Bar Plot", id=2):
-            output_top_tokens_barplot = gr.BarPlot(
-                label="Top-K Token Importance Scores",
-                x="token",
-                y="score",
-                tooltip=['token', 'score'],
-                min_width=300
-            )
-        with gr.TabItem("🌐 Token Embeddings 3D PCA (Interactive)", id=3):
-            output_pca_plot = gr.Plot(label="3D PCA of Token Embeddings (Colored by Importance Score)")
+    # Visualization Section Title
+    gr.Markdown("## 📊 Detailed Visualizations")
+
+    # HTML Highlight (Custom) - Full Width
+    with gr.Group():
+        gr.Markdown("### 🎨 HTML Highlight (Custom)")
+        output_html_visualization = gr.HTML(label="Token Importance (Gradient x Input based)")
+
+    # Highlighted Text (Gradio) - Full Width
+    with gr.Group():
+        gr.Markdown("### 🖍️ Highlighted Text (Gradio)")
+        output_highlighted_text = gr.HighlightedText(
+            label="Token Importance (Score: 0-1)",
+            show_legend=True,
+            combine_adjacent=False
+        )
+
+    # BarPlot and PCA Plot Side-by-Side
+    with gr.Row():
+        with gr.Column(scale=1, min_width=400):  # Adjusted min_width for BarPlot
+            with gr.Group():
+                gr.Markdown("### 📊 Top-K Bar Plot")
+                output_top_tokens_barplot = gr.BarPlot(
+                    label="Top-K Token Importance Scores",
+                    x="token",
+                    y="score",
+                    tooltip=['token', 'score'],
+                    min_width=300  # BarPlot itself can define min_width
+                )
+        with gr.Column(scale=1, min_width=400):  # Adjusted min_width for PCA
+            with gr.Group():
+                gr.Markdown("### 🌐 Token Embeddings 3D PCA (Interactive)")
+                output_pca_plot = gr.Plot(label="3D PCA of Token Embeddings (Colored by Importance Score)")
 
-    gr.Markdown("---")
+    gr.Markdown("---")  # Separator
+
     gr.Examples(
         examples=[
             ["This movie is an absolute masterpiece, captivating from start to finish.", 5],