kikikara committed on
Commit
2b56bba
·
verified ·
1 Parent(s): 1613a6f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -21
app.py CHANGED
@@ -9,7 +9,6 @@ import pandas as pd
9
  import matplotlib
10
  matplotlib.use('Agg') # Matplotlib ๋ฐฑ์—”๋“œ ์„ค์ • (Gradio์™€ ํ•จ๊ป˜ ์‚ฌ์šฉ ์‹œ ์ค‘์š”)
11
  import matplotlib.pyplot as plt
12
- # from mpl_toolkits.mplot3d import Axes3D # 3D ํ”Œ๋กฏ์— Axes3D ๋ช…์‹œ์  ์ž„ํฌํŠธ๋Š” ์ตœ์‹  matplotlib์—์„œ ํ•„์ˆ˜๋Š” ์•„๋‹ ์ˆ˜ ์žˆ์Œ
13
  from sklearn.decomposition import PCA
14
 
15
  # --- 기존 설정 및 전역 모델 로드 부분 ---
@@ -55,7 +54,7 @@ try:
55
 
56
  TOKENIZER_GLOBAL = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
57
  MODEL_GLOBAL = AutoModel.from_pretrained(
58
- MODEL_NAME, output_hidden_states=True, output_attentions=False # 어텐션 불필요
59
  ).to(DEVICE).eval()
60
 
61
  if hasattr(lda, 'classes_'): CLASS_NAMES_GLOBAL = lda.classes_
