# Source page metadata (residue from the Hugging Face Hub file view):
# kikikara's picture
# Update app.py
# 6e12229 verified
# raw
# history blame
# 17.2 kB
import gradio as gr
import os
import joblib
import torch
import numpy as np
import html
from transformers import AutoTokenizer, AutoModel, logging as hf_logging
import pandas as pd
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
import plotly.graph_objects as go
# --- Global Settings and Model Loading ---
# Runs once at import time. On success the *_GLOBAL names hold the tokenizer,
# the BERT model and the LDA / classifier weights as float32 tensors, and
# MODELS_LOADED_SUCCESSFULLY is True. On any failure the app still starts and
# MODEL_LOADING_ERROR_MESSAGE describes the problem.
hf_logging.set_verbosity_error()
MODEL_NAME = "bert-base-uncased"
DEVICE = "cpu"
# Folder holding the trained LDA / classifier pickles. The literal must be the
# real UTF-8 Korean folder name; a mis-decoded copy of it points at a path that
# cannot exist, so os.path.isdir() below would always fail.
SAVE_DIR = "저장저장1"
LAYER_ID = 4  # entry of the model's hidden_states used for the [CLS] vector
SEED = 0
CLF_NAME = "linear"
# Classifier output index -> display name.
CLASS_LABEL_MAP = {
    0: "World",
    1: "Sports",
    2: "Business",
    3: "Sci/Tech",
}
TOKENIZER_GLOBAL, MODEL_GLOBAL = None, None
W_GLOBAL, MU_GLOBAL, W_P_GLOBAL, B_P_GLOBAL = None, None, None, None
MODELS_LOADED_SUCCESSFULLY = False
MODEL_LOADING_ERROR_MESSAGE = ""
try:
    print("Gradio App: Initializing model loading...")
    lda_file_path = os.path.join(SAVE_DIR, f"lda_layer{LAYER_ID}_seed{SEED}.pkl")
    clf_file_path = os.path.join(SAVE_DIR, f"{CLF_NAME}_layer{LAYER_ID}_projlda_seed{SEED}.pkl")
    if not os.path.isdir(SAVE_DIR):
        raise FileNotFoundError(f"Error: Model storage directory '{SAVE_DIR}' not found.")
    if not os.path.exists(lda_file_path):
        raise FileNotFoundError(f"Error: LDA model file '{lda_file_path}' not found.")
    if not os.path.exists(clf_file_path):
        raise FileNotFoundError(f"Error: Classifier model file '{clf_file_path}' not found.")

    # joblib.load unpickles arbitrary objects: only point it at trusted, bundled files.
    lda = joblib.load(lda_file_path)
    clf = joblib.load(clf_file_path)
    # Unwrap a meta-estimator to reach the fitted linear model, but only when the
    # wrapped object really carries weights: `base_estimator` can be a placeholder
    # (e.g. the string "deprecated" in some scikit-learn releases) or an unfitted
    # template, neither of which has `coef_`.
    inner_estimator = getattr(clf, "base_estimator", None)
    if hasattr(inner_estimator, "coef_"):
        clf = inner_estimator

    # Projection used by the analysis: z = (cls_vec - MU) @ W, logits = z @ W_P.T + B_P.
    # NOTE(review): LinearDiscriminantAnalysis only defines `xbar_` for the 'svd'
    # solver — confirm the saved LDA was fitted with it.
    W_GLOBAL = torch.tensor(lda.scalings_, dtype=torch.float32, device=DEVICE)
    MU_GLOBAL = torch.tensor(lda.xbar_, dtype=torch.float32, device=DEVICE)
    W_P_GLOBAL = torch.tensor(clf.coef_, dtype=torch.float32, device=DEVICE)
    B_P_GLOBAL = torch.tensor(clf.intercept_, dtype=torch.float32, device=DEVICE)

    TOKENIZER_GLOBAL = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
    MODEL_GLOBAL = AutoModel.from_pretrained(
        MODEL_NAME, output_hidden_states=True, output_attentions=False
    ).to(DEVICE).eval()
    # Attributions only need gradients w.r.t. the input embeddings. Freezing the
    # weights stops every backward pass from allocating and accumulating .grad on
    # the shared model parameters. Input gradients are unaffected.
    MODEL_GLOBAL.requires_grad_(False)
    MODELS_LOADED_SUCCESSFULLY = True
    print("Gradio App: All models and data loaded successfully!")
