VincentCroft's picture
Update app.py
85e194e verified
raw
history blame
11.3 kB
# app.py -- Spaces-ready robust version for lstm_cnn model
import os
import threading
import traceback
import numpy as np
import pandas as pd
import gradio as gr
from huggingface_hub import hf_hub_download
# Use Keras model loader (change if you have PyTorch)
from tensorflow.keras.models import load_model
# ---------------- Config ----------------
# Default model file name (override with the LOCAL_MODEL_FILE env var)
LOCAL_MODEL_FILE = os.environ.get("LOCAL_MODEL_FILE", "lstm_cnn_model.h5")
# If your model lives on the Hugging Face Hub, set these environment variables.
# Example: HUB_REPO="your-username/your-repo-name"
#          HUB_FILENAME="lstm_cnn_model.h5"
HUB_REPO = os.environ.get("HUB_REPO", "")
HUB_FILENAME = os.environ.get("HUB_FILENAME", "")
# Alternatively, hard-code the values here (if you prefer not to use env vars):
# HUB_REPO = "your-username/your-repo-name"   # replace with your actual repo
# HUB_FILENAME = "lstm_cnn_model.h5"          # replace with your actual file name
# ----------------------------------------
MODEL = None
MODEL_READY = False
MODEL_LOAD_ERROR = None
def create_dummy_model():
    """Build a small stand-in Keras model for demos when the real model is unavailable.

    Mirrors the expected lstm_cnn interface: input of shape (1 timestep,
    10 features), softmax over 4 classes. For demonstration only.
    """
    from tensorflow.keras.models import Sequential
    from tensorflow.keras.layers import Dense, LSTM, Conv1D, Flatten, Input

    print("[model] Creating dummy model for demonstration...", flush=True)
    layer_stack = [
        Input(shape=(1, 10)),  # assumed input shape — TODO confirm against real model
        Conv1D(32, 3, activation='relu', padding='same'),
        LSTM(64, return_sequences=False),
        Dense(32, activation='relu'),
        Dense(4, activation='softmax'),  # assumed 4 output classes
    ]
    dummy = Sequential(layer_stack)
    dummy.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
    return dummy
def download_from_hub(repo: str, filename: str):
    """Download `filename` from the Hugging Face Hub repository `repo`.

    Returns the local cache path on success, or None on any failure.
    Best-effort: errors are logged rather than raised so the caller can
    fall back to the dummy model.
    """
    try:
        # Fix: the log line previously printed the literal "(unknown)"
        # instead of the file actually being fetched.
        print(f"[model] Downloading {filename} from {repo} ...", flush=True)
        path = hf_hub_download(repo_id=repo, filename=filename)
        print("[model] Downloaded to:", path, flush=True)
        return path
    except Exception as e:  # boundary: any hub/network error just means "no model"
        print("[model] Hub download failed:", e, flush=True)
        return None
def load_model_background():
    """Resolve and load the Keras model, with a dummy-model fallback.

    Lookup order: local file -> Hugging Face Hub download -> freshly built
    dummy model. Intended to run in a daemon thread so the web server can
    bind its port immediately; results are published through the
    module-level globals MODEL / MODEL_READY / MODEL_LOAD_ERROR.
    """
    global MODEL, MODEL_READY, MODEL_LOAD_ERROR
    try:
        model_path = None
        # Prefer a model file that already exists locally.
        if os.path.exists(LOCAL_MODEL_FILE):
            model_path = LOCAL_MODEL_FILE
            print(f"[model] Found local model: {model_path}", flush=True)
        # Otherwise try to download it from the Hub, if configured.
        elif HUB_REPO and HUB_FILENAME:
            model_path = download_from_hub(HUB_REPO, HUB_FILENAME)
        else:
            print("[model] No local model file and no HUB_REPO/HUB_FILENAME configured.", flush=True)
            print("[model] Using dummy model for demonstration purposes.", flush=True)
        if model_path and os.path.exists(model_path):
            print(f"[model] Loading model from {model_path} ...", flush=True)
            MODEL = load_model(model_path)
            print("[model] Model loaded successfully!", flush=True)
        else:
            # No real model available: build a dummy one so the demo still works.
            print("[model] Creating dummy model since no real model is available.", flush=True)
            MODEL = create_dummy_model()
            print("[model] Dummy model created. Note: This is for demo only!", flush=True)
        MODEL_READY = True
    except Exception as e:
        MODEL_LOAD_ERROR = traceback.format_exc()
        MODEL_READY = False
        print("[model] Error loading model:\n", MODEL_LOAD_ERROR, flush=True)
        # Even after a load failure, try to fall back to the dummy model.
        try:
            print("[model] Attempting to create dummy model as fallback...", flush=True)
            MODEL = create_dummy_model()
            MODEL_READY = True
            MODEL_LOAD_ERROR = None
            print("[model] Dummy model created as fallback.", flush=True)
        except Exception as e2:
            print(f"[model] Failed to create dummy model: {e2}", flush=True)
