# app.py -- Spaces-ready robust version for lstm_cnn model
import os
import threading
import traceback

import numpy as np
import pandas as pd
import gradio as gr
from huggingface_hub import hf_hub_download

# Use Keras model loader (change if you have PyTorch)
from tensorflow.keras.models import load_model

# ---------------- Config ----------------
# Default model file name.
LOCAL_MODEL_FILE = os.environ.get("LOCAL_MODEL_FILE", "lstm_cnn_model.h5")

# If the model lives on the Hugging Face Hub, set these environment variables,
# e.g. HUB_REPO="your-username/your-repo-name"
#      HUB_FILENAME="lstm_cnn_model.h5"
HUB_REPO = os.environ.get("HUB_REPO", "")
HUB_FILENAME = os.environ.get("HUB_FILENAME", "")

# Alternatively, hard-code them here (if you prefer not to use env vars):
# HUB_REPO = "your-username/your-repo-name"  # replace with your actual repo
# HUB_FILENAME = "lstm_cnn_model.h5"  # replace with your actual file name
# ----------------------------------------

# Shared state written by the background loader thread and read by the
# Gradio callbacks.
MODEL = None
MODEL_READY = False
# Traceback text of the last real-model load failure, or None.
MODEL_LOAD_ERROR = None
# True when MODEL is the untrained demo network rather than a real model.
MODEL_IS_DUMMY = False


def create_dummy_model():
    """Build an untrained stand-in model for demos when no real model exists.

    The network takes input of shape (1, 10) and emits 4 softmax classes;
    both numbers are placeholders, not properties of any real model.
    """
    from tensorflow.keras.models import Sequential
    from tensorflow.keras.layers import Dense, LSTM, Conv1D, Input

    print("[model] Creating dummy model for demonstration...", flush=True)
    model = Sequential([
        Input(shape=(1, 10)),  # assumed input shape
        Conv1D(32, 3, activation='relu', padding='same'),
        LSTM(64, return_sequences=False),
        Dense(32, activation='relu'),
        Dense(4, activation='softmax'),  # assumed 4 classes
    ])
    model.compile(
        optimizer='adam',
        loss='categorical_crossentropy',
        metrics=['accuracy'],
    )
    return model


def download_from_hub(repo: str, filename: str):
    """Download *filename* from Hub repo *repo*.

    Returns the local path on success, or None if the download failed.
    Failures are logged rather than raised (best effort) so the caller can
    fall back to the dummy model.
    """
    try:
        print(f"[model] Downloading {filename} from {repo} ...", flush=True)
        path = hf_hub_download(repo_id=repo, filename=filename)
        print("[model] Downloaded to:", path, flush=True)
        return path
    except Exception as e:
        print("[model] Hub download failed:", e, flush=True)
        return None


def load_model_background():
    """Load the model (local file, then Hub, then dummy) and publish it.

    Sets the module globals MODEL, MODEL_READY, MODEL_LOAD_ERROR and
    MODEL_IS_DUMMY. If loading the real model raises, the traceback is kept
    in MODEL_LOAD_ERROR even when the dummy fallback succeeds, so the UI can
    tell the user that predictions come from a demo model.
    """
    global MODEL, MODEL_READY, MODEL_LOAD_ERROR, MODEL_IS_DUMMY
    try:
        model_path = None
        # Prefer a local file.
        if os.path.exists(LOCAL_MODEL_FILE):
            model_path = LOCAL_MODEL_FILE
            print(f"[model] Found local model: {model_path}", flush=True)
        # Otherwise try to download from the Hub.
        elif HUB_REPO and HUB_FILENAME:
            model_path = download_from_hub(HUB_REPO, HUB_FILENAME)
        else:
            print("[model] No local model file and no HUB_REPO/HUB_FILENAME configured.", flush=True)
            print("[model] Using dummy model for demonstration purposes.", flush=True)

        if model_path and os.path.exists(model_path):
            print(f"[model] Loading model from {model_path} ...", flush=True)
            MODEL = load_model(model_path)
            MODEL_IS_DUMMY = False
            print("[model] Model loaded successfully!", flush=True)
        else:
            # No real model available: create a dummy one for the demo.
            print("[model] Creating dummy model since no real model is available.", flush=True)
            MODEL = create_dummy_model()
            MODEL_IS_DUMMY = True
            print("[model] Dummy model created. Note: This is for demo only!", flush=True)

        MODEL_READY = True
    except Exception:
        MODEL_LOAD_ERROR = traceback.format_exc()
        MODEL_READY = False
        print("[model] Error loading model:\n", MODEL_LOAD_ERROR, flush=True)
        # Even after a failure, try to fall back to the dummy model.
        try:
            print("[model] Attempting to create dummy model as fallback...", flush=True)
            MODEL = create_dummy_model()
            MODEL_IS_DUMMY = True
            # MODEL_LOAD_ERROR is deliberately NOT cleared: update_status
            # uses it to report that the real model failed to load.
            MODEL_READY = True
            print("[model] Dummy model created as fallback.", flush=True)
        except Exception as e2:
            print(f"[model] Failed to create dummy model: {e2}", flush=True)


