import os

import numpy as np
import torch
import gradio as gr
import librosa
from huggingface_hub import hf_hub_download

from model import DCCRN  # requires model.py and its utils/ dependencies

# ===== Basic config =====
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
SR = int(os.getenv("SAMPLE_RATE", "16000"))

# Model repo and filename come from environment variables
REPO_ID = os.getenv("MODEL_REPO_ID", "Ada312/DCCRN")  # change the default if needed
FILENAME = os.getenv("MODEL_FILENAME", "dccrn.ckpt")
TOKEN = os.getenv("HF_TOKEN")  # only required if the model repo is private

# ===== Download & load weights =====
ckpt_path = hf_hub_download(repo_id=REPO_ID, filename=FILENAME, token=TOKEN)
net = DCCRN()  # if you trained with custom args, instantiate with the same args here
ckpt = torch.load(ckpt_path, map_location=DEVICE)
state = ckpt.get("state_dict", ckpt)  # Lightning checkpoints nest weights under "state_dict"


def _strip_prefix(key: str) -> str:
    """Strip a leading "model."/"module." added by Lightning wrappers or DataParallel."""
    for prefix in ("model.", "module."):
        if key.startswith(prefix):
            key = key[len(prefix):]
    return key


state = {_strip_prefix(k): v for k, v in state.items()}
net.load_state_dict(state, strict=False)
net.to(DEVICE).eval()


# ===== Inference =====
def enhance(audio_path: str):
    wav, _ = librosa.load(audio_path, sr=SR, mono=True)
    x = torch.from_numpy(wav).float().to(DEVICE)
    if x.ndim == 1:
        x = x.unsqueeze(0)  # [1, T]
    with torch.no_grad():
        # Many DCCRN variants expect [B, 1, T]; try that first, fall back to [B, T]
        try:
            y = net(x.unsqueeze(1))  # [1, 1, T]
        except Exception:
            y = net(x)  # [1, T]
    if isinstance(y, (tuple, list)):
        y = y[-1]  # some DCCRN forks return (spec, wav); keep the waveform
    y = y.squeeze().detach().cpu().numpy().astype(np.float32)
    return (SR, y)


# ===== Gradio UI =====
with gr.Blocks() as demo:
    gr.Markdown(
        """
        # 🎧 DCCRN Speech Enhancement (Demo)

        **How to use:** drag & drop a noisy audio clip (or upload / record) → click **Enhance** → listen to & download the result.

        **Sample audio:** click a sample below to auto-fill the input, then click **Enhance**.
        """
    )

    with gr.Row():
        inp = gr.Audio(
            sources=["upload", "microphone"],  # drag & drop is supported by default
            type="filepath",
            label="Input: noisy speech (drag & drop or upload / record)",
        )
        out = gr.Audio(
            label="Output: enhanced speech (downloadable)",
            show_download_button=True,
        )

    enhance_btn = gr.Button("Enhance")

    # On-page sample clips (make sure these files exist in the repo)
    gr.Examples(
        examples=[
            ["examples/noisy_1.wav"],
            ["examples/noisy_2.wav"],
            ["examples/noisy_3.wav"],
        ],
        inputs=inp,
        label="Sample audio",
        examples_per_page=3,
    )

    # Gradio 4.x: concurrency is set on the event listener, not on queue()
    enhance_btn.click(enhance, inputs=inp, outputs=out, concurrency_limit=1)

# Cap pending requests; concurrency_limit=1 above keeps GPU memory bounded
demo.queue(max_size=16)
demo.launch()
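
# ----------------------------------------------------------------------------
# Optional: chunked inference for long clips. This is a hedged sketch, not
# wired into the UI above: a single forward pass over a multi-minute file can
# exhaust GPU memory even with concurrency_limit=1, and one common workaround
# is Hann-windowed overlap-add over fixed windows. The win_s/hop_s values are
# illustrative assumptions, not tuned for this checkpoint.
#
# def enhance_chunked(audio_path: str, win_s: float = 10.0, hop_s: float = 5.0):
#     wav, _ = librosa.load(audio_path, sr=SR, mono=True)
#     win, hop = int(win_s * SR), int(hop_s * SR)
#     if len(wav) <= win:
#         return enhance(audio_path)  # short clip: one pass is fine
#     out = np.zeros(len(wav), dtype=np.float32)
#     norm = np.zeros(len(wav), dtype=np.float32)
#     for start in range(0, len(wav), hop):
#         chunk = wav[start:start + win]
#         taper = np.hanning(len(chunk)).astype(np.float32)
#         x = torch.from_numpy(chunk).float().to(DEVICE)
#         with torch.no_grad():
#             y = net(x.unsqueeze(0).unsqueeze(0))
#         if isinstance(y, (tuple, list)):
#             y = y[-1]  # keep the waveform, as in enhance()
#         y = y.squeeze().cpu().numpy()[: len(chunk)]
#         out[start:start + len(y)] += y * taper[: len(y)]
#         norm[start:start + len(y)] += taper[: len(y)]
#     out /= np.maximum(norm, 1e-8)
#     return (SR, out)
#
# To try it, swap `enhance` for `enhance_chunked` in the .click() call above.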