# ML_predict_LNM / app.py
# Uploaded by Ashleygxr — commit 14a6b31 (verified), 2.45 kB
import os
import tempfile

import gradio as gr
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import shap
import xgboost as xgb
# Use the non-interactive Agg backend so plotting works on a server without a display.
matplotlib.use("Agg")
# Input feature order expected by the model. The DataFrame built in
# predict_probability and the Gradio input list use this same order.
feature_names = ["CT value(HU)", "Tumor size(cm)", "ctDNA", "CEA", "Location", "CYFRA21-1", "CA125", "LDH"]
# Load the model (it must have been trained and saved through the sklearn API).
# Resolve the file next to this script rather than in the current working
# directory, so the app also starts when launched from elsewhere.
MODEL_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "xgb_model.json")
model = xgb.XGBClassifier()
model.load_model(MODEL_PATH)
# Make the booster's feature names match the DataFrame columns passed at
# prediction time.
model.get_booster().feature_names = feature_names
# Build the SHAP explainer once at startup; it is reused for every request.
explainer = shap.Explainer(model)
# 预测函数
def predict_probability(CT_value, Tumor_size, ctDNA, CEA, Location, CYFRA21_1, CA125, LDH):
    """Predict the positive-class probability and render a SHAP waterfall plot.

    Called by the Gradio interface with one value per input component, in
    the same order as ``feature_names``.

    Args:
        CT_value: CT value (HU).
        Tumor_size: Tumor size (cm).
        ctDNA: "Positive" or "Negative"; encoded as 1 / 0.
        CEA: CEA level.
        Location: "Central" or "Peripheral"; encoded as 1 / 0.
        CYFRA21_1: CYFRA21-1 level.
        CA125: CA125 level.
        LDH: LDH level.

    Returns:
        A ``(message, image_path)`` tuple. On success, ``message`` reports
        the positive-class probability and ``image_path`` is the path of a
        freshly written PNG of the SHAP waterfall plot. On failure,
        ``message`` describes the error and ``image_path`` is ``None``.
    """
    raw_values = [CT_value, Tumor_size, ctDNA, CEA, Location, CYFRA21_1, CA125, LDH]
    # Gradio passes None for an empty Number field or an unselected Dropdown.
    # An all-None numeric column becomes object dtype, which XGBoost rejects.
    # An unselected dropdown would be silently mapped to NaN. Reject both
    # up front with a clear message.
    if any(value is None for value in raw_values):
        return "预测出错: 请填写所有输入项", None

    input_data = pd.DataFrame([raw_values], columns=feature_names)
    # Encode the categorical inputs as 0/1 codes (assumed to match the
    # encoding used when the model was trained — TODO confirm).
    input_data["Location"] = input_data["Location"].map({"Central": 1, "Peripheral": 0})
    input_data["ctDNA"] = input_data["ctDNA"].map({"Positive": 1, "Negative": 0})

    # Predict: column 1 of predict_proba is the positive class.
    try:
        prob = model.predict_proba(input_data)[0][1]
    except Exception as e:
        return f"预测出错: {e}", None

    # Compute SHAP values and save the waterfall plot for this single row.
    plot_path = None
    try:
        shap_values = explainer(input_data)
        shap.plots.waterfall(shap_values[0], show=False)
        plt.title("SHAP Waterfall Plot")
        # Use a fresh file per request. A shared fixed filename could be
        # overwritten by another request before Gradio serves it.
        fd, plot_path = tempfile.mkstemp(prefix="shap_plot_", suffix=".png")
        os.close(fd)
        plt.savefig(plot_path, bbox_inches="tight", dpi=300)
    except Exception as e:
        # Do not leave a partially written image behind.
        if plot_path is not None and os.path.exists(plot_path):
            os.remove(plot_path)
        return f"SHAP 图生成失败: {e}", None
    finally:
        # Always release the figure, including when plotting or saving
        # failed; otherwise pyplot keeps each failed figure open.
        plt.close()
    return f"阳性概率: {prob:.2%}", plot_path
# Gradio UI. The input components are passed to predict_probability
# positionally, so their order must match its parameters.
demo = gr.Interface(
    fn=predict_probability,
    inputs=[
        gr.Number(label="CT value(HU)"),
        gr.Number(label="Tumor size(cm)"),
        gr.Dropdown(choices=["Positive", "Negative"], label="ctDNA"),
        gr.Number(label="CEA (ng/mL) Normal range: 0-5"),
        # Categorical input: predict_probability maps these choices to 1 / 0.
        gr.Dropdown(choices=["Central", "Peripheral"], label="Location"),
        gr.Number(label="CYFRA21-1 (ng/mL) Normal range: 0-5"),
        gr.Number(label="CA125 (U/mL) Normal range: 0-35"),
        gr.Number(label="LDH (U/L) Normal range: 120-250"),
    ],
    outputs=[
        gr.Textbox(label="预测结果"),
        gr.Image(type="filepath", label="SHAP Waterfall Plot"),
    ],
    title="淋巴结转移预测",
    description="输入变量,获取预测阳性概率及SHAP解释图",
)

if __name__ == "__main__":
    # Start the server only when run as a script, so importing this module
    # does not launch it. share=True also requests a public Gradio share link.
    demo.launch(share=True)