Spaces:
Runtime error
Runtime error
| import argparse | |
| import numpy as np | |
| import os | |
| import sys | |
| import time | |
| from tqdm import tqdm | |
| from collections import OrderedDict | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torch.utils.data import DataLoader | |
| from torch.utils.tensorboard import SummaryWriter | |
| from config import init_args | |
| import data | |
| import models | |
| from models import * | |
| from utils import utils, torch_utils | |
# Run on the GPU when PyTorch reports one as available; otherwise use the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
def validation(args, net, criterion, data_loader, device='cuda'):
    """Run ``net`` over ``data_loader`` without gradients and score the result.

    The network is switched to eval mode for the pass and is unconditionally
    put back into train mode before returning.

    Args:
        args: Run configuration; forwarded untouched to ``predict``.
        net: Model to evaluate.
        criterion: Object exposing ``evaluate(pred, target)``.
        data_loader: Iterable of batches accepted by ``predict``; must
            support ``len()`` (used for the progress bar).
        device: Device on which batches are placed and results are gathered.

    Returns:
        Whatever ``criterion.evaluate`` returns for the concatenated
        predictions and targets.
    """
    net.eval()

    # Each list starts with the same empty tensor the running result used to
    # be initialised with, so it takes part in the concatenation exactly as
    # before and an empty loader still yields empty tensors.
    pred_chunks = [torch.tensor([]).to(device)]
    target_chunks = [torch.tensor([]).to(device)]

    with torch.no_grad():
        for batch in tqdm(data_loader, total=len(data_loader), desc="Validation"):
            pred, target = predict(args, net, batch, device)
            pred_chunks.append(pred)
            target_chunks.append(target)

        # Concatenate once at the end instead of re-copying everything
        # gathered so far on every batch.
        pred_all = torch.cat(pred_chunks, dim=0)
        target_all = torch.cat(target_chunks, dim=0)

    res = criterion.evaluate(pred_all, target_all)
    torch.cuda.empty_cache()
    net.train()
    return res
def predict(args, net, batch, device):
    """Feed one batch through ``net`` and return ``(prediction, target)``.

    Args:
        args: Run configuration; not read by this function.
        net: Callable that takes a dict with a ``'frames'`` entry.
        batch: Mapping providing ``'frames'`` and ``'label'`` tensors.
        device: Device both tensors are moved to.

    Returns:
        Tuple of the network output and ``batch['label']`` on ``device``.
    """
    frames = batch['frames'].to(device)
    # Only the frames are handed to the network; the label stays outside.
    output = net({'frames': frames})
    labels = batch['label'].to(device)
    return output, labels
