File size: 2,495 Bytes
4399fdc
 
 
1f1a361
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import os
os.environ["HF_HOME"] = "/tmp/.cache/huggingface"
os.environ["MPLCONFIGDIR"] = "/tmp/.config/matplotlib"
import streamlit as st
import joblib
import pandas as pd
import shap
import matplotlib.pyplot as plt
import platform
from huggingface_hub import hf_hub_download

# Cross-platform font setup: choose the matplotlib font family from the
# operating system reported by platform.system(). Anything that is neither
# Windows nor macOS ("Darwin") falls back to the Linux choice.
_FONT_FAMILY_BY_SYSTEM = {
    'Windows': 'Microsoft JhengHei',
    'Darwin': 'AppleGothic',  # macOS
}
plt.rcParams['font.family'] = _FONT_FAMILY_BY_SYSTEM.get(
    platform.system(),
    'Noto Sans CJK TC',  # Linux
)

# Render negative signs with the ASCII hyphen-minus instead of U+2212.
plt.rcParams['axes.unicode_minus'] = False

@st.cache_resource(show_spinner=True)
def load_model_and_explainer():
    """Download the model bundle and build a SHAP explainer for it.

    The bundle ``rf_model_with_encoder.pkl`` is fetched from the
    ``jung-ming/Ocean-Meets-Forest`` model repository via
    ``hf_hub_download`` and loaded with ``joblib``. Its ``"model"`` and
    ``"label_encoder"`` entries are extracted, and a ``shap.TreeExplainer``
    is constructed for the model.

    The function is wrapped in ``st.cache_resource``, so Streamlit is
    expected to reuse the returned objects instead of re-running the body.

    Returns:
        tuple: ``(model, label_encoder, explainer)``.

    Raises:
        KeyError: If the loaded bundle lacks the ``"model"`` or
            ``"label_encoder"`` key.
    """
    # Fetch the pickled bundle holding the model and its LabelEncoder.
    bundle_file = hf_hub_download(
        repo_id="jung-ming/Ocean-Meets-Forest",
        filename="rf_model_with_encoder.pkl",
        repo_type="model",
    )
    artifacts = joblib.load(bundle_file)
    fitted_model = artifacts["model"]
    label_encoder = artifacts["label_encoder"]

    # Build the explainer here rather than unpickling one, which avoids
    # loading Numba-compiled objects through pickle.
    tree_explainer = shap.TreeExplainer(
        fitted_model,
        feature_perturbation="interventional",
    )

    return fitted_model, label_encoder, tree_explainer

# Obtain the model, label encoder and SHAP explainer from the loader.
model, le, explainer = load_model_and_explainer()

# Map each ship-type label known to the encoder to its encoded value.
ship_type_to_code = {
    label: code
    for label, code in zip(le.classes_, le.transform(le.classes_))
}

# Page header and instructions.
st.title("🚢 台中港艘次預測系統")
st.markdown("請輸入以下資訊,模型將預測該月艘次數")

# Input widgets, rendered in this order.
port_count = st.selectbox("航線組合數", [*range(1, 100)])
year = st.selectbox("年", list(range(2020, 2026)))
month = st.selectbox("月", [*range(1, 13)])
ship_type = st.selectbox("船舶種類", list(ship_type_to_code))

if st.button("🔮 開始預測"):
    # Assemble a single-row frame using the column names passed to the model.
    feature_values = {
        "航線組合數": port_count,
        "年": year,
        "月": month,
        "船舶種類_編碼": ship_type_to_code[ship_type],
    }
    features = pd.DataFrame(
        {column: [value] for column, value in feature_values.items()}
    )

    predictions = model.predict(features)
    estimate = predictions[0]
    st.success(f"預測結果:🚢 約為 {estimate:.2f} 艘次")

    st.subheader("🧠 模型決策解釋圖(SHAP Waterfall)")

    explanation = explainer(features)
    fig, ax = plt.subplots(figsize=(8, 4))
    # show=False keeps the plot on the current figure so it can be handed
    # to Streamlit below instead of being displayed by SHAP.
    shap.plots.waterfall(explanation[0], show=False)

    # Fix minus-sign rendering: labels that start with U+2212 get every
    # U+2212 swapped for the ASCII hyphen-minus.
    for label in ax.texts:
        content = label.get_text()
        if content.startswith('\u2212'):
            label.set_text(content.replace('\u2212', '-'))

    st.pyplot(fig)
    plt.close(fig)