Spaces:
Running
Running
Fix bug
Browse files- translate_troch2trt.py +1 -1
translate_troch2trt.py
CHANGED
|
@@ -73,7 +73,7 @@ def main(
|
|
| 73 |
) as output_file:
|
| 74 |
with torch.no_grad():
|
| 75 |
for batch in data_loader:
|
| 76 |
-
batch = batch.to(device, dtype=dtype)
|
| 77 |
generated_tokens = model.generate(
|
| 78 |
**batch, forced_bos_token_id=lang_code_to_idx
|
| 79 |
)
|
|
|
|
| 73 |
) as output_file:
|
| 74 |
with torch.no_grad():
|
| 75 |
for batch in data_loader:
|
| 76 |
+
batch["input_ids"] = batch["input_ids"].to(device, dtype=dtype)
|
| 77 |
generated_tokens = model.generate(
|
| 78 |
**batch, forced_bos_token_id=lang_code_to_idx
|
| 79 |
)
|