Merge pull request #32 from borisdayma/feat-model
feat: save and restore checkpoints
Former-commit-id: 6254697762481523764fcb4c8856e63203d2f117
- seq2seq/run_seq2seq_flax.py +55 -12
seq2seq/run_seq2seq_flax.py
CHANGED
@@ -271,6 +271,10 @@ class TrainState(train_state.TrainState):
 
 class CustomFlaxBartModule(FlaxBartModule):
     def setup(self):
+        # check config is valid, otherwise set default values
+        self.config.vocab_size_output = getattr(self.config, 'vocab_size_output', OUTPUT_VOCAB_SIZE)
+        self.config.max_position_embeddings_decoder = getattr(self.config, 'max_position_embeddings_decoder', OUTPUT_LENGTH)
+
         # we keep shared to easily load pre-trained weights
         self.shared = nn.Embed(
             self.config.vocab_size,
@@ -280,7 +284,7 @@ class CustomFlaxBartModule(FlaxBartModule):
         )
         # a separate embedding is used for the decoder
         self.decoder_embed = nn.Embed(
-
+            self.config.vocab_size_output,
             self.config.d_model,
             embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
             dtype=self.dtype,
@@ -289,20 +293,23 @@ class CustomFlaxBartModule(FlaxBartModule):
 
         # the decoder has a different config
         decoder_config = BartConfig(self.config.to_dict())
-        decoder_config.max_position_embeddings =
-        decoder_config.vocab_size =
+        decoder_config.max_position_embeddings = self.config.max_position_embeddings_decoder
+        decoder_config.vocab_size = self.config.vocab_size_output
         self.decoder = FlaxBartDecoder(decoder_config, dtype=self.dtype, embed_tokens=self.decoder_embed)
 
 class CustomFlaxBartForConditionalGenerationModule(FlaxBartForConditionalGenerationModule):
     def setup(self):
+        # check config is valid, otherwise set default values
+        self.config.vocab_size_output = getattr(self.config, 'vocab_size_output', OUTPUT_VOCAB_SIZE)
+
         self.model = CustomFlaxBartModule(config=self.config, dtype=self.dtype)
         self.lm_head = nn.Dense(
-
+            self.config.vocab_size_output,
             use_bias=False,
             dtype=self.dtype,
             kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
         )
-        self.final_logits_bias = self.param("final_logits_bias", self.bias_init, (1,
+        self.final_logits_bias = self.param("final_logits_bias", self.bias_init, (1, self.config.vocab_size_output))
 
 class CustomFlaxBartForConditionalGeneration(FlaxBartForConditionalGeneration):
     module_class = CustomFlaxBartForConditionalGenerationModule
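Aside (not part of the patch): a minimal sketch of how the two optional config fields added above can be supplied explicitly, assuming the custom classes defined in the hunks above are in scope. The constant values and checkpoint name are placeholders standing in for the script's own `OUTPUT_VOCAB_SIZE` / `OUTPUT_LENGTH` and model arguments.

from transformers import BartConfig

OUTPUT_VOCAB_SIZE = 16384  # placeholder values; the script defines its own constants
OUTPUT_LENGTH = 257

config = BartConfig.from_pretrained("facebook/bart-large")  # placeholder checkpoint
# Optional overrides; when they are missing, setup() falls back to the constants via getattr().
config.vocab_size_output = OUTPUT_VOCAB_SIZE
config.max_position_embeddings_decoder = OUTPUT_LENGTH

model = CustomFlaxBartForConditionalGeneration(config)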
@@ -429,11 +436,24 @@ def main():
     # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
     # https://huggingface.co/docs/datasets/loading_datasets.html.
 
-    #
-    tokenizer =
-
-    )
+    # Set up items to load or create
+    tokenizer = None
+    artifact_dir = None
 
+    def restore_state(state, artifact_dir):
+        # restore optimizer state
+        if (Path(artifact_dir) / 'opt_state.msgpack').exists():
+            with (Path(artifact_dir) / 'opt_state.msgpack').open('rb') as f:
+                opt_state = from_bytes(state.opt_state, f.read())
+
+        # restore steps
+        if (Path(artifact_dir) / 'training_state.json').exists():
+            with (Path(artifact_dir) / 'training_state.json').open('r') as f:
+                training_state = json.load(f)
+            step = training_state['step']
+            optimizer_step = step // training_args.gradient_accumulation_steps
+            state.replace(step=step, optimizer_step=optimizer_step)
+
     if model_args.from_checkpoint is not None:
         artifact = wandb.run.use_artifact(model_args.from_checkpoint)
         artifact_dir = artifact.download()
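Aside (not part of the patch): `from_bytes` needs a target with the same pytree structure as the serialized object, which is why `state.opt_state` is passed as the template above. A self-contained sketch of that round trip, using a throwaway optax optimizer:

import jax.numpy as jnp
import optax
from flax.serialization import to_bytes, from_bytes

params = {"w": jnp.ones((2, 2))}
tx = optax.adamw(learning_rate=1e-3)
opt_state = tx.init(params)

raw = to_bytes(opt_state)                    # what gets written to opt_state.msgpack
restored = from_bytes(tx.init(params), raw)  # template with matching structure + raw bytes

The script performs the same call against the opt_state.msgpack file downloaded from the checkpoint artifact.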
@@ -448,6 +468,12 @@ def main():
         # used in the preprocessing function
         config = model.config
 
+        # load tokenizer if present
+        if (Path(artifact_dir) / 'tokenizer_config.json').exists():
+            tokenizer = AutoTokenizer.from_pretrained(
+                model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
+            )
+
     else:
         base_model = FlaxAutoModelForSeq2SeqLM.from_pretrained(
             model_args.model_name_or_path, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
@@ -473,6 +499,12 @@ def main():
         model.params['model']['shared'] = base_model.params['model']['shared']
         del base_model
 
+    # Load tokenizer if it has not been set
+    if tokenizer is None:
+        tokenizer = AutoTokenizer.from_pretrained(
+            model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
+        )
+
     print(f"TPUs: {jax.device_count()}")
     assert jax.device_count() == 8, "TPUs in use, please check running processes"
 
@@ -669,6 +701,9 @@ def main():
         grad_accum=jax.tree_map(jnp.zeros_like, model.params),
         optimizer_step=0,
     )
+    if model_args.from_checkpoint is not None:
+        # restore optimizer state, step and optimizer_step
+        restore_state(state, artifact_dir)
 
     # label smoothed cross entropy
     def loss_fn(logits, labels):
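Aside (not part of the patch): `TrainState` is a Flax struct dataclass, so `.replace(...)` returns a new instance rather than mutating the existing one. A sketch of the same restore logic written in that functional style; the helper name is illustrative, the file names are the ones used by the patch:

import json
from pathlib import Path
from flax.serialization import from_bytes

def restore_state_functional(state, artifact_dir, grad_accum_steps):
    # restore optimizer state from the serialized msgpack, if present
    opt_path = Path(artifact_dir) / "opt_state.msgpack"
    if opt_path.exists():
        state = state.replace(opt_state=from_bytes(state.opt_state, opt_path.read_bytes()))

    # restore step counters from the JSON training state, if present
    train_path = Path(artifact_dir) / "training_state.json"
    if train_path.exists():
        step = json.loads(train_path.read_text())["step"]
        state = state.replace(step=step, optimizer_step=step // grad_accum_steps)
    return state

# state = restore_state_functional(state, artifact_dir, training_args.gradient_accumulation_steps)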
@@ -811,13 +846,16 @@ def main():
             params=params,
         )
 
+        # save tokenizer
+        tokenizer.save_pretrained(training_args.output_dir)
+
         # save state
         state = unreplicate(state)
         with (Path(training_args.output_dir) / 'opt_state.msgpack').open('wb') as f:
             f.write(to_bytes(state.opt_state))
         with (Path(training_args.output_dir) / 'training_state.json').open('w') as f:
             json.dump({'step': state.step.item()}, f)
-
+
         # save to W&B
         if data_args.log_model:
             metadata = {'step': step, 'epoch': epoch}
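Aside (not part of the patch): the extra files added to the artifact in the next hunk correspond to what `save_pretrained` typically writes for a fast BART tokenizer. A quick way to check on a given transformers version; the checkpoint name and output directory are placeholders:

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("facebook/bart-large", use_fast=True)
saved_files = tok.save_pretrained("output")  # returns the paths it wrote
print(saved_files)  # typically tokenizer_config.json, special_tokens_map.json, vocab.json, merges.txt, tokenizer.json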
@@ -826,8 +864,13 @@ def main():
             artifact = wandb.Artifact(
                 name=f"model-{wandb.run.id}", type="bart_model", metadata=metadata
             )
-            artifact.add_file(str(Path(training_args.output_dir) / 'flax_model.msgpack'))
-            artifact.add_file(str(Path(training_args.output_dir) / 'config.json'))
+            artifact.add_file(str(Path(training_args.output_dir) / 'flax_model.msgpack'))
+            artifact.add_file(str(Path(training_args.output_dir) / 'config.json'))
+            artifact.add_file(str(Path(training_args.output_dir) / 'tokenizer.json'))
+            artifact.add_file(str(Path(training_args.output_dir) / 'tokenizer_config.json'))
+            artifact.add_file(str(Path(training_args.output_dir) / 'vocab.json'))
+            artifact.add_file(str(Path(training_args.output_dir) / 'merges.txt'))
+            artifact.add_file(str(Path(training_args.output_dir) / 'special_tokens_map.json'))
             artifact.add_file(str(Path(training_args.output_dir) / 'opt_state.msgpack'))
             artifact.add_file(str(Path(training_args.output_dir) / 'training_state.json'))
             wandb.run.log_artifact(artifact)
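Aside (not part of the patch): end to end, the checkpoint flow is "bundle files into a W&B artifact during training, then resume by passing an artifact reference as `--from_checkpoint`". A minimal sketch of both sides; the project name, file contents and artifact alias are placeholders:

from pathlib import Path
import wandb

run = wandb.init(project="my-project")  # placeholder project

# training side: bundle checkpoint files into an artifact (as in the hunks above)
Path("output").mkdir(exist_ok=True)
Path("output/flax_model.msgpack").write_bytes(b"")  # empty stand-ins so the sketch runs
Path("output/opt_state.msgpack").write_bytes(b"")
artifact = wandb.Artifact(name=f"model-{run.id}", type="bart_model", metadata={"step": 0})
artifact.add_file("output/flax_model.msgpack")
artifact.add_file("output/opt_state.msgpack")
run.log_artifact(artifact)
artifact.wait()  # make sure the upload has finished before referencing it

# resume side: the script passes --from_checkpoint straight to use_artifact()
artifact = run.use_artifact(f"model-{run.id}:latest")
artifact_dir = artifact.download()  # local directory containing the files added above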
|