File size: 7,588 Bytes
9fe4b01
 
 
ad87d43
9fe4b01
 
ad87d43
877a009
 
9fe4b01
877a009
 
 
 
 
 
 
 
 
 
 
 
 
9fe4b01
 
 
 
 
d4d0fcb
9fe4b01
ad87d43
 
9fe4b01
 
 
 
 
 
 
 
 
877a009
9fe4b01
 
 
 
 
 
 
b1be1d1
9fe4b01
 
 
 
 
 
 
 
 
 
 
f778b06
 
 
 
 
 
 
 
 
 
 
 
 
552f453
877a009
 
 
552f453
f778b06
 
 
 
 
 
552f453
9fe4b01
 
 
 
 
6b5c543
9fe4b01
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b1be1d1
9fe4b01
 
f778b06
ad87d43
9fe4b01
 
 
 
 
f778b06
9fe4b01
 
 
 
 
 
 
 
 
 
0d027c2
9fe4b01
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ad87d43
9fe4b01
 
 
 
ad87d43
 
9fe4b01
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ad87d43
9fe4b01
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
# app.py
import os
from typing import Dict, List, Union

import gradio as gr
import pandas as pd
import torch
import torch.nn as nn
from transformers import AutoTokenizer, BertPreTrainedModel, BertModel

class BertForCLSClassification(BertPreTrainedModel):
    """BERT encoder with a single linear head on the first token's hidden state.

    Unlike the stock ``transformers`` classification heads, ``forward`` returns
    the raw logits tensor directly rather than a model-output object.
    """

    def __init__(self, config):
        """Create the encoder and the linear classifier described by ``config``.

        Args:
            config: Model configuration; ``hidden_size`` and ``num_labels`` are
                read to size the classification layer.
        """
        super().__init__(config)
        self.bert = BertModel(config)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.post_init()

    def forward(self, input_ids=None, token_type_ids=None, attention_mask=None, labels=None):
        """Encode the inputs and return the classification logits.

        Args:
            input_ids: Forwarded to the BERT encoder.
            token_type_ids: Forwarded to the BERT encoder.
            attention_mask: Forwarded to the BERT encoder.
            labels: Accepted for call compatibility but not used; no loss is
                computed here.

        Returns:
            The output of the linear classifier applied to the hidden state at
            sequence position 0.
        """
        encoded = self.bert(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
        )
        # Take the hidden state at position 0 of every sequence
        # (kept the same as the original implementation).
        first_token_state = encoded.last_hidden_state[:, 0, :]
        return self.classifier(first_token_state)
    
# =========================
# Settings: model location & optional private Hugging Face token
# =========================
# If the weights are stored directly in the Space directory (e.g. config.json,
# model.safetensors, tokenizer files), change MODEL_ID to "." to load them locally.
MODEL_ID = "TAIDE-EDU/task4-level-judgement"
HF_TOKEN = os.environ.get("HF_TOKEN")

# Use the GPU when torch reports CUDA as available; otherwise stay on the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# =========================
# Load the model and tokenizer
# =========================
def load_model_and_tokenizer():
    """Load the classifier and its tokenizer from ``MODEL_ID``.

    ``HF_TOKEN`` is passed along as ``token`` only when it is set and
    ``MODEL_ID`` is not ``"."``.

    Returns:
        A ``(model, tokenizer, ordered_labels)`` tuple. The model has been
        moved to ``device`` and put in eval mode; ``ordered_labels`` is a list
        of label names indexed by class id.
    """
    use_token = bool(HF_TOKEN) and MODEL_ID != "."
    hub_kwargs = {"token": HF_TOKEN} if use_token else {}

    tok = AutoTokenizer.from_pretrained(MODEL_ID, **hub_kwargs)
    clf = BertForCLSClassification.from_pretrained(MODEL_ID, **hub_kwargs)
    clf = clf.to(device)
    clf.eval()

    # Prefer the label names stored in the config; use built-in defaults when
    # the config has none (missing or empty).
    label_map = getattr(clf.config, "id2label", None)
    if not label_map:
        label_map = {0: "入門基礎", 1: "進階高階", 2: "流利精通"}

    # Order by class id so the display order is stable; an id that is not a
    # key of the mapping is shown as its number.
    names = [label_map.get(idx, str(idx)) for idx in range(len(label_map))]
    return clf, tok, names

