Ashleygxr commited on
Commit
14a6b31
·
verified ·
1 Parent(s): 2727f38

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +73 -0
app.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import pandas as pd
3
+ import xgboost as xgb
4
+ import shap
5
+ import gradio as gr
6
+ import matplotlib.pyplot as plt
7
+ import matplotlib
8
+
9
+ matplotlib.use("Agg") # 防止服务器无图形界面时报错
10
+
11
+ # 指定输入特征顺序
12
+ feature_names = ["CT value(HU)", "Tumor size(cm)", "ctDNA", "CEA", "Location", "CYFRA21-1", "CA125", "LDH"]
13
+
14
+ # 加载模型(确保是用 sklearn API 训练并保存的)
15
+ model = xgb.XGBClassifier()
16
+ model.load_model("./xgb_model.json")
17
+ model.get_booster().feature_names = feature_names
18
+
19
+ # 初始化 SHAP 解释器
20
+ explainer = shap.Explainer(model)
21
+
22
+
23
+ # 预测函数
24
+ def predict_probability(CT_value, Tumor_size, ctDNA, CEA, Location, CYFRA21_1, CA125, LDH):
25
+ input_data = pd.DataFrame(
26
+ [[CT_value, Tumor_size, ctDNA, CEA, Location, CYFRA21_1, CA125, LDH]], columns=feature_names
27
+ )
28
+
29
+ # 将 Location 和 ctDNA 转换为数值型
30
+ input_data["Location"] = input_data["Location"].map({"Central": 1, "Peripheral": 0})
31
+ input_data["ctDNA"] = input_data["ctDNA"].map({"Positive": 1, "Negative": 0})
32
+ # 预测
33
+ try:
34
+ prob = model.predict_proba(input_data)[0][1]
35
+ except Exception as e:
36
+ return f"预测出错: {e}", None
37
+
38
+ # 计算 SHAP 值
39
+ try:
40
+ shap_values = explainer(input_data)
41
+
42
+ # 绘图
43
+ shap.plots.waterfall(shap_values[0], show=False)
44
+ plt.title("SHAP Waterfall Plot")
45
+ plt.savefig("shap_plot.png", bbox_inches="tight", dpi=300)
46
+ plt.close()
47
+ except Exception as e:
48
+ return f"SHAP 图生成失败: {e}", None
49
+
50
+ return f"阳性概率: {prob:.2%}", "shap_plot.png"
51
+
52
+
53
+ demo = gr.Interface(
54
+ fn=predict_probability,
55
+ inputs=[
56
+ gr.Number(label="CT value(HU)"),
57
+ gr.Number(label="Tumor size(cm)"),
58
+ gr.Dropdown(choices=["Positive", "Negative"], label="ctDNA"),
59
+ gr.Number(label="CEA (ng/mL) Normal range: 0-5"),
60
+ gr.Dropdown(choices=["Central", "Peripheral"], label="Location"), # 修改为 Dropdown 类型
61
+ gr.Number(label="CYFRA21-1 (ng/mL) Normal range: 0-5"),
62
+ gr.Number(label="CA125 (U/mL) Normal range: 0-35"),
63
+ gr.Number(label="LDH (U/L) Normal range: 120-250"),
64
+ ],
65
+ outputs=[
66
+ gr.Textbox(label="预测结果"),
67
+ gr.Image(type="filepath", label="SHAP Waterfall Plot"),
68
+ ],
69
+ title="淋巴结转移预测",
70
+ description="输入变量,获取预测阳性概率及SHAP解释图",
71
+ )
72
+
73
+ demo.launch(share=True)