import gradio as gr
import os
import joblib
import torch
import numpy as np
import html
from transformers import AutoTokenizer, AutoModel, logging as hf_logging
import pandas as pd
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
import plotly.graph_objects as go
# --- Global Settings and Model Loading ---
hf_logging.set_verbosity_error()  # only surface errors from transformers

MODEL_NAME = "bert-base-uncased"
DEVICE = "cpu"
SAVE_DIR = "저장저장1"  # folder holding the pickled LDA and classifier
LAYER_ID = 4  # index into the model's hidden_states tuple used by the analysis
SEED = 0
CLF_NAME = "linear"
CLASS_LABEL_MAP = {
    0: "World",
    1: "Sports",
    2: "Business",
    3: "Sci/Tech",
}

# Filled in by the loader below. analyze_sentence_for_gradio checks
# MODELS_LOADED_SUCCESSFULLY before using any of them.
TOKENIZER_GLOBAL, MODEL_GLOBAL = None, None
W_GLOBAL, MU_GLOBAL, W_P_GLOBAL, B_P_GLOBAL = None, None, None, None
MODELS_LOADED_SUCCESSFULLY = False
MODEL_LOADING_ERROR_MESSAGE = ""

try:
    print("Gradio App: Initializing model loading...")
    lda_file_path = os.path.join(SAVE_DIR, f"lda_layer{LAYER_ID}_seed{SEED}.pkl")
    clf_file_path = os.path.join(SAVE_DIR, f"{CLF_NAME}_layer{LAYER_ID}_projlda_seed{SEED}.pkl")
    if not os.path.isdir(SAVE_DIR):
        raise FileNotFoundError(f"Error: Model storage directory '{SAVE_DIR}' not found.")
    if not os.path.exists(lda_file_path):
        raise FileNotFoundError(f"Error: LDA model file '{lda_file_path}' not found.")
    if not os.path.exists(clf_file_path):
        raise FileNotFoundError(f"Error: Classifier model file '{clf_file_path}' not found.")

    # NOTE(review): joblib.load unpickles arbitrary objects -- only point
    # SAVE_DIR at trusted artifacts.
    lda = joblib.load(lda_file_path)
    clf = joblib.load(clf_file_path)

    # If the loaded classifier is a wrapper without linear weights of its own,
    # unwrap it. Newer scikit-learn calls the wrapped model `estimator`; older
    # releases used `base_estimator`, which may hold a placeholder string such
    # as "deprecated" instead of a model, so only accept objects with coef_.
    if not hasattr(clf, "coef_"):
        for attr_name in ("estimator", "base_estimator"):
            inner = getattr(clf, attr_name, None)
            if hasattr(inner, "coef_"):
                clf = inner
                break

    coef = np.asarray(clf.coef_)
    # The classifier consumes LDA-projected features. `scalings_` may carry more
    # columns than LDA.transform() keeps, so keep exactly as many projection
    # directions as the classifier has input features (no-op when they match).
    scalings = np.ascontiguousarray(np.asarray(lda.scalings_)[:, : coef.shape[1]])

    W_GLOBAL = torch.tensor(scalings, dtype=torch.float32, device=DEVICE)
    MU_GLOBAL = torch.tensor(lda.xbar_, dtype=torch.float32, device=DEVICE)
    W_P_GLOBAL = torch.tensor(coef, dtype=torch.float32, device=DEVICE)
    B_P_GLOBAL = torch.tensor(clf.intercept_, dtype=torch.float32, device=DEVICE)

    TOKENIZER_GLOBAL = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
    MODEL_GLOBAL = AutoModel.from_pretrained(
        MODEL_NAME, output_hidden_states=True, output_attentions=False
    ).to(DEVICE).eval()
    MODELS_LOADED_SUCCESSFULLY = True
    print("Gradio App: All models and data loaded successfully!")
except Exception as e:
    # Deliberately non-fatal: the UI still starts and reports this message.
    MODELS_LOADED_SUCCESSFULLY = False
    MODEL_LOADING_ERROR_MESSAGE = f"Critical error during model loading: {str(e)}\nPlease ensure the '{SAVE_DIR}' folder and its contents are correct."
    print(MODEL_LOADING_ERROR_MESSAGE)
