Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import os | |
| import joblib | |
| import torch | |
| import numpy as np | |
| import html | |
| from transformers import AutoTokenizer, AutoModel, logging as hf_logging | |
# Only show errors from the Hugging Face Transformers logger.
hf_logging.set_verbosity_error()

# ---------- Settings ----------
MODEL_NAME = "bert-base-uncased"
DEVICE = "cpu"
# This folder must sit in the same location as app.py.
# NOTE(review): the literal below looks mis-encoded in this copy of the file;
# confirm it matches the real on-disk folder name.
SAVE_DIR = "μ μ₯μ μ₯1"
LAYER_ID = 4
SEED = 0
CLF_NAME = "linear"

# ---------- Global model state (filled once when the Gradio app starts) ----------
# Plain module-level variables are used instead of Streamlit's
# @st.cache_resource so that everything is loaded at app start-up.
TOKENIZER_GLOBAL = None
MODEL_GLOBAL = None
# LDA projection matrix / mean and linear-classifier weights / bias.
W_GLOBAL = MU_GLOBAL = W_P_GLOBAL = B_P_GLOBAL = None
CLASS_NAMES_GLOBAL = None
MODELS_LOADED_SUCCESSFULLY = False
MODEL_LOADING_ERROR_MESSAGE = ""
# Load the LDA projection, the linear classifier, the tokenizer and the BERT
# encoder once at import time. Any failure is recorded in
# MODEL_LOADING_ERROR_MESSAGE instead of crashing, and
# MODELS_LOADED_SUCCESSFULLY stays False.
try:
    print("Gradio App: λͺ¨λΈ λ‘λ©μ μμν©λλ€...")

    lda_file_path = os.path.join(SAVE_DIR, f"lda_layer{LAYER_ID}_seed{SEED}.pkl")
    clf_file_path = os.path.join(SAVE_DIR, f"{CLF_NAME}_layer{LAYER_ID}_projlda_seed{SEED}.pkl")

    # (check, target, message) triples, evaluated lazily and in order:
    # the directory first, then each pickle file.
    _required_paths = (
        (
            os.path.isdir,
            SAVE_DIR,
            f"μ€λ₯: λͺ¨λΈ μ μ₯ λλ ν 리 '{SAVE_DIR}'λ₯Ό μ°Ύμ μ μμ΅λλ€. 'μ μ₯μ μ₯1' ν΄λλ₯Ό νμΈνμΈμ.",
        ),
        (
            os.path.exists,
            lda_file_path,
            f"μ€λ₯: LDA λͺ¨λΈ νμΌ '{lda_file_path}'λ₯Ό μ°Ύμ μ μμ΅λλ€.",
        ),
        (
            os.path.exists,
            clf_file_path,
            f"μ€λ₯: λΆλ₯κΈ° λͺ¨λΈ νμΌ '{clf_file_path}'λ₯Ό μ°Ύμ μ μμ΅λλ€.",
        ),
    )
    for _path_check, _target, _missing_message in _required_paths:
        if not _path_check(_target):
            raise FileNotFoundError(_missing_message)

    # NOTE: joblib.load unpickles arbitrary objects; only load trusted files.
    lda = joblib.load(lda_file_path)
    clf = joblib.load(clf_file_path)
    # Unwrap a wrapper estimator that exposes the real model as `base_estimator`.
    clf = getattr(clf, "base_estimator", clf)

    def _as_float_tensor(array_like):
        """Convert a fitted-model array into a float32 tensor on DEVICE."""
        return torch.tensor(array_like, dtype=torch.float32, device=DEVICE)

    W_GLOBAL = _as_float_tensor(lda.scalings_)
    MU_GLOBAL = _as_float_tensor(lda.xbar_)
    W_P_GLOBAL = _as_float_tensor(clf.coef_)
    B_P_GLOBAL = _as_float_tensor(clf.intercept_)

    TOKENIZER_GLOBAL = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
    _encoder = AutoModel.from_pretrained(MODEL_NAME, output_hidden_states=True)
    MODEL_GLOBAL = _encoder.to(DEVICE).eval()

    # Prefer the LDA's class list; fall back to the classifier's.
    for _fitted in (lda, clf):
        if hasattr(_fitted, 'classes_'):
            CLASS_NAMES_GLOBAL = _fitted.classes_
            break

    MODELS_LOADED_SUCCESSFULLY = True
    print("Gradio App: λͺ¨λ λͺ¨λΈ λ° λ°μ΄ν° λ‘λ μ±κ³΅!")
