Spaces:
Sleeping
Sleeping
import os | |
# Point the Hugging Face and matplotlib cache/config directories at /tmp.
# Set before huggingface_hub and matplotlib are imported below so those
# libraries pick up these locations.
# NOTE(review): presumably /tmp is used because the default home directory is
# not writable in the hosting container — confirm.
os.environ.update({
    "HF_HOME": "/tmp/.cache/huggingface",
    "MPLCONFIGDIR": "/tmp/.config/matplotlib",
})
import streamlit as st | |
import joblib | |
import pandas as pd | |
import shap | |
import matplotlib.pyplot as plt | |
import platform | |
from huggingface_hub import hf_hub_download | |
import matplotlib.font_manager as fm | |
# 自訂函數:嘗試找系統中可用的中文字型路徑 | |
def find_chinese_font(font_paths=None):
    """Return the path of an installed font file that can render Chinese text.

    Candidates are matched by keywords in the file *name* only, so directory
    names cannot produce false positives. They are scanned in sorted order, so
    the pick is deterministic across runs (the system font list is unordered).

    Args:
        font_paths: Optional iterable of font file paths to search. When None
            (the default), the fonts reported by
            ``matplotlib.font_manager.findSystemFonts(fontext='ttf')`` are used.

    Returns:
        The first matching font path, or None if nothing matches.
    """
    if font_paths is None:
        font_paths = fm.findSystemFonts(fontext='ttf')
    for font_path in sorted(font_paths):
        name = os.path.basename(font_path)
        # Noto Sans Traditional Chinese builds, plus the pan-CJK collection
        # (e.g. NotoSansCJK-Regular.ttc), which also covers Chinese glyphs.
        is_noto_chinese = "NotoSans" in name and (
            "TC" in name or "TraditionalChinese" in name or "CJK" in name
        )
        # macOS Heiti / LiHei families. Adjust these keywords to match the
        # font names installed on your system.
        is_mac_chinese = any(key in name for key in ("STHeiti", "Heiti", "LiHei"))
        if is_noto_chinese or is_mac_chinese:
            return font_path
    return None
# Resolve a Chinese-capable font file. chinese_font_prop stays None when no
# candidate font is installed; a status line is printed either way.
chinese_font_path = find_chinese_font()
chinese_font_prop = fm.FontProperties(fname=chinese_font_path) if chinese_font_path else None
if chinese_font_path:
    print(f"使用中文字型檔: {chinese_font_path}")
else:
    print("找不到適合的中文字型檔,請確認系統已安裝中文字型")
# Global matplotlib font family, used as the fallback for text drawn without
# an explicit FontProperties. Windows and macOS use a fixed family name; other
# systems (Linux) use the detected Chinese font, else DejaVu Sans.
_os_name = platform.system()
if _os_name == 'Windows':
    _default_family = 'Microsoft JhengHei'
elif _os_name == 'Darwin':  # macOS
    _default_family = 'AppleGothic'
elif chinese_font_prop:
    _default_family = chinese_font_prop.get_name()
else:
    _default_family = 'DejaVu Sans'
plt.rcParams['font.family'] = _default_family
# Render minus signs as the ASCII hyphen-minus rather than U+2212.
plt.rcParams['axes.unicode_minus'] = False
@st.cache_resource
def load_model_and_explainer():
    """Download the trained bundle from the Hugging Face Hub and build a SHAP explainer.

    Streamlit re-executes this script on every widget interaction, so the
    result is cached with ``st.cache_resource``. Without it, the pickle would be
    reloaded and the TreeExplainer rebuilt on each rerun.

    Returns:
        A ``(model, le, explainer)`` tuple: the object stored under ``"model"``
        in the bundle, the object stored under ``"label_encoder"``, and a
        ``shap.TreeExplainer`` built for that model.
    """
    model_path = hf_hub_download(
        repo_id="jung-ming/Ocean-Meets-Forest",
        filename="rf_model_with_encoder.pkl",
        repo_type="model",
    )
    # NOTE(security): joblib.load unpickles the file, which can execute code;
    # only load bundles from a repository you trust.
    bundle = joblib.load(model_path)
    model = bundle["model"]
    le = bundle["label_encoder"]
    # No background dataset is passed, so SHAP cannot use the "interventional"
    # algorithm (it falls back to, or requires, tree_path_dependent).
    # Request tree_path_dependent explicitly.
    explainer = shap.TreeExplainer(model, feature_perturbation="tree_path_dependent")
    return model, le, explainer
model, le, explainer = load_model_and_explainer()

# Ship-type label -> integer code, taken from the label encoder bundled with
# the model so the UI choices line up with the encoder's classes.
ship_type_to_code = {
    label: code
    for label, code in zip(le.classes_, le.transform(le.classes_))
}
st.title("🚢 台中港艘次預測系統")
st.markdown("請輸入以下資訊,模型將預測該月艘次數")

# Input widgets, rendered in this order; the option lists are fixed ranges.
route_count_options = list(range(1, 100))  # 1..99
year_options = list(range(2020, 2026))      # 2020..2025
month_options = list(range(1, 13))          # 1..12

port_count = st.selectbox("航線組合數", route_count_options)
year = st.selectbox("年", year_options)
month = st.selectbox("月", month_options)
ship_type = st.selectbox("船舶種類", list(ship_type_to_code))
if st.button("🔮 開始預測"):
    ship_type_encoded = ship_type_to_code[ship_type]
    # input_df uses the Chinese column names, matching the format the model
    # was trained with.
    input_df = pd.DataFrame({
        "航線組合數": [port_count],
        "年": [year],
        "月": [month],
        "船舶種類_編碼": [ship_type_encoded]
    })
    pred = model.predict(input_df)[0]
    st.success(f"預測結果:🚢 約為 {pred:.2f} 艘次")

    st.subheader("🧠 模型決策解釋圖(SHAP Waterfall plot)")
    shap_values = explainer(input_df)
    # Show English feature names in the plot; input_df's columns are left as-is.
    feature_name_map = {
        "航線組合數": "Route Count",
        "年": "Year",
        "月": "Month",
        "船舶種類_編碼": "Ship Type Code"
    }
    shap_values.feature_names = [feature_name_map.get(f, f) for f in shap_values.feature_names]

    # waterfall draws on the current pyplot figure. Fetch it with plt.gcf()
    # instead of relying on the return value, whose type varies across SHAP
    # versions.
    shap.plots.waterfall(shap_values[0], show=False)
    fig = plt.gcf()
    try:
        # SHAP writes negative values with U+2212 (minus sign). Those value
        # labels are Text artists on the Axes, not figure-level texts, so walk
        # every Axes and swap in an ASCII '-' to avoid glyph issues.
        for axis in fig.axes:
            for text in axis.texts:
                label = text.get_text()
                if '\u2212' in label:
                    text.set_text(label.replace('\u2212', '-'))
        st.pyplot(fig)
    finally:
        # Always release the figure, even if rendering fails, so pyplot's
        # global figure registry does not grow across reruns.
        plt.close(fig)