fix vocoder speech overlap
Files changed:
- cosyvoice/cli/model.py +79 -66
- cosyvoice/hifigan/generator.py +8 -4
- cosyvoice/utils/common.py +6 -4
cosyvoice/cli/model.py
CHANGED

@@ -31,18 +31,25 @@ class CosyVoiceModel:
         self.flow = flow
         self.hift = hift
         self.token_min_hop_len = 100
+        self.token_max_hop_len = 200
         self.token_overlap_len = 20
+        # mel fade in out
+        self.mel_overlap_len = 34
+        self.mel_window = np.hamming(2 * self.mel_overlap_len)
+        # hift cache
+        self.mel_cache_len = 20
+        self.source_cache_len = int(self.mel_cache_len * 256)
+        # rtf and decoding related
         self.stream_scale_factor = 1
         assert self.stream_scale_factor >= 1, 'stream_scale_factor should be greater than 1, change it according to your actual rtf'
         self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
         self.flow_hift_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
         self.lock = threading.Lock()
         # dict used to store session related variable
+        self.tts_speech_token_dict = {}
+        self.llm_end_dict = {}
+        self.mel_overlap_dict = {}
+        self.hift_cache_dict = {}

     def load(self, llm_model, flow_model, hift_model):
         self.llm.load_state_dict(torch.load(llm_model, map_location=self.device))

@@ -64,102 +71,108 @@ class CosyVoiceModel:
         self.flow.decoder.estimator = xxx
         self.flow.decoder.session = xxx

+    def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
         with self.llm_context:
             for i in self.llm.inference(text=text.to(self.device),
+                                        text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device),
                                         prompt_text=prompt_text.to(self.device),
+                                        prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
                                         prompt_speech_token=llm_prompt_speech_token.to(self.device),
+                                        prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
                                         embedding=llm_embedding.to(self.device).half(),
                                         sampling=25,
                                         max_token_text_ratio=30,
                                         min_token_text_ratio=3):
+                self.tts_speech_token_dict[uuid].append(i)
+        self.llm_end_dict[uuid] = True

+    def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False):
         with self.flow_hift_context:
             tts_mel = self.flow.inference(token=token.to(self.device),
+                                          token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
                                           prompt_token=prompt_token.to(self.device),
+                                          prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
                                           prompt_feat=prompt_feat.to(self.device),
+                                          prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
                                           embedding=embedding.to(self.device))
+            # mel overlap fade in out
+            if self.mel_overlap_dict[uuid] is not None:
+                tts_mel = fade_in_out(tts_mel, self.mel_overlap_dict[uuid], self.mel_window)
+            # append hift cache
+            if self.hift_cache_dict[uuid] is not None:
+                hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
+                tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
+            else:
+                hift_cache_source = torch.zeros(1, 1, 0)
+            # keep overlap mel and hift cache
+            if finalize is False:
+                self.mel_overlap_dict[uuid] = tts_mel[:, :, -self.mel_overlap_len:]
+                tts_mel = tts_mel[:, :, :-self.mel_overlap_len]
+                tts_speech, tts_source = self.hift.inference(mel=tts_mel, cache_source=hift_cache_source)
+                self.hift_cache_dict[uuid] = {'source': tts_source[:, :, -self.source_cache_len:], 'mel': tts_mel[:, :, -self.mel_cache_len:]}
+                tts_speech = tts_speech[:, :-self.source_cache_len]
+            else:
+                tts_speech, tts_source = self.hift.inference(mel=tts_mel, cache_source=hift_cache_source)
         return tts_speech

+    def inference(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192),
+                  prompt_text=torch.zeros(1, 0, dtype=torch.int32),
+                  llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
+                  flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
+                  prompt_speech_feat=torch.zeros(1, 0, 80), stream=False, **kwargs):
         # this_uuid is used to track variables related to this inference thread
         this_uuid = str(uuid.uuid1())
         with self.lock:
+            self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid], self.mel_overlap_dict[this_uuid], self.hift_cache_dict[this_uuid] = [], False, None, None
+        p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
         p.start()
+        p.join()
         if stream is True:
+            token_hop_len = self.token_min_hop_len
             while True:
                 time.sleep(0.1)
+                if len(self.tts_speech_token_dict[this_uuid]) >= token_hop_len + self.token_overlap_len:
+                    this_tts_speech_token = torch.concat(self.tts_speech_token_dict[this_uuid][:token_hop_len + self.token_overlap_len], dim=1)
                     with self.flow_hift_context:
                         this_tts_speech = self.token2wav(token=this_tts_speech_token,
+                                                         prompt_token=flow_prompt_speech_token,
+                                                         prompt_feat=prompt_speech_feat,
+                                                         embedding=flow_embedding,
+                                                         uuid=this_uuid,
+                                                         finalize=False)
+                    yield {'tts_speech': this_tts_speech.cpu()}
                     with self.lock:
+                        self.tts_speech_token_dict[this_uuid] = self.tts_speech_token_dict[this_uuid][token_hop_len:]
                     # increase token_hop_len for better speech quality
                     token_hop_len = min(self.token_max_hop_len, int(token_hop_len * self.stream_scale_factor))
+                if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) < token_hop_len + self.token_overlap_len:
                     break
+            # p.join()
             # deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None
+            this_tts_speech_token = torch.concat(self.tts_speech_token_dict[this_uuid], dim=1)
             with self.flow_hift_context:
                 this_tts_speech = self.token2wav(token=this_tts_speech_token,
+                                                 prompt_token=flow_prompt_speech_token,
+                                                 prompt_feat=prompt_speech_feat,
+                                                 embedding=flow_embedding,
+                                                 uuid=this_uuid,
+                                                 finalize=True)
+            yield {'tts_speech': this_tts_speech.cpu()}
         else:
             # deal with all tokens
+            # p.join()
+            this_tts_speech_token = torch.concat(self.tts_speech_token_dict[this_uuid], dim=1)
             with self.flow_hift_context:
                 this_tts_speech = self.token2wav(token=this_tts_speech_token,
+                                                 prompt_token=flow_prompt_speech_token,
+                                                 prompt_feat=prompt_speech_feat,
+                                                 embedding=flow_embedding,
+                                                 uuid=this_uuid,
+                                                 finalize=True)
+            yield {'tts_speech': this_tts_speech.cpu()}
         with self.lock:
+            self.tts_speech_token_dict.pop(this_uuid)
+            self.llm_end_dict.pop(this_uuid)
+            self.mel_overlap_dict.pop(this_uuid)
+            self.hift_cache_dict.pop(this_uuid)
         torch.cuda.synchronize()
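Each chunk yielded by inference(..., stream=True) has already been cross-faded against the previous mel overlap and vocoded with the cached HiFT source, so a caller only needs to concatenate the yielded waveforms. A minimal consumption sketch, assuming a loaded CosyVoiceModel and suitable prompt tensors; collect_streaming_speech is a hypothetical helper, not part of the repo:

import torch

def collect_streaming_speech(model, text, flow_embedding, **prompt_kwargs):
    # Chunk boundaries are smoothed inside token2wav, so plain concatenation
    # along the time axis is enough on the caller side.
    chunks = [out['tts_speech'] for out in model.inference(text, flow_embedding, stream=True, **prompt_kwargs)]
    return torch.concat(chunks, dim=1)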
cosyvoice/hifigan/generator.py
CHANGED

@@ -335,10 +335,14 @@ class HiFTGenerator(nn.Module):
         inverse_transform = torch.istft(torch.complex(real, img), self.istft_params["n_fft"], self.istft_params["hop_len"], self.istft_params["n_fft"], window=self.stft_window.to(magnitude.device))
         return inverse_transform

-    def forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward(self, x: torch.Tensor, cache_source: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:
         f0 = self.f0_predictor(x)
         s = self._f02source(f0)

+        # use cache_source to avoid glitch
+        if cache_source.shape[2] != 0:
+            s[:, :, :cache_source.shape[2]] = cache_source
+
         s_stft_real, s_stft_imag = self._stft(s.squeeze(1))
         s_stft = torch.cat([s_stft_real, s_stft_imag], dim=1)

@@ -370,7 +374,7 @@ class HiFTGenerator(nn.Module):

         x = self._istft(magnitude, phase)
         x = torch.clamp(x, -self.audio_limit, self.audio_limit)
-        return x
+        return x, s

     def remove_weight_norm(self):
         print('Removing weight norm...')

@@ -387,5 +391,5 @@ class HiFTGenerator(nn.Module):
         l.remove_weight_norm()

     @torch.inference_mode()
-    def inference(self, mel: torch.Tensor) -> torch.Tensor:
-        return self.forward(x=mel)
+    def inference(self, mel: torch.Tensor, cache_source: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:
+        return self.forward(x=mel, cache_source=cache_source)
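token2wav above layers two smoothing mechanisms on top of this vocoder change: a 34-frame mel cross-fade between consecutive flow outputs, and the HiFT cache shown here, which re-synthesizes the last mel_cache_len = 20 mel frames (20 × 256 = 5120 output samples) of the previous chunk with the cached source excitation so the sine source stays phase-continuous. A stripped-down sketch of just that caching loop, with hift and mel_chunks as placeholder inputs and the mel cross-fade omitted:

import torch

def vocode_chunks(hift, mel_chunks, mel_cache_len=20, hop_len=256):
    # Prepend the cached mel tail, reuse the matching source excitation, then
    # drop the re-synthesized overlap from every non-final chunk (sketch only).
    source_cache_len = mel_cache_len * hop_len
    cache_mel, cache_source = None, torch.zeros(1, 1, 0)
    pieces = []
    for i, mel in enumerate(mel_chunks):           # each mel: (batch, num_mels, frames)
        finalize = i == len(mel_chunks) - 1
        if cache_mel is not None:
            mel = torch.concat([cache_mel, mel], dim=2)
        speech, source = hift.inference(mel=mel, cache_source=cache_source)
        if not finalize:
            cache_mel = mel[:, :, -mel_cache_len:]
            cache_source = source[:, :, -source_cache_len:]
            speech = speech[:, :-source_cache_len]
        pieces.append(speech)
    return torch.concat(pieces, dim=1)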
cosyvoice/utils/common.py
CHANGED

@@ -131,7 +131,9 @@ def random_sampling(weighted_scores, decoded_tokens, sampling):
     top_ids = weighted_scores.softmax(dim=0).multinomial(1, replacement=True)
     return top_ids

+def fade_in_out(fade_in_mel, fade_out_mel, window):
+    device = fade_in_mel.device
+    fade_in_mel, fade_out_mel = fade_in_mel.cpu(), fade_out_mel.cpu()
+    mel_overlap_len = int(window.shape[0] / 2)
+    fade_in_mel[:, :, :mel_overlap_len] = fade_in_mel[:, :, :mel_overlap_len] * window[:mel_overlap_len] + fade_out_mel[:, :, -mel_overlap_len:] * window[mel_overlap_len:]
+    return fade_in_mel.to(device)