Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import os | |
| import joblib | |
| import torch | |
| import numpy as np | |
| import html | |
| from transformers import AutoTokenizer, AutoModel, logging as hf_logging | |
| import pandas as pd | |
| import matplotlib | |
| matplotlib.use('Agg') # Matplotlib ๋ฐฑ์๋ ์ค์ (Gradio์ ํจ๊ป ์ฌ์ฉ ์ ์ค์) | |
| import matplotlib.pyplot as plt | |
| from sklearn.decomposition import PCA | |
# --- Configuration and global model loading ---
# Restrict Hugging Face Transformers logging to the error level.
hf_logging.set_verbosity_error()

# Configuration
MODEL_NAME = "bert-base-uncased"
DEVICE = "cpu"
SAVE_DIR = "์ ์ฅ์ ์ฅ1"
LAYER_ID = 4
SEED = 0
CLF_NAME = "linear"

# Module-level state filled in by the loading block below. The analysis
# function reads these names, so they must exist even when loading fails.
TOKENIZER_GLOBAL = MODEL_GLOBAL = None
W_GLOBAL = MU_GLOBAL = W_P_GLOBAL = B_P_GLOBAL = None
CLASS_NAMES_GLOBAL = None
MODELS_LOADED_SUCCESSFULLY = False
MODEL_LOADING_ERROR_MESSAGE = ""


def _as_float_tensor(values):
    """Copy *values* into a float32 tensor placed on DEVICE."""
    return torch.tensor(values, dtype=torch.float32, device=DEVICE)


try:
    print("Gradio App: ๋ชจ๋ธ ๋ก๋ฉ์ ์์ํฉ๋๋ค...")
    lda_file_path = os.path.join(SAVE_DIR, f"lda_layer{LAYER_ID}_seed{SEED}.pkl")
    clf_file_path = os.path.join(SAVE_DIR, f"{CLF_NAME}_layer{LAYER_ID}_projlda_seed{SEED}.pkl")

    # Fail early, with a specific message, when an expected artefact is missing.
    for _present, _problem in (
        (os.path.isdir(SAVE_DIR),
         f"์ค๋ฅ: ๋ชจ๋ธ ์ ์ฅ ๋๋ ํ ๋ฆฌ '{SAVE_DIR}'๋ฅผ ์ฐพ์ ์ ์์ต๋๋ค."),
        (os.path.exists(lda_file_path),
         f"์ค๋ฅ: LDA ๋ชจ๋ธ ํ์ผ '{lda_file_path}'๋ฅผ ์ฐพ์ ์ ์์ต๋๋ค."),
        (os.path.exists(clf_file_path),
         f"์ค๋ฅ: ๋ถ๋ฅ๊ธฐ ๋ชจ๋ธ ํ์ผ '{clf_file_path}'๋ฅผ ์ฐพ์ ์ ์์ต๋๋ค."),
    ):
        if not _present:
            raise FileNotFoundError(_problem)

    # NOTE(review): joblib.load unpickles the files; only ship artefacts you trust.
    lda = joblib.load(lda_file_path)
    clf = joblib.load(clf_file_path)
    # Unwrap a wrapper estimator when it exposes `base_estimator`.
    # NOTE(review): newer scikit-learn wrappers name this attribute `estimator`
    # - confirm against the pinned scikit-learn version.
    clf = getattr(clf, "base_estimator", clf)

    # LDA projection (scalings, mean) and linear classifier (weights, bias).
    W_GLOBAL = _as_float_tensor(lda.scalings_)
    MU_GLOBAL = _as_float_tensor(lda.xbar_)
    W_P_GLOBAL = _as_float_tensor(clf.coef_)
    B_P_GLOBAL = _as_float_tensor(clf.intercept_)

    TOKENIZER_GLOBAL = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
    _encoder = AutoModel.from_pretrained(
        MODEL_NAME, output_hidden_states=True, output_attentions=False
    )
    MODEL_GLOBAL = _encoder.to(DEVICE).eval()

    # Prefer the LDA's class labels; fall back to the classifier's.
    for _labelled in (lda, clf):
        if hasattr(_labelled, 'classes_'):
            CLASS_NAMES_GLOBAL = _labelled.classes_
            break

    MODELS_LOADED_SUCCESSFULLY = True
    print("Gradio App: ๋ชจ๋ ๋ชจ๋ธ ๋ฐ ๋ฐ์ดํฐ ๋ก๋ ์ฑ๊ณต!")
