# FaultDetectionDeepLearning / lstm_cnn_app.py
# VincentCroft's picture
# Upload 3 files
# ea59ebe verified
# raw
# history blame
# 3.78 kB
"""
lstm_cnn_app.py
Gradio app to serve the CNN-LSTM fault classification model.
Usage:
- Place a local model file named by LOCAL_MODEL_FILE in the same repo, or
- Set HUB_REPO and HUB_FILENAME to a public Hugging Face model repo + filename,
and the app will download it at startup using hf_hub_download.
"""
import os
import numpy as np
import pandas as pd
import gradio as gr
from tensorflow.keras.models import load_model
from huggingface_hub import hf_hub_download
# CONFIG: change these if your model filename/repo are different.
# Local model file; get_model_path() returns this path when it exists on disk.
LOCAL_MODEL_FILE = "lstm_cnn_model.h5"
# Hugging Face Hub fallback; only used when both values below are non-empty.
HUB_REPO = "" # e.g., "username/lstm-cnn-model"
HUB_FILENAME = "" # e.g., "lstm_cnn_model.h5"
def get_model_path():
    """Locate the model file to load.

    The local file named by ``LOCAL_MODEL_FILE`` wins when it exists.
    Otherwise, if both ``HUB_REPO`` and ``HUB_FILENAME`` are set, the file is
    fetched with ``hf_hub_download``.

    Returns:
        The path of the model file, or ``None`` when there is no local file
        and the hub download is either not configured or fails.
    """
    local_found = os.path.exists(LOCAL_MODEL_FILE)
    if local_found:
        return LOCAL_MODEL_FILE

    hub_configured = HUB_REPO and HUB_FILENAME
    if not hub_configured:
        return None

    try:
        print(f"Downloading {HUB_FILENAME} from {HUB_REPO} ...")
        downloaded = hf_hub_download(repo_id=HUB_REPO, filename=HUB_FILENAME)
    except Exception as err:
        # Best effort: report the failure and let the caller see "no model".
        print("Failed to download from hub:", err)
        return None
    return downloaded
# Resolve and load the model once at import time. MODEL stays None when no
# model file is available or loading fails; the predict functions check this.
MODEL = None
MODEL_PATH = get_model_path()
if not MODEL_PATH:
    print("No model found. Please upload a model named", LOCAL_MODEL_FILE, "or set HUB_REPO/HUB_FILENAME.")
else:
    try:
        MODEL = load_model(MODEL_PATH)
        print("Loaded model:", MODEL_PATH)
    except Exception as load_err:
        # Keep the app running without a model rather than failing at import.
        print("Failed to load model:", load_err)
        MODEL = None
def prepare_input_array(arr, n_timesteps=1, n_features=None):
    """Convert raw feature data into a NumPy array for the model.

    A one-dimensional input is treated as a single sample and reshaped to
    ``(1, n_timesteps, n_features)``. When ``n_features`` is ``None`` the last
    dimension is inferred from the amount of data. Input with any other
    number of dimensions is converted to an array and returned without
    reshaping.

    Args:
        arr: Array-like feature data.
        n_timesteps: Size of the second axis for one-dimensional input.
        n_features: Size of the last axis for one-dimensional input, or
            ``None`` to infer it.

    Returns:
        The resulting ``numpy.ndarray``.
    """
    data = np.array(arr)
    if data.ndim != 1:
        # Anything that is not a flat vector is passed through untouched.
        return data
    # -1 asks NumPy to work out the last dimension from the element count.
    last_dim = -1 if n_features is None else n_features
    return data.reshape(1, n_timesteps, last_dim)
def predict_text(text, n_timesteps=1, n_features=None):
    """Classify a single sample given as comma-separated feature text.

    The text is parsed strictly: every comma-separated token must be a valid
    number. Malformed input produces a user-facing message instead of being
    silently truncated or crashing the UI callback.

    Args:
        text: Comma-separated numeric features for one sample.
        n_timesteps: Number of timesteps to reshape into; ``None`` means 1.
        n_features: Features per timestep; a falsy value means "infer it".

    Returns:
        A message string: the predicted class index and its probability, or
        a description of why no prediction could be made.
    """
    if MODEL is None:
        return "模型未加载,请上传或配置模型。"

    # Tolerate surrounding whitespace and a trailing comma, but nothing else:
    # an unparsable token must not be dropped, since that would shift or
    # shorten the feature vector and yield a meaningless prediction.
    cleaned = str(text).strip().strip(",")
    if not cleaned:
        return "输入格式错误: 请提供逗号分隔的数值特征。"
    try:
        values = [float(token) for token in cleaned.split(",")]
    except ValueError:
        return "输入格式错误: 请提供逗号分隔的数值特征。"

    # The UI number fields yield None when cleared.
    steps = 1 if n_timesteps is None else int(n_timesteps)
    feats = int(n_features) if n_features else None
    try:
        x = prepare_input_array(np.array(values), n_timesteps=steps, n_features=feats)
    except ValueError as err:
        # The number of values does not fit the requested (timesteps, features).
        return f"输入形状错误: {err}"

    probs = MODEL.predict(x)
    label = int(np.argmax(probs, axis=1)[0])
    return f"预测类别: {label} (概率: {float(np.max(probs)):.4f})"
