Jezia commited on
Commit
c58bbe8
·
1 Parent(s): e43bbf3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +75 -2
app.py CHANGED
@@ -3,5 +3,78 @@ import pickle as pickle
3
  import os
4
  import sys
5
 
6
- os.system("git clone https://github.com/NVlabs/stylegan3")
7
- sys.path.append('./stylegan3')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  import os
4
  import sys
5
 
6
+ #os.system("git clone https://github.com/NVlabs/stylegan3")
7
+ #sys.path.append('./stylegan3')
8
+
9
+ model_names = {
10
+ 'AFHQv2-512-R': 'stylegan3-r-afhqv2-512x512.pkl',
11
+ 'FFHQ-1024-R': 'stylegan3-r-ffhq-1024x1024.pkl',
12
+ 'FFHQ-U-256-R': 'stylegan3-r-ffhqu-256x256.pkl',
13
+ 'FFHQ-U-1024-R': 'stylegan3-r-ffhqu-1024x1024.pkl',
14
+ 'MetFaces-1024-R': 'stylegan3-r-metfaces-1024x1024.pkl',
15
+ 'MetFaces-U-1024-R': 'stylegan3-r-metfacesu-1024x1024.pkl',
16
+ 'AFHQv2-512-T': 'stylegan3-t-afhqv2-512x512.pkl',
17
+ 'FFHQ-1024-T': 'stylegan3-t-ffhq-1024x1024.pkl',
18
+ 'FFHQ-U-256-T': 'stylegan3-t-ffhqu-256x256.pkl',
19
+ 'FFHQ-U-1024-T': 'stylegan3-t-ffhqu-1024x1024.pkl',
20
+ 'MetFaces-1024-T': 'stylegan3-t-metfaces-1024x1024.pkl',
21
+ 'MetFaces-U-1024-T': 'stylegan3-t-metfacesu-1024x1024.pkl',
22
+ }
23
+ model_dict = {
24
+ name: file_name
25
+ for name, file_name in model_names.items()
26
+ }
27
+
28
+ def fetch_model(url_or_path):
29
+ basename = os.path.basename(url_or_path)
30
+ if os.path.exists(basename):
31
+ return basename
32
+ else:
33
+ !wget -c '{url_or_path}'
34
+ return basename
35
+
36
+ def load_model(file_name: str, device: torch.device):
37
+ #path = torch.hub.download_url_to_file('https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/'+f'{file_name}',
38
+ # f'{file_name}')
39
+ base_url = "https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/"
40
+ network_url = base_url + f'{file_name}'
41
+
42
+ local_path = '/content/'f'{file_name}'
43
+ print(local_path)
44
+ with open(fetch_model(network_url), 'rb') as f:
45
+ model = pickle.load(f)['G_ema']
46
+ model.eval()
47
+ model.to(device)
48
+ with torch.inference_mode():
49
+ z = torch.zeros((1, model.z_dim)).to(device)
50
+ label = torch.zeros([1, model.c_dim], device=device)
51
+ model(z, label)
52
+ return model
53
+
54
+ def generate_image(model_name: str, seed: int, truncation_psi: float):
55
+ device = 'cuda'
56
+ model = model_dict[model_name]
57
+ model = load_model(model, device)
58
+ seed = int(np.clip(seed, 0, np.iinfo(np.uint32).max))
59
+ z = torch.from_numpy(np.random.RandomState(seed).randn(1, model.z_dim)).to(device)
60
+ label = torch.zeros([1, model.c_dim], device=device)
61
+
62
+ out = model(z, label, truncation_psi=truncation_psi)
63
+ out = (out.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
64
+ return out[0].cpu().numpy()
65
+
66
+ import gradio as gr
67
+ gr.Interface(
68
+ generate_image,
69
+ [
70
+ gr.inputs.Radio(list(model_names.keys()),
71
+ type='value',
72
+ default='FFHQ-1024-R',
73
+ label='Model'),
74
+ gr.inputs.Number(default=0, label='Seed'),
75
+ gr.inputs.Slider(
76
+ 0, 2, step=0.05, default=0.7, label='Truncation psi')
77
+ ],
78
+ gr.outputs.Image(type='numpy', label='Output')
79
+ ).launch(debug=True
80
+ )