# Helper function: 3D PCA Visualization using Plotly
def plot_token_pca_3d_plotly(
    token_embeddings_3d,
    tokens,
    scores,
    title="Token Embeddings 3D PCA (Colored by Importance)",
    max_annotations=20,
):
    """Build an interactive 3D scatter of PCA-projected token embeddings.

    Args:
        token_embeddings_3d: Array indexed as ``[:, 0]``, ``[:, 1]``, ``[:, 2]``;
            row i is the point for ``tokens[i]`` (presumably shape
            (n_tokens, 3) -- confirm at call sites).
        tokens: Token strings, one per row.
        scores: Per-token importance values; flattened and used for marker
            color and hover text.
        title: Figure title.
        max_annotations: At most this many of the highest-scoring tokens get
            an on-canvas text label; every point still has hover text.

    Returns:
        A ``plotly.graph_objects.Figure``.
    """
    scores_array = np.array(scores).flatten()
    n_tokens = len(tokens)
    num_annotations = max(0, min(n_tokens, int(max_annotations)))

    # Label only the top-scoring tokens to keep the 3D view readable.
    text_annotations = [""] * n_tokens
    # Guard on num_annotations > 0: argsort(...)[-0:] would select *every* index.
    if num_annotations > 0 and scores_array.size > 0:
        for i in np.argsort(scores_array)[-num_annotations:]:
            if i < n_tokens:
                text_annotations[i] = tokens[i]

    # Plotly hover text uses HTML-style markup, so the line break is <br>.
    hover_text = [f"Token: {t}<br>Score: {s:.3f}" for t, s in zip(tokens, scores_array)]

    fig = go.Figure(data=[go.Scatter3d(
        x=token_embeddings_3d[:, 0],
        y=token_embeddings_3d[:, 1],
        z=token_embeddings_3d[:, 2],
        mode="markers+text",
        text=text_annotations,
        textfont=dict(size=9, color="#333333"),
        textposition="top center",
        marker=dict(
            size=6,
            color=scores_array,
            colorscale="RdBu",
            reversescale=True,
            opacity=0.8,
            colorbar=dict(title="Importance", tickfont=dict(size=9), len=0.75, yanchor="middle"),
        ),
        hoverinfo="text",
        hovertext=hover_text,
    )])

    def _axis(label):
        # Shared styling for the three scene axes.
        return dict(
            title=dict(text=label, font=dict(size=10)),
            tickfont=dict(size=9),
            backgroundcolor="rgba(230, 230, 230, 0.8)",
        )

    fig.update_layout(
        title=dict(text=title, x=0.5, font=dict(size=16)),
        scene=dict(
            xaxis=_axis("PCA Comp 1"),
            yaxis=_axis("PCA Comp 2"),
            zaxis=_axis("PCA Comp 3"),
            bgcolor="rgba(255, 255, 255, 0.95)",
            camera=dict(eye=dict(x=1.5, y=1.5, z=0.5)),
        ),
        margin=dict(l=5, r=5, b=5, t=45),
        paper_bgcolor="rgba(0,0,0,0)",
    )
    return fig
# Helper function: Create an empty Plotly figure for placeholders
def create_empty_plotly_figure(message="N/A"):
    """Return a blank 300px-tall figure with *message* centered in grey.

    Used as a placeholder for plot outputs when there is nothing to draw.
    Plotly annotation text does not honor newline characters, so any
    newlines in *message* are converted to ``<br>`` line breaks.
    """
    fig = go.Figure()
    fig.add_annotation(
        text=str(message).replace("\n", "<br>"),
        xref="paper",
        yref="paper",
        x=0.5,
        y=0.5,
        showarrow=False,
        font=dict(size=12, color="grey"),
    )
    fig.update_layout(
        xaxis={"visible": False},
        yaxis={"visible": False},
        height=300,
        paper_bgcolor="rgba(0,0,0,0)",
        plot_bgcolor="rgba(0,0,0,0)",
    )
    return fig
# --- Core Analysis Function (returns 7 items for Gradio UI) ---
def _error_outputs(error_html, status_text, label_text, figure_message):
    """Return the 7 placeholder outputs used when analysis cannot proceed.

    ``label_text`` is deliberately a plain string: gr.Label shows a str as-is,
    whereas a dict is interpreted as ``{class_name: confidence}`` and needs
    numeric values.
    """
    empty_df = pd.DataFrame(columns=["token", "score"])
    empty_fig = create_empty_plotly_figure(figure_message)
    return error_html, [], status_text, label_text, [], empty_df, empty_fig


