File size: 4,849 Bytes
9b9e0ee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
# Author: Haohe Liu
# Email: [email protected]
# Date: 11 Feb 2023

import sys

sys.path.append("src")

import os
import wandb

import argparse
import yaml
import torch
from pytorch_lightning.strategies.ddp import DDPStrategy
from qa_mdt.audioldm_train.utilities.data.dataset import AudioDataset
from torch.utils.data import DataLoader
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning import Trainer
from qa_mdt.audioldm_train.modules.latent_encoder.autoencoder import AutoencoderKL
from pytorch_lightning.callbacks import ModelCheckpoint
from qa_mdt.audioldm_train.utilities.tools import get_restore_step


def listdir_nohidden(path):
    """Lazily yield the entries of ``path`` whose names do not begin with "."."""
    yield from (entry for entry in os.listdir(path) if not entry.startswith("."))


def main(configs, exp_group_name, exp_name):
    """Train an ``AutoencoderKL`` model from a parsed configuration.

    Builds the train/val ``AudioDataset`` loaders, instantiates the model,
    configures checkpointing and Weights & Biases logging, decides which
    checkpoint (if any) to resume from, and runs ``Trainer.fit``.

    All settings are read from the ``configs`` argument (not from a
    module-level global), so this function can be imported and called from
    other modules.

    Args:
        configs: Parsed YAML configuration (a dict). Keys read here:
            ``model`` (``params`` and ``base_learning_rate``),
            ``log_directory``, ``data``, ``preprocessing``, ``step``,
            ``project``, and the optional ``precision`` and
            ``reload_from_ckpt``.
        exp_group_name: Experiment group name; used as a sub-directory of
            the log directory and as part of the W&B run name.
        exp_name: Experiment name; used as a sub-directory of the group
            directory and as part of the W&B run name.
    """
    if "precision" in configs:
        torch.set_float32_matmul_precision(configs["precision"])

    model_params = configs["model"]["params"]
    batch_size = model_params["batchsize"]
    log_path = configs["log_directory"]

    # Optional extra dataloader features; default to none.
    dataloader_add_ons = configs["data"].get("dataloader_add_ons", [])

    dataset = AudioDataset(configs, split="train", add_ons=dataloader_add_ons)

    loader = DataLoader(
        dataset, batch_size=batch_size, num_workers=8, pin_memory=True, shuffle=True
    )

    print(
        "The length of the dataset is %s, the length of the dataloader is %s, the batchsize is %s"
        % (len(dataset), len(loader), batch_size)
    )

    val_dataset = AudioDataset(configs, split="val", add_ons=dataloader_add_ons)

    # NOTE(review): the validation loader is shuffled; together with
    # limit_val_batches below this presumably samples a random validation
    # subset on each run -- confirm this is intended.
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        num_workers=8,
        shuffle=True,
    )

    model = AutoencoderKL(
        ddconfig=model_params["ddconfig"],
        lossconfig=model_params["lossconfig"],
        embed_dim=model_params["embed_dim"],
        image_key=model_params["image_key"],
        base_learning_rate=configs["model"]["base_learning_rate"],
        subband=model_params["subband"],
        sampling_rate=configs["preprocessing"]["audio"]["sampling_rate"],
    )

    # Optional checkpoint named in the config; None when the key is absent.
    config_reload_from_ckpt = configs.get("reload_from_ckpt")

    wandb_path = os.path.join(log_path, exp_group_name, exp_name)
    checkpoint_path = os.path.join(wandb_path, "checkpoints")

    checkpoint_callback = ModelCheckpoint(
        dirpath=checkpoint_path,
        monitor="global_step",
        mode="max",
        filename="checkpoint-{global_step:.0f}",
        every_n_train_steps=5000,
        save_top_k=configs["step"]["save_top_k"],
        auto_insert_metric_name=False,
        save_last=True,
    )

    model.set_log_dir(log_path, exp_group_name, exp_name)

    os.makedirs(checkpoint_path, exist_ok=True)

    # Resume priority: existing checkpoints of this experiment first, then
    # the checkpoint named in the config, otherwise train from scratch.
    if os.listdir(checkpoint_path):
        print("Load checkpoint from path: %s" % checkpoint_path)
        restore_step, _ = get_restore_step(checkpoint_path)
        resume_from_checkpoint = os.path.join(checkpoint_path, restore_step)
        print("Resume from checkpoint", resume_from_checkpoint)
    elif config_reload_from_ckpt is not None:
        resume_from_checkpoint = config_reload_from_ckpt
        print("Reload ckpt specified in the config file %s" % resume_from_checkpoint)
    else:
        print("Train from scratch")
        resume_from_checkpoint = None

    # Use every CUDA device visible to this process.
    devices = torch.cuda.device_count()

    wandb_logger = WandbLogger(
        save_dir=wandb_path,
        project=configs["project"],
        config=configs,
        name="%s/%s" % (exp_group_name, exp_name),
    )

    trainer = Trainer(
        accelerator="gpu",
        devices=devices,
        logger=wandb_logger,
        limit_val_batches=100,
        callbacks=[checkpoint_callback],
        strategy=DDPStrategy(find_unused_parameters=True),
        val_check_interval=2000,
    )

    # TRAINING
    trainer.fit(model, loader, val_loader, ckpt_path=resume_from_checkpoint)

    # EVALUATION (disabled; no test_loader is built in this function)
    # trainer.test(model, test_loader, ckpt_path=resume_from_checkpoint)


if __name__ == "__main__":
    # Command-line entry point: parse the config path, derive the experiment
    # names from it, load the YAML file and start training.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-c",
        "--autoencoder_config",
        type=str,
        required=True,
        help="path to autoencoder config .yaml",
    )

    args = parser.parse_args()

    config_path = args.autoencoder_config

    # Experiment name = config file name up to its first dot. Take the
    # basename BEFORE splitting so dots elsewhere in the path (e.g.
    # "./configs/x.yaml" or a dotted directory name) cannot truncate it.
    exp_name = os.path.basename(config_path).split(".")[0]
    # Experiment group = name of the directory that contains the config file.
    exp_group_name = os.path.basename(os.path.dirname(config_path))

    # The parsed config stays bound to the module-level name `config_yaml`,
    # which code in this module has historically read as a global.
    # NOTE(security): FullLoader can construct some Python objects; prefer
    # yaml.safe_load if configs may come from untrusted sources.
    with open(config_path, "r") as config_file:
        config_yaml = yaml.load(config_file, Loader=yaml.FullLoader)

    main(config_yaml, exp_group_name, exp_name)