# This script trains the model using the decoder loss directly.

# Deep learning
import torch
from trainer import TrainerDirectDecoder
# Parallel
from torch.utils.data.distributed import DistributedSampler
from torch.distributed import init_process_group, destroy_process_group
# Data
from utils import MoleculeModule, get_optim_groups
from torch.utils.data import DataLoader
# Standard library
import os
import args  # local module providing the CLI argument parser


def ddp_setup():
    init_process_group(backend="nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
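
# Note: this script is meant to be launched with torchrun, which sets the
# RANK, WORLD_SIZE and LOCAL_RANK environment variables for each worker;
# "nccl" is the standard backend for multi-GPU training.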


def load_train_objs(config):
    # load data
    train_loader = MoleculeModule(
        config.max_len,
        config.train_load,
        config.data_root
    )
    train_loader.setup()
    loader = DataLoader(
        train_loader.pubchem,
        batch_size=config.n_batch,
        pin_memory=True,
        shuffle=False,  # shuffling is delegated to the DistributedSampler
        collate_fn=train_loader.text_encoder.process,
        sampler=DistributedSampler(train_loader.pubchem),
        num_workers=config.n_workers
    )
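    # DistributedSampler shards the dataset across ranks; for a fresh shuffle
    # each epoch, `loader.sampler.set_epoch(epoch)` must be called at the start
    # of every epoch (presumably handled inside TrainerDirectDecoder).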
    # load model (v1 = light variant, v2 = large variant)
    if config.smi_ted_version == 'v1':
        from smi_ted_light.load import Smi_ted
    elif config.smi_ted_version == 'v2':
        from smi_ted_large.load import Smi_ted
    else:
        raise ValueError(f"Unknown smi_ted_version: {config.smi_ted_version!r}")
    model = Smi_ted(config, train_loader.get_vocab()).to('cuda')
    model.apply(model._init_weights)

    # load optimizer (fused AdamW requires parameters on a CUDA device)
    optim_groups = get_optim_groups(model)
    optimizer = torch.optim.AdamW(optim_groups, lr=config.lr_decoder, betas=(0.9, 0.99), fused=True)

    return loader, model, optimizer
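

# For reference: gradient synchronization in DDP requires wrapping the model
# in DistributedDataParallel, which TrainerDirectDecoder is assumed to do
# internally. A minimal sketch of that wrapping (not part of this script):
#
#   from torch.nn.parallel import DistributedDataParallel as DDP
#   local_rank = int(os.environ["LOCAL_RANK"])
#   model = DDP(model, device_ids=[local_rank])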


def main(
    config,
    save_every: int,
    total_epochs: int,
    save_checkpoint_path: str,
    load_checkpoint_path: str
):
    ddp_setup()

    # training objects
    train_data, model, optimizer = load_train_objs(config)

    # init trainer
    trainer = TrainerDirectDecoder(
        model,
        train_data,
        optimizer,
        save_every,
        save_checkpoint_path,
        load_checkpoint_path,
        config
    )
    trainer.train(total_epochs)

    destroy_process_group()


if __name__ == '__main__':
    parser = args.get_parser()
    config = parser.parse_args()  # named `config` to avoid shadowing the `args` module
    main(
        config,
        config.checkpoint_every,
        config.max_epochs,
        save_checkpoint_path=config.save_checkpoint_path,
        load_checkpoint_path=config.load_checkpoint_path,
    )
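
# Example launch with torchrun (script name and flag values are illustrative;
# the actual flags are defined in the local `args` module):
#
#   torchrun --nnodes=1 --nproc_per_node=4 train_decoder.py \
#       --max_epochs 10 --checkpoint_every 1 \
#       --save_checkpoint_path ./checkpoints \
#       --load_checkpoint_path ./checkpoints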