|
|
|
""" |
|
簡單的中文情感分析模型創建腳本 |
|
基於 bert-base-chinese 創建一個可推理的模型 |
|
""" |
|
|
|
from transformers import ( |
|
BertTokenizer, |
|
BertForSequenceClassification, |
|
pipeline |
|
) |
|
import torch |
|
|
|
def create_model(model_name="bert-base-chinese", seed=None):
    """Load a BERT checkpoint as a 2-way (NEGATIVE/POSITIVE) sequence classifier.

    The tokenizer and encoder weights are loaded from ``model_name``, and the
    model is configured with two labels. This script performs no training, so
    the classification head is whatever ``from_pretrained`` produces for the
    checkpoint.

    NOTE(review): for a base checkpoint such as bert-base-chinese, which has
    no classifier weights, transformers initializes the head randomly. Its
    predictions are not meaningful until the model is fine-tuned.

    Args:
        model_name: Hugging Face model id or local path to load. Defaults to
            ``"bert-base-chinese"``, the checkpoint this script was written for.
        seed: Optional integer passed to ``torch.manual_seed`` before the model
            is built, so that a randomly initialized head is reproducible.
            ``None`` (default) leaves torch's global RNG untouched.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    print(f"正在載入 {model_name}...")

    if seed is not None:
        # Seed before construction: the head's initial weights are drawn here.
        torch.manual_seed(seed)

    # Single source of truth for the label set, so the two maps cannot drift.
    labels = ("NEGATIVE", "POSITIVE")
    id2label = dict(enumerate(labels))
    label2id = {label: idx for idx, label in id2label.items()}

    tokenizer = BertTokenizer.from_pretrained(model_name)
    model = BertForSequenceClassification.from_pretrained(
        model_name,
        num_labels=len(labels),
        id2label=id2label,
        label2id=label2id,
    )

    print("✅ 模型載入完成!")
    return model, tokenizer
|
|
|
def save_model(model, tokenizer, save_path="./"):
    """Save ``model`` and ``tokenizer`` to ``save_path`` and list the directory.

    Both objects are written with their own ``save_pretrained`` method, so any
    object exposing that method is accepted. After saving, every entry in
    ``save_path`` whose name does not start with ``.`` is printed in sorted
    order. With the default ``"./"``, that listing includes any files that
    were already in the current directory, not only the newly saved ones.

    Args:
        model: Object with a ``save_pretrained(path)`` method.
        tokenizer: Object with a ``save_pretrained(path)`` method.
        save_path: Destination directory. Defaults to the current directory.
    """
    import os

    print(f"正在保存模型到 {save_path}...")
    model.save_pretrained(save_path)
    tokenizer.save_pretrained(save_path)
    print("✅ 模型保存完成!")

    # "\n" must be a real newline escape; "\\n" would print a backslash and 'n'.
    print("\n生成的檔案:")
    for entry in sorted(os.listdir(save_path)):
        if not entry.startswith("."):  # skip dotfiles such as .git / .gitattributes
            print(f" 📄 {entry}")
|
|
|
_DEFAULT_TEST_TEXTS = (
    "這個產品真的很棒!我很喜歡。",
    "質量太差了,完全不值得購買。",
    "還不錯,可以考慮。",
    "非常滿意這次的服務體驗。",
)


def test_model(model_path="./", texts=None):
    """Smoke-test inference on a saved model directory and print the predictions.

    Builds a ``text-classification`` pipeline from ``model_path``, which is
    used for both model and tokenizer. It then prints the top label and score
    for each text.

    Any exception is caught and reported rather than raised, so this stays a
    best-effort check. The outcome is also returned, so callers can act on it.

    Args:
        model_path: Directory (or model id) to load the model and tokenizer from.
        texts: Iterable of strings to classify. ``None`` (default) uses a
            built-in set of four Chinese review sentences.

    Returns:
        True if the pipeline loaded and every text was classified, else False.
    """
    print("\n=== 測試模型推理 ===")

    samples = _DEFAULT_TEST_TEXTS if texts is None else texts

    try:
        classifier = pipeline(
            "text-classification",
            model=model_path,
            tokenizer=model_path,
        )

        print("\n推理結果:")
        for i, text in enumerate(samples, 1):
            # The pipeline returns a list with the top prediction first.
            top = classifier(text)[0]
            print(f"{i}. 文本: {text}")
            print(f" 預測: {top['label']} (信心度: {top['score']:.4f})")
            print()
    except Exception as e:  # best-effort smoke test: report instead of crashing
        print(f"❌ 推理測試失敗: {e}")
        return False

    print("✅ 推理測試完成!")
    return True
|
|
|
def main():
    """Create, save and smoke-test the model, then print next steps.

    Returns:
        A process exit code: 0 on success, or 1 if building or saving raised,
        or if the inference smoke test reported failure.
    """
    print("🚀 開始創建中文情感分析模型...")

    try:
        model, tokenizer = create_model()
        save_model(model, tokenizer)
        # Treat an explicit False from test_model as a failed smoke test; any
        # other return value (including None) counts as success.
        inference_ok = test_model() is not False
    except Exception as e:
        print(f"❌ 錯誤: {e}")
        print("請確保網路連接正常,能夠下載 bert-base-chinese 模型")
        return 1

    if not inference_ok:
        # test_model already printed the failure; don't tell the user to push
        # a model that cannot run inference.
        return 1

    print("\n" + "=" * 50)
    print("🎉 模型創建成功!")
    print("\n📋 下一步:")
    # NOTE(review): this script performs no training, so "trained model" in
    # the suggested commit message overstates what was produced.
    print("1. git add . && git commit -m 'Add trained model'")
    print("2. git push origin main")
    print("3. 其他人可以使用:")
    print(" from transformers import pipeline")
    print(" classifier = pipeline('text-classification', model='sk413025/my-awesome-model')")
    return 0


if __name__ == "__main__":
    # Propagate failure to the shell / CI via a non-zero exit status.
    raise SystemExit(main())
|
|