import pytorch_lightning as pl
import os
import shutil
import fnmatch
import torch
from argparse import ArgumentParser
from pytorch_lightning.callbacks import LearningRateMonitor
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.strategies import DDPStrategy
from pytorch_lightning.loggers import WandbLogger

from dev.utils.func import RankedLogger, load_config_act, CONSOLE
from dev.datasets.scalable_dataset import MultiDataModule
from dev.model.smart import SMART


def backup(source_dir, backup_dir):
    """
    Back up the source directory (code and configs) to a backup directory.

    Only files matching the module-level `include_patterns` are copied, and directories
    matching `exclude_patterns` are skipped; both lists, as well as `logger`, are defined
    in the `__main__` block below.
    """

    if os.path.exists(backup_dir):
        return
    os.makedirs(backup_dir, exist_ok=False)

    # Helper function to check if a path matches exclude patterns
    def should_exclude(path):
        for pattern in exclude_patterns:
            if fnmatch.fnmatch(os.path.basename(path), pattern):
                return True
        return False

    # Iterate through the files and directories in source_dir
    for root, dirs, files in os.walk(source_dir):
        # Skip excluded directories
        dirs[:] = [d for d in dirs if not should_exclude(d)]

        # Determine the relative path and destination path
        rel_path = os.path.relpath(root, source_dir)
        dest_dir = os.path.join(backup_dir, rel_path)
        os.makedirs(dest_dir, exist_ok=True)
        
        # Copy all relevant files
        for file in files:
            if any(fnmatch.fnmatch(file, pattern) for pattern in include_patterns):
                shutil.copy2(os.path.join(root, file), os.path.join(dest_dir, file))
    
    logger.info(f"Backup completed. Files saved to: {backup_dir}")


if __name__ == '__main__':
    pl.seed_everything(2024, workers=True)
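    # Seed Python, NumPy and PyTorch RNGs; workers=True also seeds dataloader worker processes.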
    torch.set_printoptions(precision=3)

    parser = ArgumentParser()
    parser.add_argument('--config', type=str, default='configs/ours_long_term.yaml')
    parser.add_argument('--pretrain_ckpt', type=str, default=None,
                        help='Path to a pretrained model; only its parameters will be loaded.'
    )
    parser.add_argument('--ckpt_path', type=str, default=None,
                        help='Path to a trained checkpoint; all training states will be loaded.'
    )
    parser.add_argument('--save_ckpt_path', type=str, default='output/debug',
                        help='Path to save checkpoints in training mode.'
    )
    parser.add_argument('--save_path', type=str, default=None,
                        help='Path to save inference results in validation and test mode.'
    )
    parser.add_argument('--wandb', action='store_true',
                        help='Whether to use the wandb logger during training.'
    )
    parser.add_argument('--devices', type=int, default=None,
                        help='Number of devices; falls back to the value in the config when not given.'
    )
    parser.add_argument('--train', action='store_true')
    parser.add_argument('--validate', action='store_true')
    parser.add_argument('--test', action='store_true')
    parser.add_argument('--plot_rollouts', action='store_true')
    args = parser.parse_args()

    if not (args.train or args.validate or args.test or args.plot_rollouts):
        raise RuntimeError("No action specified, expected one of ['--train', '--validate', '--test', '--plot_rollouts']")

    # ! setup logger
    logger = RankedLogger(__name__, rank_zero_only=True)

    # ! backup codes
    exclude_patterns = ['*output*', '*logs', 'wandb', 'data', '*debug*', '*backup*', 'interact_*', '*edge_map*', '__pycache__']
    include_patterns = ['*.py', '*.json', '*.yaml', '*.yml', '*.sh']
    backup(os.getcwd(), os.path.join(args.save_ckpt_path, 'backups'))

    config = load_config_act(args.config)

    wandb_logger = None
    if args.wandb and not int(os.getenv('DEBUG', 0)):
        # squeue -O username,state,nodelist,gres,minmemory,numcpus,name
        wandb_logger = WandbLogger(project='simagent')

    trainer_config = config.Trainer
    max_epochs = trainer_config.max_epochs

    # ! setup datamodule and model
    datamodule = MultiDataModule(**vars(config.Dataset), logger=logger)
    model = SMART(config.Model, save_path=args.save_ckpt_path, logger=logger, max_epochs=max_epochs)
    if args.pretrain_ckpt:
        model.load_state_from_file(filename=args.pretrain_ckpt)
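    # DDP settings: find_unused_parameters=True tolerates parameters that receive no gradient in a
    # given step (at some extra overhead), and gradient_as_bucket_view=True lets gradients share
    # memory with DDP's communication buckets to reduce peak memory usage.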
    strategy = DDPStrategy(find_unused_parameters=True, gradient_as_bucket_view=True)
    logger.info(f'Build model: {model.__class__.__name__} datamodule: {datamodule.__class__.__name__}')

    # ! checkpoint configuration
    every_n_epochs = 1
    if int(os.getenv('OVERFIT', 0)):
        max_epochs = trainer_config.overfit_epochs
        every_n_epochs = 100

    if int(os.getenv('CHECK_INPUTS', 0)):
        max_epochs = 1

    check_val_every_n_epoch = 1  # run validation after every training epoch
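    # NOTE: the `every_n_epochs` value above is not passed to ModelCheckpoint below (it would be
    # mutually exclusive with `every_n_train_steps`); checkpoints are saved every 1000 training
    # steps and at the end of each epoch, keeping the 5 latest epochs plus the last checkpoint.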
    model_checkpoint = ModelCheckpoint(dirpath=args.save_ckpt_path,
                                       filename='{epoch:02d}',
                                       save_top_k=5,
                                       monitor='epoch',
                                       mode='max',
                                       save_last=True,
                                       every_n_train_steps=1000,
                                       save_on_train_epoch_end=True)

    # ! setup trainer
    lr_monitor = LearningRateMonitor(logging_interval='epoch')
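    # Trainer notes: gradients are accumulated over `accumulate_grad_batches` batches before each
    # optimizer step and clipped to a global norm of 0.5; the pre-fit sanity validation pass is
    # skipped (num_sanity_val_steps=0).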
    trainer = pl.Trainer(accelerator=trainer_config.accelerator,
                         devices=args.devices if args.devices is not None else trainer_config.devices,
                         strategy=strategy, logger=wandb_logger,
                         accumulate_grad_batches=trainer_config.accumulate_grad_batches,
                         num_nodes=trainer_config.num_nodes,
                         callbacks=[model_checkpoint, lr_monitor],
                         max_epochs=max_epochs,
                         num_sanity_val_steps=0,
                         check_val_every_n_epoch=check_val_every_n_epoch,
                         log_every_n_steps=1,
                         gradient_clip_val=0.5)
    logger.info(f'Build trainer: {trainer.__class__.__name__}')

    # ! run
    if args.train:

        logger.info('Start training ...')
        trainer.fit(model, datamodule, ckpt_path=args.ckpt_path)

    # NOTE: validation, test and plot_rollouts all run on the validation split via trainer.validate().
    # In validation mode, online metric computation is enabled alongside result dumping;
    # in test mode it is disabled and only the inference results are dumped.
    else:

        if args.save_path is not None:
            save_path = args.save_path
        else:
            assert args.ckpt_path is not None and os.path.exists(args.ckpt_path), \
                    f'Path {args.ckpt_path} does not exist!'
            save_path = os.path.join(os.path.dirname(args.ckpt_path), 'validation')
        os.makedirs(save_path, exist_ok=True)
        CONSOLE.log(f'Results will be saved to [yellow]{save_path}[/]')

        model.save_path = save_path

        if not args.ckpt_path:
            CONSOLE.log('[yellow] Warning: no checkpoint will be loaded for validation! [/]')

        if args.validate:

            CONSOLE.log('[on blue] Start validating ... [/]')
            model.set(mode='validation')

        elif args.test:

            CONSOLE.log('[on blue] Start testing ... [/]')
            model.set(mode='test')

        elif args.plot_rollouts:

            CONSOLE.log('[on blue] Start generating ... [/]')
            model.set(mode='plot_rollouts')

        trainer.validate(model, datamodule, ckpt_path=args.ckpt_path)