except Exception as e:
    # Deliberately broad: the UI still launches and shows this message instead of crashing.
    MODELS_LOADED_SUCCESSFULLY = False
    MODEL_LOADING_ERROR_MESSAGE = f"Critical error during model loading: {str(e)}\nPlease ensure the '{SAVE_DIR}' folder and its contents are correct."
    print(MODEL_LOADING_ERROR_MESSAGE)
# Helper function: 3D PCA Visualization using Plotly
def plot_token_pca_3d_plotly(token_embeddings_3d, tokens, scores,
                             title="Token Embeddings 3D PCA (Colored by Importance)",
                             max_annotations=20):
    """Build an interactive 3D scatter of token embeddings coloured by importance.

    Args:
        token_embeddings_3d: Array indexed as ``[:, 0]``, ``[:, 1]``, ``[:, 2]``,
            one row per token (e.g. the output of a 3-component PCA).
        tokens: Token strings aligned with the rows of ``token_embeddings_3d``.
        scores: Per-token importance scores; flattened and used as marker colour.
        title: Figure title.
        max_annotations: How many of the highest-scoring tokens get a visible text
            label next to their marker. Every token still gets a hover label.
            Defaults to 20, the previously hard-coded value.

    Returns:
        A ``plotly.graph_objects.Figure``.
    """
    scores_array = np.asarray(scores).flatten()
    text_annotations = [""] * len(tokens)
    num_annotations = min(len(tokens), max(0, int(max_annotations)))
    # num_annotations must be > 0: argsort(...)[-0:] would select *every* index.
    if num_annotations > 0 and scores_array.size > 0:
        for i in np.argsort(scores_array)[-num_annotations:]:
            if i < len(tokens):  # tolerate more scores than tokens
                text_annotations[i] = tokens[i]
    fig = go.Figure(data=[go.Scatter3d(
        x=token_embeddings_3d[:, 0],
        y=token_embeddings_3d[:, 1],
        z=token_embeddings_3d[:, 2],
        mode='markers+text',
        text=text_annotations,
        textfont=dict(size=9, color='#333333'),
        textposition='top center',
        marker=dict(
            size=6,
            color=scores_array,
            colorscale='RdBu',
            reversescale=True,
            opacity=0.8,
            colorbar=dict(title='Importance', tickfont=dict(size=9), len=0.75, yanchor='middle'),
        ),
        hoverinfo='text',
        hovertext=[f"Token: {t}<br>Score: {s:.3f}" for t, s in zip(tokens, scores_array)],
    )])
    fig.update_layout(
        title=dict(text=title, x=0.5, font=dict(size=16)),
        scene=dict(
            xaxis=dict(title=dict(text='PCA Comp 1', font=dict(size=10)), tickfont=dict(size=9), backgroundcolor="rgba(230, 230, 230, 0.8)"),
            yaxis=dict(title=dict(text='PCA Comp 2', font=dict(size=10)), tickfont=dict(size=9), backgroundcolor="rgba(230, 230, 230, 0.8)"),
            zaxis=dict(title=dict(text='PCA Comp 3', font=dict(size=10)), tickfont=dict(size=9), backgroundcolor="rgba(230, 230, 230, 0.8)"),
            bgcolor="rgba(255, 255, 255, 0.95)",
            camera_eye=dict(x=1.5, y=1.5, z=0.5),
        ),
        margin=dict(l=5, r=5, b=5, t=45),
        paper_bgcolor='rgba(0,0,0,0)',
    )
    return fig
