YaohuiW committed
Commit f01681c · 1 Parent(s): 6b0ef0f
app.py CHANGED

@@ -17,7 +17,7 @@ ckpt_path = hf_hub_download(repo_id="YaohuiW/LIA-X", filename="lia-x.pt")
 gen.load_state_dict(torch.load(ckpt_path, weights_only=True))
 gen.eval()
 
-chunk_size=16
+chunk_size=30
 
 def load_file(path):
 
 
 
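For context on the bump from 16 to 30: `chunk_size` sets how many driving-video frames each call to the compiled decoder processes (see the `gradio_tabs/animation.py` hunk below), so a larger value means fewer, bigger batches per video. A minimal sketch of the chunk arithmetic; `num_chunks` is illustrative, not repo code:

```python
# Illustrative only: how many fixed-size chunks cover a t-frame video
# once it is zero-padded to a multiple of chunk_size, as done in
# gradio_tabs/animation.py below. Not repo code.
def num_chunks(t: int, chunk_size: int) -> int:
    return -(-t // chunk_size)  # ceil division

assert num_chunks(150, 16) == 10  # old setting: 10 decoder calls
assert num_chunks(150, 30) == 5   # new setting: 5 larger calls
```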
assets/instruction.md CHANGED

@@ -3,18 +3,18 @@
 * **Image Animation**
 
   - Upload `Source Image` and `Driving Video`
-  - Using sliders in the `Control Panel` to edit image
+  - Using `sliders` in the `Control Panel` to edit image
   - Use `Animate` button to obtain `Animated Video`
 
 * **Image Editing**
 
   - Upload `Source Image`
-  - Using sliders in the `Control Panel` to edit image
+  - Using `sliders` in the `Control Panel` to edit image
 
 * **Video Editing**
 
   - Upload `Video`
-  - Using sliders in the `Control Panel` to edit image
+  - Using `sliders` in the `Control Panel` to edit image
   - Use `Generate` button to obtain `Edited Video`
 
 **NOTE: we recommend to crop both input images and videos using provided [tools](https://github.com/wyhsirius/LIA-X/tree/main) for better results**
 
gradio_tabs/animation.py CHANGED

@@ -90,10 +90,6 @@ def vid_preprocessing(vid_path, size):
     vid = vid_dict[0].permute(0, 3, 1, 2) # tchw
     fps = vid_dict[2]['video_fps']
     vid_norm = (vid / 255.0 - 0.5) * 2.0 # [-1, 1]
-
-    #vid_norm = torch.cat([
-    #    resize(vid_norm[i:i+1, :, :, :], size).unsqueeze(1) for i in range(vid.size(0))
-    #], dim=1)
     vid_norm = resize(vid_norm, size) # tchw
 
     return vid_norm, fps
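The deleted comment block was an older per-frame resize loop; the live code resizes the whole `tchw` batch in one call. A sketch of the same pattern, assuming `resize` behaves like torchvision's functional resize (the repo's own helper is not shown in this diff):

```python
# Sketch, assuming torchvision semantics for `resize`; the repo's
# helper is not shown in this commit.
import torch
from torchvision.transforms.functional import resize

vid_norm = torch.rand(16, 3, 720, 1280) * 2 - 1          # tchw in [-1, 1]
vid_512 = resize(vid_norm, [512, 512], antialias=True)   # one call for all frames
assert vid_512.shape == (16, 3, 512, 512)
```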
@@ -135,9 +131,7 @@ def vid_postprocessing(video, w, h, fps):
 
     t,c,_,_ = video.size()
     vid = resize_back(video, w, h)
-
-    vid = vid.clamp(-1, 1)
-    vid = (vid - vid.min()) / (vid.max() - vid.min())
+    vid = vid_denorm(vid)
 
     vid = rearrange(vid, "t c h w -> t h w c") # T H W C
     vid_np = (vid.cpu().numpy() * 255).astype('uint8')
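`vid_denorm` is not defined in this diff. What the change buys: the removed min-max rescaling normalized by each clip's own extrema, so output brightness depended on frame content; a fixed affine denorm is content-independent. A minimal sketch of what `vid_denorm` plausibly does, assuming the usual [-1, 1] to [0, 1] mapping:

```python
# Hypothetical vid_denorm, assuming a fixed mapping from the model's
# [-1, 1] range back to [0, 1]; not taken from the repo.
import torch

def vid_denorm(vid: torch.Tensor) -> torch.Tensor:
    return (vid.clamp(-1, 1) + 1) / 2
```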
@@ -215,30 +209,27 @@ def animation(gen, chunk_size, device):
         vid_target_tensor, fps = vid_preprocessing(video, 512)
         image_tensor = image_tensor.to(device)
         video_target_tensor = vid_target_tensor.to(device) #tchw
-
-        #animated_video = gen.animate_batch(image_tensor, video_target_tensor, labels_v, selected_s, chunk_size)
-        #edited_image = animated_video[:,:,0,:,:]
 
         img_start = video_target_tensor[0:1,:,:,:]
-        #vid_target_tensor_batch = rearrange(video_target_tensor, 'b t c h w -> (b t) c h w')
 
         res = []
-        t = video_target_tensor.size(1)
+        t, c, h, w = video_target_tensor.size()
+
         chunks = t // chunk_size
+        if t%chunk_size == 0:
+            vid_target_tensor_batch = torch.zeros(chunk_size * chunks, c, h, w).to(device)
+        else:
+            vid_target_tensor_batch = torch.zeros(chunk_size * (chunks + 1), c, h, w).to(device)
+        vid_target_tensor_batch[:t] = video_target_tensor
+
         z_s2r, alpha_r2s, feat_rgb = compiled_enc_img(image_tensor, selected_s)
-        #z_s2r, alpha_r2s, feat_rgb = gen.enc_img(image_tensor, labels_v, selected_s)
         for i in range(chunks+1):
-            if i == chunks:
-                img_target = vid_target_tensor[i*chunk_size:, :, :, :]
-                img_animated = compiled_dec_vid(z_s2r, alpha_r2s, feat_rgb, img_start, img_target)
-                #img_animated_batch = gen.dec_vid(z_s2r, alpha_r2s, feat_rgb, img_start, img_target_batch)
-            else:
-                img_target = vid_target_tensor[i*chunk_size:(i+1)*chunk_size, :, :, :]
-                img_animated = compiled_dec_vid(z_s2r, alpha_r2s, feat_rgb, img_start, img_target)
-                #img_animated_batch = gen.dec_vid(z_s2r, alpha_r2s, feat_rgb, img_start, img_target_batch)
-
-            res.append(img_animated)
-        animated_video = torch.cat(res, dim=0) # TCHW
+
+            img_target_batch = vid_target_tensor_batch[i * chunk_size:(i + 1) * chunk_size, :, :, :]
+            img_animated_batch = compiled_dec_vid(z_s2r, alpha_r2s, feat_rgb, img_start, img_target_batch)
+
+            res.append(img_animated_batch)
+        animated_video = torch.cat(res, dim=0)[:t] # TCHW
         edited_image = animated_video[0:1,:,:,:]
 
         # postprocessing
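This hunk zero-pads the driving video to a multiple of `chunk_size` so every `compiled_dec_vid` call sees the same batch shape (keeping `torch.compile` to a single trace), then `[:t]` trims the padding. One caveat: when `t` is an exact multiple of `chunk_size`, the unconditional `range(chunks+1)` makes the final iteration slice an empty chunk. A hedged sketch with a ceil-division bound that avoids that edge case; names mirror the diff, but the helper is illustrative:

```python
# Sketch of the padded chunking above, iterating ceil(t / chunk_size)
# times so an exact multiple of chunk_size never yields an empty
# trailing slice. Illustrative, not repo code.
import math
import torch

def chunk_padded(video: torch.Tensor, chunk_size: int) -> list:
    t, c, h, w = video.size()
    n = math.ceil(t / chunk_size)
    padded = video.new_zeros(n * chunk_size, c, h, w)
    padded[:t] = video
    return [padded[i * chunk_size:(i + 1) * chunk_size] for i in range(n)]

frames = torch.randn(65, 3, 64, 64)
chunks = chunk_padded(frames, 30)                    # 3 fixed-shape chunks
assert torch.equal(torch.cat(chunks)[:65], frames)   # [:t] trims the padding
```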
@@ -308,7 +299,7 @@
             #video_output.render()
             video_output = gr.Video(label="Output Video", elem_id="output_vid", width=512)#.render()
 
-            with gr.Accordion("Control Panel (Using Sliders to Edit Image)", open=True):
+            with gr.Accordion("Control Panel - Using Sliders to Edit Image", open=True):
                 with gr.Tab("Head"):
                     with gr.Row():
                         for k in labels_k[:3]:
@@ -344,23 +335,12 @@
             fn=edit_media,
             inputs=[image_input] + inputs_s,
             outputs=[image_output],
-
             show_progress='hidden',
-
             trigger_mode='always_last',
-
             # currently we have a latency around 450ms
             stream_every=0.5
         )
 
-
-        #edit_btn.click(
-        #    fn=edit_media,
-        #    inputs=[image_input] + inputs_s,
-        #    outputs=[image_output],
-        #    show_progress=True
-        #)
-
         animate_btn.click(
             fn=animate_media,
             inputs=[image_input, video_input] + inputs_s,
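The kwargs above belong to a Gradio event listener whose opening call sits outside the hunk. A self-contained sketch of the same streaming-slider pattern, assuming a `.change` listener on Gradio 4+; `edit_media` here is a stub for the ~450 ms model call the comment mentions:

```python
# Minimal sketch of the slider-streaming pattern, assuming a Gradio 4+
# `.change` listener; the real wiring sits outside this hunk.
import gradio as gr

def edit_media(img, value):
    return img  # stub for the ~450 ms model call noted in the diff

with gr.Blocks() as demo:
    image_input = gr.Image(type="numpy")
    yaw = gr.Slider(-1.0, 1.0, value=0.0, label="yaw")
    image_output = gr.Image(type="numpy", interactive=False)

    yaw.change(
        fn=edit_media,
        inputs=[image_input, yaw],
        outputs=[image_output],
        show_progress="hidden",
        trigger_mode="always_last",  # keep only the latest event while dragging
        stream_every=0.5,            # throttle to one call per 0.5 s
    )
```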
 
gradio_tabs/img_edit.py CHANGED

@@ -95,14 +95,10 @@ def img_denorm(img):
 def img_postprocessing(img, w, h):
 
     img = resize_back(img, w, h)
-    #image = image.permute(0, 2, 3, 1)
     img = img_denorm(img)
     img = img.squeeze(0).permute(1, 2, 0).contiguous() # contiguous() for fast transfer
     img_output = (img.cpu().numpy() * 255).astype(np.uint8)
 
-    #with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_file:
-    #    imageio.imwrite(temp_file.name, img_output, quality=8)
-    #    return temp_file.name
     return img_output
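A note on the `.contiguous()` call the diff keeps: `permute` returns a strided view, and the repo's inline comment credits `.contiguous()` with a faster transfer, since it materializes one dense row-major copy before the CPU/NumPy hop. A tiny sketch; not repo code:

```python
# permute only rearranges strides; .contiguous() makes one dense,
# row-major copy, which the repo's comment says transfers faster.
import torch

chw = torch.rand(3, 4, 5)
hwc = chw.permute(1, 2, 0)    # a view; memory layout unchanged
assert not hwc.is_contiguous()
dense = hwc.contiguous()      # single copy into row-major layout
assert dense.is_contiguous()
```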
@@ -196,7 +192,7 @@ def img_edit(gen, device):
     image_output = gr.Image(label="Output Image", type='numpy', interactive=False, width=512)
 
 
-    with gr.Accordion("Control Panel (Using Sliders to Edit Image)", open=True):
+    with gr.Accordion("Control Panel - Using Sliders to Edit Image", open=True):
         with gr.Tab("Head"):
             with gr.Row():
                 for k in labels_k[:3]:
@@ -239,15 +235,7 @@
 
             # currently we have a latency around 450ms
             stream_every=0.5
-        )
-
-
-        #edit_btn.click(
-        #    fn=edit_img,
-        #    inputs=[image_input] + inputs_s,
-        #    outputs=[image_output],
-        #    show_progress=True
-        #)
+        )
 
         clear_btn.click(
             fn=clear_media,
 
gradio_tabs/vid_edit.py CHANGED

@@ -231,21 +231,23 @@ def vid_edit(gen, chunk_size, device):
         res = []
         t = video_target_tensor.size(1)
         chunks = t // chunk_size
+
+
         z_s2r, alpha_r2s, feat_rgb = compiled_enc_img(img_start, selected_s)
         for i in range(chunks + 1):
             if i == chunks:
-                img_target_batch = vid_target_tensor_batch[i * chunk_size:, :, :, :]
-                img_animated_batch = compiled_dec_vid(z_s2r, alpha_r2s, feat_rgb, img_start, img_target)
+                img_target_batch = video_target_tensor[i * chunk_size:, :, :, :]
+                img_animated_batch = compiled_dec_vid(z_s2r, alpha_r2s, feat_rgb, img_start, img_target_batch)
             else:
-                img_target_batch = vid_target_tensor_batch[i * chunk_size:(i + 1) * chunk_size, :, :, :]
-                img_animated_batch = compiled_dec_vid(z_s2r, alpha_r2s, feat_rgb, img_start, img_target)
+                img_target_batch = video_target_tensor[i * chunk_size:(i + 1) * chunk_size, :, :, :]
+                img_animated_batch = compiled_dec_vid(z_s2r, alpha_r2s, feat_rgb, img_start, img_target_batch)
 
             res.append(img_animated_batch)
         edited_video_tensor = torch.cat(res, dim=0) # TCHW
         edited_image_tensor = edited_video_tensor[0:1,:,:,:]
 
         # de-norm
-        animated_video, animated_all_video = vid_all_save(vid_target_tensor_batch, edited_video_tensor, w, h, fps)
+        animated_video, animated_all_video = vid_all_save(video_target_tensor, edited_video_tensor, w, h, fps)
         edited_image = img_postprocessing(edited_image_tensor, w, h)
 
         return edited_image, animated_video, animated_all_video
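This hunk is a bug fix, not a refactor: the old loop body assigned `img_target_batch` but passed `img_target` (never assigned in this hunk) to `compiled_dec_vid`, and sliced `vid_target_tensor_batch`, a name the new code drops in favor of `video_target_tensor`; as written, the old path would raise a `NameError` unless those names were defined elsewhere. Unlike the padded scheme in `animation.py`, the final chunk here keeps its natural length `t % chunk_size`, so a compiled decoder sees a second shape and may trace once more. A sketch of that last-chunk shape, with assumed sizes:

```python
# Shape of the final, unpadded chunk in vid_edit.py's loop, assuming
# t = 65 frames and chunk_size = 30. Illustrative values only.
import torch

t, chunk_size = 65, 30
video_target_tensor = torch.randn(t, 3, 64, 64)

chunks = t // chunk_size                            # 2 full chunks
last = video_target_tensor[chunks * chunk_size:]    # the `i == chunks` branch
assert last.shape == (t % chunk_size, 3, 64, 64)    # (5, 3, 64, 64)
```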
@@ -293,7 +295,7 @@
     video_all_output = gr.Video(label="Videos", elem_id="output_vid_all")
 
     with gr.Column(scale=1):
-        with gr.Accordion("Control Panel (Using Sliders to Edit Image)", open=True):
+        with gr.Accordion("Control Panel - Using Sliders to Edit Image", open=True):
             with gr.Tab("Head"):
                 with gr.Row():
                     for k in labels_k[:3]:
@@ -342,13 +344,6 @@
             stream_every=0.5
         )
 
-        #edit_btn.click(
-        #    fn=edit_img,
-        #    inputs=[video_input] + inputs_s,
-        #    outputs=[image_output],
-        #    show_progress=True
-        #)
-
         animate_btn.click(
             fn=edit_vid,
             inputs=[video_input] + inputs_s, # [image_input, video_input] + inputs_s,
 