Update vtoonify/train_vtoonify_d.py

vtoonify/train_vtoonify_d.py  (+84 −1)
```diff
@@ -391,6 +391,7 @@ def train(args, generator, discriminator, g_optim, d_optim, g_ema, percept, pars
 
 
 
+
 if __name__ == "__main__":
 
     device = "cuda"
@@ -430,4 +431,86 @@ if __name__ == "__main__":
     if not args.pretrain:
         generator.encoder.load_state_dict(torch.load(args.encoder_path, map_location=lambda storage, loc: storage)["g_ema"])
         # we initialize the fusion modules to map f_G \otimes f_E to f_G.
-
+        for k in generator.fusion_out:
+            k.conv.weight.data *= 0.01
+            k.conv.weight[:,0:k.conv.weight.shape[0],1,1].data += torch.eye(k.conv.weight.shape[0]).cuda()
+        for k in generator.fusion_skip:
+            k.weight.data *= 0.01
+            k.weight[:,0:k.weight.shape[0],1,1].data += torch.eye(k.weight.shape[0]).cuda()
+
+    accumulate(g_ema.encoder, generator.encoder, 0)
+    accumulate(g_ema.fusion_out, generator.fusion_out, 0)
+    accumulate(g_ema.fusion_skip, generator.fusion_skip, 0)
+
+    g_parameters = list(generator.encoder.parameters())
+    if not args.pretrain:
+        g_parameters = g_parameters + list(generator.fusion_out.parameters()) + list(generator.fusion_skip.parameters())
+
+    g_optim = optim.Adam(
+        g_parameters,
+        lr=args.lr,
+        betas=(0.9, 0.99),
+    )
+
+    if args.distributed:
+        generator = nn.parallel.DistributedDataParallel(
+            generator,
+            device_ids=[args.local_rank],
+            output_device=args.local_rank,
+            broadcast_buffers=False,
+            find_unused_parameters=True,
+        )
+
+    parsingpredictor = BiSeNet(n_classes=19)
+    parsingpredictor.load_state_dict(torch.load(args.faceparsing_path, map_location=lambda storage, loc: storage))
+    parsingpredictor.to(device).eval()
+    requires_grad(parsingpredictor, False)
+
+    # we apply gaussian blur to the images to avoid flickers caused during downsampling
+    down = Downsample(kernel=[1, 3, 3, 1], factor=2).to(device)
+    requires_grad(down, False)
+
+    directions = torch.tensor(np.load(args.direction_path)).to(device)
+
+    # load style codes of DualStyleGAN
+    exstyles = np.load(args.exstyle_path, allow_pickle='TRUE').item()
+    if args.local_rank == 0 and not os.path.exists('checkpoint/%s/exstyle_code.npy'%(args.name)):
+        np.save('checkpoint/%s/exstyle_code.npy'%(args.name), exstyles, allow_pickle=True)
+    styles = []
+    with torch.no_grad():
+        for stylename in exstyles.keys():
+            exstyle = torch.tensor(exstyles[stylename]).to(device)
+            exstyle = g_ema.zplus2wplus(exstyle)
+            styles += [exstyle]
+    styles = torch.cat(styles, dim=0)
+
+    if not args.pretrain:
+        discriminator = ConditionalDiscriminator(256, use_condition=True, style_num = styles.size(0)).to(device)
+
+        d_optim = optim.Adam(
+            discriminator.parameters(),
+            lr=args.lr,
+            betas=(0.9, 0.99),
+        )
+
+        if args.distributed:
+            discriminator = nn.parallel.DistributedDataParallel(
+                discriminator,
+                device_ids=[args.local_rank],
+                output_device=args.local_rank,
+                broadcast_buffers=False,
+                find_unused_parameters=True,
+            )
+
+    percept = lpips.PerceptualLoss(model="net-lin", net="vgg", use_gpu=device.startswith("cuda"), gpu_ids=[args.local_rank])
+    requires_grad(percept.model.net, False)
+
+    pspencoder = load_psp_standalone(args.style_encoder_path, device)
+
+    if args.local_rank == 0:
+        print('Load models and data successfully loaded!')
+
+    if args.pretrain:
+        pretrain(args, generator, g_optim, g_ema, parsingpredictor, down, directions, styles, device)
+    else:
+        train(args, generator, discriminator, g_optim, d_optim, g_ema, percept, parsingpredictor, down, pspencoder, directions, styles, device)
```
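The fusion-module initialization is the core of this change. Each fusion layer takes the channel-wise concatenation of a DualStyleGAN generator feature f_G and an encoder feature f_E (hence the `f_G \otimes f_E` comment) and maps it back to the generator's channel width; the diff damps the randomly initialized weights by 0.01 and adds an identity at the spatial center over the first block of input channels, so at the start of training the fused output is approximately f_G and optimization begins from the pretrained generator's behaviour. The fusion_skip layers get the same treatment without the `.conv` indirection because they are plain convolutions on the skip branch. Below is a minimal sketch of the same idea on a standalone `nn.Conv2d`; the channel count `C` and the explicit bias zeroing are illustrative assumptions, not code from this repository.

```python
import torch
import torch.nn as nn

C = 8  # hypothetical channel width; VToonify uses the generator's per-resolution widths

# Stand-in for a fusion layer: maps concat(f_G, f_E) with 2*C channels back to C channels.
fuse = nn.Conv2d(2 * C, C, kernel_size=3, padding=1)

with torch.no_grad():
    # Same recipe as the diff: damp the random weights, then plant an identity
    # at the spatial center (1, 1) over the first C input channels (the f_G half).
    fuse.weight *= 0.01
    fuse.weight[:, 0:C, 1, 1] += torch.eye(C)
    fuse.bias.zero_()  # assumption for this sketch; keeps the output comparable to f_G

f_G = torch.randn(1, C, 16, 16)
f_E = torch.randn(1, C, 16, 16)
fused = fuse(torch.cat([f_G, f_E], dim=1))

# At initialization the fusion layer behaves almost like "return f_G".
print(torch.allclose(fused, f_G, atol=0.1))  # True
```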
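The three `accumulate(..., 0)` calls seed the exponential-moving-average copy `g_ema` with the freshly initialized encoder and fusion weights before training starts. `accumulate` itself is defined elsewhere in the file and is not part of this diff; assuming it follows the usual StyleGAN2-style helper, a decay of 0 reduces to a straight parameter copy:

```python
import torch.nn as nn

# Sketch of the conventional StyleGAN2-style EMA helper; the repo's own accumulate
# is only assumed to follow this convention.
def accumulate(model1, model2, decay=0.999):
    par1 = dict(model1.named_parameters())
    par2 = dict(model2.named_parameters())
    for k in par1.keys():
        # new = decay * old + (1 - decay) * current; decay = 0 is a straight copy
        par1[k].data.mul_(decay).add_(par2[k].data, alpha=1 - decay)

ema_net, net = nn.Linear(4, 4), nn.Linear(4, 4)
accumulate(ema_net, net, 0)  # ema_net's parameters now equal net's exactly
```

During training the same helper would typically be called with a decay close to 1 so that `g_ema` tracks a smoothed version of the generator.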
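The `# we apply gaussian blur ...` comment refers to the `Downsample` module built with kernel `[1, 3, 3, 1]`, the binomial filter commonly used for anti-aliased resampling in StyleGAN2-style code: frames are low-pass filtered before being subsampled by a factor of 2, which suppresses the aliasing flicker that naive strided subsampling produces on video. The sketch below only illustrates that filter-then-subsample behaviour; it assumes nothing about the repo's actual `Downsample` implementation beyond the kernel and factor, and the padding choice is a simplification.

```python
import torch
import torch.nn.functional as F

def blur_downsample(x, factor=2):
    # Separable binomial kernel [1, 3, 3, 1] -> normalized 4x4 low-pass filter.
    k = torch.tensor([1., 3., 3., 1.], device=x.device, dtype=x.dtype)
    k2d = torch.outer(k, k)
    k2d = k2d / k2d.sum()
    c = x.shape[1]
    weight = k2d.repeat(c, 1, 1, 1)               # one depthwise filter per channel
    x = F.conv2d(x, weight, padding=1, groups=c)  # blur first...
    return x[:, :, ::factor, ::factor]            # ...then subsample

img = torch.randn(1, 3, 8, 8)
print(blur_downsample(img).shape)  # torch.Size([1, 3, 4, 4])
```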