VincentCroft commited on
Commit
85e194e
·
verified ·
1 Parent(s): 43102e8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +220 -53
app.py CHANGED
@@ -11,15 +11,40 @@ from huggingface_hub import hf_hub_download
11
  from tensorflow.keras.models import load_model
12
 
13
  # ---------------- Config ----------------
 
14
  LOCAL_MODEL_FILE = os.environ.get("LOCAL_MODEL_FILE", "lstm_cnn_model.h5")
15
- HUB_REPO = os.environ.get("HUB_REPO", "") # optional: "username/repo"
16
- HUB_FILENAME = os.environ.get("HUB_FILENAME", "") # optional: "lstm_cnn_model.h5"
 
 
 
 
 
 
 
 
17
  # ----------------------------------------
18
 
19
  MODEL = None
20
  MODEL_READY = False
21
  MODEL_LOAD_ERROR = None
22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  def download_from_hub(repo: str, filename: str):
24
  try:
25
  print(f"[model] Downloading {filename} from {repo} ...", flush=True)
@@ -34,25 +59,44 @@ def load_model_background():
34
  global MODEL, MODEL_READY, MODEL_LOAD_ERROR
35
  try:
36
  model_path = None
 
 
37
  if os.path.exists(LOCAL_MODEL_FILE):
38
  model_path = LOCAL_MODEL_FILE
39
  print(f"[model] Found local model: {model_path}", flush=True)
 
40
  elif HUB_REPO and HUB_FILENAME:
41
  model_path = download_from_hub(HUB_REPO, HUB_FILENAME)
42
  else:
43
  print("[model] No local model file and no HUB_REPO/HUB_FILENAME configured.", flush=True)
 
44
 
45
- if model_path is None:
46
- raise FileNotFoundError("Model file not found locally or on Hugging Face Hub.")
47
-
48
- print(f"[model] Loading model from {model_path} ...", flush=True)
49
- MODEL = load_model(model_path)
 
 
 
 
 
50
  MODEL_READY = True
51
- print("[model] Model loaded OK.", flush=True)
52
- except Exception:
53
  MODEL_LOAD_ERROR = traceback.format_exc()
54
  MODEL_READY = False
55
- print("[model] Error loading model:\\n", MODEL_LOAD_ERROR, flush=True)
 
 
 
 
 
 
 
 
 
 
56
 
57
  # Start model loader in background so Gradio can bind to PORT immediately
58
  loader = threading.Thread(target=load_model_background, daemon=True)
@@ -60,76 +104,193 @@ loader.start()
60
 
61
  # ---------------- Helper functions ----------------
62
  def prepare_input_array(arr, n_timesteps=1, n_features=None):
 
63
  arr = np.array(arr)
 
 
64
  if arr.ndim == 1:
65
  if n_features is None:
66
- return arr.reshape(1, n_timesteps, -1)
67
- return arr.reshape(1, n_timesteps, int(n_features))
 
 
 
68
  elif arr.ndim == 2:
69
- return arr
 
 
 
 
70
  else:
71
  return arr
72
 
73
  def predict_text(text, n_timesteps=1, n_features=None):
 
74
  if not MODEL_READY:
75
  if MODEL_LOAD_ERROR:
76
- return f"模型加载失败:\\n{MODEL_LOAD_ERROR}"
77
- return "模型尚未加载完成,请稍候(后台正在加载)。"
 
78
  try:
79
- arr = np.fromstring(text, sep=',')
80
- x = prepare_input_array(arr, n_timesteps=int(n_timesteps), n_features=(int(n_features) if n_features else None))
81
- probs = MODEL.predict(x)
 
 
 
 
 
 
 
 
 
82
  label = int(np.argmax(probs, axis=1)[0])
83
- return f"预测类别: {label} (概率: {float(np.max(probs)):.4f})"
 
 
 
84
  except Exception as e:
85
- return f"预测失败: {e}"
86
 
87
  def predict_csv(file, n_timesteps=1, n_features=None):
 
