rendchevi commited on
Commit
b546670
·
1 Parent(s): 7a08c0b

initial commit

Browse files
.gitignore ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Python cache
2
+ __pycache__/
3
+ *.pyc
4
+ .ipynb_checkpoints
5
+
6
+ notebooks/
7
+
8
+ secrets.toml
9
+
10
+ cache_sound/*.wav
11
+ assets/*/*.onnx
.streamlit/config.toml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ [theme]
2
+
3
+ base="dark"
4
+ primaryColor="#4551b2"
5
+ backgroundColor="#292a57"
6
+ textColor="#faf9fc"
README.md DELETED
@@ -1,13 +0,0 @@
1
- ---
2
- title: Nix Tts
3
- emoji: 📚
4
- colorFrom: purple
5
- colorTo: red
6
- sdk: streamlit
7
- sdk_version: 1.2.0
8
- app_file: app.py
9
- pinned: false
10
- license: mit
11
- ---
12
-
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces#reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
app.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Utils
2
+ import os
3
+ import soundfile as sf
4
+
5
+ # Streamlit
6
+ import streamlit as st
7
+
8
+ # Custom elements
9
+ from elements.component import (
10
+ centered_text,
11
+ )
12
+ from elements.session_states import (
13
+ init_session_state,
14
+ update_session_state,
15
+ update_model,
16
+ )
17
+ from elements.tts import (
18
+ generate_voice,
19
+ )
20
+
21
+ st.set_page_config(
22
+ page_title = "Nix-TTS Interactive Demo",
23
+ page_icon = "🐤",
24
+ )
25
+
26
+ # Initiate stuffs
27
+ init_session_state()
28
+
29
+ # ---------------------------------------------------------------------------------
30
+
31
+ # Description
32
+ centered_text("🐤 Nix-TTS Interactive Demo")
33
+ centered_text("An incredibly lightweight end-to-end text-to-speech model via knowledge distillation", "h5")
34
+ st.write(" ")
35
+ st.caption("🗒️ This is a demo from our latest paper, **Nix-TTS**. <br>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;You can access the paper and the released models [here](https://github.com/rendchevi/nix-tts). <br>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;Authors: Rendi Chevi, Radityo Eko Prasojo, Alham Fikri Aji.<br>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;**Corresponding Author**: Rendi Chevi | rendi.chevi@{kata.ai, gmail.com}.", True)
36
+
37
+ # Model demo
38
+ st.write(" ")
39
+ st.write(" ")
40
+ col1, col2 = st.columns(2)
41
+ with col1:
42
+ input_text = st.text_input(
43
+ "Input Text",
44
+ value = "Born to multiply, born to gaze into night skies.",
45
+ )
46
+ with col2:
47
+ model_variant = st.selectbox("Choose Model Variant", options = ["Deterministic", "Stochastic"], index = 1)
48
+ if model_variant != st.session_state.model_variant:
49
+ # Update variant choice
50
+ update_session_state("model_variant", model_variant)
51
+ # Re-load model
52
+ update_model()
53
+
54
+ button_gen = st.button("Generate Voice")
55
+ if button_gen == True:
56
+ generate_voice(input_text)
assets/nix-ljspeech-sdp-v0.1/foo.txt ADDED
File without changes
assets/nix-ljspeech-sdp-v0.1/tokenizer_state.pkl ADDED
Binary file (2.7 kB). View file
 
assets/nix-ljspeech-v0.1/foo.txt ADDED
File without changes
assets/nix-ljspeech-v0.1/tokenizer_state.pkl ADDED
Binary file (2.7 kB). View file
 
