Spaces:
Build error
Build error
| import torch | |
| import torch.nn.functional as F | |
def build_word_mask(x2word, y2word):
    """Return an integer mask marking where two word-index sequences agree.

    Entry ``[b, i, j]`` of the result is 1 when ``x2word[b, i]`` equals
    ``y2word[b, j]`` and 0 otherwise.

    Args:
        x2word: Tensor indexed as ``[batch, position]``
            (presumably word ids per x position -- confirm with callers).
        y2word: Tensor indexed as ``[batch, position]``
            (presumably word ids per y position -- confirm with callers).

    Returns:
        A ``long`` tensor produced by broadcasting the two inputs against
        each other; ``[B, T_x, T_y]`` for 2-D inputs.
    """
    x_column = x2word[:, :, None]  # add a trailing axis to broadcast over y
    y_row = y2word[:, None, :]  # add a middle axis to broadcast over x
    same_word = x_column == y_row
    return same_word.long()
def mel2ph_to_mel2word(mel2ph, ph2word):
    """Compose a frame->phoneme map with a phoneme->word map.

    Both maps are treated as 1-based: they are shifted to 0-based before the
    lookup and the result is shifted back. Positions where ``mel2ph`` is not
    positive are forced to 0 in the output.

    Args:
        mel2ph: Integer tensor indexed as ``[batch, frame]`` whose positive
            values are 1-based indices into dim 1 of ``ph2word``
            (presumably 0 marks padding -- confirm with callers).
        ph2word: Integer tensor indexed as ``[batch, phoneme]``.

    Returns:
        Tensor shaped like ``mel2ph`` holding ``ph2word`` looked up at each
        ``mel2ph`` position, with 0 wherever ``mel2ph <= 0``.
    """
    # Clamp so non-positive entries still gather a valid index; they are
    # zeroed out by the mask below.
    ph_index = (mel2ph - 1).clamp(min=0)
    word_zero_based = ph2word - 1
    looked_up = word_zero_based.gather(1, ph_index) + 1
    valid = (mel2ph > 0).long()
    return looked_up * valid
def clip_mel2token_to_multiple(mel2token, frames_multiple):
    """Trim dim 1 of ``mel2token`` down to a multiple of ``frames_multiple``.

    Args:
        mel2token: Tensor with at least two dimensions; dim 1 is the one
            that gets truncated (presumably the frame axis -- confirm).
        frames_multiple: Integer the kept length must be divisible by.

    Returns:
        ``mel2token`` sliced along dim 1 to the largest length that is a
        multiple of ``frames_multiple`` and does not exceed the original.
    """
    total_frames = mel2token.shape[1]
    # Drop the remainder so the kept length divides evenly.
    kept_frames = total_frames - total_frames % frames_multiple
    return mel2token[:, :kept_frames]
def expand_states(h, mel2token):
    """Gather rows of ``h`` along dim 1 according to ``mel2token``.

    A zero row is prepended to ``h`` along dim 1 before gathering, so index
    0 in ``mel2token`` selects an all-zero vector and index ``k >= 1``
    selects the original row ``k - 1``.

    Args:
        h: Tensor indexed as ``[batch, token, hidden]`` (assumed 3-D, as the
            padding spec and the gather index below imply -- confirm).
        mel2token: Integer tensor indexed as ``[batch, frame]`` holding the
            row to pick for each output position.

    Returns:
        Tensor of shape ``[B, T, H]`` where ``T`` is ``mel2token.shape[1]``.
    """
    # Pad spec is (last-dim left, last-dim right, dim-1 left, dim-1 right):
    # one zero row is inserted in front along dim 1.
    padded = F.pad(h, [0, 0, 1, 0])
    hidden_size = padded.shape[-1]
    # Replicate each index across the hidden axis so gather picks whole rows.
    gather_index = mel2token.unsqueeze(-1).repeat(1, 1, hidden_size)
    return padded.gather(1, gather_index)  # [B, T, H]