dccrn-demo / app.py
Ada312's picture
Update app.py
9298b00 verified
raw
history blame
1.67 kB
import os, numpy as np, torch, gradio as gr, librosa
from huggingface_hub import hf_hub_download
from model import DCCRN # 确保你上传了 model.py 和 utils 依赖
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
SR = 16000  # sampling rate (Hz) audio is resampled to; should match the rate the model was trained at

# Model repository and weight file name, overridable through environment variables.
REPO_ID = os.getenv("MODEL_REPO_ID", "Ada312/DCCRN")  # model repository on the Hub
FILENAME = os.getenv("MODEL_FILENAME", "dccrn.ckpt")  # weight file inside that repository
TOKEN = os.getenv("HF_TOKEN")  # only needed when the model repository is private

# Download the weights to the local cache.
ckpt_path = hf_hub_download(repo_id=REPO_ID, filename=FILENAME, token=TOKEN)


def _strip_wrapper_prefixes(key, prefixes=("model.", "module.")):
    """Return *key* with any leading wrapper prefixes removed.

    Only the start of the key is touched, so a parameter whose own name
    happens to contain "model." or "module." further in (for example
    "submodel.weight") is left intact. Stacked prefixes such as
    "model.module." are removed in full.
    """
    changed = True
    while changed:
        changed = False
        for prefix in prefixes:
            if key.startswith(prefix):
                key = key[len(prefix):]
                changed = True
    return key


# Build the model and load the weights.
net = DCCRN()
# NOTE(security): torch.load unpickles the file, which can execute arbitrary
# code. Only point MODEL_REPO_ID / MODEL_FILENAME at repositories you trust.
ckpt = torch.load(ckpt_path, map_location=DEVICE)
# Some checkpoints nest the weights under "state_dict"; otherwise the
# checkpoint itself is treated as the state dict.
state = ckpt.get("state_dict", ckpt)
state = {_strip_wrapper_prefixes(k): v for k, v in state.items()}
# strict=False tolerates key mismatches, so report them instead of silently
# running with partially (or entirely) uninitialised weights.
load_result = net.load_state_dict(state, strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print(
        "WARNING: checkpoint does not fully match the model: "
        f"missing keys={list(load_result.missing_keys)}, "
        f"unexpected keys={list(load_result.unexpected_keys)}"
    )
net.to(DEVICE).eval()
# Inference: noisy audio file in, enhanced audio out.
def enhance(audio_path: str):
    """Denoise the audio file at *audio_path* with the module-level model.

    Args:
        audio_path: Path of the recording to enhance. May be ``None`` or
            empty when the handler is triggered before any audio has been
            uploaded or recorded.

    Returns:
        A ``(SR, samples)`` tuple holding the sampling rate and the model
        output as a NumPy array, or ``None`` when there is nothing to
        process.
    """
    # Nothing was uploaded/recorded: return an empty result instead of
    # failing inside librosa.
    if not audio_path:
        return None

    # Load as mono, resampled to the rate used by the model.
    wav, _ = librosa.load(audio_path, sr=SR, mono=True)
    if wav.size == 0:
        # Zero-length recording: there are no samples to feed the model.
        return None

    # Add two leading singleton axes: (T,) -> (1, 1, T).
    x = torch.from_numpy(wav).float().to(DEVICE)[None, None, :]
    with torch.no_grad():
        y = net(x).squeeze().cpu().numpy()
    return (SR, y)
# Gradio interface: one audio input (upload or microphone), one audio output,
# and a button that runs `enhance` on the input.
with gr.Blocks() as demo:
    gr.Markdown("## 🎧 DCCRN Speech Enhancement\n上传或录音,点击“去噪”。")
    with gr.Row():
        inp = gr.Audio(sources=["upload", "microphone"], type="filepath", label="Noisy speech")
        out = gr.Audio(label="Enhanced speech")
    gr.Button("去噪").click(enhance, inputs=inp, outputs=out)

# `concurrency_count` was removed from Blocks.queue() in Gradio 4 (the same
# major version that introduced the `sources=` argument used above), and
# passing it raises at startup. By default each event already runs one job at
# a time, so only the queue length needs to be set here.
demo.queue(max_size=8)
demo.launch()