File size: 3,784 Bytes
ea59ebe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
"""
lstm_cnn_app.py
Gradio app to serve the CNN-LSTM fault classification model.

Usage:
- Place a local model file named by LOCAL_MODEL_FILE in the same repo, or
- Set HUB_REPO and HUB_FILENAME to a public Hugging Face model repo + filename,
  and the app will download it at startup using hf_hub_download.
"""
import os
import numpy as np
import pandas as pd
import gradio as gr
from tensorflow.keras.models import load_model
from huggingface_hub import hf_hub_download

# CONFIG: change these defaults if your model filename/repo are different, or
# override them without editing this file by setting environment variables of
# the same name (e.g. HUB_REPO=username/lstm-cnn-model).
LOCAL_MODEL_FILE = os.environ.get("LOCAL_MODEL_FILE", "lstm_cnn_model.h5")
HUB_REPO = os.environ.get("HUB_REPO", "")          # e.g., "username/lstm-cnn-model"
HUB_FILENAME = os.environ.get("HUB_FILENAME", "")  # e.g., "lstm_cnn_model.h5"

def get_model_path():
    """Return a filesystem path to the model file, or None if unavailable.

    Lookup order:
      1. ``LOCAL_MODEL_FILE`` when that file exists on disk.
      2. ``HUB_FILENAME`` downloaded from the Hugging Face repo ``HUB_REPO``
         via ``hf_hub_download`` (only attempted when both are non-empty).

    A failed download is reported on stdout and treated as "no model", so
    startup continues instead of crashing.
    """
    if os.path.exists(LOCAL_MODEL_FILE):
        return LOCAL_MODEL_FILE

    hub_configured = bool(HUB_REPO and HUB_FILENAME)
    if not hub_configured:
        return None

    try:
        print(f"Downloading {HUB_FILENAME} from {HUB_REPO} ...")
        return hf_hub_download(repo_id=HUB_REPO, filename=HUB_FILENAME)
    except Exception as download_err:
        # Best effort: any download failure falls back to "no model found".
        print("Failed to download from hub:", download_err)
        return None

# Resolve and load the model once, at import time. MODEL stays None when no
# model file could be located or loading it failed; the app keeps running and
# the prediction helpers report that state instead of predicting.
MODEL_PATH = get_model_path()
MODEL = None
if not MODEL_PATH:
    print("No model found. Please upload a model named", LOCAL_MODEL_FILE, "or set HUB_REPO/HUB_FILENAME.")
else:
    try:
        MODEL = load_model(MODEL_PATH)
        print("Loaded model:", MODEL_PATH)
    except Exception as load_err:
        # Any load failure (corrupt file, incompatible format, ...) leaves the
        # app model-less rather than aborting startup.
        print("Failed to load model:", load_err)
        MODEL = None

def prepare_input_array(arr, n_timesteps=1, n_features=None):
    """Reshape a flat feature vector into a single-sample 3-D batch.

    A 1-D input is treated as one sample and reshaped to
    ``(1, n_timesteps, n_features)``. When ``n_features`` is None, the last
    axis is inferred from the input length. Inputs with any other rank
    (0-D, 2-D, 3-D, ...) are returned unchanged, as a NumPy array copy.

    Args:
        arr: Array-like of feature values.
        n_timesteps: Number of timesteps. Integral floats such as ``2.0``
            (what ``gr.Number`` produces) are accepted and converted to int.
        n_features: Features per timestep, or None to infer it. Integral
            floats are accepted as well.

    Returns:
        numpy.ndarray: The reshaped array for 1-D input, otherwise the input
        converted to an array.

    Raises:
        ValueError: If a 1-D input's length does not fit the requested
            ``n_timesteps`` / ``n_features``.
    """
    arr = np.array(arr)
    if arr.ndim != 1:
        # NOTE(review): 2-D input is passed through untouched; it is NOT
        # reshaped to (samples, timesteps, features) here.
        return arr
    # reshape() rejects float dimensions, so normalise values coming from UI widgets.
    feature_dim = -1 if n_features is None else int(n_features)
    return arr.reshape(1, int(n_timesteps), feature_dim)

def predict_text(text, n_timesteps=1, n_features=None):
    """Classify one sample entered as comma-separated feature values.

    Args:
        text: Feature values separated by commas, e.g. ``"0.1,0.2,0.3"``.
            Surrounding whitespace and a single trailing comma are tolerated.
        n_timesteps: Timesteps to reshape the sample into. None (a cleared
            ``gr.Number`` field) falls back to 1.
        n_features: Features per timestep. A falsy value (None/0) lets the
            feature axis be inferred from the number of values.

    Returns:
        str: A human-readable prediction, or a message explaining why no
        prediction was made. That happens when the model is not loaded, the
        input is empty or unparseable, or the value count does not fit the
        requested shape.
    """
    if MODEL is None:
        return "模型未加载,请上传或配置模型。"

    # Parse explicitly rather than with np.fromstring(text, sep=','), which
    # stops at the first malformed token and silently returns only the values
    # before it (and yields an empty array for empty input).
    raw = (text or "").strip().rstrip(",")
    if not raw:
        return "未提供任何特征值。"
    try:
        values = np.array([float(tok) for tok in raw.split(",")], dtype=float)
    except ValueError as err:
        return f"无法解析特征值: {err}"

    timesteps = 1 if n_timesteps is None else int(n_timesteps)
    features = int(n_features) if n_features else None
    try:
        x = prepare_input_array(values, n_timesteps=timesteps, n_features=features)
    except ValueError as err:
        # The number of values is not divisible into the requested timesteps/features.
        return f"输入长度 ({values.size}) 与 timesteps/features 不匹配: {err}"

    probs = MODEL.predict(x)
    label = int(np.argmax(probs, axis=1)[0])
    return f"预测类别: {label} (概率: {float(np.max(probs)):.4f})"

