|
|
|
|
|
import gradio as gr |
|
import torch |
|
import torchaudio |
|
from demucs.pretrained import get_model |
|
from demucs.apply import apply_model |
|
import os |
|
import tempfile |
|
import numpy as np |
|
import warnings |
|
import soundfile as sf |
|
import librosa |
|
import time |
|
# Silence library warnings so the console only shows this app's progress lines.
warnings.filterwarnings("ignore")

# --- Device selection ---
print("Setting up models...")
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f"Using device: {device}")

# --- HT-Demucs: load once at import time, move to the device, inference mode ---
print("Loading HT-Demucs model...")
htdemucs_model = get_model(name="htdemucs").to(device)
htdemucs_model.eval()
print("HT-Demucs model loaded successfully.")

# --- Spleeter: module-level handles, filled in later by the setup function ---
print("Setting up Spleeter...")
spleeter_separator = spleeter_audio_adapter = None
spleeter_available = False
|
|
|
def patch_spleeter_redirects():
    """Monkey-patch Spleeter's GitHub model provider so downloads follow redirects.

    Replaces ``GithubModelProvider.download`` with a wrapper that fetches the
    ``5stems`` archive through an ``httpx`` client created with
    ``follow_redirects=True`` and unpacks it into the requested directory.
    Any other model name, and any failure of the custom download, is
    delegated to the original ``download`` method.

    Returns:
        bool: ``True`` when the patch was installed, ``False`` when ``httpx``
        or Spleeter could not be imported or patching failed otherwise.
    """
    try:
        import httpx
        from spleeter.model.provider.github import GithubModelProvider

        # Keep a handle on the stock implementation for delegation/fallback.
        original_download = GithubModelProvider.download

        def patched_download(self, name, model_directory):
            """Download and extract ``name`` into ``model_directory``, following redirects."""
            import tarfile

            print(f"Downloading {name} model with redirect handling...")

            # Only models listed here take the custom path.
            model_urls = {
                '5stems': 'https://github.com/deezer/spleeter/releases/download/v1.4.0/5stems.tar.gz'
            }

            if name not in model_urls:
                return original_download(self, name, model_directory)

            url = model_urls[name]
            tmp_file_path = None

            try:
                with httpx.Client(follow_redirects=True, timeout=300) as client:
                    print(f"Downloading from: {url}")
                    response = client.get(url)
                    response.raise_for_status()

                with tempfile.NamedTemporaryFile(delete=False, suffix='.tar.gz') as tmp_file:
                    # Record the path before writing so a failed write is still cleaned up.
                    tmp_file_path = tmp_file.name
                    tmp_file.write(response.content)

                print(f"Downloaded {len(response.content)} bytes")

                os.makedirs(model_directory, exist_ok=True)
                with tarfile.open(tmp_file_path, 'r:gz') as tar:
                    if hasattr(tarfile, 'data_filter'):
                        # Refuse members that would escape model_directory
                        # (absolute paths, "..", links pointing outside).
                        tar.extractall(model_directory, filter='data')
                    else:
                        # Older interpreters without extraction filters.
                        tar.extractall(model_directory)

                print(f"✅ Successfully downloaded and extracted {name} model")

            except Exception as e:
                print(f"❌ Failed to download {name} model: {e}")
                # Fall back to Spleeter's own downloader.
                return original_download(self, name, model_directory)

            finally:
                # Remove the temporary archive on success and on failure alike.
                if tmp_file_path is not None and os.path.exists(tmp_file_path):
                    os.unlink(tmp_file_path)

        GithubModelProvider.download = patched_download
        print("✅ Patched Spleeter to handle GitHub redirects")
        return True

    except Exception as e:
        print(f"⚠️ Could not patch Spleeter redirects: {e}")
        return False
|
|
|
def setup_spleeter_with_retry():
    """Load the Spleeter 5stems separator and publish it via module globals.

    On success ``spleeter_separator`` and ``spleeter_audio_adapter`` hold the
    loaded objects and ``spleeter_available`` is ``True``. On any failure
    (including Spleeter not being installed) all three are reset and the
    error is printed instead of raised. Despite the name, a single attempt
    is made.

    Returns:
        bool: ``True`` if Spleeter was set up, ``False`` otherwise.
    """
    global spleeter_separator, spleeter_audio_adapter, spleeter_available

    try:
        model_dir = '/tmp/spleeter_models'
        # Spleeter resolves its model directory from the MODEL_PATH environment
        # variable when its provider module is first imported, so this must
        # happen before the imports below. setdefault keeps an operator override.
        os.environ.setdefault('MODEL_PATH', model_dir)
        # Kept for backward compatibility with anything reading the old name.
        os.environ['SPLEETER_MODEL_PATH'] = model_dir

        from spleeter.separator import Separator
        from spleeter.audio.adapter import AudioAdapter

        # Best effort: the return value is ignored, setup continues either way.
        patch_spleeter_redirects()

        print("Creating Spleeter 5stems separator...")
        spleeter_separator = Separator('spleeter:5stems')
        spleeter_audio_adapter = AudioAdapter.default()
        spleeter_available = True
        print("✅ Spleeter 5stems model loaded successfully!")
        return True

    except Exception as e:
        print(f"❌ Failed to load Spleeter 5stems: {e}")
        spleeter_separator = None
        spleeter_audio_adapter = None
        spleeter_available = False
        return False
