Set matmul tf32
scripts/finetune.py +3 -0
scripts/finetune.py
CHANGED
@@ -183,6 +183,9 @@ def train(
     cfg.fp16 = True
     cfg.bf16 = False
 
+    if cfg.tf32:
+        torch.backends.cuda.matmul.allow_tf32 = True
+
     # load the tokenizer first
     tokenizer_config = cfg.tokenizer_config or cfg.base_model_config
     logging.info(f"loading tokenizer... {tokenizer_config}")