fix device map
Browse files- scripts/finetune.py +5 -4
scripts/finetune.py
CHANGED
|
@@ -47,10 +47,11 @@ def choose_device(cfg):
|
|
| 47 |
return "cpu"
|
| 48 |
|
| 49 |
cfg.device = get_device()
|
| 50 |
-
if cfg.
|
| 51 |
-
cfg.
|
| 52 |
-
|
| 53 |
-
|
|
|
|
| 54 |
|
| 55 |
|
| 56 |
def get_multi_line_input() -> Optional[str]:
|
|
|
|
| 47 |
return "cpu"
|
| 48 |
|
| 49 |
cfg.device = get_device()
|
| 50 |
+
if cfg.device_map != "auto":
|
| 51 |
+
if cfg.device.startswith("cuda"):
|
| 52 |
+
cfg.device_map = {"": cfg.local_rank}
|
| 53 |
+
else:
|
| 54 |
+
cfg.device_map = {"": cfg.device}
|
| 55 |
|
| 56 |
|
| 57 |
def get_multi_line_input() -> Optional[str]:
|