def _token_importance(model, input_ids, attn_mask, W, mu, w_p, b_p):
    """Score each input position by gradient x input of the predicted logit.

    Prediction head: first-position ([CLS]) vector of hidden_states[LAYER_ID]
    -> LDA projection ``(h - mu) @ W`` -> linear layer ``@ w_p.T + b_p``.

    Returns:
        (raw_scores, pred_idx, probs, input_embeds): raw_scores is a numpy
        array with one non-negative score per position of ``input_ids[0]``;
        probs are the detached softmax class probabilities; input_embeds are
        the detached word embeddings fed to the model.
    """
    # Detached word embeddings plus a separate leaf copy that requires grad,
    # so the gradient is taken w.r.t. the embeddings only.
    input_embeds = model.embeddings.word_embeddings(input_ids).detach()
    embeds_for_grad = input_embeds.clone().requires_grad_(True)
    outputs = model(
        inputs_embeds=embeds_for_grad,
        attention_mask=attn_mask,
        output_hidden_states=True,
        output_attentions=False,
    )
    cls_vec = outputs.hidden_states[LAYER_ID][:, 0, :]
    logits = ((cls_vec - mu) @ W) @ w_p.T + b_p
    probs = torch.softmax(logits, dim=1)
    pred_idx = int(torch.argmax(probs[0]).item())
    # autograd.grad returns only d(logit)/d(embeds); unlike .backward() it does
    # not accumulate .grad into the parameters of the shared global model.
    (grads,) = torch.autograd.grad(logits[0, pred_idx], embeds_for_grad)
    raw_scores = (grads * input_embeds).norm(dim=2).squeeze(0)
    return raw_scores.cpu().numpy(), pred_idx, probs.detach(), input_embeds


def _normalize_scores(raw_scores):
    """Scale scores so the largest finite one is ~1; non-finite entries -> 0."""
    scores = np.where(np.isfinite(raw_scores), raw_scores, 0.0)
    peak = scores.max() if scores.size else 0.0
    return scores / (peak + 1e-9) if peak > 0 else np.zeros_like(scores)


def _render_token_highlights(tokens, token_ids, scores, special_ids):
    """Build the custom HTML view and the gr.HighlightedText data.

    WordPiece continuation pieces ("##xyz") are shown without their marker
    and without a gap before them, so words read naturally. Special tokens
    get a muted style and no score (None).
    """
    html_parts, highlighted = [], []
    next_tokens = list(tokens[1:]) + [""]
    for tok, nxt, tok_id, score in zip(tokens, next_tokens, token_ids, scores):
        clean = tok[2:] if tok.startswith("##") else tok
        gap = "" if nxt.startswith("##") else " "
        escaped = html.escape(clean)
        if tok_id in special_ids:
            html_parts.append(
                f"<span style='color: #888888; font-size: 0.85em;'>{escaped}</span>{gap}"
            )
            highlighted.append((clean + gap, None))
        else:
            clipped = float(min(1.0, max(0.0, score)))
            color = f"rgba(220, 50, 50, {clipped:.2f})"
            html_parts.append(
                f"<span style='background-color: {color}; padding: 1px 2px; "
                f"border-radius: 3px;' title='{clipped:.3f}'>{escaped}</span>{gap}"
            )
            # Plain float (not numpy) so Gradio can serialize it.
            highlighted.append((clean + gap, round(clipped, 3)))
    body = "".join(html_parts).rstrip()
    return f"<div style='line-height: 2;'>{body}</div>", highlighted


def _top_k_tokens(tokens, token_ids, scores, special_ids, k):
    """Return (table_rows, bar_df) for the k highest-scoring non-special tokens.

    Ties keep their original order because sorted() is stable.
    """
    candidates = [i for i, tid in enumerate(token_ids) if tid not in special_ids]
    ranked = sorted(candidates, key=lambda i: -scores[i])[: max(0, k)]
    table_rows = [[tokens[i], f"{scores[i]:.3f}"] for i in ranked]
    bar_df = pd.DataFrame(
        {
            "token": [tokens[i] for i in ranked],
            "score": [float(scores[i]) for i in ranked],
        },
        columns=["token", "score"],
    )
    return table_rows, bar_df


def _pca_figure(input_embeds_np, tokens, token_ids, scores, special_ids):
    """3D PCA plot of the non-special tokens' input word embeddings.

    The embeddings are the word-embedding lookups fed to the model, not the
    contextual hidden states. A 3-component PCA needs at least 3 samples, so
    a placeholder is returned when fewer non-special tokens exist.
    """
    keep = [i for i, tid in enumerate(token_ids) if tid not in special_ids]
    if len(keep) < 3:
        return create_empty_plotly_figure(
            "PCA Plot N/A<br>(Not enough non-special tokens for 3D)"
        )
    coords = PCA(n_components=3, random_state=SEED).fit_transform(input_embeds_np[keep, :])
    return plot_token_pca_3d_plotly(coords, [tokens[i] for i in keep], scores[keep])


