Update app.py
app.py
CHANGED
@@ -6,11 +6,11 @@ import numpy as np
 import html
 from transformers import AutoTokenizer, AutoModel, logging as hf_logging
 import pandas as pd
-import matplotlib
-matplotlib.use('Agg')
+import matplotlib
+matplotlib.use('Agg')
 import matplotlib.pyplot as plt
 from sklearn.decomposition import PCA
-import plotly.graph_objects as go
+import plotly.graph_objects as go

 # --- Global Settings and Model Loading ---
 hf_logging.set_verbosity_error()
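The import block keeps the non-interactive Agg backend selected before pyplot is imported, which is what a headless Space needs. A minimal sketch of that pattern (the bar data and filename are illustrative only):

```python
# Minimal sketch: pick a non-display backend before importing pyplot,
# then render figures to files rather than a window.
import matplotlib
matplotlib.use('Agg')              # must run before the pyplot import below
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
ax.bar(["token_a", "token_b"], [0.7, 0.3])   # illustrative data only
fig.savefig("scores.png")
plt.close(fig)
```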
@@ -103,7 +103,6 @@ def plot_token_pca_3d_plotly(token_embeddings_3d, tokens, scores, title="Token E
     fig.update_layout(
         title=dict(text=title, x=0.5, font=dict(size=16)),
         scene=dict(
-            # Modified: nest text and font inside the axis title property
             xaxis=dict(title=dict(text='PCA Comp 1', font=dict(size=10)), tickfont=dict(size=9), backgroundcolor="rgba(230, 230, 230, 0.8)"),
             yaxis=dict(title=dict(text='PCA Comp 2', font=dict(size=10)), tickfont=dict(size=9), backgroundcolor="rgba(230, 230, 230, 0.8)"),
             zaxis=dict(title=dict(text='PCA Comp 3', font=dict(size=10)), tickfont=dict(size=9), backgroundcolor="rgba(230, 230, 230, 0.8)"),
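The deleted comment pointed at Plotly's nested axis-title form, which the surviving xaxis/yaxis/zaxis lines already use. A minimal standalone sketch of that form, assuming plotly is installed:

```python
# Minimal sketch of the nested scene-axis title used above.
import plotly.graph_objects as go

fig = go.Figure(data=[go.Scatter3d(x=[0, 1], y=[0, 1], z=[0, 1], mode="markers")])
fig.update_layout(
    scene=dict(
        # title text and font belong inside the title dict, not as flat keys
        xaxis=dict(title=dict(text="PCA Comp 1", font=dict(size=10)), tickfont=dict(size=9)),
    )
)
```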
@@ -134,8 +133,9 @@ def analyze_sentence_for_gradio(sentence_text, top_k_value):
         error_html = f"<p style='color:red;'>Initialization Error: {html.escape(MODEL_LOADING_ERROR_MESSAGE)}</p>"
         empty_df = pd.DataFrame(columns=['token', 'score'])
         empty_fig = create_empty_plotly_figure("Model Loading Failed")
-        # Fixed the error return value for gr.Label
-
+        # Fixed the error return value for gr.Label (plain dictionary or string)
+        error_label_output = {"Status": "Error", "Message": "Model Loading Failed. Check logs."}
+        return error_html, [], "Model Loading Failed", error_label_output, [], empty_df, empty_fig

     try:
         tokenizer, model = TOKENIZER_GLOBAL, MODEL_GLOBAL
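The new early return hands gr.Label a dictionary. gr.Label renders either a plain string or a dict mapping class names to confidence values; a minimal sketch of the dict form (the lambda and class names are placeholders, not taken from app.py):

```python
# Minimal sketch: gr.Label showing per-class confidences from a dict.
import gradio as gr

demo = gr.Interface(
    fn=lambda text: {"positive": 0.82, "negative": 0.18},  # placeholder classifier
    inputs=gr.Textbox(label="Sentence"),
    outputs=gr.Label(label="Prediction Details & Confidence"),
)
# demo.launch()
```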
@@ -147,7 +147,8 @@ def analyze_sentence_for_gradio(sentence_text, top_k_value):
         if input_ids.shape[1] == 0:
             empty_df = pd.DataFrame(columns=['token', 'score'])
             empty_fig = create_empty_plotly_figure("Invalid Input")
-
+            error_label_output = {"Status": "Error", "Message": "Invalid input, no valid tokens."}
+            return "<p style='color:orange;'>Input Error: No valid tokens found.</p>", [], "Input Error", error_label_output, [], empty_df, empty_fig

         input_embeds_detached = model.embeddings.word_embeddings(input_ids).clone().detach()
         input_embeds_for_grad = input_embeds_detached.clone().requires_grad_(True)
@@ -166,7 +167,8 @@ def analyze_sentence_for_gradio(sentence_text, top_k_value):
         if input_embeds_for_grad.grad is None:
             empty_df = pd.DataFrame(columns=['token', 'score'])
             empty_fig = create_empty_plotly_figure("Gradient Error")
-
+            error_label_output = {"Status": "Error", "Message": "Gradient calculation failed."}
+            return "<p style='color:red;'>Analysis Error: Gradient calculation failed.</p>", [], "Analysis Error", error_label_output, [], empty_df, empty_fig

         grads = input_embeds_for_grad.grad.clone().detach()
         scores = (grads * input_embeds_detached).norm(dim=2).squeeze(0)
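The context lines compute gradient-times-input saliency from input_embeds_for_grad. A self-contained sketch of that scoring step, with a stand-in scalar in place of the model's class logit:

```python
# Minimal sketch of gradient-x-input token scores, assuming a
# (batch, tokens, hidden) embedding tensor and a scalar target.
import torch

embeds = torch.randn(1, 5, 8, requires_grad=True)   # stand-in embeddings
target = (embeds.sum(dim=2) ** 2).sum()              # stand-in for the class logit
target.backward()

# per-token importance: L2 norm of grad * embedding over the hidden dimension
token_scores = (embeds.grad * embeds.detach()).norm(dim=2).squeeze(0)
print(token_scores.shape)  # torch.Size([5])
```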
@@ -210,11 +212,12 @@ def analyze_sentence_for_gradio(sentence_text, top_k_value):

         barplot_df = pd.DataFrame(top_tokens_for_barplot_list) if top_tokens_for_barplot_list else pd.DataFrame(columns=['token', 'score'])

-        predicted_class_label_str = CLASS_LABEL_MAP.get(pred_idx, f"Unknown Index
+        predicted_class_label_str = CLASS_LABEL_MAP.get(pred_idx, f"Unknown Index ({pred_idx})")

         prediction_summary_text = f"Predicted Class: {predicted_class_label_str}\nProbability: {pred_prob_val:.3f}"
-
-
+        # Modified: dictionary format suited to gr.Label (class name: probability)
+        prediction_details_for_label = {predicted_class_label_str: float(f"{pred_prob_val:.3f}")}  # pass the probability as a float
+
         pca_fig = create_empty_plotly_figure("PCA Plot N/A\n(Not enough non-special tokens for 3D)")
         non_special_token_indices = [idx for idx, token_id in enumerate(input_ids[0,:len(actual_tokens)].tolist())
                                      if token_id not in [cls_token_id, sep_token_id]]
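The surrounding code projects the non-special token embeddings down to three PCA components before handing them to the Plotly figure. A minimal sketch of that projection, assuming a (num_tokens, hidden_size) array:

```python
# Minimal sketch of the 3-D PCA step; the random matrix stands in for hidden states.
import numpy as np
from sklearn.decomposition import PCA

token_embeddings = np.random.rand(12, 768)
if token_embeddings.shape[0] >= 3:                 # need at least 3 tokens for 3 components
    coords_3d = PCA(n_components=3).fit_transform(token_embeddings)
    print(coords_3d.shape)                         # (12, 3)
```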
@@ -242,7 +245,8 @@ def analyze_sentence_for_gradio(sentence_text, top_k_value):
         empty_df = pd.DataFrame(columns=['token', 'score'])
         empty_fig = create_empty_plotly_figure("Analysis Error")
         # Fixed the error return value for gr.Label
-
+        error_label_output = {"Status": "Error", "Message": f"Analysis failed: {str(e)}"}
+        return error_html, [], "Analysis Failed", error_label_output, [], empty_df, empty_fig

 # --- Gradio UI Definition (Translated and Enhanced) ---
 theme = gr.themes.Monochrome(
@@ -272,7 +276,7 @@ with gr.Blocks(title="AI Sentence Analyzer XAI 🔍", theme=theme, css=".gradio-
         with gr.Column(scale=2):
             with gr.Accordion("🎯 Prediction Outcome", open=True):
                 output_prediction_summary = gr.Textbox(label="Prediction Summary", lines=2, interactive=False)
-                output_prediction_details = gr.Label(label="
+                output_prediction_details = gr.Label(label="Prediction Details & Confidence")  # label name changed
             with gr.Accordion("⭐ Top-K Important Tokens (Table)", open=True):
                 output_top_tokens_df = gr.DataFrame(headers=["Token", "Score"], label="Most Important Tokens",
                                                     row_count=(1,"dynamic"), col_count=(2,"fixed"), interactive=False, wrap=True)
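Every early return above now yields seven values, so they have to line up, in order, with the seven output components registered on the click handler. A hedged sketch of that wiring; apart from output_prediction_summary, output_prediction_details, and output_top_tokens_df, the component and button names below are hypothetical and only illustrate the arity:

```python
# Hypothetical wiring sketch: tuple length and order must match this outputs list.
analyze_button.click(
    fn=analyze_sentence_for_gradio,
    inputs=[input_sentence, top_k_slider],
    outputs=[output_html, output_highlighted_tokens, output_prediction_summary,
             output_prediction_details, output_token_list,
             output_top_tokens_df, output_pca_plot],
)
```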