Jiatao Gu commited on
Commit
df44b7d
1 Parent(s): 77c753d

fix bug for cpu running

Browse files
Files changed (3) hide show
  1. app.py +4 -14
  2. gradio_queue.db +0 -0
  3. training/networks.py +3 -3
app.py CHANGED
@@ -9,29 +9,18 @@ import time
9
  import legacy
10
  import torch
11
  import glob
12
-
13
  import cv2
14
- import signal
15
  from torch_utils import misc
16
  from renderer import Renderer
17
  from training.networks import Generator
18
  from huggingface_hub import hf_hub_download
19
 
20
 
21
- device = torch.device('cuda')
22
  port = int(sys.argv[1]) if len(sys.argv) > 1 else 21111
23
 
24
 
25
-
26
- def handler(signum, frame):
27
- res = input("Ctrl-c was pressed. Do you really want to exit? y/n ")
28
- if res == 'y':
29
- gr.close_all()
30
- exit(1)
31
-
32
- signal.signal(signal.SIGINT, handler)
33
-
34
-
35
  def set_random_seed(seed):
36
  torch.manual_seed(seed)
37
  np.random.seed(seed)
@@ -202,11 +191,12 @@ yaw = gr.inputs.Slider(minimum=-1, maximum=1, default=0, label="yaw")
202
  pitch = gr.inputs.Slider(minimum=-1, maximum=1, default=0, label="pitch")
203
  roll = gr.inputs.Slider(minimum=-1, maximum=1, default=0, label="roll (optional, not suggested)")
204
  fov = gr.inputs.Slider(minimum=9, maximum=15, default=12, label="fov")
205
- css = ".output_image {height: 40rem !important; width: 100% !important;}"
206
 
207
  gr.Interface(fn=f_synthesis,
208
  inputs=[model_name, model_find, render_option, early, trunc, seed1, seed2, mix1, mix2, yaw, pitch, roll, fov, "state"],
209
  title="Interctive Web Demo for StyleNeRF (ICLR 2022)",
 
210
  outputs=["image", "state"],
211
  layout='unaligned',
212
  css=css, theme='dark-huggingface',
 
9
  import legacy
10
  import torch
11
  import glob
 
12
  import cv2
13
+
14
  from torch_utils import misc
15
  from renderer import Renderer
16
  from training.networks import Generator
17
  from huggingface_hub import hf_hub_download
18
 
19
 
20
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
21
  port = int(sys.argv[1]) if len(sys.argv) > 1 else 21111
22
 
23
 
 
 
 
 
 
 
 
 
 
 
24
  def set_random_seed(seed):
25
  torch.manual_seed(seed)
26
  np.random.seed(seed)
 
191
  pitch = gr.inputs.Slider(minimum=-1, maximum=1, default=0, label="pitch")
192
  roll = gr.inputs.Slider(minimum=-1, maximum=1, default=0, label="roll (optional, not suggested)")
193
  fov = gr.inputs.Slider(minimum=9, maximum=15, default=12, label="fov")
194
+ css = ".output-image, .input-image, .image-preview {height: 600px !important} "
195
 
196
  gr.Interface(fn=f_synthesis,
197
  inputs=[model_name, model_find, render_option, early, trunc, seed1, seed2, mix1, mix2, yaw, pitch, roll, fov, "state"],
198
  title="Interctive Web Demo for StyleNeRF (ICLR 2022)",
199
+ description="Demo for ICLR 2022 Papaer: A Style-based 3D-Aware Generator for High-resolution Image Synthesis. Currently the demo runs on CPU only."
200
  outputs=["image", "state"],
201
  layout='unaligned',
202
  css=css, theme='dark-huggingface',
gradio_queue.db ADDED
Binary file (856 kB). View file
 
training/networks.py CHANGED
@@ -794,7 +794,7 @@ class SynthesisBlock(torch.nn.Module):
794
  def forward(self, x, img, ws, force_fp32=False, fused_modconv=None, add_on=None, block_noise=None, disable_rgb=False, **layer_kwargs):
795
  misc.assert_shape(ws, [None, self.num_conv + self.num_torgb, self.w_dim])
796
  w_iter = iter(ws.unbind(dim=1))
797
- dtype = torch.float16 if self.use_fp16 and not force_fp32 else torch.float32
798
  memory_format = torch.channels_last if self.channels_last and not force_fp32 else torch.contiguous_format
799
  if fused_modconv is None:
800
  with misc.suppress_tracer_warnings(): # this value will be treated as a constant
@@ -937,7 +937,7 @@ class SynthesisBlock3(torch.nn.Module):
937
 
938
  def forward(self, x, img, ws, force_fp32=False, add_on=None, disable_rgb=False, **layer_kwargs):
939
  w_iter = iter(ws.unbind(dim=1))
940
- dtype = torch.float16 if self.use_fp16 and not force_fp32 else torch.float32
941
  memory_format = torch.contiguous_format
942
 
943
  # Main layers.
@@ -1141,7 +1141,7 @@ class DiscriminatorBlock(torch.nn.Module):
1141
  trainable=next(trainable_iter), resample_filter=resample_filter, channels_last=self.channels_last)
1142
 
1143
  def forward(self, x, img, force_fp32=False, downsampler=None):
1144
- dtype = torch.float16 if self.use_fp16 and not force_fp32 else torch.float32
1145
  memory_format = torch.channels_last if self.channels_last and not force_fp32 else torch.contiguous_format
1146
 
1147
  # Input.
 
794
  def forward(self, x, img, ws, force_fp32=False, fused_modconv=None, add_on=None, block_noise=None, disable_rgb=False, **layer_kwargs):
795
  misc.assert_shape(ws, [None, self.num_conv + self.num_torgb, self.w_dim])
796
  w_iter = iter(ws.unbind(dim=1))
797
+ dtype = torch.float16 if (self.use_fp16 and x.device.type == 'cuda') and not force_fp32 else torch.float32
798
  memory_format = torch.channels_last if self.channels_last and not force_fp32 else torch.contiguous_format
799
  if fused_modconv is None:
800
  with misc.suppress_tracer_warnings(): # this value will be treated as a constant
 
937
 
938
  def forward(self, x, img, ws, force_fp32=False, add_on=None, disable_rgb=False, **layer_kwargs):
939
  w_iter = iter(ws.unbind(dim=1))
940
+ dtype = torch.float16 if (self.use_fp16 and x.device.type == 'cuda') and not force_fp32 else torch.float32
941
  memory_format = torch.contiguous_format
942
 
943
  # Main layers.
 
1141
  trainable=next(trainable_iter), resample_filter=resample_filter, channels_last=self.channels_last)
1142
 
1143
  def forward(self, x, img, force_fp32=False, downsampler=None):
1144
+ dtype = torch.float16 if (self.use_fp16 and x.device.type == 'cuda') and not force_fp32 else torch.float32
1145
  memory_format = torch.channels_last if self.channels_last and not force_fp32 else torch.contiguous_format
1146
 
1147
  # Input.