|
|
|
|
|
# Try to load Spleeter once at import time; the result is exposed through the
# module-level `spleeter_available` flag rather than the return value.
setup_spleeter_with_retry()
|
|
|
|
|
def separate_with_htdemucs(audio_path):
    """Separate an audio file into drums, bass, other and vocals with HT-Demucs.

    The audio is converted to two channels and resampled to the model's
    sample rate before separation. Stems are written as WAV files into a
    fresh ``htdemucs_stems_<timestamp>`` directory under the working directory.

    Args:
        audio_path: Path of the audio file to separate, or ``None``.

    Returns:
        tuple: ``(drums_path, bass_path, other_path, vocals_path, status)``.
        The four paths are ``None`` when no file was given or separation failed;
        ``status`` is always a human-readable message.
    """
    if audio_path is None:
        return None, None, None, None, "Please upload an audio file."

    try:
        print(f"HT-Demucs: Loading audio from: {audio_path}")
        wav, sr = torchaudio.load(audio_path)

        if wav.shape[0] == 1:
            print("Audio is mono, converting to stereo.")
            wav = wav.repeat(2, 1)
        elif wav.shape[0] > 2:
            # Keep the first two channels of multichannel input.
            wav = wav[:2]

        # Feed the model audio at its own sample rate; otherwise stems come
        # out degraded for files recorded at a different rate.
        model_sr = htdemucs_model.samplerate
        if sr != model_sr:
            print(f"HT-Demucs: Resampling from {sr} Hz to {model_sr} Hz.")
            wav = torchaudio.functional.resample(wav, sr, model_sr)
            sr = model_sr

        wav = wav.to(device)

        print("HT-Demucs: Applying the separation model...")
        with torch.no_grad():
            # wav[None] adds the batch dimension; [0] removes it again.
            sources = apply_model(htdemucs_model, wav[None], device=device, progress=True)[0]
        print("HT-Demucs: Separation complete.")

        # Millisecond timestamp keeps each run's files apart.
        timestamp = int(time.time() * 1000)
        output_dir = f"htdemucs_stems_{timestamp}"
        os.makedirs(output_dir, exist_ok=True)

        # NOTE(review): assumes the model emits stems in this order — confirm
        # against htdemucs_model.sources if the model name ever changes.
        stem_names = ["drums", "bass", "other", "vocals"]

        output_paths = []
        for i, name in enumerate(stem_names):
            out_path = os.path.join(output_dir, f"{name}_{timestamp}.wav")
            torchaudio.save(out_path, sources[i].cpu(), sr)
            output_paths.append(out_path)
            print(f"✅ HT-Demucs saved {name} to {out_path}")

        return output_paths[0], output_paths[1], output_paths[2], output_paths[3], "✅ HT-Demucs separation successful!"

    except Exception as e:
        print(f"HT-Demucs Error: {e}")
        import traceback
        traceback.print_exc()
        return None, None, None, None, f"❌ HT-Demucs Error: {e}"
