Upload run_summarization_flax.py
Browse files
- run_summarization_flax.py +19 -17
run_summarization_flax.py
CHANGED
|
@@ -431,23 +431,25 @@ def main():
|
|
| 431 |
return
|
| 432 |
|
| 433 |
# Get the column names for input/target.
|
| 434 |
-
|
| 435 |
-
|
| 436 |
-
|
| 437 |
-
|
| 438 |
-
|
| 439 |
-
|
| 440 |
-
|
| 441 |
-
|
| 442 |
-
|
| 443 |
-
|
| 444 |
-
|
| 445 |
-
|
| 446 |
-
|
| 447 |
-
|
| 448 |
-
|
| 449 |
-
|
| 450 |
-
|
|
|
|
|
|
|
| 451 |
|
| 452 |
# Temporarily set max_target_length for training.
|
| 453 |
max_target_length = data_args.max_target_length
|
|
|
|
| 431 |
return
|
| 432 |
|
| 433 |
# Get the column names for input/target.
|
| 434 |
+
if not data_args.pretokenized:
|
| 435 |
+
|
| 436 |
+
dataset_columns = summarization_name_mapping.get(data_args.dataset_name, None)
|
| 437 |
+
if data_args.text_column is None:
|
| 438 |
+
text_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
|
| 439 |
+
else:
|
| 440 |
+
text_column = data_args.text_column
|
| 441 |
+
if text_column not in column_names:
|
| 442 |
+
raise ValueError(
|
| 443 |
+
f"--text_column' value '{data_args.text_column}' needs to be one of: {', '.join(column_names)}"
|
| 444 |
+
)
|
| 445 |
+
if data_args.summary_column is None:
|
| 446 |
+
summary_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
|
| 447 |
+
else:
|
| 448 |
+
summary_column = data_args.summary_column
|
| 449 |
+
if summary_column not in column_names:
|
| 450 |
+
raise ValueError(
|
| 451 |
+
f"--summary_column' value '{data_args.summary_column}' needs to be one of: {', '.join(column_names)}"
|
| 452 |
+
)
|
| 453 |
|
| 454 |
# Temporarily set max_target_length for training.
|
| 455 |
max_target_length = data_args.max_target_length
|