# kikikara's picture
# Update app.py
# f38d876 verified
# raw
# history blame
# 11.8 kB
import gradio as gr
import os
import joblib
import torch
import numpy as np
import html
from transformers import AutoTokenizer, AutoModel, logging as hf_logging
# Set the Hugging Face Transformers logging level so that only errors are reported.
hf_logging.set_verbosity_error()
# ────────── Settings ──────────
MODEL_NAME = "bert-base-uncased"  # Hugging Face model id of the encoder
DEVICE = "cpu"  # torch device used for every tensor and for the model
# Folder holding the pickled LDA / classifier artifacts; it must sit next to
# app.py.  The literal had been mis-decoded (UTF-8 bytes read through a
# single-byte codepage), so it is restored to the real Hangul folder name.
SAVE_DIR = "저장저장1"
LAYER_ID = 4  # index into the model's hidden_states used for the [CLS] vector
SEED = 0  # seed suffix baked into the artifact file names
CLF_NAME = "linear"  # classifier prefix baked into the artifact file name
# ────────── Global model loading (runs once when the Gradio app starts) ──────────
# Kept as module-level globals that are filled at start-up, in place of
# Streamlit's @st.cache_resource.
TOKENIZER_GLOBAL = None
MODEL_GLOBAL = None
W_GLOBAL, MU_GLOBAL, W_P_GLOBAL, B_P_GLOBAL = None, None, None, None
CLASS_NAMES_GLOBAL = None
MODELS_LOADED_SUCCESSFULLY = False
MODEL_LOADING_ERROR_MESSAGE = ""

try:
    print("Gradio App: 모델 로딩을 시작합니다...")
    lda_file_path = os.path.join(SAVE_DIR, f"lda_layer{LAYER_ID}_seed{SEED}.pkl")
    clf_file_path = os.path.join(SAVE_DIR, f"{CLF_NAME}_layer{LAYER_ID}_projlda_seed{SEED}.pkl")

    # Fail early with a specific message for each missing artifact.
    if not os.path.isdir(SAVE_DIR):
        raise FileNotFoundError(
            f"오류: 모델 저장 디렉토리 '{SAVE_DIR}'를 찾을 수 없습니다. '{SAVE_DIR}' 폴더를 확인하세요."
        )
    if not os.path.exists(lda_file_path):
        raise FileNotFoundError(f"오류: LDA 모델 파일 '{lda_file_path}'를 찾을 수 없습니다.")
    if not os.path.exists(clf_file_path):
        raise FileNotFoundError(f"오류: 분류기 모델 파일 '{clf_file_path}'를 찾을 수 없습니다.")

    # NOTE: joblib.load unpickles arbitrary objects; only ship artifacts you
    # produced yourself in SAVE_DIR.
    lda = joblib.load(lda_file_path)
    clf = joblib.load(clf_file_path)
    # Unwrap a wrapper estimator when one is present.
    # NOTE(review): presumably a prefit calibration/ensemble wrapper — confirm
    # that the unwrapped estimator is the fitted one carrying coef_/intercept_.
    if hasattr(clf, "base_estimator"):
        clf = clf.base_estimator

    # LDA projection (scalings / mean) and linear classifier (weights / bias)
    # as float32 tensors so the whole pipeline is differentiable in torch.
    W_GLOBAL = torch.tensor(lda.scalings_, dtype=torch.float32, device=DEVICE)
    MU_GLOBAL = torch.tensor(lda.xbar_, dtype=torch.float32, device=DEVICE)
    W_P_GLOBAL = torch.tensor(clf.coef_, dtype=torch.float32, device=DEVICE)
    B_P_GLOBAL = torch.tensor(clf.intercept_, dtype=torch.float32, device=DEVICE)

    TOKENIZER_GLOBAL = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
    MODEL_GLOBAL = AutoModel.from_pretrained(
        MODEL_NAME, output_hidden_states=True
    ).to(DEVICE).eval()

    # Prefer the class names stored on the LDA model, then on the classifier.
    if hasattr(lda, "classes_"):
        CLASS_NAMES_GLOBAL = lda.classes_
    elif hasattr(clf, "classes_"):
        CLASS_NAMES_GLOBAL = clf.classes_

    MODELS_LOADED_SUCCESSFULLY = True
    print("Gradio App: 모든 모델 및 데이터 로드 성공!")
