feat: simplify parameters

dev/seq2seq/run_seq2seq_flax.py  CHANGED  (+20 -47)
@@ -151,7 +151,7 @@ class DataTrainingArguments:
             "than this will be truncated, sequences shorter will be padded."
         },
     )
-
+    use_decay: bool = field(
         default=False,
         metadata={"help": "Whether to use decay in the learning rate scheduler."},
     )
@@ -170,18 +170,16 @@ class DataTrainingArguments:
         },
     )
     preprocessing_num_workers: Optional[int] = field(
-        default=80,  # ensure we have the same datasets cached data and avoid using too much space
-        metadata={"help": "The number of processes to use for the preprocessing."},
-    )
-    source_prefix: Optional[str] = field(
         default=None,
         metadata={
-            "help": "
+            "help": "The number of processes to use for the preprocessing. Not used in streaming mode."
         },
     )
     overwrite_cache: bool = field(
         default=False,
-        metadata={
+        metadata={
+            "help": "Overwrite the cached training and evaluation sets. Not used in streaming mode."
+        },
     )
     log_interval: Optional[int] = field(
         default=40,
@@ -189,41 +187,16 @@ class DataTrainingArguments:
     )
     log_model: bool = field(
         default=False,
-        metadata={"help": "
+        metadata={"help": "Log frequency for model"},
     )
     save_model_steps: Optional[int] = field(
-        default=5000,
-        metadata={
-            "help": "For logging the model more frequently. Used only when `log_model` is set."
-        },
+        default=5000,
+        metadata={"help": "For saving/logging the model more frequently"},
     )

     def __post_init__(self):
         if self.dataset_repo_or_path is None:
             raise ValueError("Need a dataset repository or path.")
-        if self.train_file is None or self.validation_file is None:
-            raise ValueError("Need training/validation file.")
-        else:
-            if self.train_file is not None:
-                extension = self.train_file.split(".")[-1]
-                assert extension in [
-                    "tsv",
-                    "csv",
-                    "json",
-                    "jsonl",
-                ], "`train_file` should be a tsv, csv or json file."
-            if self.validation_file is not None:
-                extension = self.validation_file.split(".")[-1]
-                assert extension in [
-                    "tsv",
-                    "csv",
-                    "json",
-                    "jsonl",
-                ], "`validation_file` should be a tsv, csv or json file."
-        if self.streaming and (self.len_train is None or self.len_eval is None):
-            raise ValueError(
-                "Streaming requires providing length of training and validation datasets"
-            )


 class TrainState(train_state.TrainState):
@@ -291,7 +264,7 @@ def create_learning_rate_fn(
     num_train_epochs: int,
     num_warmup_steps: int,
     learning_rate: float,
-
+    use_decay: bool,
 ) -> Callable[[int], jnp.array]:
     """Returns a linear warmup, linear_decay learning rate function."""
     steps_per_epoch = train_ds_size // train_batch_size
@@ -299,7 +272,7 @@
     warmup_fn = optax.linear_schedule(
         init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps
     )
-    if
+    if not use_decay:
         return warmup_fn
     decay_fn = optax.linear_schedule(
         init_value=learning_rate,
@@ -372,10 +345,13 @@ def main():
     # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
     # (the dataset will be downloaded automatically from the datasets Hub).
     #
-
-
-
-
+    if data_args.train_file is not None or data_args.validation_file is not None:
+        data_files = {
+            "train": data_args.train_file,
+            "validation": data_args.validation_file,
+        }
+    else:
+        data_files = None
     dataset = load_dataset(
         data_args.dataset_repo_or_path,
         data_files=data_files,
@@ -449,8 +425,6 @@ def main():
     print(f"TPUs: {jax.device_count()}")
     assert jax.device_count() == 8, "TPUs in use, please check running processes"

-    prefix = data_args.source_prefix if data_args.source_prefix is not None else ""
-
     # Preprocessing the datasets.
     # We need to tokenize inputs and targets.

@@ -475,7 +449,6 @@ def main():

     def preprocess_function(examples):
         inputs = examples[text_column]
-        inputs = [prefix + inp for inp in inputs] if prefix else inputs
         # Setting padding="max_length" as we need fixed length inputs for jitted functions
         model_inputs = tokenizer(
             inputs,
@@ -617,7 +590,7 @@ def main():
         training_args.num_train_epochs,
         training_args.warmup_steps,
         training_args.learning_rate,
-        data_args.
+        data_args.use_decay,
     )

     # We use Optax's "masking" functionality to not apply weight decay
@@ -625,8 +598,6 @@ def main():
     # mask boolean with the same structure as the parameters.
     # The mask is True for parameters that should be decayed.
     # Note that this mask is specifically adapted for FlaxBart.
-    # For FlaxT5, one should correct the layer norm parameter naming
-    # accordingly - see `run_t5_mlm_flax.py` e.g.
     def decay_mask_fn(params):
         flat_params = traverse_util.flatten_dict(params)
         layer_norm_params = [
@@ -649,6 +620,8 @@ def main():
         # For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74
         optimizer = optax.adafactor(
             learning_rate=learning_rate_fn,
+            weight_decay_rate=training_args.weight_decay,
+            weight_decay_mask=decay_mask_fn
         )
     else:
         optimizer = optax.adamw(
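For context, a minimal self-contained sketch of what create_learning_rate_fn looks like once the new use_decay flag is threaded through. The warmup schedule and the `if not use_decay` early return come from the hunks above; the linear-decay tail and the optax.join_schedules call are assumed from the standard Hugging Face Flax seq2seq examples and are not shown in this diff.

from typing import Callable

import jax.numpy as jnp
import optax


def create_learning_rate_fn(
    train_ds_size: int,
    train_batch_size: int,
    num_train_epochs: int,
    num_warmup_steps: int,
    learning_rate: float,
    use_decay: bool,
) -> Callable[[int], jnp.ndarray]:
    """Linear warmup, then either a constant rate or a linear decay to zero."""
    steps_per_epoch = train_ds_size // train_batch_size
    num_train_steps = steps_per_epoch * num_train_epochs
    warmup_fn = optax.linear_schedule(
        init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps
    )
    if not use_decay:
        # Keep the peak learning rate after warmup.
        return warmup_fn
    decay_fn = optax.linear_schedule(
        init_value=learning_rate,
        end_value=0,
        transition_steps=num_train_steps - num_warmup_steps,
    )
    return optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])


# Example: 1,000 warmup steps over roughly 4,700 total steps, with and without decay.
lr_flat = create_learning_rate_fn(100_000, 64, 3, 1_000, 5e-5, use_decay=False)
lr_decay = create_learning_rate_fn(100_000, 64, 3, 1_000, 5e-5, use_decay=True)
print(lr_flat(4_000), lr_decay(4_000))  # peak rate vs. partially decayed rate

Returning the warmup schedule directly means the learning rate simply stays at its peak value after warmup when use_decay is False, which is what the simplified flag controls.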