Spaces:
Sleeping
Sleeping
# Create .streamlit/config.toml to avoid permission errors on Hugging Face Spaces.
import os

_STREAMLIT_CONFIG_PATH = os.path.join(".streamlit", "config.toml")
_STREAMLIT_CONFIG = """
[server]
headless = true
port = 7860
enableCORS = true
"""


def _write_streamlit_config():
    """Write the Streamlit server config, skipping the write if it is already current.

    Streamlit re-executes this whole script on every widget interaction, so an
    unconditional write would truncate and rewrite the file on every rerun (and
    from every concurrent session). Any missing or stale file is still rewritten.
    """
    os.makedirs(os.path.dirname(_STREAMLIT_CONFIG_PATH), exist_ok=True)
    try:
        with open(_STREAMLIT_CONFIG_PATH, encoding="utf-8") as existing:
            current = existing.read()
    except FileNotFoundError:
        current = None
    if current != _STREAMLIT_CONFIG:
        with open(_STREAMLIT_CONFIG_PATH, "w", encoding="utf-8") as f:
            f.write(_STREAMLIT_CONFIG)


# NOTE(review): Streamlit reads server.* options when the server starts, so writing
# this file from inside the running script most likely only affects the next
# launch — confirm this is the intended mechanism.
_write_streamlit_config()
import streamlit as st | |
import joblib | |
import pandas as pd | |
import shap | |
import matplotlib.pyplot as plt | |
import platform | |
from huggingface_hub import hf_hub_download | |
# Cross-platform CJK font setup so Chinese labels (e.g. feature names in the SHAP
# plot) render. Each platform keeps its original preferred font first, followed by
# fallbacks: recent matplotlib falls back glyph-by-glyph through the list, and older
# versions use the best-matching family they can find in it.
_CJK_FONT_CANDIDATES = {
    "Windows": ["Microsoft JhengHei", "Microsoft YaHei", "sans-serif"],
    "Darwin": [  # macOS
        "AppleGothic",
        "Heiti TC",
        "PingFang TC",
        "Arial Unicode MS",
        "sans-serif",
    ],
    "Linux": [
        # matplotlib typically indexes only the first face of a .ttc collection,
        # which for the Noto CJK bundle is usually the JP face, so "Noto Sans CJK TC"
        # alone is often reported as missing; JP/SC cover the same Han glyphs.
        "Noto Sans CJK TC",
        "Noto Sans CJK JP",
        "Noto Sans CJK SC",
        "WenQuanYi Zen Hei",
        "sans-serif",
    ],
}
# Any platform other than Windows/macOS uses the Linux list, as the original else-branch did.
plt.rcParams['font.family'] = _CJK_FONT_CANDIDATES.get(
    platform.system(), _CJK_FONT_CANDIDATES["Linux"]
)
plt.rcParams['axes.unicode_minus'] = False  # use an ASCII hyphen-minus in matplotlib-formatted tick labels
@st.cache_resource
def load_model_and_explainer():
    """Download the trained model bundle and build a SHAP explainer for it.

    Decorated with ``st.cache_resource`` because Streamlit re-executes the whole
    script on every widget interaction; without caching, the bundle would be
    unpickled and the explainer rebuilt on every rerun. The cached objects are
    shared across sessions, so callers must treat them as read-only.

    Returns:
        tuple: ``(model, le, explainer)`` — the object stored under ``"model"``
        in the bundle, the object stored under ``"label_encoder"``, and a
        ``shap.TreeExplainer`` wrapping the model.
    """
    # Download the model together with its LabelEncoder from the Hugging Face Hub.
    model_path = hf_hub_download(
        repo_id="jung-ming/Ocean-Meets-Forest",
        filename="rf_model_with_encoder.pkl",
        repo_type="model",
    )
    # SECURITY NOTE: joblib.load unpickles the file, which can execute arbitrary
    # code; only load bundles from a repository you control and trust.
    bundle = joblib.load(model_path)
    model = bundle["model"]
    le = bundle["label_encoder"]
    # Build the explainer here rather than unpickling a saved one, to avoid
    # loading Numba-compiled objects through pickle.
    # NOTE(review): no background `data` is passed; some shap versions then fall
    # back from "interventional" to "tree_path_dependent" — confirm intended mode.
    explainer = shap.TreeExplainer(model, feature_perturbation="interventional")
    return model, le, explainer
# Fetch the model, its label encoder and the SHAP explainer.
model, le, explainer = load_model_and_explainer()

# Ship-type label -> integer code, exactly as produced by the fitted encoder.
ship_type_to_code = {
    label: code for label, code in zip(le.classes_, le.transform(le.classes_))
}
# --- Page header ---
st.title("🚢 台中港艘次預測系統")
st.markdown("請輸入以下資訊,模型將預測該月艘次數")

# --- Input widgets (Streamlit renders them in this order) ---
port_count = st.selectbox("航線組合數", [*range(1, 100)])
year = st.selectbox("年", [*range(2020, 2026)])
month = st.selectbox("月", [*range(1, 13)])
ship_type = st.selectbox("船舶種類", [*ship_type_to_code])
if st.button("🔮 開始預測"):
    # Encode the selected ship type with the label encoder shipped alongside the model.
    ship_type_encoded = ship_type_to_code[ship_type]
    # One-row frame whose column names match what the model is fed everywhere else.
    input_df = pd.DataFrame({
        "航線組合數": [port_count],
        "年": [year],
        "月": [month],
        "船舶種類_編碼": [ship_type_encoded],
    })
    pred = model.predict(input_df)[0]
    st.success(f"預測結果:🚢 約為 {pred:.2f} 艘次")

    st.subheader("🧠 模型決策解釋圖(SHAP Waterfall)")
    shap_values = explainer(input_df)
    fig, _ = plt.subplots(figsize=(8, 4))
    try:
        # With show=False, shap is expected to draw onto the current figure (the
        # one just created) instead of calling plt.show().
        shap.plots.waterfall(shap_values[0], show=False)
        # Fix minus-sign display: shap's value labels contain U+2212 (unicode
        # minus); replace every occurrence, on every axes of the figure, with an
        # ASCII hyphen-minus.
        for axis in fig.axes:
            for text in axis.texts:
                label = text.get_text()
                if "\u2212" in label:
                    text.set_text(label.replace("\u2212", "-"))
        st.pyplot(fig)
    finally:
        # Always release the figure, even if plotting fails: pyplot keeps figures
        # alive globally, and this server process serves many reruns.
        plt.close(fig)