except Exception as e:
    # Deliberately broad: this is the start-up boundary.  The UI must still
    # come up so the failure can be reported to the user instead of crashing.
    MODEL_LOADING_ERROR_MESSAGE = (
        f"모델 로딩 중 심각한 오류 발생: {e}\n'{SAVE_DIR}' 폴더와 내용물을 확인해주세요."
    )
    print(MODEL_LOADING_ERROR_MESSAGE)
    # The stored message is what the Gradio UI can later show to the user.
# ────────── Core analysis function (called by the Gradio interface) ──────────
def analyze_sentence_for_gradio(sentence_text, top_k_value):
    """Classify a sentence and attribute the prediction to its tokens.

    The sentence is encoded with the globally loaded BERT model, the [CLS]
    vector of hidden layer ``LAYER_ID`` is projected with the LDA scalings and
    scored by the linear classifier.  Token importance is the L2 norm of
    ``gradient * input embedding`` for the predicted class, scaled to [0, 1].

    Args:
        sentence_text: Sentence typed by the user.
        top_k_value: Number of top-scoring tokens to list; the slider may
            deliver it as a float, so it is coerced to ``int``.

    Returns:
        A 4-tuple matching the Gradio outputs, in order:
        HTML heat-map string, plain-text prediction summary, value for the
        ``gr.Label`` component (``{class name: probability}`` on success, a
        plain string on failure) and a list of ``[token, score]`` rows.
        Errors are reported through these values; nothing is raised.
    """
    if not MODELS_LOADED_SUCCESSFULLY:
        # Every declared Gradio output needs a value, even on failure.
        error_html = f"<p style='color:red;'>초기화 오류: {html.escape(MODEL_LOADING_ERROR_MESSAGE)}</p>"
        return error_html, "모델 로딩 실패", "N/A", []

    # BERT always adds [CLS]/[SEP], so an empty sentence has to be rejected
    # here; the token-count check further down can never catch it.
    if sentence_text is None or not str(sentence_text).strip():
        return "<p style='color:orange;'>입력 오류: 유효한 토큰이 없습니다.</p>", "입력 오류", "N/A", []

    try:
        # Use the models loaded at module import time.
        tokenizer = TOKENIZER_GLOBAL
        model = MODEL_GLOBAL
        W, mu, w_p, b_p = W_GLOBAL, MU_GLOBAL, W_P_GLOBAL, B_P_GLOBAL
        class_names = CLASS_NAMES_GLOBAL
        top_k = max(0, int(top_k_value))  # a float would break list slicing

        # 1) Tokenize.
        enc = tokenizer(sentence_text, return_tensors="pt", truncation=True, max_length=510, padding=True)
        input_ids = enc["input_ids"].to(DEVICE)
        attn_mask = enc["attention_mask"].to(DEVICE)
        if input_ids.shape[1] == 0:
            return "<p style='color:orange;'>입력 오류: 유효한 토큰이 없습니다.</p>", "입력 오류", "N/A", []

        # 2) Word embeddings as a fresh leaf tensor so gradients stop here.
        input_embeds = model.embeddings.word_embeddings(input_ids).detach().clone()
        input_embeds.requires_grad_(True)

        # 3) Forward pass; take the [CLS] position of the chosen hidden layer.
        outputs = model(inputs_embeds=input_embeds, attention_mask=attn_mask, output_hidden_states=True)
        cls_vec = outputs.hidden_states[LAYER_ID][:, 0, :]

        # 4) LDA projection followed by the linear classifier.
        z_projected = (cls_vec - mu) @ W
        logit_output = z_projected @ w_p.T + b_p
        if logit_output.shape[1] == 1:
            # A binary linear model stores a single decision value z
            # (z > 0 means the second class).  Softmax over one column would
            # always give class 0 with probability 1, so expand to
            # [-z/2, z/2]: its softmax is exactly [1 - sigmoid(z), sigmoid(z)].
            half = 0.5 * logit_output
            logit_output = torch.cat([-half, half], dim=1)
        probs = torch.softmax(logit_output, dim=1)
        pred_idx = int(torch.argmax(probs, dim=1).item())
        class_probs = probs[0].tolist()
        pred_prob_val = class_probs[pred_idx]

        # 5) Gradient of the predicted-class logit w.r.t. the input embeddings
        #    only; unlike .backward() this leaves the model parameters' .grad
        #    buffers untouched between requests.
        (grads,) = torch.autograd.grad(logit_output[0, pred_idx], input_embeds, allow_unused=True)
        if grads is None:
            return "<p style='color:red;'>분석 오류: 그래디언트 계산 실패.</p>", "분석 오류", "N/A", []

        # 6) Importance = ||grad * embedding|| per token, scaled by the maximum.
        scores_np = (grads * input_embeds.detach()).norm(dim=2)[0].cpu().numpy()
        # Non-finite values would poison colours and sorting; treat them as 0.
        scores_np = np.where(np.isfinite(scores_np), scores_np, 0.0)
        peak = float(scores_np.max()) if scores_np.size else 0.0
        if peak > 0:
            scores_np = scores_np / (peak + 1e-9)
        else:
            scores_np = np.zeros_like(scores_np)

        # 7) Build the HTML heat-map.
        token_ids = input_ids[0].tolist()
        tokens = tokenizer.convert_ids_to_tokens(token_ids, skip_special_tokens=False)
        pad_token_id = tokenizer.pad_token_id
        boundary_ids = {tokenizer.cls_token_id, tokenizer.sep_token_id}
        html_tokens_list = []
        for i, (tok_id, tok_str) in enumerate(zip(token_ids, tokens)):
            if tok_id == pad_token_id:
                continue
            # Strip only a leading WordPiece continuation marker.
            clean_tok_str = tok_str[2:] if tok_str.startswith("##") else tok_str
            if tok_id in boundary_ids:
                html_tokens_list.append(f"<span style='font-weight:bold;'>{html.escape(clean_tok_str)}</span>")
            else:
                alpha = min(1.0, max(0.0, float(scores_np[i])))
                color = f"rgba(255, 0, 0, {alpha:.2f})"
                html_tokens_list.append(f"<span style='background-color:{color}; padding: 1px 2px; margin: 1px; border-radius: 3px; display:inline-block;'>{html.escape(clean_tok_str)}</span>")
        # Tokens are already wrapped in spans, so no " ##" merge pass is needed.
        html_output_str = " ".join(html_tokens_list)

        # Top-K tokens as a list of rows for the DataFrame component.
        skipped_ids = boundary_ids | {pad_token_id}
        ranked = [i for i, tok_id in enumerate(token_ids) if tok_id not in skipped_ids]
        ranked.sort(key=lambda i: scores_np[i], reverse=True)  # stable for ties
        top_tokens_for_df = [[tokens[i], f"{scores_np[i]:.3f}"] for i in ranked[:top_k]]

        # Predicted class label and outputs.
        predicted_class_label_str = _class_label(class_names, pred_idx)
        prediction_summary_text = f"클래스: {predicted_class_label_str}\n확률: {pred_prob_val:.3f}"
        # gr.Label expects {class name: float confidence}; string values are
        # outside its contract.
        prediction_details_for_label = {
            _class_label(class_names, i): float(p) for i, p in enumerate(class_probs)
        }
        return html_output_str, prediction_summary_text, prediction_details_for_label, top_tokens_for_df
    except Exception as e:
        # UI boundary: report the failure in the page instead of raising.
        import traceback
        tb_str = traceback.format_exc()
        error_html = f"<p style='color:red;'>분석 중 오류 발생: {html.escape(str(e))}</p><pre>{html.escape(tb_str)}</pre>"
        print(f"Analyze_sentence_for_gradio error: {e}\n{tb_str}")
        # A plain string is a valid gr.Label value; a dict of strings is not.
        return error_html, "분석 실패", "N/A", []


