kikikara's picture
Update app.py
2b56bba verified
raw
history blame
16.3 kB
import gradio as gr
import os
import joblib
import torch
import numpy as np
import html
from transformers import AutoTokenizer, AutoModel, logging as hf_logging
import pandas as pd
import matplotlib
matplotlib.use('Agg') # Matplotlib ๋ฐฑ์—”๋“œ ์„ค์ • (Gradio์™€ ํ•จ๊ป˜ ์‚ฌ์šฉ ์‹œ ์ค‘์š”)
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
# --- Settings and global model loading ---
# Only let Hugging Face Transformers report error-level messages.
hf_logging.set_verbosity_error()

# Settings
MODEL_NAME: str = "bert-base-uncased"  # checkpoint for both the tokenizer and the encoder
DEVICE: str = "cpu"  # torch device for the encoder and the projection tensors
SAVE_DIR: str = "์ €์žฅ์ €์žฅ1"  # directory holding the saved LDA / classifier pickles
LAYER_ID: int = 4  # index into outputs.hidden_states whose [CLS] vector is analysed
SEED: int = 0  # seed tag in the artifact file names; also the PCA random_state
CLF_NAME: str = "linear"  # classifier prefix in the classifier pickle file name
# Global model state, populated once at import time. If anything fails, these stay
# None / False and analyze_sentence_for_gradio() reports MODEL_LOADING_ERROR_MESSAGE.
TOKENIZER_GLOBAL, MODEL_GLOBAL = None, None
W_GLOBAL, MU_GLOBAL, W_P_GLOBAL, B_P_GLOBAL = None, None, None, None
CLASS_NAMES_GLOBAL = None
MODELS_LOADED_SUCCESSFULLY = False
MODEL_LOADING_ERROR_MESSAGE = ""
try:
    print("Gradio App: ๋ชจ๋ธ ๋กœ๋”ฉ์„ ์‹œ์ž‘ํ•ฉ๋‹ˆ๋‹ค...")
    lda_file_path = os.path.join(SAVE_DIR, f"lda_layer{LAYER_ID}_seed{SEED}.pkl")
    clf_file_path = os.path.join(SAVE_DIR, f"{CLF_NAME}_layer{LAYER_ID}_projlda_seed{SEED}.pkl")
    # Fail early with a specific message for each missing artifact.
    if not os.path.isdir(SAVE_DIR):
        raise FileNotFoundError(f"์˜ค๋ฅ˜: ๋ชจ๋ธ ์ €์žฅ ๋””๋ ‰ํ† ๋ฆฌ '{SAVE_DIR}'๋ฅผ ์ฐพ์„ ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค.")
    if not os.path.exists(lda_file_path):
        raise FileNotFoundError(f"์˜ค๋ฅ˜: LDA ๋ชจ๋ธ ํŒŒ์ผ '{lda_file_path}'๋ฅผ ์ฐพ์„ ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค.")
    if not os.path.exists(clf_file_path):
        raise FileNotFoundError(f"์˜ค๋ฅ˜: ๋ถ„๋ฅ˜๊ธฐ ๋ชจ๋ธ ํŒŒ์ผ '{clf_file_path}'๋ฅผ ์ฐพ์„ ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค.")
    # NOTE(review): joblib.load unpickles arbitrary objects; only load trusted artifacts.
    lda = joblib.load(lda_file_path)
    clf = joblib.load(clf_file_path)
    # Unwrap an estimator wrapper exposing `base_estimator` so coef_/intercept_ can be
    # read directly. NOTE(review): assumes the wrapped estimator is the fitted one — confirm.
    if hasattr(clf, "base_estimator"):
        clf = clf.base_estimator
    # The analysis computes z = (x - xbar_) @ scalings_ and then z @ coef_.T + intercept_.
    W_GLOBAL = torch.tensor(lda.scalings_, dtype=torch.float32, device=DEVICE)
    MU_GLOBAL = torch.tensor(lda.xbar_, dtype=torch.float32, device=DEVICE)
    W_P_GLOBAL = torch.tensor(clf.coef_, dtype=torch.float32, device=DEVICE)
    B_P_GLOBAL = torch.tensor(clf.intercept_, dtype=torch.float32, device=DEVICE)
    TOKENIZER_GLOBAL = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
    MODEL_GLOBAL = AutoModel.from_pretrained(
        MODEL_NAME, output_hidden_states=True, output_attentions=False
    ).to(DEVICE).eval()
    # Attribution only needs gradients w.r.t. the input embeddings (the analysis builds
    # its own requires_grad leaf tensor). Freezing the weights stops every backward
    # pass from also computing and accumulating parameter gradients that are never
    # read or zeroed, which wastes compute and keeps a full extra copy of the weights' size.
    MODEL_GLOBAL.requires_grad_(False)
    # Prefer the LDA's class labels; fall back to the classifier's.
    if hasattr(lda, "classes_"):
        CLASS_NAMES_GLOBAL = lda.classes_
    elif hasattr(clf, "classes_"):
        CLASS_NAMES_GLOBAL = clf.classes_
    MODELS_LOADED_SUCCESSFULLY = True
    print("Gradio App: ๋ชจ๋“  ๋ชจ๋ธ ๋ฐ ๋ฐ์ดํ„ฐ ๋กœ๋“œ ์„ฑ๊ณต!")
