feat: padding mask not required
seq2seq/run_seq2seq_flax.py CHANGED  (+4 -15)
@@ -487,10 +487,6 @@ def main():

         model_inputs["decoder_input_ids"] = labels

-        # We need decoder_attention_mask so we can ignore pad tokens from loss
-        # TODO: I don't believe we need "decoder_attention_mask" in this case because all labels have same length
-        #model_inputs["decoder_attention_mask"] = labels["attention_mask"]
-
         return model_inputs

     if training_args.do_train:

@@ -643,7 +639,7 @@ def main():
     state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer, dropout_rng=dropout_rng)

     # label smoothed cross entropy
-    def loss_fn(logits, labels, padding_mask, label_smoothing_factor=0.0):
+    def loss_fn(logits, labels, label_smoothing_factor=0.0):
         """
         The label smoothing implementation is adapted from Flax's official example:
         https://github.com/google/flax/blob/87a211135c6a377c8f29048a1cac3840e38b9da4/examples/wmt/train.py#L104

@@ -659,12 +655,7 @@ def main():
         loss = optax.softmax_cross_entropy(logits, soft_labels)
         loss = loss - normalizing_constant

-
-        padding_mask = np.ones(loss.shape)
-
-        # ignore padded tokens from loss
-        loss = loss * padding_mask
-        loss = loss.sum() / padding_mask.sum()
+        loss = loss.mean()
         return loss

     # Define gradient update step fn

@@ -674,8 +665,7 @@ def main():
         def compute_loss(params):
             labels = batch.pop("labels")
             logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
-
-            loss = loss_fn(logits, labels, padding_mask, label_smoothing_factor)
+            loss = loss_fn(logits, labels, label_smoothing_factor)
             return loss

         grad_fn = jax.value_and_grad(compute_loss)

@@ -693,8 +683,7 @@ def main():
     def eval_step(params, batch, label_smoothing_factor=0.0):
         labels = batch.pop("labels")
         logits = model(**batch, params=params, train=False)[0]
-
-        loss = loss_fn(logits, labels, padding_mask, label_smoothing_factor)
+        loss = loss_fn(logits, labels, label_smoothing_factor)

         # summarize metrics
         metrics = {"loss": loss}
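The simplification rests on the assumption spelled out in the removed TODO: every target sequence is padded to the same fixed length, so the mask built with np.ones(loss.shape) was all ones and loss.sum() / padding_mask.sum() was already just the mean. The following standalone sketch is written for this note, not taken from run_seq2seq_flax.py; the helper name label_smoothed_loss, the toy shapes, and the self-check are illustrative assumptions. It shows the simplified label-smoothed loss and checks that the masked and unmasked averages agree when the mask is all ones.

    import jax
    import jax.numpy as jnp
    import optax


    def label_smoothed_loss(logits, labels, label_smoothing_factor=0.0):
        """Label-smoothed cross entropy averaged over all tokens (no padding mask)."""
        vocab_size = logits.shape[-1]
        confidence = 1.0 - label_smoothing_factor
        low_confidence = (1.0 - confidence) / (vocab_size - 1)
        normalizing_constant = -(
            confidence * jnp.log(confidence)
            + (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20)
        )
        # Soft targets as in the Flax WMT example: `confidence` on the true class,
        # `low_confidence` spread over the remaining classes.
        onehot = jax.nn.one_hot(labels, vocab_size)
        soft_labels = onehot * confidence + (1.0 - onehot) * low_confidence
        loss = optax.softmax_cross_entropy(logits, soft_labels) - normalizing_constant
        # All labels share the same length here, so a plain mean replaces
        # loss.sum() / padding_mask.sum() with an all-ones mask.
        return loss.mean()


    # Sanity check on toy data: with an all-ones mask the two averages coincide.
    key_logits, key_labels = jax.random.split(jax.random.PRNGKey(0))
    logits = jax.random.normal(key_logits, (2, 5, 11))        # (batch, target_len, vocab)
    labels = jax.random.randint(key_labels, (2, 5), 0, 11)    # fixed-length targets
    padding_mask = jnp.ones(labels.shape)

    per_token = optax.softmax_cross_entropy(logits, jax.nn.one_hot(labels, 11))
    assert jnp.allclose((per_token * padding_mask).sum() / padding_mask.sum(),
                        per_token.mean())

    print(label_smoothed_loss(logits, labels, label_smoothing_factor=0.1))

Dropping the mask also means compute_loss and eval_step no longer have to build or pass a padding_mask value, which is why the call sites in the diff shrink to loss_fn(logits, labels, label_smoothing_factor).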