# Start the model loader in the background so Gradio can bind to PORT immediately
# (Spaces health-checks the port; a slow model load must not block startup).
loader = threading.Thread(target=load_model_background, daemon=True)
loader.start()
# ---------------- Helper functions ----------------
def prepare_input_array(arr, n_timesteps=1, n_features=None):
    """Reshape `arr` into the (batch, timesteps, features) layout the model expects.

    - 1-D input becomes a single sample: (1, n_timesteps, n_features).
    - 2-D input is treated as (batch, flat_features) and reshaped row-wise.
    - 3-D (or higher) input is returned unchanged.

    If n_features is None it is inferred as total_features // n_timesteps.

    Raises:
        ValueError: if the flat feature count does not divide evenly into
            the requested (timesteps, features) grid. Previously this
            surfaced as an opaque numpy reshape error (or a silently
            truncated inference when the division had a remainder).
    """
    arr = np.asarray(arr)
    # Already batched with explicit timesteps: pass through untouched.
    if arr.ndim >= 3:
        return arr

    if arr.ndim == 1:
        batch, total = 1, arr.size
    else:  # 2-D: first axis is the batch dimension
        batch, total = arr.shape[0], arr.shape[1]

    if n_features is None:
        if total % n_timesteps != 0:
            raise ValueError(
                f"flat feature count {total} is not divisible by n_timesteps={n_timesteps}"
            )
        n_features = total // n_timesteps
    if n_timesteps * n_features != total:
        raise ValueError(
            f"flat feature count {total} != n_timesteps*n_features "
            f"({n_timesteps}*{n_features}={n_timesteps * n_features})"
        )
    return arr.reshape(batch, n_timesteps, n_features)
def predict_text(text, n_timesteps=1, n_features=None):
    """Predict the class of a single comma-separated feature vector.

    Args:
        text: comma-separated numeric feature values.
        n_timesteps: sequence length to reshape into.
        n_features: features per timestep; inferred when falsy.

    Returns:
        A human-readable result string (user-facing text is Chinese).
    """
    if not MODEL_READY:
        if MODEL_LOAD_ERROR:
            return f"⚠️ 模型加载失败:\n{MODEL_LOAD_ERROR}"
        return "⏳ 模型正在加载中,请稍候..."
    try:
        # Parse the numbers explicitly. The previous np.fromstring(...) call
        # is deprecated and silently truncates/mangles malformed input;
        # float() raises instead, which the except below turns into a
        # clear failure message.
        tokens = [t for t in (s.strip() for s in text.strip().split(',')) if t]
        arr = np.array([float(t) for t in tokens], dtype=float)
        if len(arr) == 0:
            return "❌ 输入无效:请提供逗号分隔的数字"
        # Reshape to the (batch, timesteps, features) layout.
        x = prepare_input_array(arr, n_timesteps=int(n_timesteps),
                                n_features=(int(n_features) if n_features else None))
        # Single-sample inference.
        probs = MODEL.predict(x, verbose=0)
        label = int(np.argmax(probs, axis=1)[0])
        confidence = float(np.max(probs))
        return f"✅ 预测类别: {label}\n置信度: {confidence:.2%}"
    except Exception as e:
        return f"❌ 预测失败: {str(e)}"
def predict_csv(file, n_timesteps=1, n_features=None):
    """Batch-predict every row of an uploaded CSV file.

    Each CSV row is one sample of flat features, reshaped to
    (timesteps, features) before inference.

    Returns:
        (DataFrame of per-sample predictions, status message) on success,
        (None, error message) on failure.
    """
    if not MODEL_READY:
        if MODEL_LOAD_ERROR:
            return None, f"⚠️ 模型加载失败:\n{MODEL_LOAD_ERROR}"
        return None, "⏳ 模型正在加载中,请稍候..."
    try:
        # Fix: newer Gradio versions pass the upload as a plain path string,
        # older ones pass a tempfile-like object with a .name attribute.
        # The original `file.name` crashed on the string form.
        path = file if isinstance(file, str) else file.name
        df = pd.read_csv(path)
        X = df.values
        n_samples = X.shape[0]
        n_timesteps = int(n_timesteps)
        if n_features:
            n_feat = int(n_features)
        else:
            # Infer features-per-timestep from the flat column count.
            n_feat = X.shape[1] // n_timesteps
        X = X.reshape(n_samples, n_timesteps, n_feat)
        # Batch inference.
        preds = MODEL.predict(X, verbose=0)
        labels = preds.argmax(axis=1).tolist()
        result_df = pd.DataFrame({
            'Sample': range(len(labels)),
            'Predicted_Label': labels,
            'Confidence': [float(np.max(p)) for p in preds]
        })
        # Append one probability column per class.
        for i in range(preds.shape[1]):
            result_df[f'Prob_Class_{i}'] = preds[:, i]
        return result_df, f"✅ 成功预测 {len(labels)} 个样本"
    except Exception as e:
        return None, f"❌ 预测失败: {str(e)}"