def _class_label(class_names, class_idx):
    """Return the display name of *class_idx*, or the index itself if unnamed."""
    if class_names is not None and 0 <= class_idx < len(class_names):
        return str(class_names[class_idx])
    return str(class_idx)
# ────────── Gradio interface definition ──────────
# The user-facing labels below had been mis-decoded (mojibake) and are
# restored to the intended Korean text.
# Input components
input_sentence = gr.Textbox(lines=3, label="분석할 영어 문장", placeholder="여기에 영어 문장을 입력하세요...")
input_top_k = gr.Slider(minimum=1, maximum=10, value=5, step=1, label="표시할 Top-K 중요 토큰 수")
# Output components
output_html_visualization = gr.HTML(label="토큰 중요도 시각화")
output_prediction_summary = gr.Textbox(label="예측 요약", lines=2)  # short plain-text summary
output_prediction_details = gr.Label(label="예측 상세")  # takes a class name or a {class: confidence} mapping
output_top_tokens_df = gr.DataFrame(headers=["Token", "Score"], label="Top-K 중요 토큰", row_count=(1, "dynamic"), col_count=(2, "fixed"))
# Lay the page out with Gradio Blocks (optional; more flexible than Interface).
# User-facing strings had been mis-decoded (mojibake) and are restored here.
with gr.Blocks(title="문장 토큰 중요도 분석기 (Gradio)", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 📝 문장 토큰 중요도 분석기 (Gradio)")
    gr.Markdown("BERT와 LDA를 활용하여 문장 내 각 토큰의 중요도를 시각화합니다.")

    with gr.Row():
        # Left column: inputs and the submit button.
        with gr.Column(scale=2):
            input_sentence.render()
            input_top_k.render()
            submit_button = gr.Button("분석 실행하기 🚀", variant="primary")
        # Right column: prediction outputs.
        with gr.Column(scale=3):
            output_prediction_summary.render()
            output_prediction_details.render()

    # NOTE(review): the original indentation was lost; the heat-map and the
    # table are assumed to sit full-width below the row — confirm.
    output_html_visualization.render()
    output_top_tokens_df.render()

    gr.Markdown("---")
    gr.Markdown("<p style='text-align: center; color: grey;'>BERT 기반 문장 분석 데모 (Gradio)</p>")

    # Wire the button click to the analysis function.
    submit_button.click(
        fn=analyze_sentence_for_gradio,
        inputs=[input_sentence, input_top_k],
        outputs=[
            output_html_visualization,
            output_prediction_summary,
            output_prediction_details,
            output_top_tokens_df,
        ],
    )

    # Clickable examples.
    gr.Examples(
        examples=[
            ["This is a great movie and I really enjoyed it!", 5],
            ["The weather is quite gloomy today.", 3],
            ["I am not sure if this is the right way to do it, but let's try.", 4],
        ],
        inputs=[input_sentence, input_top_k],
        # Running an example needs every output component as well.
        outputs=[
            output_html_visualization,
            output_prediction_summary,
            output_prediction_details,
            output_top_tokens_df,
        ],
        fn=analyze_sentence_for_gradio,  # examples run through the same function
        # True can make examples load faster once the models are available;
        # False is recommended while debugging.
        cache_examples=False,
    )
# Launch the Gradio app (on Hugging Face Spaces this part is handled automatically).
# For local testing: demo.launch()
if __name__ == "__main__":
    if not MODELS_LOADED_SUCCESSFULLY:
        # Warn on the console; the UI still starts and reports the error when
        # the analysis button is pressed.  Messages restored from mojibake.
        banner = "*" * 80
        print(banner)
        print("경고: 모델 로딩에 실패하여 Gradio 앱이 정상적으로 작동하지 않을 수 있습니다.")
        print(f"오류 내용: {MODEL_LOADING_ERROR_MESSAGE}")
        print("Gradio UI는 표시되지만, '분석 실행하기' 버튼을 눌렀을 때 오류가 발생합니다.")
        print(f"`{SAVE_DIR}` 폴더 및 내부 파일들이 `app.py`와 동일한 디렉토리에 있는지 확인하세요.")
        print(banner)
    # Hugging Face Spaces runs app.py and looks for demo.launch(), or for a
    # launchable Blocks/Interface object named `demo`.
    # NOTE(review): the original indentation was lost; launch() is assumed to
    # belong under the __main__ guard — confirm.
    demo.launch()