88
  if not MODEL_READY:
89
  if MODEL_LOAD_ERROR:
90
- return {"error": f"模型加载失败:\\n{MODEL_LOAD_ERROR}"}
91
- return {"error": "模型尚未加载完成,请稍候(后台正在加载)。"}
 
92
  try:
 
93
  df = pd.read_csv(file.name)
94
  X = df.values
 
 
95
  if n_features:
96
- X = X.reshape(X.shape[0], int(n_timesteps), int(n_features))
97
- preds = MODEL.predict(X)
 
 
 
 
 
 
 
 
 
98
  labels = preds.argmax(axis=1).tolist()
99
- return {"labels": labels, "probs": preds.tolist()}
 
 
 
 
 
 
 
 
 
 
 
 
 
100
  except Exception as e:
101
- return {"error": f"预测失败: {e}"}
102
 
103
  # ---------------- Gradio UI ----------------
104
- with gr.Blocks() as demo:
105
- gr.Markdown("# CNN-LSTM Fault Classification (Spaces)")
106
- if MODEL_READY:
107
- gr.Markdown("模型已加载 ✅")
108
- else:
109
- if MODEL_LOAD_ERROR:
110
- gr.Markdown("**模型加载失败**,请查看运行日志(下方可能有堆栈)。")
111
- else:
112
- gr.Markdown("模型正在后台加载(不会阻塞应用启动),请稍候。")
113
  with gr.Row():
114
- file_in = gr.File(label="上传 CSV(每行 = 一个样本)")
115
- text_in = gr.Textbox(lines=2, placeholder="粘贴逗号分隔的一行特征,例如: 0.1,0.2,0.3,...")
116
- n_ts = gr.Number(value=1, label="timesteps (整型)")
117
- n_feat = gr.Number(value=None, label="features (可选,留空尝试自动推断)")
118
- btn = gr.Button("预测")
119
- out_text = gr.Textbox(label="单样本预测输出")
120
- out_json = gr.JSON(label="批量预测结果 (labels & probs)")
121
-
122
- def run_predict(file, text, n_timesteps, n_features):
123
- if file is not None:
124
- return "CSV 预测完成", predict_csv(file, n_timesteps, n_features)
125
- if text:
126
- return predict_text(text, n_timesteps, n_features), {}
127
- return "请提供 CSV 或特征文本", {}
128
-
129
- btn.click(run_predict, inputs=[file_in, text_in, n_ts, n_feat], outputs=[out_text, out_json])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
130
 
131
  # ---------------- Launch (Spaces-friendly) ----------------
132
  def get_port():
 
133
  try:
134
  return int(os.environ.get("PORT", 7860))
135
  except:
@@ -138,5 +299,11 @@ def get_port():
138
  if __name__ == "__main__":
139
  port = get_port()
140
  print(f"[app] Starting Gradio on 0.0.0.0:{port}", flush=True)
141
- # Do NOT use share=True on Spaces
142
- demo.launch(server_name="0.0.0.0", server_port=port, show_error=True, enable_queue=True)
 
 
 
 
 
 
 
11
  from tensorflow.keras.models import load_model
12
 
13
  # ---------------- Config ----------------
14
+ # 设置默认的模型文件名
15
  LOCAL_MODEL_FILE = os.environ.get("LOCAL_MODEL_FILE", "lstm_cnn_model.h5")
16
+
17
+ # 如果您的模型在 Hugging Face Hub 上,请设置这些环境变量
18
+ # 例如: HUB_REPO="your-username/your-repo-name"
19
+ # HUB_FILENAME="lstm_cnn_model.h5"
20
+ HUB_REPO = os.environ.get("HUB_REPO", "")
21
+ HUB_FILENAME = os.environ.get("HUB_FILENAME", "")
22
+
23
+ # 也可以直接在这里硬编码(如果不想用环境变量)
24
+ # HUB_REPO = "your-username/your-repo-name" # 替换为您的实际仓库
25
+ # HUB_FILENAME = "lstm_cnn_model.h5" # 替换为您的实际文件名
26
  # ----------------------------------------