except Exception as e:
    # Top-level boundary: keep the app importable and remember why loading
    # failed so the error can be shown to the user through the Gradio UI.
    MODEL_LOADING_ERROR_MESSAGE = f"λͺ¨λΈ λ‘λ© μ€ μ¬κ°ν μ€λ₯ λ°μ: {str(e)}\n'μ μ₯μ μ₯1' ν΄λμ λ΄μ©λ¬Όμ νμΈν΄μ£ΌμΈμ."
    print(MODEL_LOADING_ERROR_MESSAGE)
# ---------- Core analysis function (called by the Gradio interface) ----------
def _failure_outputs(message_html, summary_text):
    """Return the four Gradio outputs for a request that could not be analysed.

    The Label output gets the plain string "N/A" and the table gets no rows,
    so every failure path hands each component a value it can render.
    """
    return message_html, summary_text, "N/A", []


def _class_probabilities(logits):
    """Convert decision scores of shape (1, n_outputs) into class probabilities.

    A binary scikit-learn linear model exposes a single decision score
    (``coef_`` has one row). Softmax over one column is always 1.0, so that
    case uses a sigmoid and returns two columns: [P(class 0), P(class 1)].
    """
    if logits.shape[1] == 1:
        positive = torch.sigmoid(logits)
        return torch.cat([1.0 - positive, positive], dim=1)
    return torch.softmax(logits, dim=1)


def _normalize_scores(raw_scores):
    """Scale scores by their maximum; non-finite entries become 0."""
    cleaned = np.where(np.isfinite(raw_scores), raw_scores, 0.0)
    peak = cleaned.max() if cleaned.size else 0.0
    if peak > 0:
        return cleaned / (peak + 1e-9)
    return np.zeros_like(cleaned)


