Pass weakref to model in the SIGINT handler to free up model post train function (#1581)
Browse files
* Pass weakref to model in the SIGINT handler to free up model post train()
* Fix lint issues
* chore: lint
---------
Co-authored-by: Wing Lian <[email protected]>
- src/axolotl/train.py +12 -5
src/axolotl/train.py
CHANGED
|
@@ -3,6 +3,7 @@
|
|
| 3 |
import os
|
| 4 |
import signal
|
| 5 |
import sys
|
|
|
|
| 6 |
from dataclasses import dataclass
|
| 7 |
from pathlib import Path
|
| 8 |
from typing import Optional, Tuple, Union
|
|
@@ -127,14 +128,20 @@ def train(
|
|
| 127 |
# In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
|
| 128 |
if cfg.local_rank == 0:
|
| 129 |
|
| 130 |
-
def terminate_handler(_, __,
|
| 131 |
-
if
|
| 132 |
-
|
| 133 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 134 |
sys.exit(0)
|
| 135 |
|
|
|
|
| 136 |
signal.signal(
|
| 137 |
-
signal.SIGINT,
|
|
|
|
| 138 |
)
|
| 139 |
|
| 140 |
badge_markdown = """[<img src="https://raw.githubusercontent.com/OpenAccess-AI-Collective/axolotl/main/image/axolotl-badge-web.png" alt="Built with Axolotl" width="200" height="32"/>](https://github.com/OpenAccess-AI-Collective/axolotl)"""
|
|
|
|
| 3 |
import os
|
| 4 |
import signal
|
| 5 |
import sys
|
| 6 |
+
import weakref
|
| 7 |
from dataclasses import dataclass
|
| 8 |
from pathlib import Path
|
| 9 |
from typing import Optional, Tuple, Union
|
|
|
|
| 128 |
# In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
|
| 129 |
if cfg.local_rank == 0:
|
| 130 |
|
| 131 |
+
def terminate_handler(_, __, model_weakref):
|
| 132 |
+
if model_weakref() is not None:
|
| 133 |
+
_model = model_weakref()
|
| 134 |
+
if cfg.flash_optimum and BetterTransformer:
|
| 135 |
+
_model = BetterTransformer.reverse(_model)
|
| 136 |
+
_model.save_pretrained(
|
| 137 |
+
cfg.output_dir, safe_serialization=safe_serialization
|
| 138 |
+
)
|
| 139 |
sys.exit(0)
|
| 140 |
|
| 141 |
+
_model_weakref = weakref.ref(model)
|
| 142 |
signal.signal(
|
| 143 |
+
signal.SIGINT,
|
| 144 |
+
lambda signum, frame: terminate_handler(signum, frame, _model_weakref),
|
| 145 |
)
|
| 146 |
|
| 147 |
badge_markdown = """[<img src="https://raw.githubusercontent.com/OpenAccess-AI-Collective/axolotl/main/image/axolotl-badge-web.png" alt="Built with Axolotl" width="200" height="32"/>](https://github.com/OpenAccess-AI-Collective/axolotl)"""
|