Create scripts/train.py (#4)
Create scripts/train.py (9321f687806c60589ef9d0181b30b8a7ccae6e60)
Co-authored-by: Ryan Keivanfar <[email protected]>
- scripts/train.py +437 -0
scripts/train.py
ADDED
@@ -0,0 +1,437 @@
# train.py
"""
Main training script for TahoeFormer.

This script handles:
- Loading configuration from a YAML file.
- Setting up logging (Weights & Biases).
- Initializing the model (LitEnformerSMILES, using Morgan Fingerprints).
- Initializing dataloaders (TahoeSMILESDataset).
- Setting up PyTorch Lightning Callbacks (ModelCheckpoint, EarlyStopping, MetricLogger).
- Running the training and testing loops using PyTorch Lightning Trainer.
"""
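
# Example invocation (illustrative only; the config filename is hypothetical):
#   python scripts/train.py --config_path configs/train_example.yaml
#   python scripts/train.py --config_path configs/train_example.yaml --seed 123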

import argparse
import yaml
import os
import torch
import pandas as pd
import lightning.pytorch as pl
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint, LearningRateMonitor
from lightning.pytorch import Trainer
from lightning.pytorch.loggers import WandbLogger
from torch.utils.data import DataLoader, random_split
import wandb

from pl_models import LitEnformerSMILES, MetricLogger
from datasets import TahoeSMILESDataset, ENFORMER_INPUT_SEQ_LENGTH

import warnings
warnings.filterwarnings('ignore', '.*does not have many workers.*')
warnings.filterwarnings('ignore', '.*Detecting val_dataloader.*')

# --- Default Configs --- (can be overridden by config YAML)
DEFAULT_CONFIG = {
    'data': {
        'regions_csv_path': 'data/Enformer_genomic_regions_TSSCenteredGenes_FixedOverlapRemoval.csv',
        'pbulk_parquet_path': 'data/pseudoBulk_celllineXdrug_top3k_for_testing.parquet',
        'drug_meta_csv_path': 'data/drug_metadata.csv',
        'fasta_file_path': 'data/hg38.fa',
        'enformer_input_seq_length': 196_608,
        'morgan_fp_radius': 2,  # For TahoeSMILESDataset
        'morgan_fp_nbits': 2048,  # For TahoeSMILESDataset
        'filter_drugs_by_ids': None,
        # Column name defaults for TahoeSMILESDataset
        'regions_gene_col': 'gene_id',
        'regions_chr_col': 'seqnames',
        'regions_start_col': 'start',
        'regions_end_col': 'end',
        'regions_strand_col': None,
        'regions_set_col': 'set',  # column name for train/val/test split in regions_csv
        'pbulk_gene_col': 'gene_id',
        'pbulk_dose_col': 'drug_dose',
        'pbulk_expr_col': 'expression',
        'pbulk_cell_line_col': 'cell_line_id',
        'drug_meta_id_col': 'drug_id',
        'drug_meta_smiles_col': 'canonical_smiles'
    },
    'model': {
        'enformer_model_name': 'EleutherAI/enformer-official-rough',
        'enformer_target_length': -1,
        'morgan_fingerprint_dim': 2048,
        'dose_input_dim': 1,
        'fusion_hidden_dim': 256,
        'final_output_tracks': 1,
        'learning_rate': 5e-6,
        'loss_alpha': 1.0,
        'weight_decay': 0.01,
        'eval_gene_sets': None
    },
    'training': {
        'batch_size': 2,
        'num_workers': 0,
        'pin_memory': False,
        'max_epochs': 50,
        'gpus': -1,  # -1 for all available GPUs, or specify count e.g., 1, 2
        'accelerator': 'auto',
        'strategy': 'ddp_find_unused_parameters_true',
        'precision': '16-mixed',  # '32' or '16-mixed' or 'bf16-mixed'
        'val_check_interval': 1.0,
        'limit_train_batches': 1.0,
        'limit_val_batches': 1.0,
        'limit_test_batches': 1.0,
        'deterministic': True,
        'seed': 42
    },
    'logging': {
        'wandb_project': 'TahoeformerDebug',
        'wandb_entity': None,  # W&B info (username or team)
        'save_dir': 'outputs/model_checkpoints',
        'checkpoint_monitor_metric': 'validation_pearson_epoch',
        'checkpoint_monitor_mode': 'max',
        'save_top_k': 1,
        'early_stopping_metric': 'validation_pearson_epoch',
        'early_stopping_mode': 'max',
        'early_stopping_patience': 10
    }
}
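
# Illustrative sketch of a user YAML that load_config() merges over the defaults above.
# Only the keys being overridden need to appear; values here are examples, not recommendations.
#
#   data:
#     fasta_file_path: /path/to/hg38.fa
#   model:
#     learning_rate: 1.0e-5
#   training:
#     batch_size: 4
#     max_epochs: 10
#   logging:
#     wandb_project: TahoeformerDebug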

def delete_checkpoint_at_end(trainer):
    """ Delete checkpoint after training and testing if desired """
    checkpoint_callbacks = [cb for cb in trainer.callbacks if isinstance(cb, ModelCheckpoint)]
    if checkpoint_callbacks:
        checkpoint_callback = checkpoint_callbacks[0]
        if hasattr(checkpoint_callback, 'best_model_path') and checkpoint_callback.best_model_path and os.path.exists(checkpoint_callback.best_model_path):
            print(f"Deleting best checkpoint: {checkpoint_callback.best_model_path}")
            os.remove(checkpoint_callback.best_model_path)
        else:
            print("No best model path found to delete or path does not exist.")
    else:
        print("No ModelCheckpoint callback found.")


def parse_optional_gene_list(filepath):
    """ Parses a file containing one gene name per row and returns a list. Returns an empty list if the path is None or invalid. """
    if filepath is None or not os.path.exists(filepath):
        return []
    gene_list = []
    with open(filepath, 'r') as file:
        for gene in file:
            gene_list.append(gene.strip())
    return gene_list
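
# Expected gene-list file format: plain text with one gene identifier per line,
# presumably matching the gene naming used elsewhere in the data (e.g. regions_gene_col).
# The identifiers below are illustrative:
#   ENSG00000141510
#   ENSG00000146648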

def load_config(config_path=None):
    """Loads configuration from YAML, merging with defaults. Ensures deep copy of defaults."""
    # Manual deep copy for 2 levels, as DEFAULT_CONFIG is structured
    config = {}
    for k, v in DEFAULT_CONFIG.items():
        if isinstance(v, dict):
            config[k] = v.copy()  # Copies the inner dictionary
        else:
            config[k] = v

    if config_path:
        with open(config_path, 'r') as f:
            user_config = yaml.safe_load(f)
        if user_config:  # Ensure user_config is not None (e.g. if YAML is empty)
            for key, value in user_config.items():
                if isinstance(value, dict) and key in config and isinstance(config[key], dict):
                    config[key].update(value)  # Merge level 2 dicts
                else:
                    config[key] = value  # Overwrite or add new keys/values
    return config
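
# Merge behaviour sketch: a YAML containing only
#   training:
#     batch_size: 8
# yields config['training']['batch_size'] == 8 while the remaining 'training' defaults
# (max_epochs, precision, seed, ...) are preserved. The merge is two levels deep:
# values nested below the second level are replaced wholesale rather than merged.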

def build_model(config):
    """
    Builds the LitEnformerSMILES model using Morgan Fingerprints.
    Model parameters are sourced from the 'model' section of the config.
    """
    model_params = config['model']
    # Ensure morgan_fingerprint_dim from data config (for dataset) matches model config.
    # Model will use its own `morgan_fingerprint_dim` parameter.
    # The dataset's `morgan_fp_nbits` should align with this.
    print(f"Building LitEnformerSMILES model with morgan_fingerprint_dim: {model_params.get('morgan_fingerprint_dim')}")

    return LitEnformerSMILES(
        enformer_model_name=model_params.get('enformer_model_name'),
        enformer_target_length=model_params.get('enformer_target_length'),
        # num_output_tracks_enformer_head is not set in DEFAULT_CONFIG; if absent it is
        # passed as None here (presumably deferring to the model's own default).
        num_output_tracks_enformer_head=model_params.get('num_output_tracks_enformer_head'),
        morgan_fingerprint_dim=model_params.get('morgan_fingerprint_dim', 2048),  # Default from model if not in config
        dose_input_dim=model_params.get('dose_input_dim'),
        fusion_hidden_dim=model_params.get('fusion_hidden_dim'),
        final_output_tracks=model_params.get('final_output_tracks'),
        learning_rate=model_params.get('learning_rate'),
        loss_alpha=model_params.get('loss_alpha'),
        weight_decay=model_params.get('weight_decay'),
        eval_gene_sets=model_params.get('eval_gene_sets')
    )
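
# eval_gene_sets is forwarded to the model as-is; run_experiment() below assembles it as a
# dict of named gene lists, e.g. {'train_eval_set': [...], 'valid_eval_set': [...], 'test_eval_set': [...]}.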

def load_tahoe_smiles_dataloaders(config):
    """
    Initializes TahoeSMILESDataset (now using Morgan Fingerprints) and creates DataLoaders.
    Dataset parameters are sourced from the 'data' section of the config.
    Training parameters (batch_size, num_workers) from 'training' section.
    """
    data_config = config['data']
    train_config = config['training']

    # Pass Morgan fingerprint params to TahoeSMILESDataset
    dataset_args = {
        'regions_csv_path': data_config['regions_csv_path'],
        'pbulk_parquet_path': data_config['pbulk_parquet_path'],
        'drug_meta_csv_path': data_config['drug_meta_csv_path'],
        'fasta_file_path': data_config['fasta_file_path'],
        'enformer_input_seq_length': data_config.get('enformer_input_seq_length'),
        'morgan_fp_radius': data_config.get('morgan_fp_radius', 2),
        'morgan_fp_nbits': data_config.get('morgan_fp_nbits', 2048),
        'filter_drugs_by_ids': data_config.get('filter_drugs_by_ids'),
        # Pass column name configurations
        'regions_gene_col': data_config.get('regions_gene_col', 'gene_name'),
        'regions_chr_col': data_config.get('regions_chr_col', 'seqnames'),
        'regions_start_col': data_config.get('regions_start_col', 'starts'),
        'regions_end_col': data_config.get('regions_end_col', 'ends'),
        'regions_strand_col': data_config.get('regions_strand_col', None),
        'regions_set_col': data_config.get('regions_set_col', 'set'),  # Added for set-based splitting
        'pbulk_gene_col': data_config.get('pbulk_gene_col', 'gene_id'),
        'pbulk_dose_col': data_config.get('pbulk_dose_col', 'dose_nM'),
        'pbulk_expr_col': data_config.get('pbulk_expr_col', 'value'),
        'pbulk_cell_line_col': data_config.get('pbulk_cell_line_col', 'cell_line_id'),
        'drug_meta_id_col': data_config.get('drug_meta_id_col', 'drug_id'),
        'drug_meta_smiles_col': data_config.get('drug_meta_smiles_col', 'canonical_smiles')
    }

    # print(f"Initializing TahoeSMILESDataset with morgan_fp_nbits: {dataset_args['morgan_fp_nbits']}")

    # Instantiate dataset for each split using the 'target_set' parameter
    print("Initializing train dataset...")
    train_dataset = TahoeSMILESDataset(**dataset_args, target_set='train')
    print("Initializing validation dataset...")
    val_dataset = TahoeSMILESDataset(**dataset_args, target_set='valid')
    # In the original Enformer_genomic_regions_TSSCenteredGenes_FixedOverlapRemoval.csv,
    # the validation set is often named 'valid'. If it's 'validation' in your file, adjust accordingly.
    print("Initializing test dataset...")
    test_dataset = TahoeSMILESDataset(**dataset_args, target_set='test')

    if len(train_dataset) == 0:
        print("WARNING: Train dataset is empty. This could be due to filtering by set='train' or other data issues. Training might fail or be skipped.")
    if len(val_dataset) == 0:
        print("WARNING: Validation dataset is empty (set='valid'). Validation loop will likely be skipped.")
    if len(test_dataset) == 0:
        print("WARNING: Test dataset is empty (set='test'). Testing loop will likely be skipped.")

    train_loader = DataLoader(
        train_dataset,
        batch_size=train_config.get('batch_size', 2),
        shuffle=True,
        num_workers=train_config.get('num_workers', 0),
        pin_memory=train_config.get('pin_memory', False),
        drop_last=True  # Important for DDP and BatchNorm if batches can be size 1 per GPU
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=train_config.get('batch_size', 2) * 2,  # Often use larger batch for val
        shuffle=False,
        num_workers=train_config.get('num_workers', 0),
        pin_memory=train_config.get('pin_memory', False)
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=train_config.get('batch_size', 2) * 2,
        shuffle=False,
        num_workers=train_config.get('num_workers', 0),
        pin_memory=train_config.get('pin_memory', False)
    )
    return train_loader, val_loader, test_loader


def load_trainer_and_callbacks(config, experiment_name_for_wandb, run_name_for_wandb):
    """ Loads PyTorch Lightning Trainer and associated callbacks. """

    metric_logger = MetricLogger(save_dir_prefix=os.path.join(config['logging']['save_dir'], "metrics"))

    checkpoint_dir = os.path.join(config['logging']['save_dir'], 'checkpoints')
    os.makedirs(checkpoint_dir, exist_ok=True)

    monitor_metric = config['logging']['checkpoint_monitor_metric']  # MetricLogger logs with _epoch suffix
    monitor_mode = config['logging']['checkpoint_monitor_mode']
    save_top_k = config['logging'].get('save_top_k', 1)

    # With Lightning's default auto_insert_metric_name, the filename template below resolves
    # to checkpoint names like 'epoch=3-validation_pearson_epoch=0.4213.ckpt'.
    checkpoint_callback = ModelCheckpoint(
        dirpath=checkpoint_dir,
        filename=f"{{epoch}}-{{{monitor_metric}:.4f}}",
        save_top_k=save_top_k,
        monitor=monitor_metric,
        mode=monitor_mode
    )

    early_stop_monitor_metric = config['logging']['early_stopping_metric']
    early_stop_mode = config['logging']['early_stopping_mode']
    min_delta = 0.001
    patience = config['logging']['early_stopping_patience']

    early_stopping_callback = EarlyStopping(
        monitor=early_stop_monitor_metric,
        min_delta=min_delta,
        patience=patience,
        verbose=True,
        mode=early_stop_mode
    )

    lr_monitor = LearningRateMonitor(logging_interval='step')

    callbacks = [checkpoint_callback, metric_logger, early_stopping_callback, lr_monitor]

    wandb_logger = None
    if config['logging'].get('wandb_project'):
        wandb_logger = WandbLogger(
            name=run_name_for_wandb,
            project=config['logging']['wandb_project'],
            group=experiment_name_for_wandb,
            config=config,  # Log the entire config dictionary
            save_dir=config['logging']['save_dir'],  # Optional: ensure logger saves to the same base dir
            id=run_name_for_wandb  # Use the unique run_name as the W&B run ID
        )

    trainer = Trainer(
        max_epochs=config['training']['max_epochs'],
        precision=config['training']['precision'],
        accumulate_grad_batches=config['training'].get('accumulate_grad_batches', 1),
        gradient_clip_val=config['training'].get('gradient_clip_val', 0.5),
        callbacks=callbacks,
        logger=wandb_logger,  # Use the configured logger
        num_sanity_val_steps=config['training'].get('num_sanity_val_steps', 0),  # Often 0 if val metrics are complex
        log_every_n_steps=config['training'].get('log_every_n_steps', 50),
        check_val_every_n_epoch=config['training'].get('check_val_every_n_epoch', 1),
        deterministic=config['training']['deterministic'],  # For reproducibility
        strategy=config['training']['strategy'],
        accelerator=config['training']['accelerator'],
        devices=config['training'].get('gpus', 'auto')
    )

    if config['training'].get('accumulate_grad_batches', 1) > 1:
        effective_batch_size = config['training']['batch_size'] * config['training'].get('accumulate_grad_batches', 1)
        print(f"Gradient Accumulation: Effective batch size will be {effective_batch_size}")
        # Log to wandb config if logger is active and wandb.run exists
        if wandb_logger and wandb.run:
            wandb.config.update({'effective_train_batch_size': effective_batch_size}, allow_val_change=True)

    return trainer

def run_experiment(config: dict):
    """ Main training and evaluation loop. """
    print("Starting experiment with configuration:")
    for key, value in config.items():
        print(f" {key}: {value}")

    train_loader, val_loader, test_loader = load_tahoe_smiles_dataloaders(config)

    eval_gene_sets = {
        'train_eval_set': parse_optional_gene_list(config.get('eval_train_gene_path')),
        'valid_eval_set': parse_optional_gene_list(config.get('eval_valid_gene_path')),
        'test_eval_set': parse_optional_gene_list(config.get('eval_test_gene_path'))
    }
    eval_gene_sets = {k: v for k, v in eval_gene_sets.items() if v}  # Keep only non-empty lists
    if eval_gene_sets:
        # Wire the parsed gene lists into the model config; otherwise they would go unused.
        config['model']['eval_gene_sets'] = eval_gene_sets

    model = build_model(config)

    experiment_name = config.get('experiment_name', 'DefaultExperiment')
    run_name = config.get('run_name', f"{experiment_name}_default_run_id")  # Fallback run_name

    trainer = load_trainer_and_callbacks(config, experiment_name, run_name)

    if config.get('validate_before_train', False) and val_loader.dataset:
        print("Running pre-training validation loop...")
        trainer.validate(model, dataloaders=val_loader)

    print("Starting training...")
    if train_loader.dataset:
        trainer.fit(
            model=model,
            train_dataloaders=train_loader,
            val_dataloaders=val_loader if val_loader.dataset else None
        )
    else:
        print("Skipping training as train_loader is empty.")

    print("Starting testing...")
    if test_loader.dataset:
        best_model_path = trainer.checkpoint_callback.best_model_path if hasattr(trainer.checkpoint_callback, 'best_model_path') else None
        if best_model_path and os.path.exists(best_model_path):
            print(f"Loading best model for testing from: {best_model_path}")
            trainer.test(model, dataloaders=test_loader, ckpt_path=best_model_path)
        elif not best_model_path:
            print("No best_model_path found from checkpoint callback. Testing with current model state (if any training happened).")
            trainer.test(model, dataloaders=test_loader)  # Test with current model if no checkpoint or if training was skipped
        else:  # best_model_path is set but the file no longer exists
            print(f"Best model path {best_model_path} not found. Testing with current model state.")
            trainer.test(model, dataloaders=test_loader)
    else:
        print("Skipping testing as test_loader is empty.")

    if config.get('delete_checkpoint_after_run', False):
        delete_checkpoint_at_end(trainer)


def main():
    parser = argparse.ArgumentParser(description='Run PyTorch Lightning Enformer-SMILES experiment.')
    parser.add_argument('--config_path', type=str, required=True, help='Path to the YAML configuration file.')
    # allow seed override from command line, though config is primary source
    parser.add_argument("--seed", type=int, help="Override seed from config file.")
    args = parser.parse_args()

    effective_config = load_config(args.config_path)

    if args.seed is not None:
        # Ensure 'training' key exists if seed is to be put there
        if 'training' not in effective_config:
            effective_config['training'] = {}
        effective_config['training']['seed'] = args.seed

    seed = effective_config.get('training', {}).get('seed', 42)
    pl.seed_everything(seed, workers=True)

    if effective_config.get('training', {}).get('deterministic', False):
        torch.use_deterministic_algorithms(True, warn_only=True)  # Ensure deterministic ops if requested

    current_script_dir = os.path.dirname(os.path.abspath(__file__))

    experiment_name = effective_config.get('experiment_name', 'EnformerSMILESExperiment')

    run_name = f"{experiment_name}_Seed-{seed}_{pd.Timestamp.now().strftime('%Y%m%d_%H%M%S')}"
    effective_config['run_name'] = run_name
    effective_config['experiment_name'] = experiment_name  # Ensure experiment_name is also in config

    default_save_dir = os.path.join(current_script_dir, f"../results/{experiment_name}/{run_name}")

    if 'logging' not in effective_config:
        effective_config['logging'] = {}
    # NOTE: DEFAULT_CONFIG already defines logging.save_dir, so default_save_dir only takes
    # effect if that key is absent from the merged config.
    effective_config['logging']['save_dir'] = effective_config.get('logging', {}).get('save_dir', default_save_dir)
    os.makedirs(effective_config['logging']['save_dir'], exist_ok=True)
    print(f"Results and checkpoints will be saved in: {effective_config['logging']['save_dir']}")

    run_experiment(effective_config)

    if effective_config.get('logging', {}).get('wandb_project') and wandb.run is not None:
        wandb.finish()


if __name__ == '__main__':
    main()