|
|
|
import os |
|
import threading |
|
import traceback |
|
import numpy as np |
|
import pandas as pd |
|
import gradio as gr |
|
from huggingface_hub import hf_hub_download |
|
|
|
|
|
from tensorflow.keras.models import load_model |
|
|
|
|
|
|
|
# Model file looked up on local disk first; overridable through the environment.
LOCAL_MODEL_FILE = os.getenv("LOCAL_MODEL_FILE", "lstm_cnn_model.h5")

# Hugging Face Hub coordinates, consulted only when the local file is missing.
# A download is attempted only when both values are non-empty.
HUB_REPO = os.getenv("HUB_REPO", "")
HUB_FILENAME = os.getenv("HUB_FILENAME", "")

# Shared state: written by the background loader, read by the UI handlers.
MODEL = None  # the loaded (or dummy) model once available
MODEL_READY = False  # flips to True when MODEL can serve predictions
MODEL_LOAD_ERROR = None  # formatted traceback of a load failure, if any
|
|
|
def create_dummy_model():
    """Build a stand-in model for demos when no real model is available.

    The network expects input of shape ``(1, 10)`` per sample and ends in a
    4-unit softmax layer. It is compiled but never trained here, so its
    predictions are not meaningful.

    Returns:
        The compiled Keras ``Sequential`` model.
    """
    # Imported lazily, only when a dummy model is actually requested.
    from tensorflow.keras.layers import Conv1D, Dense, Input, LSTM
    from tensorflow.keras.models import Sequential

    print("[model] Creating dummy model for demonstration...", flush=True)

    layer_stack = [
        Input(shape=(1, 10)),
        Conv1D(32, 3, activation="relu", padding="same"),
        LSTM(64, return_sequences=False),
        Dense(32, activation="relu"),
        Dense(4, activation="softmax"),
    ]
    dummy = Sequential(layer_stack)
    dummy.compile(
        optimizer="adam",
        loss="categorical_crossentropy",
        metrics=["accuracy"],
    )
    return dummy
|
|
|
def download_from_hub(repo: str, filename: str):
    """Fetch ``filename`` from the Hub repository ``repo``.

    This is best-effort: any exception is printed and swallowed.

    Args:
        repo: Repository id passed to ``hf_hub_download`` as ``repo_id``.
        filename: Name of the file inside that repository.

    Returns:
        The path returned by ``hf_hub_download``, or ``None`` if anything
        went wrong.
    """
    try:
        print(f"[model] Downloading {filename} from {repo} ...", flush=True)
        local_path = hf_hub_download(repo_id=repo, filename=filename)
        print("[model] Downloaded to:", local_path, flush=True)
    except Exception as exc:
        print("[model] Hub download failed:", exc, flush=True)
        return None
    return local_path
