on1onmangoes commited on
Commit
0c88c1c
·
verified ·
1 Parent(s): 28b9cdd

Upload utils/device.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. utils/device.py +40 -0
utils/device.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ import subprocess
4
+
5
+ def get_device(force_cpu=False):
6
+ if force_cpu:
7
+ return "cpu"
8
+ if torch.cuda.is_available():
9
+ return "cuda"
10
+ elif torch.backends.mps.is_available():
11
+ torch.mps.empty_cache()
12
+ return "mps"
13
+ else:
14
+ return "cpu"
15
+
16
+ def get_torch_and_np_dtypes(device, use_bfloat16=False):
17
+ if device == "cuda":
18
+ torch_dtype = torch.bfloat16 if use_bfloat16 else torch.float16
19
+ np_dtype = np.float16
20
+ elif device == "mps":
21
+ torch_dtype = torch.bfloat16 if use_bfloat16 else torch.float16
22
+ np_dtype = np.float16
23
+ else:
24
+ torch_dtype = torch.float32
25
+ np_dtype = np.float32
26
+ return torch_dtype, np_dtype
27
+
28
+ def cuda_version_check():
29
+ if torch.cuda.is_available():
30
+ try:
31
+ cuda_runtime = subprocess.check_output(["nvcc", "--version"]).decode()
32
+ cuda_version = cuda_runtime.split()[-2]
33
+ except Exception:
34
+ # Fallback to PyTorch's built-in version if nvcc isn't available
35
+ cuda_version = torch.version.cuda
36
+
37
+ device_name = torch.cuda.get_device_name(0)
38
+ return cuda_version, device_name
39
+ else:
40
+ return None, None