# Source: lif31up, commit c4c047d ("big update"), 1.9 kB.
import argparse
import torchvision as tv
from src.train import train
from src.evaluate import evaluate
def main(argv=None):
    """Command-line entry point for the Prototypical Network few-shot tool.

    With no subcommand the script evaluates a model; both ``--model`` and
    ``--dataset`` must then be given. Subcommands:

    * ``train``    -- train a model and save it to ``--save_to``.
    * ``download`` -- download the Omniglot background split into ``--path``.

    Args:
        argv: Argument list to parse instead of ``sys.argv[1:]``. ``None``
            (the default) keeps the normal command-line behaviour.

    Raises:
        SystemExit: On invalid or missing arguments (via argparse's usage error).
    """
    # eval (default, no subcommand)
    parser = argparse.ArgumentParser(description="Few-shot learning using Prototypical Network")
    parser.add_argument("--model", type=str, help="path of your model")
    parser.add_argument("--dataset", type=str, help="path of your dataset")

    subparser = parser.add_subparsers(title="subcommands", dest="subcommand")

    # train
    parser_train = subparser.add_parser("train", help="train your model")
    # SUPPRESS stops the sub-parser from writing its own default (None) over a
    # top-level ``--dataset`` value when the option is omitted after ``train``.
    parser_train.add_argument("--dataset", type=str, default=argparse.SUPPRESS, help="path to your dataset")
    parser_train.add_argument("--save_to", type=str, help="path to save your model")
    parser_train.add_argument("--n_way", type=int, help="number of classes per episode")
    parser_train.add_argument("--k_shot", type=int, help="number of support samples per class")
    parser_train.add_argument("--n_query", type=int, help="number of query samples per class")
    parser_train.add_argument("--iters", type=int, help="how much iteration your model does for an episode")
    parser_train.add_argument("--epochs", type=int, help="how much epochs your model does for training")
    # NOTE(review): --iters and --epochs are parsed but not forwarded to
    # train(); TODO confirm train()'s signature and pass them through.
    parser_train.set_defaults(func=lambda args: train(
        DATASET=args.dataset,
        SAVE_TO=args.save_to,
        N_WAY=args.n_way,
        K_SHOT=args.k_shot,
        N_QUERY=args.n_query,
    ))

    # download dataset
    parser_download = subparser.add_parser("download", help="download dataset")
    # Required: Omniglot(root=None) would otherwise fail deep inside torchvision.
    parser_download.add_argument("--path", type=str, required=True, help="path to download dataset")
    parser_download.set_defaults(
        func=lambda args: tv.datasets.Omniglot(root=args.path, background=True, download=True)
    )

    # parse logic
    args = parser.parse_args(argv)
    if hasattr(args, "func"):
        args.func(args)
        return

    # Report a usage error instead of handing None paths to evaluate().
    if args.model is None or args.dataset is None:
        parser.error("evaluation requires both --model and --dataset (or use a subcommand)")
    evaluate(MODEL=args.model, DATASET=args.dataset)
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()