Spaces:
Runtime error
Runtime error
Mehdi Cherti
commited on
Commit
·
9e9d0ce
1
Parent(s):
27911d6
small fix
Browse files- test_ddgan.py +16 -16
test_ddgan.py
CHANGED
@@ -433,7 +433,7 @@ def sample_and_test(args):
|
|
433 |
break
|
434 |
print("PATH", path)
|
435 |
suffix = '_' + args.eval_name if args.eval_name else ""
|
436 |
-
dest = './saved_info/dd_gan/{}/{}/eval_{}{}.json'.format(cfg['dataset'],
|
437 |
if (args.compute_fid or args.compute_clip_score or args.compute_image_reward) and os.path.exists(dest):
|
438 |
continue
|
439 |
print("Load epoch", args.epoch_id, "checkpoint")
|
@@ -496,7 +496,7 @@ def sample_and_test(args):
|
|
496 |
if args.guidance_scale:
|
497 |
fake_sample = sample_from_model_classifier_free_guidance(pos_coeff, netG, args.num_timesteps, x_t_1,T, args, text_encoder, cond=cond, guidance_scale=args.guidance_scale)
|
498 |
else:
|
499 |
-
fake_sample = sample(generator=
|
500 |
fake_sample = to_range_0_1(fake_sample)
|
501 |
|
502 |
if args.compute_fid:
|
@@ -600,25 +600,25 @@ def sample_and_test(args):
|
|
600 |
if __name__ == '__main__':
|
601 |
parser = argparse.ArgumentParser('ddgan parameters')
|
602 |
parser.add_argument('--name', type=str, default="", help="model config name")
|
603 |
-
parser.add_argument('--
|
604 |
parser.add_argument('--seed', type=int, default=1024, help='seed used for initialization')
|
605 |
-
parser.add_argument('--
|
606 |
help='whether or not compute FID')
|
607 |
-
parser.add_argument('--
|
608 |
help='whether or not compute CLIP score')
|
609 |
-
parser.add_argument('--
|
610 |
help='whether or not compute CLIP score')
|
611 |
|
612 |
-
parser.add_argument('--
|
613 |
-
parser.add_argument('--
|
614 |
-
parser.add_argument('--
|
615 |
-
parser.add_argument('--
|
616 |
-
parser.add_argument('--
|
617 |
-
parser.add_argument('--
|
618 |
-
parser.add_argument('--
|
619 |
-
parser.add_argument('--
|
620 |
-
parser.add_argument('--
|
621 |
-
parser.add_argument('--
|
622 |
args = parser.parse_args()
|
623 |
sample_and_test(args)
|
624 |
|
|
|
433 |
break
|
434 |
print("PATH", path)
|
435 |
suffix = '_' + args.eval_name if args.eval_name else ""
|
436 |
+
dest = './saved_info/dd_gan/{}/{}/eval_{}{}.json'.format(cfg['dataset'], args.name, args.epoch_id, suffix)
|
437 |
if (args.compute_fid or args.compute_clip_score or args.compute_image_reward) and os.path.exists(dest):
|
438 |
continue
|
439 |
print("Load epoch", args.epoch_id, "checkpoint")
|
|
|
496 |
if args.guidance_scale:
|
497 |
fake_sample = sample_from_model_classifier_free_guidance(pos_coeff, netG, args.num_timesteps, x_t_1,T, args, text_encoder, cond=cond, guidance_scale=args.guidance_scale)
|
498 |
else:
|
499 |
+
fake_sample = sample(generator=netG, x_init=x_init, cond=cond)
|
500 |
fake_sample = to_range_0_1(fake_sample)
|
501 |
|
502 |
if args.compute_fid:
|
|
|
600 |
if __name__ == '__main__':
|
601 |
parser = argparse.ArgumentParser('ddgan parameters')
|
602 |
parser.add_argument('--name', type=str, default="", help="model config name")
|
603 |
+
parser.add_argument('--batch-size', type=int, default=16)
|
604 |
parser.add_argument('--seed', type=int, default=1024, help='seed used for initialization')
|
605 |
+
parser.add_argument('--compute-fid', action='store_true', default=False,
|
606 |
help='whether or not compute FID')
|
607 |
+
parser.add_argument('--compute-clip-score', action='store_true', default=False,
|
608 |
help='whether or not compute CLIP score')
|
609 |
+
parser.add_argument('--compute-image-reward', action='store_true', default=False,
|
610 |
help='whether or not compute CLIP score')
|
611 |
|
612 |
+
parser.add_argument('--clip-model', type=str,default="ViT-L/14")
|
613 |
+
parser.add_argument('--eval-name', type=str,default="")
|
614 |
+
parser.add_argument('--epoch-id', type=int,default=-1)
|
615 |
+
parser.add_argument('--guidance-scale', type=float,default=0)
|
616 |
+
parser.add_argument('--dynamic-thresholding-quantile', type=float,default=0)
|
617 |
+
parser.add_argument('--cond-text', type=str,default="a chair in the form of an avocado")
|
618 |
+
parser.add_argument('--scale-factor-h', type=int,default=1)
|
619 |
+
parser.add_argument('--scale-factor-w', type=int,default=1)
|
620 |
+
parser.add_argument('--scale-method', type=str,default="convolutional")
|
621 |
+
parser.add_argument('--nb-images-for-fid', type=int, default=0)
|
622 |
args = parser.parse_args()
|
623 |
sample_and_test(args)
|
624 |
|