def train(args, device):
    """Train ``models.VideoOnsetNet`` on the CountixAV train split.

    Side effects: creates ``./checkpoints/<args.exp>``, replaces
    ``sys.stdout`` with a ``utils.LoggerOutput`` writing to ``log.txt`` in
    that directory, writes TensorBoard events under ``visualization`` and
    saves checkpoints as ``checkpoint_ep<N>.pth.tar``.

    Flow: optionally resume weights (and optimizer state) from
    ``./checkpoints/<args.resume>``, run one validation pass, then for each
    epoch adjust the learning rate, train over the loader, validate every
    ``args.valid_step`` epochs and save every ``args.save_step`` epochs.

    Args:
        args: Run configuration. Reads ``exp``, ``batch_size``,
            ``num_workers``, ``resume``, ``resume_optim``, ``epochs``,
            ``valid_step`` and ``save_step``; ``start_epoch`` is assigned
            here.
        device: Device the network and batches are placed on.
    """
    gpu_ids = list(range(torch.cuda.device_count()))

    # ----- make dirs for checkpoints ----- #
    # Create the experiment directory before opening the log file inside it,
    # so logging does not depend on LoggerOutput creating the directory.
    os.makedirs('./checkpoints/' + args.exp, exist_ok=True)
    sys.stdout = utils.LoggerOutput(os.path.join('checkpoints', args.exp, 'log.txt'))
    writer = SummaryWriter(os.path.join('./checkpoints', args.exp, 'visualization'))

    # Close the writer even if training raises, so its resources are
    # released on every exit path.
    try:
        tqdm.write('{}'.format(args))

        # ----- Dataset and Dataloader ----- #
        train_dataset = data.CountixAVDataset(args, split='train')
        train_loader = DataLoader(
            train_dataset,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.num_workers,
            pin_memory=True,
            drop_last=False)

        val_dataset = data.CountixAVDataset(args, split='val')
        val_loader = DataLoader(
            val_dataset,
            batch_size=args.batch_size,
            shuffle=False,
            num_workers=args.num_workers,
            pin_memory=True,
            drop_last=False)

        # ----- Network ----- #
        net = models.VideoOnsetNet(pretrained=False).to(device)
        criterion = models.BCLoss(args)
        optimizer = torch_utils.make_optimizer(net, args)

        # -------- Loading checkpoint weights ------------- #
        if args.resume:
            resume = './checkpoints/' + args.resume
            net, args.start_epoch = torch_utils.load_model(resume, net, device=device, strict=True)
            if args.resume_optim:
                tqdm.write('loading optimizer...')
                # map_location lets a checkpoint written on another device
                # (e.g. a GPU) be read on the current one.
                optim_state = torch.load(resume, map_location=device)['optimizer']
                optimizer.load_state_dict(optim_state)
                tqdm.write('loaded optimizer!')
        else:
            args.start_epoch = 0

        net = nn.DataParallel(net, device_ids=gpu_ids)

        # --------- Random or resume validation ------------ #
        res = validation(args, net, criterion, val_loader, device)
        writer.add_scalars('VideoOnset' + '/validation', res, args.start_epoch)
        tqdm.write("Beginning, Validation results: {}".format(res))
        tqdm.write('\n')

        # ----------------- Training ---------------- #
        valid_step = args.valid_step
        save_step = args.save_step
        board_step = 3  # training loss is averaged and logged every 3 steps
        steps_per_epoch = len(train_loader)

        for epoch in range(args.start_epoch, args.epochs):
            running_loss = 0.0
            # Defined up front so the checkpoint below can still be written
            # if the loader yields no batches.
            current_step = epoch * steps_per_epoch
            torch_utils.adjust_learning_rate(optimizer, epoch, args)

            for step, batch in tqdm(enumerate(train_loader), total=steps_per_epoch, desc="Training"):
                pred, target = predict(args, net, batch, device)
                loss = criterion(pred, target)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                tqdm.write("Epoch: {}/{}, step: {}/{}, loss: {}".format(epoch+1, args.epochs, step+1, steps_per_epoch, loss))
                running_loss += loss.item()

                current_step = epoch * steps_per_epoch + step + 1
                if (step + 1) % board_step == 0:
                    writer.add_scalar('VideoOnset' + '/training loss', running_loss / board_step, current_step)
                    running_loss = 0.0

            # ----------- Validation -------------- #
            if (epoch + 1) % valid_step == 0:
                res = validation(args, net, criterion, val_loader, device)
                writer.add_scalars('VideoOnset' + '/validation', res, epoch + 1)
                tqdm.write("Epoch: {}/{}, Validation results: {}".format(epoch + 1, args.epochs, res))

            # ---------- Save model ----------- #
            if (epoch + 1) % save_step == 0:
                path = os.path.join('./checkpoints', args.exp, 'checkpoint_ep' + str(epoch + 1) + '.pth.tar')
                # NOTE(review): the state dict is taken from the DataParallel
                # wrapper; confirm torch_utils.load_model accepts those keys
                # when resuming into the unwrapped network.
                torch.save({'epoch': epoch + 1,
                            'step': current_step,
                            'state_dict': net.state_dict(),
                            'optimizer': optimizer.state_dict(),
                            },
                           path)

        torch.cuda.empty_cache()
        tqdm.write('Training Complete!')
    finally:
        writer.close()
def test(args, device):
    """Evaluate ``models.VideoOnsetNet`` on the CountixAV test split.

    Side effects: creates ``./results/<args.exp>`` and replaces
    ``sys.stdout`` with a ``utils.LoggerOutput`` writing to ``log.txt`` in
    that directory.

    Weights are loaded from ``./checkpoints/<args.resume>`` when
    ``args.resume`` is set; otherwise the freshly constructed network is
    evaluated as is.

    Args:
        args: Run configuration. Reads ``exp``, ``batch_size``,
            ``num_workers`` and ``resume``.
        device: Device the network and batches are placed on.
    """
    gpu_ids = list(range(torch.cuda.device_count()))

    # ----- make dirs for results ----- #
    # Create the results directory before opening the log file inside it,
    # so logging does not depend on LoggerOutput creating the directory.
    os.makedirs('./results/' + args.exp, exist_ok=True)
    sys.stdout = utils.LoggerOutput(os.path.join('results', args.exp, 'log.txt'))

    tqdm.write('{}'.format(args))

    # ----- Dataset and Dataloader ----- #
    test_dataset = data.CountixAVDataset(args, split='test')
    test_loader = DataLoader(
        test_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=False)

    # ----- Network ----- #
    net = models.VideoOnsetNet(pretrained=False).to(device)
    criterion = models.BCLoss(args)

    # -------- Loading checkpoint weights ------------- #
    if args.resume:
        resume = './checkpoints/' + args.resume
        net, _ = torch_utils.load_model(resume, net, device=device, strict=True)

    net = nn.DataParallel(net, device_ids=gpu_ids)

    # --------- Testing ------------ #
    res = validation(args, net, criterion, test_loader, device)
    tqdm.write("Testing results: {}".format(res))
| # CUDA_VISIBLE_DEVICES=1 python main.py --exp='EXP1' --epochs=100 --batch_size=12 --num_workers=8 --save_step=10 --valid_step=1 --lr=0.0001 --optim='Adam' --repeat=1 --schedule='cos' | |
if __name__ == '__main__':
    # Parse the run configuration, then dispatch to evaluation or training
    # depending on the test_mode flag.
    args = init_args()
    entry_point = test if args.test_mode else train
    entry_point(args, DEVICE)