Update app.py
Browse files
app.py
CHANGED
@@ -11,15 +11,40 @@ from huggingface_hub import hf_hub_download
|
|
11 |
from tensorflow.keras.models import load_model
|
12 |
|
13 |
# ---------------- Config ----------------
|
|
|
14 |
LOCAL_MODEL_FILE = os.environ.get("LOCAL_MODEL_FILE", "lstm_cnn_model.h5")
|
15 |
-
|
16 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
# ----------------------------------------
|
18 |
|
19 |
MODEL = None
|
20 |
MODEL_READY = False
|
21 |
MODEL_LOAD_ERROR = None
|
22 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
23 |
def download_from_hub(repo: str, filename: str):
|
24 |
try:
|
25 |
print(f"[model] Downloading {filename} from {repo} ...", flush=True)
|
@@ -34,25 +59,44 @@ def load_model_background():
|
|
34 |
global MODEL, MODEL_READY, MODEL_LOAD_ERROR
|
35 |
try:
|
36 |
model_path = None
|
|
|
|
|
37 |
if os.path.exists(LOCAL_MODEL_FILE):
|
38 |
model_path = LOCAL_MODEL_FILE
|
39 |
print(f"[model] Found local model: {model_path}", flush=True)
|
|
|
40 |
elif HUB_REPO and HUB_FILENAME:
|
41 |
model_path = download_from_hub(HUB_REPO, HUB_FILENAME)
|
42 |
else:
|
43 |
print("[model] No local model file and no HUB_REPO/HUB_FILENAME configured.", flush=True)
|
|
|
44 |
|
45 |
-
if model_path
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
|
|
|
|
|
|
|
|
|
|
50 |
MODEL_READY = True
|
51 |
-
|
52 |
-
except Exception:
|
53 |
MODEL_LOAD_ERROR = traceback.format_exc()
|
54 |
MODEL_READY = False
|
55 |
-
print("[model] Error loading model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
56 |
|
57 |
# Start model loader in background so Gradio can bind to PORT immediately
|
58 |
loader = threading.Thread(target=load_model_background, daemon=True)
|
@@ -60,76 +104,193 @@ loader.start()
|
|
60 |
|
61 |
# ---------------- Helper functions ----------------
|
62 |
def prepare_input_array(arr, n_timesteps=1, n_features=None):
|
|
|
63 |
arr = np.array(arr)
|
|
|
|
|
64 |
if arr.ndim == 1:
|
65 |
if n_features is None:
|
66 |
-
|
67 |
-
|
|
|
|
|
|
|
68 |
elif arr.ndim == 2:
|
69 |
-
|
|
|
|
|
|
|
|
|
70 |
else:
|
71 |
return arr
|
72 |
|
73 |
def predict_text(text, n_timesteps=1, n_features=None):
|
|
|
74 |
if not MODEL_READY:
|
75 |
if MODEL_LOAD_ERROR:
|
76 |
-
return f"
|
77 |
-
return "
|
|
|
78 |
try:
|
79 |
-
|
80 |
-
|
81 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
82 |
label = int(np.argmax(probs, axis=1)[0])
|
83 |
-
|
|
|
|
|
|
|
84 |
except Exception as e:
|
85 |
-
return f"预测失败: {e}"
|
86 |
|
87 |
def predict_csv(file, n_timesteps=1, n_features=None):
|
|
|
88 |
if not MODEL_READY:
|
89 |
if MODEL_LOAD_ERROR:
|
90 |
-
return
|
91 |
-
return
|
|
|
92 |
try:
|
|
|
93 |
df = pd.read_csv(file.name)
|
94 |
X = df.values
|
|
|
|
|
95 |
if n_features:
|
96 |
-
|
97 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
98 |
labels = preds.argmax(axis=1).tolist()
|
99 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
100 |
except Exception as e:
|
101 |
-
return
|
102 |
|
103 |
# ---------------- Gradio UI ----------------
|
104 |
-
with gr.Blocks() as demo:
|
105 |
-
gr.Markdown("# CNN-LSTM
|
106 |
-
|
107 |
-
|
108 |
-
else:
|
109 |
-
if MODEL_LOAD_ERROR:
|
110 |
-
gr.Markdown("**模型加载失败**,请查看运行日志(下方可能有堆栈)。")
|
111 |
-
else:
|
112 |
-
gr.Markdown("模型正在后台加载(不会阻塞应用启动),请稍候。")
|
113 |
with gr.Row():
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
130 |
|
131 |
# ---------------- Launch (Spaces-friendly) ----------------
|
132 |
def get_port():
|
|
|
133 |
try:
|
134 |
return int(os.environ.get("PORT", 7860))
|
135 |
except:
|
@@ -138,5 +299,11 @@ def get_port():
|
|
138 |
if __name__ == "__main__":
|
139 |
port = get_port()
|
140 |
print(f"[app] Starting Gradio on 0.0.0.0:{port}", flush=True)
|
141 |
-
|
142 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
from tensorflow.keras.models import load_model
|
12 |
|
13 |
# ---------------- Config ----------------
|
14 |
+
# 设置默认的模型文件名
|
15 |
LOCAL_MODEL_FILE = os.environ.get("LOCAL_MODEL_FILE", "lstm_cnn_model.h5")
|
16 |
+
|
17 |
+
# 如果您的模型在 Hugging Face Hub 上,请设置这些环境变量
|
18 |
+
# 例如: HUB_REPO="your-username/your-repo-name"
|
19 |
+
# HUB_FILENAME="lstm_cnn_model.h5"
|
20 |
+
HUB_REPO = os.environ.get("HUB_REPO", "")
|
21 |
+
HUB_FILENAME = os.environ.get("HUB_FILENAME", "")
|
22 |
+
|
23 |
+
# 也可以直接在这里硬编码(如果不想用环境变量)
|
24 |
+
# HUB_REPO = "your-username/your-repo-name" # 替换为您的实际仓库
|
25 |
+
# HUB_FILENAME = "lstm_cnn_model.h5" # 替换为您的实际文件名
|
26 |
# ----------------------------------------
|
27 |
|
28 |
MODEL = None
|
29 |
MODEL_READY = False
|
30 |
MODEL_LOAD_ERROR = None
|
31 |
|
32 |
+
def create_dummy_model():
|
33 |
+
"""创建一个虚拟模型用于演示(当真实模型不可用时)"""
|
34 |
+
from tensorflow.keras.models import Sequential
|
35 |
+
from tensorflow.keras.layers import Dense, LSTM, Conv1D, Flatten, Input
|
36 |
+
|
37 |
+
print("[model] Creating dummy model for demonstration...", flush=True)
|
38 |
+
model = Sequential([
|
39 |
+
Input(shape=(1, 10)), # 假设输入形状
|
40 |
+
Conv1D(32, 3, activation='relu', padding='same'),
|
41 |
+
LSTM(64, return_sequences=False),
|
42 |
+
Dense(32, activation='relu'),
|
43 |
+
Dense(4, activation='softmax') # 假设4个类别
|
44 |
+
])
|
45 |
+
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
|
46 |
+
return model
|
47 |
+
|
48 |
def download_from_hub(repo: str, filename: str):
|
49 |
try:
|
50 |
print(f"[model] Downloading {filename} from {repo} ...", flush=True)
|
|
|
59 |
global MODEL, MODEL_READY, MODEL_LOAD_ERROR
|
60 |
try:
|
61 |
model_path = None
|
62 |
+
|
63 |
+
# 优先检查本地文件
|
64 |
if os.path.exists(LOCAL_MODEL_FILE):
|
65 |
model_path = LOCAL_MODEL_FILE
|
66 |
print(f"[model] Found local model: {model_path}", flush=True)
|
67 |
+
# 如果本地没有,尝试从 Hub 下载
|
68 |
elif HUB_REPO and HUB_FILENAME:
|
69 |
model_path = download_from_hub(HUB_REPO, HUB_FILENAME)
|
70 |
else:
|
71 |
print("[model] No local model file and no HUB_REPO/HUB_FILENAME configured.", flush=True)
|
72 |
+
print("[model] Using dummy model for demonstration purposes.", flush=True)
|
73 |
|
74 |
+
if model_path and os.path.exists(model_path):
|
75 |
+
print(f"[model] Loading model from {model_path} ...", flush=True)
|
76 |
+
MODEL = load_model(model_path)
|
77 |
+
print("[model] Model loaded successfully!", flush=True)
|
78 |
+
else:
|
79 |
+
# 如果没有真实模型,创建虚拟模型以便演示
|
80 |
+
print("[model] Creating dummy model since no real model is available.", flush=True)
|
81 |
+
MODEL = create_dummy_model()
|
82 |
+
print("[model] Dummy model created. Note: This is for demo only!", flush=True)
|
83 |
+
|
84 |
MODEL_READY = True
|
85 |
+
|
86 |
+
except Exception as e:
|
87 |
MODEL_LOAD_ERROR = traceback.format_exc()
|
88 |
MODEL_READY = False
|
89 |
+
print("[model] Error loading model:\n", MODEL_LOAD_ERROR, flush=True)
|
90 |
+
|
91 |
+
# 即使出错也尝试创建虚拟模型
|
92 |
+
try:
|
93 |
+
print("[model] Attempting to create dummy model as fallback...", flush=True)
|
94 |
+
MODEL = create_dummy_model()
|
95 |
+
MODEL_READY = True
|
96 |
+
MODEL_LOAD_ERROR = None
|
97 |
+
print("[model] Dummy model created as fallback.", flush=True)
|
98 |
+
except Exception as e2:
|
99 |
+
print(f"[model] Failed to create dummy model: {e2}", flush=True)
|
100 |
|
101 |
# Start model loader in background so Gradio can bind to PORT immediately
|
102 |
loader = threading.Thread(target=load_model_background, daemon=True)
|
|
|
104 |
|
105 |
# ---------------- Helper functions ----------------
|
106 |
def prepare_input_array(arr, n_timesteps=1, n_features=None):
|
107 |
+
"""准备输入数组为模型所需的形状"""
|
108 |
arr = np.array(arr)
|
109 |
+
|
110 |
+
# 如果是一维数组,重塑为 (1, timesteps, features)
|
111 |
if arr.ndim == 1:
|
112 |
if n_features is None:
|
113 |
+
# 自动推断 features
|
114 |
+
n_features = len(arr) // n_timesteps
|
115 |
+
return arr.reshape(1, n_timesteps, n_features)
|
116 |
+
|
117 |
+
# 如果是二维数组,假设第一维是 batch
|
118 |
elif arr.ndim == 2:
|
119 |
+
if n_features is None:
|
120 |
+
n_features = arr.shape[1] // n_timesteps
|
121 |
+
return arr.reshape(arr.shape[0], n_timesteps, n_features)
|
122 |
+
|
123 |
+
# 三维直接返回
|
124 |
else:
|
125 |
return arr
|
126 |
|
127 |
def predict_text(text, n_timesteps=1, n_features=None):
|
128 |
+
"""对单个文本输入进行预测"""
|
129 |
if not MODEL_READY:
|
130 |
if MODEL_LOAD_ERROR:
|
131 |
+
return f"⚠️ 模型加载失败:\n{MODEL_LOAD_ERROR}"
|
132 |
+
return "⏳ 模型正在加载中,请稍候..."
|
133 |
+
|
134 |
try:
|
135 |
+
# 解析逗号分隔的数字
|
136 |
+
arr = np.fromstring(text.strip(), sep=',')
|
137 |
+
|
138 |
+
if len(arr) == 0:
|
139 |
+
return "❌ 输入无效:请提供逗号分隔的数字"
|
140 |
+
|
141 |
+
# 准备输入
|
142 |
+
x = prepare_input_array(arr, n_timesteps=int(n_timesteps),
|
143 |
+
n_features=(int(n_features) if n_features else None))
|
144 |
+
|
145 |
+
# 预测
|
146 |
+
probs = MODEL.predict(x, verbose=0)
|
147 |
label = int(np.argmax(probs, axis=1)[0])
|
148 |
+
confidence = float(np.max(probs))
|
149 |
+
|
150 |
+
return f"✅ 预测类别: {label}\n置信度: {confidence:.2%}"
|
151 |
+
|
152 |
except Exception as e:
|
153 |
+
return f"❌ 预测失败: {str(e)}"
|
154 |
|
155 |
def predict_csv(file, n_timesteps=1, n_features=None):
|
156 |
+
"""对CSV文件进行批量预测"""
|
157 |
if not MODEL_READY:
|
158 |
if MODEL_LOAD_ERROR:
|
159 |
+
return None, f"⚠️ 模型加载失败:\n{MODEL_LOAD_ERROR}"
|
160 |
+
return None, "⏳ 模型正在加载中,请稍候..."
|
161 |
+
|
162 |
try:
|
163 |
+
# 读取CSV
|
164 |
df = pd.read_csv(file.name)
|
165 |
X = df.values
|
166 |
+
|
167 |
+
# 准备输入
|
168 |
if n_features:
|
169 |
+
n_samples = X.shape[0]
|
170 |
+
X = X.reshape(n_samples, int(n_timesteps), int(n_features))
|
171 |
+
else:
|
172 |
+
# 自动推断
|
173 |
+
n_samples = X.shape[0]
|
174 |
+
n_features_total = X.shape[1]
|
175 |
+
n_features = n_features_total // int(n_timesteps)
|
176 |
+
X = X.reshape(n_samples, int(n_timesteps), n_features)
|
177 |
+
|
178 |
+
# 批量预测
|
179 |
+
preds = MODEL.predict(X, verbose=0)
|
180 |
labels = preds.argmax(axis=1).tolist()
|
181 |
+
|
182 |
+
# 创建结果DataFrame
|
183 |
+
result_df = pd.DataFrame({
|
184 |
+
'Sample': range(len(labels)),
|
185 |
+
'Predicted_Label': labels,
|
186 |
+
'Confidence': [float(np.max(p)) for p in preds]
|
187 |
+
})
|
188 |
+
|
189 |
+
# 添加各类别的概率
|
190 |
+
for i in range(preds.shape[1]):
|
191 |
+
result_df[f'Prob_Class_{i}'] = preds[:, i]
|
192 |
+
|
193 |
+
return result_df, f"✅ 成功预测 {len(labels)} 个样本"
|
194 |
+
|
195 |
except Exception as e:
|
196 |
+
return None, f"❌ 预测失败: {str(e)}"
|
197 |
|
198 |
# ---------------- Gradio UI ----------------
|
199 |
+
with gr.Blocks(title="CNN-LSTM 故障分类") as demo:
|
200 |
+
gr.Markdown("# 🤖 CNN-LSTM 故障分类系统")
|
201 |
+
|
202 |
+
# 显示模型状态
|
|
|
|
|
|
|
|
|
|
|
203 |
with gr.Row():
|
204 |
+
with gr.Column():
|
205 |
+
status_text = gr.Markdown("⏳ 模型状态检查中...")
|
206 |
+
|
207 |
+
gr.Markdown("---")
|
208 |
+
|
209 |
+
# 输入选项卡
|
210 |
+
with gr.Tabs():
|
211 |
+
with gr.Tab("📝 单样本预测"):
|
212 |
+
text_in = gr.Textbox(
|
213 |
+
lines=3,
|
214 |
+
placeholder="输入逗号分隔的特征值,例如:\n0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0",
|
215 |
+
label="特征输入"
|
216 |
+
)
|
217 |
+
|
218 |
+
with gr.Row():
|
219 |
+
n_ts_text = gr.Number(value=1, label="时间步数 (timesteps)", precision=0)
|
220 |
+
n_feat_text = gr.Number(value=10, label="特征数 (features)", precision=0)
|
221 |
+
|
222 |
+
predict_btn = gr.Button("🔍 预测", variant="primary")
|
223 |
+
out_text = gr.Textbox(label="预测结果", lines=3)
|
224 |
+
|
225 |
+
with gr.Tab("📊 批量预测"):
|
226 |
+
file_in = gr.File(
|
227 |
+
label="上传 CSV 文件",
|
228 |
+
file_types=[".csv"],
|
229 |
+
file_count="single"
|
230 |
+
)
|
231 |
+
|
232 |
+
gr.Markdown("**CSV 格式说明:** 每行代表一个样本,列为特征值")
|
233 |
+
|
234 |
+
with gr.Row():
|
235 |
+
n_ts_csv = gr.Number(value=1, label="时间步数 (timesteps)", precision=0)
|
236 |
+
n_feat_csv = gr.Number(value=10, label="特征数 (features,可选)", precision=0)
|
237 |
+
|
238 |
+
predict_csv_btn = gr.Button("🔍 批量预测", variant="primary")
|
239 |
+
|
240 |
+
out_csv_msg = gr.Textbox(label="处理消息", lines=2)
|
241 |
+
out_csv_df = gr.Dataframe(label="预测结果表格", interactive=False)
|
242 |
+
|
243 |
+
gr.Markdown("---")
|
244 |
+
|
245 |
+
# 使用说明
|
246 |
+
with gr.Accordion("📖 使用说明", open=False):
|
247 |
+
gr.Markdown("""
|
248 |
+
### 输入格式
|
249 |
+
- **单样本**:输入逗号分隔的数值特征
|
250 |
+
- **批量预测**:上传CSV文件,每行一个样本
|
251 |
+
|
252 |
+
### 参数说明
|
253 |
+
- **时间步数 (timesteps)**:序列的时间步长度
|
254 |
+
- **特征数 (features)**:每个时间步的特征维度
|
255 |
+
- 输入总长度应该等于 `timesteps × features`
|
256 |
+
|
257 |
+
### 示例
|
258 |
+
如果您有10个特征,1个时间步:
|
259 |
+
- 输入:`0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0`
|
260 |
+
- timesteps = 1, features = 10
|
261 |
+
""")
|
262 |
+
|
263 |
+
# 定期更新状态
|
264 |
+
def update_status():
|
265 |
+
if MODEL_READY:
|
266 |
+
if MODEL_LOAD_ERROR is None:
|
267 |
+
return "✅ 模型已就绪"
|
268 |
+
else:
|
269 |
+
return "⚠️ 使用演示模型(真实模型加载失败)"
|
270 |
+
elif MODEL_LOAD_ERROR:
|
271 |
+
return f"❌ 模型加载失败"
|
272 |
+
else:
|
273 |
+
return "⏳ 模型加载中..."
|
274 |
+
|
275 |
+
# 事件绑定
|
276 |
+
predict_btn.click(
|
277 |
+
fn=predict_text,
|
278 |
+
inputs=[text_in, n_ts_text, n_feat_text],
|
279 |
+
outputs=out_text
|
280 |
+
)
|
281 |
+
|
282 |
+
predict_csv_btn.click(
|
283 |
+
fn=predict_csv,
|
284 |
+
inputs=[file_in, n_ts_csv, n_feat_csv],
|
285 |
+
outputs=[out_csv_df, out_csv_msg]
|
286 |
+
)
|
287 |
+
|
288 |
+
# 页面加载时更新状态
|
289 |
+
demo.load(fn=update_status, outputs=status_text)
|
290 |
|
291 |
# ---------------- Launch (Spaces-friendly) ----------------
|
292 |
def get_port():
|
293 |
+
"""获取端口号,优先使用环境变量"""
|
294 |
try:
|
295 |
return int(os.environ.get("PORT", 7860))
|
296 |
except:
|
|
|
299 |
if __name__ == "__main__":
|
300 |
port = get_port()
|
301 |
print(f"[app] Starting Gradio on 0.0.0.0:{port}", flush=True)
|
302 |
+
|
303 |
+
# 新版 Gradio 不再需要 enable_queue 参数
|
304 |
+
# queue 功能现在是默认启用的
|
305 |
+
demo.launch(
|
306 |
+
server_name="0.0.0.0",
|
307 |
+
server_port=port,
|
308 |
+
show_error=True
|
309 |
+
)
|