fix model device
Browse files
app.py
CHANGED
|
@@ -84,11 +84,18 @@ def predict(model_name, pairs_file, sequence_file, progress = gr.Progress()):
|
|
| 84 |
try:
|
| 85 |
run_id = uuid4()
|
| 86 |
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
|
|
|
|
| 87 |
|
| 88 |
# gr.Info("Loading model...")
|
| 89 |
_ = lm_embed("M", use_cuda = (device.type == "cuda"))
|
| 90 |
|
| 91 |
-
model = DSCRIPTModel.from_pretrained(model_map[model_name], use_cuda=
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 92 |
|
| 93 |
# gr.Info("Loading files...")
|
| 94 |
try:
|
|
|
|
| 84 |
try:
|
| 85 |
run_id = uuid4()
|
| 86 |
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
|
| 87 |
+
use_cuda = torch.cuda.is_available()
|
| 88 |
|
| 89 |
# gr.Info("Loading model...")
|
| 90 |
_ = lm_embed("M", use_cuda = (device.type == "cuda"))
|
| 91 |
|
| 92 |
+
model = DSCRIPTModel.from_pretrained(model_map[model_name], use_cuda=use_cuda)
|
| 93 |
+
if use_cuda:
|
| 94 |
+
model = model.to(device)
|
| 95 |
+
model.use_cuda = True
|
| 96 |
+
else:
|
| 97 |
+
model = model.to("cpu")
|
| 98 |
+
model.use_cuda = False
|
| 99 |
|
| 100 |
# gr.Info("Loading files...")
|
| 101 |
try:
|