AnimeRecBERT / template_args.py
mramazan's picture
Upload 60 files
426ffb5 verified
raw
history blame
5.38 kB
def set_template(args):
if args.template is None:
return
elif args.template.startswith('train_bert'):
args.mode = 'train'
args.dataset_code = 'AnimeRatings54M'
args.min_rating = 7
args.min_uc = 10
args.min_sc = 10
args.split = 'leave_one_out'
args.dataloader_code = 'bert'
batch = 128
args.train_batch_size = batch
args.val_batch_size = batch
args.test_batch_size = batch
args.train_negative_sampler_code = 'random'
args.train_negative_sample_size = 100
args.train_negative_sampling_seed = 0
args.test_negative_sampler_code = 'random'
args.test_negative_sample_size = 100
args.test_negative_sampling_seed = 98765
args.trainer_code = 'bert'
args.device = 'cuda'
args.num_gpu = 1
args.device_idx = '0'
args.optimizer = 'Adam'
args.lr = 0.001
args.enable_lr_schedule = True
args.decay_step = 25
args.gamma = 1.0
args.num_epochs = 5
args.metric_ks = [1, 5, 10, 20, 50, 100]
args.best_metric = 'NDCG@10'
args.model_code = 'bert'
args.model_init_seed = 0
args.bert_dropout = 0.2
args.weight_decay = 1e-4
args.bert_hidden_units = 256
args.bert_mask_prob = 0.15
args.bert_max_len = 128
args.bert_num_blocks = 2
args.bert_num_heads = 4
elif args.template.startswith('train_dae'):
args.mode = 'train'
args.dataset_code = 'ml-' + input('Input 1 for ml-1m, 20 for ml-20m: ') + 'm'
args.min_rating = 7
args.min_uc = 20
args.min_sc = 20
args.split = 'holdout'
args.dataset_split_seed = 98765
args.eval_set_size = 500 if args.dataset_code == 'ml-1m' else 20000
args.dataloader_code = 'ae'
batch = 128
args.train_batch_size = batch
args.val_batch_size = batch
args.test_batch_size = batch
args.trainer_code = 'dae'
args.device = 'cuda'
args.num_gpu = 1
args.device_idx = '0'
args.optimizer = 'Adam'
args.lr = 1e-3
args.enable_lr_schedule = False
args.weight_decay = 1e-4
args.num_epochs = 100 if args.dataset_code == 'ml-1m' else 200
args.metric_ks = [1, 5, 10, 20, 50, 100]
args.best_metric = 'NDCG@10'
args.model_code = 'dae'
args.model_init_seed = 0
args.dae_num_hidden = 2
args.dae_hidden_dim = 600
args.dae_latent_dim = 200
args.dae_dropout = 0.5
elif args.template.startswith('train_vae_search_beta'):
args.mode = 'train'
args.dataset_code = 'ml-' + input('Input 1 for ml-1m, 20 for ml-20m: ') + 'm'
args.min_rating = 0 if args.dataset_code == 'ml-1m' else 4
args.min_uc = 5
args.min_sc = 0
args.split = 'holdout'
args.dataset_split_seed = 98765
args.eval_set_size = 500 if args.dataset_code == 'ml-1m' else 10000
args.dataloader_code = 'ae'
batch = 128 if args.dataset_code == 'ml-1m' else 512
args.train_batch_size = batch
args.val_batch_size = batch
args.test_batch_size = batch
args.trainer_code = 'vae'
args.device = 'cuda'
args.num_gpu = 1
args.device_idx = '0'
args.optimizer = 'Adam'
args.lr = 1e-3
args.enable_lr_schedule = False
args.weight_decay = 0.01
args.num_epochs = 100 if args.dataset_code == 'ml-1m' else 200
args.metric_ks = [1, 5, 10, 20, 50, 100]
args.best_metric = 'NDCG@10'
args.total_anneal_steps = 3000 if args.dataset_code == 'ml-1m' else 20000
args.find_best_beta = True
args.model_code = 'vae'
args.model_init_seed = 0
args.vae_num_hidden = 2
args.vae_hidden_dim = 600
args.vae_latent_dim = 200
args.vae_dropout = 0.5
elif args.template.startswith('train_vae_give_beta'):
args.mode = 'train'
args.dataset_code = 'ml-' + input('Input 1 for ml-1m, 20 for ml-20m: ') + 'm'
args.min_rating = 0 if args.dataset_code == 'ml-1m' else 4
args.min_uc = 5
args.min_sc = 0
args.split = 'holdout'
args.dataset_split_seed = 98765
args.eval_set_size = 500 if args.dataset_code == 'ml-1m' else 10000
args.dataloader_code = 'ae'
batch = 128 if args.dataset_code == 'ml-1m' else 512
args.train_batch_size = batch
args.val_batch_size = batch
args.test_batch_size = batch
args.trainer_code = 'vae'
args.device = 'cuda'
args.num_gpu = 1
args.device_idx = '0'
args.optimizer = 'Adam'
args.lr = 1e-3
args.enable_lr_schedule = False
args.weight_decay = 0.01
args.num_epochs = 100 if args.dataset_code == 'ml-1m' else 200
args.metric_ks = [1, 5, 10, 20, 50, 100]
args.best_metric = 'NDCG@100'
args.find_best_beta = False
args.anneal_cap = 0.342
args.total_anneal_steps = 3000 if args.dataset_code == 'ml-1m' else 20000
args.model_code = 'vae'
args.model_init_seed = 0
args.vae_num_hidden = 2
args.vae_hidden_dim = 600
args.vae_latent_dim = 200
args.vae_dropout = 0.5