# Start model loader in background so Gradio can bind to PORT immediately
loader = threading.Thread(target=load_model_background, daemon=True)
loader.start()


# ---------------- Helper functions ----------------
def prepare_input_array(arr, n_timesteps=1, n_features=None):
    """Reshape *arr* to the (batch, timesteps, features) layout.

    A 1-D input becomes a single sample; a 2-D input is treated as
    (batch, values). Input of any other rank is returned unchanged.
    When *n_features* is None it is inferred as values // n_timesteps.

    Raises:
        ValueError: if n_timesteps < 1 or the values per sample do not
            equal n_timesteps * n_features.
    """
    arr = np.array(arr)
    if arr.ndim not in (1, 2):
        return arr

    if n_timesteps < 1:
        raise ValueError(f"n_timesteps must be >= 1, got {n_timesteps}")

    batch = 1 if arr.ndim == 1 else arr.shape[0]
    width = arr.shape[-1]  # number of values per sample
    if n_features is None:
        n_features = width // n_timesteps
    if n_timesteps * n_features != width:
        raise ValueError(
            f"cannot split {width} values per sample into "
            f"{n_timesteps} timesteps x {n_features} features"
        )
    return arr.reshape(batch, n_timesteps, n_features)


def _model_unavailable_message():
    """Return a user-facing message if the model cannot serve yet, else None."""
    if MODEL_READY:
        return None
    if MODEL_LOAD_ERROR:
        return f"⚠️ 模型加载失败:\n{MODEL_LOAD_ERROR}"
    return "⏳ 模型正在加载中,请稍候..."


def _parse_feature_text(text):
    """Parse comma-separated numbers into a 1-D float array.

    Empty tokens (e.g. from a trailing comma) are skipped. Raises ValueError
    on any non-numeric token instead of silently truncating the input.
    """
    tokens = (tok.strip() for tok in (text or "").split(","))
    return np.array([float(tok) for tok in tokens if tok], dtype=float)


def predict_text(text, n_timesteps=1, n_features=None):
    """Predict the class for one sample given as comma-separated numbers.

    Returns a human-readable result or error message; never raises.
    """
    blocked = _model_unavailable_message()
    if blocked is not None:
        return blocked

    try:
        # Parse the comma-separated numbers.
        try:
            arr = _parse_feature_text(text)
        except ValueError:
            return "❌ 输入无效:请提供逗号分隔的数字"
        if arr.size == 0:
            return "❌ 输入无效:请提供逗号分隔的数字"

        # Prepare the model input.
        x = prepare_input_array(
            arr,
            n_timesteps=int(n_timesteps),
            n_features=(int(n_features) if n_features else None),
        )

        # Predict.
        probs = MODEL.predict(x, verbose=0)
        label = int(np.argmax(probs, axis=1)[0])
        confidence = float(np.max(probs))

        return f"✅ 预测类别: {label}\n置信度: {confidence:.2%}"
    except Exception as e:
        return f"❌ 预测失败: {str(e)}"


def predict_csv(file, n_timesteps=1, n_features=None):
    """Run batch prediction over an uploaded CSV (one sample per row).

    *file* may be a path string or an object exposing the path as ``.name``
    (which of the two Gradio passes depends on its version/configuration).

    Returns ``(result_dataframe_or_None, message)``; never raises.
    """
    blocked = _model_unavailable_message()
    if blocked is not None:
        return None, blocked

    if file is None:
        return None, "❌ 请先上传 CSV 文件"

    try:
        # Read the CSV.
        csv_path = file if isinstance(file, (str, os.PathLike)) else file.name
        df = pd.read_csv(csv_path)

        # Prepare the model input (n_features is inferred when not given).
        X = prepare_input_array(
            df.values,
            n_timesteps=int(n_timesteps),
            n_features=(int(n_features) if n_features else None),
        )

        # Batch prediction.
        preds = MODEL.predict(X, verbose=0)
        labels = preds.argmax(axis=1).tolist()

        # Build the result table.
        result_df = pd.DataFrame({
            'Sample': range(len(labels)),
            'Predicted_Label': labels,
            'Confidence': preds.max(axis=1).astype(float),
        })

        # Add the per-class probabilities.
        for i in range(preds.shape[1]):
            result_df[f'Prob_Class_{i}'] = preds[:, i]

        return result_df, f"✅ 成功预测 {len(labels)} 个样本"
    except Exception as e:
        return None, f"❌ 预测失败: {str(e)}"


