add bf16 check (#587)
Browse files
src/axolotl/utils/config.py
CHANGED
|
@@ -4,6 +4,7 @@ import logging
|
|
| 4 |
import os
|
| 5 |
|
| 6 |
import torch
|
|
|
|
| 7 |
|
| 8 |
from axolotl.utils.bench import log_gpu_memory_usage
|
| 9 |
from axolotl.utils.models import load_model_config
|
|
@@ -89,6 +90,14 @@ def normalize_config(cfg):
|
|
| 89 |
|
| 90 |
|
| 91 |
def validate_config(cfg):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 92 |
if cfg.max_packed_sequence_len and cfg.sample_packing:
|
| 93 |
raise ValueError(
|
| 94 |
"please set only one of max_packed_sequence_len (deprecated soon) or sample_packing"
|
|
|
|
| 4 |
import os
|
| 5 |
|
| 6 |
import torch
|
| 7 |
+
from transformers.utils import is_torch_bf16_gpu_available
|
| 8 |
|
| 9 |
from axolotl.utils.bench import log_gpu_memory_usage
|
| 10 |
from axolotl.utils.models import load_model_config
|
|
|
|
| 90 |
|
| 91 |
|
| 92 |
def validate_config(cfg):
|
| 93 |
+
if is_torch_bf16_gpu_available():
|
| 94 |
+
if not cfg.bf16 and not cfg.bfloat16:
|
| 95 |
+
LOG.info("bf16 support detected, but not enabled for this configuration.")
|
| 96 |
+
else:
|
| 97 |
+
if cfg.bf16 or cfg.bfloat16:
|
| 98 |
+
raise ValueError(
|
| 99 |
+
"bf16 requested, but AMP is not supported on this GPU. Requires Ampere series or above."
|
| 100 |
+
)
|
| 101 |
if cfg.max_packed_sequence_len and cfg.sample_packing:
|
| 102 |
raise ValueError(
|
| 103 |
"please set only one of max_packed_sequence_len (deprecated soon) or sample_packing"
|