|
|
|
|
|
def separate_with_spleeter(audio_path):
    """Separate an audio file into vocals, drums, bass, other and piano with Spleeter.

    Uses the module-level ``spleeter_separator`` / ``spleeter_audio_adapter``
    through Spleeter's Python API. Stems are written as WAV files into a fresh
    ``spleeter_stems_<timestamp>`` directory under the working directory.

    Args:
        audio_path: Path of the audio file to separate, or ``None``.

    Returns:
        tuple: ``(vocals_path, drums_path, bass_path, other_path, piano_path,
        status)``. A path is ``None`` when that stem is missing from the
        prediction; all five are ``None`` when no file was given, Spleeter is
        unavailable, or separation failed.
    """
    if audio_path is None:
        return None, None, None, None, None, "Please upload an audio file."

    if not spleeter_available or spleeter_separator is None or spleeter_audio_adapter is None:
        return None, None, None, None, None, "❌ Spleeter not available. Please install Spleeter."

    try:
        print(f"Spleeter: Processing audio from: {audio_path}")

        print("Spleeter: Loading audio...")
        waveform, sample_rate = spleeter_audio_adapter.load(audio_path, sample_rate=44100)
        print(f"Spleeter: Loaded audio - shape: {waveform.shape}, sr: {sample_rate}")

        print("Spleeter: Separating audio sources...")
        prediction = spleeter_separator.separate(waveform)
        print("Spleeter: Separation complete.")
        print(f"Spleeter: Prediction keys: {list(prediction.keys())}")

        # Create the output directory only once there is something to write,
        # so a failed load/separation leaves no empty directory behind.
        timestamp = int(time.time() * 1000)
        output_dir = f"spleeter_stems_{timestamp}"
        os.makedirs(output_dir, exist_ok=True)

        # Order matches the order of the returned tuple.
        stem_names = ["vocals", "drums", "bass", "other", "piano"]

        output_paths = []
        for stem_name in stem_names:
            if stem_name not in prediction:
                print(f"⚠️ Warning: {stem_name} not found in prediction")
                output_paths.append(None)
                continue

            out_path = os.path.join(output_dir, f"{stem_name}_{timestamp}.wav")
            stem_audio = prediction[stem_name]
            print(f"Spleeter: {stem_name} audio shape: {stem_audio.shape}, dtype: {stem_audio.dtype}")

            sf.write(out_path, stem_audio, sample_rate)
            output_paths.append(out_path)
            print(f"✅ Spleeter saved {stem_name} to {out_path}")

        # output_paths has exactly one entry per stem name at this point.
        return output_paths[0], output_paths[1], output_paths[2], output_paths[3], output_paths[4], "✅ Spleeter separation successful!"

    except Exception as e:
        print(f"Spleeter Error: {e}")
        import traceback
        traceback.print_exc()
        return None, None, None, None, None, f"❌ Spleeter Error: {e}"
|
|
|
|
|
def separate_selected_models(audio_path, run_htdemucs, run_spleeter):
    """Run the selected separation models and collect their stems for the UI.

    Args:
        audio_path: Path of the uploaded audio file, or ``None``.
        run_htdemucs: Whether to run HT-Demucs.
        run_spleeter: Whether to run Spleeter.

    Returns:
        list: Always ten items, matching the ten outputs of the click handler:
        four HT-Demucs stem paths (drums, bass, other, vocals), five Spleeter
        stem paths (vocals, drums, bass, other, piano) and a status string.
        Stem entries are ``None`` for models that were not run or that failed,
        and for every validation or error case.
    """
    # 4 HT-Demucs stems + 5 Spleeter stems; the status text is appended as the
    # tenth item so every return path has the same flat shape.
    empty_stems = [None] * 9

    if audio_path is None:
        return empty_stems + ["Please upload an audio file."]

    if not run_htdemucs and not run_spleeter:
        return empty_stems + ["❌ Please select at least one model to run."]

    try:
        # Placeholders shaped like each model's return value (stems + status).
        htdemucs_results = [None] * 5
        spleeter_results = [None] * 6
        status_messages = []
        models_used = []

        if run_htdemucs:
            print("Running HT-Demucs...")
            htdemucs_results = separate_with_htdemucs(audio_path)
            status_messages.append(htdemucs_results[-1])
            models_used.append("HT-Demucs")

        if run_spleeter:
            print("Running Spleeter...")
            spleeter_results = separate_with_spleeter(audio_path)
            status_messages.append(spleeter_results[-1])
            models_used.append("Spleeter")

        # Drop each model's trailing status message, keeping only stem paths.
        all_results = list(htdemucs_results[:-1]) + list(spleeter_results[:-1])

        combined_status = f"🎵 {' + '.join(models_used)} completed!\n\n" + "\n".join(status_messages)

        return all_results + [combined_status]

    except Exception as e:
        print(f"Combined Error: {e}")
        import traceback
        traceback.print_exc()
        return empty_stems + [f"❌ Error: {e}"]
