JasonLiao's picture
Update app.py
422bc91 verified
raw
history blame
7.59 kB
# app.py
import os
from typing import Dict, List
import gradio as gr
import pandas as pd
import torch
import torch.nn as nn
from transformers import AutoTokenizer, BertPreTrainedModel, BertModel
class BertForCLSClassification(BertPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.bert = BertModel(config)
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
self.post_init()
def forward(self, input_ids=None, token_type_ids=None, attention_mask=None, labels=None):
outputs = self.bert(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
cls_output = outputs.last_hidden_state[:, 0, :] # 跟你原本一樣
logits = self.classifier(cls_output)
return logits
# =========================
# 設定:模型位置 & HF 私有 Token(可選)
# =========================
# 若你把權重直接放在 Space 目錄(例如 config.json、model.safetensors、tokenizer 檔),
# 可把 MODEL_ID 改成 "." 以載入本地檔案。
MODEL_ID = "TAIDE-EDU/task4-level-judgement"
HF_TOKEN = os.getenv("HF_TOKEN", None)
device = "cuda" if torch.cuda.is_available() else "cpu"
# =========================
# 載入模型與 tokenizer
# =========================
def load_model_and_tokenizer():
kwargs = {}
if HF_TOKEN and MODEL_ID != ".":
kwargs["token"] = HF_TOKEN
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, **kwargs)
model = BertForCLSClassification.from_pretrained(MODEL_ID, **kwargs).to(device)
model.eval()
# 從 config 取 id2label;若缺少則提供預設
id2label = getattr(model.config, "id2label", None) or {0: "入門基礎", 1: "進階高階", 2: "流利精通"}
# 依 id 排序成 list 方便顯示順序穩定
ordered_labels = [id2label[i] if i in id2label else str(i) for i in range(len(id2label))]
return model, tokenizer, ordered_labels
model, tokenizer, ordered_labels = load_model_and_tokenizer()
MAX_LEN = 512
try:
if model is not None:
MAX_LEN = min(getattr(model.config, "max_position_embeddings", 512) or 512, 512)
except Exception:
pass
# =========================
# 推論邏輯(任務0:程度等級分布)
# =========================
def class_judgement(text: str) -> dict:
"""回傳任務0的機率分布(字串,美觀顯示於 Label)。"""
if not text or not text.strip():
return "(請輸入或從下方表格點選範例)"
batch = tokenizer(
[text],
max_length=MAX_LEN,
padding=True,
truncation=True,
return_tensors="pt",
)
batch = {k: v.to(device) for k, v in batch.items()}
with torch.no_grad():
output = model(**batch)
print(output.shape)
probs = torch.softmax(output, dim=-1)[0].tolist()
# 轉成「label: prob」並格式化為多行字串
predictions = {}
for i, lab in enumerate(ordered_labels):
p = probs[i] if i < len(probs) else 0.0
predictions[lab] = p
return predictions
# =========================
# 範例資料(只放兩個)
# - 閱讀測驗:id=1
# - 克漏字: id=2
# =========================
headers = ['id', '範例']
reading_article = (
"文章:\n"
"我家旁邊有一個公園,公園裡有高高的樹和很多花。早上常常有小鳥在唱歌,"
"小孩也喜歡去那裡玩。下雨以後,空氣會變得很新鮮。我和媽媽有時會一起去公園走路,"
"看天上的雲,覺得很快樂。\n\n"
"題目:\n"
"1.這篇文章主要在介紹什麼地方?\n"
"(A)我家旁邊的公園\n"
"(B)我學校附近的商店\n"
"(C)我媽媽的工作地點\n"
"(D)小孩最想去的玩具店\n"
)
cloze_text = (
"陳老師在大學教書已經十年了,他一直認為,學生除了學會課本上的知識,"
"更重要的是學會__1__和別人合作。每學期開始前,他都會__2__一個小組討論的主題,"
"讓學生們分組進行研究。每個組員要分工合作,有的人負責收集資料,有的人負責__3__和報告,"
"大家__4__幫忙,才能完成老師的要求。有時候小組成員之間會出現意見不同的情況,"
"陳老師總是__5__他們耐心溝通,學會傾聽別人的想法。他相信,這樣的經歷__6__能提升學生的能力,"
"__6__會對他們未來的工作和生活有很大幫助。\n"
"1.\n(A)既然\n(B)如何\n(C)常常\n(D)但是\n"
"2.\n(A)放棄\n(B)安排\n(C)要求\n(D)修理\n"
"3.\n(A)整理\n(B)傳達\n(C)指導\n(D)販賣\n"
"4.\n(A)故意\n(B)互相\n(C)分別\n(D)偶爾\n"
"5.\n(A)鼓勵\n(B)懷疑\n(C)責罵\n(D)批評\n"
"6.\n(A)不僅⋯⋯還⋯⋯\n(B)與其⋯⋯不如⋯⋯\n(C)要麼⋯⋯要麼⋯⋯\n(D)只有⋯⋯才⋯⋯\n"
)
# 兩個表只各放一列(id=1 與 id=2),其他欄位填空字串即可
reading_samples = [
[1, reading_article],
]
filling_samples = [
[2, cloze_text],
]
reading_test_df = pd.DataFrame(reading_samples, columns=headers)
filling_test_df = pd.DataFrame(filling_samples, columns=headers)
# 點選表格時要填入的實際「完整 Prompt」內容
reading_test_id2data_str: Dict[int, str] = {1: reading_article}
filling_test_id2data_str: Dict[int, str] = {2: cloze_text}
# =========================
# 表格點選事件
# =========================
def on_row_select(evt: gr.SelectData, df: pd.DataFrame) -> str:
"""回傳應填入 Textbox 的完整文字。"""
# evt.index 可能是 int 或 [int];統一成 int
row_idx = evt.index[0] if isinstance(evt.index, (list, tuple)) else evt.index
row = df.iloc[int(row_idx)]
row_id = int(row["id"])
if row_id in reading_test_id2data_str.keys():
# 閱讀測驗
return reading_test_id2data_str[row_id]
else:
# 克漏字
return filling_test_id2data_str[row_id]
# =========================
# Gradio 介面
# =========================
with gr.Blocks(title="class_judgement") as demo:
gr.Markdown("## 華策會等級分類器(任務0 Demo)\n從下方**範例表**點選一列,系統會自動帶入並推論。")
with gr.Row():
inp = gr.Textbox(label="輸入文章", lines=12, placeholder="也可手動貼文後按下『送出』")
btn = gr.Button("送出", variant="primary")
with gr.Row():
out0 = gr.Label(label="任務0:程度等級分布")
# 閱讀測驗範例表(只 1 列)
table_reading = gr.Dataframe(
value=reading_test_df,
headers=headers,
row_count=(len(reading_test_df), "fixed"),
col_count=len(headers),
interactive=False,
wrap=True,
label="閱讀測驗範例表(點選帶入)"
)
table_reading.select(
on_row_select,
inputs=table_reading,
outputs=[inp]
)
# 克漏字範例表(只 1 列)
table_filling = gr.Dataframe(
value=filling_test_df,
headers=headers,
row_count=(len(filling_test_df), "fixed"),
col_count=len(headers),
interactive=False,
wrap=True,
label="克漏字填空範例表(點選帶入)"
)
table_filling.select(
on_row_select,
inputs=table_filling,
outputs=[inp]
)
# 手動輸入或點範例帶入後,按「送出」只輸出任務0
btn.click(
class_judgement,
inputs=inp,
outputs=[out0]
)
# 建立完 demo 後
demo.queue(max_size=32)
if __name__ == "__main__":
demo.launch()