Мясников Филипп Сергеевич committed on
Commit
8c42239
1 Parent(s): 1816654

like in colab

Browse files
Files changed (1) hide show
  1. app.py +87 -29
app.py CHANGED
@@ -16,7 +16,6 @@ from tqdm import tqdm
16
  import lpips
17
  from model import *
18
 
19
-
20
  #from e4e_projection import projection as e4e_projection
21
 
22
  from copy import deepcopy
@@ -30,44 +29,103 @@ import torch
30
  import torchvision.transforms as transforms
31
  from argparse import Namespace
32
  from e4e.models.psp import pSp
 
33
  from util import *
34
  from huggingface_hub import hf_hub_download
35
 
36
  device= 'cpu'
37
- model_path_e = hf_hub_download(repo_id="bankholdup/stylegan_petbreeder", filename="e4e_ffhq512_cat.pt")
38
- ckpt = torch.load(model_path_e, map_location='cpu')
39
- opts = ckpt['opts']
40
- opts['checkpoint_path'] = model_path_e
41
- opts= Namespace(**opts)
42
- net = pSp(opts, device).eval().to(device)
43
-
44
- @ torch.no_grad()
45
- def projection(img, name, device='cuda'):
46
-
47
-
48
- transform = transforms.Compose(
49
- [
50
- transforms.Resize(256),
51
- transforms.CenterCrop(256),
52
- transforms.ToTensor(),
53
- transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
54
- ]
55
- )
56
- img = transform(img).unsqueeze(0).to(device)
57
- images, w_plus = net(img, randomize_noise=False, return_latents=True)
58
- result_file = {}
59
- result_file['latent'] = w_plus[0]
60
- torch.save(result_file, name)
61
- return w_plus[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
 
63
 
64
  def inference(img):
65
  img.save('out.jpg')
66
  aligned_face = align_face('out.jpg')
 
 
 
 
 
 
 
 
 
67
 
68
- my_w = projection(aligned_face, "test.pt", device).unsqueeze(0)
69
-
70
- npimage = my_w.numpy()
71
  imageio.imwrite('filename.jpeg', npimage)
72
  return 'filename.jpeg'
73
 
 
16
  import lpips
17
  from model import *
18
 
 
19
  #from e4e_projection import projection as e4e_projection
20
 
21
  from copy import deepcopy
 
29
  import torchvision.transforms as transforms
30
  from argparse import Namespace
31
  from e4e.models.psp import pSp
32
+ from models.encoders import psp_encoders
33
  from util import *
34
  from huggingface_hub import hf_hub_download
35
 
36
  device= 'cpu'
37
+ ffhq_model_path = hf_hub_download(repo_id="bankholdup/stylegan_petbreeder", filename="e4e_ffhq512.pt")
38
+
39
+ ffhq_ckpt = torch.load(ffhq_model_path, map_location='cpu')
40
+ ffhq_latent_avg = ffhq_ckpt['latent_avg'].to('cuda:0')
41
+ ffhq_opts = ffhq_ckpt['opts']
42
+ ffhq_opts['checkpoint_path'] = ffhq_model_path
43
+ ffhq_opts= Namespace(**ffhq_opts)
44
+
45
+ ffhq_encoder = psp_encoders.Encoder4Editing(50, 'ir_se', ffhq_opts)
46
+ ffhq_e_filt = {k[len('encoder') + 1:]: v for k, v in ffhq_ckpt['state_dict'].items() if k[:len('encoder')] == 'encoder'}
47
+ ffhq_encoder.load_state_dict(ffhq_e_filt, strict=True)
48
+ ffhq_encoder.eval()
49
+ ffhq_encoder.to(device)
50
+
51
+ ffhq_decoder = Generator(512, 512, 8, channel_multiplier=2)
52
+ ffhq_d_filt = {k[len('decoder') + 1:]: v for k, v in ffhq_ckpt['state_dict'].items() if k[:len('decoder')] == 'decoder'}
53
+ ffhq_decoder.load_state_dict(ffhq_d_filt, strict=True)
54
+ ffhq_decoder.eval()
55
+ ffhq_decoder.to(device)
56
+ clear_output()
57
+
58
+ dog_model_path = hf_hub_download(repo_id="bankholdup/stylegan_petbreeder", filename="e4e_ffhq512_dog.pt")
59
+
60
+ dog_ckpt = torch.load(dog_model_path, map_location='cpu')
61
+ dog_latent_avg = dog_ckpt['latent_avg'].to('cuda:0')
62
+ dog_opts = dog_ckpt['opts']
63
+ dog_opts['checkpoint_path'] = dog_model_path
64
+ dog_opts= Namespace(**dog_opts)
65
+
66
+ dog_encoder = psp_encoders.Encoder4Editing(50, 'ir_se', dog_opts)
67
+ dog_e_filt = {k[len('encoder') + 1:]: v for k, v in dog_ckpt['state_dict'].items() if k[:len('encoder')] == 'encoder'}
68
+ dog_encoder.load_state_dict(dog_e_filt, strict=True)
69
+ dog_encoder.eval()
70
+ dog_encoder.to(device)
71
+
72
+ dog_decoder = Generator(512, 512, 8, channel_multiplier=2)
73
+ dog_d_filt = {k[len('decoder') + 1:]: v for k, v in dog_ckpt['state_dict'].items() if k[:len('decoder')] == 'decoder'}
74
+ dog_decoder.load_state_dict(dog_d_filt, strict=True)
75
+ dog_decoder.eval()
76
+ dog_decoder.to(device)
77
+ clear_output()
78
+
79
+ cat_model_path = hf_hub_download(repo_id="bankholdup/stylegan_petbreeder", filename="e4e_ffhq512_cat.pt")
80
+
81
+ cat_ckpt = torch.load(cat_model_path, map_location='cpu')
82
+ cat_latent_avg = cat_ckpt['latent_avg'].to('cuda:0')
83
+ cat_opts = cat_ckpt['opts']
84
+ cat_opts['checkpoint_path'] = cat_model_path
85
+ cat_opts= Namespace(**cat_opts)
86
+
87
+ cat_encoder = psp_encoders.Encoder4Editing(50, 'ir_se', cat_opts)
88
+ cat_e_filt = {k[len('encoder') + 1:]: v for k, v in cat_ckpt['state_dict'].items() if k[:len('encoder')] == 'encoder'}
89
+ cat_encoder.load_state_dict(cat_e_filt, strict=True)
90
+ cat_encoder.eval()
91
+ cat_encoder.to(device)
92
+
93
+ cat_decoder = Generator(512, 512, 8, channel_multiplier=2)
94
+ cat_d_filt = {k[len('decoder') + 1:]: v for k, v in cat_ckpt['state_dict'].items() if k[:len('decoder')] == 'decoder'}
95
+ cat_decoder.load_state_dict(cat_d_filt, strict=True)
96
+ cat_decoder.eval()
97
+ cat_decoder.to(device)
98
+ clear_output()
99
+
100
+
101
+ def gen_im(model_type='ffhq'):
102
+ if model_type=='ffhq':
103
+ imgs, _ = ffhq_decoder([ffhq_codes], input_is_latent=True, randomize_noise=False, return_latents=True)
104
+ elif model_type=='dog':
105
+ imgs, _ = dog_decoder([dog_codes], input_is_latent=True, randomize_noise=False, return_latents=True)
106
+ elif model_type=='cat':
107
+ imgs, _ = cat_decoder([cat_codes], input_is_latent=True, randomize_noise=False, return_latents=True)
108
+ else:
109
+ imgs, _ = custom_decoder([custom_codes], input_is_latent=True, randomize_noise=False, return_latents=True)
110
+ return tensor2im(imgs[0])
111
 
112
 
113
  def inference(img):
114
  img.save('out.jpg')
115
  aligned_face = align_face('out.jpg')
116
+
117
+ ffhq_codes = ffhq_encoder(aligned_face.unsqueeze(0).to("cuda").float())
118
+ ffhq_codes = ffhq_codes + ffhq_latent_avg.repeat(ffhq_codes.shape[0], 1, 1)
119
+
120
+ cat_codes = cat_encoder(aligned_face.unsqueeze(0).to("cuda").float())
121
+ cat_codes = cat_codes + ffhq_latent_avg.repeat(cat_codes.shape[0], 1, 1)
122
+
123
+ dog_codes = dog_encoder(aligned_face.unsqueeze(0).to("cuda").float())
124
+ dog_codes = dog_codes + ffhq_latent_avg.repeat(dog_codes.shape[0], 1, 1)
125
 
126
+ animal = "cat"
127
+ npimage = gen_im(animal)
128
+
129
  imageio.imwrite('filename.jpeg', npimage)
130
  return 'filename.jpeg'
131