Spaces:
Sleeping
Sleeping
import os, numpy as np, torch, gradio as gr, librosa | |
from huggingface_hub import hf_hub_download | |
from model import DCCRN # 确保你上传了 model.py 和 utils 依赖 | |
# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
DEVICE = "cpu"
if torch.cuda.is_available():
    DEVICE = "cuda"

# Sample rate the model was trained at; input audio is resampled to it.
SR = 16_000

# Model location, overridable through environment variables.
REPO_ID = os.environ.get("MODEL_REPO_ID", "Ada312/DCCRN")  # Hub repo holding the weights
FILENAME = os.environ.get("MODEL_FILENAME", "dccrn.ckpt")  # weight file inside that repo
TOKEN = os.environ.get("HF_TOKEN")  # access token; only needed for a private repo
# Download the weight file to the local Hugging Face cache and get its local path.
ckpt_path = hf_hub_download(repo_id=REPO_ID, filename=FILENAME, token=TOKEN)


def _strip_wrapper_prefixes(key, prefixes=("model.", "module.")):
    """Remove leading wrapper prefixes (e.g. ``model.``, ``module.``) from a state-dict key.

    Prefixes are stripped repeatedly and in any order, so both
    ``module.model.x`` and ``model.module.x`` become ``x``. Only leading
    occurrences are removed: a key that merely contains ``model.`` in the
    middle (e.g. ``enc.submodel.weight``) is left intact. ``str.replace``
    would mangle such a key into ``enc.subweight``.

    Args:
        key: A parameter/buffer name from a checkpoint's state dict.
        prefixes: Wrapper prefixes to remove from the start of ``key``.

    Returns:
        ``key`` with all leading wrapper prefixes removed.
    """
    changed = True
    while changed:
        changed = False
        for prefix in prefixes:
            if key.startswith(prefix):
                key = key[len(prefix):]
                changed = True
    return key


# Build the model and load the downloaded weights into it.
net = DCCRN()
# NOTE(review): torch.load unpickles the file; only point MODEL_REPO_ID at a trusted repo.
ckpt = torch.load(ckpt_path, map_location=DEVICE)
# Checkpoints that nest the weights under "state_dict" (e.g. PyTorch Lightning)
# are unwrapped; a bare state dict is used as-is.
state = ckpt.get("state_dict", ckpt)
state = {_strip_wrapper_prefixes(k): v for k, v in state.items()}
# strict=False tolerates mismatched keys, so surface them instead of silently
# running with partially random weights.
load_result = net.load_state_dict(state, strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print(
        f"[DCCRN] checkpoint key mismatch - missing: {load_result.missing_keys}, "
        f"unexpected: {load_result.unexpected_keys}"
    )
net.to(DEVICE).eval()
# Inference handler: noisy audio in -> enhanced audio out.
def enhance(audio_path: str):
    """Denoise one recording with the loaded DCCRN model.

    Args:
        audio_path: Local path of the clip from the ``gr.Audio`` input
            (``type="filepath"``). It may be ``None`` when the button is
            clicked before anything was uploaded or recorded.

    Returns:
        ``(SR, waveform)``: the tuple form ``gr.Audio`` accepts as output,
        with the enhanced signal as a NumPy array.

    Raises:
        gr.Error: If no audio was provided or the clip is empty. The message
            is shown to the user in the UI instead of a raw traceback.
    """
    if audio_path is None:
        raise gr.Error("请先上传或录制一段音频。")
    # Resample to the model's rate and downmix to a single channel.
    wav, _ = librosa.load(audio_path, sr=SR, mono=True)
    if wav.size == 0:
        raise gr.Error("音频为空,请重新上传或录制。")
    # (samples,) -> (1, 1, samples): add batch and channel dims.
    # NOTE(review): assumes DCCRN.forward takes (batch, 1, samples) and returns
    # a single tensor - confirm against model.py.
    x = torch.from_numpy(wav).float().to(DEVICE)[None, None, :]
    with torch.no_grad():
        y = net(x).squeeze().cpu().numpy()
    return (SR, y)
# Gradio UI: audio input (upload or microphone), enhanced-audio output, and a run button.
with gr.Blocks() as demo:
    gr.Markdown("## 🎧 DCCRN Speech Enhancement\n上传或录音,点击“去噪”。")
    with gr.Row():
        inp = gr.Audio(sources=["upload", "microphone"], type="filepath", label="Noisy speech")
        out = gr.Audio(label="Enhanced speech")
    gr.Button("去噪").click(enhance, inputs=inp, outputs=out)

# Gradio 4 (required by `gr.Audio(sources=...)` above) removed `concurrency_count`
# from queue(); passing it raises at startup. `default_concurrency_limit=1` keeps
# the intended one-inference-at-a-time behaviour, and at most 8 requests can wait.
demo.queue(default_concurrency_limit=1, max_size=8)
demo.launch()