|
|
|
def load_model_background():
    """Resolve and load the model, publishing the result in module globals.

    Resolution order:
      1. ``LOCAL_MODEL_FILE`` when it exists on disk;
      2. otherwise a Hub download when both ``HUB_REPO`` and ``HUB_FILENAME``
         are set;
      3. otherwise, or when the resolved path does not exist, a dummy model
         from ``create_dummy_model``.

    If any of that raises, the formatted traceback is stored in
    ``MODEL_LOAD_ERROR`` and a dummy model is attempted as a fallback.

    Side effects:
        Writes ``MODEL``, ``MODEL_READY`` and ``MODEL_LOAD_ERROR``.
    """
    global MODEL, MODEL_READY, MODEL_LOAD_ERROR

    try:
        model_path = None

        if os.path.exists(LOCAL_MODEL_FILE):
            model_path = LOCAL_MODEL_FILE
            print(f"[model] Found local model: {model_path}", flush=True)
        elif HUB_REPO and HUB_FILENAME:
            model_path = download_from_hub(HUB_REPO, HUB_FILENAME)
        else:
            print("[model] No local model file and no HUB_REPO/HUB_FILENAME configured.", flush=True)
            print("[model] Using dummy model for demonstration purposes.", flush=True)

        if model_path and os.path.exists(model_path):
            print(f"[model] Loading model from {model_path} ...", flush=True)
            MODEL = load_model(model_path)
            print("[model] Model loaded successfully!", flush=True)
        else:
            print("[model] Creating dummy model since no real model is available.", flush=True)
            MODEL = create_dummy_model()
            print("[model] Dummy model created. Note: This is for demo only!", flush=True)

        # Publish readiness only after MODEL is assigned.
        MODEL_READY = True

    except Exception:
        MODEL_LOAD_ERROR = traceback.format_exc()
        MODEL_READY = False
        print("[model] Error loading model:\n", MODEL_LOAD_ERROR, flush=True)

        try:
            print("[model] Attempting to create dummy model as fallback...", flush=True)
            MODEL = create_dummy_model()
            # MODEL_LOAD_ERROR is deliberately kept: "ready, but with a
            # recorded error" is how a fallback demo model is told apart
            # from a successfully loaded one.
            MODEL_READY = True
            print("[model] Dummy model created as fallback.", flush=True)
        except Exception as fallback_exc:
            print(f"[model] Failed to create dummy model: {fallback_exc}", flush=True)
|
|
|
|
|
# Load the model on a daemon thread so module import is not blocked by it.
loader = threading.Thread(
    target=load_model_background,
    daemon=True,
)
loader.start()
|
|
|
|
|
def prepare_input_array(arr, n_timesteps=1, n_features=None):
    """Reshape flat feature data to ``(samples, timesteps, features)``.

    Args:
        arr: Array-like input. 1-D input is treated as a single sample and
            2-D input as one sample per row. Input of any other rank is
            returned as an ndarray without reshaping or validation.
        n_timesteps: Timesteps per sample; must be at least 1.
        n_features: Features per timestep. When ``None`` it is derived as the
            per-sample length divided by ``n_timesteps``.

    Returns:
        An ndarray of shape ``(n_samples, n_timesteps, n_features)`` for 1-D
        or 2-D input; otherwise the input converted to an ndarray.

    Raises:
        ValueError: If ``n_timesteps`` is not positive, or the per-sample
            length does not equal ``n_timesteps * n_features``.
    """
    arr = np.array(arr)

    if arr.ndim not in (1, 2):
        return arr

    if n_timesteps < 1:
        raise ValueError(f"n_timesteps must be >= 1, got {n_timesteps}")

    # The last axis holds one sample's flattened values in both supported ranks.
    per_sample = arr.shape[-1]
    n_samples = 1 if arr.ndim == 1 else arr.shape[0]

    if n_features is None:
        # Fail loudly instead of letting floor division drop trailing values.
        if per_sample % n_timesteps:
            raise ValueError(
                f"cannot split {per_sample} values per sample evenly into "
                f"{n_timesteps} timesteps"
            )
        n_features = per_sample // n_timesteps
    elif n_timesteps * n_features != per_sample:
        raise ValueError(
            f"expected timesteps x features = {n_timesteps} x {n_features} = "
            f"{n_timesteps * n_features} values per sample, got {per_sample}"
        )

    return arr.reshape(n_samples, n_timesteps, n_features)
|
|
|
def _parse_feature_text(text):
    """Convert comma-separated text into a 1-D float array.

    Every token must be a valid number; a single trailing comma is tolerated.

    Raises:
        ValueError: If any token cannot be converted with ``float``.
    """
    tokens = [tok.strip() for tok in (text or "").split(",")]
    if tokens[-1] == "":
        # Covers both a trailing comma and completely empty input.
        tokens.pop()
    return np.array([float(tok) for tok in tokens], dtype=float)


