Maxime
committed on
fix distributed devices (#612)
Browse files* fix distributed devices
* Update distributed.py
* Update distributed.py
src/axolotl/utils/distributed.py
CHANGED
|
@@ -77,7 +77,9 @@ def gather_scalar_from_all_ranks(fn, world_size=1): # pylint: disable=invalid-n
|
|
| 77 |
value_scalar = fn()
|
| 78 |
if not is_distributed():
|
| 79 |
return [value_scalar]
|
| 80 |
-
value_tensor = torch.tensor(
|
|
|
|
|
|
|
| 81 |
|
| 82 |
if not is_main_process():
|
| 83 |
dist.gather(value_tensor, dst=0)
|
|
@@ -137,9 +139,13 @@ def compute_and_broadcast(fn): # pylint: disable=invalid-name
|
|
| 137 |
"""
|
| 138 |
if is_main_process():
|
| 139 |
value_scalar = fn()
|
| 140 |
-
value_tensor = torch.tensor(
|
|
|
|
|
|
|
| 141 |
else:
|
| 142 |
-
value_tensor = torch.tensor(
|
|
|
|
|
|
|
| 143 |
|
| 144 |
# Broadcast the tensor to all processes.
|
| 145 |
barrier()
|
|
@@ -164,7 +170,9 @@ def gather_from_all_ranks(fn, world_size=1): # pylint: disable=invalid-name
|
|
| 164 |
- A list of computed values from all ranks if on the gathering rank, otherwise None.
|
| 165 |
"""
|
| 166 |
value_scalar = fn()
|
| 167 |
-
value_tensor = torch.tensor(
|
|
|
|
|
|
|
| 168 |
|
| 169 |
# Placeholder tensor for gathering results
|
| 170 |
if is_main_process():
|
|
|
|
| 77 |
value_scalar = fn()
|
| 78 |
if not is_distributed():
|
| 79 |
return [value_scalar]
|
| 80 |
+
value_tensor = torch.tensor(
|
| 81 |
+
value_scalar, device=torch.cuda.current_device()
|
| 82 |
+
).float()
|
| 83 |
|
| 84 |
if not is_main_process():
|
| 85 |
dist.gather(value_tensor, dst=0)
|
|
|
|
| 139 |
"""
|
| 140 |
if is_main_process():
|
| 141 |
value_scalar = fn()
|
| 142 |
+
value_tensor = torch.tensor(
|
| 143 |
+
value_scalar, device=torch.cuda.current_device()
|
| 144 |
+
).float()
|
| 145 |
else:
|
| 146 |
+
value_tensor = torch.tensor(
|
| 147 |
+
0.0, device=torch.cuda.current_device()
|
| 148 |
+
) # Placeholder tensor
|
| 149 |
|
| 150 |
# Broadcast the tensor to all processes.
|
| 151 |
barrier()
|
|
|
|
| 170 |
- A list of computed values from all ranks if on the gathering rank, otherwise None.
|
| 171 |
"""
|
| 172 |
value_scalar = fn()
|
| 173 |
+
value_tensor = torch.tensor(
|
| 174 |
+
value_scalar, device=torch.cuda.current_device()
|
| 175 |
+
).float()
|
| 176 |
|
| 177 |
# Placeholder tensor for gathering results
|
| 178 |
if is_main_process():
|