support Apple mps
Browse files
README.md
CHANGED
@@ -46,8 +46,8 @@ from hit_sng_arch import HiT_SNG
|
|
46 |
from hit_srf_arch import HiT_SRF
|
47 |
import cv2
|
48 |
|
49 |
-
#
|
50 |
-
|
51 |
|
52 |
# initialize model (change model and upscale according to your setting)
|
53 |
model = HiT_SRF(upscale=4)
|
@@ -55,11 +55,10 @@ model = HiT_SRF(upscale=4)
|
|
55 |
# load model (change repo_name according to your setting)
|
56 |
repo_name = "XiangZ/hit-srf-4x"
|
57 |
model = model.from_pretrained(repo_name)
|
58 |
-
|
59 |
-
model.cuda()
|
60 |
|
61 |
# test and save results
|
62 |
-
sr_results = model.infer_image("path-to-input-image",
|
63 |
cv2.imwrite("path-to-output-location", sr_results)
|
64 |
```
|
65 |
|
|
|
46 |
from hit_srf_arch import HiT_SRF
|
47 |
import cv2
|
48 |
|
49 |
+
# detect device
|
50 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu')
|
51 |
|
52 |
# initialize model (change model and upscale according to your setting)
|
53 |
model = HiT_SRF(upscale=4)
|
|
|
55 |
# load model (change repo_name according to your setting)
|
56 |
repo_name = "XiangZ/hit-srf-4x"
|
57 |
model = model.from_pretrained(repo_name)
|
58 |
+
model.to(device)
|
|
|
59 |
|
60 |
# test and save results
|
61 |
+
sr_results = model.infer_image("path-to-input-image", device=device)
|
62 |
cv2.imwrite("path-to-output-location", sr_results)
|
63 |
```
|
64 |
|