# Captured from the Hugging Face Spaces file view (Space status: Sleeping; file size: 7,588 bytes).
# The per-line commit-hash and line-number gutters of the web view were extraction residue and are omitted.
# app.py
import os
from typing import Dict, List
import gradio as gr
import pandas as pd
import torch
import torch.nn as nn
from transformers import AutoTokenizer, BertPreTrainedModel, BertModel
class BertForCLSClassification(BertPreTrainedModel):
    """BERT encoder followed by one linear layer over the first token's state.

    The linear head maps ``config.hidden_size`` features to
    ``config.num_labels`` outputs. ``forward`` returns the raw logits tensor
    rather than a model-output object.
    """

    def __init__(self, config):
        """Build the encoder and the classification head from *config*."""
        super().__init__(config)
        self.bert = BertModel(config)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.post_init()

    def forward(self, input_ids=None, token_type_ids=None, attention_mask=None, labels=None):
        """Return classification logits for a batch.

        ``labels`` is accepted for signature compatibility but is not used;
        no loss is computed here.
        """
        encoded = self.bert(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
        )
        # Position 0 on the sequence axis (the [CLS] slot, per the class name).
        first_token_state = encoded.last_hidden_state[:, 0, :]
        return self.classifier(first_token_state)
# =========================
# Settings: model location & optional private Hugging Face token
# =========================
# If the weights are placed directly in the Space directory (e.g. config.json,
# model.safetensors, tokenizer files), change MODEL_ID to "." to load them
# from local files.
MODEL_ID = "TAIDE-EDU/task4-level-judgement"
# None when the HF_TOKEN environment variable is not set.
HF_TOKEN = os.environ.get("HF_TOKEN")
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
# =========================
# Load the model and tokenizer
# =========================
def load_model_and_tokenizer():
    """Load the tokenizer and classifier for ``MODEL_ID`` and derive label names.

    The access token is forwarded only when ``HF_TOKEN`` is set and
    ``MODEL_ID`` is not the local directory marker ``"."``.

    Returns:
        A ``(model, tokenizer, ordered_labels)`` tuple. The model has been
        moved to ``device`` and switched to eval mode; ``ordered_labels`` is
        the list of label names indexed by class id.
    """
    hub_kwargs = {"token": HF_TOKEN} if HF_TOKEN and MODEL_ID != "." else {}

    tok = AutoTokenizer.from_pretrained(MODEL_ID, **hub_kwargs)
    net = BertForCLSClassification.from_pretrained(MODEL_ID, **hub_kwargs)
    net = net.to(device)
    net.eval()

    # Take id2label from the config; fall back to defaults when it is missing or empty.
    label_map = getattr(net.config, "id2label", None)
    if not label_map:
        label_map = {0: "入門基礎", 1: "進階高階", 2: "流利精通"}

    # Order by class id so the display order is stable; ids without a name
    # are shown as their number.
    labels_in_order = [label_map.get(idx, str(idx)) for idx in range(len(label_map))]
    return net, tok, labels_in_order
model, tokenizer, ordered_labels = load_model_and_tokenizer()

# Maximum tokenized length: 512, or less when the model config declares a
# smaller positive integer ``max_position_embeddings``. Missing, zero, or
# non-integer values keep the default, so no exception handling is needed.
MAX_LEN = 512
_position_limit = getattr(getattr(model, "config", None), "max_position_embeddings", None)
if isinstance(_position_limit, int) and 0 < _position_limit < MAX_LEN:
    MAX_LEN = _position_limit
# =========================
# Inference logic (task 0: proficiency-level distribution)
# =========================
def class_judgement(text: str) -> dict:
    """Return the task-0 probability distribution over proficiency levels.

    Args:
        text: The passage/question text to classify.

    Returns:
        A dict mapping each name in ``ordered_labels`` to its softmax
        probability. When *text* is empty or whitespace-only, a hint string
        is returned instead of a dict; the ``dict`` annotation is kept
        unchanged for interface compatibility.
    """
    if not text or not text.strip():
        return "(請輸入或從下方表格點選範例)"

    encoded = tokenizer(
        [text],
        max_length=MAX_LEN,
        padding=True,
        truncation=True,
        return_tensors="pt",
    )
    encoded = {name: tensor.to(device) for name, tensor in encoded.items()}

    with torch.no_grad():
        logits = model(**encoded)

    # Single-item batch: take row 0 of the softmax over the class axis.
    probs = torch.softmax(logits, dim=-1)[0].tolist()

    # Labels beyond the number of model outputs are reported with probability 0.
    return {
        label: (probs[idx] if idx < len(probs) else 0.0)
        for idx, label in enumerate(ordered_labels)
    }
# =========================
# Sample data (two examples only)
# - Reading comprehension: id=1
# - Cloze:                 id=2
# =========================
headers = ["id", "範例"]