# Helper function: Create an empty Plotly figure for placeholders
def create_empty_plotly_figure(message="N/A"):
    """Return a blank, axis-less Plotly figure with *message* centred on it.

    Used as the placeholder for plot outputs on error paths or when there is
    nothing to plot.

    Args:
        message: Text to display. Newline characters are converted to ``<br>``.
            Plotly annotation text only breaks lines on ``<br>`` tags, so a raw
            newline would otherwise be rendered on a single line.

    Returns:
        A 300px-high ``plotly.graph_objects.Figure`` with a transparent background.
    """
    fig = go.Figure()
    fig.add_annotation(
        text=str(message).replace("\n", "<br>"),
        xref="paper", yref="paper", x=0.5, y=0.5,
        showarrow=False, font=dict(size=12, color="grey"),
    )
    fig.update_layout(
        xaxis={'visible': False},
        yaxis={'visible': False},
        height=300,
        paper_bgcolor='rgba(0,0,0,0)',
        plot_bgcolor='rgba(0,0,0,0)',
    )
    return fig
# --- Helpers for the core analysis function ---
def _error_outputs(html_message, summary_text, label_text, figure_message):
    """Return the 7 UI outputs for a failure path, with empty placeholders.

    ``label_text`` is a plain string. gr.Label accepts a string, whereas a dict
    must map class names to numeric confidences.
    """
    return (html_message, [], summary_text, label_text, [],
            pd.DataFrame(columns=['token', 'score']),
            create_empty_plotly_figure(figure_message))


def _input_error_outputs():
    """Return the UI outputs for input that yields nothing to analyse."""
    return _error_outputs("<p style='color:orange;'>Input Error: No valid tokens found.</p>",
                          "Input Error", "Error: Invalid input, no valid tokens.", "Invalid Input")


def _gradient_x_input(input_ids, attn_mask):
    """Run BERT -> LDA projection -> linear head and differentiate the top logit.

    Returns ``(probs, pred_idx, grads, input_embeds)``:

    * ``probs``: 1-D softmax over classes for the single sentence.
    * ``pred_idx``: the argmax of ``probs``.
    * ``grads``: the gradient of that class's logit w.r.t. the word embeddings,
      or ``None`` if the logit does not depend on them.
    * ``input_embeds``: the detached word embeddings the gradient was taken at.
    """
    model = MODEL_GLOBAL
    input_embeds = model.embeddings.word_embeddings(input_ids).detach()
    embeds_for_grad = input_embeds.clone().requires_grad_(True)
    outputs = model(inputs_embeds=embeds_for_grad, attention_mask=attn_mask,
                    output_hidden_states=True, output_attentions=False)
    cls_vec = outputs.hidden_states[LAYER_ID][:, 0, :]
    logits = ((cls_vec - MU_GLOBAL) @ W_GLOBAL) @ W_P_GLOBAL.T + B_P_GLOBAL
    probs = torch.softmax(logits, dim=1)[0]
    pred_idx = int(torch.argmax(probs).item())
    # autograd.grad returns the input gradient directly. Unlike backward(), it
    # never writes or accumulates .grad on the shared model parameters.
    (grads,) = torch.autograd.grad(logits[0, pred_idx], embeds_for_grad, allow_unused=True)
    return probs.detach(), pred_idx, grads, input_embeds


def _render_token_highlights(tokens, token_ids, scores, special_ids):
    """Render tokens as custom HTML spans and as gr.HighlightedText tuples.

    Special tokens are shown in bold with no score. Other tokens get a red
    background whose alpha is the score clipped to [0, 1]. WordPiece
    continuation pieces ("##xx") lose their "##" and are glued to the previous
    piece, with no separating space.
    """
    html_parts, highlighted = [], []
    n_tokens = len(tokens)
    for i, (tok_str, tok_id) in enumerate(zip(tokens, token_ids)):
        clean_tok_str = tok_str[2:] if tok_str.startswith("##") else tok_str
        next_is_piece = i + 1 < n_tokens and tokens[i + 1].startswith("##")
        sep = "" if next_is_piece else " "
        score_clipped = float(max(0, min(1, scores[i])))
        if tok_id in special_ids:
            html_parts.append(f"<span style='font-weight:bold;'>{html.escape(clean_tok_str)}</span>")
            highlighted.append((clean_tok_str + sep, None))
        else:
            color = f"rgba(220, 50, 50, {score_clipped:.2f})"
            html_parts.append(f"<span style='background-color:{color}; color:white; padding: 1px 3px; margin: 1px; border-radius: 4px; display:inline-block;'>{html.escape(clean_tok_str)}</span>")
            highlighted.append((clean_tok_str + sep, round(score_clipped, 3)))
        if i + 1 < n_tokens:
            html_parts.append(sep)
    return "".join(html_parts), highlighted