except Exception as e:
    # Deliberately non-fatal: the UI still starts and shows this message per request.
    MODELS_LOADED_SUCCESSFULLY = False
    MODEL_LOADING_ERROR_MESSAGE = f"๋ชจ๋ธ ๋กœ๋”ฉ ์ค‘ ์‹ฌ๊ฐํ•œ ์˜ค๋ฅ˜ ๋ฐœ์ƒ: {str(e)}\n'{SAVE_DIR}' ํด๋”์™€ ๋‚ด์šฉ๋ฌผ์„ ํ™•์ธํ•ด์ฃผ์„ธ์š”."
    print(MODEL_LOADING_ERROR_MESSAGE)
# Helper: 3-D PCA visualisation of token embeddings.
def plot_token_pca_3d(token_embeddings_3d, tokens, scores,
                      title="Token Embeddings 3D PCA (Colored by Importance)",
                      max_annotations=15):
    """Draw a 3-D scatter of token embeddings coloured by importance score.

    Args:
        token_embeddings_3d: array indexed as ``[i, 0..2]`` — one row of x/y/z
            coordinates per token (e.g. the output of a 3-component PCA).
        tokens: token strings, parallel to the rows of ``token_embeddings_3d``.
        scores: per-token importance values; used as the scatter colour values
            and to pick which points get a text label.
        title: axes title.
        max_annotations: label at most this many highest-scoring tokens
            (default 15; 0 or less disables labels).

    Returns:
        A ``matplotlib.figure.Figure`` that is not registered with pyplot.
    """
    # A bare Figure is not tracked by pyplot's figure manager, so it is freed once
    # the caller drops it. plt.figure() would keep every figure alive until
    # plt.close(), leaking memory in a long-running server.
    from matplotlib.figure import Figure

    fig = Figure(figsize=(10, 8))
    ax = fig.add_subplot(111, projection="3d")

    score_values = np.asarray(scores)
    n_labels = min(len(tokens), max(0, int(max_annotations)))
    annotate = set()
    if n_labels > 0 and score_values.size > 0:
        # argsort is ascending, so the last n_labels positions hold the top scores.
        annotate = set(np.argsort(score_values)[-n_labels:].tolist())

    scatter = ax.scatter(
        token_embeddings_3d[:, 0], token_embeddings_3d[:, 1], token_embeddings_3d[:, 2],
        c=scores, cmap="coolwarm_r", s=50, alpha=0.8, depthshade=True,
    )
    for i, token in enumerate(tokens):
        if i in annotate:
            ax.text(token_embeddings_3d[i, 0], token_embeddings_3d[i, 1], token_embeddings_3d[i, 2],
                    f" {token}", size=8, zorder=1, color="k")

    ax.set_title(title, fontsize=14)
    ax.set_xlabel("PCA Component 1", fontsize=10)
    ax.set_ylabel("PCA Component 2", fontsize=10)
    ax.set_zlabel("PCA Component 3", fontsize=10)
    # Attach the colorbar to this figure/axes explicitly instead of pyplot's "current" figure.
    cbar = fig.colorbar(scatter, ax=ax, label="Importance Score", shrink=0.7)
    cbar.ax.tick_params(labelsize=8)
    ax.tick_params(axis="both", which="major", labelsize=8)
    fig.tight_layout()
    return fig
# ---------- Core analysis function (returns 7 values) ----------
def _empty_plot(message="N/A"):
    """Return a small placeholder figure that only shows ``message``.

    Uses ``matplotlib.figure.Figure`` rather than ``plt.figure`` so the figure is
    not tracked by pyplot, which keeps every figure alive until ``plt.close`` and
    therefore leaks memory across requests in a server.
    """
    from matplotlib.figure import Figure

    fig = Figure(figsize=(2, 2))
    ax = fig.add_subplot(111)
    ax.text(0.5, 0.5, message, ha="center", va="center", fontsize=10)
    ax.axis("off")
    return fig


