File size: 668 Bytes
b84549f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
import torch
from torch import nn 


def to(x, device):
    """Move ``x`` to ``device`` and return it.

    Args:
        x: Either a dict or any object exposing a ``.to(device)`` method
            (for example a ``torch.Tensor``).
        device: Forwarded unchanged to every ``.to(...)`` call.

    Returns:
        For a dict, the *same* dict object, updated in place: each top-level
        value that is a ``torch.Tensor`` is replaced by ``value.to(device)``.
        Other values are left untouched and nested containers are not
        traversed. For anything else, the result of ``x.to(device)``.

    Raises:
        AttributeError: If ``x`` is not a dict and has no ``to`` attribute.
    """
    # Non-dict inputs are delegated straight to their own ``.to``.
    if not isinstance(x, dict):
        return x.to(device)

    # Only existing keys are reassigned, so the dict never changes size
    # while it is being iterated.
    for key in x:
        value = x[key]
        if isinstance(value, torch.Tensor):
            x[key] = value.to(device)
    return x


def get_cur_acc(testset, hyps, model, shuffle, iter_index):
    """Return the model's accuracy on one slice of ``testset``.

    The slice is the first element of
    ``split_dataset(testset, hyps['val_batch_size'], iter_index)``. It is
    wrapped in a dataloader and handed to ``model.get_accuracy``.

    Args:
        testset: Dataset passed to ``data.split_dataset``.
        hyps: Mapping that must contain ``'val_batch_size'``,
            ``'train_batch_size'`` and ``'num_workers'``.
        model: Object exposing ``get_accuracy(dataloader)``.
        shuffle: Forwarded as the last positional argument of
            ``data.build_dataloader``.
        iter_index: Forwarded as the third argument of
            ``data.split_dataset``.

    Returns:
        Whatever ``model.get_accuracy`` returns for the built dataloader.
    """
    # Imported lazily inside the function, as in the original code
    # (presumably to avoid a circular import with ``data`` -- TODO confirm).
    from data import build_dataloader, split_dataset

    batch_subset = split_dataset(testset, hyps['val_batch_size'], iter_index)[0]

    # NOTE(review): the slice is cut with 'val_batch_size' but loaded with
    # 'train_batch_size' -- confirm this mismatch is intentional.
    # NOTE(review): the meaning of the positional ``False`` depends on
    # build_dataloader's signature, which is not visible here.
    loader = build_dataloader(
        batch_subset,
        hyps['train_batch_size'],
        hyps['num_workers'],
        False,
        shuffle,
    )
    return model.get_accuracy(loader)