def _top_k_tokens(tokens, scores, candidate_indices, k):
    """Return the *k* highest-scoring candidates as table rows and as a bar-plot DataFrame."""
    ranked = sorted(candidate_indices, key=lambda idx: -scores[idx])[:max(0, k)]
    rows = [[tokens[i], f"{scores[i]:.3f}"] for i in ranked]
    if not ranked:
        return rows, pd.DataFrame(columns=['token', 'score'])
    return rows, pd.DataFrame([{"token": tokens[i], "score": scores[i]} for i in ranked])


def _pca_figure(tokens, embeddings, scores, indices):
    """Build a 3D PCA scatter of the tokens at *indices*, or a placeholder if there are fewer than 3."""
    if len(indices) < 3:
        return create_empty_plotly_figure("PCA Plot N/A\n(Not enough non-special tokens for 3D)")
    pca = PCA(n_components=3, random_state=SEED)
    token_embeddings_3d = pca.fit_transform(embeddings[indices, :])
    return plot_token_pca_3d_plotly(token_embeddings_3d, [tokens[i] for i in indices], scores[indices])


# --- Core Analysis Function (returns 7 items for Gradio UI) ---
def analyze_sentence_for_gradio(sentence_text, top_k_value):
    """Explain the classifier's prediction for one sentence.

    Pipeline:

    1. Tokenize the sentence.
    2. Run BERT from its word embeddings.
    3. Take the [CLS] vector of ``hidden_states[LAYER_ID]``.
    4. Centre it by ``MU_GLOBAL`` and project it with ``W_GLOBAL``.
    5. Apply the linear head (``W_P_GLOBAL``, ``B_P_GLOBAL``), then softmax.

    Token importance is the L2 norm, over the embedding dimension, of
    gradient x input, where the gradient is that of the predicted class's
    logit w.r.t. the word embeddings. Scores are divided by the largest
    finite score.

    Args:
        sentence_text: Sentence from the textbox.
        top_k_value: Number of top tokens to list (slider value; coerced to int).

    Returns:
        A 7-tuple in UI output order: HTML highlight string, HighlightedText
        data, prediction summary text, gr.Label value ({class: probability}),
        top-k table rows, bar-plot DataFrame, Plotly figure. On failure the
        same shape is returned with placeholders, and the gr.Label slot holds
        a plain error string.
    """
    if not MODELS_LOADED_SUCCESSFULLY:
        error_html = f"<p style='color:red;'>Initialization Error: {html.escape(MODEL_LOADING_ERROR_MESSAGE)}</p>"
        return _error_outputs(error_html, "Model Loading Failed",
                              "Error: Model Loading Failed. Check logs.", "Model Loading Failed")
    # A blank textbox would otherwise be analysed as just [CLS] [SEP].
    if sentence_text is None or not str(sentence_text).strip():
        return _input_error_outputs()
    try:
        tokenizer = TOKENIZER_GLOBAL
        enc = tokenizer(sentence_text, return_tensors="pt", truncation=True, max_length=510, padding=True)
        input_ids = enc["input_ids"].to(DEVICE)
        attn_mask = enc["attention_mask"].to(DEVICE)
        if input_ids.shape[1] == 0:
            return _input_error_outputs()

        probs, pred_idx, grads, input_embeds = _gradient_x_input(input_ids, attn_mask)
        if grads is None:
            return _error_outputs("<p style='color:red;'>Analysis Error: Gradient calculation failed.</p>",
                                  "Analysis Error", "Error: Gradient calculation failed.", "Gradient Error")

        scores_np = (grads * input_embeds).norm(dim=2).squeeze(0).cpu().numpy()
        finite_scores = scores_np[np.isfinite(scores_np)]
        if finite_scores.size > 0 and finite_scores.max() > 0:
            scores_np = scores_np / (finite_scores.max() + 1e-9)
        else:
            scores_np = np.zeros_like(scores_np)

        ids = input_ids[0].tolist()
        tokens_raw = tokenizer.convert_ids_to_tokens(ids, skip_special_tokens=False)
        # Drop padding. Slicing scores and embeddings to the same length assumes
        # any padding sits at the end of the sequence.
        actual_tokens = [tok for tok, tid in zip(tokens_raw, ids) if tid != tokenizer.pad_token_id]
        n_tokens = len(actual_tokens)
        actual_ids = ids[:n_tokens]
        actual_scores_np = scores_np[:n_tokens]
        actual_input_embeds = input_embeds[0, :n_tokens, :].cpu().numpy()
        special_ids = {tokenizer.cls_token_id, tokenizer.sep_token_id}
        content_indices = [i for i, tid in enumerate(actual_ids) if tid not in special_ids]

        html_output_str, highlighted_text_data = _render_token_highlights(
            actual_tokens, actual_ids, actual_scores_np, special_ids)
        top_tokens_for_df, barplot_df = _top_k_tokens(
            actual_tokens, actual_scores_np, content_indices, int(top_k_value))

        class_probs = probs.tolist()
        pred_prob_val = class_probs[pred_idx]
        predicted_class_label_str = CLASS_LABEL_MAP.get(pred_idx, f"Unknown Index ({pred_idx})")
        prediction_summary_text = f"Predicted Class: {predicted_class_label_str}\nProbability: {pred_prob_val:.3f}"
        # gr.Label takes {class name: probability}. Passing every class shows the
        # full confidence distribution, with the predicted class on top.
        prediction_details_for_label = {
            CLASS_LABEL_MAP.get(i, f"Unknown Index ({i})"): round(p, 3)
            for i, p in enumerate(class_probs)
        }
        pca_fig = _pca_figure(actual_tokens, actual_input_embeds, actual_scores_np, content_indices)
        return (html_output_str, highlighted_text_data,
                prediction_summary_text, prediction_details_for_label,
                top_tokens_for_df, barplot_df,
                pca_fig)
    except Exception as e:
        # Top-level UI boundary: report the failure in the UI instead of raising.
        import traceback
        tb_str = traceback.format_exc()
        print(f"analyze_sentence_for_gradio error: {e}\n{tb_str}")
        # NOTE(review): the full traceback is also shown to end users; consider
        # keeping it server-side only on a public deployment.
        error_html = f"<p style='color:red;'>Analysis Error: {html.escape(str(e))}</p><pre>{html.escape(tb_str)}</pre>"
        return _error_outputs(error_html, "Analysis Failed",
                              f"Error: Analysis failed: {str(e)}", "Analysis Error")
