Мясников Филипп Сергеевич commited on
Commit
025efa1
1 Parent(s): d0396da
Files changed (1) hide show
  1. app.py +22 -4
app.py CHANGED
@@ -28,11 +28,18 @@ from PIL import Image
28
  import torch
29
  import torchvision.transforms as transforms
30
  from argparse import Namespace
 
31
  from e4e.models.psp import pSp
32
  from e4e.models.encoders import psp_encoders
33
  from util import *
34
  from huggingface_hub import hf_hub_download
35
 
 
 
 
 
 
 
36
  device= 'cpu'
37
  ffhq_model_path = hf_hub_download(repo_id="bankholdup/stylegan_petbreeder", filename="e4e_ffhq512.pt")
38
 
@@ -95,6 +102,15 @@ cat_decoder.eval()
95
  cat_decoder.to(device)
96
 
97
 
 
 
 
 
 
 
 
 
 
98
  def gen_im(model_type='ffhq'):
99
  if model_type=='ffhq':
100
  imgs, _ = ffhq_decoder([ffhq_codes], input_is_latent=True, randomize_noise=False, return_latents=True)
@@ -109,15 +125,17 @@ def gen_im(model_type='ffhq'):
109
 
110
  def inference(img):
111
  img.save('out.jpg')
112
- aligned_face = align_face('out.jpg')
 
 
113
 
114
- ffhq_codes = ffhq_encoder(aligned_face.unsqueeze(0).to(device).float())
115
  ffhq_codes = ffhq_codes + ffhq_latent_avg.repeat(ffhq_codes.shape[0], 1, 1)
116
 
117
- cat_codes = cat_encoder(aligned_face.unsqueeze(0).to(device).float())
118
  cat_codes = cat_codes + ffhq_latent_avg.repeat(cat_codes.shape[0], 1, 1)
119
 
120
- dog_codes = dog_encoder(aligned_face.unsqueeze(0).to(device).float())
121
  dog_codes = dog_codes + ffhq_latent_avg.repeat(dog_codes.shape[0], 1, 1)
122
 
123
  animal = "cat"
 
28
  import torch
29
  import torchvision.transforms as transforms
30
  from argparse import Namespace
31
+ from e4e.utils.common import tensor2im
32
  from e4e.models.psp import pSp
33
  from e4e.models.encoders import psp_encoders
34
  from util import *
35
  from huggingface_hub import hf_hub_download
36
 
37
+ transform = transforms.Compose([
38
+ transforms.Resize((256, 256)),
39
+ transforms.ToTensor(),
40
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
41
+ resize_dims = (256, 256)
42
+
43
  device= 'cpu'
44
  ffhq_model_path = hf_hub_download(repo_id="bankholdup/stylegan_petbreeder", filename="e4e_ffhq512.pt")
45
 
 
102
  cat_decoder.to(device)
103
 
104
 
105
+ def run_alignment(image_path):
106
+ import dlib
107
+ from e4e.utils.alignment import align_face
108
+ predictor = dlib.shape_predictor("shape_predictor_68_face_landmarks.dat")
109
+ aligned_image = align_face(filepath=image_path, predictor=predictor)
110
+ print("Aligned image has shape: {}".format(aligned_image.size))
111
+ return aligned_image
112
+
113
+
114
  def gen_im(model_type='ffhq'):
115
  if model_type=='ffhq':
116
  imgs, _ = ffhq_decoder([ffhq_codes], input_is_latent=True, randomize_noise=False, return_latents=True)
 
125
 
126
  def inference(img):
127
  img.save('out.jpg')
128
+ #aligned_face = align_face('out.jpg')
129
+ input_image = run_alignment(image_path)
130
+ transformed_image = transform(input_image)
131
 
132
+ ffhq_codes = ffhq_encoder(transformed_image.unsqueeze(0).to(device).float())
133
  ffhq_codes = ffhq_codes + ffhq_latent_avg.repeat(ffhq_codes.shape[0], 1, 1)
134
 
135
+ cat_codes = cat_encoder(transformed_image.unsqueeze(0).to(device).float())
136
  cat_codes = cat_codes + ffhq_latent_avg.repeat(cat_codes.shape[0], 1, 1)
137
 
138
+ dog_codes = dog_encoder(transformed_image.unsqueeze(0).to(device).float())
139
  dog_codes = dog_codes + ffhq_latent_avg.repeat(dog_codes.shape[0], 1, 1)
140
 
141
  animal = "cat"