kikikara committed (verified)
Commit 4c2f748 · 1 Parent(s): 1f56dfe

Upload app.py

Files changed (1)
  1. app.py +284 -0
app.py ADDED
@@ -0,0 +1,284 @@
# -*- coding: utf-8 -*-
"""app.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1y3yISz14Lpsr131OIJCKA77lwbFmEJzB
"""

import streamlit as st
import os
import joblib
import torch
import numpy as np
import html
from transformers import AutoTokenizer, AutoModel, logging as hf_logging

# Set the Hugging Face Transformers logging level (show errors only)
hf_logging.set_verbosity_error()

# ────────── Configuration (adjusted for the Hugging Face Spaces environment) ──────────
MODEL_NAME = "bert-base-uncased"
DEVICE = "cpu"         # the free Hugging Face Spaces tier runs on CPU
SAVE_DIR = "저장저장1"  # must match the name of the uploaded folder
LAYER_ID = 4           # layer with the best SeparationScore in the original code
SEED = 0               # SEED value from the original code
CLF_NAME = "linear"    # CLF_NAME from the original code

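# For illustration (an assumption inferred from the path construction in
# load_all_models_and_data below), the expected on-disk layout with LAYER_ID=4
# and SEED=0 would be:
#   저장저장1/lda_layer4_seed0.pkl
#   저장저장1/linear_layer4_projlda_seed0.pkl
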
# ────────── Model loading (cached with st.cache_resource so it runs once per app) ──────────
@st.cache_resource
def load_all_models_and_data():
    """
    Loads the LDA model, the classifier, the tokenizer, the BERT model, and the
    related matrices. File paths must be correct when deploying to Hugging Face Spaces.
    """
    lda_file_path = os.path.join(SAVE_DIR, f"lda_layer{LAYER_ID}_seed{SEED}.pkl")
    clf_file_path = os.path.join(SAVE_DIR, f"{CLF_NAME}_layer{LAYER_ID}_projlda_seed{SEED}.pkl")

    # Check that the files exist (useful for debugging the deployment environment)
    if not os.path.isdir(SAVE_DIR):
        st.error(f"Error: model directory '{SAVE_DIR}' not found. Check that the folder was uploaded to Spaces correctly and that the name matches.")
        return None
    if not os.path.exists(lda_file_path):
        st.error(f"Error: LDA model file '{lda_file_path}' not found. Check the file name and path.")
        return None
    if not os.path.exists(clf_file_path):
        st.error(f"Error: classifier model file '{clf_file_path}' not found. Check the file name and path.")
        return None

    try:
        lda = joblib.load(lda_file_path)
        clf = joblib.load(clf_file_path)
    except Exception as e:
        st.error(f"Error while loading model files: {e}")
        st.error("The files may be corrupted, or there may be a joblib version compatibility issue.")
        return None

    if hasattr(clf, "base_estimator"):  # calibrated Ridge case
        clf = clf.base_estimator

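    # Note (assumption): the original comment mentions a calibrated Ridge, so the
    # check above unwraps a calibration wrapper to reach the underlying linear
    # model; the raw coef_/intercept_ read below exist only on the unwrapped estimator.
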
    # Convert the LDA matrix/mean and the classifier weights to PyTorch tensors
    W_tensor = torch.tensor(lda.scalings_, dtype=torch.float32, device=DEVICE)
    mu_vector = torch.tensor(lda.xbar_, dtype=torch.float32, device=DEVICE)
    w_p_tensor = torch.tensor(clf.coef_, dtype=torch.float32, device=DEVICE)
    b_p_vector = torch.tensor(clf.intercept_, dtype=torch.float32, device=DEVICE)

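    # These tensors re-implement the scikit-learn decision function in pure
    # PyTorch so gradients can flow back to the input embeddings (a sketch of
    # the math, matching step 4 of explain_sentence_streamlit below):
    #   z      = (x - xbar_) @ scalings_    # LDA projection
    #   logits = z @ coef_.T + intercept_   # linear classifier
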
    # Load the Hugging Face tokenizer and BERT model
    try:
        tokenizer_obj = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
        model_obj = AutoModel.from_pretrained(
            MODEL_NAME, output_hidden_states=True
        ).to(DEVICE).eval()
    except Exception as e:
        st.error(f"Error while loading the Hugging Face model ({MODEL_NAME}): {e}")
        st.error("Check your internet connection and that the model name is correct.")
        return None

    # Try to recover the class names
    class_names = None
    if hasattr(lda, 'classes_'):    # scikit-learn LDA
        class_names = lda.classes_
    elif hasattr(clf, 'classes_'):  # scikit-learn classifier
        class_names = clf.classes_

    return tokenizer_obj, model_obj, W_tensor, mu_vector, w_p_tensor, b_p_vector, class_names

# ────────── Core analysis function (based on the original code) ──────────
def explain_sentence_streamlit(
    text: str,
    tokenizer, model, W, mu, w_p, b_p,         # loaded objects
    layer_id_to_use: int, device_to_use: str,  # configuration values
    top_k_tokens: int = 5
) -> tuple[str, int, float, list] | None:      # explicit result type (None on failure)
    """
    Predicts the class of the input sentence, computes per-token importances,
    and returns the results.
    """
    try:
        # 1) Tokenize (with max length and truncation)
        enc = tokenizer(text, return_tensors="pt", truncation=True, max_length=510, padding=True)  # account for BERT's 512-token limit, reserving room for CLS/SEP
        input_ids = enc["input_ids"].to(device_to_use)
        attn_mask = enc["attention_mask"].to(device_to_use)

        if input_ids.shape[1] == 0:  # input too short or entirely filtered out
            # In the Streamlit app a warning could be shown to the user here:
            # st.warning("Tokenization produced no valid tokens. Try a different sentence.")
            return None

        # 2) Track gradients on the input embeddings
        input_embeds = model.embeddings.word_embeddings(input_ids).clone().detach()
        input_embeds.requires_grad_(True)

        # 3) Forward pass → extract the CLS vector
        outputs = model(inputs_embeds=input_embeds,
                        attention_mask=attn_mask,  # pass the attention mask
                        output_hidden_states=True)
        cls_vec = outputs.hidden_states[layer_id_to_use][:, 0, :]  # (1, 768)

        # 4) LDA projection → classification logits
        z_projected = (cls_vec - mu) @ W          # (1, d)
        logit_output = z_projected @ w_p.T + b_p  # (1, C)

        probs = torch.softmax(logit_output, dim=1)
        pred_idx = torch.argmax(probs, dim=1).item()
        pred_prob = probs[0, pred_idx].item()

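        # Because the whole chain above (BERT forward pass → LDA projection →
        # linear classifier) is written in differentiable torch ops, calling
        # backward() on the chosen logit in step 5 yields d(logit)/d(input_embeds).
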
        # 5) Compute gradients
        if input_embeds.grad is not None:
            input_embeds.grad.zero_()         # clear any previous gradients
        logit_output[0, pred_idx].backward()  # gradients w.r.t. the predicted class

        if input_embeds.grad is None:  # guard against the edge case where grad is still missing after backward
            # st.error("Could not compute gradients.")  # show the error inside the Streamlit app
            return None

        grads = input_embeds.grad.clone().detach()

        # 6) Grad × Input → importance scores
        scores = (grads * input_embeds.detach()).norm(dim=2).squeeze(0)
        scores_np = scores.cpu().numpy()

        # Normalize using only the valid scores (guard against NaN/Inf)
        valid_scores = scores_np[np.isfinite(scores_np)]
        if len(valid_scores) > 0 and valid_scores.max() > 0:
            scores_np = scores_np / (valid_scores.max() + 1e-9)  # normalize to 0–1
        else:  # all scores are zero or invalid
            scores_np = np.zeros_like(scores_np)

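        # Note: this is a Gradient × Input style attribution; per token i the
        # importance is the L2 norm over the hidden dimension of grad_i * embed_i,
        # rescaled so the most important finite token scores 1.0.
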
        # 7) Build the HTML highlights
        tokens = tokenizer.convert_ids_to_tokens(input_ids[0], skip_special_tokens=False)  # keep special tokens
        html_tokens_list = []

        # Look up the CLS, SEP, and PAD token IDs
        cls_token_id = tokenizer.cls_token_id
        sep_token_id = tokenizer.sep_token_id
        pad_token_id = tokenizer.pad_token_id

        for i, tok_str in enumerate(tokens):
            if input_ids[0, i] == pad_token_id:  # skip PAD tokens
                continue

            clean_tok_str = tok_str[2:] if tok_str.startswith("##") else tok_str

            # Special tokens get distinct styling and are excluded from the importance coloring
            if input_ids[0, i] == cls_token_id or input_ids[0, i] == sep_token_id:
                html_tokens_list.append(f"<span style='font-weight:bold;'>{html.escape(clean_tok_str)}</span>")
            else:
                score_val = scores_np[i] if i < len(scores_np) else 0  # bounds-check the score array
                color = f"rgba(255, 0, 0, {max(0, min(1, score_val)):.2f})"  # clamp the score to the 0–1 range
                html_tokens_list.append(
                    f"<span style='background-color:{color}; padding: 1px 2px; margin: 1px; border-radius: 3px; display:inline-block;'>{html.escape(clean_tok_str)}</span>"
                )

        html_output_str = " ".join(html_tokens_list)
        # Clean up unneeded whitespace (e.g., a stray " ##" between subwords)
        html_output_str = html_output_str.replace(" ##", "")

        # Top-K important tokens (excluding special and PAD tokens)
        top_tokens_info_list = []
        valid_indices_for_top_k = [
            idx for idx, token_id in enumerate(input_ids[0].tolist())
            if token_id not in [cls_token_id, sep_token_id, pad_token_id] and idx < len(scores_np)
        ]

        # Sort by descending score
        sorted_valid_indices = sorted(valid_indices_for_top_k, key=lambda idx: -scores_np[idx])

        for token_idx in sorted_valid_indices[:top_k_tokens]:
            top_tokens_info_list.append({
                "token": tokens[token_idx],
                "score": f"{scores_np[token_idx]:.3f}"
            })

        return html_output_str, pred_idx, pred_prob, top_tokens_info_list

    except Exception as e:
        # Adjusted so errors surface more clearly inside the Streamlit app:
        # st.error(f"Unexpected error while analyzing the sentence: {e}")
        # import traceback
        # st.text_area("Error details (for debugging):", traceback.format_exc(), height=200)
        # print(f"Unexpected error while analyzing the sentence: {e}")  # console logging (visible in the Spaces logs)
        # print(traceback.format_exc())                                 # console logging
        raise  # re-raise so Streamlit handles it; alternatively, return None here
        # return None

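# A minimal (hypothetical) direct call outside the UI, assuming the loader
# succeeded; the Streamlit flow below wires up the same arguments:
#   res = explain_sentence_streamlit("Great movie!", tokenizer, model,
#                                    W, mu, w_p, b_p, LAYER_ID, DEVICE)
#   if res:
#       html_str, pred_idx, prob, top_tokens = res
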
# ────────── Streamlit UI ──────────
st.set_page_config(page_title="Sentence Token Importance Analyzer", layout="wide")
st.title("📝 Sentence Token Importance Analyzer")
st.markdown("Visualizes the importance of each token in a sentence using BERT and LDA.")

# Attempt to load the models
loaded_data_tuple = load_all_models_and_data()

if loaded_data_tuple:
    tokenizer, model, W, mu, w_p, b_p, class_names = loaded_data_tuple

    # Show model info in the sidebar
    st.sidebar.header("⚙️ Model & Settings")
    st.sidebar.info(f"**BERT model:** `{MODEL_NAME}`\n\n"
                    f"**Layer ID used:** `{LAYER_ID}`\n\n"
                    f"**Classifier:** `{CLF_NAME}` (LDA-projection based)\n\n"
                    f"**Device:** `{DEVICE.upper()}`")
    if class_names is not None:
        st.sidebar.markdown(f"**Predictable classes:** `{', '.join(map(str, class_names))}`")

    # User input
    st.subheader("👇 Enter an English sentence to analyze:")
    user_sentence = st.text_area("Sentence:", "This movie is exceptionally good and I highly recommend it.", height=100)

    top_k_slider = st.slider("Number of top-K important tokens to display:", min_value=1, max_value=10, value=5, step=1)

    if st.button("Run analysis 🚀", type="primary"):
        if user_sentence:
            with st.spinner("Analyzing the sentence... please wait ⏳"):
                analysis_results = None
                try:
                    analysis_results = explain_sentence_streamlit(
                        user_sentence, tokenizer, model, W, mu, w_p, b_p,
                        LAYER_ID, DEVICE, top_k_tokens=top_k_slider
                    )
                except Exception as e:  # handle errors re-raised from explain_sentence_streamlit
                    st.error(f"Error during analysis: {e}")
                    st.info("Check the input sentence and model compatibility. If the problem persists, contact the administrator.")
                    # More detail is available in the Spaces logs (when print statements are used)

            if analysis_results:  # results were returned successfully
                html_viz, predicted_idx, probability, top_k_list = analysis_results

                st.markdown("---")
                st.subheader("📊 Analysis Results")

                predicted_class_label = str(predicted_idx)  # default: the class index
                if class_names is not None and 0 <= predicted_idx < len(class_names):
                    predicted_class_label = str(class_names[predicted_idx])  # use the class name

                st.success(f"**Predicted class:** **`{predicted_class_label}`** (confidence: **{probability:.2f}**)")

                st.subheader("🎨 Token-level importance visualization")
                st.markdown(html_viz, unsafe_allow_html=True)

                st.subheader(f"⭐ Top-{top_k_slider} important tokens")
                if top_k_list:
                    cols = st.columns(len(top_k_list) if len(top_k_list) <= 5 else 5)  # at most 5 per row
                    for i, item in enumerate(top_k_list):
                        with cols[i % len(cols)]:
                            st.metric(label=item['token'], value=item['score'])
                else:
                    st.info("No important tokens were found (special tokens are excluded).")
            # If analysis_results is None and st.error was already shown by the exception handler, no extra message is needed
            elif analysis_results is None and not user_sentence:  # button pressed without a sentence (effectively handled above)
                pass  # already covered by the st.warning below

        else:  # button pressed without entering a sentence
            st.warning("Please enter a sentence to analyze.")
else:
    st.error("Model loading failed, so the application cannot start. Check the uploaded files and path settings. Detailed errors are available in the 'Logs' tab of the Hugging Face Space.")

st.markdown("---")
st.markdown("<p style='text-align: center; color: grey;'>BERT-based sentence analysis demo</p>", unsafe_allow_html=True)