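"""Evaluate AudioMAE on the AudioSet eval set and report mAP."""
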
import torch
import numpy as np
from torch.utils.data import DataLoader
from tqdm import tqdm

from audiovisual_dataset import AudioVisualDataset, collate_fn
from util.stat import calculate_stats
from AudioMAE import AudioMAE

if __name__ == "__main__":
    device = "cuda"
    dataset = AudioVisualDataset(
        datafiles=[
            "/mnt/bn/data-xubo/dataset/audioset_videos/datafiles/audioset_eval.json"
        ],
        # disable SpecAug during evaluation
        freqm=0,
        timem=0,
        return_label=True,
    )

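    # Build the AudioMAE classifier and switch to inference mode.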
    model = AudioMAE().to(device)
    model.eval()

    outputs = []
    targets = []

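    # Keep evaluation order deterministic; the dataset's custom collate_fn
    # assembles the batched fbank/label dictionaries.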
    dataloader = DataLoader(
        dataset, batch_size=64, num_workers=8, shuffle=False, collate_fn=collate_fn
    )

    print("Start evaluation on AudioSet ...")
    with torch.no_grad():
        for data in tqdm(dataloader):
            fbank = data["fbank"]  # [B, 1, T, F]
            fbank = fbank.to(device)
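            # Masking is only used during MAE pre-training; disable it for eval.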
            output = model(fbank, mask_t_prob=0.0, mask_f_prob=0.0)
            target = data["labels"]
            outputs.append(output)
            targets.append(target)

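    # Aggregate per-batch predictions and labels, then compute per-class stats.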
    outputs = torch.cat(outputs).cpu().numpy()
    targets = torch.cat(targets).cpu().numpy()
    stats = calculate_stats(outputs, targets)

    AP = [stat["AP"] for stat in stats]
    mAP = np.mean([stat["AP"] for stat in stats])
    print("Done ... mAP: {:.6f}".format(mAP))

    # mAP: 0.463003