# ---------------- Gradio UI ----------------
with gr.Blocks(title="CNN-LSTM 故障分类") as demo:
    gr.Markdown("# 🤖 CNN-LSTM 故障分类系统")
    # Model-status banner; refreshed on page load via demo.load below.
    with gr.Row():
        with gr.Column():
            status_text = gr.Markdown("⏳ 模型状态检查中...")
    gr.Markdown("---")
    # Two input modes: single-sample text entry and batch CSV upload.
    with gr.Tabs():
        with gr.Tab("📝 单样本预测"):
            text_in = gr.Textbox(
                lines=3,
                placeholder="输入逗号分隔的特征值,例如:\n0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0",
                label="特征输入"
            )
            with gr.Row():
                n_ts_text = gr.Number(value=1, label="时间步数 (timesteps)", precision=0)
                n_feat_text = gr.Number(value=10, label="特征数 (features)", precision=0)
            predict_btn = gr.Button("🔍 预测", variant="primary")
            out_text = gr.Textbox(label="预测结果", lines=3)
        with gr.Tab("📊 批量预测"):
            file_in = gr.File(
                label="上传 CSV 文件",
                file_types=[".csv"],
                file_count="single"
            )
            gr.Markdown("**CSV 格式说明:** 每行代表一个样本,列为特征值")
            with gr.Row():
                n_ts_csv = gr.Number(value=1, label="时间步数 (timesteps)", precision=0)
                n_feat_csv = gr.Number(value=10, label="特征数 (features,可选)", precision=0)
            predict_csv_btn = gr.Button("🔍 批量预测", variant="primary")
            out_csv_msg = gr.Textbox(label="处理消息", lines=2)
            out_csv_df = gr.Dataframe(label="预测结果表格", interactive=False)
    gr.Markdown("---")
    # Collapsible usage notes.
    with gr.Accordion("📖 使用说明", open=False):
        gr.Markdown("""
### 输入格式
- **单样本**:输入逗号分隔的数值特征
- **批量预测**:上传CSV文件,每行一个样本
### 参数说明
- **时间步数 (timesteps)**:序列的时间步长度
- **特征数 (features)**:每个时间步的特征维度
- 输入总长度应该等于 `timesteps × features`
### 示例
如果您有10个特征,1个时间步:
- 输入:`0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0`
- timesteps = 1, features = 10
""")
    # Map the loader globals to the status-banner text.
    # NOTE(review): only invoked once on page load, not periodically.
    def update_status():
        if MODEL_READY:
            if MODEL_LOAD_ERROR is None:
                return "✅ 模型已就绪"
            else:
                return "⚠️ 使用演示模型(真实模型加载失败)"
        elif MODEL_LOAD_ERROR:
            return f"❌ 模型加载失败"
        else:
            return "⏳ 模型加载中..."
    # Wire the buttons to their prediction handlers.
    predict_btn.click(
        fn=predict_text,
        inputs=[text_in, n_ts_text, n_feat_text],
        outputs=out_text
    )
    predict_csv_btn.click(
        fn=predict_csv,
        inputs=[file_in, n_ts_csv, n_feat_csv],
        outputs=[out_csv_df, out_csv_msg]
    )
    # Refresh the status banner when the page loads.
    demo.load(fn=update_status, outputs=status_text)
# ---------------- Launch (Spaces-friendly) ----------------
def get_port():
    """Return the TCP port to serve on.

    Reads the PORT environment variable (set by hosting platforms such as
    Spaces); falls back to Gradio's default 7860 when it is unset or not
    a valid integer.
    """
    try:
        return int(os.environ.get("PORT", 7860))
    except (TypeError, ValueError):
        # Fix: was a bare `except:`, which also swallowed SystemExit /
        # KeyboardInterrupt; only conversion failures belong here.
        return 7860
if __name__ == "__main__":
    port = get_port()
    print(f"[app] Starting Gradio on 0.0.0.0:{port}", flush=True)
    # Recent Gradio versions no longer accept an enable_queue argument;
    # queueing is enabled by default.
    demo.launch(
        server_name="0.0.0.0",  # bind all interfaces so the Spaces proxy can reach us
        server_port=port,
        show_error=True
    )