Мясников Филипп Сергеевич commited on
Commit
9dc5640
1 Parent(s): bb1243f
Files changed (1) hide show
  1. app.py +6 -6
app.py CHANGED
@@ -37,7 +37,7 @@ device= 'cpu'
37
  ffhq_model_path = hf_hub_download(repo_id="bankholdup/stylegan_petbreeder", filename="e4e_ffhq512.pt")
38
 
39
  ffhq_ckpt = torch.load(ffhq_model_path, map_location='cpu')
40
- ffhq_latent_avg = ffhq_ckpt['latent_avg'].to('cuda:0')
41
  ffhq_opts = ffhq_ckpt['opts']
42
  ffhq_opts['checkpoint_path'] = ffhq_model_path
43
  ffhq_opts= Namespace(**ffhq_opts)
@@ -58,7 +58,7 @@ clear_output()
58
  dog_model_path = hf_hub_download(repo_id="bankholdup/stylegan_petbreeder", filename="e4e_ffhq512_dog.pt")
59
 
60
  dog_ckpt = torch.load(dog_model_path, map_location='cpu')
61
- dog_latent_avg = dog_ckpt['latent_avg'].to('cuda:0')
62
  dog_opts = dog_ckpt['opts']
63
  dog_opts['checkpoint_path'] = dog_model_path
64
  dog_opts= Namespace(**dog_opts)
@@ -79,7 +79,7 @@ clear_output()
79
  cat_model_path = hf_hub_download(repo_id="bankholdup/stylegan_petbreeder", filename="e4e_ffhq512_cat.pt")
80
 
81
  cat_ckpt = torch.load(cat_model_path, map_location='cpu')
82
- cat_latent_avg = cat_ckpt['latent_avg'].to('cuda:0')
83
  cat_opts = cat_ckpt['opts']
84
  cat_opts['checkpoint_path'] = cat_model_path
85
  cat_opts= Namespace(**cat_opts)
@@ -114,13 +114,13 @@ def inference(img):
114
  img.save('out.jpg')
115
  aligned_face = align_face('out.jpg')
116
 
117
- ffhq_codes = ffhq_encoder(aligned_face.unsqueeze(0).to("cuda").float())
118
  ffhq_codes = ffhq_codes + ffhq_latent_avg.repeat(ffhq_codes.shape[0], 1, 1)
119
 
120
- cat_codes = cat_encoder(aligned_face.unsqueeze(0).to("cuda").float())
121
  cat_codes = cat_codes + ffhq_latent_avg.repeat(cat_codes.shape[0], 1, 1)
122
 
123
- dog_codes = dog_encoder(aligned_face.unsqueeze(0).to("cuda").float())
124
  dog_codes = dog_codes + ffhq_latent_avg.repeat(dog_codes.shape[0], 1, 1)
125
 
126
  animal = "cat"
 
37
  ffhq_model_path = hf_hub_download(repo_id="bankholdup/stylegan_petbreeder", filename="e4e_ffhq512.pt")
38
 
39
  ffhq_ckpt = torch.load(ffhq_model_path, map_location='cpu')
40
+ ffhq_latent_avg = ffhq_ckpt['latent_avg'].to(device)
41
  ffhq_opts = ffhq_ckpt['opts']
42
  ffhq_opts['checkpoint_path'] = ffhq_model_path
43
  ffhq_opts= Namespace(**ffhq_opts)
 
58
  dog_model_path = hf_hub_download(repo_id="bankholdup/stylegan_petbreeder", filename="e4e_ffhq512_dog.pt")
59
 
60
  dog_ckpt = torch.load(dog_model_path, map_location='cpu')
61
+ dog_latent_avg = dog_ckpt['latent_avg'].to(device)
62
  dog_opts = dog_ckpt['opts']
63
  dog_opts['checkpoint_path'] = dog_model_path
64
  dog_opts= Namespace(**dog_opts)
 
79
  cat_model_path = hf_hub_download(repo_id="bankholdup/stylegan_petbreeder", filename="e4e_ffhq512_cat.pt")
80
 
81
  cat_ckpt = torch.load(cat_model_path, map_location='cpu')
82
+ cat_latent_avg = cat_ckpt['latent_avg'].to(device)
83
  cat_opts = cat_ckpt['opts']
84
  cat_opts['checkpoint_path'] = cat_model_path
85
  cat_opts= Namespace(**cat_opts)
 
114
  img.save('out.jpg')
115
  aligned_face = align_face('out.jpg')
116
 
117
+ ffhq_codes = ffhq_encoder(aligned_face.unsqueeze(0).to(device).float())
118
  ffhq_codes = ffhq_codes + ffhq_latent_avg.repeat(ffhq_codes.shape[0], 1, 1)
119
 
120
+ cat_codes = cat_encoder(aligned_face.unsqueeze(0).to(device).float())
121
  cat_codes = cat_codes + ffhq_latent_avg.repeat(cat_codes.shape[0], 1, 1)
122
 
123
+ dog_codes = dog_encoder(aligned_face.unsqueeze(0).to(device).float())
124
  dog_codes = dog_codes + ffhq_latent_avg.repeat(dog_codes.shape[0], 1, 1)
125
 
126
  animal = "cat"