Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -10,15 +10,15 @@ from model import DCCRN # requires model.py and utils/ dependencies
|
|
| 10 |
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
| 11 |
SR = int(os.getenv("SAMPLE_RATE", "16000"))
|
| 12 |
|
| 13 |
-
#
|
| 14 |
-
REPO_ID = os.getenv("MODEL_REPO_ID", "
|
| 15 |
-
FILENAME = os.getenv("MODEL_FILENAME", "
|
| 16 |
TOKEN = os.getenv("HF_TOKEN") # only required if the model repo is private
|
| 17 |
|
| 18 |
# ===== Download & load weights =====
|
| 19 |
ckpt_path = hf_hub_download(repo_id=REPO_ID, filename=FILENAME, token=TOKEN)
|
| 20 |
|
| 21 |
-
net = DCCRN() # instantiate with
|
| 22 |
ckpt = torch.load(ckpt_path, map_location=DEVICE)
|
| 23 |
state = ckpt.get("state_dict", ckpt)
|
| 24 |
state = {k.replace("model.", "").replace("module.", ""): v for k, v in state.items()}
|
|
@@ -31,11 +31,14 @@ def enhance(audio_path: str):
|
|
| 31 |
x = torch.from_numpy(wav).float().to(DEVICE)
|
| 32 |
if x.ndim == 1:
|
| 33 |
x = x.unsqueeze(0) # [1, T]
|
|
|
|
| 34 |
with torch.no_grad():
|
|
|
|
| 35 |
try:
|
| 36 |
-
y = net(x.unsqueeze(1)) #
|
| 37 |
except Exception:
|
| 38 |
-
y = net(x) #
|
|
|
|
| 39 |
y = y.squeeze().detach().cpu().numpy()
|
| 40 |
return (SR, y)
|
| 41 |
|
|
@@ -44,28 +47,26 @@ with gr.Blocks() as demo:
|
|
| 44 |
gr.Markdown(
|
| 45 |
"""
|
| 46 |
# 🎧 DCCRN Speech Enhancement (Demo)
|
|
|
|
| 47 |
|
| 48 |
-
**How to use:** drag & drop a noisy audio clip (or upload / record) → click **Enhance** → listen & download the result.
|
| 49 |
**Sample audio:** click a sample below to auto-fill the input, then click **Enhance**.
|
| 50 |
"""
|
| 51 |
)
|
| 52 |
|
| 53 |
with gr.Row():
|
| 54 |
inp = gr.Audio(
|
| 55 |
-
sources=["upload", "microphone"],
|
| 56 |
type="filepath",
|
| 57 |
-
label="Input: noisy speech"
|
| 58 |
-
placeholder="Drag & drop or click to upload / record",
|
| 59 |
-
show_label=True,
|
| 60 |
)
|
| 61 |
out = gr.Audio(
|
| 62 |
label="Output: enhanced speech (downloadable)",
|
| 63 |
-
show_download_button=True
|
| 64 |
)
|
| 65 |
|
| 66 |
enhance_btn = gr.Button("Enhance")
|
| 67 |
|
| 68 |
-
# On-page sample clips (
|
| 69 |
gr.Examples(
|
| 70 |
examples=[
|
| 71 |
["examples/noisy_1.wav"],
|
|
@@ -80,6 +81,6 @@ with gr.Blocks() as demo:
|
|
| 80 |
# Gradio ≥4.44: set concurrency on the event listener
|
| 81 |
enhance_btn.click(enhance, inputs=inp, outputs=out, concurrency_limit=1)
|
| 82 |
|
| 83 |
-
#
|
| 84 |
demo.queue(max_size=16)
|
| 85 |
demo.launch()
|
|
|
|
| 10 |
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
| 11 |
SR = int(os.getenv("SAMPLE_RATE", "16000"))
|
| 12 |
|
| 13 |
+
# Read model repo and filename from environment variables
|
| 14 |
+
REPO_ID = os.getenv("MODEL_REPO_ID", "Ada312/DCCRN") # change default if needed
|
| 15 |
+
FILENAME = os.getenv("MODEL_FILENAME", "dccrn.ckpt")
|
| 16 |
TOKEN = os.getenv("HF_TOKEN") # only required if the model repo is private
|
| 17 |
|
| 18 |
# ===== Download & load weights =====
|
| 19 |
ckpt_path = hf_hub_download(repo_id=REPO_ID, filename=FILENAME, token=TOKEN)
|
| 20 |
|
| 21 |
+
net = DCCRN() # if you trained with custom args, instantiate with the same args here
|
| 22 |
ckpt = torch.load(ckpt_path, map_location=DEVICE)
|
| 23 |
state = ckpt.get("state_dict", ckpt)
|
| 24 |
state = {k.replace("model.", "").replace("module.", ""): v for k, v in state.items()}
|
|
|
|
| 31 |
x = torch.from_numpy(wav).float().to(DEVICE)
|
| 32 |
if x.ndim == 1:
|
| 33 |
x = x.unsqueeze(0) # [1, T]
|
| 34 |
+
|
| 35 |
with torch.no_grad():
|
| 36 |
+
# Many DCCRN implementations expect [B, 1, T]; try that first, then fall back to [B, T]
|
| 37 |
try:
|
| 38 |
+
y = net(x.unsqueeze(1)) # [1, 1, T]
|
| 39 |
except Exception:
|
| 40 |
+
y = net(x) # [1, T]
|
| 41 |
+
|
| 42 |
y = y.squeeze().detach().cpu().numpy()
|
| 43 |
return (SR, y)
|
| 44 |
|
|
|
|
| 47 |
gr.Markdown(
|
| 48 |
"""
|
| 49 |
# 🎧 DCCRN Speech Enhancement (Demo)
|
| 50 |
+
**How to use:** drag & drop a noisy audio clip (or upload / record) → click **Enhance** → listen & download the result.
|
| 51 |
|
|
|
|
| 52 |
**Sample audio:** click a sample below to auto-fill the input, then click **Enhance**.
|
| 53 |
"""
|
| 54 |
)
|
| 55 |
|
| 56 |
with gr.Row():
|
| 57 |
inp = gr.Audio(
|
| 58 |
+
sources=["upload", "microphone"], # drag & drop supported by default
|
| 59 |
type="filepath",
|
| 60 |
+
label="Input: noisy speech (drag & drop or upload / record)"
|
|
|
|
|
|
|
| 61 |
)
|
| 62 |
out = gr.Audio(
|
| 63 |
label="Output: enhanced speech (downloadable)",
|
| 64 |
+
show_download_button=True
|
| 65 |
)
|
| 66 |
|
| 67 |
enhance_btn = gr.Button("Enhance")
|
| 68 |
|
| 69 |
+
# On-page sample clips (make sure these files exist in the repo)
|
| 70 |
gr.Examples(
|
| 71 |
examples=[
|
| 72 |
["examples/noisy_1.wav"],
|
|
|
|
| 81 |
# Gradio ≥4.44: set concurrency on the event listener
|
| 82 |
enhance_btn.click(enhance, inputs=inp, outputs=out, concurrency_limit=1)
|
| 83 |
|
| 84 |
+
# Queue: keep a small queue to avoid OOM
|
| 85 |
demo.queue(max_size=16)
|
| 86 |
demo.launch()
|