# Page-scrape header (not part of the program):
# jung-ming's picture
# Update app.py
# c7eac71 verified
import os
os.environ["HF_HOME"] = "/tmp/.cache/huggingface"
os.environ["MPLCONFIGDIR"] = "/tmp/.config/matplotlib"
import streamlit as st
import joblib
import pandas as pd
import shap
import matplotlib.pyplot as plt
import platform
from huggingface_hub import hf_hub_download
import matplotlib.font_manager as fm
# Custom helper: try to locate a Chinese font file available on this system.
def find_chinese_font():
    """Return the path of the first system font whose path matches a keyword.

    The font paths reported by ``fm.findSystemFonts(fontext='ttf')`` are
    scanned in the order given. A path matches when it contains "NotoSans"
    together with "TC" or "TraditionalChinese", or when it contains any of
    "STHeiti", "Heiti" or "LiHei". Matching is a case-sensitive substring
    test against the whole path.

    Returns:
        The first matching font path, or ``None`` when nothing matches.
    """
    # Adjust these keywords to match the font names installed on your system.
    noto_variants = ("TC", "TraditionalChinese")
    standalone_keywords = ("STHeiti", "Heiti", "LiHei")

    def looks_chinese(candidate):
        if "NotoSans" in candidate and any(tag in candidate for tag in noto_variants):
            return True
        return any(tag in candidate for tag in standalone_keywords)

    candidates = fm.findSystemFonts(fontext='ttf')
    return next((path for path in candidates if looks_chinese(path)), None)
# Resolve the Chinese font once; the path is None when no font matched.
chinese_font_path = find_chinese_font()
chinese_font_prop = (
    fm.FontProperties(fname=chinese_font_path) if chinese_font_path else None
)
if not chinese_font_path:
    print("找不到適合的中文字型檔,請確認系統已安裝中文字型")
else:
    print(f"使用中文字型檔: {chinese_font_path}")

# Global matplotlib font family (used as the fallback for all text).
_system_name = platform.system()
if _system_name == 'Windows':
    _font_family = 'Microsoft JhengHei'
elif _system_name == 'Darwin':  # macOS
    _font_family = 'AppleGothic'
elif chinese_font_prop:
    # Other platforms (e.g. Linux): use the family name of the font found above.
    _font_family = chinese_font_prop.get_name()
else:
    # No Chinese font was found: fall back to DejaVu Sans.
    _font_family = 'DejaVu Sans'
plt.rcParams['font.family'] = _font_family

# Render the minus sign as an ASCII hyphen-minus.
plt.rcParams['axes.unicode_minus'] = False
@st.cache_resource(show_spinner=True)
def load_model_and_explainer():
    """Download the model bundle and build the objects used for prediction.

    The file ``rf_model_with_encoder.pkl`` is fetched from the
    ``jung-ming/Ocean-Meets-Forest`` model repository and loaded with joblib.
    The loaded bundle is indexed by the keys ``"model"`` and
    ``"label_encoder"``. The function is wrapped in ``st.cache_resource``.

    Returns:
        A ``(model, label_encoder, explainer)`` tuple, where ``explainer`` is
        a ``shap.TreeExplainer`` built on the model.
    """
    # NOTE(review): joblib.load unpickles the downloaded file, so the
    # repository contents must be trusted.
    payload = joblib.load(
        hf_hub_download(
            repo_id="jung-ming/Ocean-Meets-Forest",
            filename="rf_model_with_encoder.pkl",
            repo_type="model",
        )
    )
    estimator = payload["model"]
    encoder = payload["label_encoder"]
    tree_explainer = shap.TreeExplainer(
        estimator, feature_perturbation="interventional"
    )
    return estimator, encoder, tree_explainer
model, le, explainer = load_model_and_explainer()

# Map every ship-type label known to the encoder to its encoded value.
ship_type_to_code = {
    label: code for label, code in zip(le.classes_, le.transform(le.classes_))
}

# Page header.
st.title("🚢 台中港艘次預測系統")
st.markdown("請輸入以下資訊,模型將預測該月艘次數")

# Input widgets: route-combination count, year, month and ship type.
port_count = st.selectbox("航線組合數", list(range(1, 100)))
year = st.selectbox("年", list(range(2020, 2026)))
month = st.selectbox("月", list(range(1, 13)))
ship_type = st.selectbox("船舶種類", list(ship_type_to_code))
# Run the prediction and render the SHAP explanation when the button is pressed.
if st.button("🔮 開始預測"):
    ship_type_encoded = ship_type_to_code[ship_type]

    # input_df keeps the Chinese column names so that it matches the format
    # the model was trained with.
    input_df = pd.DataFrame({
        "航線組合數": [port_count],
        "年": [year],
        "月": [month],
        "船舶種類_編碼": [ship_type_encoded],
    })

    pred = model.predict(input_df)[0]
    st.success(f"預測結果:🚢 約為 {pred:.2f} 艘次")

    st.subheader("🧠 模型決策解釋圖(SHAP Waterfall plot)")
    shap_values = explainer(input_df)

    # Show English feature names in the plot only; input_df columns are not
    # renamed.
    feature_name_map = {
        "航線組合數": "Route Count",
        "年": "Year",
        "月": "Month",
        "船舶種類_編碼": "Ship Type Code",
    }
    shap_values.feature_names = [
        feature_name_map.get(name, name) for name in shap_values.feature_names
    ]

    # With show=False, shap's waterfall returns the figure it drew on, not an
    # Axes. Reading `.texts` straight off that object only sees figure-level
    # text and leaves the value labels on the axes untouched. Resolve the
    # figure in a way that works for either return type.
    ax = shap.plots.waterfall(shap_values[0], show=False)
    fig = getattr(ax, "figure", None) or plt.gcf()

    # Replace the Unicode minus (U+2212) with an ASCII hyphen wherever it
    # appears, because the selected font may not have a glyph for it.
    for axes in fig.axes:
        for text in axes.texts:
            label = text.get_text()
            if '\u2212' in label:
                text.set_text(label.replace('\u2212', '-'))

    st.pyplot(fig)
    # Close the figure so repeated predictions do not accumulate open figures.
    plt.close(fig)