YaohuiW committed
Commit 92a3897 · verified · 1 Parent(s): 8b1952d

Update gradio_tabs/vid_edit.py

Files changed (1):
  1. gradio_tabs/vid_edit.py +26 -16
gradio_tabs/vid_edit.py CHANGED

@@ -21,7 +21,7 @@ labels_k = [
     'pout',
     'open->close',
     '"O" mouth',
-    'apple cheek',
+    'smile',

     'close->open',
     'eyebrows',
@@ -58,16 +58,24 @@ def img_preprocessing(img_path, size):

 def resize(img, size):
     transform = torchvision.transforms.Compose([
-        torchvision.transforms.Resize(size, antialias=True),
-        torchvision.transforms.CenterCrop(size)
+        torchvision.transforms.Resize((size, size), antialias=True),
     ])

     return transform(img)


+def resize_back(img, w, h):
+    transform = torchvision.transforms.Compose([
+        torchvision.transforms.Resize((h, w), antialias=True),
+    ])
+
+    return transform(img)
+
+
 def vid_preprocessing(vid_path, size):
     vid_dict = torchvision.io.read_video(vid_path, pts_unit='sec')
     vid = vid_dict[0].permute(0, 3, 1, 2).unsqueeze(0)  # btchw
+    _,_,_,h,w = vid.size()
     fps = vid_dict[2]['video_fps']
     vid_norm = (vid / 255.0 - 0.5) * 2.0  # [-1, 1]

@@ -75,7 +83,7 @@ def vid_preprocessing(vid_path, size):
         resize(vid_norm[:, i, :, :, :], size).unsqueeze(1) for i in range(vid.size(1))
     ], dim=1)

-    return vid_norm, fps
+    return vid_norm, fps, w, h


 def img_denorm(img):
@@ -92,7 +100,8 @@ def vid_denorm(vid):
     return vid


-def img_postprocessing(image):
+def img_postprocessing(image, w, h):
+    image = resize_back(image, w, h)
     image = image.permute(0, 2, 3, 1)
     edited_image = img_denorm(image)
     img_output = (edited_image[0].numpy() * 255).astype(np.uint8)
@@ -102,10 +111,14 @@ def img_postprocessing(image):
     return temp_file.name


-def vid_all_save(vid_d, vid_a, fps):
+def vid_all_save(vid_d, vid_a, w, h, fps):

-    vid_d = rearrange(vid_d, 'b t c h w -> b t h w c')
-    vid_a = rearrange(vid_a, 'b c t h w -> b t h w c')
+    b,t,c,_,_ = vid_d.size()
+    vid_d_batch = resize_back(rearrange(vid_d, "b t c h w -> (b t) c h w"), w, h)
+    vid_a_batch = resize_back(rearrange(vid_a, "b c t h w -> (b t) c h w"), w, h)
+
+    vid_d = rearrange(vid_d_batch, "(b t) c h w -> b t h w c", b=b)  # B T H W C
+    vid_a = rearrange(vid_a_batch, "(b t) c h w -> b t h w c", b=b)  # B T H W C
     vid_all = torch.cat([vid_d, vid_a], dim=3)

     vid_a_np = (vid_denorm(vid_a[0]).numpy() * 255).astype('uint8')
@@ -126,14 +139,14 @@ def vid_edit(gen, chunk_size, device):
     @torch.no_grad()
     def edit_img(video, *selected_s):

-        vid_target_tensor, fps = vid_preprocessing(video, 512)
+        vid_target_tensor, fps, w, h = vid_preprocessing(video, 512)
         video_target_tensor = vid_target_tensor.to(device)
         image_tensor = video_target_tensor[:,0,:,:,:]

         edited_image_tensor = gen.edit_img(image_tensor, labels_v, selected_s)

         # de-norm
-        edited_image = img_postprocessing(edited_image_tensor)
+        edited_image = img_postprocessing(edited_image_tensor, w, h)

         return edited_image

@@ -141,15 +154,15 @@ def vid_edit(gen, chunk_size, device):
     @torch.no_grad()
     def edit_vid(video, *selected_s):

-        video_target_tensor, fps = vid_preprocessing(video, 512)
+        video_target_tensor, fps, w, h = vid_preprocessing(video, 512)
         video_target_tensor = video_target_tensor.to(device)

         edited_video_tensor = gen.edit_vid_batch(video_target_tensor, labels_v, selected_s, chunk_size)
         edited_image_tensor = edited_video_tensor[:,:,0,:,:]

         # de-norm
-        animated_video, animated_all_video = vid_all_save(video_target_tensor, edited_video_tensor, fps)
-        edited_image = img_postprocessing(edited_image_tensor)
+        animated_video, animated_all_video = vid_all_save(video_target_tensor, edited_video_tensor, w, h, fps)
+        edited_image = img_postprocessing(edited_image_tensor, w, h)

         return edited_image, animated_video, animated_all_video

@@ -187,16 +200,13 @@ def vid_edit(gen, chunk_size, device):

         with gr.Row():
             with gr.Accordion(open=True, label="Edited First Frame"):
-                #image_output.render()
                 image_output = gr.Image(label="Image", elem_id="output_img", type='numpy', interactive=False, width=512)

             with gr.Accordion(open=True, label="Edited Video"):
-                #video_output.render()
                 video_output = gr.Video(label="Video", elem_id="output_vid", width=512)

         with gr.Row():
             with gr.Accordion(open=True, label="Original & Edited Videos"):
-                #video_all_output.render()
                 video_all_output = gr.Video(label="Videos", elem_id="output_vid_all")

         with gr.Column(scale=1):
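
For context, a minimal self-contained sketch of the two behavioral changes this commit makes, reusing the commit's own resize/resize_back helpers. It assumes torch, torchvision, and einops are installed; the 2-frame 640x360 clip below is a hypothetical example, not data from the repo.

import torch
import torchvision
from einops import rearrange

def resize(img, size):
    # New behavior: squash the frame to an exact size x size square.
    # The old code (Resize(size) + CenterCrop(size)) cropped it instead.
    return torchvision.transforms.Resize((size, size), antialias=True)(img)

def resize_back(img, w, h):
    # New helper: undo the squash, restoring the source resolution.
    return torchvision.transforms.Resize((h, w), antialias=True)(img)

# Hypothetical squashed tensors as they would leave the generator.
vid_d = torch.rand(1, 2, 3, 512, 512)   # original video, b t c h w
vid_a = torch.rand(1, 3, 2, 512, 512)   # edited video, b c t h w
w, h = 640, 360                          # source size recorded by vid_preprocessing

# vid_all_save now flattens batch and time so resize_back runs framewise.
b = vid_d.size(0)
vid_d_flat = resize_back(rearrange(vid_d, "b t c h w -> (b t) c h w"), w, h)
vid_a_flat = resize_back(rearrange(vid_a, "b c t h w -> (b t) c h w"), w, h)

vid_d = rearrange(vid_d_flat, "(b t) c h w -> b t h w c", b=b)  # B T H W C
vid_a = rearrange(vid_a_flat, "(b t) c h w -> b t h w c", b=b)  # B T H W C

vid_all = torch.cat([vid_d, vid_a], dim=3)  # original and edited side by side along width
assert vid_all.shape == (1, 2, 360, 1280, 3)

Net effect: edited frames and videos come back at the source resolution rather than as a 512x512 center crop, at the cost of aspect-ratio distortion on the way into the generator.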