File size: 1,670 Bytes
9298b00
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import os, numpy as np, torch, gradio as gr, librosa
from huggingface_hub import hf_hub_download
from model import DCCRN  # 确保你上传了 model.py 和 utils 依赖

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
SR = 16000  # sampling rate the model was trained at

# Model repository and weight file name; both can be overridden through
# environment variables.
REPO_ID = os.getenv("MODEL_REPO_ID", "Ada312/DCCRN")   # model repository
FILENAME = os.getenv("MODEL_FILENAME", "dccrn.ckpt")   # weight file
TOKEN = os.getenv("HF_TOKEN")  # only needed when the model repository is private

# Download the weights (hf_hub_download returns the local file path).
ckpt_path = hf_hub_download(repo_id=REPO_ID, filename=FILENAME, token=TOKEN)


def _strip_wrapper_prefixes(key, prefixes=("model.", "module.")):
    """Return *key* with any leading wrapper prefixes removed.

    Only prefixes at the very start of the key are stripped (repeatedly, in
    any order), so "model.module.enc.weight" becomes "enc.weight" while an
    inner name such as "submodule.weight" is left intact. A plain
    ``str.replace`` would also mangle those inner occurrences.
    """
    stripped = True
    while stripped:
        stripped = False
        for prefix in prefixes:
            if key.startswith(prefix):
                key = key[len(prefix):]
                stripped = True
    return key


# Build the model and load the weights.
net = DCCRN()
# NOTE(review): torch.load unpickles the file, and REPO_ID/FILENAME come from
# the environment. Only point them at trusted repositories, or consider
# weights_only=True if the checkpoint format allows it.
ckpt = torch.load(ckpt_path, map_location=DEVICE)
# Use the nested "state_dict" entry when present, otherwise treat the loaded
# object itself as the state dict.
state = ckpt.get("state_dict", ckpt)
state = {_strip_wrapper_prefixes(k): v for k, v in state.items()}
# strict=False tolerates key mismatches; report them instead of silently
# running with partially initialised weights.
incompatible = net.load_state_dict(state, strict=False)
if incompatible.missing_keys or incompatible.unexpected_keys:
    print(
        "[warn] checkpoint/model key mismatch: "
        f"{len(incompatible.missing_keys)} missing, "
        f"{len(incompatible.unexpected_keys)} unexpected"
    )
net.to(DEVICE).eval()

# Inference function: noisy audio in, enhanced audio out.
def enhance(audio_path: str):
    """Denoise the audio file at *audio_path* with the loaded model.

    Args:
        audio_path: Path to the noisy recording. The UI may pass ``None``
            (or an empty string) when the button is clicked before any
            audio was uploaded or recorded.

    Returns:
        ``(SR, samples)`` where ``samples`` is the squeezed model output as
        a NumPy array, or ``None`` when no input audio was provided.
    """
    # Nothing to process: return early instead of letting librosa fail on a
    # missing path.
    if not audio_path:
        return None

    # Load as mono at the module-level sampling rate SR.
    wav, _ = librosa.load(audio_path, sr=SR, mono=True)
    # Add two leading singleton axes; presumably (batch, channel, time) --
    # TODO confirm against DCCRN.forward.
    noisy = torch.from_numpy(wav).float().to(DEVICE)[None, None, :]
    with torch.no_grad():
        # Assumes the model returns a single tensor -- TODO confirm.
        enhanced = net(noisy).squeeze().cpu().numpy()
    return (SR, enhanced)

# Gradio interface: one audio input, one audio output, and a button that
# runs `enhance` on the input.
with gr.Blocks() as demo:
    gr.Markdown("## 🎧 DCCRN Speech Enhancement\n上传或录音,点击“去噪”。")
    with gr.Row():
        # type="filepath" makes Gradio hand `enhance` a path on disk.
        inp = gr.Audio(sources=["upload", "microphone"], type="filepath", label="Noisy speech")
        out = gr.Audio(label="Enhanced speech")
    gr.Button("去噪").click(enhance, inputs=inp, outputs=out)

# `concurrency_count` was removed from Blocks.queue() in Gradio 4, the same
# major version that introduced the `sources=` argument used above.
# `default_concurrency_limit=1` keeps requests running one at a time, and
# `max_size=8` caps the number of queued requests.
demo.queue(default_concurrency_limit=1, max_size=8)
demo.launch()