# app.py import os from typing import Dict, List import gradio as gr import pandas as pd import torch import torch.nn as nn from transformers import AutoTokenizer, BertPreTrainedModel, BertModel class BertForCLSClassification(BertPreTrainedModel): def __init__(self, config): super().__init__(config) self.bert = BertModel(config) self.classifier = nn.Linear(config.hidden_size, config.num_labels) self.post_init() def forward(self, input_ids=None, token_type_ids=None, attention_mask=None, labels=None): outputs = self.bert(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask) cls_output = outputs.last_hidden_state[:, 0, :] # 跟你原本一樣 logits = self.classifier(cls_output) return logits # ========================= # 設定:模型位置 & HF 私有 Token(可選) # ========================= # 若你把權重直接放在 Space 目錄(例如 config.json、model.safetensors、tokenizer 檔), # 可把 MODEL_ID 改成 "." 以載入本地檔案。 MODEL_ID = "TAIDE-EDU/task4-level-judgement" HF_TOKEN = os.getenv("HF_TOKEN", None) device = "cuda" if torch.cuda.is_available() else "cpu" # ========================= # 載入模型與 tokenizer # ========================= def load_model_and_tokenizer(): kwargs = {} if HF_TOKEN and MODEL_ID != ".": kwargs["token"] = HF_TOKEN tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, **kwargs) model = BertForCLSClassification.from_pretrained(MODEL_ID, **kwargs).to(device) model.eval() # 從 config 取 id2label;若缺少則提供預設 id2label = getattr(model.config, "id2label", None) or {0: "入門基礎", 1: "進階高階", 2: "流利精通"} # 依 id 排序成 list 方便顯示順序穩定 ordered_labels = [id2label[i] if i in id2label else str(i) for i in range(len(id2label))] return model, tokenizer, ordered_labels model, tokenizer, ordered_labels = load_model_and_tokenizer() MAX_LEN = 512 try: if model is not None: MAX_LEN = min(getattr(model.config, "max_position_embeddings", 512) or 512, 512) except Exception: pass # ========================= # 推論邏輯(任務0:程度等級分布) # ========================= def class_judgement(text: str) -> dict: """回傳任務0的機率分布(字串,美觀顯示於 Label)。""" if not text or not text.strip(): return "(請輸入或從下方表格點選範例)" batch = tokenizer( [text], max_length=MAX_LEN, padding=True, truncation=True, return_tensors="pt", ) batch = {k: v.to(device) for k, v in batch.items()} with torch.no_grad(): output = model(**batch) print(output.shape) probs = torch.softmax(output, dim=-1)[0].tolist() # 轉成「label: prob」並格式化為多行字串 predictions = {} for i, lab in enumerate(ordered_labels): p = probs[i] if i < len(probs) else 0.0 predictions[lab] = p return predictions # ========================= # 範例資料(只放兩個) # - 閱讀測驗:id=1 # - 克漏字: id=2 # ========================= headers = ['id', '範例'] reading_article = ( "文章:\n" "我家旁邊有一個公園,公園裡有高高的樹和很多花。早上常常有小鳥在唱歌," "小孩也喜歡去那裡玩。下雨以後,空氣會變得很新鮮。我和媽媽有時會一起去公園走路," "看天上的雲,覺得很快樂。\n\n" "題目:\n" "1.這篇文章主要在介紹什麼地方?\n" "(A)我家旁邊的公園\n" "(B)我學校附近的商店\n" "(C)我媽媽的工作地點\n" "(D)小孩最想去的玩具店\n" ) cloze_text = ( "陳老師在大學教書已經十年了,他一直認為,學生除了學會課本上的知識," "更重要的是學會__1__和別人合作。每學期開始前,他都會__2__一個小組討論的主題," "讓學生們分組進行研究。每個組員要分工合作,有的人負責收集資料,有的人負責__3__和報告," "大家__4__幫忙,才能完成老師的要求。有時候小組成員之間會出現意見不同的情況," "陳老師總是__5__他們耐心溝通,學會傾聽別人的想法。他相信,這樣的經歷__6__能提升學生的能力," "__6__會對他們未來的工作和生活有很大幫助。\n" "1.\n(A)既然\n(B)如何\n(C)常常\n(D)但是\n" "2.\n(A)放棄\n(B)安排\n(C)要求\n(D)修理\n" "3.\n(A)整理\n(B)傳達\n(C)指導\n(D)販賣\n" "4.\n(A)故意\n(B)互相\n(C)分別\n(D)偶爾\n" "5.\n(A)鼓勵\n(B)懷疑\n(C)責罵\n(D)批評\n" "6.\n(A)不僅⋯⋯還⋯⋯\n(B)與其⋯⋯不如⋯⋯\n(C)要麼⋯⋯要麼⋯⋯\n(D)只有⋯⋯才⋯⋯\n" ) # 兩個表只各放一列(id=1 與 id=2),其他欄位填空字串即可 reading_samples = [ [1, reading_article], ] filling_samples = [ [2, cloze_text], ] reading_test_df = pd.DataFrame(reading_samples, columns=headers) filling_test_df = pd.DataFrame(filling_samples, columns=headers) # 點選表格時要填入的實際「完整 Prompt」內容 reading_test_id2data_str: Dict[int, str] = {1: reading_article} filling_test_id2data_str: Dict[int, str] = {2: cloze_text} # ========================= # 表格點選事件 # ========================= def on_row_select(evt: gr.SelectData, df: pd.DataFrame) -> str: """回傳應填入 Textbox 的完整文字。""" # evt.index 可能是 int 或 [int];統一成 int row_idx = evt.index[0] if isinstance(evt.index, (list, tuple)) else evt.index row = df.iloc[int(row_idx)] row_id = int(row["id"]) if row_id in reading_test_id2data_str.keys(): # 閱讀測驗 return reading_test_id2data_str[row_id] else: # 克漏字 return filling_test_id2data_str[row_id] # ========================= # Gradio 介面 # ========================= with gr.Blocks(title="class_judgement") as demo: gr.Markdown("## 華策會等級分類器(任務0 Demo)\n從下方**範例表**點選一列,系統會自動帶入並推論。") with gr.Row(): inp = gr.Textbox(label="輸入文章", lines=12, placeholder="也可手動貼文後按下『送出』") btn = gr.Button("送出", variant="primary") with gr.Row(): out0 = gr.Label(label="任務0:程度等級分布") # 閱讀測驗範例表(只 1 列) table_reading = gr.Dataframe( value=reading_test_df, headers=headers, row_count=(len(reading_test_df), "fixed"), col_count=len(headers), interactive=False, wrap=True, label="閱讀測驗範例表(點選帶入)" ) table_reading.select( on_row_select, inputs=table_reading, outputs=[inp] ) # 克漏字範例表(只 1 列) table_filling = gr.Dataframe( value=filling_test_df, headers=headers, row_count=(len(filling_test_df), "fixed"), col_count=len(headers), interactive=False, wrap=True, label="克漏字填空範例表(點選帶入)" ) table_filling.select( on_row_select, inputs=table_filling, outputs=[inp] ) # 手動輸入或點範例帶入後,按「送出」只輸出任務0 btn.click( class_judgement, inputs=inp, outputs=[out0] ) # 建立完 demo 後 demo.queue(max_size=32) if __name__ == "__main__": demo.launch()