# lightning.pytorch==2.4.0
seed_everything: 42
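# Trainer: DDP (find_unused_parameters enabled) on all available devices, fp32 precision,
# Weights & Biases logging under project "rna_tasks"; validation runs every 499 steps and
# checkpoints are written every 500 steps, keeping the best val_spearman.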
trainer:
  accelerator: auto
  strategy:
    class_path: lightning.pytorch.strategies.DDPStrategy
    init_args:
      find_unused_parameters: true
  devices: auto
  num_nodes: 1
  precision: 32
  logger:
    class_path: lightning.pytorch.loggers.WandbLogger
    init_args:
      name: isoform_expression
      save_dir: logs
      project: rna_tasks
  callbacks:
    - class_path: lightning.pytorch.callbacks.ModelCheckpoint
      init_args:
        dirpath: logs/
        filename: best_val:{step}-{val_spearman:.3f}-{val_r2:.3f}
        monitor: val_spearman
        save_last: true
        save_top_k: 1
        mode: max
        every_n_train_steps: 500
        every_n_epochs: null
        save_on_train_epoch_end: true
  max_epochs: 3
  max_steps: -1
  val_check_interval: 499
  check_val_every_n_epoch: null
  log_every_n_steps: 1
  accumulate_grad_batches: 1
  gradient_clip_val: 0.1
  gradient_clip_algorithm: null
  default_root_dir: logs
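# Model: MMSequenceRegression over three backbones: Enformer for DNA (fully fine-tuned),
# aido_rna_1b600m_cds for RNA and esm2_150m for protein (both LoRA-tuned, r=32, alpha=64,
# on query/value), combined by the fusion adapter below into 30 regression outputs.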
model:
  class_path: modelgenerator.tasks.MMSequenceRegression
  init_args:
    backbone:
      class_path: modelgenerator.backbones.enformer
      init_args:
        max_length: 196_608
        frozen: false
    backbone1:
      class_path: modelgenerator.backbones.aido_rna_1b600m_cds
      init_args:
        max_length: 1024
        frozen: false
        use_peft: true
        save_peft_only: true
        lora_r: 32
        lora_alpha: 64
        lora_dropout: 0.1
        lora_target_modules:
          - query
          - value
        config_overwrites:
          hidden_dropout_prob: 0
          attention_probs_dropout_prob: 0
        model_init_args: null
    backbone2:
      class_path: modelgenerator.backbones.esm2_150m
      init_args:
        max_length: 1024
        frozen: false
        use_peft: true
        save_peft_only: true
        lora_r: 32
        lora_alpha: 64
        lora_dropout: 0.1
        lora_target_modules:
          - query
          - value
        config_overwrites:
          hidden_dropout_prob: 0
          attention_probs_dropout_prob: 0
        model_init_args: null
    backbone_order:
      - dna_seq
      - rna_seq
      - protein_seq
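    # Fusion head: per the class names and arguments, each backbone's token embeddings are
    # mean-pooled, projected to project_size=1024 and concatenated (ConcatFusion), then passed
    # through an MLPAdapter with one 1024-unit hidden layer.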
    adapter:
      class_path: modelgenerator.adapters.fusion.MMFusionTokenAdapter
      init_args:
        fusion:
          class_path: modelgenerator.adapters.fusion.ConcatFusion
          init_args:
            project_size: 1024
        pooling: mean_pooling
        adapter:
          class_path: modelgenerator.adapters.MLPAdapter
          init_args:
            hidden_sizes:
              - 1024
            bias: true
            dropout: 0.1
            dropout_in_middle: false
    num_outputs: 30
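    # Optimization: AdamW at lr 1e-4 with a cosine schedule and a warmup covering 1% of
    # training steps (CosineWithWarmup, warmup_ratio 0.01).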
    optimizer:
      class_path: torch.optim.AdamW
      init_args:
        lr: 1e-4
        betas:
          - 0.9
          - 0.999
        eps: 1.0e-08
        weight_decay: 0.01
        amsgrad: false
        maximize: false
        foreach: null
        capturable: false
        differentiable: false
        fused: null
    lr_scheduler:
      class_path: modelgenerator.lr_schedulers.CosineWithWarmup
      init_args:
        warmup_ratio: 0.01
        num_warmup_steps: null
        last_epoch: -1
        verbose: deprecated
    use_legacy_adapter: false
    strict_loading: true
    reset_optimizer_states: false
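# Data: the genbio-ai/transcript_isoform_expression_prediction dataset with paired
# dna_seq / rna_seq / protein_seq inputs, normalized targets, and a batch size of 2 per dataloader.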
data:
  class_path: modelgenerator.data.IsoformExpression
  init_args:
    path: genbio-ai/transcript_isoform_expression_prediction
    config_name: null
    valid_split_name: valid
    train_split_files:
      - train_*.tsv
    valid_split_files:
      - validation.tsv
    test_split_files:
      - test.tsv
    x_col:
      - rna_seq
      - dna_seq
      - protein_seq
    normalize: true
    random_seed: 42
    batch_size: 2
    shuffle: true
    sampler: null
    num_workers: 0
    pin_memory: true
    persistent_workers: false
    cv_num_folds: 1
    cv_test_fold_id: 0
    cv_enable_val_fold: true
    cv_fold_id_col: null
ckpt_path: null
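# A config like this is typically launched through the ModelGenerator CLI, e.g.
# `mgen fit --config config.yaml` (command name assumed from AIDO.ModelGenerator's docs).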