# Snapshot of this script at commit 9b9e0ee (1,493 bytes).
import torch
import torch.nn as nn
import numpy as np
from timm.models.layers import to_2tuple
import models_vit
from audiovisual_dataset import AudioVisualDataset, collate_fn
from torch.utils.data import DataLoader
from util.stat import calculate_stats
from tqdm import tqdm
from AudioMAE import AudioMAE
def collect_predictions(model, batches, device):
    """Run ``model`` over ``batches`` and gather its predictions and the labels.

    Args:
        model: Callable invoked as ``model(fbank, mask_t_prob=0.0,
            mask_f_prob=0.0)``; both masking probabilities are set to zero
            for evaluation.
        batches: Iterable of dicts holding a ``"fbank"`` tensor (annotated as
            ``[B, 1, T, F]`` in this script) and a ``"labels"`` tensor, both
            batched along dim 0.
        device: Device the ``"fbank"`` tensors are moved to before the
            forward pass.

    Returns:
        ``(outputs, targets)`` as NumPy arrays, each being the per-batch
        results concatenated along dim 0.
    """
    outputs = []
    targets = []
    with torch.no_grad():
        for data in batches:
            fbank = data["fbank"].to(device)  # [B, 1, T, F]
            output = model(fbank, mask_t_prob=0.0, mask_f_prob=0.0)
            # Move each batch off the accelerator right away so device memory
            # does not grow with the size of the evaluation set.
            outputs.append(output.cpu())
            targets.append(data["labels"])
    return torch.cat(outputs).numpy(), torch.cat(targets).cpu().numpy()


def mean_average_precision(stats):
    """Return the mean of the ``"AP"`` entries of the per-class ``stats``."""
    return float(np.mean([stat["AP"] for stat in stats]))


def main():
    """Evaluate AudioMAE on the AudioSet eval datafile and print the mAP."""
    # Fall back to CPU so the script still runs on machines without CUDA.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dataset = AudioVisualDataset(
        datafiles=[
            "/mnt/bn/data-xubo/dataset/audioset_videos/datafiles/audioset_eval.json"
        ],
        # disable SpecAug during evaluation
        freqm=0,
        timem=0,
        return_label=True,
    )
    model = AudioMAE().to(device)
    model.eval()
    dataloader = DataLoader(
        dataset, batch_size=64, num_workers=8, shuffle=False, collate_fn=collate_fn
    )
    print("Start evaluation on AudioSet ...")
    outputs, targets = collect_predictions(model, tqdm(dataloader), device)
    stats = calculate_stats(outputs, targets)
    mAP = mean_average_precision(stats)
    print("Done ... mAP: {:.6f}".format(mAP))
    # Reference result recorded for this script: mAP: 0.463003


if __name__ == "__main__":
    main()