def predict_text(text, n_timesteps=1, n_features=None):
    """Predict the class of one sample given as comma-separated numbers.

    Args:
        text: Comma-separated feature values for a single sample.
        n_timesteps: Timesteps to reshape the sample into.
        n_features: Features per timestep; a falsy value lets
            ``prepare_input_array`` derive it from the input length.

    Returns:
        A user-facing message: the predicted label and confidence, or a
        description of why no prediction was made. Errors are reported in
        the returned string rather than raised.
    """
    if not MODEL_READY:
        if MODEL_LOAD_ERROR:
            return f"⚠️ 模型加载失败:\n{MODEL_LOAD_ERROR}"
        return "⏳ 模型正在加载中,请稍候..."

    invalid_input = "❌ 输入无效:请提供逗号分隔的数字"

    try:
        # Parse strictly: a malformed token rejects the whole input rather
        # than yielding a prediction on a partially parsed vector.
        try:
            features = _parse_feature_text(text)
        except ValueError:
            return invalid_input

        if features.size == 0:
            return invalid_input

        x = prepare_input_array(
            features,
            n_timesteps=int(n_timesteps),
            n_features=(int(n_features) if n_features else None),
        )

        probs = MODEL.predict(x, verbose=0)
        label = int(np.argmax(probs, axis=1)[0])
        confidence = float(np.max(probs))

        return f"✅ 预测类别: {label}\n置信度: {confidence:.2%}"

    except Exception as e:
        return f"❌ 预测失败: {str(e)}"
|
|
|
def predict_csv(file, n_timesteps=1, n_features=None):
    """Run batch prediction over an uploaded CSV file.

    Each CSV row is one sample; its columns are reshaped to
    ``(n_timesteps, n_features)`` via ``prepare_input_array``.

    Args:
        file: Value delivered by the Gradio file input, or ``None`` when
            nothing was uploaded.
        n_timesteps: Timesteps to reshape every row into.
        n_features: Features per timestep; a falsy value derives it from the
            column count.

    Returns:
        ``(result_df, message)``. ``result_df`` has one row per sample with
        the predicted label, its confidence and one ``Prob_Class_i`` column
        per class; it is ``None`` whenever prediction did not happen. Errors
        are reported through ``message`` rather than raised.
    """
    if not MODEL_READY:
        if MODEL_LOAD_ERROR:
            return None, f"⚠️ 模型加载失败:\n{MODEL_LOAD_ERROR}"
        return None, "⏳ 模型正在加载中,请稍候..."

    if file is None:
        # Report a missing upload directly instead of surfacing an
        # AttributeError message to the user.
        return None, "❌ 预测失败: 未上传 CSV 文件"

    try:
        # NOTE(review): the upload may arrive as an object exposing `.name`
        # or as a plain path string depending on Gradio version/config;
        # accept both - confirm against the deployed Gradio version.
        csv_path = getattr(file, "name", file)
        df = pd.read_csv(csv_path)

        # Reuse the shared reshaping helper rather than duplicating it here.
        X = prepare_input_array(
            df.values,
            n_timesteps=int(n_timesteps),
            n_features=(int(n_features) if n_features else None),
        )

        preds = MODEL.predict(X, verbose=0)
        labels = preds.argmax(axis=1).tolist()

        result_df = pd.DataFrame({
            'Sample': range(len(labels)),
            'Predicted_Label': labels,
            # astype(float) keeps the column as Python-float precision, as
            # the per-row float() conversion did.
            'Confidence': preds.max(axis=1).astype(float),
        })

        for i in range(preds.shape[1]):
            result_df[f'Prob_Class_{i}'] = preds[:, i]

        return result_df, f"✅ 成功预测 {len(labels)} 个样本"

    except Exception as e:
        return None, f"❌ 预测失败: {str(e)}"