|
|
|
|
|
# Gradio UI: upload + model toggles on top, per-model stem players below,
# then a static comparison table and footer.
#
# NOTE(review): emoji in the labels below were restored from mis-encoded text.
# Most map back unambiguously; the ones on "Separate Music" / "Powered by" (🚀),
# "Status" / "Model Comparison" (📊), "Excellent" (🌟) and "Good" (👍) were
# unrecoverable and chosen from context — confirm against the original file.
# The widget nesting was also reconstructed from a copy without indentation.
print("Creating Gradio interface...")
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # 🎵 Spleeter & Demucs - Now Both Work!

    **Follow me on:** [ Hugging Face @ahk-d](https://huggingface.co/ahk-d) | [ GitHub @ahk-d](https://github.com/ahk-d)
    """)

    with gr.Row():
        with gr.Column(scale=1):
            audio_input = gr.Audio(type="filepath", label="🎵 Upload Your Song")

            gr.Markdown("### 🎛️ Select Models to Run")
            with gr.Row():
                htdemucs_toggle = gr.Checkbox(label="🎯 HT-Demucs", value=True, info="Drums, Bass, Other, Vocals")
                # The Spleeter toggle is pre-checked and clickable only when
                # the model was loaded at start-up.
                spleeter_label = "🎵 Spleeter 2025 (5stems)" if spleeter_available else "🎵 Spleeter 2025"
                spleeter_info = "Vocals, Drums, Bass, Other, Piano" if spleeter_available else "5stems model not available"
                spleeter_toggle = gr.Checkbox(
                    label=spleeter_label,
                    value=spleeter_available,
                    info=spleeter_info,
                    interactive=spleeter_available
                )

            separate_button = gr.Button("🚀 Separate Music", variant="primary", size="lg")
            status_output = gr.Textbox(label="📊 Status", interactive=False, lines=4)

    gr.Markdown("---")

    with gr.Row():
        # Left column: the four HT-Demucs stems.
        with gr.Column():
            gr.Markdown("### 🎯 HT-Demucs Results")
            with gr.Row():
                htdemucs_drums = gr.Audio(label="🥁 Drums", type="filepath")
                htdemucs_bass = gr.Audio(label="🎸 Bass", type="filepath")
            with gr.Row():
                htdemucs_other = gr.Audio(label="🎼 Other", type="filepath")
                htdemucs_vocals = gr.Audio(label="🎤 Vocals", type="filepath")

        # Right column: the five Spleeter stems.
        with gr.Column():
            gr.Markdown("### 🎵 Spleeter 2025 Results")
            with gr.Row():
                spleeter_vocals = gr.Audio(label="🎤 Vocals", type="filepath")
                spleeter_drums = gr.Audio(label="🥁 Drums", type="filepath")
            with gr.Row():
                spleeter_bass = gr.Audio(label="🎸 Bass", type="filepath")
                spleeter_other = gr.Audio(label="🎼 Other", type="filepath")
            with gr.Row():
                spleeter_piano = gr.Audio(label="🎹 Piano", type="filepath")

            if spleeter_available:
                gr.Markdown("*5stems model: Vocals, Drums, Bass, Other, Piano*")
            else:
                gr.Markdown("*Note: Spleeter 5stems model not available*")

    gr.Markdown("---")

    with gr.Row():
        # The Spleeter column of the table reflects start-up availability.
        comparison_text = f"""
        ### 📊 Model Comparison

        | Feature | HT-Demucs | Spleeter 2025 (5stems) |
        |---------|-----------|----------|
        | **Vocals** | ✅ High Quality | {'✅ Available' if spleeter_available else '❌ N/A'} |
        | **Drums** | ✅ High Quality | {'✅ Available' if spleeter_available else '❌ N/A'} |
        | **Bass** | ✅ High Quality | {'✅ Available' if spleeter_available else '❌ N/A'} |
        | **Other** | ✅ High Quality | {'✅ Available' if spleeter_available else '❌ N/A'} |
        | **Piano** | ❌ Not Available | {'✅ **Available**' if spleeter_available else '❌ N/A'} |
        | **Speed** | ⚡ Fast | {'⚡ Fast' if spleeter_available else '❌ N/A'} |
        | **Quality** | 🌟 Excellent | {'👍 Good' if spleeter_available else '❌ N/A'} |

        **💡 Tip:** Use Spleeter 2025 for piano separation, HT-Demucs for other instruments!
        """
        gr.Markdown(comparison_text)

    # Output order must match the list returned by separate_selected_models:
    # 4 HT-Demucs stems, 5 Spleeter stems, then the status text.
    separate_button.click(
        fn=separate_selected_models,
        inputs=[audio_input, htdemucs_toggle, spleeter_toggle],
        outputs=[
            htdemucs_drums, htdemucs_bass, htdemucs_other, htdemucs_vocals,
            spleeter_vocals, spleeter_drums, spleeter_bass, spleeter_other, spleeter_piano,
            status_output
        ]
    )

    gr.Markdown("""
    ---
    <p style='text-align: center; font-size: small;'>
    🚀 Powered by <strong>HT-Demucs</strong> & <strong>Spleeter 2025</strong> |
    🎵 Compare and choose your best stems!
    </p>
    """)
|
|
|
# Start the web app only when run as a script; share=True asks Gradio for a
# public share link in addition to the local server.
if __name__ == "__main__":
    demo.launch(share=True)