model, tokenizer, ordered_labels = load_model_and_tokenizer()

# Maximum tokenized length handed to the tokenizer: the model config's
# ``max_position_embeddings`` when it holds a usable number, capped at 512.
_DEFAULT_MAX_LEN = 512
_config_max_len = getattr(getattr(model, "config", None), "max_position_embeddings", None)
try:
    MAX_LEN = min(int(_config_max_len or _DEFAULT_MAX_LEN), _DEFAULT_MAX_LEN)
except (TypeError, ValueError, OverflowError):
    # A non-numeric config value falls back to the default instead of
    # aborting startup; any other error is no longer silently swallowed.
    MAX_LEN = _DEFAULT_MAX_LEN

# =========================
# Inference logic (task 0: proficiency-level distribution)
# =========================
def class_judgement(text: str) -> Union[Dict[str, float], str]:
    """Classify ``text`` and return the probability of each proficiency level.

    Args:
        text: The passage to classify. It is tokenized with truncation to
            ``MAX_LEN`` tokens.

    Returns:
        A dict mapping every label in ``ordered_labels`` to its softmax
        probability (``0.0`` for a label with no corresponding model output).
        When ``text`` is empty or whitespace-only, a hint string asking the
        user to provide input is returned instead and the model is not run.
    """
    if not text or not text.strip():
        return "(請輸入或從下方表格點選範例)"

    batch = tokenizer(
        [text],
        max_length=MAX_LEN,
        padding=True,
        truncation=True,
        return_tensors="pt",
    )
    # Move every tokenizer output onto the same device as the model.
    batch = {k: v.to(device) for k, v in batch.items()}
    with torch.no_grad():
        logits = model(**batch)
        # Single-item batch: take row 0 of the softmax as a plain list.
        probs = torch.softmax(logits, dim=-1)[0].tolist()

    # Pair each label with its probability, padding with 0.0 if the model
    # produced fewer outputs than there are labels.
    return {
        label: (probs[i] if i < len(probs) else 0.0)
        for i, label in enumerate(ordered_labels)
    }

# =========================
# Example data (only two examples)
# - Reading comprehension: id=1
# - Cloze (fill in the blank): id=2
# =========================
# Column names shared by both example tables.
headers = ['id', '範例']

# Reading-comprehension example: a short passage followed by one
# multiple-choice question with options (A)-(D).
reading_article = (
    "文章:\n"
    "我家旁邊有一個公園,公園裡有高高的樹和很多花。早上常常有小鳥在唱歌,"
    "小孩也喜歡去那裡玩。下雨以後,空氣會變得很新鮮。我和媽媽有時會一起去公園走路,"
    "看天上的雲,覺得很快樂。\n\n"
    "題目:\n"
    "1.這篇文章主要在介紹什麼地方?\n"
    "(A)我家旁邊的公園\n"
    "(B)我學校附近的商店\n"
    "(C)我媽媽的工作地點\n"
    "(D)小孩最想去的玩具店\n"
)

# Cloze example: a passage with numbered blanks (__1__ ... __6__) followed by
# the options (A)-(D) for each blank.
# NOTE(review): blank __6__ appears twice in the passage, presumably because
# question 6 offers paired conjunctions — confirm this is intentional.
cloze_text = (
    "陳老師在大學教書已經十年了,他一直認為,學生除了學會課本上的知識,"
    "更重要的是學會__1__和別人合作。每學期開始前,他都會__2__一個小組討論的主題,"
    "讓學生們分組進行研究。每個組員要分工合作,有的人負責收集資料,有的人負責__3__和報告,"
    "大家__4__幫忙,才能完成老師的要求。有時候小組成員之間會出現意見不同的情況,"
    "陳老師總是__5__他們耐心溝通,學會傾聽別人的想法。他相信,這樣的經歷__6__能提升學生的能力,"
    "__6__會對他們未來的工作和生活有很大幫助。\n"
    "1.\n(A)既然\n(B)如何\n(C)常常\n(D)但是\n"
    "2.\n(A)放棄\n(B)安排\n(C)要求\n(D)修理\n"
    "3.\n(A)整理\n(B)傳達\n(C)指導\n(D)販賣\n"
    "4.\n(A)故意\n(B)互相\n(C)分別\n(D)偶爾\n"
    "5.\n(A)鼓勵\n(B)懷疑\n(C)責罵\n(D)批評\n"
    "6.\n(A)不僅⋯⋯還⋯⋯\n(B)與其⋯⋯不如⋯⋯\n(C)要麼⋯⋯要麼⋯⋯\n(D)只有⋯⋯才⋯⋯\n"
)

