Update gigaam_transformers.py

Changed file: gigaam_transformers.py (+2 −1)
@@ -6,6 +6,7 @@ import torch.nn as nn
 import torchaudio
 from .encoder import ConformerEncoder
 from torch import Tensor
+from torch.nn.utils.rnn import pad_sequence
 from transformers import Wav2Vec2CTCTokenizer, Wav2Vec2Processor
 from transformers.configuration_utils import PretrainedConfig
 from transformers.feature_extraction_sequence_utils import \
@@ -445,4 +446,4 @@ class GigaAMRNNTHF(PreTrainedModel):
         for i in range(b):
             inseq = encoder_out[i, :, :].unsqueeze(1)
             preds.append(self._greedy_decode(inseq, encoded_lengths[i]))
-        return
+        return pad_sequence(preds, batch_first=True, padding_value=self.config.blank_id)