feat: simplify loss function
seq2seq/run_seq2seq_flax.py  (+8 -22)
@@ -639,33 +639,19 @@ def main():
     state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer, dropout_rng=dropout_rng)

     # label smoothed cross entropy
-    def loss_fn(logits, labels, label_smoothing_factor=0.0):
-        """
-        The label smoothing implementation is adapted from Flax's official example:
-        https://github.com/google/flax/blob/87a211135c6a377c8f29048a1cac3840e38b9da4/examples/wmt/train.py#L104
-        """
-        vocab_size = logits.shape[-1]
-        confidence = 1.0 - label_smoothing_factor
-        low_confidence = (1.0 - confidence) / (vocab_size - 1)
-        normalizing_constant = -(
-            confidence * jnp.log(confidence) + (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20)
-        )
-        soft_labels = onehot(labels, vocab_size, on_value=confidence, off_value=low_confidence)
-
-        loss = optax.softmax_cross_entropy(logits, soft_labels)
-        loss = loss - normalizing_constant
-
+    def loss_fn(logits, labels):
+        loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1]))
         loss = loss.mean()
         return loss

     # Define gradient update step fn
-    def train_step(state, batch, label_smoothing_factor=0.0):
+    def train_step(state, batch):
         dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)

         def compute_loss(params):
             labels = batch.pop("labels")
             logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
-            loss = loss_fn(logits, labels, label_smoothing_factor)
+            loss = loss_fn(logits, labels)
             return loss

         grad_fn = jax.value_and_grad(compute_loss)
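Note on the simplification: the removed code blends the one-hot target with a uniform distribution (confidence 1 - label_smoothing_factor on the true token, the remainder spread over the other vocab_size - 1 tokens) and subtracts the entropy of that soft target as a normalizing constant. With a smoothing factor of 0, confidence is 1, low_confidence and the normalizing constant are 0, and the expression reduces exactly to plain cross entropy, which is what the new loss_fn computes directly. Below is a minimal, self-contained sketch of the simplified loss for reference; it substitutes jax.nn.one_hot for the script's onehot helper (from flax.training.common_utils, which behaves the same here), and the shapes in the usage example are made up for illustration.

# Minimal sketch of the simplified loss (illustrative only, not the script itself).
import jax
import jax.numpy as jnp
import optax

def loss_fn(logits, labels):
    # Plain cross entropy against one-hot targets, no label smoothing.
    targets = jax.nn.one_hot(labels, logits.shape[-1])
    loss = optax.softmax_cross_entropy(logits, targets)
    return loss.mean()

# Hypothetical shapes for a quick check: batch=2, seq_len=3, vocab=5.
logits = jnp.zeros((2, 3, 5))               # uniform predictions
labels = jnp.array([[1, 2, 3], [0, 4, 2]])
print(loss_fn(logits, labels))              # uniform logits -> loss == log(5) ~ 1.609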
@@ -680,10 +666,10 @@ def main():
         return new_state, metrics

     # Define eval fn
-    def eval_step(params, batch, label_smoothing_factor=0.0):
+    def eval_step(params, batch):
         labels = batch.pop("labels")
         logits = model(**batch, params=params, train=False)[0]
-        loss = loss_fn(logits, labels, label_smoothing_factor)
+        loss = loss_fn(logits, labels)

         # summarize metrics
         metrics = {"loss": loss}
@@ -704,9 +690,9 @@ def main():

     # Create parallel version of the train and eval step
     p_train_step = jax.pmap(
-        partial(train_step, label_smoothing_factor=training_args.label_smoothing_factor), "batch", donate_argnums=(0,)
+        train_step, "batch", donate_argnums=(0,)
     )
-    p_eval_step = jax.pmap(partial(eval_step, label_smoothing_factor=training_args.label_smoothing_factor), "batch")
+    p_eval_step = jax.pmap(eval_step, "batch")
     p_generate_step = jax.pmap(generate_step, "batch")

     # Replicate the train state on each device
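Since loss_fn, train_step, and eval_step no longer take a label_smoothing_factor argument, the partial wrappers around the pmapped steps can be dropped. The snippet below is a small, runnable illustration of the jax.pmap pattern kept here, with invented toy functions and shapes rather than the script's real step functions: a named "batch" axis for cross-device reductions plus donate_argnums=(0,) so XLA may reuse the donated state buffers.

# Toy, self-contained illustration of the pmap pattern above (names and shapes
# are hypothetical, not taken from the script).
import jax
import jax.numpy as jnp

def toy_train_step(state, batch):
    # Pretend "state" is a parameter vector and take one gradient-descent-like step.
    grads = batch.mean(axis=0)
    grads = jax.lax.pmean(grads, axis_name="batch")  # average grads across devices
    return state - 0.1 * grads

n_dev = jax.local_device_count()
# "batch" names the mapped axis so collectives like jax.lax.pmean can reduce over
# devices; donate_argnums=(0,) lets XLA reuse the old state's buffers for the output.
p_toy_train_step = jax.pmap(toy_train_step, "batch", donate_argnums=(0,))

state = jnp.zeros((n_dev, 4))    # replicated toy "state", one copy per device
batch = jnp.ones((n_dev, 8, 4))  # one batch shard per device
state = p_toy_train_step(state, batch)
print(state.shape)               # (n_dev, 4)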