XiangZ commited on
Commit
1233664
·
verified ·
1 Parent(s): 26f4236

support Apple mps

Browse files
Files changed (1) hide show
  1. README.md +4 -5
README.md CHANGED
@@ -46,8 +46,8 @@ from hit_sng_arch import HiT_SNG
46
  from hit_srf_arch import HiT_SRF
47
  import cv2
48
 
49
- # use GPU (True) or CPU (False)
50
- cuda_flag = True
51
 
52
  # initialize model (change model and upscale according to your setting)
53
  model = HiT_SRF(upscale=4)
@@ -55,11 +55,10 @@ model = HiT_SRF(upscale=4)
55
  # load model (change repo_name according to your setting)
56
  repo_name = "XiangZ/hit-srf-4x"
57
  model = model.from_pretrained(repo_name)
58
- if cuda_flag:
59
- model.cuda()
60
 
61
  # test and save results
62
- sr_results = model.infer_image("path-to-input-image", cuda=cuda_flag)
63
  cv2.imwrite("path-to-output-location", sr_results)
64
  ```
65
 
 
46
  from hit_srf_arch import HiT_SRF
47
  import cv2
48
 
49
+ # detect device
50
+ device = torch.device('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu')
51
 
52
  # initialize model (change model and upscale according to your setting)
53
  model = HiT_SRF(upscale=4)
 
55
  # load model (change repo_name according to your setting)
56
  repo_name = "XiangZ/hit-srf-4x"
57
  model = model.from_pretrained(repo_name)
58
+ model.to(device)
 
59
 
60
  # test and save results
61
+ sr_results = model.infer_image("path-to-input-image", device=device)
62
  cv2.imwrite("path-to-output-location", sr_results)
63
  ```
64