# --- Gradio UI Definition (Translated and Enhanced) ---
theme = gr.themes.Monochrome(
    primary_hue=gr.themes.colors.blue,
    secondary_hue=gr.themes.colors.sky,
    neutral_hue=gr.themes.colors.slate,
).set(
    body_background_fill="#f0f2f6",
    block_shadow="*shadow_drop_lg",
    button_primary_background_fill="*primary_500",
    button_primary_text_color="white",
)

# Example rows: [sentence, top-k]. The classifier's label set (CLASS_LABEL_MAP)
# is news topics, so the examples are news-style sentences written to resemble
# World, Sports, Business and Sci/Tech respectively.
EXAMPLE_INPUTS = [
    ["Leaders from dozens of countries met in Geneva for emergency talks on the refugee crisis.", 5],
    ["The striker scored twice as the home team clinched the championship on Sunday.", 3],
    ["Shares of the automaker fell sharply after it cut its annual profit forecast.", 4],
    ["Researchers unveiled a new processor that could make smartphones twice as fast.", 5],
]

with gr.Blocks(title="AI Sentence Analyzer XAI 🚀", theme=theme,
               css=".gradio-container {max-width: 98% !important;}") as demo:
    gr.Markdown("# 🚀 AI Sentence Analyzer XAI: Exploring Model Explanations")
    gr.Markdown("Analyze English sentences to understand BERT model predictions through various XAI visualization techniques. "
                "Explore token importance and their distribution in the embedding space.")
    with gr.Row(equal_height=False):
        # Left column: inputs.
        with gr.Column(scale=1, min_width=350):
            with gr.Group():
                gr.Markdown("### ✏️ Input Sentence & Settings")
                input_sentence = gr.Textbox(lines=5, label="English Sentence to Analyze",
                                            placeholder="Enter the English sentence you want to analyze here...")
                input_top_k = gr.Slider(minimum=1, maximum=10, value=5, step=1,
                                        label="Number of Top-K Tokens")
            submit_button = gr.Button("Analyze Sentence 💫", variant="primary")
        # Right column: prediction, top-k table and the visualisation tabs.
        with gr.Column(scale=2):
            with gr.Accordion("🎯 Prediction Outcome", open=True):
                output_prediction_summary = gr.Textbox(label="Prediction Summary", lines=2, interactive=False)
                output_prediction_details = gr.Label(label="Prediction Details & Confidence")
            with gr.Accordion("⭐ Top-K Important Tokens (Table)", open=True):
                output_top_tokens_df = gr.DataFrame(headers=["Token", "Score"], label="Most Important Tokens",
                                                    row_count=(1, "dynamic"), col_count=(2, "fixed"),
                                                    interactive=False, wrap=True)
            with gr.Tabs():
                with gr.TabItem("🎨 HTML Highlight (Custom)", id=0):
                    output_html_visualization = gr.HTML(label="Token Importance (Gradient x Input based)")
                with gr.TabItem("🖍️ Highlighted Text (Gradio)", id=1):
                    output_highlighted_text = gr.HighlightedText(
                        label="Token Importance (Score: 0-1)",
                        show_legend=True,
                        combine_adjacent=False,
                    )
                with gr.TabItem("📊 Top-K Bar Plot", id=2):
                    output_top_tokens_barplot = gr.BarPlot(
                        label="Top-K Token Importance Scores",
                        x="token",
                        y="score",
                        tooltip=['token', 'score'],
                        min_width=300,
                    )
                with gr.TabItem("🌐 Token Embeddings 3D PCA (Interactive)", id=3):
                    output_pca_plot = gr.Plot(label="3D PCA of Token Embeddings (Colored by Importance Score)")
    gr.Markdown("---")
    # Order must match the 7-tuple returned by analyze_sentence_for_gradio.
    analysis_outputs = [
        output_html_visualization, output_highlighted_text,
        output_prediction_summary, output_prediction_details,
        output_top_tokens_df, output_top_tokens_barplot,
        output_pca_plot,
    ]
    gr.Examples(
        examples=EXAMPLE_INPUTS,
        inputs=[input_sentence, input_top_k],
        outputs=analysis_outputs,
        fn=analyze_sentence_for_gradio,
        cache_examples=False,
    )
    gr.HTML("<p style='text-align: center; color: #4a5568;'>Explainable AI Demo powered by Gradio & Hugging Face Transformers</p>")
    submit_button.click(
        fn=analyze_sentence_for_gradio,
        inputs=[input_sentence, input_top_k],
        outputs=analysis_outputs,
        api_name="explain_sentence_xai",
    )
if __name__ == "__main__":
    # Launch the UI even when loading failed, but warn loudly on the console first;
    # in that case every analysis request returns the loading error.
    if not MODELS_LOADED_SUCCESSFULLY:
        rule = "*" * 80
        warning_lines = (
            rule,
            f"WARNING: Models failed to load! {MODEL_LOADING_ERROR_MESSAGE}",
            "The Gradio UI will be displayed, but analysis will fail.",
            rule,
        )
        for warning_line in warning_lines:
            print(warning_line)
    demo.launch()