# ---------------- Gradio UI ----------------
with gr.Blocks(title="CNN-LSTM 故障分类") as demo:
    gr.Markdown("# 🤖 CNN-LSTM 故障分类系统")

    # Model status display.
    with gr.Row():
        with gr.Column():
            status_text = gr.Markdown("⏳ 模型状态检查中...")

    gr.Markdown("---")

    # Input tabs.
    with gr.Tabs():
        with gr.Tab("📝 单样本预测"):
            text_in = gr.Textbox(
                lines=3,
                placeholder="输入逗号分隔的特征值,例如:\n0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0",
                label="特征输入"
            )
            with gr.Row():
                n_ts_text = gr.Number(value=1, label="时间步数 (timesteps)", precision=0)
                n_feat_text = gr.Number(value=10, label="特征数 (features)", precision=0)
            predict_btn = gr.Button("🔍 预测", variant="primary")
            out_text = gr.Textbox(label="预测结果", lines=3)

        with gr.Tab("📊 批量预测"):
            file_in = gr.File(
                label="上传 CSV 文件",
                file_types=[".csv"],
                file_count="single"
            )
            gr.Markdown("**CSV 格式说明:** 每行代表一个样本,列为特征值")
            with gr.Row():
                n_ts_csv = gr.Number(value=1, label="时间步数 (timesteps)", precision=0)
                n_feat_csv = gr.Number(value=10, label="特征数 (features,可选)", precision=0)
            predict_csv_btn = gr.Button("🔍 批量预测", variant="primary")
            out_csv_msg = gr.Textbox(label="处理消息", lines=2)
            out_csv_df = gr.Dataframe(label="预测结果表格", interactive=False)

    gr.Markdown("---")

    # Usage instructions.
    with gr.Accordion("📖 使用说明", open=False):
        gr.Markdown("""
        ### 输入格式
        - **单样本**:输入逗号分隔的数值特征
        - **批量预测**:上传CSV文件,每行一个样本

        ### 参数说明
        - **时间步数 (timesteps)**:序列的时间步长度
        - **特征数 (features)**:每个时间步的特征维度
        - 输入总长度应该等于 `timesteps × features`

        ### 示例
        如果您有10个特征,1个时间步:
        - 输入:`0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0`
        - timesteps = 1, features = 10
        """)

    # Status callback. NOTE: it is only invoked on page load (see demo.load
    # below), so a page opened while the model is still loading keeps showing
    # the loading state until it is reloaded.
    def update_status():
        """Return the Markdown status line for the current model state."""
        if MODEL_READY:
            if MODEL_LOAD_ERROR is not None:
                return "⚠️ 使用演示模型(真实模型加载失败)"
            if MODEL_IS_DUMMY:
                # No real model was found/configured; predictions are random.
                return "⚠️ 使用演示模型(未加载真实模型)"
            return "✅ 模型已就绪"
        elif MODEL_LOAD_ERROR:
            return "❌ 模型加载失败"
        else:
            return "⏳ 模型加载中..."

    # Event bindings.
    predict_btn.click(
        fn=predict_text,
        inputs=[text_in, n_ts_text, n_feat_text],
        outputs=out_text
    )

    predict_csv_btn.click(
        fn=predict_csv,
        inputs=[file_in, n_ts_csv, n_feat_csv],
        outputs=[out_csv_df, out_csv_msg]
    )

    # Refresh the status when the page loads.
    demo.load(fn=update_status, outputs=status_text)


# ---------------- Launch (Spaces-friendly) ----------------
def get_port():
    """Return the port from the PORT environment variable, defaulting to 7860."""
    try:
        return int(os.environ.get("PORT", 7860))
    except (TypeError, ValueError):
        # PORT was set to something that is not an integer.
        return 7860


if __name__ == "__main__":
    port = get_port()
    print(f"[app] Starting Gradio on 0.0.0.0:{port}", flush=True)

    # Recent Gradio versions no longer need the enable_queue argument;
    # queueing is enabled by default.
    demo.launch(
        server_name="0.0.0.0",
        server_port=port,
        show_error=True
    )