kikikara committed on
Commit 6e12229 · verified · 1 Parent(s): 9a618da

Update app.py

Files changed (1)
  1. app.py +17 -13
app.py CHANGED
@@ -6,11 +6,11 @@ import numpy as np
  import html
  from transformers import AutoTokenizer, AutoModel, logging as hf_logging
  import pandas as pd
- import matplotlib # Still used for a basic empty plot if Plotly one is too complex for that
- matplotlib.use('Agg') # Matplotlib backend setting
+ import matplotlib
+ matplotlib.use('Agg')
  import matplotlib.pyplot as plt
  from sklearn.decomposition import PCA
- import plotly.graph_objects as go # For interactive 3D PCA plot
+ import plotly.graph_objects as go
 
  # --- Global Settings and Model Loading ---
  hf_logging.set_verbosity_error()
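Note on the import hunk above: matplotlib.use('Agg') is set before matplotlib.pyplot is imported so figures render headlessly (no display server), which is typical for a hosted Space. A minimal standalone sketch of that pattern, with placeholder data not taken from app.py:

    import matplotlib
    matplotlib.use('Agg')            # conventionally set before pyplot is first imported
    import matplotlib.pyplot as plt

    fig, ax = plt.subplots()
    ax.bar(["token_a", "token_b"], [0.7, 0.3])   # dummy scores for illustration
    fig.savefig("saliency_bar.png")  # Agg writes to files/buffers instead of a window
    plt.close(fig)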
@@ -103,7 +103,6 @@ def plot_token_pca_3d_plotly(token_embeddings_3d, tokens, scores, title="Token E
  fig.update_layout(
  title=dict(text=title, x=0.5, font=dict(size=16)),
  scene=dict(
- # Changed part: nest text and font inside the title attribute
  xaxis=dict(title=dict(text='PCA Comp 1', font=dict(size=10)), tickfont=dict(size=9), backgroundcolor="rgba(230, 230, 230, 0.8)"),
  yaxis=dict(title=dict(text='PCA Comp 2', font=dict(size=10)), tickfont=dict(size=9), backgroundcolor="rgba(230, 230, 230, 0.8)"),
  zaxis=dict(title=dict(text='PCA Comp 3', font=dict(size=10)), tickfont=dict(size=9), backgroundcolor="rgba(230, 230, 230, 0.8)"),
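This hunk only drops an explanatory comment; the nested axis-title form stays, which is how current Plotly versions expect 3D scene axis titles (text and font inside title=dict(...)). A self-contained sketch with placeholder points, not taken from app.py:

    import plotly.graph_objects as go

    fig = go.Figure(data=[go.Scatter3d(x=[0, 1, 2], y=[2, 1, 0], z=[1, 2, 0],
                                       mode="markers+text", text=["tok1", "tok2", "tok3"])])
    fig.update_layout(
        title=dict(text="Token Embeddings (PCA 3D)", x=0.5, font=dict(size=16)),
        scene=dict(
            # axis title text and font are nested under title=dict(...)
            xaxis=dict(title=dict(text="PCA Comp 1", font=dict(size=10)), tickfont=dict(size=9)),
            yaxis=dict(title=dict(text="PCA Comp 2", font=dict(size=10)), tickfont=dict(size=9)),
            zaxis=dict(title=dict(text="PCA Comp 3", font=dict(size=10)), tickfont=dict(size=9)),
        ),
    )
    fig.write_html("pca_3d.html")   # or fig.show() in an interactive session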
@@ -134,8 +133,9 @@ def analyze_sentence_for_gradio(sentence_text, top_k_value):
  error_html = f"<p style='color:red;'>Initialization Error: {html.escape(MODEL_LOADING_ERROR_MESSAGE)}</p>"
  empty_df = pd.DataFrame(columns=['token', 'score'])
  empty_fig = create_empty_plotly_figure("Model Loading Failed")
- # Fix the error return value for gr.Label
- return error_html, [], "Model Loading Failed", {"Status":"Error", "Message":"Model Loading Failed"}, [], empty_df, empty_fig
+ # Fix the error return value for gr.Label (plain dictionary or string)
+ error_label_output = {"Status": "Error", "Message": "Model Loading Failed. Check logs."}
+ return error_html, [], "Model Loading Failed", error_label_output, [], empty_df, empty_fig
 
  try:
  tokenizer, model = TOKENIZER_GLOBAL, MODEL_GLOBAL
@@ -147,7 +147,8 @@ def analyze_sentence_for_gradio(sentence_text, top_k_value):
  if input_ids.shape[1] == 0:
  empty_df = pd.DataFrame(columns=['token', 'score'])
  empty_fig = create_empty_plotly_figure("Invalid Input")
- return "<p style='color:orange;'>Input Error: No valid tokens found.</p>", [], "Input Error", {"Status":"Error", "Message":"Invalid Input"}, [], empty_df, empty_fig
+ error_label_output = {"Status": "Error", "Message": "Invalid input, no valid tokens."}
+ return "<p style='color:orange;'>Input Error: No valid tokens found.</p>", [], "Input Error", error_label_output, [], empty_df, empty_fig
 
  input_embeds_detached = model.embeddings.word_embeddings(input_ids).clone().detach()
  input_embeds_for_grad = input_embeds_detached.clone().requires_grad_(True)
@@ -166,7 +167,8 @@ def analyze_sentence_for_gradio(sentence_text, top_k_value):
  if input_embeds_for_grad.grad is None:
  empty_df = pd.DataFrame(columns=['token', 'score'])
  empty_fig = create_empty_plotly_figure("Gradient Error")
- return "<p style='color:red;'>Analysis Error: Gradient calculation failed.</p>", [],"Analysis Error", {"Status":"Error", "Message":"Gradient Error"}, [], empty_df, empty_fig
+ error_label_output = {"Status": "Error", "Message": "Gradient calculation failed."}
+ return "<p style='color:red;'>Analysis Error: Gradient calculation failed.</p>", [],"Analysis Error", error_label_output, [], empty_df, empty_fig
 
  grads = input_embeds_for_grad.grad.clone().detach()
  scores = (grads * input_embeds_detached).norm(dim=2).squeeze(0)
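The last two context lines are the attribution step itself: gradient-times-input saliency, where each token's score is the L2 norm, over the hidden dimension, of the elementwise product of its embedding and the gradient of the predicted logit. A toy PyTorch sketch, with a dummy scalar standing in for the model's class logit:

    import torch

    batch, seq_len, hidden = 1, 5, 8
    embeds = torch.randn(batch, seq_len, hidden, requires_grad=True)  # stand-in for word embeddings
    logit = (embeds.sum(dim=2) ** 2).sum()                            # dummy scalar in place of the class logit
    logit.backward()                                                  # populates embeds.grad

    # gradient-times-input: elementwise product, then L2 norm over the hidden dimension
    scores = (embeds.grad * embeds.detach()).norm(dim=2).squeeze(0)   # shape: (seq_len,)
    print(scores)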
@@ -210,11 +212,12 @@ def analyze_sentence_for_gradio(sentence_text, top_k_value):
 
  barplot_df = pd.DataFrame(top_tokens_for_barplot_list) if top_tokens_for_barplot_list else pd.DataFrame(columns=['token', 'score'])
 
- predicted_class_label_str = CLASS_LABEL_MAP.get(pred_idx, f"Unknown Index: {pred_idx}")
+ predicted_class_label_str = CLASS_LABEL_MAP.get(pred_idx, f"Unknown Index ({pred_idx})")
 
  prediction_summary_text = f"Predicted Class: {predicted_class_label_str}\nProbability: {pred_prob_val:.3f}"
- prediction_details_for_label = {"Predicted Class": predicted_class_label_str, "Probability": f"{pred_prob_val:.3f}"}
-
+ # Changed part: dictionary format suited to gr.Label (class name: probability value)
+ prediction_details_for_label = {predicted_class_label_str: float(f"{pred_prob_val:.3f}")} # pass the probability as a float
+
  pca_fig = create_empty_plotly_figure("PCA Plot N/A\n(Not enough non-special tokens for 3D)")
  non_special_token_indices = [idx for idx, token_id in enumerate(input_ids[0,:len(actual_tokens)].tolist())
  if token_id not in [cls_token_id, sep_token_id]]
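The new prediction_details_for_label shape matches what gr.Label accepts: either a plain string or a dict mapping class names to float confidences, rendered as confidence bars. A minimal sketch with a hypothetical two-class classifier (the label names and scores are made up):

    import gradio as gr

    def classify(text):
        # hypothetical fixed scores, in the {class name: probability} shape gr.Label expects
        return {"positive": 0.87, "negative": 0.13}

    demo = gr.Interface(fn=classify,
                        inputs=gr.Textbox(label="Sentence"),
                        outputs=gr.Label(label="Prediction Details & Confidence"))

    if __name__ == "__main__":
        demo.launch()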
@@ -242,7 +245,8 @@ def analyze_sentence_for_gradio(sentence_text, top_k_value):
  empty_df = pd.DataFrame(columns=['token', 'score'])
  empty_fig = create_empty_plotly_figure("Analysis Error")
  # Fix the error return value for gr.Label
- return error_html, [], "Analysis Failed", {"Status":"Error", "Message": str(e)}, [], empty_df, empty_fig
+ error_label_output = {"Status": "Error", "Message": f"Analysis failed: {str(e)}"}
+ return error_html, [], "Analysis Failed", error_label_output, [], empty_df, empty_fig
 
  # --- Gradio UI Definition (Translated and Enhanced) ---
  theme = gr.themes.Monochrome(
@@ -272,7 +276,7 @@ with gr.Blocks(title="AI Sentence Analyzer XAI 🚀", theme=theme, css=".gradio-
  with gr.Column(scale=2):
  with gr.Accordion("🎯 Prediction Outcome", open=True):
  output_prediction_summary = gr.Textbox(label="Prediction Summary", lines=2, interactive=False)
- output_prediction_details = gr.Label(label="Detailed Prediction")
+ output_prediction_details = gr.Label(label="Prediction Details & Confidence") # renamed label
  with gr.Accordion("⭐ Top-K Important Tokens (Table)", open=True):
  output_top_tokens_df = gr.DataFrame(headers=["Token", "Score"], label="Most Important Tokens",
  row_count=(1,"dynamic"), col_count=(2,"fixed"), interactive=False, wrap=True)
 