VincentCroft committed on
Commit 43102e8 · verified · Parent: 64866ab

Update app.py

Files changed (1): app.py (+72, -93)

app.py CHANGED
@@ -1,85 +1,80 @@
-# lstm_cnn_app.py (modified)
-"""
-Robust Gradio app for CNN-LSTM fault classification.
-
-Features added:
-- Prefer local model file; optionally download from Hugging Face Hub if HUB_REPO/HUB_FILENAME set.
-- If no model found, app still starts but prediction functions return a friendly message.
-- Port selection:
-  * If GRADIO_SERVER_PORT or PORT env var is set, try that.
-  * Otherwise find a free ephemeral port and use it.
-  * If binding fails, fall back to demo.launch() with no explicit port (Gradio picks).
-- Reduces TF logging noise via TF_CPP_MIN_LOG_LEVEL (optional).
-"""
+# app.py -- Spaces-ready robust version for lstm_cnn model
 import os
-import socket
+import threading
+import traceback
 import numpy as np
 import pandas as pd
 import gradio as gr
-from tensorflow.keras.models import load_model
 from huggingface_hub import hf_hub_download
 
-# Reduce TensorFlow log noise (keeps warnings but hides info/debug)
-os.environ.setdefault("TF_CPP_MIN_LOG_LEVEL", "2")
+# Use Keras model loader (change if you have PyTorch)
+from tensorflow.keras.models import load_model
+
+# ---------------- Config ----------------
+LOCAL_MODEL_FILE = os.environ.get("LOCAL_MODEL_FILE", "lstm_cnn_model.h5")
+HUB_REPO = os.environ.get("HUB_REPO", "")          # optional: "username/repo"
+HUB_FILENAME = os.environ.get("HUB_FILENAME", "")  # optional: "lstm_cnn_model.h5"
+# ----------------------------------------
 
-# CONFIG: change these if your model filename/repo are different
-LOCAL_MODEL_FILE = "lstm_cnn_model.h5"
-HUB_REPO = ""      # e.g., "username/lstm-cnn-model" (leave empty to disable)
-HUB_FILENAME = ""  # e.g., "lstm_cnn_model.h5"
+MODEL = None
+MODEL_READY = False
+MODEL_LOAD_ERROR = None
 
 def download_from_hub(repo: str, filename: str):
     try:
-        print(f"Downloading {filename} from {repo} ...")
+        print(f"[model] Downloading {filename} from {repo} ...", flush=True)
         path = hf_hub_download(repo_id=repo, filename=filename)
-        print("Downloaded to:", path)
+        print("[model] Downloaded to:", path, flush=True)
         return path
     except Exception as e:
-        print("Failed to download from hub:", e)
+        print("[model] Hub download failed:", e, flush=True)
         return None
 
-def get_model_path():
-    # Prefer local file
-    if os.path.exists(LOCAL_MODEL_FILE):
-        return LOCAL_MODEL_FILE
-    # Try env override for local path (handy in Spaces)
-    alt = os.environ.get("MODEL_FILE_PATH")
-    if alt and os.path.exists(alt):
-        return alt
-    # Try hub
-    if HUB_REPO and HUB_FILENAME:
-        return download_from_hub(HUB_REPO, HUB_FILENAME)
-    return None
-
-def try_load_model(path):
+def load_model_background():
+    global MODEL, MODEL_READY, MODEL_LOAD_ERROR
     try:
-        m = load_model(path)
-        print("Loaded model:", path)
-        return m
-    except Exception as e:
-        print("Failed to load model:", e)
-        return None
-
-MODEL_PATH = get_model_path()
-MODEL = try_load_model(MODEL_PATH) if MODEL_PATH else None
-
+        model_path = None
+        if os.path.exists(LOCAL_MODEL_FILE):
+            model_path = LOCAL_MODEL_FILE
+            print(f"[model] Found local model: {model_path}", flush=True)
+        elif HUB_REPO and HUB_FILENAME:
+            model_path = download_from_hub(HUB_REPO, HUB_FILENAME)
+        else:
+            print("[model] No local model file and no HUB_REPO/HUB_FILENAME configured.", flush=True)
+
+        if model_path is None:
+            raise FileNotFoundError("Model file not found locally or on Hugging Face Hub.")
+
+        print(f"[model] Loading model from {model_path} ...", flush=True)
+        MODEL = load_model(model_path)
+        MODEL_READY = True
+        print("[model] Model loaded OK.", flush=True)
+    except Exception:
+        MODEL_LOAD_ERROR = traceback.format_exc()
+        MODEL_READY = False
+        print("[model] Error loading model:\n", MODEL_LOAD_ERROR, flush=True)
+
+# Start model loader in background so Gradio can bind to PORT immediately
+loader = threading.Thread(target=load_model_background, daemon=True)
+loader.start()
+
+# ---------------- Helper functions ----------------
 def prepare_input_array(arr, n_timesteps=1, n_features=None):
     arr = np.array(arr)
     if arr.ndim == 1:
         if n_features is None:
-            # assume arr is flattened timesteps*features
             return arr.reshape(1, n_timesteps, -1)
         return arr.reshape(1, n_timesteps, int(n_features))
     elif arr.ndim == 2:
-        # treat as (timesteps, features) -> add batch dim
-        if arr.shape[0] == 1:
-            return arr.reshape(1, arr.shape[1], -1)
         return arr
     else:
         return arr
 
 def predict_text(text, n_timesteps=1, n_features=None):
-    if MODEL is None:
-        return "Model not loaded. Upload 'lstm_cnn_model.h5' to the Space root directory, or set HUB_REPO/HUB_FILENAME."
+    if not MODEL_READY:
+        if MODEL_LOAD_ERROR:
+            return f"Model failed to load:\n{MODEL_LOAD_ERROR}"
+        return "Model is still loading in the background; please wait."
     try:
         arr = np.fromstring(text, sep=',')
         x = prepare_input_array(arr, n_timesteps=int(n_timesteps), n_features=(int(n_features) if n_features else None))
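The substantive change in this hunk is that model loading moves off the startup path: a daemon thread runs `load_model_background()` while Gradio binds its port immediately (as the diff's own comment notes, Spaces expects the port to be up quickly). A minimal, self-contained sketch of that pattern; the `slow_load` stand-in below is hypothetical, not part of the app:

```python
import threading
import time

MODEL = None
MODEL_READY = False

def slow_load():
    # Hypothetical stand-in for tensorflow.keras.models.load_model.
    time.sleep(2)
    return object()

def load_in_background():
    global MODEL, MODEL_READY
    MODEL = slow_load()   # heavy work happens off the main thread
    MODEL_READY = True    # readiness flag checked by request handlers

threading.Thread(target=load_in_background, daemon=True).start()
print("server can bind its port while the model is still loading")
```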
@@ -90,8 +85,10 @@ def predict_text(text, n_timesteps=1, n_features=None):
         return f"Prediction failed: {e}"
 
 def predict_csv(file, n_timesteps=1, n_features=None):
-    if MODEL is None:
-        return {"error": "Model not loaded. Upload 'lstm_cnn_model.h5' to the Space root directory, or set HUB_REPO/HUB_FILENAME."}
+    if not MODEL_READY:
+        if MODEL_LOAD_ERROR:
+            return {"error": f"Model failed to load:\n{MODEL_LOAD_ERROR}"}
+        return {"error": "Model is still loading in the background; please wait."}
     try:
         df = pd.read_csv(file.name)
         X = df.values
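For reference, this is how a pasted row flows through `prepare_input_array` with the defaults (one timestep, feature count inferred); a numpy-only worked example, no model required:

```python
import numpy as np

row = np.fromstring("0.1,0.2,0.3,0.4", sep=",")  # shape (4,)
x = row.reshape(1, 1, -1)                        # (batch, timesteps, features)
print(x.shape)                                   # -> (1, 1, 4)
```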
@@ -103,13 +100,16 @@ def predict_csv(file, n_timesteps=1, n_features=None):
     except Exception as e:
         return {"error": f"Prediction failed: {e}"}
 
-# Gradio UI
+# ---------------- Gradio UI ----------------
 with gr.Blocks() as demo:
-    gr.Markdown("# CNN-LSTM Fault Classification")
-    if MODEL is None:
-        gr.Markdown("**Note**: no model detected (lstm_cnn_model.h5). Upload the model or set HUB_REPO/HUB_FILENAME in the code. The app will still start, but prediction will be unavailable.")
+    gr.Markdown("# CNN-LSTM Fault Classification (Spaces)")
+    if MODEL_READY:
+        gr.Markdown("Model loaded.")
     else:
-        gr.Markdown("Model loaded. Upload a CSV or paste one comma-separated row of features to predict.")
+        if MODEL_LOAD_ERROR:
+            gr.Markdown("**Model failed to load**; check the runtime logs for the stack trace.")
+        else:
+            gr.Markdown("The model is loading in the background (app startup is not blocked); please wait.")
     with gr.Row():
         file_in = gr.File(label="Upload CSV (one sample per row)")
         text_in = gr.Textbox(lines=2, placeholder="Paste one comma-separated row of features, e.g. 0.1,0.2,0.3,...")
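Note that this status banner is chosen once, when the `gr.Blocks` context is built at import time; since the loader thread has usually not finished by then, visitors normally see the "loading" branch, and the per-request `MODEL_READY` checks above provide the live state. A sketch of the branch logic, assuming the flags defined earlier in the diff:

```python
import gradio as gr

MODEL_READY = False     # flags as defined earlier in app.py
MODEL_LOAD_ERROR = None

with gr.Blocks() as demo:
    if MODEL_READY:
        gr.Markdown("Model loaded.")
    elif MODEL_LOAD_ERROR:
        gr.Markdown("**Model failed to load**; check the runtime logs.")
    else:
        gr.Markdown("Model loading in the background; please wait.")
```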
@@ -128,36 +128,15 @@ with gr.Blocks() as demo:
 
     btn.click(run_predict, inputs=[file_in, text_in, n_ts, n_feat], outputs=[out_text, out_json])
 
-# Robust port selection & launch helper
-def find_free_port():
-    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
-    s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
-    s.bind(('', 0))
-    addr, port = s.getsockname()
-    s.close()
-    return port
-
-def get_desired_port():
-    # priority: GRADIO_SERVER_PORT -> PORT -> auto find
-    p = os.environ.get("GRADIO_SERVER_PORT") or os.environ.get("PORT")
-    if p:
-        try:
-            return int(p)
-        except:
-            pass
-    # fallback to ephemeral free port
-    return find_free_port()
-
-if __name__ == '__main__':
-    port = None
+# ---------------- Launch (Spaces-friendly) ----------------
+def get_port():
     try:
-        port = get_desired_port()
-        print(f"Launching server on port {port} (server_name=0.0.0.0)")
-        demo.launch(server_name='0.0.0.0', server_port=port)
-    except OSError as e:
-        print("Failed to bind requested port:", e)
-        print("Falling back to default demo.launch() (no explicit port).")
-        # last fallback: let Gradio choose/handle
-        # Colab / local debug: create a temporary public link (share=True)
-        demo.launch(share=True)
-
+        return int(os.environ.get("PORT", 7860))
+    except ValueError:
+        return 7860
+
+if __name__ == "__main__":
+    port = get_port()
+    print(f"[app] Starting Gradio on 0.0.0.0:{port}", flush=True)
+    # Do NOT use share=True on Spaces
+    demo.queue().launch(server_name="0.0.0.0", server_port=port, show_error=True)  # queue() replaces the deprecated enable_queue=True launch argument
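The launch logic reduces to reading the PORT variable that Spaces injects (7860 is Gradio's usual default) and binding to 0.0.0.0. A standalone sketch of that resolution:

```python
import os

def get_port(default: int = 7860) -> int:
    # Spaces injects PORT; fall back to Gradio's default elsewhere.
    try:
        return int(os.environ.get("PORT", default))
    except ValueError:  # PORT set to a non-numeric value
        return default

print(get_port())  # -> 7860 unless the PORT env var overrides it
```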
 
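When HUB_REPO and HUB_FILENAME are set, the fallback path relies on `hf_hub_download`, which fetches the file into the local Hugging Face cache and returns its path. A usage sketch; the repo id and filename below are placeholders, not values from the commit:

```python
from huggingface_hub import hf_hub_download

# Placeholder repo/filename; configure via the HUB_REPO / HUB_FILENAME env vars.
path = hf_hub_download(repo_id="username/lstm-cnn-model",
                       filename="lstm_cnn_model.h5")
print(path)  # cached local path that can be passed to load_model()
```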