def analyze_sentence_for_gradio(sentence_text, top_k_value):
    """Classify *sentence_text* and explain the prediction for the Gradio UI.

    Args:
        sentence_text: Input sentence. Blank or None input gives an input error.
        top_k_value: Number of top tokens to list. The UI slider may send a
            float, so it is cast to int.

    Returns:
        A 7-tuple in UI order:
        - the custom highlight HTML,
        - gr.HighlightedText data,
        - the prediction summary text,
        - the gr.Label value (``{class_name: probability}`` for every class),
        - the top-K table rows,
        - the top-K bar-plot DataFrame (columns ``token``, ``score``),
        - the 3D PCA Plotly figure.
        On failure, placeholder values are returned and the HTML slot
        describes the error.
    """
    if not MODELS_LOADED_SUCCESSFULLY:
        return _error_outputs(
            "<div style='color: #b00020; white-space: pre-wrap;'>"
            f"<b>Initialization Error:</b> {html.escape(MODEL_LOADING_ERROR_MESSAGE)}</div>",
            "Model Loading Failed",
            "Error: Model Loading Failed. Check logs.",
            "Model Loading Failed",
        )
    try:
        tokenizer = TOKENIZER_GLOBAL
        enc = tokenizer(
            sentence_text or "",
            return_tensors="pt",
            truncation=True,
            max_length=510,
            padding=True,
        )
        input_ids = enc["input_ids"].to(DEVICE)
        attn_mask = enc["attention_mask"].to(DEVICE)
        special_ids = {tokenizer.cls_token_id, tokenizer.sep_token_id}
        token_ids = [tid for tid in input_ids[0].tolist() if tid != tokenizer.pad_token_id]

        # Blank input still tokenizes to [CLS] [SEP]; require real content.
        if all(tid in special_ids for tid in token_ids):
            return _error_outputs(
                "<div style='color: #b00020;'>Input Error: No valid tokens found.</div>",
                "Input Error",
                "Error: Invalid input, no valid tokens.",
                "Invalid Input",
            )

        raw_scores, pred_idx, probs, input_embeds = _token_importance(
            MODEL_GLOBAL, input_ids, attn_mask, W_GLOBAL, MU_GLOBAL, W_P_GLOBAL, B_P_GLOBAL
        )
        n_tokens = len(token_ids)
        tokens = tokenizer.convert_ids_to_tokens(token_ids)
        scores = _normalize_scores(raw_scores)[:n_tokens]
        embeds_np = input_embeds[0, :n_tokens, :].cpu().numpy()

        html_output_str, highlighted_text_data = _render_token_highlights(
            tokens, token_ids, scores, special_ids
        )
        top_tokens_for_df, barplot_df = _top_k_tokens(
            tokens, token_ids, scores, special_ids, int(top_k_value)
        )

        class_probs = probs[0].cpu().numpy()
        predicted_label = CLASS_LABEL_MAP.get(pred_idx, f"Unknown Index ({pred_idx})")
        prediction_summary_text = (
            f"Predicted Class: {predicted_label}\n"
            f"Probability: {float(class_probs[pred_idx]):.3f}"
        )
        # gr.Label expects {class_name: confidence}; the top class becomes the headline.
        prediction_details_for_label = {
            CLASS_LABEL_MAP.get(i, f"Unknown Index ({i})"): round(float(p), 3)
            for i, p in enumerate(class_probs)
        }
        pca_fig = _pca_figure(embeds_np, tokens, token_ids, scores, special_ids)

        return (
            html_output_str,
            highlighted_text_data,
            prediction_summary_text,
            prediction_details_for_label,
            top_tokens_for_df,
            barplot_df,
            pca_fig,
        )
    except Exception as e:
        # UI boundary: log the traceback and show it instead of crashing the app.
        import traceback

        tb_str = traceback.format_exc()
        print(f"analyze_sentence_for_gradio error: {e}\n{tb_str}")
        return _error_outputs(
            "<div style='color: #b00020;'>"
            f"<b>Analysis Error:</b> {html.escape(str(e))}"
            f"<pre style='white-space: pre-wrap; font-size: 0.8em;'>{html.escape(tb_str)}</pre>"
            "</div>",
            "Analysis Failed",
            f"Error: Analysis failed: {e}",
            "Analysis Error",
        )
