|
local env = import "../env.jsonnet"; |
|
local base = import "ace.jsonnet"; |
|
|
|
// Archive of the pretrained base ACE model to fine-tune from.
// Presumably env.str reads the PRETRAINED_PATH environment variable,
// falling back to the default path — confirm against env.jsonnet.
local pretrained_path = env.str("PRETRAINED_PATH", "cache/ace/best");

// Fine-tuning learning rate for the optimizer's "base" (Adam) group.
// Parsed via env.json so FT_LR="3e-5" style overrides work — confirm
// parsing semantics against env.jsonnet.
local lr = env.json("FT_LR", 5e-5);



// training

// Device list inherited from the base config; drives the single- vs
// multi-GPU conditional fields below.
local cuda_devices = base.cuda_devices;
|
|
|
{
  // Data pipeline: inherited unchanged from the base ACE config so the
  // fine-tuning run sees exactly the same reader, splits, and loaders.
  dataset_reader: base.dataset_reader,

  train_data_path: base.train_data_path,

  validation_data_path: base.validation_data_path,

  test_data_path: base.test_data_path,

  datasets_for_vocab_creation: ["train"],

  data_loader: base.data_loader,

  validation_data_loader: base.validation_data_loader,

  // Load the whole pretrained model from the base run's archive, and
  // reuse that run's vocabulary files so label/token indices stay
  // aligned with the restored weights.
  model: {
    type: "from_archive",
    archive_file: pretrained_path,
  },

  vocabulary: {
    type: "from_files",
    directory: pretrained_path + "/vocabulary",
  },

  trainer: {
    num_epochs: base.trainer.num_epochs,

    patience: base.trainer.patience,

    // Single-device runs pin the GPU here; multi-device runs get their
    // device list via the top-level "distributed" field instead.
    [if std.length(cuda_devices) == 1 then "cuda_device"]: cuda_devices[0],

    validation_metric: "+arg-c_f",

    num_gradient_accumulation_steps: base.trainer.num_gradient_accumulation_steps,

    optimizer: {
      type: "transformer",

      base: {
        type: "adam",
        lr: lr,
      },

      // Embeddings are frozen (lr 0.0); encoder and pooler are tuned at
      // a rate lower than the base group's `lr` above.
      embeddings_lr: 0.0,

      encoder_lr: 1e-5,

      pooler_lr: 1e-5,

      layer_fix: base.trainer.optimizer.layer_fix,
    },
  },

  // Multi-GPU only: hand the full device list to the distributed trainer.
  // (Key was quoted in the original — unquoted now for consistency with
  // every other field; the manifested JSON is identical.)
  [if std.length(cuda_devices) > 1 then "distributed"]: {
    cuda_devices: cuda_devices,
  },

  // Test-set evaluation is only enabled for the single-device case.
  [if std.length(cuda_devices) == 1 then "evaluate_on_test"]: true,
}
|
|