# Page header captured along with this file (not part of the program):
#   Spaces status: Runtime error
#   File size: 1,997 Bytes
#   Ref: cb80c28
import sys
import logging
import copy
import torch
from PIL import Image
import torchvision.transforms as transforms
from utils import factory
from utils.data_manager import DataManager
from utils.toolkit import count_parameters
import os
import numpy as np
import json
import argparse
def _set_device(args):
device_type = args["device"]
gpus = []
for device in device_type:
if device == -1:
device = torch.device("cpu")
else:
device = torch.device("cuda:{}".format(device))
gpus.append(device)
args["device"] = gpus
def get_methods(object, spacing=20):
    """Print the callable attributes of ``object`` with a short doc summary.

    Each output line is the attribute name left-justified to ``spacing``
    columns, a space, and the first 90 characters of the attribute's
    ``__doc__`` with runs of whitespace collapsed to single spaces.

    Attributes whose lookup raises are still listed; their line ends with
    ``getattr() failed`` instead of a summary.
    """
    def _squash(text):
        # Collapse every run of whitespace (including newlines) to one space.
        return ' '.join(text.split())

    names = []
    for attr in dir(object):
        try:
            keep = callable(getattr(object, attr))
        except Exception:
            # Lookup failed: keep the name so the failure is reported below.
            keep = True
        if keep:
            names.append(str(attr))

    for name in names:
        try:
            doc_head = str(getattr(object, name).__doc__)[0:90]
            print(str(name.ljust(spacing)) + ' ' + _squash(doc_head))
        except Exception:
            print(name.ljust(spacing) + ' ' + ' getattr() failed')
def load_model(args):
    """Build the model named in ``args`` and load its checkpoint.

    ``args["device"]`` is first converted in place by ``_set_device``. The
    model is then created with ``factory.get_model(args["model_name"], args)``
    and asked to load ``args["checkpoint"]``.

    Returns:
        The object returned by ``factory.get_model``.
    """
    _set_device(args)
    learner = factory.get_model(args["model_name"], args)
    checkpoint = args["checkpoint"]
    learner.load_checkpoint(checkpoint)
    return learner
def main():
    """Parse the command line, merge in the JSON settings, and load the model.

    The parsed arguments are turned into a dict, updated with the contents of
    the JSON file given by ``--config`` (JSON values win on key clashes), and
    passed to ``load_model``.
    """
    cli_args = setup_parser().parse_args()
    file_settings = load_json(cli_args.config)
    settings = vars(cli_args)  # argparse Namespace -> plain dict
    settings.update(file_settings)  # layer the JSON parameters on top
    load_model(settings)
def load_json(settings_path):
    """Read a JSON settings file and return its parsed content.

    Args:
        settings_path: Path of the JSON file to read.

    Returns:
        The value produced by ``json.load`` for the file.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the content is not valid JSON.
    """
    # Decode as UTF-8 explicitly: without it, open() falls back to the
    # platform locale encoding, which can mis-decode non-ASCII settings.
    with open(settings_path, encoding="utf-8") as data_file:
        return json.load(data_file)
def setup_parser():
    """Create the command-line parser for this script.

    The parser accepts one option, ``--config``: the path of the JSON
    settings file, defaulting to ``./exps/finetune.json``.

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    cli = argparse.ArgumentParser(
        description='Reproduce of multiple continual learning algorthms.',
    )
    config_option = {
        'type': str,
        'default': './exps/finetune.json',
        'help': 'Json file of settings.',
    }
    cli.add_argument('--config', **config_option)
    return cli
# Run the entry point only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|