Spaces:
Running
Running
Fix bug
Browse files- translate_troch2trt.py +3 -1
translate_troch2trt.py
CHANGED
|
@@ -55,8 +55,10 @@ def main(
|
|
| 55 |
device = "cuda"
|
| 56 |
from torch2trt import torch2trt
|
| 57 |
|
|
|
|
|
|
|
| 58 |
model = torch2trt(
|
| 59 |
-
model
|
| 60 |
[torch.randn((batch_size, max_length)).to(device, dtype=torch.long)],
|
| 61 |
)
|
| 62 |
|
|
|
|
| 55 |
device = "cuda"
|
| 56 |
from torch2trt import torch2trt
|
| 57 |
|
| 58 |
+
model.to(device, dtype=dtype)
|
| 59 |
+
|
| 60 |
model = torch2trt(
|
| 61 |
+
model,
|
| 62 |
[torch.randn((batch_size, max_length)).to(device, dtype=torch.long)],
|
| 63 |
)
|
| 64 |
|