Spaces:
Running
Running
fix: OOM with checkpoints
Browse files
dev/seq2seq/run_seq2seq_flax.py
CHANGED
|
@@ -262,15 +262,15 @@ class TrainState(train_state.TrainState):
|
|
| 262 |
def restore_state(self, artifact_dir):
|
| 263 |
# restore optimizer state
|
| 264 |
with (Path(artifact_dir) / "opt_state.msgpack").open("rb") as f:
|
| 265 |
-
|
| 266 |
|
| 267 |
# restore steps
|
| 268 |
with (Path(artifact_dir) / "training_state.json").open("r") as f:
|
| 269 |
training_state = json.load(f)
|
| 270 |
-
|
| 271 |
|
| 272 |
# replace state
|
| 273 |
-
return self.replace(step=
|
| 274 |
|
| 275 |
|
| 276 |
class CustomFlaxBartModule(FlaxBartModule):
|
|
@@ -802,6 +802,7 @@ def main():
|
|
| 802 |
|
| 803 |
# Replicate the train state on each device
|
| 804 |
state = state.replicate()
|
|
|
|
| 805 |
|
| 806 |
logger.info("***** Running training *****")
|
| 807 |
logger.info(f" Num examples = {len_train_dataset}")
|
|
|
|
| 262 |
def restore_state(self, artifact_dir):
|
| 263 |
# restore optimizer state
|
| 264 |
with (Path(artifact_dir) / "opt_state.msgpack").open("rb") as f:
|
| 265 |
+
new_opt_state = from_bytes(self.opt_state, f.read())
|
| 266 |
|
| 267 |
# restore steps
|
| 268 |
with (Path(artifact_dir) / "training_state.json").open("r") as f:
|
| 269 |
training_state = json.load(f)
|
| 270 |
+
new_step = training_state["step"]
|
| 271 |
|
| 272 |
# replace state
|
| 273 |
+
return self.replace(step=new_step, opt_state=new_opt_state)
|
| 274 |
|
| 275 |
|
| 276 |
class CustomFlaxBartModule(FlaxBartModule):
|
|
|
|
| 802 |
|
| 803 |
# Replicate the train state on each device
|
| 804 |
state = state.replicate()
|
| 805 |
+
del model._params
|
| 806 |
|
| 807 |
logger.info("***** Running training *****")
|
| 808 |
logger.info(f" Num examples = {len_train_dataset}")
|