except Exception as e:
    # Keep the app importable: record the failure so the UI can report it.
    MODELS_LOADED_SUCCESSFULLY = False
    MODEL_LOADING_ERROR_MESSAGE = f"๋ชจ๋ธ ๋ก๋ฉ ์ค ์ฌ๊ฐํ ์ค๋ฅ ๋ฐ์: {str(e)}\n'์ ์ฅ์ ์ฅ1' ํด๋์ ๋ด์ฉ๋ฌผ์ ํ์ธํด์ฃผ์ธ์."
    print(MODEL_LOADING_ERROR_MESSAGE)
# Helper: PCA visualization (3D)
def plot_token_pca_3d(token_embeddings_3d, tokens, scores, title="Token Embeddings 3D PCA (Colored by Importance)"):
    """Draw a 3D scatter of projected token embeddings colored by score.

    Args:
        token_embeddings_3d: 2-D array indexed as ``[token, component]``;
            columns 0, 1 and 2 are used as the x, y and z coordinates.
        tokens: Token strings; used for the text annotations.
        scores: One importance value per token; used for the point colors
            and to choose which points get a text label (the 15 highest).
        title: Title placed on the axes.

    Returns:
        The Matplotlib figure. It is deregistered from pyplot before being
        returned so that repeated calls in a long-running server do not
        accumulate open figures; the figure object itself stays usable.
    """
    fig = plt.figure(figsize=(10, 8))
    ax = fig.add_subplot(111, projection='3d')

    # Label only the highest-scoring tokens to keep the plot readable.
    num_annotations = min(len(tokens), 15)
    if len(scores) > 0 and len(tokens) > 0:
        # `scores` may be a plain list, so convert before sorting.
        indices_to_annotate = np.argsort(np.asarray(scores))[-num_annotations:]
    else:
        indices_to_annotate = np.array([], dtype=int)

    scatter = ax.scatter(
        token_embeddings_3d[:, 0], token_embeddings_3d[:, 1], token_embeddings_3d[:, 2],
        c=scores, cmap="coolwarm_r", s=50, alpha=0.8, depthshade=True,
    )

    # Visit the selected indices directly; indices beyond `tokens` are skipped,
    # exactly as the former per-token membership test did.
    for i in sorted(int(j) for j in indices_to_annotate if j < len(tokens)):
        ax.text(token_embeddings_3d[i, 0], token_embeddings_3d[i, 1], token_embeddings_3d[i, 2],
                f' {tokens[i]}', size=8, zorder=1, color='k')

    ax.set_title(title, fontsize=14)
    ax.set_xlabel("PCA Component 1", fontsize=10)
    ax.set_ylabel("PCA Component 2", fontsize=10)
    ax.set_zlabel("PCA Component 3", fontsize=10)

    # Address the figure explicitly instead of relying on pyplot's notion of
    # the "current figure", which is shared global state.
    cbar = fig.colorbar(scatter, ax=ax, label="Importance Score", shrink=0.7)
    cbar.ax.tick_params(labelsize=8)
    ax.tick_params(axis='both', which='major', labelsize=8)
    fig.tight_layout()

    # pyplot keeps a reference to every figure it created until it is closed.
    plt.close(fig)
    return fig