def predict_csv(file, n_timesteps=1, n_features=None):
    """Classify every row of an uploaded CSV file as one sample.

    Args:
        file: The value delivered by the ``gr.File`` input. This is either a
            plain path (str / ``os.PathLike``, Gradio 4.x ``type="filepath"``)
            or a tempfile-like object exposing ``.name`` (Gradio 3.x).
        n_timesteps: Timesteps per sample. None falls back to 1.
        n_features: Features per timestep. A falsy value (None/0) infers the
            feature axis from the column count, as the single-sample text
            path does.

    Returns:
        dict: ``{"labels": [...], "probs": [[...], ...]}`` on success, or
        ``{"error": message}`` when the model is not loaded or the table
        cannot be reshaped to the requested layout.
    """
    if MODEL is None:
        return {"error": "模型未加载,请上传或配置模型。"}

    # Gradio 4 hands over a path string; Gradio 3 a tempfile wrapper with .name.
    path = file if isinstance(file, (str, os.PathLike)) else file.name
    # NOTE(review): read_csv treats the first row as a header, so a header-less
    # CSV loses its first sample — confirm the expected file format.
    df = pd.read_csv(path)
    X = df.values

    timesteps = 1 if n_timesteps is None else int(n_timesteps)
    # Always build a 3-D (rows, timesteps, features) batch. With no explicit
    # feature count, infer it (-1) instead of feeding the raw 2-D table.
    features = int(n_features) if n_features else -1
    try:
        X = X.reshape(X.shape[0], timesteps, features)
    except ValueError as err:
        return {"error": f"CSV 列数 ({X.shape[1]}) 与 timesteps/features 不匹配: {err}"}

    preds = MODEL.predict(X)
    labels = preds.argmax(axis=1).tolist()
    return {"labels": labels, "probs": preds.tolist()}

# Build the Gradio UI. A CSV upload and a free-text box feed the same
# "predict" button, and both input paths share the timesteps/features controls.
# Components render in the order they are created below.
with gr.Blocks() as demo:
    gr.Markdown("# CNN-LSTM Fault Classification")
    gr.Markdown("上传 CSV(每行一个样本)或粘贴逗号分隔的一行特征进行预测。")
    with gr.Row():
        file_in = gr.File(label="上传 CSV(每行 = 一个样本)")
        text_in = gr.Textbox(lines=2, placeholder="粘贴逗号分隔的一行特征,例如: 0.1,0.2,0.3,...")
    # Reshape parameters, forwarded unchanged to predict_csv / predict_text.
    n_ts = gr.Number(value=1, label="timesteps (整型)")
    n_feat = gr.Number(value=None, label="features (可选,留空尝试自动推断)")
    btn = gr.Button("预测")
    # The text output receives the single-sample result (or a status line);
    # the JSON output receives the CSV batch result.
    out_text = gr.Textbox(label="单样本预测输出")
    out_json = gr.JSON(label="批量预测结果 (labels & probs)")

    def run_predict(file, text, n_timesteps, n_features):
        """Dispatch one button click to the CSV or the text prediction path.

        An uploaded file takes precedence over text. Returns a
        ``(status_text, json_payload)`` pair matching
        ``outputs=[out_text, out_json]``.
        """
        if file is not None:
            # NOTE(review): the status always reads "CSV prediction done", even
            # when predict_csv returns {"error": ...}; the error only shows in
            # the JSON panel.
            return "CSV 预测完成", predict_csv(file, n_timesteps, n_features)
        if text:
            return predict_text(text, n_timesteps, n_features), {}
        # Neither input was provided: ask for a CSV or feature text.
        return "请提供 CSV 或特征文本", {}

    btn.click(run_predict, inputs=[file_in, text_in, n_ts, n_feat], outputs=[out_text, out_json])

if __name__ == '__main__':
    # Defaults bind to all interfaces on port 7861. Gradio's own environment
    # variables GRADIO_SERVER_NAME / GRADIO_SERVER_PORT override them, e.g.
    # when a hosting platform requires a specific port.
    demo.launch(
        server_name=os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0"),
        server_port=int(os.environ.get("GRADIO_SERVER_PORT", "7861")),
    )