import streamlit as st
import numpy as np
import librosa
import librosa.display  # explicit import needed for specshow on older librosa versions
import soundfile as sf
import os
import tempfile
import torch
import base64
import io
import matplotlib.pyplot as plt
# Page configuration
st.set_page_config(
    page_title="Music Stem Splitter",
    page_icon="🎵",
    layout="wide",
    initial_sidebar_state="expanded"
)
# Set maximum audio duration (in seconds) and file size (in MB)
MAX_AUDIO_DURATION = 300  # 5 minutes
MAX_FILE_SIZE_MB = 100
# Load pretrained separator model, cached so it is only loaded once per process
@st.cache_resource
def load_separator_model():
    try:
        # Import here to avoid loading until needed
        from demucs.pretrained import get_model
        model = get_model('htdemucs')
        model.eval()
        if torch.cuda.is_available():
            model.cuda()
        return model
    except ImportError:
        st.error("Required package 'demucs' not found. Please install it with 'pip install demucs'.")
        return None
# Function to check audio length
def check_audio_length(audio_path):
    try:
        duration = librosa.get_duration(path=audio_path)
        return duration
    except Exception as e:
        st.error(f"Could not determine audio length: {str(e)}")
        return MAX_AUDIO_DURATION + 1  # Return a value that will fail the check
# Function to separate stems from an audio file
def separate_stems(audio_path, model, sample_rate=44100):
    from demucs.apply import apply_model
    import torchaudio

    # Load audio with potential resampling to save memory
    waveform, original_sample_rate = torchaudio.load(audio_path)

    # Resample if needed to optimize memory usage
    if original_sample_rate > sample_rate:
        resampler = torchaudio.transforms.Resample(orig_freq=original_sample_rate, new_freq=sample_rate)
        waveform = resampler(waveform)
        st.info(f"Audio resampled from {original_sample_rate}Hz to {sample_rate}Hz to optimize performance.")
    else:
        sample_rate = original_sample_rate
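    # Caveat: htdemucs is trained on 44.1 kHz audio, so the lower processing
    # rates used for longer files trade some separation quality for memory.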
    # Create a progress bar
    progress_bar = st.progress(0)
    status_text = st.empty()

    # Prepare the model input
    if torch.cuda.is_available():
        waveform = waveform.cuda()

    # For Demucs, we need the audio as (batch, channels, time)
    if waveform.dim() == 2:  # (channels, time)
        waveform = waveform.unsqueeze(0)

    stems = {}
    # Process and separate stems
    status_text.text("Separating stems... This may take a while depending on the audio length.")

    # Optimize memory usage by processing in chunks if needed
    with torch.no_grad():
        # Use smaller chunks for CPU, larger for GPU
        chunk_size = 10 * sample_rate if torch.cuda.is_available() else 5 * sample_rate
        if waveform.shape[-1] > chunk_size and waveform.shape[-1] > 30 * sample_rate:
            # Process in chunks for very long audio
            st.info("Processing long audio in chunks to optimize memory usage...")
            sources = None

            # Calculate number of chunks
            num_chunks = int(np.ceil(waveform.shape[-1] / chunk_size))
            for i in range(num_chunks):
                # Update progress
                progress = i / num_chunks * 0.7  # 70% of progress for separation
                progress_bar.progress(progress)
                status_text.text(f"Processing chunk {i+1}/{num_chunks}...")

                # Extract chunk
                start = i * chunk_size
                end = min(start + chunk_size, waveform.shape[-1])
                chunk = waveform[:, :, start:end]

                # Process chunk
                chunk_sources = apply_model(model, chunk, device="cuda" if torch.cuda.is_available() else "cpu")

                # Concatenate along the time dimension
                if sources is None:
                    sources = chunk_sources
                else:
                    sources = torch.cat([sources, chunk_sources], dim=-1)

                # Clear GPU memory if needed
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
        else:
            # Process entire audio at once for shorter clips
            sources = apply_model(model, waveform, device="cuda" if torch.cuda.is_available() else "cpu")
    # sources is (batch, source, channels, time)
    sources = sources[0]  # Remove batch dimension

    # Save each source (order follows htdemucs' output)
    source_names = ["drums", "bass", "other", "vocals"]
    for i, source_name in enumerate(source_names):
        stems[source_name] = sources[i].cpu().numpy()

        # Update progress
        progress = 0.7 + (i + 1) / len(source_names) * 0.2  # 20% of progress for stem saving
        progress_bar.progress(progress)
        status_text.text(f"Processed {source_name} stem ({i+1}/{len(source_names)})")
    # Create visualizations (at reduced resolution to save memory)
    visualizations = {}
    for stem_name, audio_data in stems.items():
        # Create spectrogram visualization
        plt.figure(figsize=(10, 4))

        # Use a smaller portion of audio for visualization if it's too long
        max_samples = min(sample_rate * 30, audio_data.shape[1])  # 30 seconds max
        visualization_data = audio_data[0, :max_samples]

        # Create spectrogram with reduced resolution
        D = librosa.amplitude_to_db(np.abs(librosa.stft(visualization_data, n_fft=1024, hop_length=512)), ref=np.max)
        librosa.display.specshow(D, y_axis='log', x_axis='time', sr=sample_rate)
        plt.title(f'{stem_name.capitalize()} Spectrogram')
        plt.colorbar(format='%+2.0f dB')
        plt.tight_layout()

        # Save figure to bytes
        buf = io.BytesIO()
        plt.savefig(buf, format='png', dpi=100)  # Lower DPI to save memory
        buf.seek(0)
        visualizations[stem_name] = buf
        plt.close()

    # Clear GPU memory
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # Update progress to complete
    progress_bar.progress(1.0)
    status_text.text("Stem separation complete!")

    return stems, sample_rate, visualizations
# Function to create a download link for audio files
def get_binary_file_downloader_html(bin_data, file_label, file_extension):
    b64data = base64.b64encode(bin_data).decode()
    href = f'<a href="data:audio/{file_extension};base64,{b64data}" download="{file_label}.{file_extension}">Download {file_label}</a>'
    return href
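# Note: a base64 data URI embeds the whole file in the page, inflating it by
# roughly a third of the file size, so large stems can make the page sluggish.
# Streamlit's built-in st.download_button is a lighter alternative, e.g.:
#   st.download_button(f"Download {file_label}", bin_data,
#                      file_name=f"{file_label}.{file_extension}",
#                      mime=f"audio/{file_extension}")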
# Title and description
st.title("🎵 Music Stem Splitter")
st.markdown("""
This application separates music tracks into individual stems:
- **Vocals**: Lead and background vocals
- **Drums**: Drum kit and percussion
- **Bass**: Bass guitar, synth bass, etc.
- **Other**: All other instruments and sounds

Upload an audio file (MP3, WAV, FLAC, or OGG) to get started.
""")

# Add warning about HF Spaces limitations
st.warning(f"""
⚠️ **Hugging Face Spaces Limitations**:
- Maximum file size: {MAX_FILE_SIZE_MB}MB
- Maximum audio duration: {MAX_AUDIO_DURATION} seconds ({MAX_AUDIO_DURATION//60} minutes)
- Processing may take several minutes depending on server load
""")
# Initialize session state for storing results
if 'stems' not in st.session_state:
    st.session_state.stems = None
if 'sample_rate' not in st.session_state:
    st.session_state.sample_rate = None
if 'visualizations' not in st.session_state:
    st.session_state.visualizations = None
# File uploader
st.subheader("Upload Audio File")
uploaded_file = st.file_uploader("Choose an audio file", type=["mp3", "wav", "flac", "ogg"])

# Model loading (only when needed)
model_load_state = st.empty()
# Process the uploaded file
if uploaded_file is not None:
    # Check file size
    file_size_mb = uploaded_file.size / 1e6
    if file_size_mb > MAX_FILE_SIZE_MB:
        st.error(f"File too large: {file_size_mb:.1f}MB. Maximum allowed size is {MAX_FILE_SIZE_MB}MB.")
    else:
        # Display file info
        file_details = {"Filename": uploaded_file.name, "FileSize": f"{file_size_mb:.2f} MB"}
        st.write(file_details)

        # Create a temporary file
        with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(uploaded_file.name)[1]) as tmp_file:
            tmp_file.write(uploaded_file.getvalue())
            tmp_path = tmp_file.name

        # Check audio duration
        audio_duration = check_audio_length(tmp_path)
        if audio_duration > MAX_AUDIO_DURATION:
            st.error(f"Audio duration too long: {audio_duration:.1f} seconds. Maximum allowed duration is {MAX_AUDIO_DURATION} seconds ({MAX_AUDIO_DURATION//60} minutes).")
            # Clean up temporary file
            os.unlink(tmp_path)
        else:
            st.info(f"Audio duration: {audio_duration:.1f} seconds")

            # Load model (cached after the first run)
            with model_load_state:
                st.info("Loading AI model... This may take a moment the first time.")
                model = load_separator_model()

            if model is not None:
                # Process button
                if st.button("Split into Stems"):
                    try:
                        # Select processing sample rate based on file duration.
                        # Shorter files can use higher quality; longer files use
                        # lower rates to save memory.
                        if audio_duration < 60:  # Less than 1 minute
                            processing_sample_rate = 44100
                        elif audio_duration < 180:  # 1-3 minutes
                            processing_sample_rate = 32000
                        else:  # 3-5 minutes
                            processing_sample_rate = 22050

                        # Perform stem separation
                        st.session_state.stems, st.session_state.sample_rate, st.session_state.visualizations = separate_stems(
                            tmp_path, model, sample_rate=processing_sample_rate
                        )
                        st.success("Stem separation completed! Scroll down to preview and download individual stems.")
                    except Exception as e:
                        st.error(f"An error occurred during processing: {str(e)}")
                        st.info("Try with a shorter audio clip or a different file format.")
            else:
                st.warning("Required packages not available. To run locally, install with 'pip install demucs librosa soundfile'")

            # Clean up temporary file
            os.unlink(tmp_path)
# Display results if available
if st.session_state.stems is not None:
    st.header("Separated Stems")

    # Create tabs for each stem
    stem_tabs = st.tabs(["Vocals", "Drums", "Bass", "Other"])

    # Get stem names in display order
    stem_names = ["vocals", "drums", "bass", "other"]

    # Process each stem
    for stem_tab, stem_name in zip(stem_tabs, stem_names):
        with stem_tab:
            # Create columns for audio player and visualization
            col1, col2 = st.columns([1, 1])

            with col1:
                st.subheader(f"{stem_name.capitalize()} Stem")

                # Convert numpy array to audio file for playback
                audio_data = st.session_state.stems[stem_name]

                # Create a temporary buffer for the audio data
                buf = io.BytesIO()
                sf.write(buf, audio_data.T, st.session_state.sample_rate, format='WAV')
                buf.seek(0)

                # Display audio player
                st.audio(buf, format='audio/wav')

                # Download button
                st.markdown(get_binary_file_downloader_html(buf.getvalue(), f"{stem_name}", "wav"), unsafe_allow_html=True)

                # Additional information
                if stem_name == "vocals":
                    st.info("Contains lead vocals and backing vocals.")
                elif stem_name == "drums":
                    st.info("Contains drums and percussion elements.")
                elif stem_name == "bass":
                    st.info("Contains bass guitar and low-frequency elements.")
                else:  # other
                    st.info("Contains all other instruments (guitars, keys, synths, etc).")

            with col2:
                # Display visualization
                if st.session_state.visualizations and stem_name in st.session_state.visualizations:
                    st.image(st.session_state.visualizations[stem_name], caption=f"{stem_name.capitalize()} Spectrogram")
# Show instructions for downloading all stems
st.header("Usage Tips")
st.markdown("""
### What can you do with these stems?
- Create remixes or mashups
- Practice playing along with isolated instrument tracks
- Create karaoke versions by removing vocals
- Analyze individual instrument parts for educational purposes

### Next steps:
1. Download each stem you want to use
2. Import them into your DAW (Digital Audio Workstation)
3. Mix, process, and create!
""")
# Add instructions for local deployment
st.sidebar.header("About This App")
st.sidebar.markdown("""
This application uses the Demucs model to separate audio tracks into individual stems. The model was developed by Facebook AI Research.

### How it works
The separation process uses a deep neural network to identify and isolate:
- Vocals
- Drums
- Bass
- Other instruments

### Source code
[GitHub Repository](https://github.com/huggingface/music-stem-splitter)
(Link to your repo once created)
""")
# Add a note about processing time
st.sidebar.markdown("""
### Processing Time
The processing time depends on:
- Length of the audio file
- Available computational resources
- File quality

For best results, use high-quality audio files without excessive background noise.
""")
# Show model information
st.sidebar.markdown("""
### Model Information
This app uses the HTDemucs model, which is trained to separate music into four stems.

Audio processing is optimized based on file length:
- Short files (< 1 min): 44.1kHz processing
- Medium files (1-3 min): 32kHz processing
- Longer files (3-5 min): 22.05kHz processing
""")