# ---------- Core analysis function (returns 7 values) ----------
def analyze_sentence_for_gradio(sentence_text, top_k_value):
    """Classify a sentence and attribute the prediction to its tokens.

    The sentence is run through the globally loaded encoder, the first-token
    vector of hidden layer ``LAYER_ID`` is projected with the loaded LDA
    parameters and scored by the loaded linear classifier. Token importance
    is the norm of (gradient x input embedding), scaled by the largest
    finite score.

    Args:
        sentence_text: Text to analyse.
        top_k_value: Number of highest-scoring non-special tokens to report;
            cast to ``int`` because UI sliders may deliver a float.

    Returns:
        A 7-tuple, in the order the UI outputs are wired:
        HTML string, list of ``(token, score or None)`` pairs, summary text,
        label value (``{class: probability}`` on success, an error string on
        failure), list of ``[token, score]`` rows, bar-plot DataFrame with
        columns ``token``/``score``, and a Matplotlib figure.
        Failures never raise; they are reported through the same tuple.
    """
    error_key = "์ค๋ฅ"

    def create_empty_plot(message="N/A"):
        """Return a small axis-less figure that only shows *message*."""
        fig = plt.figure(figsize=(2, 2))
        ax = fig.add_subplot(111)
        ax.text(0.5, 0.5, message, ha='center', va='center', fontsize=10)
        ax.axis('off')
        # Deregister from pyplot so placeholder figures do not pile up.
        plt.close(fig)
        return fig

    def failure(html_message, summary, detail=None, plot_message="N/A"):
        """Build the 7-tuple returned for every error path."""
        reason = summary if detail is None else detail
        return (
            html_message, [], summary,
            # The label output accepts a plain string; a dict must map
            # class names to float confidences, which an error is not.
            f"{error_key}: {reason}",
            [], pd.DataFrame(columns=['token', 'score']),
            create_empty_plot(plot_message),
        )

    if not MODELS_LOADED_SUCCESSFULLY:
        return failure(
            f"<p style='color:red;'>์ด๊ธฐํ ์ค๋ฅ: {html.escape(MODEL_LOADING_ERROR_MESSAGE)}</p>",
            "๋ชจ๋ธ ๋ก๋ฉ ์คํจ",
        )

    input_error = (
        "<p style='color:orange;'>์ ๋ ฅ ์ค๋ฅ: ์ ํจํ ํ ํฐ์ด ์์ต๋๋ค.</p>",
        "์ ๋ ฅ ์ค๋ฅ",
    )

    try:
        # Blank input still tokenizes to special tokens only, so the shape
        # check below would never catch it; reject it explicitly.
        if sentence_text is None or not str(sentence_text).strip():
            return failure(*input_error)

        tokenizer, model = TOKENIZER_GLOBAL, MODEL_GLOBAL
        W, mu, w_p, b_p = W_GLOBAL, MU_GLOBAL, W_P_GLOBAL, B_P_GLOBAL
        class_names = CLASS_NAMES_GLOBAL
        top_k = int(top_k_value)

        enc = tokenizer(sentence_text, return_tensors="pt", truncation=True, max_length=510, padding=True)
        input_ids = enc["input_ids"].to(DEVICE)
        attn_mask = enc["attention_mask"].to(DEVICE)
        if input_ids.shape[1] == 0:
            return failure(*input_error)

        # Feed embeddings (not ids) so gradients can flow back to the input.
        input_embeds_detached = model.embeddings.word_embeddings(input_ids).clone().detach()
        input_embeds_for_grad = input_embeds_detached.clone().requires_grad_(True)
        outputs = model(inputs_embeds=input_embeds_for_grad, attention_mask=attn_mask,
                        output_hidden_states=True, output_attentions=False)
        cls_vec = outputs.hidden_states[LAYER_ID][:, 0, :]
        z_projected = (cls_vec - mu) @ W
        logit_output = z_projected @ w_p.T + b_p

        if logit_output.shape[1] == 1:
            # Single output row (as scikit-learn linear models store two-class
            # problems): a softmax over one value is always 1.0, so treat the
            # value as a decision score where positive means the second class.
            score = logit_output[0, 0]
            positive_prob = torch.sigmoid(score)
            probs = torch.stack([1.0 - positive_prob, positive_prob]).unsqueeze(0)
            pred_idx = int(score.item() > 0)
            target_logit = score if pred_idx == 1 else -score
        else:
            probs = torch.softmax(logit_output, dim=1)
            pred_idx = int(torch.argmax(probs, dim=1).item())
            target_logit = logit_output[0, pred_idx]
        probs = probs.detach()
        pred_prob_val = probs[0, pred_idx].item()

        # autograd.grad returns only the requested gradient; .backward() would
        # also accumulate gradients into every model parameter on each call.
        (grads,) = torch.autograd.grad(target_logit, input_embeds_for_grad, allow_unused=True)
        if grads is None:
            return failure(
                "<p style='color:red;'>๋ถ์ ์ค๋ฅ: ๊ทธ๋๋์ธํธ ๊ณ์ฐ ์คํจ.</p>",
                "๋ถ์ ์ค๋ฅ",
            )

        scores_np = (grads.detach() * input_embeds_detached).norm(dim=2).squeeze(0).cpu().numpy()
        # Scale by the largest finite score; all-zero/invalid scores become zeros.
        valid_scores = scores_np[np.isfinite(scores_np)]
        if len(valid_scores) > 0 and valid_scores.max() > 0:
            scores_np = scores_np / (valid_scores.max() + 1e-9)
        else:
            scores_np = np.zeros_like(scores_np)

        ids = input_ids[0].tolist()
        tokens_raw = tokenizer.convert_ids_to_tokens(ids, skip_special_tokens=False)
        pad_id = tokenizer.pad_token_id
        actual_tokens = [tok for tok, tok_id in zip(tokens_raw, ids) if tok_id != pad_id]
        n_tokens = len(actual_tokens)
        actual_ids = ids[:n_tokens]
        actual_scores_np = scores_np[:n_tokens]
        actual_input_embeds = input_embeds_detached[0, :n_tokens, :].cpu().numpy()

        special_ids = {tokenizer.cls_token_id, tokenizer.sep_token_id}
        html_tokens_list, highlighted_text_data = [], []
        for tok_str, tok_id, raw_score in zip(actual_tokens, actual_ids, actual_scores_np):
            # Drop the word-piece continuation prefix for display.
            clean_tok_str = tok_str[2:] if tok_str.startswith("##") else tok_str
            # Plain float so the value serializes cleanly for the UI.
            clipped = float(max(0, min(1, raw_score)))
            if tok_id in special_ids:
                html_tokens_list.append(f"<span style='font-weight:bold;'>{html.escape(clean_tok_str)}</span>")
                highlighted_text_data.append((clean_tok_str + " ", None))
            else:
                color = f"rgba(255, 0, 0, {clipped:.2f})"
                html_tokens_list.append(f"<span style='background-color:{color}; padding: 1px 2px; margin: 1px; border-radius: 3px; display:inline-block;'>{html.escape(clean_tok_str)}</span>")
                highlighted_text_data.append((clean_tok_str + " ", round(clipped, 3)))
        html_output_str = " ".join(html_tokens_list)

        # Positions of real (non-[CLS]/[SEP]) tokens; shared by top-K and PCA.
        content_indices = [idx for idx, tok_id in enumerate(actual_ids) if tok_id not in special_ids]

        top_tokens_for_df, top_tokens_for_barplot_list = [], []
        ranked_indices = sorted(content_indices, key=lambda idx: -actual_scores_np[idx])
        for token_idx in ranked_indices[:top_k]:
            token_str = actual_tokens[token_idx]
            token_score = float(actual_scores_np[token_idx])
            top_tokens_for_df.append([token_str, f"{token_score:.3f}"])
            top_tokens_for_barplot_list.append({"token": token_str, "score": token_score})
        if top_tokens_for_barplot_list:
            barplot_df = pd.DataFrame(top_tokens_for_barplot_list)
        else:
            barplot_df = pd.DataFrame(columns=['token', 'score'])

        def class_label(idx):
            """Human-readable class name, falling back to the index."""
            if class_names is not None and 0 <= idx < len(class_names):
                return str(class_names[idx])
            return str(idx)

        predicted_class_label_str = class_label(pred_idx)
        prediction_summary_text = f"ํด๋์ค: {predicted_class_label_str}\nํ๋ฅ : {pred_prob_val:.3f}"
        # The label output expects class -> float confidence.
        prediction_details_for_label = {
            class_label(idx): float(prob) for idx, prob in enumerate(probs[0].tolist())
        }

        # A 3-component PCA needs at least three samples; build the placeholder
        # only when it is actually the figure being returned.
        if len(content_indices) >= 3:
            pca_tokens = [actual_tokens[i] for i in content_indices]
            pca_embeddings = actual_input_embeds[content_indices, :]
            pca_scores = actual_scores_np[content_indices]
            token_embeddings_3d = PCA(n_components=3, random_state=SEED).fit_transform(pca_embeddings)
            pca_fig = plot_token_pca_3d(token_embeddings_3d, pca_tokens, pca_scores)
            # No-op if the helper already closed it; prevents figure build-up.
            plt.close(pca_fig)
        else:
            pca_fig = create_empty_plot("PCA Plot N/A\n(Not enough non-special tokens for 3D)")

        return (html_output_str, highlighted_text_data,
                prediction_summary_text, prediction_details_for_label,
                top_tokens_for_df, barplot_df,
                pca_fig)
    except Exception as e:
        # UI boundary: report the failure in the outputs instead of raising.
        import traceback
        tb_str = traceback.format_exc()
        # NOTE(review): the escaped traceback is shown to the end user.
        error_html = f"<p style='color:red;'>๋ถ์ ์ค ์ค๋ฅ ๋ฐ์: {html.escape(str(e))}</p><pre>{html.escape(tb_str)}</pre>"
        print(f"Analyze_sentence_for_gradio error: {e}\n{tb_str}")
        return failure(error_html, "๋ถ์ ์คํจ", detail=str(e),
                       plot_message="Error during plot generation")