|
|
|
|
|
# Gradio UI: a status line, two prediction tabs (single sample / CSV batch),
# a collapsible help section, and the event wiring for all of them.
with gr.Blocks(title="CNN-LSTM 故障分类") as demo:
    gr.Markdown("# 🤖 CNN-LSTM 故障分类系统")

    # Model status, filled in by `update_status` when the page loads.
    with gr.Row():
        with gr.Column():
            status_text = gr.Markdown("⏳ 模型状态检查中...")

    gr.Markdown("---")

    with gr.Tabs():
        # Tab 1: one sample typed in as comma-separated values.
        with gr.Tab("📝 单样本预测"):
            text_in = gr.Textbox(
                lines=3,
                placeholder="输入逗号分隔的特征值,例如:\n0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0",
                label="特征输入",
            )

            with gr.Row():
                n_ts_text = gr.Number(value=1, label="时间步数 (timesteps)", precision=0)
                n_feat_text = gr.Number(value=10, label="特征数 (features)", precision=0)

            predict_btn = gr.Button("🔍 预测", variant="primary")
            out_text = gr.Textbox(label="预测结果", lines=3)

        # Tab 2: many samples uploaded as a CSV file.
        with gr.Tab("📊 批量预测"):
            file_in = gr.File(
                label="上传 CSV 文件",
                file_types=[".csv"],
                file_count="single",
            )

            gr.Markdown("**CSV 格式说明:** 每行代表一个样本,列为特征值")

            with gr.Row():
                n_ts_csv = gr.Number(value=1, label="时间步数 (timesteps)", precision=0)
                n_feat_csv = gr.Number(value=10, label="特征数 (features,可选)", precision=0)

            predict_csv_btn = gr.Button("🔍 批量预测", variant="primary")

            out_csv_msg = gr.Textbox(label="处理消息", lines=2)
            out_csv_df = gr.Dataframe(label="预测结果表格", interactive=False)

    gr.Markdown("---")

    # Collapsed usage notes.
    with gr.Accordion("📖 使用说明", open=False):
        gr.Markdown("""
        ### 输入格式
        - **单样本**:输入逗号分隔的数值特征
        - **批量预测**:上传CSV文件,每行一个样本

        ### 参数说明
        - **时间步数 (timesteps)**:序列的时间步长度
        - **特征数 (features)**:每个时间步的特征维度
        - 输入总长度应该等于 `timesteps × features`

        ### 示例
        如果您有10个特征,1个时间步:
        - 输入:`0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0`
        - timesteps = 1, features = 10
        """)

    def update_status():
        """Return the status line matching the current model-loading state."""
        if not MODEL_READY:
            if MODEL_LOAD_ERROR:
                return "❌ 模型加载失败"
            return "⏳ 模型加载中..."
        if MODEL_LOAD_ERROR is None:
            return "✅ 模型已就绪"
        # Ready, but an error was recorded along the way.
        return "⚠️ 使用演示模型(真实模型加载失败)"

    # Event wiring.
    predict_btn.click(
        fn=predict_text,
        inputs=[text_in, n_ts_text, n_feat_text],
        outputs=out_text,
    )

    predict_csv_btn.click(
        fn=predict_csv,
        inputs=[file_in, n_ts_csv, n_feat_csv],
        outputs=[out_csv_df, out_csv_msg],
    )

    # Evaluate the status when the page loads.
    demo.load(fn=update_status, outputs=status_text)
|
|
|
|
|
def get_port():
    """Return the TCP port to serve on.

    Reads the ``PORT`` environment variable. Falls back to 7860 when the
    variable is unset, is not an integer, or lies outside 1-65535.

    Returns:
        The port number as an ``int``.
    """
    default_port = 7860
    try:
        port = int(os.environ.get("PORT", default_port))
    except (TypeError, ValueError):
        # Narrow on purpose: a bare `except` would also swallow
        # KeyboardInterrupt and SystemExit.
        return default_port
    # An out-of-range value cannot be bound, so treat it like a bad value.
    return port if 0 < port < 65536 else default_port
|
|
|
if __name__ == "__main__":
    # Script entry point: resolve the port, announce it, then start the server.
    port = get_port()
    print(f"[app] Starting Gradio on 0.0.0.0:{port}", flush=True)

    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": port,
        "show_error": True,
    }
    demo.launch(**launch_options)