def _failure_outputs(message_html, summary_text, label_text, plot_message="N/A"):
    """Build the 7 UI outputs for a failed analysis.

    ``label_text`` goes to ``gr.Label`` as a plain string, which the component
    displays as-is (a dict value would have to map labels to float confidences).
    """
    return (
        message_html,                              # HTML highlight
        [],                                        # HighlightedText pairs
        summary_text,                              # summary Textbox
        label_text,                                # gr.Label
        [],                                        # top-k table rows
        pd.DataFrame(columns=["token", "score"]),  # bar-plot data
        _empty_plot(plot_message),                 # PCA plot placeholder
    )


def _binary_safe_logits(logits):
    """Return classifier logits with one column per class.

    A scikit-learn binary linear classifier has a single-row ``coef_``, giving one
    decision score ``d`` per sample, and predicts ``classes_[1]`` when ``d > 0``.
    A softmax over that lone column is always 1.0, pinning the prediction to
    class 0. Expanding it to ``[-d/2, d/2]`` makes argmax agree with sklearn's
    ``predict`` and turns softmax into ``[sigmoid(-d), sigmoid(d)]``.
    Logits with two or more columns are returned unchanged.
    """
    if logits.shape[1] == 1:
        half = logits / 2
        return torch.cat([-half, half], dim=1)
    return logits