cache_sound/foo.txt ADDED
File without changes
elements/component.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Streamlit
2
+ import streamlit as st
3
+
4
+ def centered_text(
5
+ input_text,
6
+ mode = "h1",
7
+ ):
8
+ st.markdown(
9
+ f"<{mode} style='text-align: center;'>{input_text}</{mode}>",
10
+ unsafe_allow_html = True
11
+ )
elements/session_states.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Utils
2
+ import uuid
3
+
4
+ # Streamlit
5
+ import streamlit as st
6
+
7
+ # Nix
8
+ from nix.models.TTS import NixTTSInference
9
+
10
+ # --------------------- SESSION STATE MANAGEMENT -------------------------
11
+
12
+ def init_session_state():
13
+ # Model
14
+ if "init_model" not in st.session_state:
15
+ st.session_state.init_model = True
16
+ st.session_state.random_str = uuid.uuid1().hex
17
+ st.session_state.model_variant = "Stochastic"
18
+ st.session_state.TTS = NixTTSInference("assets/nix-ljspeech-sdp-v0.1")
19
+
20
+ def update_model():
21
+ if st.session_state.model_variant == "Deterministic":
22
+ st.session_state.TTS = NixTTSInference("assets/nix-ljspeech-v0.1")
23
+ elif st.session_state.model_variant == "Stochastic":
24
+ st.session_state.TTS = NixTTSInference("assets/nix-ljspeech-sdp-v0.1")
25
+
26
+ def update_session_state(
27
+ state_id,
28
+ state_value,
29
+ ):
30
+ st.session_state[f"{state_id}"] = state_value
elements/tts.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Utils
2
+ import timeit
3
+ import soundfile as sf
4
+
5
+ # Streamlit
6
+ import streamlit as st
7
+
8
+ # Custom elements
9
+ from elements.component import (
10
+ centered_text,
11
+ )
12
+
13
+ def generate_voice(
14
+ input_text,
15
+ ):
16
+ # TTS Inference
17
+ start_time = timeit.default_timer()
18
+ c, c_length, phoneme = st.session_state.TTS.tokenize(input_text)
19
+ tok_time = timeit.default_timer() - start_time
20
+
21
+ start_time = timeit.default_timer()
22
+ voice = st.session_state.TTS.vocalize(c, c_length)
23
+ tts_time = timeit.default_timer() - start_time
24
+
25
+ # Time stats
26
+ total_infer_time = tts_time + tok_time
27
+ audio_time = voice.shape[-1] / 22050
28
+ rtf = total_infer_time / audio_time
29
+ rt_ratio = 1 / rtf
30
+
31
+ # Save audio (bug in Streamlit, can't play numpy array directly)
32
+ sf.write(f"cache_sound/{st.session_state.random_str}.wav", voice[0,0], 22050)
33
+
34
+ # Play audio
35
+ st.audio(f"cache_sound/{st.session_state.random_str}.wav", format = "audio/wav")
36
+ st.caption("Generated Voice")
37
+
38
+ st.code(
39
+ f"💬 Output Audio: {str(audio_time)[:6]} sec.\n\n⏳ Elapsed time for:\n => Tokenization: {str(tok_time)[:6]} sec.\n => Model Inference: {str(tts_time)[:6]} sec.\n\n⏰ Real-time Factor (RTF): {str(rtf)[:6]}\n\n🏃 The model runs {str(rt_ratio)[:6]} x faster than real-time \
40
+ ",
41
+ language = "bash",
42
+ )
43
+ st.caption("Elapsed Time Stats")
nix/__init__.py ADDED
File without changes
nix/models/TTS.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pickle
3
+ import timeit
4
+
5
+ import numpy as np
6
+ import onnxruntime as ort
7
+
8
+ from nix.tokenizers.tokenizer_en import NixTokenizerEN
9
+
10
+ class NixTTSInference:
11
+
12
+ def __init__(
13
+ self,
14
+ model_dir,
15
+ ):
16
+ # Load tokenizer
17
+ self.tokenizer = NixTokenizerEN(pickle.load(open(os.path.join(model_dir, "tokenizer_state.pkl"), "rb")))
18
+ # Load TTS model
19
+ self.encoder = ort.InferenceSession(os.path.join(model_dir, "encoder.onnx"))
20
+ self.decoder = ort.InferenceSession(os.path.join(model_dir, "decoder.onnx"))
21
+
22
+ def tokenize(
23
+ self,
24
+ text,
25
+ ):
26
+ # Tokenize input text
27
+ c, c_lengths, phonemes = self.tokenizer([text])
28
+
29
+ return np.array(c, dtype = np.int64), np.array(c_lengths, dtype = np.int64), phonemes
30
+
31
+ def vocalize(
32
+ self,
33
+ c,
34
+ c_lengths,
35
+ ):
36
+ """
37
+ Single-batch TTS inference
38
+ """
39
+ # Infer latent samples from encoder
40
+ z = self.encoder.run(
41
+ None,
42
+ {
43
+ "c": c,
44
+ "c_lengths": c_lengths,
45
+ }
46
+ )[2]
47
+ # Decode raw audio with decoder
48
+ xw = self.decoder.run(
49
+ None,
50
+ {
51
+ "z": z,
52
+ }
53
+ )[0]
54
+
55
+ return xw
nix/tokenizers/tokenizer_en.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Regex
2
+ import re
3
+
4
+ # Phonemizer
5
+ from phonemizer.backend import EspeakBackend
6
+ phonemizer_backend = EspeakBackend(
7
+ language = 'en-us',
8
+ preserve_punctuation = True,
9
+ with_stress = True
10
+ )
11
+
12
+ class NixTokenizerEN:
13
+
14
+ def __init__(
15
+ self,
16
+ tokenizer_state,
17
+ ):
18
+ # Vocab and abbreviations dictionary
19
+ self.vocab_dict = tokenizer_state["vocab_dict"]
20
+ self.abbreviations_dict = tokenizer_state["abbreviations_dict"]
21
+
22
+ # Regex recipe
23
+ self.whitespace_regex = tokenizer_state["whitespace_regex"]
24
+ self.abbreviations_regex = tokenizer_state["abbreviations_regex"]
25
+
26
+ def __call__(
27
+ self,
28
+ texts,
29
+ ):
30
+ # 1. Phonemize input texts
31
+ phonemes = [ self._collapse_whitespace(
32
+ phonemizer_backend.phonemize(
33
+ self._expand_abbreviations(text.lower()),
34
+ strip = True,
35
+ )
36
+ ) for text in texts ]
37
+
38
+ # 2. Tokenize phonemes
39
+ tokens = [ self._intersperse([self.vocab_dict[p] for p in phoneme], 0) for phoneme in phonemes ]
40
+
41
+ # 3. Pad tokens
42
+ tokens, tokens_lengths = self._pad_tokens(tokens)
43
+
44
+ return tokens, tokens_lengths, phonemes
45
+
46
+ def _expand_abbreviations(
47
+ self,
48
+ text
49
+ ):
50
+ for regex, replacement in self.abbreviations_regex:
51
+ text = re.sub(regex, replacement, text)
52
+
53
+ return text
54
+
55
+ def _collapse_whitespace(
56
+ self,
57
+ text
58
+ ):
59
+ return re.sub(self.whitespace_regex, ' ', text)
60
+
61
+ def _intersperse(
62
+ self,
63
+ lst,
64
+ item,
65
+ ):
66
+ result = [item] * (len(lst) * 2 + 1)
67
+ result[1::2] = lst
68
+ return result
69
+
70
+ def _pad_tokens(
71
+ self,
72
+ tokens,
73
+ ):
74
+ tokens_lengths = [len(token) for token in tokens]
75
+ max_len = max(tokens_lengths)
76
+ tokens = [token + [0 for _ in range(max_len - len(token))] for token in tokens]
77
+ return tokens, tokens_lengths
packages.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ libsndfile1-dev
2
+ espeak
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ streamlit
2
+ numpy
3
+ onnxruntime
4
+ phonemizer
5
+ SoundFile