| # ---------- Gradio interface definition: theme, input controls, prediction/top-K outputs, visualization tabs, examples and submit-button wiring. NOTE(review): this copy has lost its indentation, so the nesting of the `with` containers below must be confirmed against the original file. ---------- | |
| theme = gr.themes.Glass(primary_hue="blue", secondary_hue="cyan", neutral_hue="sky").set( | |
| body_background_fill="linear-gradient(to right, #c9d6ff, #e2e2e2)", | |
| block_background_fill="rgba(255,255,255,0.8)", | |
| block_border_width="1px", | |
| block_shadow="*shadow_drop_lg" | |
| ) | |
| with gr.Blocks(title="AI ๋ฌธ์ฅ ๋ถ์๊ธฐ XAI ๐", theme=theme, css=".gradio-container {max-width: 98% !important;}") as demo: | |
| gr.Markdown("# ๐ AI ๋ฌธ์ฅ ๋ถ์๊ธฐ XAI: ๋ชจ๋ธ ํด์ ํํ") | |
| gr.Markdown("BERT ๋ชจ๋ธ ์์ธก์ ๊ทผ๊ฑฐ๋ฅผ ๋ค์ํ ์๊ฐํ ๊ธฐ๋ฒ์ผ๋ก ํ์ํฉ๋๋ค. ํ ํฐ์ ์ค์๋์ ์๋ฒ ๋ฉ ๊ณต๊ฐ์์์ ๋ถํฌ๋ฅผ ํ์ธํด๋ณด์ธ์.") | |
| with gr.Row(equal_height=False): | |
| with gr.Column(scale=1, min_width=300): | |
| with gr.Group(): | |
| gr.Markdown("### โ๏ธ ๋ฌธ์ฅ ์ ๋ ฅ & ์ค์ ") | |
| input_sentence = gr.Textbox(lines=5, label="๋ถ์ํ ์์ด ๋ฌธ์ฅ", placeholder="์ฌ๊ธฐ์ ๋ถ์ํ๊ณ ์ถ์ ์์ด ๋ฌธ์ฅ์ ์ ๋ ฅํ์ธ์...") | |
| input_top_k = gr.Slider(minimum=1, maximum=10, value=5, step=1, label="Top-K ํ ํฐ ์") | |
| submit_button = gr.Button("๋ถ์ ์์ ๐ซ", variant="primary", scale=1) | |
| with gr.Column(scale=2): | |
| with gr.Accordion("๐ฏ ์์ธก ๊ฒฐ๊ณผ", open=True): | |
| output_prediction_summary = gr.Textbox(label="๊ฐ๋จ ์์ฝ", lines=2, interactive=False) | |
| output_prediction_details = gr.Label(label="์์ธ ์ ๋ณด") | |
| with gr.Accordion("โญ Top-K ์ค์ ํ ํฐ (ํ)", open=True): | |
| output_top_tokens_df = gr.DataFrame(headers=["Token", "Score"], label="์ค์๋ ๋์ ํ ํฐ", | |
| row_count=(1,"dynamic"), col_count=(2,"fixed"), interactive=False, wrap=True) | |
| with gr.Tabs() as tabs: | |
| with gr.TabItem("๐จ HTML ํ์ด๋ผ์ดํธ", id=0): | |
| output_html_visualization = gr.HTML(label="ํ ํฐ๋ณ ์ค์๋ (Gradient x Input)") | |
| with gr.TabItem("๐๏ธ ํ ์คํธ ํ์ด๋ผ์ดํธ", id=1): | |
| output_highlighted_text = gr.HighlightedText( | |
| label="์ค์๋ ๊ธฐ๋ฐ ํ ์คํธ ํ์ด๋ผ์ดํธ (์ ์: 0~1)", | |
| show_legend=True, | |
| combine_adjacent=False | |
| ) | |
| with gr.TabItem("๐ Top-K ๋ง๋ ๊ทธ๋ํ", id=2): | |
| output_top_tokens_barplot = gr.BarPlot( | |
| label="Top-K ํ ํฐ ์ค์๋", | |
| x="token", | |
| y="score", | |
| tooltip=['token', 'score'], # SyntaxError fixed here | |
| min_width=300 | |
| ) | |
| with gr.TabItem("๐ ํ ํฐ ์๋ฒ ๋ฉ 3D PCA", id=3): | |
| output_pca_plot = gr.Plot(label="ํ ํฐ ์๋ฒ ๋ฉ 3D PCA (์ค์๋ ์์)") | |
| gr.Markdown("---") | |
| gr.Examples( | |
| examples=[ | |
| ["This movie is an absolute masterpiece, captivating from start to finish.", 5], | |
| ["Despite some flaws, the film offers a compelling narrative.", 3], | |
| ["I was thoroughly disappointed with the lackluster performance and predictable plot.", 4] | |
| ], | |
| inputs=[input_sentence, input_top_k], | |
| outputs=[ | |
| output_html_visualization, output_highlighted_text, | |
| output_prediction_summary, output_prediction_details, | |
| output_top_tokens_df, output_top_tokens_barplot, | |
| output_pca_plot | |
| ], | |
| fn=analyze_sentence_for_gradio, | |
| cache_examples=False | |
| ) | |
| # Use gr.HTML instead of gr.Markdown so the HTML tags are rendered directly | |
| gr.HTML("<p style='text-align: center; color: #666;'>Explainable AI Demo with Gradio & Transformers</p>") | |
| submit_button.click( | |
| fn=analyze_sentence_for_gradio, | |
| inputs=[input_sentence, input_top_k], | |
| outputs=[ | |
| output_html_visualization, output_highlighted_text, | |
| output_prediction_summary, output_prediction_details, | |
| output_top_tokens_df, output_top_tokens_barplot, | |
| output_pca_plot | |
| ], | |
| api_name="explain_sentence_xai" | |
| ) | |
# Script entry point: warn on the console when the models failed to load,
# then start the Gradio server either way so the UI can show the error.
if __name__ == "__main__":
    if not MODELS_LOADED_SUCCESSFULLY:
        _banner = "*" * 80
        for _line in (
            _banner,
            f"๊ฒฝ๊ณ : ๋ชจ๋ธ ๋ก๋ฉ ์คํจ! {MODEL_LOADING_ERROR_MESSAGE}",
            "Gradio UI๋ ํ์๋์ง๋ง ๋ถ์ ๊ธฐ๋ฅ์ด ์ ๋๋ก ์๋ํ์ง ์์ต๋๋ค.",
            _banner,
        ):
            print(_line)
    demo.launch()