waveletdeboshir commited on
Commit
8d2bbab
·
verified ·
1 Parent(s): a3fc5de

Update gigaam_transformers.py

Browse files
Files changed (1) hide show
  1. gigaam_transformers.py +2 -1
gigaam_transformers.py CHANGED
@@ -6,6 +6,7 @@ import torch.nn as nn
6
  import torchaudio
7
  from .encoder import ConformerEncoder
8
  from torch import Tensor
 
9
  from transformers import Wav2Vec2CTCTokenizer, Wav2Vec2Processor
10
  from transformers.configuration_utils import PretrainedConfig
11
  from transformers.feature_extraction_sequence_utils import \
@@ -445,4 +446,4 @@ class GigaAMRNNTHF(PreTrainedModel):
445
  for i in range(b):
446
  inseq = encoder_out[i, :, :].unsqueeze(1)
447
  preds.append(self._greedy_decode(inseq, encoded_lengths[i]))
448
- return torch.cat(preds, dim=1)
 
6
  import torchaudio
7
  from .encoder import ConformerEncoder
8
  from torch import Tensor
9
+ from torch.nn.utils.rnn import pad_sequence
10
  from transformers import Wav2Vec2CTCTokenizer, Wav2Vec2Processor
11
  from transformers.configuration_utils import PretrainedConfig
12
  from transformers.feature_extraction_sequence_utils import \
 
446
  for i in range(b):
447
  inseq = encoder_out[i, :, :].unsqueeze(1)
448
  preds.append(self._greedy_decode(inseq, encoded_lengths[i]))
449
+ return pad_sequence(preds, batch_first=True, padding_value=self.config.blank_id)