Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -9,7 +9,6 @@ import pandas as pd
|
|
| 9 |
import matplotlib
|
| 10 |
matplotlib.use('Agg') # Matplotlib ๋ฐฑ์๋ ์ค์ (Gradio์ ํจ๊ป ์ฌ์ฉ ์ ์ค์)
|
| 11 |
import matplotlib.pyplot as plt
|
| 12 |
-
# from mpl_toolkits.mplot3d import Axes3D # 3D ํ๋กฏ์ Axes3D ๋ช
์์ ์ํฌํธ๋ ์ต์ matplotlib์์ ํ์๋ ์๋ ์ ์์
|
| 13 |
from sklearn.decomposition import PCA
|
| 14 |
|
| 15 |
# --- ๊ธฐ์กด ์ค์ ๋ฐ ์ ์ญ ๋ชจ๋ธ ๋ก๋ ๋ถ๋ถ ---
|
|
@@ -55,7 +54,7 @@ try:
|
|
| 55 |
|
| 56 |
TOKENIZER_GLOBAL = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
|
| 57 |
MODEL_GLOBAL = AutoModel.from_pretrained(
|
| 58 |
-
MODEL_NAME, output_hidden_states=True, output_attentions=False
|
| 59 |
).to(DEVICE).eval()
|
| 60 |
|
| 61 |
if hasattr(lda, 'classes_'): CLASS_NAMES_GLOBAL = lda.classes_
|
|
@@ -75,14 +74,13 @@ def plot_token_pca_3d(token_embeddings_3d, tokens, scores, title="Token Embeddin
|
|
| 75 |
ax = fig.add_subplot(111, projection='3d')
|
| 76 |
|
| 77 |
num_annotations = min(len(tokens), 15)
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
indices_to_annotate = np.argsort(
|
| 82 |
-
else:
|
| 83 |
indices_to_annotate = np.array([])
|
| 84 |
|
| 85 |
-
|
| 86 |
scatter = ax.scatter(token_embeddings_3d[:, 0], token_embeddings_3d[:, 1], token_embeddings_3d[:, 2],
|
| 87 |
c=scores, cmap="coolwarm_r", s=50, alpha=0.8, depthshade=True)
|
| 88 |
|
|
@@ -110,10 +108,6 @@ def analyze_sentence_for_gradio(sentence_text, top_k_value):
|
|
| 110 |
ax = fig.add_subplot(111)
|
| 111 |
ax.text(0.5, 0.5, message, ha='center', va='center', fontsize=10)
|
| 112 |
ax.axis('off')
|
| 113 |
-
# Gradio๊ฐ Figure ๊ฐ์ฒด๋ฅผ ์ฒ๋ฆฌํ๋ฏ๋ก, ์ฌ๊ธฐ์๋ close๋ฅผ ํธ์ถํ์ง ์๊ฑฐ๋
|
| 114 |
-
# Gradio์ Plot ์ปดํฌ๋ํธ๊ฐ Figure๋ฅผ ์ด๋ป๊ฒ ๋ค๋ฃจ๋์ง ํ์ธ ํ์.
|
| 115 |
-
# ์ผ๋ฐ์ ์ผ๋ก๋ closeํ์ง ์๊ณ Figure ๊ฐ์ฒด ์์ฒด๋ฅผ ๋ฐํํฉ๋๋ค.
|
| 116 |
-
# plt.close(fig) # ์ผ๋จ ์ฃผ์ ์ฒ๋ฆฌํ์ฌ Gradio๊ฐ Figure๋ฅผ ๋ฐ๋๋ก ํจ
|
| 117 |
return fig
|
| 118 |
|
| 119 |
if not MODELS_LOADED_SUCCESSFULLY:
|
|
@@ -204,23 +198,20 @@ def analyze_sentence_for_gradio(sentence_text, top_k_value):
|
|
| 204 |
prediction_summary_text = f"ํด๋์ค: {predicted_class_label_str}\nํ๋ฅ : {pred_prob_val:.3f}"
|
| 205 |
prediction_details_for_label = {"์์ธก ํด๋์ค": predicted_class_label_str, "ํ๋ฅ ": f"{pred_prob_val:.3f}"}
|
| 206 |
|
| 207 |
-
pca_fig = create_empty_plot("PCA Plot N/A\n(Not enough non-special tokens for 3D)")
|
| 208 |
non_special_token_indices = [idx for idx, token_id in enumerate(input_ids[0,:len(actual_tokens)].tolist())
|
| 209 |
if token_id not in [cls_token_id, sep_token_id]]
|
| 210 |
|
| 211 |
if len(non_special_token_indices) >= 3 :
|
| 212 |
pca_tokens = [actual_tokens[i] for i in non_special_token_indices]
|
| 213 |
-
|
| 214 |
-
if len(pca_tokens) > 0: # pca_tokens๊ฐ ๋น์ด์์ง ์์์ง ํ์ธ
|
| 215 |
pca_embeddings = actual_input_embeds[non_special_token_indices, :]
|
| 216 |
pca_scores = actual_scores_np[non_special_token_indices]
|
| 217 |
|
| 218 |
pca = PCA(n_components=3, random_state=SEED)
|
| 219 |
token_embeddings_3d = pca.fit_transform(pca_embeddings)
|
| 220 |
-
#
|
| 221 |
-
plt.close(pca_fig)
|
| 222 |
pca_fig = plot_token_pca_3d(token_embeddings_3d, pca_tokens, pca_scores)
|
| 223 |
-
# else: pca_fig๋ ์ด๋ฏธ ์์์ ๋น ํ๋กฏ์ผ๋ก ์ด๊ธฐํ๋จ
|
| 224 |
|
| 225 |
return (html_output_str, highlighted_text_data,
|
| 226 |
prediction_summary_text, prediction_details_for_label,
|
|
@@ -236,7 +227,6 @@ def analyze_sentence_for_gradio(sentence_text, top_k_value):
|
|
| 236 |
empty_fig_placeholder = create_empty_plot("Error during plot generation")
|
| 237 |
return error_html, [], "๋ถ์ ์คํจ", {"์ค๋ฅ": str(e)}, [], empty_df, empty_fig_placeholder
|
| 238 |
|
| 239 |
-
|
| 240 |
# โโโโโโโโโโ Gradio ์ธํฐํ์ด์ค ์ ์ โโโโโโโโโโ
|
| 241 |
theme = gr.themes.Glass(primary_hue="blue", secondary_hue="cyan", neutral_hue="sky").set(
|
| 242 |
body_background_fill="linear-gradient(to right, #c9d6ff, #e2e2e2)",
|
|
@@ -279,7 +269,7 @@ with gr.Blocks(title="AI ๋ฌธ์ฅ ๋ถ์๊ธฐ XAI ๐", theme=theme, css=".gradio-c
|
|
| 279 |
label="Top-K ํ ํฐ ์ค์๋",
|
| 280 |
x="token",
|
| 281 |
y="score",
|
| 282 |
-
tooltip=['token', 'score'], #
|
| 283 |
min_width=300
|
| 284 |
)
|
| 285 |
with gr.TabItem("๐ ํ ํฐ ์๋ฒ ๋ฉ 3D PCA", id=3):
|
|
@@ -302,7 +292,8 @@ with gr.Blocks(title="AI ๋ฌธ์ฅ ๋ถ์๊ธฐ XAI ๐", theme=theme, css=".gradio-c
|
|
| 302 |
fn=analyze_sentence_for_gradio,
|
| 303 |
cache_examples=False
|
| 304 |
)
|
| 305 |
-
gr.Markdown
|
|
|
|
| 306 |
|
| 307 |
submit_button.click(
|
| 308 |
fn=analyze_sentence_for_gradio,
|
|
|
|
| 9 |
import matplotlib
|
| 10 |
matplotlib.use('Agg') # Matplotlib ๋ฐฑ์๋ ์ค์ (Gradio์ ํจ๊ป ์ฌ์ฉ ์ ์ค์)
|
| 11 |
import matplotlib.pyplot as plt
|
|
|
|
| 12 |
from sklearn.decomposition import PCA
|
| 13 |
|
| 14 |
# --- ๊ธฐ์กด ์ค์ ๋ฐ ์ ์ญ ๋ชจ๋ธ ๋ก๋ ๋ถ๋ถ ---
|
|
|
|
| 54 |
|
| 55 |
TOKENIZER_GLOBAL = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
|
| 56 |
MODEL_GLOBAL = AutoModel.from_pretrained(
|
| 57 |
+
MODEL_NAME, output_hidden_states=True, output_attentions=False
|
| 58 |
).to(DEVICE).eval()
|
| 59 |
|
| 60 |
if hasattr(lda, 'classes_'): CLASS_NAMES_GLOBAL = lda.classes_
|
|
|
|
| 74 |
ax = fig.add_subplot(111, projection='3d')
|
| 75 |
|
| 76 |
num_annotations = min(len(tokens), 15)
|
| 77 |
+
if len(scores) > 0 and len(tokens) > 0: # scores์ tokens๊ฐ ๋น์ด์์ง ์์์ง ํ์ธ
|
| 78 |
+
# scores๊ฐ NumPy ๋ฐฐ์ด์ด ์๋ ์ ์์ผ๋ฏ๋ก, ๋ฆฌ์คํธ์ธ ๊ฒฝ์ฐ np.array๋ก ๋ณํ
|
| 79 |
+
scores_np_array = np.array(scores)
|
| 80 |
+
indices_to_annotate = np.argsort(scores_np_array)[-num_annotations:]
|
| 81 |
+
else:
|
| 82 |
indices_to_annotate = np.array([])
|
| 83 |
|
|
|
|
| 84 |
scatter = ax.scatter(token_embeddings_3d[:, 0], token_embeddings_3d[:, 1], token_embeddings_3d[:, 2],
|
| 85 |
c=scores, cmap="coolwarm_r", s=50, alpha=0.8, depthshade=True)
|
| 86 |
|
|
|
|
| 108 |
ax = fig.add_subplot(111)
|
| 109 |
ax.text(0.5, 0.5, message, ha='center', va='center', fontsize=10)
|
| 110 |
ax.axis('off')
|
|
|
|
|
|
|
|
|
|
|
|
|
| 111 |
return fig
|
| 112 |
|
| 113 |
if not MODELS_LOADED_SUCCESSFULLY:
|
|
|
|
| 198 |
prediction_summary_text = f"ํด๋์ค: {predicted_class_label_str}\nํ๋ฅ : {pred_prob_val:.3f}"
|
| 199 |
prediction_details_for_label = {"์์ธก ํด๋์ค": predicted_class_label_str, "ํ๋ฅ ": f"{pred_prob_val:.3f}"}
|
| 200 |
|
| 201 |
+
pca_fig = create_empty_plot("PCA Plot N/A\n(Not enough non-special tokens for 3D)")
|
| 202 |
non_special_token_indices = [idx for idx, token_id in enumerate(input_ids[0,:len(actual_tokens)].tolist())
|
| 203 |
if token_id not in [cls_token_id, sep_token_id]]
|
| 204 |
|
| 205 |
if len(non_special_token_indices) >= 3 :
|
| 206 |
pca_tokens = [actual_tokens[i] for i in non_special_token_indices]
|
| 207 |
+
if len(pca_tokens) > 0:
|
|
|
|
| 208 |
pca_embeddings = actual_input_embeds[non_special_token_indices, :]
|
| 209 |
pca_scores = actual_scores_np[non_special_token_indices]
|
| 210 |
|
| 211 |
pca = PCA(n_components=3, random_state=SEED)
|
| 212 |
token_embeddings_3d = pca.fit_transform(pca_embeddings)
|
| 213 |
+
# plt.close(pca_fig) # ์ด์ ๋น ๊ทธ๋ฆผ ๋ซ๊ธฐ
|
|
|
|
| 214 |
pca_fig = plot_token_pca_3d(token_embeddings_3d, pca_tokens, pca_scores)
|
|
|
|
| 215 |
|
| 216 |
return (html_output_str, highlighted_text_data,
|
| 217 |
prediction_summary_text, prediction_details_for_label,
|
|
|
|
| 227 |
empty_fig_placeholder = create_empty_plot("Error during plot generation")
|
| 228 |
return error_html, [], "๋ถ์ ์คํจ", {"์ค๋ฅ": str(e)}, [], empty_df, empty_fig_placeholder
|
| 229 |
|
|
|
|
| 230 |
# โโโโโโโโโโ Gradio ์ธํฐํ์ด์ค ์ ์ โโโโโโโโโโ
|
| 231 |
theme = gr.themes.Glass(primary_hue="blue", secondary_hue="cyan", neutral_hue="sky").set(
|
| 232 |
body_background_fill="linear-gradient(to right, #c9d6ff, #e2e2e2)",
|
|
|
|
| 269 |
label="Top-K ํ ํฐ ์ค์๋",
|
| 270 |
x="token",
|
| 271 |
y="score",
|
| 272 |
+
tooltip=['token', 'score'], # SyntaxError ์์ ๋จ
|
| 273 |
min_width=300
|
| 274 |
)
|
| 275 |
with gr.TabItem("๐ ํ ํฐ ์๋ฒ ๋ฉ 3D PCA", id=3):
|
|
|
|
| 292 |
fn=analyze_sentence_for_gradio,
|
| 293 |
cache_examples=False
|
| 294 |
)
|
| 295 |
+
# gr.Markdown์ gr.HTML๋ก ๋ณ๊ฒฝํ์ฌ HTML ํ๊ทธ ์ง์ ์ฌ์ฉ
|
| 296 |
+
gr.HTML("<p style='text-align: center; color: #666;'>Explainable AI Demo with Gradio & Transformers</p>")
|
| 297 |
|
| 298 |
submit_button.click(
|
| 299 |
fn=analyze_sentence_for_gradio,
|