27
 
28
  MODEL = None
29
  MODEL_READY = False
30
  MODEL_LOAD_ERROR = None
31
 
32
+ def create_dummy_model():
33
+ """创建一个虚拟模型用于演示(当真实模型不可用时)"""
34
+ from tensorflow.keras.models import Sequential
35
+ from tensorflow.keras.layers import Dense, LSTM, Conv1D, Flatten, Input
36
+
37
+ print("[model] Creating dummy model for demonstration...", flush=True)
38
+ model = Sequential([
39
+ Input(shape=(1, 10)), # 假设输入形状
40
+ Conv1D(32, 3, activation='relu', padding='same'),
41
+ LSTM(64, return_sequences=False),
42
+ Dense(32, activation='relu'),
43
+ Dense(4, activation='softmax') # 假设4个类别
44
+ ])
45
+ model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
46
+ return model
47
+
48
  def download_from_hub(repo: str, filename: str):
49
  try:
50
  print(f"[model] Downloading {filename} from {repo} ...", flush=True)
 
59
  global MODEL, MODEL_READY, MODEL_LOAD_ERROR
60
  try:
61
  model_path = None
62
+
63
+ # 优先检查本地文件
64
  if os.path.exists(LOCAL_MODEL_FILE):
65
  model_path = LOCAL_MODEL_FILE
66
  print(f"[model] Found local model: {model_path}", flush=True)
67
+ # 如果本地没有,尝试从 Hub 下载
68
  elif HUB_REPO and HUB_FILENAME:
69
  model_path = download_from_hub(HUB_REPO, HUB_FILENAME)
70
  else:
71
  print("[model] No local model file and no HUB_REPO/HUB_FILENAME configured.", flush=True)
72
+ print("[model] Using dummy model for demonstration purposes.", flush=True)
73
 
74
+ if model_path and os.path.exists(model_path):
75
+ print(f"[model] Loading model from {model_path} ...", flush=True)
76
+ MODEL = load_model(model_path)
77
+ print("[model] Model loaded successfully!", flush=True)
78
+ else:
79
+ # 如果没有真实模型,创建虚拟模型以便演示
80
+ print("[model] Creating dummy model since no real model is available.", flush=True)
81
+ MODEL = create_dummy_model()
82
+ print("[model] Dummy model created. Note: This is for demo only!", flush=True)
83
+
84
  MODEL_READY = True
85
+
86
+ except Exception as e:
87
  MODEL_LOAD_ERROR = traceback.format_exc()
88
  MODEL_READY = False
89
+ print("[model] Error loading model:\n", MODEL_LOAD_ERROR, flush=True)
90
+
91
+ # 即使出错也尝试创建虚拟模型
92
+ try:
93
+ print("[model] Attempting to create dummy model as fallback...", flush=True)
94
+ MODEL = create_dummy_model()
95
+ MODEL_READY = True
96
+ MODEL_LOAD_ERROR = None
97
+ print("[model] Dummy model created as fallback.", flush=True)
98
+ except Exception as e2:
99
+ print(f"[model] Failed to create dummy model: {e2}", flush=True)
100
 
101
  # Start model loader in background so Gradio can bind to PORT immediately
102
  loader = threading.Thread(target=load_model_background, daemon=True)
 
104
 
105
  # ---------------- Helper functions ----------------
106
  def prepare_input_array(arr, n_timesteps=1, n_features=None):
107
+ """准备输入数组为模型所需的形状"""
108
  arr = np.array(arr)
109
+
110
+ # 如果是一维数组,重塑为 (1, timesteps, features)
111
  if arr.ndim == 1:
112
  if n_features is None:
113
+ # 自动推断 features
114
+ n_features = len(arr) // n_timesteps
115
+ return arr.reshape(1, n_timesteps, n_features)
116
+
117
+ # 如果是二维数组,假设第一维是 batch
118
  elif arr.ndim == 2:
