Ada312 committed on
Commit
9298b00
·
verified ·
1 Parent(s): dab4401

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -0
app.py CHANGED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, numpy as np, torch, gradio as gr, librosa
2
+ from huggingface_hub import hf_hub_download
3
+ from model import DCCRN # 确保你上传了 model.py 和 utils 依赖
4
+
5
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
SR = 16000  # sampling rate (Hz) the model was trained at

# Model repository id and weight filename come from the environment,
# falling back to the defaults below.
REPO_ID = os.getenv("MODEL_REPO_ID", "Ada312/DCCRN")  # model repository
FILENAME = os.getenv("MODEL_FILENAME", "dccrn.ckpt")  # weight file
TOKEN = os.getenv("HF_TOKEN")  # only needed when the model repo is private

# Download the weights into the local cache.
ckpt_path = hf_hub_download(repo_id=REPO_ID, filename=FILENAME, token=TOKEN)


def _strip_wrapper_prefixes(key, prefixes=("model.", "module.")):
    """Return *key* with any leading wrapper prefixes removed.

    Only prefixes at the very start of the key are stripped (repeatedly, so
    "model.module.x" becomes "x"). Occurrences in the middle of a key, such
    as "submodel.weight" or "enc.model.0.weight", are left untouched.
    """
    stripped = True
    while stripped:
        stripped = False
        for prefix in prefixes:
            if key.startswith(prefix):
                key = key[len(prefix):]
                stripped = True
    return key


# Build the model and load the weights.
net = DCCRN()
# NOTE(review): torch.load unpickles the downloaded file; only point
# MODEL_REPO_ID at repositories you trust.
ckpt = torch.load(ckpt_path, map_location=DEVICE)
# Accept either a bare state dict or a checkpoint that nests it under
# "state_dict".
state = ckpt.get("state_dict", ckpt)
state = {_strip_wrapper_prefixes(k): v for k, v in state.items()}
# strict=False keeps loading best-effort, but report any mismatch so a
# wrongly named checkpoint does not silently leave random weights in place.
_load_result = net.load_state_dict(state, strict=False)
if _load_result.missing_keys or _load_result.unexpected_keys:
    print(
        "WARNING: checkpoint/model key mismatch: "
        f"missing={list(_load_result.missing_keys)} "
        f"unexpected={list(_load_result.unexpected_keys)}"
    )
net.to(DEVICE).eval()
23
+
24
+ # 推理函数:输入 noisy audio → 输出 enhanced audio
25
def enhance(audio_path: str):
    """Run the loaded model on one noisy recording and return the result.

    Args:
        audio_path: Filesystem path of the input audio. May be None/empty
            when the UI callback fires without any audio supplied.

    Returns:
        A ``(SR, samples)`` tuple, where ``samples`` is the squeezed model
        output as a NumPy array.

    Raises:
        gr.Error: If no audio was supplied or the decoded clip has no samples.
    """
    if not audio_path:
        # Surface a readable message instead of a librosa traceback.
        raise gr.Error("Please upload or record an audio clip first.")

    # Resample to the model's rate and downmix to a single channel.
    wav, _ = librosa.load(audio_path, sr=SR, mono=True)
    if wav.size == 0:
        raise gr.Error("The audio clip is empty.")

    # Add two leading singleton axes -> shape (1, 1, num_samples).
    # NOTE(review): presumably (batch, channel, time) for DCCRN - confirm
    # against model.py.
    x = torch.from_numpy(wav).float().to(DEVICE)[None, None, :]
    with torch.no_grad():
        y = net(x).squeeze().cpu().numpy()
    return (SR, y)
31
+
32
# Gradio UI: one audio input (upload or microphone), one audio output, and a
# button that runs `enhance` on the input.
with gr.Blocks() as demo:
    gr.Markdown("## 🎧 DCCRN Speech Enhancement\n上传或录音,点击“去噪”。")
    with gr.Row():
        inp = gr.Audio(sources=["upload","microphone"], type="filepath", label="Noisy speech")
        out = gr.Audio(label="Enhanced speech")
    gr.Button("去噪").click(enhance, inputs=inp, outputs=out)

# `concurrency_count` is a Gradio 3.x-only argument and is rejected by
# Gradio 4+, which the `sources=[...]` argument above targets. The default
# concurrency is already one job at a time on both major versions, so only
# the queue size needs to be set.
demo.queue(max_size=8)
demo.launch()