def analyze_sentence_for_gradio(sentence_text, top_k_value):
    """Classify a sentence and score each token by gradient x input.

    Args:
        sentence_text: The sentence to analyse.
        top_k_value: How many of the highest-scoring tokens to list. Converted
            with ``int()`` (a slider may deliver a float) and clamped at 0.

    Returns:
        A 4-tuple matching the Gradio outputs, in order:
        1. HTML string with one ``<span>`` per token, shaded by importance.
        2. Text summary of the predicted class and its probability.
        3. Value for ``gr.Label``: a ``{class name: probability}`` dict on
           success, or the string ``"N/A"`` on any failure.
        4. List of ``[token, score]`` rows for the top-K table (empty on failure).

    No exception escapes: failures are reported through the returned HTML.
    """
    if not MODELS_LOADED_SUCCESSFULLY:
        return _failure_outputs(
            f"<p style='color:red;'>μ΄κΈ°ν μ€λ₯: {html.escape(MODEL_LOADING_ERROR_MESSAGE)}</p>",
            "λͺ¨λΈ λ‘λ© μ€ν¨",
        )

    invalid_input = _failure_outputs(
        "<p style='color:orange;'>μ λ ₯ μ€λ₯: μ ν¨ν ν ν°μ΄ μμ΅λλ€.</p>",
        "μ λ ₯ μ€λ₯",
    )
    # The tokenizer always adds [CLS]/[SEP], so an empty sentence never yields
    # a zero-length sequence; reject blank text up front instead.
    if sentence_text is None or not str(sentence_text).strip():
        return invalid_input

    try:
        tokenizer = TOKENIZER_GLOBAL
        model = MODEL_GLOBAL
        W, mu, w_p, b_p = W_GLOBAL, MU_GLOBAL, W_P_GLOBAL, B_P_GLOBAL
        class_names = CLASS_NAMES_GLOBAL

        top_k = max(0, int(top_k_value))

        # 1) Tokenise.
        enc = tokenizer(sentence_text, return_tensors="pt", truncation=True, max_length=510, padding=True)
        input_ids = enc["input_ids"].to(DEVICE)
        attn_mask = enc["attention_mask"].to(DEVICE)
        token_ids = input_ids[0].tolist()

        cls_token_id = tokenizer.cls_token_id
        sep_token_id = tokenizer.sep_token_id
        pad_token_id = tokenizer.pad_token_id
        special_ids = (cls_token_id, sep_token_id, pad_token_id)
        content_positions = [pos for pos, tok_id in enumerate(token_ids) if tok_id not in special_ids]
        if not content_positions:
            return invalid_input

        # 2) Detached copy of the word embeddings that gradients can flow into.
        input_embeds = model.embeddings.word_embeddings(input_ids).clone().detach()
        input_embeds.requires_grad_(True)

        # 3) Forward pass; use the [CLS] vector of the chosen hidden layer.
        outputs = model(inputs_embeds=input_embeds, attention_mask=attn_mask, output_hidden_states=True)
        cls_vec = outputs.hidden_states[LAYER_ID][:, 0, :]

        # 4) LDA projection followed by the linear classifier.
        z_projected = (cls_vec - mu) @ W
        logit_output = z_projected @ w_p.T + b_p
        probs = _class_probabilities(logit_output)
        pred_idx = int(torch.argmax(probs, dim=1).item())
        pred_prob_val = probs[0, pred_idx].item()

        # 5) Gradient of the predicted class's score w.r.t. the embeddings.
        if logit_output.shape[1] == 1:
            # Single decision score: positive favours class 1, negative class 0.
            target = logit_output[0, 0] if pred_idx == 1 else -logit_output[0, 0]
        else:
            target = logit_output[0, pred_idx]
        # autograd.grad returns only the requested gradient; unlike
        # .backward() it does not accumulate into the model parameters' .grad.
        (grads,) = torch.autograd.grad(target, input_embeds, allow_unused=True)
        if grads is None:
            return _failure_outputs(
                "<p style='color:red;'>λΆμ μ€λ₯: κ·ΈλλμΈνΈ κ³μ° μ€ν¨.</p>",
                "λΆμ μ€λ₯",
            )

        # 6) Importance = L2 norm of (gradient * embedding) per token, max-scaled.
        raw_scores = (grads * input_embeds.detach()).norm(dim=2).squeeze(0)
        scores_np = _normalize_scores(raw_scores.cpu().numpy())

        # 7) HTML: special tokens in bold, other tokens shaded red by score.
        tokens = tokenizer.convert_ids_to_tokens(token_ids, skip_special_tokens=False)
        html_tokens_list = []
        for pos, (tok_id, tok_str) in enumerate(zip(token_ids, tokens)):
            if tok_id == pad_token_id:
                continue
            # Drop the WordPiece continuation prefix for display.
            shown = html.escape(tok_str[2:] if tok_str.startswith("##") else tok_str)
            if tok_id == cls_token_id or tok_id == sep_token_id:
                html_tokens_list.append(f"<span style='font-weight:bold;'>{shown}</span>")
            else:
                score_val = scores_np[pos] if pos < len(scores_np) else 0
                color = f"rgba(255, 0, 0, {max(0, min(1, score_val)):.2f})"
                html_tokens_list.append(f"<span style='background-color:{color}; padding: 1px 2px; margin: 1px; border-radius: 3px; display:inline-block;'>{shown}</span>")
        html_output_str = " ".join(html_tokens_list)

        # Top-K tokens as a list of rows for the DataFrame component.
        ranked_positions = sorted(
            (pos for pos in content_positions if pos < len(scores_np)),
            key=lambda pos: -scores_np[pos],
        )
        top_tokens_for_df = [
            [tokens[pos], f"{scores_np[pos]:.3f}"] for pos in ranked_positions[:top_k]
        ]

        def label_for(class_idx):
            """Class name when available, otherwise the index as text."""
            if class_names is not None and 0 <= class_idx < len(class_names):
                return str(class_names[class_idx])
            return str(class_idx)

        predicted_class_label_str = label_for(pred_idx)
        prediction_summary_text = f"ν΄λμ€: {predicted_class_label_str}\nνλ₯ : {pred_prob_val:.3f}"
        # gr.Label expects a mapping of label -> float confidence.
        prediction_details_for_label = {
            label_for(class_idx): float(prob)
            for class_idx, prob in enumerate(probs[0].detach().tolist())
        }
        return html_output_str, prediction_summary_text, prediction_details_for_label, top_tokens_for_df
    except Exception as e:
        # UI boundary: report the failure in the page instead of raising.
        import traceback
        tb_str = traceback.format_exc()
        error_html = f"<p style='color:red;'>λΆμ μ€ μ€λ₯ λ°μ: {html.escape(str(e))}</p><pre>{html.escape(tb_str)}</pre>"
        print(f"Analyze_sentence_for_gradio error: {e}\n{tb_str}")
        return _failure_outputs(error_html, "λΆμ μ€ν¨")