def analyze_sentence_for_gradio(sentence_text, top_k_value):
    """Explain the classifier's prediction for one sentence.

    Pipeline: tokenize -> run the encoder on the word embeddings (``inputs_embeds``)
    -> take position 0 of ``hidden_states[LAYER_ID]`` -> LDA projection
    ``(h - xbar_) @ scalings_`` -> linear classifier logits -> differentiate the
    predicted-class logit w.r.t. the word embeddings. Token importance is the L2
    norm of ``grad * embedding`` (Gradient x Input), scaled so the largest finite
    score is 1.

    Args:
        sentence_text: sentence to analyse.
        top_k_value: number of highest-scoring non-[CLS]/[SEP] tokens to list;
            coerced with ``int`` (API callers may send floats); ``None`` lists all.

    Returns:
        7-tuple in UI order: (HTML highlight string, HighlightedText pairs,
        summary text, gr.Label value ``{class: probability}``, top-k table rows,
        bar-plot DataFrame, PCA figure). Failures return the same shape with an
        error message; exceptions are caught, printed with a traceback and shown
        in the HTML output.
    """
    if not MODELS_LOADED_SUCCESSFULLY:
        error_html = f"<p style='color:red;'>์ดˆ๊ธฐํ™” ์˜ค๋ฅ˜: {html.escape(MODEL_LOADING_ERROR_MESSAGE)}</p>"
        return _failure_outputs(error_html, "๋ชจ๋ธ ๋กœ๋”ฉ ์‹คํŒจ", "๋ชจ๋ธ ๋กœ๋”ฉ ์‹คํŒจ")
    try:
        tokenizer, model = TOKENIZER_GLOBAL, MODEL_GLOBAL
        W, mu, w_p, b_p = W_GLOBAL, MU_GLOBAL, W_P_GLOBAL, B_P_GLOBAL
        class_names = CLASS_NAMES_GLOBAL

        input_error_html = "<p style='color:orange;'>์ž…๋ ฅ ์˜ค๋ฅ˜: ์œ ํšจํ•œ ํ† ํฐ์ด ์—†์Šต๋‹ˆ๋‹ค.</p>"
        # A blank sentence still tokenizes to [CLS] [SEP]; reject it rather than
        # reporting a prediction for empty input.
        if sentence_text is None or not str(sentence_text).strip():
            return _failure_outputs(input_error_html, "์ž…๋ ฅ ์˜ค๋ฅ˜", "์ž…๋ ฅ ์˜ค๋ฅ˜")

        enc = tokenizer(sentence_text, return_tensors="pt", truncation=True, max_length=510, padding=True)
        input_ids = enc["input_ids"].to(DEVICE)
        attn_mask = enc["attention_mask"].to(DEVICE)
        if input_ids.shape[1] == 0:
            return _failure_outputs(input_error_html, "์ž…๋ ฅ ์˜ค๋ฅ˜", "์ž…๋ ฅ ์˜ค๋ฅ˜")

        # Feed the word embeddings explicitly so the logit can be differentiated
        # with respect to them.
        input_embeds_detached = model.embeddings.word_embeddings(input_ids).clone().detach()
        input_embeds_for_grad = input_embeds_detached.clone().requires_grad_(True)
        outputs = model(inputs_embeds=input_embeds_for_grad, attention_mask=attn_mask,
                        output_hidden_states=True, output_attentions=False)
        cls_vec = outputs.hidden_states[LAYER_ID][:, 0, :]
        z_projected = (cls_vec - mu) @ W
        logit_output = _binary_safe_logits(z_projected @ w_p.T + b_p)
        probs = torch.softmax(logit_output, dim=1)
        pred_idx = int(torch.argmax(probs, dim=1).item())
        pred_prob_val = probs[0, pred_idx].item()

        # autograd.grad returns only the input-embedding gradient and, unlike
        # .backward(), never accumulates into the model parameters' .grad.
        (grads,) = torch.autograd.grad(logit_output[0, pred_idx], input_embeds_for_grad,
                                       allow_unused=True)
        if grads is None:
            return _failure_outputs(
                "<p style='color:red;'>๋ถ„์„ ์˜ค๋ฅ˜: ๊ทธ๋ž˜๋””์–ธํŠธ ๊ณ„์‚ฐ ์‹คํŒจ.</p>",
                "๋ถ„์„ ์˜ค๋ฅ˜", "๋ถ„์„ ์˜ค๋ฅ˜")

        scores_np = (grads.detach() * input_embeds_detached).norm(dim=2).squeeze(0).cpu().numpy()
        finite_scores = scores_np[np.isfinite(scores_np)]
        if finite_scores.size > 0 and finite_scores.max() > 0:
            scores_np = scores_np / (finite_scores.max() + 1e-9)
        else:
            scores_np = np.zeros_like(scores_np)
        # Keep NaN/inf out of the table, colours and bar plot.
        scores_np = np.nan_to_num(scores_np, nan=0.0, posinf=1.0, neginf=0.0)

        ids = input_ids[0].tolist()
        tokens_raw = tokenizer.convert_ids_to_tokens(ids, skip_special_tokens=False)
        pad_id = tokenizer.pad_token_id
        actual_tokens = [tok for tok, tid in zip(tokens_raw, ids) if tid != pad_id]
        n_actual = len(actual_tokens)
        actual_ids = ids[:n_actual]
        actual_scores_np = scores_np[:n_actual]
        actual_input_embeds = input_embeds_detached[0, :n_actual, :].cpu().numpy()
        special_ids = {tokenizer.cls_token_id, tokenizer.sep_token_id}

        html_tokens_list, highlighted_text_data = [], []
        for tok_str, tok_id, score in zip(actual_tokens, actual_ids, actual_scores_np):
            # WordPiece marks word continuations with a leading "##"; hide it for display.
            clean_tok_str = tok_str[2:] if tok_str.startswith("##") else tok_str
            score_clipped = float(min(1.0, max(0.0, score)))
            if tok_id in special_ids:
                html_tokens_list.append(f"<span style='font-weight:bold;'>{html.escape(clean_tok_str)}</span>")
                highlighted_text_data.append((clean_tok_str + " ", None))
            else:
                color = f"rgba(255, 0, 0, {score_clipped:.2f})"
                html_tokens_list.append(f"<span style='background-color:{color}; padding: 1px 2px; margin: 1px; border-radius: 3px; display:inline-block;'>{html.escape(clean_tok_str)}</span>")
                highlighted_text_data.append((clean_tok_str + " ", round(score_clipped, 3)))
        html_output_str = " ".join(html_tokens_list)

        # Indices of real content tokens (everything except [CLS]/[SEP]).
        content_indices = [i for i, tid in enumerate(actual_ids) if tid not in special_ids]
        ranked = sorted(content_indices, key=lambda i: -actual_scores_np[i])
        top_k = len(ranked) if top_k_value is None else max(0, int(top_k_value))
        top_tokens_for_df, top_tokens_for_barplot_list = [], []
        for i in ranked[:top_k]:
            score = float(actual_scores_np[i])
            top_tokens_for_df.append([actual_tokens[i], f"{score:.3f}"])
            top_tokens_for_barplot_list.append({"token": actual_tokens[i], "score": score})
        barplot_df = pd.DataFrame(top_tokens_for_barplot_list, columns=["token", "score"])

        def class_label(idx):
            if class_names is not None and 0 <= idx < len(class_names):
                return str(class_names[idx])
            return str(idx)

        predicted_class_label_str = class_label(pred_idx)
        prediction_summary_text = f"ํด๋ž˜์Šค: {predicted_class_label_str}\nํ™•๋ฅ : {pred_prob_val:.3f}"
        # gr.Label expects {label: confidence}; show the full class distribution.
        prediction_details_for_label = {
            class_label(i): float(p) for i, p in enumerate(probs[0].detach().tolist())
        }

        if len(content_indices) >= 3:
            pca = PCA(n_components=3, random_state=SEED)
            token_embeddings_3d = pca.fit_transform(actual_input_embeds[content_indices, :])
            pca_fig = plot_token_pca_3d(token_embeddings_3d,
                                        [actual_tokens[i] for i in content_indices],
                                        actual_scores_np[content_indices])
        else:
            # Created only when needed, so no unused placeholder figure is left behind.
            pca_fig = _empty_plot("PCA Plot N/A\n(Not enough non-special tokens for 3D)")

        return (html_output_str, highlighted_text_data,
                prediction_summary_text, prediction_details_for_label,
                top_tokens_for_df, barplot_df,
                pca_fig)
    except Exception as e:
        import traceback
        tb_str = traceback.format_exc()
        error_html = f"<p style='color:red;'>๋ถ„์„ ์ค‘ ์˜ค๋ฅ˜ ๋ฐœ์ƒ: {html.escape(str(e))}</p><pre>{html.escape(tb_str)}</pre>"
        print(f"Analyze_sentence_for_gradio error: {e}\n{tb_str}")
        return _failure_outputs(error_html, "๋ถ„์„ ์‹คํŒจ", str(e), "Error during plot generation")