@@ -75,14 +74,13 @@ def plot_token_pca_3d(token_embeddings_3d, tokens, scores, title="Token Embeddin
75
  ax = fig.add_subplot(111, projection='3d')
76
 
77
  num_annotations = min(len(tokens), 15)
78
- # ์ ์ˆ˜๊ฐ€ ๋†’์€ ์ˆœ์œผ๋กœ ์ •๋ ฌํ•˜์—ฌ ์–ด๋…ธํ…Œ์ดํŠธํ•  ์ธ๋ฑ์Šค ์„ ํƒ (์ค‘์š”๋„ ๋†’์€ ํ† ํฐ ์œ„์ฃผ)
79
- # scores๊ฐ€ NumPy ๋ฐฐ์—ด์ด๋ผ๊ณ  ๊ฐ€์ •
80
- if len(scores) > 0:
81
- indices_to_annotate = np.argsort(scores)[-num_annotations:]
82
- else: # scores๊ฐ€ ๋น„์–ด์žˆ๊ฑฐ๋‚˜ ๋ฌธ์ œ๊ฐ€ ์žˆ๋Š” ๊ฒฝ์šฐ
83
  indices_to_annotate = np.array([])
84
 
85
-
86
  scatter = ax.scatter(token_embeddings_3d[:, 0], token_embeddings_3d[:, 1], token_embeddings_3d[:, 2],
87
  c=scores, cmap="coolwarm_r", s=50, alpha=0.8, depthshade=True)
88
 
@@ -110,10 +108,6 @@ def analyze_sentence_for_gradio(sentence_text, top_k_value):
110
  ax = fig.add_subplot(111)
111
  ax.text(0.5, 0.5, message, ha='center', va='center', fontsize=10)
112
  ax.axis('off')
113
- # Gradio๊ฐ€ Figure ๊ฐ์ฒด๋ฅผ ์ฒ˜๋ฆฌํ•˜๋ฏ€๋กœ, ์—ฌ๊ธฐ์„œ๋Š” close๋ฅผ ํ˜ธ์ถœํ•˜์ง€ ์•Š๊ฑฐ๋‚˜
114
- # Gradio์˜ Plot ์ปดํฌ๋„ŒํŠธ๊ฐ€ Figure๋ฅผ ์–ด๋–ป๊ฒŒ ๋‹ค๋ฃจ๋Š”์ง€ ํ™•์ธ ํ•„์š”.
115
- # ์ผ๋ฐ˜์ ์œผ๋กœ๋Š” closeํ•˜์ง€ ์•Š๊ณ  Figure ๊ฐ์ฒด ์ž์ฒด๋ฅผ ๋ฐ˜ํ™˜ํ•ฉ๋‹ˆ๋‹ค.
116
- # plt.close(fig) # ์ผ๋‹จ ์ฃผ์„ ์ฒ˜๋ฆฌํ•˜์—ฌ Gradio๊ฐ€ Figure๋ฅผ ๋ฐ›๋„๋ก ํ•จ
117
  return fig
118
 
119
  if not MODELS_LOADED_SUCCESSFULLY:
@@ -204,23 +198,20 @@ def analyze_sentence_for_gradio(sentence_text, top_k_value):
204
  prediction_summary_text = f"클래스: {predicted_class_label_str}\n확률: {pred_prob_val:.3f}"
205
  prediction_details_for_label = {"예측 클래스": predicted_class_label_str, "확률": f"{pred_prob_val:.3f}"}
206
 
207
- pca_fig = create_empty_plot("PCA Plot N/A\n(Not enough non-special tokens for 3D)") # 기본 빈 플롯
208
  non_special_token_indices = [idx for idx, token_id in enumerate(input_ids[0,:len(actual_tokens)].tolist())
209
  if token_id not in [cls_token_id, sep_token_id]]
210
 
211
  if len(non_special_token_indices) >= 3 :
212
  pca_tokens = [actual_tokens[i] for i in non_special_token_indices]
213
- # non_special_token_indices์— ํ•ด๋‹นํ•˜๋Š” ์ž„๋ฒ ๋”ฉ๊ณผ ์ ์ˆ˜๋งŒ ์ถ”์ถœ
214
- if len(pca_tokens) > 0: # pca_tokens๊ฐ€ ๋น„์–ด์žˆ์ง€ ์•Š์€์ง€ ํ™•์ธ
215
  pca_embeddings = actual_input_embeds[non_special_token_indices, :]
216
  pca_scores = actual_scores_np[non_special_token_indices]
217
 
218
  pca = PCA(n_components=3, random_state=SEED)
219
  token_embeddings_3d = pca.fit_transform(pca_embeddings)
220
- # ์ด์ „ ๊ทธ๋ฆผ์ด ์žˆ๋‹ค๋ฉด ๋‹ซ๊ณ  ์ƒˆ๋กœ ๊ทธ๋ฆผ (Gradio Plot์ด Figure ๊ฐ์ฒด๋ฅผ ์ง์ ‘ ๋ฐ›์œผ๋ฏ€๋กœ)
221
- plt.close(pca_fig)
222
  pca_fig = plot_token_pca_3d(token_embeddings_3d, pca_tokens, pca_scores)
223
- # else: pca_fig๋Š” ์ด๋ฏธ ์œ„์—์„œ ๋นˆ ํ”Œ๋กฏ์œผ๋กœ ์ดˆ๊ธฐํ™”๋จ
224
 
225
  return (html_output_str, highlighted_text_data,
226
  prediction_summary_text, prediction_details_for_label,
@@ -236,7 +227,6 @@ def analyze_sentence_for_gradio(sentence_text, top_k_value):
236
  empty_fig_placeholder = create_empty_plot("Error during plot generation")
237
  return error_html, [], "분석 실패", {"오류": str(e)}, [], empty_df, empty_fig_placeholder
238
 
239
-
240
  # ────────── Gradio 인터페이스 정의 ──────────
241
  theme = gr.themes.Glass(primary_hue="blue", secondary_hue="cyan", neutral_hue="sky").set(
242
  body_background_fill="linear-gradient(to right, #c9d6ff, #e2e2e2)",
@@ -279,7 +269,7 @@ with gr.Blocks(title="AI 문장 분석기 XAI 🚀", theme=theme, css=".gradio-c
279
  label="Top-K 토큰 중요도",
280
  x="token",
281
  y="score",
282
- tooltip=['token', 'score'], # 수정된 부분: 문제가 된 파라미터 삭제
283
  min_width=300
284
  )
285
  with gr.TabItem("๐ŸŒ ํ† ํฐ ์ž„๋ฒ ๋”ฉ 3D PCA", id=3):
@@ -302,7 +292,8 @@ with gr.Blocks(title="AI 문장 분석기 XAI 🚀", theme=theme, css=".gradio-c
302
  fn=analyze_sentence_for_gradio,
303
  cache_examples=False
304
  )
305
- gr.Markdown("<p style='text-align: center; color: #666;'>Explainable AI Demo with Gradio & Transformers</p>", unsafe_allow_html=True)
 
306
 
307
  submit_button.click(
308
  fn=analyze_sentence_for_gradio,
 
9
  import matplotlib
10
  matplotlib.use('Agg') # Matplotlib ๋ฐฑ์—”๋“œ ์„ค์ • (Gradio์™€ ํ•จ๊ป˜ ์‚ฌ์šฉ ์‹œ ์ค‘์š”)
11
  import matplotlib.pyplot as plt
 
12
  from sklearn.decomposition import PCA
13
 
14
  # --- 기존 설정 및 전역 모델 로드 부분 ---
 
54
 
55
  TOKENIZER_GLOBAL = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
56
  MODEL_GLOBAL = AutoModel.from_pretrained(
57
+ MODEL_NAME, output_hidden_states=True, output_attentions=False
58
  ).to(DEVICE).eval()
59
 
60
  if hasattr(lda, 'classes_'): CLASS_NAMES_GLOBAL = lda.classes_
 
74
  ax = fig.add_subplot(111, projection='3d')
75
 
76
  num_annotations = min(len(tokens), 15)
77
+ if len(scores) > 0 and len(tokens) > 0: # scores์™€ tokens๊ฐ€ ๋น„์–ด์žˆ์ง€ ์•Š์€์ง€ ํ™•์ธ
78
+ # scores๊ฐ€ NumPy ๋ฐฐ์—ด์ด ์•„๋‹ ์ˆ˜ ์žˆ์œผ๋ฏ€๋กœ, ๋ฆฌ์ŠคํŠธ์ธ ๊ฒฝ์šฐ np.array๋กœ ๋ณ€ํ™˜
79
+ scores_np_array = np.array(scores)
80
+ indices_to_annotate = np.argsort(scores_np_array)[-num_annotations:]
81
+ else:
82
  indices_to_annotate = np.array([])
83
 
 
84
  scatter = ax.scatter(token_embeddings_3d[:, 0], token_embeddings_3d[:, 1], token_embeddings_3d[:, 2],
85
  c=scores, cmap="coolwarm_r", s=50, alpha=0.8, depthshade=True)
86
 
 
108
  ax = fig.add_subplot(111)
109
  ax.text(0.5, 0.5, message, ha='center', va='center', fontsize=10)
110
  ax.axis('off')
 
 
 
 
111
  return fig
112
 
113
  if not MODELS_LOADED_SUCCESSFULLY:
 
198
  prediction_summary_text = f"클래스: {predicted_class_label_str}\n확률: {pred_prob_val:.3f}"
199
  prediction_details_for_label = {"예측 클래스": predicted_class_label_str, "확률": f"{pred_prob_val:.3f}"}
200
 
201
+ pca_fig = create_empty_plot("PCA Plot N/A\n(Not enough non-special tokens for 3D)")
202
  non_special_token_indices = [idx for idx, token_id in enumerate(input_ids[0,:len(actual_tokens)].tolist())
203
  if token_id not in [cls_token_id, sep_token_id]]
204
 
205
  if len(non_special_token_indices) >= 3 :
206
  pca_tokens = [actual_tokens[i] for i in non_special_token_indices]
207
+ if len(pca_tokens) > 0:
 
208
  pca_embeddings = actual_input_embeds[non_special_token_indices, :]
209
  pca_scores = actual_scores_np[non_special_token_indices]
210
 
211
  pca = PCA(n_components=3, random_state=SEED)
212
  token_embeddings_3d = pca.fit_transform(pca_embeddings)
213
+ # plt.close(pca_fig) # 이전 빈 그림 닫기
 
214
  pca_fig = plot_token_pca_3d(token_embeddings_3d, pca_tokens, pca_scores)
 
215
 
216
  return (html_output_str, highlighted_text_data,
217
  prediction_summary_text, prediction_details_for_label,
 
227
  empty_fig_placeholder = create_empty_plot("Error during plot generation")
228
  return error_html, [], "분석 실패", {"오류": str(e)}, [], empty_df, empty_fig_placeholder
229
 
 
230
  # ────────── Gradio 인터페이스 정의 ──────────
231
  theme = gr.themes.Glass(primary_hue="blue", secondary_hue="cyan", neutral_hue="sky").set(
232
  body_background_fill="linear-gradient(to right, #c9d6ff, #e2e2e2)",
 
269
  label="Top-K ํ† ํฐ ์ค‘์š”๋„",
270
  x="token",
271
  y="score",
272
+ tooltip=['token', 'score'], # SyntaxError 수정됨
273
  min_width=300
274
  )
275
  with gr.TabItem("๐ŸŒ ํ† ํฐ ์ž„๋ฒ ๋”ฉ 3D PCA", id=3):
 
292
  fn=analyze_sentence_for_gradio,
293
  cache_examples=False
294
  )
295
+ # gr.Markdown์„ gr.HTML๋กœ ๋ณ€๊ฒฝํ•˜์—ฌ HTML ํƒœ๊ทธ ์ง์ ‘ ์‚ฌ์šฉ
296
+ gr.HTML("<p style='text-align: center; color: #666;'>Explainable AI Demo with Gradio & Transformers</p>")
297
 
298
  submit_button.click(
299
  fn=analyze_sentence_for_gradio,