# --- Gradio UI Definition (Translated and Enhanced) ---
theme = gr.themes.Monochrome(
    primary_hue=gr.themes.colors.blue,
    secondary_hue=gr.themes.colors.sky,
    neutral_hue=gr.themes.colors.slate,
).set(
    body_background_fill="#f0f2f6",
    block_shadow="*shadow_drop_lg",
    button_primary_background_fill="*primary_500",
    button_primary_text_color="white",
)

with gr.Blocks(title="AI Sentence Analyzer XAI 🚀", theme=theme, css=".gradio-container {max-width: 98% !important;}") as demo:
    gr.Markdown("# 🚀 AI Sentence Analyzer XAI: Exploring Model Explanations")
    gr.Markdown("Analyze English sentences to understand BERT model predictions through various XAI visualization techniques. "
                "Explore token importance and their distribution in the embedding space.")

    with gr.Row(equal_height=False):
        with gr.Column(scale=1, min_width=350):
            with gr.Group():
                gr.Markdown("### ✏️ Input Sentence & Settings")
                input_sentence = gr.Textbox(lines=5, label="English Sentence to Analyze", placeholder="Enter the English sentence you want to analyze here...")
                input_top_k = gr.Slider(minimum=1, maximum=10, value=5, step=1, label="Number of Top-K Tokens")
                submit_button = gr.Button("Analyze Sentence 💫", variant="primary")
        with gr.Column(scale=2):
            with gr.Accordion("🎯 Prediction Outcome", open=True):
                output_prediction_summary = gr.Textbox(label="Prediction Summary", lines=2, interactive=False)
                output_prediction_details = gr.Label(label="Prediction Details & Confidence")
            with gr.Accordion("⭐ Top-K Important Tokens (Table)", open=True):
                output_top_tokens_df = gr.DataFrame(headers=["Token", "Score"], label="Most Important Tokens",
                                                    row_count=(1, "dynamic"), col_count=(2, "fixed"), interactive=False, wrap=True)

    with gr.Tabs() as tabs:
        with gr.TabItem("🎨 HTML Highlight (Custom)", id=0):
            output_html_visualization = gr.HTML(label="Token Importance (Gradient x Input based)")
        with gr.TabItem("🖍️ Highlighted Text (Gradio)", id=1):
            output_highlighted_text = gr.HighlightedText(
                label="Token Importance (Score: 0-1)",
                show_legend=True,
                combine_adjacent=False,
            )
        with gr.TabItem("📊 Top-K Bar Plot", id=2):
            output_top_tokens_barplot = gr.BarPlot(
                label="Top-K Token Importance Scores",
                x="token",
                y="score",
                tooltip=["token", "score"],
                min_width=300,
            )
        with gr.TabItem("🌐 Token Embeddings 3D PCA (Interactive)", id=3):
            output_pca_plot = gr.Plot(label="3D PCA of Token Embeddings (Colored by Importance Score)")

    # Shared by the examples and the submit button; the order must match the
    # 7-tuple returned by analyze_sentence_for_gradio.
    analysis_outputs = [
        output_html_visualization, output_highlighted_text,
        output_prediction_summary, output_prediction_details,
        output_top_tokens_df, output_top_tokens_barplot,
        output_pca_plot,
    ]

    gr.Markdown("---")
    # Examples target the news-topic classes in CLASS_LABEL_MAP.
    gr.Examples(
        examples=[
            ["The central bank raised interest rates again as inflation pressures weighed on global markets.", 5],
            ["The striker scored twice in the second half to lead his team to the championship title.", 3],
            ["Researchers unveiled a new quantum processor that could dramatically speed up computing tasks.", 4],
            ["Leaders from several nations met in Geneva to negotiate a ceasefire agreement.", 5],
        ],
        inputs=[input_sentence, input_top_k],
        outputs=analysis_outputs,
        fn=analyze_sentence_for_gradio,
        cache_examples=False,
    )
    gr.HTML("<div style='text-align: center; color: #666666; margin-top: 1em;'>"
            "Explainable AI Demo powered by Gradio & Hugging Face Transformers</div>")

    # Event listeners must be registered inside the Blocks context.
    submit_button.click(
        fn=analyze_sentence_for_gradio,
        inputs=[input_sentence, input_top_k],
        outputs=analysis_outputs,
        api_name="explain_sentence_xai",
    )

if __name__ == "__main__":
    if not MODELS_LOADED_SUCCESSFULLY:
        print("*" * 80)
        print(f"WARNING: Models failed to load! {MODEL_LOADING_ERROR_MESSAGE}")
        print("The Gradio UI will be displayed, but analysis will fail.")
        print("*" * 80)
    demo.launch()