import os

import numpy as np
import torch
import gradio as gr
import librosa
from huggingface_hub import hf_hub_download

from model import DCCRN  # requires local model.py and its utils/ dependencies

# ===== Basic configuration =====
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
SR = int(os.getenv("SAMPLE_RATE", "16000"))

# Model repo and weight filename are configurable via environment variables.
REPO_ID = os.getenv("MODEL_REPO_ID", "Ada312/DCCRN")
FILENAME = os.getenv("MODEL_FILENAME", "dccrn.ckpt")
TOKEN = os.getenv("HF_TOKEN")  # only needed for private model repos

# ===== Download and load weights =====
ckpt_path = hf_hub_download(repo_id=REPO_ID, filename=FILENAME, token=TOKEN)

net = DCCRN()  # pass the training-time hyperparameters here if they differ
# NOTE(review): torch.load unpickles arbitrary objects from a downloaded file.
# Prefer torch.load(..., weights_only=True) if this checkpoint supports it.
ckpt = torch.load(ckpt_path, map_location=DEVICE)
state = ckpt.get("state_dict", ckpt)
# Strip Lightning ("model.") and DataParallel ("module.") prefixes so the
# keys match a bare DCCRN instance; strict=False tolerates leftovers.
state = {k.replace("model.", "").replace("module.", ""): v for k, v in state.items()}
net.load_state_dict(state, strict=False)
net.to(DEVICE).eval()


# ===== Inference =====
def enhance(audio_path: str):
    """Denoise one audio file and return ``(sample_rate, waveform)`` for gr.Audio.

    Args:
        audio_path: Path to the uploaded or recorded audio file.

    Returns:
        Tuple of ``(SR, np.ndarray)`` — the enhanced mono waveform at SR Hz.

    Raises:
        gr.Error: If the input file contains no audio samples.
    """
    wav, _ = librosa.load(audio_path, sr=SR, mono=True)
    if wav.size == 0:
        raise gr.Error("The uploaded file contains no audio.")

    x = torch.from_numpy(wav).float().to(DEVICE)
    if x.ndim == 1:
        x = x.unsqueeze(0)  # [1, T]

    with torch.inference_mode():
        # Many DCCRN variants expect [B, 1, T]; try that shape first and fall
        # back to [B, T] if the model rejects it. The broad except is
        # deliberate: the model's exact failure type is not known here.
        try:
            y = net(x.unsqueeze(1))  # [1, 1, T]
        except Exception:
            y = net(x)  # [1, T]

    y = y.squeeze().cpu().numpy()
    return (SR, y)


# ===== Gradio UI =====
with gr.Blocks() as demo:
    gr.Markdown("## 🎧 DCCRN Speech Enhancement\n上传或录音,点击“去噪”。")
    with gr.Row():
        inp = gr.Audio(sources=["upload", "microphone"], type="filepath", label="Noisy speech")
        out = gr.Audio(label="Enhanced speech")
    btn = gr.Button("去噪")
    # Gradio 4.x: the per-event concurrency limit lives on the event listener,
    # not on queue() (the old concurrency_count argument is deprecated).
    btn.click(enhance, inputs=inp, outputs=out, concurrency_limit=1)

# The queue keeps at most 8 pending requests.
demo.queue(max_size=8)

if __name__ == "__main__":
    demo.launch()