# Source: Hugging Face Hub file page (uploaded by jung-ming, commit 1f1a361
# "Upload 2 files", verified; 2.62 kB).
# Write .streamlit/config.toml (headless server on port 7860). Per the original
# author, this avoids Hugging Face permission errors.
# NOTE(review): Streamlit reads config.toml when the server starts, so a file
# written from inside the app script presumably only affects later launches —
# confirm this is still needed on the target host.
import os

os.makedirs(".streamlit", exist_ok=True)
_streamlit_config = """
[server]
headless = true
port = 7860
enableCORS = true
"""
# Streamlit re-executes this script on every widget interaction; only touch the
# file when its content actually differs instead of rewriting it on each rerun.
try:
    with open(".streamlit/config.toml", encoding="utf-8") as _fh:
        _current_config = _fh.read()
except OSError:
    _current_config = None
if _current_config != _streamlit_config:
    with open(".streamlit/config.toml", "w", encoding="utf-8") as _fh:
        _fh.write(_streamlit_config)
import streamlit as st
import joblib
import pandas as pd
import shap
import matplotlib.pyplot as plt
import platform
from huggingface_hub import hf_hub_download
# Cross-platform CJK font setup so the Traditional Chinese labels in the plot
# render. Each platform's originally chosen font stays first in the list; the
# remaining names are fallbacks used only when that font is not installed
# (otherwise matplotlib warns and falls back to a font without CJK glyphs).
_CJK_FONTS_BY_PLATFORM = {
    'Windows': ['Microsoft JhengHei', 'Microsoft YaHei'],
    'Darwin': ['AppleGothic', 'PingFang TC', 'Heiti TC', 'Arial Unicode MS'],  # macOS
}
# Any other platform (Linux and the rest) uses this list, as the original else-branch did.
_DEFAULT_CJK_FONTS = ['Noto Sans CJK TC', 'Noto Sans TC', 'Noto Sans CJK SC', 'WenQuanYi Zen Hei']

_cjk_fonts = _CJK_FONTS_BY_PLATFORM.get(platform.system(), _DEFAULT_CJK_FONTS)
# Resolve through the generic 'sans-serif' family: matplotlib picks the first
# installed name from font.sans-serif, so the candidates above are tried in order.
plt.rcParams['font.family'] = 'sans-serif'
plt.rcParams['font.sans-serif'] = _cjk_fonts + [
    name for name in plt.rcParams['font.sans-serif'] if name not in _cjk_fonts
]
plt.rcParams['axes.unicode_minus'] = False  # draw negative signs with ASCII '-' rather than U+2212
@st.cache_resource(show_spinner=True)
def load_model_and_explainer():
    """Download the trained model bundle and build a SHAP explainer for it.

    Decorated with ``st.cache_resource``, so Streamlit reuses the returned
    objects instead of repeating the download and explainer construction on
    every rerun.

    Returns:
        tuple: ``(model, le, explainer)`` — ``model`` and ``le`` are the
        ``"model"`` and ``"label_encoder"`` entries of the downloaded bundle,
        and ``explainer`` is a ``shap.TreeExplainer`` built for ``model``.
    """
    # Download the model + LabelEncoder bundle from the Hugging Face Hub.
    model_path = hf_hub_download(
        repo_id="jung-ming/Ocean-Meets-Forest",
        filename="rf_model_with_encoder.pkl",
        repo_type="model",
    )
    # SECURITY NOTE: joblib.load unpickles the file, which can execute arbitrary
    # code. Acceptable only because the repo is trusted; consider pinning
    # hf_hub_download(revision=...) so an upstream change cannot swap the file.
    bundle = joblib.load(model_path)
    model = bundle["model"]
    le = bundle["label_encoder"]
    # Build the explainer here rather than unpickling one (avoids loading
    # Numba-compiled objects through pickle). No background dataset is passed,
    # and "interventional" perturbation needs one, so request the mode that works
    # without data; depending on the SHAP version, "interventional" without data
    # is either downgraded to this mode or rejected.
    explainer = shap.TreeExplainer(model, feature_perturbation="tree_path_dependent")
    return model, le, explainer
model, le, explainer = load_model_and_explainer()

# Lookup table: ship-type label -> integer code produced by the LabelEncoder,
# i.e. the value fed to the model's "船舶種類_編碼" column.
ship_type_to_code = {
    label: code
    for label, code in zip(le.classes_, le.transform(le.classes_))
}
st.title("🚢 台中港艘次預測系統")
st.markdown("請輸入以下資訊,模型將預測該月艘次數")

# Input widgets, rendered top to bottom in this order.
port_count = st.selectbox("航線組合數", [*range(1, 100)])
year = st.selectbox("年", [*range(2020, 2026)])
month = st.selectbox("月", [*range(1, 13)])
ship_type = st.selectbox("船舶種類", list(ship_type_to_code))
if st.button("🔮 開始預測"):
    # Build a one-row frame whose column names match the model's training features.
    ship_type_encoded = ship_type_to_code[ship_type]
    input_df = pd.DataFrame({
        "航線組合數": [port_count],
        "年": [year],
        "月": [month],
        "船舶種類_編碼": [ship_type_encoded],
    })
    pred = model.predict(input_df)[0]
    st.success(f"預測結果:🚢 約為 {pred:.2f} 艘次")

    st.subheader("🧠 模型決策解釋圖(SHAP Waterfall)")
    shap_values = explainer(input_df)
    # Create a current figure for shap to draw on (it plots through pyplot's
    # current-figure API rather than taking an axes argument).
    fig, _ = plt.subplots(figsize=(8, 4))
    drawn_fig = fig
    try:
        shap.plots.waterfall(shap_values[0], show=False)
        # Act on whatever figure is current after plotting, i.e. the one shap drew on.
        drawn_fig = plt.gcf()
        # Fix negative-sign display: the value labels can contain U+2212 (MINUS
        # SIGN), which the CJK plot font may lack; swap it for ASCII '-'.
        # Walk every axes in case the plot uses more than one.
        for axes in drawn_fig.axes:
            for text in axes.texts:
                label = text.get_text()
                if "\u2212" in label:
                    text.set_text(label.replace("\u2212", "-"))
        st.pyplot(drawn_fig)
    finally:
        # pyplot keeps every open figure registered until closed, so release
        # them even when plotting or rendering raises; otherwise they pile up
        # across reruns of this long-running app.
        plt.close(fig)
        if drawn_fig is not fig:
            plt.close(drawn_fig)