waveletdeboshir commited on
Commit
27c2b34
·
verified ·
1 Parent(s): 8d2bbab

Update gigaam_transformers.py

Browse files
Files changed (1) hide show
  1. gigaam_transformers.py +1 -1
gigaam_transformers.py CHANGED
@@ -432,7 +432,7 @@ class GigaAMRNNTHF(PreTrainedModel):
432
  last_label = torch.tensor([[hyp[-1]]]).to(x.device)
433
  new_symbols += 1
434
 
435
- return torch.tensor([hyp], dtype=torch.int32)
436
 
437
  @torch.inference_mode()
438
  def generate(self, input_features: Tensor, input_lengths: Tensor, **kwargs) -> torch.Tensor:
 
432
  last_label = torch.tensor([[hyp[-1]]]).to(x.device)
433
  new_symbols += 1
434
 
435
+ return torch.tensor(hyp, dtype=torch.int32)
436
 
437
  @torch.inference_mode()
438
  def generate(self, input_features: Tensor, input_lengths: Tensor, **kwargs) -> torch.Tensor: