Spaces:
Running
Running
# app.py | |
import os | |
from typing import Dict, List | |
import gradio as gr | |
import pandas as pd | |
import torch | |
import torch.nn as nn | |
from transformers import AutoTokenizer, BertPreTrainedModel, BertModel | |
class BertForCLSClassification(BertPreTrainedModel): | |
def __init__(self, config): | |
super().__init__(config) | |
self.bert = BertModel(config) | |
self.classifier = nn.Linear(config.hidden_size, config.num_labels) | |
self.post_init() | |
def forward(self, input_ids=None, token_type_ids=None, attention_mask=None, labels=None): | |
outputs = self.bert(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask) | |
cls_output = outputs.last_hidden_state[:, 0, :] # 跟你原本一樣 | |
logits = self.classifier(cls_output) | |
return logits | |
# ========================= | |
# 設定:模型位置 & HF 私有 Token(可選) | |
# ========================= | |
# 若你把權重直接放在 Space 目錄(例如 config.json、model.safetensors、tokenizer 檔), | |
# 可把 MODEL_ID 改成 "." 以載入本地檔案。 | |
MODEL_ID = "TAIDE-EDU/task4-level-judgement" | |
HF_TOKEN = os.getenv("HF_TOKEN", None) | |
device = "cuda" if torch.cuda.is_available() else "cpu" | |
# ========================= | |
# 載入模型與 tokenizer | |
# ========================= | |
def load_model_and_tokenizer(): | |
kwargs = {} | |
if HF_TOKEN and MODEL_ID != ".": | |
kwargs["token"] = HF_TOKEN | |
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, **kwargs) | |
model = BertForCLSClassification.from_pretrained(MODEL_ID, **kwargs).to(device) | |
model.eval() | |
# 從 config 取 id2label;若缺少則提供預設 | |
id2label = getattr(model.config, "id2label", None) or {0: "入門基礎", 1: "進階高階", 2: "流利精通"} | |
# 依 id 排序成 list 方便顯示順序穩定 | |
ordered_labels = [id2label[i] if i in id2label else str(i) for i in range(len(id2label))] | |
return model, tokenizer, ordered_labels | |
model, tokenizer, ordered_labels = load_model_and_tokenizer() | |
MAX_LEN = 512 | |
try: | |
if model is not None: | |
MAX_LEN = min(getattr(model.config, "max_position_embeddings", 512) or 512, 512) | |
except Exception: | |
pass | |
# ========================= | |
# 推論邏輯(任務0:程度等級分布) | |
# ========================= | |
def class_judgement(text: str) -> dict: | |
"""回傳任務0的機率分布(字串,美觀顯示於 Label)。""" | |
if not text or not text.strip(): | |
return "(請輸入或從下方表格點選範例)" | |
batch = tokenizer( | |
[text], | |
max_length=MAX_LEN, | |
padding=True, | |
truncation=True, | |
return_tensors="pt", | |
) | |
batch = {k: v.to(device) for k, v in batch.items()} | |
with torch.no_grad(): | |
output = model(**batch) | |
print(output.shape) | |
probs = torch.softmax(output, dim=-1)[0].tolist() | |
# 轉成「label: prob」並格式化為多行字串 | |
predictions = {} | |
for i, lab in enumerate(ordered_labels): | |
p = probs[i] if i < len(probs) else 0.0 | |
predictions[lab] = p | |
return predictions | |
# ========================= | |
# 範例資料(只放兩個) | |
# - 閱讀測驗:id=1 | |
# - 克漏字: id=2 | |
# ========================= | |
headers = ['id', '範例'] | |
reading_article = ( | |
"文章:\n" | |
"我家旁邊有一個公園,公園裡有高高的樹和很多花。早上常常有小鳥在唱歌," | |
"小孩也喜歡去那裡玩。下雨以後,空氣會變得很新鮮。我和媽媽有時會一起去公園走路," | |
"看天上的雲,覺得很快樂。\n\n" | |
"題目:\n" | |
"1.這篇文章主要在介紹什麼地方?\n" | |
"(A)我家旁邊的公園\n" | |
"(B)我學校附近的商店\n" | |
"(C)我媽媽的工作地點\n" | |
"(D)小孩最想去的玩具店\n" | |
) | |
cloze_text = ( | |
"陳老師在大學教書已經十年了,他一直認為,學生除了學會課本上的知識," | |
"更重要的是學會__1__和別人合作。每學期開始前,他都會__2__一個小組討論的主題," | |
"讓學生們分組進行研究。每個組員要分工合作,有的人負責收集資料,有的人負責__3__和報告," | |
"大家__4__幫忙,才能完成老師的要求。有時候小組成員之間會出現意見不同的情況," | |
"陳老師總是__5__他們耐心溝通,學會傾聽別人的想法。他相信,這樣的經歷__6__能提升學生的能力," | |
"__6__會對他們未來的工作和生活有很大幫助。\n" | |
"1.\n(A)既然\n(B)如何\n(C)常常\n(D)但是\n" | |
"2.\n(A)放棄\n(B)安排\n(C)要求\n(D)修理\n" | |
"3.\n(A)整理\n(B)傳達\n(C)指導\n(D)販賣\n" | |
"4.\n(A)故意\n(B)互相\n(C)分別\n(D)偶爾\n" | |
"5.\n(A)鼓勵\n(B)懷疑\n(C)責罵\n(D)批評\n" | |
"6.\n(A)不僅⋯⋯還⋯⋯\n(B)與其⋯⋯不如⋯⋯\n(C)要麼⋯⋯要麼⋯⋯\n(D)只有⋯⋯才⋯⋯\n" | |
) | |
# 兩個表只各放一列(id=1 與 id=2),其他欄位填空字串即可 | |
reading_samples = [ | |
[1, reading_article], | |
] | |
filling_samples = [ | |
[2, cloze_text], | |
] | |
reading_test_df = pd.DataFrame(reading_samples, columns=headers) | |
filling_test_df = pd.DataFrame(filling_samples, columns=headers) | |
# 點選表格時要填入的實際「完整 Prompt」內容 | |
reading_test_id2data_str: Dict[int, str] = {1: reading_article} | |
filling_test_id2data_str: Dict[int, str] = {2: cloze_text} | |
# ========================= | |
# 表格點選事件 | |
# ========================= | |
def on_row_select(evt: gr.SelectData, df: pd.DataFrame) -> str: | |
"""回傳應填入 Textbox 的完整文字。""" | |
# evt.index 可能是 int 或 [int];統一成 int | |
row_idx = evt.index[0] if isinstance(evt.index, (list, tuple)) else evt.index | |
row = df.iloc[int(row_idx)] | |
row_id = int(row["id"]) | |
if row_id in reading_test_id2data_str.keys(): | |
# 閱讀測驗 | |
return reading_test_id2data_str[row_id] | |
else: | |
# 克漏字 | |
return filling_test_id2data_str[row_id] | |
# ========================= | |
# Gradio 介面 | |
# ========================= | |
with gr.Blocks(title="class_judgement") as demo: | |
gr.Markdown("## 華策會等級分類器(任務0 Demo)\n從下方**範例表**點選一列,系統會自動帶入並推論。") | |
with gr.Row(): | |
inp = gr.Textbox(label="輸入文章", lines=12, placeholder="也可手動貼文後按下『送出』") | |
btn = gr.Button("送出", variant="primary") | |
with gr.Row(): | |
out0 = gr.Label(label="任務0:程度等級分布") | |
# 閱讀測驗範例表(只 1 列) | |
table_reading = gr.Dataframe( | |
value=reading_test_df, | |
headers=headers, | |
row_count=(len(reading_test_df), "fixed"), | |
col_count=len(headers), | |
interactive=False, | |
wrap=True, | |
label="閱讀測驗範例表(點選帶入)" | |
) | |
table_reading.select( | |
on_row_select, | |
inputs=table_reading, | |
outputs=[inp] | |
) | |
# 克漏字範例表(只 1 列) | |
table_filling = gr.Dataframe( | |
value=filling_test_df, | |
headers=headers, | |
row_count=(len(filling_test_df), "fixed"), | |
col_count=len(headers), | |
interactive=False, | |
wrap=True, | |
label="克漏字填空範例表(點選帶入)" | |
) | |
table_filling.select( | |
on_row_select, | |
inputs=table_filling, | |
outputs=[inp] | |
) | |
# 手動輸入或點範例帶入後,按「送出」只輸出任務0 | |
btn.click( | |
class_judgement, | |
inputs=inp, | |
outputs=[out0] | |
) | |
# 建立完 demo 後 | |
demo.queue(max_size=32) | |
if __name__ == "__main__": | |
demo.launch() |