119
+ if n_features is None:
120
+ n_features = arr.shape[1] // n_timesteps
121
+ return arr.reshape(arr.shape[0], n_timesteps, n_features)
122
+
123
+ # 三维直接返回
124
  else:
125
  return arr
126
 
127
  def predict_text(text, n_timesteps=1, n_features=None):
128
+ """对单个文本输入进行预测"""
129
  if not MODEL_READY:
130
  if MODEL_LOAD_ERROR:
131
+ return f"⚠️ 模型加载失败:\n{MODEL_LOAD_ERROR}"
132
+ return "⏳ 模型正在加载中,请稍候..."
133
+
134
  try:
135
+ # 解析逗号分隔的数字
136
+ arr = np.fromstring(text.strip(), sep=',')
137
+
138
+ if len(arr) == 0:
139
+ return "❌ 输入无效:请提供逗号分隔的数字"
140
+
141
+ # 准备输入
142
+ x = prepare_input_array(arr, n_timesteps=int(n_timesteps),
143
+ n_features=(int(n_features) if n_features else None))
144
+
145
+ # 预测
146
+ probs = MODEL.predict(x, verbose=0)
147
  label = int(np.argmax(probs, axis=1)[0])
148
+ confidence = float(np.max(probs))
149
+
150
+ return f"✅ 预测类别: {label}\n置信度: {confidence:.2%}"
151
+
152
  except Exception as e:
153
+ return f"预测失败: {str(e)}"
154
 
155
  def predict_csv(file, n_timesteps=1, n_features=None):
156
+ """对CSV文件进行批量预测"""
157
  if not MODEL_READY:
158
  if MODEL_LOAD_ERROR:
159
+ return None, f"⚠️ 模型加载失败:\n{MODEL_LOAD_ERROR}"
160
+ return None, " 模型正在加载中,请稍候..."
161
+
162
  try:
163
+ # 读取CSV
164
  df = pd.read_csv(file.name)
165
  X = df.values
166
+
167
+ # 准备输入
168
  if n_features:
169
+ n_samples = X.shape[0]
170
+ X = X.reshape(n_samples, int(n_timesteps), int(n_features))
171
+ else:
172
+ # 自动推断
173
+ n_samples = X.shape[0]
174
+ n_features_total = X.shape[1]
175
+ n_features = n_features_total // int(n_timesteps)
176
+ X = X.reshape(n_samples, int(n_timesteps), n_features)
177
+
178
+ # 批量预测
179
+ preds = MODEL.predict(X, verbose=0)
180
  labels = preds.argmax(axis=1).tolist()
181
+
182
+ # 创建结果DataFrame
183
+ result_df = pd.DataFrame({
184
+ 'Sample': range(len(labels)),
185
+ 'Predicted_Label': labels,
186
+ 'Confidence': [float(np.max(p)) for p in preds]
187
+ })
188
+
189
+ # 添加各类别的概率
190
+ for i in range(preds.shape[1]):
191
+ result_df[f'Prob_Class_{i}'] = preds[:, i]
192
+
193
+ return result_df, f"✅ 成功预测 {len(labels)} 个样本"
194
+
195
  except Exception as e:
196
+ return None, f"预测失败: {str(e)}"
197
 
198
  # ---------------- Gradio UI ----------------
199
+ with gr.Blocks(title="CNN-LSTM 故障分类") as demo:
200
+ gr.Markdown("# 🤖 CNN-LSTM 故障分类系统")
201
+
202
+ # 显示模型状态
 
 
 
 
 
203
  with gr.Row():
