waveletdeboshir committed (verified)
Commit e4cbc9d · 1 Parent(s): 5a4f890

Update gigaam_transformers.py

Files changed (1)
  1. gigaam_transformers.py +2 -1
gigaam_transformers.py CHANGED
@@ -321,7 +321,7 @@ class GigaAMProcessor(Wav2Vec2Processor):
 
 
 class GigaAMConfig(PretrainedConfig):
-    model_type = "gigaam-ctc"
+    model_type = None
 
     def __init__(self, **kwargs):
        super().__init__(**kwargs)
@@ -394,6 +394,7 @@ class GigaAMRNNTHF(PreTrainedModel):
         hidden_states = torch.zeros((self.config.head["decoder"]["pred_rnn_layers"], batch_size, self.model.head.decoder.pred_hidden), device=encoder_out.device)
         hidden_c = torch.zeros((self.config.head["decoder"]["pred_rnn_layers"], batch_size, self.model.head.decoder.pred_hidden), device=encoder_out.device)
         plus_one_dim = self.config.blank_id * torch.ones((batch_size, 1), dtype=torch.int32, device=encoder_out.device)
+        labels[labels < 0] = self.config.blank_id
 
         decoder_out, h, c = self.model.head.decoder(torch.cat((plus_one_dim, labels), dim=1), hidden_states, hidden_c)
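
For context, a minimal sketch (not from the repository) of what the added `labels[labels < 0] = self.config.blank_id` line does: Hugging Face data collators usually pad label sequences with -100, and a negative id would break the embedding lookup inside the RNNT prediction network, so padding positions are presumably remapped to the blank token before the decoder input is assembled. The `blank_id` value and label tensors below are hypothetical.

```python
import torch

# Hedged sketch, not part of the repository: shows the effect of the added
# line on a padded label batch. blank_id and the label values are made up.
blank_id = 33
labels = torch.tensor([[5, 12, 7, -100, -100],
                       [3, 9, -100, -100, -100]], dtype=torch.int32)

labels[labels < 0] = blank_id  # same operation as in the commit

# Prepend the blank token, mirroring how plus_one_dim is concatenated
# with the labels before calling the prediction network.
plus_one_dim = blank_id * torch.ones((labels.size(0), 1), dtype=torch.int32)
decoder_input = torch.cat((plus_one_dim, labels), dim=1)
print(decoder_input)
# tensor([[33,  5, 12,  7, 33, 33],
#         [33,  3,  9, 33, 33, 33]], dtype=torch.int32)
```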