Spaces:
Build error
Build error
ddd
commited on
Commit
·
a3411b4
1
Parent(s):
853fd97
fix hparam
Browse files- modules/diffsinger_midi/fs2.py +0 -109
modules/diffsinger_midi/fs2.py
CHANGED
|
@@ -117,112 +117,3 @@ class FastSpeech2MIDI(FastSpeech2):
|
|
| 117 |
|
| 118 |
return ret
|
| 119 |
|
| 120 |
-
def add_pitch(self, decoder_inp, f0, uv, mel2ph, ret, encoder_out=None):
|
| 121 |
-
decoder_inp = decoder_inp.detach() + hparams['predictor_grad'] * (decoder_inp - decoder_inp.detach())
|
| 122 |
-
pitch_padding = mel2ph == 0
|
| 123 |
-
if hparams['pitch_ar']:
|
| 124 |
-
ret['pitch_pred'] = pitch_pred = self.pitch_predictor(decoder_inp, f0 if self.training else None)
|
| 125 |
-
if f0 is None:
|
| 126 |
-
f0 = pitch_pred[:, :, 0]
|
| 127 |
-
else:
|
| 128 |
-
ret['pitch_pred'] = pitch_pred = self.pitch_predictor(decoder_inp)
|
| 129 |
-
if f0 is None:
|
| 130 |
-
f0 = pitch_pred[:, :, 0]
|
| 131 |
-
if hparams['use_uv'] and uv is None:
|
| 132 |
-
uv = pitch_pred[:, :, 1] > 0
|
| 133 |
-
|
| 134 |
-
# here f0_denorm for pitch prediction
|
| 135 |
-
ret['f0_denorm'] = denorm_f0(f0, uv, hparams, pitch_padding=pitch_padding)
|
| 136 |
-
|
| 137 |
-
# here f0_denorm for mel prediction
|
| 138 |
-
if self.training:
|
| 139 |
-
mask = torch.full(uv.shape, hparams.get('mask_uv_prob', 0.)).to(f0.device)
|
| 140 |
-
masked_uv = torch.bernoulli(mask).bool().to(f0.device) # prob 的概率吐出一个随机uv.
|
| 141 |
-
uv_masked = uv.bool() | masked_uv
|
| 142 |
-
# print((uv.float()-uv_masked.float()).mean(dim=1))
|
| 143 |
-
f0_denorm = denorm_f0(f0, uv_masked, hparams, pitch_padding=pitch_padding)
|
| 144 |
-
else:
|
| 145 |
-
f0_denorm = ret['f0_denorm']
|
| 146 |
-
|
| 147 |
-
if pitch_padding is not None:
|
| 148 |
-
f0[pitch_padding] = 0
|
| 149 |
-
|
| 150 |
-
pitch = f0_to_coarse(f0_denorm) # start from 0
|
| 151 |
-
pitch_embed = self.pitch_embed(pitch)
|
| 152 |
-
return pitch_embed
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
class FastSpeech2MIDIMasked(FastSpeech2MIDI):
|
| 156 |
-
def forward(self, txt_tokens, mel2ph=None, spk_embed=None,
|
| 157 |
-
ref_mels=None, f0=None, uv=None, energy=None, skip_decoder=False,
|
| 158 |
-
spk_embed_dur_id=None, spk_embed_f0_id=None, infer=False, **kwargs):
|
| 159 |
-
ret = {}
|
| 160 |
-
|
| 161 |
-
midi_dur_embedding, slur_embedding = 0, 0
|
| 162 |
-
if kwargs.get('midi_dur') is not None:
|
| 163 |
-
midi_dur_embedding = self.midi_dur_layer(kwargs['midi_dur'][:, :, None]) # [B, T, 1] -> [B, T, H]
|
| 164 |
-
if kwargs.get('is_slur') is not None:
|
| 165 |
-
slur_embedding = self.is_slur_embed(kwargs['is_slur'])
|
| 166 |
-
encoder_out = self.encoder(txt_tokens, 0, midi_dur_embedding, slur_embedding) # [B, T, C]
|
| 167 |
-
src_nonpadding = (txt_tokens > 0).float()[:, :, None]
|
| 168 |
-
|
| 169 |
-
# add ref style embed
|
| 170 |
-
# Not implemented
|
| 171 |
-
# variance encoder
|
| 172 |
-
var_embed = 0
|
| 173 |
-
|
| 174 |
-
# encoder_out_dur denotes encoder outputs for duration predictor
|
| 175 |
-
# in speech adaptation, duration predictor use old speaker embedding
|
| 176 |
-
if hparams['use_spk_embed']:
|
| 177 |
-
spk_embed_dur = spk_embed_f0 = spk_embed = self.spk_embed_proj(spk_embed)[:, None, :]
|
| 178 |
-
elif hparams['use_spk_id']:
|
| 179 |
-
spk_embed_id = spk_embed
|
| 180 |
-
if spk_embed_dur_id is None:
|
| 181 |
-
spk_embed_dur_id = spk_embed_id
|
| 182 |
-
if spk_embed_f0_id is None:
|
| 183 |
-
spk_embed_f0_id = spk_embed_id
|
| 184 |
-
spk_embed = self.spk_embed_proj(spk_embed_id)[:, None, :]
|
| 185 |
-
spk_embed_dur = spk_embed_f0 = spk_embed
|
| 186 |
-
if hparams['use_split_spk_id']:
|
| 187 |
-
spk_embed_dur = self.spk_embed_dur(spk_embed_dur_id)[:, None, :]
|
| 188 |
-
spk_embed_f0 = self.spk_embed_f0(spk_embed_f0_id)[:, None, :]
|
| 189 |
-
else:
|
| 190 |
-
spk_embed_dur = spk_embed_f0 = spk_embed = 0
|
| 191 |
-
|
| 192 |
-
# add dur
|
| 193 |
-
dur_inp = (encoder_out + var_embed + spk_embed_dur) * src_nonpadding
|
| 194 |
-
|
| 195 |
-
mel2ph = self.add_dur(dur_inp, mel2ph, txt_tokens, ret)
|
| 196 |
-
|
| 197 |
-
decoder_inp = F.pad(encoder_out, [0, 0, 1, 0])
|
| 198 |
-
|
| 199 |
-
mel2ph_ = mel2ph[..., None].repeat([1, 1, encoder_out.shape[-1]])
|
| 200 |
-
decoder_inp = torch.gather(decoder_inp, 1, mel2ph_) # [B, T, H]
|
| 201 |
-
|
| 202 |
-
# expanded midi
|
| 203 |
-
midi_embedding = self.midi_embed(kwargs['pitch_midi'])
|
| 204 |
-
midi_embedding = F.pad(midi_embedding, [0, 0, 1, 0])
|
| 205 |
-
midi_embedding = torch.gather(midi_embedding, 1, mel2ph_)
|
| 206 |
-
print(midi_embedding.shape, decoder_inp.shape)
|
| 207 |
-
midi_mask = torch.full(midi_embedding.shape, hparams.get('mask_uv_prob', 0.)).to(midi_embedding.device)
|
| 208 |
-
midi_mask = 1 - torch.bernoulli(midi_mask).bool().to(midi_embedding.device) # prob 的概率吐出一个随机uv.
|
| 209 |
-
|
| 210 |
-
tgt_nonpadding = (mel2ph > 0).float()[:, :, None]
|
| 211 |
-
|
| 212 |
-
decoder_inp += midi_embedding
|
| 213 |
-
decoder_inp_origin = decoder_inp
|
| 214 |
-
# add pitch and energy embed
|
| 215 |
-
pitch_inp = (decoder_inp_origin + var_embed + spk_embed_f0) * tgt_nonpadding
|
| 216 |
-
if hparams['use_pitch_embed']:
|
| 217 |
-
pitch_inp_ph = (encoder_out + var_embed + spk_embed_f0) * src_nonpadding
|
| 218 |
-
decoder_inp = decoder_inp + self.add_pitch(pitch_inp, f0, uv, mel2ph, ret, encoder_out=pitch_inp_ph)
|
| 219 |
-
if hparams['use_energy_embed']:
|
| 220 |
-
decoder_inp = decoder_inp + self.add_energy(pitch_inp, energy, ret)
|
| 221 |
-
|
| 222 |
-
ret['decoder_inp'] = decoder_inp = (decoder_inp + spk_embed) * tgt_nonpadding
|
| 223 |
-
|
| 224 |
-
if skip_decoder:
|
| 225 |
-
return ret
|
| 226 |
-
ret['mel_out'] = self.run_decoder(decoder_inp, tgt_nonpadding, ret, infer=infer, **kwargs)
|
| 227 |
-
|
| 228 |
-
return ret
|
|
|
|
| 117 |
|
| 118 |
return ret
|
| 119 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|