204
+ with gr.Column():
205
+ status_text = gr.Markdown(" 模型状态检查中...")
206
+
207
+ gr.Markdown("---")
208
+
209
+ # 输入选项卡
210
+ with gr.Tabs():
211
+ with gr.Tab("📝 单样本预测"):
212
+ text_in = gr.Textbox(
213
+ lines=3,
214
+ placeholder="输入逗号分隔的特征值,例如:\n0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0",
215
+ label="特征输入"
216
+ )
217
+
218
+ with gr.Row():
219
+ n_ts_text = gr.Number(value=1, label="时间步数 (timesteps)", precision=0)
220
+ n_feat_text = gr.Number(value=10, label="特征数 (features)", precision=0)
221
+
222
+ predict_btn = gr.Button("🔍 预测", variant="primary")
223
+ out_text = gr.Textbox(label="预测结果", lines=3)
224
+
225
+ with gr.Tab("📊 批量预测"):
226
+ file_in = gr.File(
227
+ label="上传 CSV 文件",
228
+ file_types=[".csv"],
229
+ file_count="single"
230
+ )
231
+
232
+ gr.Markdown("**CSV 格式说明:** 每行代表一个样本,列为特征值")
233
+
234
+ with gr.Row():
235
+ n_ts_csv = gr.Number(value=1, label="时间步数 (timesteps)", precision=0)
236
+ n_feat_csv = gr.Number(value=10, label="特征数 (features,可选)", precision=0)
237
+
238
+ predict_csv_btn = gr.Button("🔍 批量预测", variant="primary")
239
+
240
+ out_csv_msg = gr.Textbox(label="处理消息", lines=2)
241
+ out_csv_df = gr.Dataframe(label="预测结果表格", interactive=False)
242
+
243
+ gr.Markdown("---")
244
+
245
+ # 使用说明
246
+ with gr.Accordion("📖 使用说明", open=False):
247
+ gr.Markdown("""
248
+ ### 输入格式
249
+ - **单样本**:输入逗号分隔的数值特征
250
+ - **批量预测**:上传CSV文件,每行一个样本
251
+
252
+ ### 参数说明
253
+ - **时间步数 (timesteps)**:序列的时间步长度
254
+ - **特征数 (features)**:每个时间步的特征维度
255
+ - 输入总长度应该等于 `timesteps × features`
256
+
257
+ ### 示例
258
+ 如果您有10个特征,1个时间步:
259
+ - 输入:`0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0`
260
+ - timesteps = 1, features = 10
261
+ """)
262
+
263
+ # 定期更新状态
264
+ def update_status():
265
+ if MODEL_READY:
266
+ if MODEL_LOAD_ERROR is None:
267
+ return "✅ 模型已就绪"
268
+ else:
269
+ return "⚠️ 使用演示模型(真实模型加载失败)"
270
+ elif MODEL_LOAD_ERROR:
271
+ return f"❌ 模型加载失败"
272
+ else:
273
+ return "⏳ 模型加载中..."
274
+
275
+ # 事件绑定
276
+ predict_btn.click(
277
+ fn=predict_text,
278
+ inputs=[text_in, n_ts_text, n_feat_text],
279
+ outputs=out_text
280
+ )
281
+
282
+ predict_csv_btn.click(
283
+ fn=predict_csv,
284
+ inputs=[file_in, n_ts_csv, n_feat_csv],
285
+ outputs=[out_csv_df, out_csv_msg]
286
+ )
287
+
288
+ # 页面加载时更新状态
289
+ demo.load(fn=update_status, outputs=status_text)
290
 
291
  # ---------------- Launch (Spaces-friendly) ----------------
292
  def get_port():
293
+ """获取端口号,优先使用环境变量"""
294
  try:
295
  return int(os.environ.get("PORT", 7860))
296
  except:
 
299
  if __name__ == "__main__":
300
  port = get_port()
301
  print(f"[app] Starting Gradio on 0.0.0.0:{port}", flush=True)
302
+
303
+ # 新版 Gradio 不再需要 enable_queue 参数
304
+ # queue 功能现在是默认启用的
305
+ demo.launch(
306
+ server_name="0.0.0.0",
307
+ server_port=port,
308
+ show_error=True
309
+ )