Update app.py

app.py CHANGED
@@ -17,12 +17,11 @@ hf_logging.set_verbosity_error()
 
 MODEL_NAME = "bert-base-uncased"
 DEVICE = "cpu"
-SAVE_DIR = "저장저장1"
+SAVE_DIR = "저장저장1"
 LAYER_ID = 4
 SEED = 0
 CLF_NAME = "linear"
 
-# Class label mapping provided by user
 CLASS_LABEL_MAP = {
     0: "World",
     1: "Sports",
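Note: the `-`/`+` pair for `SAVE_DIR` (and the similar pairs for `height=300`, `min_width=300`, and `cache_examples=False` below) shows identical text on both sides; the rendered diff strips leading whitespace, so these are most likely indentation-only changes. The visible `CLASS_LABEL_MAP` entries ("World", "Sports") match the AG News topic labels. A minimal sketch of how such a map is typically consumed (`pred_id` is a hypothetical variable, not code from app.py):

    # Hypothetical lookup: map a predicted class id to its display label,
    # with a fallback for ids missing from the map.
    label = CLASS_LABEL_MAP.get(pred_id, f"Class {pred_id}")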
@@ -32,7 +31,6 @@ CLASS_LABEL_MAP = {
 
 TOKENIZER_GLOBAL, MODEL_GLOBAL = None, None
 W_GLOBAL, MU_GLOBAL, W_P_GLOBAL, B_P_GLOBAL = None, None, None, None
-# CLASS_NAMES_GLOBAL = None # We'll use CLASS_LABEL_MAP instead for clarity
 MODELS_LOADED_SUCCESSFULLY = False
 MODEL_LOADING_ERROR_MESSAGE = ""
 
@@ -73,24 +71,20 @@ except Exception as e:
 
 # Helper function: 3D PCA Visualization using Plotly
 def plot_token_pca_3d_plotly(token_embeddings_3d, tokens, scores, title="Token Embeddings 3D PCA (Colored by Importance)"):
-    num_annotations = min(len(tokens), 20)
-
-    # Ensure scores is a 1D numpy array for Plotly marker color processing
+    num_annotations = min(len(tokens), 20)
     scores_array = np.array(scores).flatten()
-
-    # Prepare text annotations (only for most important tokens to avoid clutter)
     text_annotations = [''] * len(tokens)
     if len(scores_array) > 0 and len(tokens) > 0:
         indices_to_annotate = np.argsort(scores_array)[-num_annotations:]
         for i in indices_to_annotate:
-            if i < len(tokens):
+            if i < len(tokens):
                 text_annotations[i] = tokens[i]
 
     fig = go.Figure(data=[go.Scatter3d(
         x=token_embeddings_3d[:, 0],
         y=token_embeddings_3d[:, 1],
         z=token_embeddings_3d[:, 2],
-        mode='markers+text',
+        mode='markers+text',
         text=text_annotations,
         textfont=dict(size=9, color='#333333'),
         textposition='top center',
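Note: `np.argsort(scores_array)[-num_annotations:]` labels only the highest-scoring tokens, keeping the 3D plot readable. A self-contained sketch of that selection (sample tokens and scores are made up):

    import numpy as np

    tokens = ["the", "striker", "scored", "a", "goal"]
    scores_array = np.array([0.05, 0.90, 0.70, 0.02, 0.95])
    num_annotations = min(len(tokens), 2)

    # argsort is ascending, so the last slice holds the top-scoring indices.
    text_annotations = [""] * len(tokens)
    for i in np.argsort(scores_array)[-num_annotations:]:
        text_annotations[i] = tokens[i]
    print(text_annotations)  # ['', 'striker', '', '', 'goal']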
@@ -98,25 +92,26 @@ def plot_token_pca_3d_plotly(token_embeddings_3d, tokens, scores, title="Token E
             size=6,
             color=scores_array,
             colorscale='RdBu',
-            reversescale=True,
+            reversescale=True,
             opacity=0.8,
             colorbar=dict(title='Importance', tickfont=dict(size=9), len=0.75, yanchor='middle')
         ),
-        hoverinfo='text',
-        hovertext=[f"Token: {t}<br>Score: {s:.3f}" for t, s in zip(tokens, scores_array)]
+        hoverinfo='text',
+        hovertext=[f"Token: {t}<br>Score: {s:.3f}" for t, s in zip(tokens, scores_array)]
     )])
 
     fig.update_layout(
         title=dict(text=title, x=0.5, font=dict(size=16)),
         scene=dict(
-
-
-
+            # Fixed part: include text and font inside the title property
+            xaxis=dict(title=dict(text='PCA Comp 1', font=dict(size=10)), tickfont=dict(size=9), backgroundcolor="rgba(230, 230, 230, 0.8)"),
+            yaxis=dict(title=dict(text='PCA Comp 2', font=dict(size=10)), tickfont=dict(size=9), backgroundcolor="rgba(230, 230, 230, 0.8)"),
+            zaxis=dict(title=dict(text='PCA Comp 3', font=dict(size=10)), tickfont=dict(size=9), backgroundcolor="rgba(230, 230, 230, 0.8)"),
             bgcolor="rgba(255, 255, 255, 0.95)",
-            camera_eye=dict(x=1.5, y=1.5, z=0.5)
+            camera_eye=dict(x=1.5, y=1.5, z=0.5)
         ),
         margin=dict(l=5, r=5, b=5, t=45),
-        paper_bgcolor='rgba(0,0,0,0)'
+        paper_bgcolor='rgba(0,0,0,0)'
     )
     return fig
 
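Note: this is the substantive fix in the commit. The scene axes are rebuilt with the nested title form, `title=dict(text=..., font=...)`, which is the schema current Plotly expects; the older flat `titlefont` style was deprecated and later dropped, a common cause of crashes in upgraded Spaces. A minimal standalone sketch (random points and figure setup are assumptions, not the app's data):

    import numpy as np
    import plotly.graph_objects as go

    pts = np.random.default_rng(0).normal(size=(10, 3))
    fig = go.Figure(go.Scatter3d(x=pts[:, 0], y=pts[:, 1], z=pts[:, 2], mode="markers"))
    fig.update_layout(
        scene=dict(
            # Nested form: the axis title is itself an object with text and font.
            xaxis=dict(title=dict(text="PCA Comp 1", font=dict(size=10))),
            yaxis=dict(title=dict(text="PCA Comp 2", font=dict(size=10))),
            zaxis=dict(title=dict(text="PCA Comp 3", font=dict(size=10))),
        )
    )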
@@ -127,7 +122,7 @@ def create_empty_plotly_figure(message="N/A"):
     fig.update_layout(
         xaxis={'visible': False},
         yaxis={'visible': False},
-        height=300,
+        height=300,
         paper_bgcolor='rgba(0,0,0,0)',
         plot_bgcolor='rgba(0,0,0,0)'
     )
@@ -139,7 +134,8 @@ def analyze_sentence_for_gradio(sentence_text, top_k_value):
         error_html = f"<p style='color:red;'>Initialization Error: {html.escape(MODEL_LOADING_ERROR_MESSAGE)}</p>"
         empty_df = pd.DataFrame(columns=['token', 'score'])
         empty_fig = create_empty_plotly_figure("Model Loading Failed")
-
+        # Fix the error return values for the gr.Label output
+        return error_html, [], "Model Loading Failed", {"Status":"Error", "Message":"Model Loading Failed"}, [], empty_df, empty_fig
 
     try:
         tokenizer, model = TOKENIZER_GLOBAL, MODEL_GLOBAL
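Note: the removed error paths returned nothing or a truncated tuple (the old lines are cut off in the rendered diff), while a Gradio callback must return exactly one value per component in its `outputs` list; here the fixed paths return the full seven values, ending with the empty DataFrame and Plot placeholders. A toy sketch of that contract (the components are illustrative, not the ones defined in app.py):

    import gradio as gr

    def analyze(text):
        if not text.strip():
            # The error path must still match the outputs' arity: one value each.
            return "<p style='color:red;'>empty input</p>", {"Status": "Error"}
        return f"<p>{text}</p>", {"Status": "OK"}

    demo = gr.Interface(fn=analyze, inputs="text", outputs=[gr.HTML(), gr.JSON()])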
@@ -151,7 +147,7 @@ def analyze_sentence_for_gradio(sentence_text, top_k_value):
         if input_ids.shape[1] == 0:
             empty_df = pd.DataFrame(columns=['token', 'score'])
             empty_fig = create_empty_plotly_figure("Invalid Input")
-            return "<p style='color:orange;'>Input Error: No valid tokens found.</p>", [], "Input Error", {"Error":"Input
+            return "<p style='color:orange;'>Input Error: No valid tokens found.</p>", [], "Input Error", {"Status":"Error", "Message":"Invalid Input"}, [], empty_df, empty_fig
 
         input_embeds_detached = model.embeddings.word_embeddings(input_ids).clone().detach()
         input_embeds_for_grad = input_embeds_detached.clone().requires_grad_(True)
@@ -170,12 +166,12 @@ def analyze_sentence_for_gradio(sentence_text, top_k_value):
         if input_embeds_for_grad.grad is None:
             empty_df = pd.DataFrame(columns=['token', 'score'])
             empty_fig = create_empty_plotly_figure("Gradient Error")
-            return "<p style='color:red;'>Analysis Error: Gradient calculation failed.</p>", [],"Analysis Error", {"Error":"
+            return "<p style='color:red;'>Analysis Error: Gradient calculation failed.</p>", [],"Analysis Error", {"Status":"Error", "Message":"Gradient Error"}, [], empty_df, empty_fig
 
         grads = input_embeds_for_grad.grad.clone().detach()
         scores = (grads * input_embeds_detached).norm(dim=2).squeeze(0)
         scores_np = scores.cpu().numpy()
-        valid_scores_for_norm = scores_np[np.isfinite(scores_np)]
+        valid_scores_for_norm = scores_np[np.isfinite(scores_np)]
         scores_np = scores_np / (valid_scores_for_norm.max() + 1e-9) if len(valid_scores_for_norm) > 0 and valid_scores_for_norm.max() > 0 else np.zeros_like(scores_np)
 
         tokens_raw = tokenizer.convert_ids_to_tokens(input_ids[0], skip_special_tokens=False)
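Note: the scoring above is gradient-times-input saliency: per token, the L2 norm of gradient * embedding, then max-normalized to [0, 1] using only the finite values, which guards against NaN/inf leaking out of the backward pass. A numpy-only sketch with stand-in shapes (batch=1, 5 tokens, 8 dims):

    import numpy as np

    grads = np.random.default_rng(0).normal(size=(1, 5, 8))
    embeds = np.random.default_rng(1).normal(size=(1, 5, 8))
    scores_np = np.linalg.norm(grads * embeds, axis=2).squeeze(0)

    valid = scores_np[np.isfinite(scores_np)]
    scores_np = scores_np / (valid.max() + 1e-9) if valid.size and valid.max() > 0 else np.zeros_like(scores_np)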
@@ -196,7 +192,7 @@ def analyze_sentence_for_gradio(sentence_text, top_k_value):
             html_tokens_list.append(f"<span style='font-weight:bold;'>{html.escape(clean_tok_str)}</span>")
             highlighted_text_data.append((clean_tok_str + " ", None))
         else:
-            color = f"rgba(220, 50, 50, {current_score_clipped:.2f})"
+            color = f"rgba(220, 50, 50, {current_score_clipped:.2f})"
             html_tokens_list.append(f"<span style='background-color:{color}; color:white; padding: 1px 3px; margin: 1px; border-radius: 4px; display:inline-block;'>{html.escape(clean_tok_str)}</span>")
             highlighted_text_data.append((clean_tok_str + " ", round(current_score_clipped, 3)))
 
@@ -227,7 +223,7 @@ def analyze_sentence_for_gradio(sentence_text, top_k_value):
         pca_tokens = [actual_tokens[i] for i in non_special_token_indices]
         if len(pca_tokens) > 0:
             pca_embeddings = actual_input_embeds[non_special_token_indices, :]
-            pca_scores_for_plot = actual_scores_np[non_special_token_indices]
+            pca_scores_for_plot = actual_scores_np[non_special_token_indices]
 
             pca = PCA(n_components=3, random_state=SEED)
             token_embeddings_3d = pca.fit_transform(pca_embeddings)
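Note: PCA is fit on the non-special tokens only, so the [CLS]/[SEP] embeddings don't dominate the projection axes. A minimal sketch with stand-in embeddings (768 dims, as for bert-base-uncased):

    import numpy as np
    from sklearn.decomposition import PCA

    SEED = 0
    embeddings = np.random.default_rng(SEED).normal(size=(12, 768))
    coords_3d = PCA(n_components=3, random_state=SEED).fit_transform(embeddings)
    print(coords_3d.shape)  # (12, 3)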
@@ -245,10 +241,10 @@ def analyze_sentence_for_gradio(sentence_text, top_k_value):
         print(f"analyze_sentence_for_gradio error: {e}\n{tb_str}")
         empty_df = pd.DataFrame(columns=['token', 'score'])
         empty_fig = create_empty_plotly_figure("Analysis Error")
-
+        # Fix the error return values for the gr.Label output
+        return error_html, [], "Analysis Failed", {"Status":"Error", "Message": str(e)}, [], empty_df, empty_fig
 
 # --- Gradio UI Definition (Translated and Enhanced) ---
-# Using a built-in theme and some CSS for aesthetics
 theme = gr.themes.Monochrome(
     primary_hue=gr.themes.colors.blue,
     secondary_hue=gr.themes.colors.sky,
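Note: this second fix mirrors the first: the `except` handler now returns the same seven-element shape as the success path, surfacing `str(e)` in the JSON status output. `error_html` must be assigned earlier in the handler, in context lines this hunk does not show.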
@@ -260,14 +256,13 @@ theme = gr.themes.Monochrome(
     button_primary_text_color="white",
 )
 
-
 with gr.Blocks(title="AI Sentence Analyzer XAI 🔍", theme=theme, css=".gradio-container {max-width: 98% !important;}") as demo:
     gr.Markdown("# 🔍 AI Sentence Analyzer XAI: Exploring Model Explanations")
     gr.Markdown("Analyze English sentences to understand BERT model predictions through various XAI visualization techniques. "
                 "Explore token importance and their distribution in the embedding space.")
 
     with gr.Row(equal_height=False):
-        with gr.Column(scale=1, min_width=350):
+        with gr.Column(scale=1, min_width=350):
             with gr.Group():
                 gr.Markdown("### ✍️ Input Sentence & Settings")
                 input_sentence = gr.Textbox(lines=5, label="English Sentence to Analyze", placeholder="Enter the English sentence you want to analyze here...")
@@ -289,10 +284,6 @@ with gr.Blocks(title="AI Sentence Analyzer XAI 🔍", theme=theme, css=".gradio-
                     output_highlighted_text = gr.HighlightedText(
                         label="Token Importance (Score: 0-1)",
                         show_legend=True,
-                        # Color map can be more sophisticated if scores are categorical
-                        # For numerical scores (0-1), Gradio tries to infer intensity.
-                        # Example color map (if scores were categories like "LOW", "MEDIUM", "HIGH"):
-                        # color_map={"LOW": "lightblue", "MEDIUM": "lightgreen", "HIGH": "pink"},
                         combine_adjacent=False
                     )
                 with gr.TabItem("📊 Top-K Bar Plot", id=2):
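Note: with numeric values, `gr.HighlightedText` consumes `(text, value)` pairs, where `None` leaves a token unstyled and a 0-1 float is rendered as highlight intensity; the dropped comments described an alternative categorical `color_map`. A sketch of the value shape (sample data made up):

    highlighted = [("the ", None), ("striker ", 0.91), ("scored ", 0.73)]
    # e.g. gr.HighlightedText(value=highlighted, combine_adjacent=False)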
@@ -301,8 +292,7 @@ with gr.Blocks(title="AI Sentence Analyzer XAI 🔍", theme=theme, css=".gradio-
                         x="token",
                         y="score",
                         tooltip=['token', 'score'],
-                        min_width=300
-                        # title="Top-K Most Important Tokens" # BarPlot may not have a direct title prop
+                        min_width=300
                     )
                 with gr.TabItem("🌐 Token Embeddings 3D PCA (Interactive)", id=3):
                     output_pca_plot = gr.Plot(label="3D PCA of Token Embeddings (Colored by Importance Score)")
@@ -322,7 +312,7 @@ with gr.Blocks(title="AI Sentence Analyzer XAI 🔍", theme=theme, css=".gradio-
             output_pca_plot
         ],
         fn=analyze_sentence_for_gradio,
-        cache_examples=False
+        cache_examples=False
     )
     gr.HTML("<p style='text-align: center; color: #4a5568;'>Explainable AI Demo powered by Gradio & Hugging Face Transformers</p>")
 
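Note: with `cache_examples=False`, clicking an example runs `analyze_sentence_for_gradio` live rather than serving precomputed outputs, so the gradient-based analysis is not executed for every example at startup, at the cost of a short wait on first click.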