import json
import logging
import math
import os
import time
from contextlib import suppress

import numpy as np
import torch
import torch.nn.functional as F

try:
    import wandb
except ImportError:
    wandb = None

from open_clip import LPLoss, LPMetrics, lp_gather_features
from open_clip.utils import do_mixup, get_mix_lambda
from .distributed import is_master
from .zero_shot import zero_shot_eval

class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

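# Helper for distributed training: return the underlying module when the model
# is wrapped in (Distributed)DataParallel, otherwise return the model itself.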
def unwrap_model(model):
    if hasattr(model, "module"):
        return model.module
    else:
        return model

def train_one_epoch(
    model,
    data,
    epoch,
    optimizer,
    scaler,
    scheduler,
    args,
    tb_writer=None,
    extra_suffix="",
):
    device = torch.device(args.device)
    autocast = torch.cuda.amp.autocast if args.precision == "amp" else suppress
    model.train()
    loss = LPLoss(args.lp_loss)

    dataloader, sampler = data["train"].dataloader, data["train"].sampler
    if args.distributed and sampler is not None:
        sampler.set_epoch(epoch)
    num_batches_per_epoch = dataloader.num_batches
    sample_digits = math.ceil(math.log(dataloader.num_samples + 1, 10))

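    # Refresh the toy dataset's internal sample queue at the start of each epoch.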
    if args.dataset_type == "toy":
        dataloader.dataset.generate_queue()

    loss_m = AverageMeter()
    batch_time_m = AverageMeter()
    data_time_m = AverageMeter()
    end = time.time()

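    # Main training loop. "step" below is the global batch index across epochs and
    # drives both the learning-rate scheduler(s) and the logging x-axis.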
    for i, batch in enumerate(dataloader):
        step = num_batches_per_epoch * epoch + i

        if isinstance(scheduler, dict):
            for s in scheduler.values():
                s(step)
        else:
            scheduler(step)

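        # "audio" keeps the whole batch dict (waveform and any precomputed features);
        # only the class labels are moved to the device here, the model consumes the rest.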
        audio = batch
        class_label = batch["class_label"]

        class_label = class_label.to(device=device, non_blocking=True)

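        # Mixup augmentation: draw one mixing coefficient per example, mix the label
        # targets here, and pass the same coefficients to the model for the inputs.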
        if args.mixup:
            mix_lambda = torch.from_numpy(
                get_mix_lambda(0.5, len(audio["waveform"]))
            ).to(device)
            class_label = do_mixup(class_label, mix_lambda)
        else:
            mix_lambda = None

        data_time_m.update(time.time() - end)
        if isinstance(optimizer, dict):
            for o_ in optimizer.values():
                o_.zero_grad()
        else:
            optimizer.zero_grad()

        with autocast():
            pred = model(audio, mix_lambda=mix_lambda, device=device)
            total_loss = loss(pred, class_label)

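        # Backward pass and parameter update. Four cases are handled below: a dict of
        # optimizers vs. a single optimizer, each with or without an AMP GradScaler
        # (including Horovod's synchronize/skip_synchronize protocol).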
        if isinstance(optimizer, dict):
            if scaler is not None:
                scaler.scale(total_loss).backward()
                for o_ in optimizer.values():
                    if args.horovod:
                        o_.synchronize()
                        scaler.unscale_(o_)
                        with o_.skip_synchronize():
                            scaler.step(o_)
                    else:
                        scaler.step(o_)
                scaler.update()
            else:
                total_loss.backward()
                for o_ in optimizer.values():
                    o_.step()
        else:
            if scaler is not None:
                scaler.scale(total_loss).backward()
                if args.horovod:
                    optimizer.synchronize()
                    scaler.unscale_(optimizer)
                    with optimizer.skip_synchronize():
                        scaler.step(optimizer)
                else:
                    scaler.step(optimizer)
                scaler.update()
            else:
                total_loss.backward()
                optimizer.step()

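        # Keep CLAP's learnable logit scales bounded above by ln(100), mirroring the
        # temperature clamping used in CLIP-style contrastive training.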
        with torch.no_grad():
            unwrap_model(model).clap_model.logit_scale_a.clamp_(0, math.log(100))
            unwrap_model(model).clap_model.logit_scale_t.clamp_(0, math.log(100))

        batch_time_m.update(time.time() - end)
        end = time.time()
        batch_count = i + 1

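        # Log every 100 batches, and on the final batch of the epoch, from the
        # master process only.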
        if is_master(args) and (i % 100 == 0 or batch_count == num_batches_per_epoch):
            if isinstance(audio, dict):
                batch_size = len(audio["waveform"])
            else:
                batch_size = len(audio)
            num_samples = batch_count * batch_size * args.world_size
            samples_per_epoch = dataloader.num_samples
            percent_complete = 100.0 * batch_count / num_batches_per_epoch

            loss_m.update(total_loss.item(), batch_size)
            if isinstance(optimizer, dict):
                logging.info(
                    f"Train Epoch: {epoch} [{num_samples:>{sample_digits}}/{samples_per_epoch} ({percent_complete:.0f}%)] "
                    f"Loss: {loss_m.val:#.5g} ({loss_m.avg:#.4g}) "
                    f"Data (t): {data_time_m.avg:.3f} "
                    f"Batch (t): {batch_time_m.avg:.3f} "
                    f"LR: {[o_.param_groups[0]['lr'] for o_ in optimizer.values()]}"
                )
                log_data = {
                    "loss": loss_m.val,
                    "data_time": data_time_m.val,
                    "batch_time": batch_time_m.val,
                    "lr": [o_.param_groups[0]["lr"] for o_ in optimizer.values()],
                }
            else:
                logging.info(
                    f"Train Epoch: {epoch} [{num_samples:>{sample_digits}}/{samples_per_epoch} ({percent_complete:.0f}%)] "
                    f"Loss: {loss_m.val:#.5g} ({loss_m.avg:#.4g}) "
                    f"Data (t): {data_time_m.avg:.3f} "
                    f"Batch (t): {batch_time_m.avg:.3f} "
                    f"LR: {optimizer.param_groups[0]['lr']:.5f} "
                )

                log_data = {
                    "loss": loss_m.val,
                    "data_time": data_time_m.val,
                    "batch_time": batch_time_m.val,
                    "lr": optimizer.param_groups[0]["lr"],
                }
            for name, val in log_data.items():
                name = f"train{extra_suffix}/{name}"
                if tb_writer is not None:
                    tb_writer.add_scalar(name, val, step)
                if args.wandb:
                    assert wandb is not None, "Please install wandb."
                    wandb.log({name: val, "step": step})

            batch_time_m.reset()
            data_time_m.reset()

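# Run the linear-probe model over the validation split and compute the metrics named
# in args.lp_metrics; with args.parallel_eval the predictions are gathered across
# workers, otherwise only the master process evaluates.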
def evaluate(model, data, epoch, args, tb_writer=None, extra_suffix=""):
    metrics = {}
    if not args.parallel_eval:
        if not is_master(args):
            return metrics
    device = torch.device(args.device)
    model.eval()

    if is_master(args):
        print("Evaluating...")
    metric_names = args.lp_metrics.split(",")
    eval_tool = LPMetrics(metric_names=metric_names)

    autocast = torch.cuda.amp.autocast if args.precision == "amp" else suppress
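    # Only evaluate when a validation split exists and the epoch matches the
    # configured validation frequency (or it is the final epoch).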
    if "val" in data and (
        args.val_frequency
        and ((epoch % args.val_frequency) == 0 or epoch == args.epochs)
    ):
        num_samples = 0
        if args.parallel_eval:
            dataloader, sampler = data["val"].dataloader, data["val"].sampler
            if args.distributed and sampler is not None:
                sampler.set_epoch(epoch)
            samples_per_val = dataloader.num_samples
        else:
            dataloader = data["val"].dataloader
            samples_per_val = dataloader.num_samples

        eval_info = {"pred": [], "target": []}
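        # Accumulate predictions and targets over the validation set; under parallel
        # evaluation the per-rank tensors are gathered before being stored.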
        with torch.no_grad():
            for i, batch in enumerate(dataloader):
                audio = batch
                class_label = batch["class_label"]

                class_label = class_label.to(device=device, non_blocking=True)

                with autocast():
                    pred = model(audio, device=device)
                    if args.parallel_eval:
                        pred, class_label = lp_gather_features(
                            pred, class_label, args.world_size, args.horovod
                        )
                    eval_info["pred"].append(pred)
                    eval_info["target"].append(class_label)

                num_samples += class_label.shape[0]

                if (i % 100) == 0:
                    logging.info(
                        f"Eval Epoch: {epoch} [{num_samples} / {samples_per_val}]"
                    )

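            # After the loop, the master process concatenates everything on CPU and
            # computes the configured metrics.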
            if is_master(args):
                eval_info["pred"] = torch.cat(eval_info["pred"], 0).cpu()
                eval_info["target"] = torch.cat(eval_info["target"], 0).cpu()
                metric_dict = eval_tool.evaluate_mertics(
                    eval_info["pred"], eval_info["target"]
                )
                metrics.update(metric_dict)
                if "epoch" not in metrics.keys():
                    metrics.update({"epoch": epoch})

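    # Report the metrics from the master process and optionally persist them to
    # TensorBoard, a results.jsonl file, and wandb.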
    if is_master(args):
        if not metrics:
            return metrics

        logging.info(
            f"Eval Epoch: {epoch} "
            + "\n".join([f"{m}: {round(metrics[m], 4):.4f}" for m in metrics])
        )
        if args.save_logs:
            for name, val in metrics.items():
                if tb_writer is not None:
                    tb_writer.add_scalar(f"val{extra_suffix}/{name}", val, epoch)

            with open(os.path.join(args.checkpoint_path, "results.jsonl"), "a+") as f:
                f.write(json.dumps(metrics))
                f.write("\n")

        if args.wandb:
            assert wandb is not None, "Please install wandb."
            for name, val in metrics.items():
                wandb.log({f"val{extra_suffix}/{name}": val, "epoch": epoch})

        return metrics
    else:
        return metrics