| # ---------- Gradio interface definition ---------- | |
| # Input components | |
| input_sentence = gr.Textbox(lines=3, label="λΆμν  μμ΄ λ¬Έμ₯", placeholder="μ¬κΈ°μ μμ΄ λ¬Έμ₯μ μ λ ₯νμΈμ...") | |
| input_top_k = gr.Slider(minimum=1, maximum=10, value=5, step=1, label="νμν  Top-K μ€μ ν ν° μ") | |
| # Output components | |
| output_html_visualization = gr.HTML(label="ν ν° μ€μλ μκ°ν") | |
| output_prediction_summary = gr.Textbox(label="μμΈ‘ μμ½", lines=2) # for a short text summary | |
| output_prediction_details = gr.Label(label="μμΈ‘ μμΈ") # Label is used here to display a dictionary | |
| output_top_tokens_df = gr.DataFrame(headers=["Token", "Score"], label="Top-K μ€μ ν ν°", row_count=(1,"dynamic"), col_count=(2,"fixed")) | |
| # Build the layout with Gradio Blocks (optional; more flexible than Interface) | |
| with gr.Blocks(title="λ¬Έμ₯ ν ν° μ€μλ λΆμκΈ° (Gradio)", theme=gr.themes.Soft()) as demo: | |
| gr.Markdown("# π λ¬Έμ₯ ν ν° μ€μλ λΆμκΈ° (Gradio)") | |
| gr.Markdown("BERTμ LDAλ₯Ό νμ©νμ¬ λ¬Έμ₯ λ΄ κ° ν ν°μ μ€μλλ₯Ό μκ°νν©λλ€.") | |
| with gr.Row(): | |
| with gr.Column(scale=2): | |
| input_sentence.render() | |
| input_top_k.render() | |
| submit_button = gr.Button("λΆμ μ€ννκΈ° π", variant="primary") | |
| with gr.Column(scale=3): | |
| output_prediction_summary.render() | |
| output_prediction_details.render() | |
| output_html_visualization.render() | |
| output_top_tokens_df.render() | |
| gr.Markdown("---") | |
| gr.Markdown("<p style='text-align: center; color: grey;'>BERT κΈ°λ° λ¬Έμ₯ λΆμ λ°λͺ¨ (Gradio)</p>") | |
| # Connect the analysis function to the button click | |
| submit_button.click( | |
| fn=analyze_sentence_for_gradio, | |
| inputs=[input_sentence, input_top_k], | |
| outputs=[output_html_visualization, output_prediction_summary, output_prediction_details, output_top_tokens_df] | |
| ) | |
| # Add examples | |
| gr.Examples( | |
| examples=[ | |
| ["This is a great movie and I really enjoyed it!", 5], | |
| ["The weather is quite gloomy today.", 3], | |
| ["I am not sure if this is the right way to do it, but let's try.", 4] | |
| ], | |
| inputs=[input_sentence, input_top_k], | |
| outputs=[output_html_visualization, output_prediction_summary, output_prediction_details, output_top_tokens_df], # all output components are also listed for running examples | |
| fn=analyze_sentence_for_gradio, # the same function is used when running examples | |
| cache_examples=False # with a model, True may make examples load faster, but False is recommended while debugging | |
| ) | |
# Run the Gradio app (original author's note: on Hugging Face Spaces this part
# is handled automatically). For local testing: demo.launch()
if __name__ == "__main__":
    if not MODELS_LOADED_SUCCESSFULLY:
        # Print a visible banner: the UI still starts, but analysis will fail.
        banner = "*" * 80
        warning_lines = (
            banner,
            "κ²½κ³ : λͺ¨λΈ λ‘λ©μ μ€ν¨νμ¬ Gradio μ±μ΄ μ μμ μΌλ‘ μλνμ§ μμ μ μμ΅λλ€.",
            f"μ€λ₯ λ΄μ©: {MODEL_LOADING_ERROR_MESSAGE}",
            "Gradio UIλ νμλμ§λ§, 'λΆμ μ€ννκΈ°' λ²νΌμ λλ μ λ μ€λ₯κ° λ°μν©λλ€.",
            "`μ μ₯μ μ₯1` ν΄λ λ° λ΄λΆ νμΌλ€μ΄ `app.py`μ λμΌν λλ ν 리μ μλμ§ νμΈνμΈμ.",
            banner,
        )
        for warning_line in warning_lines:
            print(warning_line)
    # Original author's note: Hugging Face Spaces runs app.py and looks for
    # demo.launch() or a launchable Blocks/Interface object named `demo`.
    demo.launch()