qhuang20 ryankeivanfar committed
Commit 279b7c3 · verified · 1 Parent(s): f602cd0

Create scripts/train.py (#4)

- Create scripts/train.py (9321f687806c60589ef9d0181b30b8a7ccae6e60)


Co-authored-by: Ryan Keivanfar <[email protected]>

Files changed (1)
  1. scripts/train.py +437 -0
scripts/train.py ADDED
@@ -0,0 +1,437 @@
# train.py
"""
Main training script for TahoeFormer.

This script handles:
- Loading configuration from a YAML file.
- Setting up logging (Weights & Biases).
- Initializing the model (LitEnformerSMILES, using Morgan Fingerprints).
- Initializing dataloaders (TahoeSMILESDataset).
- Setting up PyTorch Lightning Callbacks (ModelCheckpoint, EarlyStopping, MetricLogger).
- Running the training and testing loops using PyTorch Lightning Trainer.
"""

import argparse
import yaml
import os
import torch
import pandas as pd
import lightning.pytorch as pl
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint, LearningRateMonitor
from lightning.pytorch import Trainer
from lightning.pytorch.loggers import WandbLogger
from torch.utils.data import DataLoader, random_split
import wandb

from pl_models import LitEnformerSMILES, MetricLogger
from datasets import TahoeSMILESDataset, ENFORMER_INPUT_SEQ_LENGTH

import warnings
warnings.filterwarnings('ignore', '.*does not have many workers.*')
warnings.filterwarnings('ignore', '.*Detecting val_dataloader.*')

# --- Default Configs --- (can be overridden by config YAML)
DEFAULT_CONFIG = {
    'data': {
        'regions_csv_path': 'data/Enformer_genomic_regions_TSSCenteredGenes_FixedOverlapRemoval.csv',
        'pbulk_parquet_path': 'data/pseudoBulk_celllineXdrug_top3k_for_testing.parquet',
        'drug_meta_csv_path': 'data/drug_metadata.csv',
        'fasta_file_path': 'data/hg38.fa',
        'enformer_input_seq_length': 196_608,
        'morgan_fp_radius': 2,      # For TahoeSMILESDataset
        'morgan_fp_nbits': 2048,    # For TahoeSMILESDataset
        'filter_drugs_by_ids': None,
        # Column name defaults for TahoeSMILESDataset
        'regions_gene_col': 'gene_id',
        'regions_chr_col': 'seqnames',
        'regions_start_col': 'start',
        'regions_end_col': 'end',
        'regions_strand_col': None,
        'regions_set_col': 'set',   # column name for train/val/test split in regions_csv
        'pbulk_gene_col': 'gene_id',
        'pbulk_dose_col': 'drug_dose',
        'pbulk_expr_col': 'expression',
        'pbulk_cell_line_col': 'cell_line_id',
        'drug_meta_id_col': 'drug_id',
        'drug_meta_smiles_col': 'canonical_smiles'
    },
    'model': {
        'enformer_model_name': 'EleutherAI/enformer-official-rough',
        'enformer_target_length': -1,
        'morgan_fingerprint_dim': 2048,
        'dose_input_dim': 1,
        'fusion_hidden_dim': 256,
        'final_output_tracks': 1,
        'learning_rate': 5e-6,
        'loss_alpha': 1.0,
        'weight_decay': 0.01,
        'eval_gene_sets': None
    },
    'training': {
        'batch_size': 2,
        'num_workers': 0,
        'pin_memory': False,
        'max_epochs': 50,
        'gpus': -1,                 # -1 for all available GPUs, or specify count e.g., 1, 2
        'accelerator': 'auto',
        'strategy': 'ddp_find_unused_parameters_true',
        'precision': '16-mixed',    # '32' or '16-mixed' or 'bf16-mixed'
        'val_check_interval': 1.0,
        'limit_train_batches': 1.0,
        'limit_val_batches': 1.0,
        'limit_test_batches': 1.0,
        'deterministic': True,
        'seed': 42
    },
    'logging': {
        'wandb_project': 'TahoeformerDebug',
        'wandb_entity': None,       # W&B info (username or team)
        'save_dir': 'outputs/model_checkpoints',
        'checkpoint_monitor_metric': 'validation_pearson_epoch',
        'checkpoint_monitor_mode': 'max',
        'save_top_k': 1,
        'early_stopping_metric': 'validation_pearson_epoch',
        'early_stopping_mode': 'max',
        'early_stopping_patience': 10
    }
}
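
# A minimal override YAML is sketched below (the parquet path and values are purely
# illustrative); load_config() merges it into DEFAULT_CONFIG, so any key omitted
# here keeps its default from above:
#
#   data:
#     pbulk_parquet_path: data/my_pseudobulk.parquet   # hypothetical file
#   training:
#     batch_size: 4
#     max_epochs: 20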

def delete_checkpoint_at_end(trainer):
    """ Delete checkpoint after training and testing if desired """
    checkpoint_callbacks = [cb for cb in trainer.callbacks if isinstance(cb, ModelCheckpoint)]
    if checkpoint_callbacks:
        checkpoint_callback = checkpoint_callbacks[0]
        if hasattr(checkpoint_callback, 'best_model_path') and checkpoint_callback.best_model_path and os.path.exists(checkpoint_callback.best_model_path):
            print(f"Deleting best checkpoint: {checkpoint_callback.best_model_path}")
            os.remove(checkpoint_callback.best_model_path)
        else:
            print("No best model path found to delete or path does not exist.")
    else:
        print("No ModelCheckpoint callback found.")

def parse_optional_gene_list(filepath):
    """ Parses a file containing one gene name per row and returns a list. Returns an empty list if the path is None or invalid. """
    if filepath is None or not os.path.exists(filepath):
        return []
    gene_list = []
    with open(filepath, 'r') as file:
        for gene in file:
            gene_list.append(gene.strip())
    return gene_list
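
# parse_optional_gene_list expects a plain-text file with one gene identifier per
# line, e.g. (gene symbols purely illustrative):
#
#   TP53
#   GAPDH
#   MYC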

def load_config(config_path=None):
    """Loads configuration from YAML, merging with defaults. Copies DEFAULT_CONFIG two levels deep so the defaults are never mutated."""
    # Manual two-level copy, matching the nesting of DEFAULT_CONFIG
    config = {}
    for k, v in DEFAULT_CONFIG.items():
        if isinstance(v, dict):
            config[k] = v.copy()  # Copies the inner dictionary
        else:
            config[k] = v

    if config_path:
        with open(config_path, 'r') as f:
            user_config = yaml.safe_load(f)
        if user_config:  # Ensure user_config is not None (e.g. if the YAML is empty)
            for key, value in user_config.items():
                if isinstance(value, dict) and key in config and isinstance(config[key], dict):
                    config[key].update(value)  # Merge level-2 dicts
                else:
                    config[key] = value  # Overwrite or add new keys/values
    return config
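
# Merge behaviour, for example: a YAML containing only `training: {batch_size: 8}`
# yields DEFAULT_CONFIG with training.batch_size set to 8 and every other key,
# including the rest of the 'training' section, left at its default value.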

def build_model(config):
    """
    Builds the LitEnformerSMILES model using Morgan Fingerprints.
    Model parameters are sourced from the 'model' section of the config.
    """
    model_params = config['model']
    # Ensure morgan_fingerprint_dim from data config (for dataset) matches model config.
    # The model will use its own `morgan_fingerprint_dim` parameter;
    # the dataset's `morgan_fp_nbits` should align with this.
    print(f"Building LitEnformerSMILES model with morgan_fingerprint_dim: {model_params.get('morgan_fingerprint_dim')}")

    return LitEnformerSMILES(
        enformer_model_name=model_params.get('enformer_model_name'),
        enformer_target_length=model_params.get('enformer_target_length'),
        num_output_tracks_enformer_head=model_params.get('num_output_tracks_enformer_head'),
        morgan_fingerprint_dim=model_params.get('morgan_fingerprint_dim', 2048),  # Default from model if not in config
        dose_input_dim=model_params.get('dose_input_dim'),
        fusion_hidden_dim=model_params.get('fusion_hidden_dim'),
        final_output_tracks=model_params.get('final_output_tracks'),
        learning_rate=model_params.get('learning_rate'),
        loss_alpha=model_params.get('loss_alpha'),
        weight_decay=model_params.get('weight_decay'),
        eval_gene_sets=model_params.get('eval_gene_sets')
    )

def load_tahoe_smiles_dataloaders(config):
    """
    Initializes TahoeSMILESDataset (now using Morgan Fingerprints) and creates DataLoaders.
    Dataset parameters are sourced from the 'data' section of the config.
    Training parameters (batch_size, num_workers) come from the 'training' section.
    """
    data_config = config['data']
    train_config = config['training']

    # Pass Morgan fingerprint params to TahoeSMILESDataset
    dataset_args = {
        'regions_csv_path': data_config['regions_csv_path'],
        'pbulk_parquet_path': data_config['pbulk_parquet_path'],
        'drug_meta_csv_path': data_config['drug_meta_csv_path'],
        'fasta_file_path': data_config['fasta_file_path'],
        'enformer_input_seq_length': data_config.get('enformer_input_seq_length'),
        'morgan_fp_radius': data_config.get('morgan_fp_radius', 2),
        'morgan_fp_nbits': data_config.get('morgan_fp_nbits', 2048),
        'filter_drugs_by_ids': data_config.get('filter_drugs_by_ids'),
        # Pass column name configurations
        'regions_gene_col': data_config.get('regions_gene_col', 'gene_name'),
        'regions_chr_col': data_config.get('regions_chr_col', 'seqnames'),
        'regions_start_col': data_config.get('regions_start_col', 'starts'),
        'regions_end_col': data_config.get('regions_end_col', 'ends'),
        'regions_strand_col': data_config.get('regions_strand_col', None),
        'regions_set_col': data_config.get('regions_set_col', 'set'),  # Added for set-based splitting
        'pbulk_gene_col': data_config.get('pbulk_gene_col', 'gene_id'),
        'pbulk_dose_col': data_config.get('pbulk_dose_col', 'dose_nM'),
        'pbulk_expr_col': data_config.get('pbulk_expr_col', 'value'),
        'pbulk_cell_line_col': data_config.get('pbulk_cell_line_col', 'cell_line_id'),
        'drug_meta_id_col': data_config.get('drug_meta_id_col', 'drug_id'),
        'drug_meta_smiles_col': data_config.get('drug_meta_smiles_col', 'canonical_smiles')
    }

    # print(f"Initializing TahoeSMILESDataset with morgan_fp_nbits: {dataset_args['morgan_fp_nbits']}")

    # Instantiate a dataset for each split using the 'target_set' parameter
    print("Initializing train dataset...")
    train_dataset = TahoeSMILESDataset(**dataset_args, target_set='train')
    print("Initializing validation dataset...")
    val_dataset = TahoeSMILESDataset(**dataset_args, target_set='valid')
    # In the original Enformer_genomic_regions_TSSCenteredGenes_FixedOverlapRemoval.csv,
    # the validation set is often named 'valid'. If it's 'validation' in your file, adjust accordingly.
    print("Initializing test dataset...")
    test_dataset = TahoeSMILESDataset(**dataset_args, target_set='test')

    if len(train_dataset) == 0:
        print("WARNING: Train dataset is empty. This could be due to filtering by set='train' or other data issues. Training might fail or be skipped.")
    if len(val_dataset) == 0:
        print("WARNING: Validation dataset is empty (set='valid'). Validation loop will likely be skipped.")
    if len(test_dataset) == 0:
        print("WARNING: Test dataset is empty (set='test'). Testing loop will likely be skipped.")

    train_loader = DataLoader(
        train_dataset,
        batch_size=train_config.get('batch_size', 2),
        shuffle=True,
        num_workers=train_config.get('num_workers', 0),
        pin_memory=train_config.get('pin_memory', False),
        drop_last=True  # Important for DDP and BatchNorm if batches can be size 1 per GPU
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=train_config.get('batch_size', 2) * 2,  # Often use a larger batch for val
        shuffle=False,
        num_workers=train_config.get('num_workers', 0),
        pin_memory=train_config.get('pin_memory', False)
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=train_config.get('batch_size', 2) * 2,
        shuffle=False,
        num_workers=train_config.get('num_workers', 0),
        pin_memory=train_config.get('pin_memory', False)
    )
    return train_loader, val_loader, test_loader
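
# Quick smoke test of the dataloaders (a sketch; it assumes the data files listed in
# DEFAULT_CONFIG exist locally and the resulting datasets are non-empty; the batch
# structure itself is defined by TahoeSMILESDataset):
#
#   train_loader, val_loader, test_loader = load_tahoe_smiles_dataloaders(DEFAULT_CONFIG)
#   batch = next(iter(train_loader))   # fetch one mini-batch
#   print(type(batch), len(train_loader.dataset))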

def load_trainer_and_callbacks(config, experiment_name_for_wandb, run_name_for_wandb):
    """ Loads PyTorch Lightning Trainer and associated callbacks. """

    metric_logger = MetricLogger(save_dir_prefix=os.path.join(config['logging']['save_dir'], "metrics"))

    checkpoint_dir = os.path.join(config['logging']['save_dir'], 'checkpoints')
    os.makedirs(checkpoint_dir, exist_ok=True)

    monitor_metric = config['logging']['checkpoint_monitor_metric']  # MetricLogger logs with _epoch suffix
    monitor_mode = config['logging']['checkpoint_monitor_mode']
    save_top_k = config['logging'].get('save_top_k', 1)

    checkpoint_callback = ModelCheckpoint(
        dirpath=checkpoint_dir,
        filename=f"{{epoch}}-{{{monitor_metric}:.4f}}",
        save_top_k=save_top_k,
        monitor=monitor_metric,
        mode=monitor_mode
    )

    early_stop_monitor_metric = config['logging']['early_stopping_metric']
    early_stop_mode = config['logging']['early_stopping_mode']
    min_delta = 0.001
    patience = config['logging']['early_stopping_patience']

    early_stopping_callback = EarlyStopping(
        monitor=early_stop_monitor_metric,
        min_delta=min_delta,
        patience=patience,
        verbose=True,
        mode=early_stop_mode
    )

    lr_monitor = LearningRateMonitor(logging_interval='step')

    callbacks = [checkpoint_callback, metric_logger, early_stopping_callback, lr_monitor]

    wandb_logger = None
    if config['logging'].get('wandb_project'):
        wandb_logger = WandbLogger(
            name=run_name_for_wandb,
            project=config['logging']['wandb_project'],
            group=experiment_name_for_wandb,
            config=config,  # Log the entire config dictionary
            save_dir=config['logging']['save_dir'],  # Optional: ensure logger saves to the same base dir
            id=run_name_for_wandb  # Use the unique run_name as the W&B run ID
        )

    trainer = Trainer(
        max_epochs=config['training']['max_epochs'],
        precision=config['training']['precision'],
        accumulate_grad_batches=config['training'].get('accumulate_grad_batches', 1),
        gradient_clip_val=config['training'].get('gradient_clip_val', 0.5),
        callbacks=callbacks,
        logger=wandb_logger,  # Use the configured logger
        num_sanity_val_steps=config['training'].get('num_sanity_val_steps', 0),  # Often 0 if val metrics are complex
        log_every_n_steps=config['training'].get('log_every_n_steps', 50),
        check_val_every_n_epoch=config['training'].get('check_val_every_n_epoch', 1),
        deterministic=config['training']['deterministic'],  # For reproducibility
        strategy=config['training']['strategy'],
        accelerator=config['training']['accelerator'],
        devices=config['training'].get('gpus', 'auto')
    )

    if config['training'].get('accumulate_grad_batches', 1) > 1:
        effective_batch_size = config['training']['batch_size'] * config['training'].get('accumulate_grad_batches', 1)
        print(f"Gradient Accumulation: Effective batch size will be {effective_batch_size}")
        # Log to wandb config if logger is active and wandb.run exists
        if wandb_logger and wandb.run:
            wandb.config.update({'effective_train_batch_size': effective_batch_size}, allow_val_change=True)

    return trainer
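
# Effective batch size, for example: with training.batch_size = 2 and
# accumulate_grad_batches = 4, each optimizer step accumulates 2 * 4 = 8 samples per
# device; under multi-GPU DDP the global batch is that figure times the device count.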

def run_experiment(config: dict):
    """ Main training and evaluation loop. """
    print("Starting experiment with configuration:")
    for key, value in config.items():
        print(f"  {key}: {value}")

    train_loader, val_loader, test_loader = load_tahoe_smiles_dataloaders(config)

    eval_gene_sets = {
        'train_eval_set': parse_optional_gene_list(config.get('eval_train_gene_path')),
        'valid_eval_set': parse_optional_gene_list(config.get('eval_valid_gene_path')),
        'test_eval_set': parse_optional_gene_list(config.get('eval_test_gene_path'))
    }
    eval_gene_sets = {k: v for k, v in eval_gene_sets.items() if v}  # Keep only non-empty lists

    model = build_model(config)

    experiment_name = config.get('experiment_name', 'DefaultExperiment')
    run_name = config.get('run_name', f"{experiment_name}_default_run_id")  # Fallback run_name

    trainer = load_trainer_and_callbacks(config, experiment_name, run_name)

    if config.get('validate_before_train', False) and val_loader.dataset:
        print("Running pre-training validation loop...")
        trainer.validate(model, dataloaders=val_loader)

    print("Starting training...")
    if train_loader.dataset:
        trainer.fit(
            model=model,
            train_dataloaders=train_loader,
            val_dataloaders=val_loader if val_loader.dataset else None
        )
    else:
        print("Skipping training as train_loader is empty.")

    print("Starting testing...")
    if test_loader.dataset:
        best_model_path = trainer.checkpoint_callback.best_model_path if hasattr(trainer.checkpoint_callback, 'best_model_path') else None
        if best_model_path and os.path.exists(best_model_path):
            print(f"Loading best model for testing from: {best_model_path}")
            trainer.test(model, dataloaders=test_loader, ckpt_path=best_model_path)
        elif not best_model_path:
            print("No best_model_path found from checkpoint callback. Testing with current model state (if any training happened).")
            trainer.test(model, dataloaders=test_loader)  # Test with the current model if no checkpoint or if training was skipped
        else:  # best_model_path is set but the checkpoint file no longer exists on disk
            print(f"Best model path {best_model_path} not found. Testing with current model state.")
            trainer.test(model, dataloaders=test_loader)
    else:
        print("Skipping testing as test_loader is empty.")

    if config.get('delete_checkpoint_after_run', False):
        delete_checkpoint_at_end(trainer)


def main():
    parser = argparse.ArgumentParser(description='Run PyTorch Lightning Enformer-SMILES experiment.')
    parser.add_argument('--config_path', type=str, required=True, help='Path to the YAML configuration file.')
    # Allow a seed override from the command line, though the config file is the primary source
    parser.add_argument("--seed", type=int, help="Override seed from config file.")
    args = parser.parse_args()

    effective_config = load_config(args.config_path)

    if args.seed is not None:
        # Ensure the 'training' key exists before placing the seed there
        if 'training' not in effective_config:
            effective_config['training'] = {}
        effective_config['training']['seed'] = args.seed

    seed = effective_config.get('training', {}).get('seed', 42)
    pl.seed_everything(seed, workers=True)

    if effective_config.get('training', {}).get('deterministic', False):
        torch.use_deterministic_algorithms(True, warn_only=True)  # Ensure deterministic ops if requested

    current_script_dir = os.path.dirname(os.path.abspath(__file__))

    experiment_name = effective_config.get('experiment_name', 'EnformerSMILESExperiment')

    run_name = f"{experiment_name}_Seed-{seed}_{pd.Timestamp.now().strftime('%Y%m%d_%H%M%S')}"
    effective_config['run_name'] = run_name
    effective_config['experiment_name'] = experiment_name  # Ensure experiment_name is also in the config

    default_save_dir = os.path.join(current_script_dir, f"../results/{experiment_name}/{run_name}")

    if 'logging' not in effective_config:
        effective_config['logging'] = {}
    # Note: a save_dir already present in the config (including the DEFAULT_CONFIG value) takes precedence over the per-run default directory.
    effective_config['logging']['save_dir'] = effective_config.get('logging', {}).get('save_dir', default_save_dir)
    os.makedirs(effective_config['logging']['save_dir'], exist_ok=True)
    print(f"Results and checkpoints will be saved in: {effective_config['logging']['save_dir']}")

    run_experiment(effective_config)

    if effective_config.get('logging', {}).get('wandb_project') and wandb.run is not None:
        wandb.finish()

if __name__ == '__main__':
    main()