# ---------- Gradio interface definition ----------
# Glass base theme with blue/cyan/sky hues, then a horizontal gradient page
# background, semi-transparent white blocks, 1px borders and large drop shadows.
theme = gr.themes.Glass(
    primary_hue="blue",
    secondary_hue="cyan",
    neutral_hue="sky",
)
theme = theme.set(
    body_background_fill="linear-gradient(to right, #c9d6ff, #e2e2e2)",
    block_background_fill="rgba(255,255,255,0.8)",
    block_border_width="1px",
    block_shadow="*shadow_drop_lg",
)
with gr.Blocks(title="AI ๋ฌธ์žฅ ๋ถ„์„๊ธฐ XAI ๐Ÿš€", theme=theme, css=".gradio-container {max-width: 98% !important;}") as demo:
    gr.Markdown("# ๐Ÿš€ AI ๋ฌธ์žฅ ๋ถ„์„๊ธฐ XAI: ๋ชจ๋ธ ํ•ด์„ ํƒํ—˜")
    gr.Markdown("BERT ๋ชจ๋ธ ์˜ˆ์ธก์˜ ๊ทผ๊ฑฐ๋ฅผ ๋‹ค์–‘ํ•œ ์‹œ๊ฐํ™” ๊ธฐ๋ฒ•์œผ๋กœ ํƒ์ƒ‰ํ•ฉ๋‹ˆ๋‹ค. ํ† ํฐ์˜ ์ค‘์š”๋„์™€ ์ž„๋ฒ ๋”ฉ ๊ณต๊ฐ„์—์„œ์˜ ๋ถ„ํฌ๋ฅผ ํ™•์ธํ•ด๋ณด์„ธ์š”.")
    with gr.Row(equal_height=False):
        # Left column: sentence input, Top-K slider and the run button.
        with gr.Column(scale=1, min_width=300):
            with gr.Group():
                gr.Markdown("### โœ๏ธ ๋ฌธ์žฅ ์ž…๋ ฅ & ์„ค์ •")
                input_sentence = gr.Textbox(lines=5, label="๋ถ„์„ํ•  ์˜์–ด ๋ฌธ์žฅ", placeholder="์—ฌ๊ธฐ์— ๋ถ„์„ํ•˜๊ณ  ์‹ถ์€ ์˜์–ด ๋ฌธ์žฅ์„ ์ž…๋ ฅํ•˜์„ธ์š”...")
                input_top_k = gr.Slider(minimum=1, maximum=10, value=5, step=1, label="Top-K ํ† ํฐ ์ˆ˜")
                submit_button = gr.Button("๋ถ„์„ ์‹œ์ž‘ ๐Ÿ’ซ", variant="primary", scale=1)
        # Right column: prediction, Top-K table and the visualisation tabs.
        with gr.Column(scale=2):
            with gr.Accordion("๐ŸŽฏ ์˜ˆ์ธก ๊ฒฐ๊ณผ", open=True):
                output_prediction_summary = gr.Textbox(label="๊ฐ„๋‹จ ์š”์•ฝ", lines=2, interactive=False)
                output_prediction_details = gr.Label(label="์ƒ์„ธ ์ •๋ณด")
            with gr.Accordion("โญ Top-K ์ค‘์š” ํ† ํฐ (ํ‘œ)", open=True):
                output_top_tokens_df = gr.DataFrame(headers=["Token", "Score"], label="์ค‘์š”๋„ ๋†’์€ ํ† ํฐ",
                                                    row_count=(1,"dynamic"), col_count=(2,"fixed"), interactive=False, wrap=True)
            with gr.Tabs() as tabs:
                with gr.TabItem("๐ŸŽจ HTML ํ•˜์ด๋ผ์ดํŠธ", id=0):
                    output_html_visualization = gr.HTML(label="ํ† ํฐ๋ณ„ ์ค‘์š”๋„ (Gradient x Input)")
                with gr.TabItem("๐Ÿ–๏ธ ํ…์ŠคํŠธ ํ•˜์ด๋ผ์ดํŠธ", id=1):
                    output_highlighted_text = gr.HighlightedText(
                        label="์ค‘์š”๋„ ๊ธฐ๋ฐ˜ ํ…์ŠคํŠธ ํ•˜์ด๋ผ์ดํŠธ (์ ์ˆ˜: 0~1)",
                        show_legend=True,
                        combine_adjacent=False
                    )
                with gr.TabItem("๐Ÿ“Š Top-K ๋ง‰๋Œ€ ๊ทธ๋ž˜ํ”„", id=2):
                    output_top_tokens_barplot = gr.BarPlot(
                        label="Top-K ํ† ํฐ ์ค‘์š”๋„",
                        x="token",
                        y="score",
                        tooltip=['token', 'score'],
                        min_width=300
                    )
                with gr.TabItem("๐ŸŒ ํ† ํฐ ์ž„๋ฒ ๋”ฉ 3D PCA", id=3):
                    output_pca_plot = gr.Plot(label="ํ† ํฐ ์ž„๋ฒ ๋”ฉ 3D PCA (์ค‘์š”๋„ ์ƒ‰์ƒ)")
    gr.Markdown("---")

    # Single source of truth for the output wiring: the order must match the
    # 7-tuple returned by analyze_sentence_for_gradio. Shared by Examples and the button.
    analysis_outputs = [
        output_html_visualization, output_highlighted_text,
        output_prediction_summary, output_prediction_details,
        output_top_tokens_df, output_top_tokens_barplot,
        output_pca_plot,
    ]
    gr.Examples(
        examples=[
            ["This movie is an absolute masterpiece, captivating from start to finish.", 5],
            ["Despite some flaws, the film offers a compelling narrative.", 3],
            ["I was thoroughly disappointed with the lackluster performance and predictable plot.", 4]
        ],
        inputs=[input_sentence, input_top_k],
        outputs=analysis_outputs,
        fn=analyze_sentence_for_gradio,
        cache_examples=False
    )
    # gr.HTML (rather than gr.Markdown) so the inline HTML tags render directly.
    gr.HTML("<p style='text-align: center; color: #666;'>Explainable AI Demo with Gradio & Transformers</p>")
    submit_button.click(
        fn=analyze_sentence_for_gradio,
        inputs=[input_sentence, input_top_k],
        outputs=analysis_outputs,
        api_name="explain_sentence_xai"
    )
if __name__ == "__main__":
    # The UI is launched even when loading failed; each analysis request then
    # returns the initialization error instead of a result.
    if not MODELS_LOADED_SUCCESSFULLY:
        banner = "*" * 80
        warning_lines = [
            banner,
            f"๊ฒฝ๊ณ : ๋ชจ๋ธ ๋กœ๋”ฉ ์‹คํŒจ! {MODEL_LOADING_ERROR_MESSAGE}",
            "Gradio UI๋Š” ํ‘œ์‹œ๋˜์ง€๋งŒ ๋ถ„์„ ๊ธฐ๋Šฅ์ด ์ œ๋Œ€๋กœ ์ž‘๋™ํ•˜์ง€ ์•Š์Šต๋‹ˆ๋‹ค.",
            banner,
        ]
        # One print of the newline-joined lines writes exactly what four prints would.
        print("\n".join(warning_lines))
    demo.launch()