CosyVoice committed on
Commit 1d881df · 1 Parent(s): f1e374a

fix vocoder speech overlap

cosyvoice/cli/model.py CHANGED
@@ -31,18 +31,25 @@ class CosyVoiceModel:
         self.flow = flow
         self.hift = hift
         self.token_min_hop_len = 100
-        self.token_max_hop_len = 400
+        self.token_max_hop_len = 200
         self.token_overlap_len = 20
-        self.speech_overlap_len = 34 * 256
-        self.window = np.hamming(2 * self.speech_overlap_len)
+        # mel fade in out
+        self.mel_overlap_len = 34
+        self.mel_window = np.hamming(2 * self.mel_overlap_len)
+        # hift cache
+        self.mel_cache_len = 20
+        self.source_cache_len = int(self.mel_cache_len * 256)
+        # rtf and decoding related
         self.stream_scale_factor = 1
         assert self.stream_scale_factor >= 1, 'stream_scale_factor should be greater than 1, change it according to your actual rtf'
         self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
         self.flow_hift_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
         self.lock = threading.Lock()
         # dict used to store session related variable
-        self.tts_speech_token = {}
-        self.llm_end = {}
+        self.tts_speech_token_dict = {}
+        self.llm_end_dict = {}
+        self.mel_overlap_dict = {}
+        self.hift_cache_dict = {}
 
     def load(self, llm_model, flow_model, hift_model):
         self.llm.load_state_dict(torch.load(llm_model, map_location=self.device))
@@ -64,102 +71,108 @@ class CosyVoiceModel:
         self.flow.decoder.estimator = xxx
         self.flow.decoder.session = xxx
 
-    def llm_job(self, text, text_len, prompt_text, prompt_text_len, llm_prompt_speech_token, llm_prompt_speech_token_len, llm_embedding, this_uuid):
+    def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
         with self.llm_context:
             for i in self.llm.inference(text=text.to(self.device),
-                                        text_len=text_len.to(self.device),
+                                        text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device),
                                         prompt_text=prompt_text.to(self.device),
-                                        prompt_text_len=prompt_text_len.to(self.device),
+                                        prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
                                         prompt_speech_token=llm_prompt_speech_token.to(self.device),
-                                        prompt_speech_token_len=llm_prompt_speech_token_len.to(self.device),
+                                        prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
                                         embedding=llm_embedding.to(self.device).half(),
                                         sampling=25,
                                         max_token_text_ratio=30,
                                         min_token_text_ratio=3):
-                self.tts_speech_token[this_uuid].append(i)
-        self.llm_end[this_uuid] = True
+                self.tts_speech_token_dict[uuid].append(i)
+        self.llm_end_dict[uuid] = True
 
-    def token2wav(self, token, prompt_token, prompt_token_len, prompt_feat, prompt_feat_len, embedding):
+    def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False):
         with self.flow_hift_context:
             tts_mel = self.flow.inference(token=token.to(self.device),
-                                          token_len=torch.tensor([token.size(1)], dtype=torch.int32).to(self.device),
+                                          token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
                                           prompt_token=prompt_token.to(self.device),
-                                          prompt_token_len=prompt_token_len.to(self.device),
+                                          prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
                                           prompt_feat=prompt_feat.to(self.device),
-                                          prompt_feat_len=prompt_feat_len.to(self.device),
+                                          prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
                                           embedding=embedding.to(self.device))
-            tts_speech = self.hift.inference(mel=tts_mel).cpu()
+            # mel overlap fade in out
+            if self.mel_overlap_dict[uuid] is not None:
+                tts_mel = fade_in_out(tts_mel, self.mel_overlap_dict[uuid], self.mel_window)
+            # append hift cache
+            if self.hift_cache_dict[uuid] is not None:
+                hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
+                tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
+            else:
+                hift_cache_source = torch.zeros(1, 1, 0)
+            # keep overlap mel and hift cache
+            if finalize is False:
+                self.mel_overlap_dict[uuid] = tts_mel[:, :, -self.mel_overlap_len:]
+                tts_mel = tts_mel[:, :, :-self.mel_overlap_len]
+                tts_speech, tts_source = self.hift.inference(mel=tts_mel, cache_source=hift_cache_source)
+                self.hift_cache_dict[uuid] = {'source': tts_source[:, :, -self.source_cache_len:], 'mel': tts_mel[:, :, -self.mel_cache_len:]}
+                tts_speech = tts_speech[:, :-self.source_cache_len]
+            else:
+                tts_speech, tts_source = self.hift.inference(mel=tts_mel, cache_source=hift_cache_source)
         return tts_speech
 
-    def inference(self, text, text_len, flow_embedding, llm_embedding=torch.zeros(0, 192),
-                  prompt_text=torch.zeros(1, 0, dtype=torch.int32), prompt_text_len=torch.zeros(1, dtype=torch.int32),
-                  llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32), llm_prompt_speech_token_len=torch.zeros(1, dtype=torch.int32),
-                  flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32), flow_prompt_speech_token_len=torch.zeros(1, dtype=torch.int32),
-                  prompt_speech_feat=torch.zeros(1, 0, 80), prompt_speech_feat_len=torch.zeros(1, dtype=torch.int32), stream=False):
+    def inference(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192),
+                  prompt_text=torch.zeros(1, 0, dtype=torch.int32),
+                  llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
+                  flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
+                  prompt_speech_feat=torch.zeros(1, 0, 80), stream=False, **kwargs):
         # this_uuid is used to track variables related to this inference thread
         this_uuid = str(uuid.uuid1())
         with self.lock:
-            self.tts_speech_token[this_uuid], self.llm_end[this_uuid] = [], False
-        p = threading.Thread(target=self.llm_job, args=(text.to(self.device), text_len.to(self.device), prompt_text.to(self.device), prompt_text_len.to(self.device),
-                                                        llm_prompt_speech_token.to(self.device), llm_prompt_speech_token_len.to(self.device), llm_embedding.to(self.device), this_uuid))
+            self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid], self.mel_overlap_dict[this_uuid], self.hift_cache_dict[this_uuid] = [], False, None, None
+        p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
         p.start()
+        p.join()
         if stream is True:
-            cache_speech, cache_token, token_hop_len = None, None, self.token_min_hop_len
+            token_hop_len = self.token_min_hop_len
             while True:
                 time.sleep(0.1)
-                if len(self.tts_speech_token[this_uuid]) >= token_hop_len + self.token_overlap_len:
-                    this_tts_speech_token = torch.concat(self.tts_speech_token[this_uuid][:token_hop_len + self.token_overlap_len], dim=1)
+                if len(self.tts_speech_token_dict[this_uuid]) >= token_hop_len + self.token_overlap_len:
+                    this_tts_speech_token = torch.concat(self.tts_speech_token_dict[this_uuid][:token_hop_len + self.token_overlap_len], dim=1)
                     with self.flow_hift_context:
                         this_tts_speech = self.token2wav(token=this_tts_speech_token,
-                                                         prompt_token=flow_prompt_speech_token.to(self.device),
-                                                         prompt_token_len=flow_prompt_speech_token_len.to(self.device),
-                                                         prompt_feat=prompt_speech_feat.to(self.device),
-                                                         prompt_feat_len=prompt_speech_feat_len.to(self.device),
-                                                         embedding=flow_embedding.to(self.device))
-                    # fade in/out if necessary
-                    if cache_speech is not None:
-                        this_tts_speech = fade_in_out(this_tts_speech, cache_speech, self.window)
-                    yield {'tts_speech': this_tts_speech[:, :-self.speech_overlap_len]}
-                    cache_speech = this_tts_speech[:, -self.speech_overlap_len:]
-                    cache_token = self.tts_speech_token[this_uuid][:token_hop_len]
+                                                         prompt_token=flow_prompt_speech_token,
+                                                         prompt_feat=prompt_speech_feat,
+                                                         embedding=flow_embedding,
+                                                         uuid=this_uuid,
+                                                         finalize=False)
+                    yield {'tts_speech': this_tts_speech.cpu()}
                     with self.lock:
-                        self.tts_speech_token[this_uuid] = self.tts_speech_token[this_uuid][token_hop_len:]
+                        self.tts_speech_token_dict[this_uuid] = self.tts_speech_token_dict[this_uuid][token_hop_len:]
                     # increase token_hop_len for better speech quality
                     token_hop_len = min(self.token_max_hop_len, int(token_hop_len * self.stream_scale_factor))
-                if self.llm_end[this_uuid] is True and len(self.tts_speech_token[this_uuid]) < token_hop_len + self.token_overlap_len:
+                if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) < token_hop_len + self.token_overlap_len:
                     break
-            p.join()
+            # p.join()
             # deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None
-            this_tts_speech_token = torch.concat(self.tts_speech_token[this_uuid], dim=1)
-            if this_tts_speech_token.shape[1] < self.token_min_hop_len + self.token_overlap_len and cache_token is not None:
-                cache_token_len = self.token_min_hop_len + self.token_overlap_len - this_tts_speech_token.shape[1]
-                this_tts_speech_token = torch.concat([torch.concat(cache_token[-cache_token_len:], dim=1), this_tts_speech_token], dim=1)
-            else:
-                cache_token_len = 0
+            this_tts_speech_token = torch.concat(self.tts_speech_token_dict[this_uuid], dim=1)
             with self.flow_hift_context:
                 this_tts_speech = self.token2wav(token=this_tts_speech_token,
-                                                 prompt_token=flow_prompt_speech_token.to(self.device),
-                                                 prompt_token_len=flow_prompt_speech_token_len.to(self.device),
-                                                 prompt_feat=prompt_speech_feat.to(self.device),
-                                                 prompt_feat_len=prompt_speech_feat_len.to(self.device),
-                                                 embedding=flow_embedding.to(self.device))
-            this_tts_speech = this_tts_speech[:, int(cache_token_len / this_tts_speech_token.shape[1] * this_tts_speech.shape[1]):]
-            if cache_speech is not None:
-                this_tts_speech = fade_in_out(this_tts_speech, cache_speech, self.window)
-            yield {'tts_speech': this_tts_speech}
+                                                 prompt_token=flow_prompt_speech_token,
+                                                 prompt_feat=prompt_speech_feat,
+                                                 embedding=flow_embedding,
+                                                 uuid=this_uuid,
+                                                 finalize=True)
+            yield {'tts_speech': this_tts_speech.cpu()}
         else:
             # deal with all tokens
-            p.join()
-            this_tts_speech_token = torch.concat(self.tts_speech_token[this_uuid], dim=1)
+            # p.join()
+            this_tts_speech_token = torch.concat(self.tts_speech_token_dict[this_uuid], dim=1)
             with self.flow_hift_context:
                 this_tts_speech = self.token2wav(token=this_tts_speech_token,
-                                                 prompt_token=flow_prompt_speech_token.to(self.device),
-                                                 prompt_token_len=flow_prompt_speech_token_len.to(self.device),
-                                                 prompt_feat=prompt_speech_feat.to(self.device),
-                                                 prompt_feat_len=prompt_speech_feat_len.to(self.device),
-                                                 embedding=flow_embedding.to(self.device))
-            yield {'tts_speech': this_tts_speech}
+                                                 prompt_token=flow_prompt_speech_token,
+                                                 prompt_feat=prompt_speech_feat,
+                                                 embedding=flow_embedding,
+                                                 uuid=this_uuid,
+                                                 finalize=True)
+            yield {'tts_speech': this_tts_speech.cpu()}
         with self.lock:
-            self.tts_speech_token.pop(this_uuid)
-            self.llm_end.pop(this_uuid)
+            self.tts_speech_token_dict.pop(this_uuid)
+            self.llm_end_dict.pop(this_uuid)
+            self.mel_overlap_dict.pop(this_uuid)
+            self.hift_cache_dict.pop(this_uuid)
         torch.cuda.synchronize()
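
Note: the streaming path above no longer crossfades overlapping waveform chunks (the old `speech_overlap_len` / `self.window` logic); the overlap handling moved inside `token2wav`, so each yielded chunk already abuts the previous one. A minimal usage sketch, assuming `model` is a loaded `CosyVoiceModel` and `text_tokens` / `spk_embedding` are hypothetical placeholder inputs:

```python
import torch

# hypothetical inputs: tokenized text and a 192-dim speaker embedding
text_tokens = torch.randint(0, 4096, (1, 32), dtype=torch.int32)
spk_embedding = torch.randn(1, 192)

chunks = []
for out in model.inference(text=text_tokens,          # `model` assumed to be a loaded CosyVoiceModel
                           flow_embedding=spk_embedding,
                           llm_embedding=spk_embedding,
                           stream=True):
    chunks.append(out['tts_speech'])                   # each chunk is already on CPU and non-overlapping
tts_speech = torch.concat(chunks, dim=1)               # simple concatenation, no caller-side crossfade
```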
cosyvoice/hifigan/generator.py CHANGED
@@ -335,10 +335,14 @@ class HiFTGenerator(nn.Module):
         inverse_transform = torch.istft(torch.complex(real, img), self.istft_params["n_fft"], self.istft_params["hop_len"], self.istft_params["n_fft"], window=self.stft_window.to(magnitude.device))
         return inverse_transform
 
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward(self, x: torch.Tensor, cache_source: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:
         f0 = self.f0_predictor(x)
         s = self._f02source(f0)
 
+        # use cache_source to avoid glitch
+        if cache_source.shape[2] != 0:
+            s[:, :, :cache_source.shape[2]] = cache_source
+
         s_stft_real, s_stft_imag = self._stft(s.squeeze(1))
         s_stft = torch.cat([s_stft_real, s_stft_imag], dim=1)
 
@@ -370,7 +374,7 @@
 
         x = self._istft(magnitude, phase)
         x = torch.clamp(x, -self.audio_limit, self.audio_limit)
-        return x
+        return x, s
 
     def remove_weight_norm(self):
         print('Removing weight norm...')
@@ -387,5 +391,5 @@ class HiFTGenerator(nn.Module):
             l.remove_weight_norm()
 
     @torch.inference_mode()
-    def inference(self, mel: torch.Tensor) -> torch.Tensor:
-        return self.forward(x=mel)
+    def inference(self, mel: torch.Tensor, cache_source: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:
+        return self.forward(x=mel, cache_source=cache_source)
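
With this change `HiFTGenerator.forward` / `inference` accept a `cache_source` tensor and return the excitation source alongside the waveform, so the caller can keep the sine excitation phase-continuous across chunk boundaries. A hedged sketch of chunked vocoding against this API, assuming `hift` is a loaded `HiFTGenerator`, `mel_chunks` is a hypothetical list of `(1, 80, T)` mel chunks, and 256 samples per mel frame:

```python
import torch

source_cache_len = 20 * 256                    # assumed: mel_cache_len frames * 256-sample hop
cache_source = torch.zeros(1, 1, 0)            # empty cache for the very first chunk
speech_chunks = []
for mel_chunk in mel_chunks:                   # hypothetical iterable of (1, 80, T) mels
    speech, source = hift.inference(mel=mel_chunk, cache_source=cache_source)
    cache_source = source[:, :, -source_cache_len:]   # reuse the excitation tail next time
    speech_chunks.append(speech)
speech = torch.concat(speech_chunks, dim=1)
```

This only illustrates the generator-side caching; `CosyVoiceModel.token2wav` above additionally prepends the cached mel frames and trims the corresponding `source_cache_len` samples from each chunk.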
cosyvoice/utils/common.py CHANGED
@@ -131,7 +131,9 @@ def random_sampling(weighted_scores, decoded_tokens, sampling):
     top_ids = weighted_scores.softmax(dim=0).multinomial(1, replacement=True)
     return top_ids
 
-def fade_in_out(fade_in_speech, fade_out_speech, window):
-    speech_overlap_len = int(window.shape[0] / 2)
-    fade_in_speech[:, :speech_overlap_len] = fade_in_speech[:, :speech_overlap_len] * window[:speech_overlap_len] + fade_out_speech[:, -speech_overlap_len:] * window[speech_overlap_len:]
-    return fade_in_speech
+def fade_in_out(fade_in_mel, fade_out_mel, window):
+    device = fade_in_mel.device
+    fade_in_mel, fade_out_mel = fade_in_mel.cpu(), fade_out_mel.cpu()
+    mel_overlap_len = int(window.shape[0] / 2)
+    fade_in_mel[:, :, :mel_overlap_len] = fade_in_mel[:, :, :mel_overlap_len] * window[:mel_overlap_len] + fade_out_mel[:, :, -mel_overlap_len:] * window[mel_overlap_len:]
+    return fade_in_mel.to(device)
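
`fade_in_out` now crossfades mel tensors along the frame dimension (dim 2) on CPU and restores the input device afterwards: the first half of the Hamming window ramps the new chunk in while the second half ramps the cached tail out. A toy example with hypothetical shapes:

```python
import numpy as np
import torch
from cosyvoice.utils.common import fade_in_out

window = np.hamming(2 * 34)               # same shape as self.mel_window (2 * mel_overlap_len)
new_mel = torch.randn(1, 80, 120)         # hypothetical incoming mel chunk (batch, bins, frames)
overlap_mel = torch.randn(1, 80, 34)      # cached overlap tail from the previous chunk
blended = fade_in_out(new_mel, overlap_mel, window)
assert blended.shape == new_mel.shape     # only the first 34 frames are blended
```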