Pedro Cuenca committed
Commit · df3c7bd · 1 Parent(s): a841a4c
Preprocessing: return "labels", "decoder_input_ids" and "decoder_attention_mask".

All fields are required later on to compute the loss.
Note that labels and decoder_input_ids are the same in our case. I'm not sure that's correct, but shifting right the decoder_inputs would lose the last token.
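As a minimal sketch of the trade-off the message describes (illustrative only, not code from this repo): the usual seq2seq right-shift prepends the decoder start token but drops the final label token, whereas prepending the start token keeps every token.

import numpy as np

decoder_start_token_id = 0
labels = np.array([[11, 12, 13, 14]])        # e.g. four target token ids

# Usual right-shift: start token first, final label token is dropped.
shifted = np.full_like(labels, decoder_start_token_id)
shifted[:, 1:] = labels[:, :-1]
print(shifted)     # [[ 0 11 12 13]]  -> 14 is lost

# What the commit does instead: prepend the start token and keep all tokens.
prepended = np.concatenate(
    [np.full((labels.shape[0], 1), decoder_start_token_id, dtype=labels.dtype), labels],
    axis=1,
)
print(prepended)   # [[ 0 11 12 13 14]]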
- seq2seq/run_seq2seq_flax.py +16 -5
seq2seq/run_seq2seq_flax.py
CHANGED
@@ -458,19 +458,30 @@ def main():
         )
 
         # set up targets
-
+        # Note: we prepend the bos token instead of doing `shift_tokens_right` because the latter
+        # removes the last token, and we know we don't need padding. In our case, labels
+        # has a length of exactly 1 + 256, while shifting would produce 256 tokens.
+        labels = [[config.decoder_start_token_id] + eval(indices) for indices in examples['encoding']]
+        labels = np.asarray(labels)
+
+        # We need the labels, in addition to the decoder_input_ids, for the compute_loss function
+        # In our case, they are the same as decoder_input_ids. Is that correct?
+        model_inputs["labels"] = labels
 
         # TODO: if data processing prevents correct compilation, we will:
         # - have data saved in JSONL (to avoid `eval` which is needed here to convert string "[2]" to list[int])
         # - use below `shift_tokens_right_fn`
-
-
-
+        # In our case, this prepends the bos token and removes the last one
+        # decoder_input_ids = shift_tokens_right_fn(
+        #     jnp.array(labels), config.pad_token_id, config.decoder_start_token_id
+        # )
 
-        model_inputs["decoder_input_ids"] =
+        model_inputs["decoder_input_ids"] = labels
 
         # We need decoder_attention_mask so we can ignore pad tokens from loss
         # TODO: I don't believe we need "decoder_attention_mask" in this case because all labels have same length
+        # However, we need to provide a mask or modify the compute_loss function, which relies on having one
+        model_inputs["decoder_attention_mask"] = np.ones(labels.shape)
         #model_inputs["decoder_attention_mask"] = labels["attention_mask"]
 
         return model_inputs
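The all-ones decoder_attention_mask is provided because the loss computation weights each position by the mask. A minimal sketch of that pattern with standard jax/optax (illustrative names, not the repo's compute_loss):

import jax
import jax.numpy as jnp
import optax

def compute_loss(logits, labels, decoder_attention_mask):
    # Per-position cross-entropy against one-hot targets.
    per_token = optax.softmax_cross_entropy(logits, jax.nn.one_hot(labels, logits.shape[-1]))
    # The mask weights each position; padded positions contribute nothing.
    per_token = per_token * decoder_attention_mask
    # Average over the real (unmasked) tokens only.
    return per_token.sum() / decoder_attention_mask.sum()

# With all labels the same length, a mask of ones reduces this to an ordinary mean.
logits = jnp.zeros((2, 5, 16))             # (batch, length, vocab)
labels = jnp.ones((2, 5), dtype=jnp.int32)
mask = jnp.ones(labels.shape)
print(compute_loss(logits, labels, mask))  # log(16) per token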