File size: 2,150 Bytes
88b57c0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
# Standard Library
import os

# Third-Party
import streamlit as st

# Local
from src.models.MDX_net.kimvocal import KimVocal
from src.loader import Loader
from src.models.MDX_net.mdx_net import Conv_TDF_net_trimm

# Constants
from src.constants import ONNX_MODEL_PATH

# Relative dataset directories handed to ``Loader`` when an upload is processed.
INPUT_FOLDER: str = "./datasets/input"
OUTPUT_FOLDER: str = "./datasets/output"


def _build_vocal_model():
    """Construct the MDX-Net vocal-separation model used by the app.

    The hyper-parameters are passed verbatim to ``Conv_TDF_net_trimm``.
    NOTE(review): they presumably have to match the architecture the ONNX
    checkpoint at ``ONNX_MODEL_PATH`` was exported with — confirm before
    changing any of them.

    Returns:
        A ``Conv_TDF_net_trimm`` configured with ``use_onnx=True`` and
        ``target_name="vocals"``.
    """
    return Conv_TDF_net_trimm(
        model_path=ONNX_MODEL_PATH,
        use_onnx=True,
        target_name="vocals",
        L=11,
        l=3,
        g=48,
        bn=8,
        bias=False,
        dim_f=11,
        dim_t=8,
    )


def main():
    """Render the vocal-remover Streamlit page.

    Lets the user upload an audio file (WAV, MP3, OGG or FLAC), separates
    the vocals with the ``KimVocal`` MDX-Net pipeline, and plays the
    extracted vocals back on the page. Nothing is processed until a file
    has been uploaded.
    """
    # Set page configuration and theming (kept as the first Streamlit call).
    st.set_page_config(
        page_title="Jarod's Vocal Remover",
        page_icon="🎵",
    )
    st.title("Vocal Remover")

    # Upload an audio file in any of the formats accepted via ``type``.
    uploaded_file = st.file_uploader(
        "Upload an Audio File (WAV, MP3, OGG, FLAC)",
        type=["wav", "mp3", "ogg", "flac"],
        key="file_uploader",
    )

    # Guard clause: there is nothing to process until a file is supplied.
    if uploaded_file is None:
        return

    st.subheader("Audio Processing")
    st.write("Processing the uploaded audio file...")

    # Progress widgets; the bar is also handed to ``demix_vocals`` below,
    # presumably so it can report incremental progress while separating.
    progress_bar = st.progress(0)
    progress_text = st.empty()

    # Decode the upload into a tensor plus its sample rate.
    loader = Loader(INPUT_FOLDER, OUTPUT_FOLDER)
    music_tensor, samplerate = loader.prepare_uploaded_file(
        uploaded_file=uploaded_file
    )

    kimvocal = KimVocal()
    vocals_tensor = kimvocal.demix_vocals(
        music_tensor=music_tensor,
        sample_rate=samplerate,
        model=_build_vocal_model(),
        streamlit_progressbar=progress_bar,
    )
    # NOTE(review): ``.numpy()`` suggests a torch tensor; if it can live on
    # a GPU or require grad this needs ``.detach().cpu()`` first — confirm.
    vocals_array = vocals_tensor.numpy()

    # Mark processing as finished.
    progress_bar.progress(100)
    progress_text.text("Audio processing complete!")

    # Play back the separated vocals at the input's sample rate.
    st.subheader("Processed Audio")
    st.audio(vocals_array, format="audio/wav", sample_rate=samplerate)

# Script entry point: build the page when this file is run as the main module.
if __name__ == "__main__":
    main()