Update gigaam_transformers.py
Browse files- gigaam_transformers.py +1 -1
gigaam_transformers.py
CHANGED
@@ -432,7 +432,7 @@ class GigaAMRNNTHF(PreTrainedModel):
|
|
432 |
last_label = torch.tensor([[hyp[-1]]]).to(x.device)
|
433 |
new_symbols += 1
|
434 |
|
435 |
-
return torch.tensor(
|
436 |
|
437 |
@torch.inference_mode()
|
438 |
def generate(self, input_features: Tensor, input_lengths: Tensor, **kwargs) -> torch.Tensor:
|
|
|
432 |
last_label = torch.tensor([[hyp[-1]]]).to(x.device)
|
433 |
new_symbols += 1
|
434 |
|
435 |
+
return torch.tensor(hyp, dtype=torch.int32)
|
436 |
|
437 |
@torch.inference_mode()
|
438 |
def generate(self, input_features: Tensor, input_lengths: Tensor, **kwargs) -> torch.Tensor:
|