chenxie95 committed on
Commit
265d119
·
verified ·
1 Parent(s): 00aa78e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -14
app.py CHANGED
@@ -10,15 +10,15 @@ from model import DCCRN # requires model.py and utils/ dependencies
10
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
11
  SR = int(os.getenv("SAMPLE_RATE", "16000"))
12
 
13
- # Model repo & filename (override via Space Variables)
14
- REPO_ID = os.getenv("MODEL_REPO_ID", "chenxie95/DCCRN") # <- change default if needed
15
- FILENAME = os.getenv("MODEL_FILENAME", "epoch=44-step=113895.ckpt")
16
  TOKEN = os.getenv("HF_TOKEN") # only required if the model repo is private
17
 
18
  # ===== Download & load weights =====
19
  ckpt_path = hf_hub_download(repo_id=REPO_ID, filename=FILENAME, token=TOKEN)
20
 
21
- net = DCCRN() # instantiate with your training-time args if they differ
22
  ckpt = torch.load(ckpt_path, map_location=DEVICE)
23
  state = ckpt.get("state_dict", ckpt)
24
  state = {k.replace("model.", "").replace("module.", ""): v for k, v in state.items()}
@@ -31,11 +31,14 @@ def enhance(audio_path: str):
31
  x = torch.from_numpy(wav).float().to(DEVICE)
32
  if x.ndim == 1:
33
  x = x.unsqueeze(0) # [1, T]
 
34
  with torch.no_grad():
 
35
  try:
36
- y = net(x.unsqueeze(1)) # try [B,1,T]
37
  except Exception:
38
- y = net(x) # fallback [B,T]
 
39
  y = y.squeeze().detach().cpu().numpy()
40
  return (SR, y)
41
 
@@ -44,28 +47,26 @@ with gr.Blocks() as demo:
44
  gr.Markdown(
45
  """
46
  # 🎧 DCCRN Speech Enhancement (Demo)
 
47
 
48
- **How to use:** drag & drop a noisy audio clip (or upload / record) → click **Enhance** → listen & download the result.
49
  **Sample audio:** click a sample below to auto-fill the input, then click **Enhance**.
50
  """
51
  )
52
 
53
  with gr.Row():
54
  inp = gr.Audio(
55
- sources=["upload", "microphone"], # drag & drop supported
56
  type="filepath",
57
- label="Input: noisy speech",
58
- placeholder="Drag & drop or click to upload / record",
59
- show_label=True,
60
  )
61
  out = gr.Audio(
62
  label="Output: enhanced speech (downloadable)",
63
- show_download_button=True,
64
  )
65
 
66
  enhance_btn = gr.Button("Enhance")
67
 
68
- # On-page sample clips (ensure these files exist under examples/)
69
  gr.Examples(
70
  examples=[
71
  ["examples/noisy_1.wav"],
@@ -80,6 +81,6 @@ with gr.Blocks() as demo:
80
  # Gradio ≥4.44: set concurrency on the event listener
81
  enhance_btn.click(enhance, inputs=inp, outputs=out, concurrency_limit=1)
82
 
83
- # Keep a small queue to avoid OOM
84
  demo.queue(max_size=16)
85
  demo.launch()
 
10
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
11
  SR = int(os.getenv("SAMPLE_RATE", "16000"))
12
 
13
+ # Read model repo and filename from environment variables
14
+ REPO_ID = os.getenv("MODEL_REPO_ID", "Ada312/DCCRN") # change default if needed
15
+ FILENAME = os.getenv("MODEL_FILENAME", "dccrn.ckpt")
16
  TOKEN = os.getenv("HF_TOKEN") # only required if the model repo is private
17
 
18
  # ===== Download & load weights =====
19
  ckpt_path = hf_hub_download(repo_id=REPO_ID, filename=FILENAME, token=TOKEN)
20
 
21
+ net = DCCRN() # if you trained with custom args, instantiate with the same args here
22
  ckpt = torch.load(ckpt_path, map_location=DEVICE)
23
  state = ckpt.get("state_dict", ckpt)
24
  state = {k.replace("model.", "").replace("module.", ""): v for k, v in state.items()}
 
31
  x = torch.from_numpy(wav).float().to(DEVICE)
32
  if x.ndim == 1:
33
  x = x.unsqueeze(0) # [1, T]
34
+
35
  with torch.no_grad():
36
+ # Many DCCRNs expect [B,1,T]; try that first, fallback to [B,T]
37
  try:
38
+ y = net(x.unsqueeze(1)) # [1, 1, T]
39
  except Exception:
40
+ y = net(x) # [1, T]
41
+
42
  y = y.squeeze().detach().cpu().numpy()
43
  return (SR, y)
44
 
 
47
  gr.Markdown(
48
  """
49
  # 🎧 DCCRN Speech Enhancement (Demo)
50
+ **How to use:** drag & drop a noisy audio clip (or upload / record) → click **Enhance** → listen & download the result.
51
 
 
52
  **Sample audio:** click a sample below to auto-fill the input, then click **Enhance**.
53
  """
54
  )
55
 
56
  with gr.Row():
57
  inp = gr.Audio(
58
+ sources=["upload", "microphone"], # drag & drop supported by default
59
  type="filepath",
60
+ label="Input: noisy speech (drag & drop or upload / record)"
 
 
61
  )
62
  out = gr.Audio(
63
  label="Output: enhanced speech (downloadable)",
64
+ show_download_button=True
65
  )
66
 
67
  enhance_btn = gr.Button("Enhance")
68
 
69
+ # On-page sample clips (make sure these files exist in the repo)
70
  gr.Examples(
71
  examples=[
72
  ["examples/noisy_1.wav"],
 
81
  # Gradio ≥4.44: set concurrency on the event listener
82
  enhance_btn.click(enhance, inputs=inp, outputs=out, concurrency_limit=1)
83
 
84
+ # Queue: keep a small queue to avoid OOM
85
  demo.queue(max_size=16)
86
  demo.launch()