# Each example table holds exactly one row: id=1 for the reading example and
# id=2 for the cloze example.
reading_samples = [[1, reading_article]]
filling_samples = [[2, cloze_text]]
reading_test_df = pd.DataFrame(data=reading_samples, columns=headers)
filling_test_df = pd.DataFrame(data=filling_samples, columns=headers)

# Full text to put into the input box when a table row is clicked, keyed by
# the row's id.
reading_test_id2data_str: Dict[int, str] = {
    sample_id: sample_text for sample_id, sample_text in reading_samples
}
filling_test_id2data_str: Dict[int, str] = {
    sample_id: sample_text for sample_id, sample_text in filling_samples
}

# =========================
# Table row-selection event
# =========================
def on_row_select(evt: gr.SelectData, df: pd.DataFrame) -> str:
    """Return the full example text for the table row that was clicked.

    Args:
        evt: Selection event; ``evt.index`` is either a row number or a
            list/tuple whose first element is the row number.
        df: The table contents; its ``"id"`` column identifies the example.

    Returns:
        The text stored for that id, looked up in the reading examples first
        and otherwise in the cloze examples.

    Raises:
        KeyError: If the id is in neither lookup table.
    """
    selected = evt.index
    if isinstance(selected, (list, tuple)):
        # Keep only the row part of the index.
        selected = selected[0]

    sample_id = int(df.iloc[int(selected)]["id"])

    # Reading-comprehension example.
    if sample_id in reading_test_id2data_str:
        return reading_test_id2data_str[sample_id]
    # Otherwise it is a cloze example.
    return filling_test_id2data_str[sample_id]

# =========================
# Gradio interface
# =========================
# Components are laid out in the order they are created inside the Blocks
# context, so the statement order below is also the on-screen order.
with gr.Blocks(title="class_judgement") as demo:
    # NOTE(review): the heading text says that selecting a row also runs
    # inference, but the select handlers below only fill the textbox; the
    # prediction is produced when the button is clicked. Confirm which
    # behaviour is intended.
    gr.Markdown("## 華策會等級分類器(任務0 Demo)\n從下方**範例表**點選一列,系統會自動帶入並推論。")

    # Input row: the text to classify and the submit button.
    with gr.Row():
        inp = gr.Textbox(label="輸入文章", lines=12, placeholder="也可手動貼文後按下『送出』")
        btn = gr.Button("送出", variant="primary")

    # Output row: the predicted level distribution.
    with gr.Row():
        out0 = gr.Label(label="任務0:程度等級分布")

    # Reading-comprehension example table (one row only)
    table_reading = gr.Dataframe(
        value=reading_test_df,
        headers=headers,
        row_count=(len(reading_test_df), "fixed"),
        col_count=len(headers),
        interactive=False,
        wrap=True,
        label="閱讀測驗範例表(點選帶入)"
    )

    # Selecting a cell copies that row's full example text into the textbox.
    table_reading.select(
        on_row_select,
        inputs=table_reading,
        outputs=[inp]
    )

    # Cloze example table (one row only)
    table_filling = gr.Dataframe(
        value=filling_test_df,
        headers=headers,
        row_count=(len(filling_test_df), "fixed"),
        col_count=len(headers),
        interactive=False,
        wrap=True,
        label="克漏字填空範例表(點選帶入)"
    )

    # Same handler as the reading table: fill the textbox with the example.
    table_filling.select(
        on_row_select,
        inputs=table_filling,
        outputs=[inp]
    )

    # After typing text manually or loading an example, clicking the submit
    # button runs the classifier and shows only the task 0 output.
    btn.click(
        class_judgement,
        inputs=inp,
        outputs=[out0]
    )

# After the demo has been built, enable its queue with max_size=32.
demo.queue(max_size=32)

# Launch the app only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()