kltn20133118's picture
Upload 337 files
dbaa71b verified
import torch
def is_gpu_available() -> bool:
    """Return whether PyTorch reports that CUDA is currently available.

    This is a thin wrapper around ``torch.cuda.is_available()``.
    """
    cuda_ready = torch.cuda.is_available()
    return cuda_ready
def get_device_id(device: str) -> int:
    """Translate a device string into an integer device index.

    Args:
        device: One of ``"cpu"``, ``"auto"``, or ``"cuda:<N>"`` where ``<N>``
            is a non-negative decimal integer.

    Returns:
        ``-1`` for the CPU, otherwise the CUDA device index. ``"auto"``
        resolves to ``0`` when ``is_gpu_available()`` is true and to ``-1``
        otherwise.

    Raises:
        ValueError: If ``device`` is not in one of the accepted forms.
            ``ValueError`` is a subclass of ``Exception``, so callers that
            catch ``Exception`` keep working.
    """
    if device == "cpu":
        return -1
    if device == "auto":
        return 0 if is_gpu_available() else -1
    prefix = "cuda:"
    if device.startswith(prefix):
        # Strip only the leading prefix; str.replace would remove every
        # occurrence and wrongly accept inputs such as "cuda:cuda:1".
        device_no = device[len(prefix):]
        # isdecimal() (unlike isnumeric()) accepts exactly the characters
        # int() can parse, so e.g. "cuda:²" is rejected with a clear message
        # instead of failing inside int(). It also rejects "", "-1", " 1".
        if device_no.isdecimal():
            return int(device_no)
    raise ValueError(f"Invalid device: '{device}'")