reading_article = (
    "文章:\n"
    "我家旁邊有一個公園,公園裡有高高的樹和很多花。早上常常有小鳥在唱歌,"
    "小孩也喜歡去那裡玩。下雨以後,空氣會變得很新鮮。我和媽媽有時會一起去公園走路,"
    "看天上的雲,覺得很快樂。\n\n"
    "題目:\n"
    "1.這篇文章主要在介紹什麼地方?\n"
    "(A)我家旁邊的公園\n"
    "(B)我學校附近的商店\n"
    "(C)我媽媽的工作地點\n"
    "(D)小孩最想去的玩具店\n"
)

cloze_text = (
    "陳老師在大學教書已經十年了,他一直認為,學生除了學會課本上的知識,"
    "更重要的是學會__1__和別人合作。每學期開始前,他都會__2__一個小組討論的主題,"
    "讓學生們分組進行研究。每個組員要分工合作,有的人負責收集資料,有的人負責__3__和報告,"
    "大家__4__幫忙,才能完成老師的要求。有時候小組成員之間會出現意見不同的情況,"
    "陳老師總是__5__他們耐心溝通,學會傾聽別人的想法。他相信,這樣的經歷__6__能提升學生的能力,"
    "__6__會對他們未來的工作和生活有很大幫助。\n"
    "1.\n(A)既然\n(B)如何\n(C)常常\n(D)但是\n"
    "2.\n(A)放棄\n(B)安排\n(C)要求\n(D)修理\n"
    "3.\n(A)整理\n(B)傳達\n(C)指導\n(D)販賣\n"
    "4.\n(A)故意\n(B)互相\n(C)分別\n(D)偶爾\n"
    "5.\n(A)鼓勵\n(B)懷疑\n(C)責罵\n(D)批評\n"
    "6.\n(A)不僅⋯⋯還⋯⋯\n(B)與其⋯⋯不如⋯⋯\n(C)要麼⋯⋯要麼⋯⋯\n(D)只有⋯⋯才⋯⋯\n"
)

# Each table holds exactly one [id, text] row (id=1 and id=2 respectively).
reading_samples = [[1, reading_article]]
filling_samples = [[2, cloze_text]]

reading_test_df = pd.DataFrame(data=reading_samples, columns=headers)
filling_test_df = pd.DataFrame(data=filling_samples, columns=headers)

# Full prompt text to place in the Textbox when a table row is clicked, keyed by id.
reading_test_id2data_str: Dict[int, str] = {
    sample_id: sample_text for sample_id, sample_text in reading_samples
}
filling_test_id2data_str: Dict[int, str] = {
    sample_id: sample_text for sample_id, sample_text in filling_samples
}
# =========================
# Table row-select event
# =========================
def on_row_select(evt: gr.SelectData, df: pd.DataFrame) -> str:
    """Return the full text that should be placed in the Textbox.

    Args:
        evt: Selection event; ``evt.index`` may be an int or a list/tuple
            whose first element is the row position.
        df: The table contents; must contain an ``id`` column.

    Raises:
        KeyError: If the row's id is in neither sample lookup table.
    """
    clicked = evt.index
    if isinstance(clicked, (list, tuple)):
        clicked = clicked[0]
    sample_id = int(df.iloc[int(clicked)]["id"])

    # Ids not found among the reading samples are looked up as cloze samples.
    if sample_id not in reading_test_id2data_str:
        return filling_test_id2data_str[sample_id]
    return reading_test_id2data_str[sample_id]
# =========================
# Gradio interface
# =========================
def _build_sample_table(frame, caption, target):
    """Create a read-only sample table whose row selection fills *target*."""
    table = gr.Dataframe(
        value=frame,
        headers=headers,
        row_count=(len(frame), "fixed"),
        col_count=len(headers),
        interactive=False,
        wrap=True,
        label=caption,
    )
    table.select(on_row_select, inputs=table, outputs=[target])
    return table


# NOTE(review): the captured source lost its indentation, so which components
# sit inside each gr.Row() is reconstructed here - confirm against the
# intended layout.
with gr.Blocks(title="class_judgement") as demo:
    gr.Markdown("## 華策會等級分類器(任務0 Demo)\n從下方**範例表**點選一列,系統會自動帶入並推論。")

    with gr.Row():
        inp = gr.Textbox(label="輸入文章", lines=12, placeholder="也可手動貼文後按下『送出』")
        btn = gr.Button("送出", variant="primary")

    with gr.Row():
        out0 = gr.Label(label="任務0:程度等級分布")

    # Reading-comprehension sample table (one row only).
    table_reading = _build_sample_table(reading_test_df, "閱讀測驗範例表(點選帶入)", inp)

    # Cloze sample table (one row only).
    table_filling = _build_sample_table(filling_test_df, "克漏字填空範例表(點選帶入)", inp)

    # After typing text or picking a sample, pressing the button outputs task 0 only.
    btn.click(class_judgement, inputs=inp, outputs=[out0])
# Once the demo has been built, enable the request queue (at most 32 pending).
demo.queue(max_size=32)

if __name__ == "__main__":
    demo.launch()