Fix padding
Browse files- inference.py +1 -1
inference.py
CHANGED
@@ -26,7 +26,7 @@ def decode_tts(tokens, quantizer, n_codebooks, n_original_tokens, start_audio_to
|
|
26 |
if reminder:
|
27 |
# pad if last frame is incomplete
|
28 |
pad_tokens = torch.zeros(n_codebooks - reminder, device="cuda")
|
29 |
-
audio_tokens = torch.cat([audio_tokens, pad_tokens
|
30 |
|
31 |
transposed = audio_tokens.view(-1, n_codebooks).t()
|
32 |
codes = transposed.view(n_codebooks, 1, -1).to(device)
|
|
|
26 |
if reminder:
|
27 |
# pad if last frame is incomplete
|
28 |
pad_tokens = torch.zeros(n_codebooks - reminder, device="cuda")
|
29 |
+
audio_tokens = torch.cat([audio_tokens, pad_tokens], dim=0)
|
30 |
|
31 |
transposed = audio_tokens.view(-1, n_codebooks).t()
|
32 |
codes = transposed.view(n_codebooks, 1, -1).to(device)
|