ollieollie committed on
Commit
7438313
·
1 Parent(s): fae012e
chatterbox/src/chatterbox/__pycache__/tts.cpython-311.pyc CHANGED
Binary files a/chatterbox/src/chatterbox/__pycache__/tts.cpython-311.pyc and b/chatterbox/src/chatterbox/__pycache__/tts.cpython-311.pyc differ
 
chatterbox/src/chatterbox/models/t3/inference/__pycache__/t3_hf_backend.cpython-311.pyc CHANGED
Binary files a/chatterbox/src/chatterbox/models/t3/inference/__pycache__/t3_hf_backend.cpython-311.pyc and b/chatterbox/src/chatterbox/models/t3/inference/__pycache__/t3_hf_backend.cpython-311.pyc differ
 
chatterbox/src/chatterbox/tts.py CHANGED
@@ -6,7 +6,6 @@ import torch
6
  import perth
7
  import torch.nn.functional as F
8
  from huggingface_hub import hf_hub_download
9
- from silero_vad import load_silero_vad, get_speech_timestamps
10
 
11
  from .models.t3 import T3
12
  from .models.s3tokenizer import S3_SR, drop_invalid_tokens
@@ -14,23 +13,11 @@ from .models.s3gen import S3GEN_SR, S3Gen
14
  from .models.tokenizers import EnTokenizer
15
  from .models.voice_encoder import VoiceEncoder
16
  from .models.t3.modules.cond_enc import T3Cond
17
- from .utils import trim_silence
18
 
19
 
20
  REPO_ID = "ResembleAI/chatterbox"
21
 
22
 
23
- def change_pace(speech_tokens: torch.Tensor, pace: float):
24
- """
25
- :param speech_tokens: Tensor of shape (L,)
26
- :param pace: float, pace (default: 1)
27
- """
28
- L = len(speech_tokens)
29
- speech_tokens = F.interpolate(speech_tokens.view(1, 1, -1).float(), size=int(L / pace), mode="nearest")
30
- speech_tokens = speech_tokens.view(-1).long()
31
- return speech_tokens
32
-
33
-
34
  def punc_norm(text: str) -> str:
35
  """
36
  Quick cleanup func for punctuation from LLMs or
@@ -134,7 +121,6 @@ class ChatterboxTTS:
134
  self.device = device
135
  self.conds = conds
136
  self.watermarker = perth.PerthImplicitWatermarker()
137
- self.vad_model = load_silero_vad()
138
 
139
  @classmethod
140
  def from_local(cls, ckpt_dir, device) -> 'ChatterboxTTS':
@@ -182,19 +168,6 @@ class ChatterboxTTS:
182
 
183
  ref_16k_wav = librosa.resample(s3gen_ref_wav, orig_sr=S3GEN_SR, target_sr=S3_SR)
184
 
185
- vad_wav = ref_16k_wav
186
- if S3_SR != 16000:
187
- vad_wav = librosa.resample(ref_16k_wav, orig_sr=S3_SR, target_sr=16000)
188
-
189
- speech_timestamps = get_speech_timestamps(
190
- vad_wav,
191
- self.vad_model,
192
- return_seconds=True,
193
- )
194
-
195
- # s3gen_ref_wav = trim_silence(s3gen_ref_wav, speech_timestamps, S3GEN_SR)
196
- # ref_16k_wav = trim_silence(ref_16k_wav, speech_timestamps, S3_SR)
197
-
198
  s3gen_ref_wav = s3gen_ref_wav[:self.DEC_COND_LEN]
199
  s3gen_ref_dict = self.s3gen.embed_ref(s3gen_ref_wav, S3GEN_SR, device=self.device)
200
 
@@ -220,8 +193,7 @@ class ChatterboxTTS:
220
  text,
221
  audio_prompt_path=None,
222
  exaggeration=0.5,
223
- cfg_weight=0,
224
- pace=1,
225
  temperature=0.8,
226
  ):
227
  if audio_prompt_path:
@@ -263,8 +235,6 @@ class ChatterboxTTS:
263
  speech_tokens = drop_invalid_tokens(speech_tokens)
264
  speech_tokens = speech_tokens.to(self.device)
265
 
266
- speech_tokens = change_pace(speech_tokens, pace=pace)
267
-
268
  wav, _ = self.s3gen.inference(
269
  speech_tokens=speech_tokens,
270
  ref_dict=self.conds.gen,
 
6
  import perth
7
  import torch.nn.functional as F
8
  from huggingface_hub import hf_hub_download
 
9
 
10
  from .models.t3 import T3
11
  from .models.s3tokenizer import S3_SR, drop_invalid_tokens
 
13
  from .models.tokenizers import EnTokenizer
14
  from .models.voice_encoder import VoiceEncoder
15
  from .models.t3.modules.cond_enc import T3Cond
 
16
 
17
 
18
  REPO_ID = "ResembleAI/chatterbox"
19
 
20
 
 
 
 
 
 
 
 
 
 
 
 
21
  def punc_norm(text: str) -> str:
22
  """
23
  Quick cleanup func for punctuation from LLMs or
 
121
  self.device = device
122
  self.conds = conds
123
  self.watermarker = perth.PerthImplicitWatermarker()
 
124
 
125
  @classmethod
126
  def from_local(cls, ckpt_dir, device) -> 'ChatterboxTTS':
 
168
 
169
  ref_16k_wav = librosa.resample(s3gen_ref_wav, orig_sr=S3GEN_SR, target_sr=S3_SR)
170
 
 
 
 
 
 
 
 
 
 
 
 
 
 
171
  s3gen_ref_wav = s3gen_ref_wav[:self.DEC_COND_LEN]
172
  s3gen_ref_dict = self.s3gen.embed_ref(s3gen_ref_wav, S3GEN_SR, device=self.device)
173
 
 
193
  text,
194
  audio_prompt_path=None,
195
  exaggeration=0.5,
196
+ cfg_weight=0.5,
 
197
  temperature=0.8,
198
  ):
199
  if audio_prompt_path:
 
235
  speech_tokens = drop_invalid_tokens(speech_tokens)
236
  speech_tokens = speech_tokens.to(self.device)
237
 
 
 
238
  wav, _ = self.s3gen.inference(
239
  speech_tokens=speech_tokens,
240
  ref_dict=self.conds.gen,
chatterbox/src/chatterbox/utils.py CHANGED
@@ -1,15 +1,15 @@
1
- import numpy as np
2
-
3
-
4
- def trim_silence(wav, speech_timestamps, sr):
5
- """TODO: fading"""
6
- if len(speech_timestamps) == 0:
7
- return wav # WARNING: no speech detected, returning original wav
8
- segs = []
9
- for segment in speech_timestamps:
10
- start_s, end_s = segment['start'], segment['end']
11
- start = int(start_s * sr)
12
- end = int(end_s * sr)
13
- seg = wav[start: end]
14
- segs.append(seg)
15
- return np.concatenate(segs)
 
1
+ # import numpy as np
2
+ #
3
+ #
4
+ # def trim_silence(wav, speech_timestamps, sr):
5
+ # """TODO: fading"""
6
+ # if len(speech_timestamps) == 0:
7
+ # return wav # WARNING: no speech detected, returning original wav
8
+ # segs = []
9
+ # for segment in speech_timestamps:
10
+ # start_s, end_s = segment['start'], segment['end']
11
+ # start = int(start_s * sr)
12
+ # end = int(end_s * sr)
13
+ # seg = wav[start: end]
14
+ # segs.append(seg)
15
+ # return np.concatenate(segs)