Update gigaam_transformers.py
Browse files- gigaam_transformers.py +2 -1
gigaam_transformers.py
CHANGED
@@ -321,7 +321,7 @@ class GigaAMProcessor(Wav2Vec2Processor):
|
|
321 |
|
322 |
|
323 |
class GigaAMConfig(PretrainedConfig):
|
324 |
-
model_type =
|
325 |
|
326 |
def __init__(self, **kwargs):
|
327 |
super().__init__(**kwargs)
|
@@ -394,6 +394,7 @@ class GigaAMRNNTHF(PreTrainedModel):
|
|
394 |
hidden_states = torch.zeros((self.config.head["decoder"]["pred_rnn_layers"], batch_size, self.model.head.decoder.pred_hidden), device=encoder_out.device)
|
395 |
hidden_c = torch.zeros((self.config.head["decoder"]["pred_rnn_layers"], batch_size, self.model.head.decoder.pred_hidden), device=encoder_out.device)
|
396 |
plus_one_dim = self.config.blank_id * torch.ones((batch_size, 1), dtype=torch.int32, device=encoder_out.device)
|
|
|
397 |
|
398 |
decoder_out, h, c = self.model.head.decoder(torch.cat((plus_one_dim, labels), dim=1), hidden_states, hidden_c)
|
399 |
|
|
|
321 |
|
322 |
|
323 |
class GigaAMConfig(PretrainedConfig):
|
324 |
+
model_type = None
|
325 |
|
326 |
def __init__(self, **kwargs):
|
327 |
super().__init__(**kwargs)
|
|
|
394 |
hidden_states = torch.zeros((self.config.head["decoder"]["pred_rnn_layers"], batch_size, self.model.head.decoder.pred_hidden), device=encoder_out.device)
|
395 |
hidden_c = torch.zeros((self.config.head["decoder"]["pred_rnn_layers"], batch_size, self.model.head.decoder.pred_hidden), device=encoder_out.device)
|
396 |
plus_one_dim = self.config.blank_id * torch.ones((batch_size, 1), dtype=torch.int32, device=encoder_out.device)
|
397 |
+
labels[labels < 0] = self.config.blank_id
|
398 |
|
399 |
decoder_out, h, c = self.model.head.decoder(torch.cat((plus_one_dim, labels), dim=1), hidden_states, hidden_c)
|
400 |
|