def predict_csv(file, n_timesteps=1, n_features=None):
    """Classify every row of an uploaded CSV file.

    Each row is one sample. Rows are always reshaped to
    ``(n_samples, n_timesteps, n_features)``; when ``n_features`` is not
    given the last dimension is inferred, matching what ``predict_text``
    does for a single sample.

    Args:
        file: Uploaded file object exposing a ``.name`` path, or a plain
            path string.
        n_timesteps: Number of timesteps per sample; ``None`` means 1.
        n_features: Features per timestep; a falsy value means "infer it".

    Returns:
        ``{"labels": [...], "probs": [...]}`` on success, or
        ``{"error": message}`` when no prediction could be made.
    """
    if MODEL is None:
        return {"error": "模型未加载,请上传或配置模型。"}

    # Accept both an upload object with ``.name`` and a bare file path.
    csv_path = getattr(file, "name", file)
    df = pd.read_csv(csv_path)
    if df.empty:
        return {"error": "CSV 文件中没有数据行。"}
    try:
        # Force a float matrix so non-numeric cells fail here with a clear
        # message rather than inside the model as an object-dtype array.
        X = df.to_numpy(dtype=float)
    except (TypeError, ValueError):
        return {"error": "CSV 包含非数值内容,无法转换为特征。"}

    # The UI number fields yield None when cleared.
    steps = 1 if n_timesteps is None else int(n_timesteps)
    # -1 lets NumPy infer the feature dimension from the column count.
    feats = int(n_features) if n_features else -1
    try:
        X = X.reshape(X.shape[0], steps, feats)
    except ValueError as err:
        return {"error": f"输入形状错误: {err}"}

    preds = MODEL.predict(X)
    labels = preds.argmax(axis=1).tolist()
    return {"labels": labels, "probs": preds.tolist()}
# Build the Gradio UI: inputs for a CSV upload or pasted feature text, shape
# parameters, and two outputs (status/single-sample text and batch JSON).
with gr.Blocks() as demo:
    gr.Markdown("# CNN-LSTM Fault Classification")
    gr.Markdown("上传 CSV(每行一个样本)或粘贴逗号分隔的一行特征进行预测。")
    # NOTE(review): the original indentation was lost; it is assumed that only
    # the two input widgets share this row. Confirm against the deployed layout.
    with gr.Row():
        file_in = gr.File(label="上传 CSV(每行 = 一个样本)")
        text_in = gr.Textbox(lines=2, placeholder="粘贴逗号分隔的一行特征,例如: 0.1,0.2,0.3,...")
    n_ts = gr.Number(value=1, label="timesteps (整型)")
    n_feat = gr.Number(value=None, label="features (可选,留空尝试自动推断)")
    btn = gr.Button("预测")
    out_text = gr.Textbox(label="单样本预测输出")
    out_json = gr.JSON(label="批量预测结果 (labels & probs)")

    def run_predict(file, text, n_timesteps, n_features):
        """Dispatch a button click to CSV or text prediction.

        A provided file takes precedence over pasted text.

        Returns:
            A ``(status_text, json_payload)`` pair for the two output widgets.
        """
        if file is not None:
            result = predict_csv(file, n_timesteps, n_features)
            # predict_csv reports failure through an "error" key; do not
            # announce success in the status box when that happened.
            if isinstance(result, dict) and "error" in result:
                return "CSV 预测失败", result
            return "CSV 预测完成", result
        if text:
            return predict_text(text, n_timesteps, n_features), {}
        return "请提供 CSV 或特征文本", {}

    btn.click(run_predict, inputs=[file_in, text_in, n_ts, n_feat], outputs=[out_text, out_json])
if __name__ == "__main__":
    # Start the web server only when run as a script; 0.0.0.0 listens on all
    # network interfaces, on port 7861.
    demo.launch(server_port=7861, server_name="0.0.0.0")