YaohuiW committed on
Commit
6b0ef0f
·
1 Parent(s): 01e0491
app.py CHANGED
@@ -7,6 +7,10 @@ from gradio_tabs.vid_edit import vid_edit
7
  from gradio_tabs.img_edit import img_edit
8
  from networks.generator import Generator
9
 
 
 
 
 
10
  device = torch.device("cuda")
11
  gen = Generator(size=512, motion_dim=40, scale=2).to(device)
12
  ckpt_path = hf_hub_download(repo_id="YaohuiW/LIA-X", filename="lia-x.pt")
 
7
  from gradio_tabs.img_edit import img_edit
8
  from networks.generator import Generator
9
 
10
+ # Optimize torch.compile performance
11
+ torch.set_float32_matmul_precision('high') # Enable TensorFloat32 for better performance
12
+ torch._dynamo.config.cache_size_limit = 64 # Increase cache size to reduce recompilations
13
+
14
  device = torch.device("cuda")
15
  gen = Generator(size=512, motion_dim=40, scale=2).to(device)
16
  ckpt_path = hf_hub_download(repo_id="YaohuiW/LIA-X", filename="lia-x.pt")
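Put together, the updated top of app.py reads roughly as the sketch below (a minimal reconstruction of the hunk above; the exact import list is assumed):

    import torch
    from huggingface_hub import hf_hub_download
    from networks.generator import Generator

    # Optimize torch.compile performance
    torch.set_float32_matmul_precision('high')    # allow TensorFloat32 matmuls for faster float32 GEMMs on Ampere+ GPUs
    torch._dynamo.config.cache_size_limit = 64    # tolerate more graph variants before Dynamo falls back to eager

    device = torch.device("cuda")
    gen = Generator(size=512, motion_dim=40, scale=2).to(device)
    ckpt_path = hf_hub_download(repo_id="YaohuiW/LIA-X", filename="lia-x.pt")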
assets/instruction.md CHANGED
@@ -3,18 +3,18 @@
3
  * **Image Animation**
4
 
5
  - Upload `Source Image` and `Driving Video`
6
- - Use `Control Panel` to edit source image and `Edit` button to display the `Edited Image`
7
- - Use `Animate` button to obtained `Animated Video`
8
 
9
  * **Image Editing**
10
 
11
  - Upload `Source Image`
12
- - Use `Control Panel` to edit source image and `Edit` button to display the `Edited Image`
13
 
14
  * **Video Editing**
15
 
16
  - Upload `Video`
17
- - Use `Control Panel` to edit first frame of video and `Edit` button to display the `Edited Image`
18
  - Use `Generate` button to obtain `Edited Video`
19
 
20
  **NOTE: we recommend to crop both input images and videos using provided [tools](https://github.com/wyhsirius/LIA-X/tree/main) for better results**
 
3
  * **Image Animation**
4
 
5
  - Upload `Source Image` and `Driving Video`
6
+ - Use sliders in the `Control Panel` to edit the image
7
+ - Use `Animate` button to obtain `Animated Video`
8
 
9
  * **Image Editing**
10
 
11
  - Upload `Source Image`
12
+ - Use sliders in the `Control Panel` to edit the image
13
 
14
  * **Video Editing**
15
 
16
  - Upload `Video`
17
+ - Use sliders in the `Control Panel` to edit the first frame of the video
18
  - Use `Generate` button to obtain `Edited Video`
19
 
20
**NOTE: we recommend cropping both input images and videos with the provided [tools](https://github.com/wyhsirius/LIA-X/tree/main) for better results**
assets/title.md CHANGED
@@ -1,4 +1,5 @@
1
  <font size=7><center>LIA-X: Interpretable Latent Portrait Animator</center></font>
 
2
  <div style="display: flex;align-items: center;justify-content: center">
3
  [<a href="https://arxiv.org/abs/2508.09959">Technical Report</a>] | [<a href="https://wyhsirius.github.io/LIA-X-project/">Project Page</a>] | [<a href="https://github.com/wyhsirius/LIA-X">Code</a>]
4
  </div>
 
1
  <font size=7><center>LIA-X: Interpretable Latent Portrait Animator</center></font>
2
+ <font size=5><center>Toward Interactive Portrait Animation and Editing</center></font>
3
  <div style="display: flex;align-items: center;justify-content: center">
4
  [<a href="https://arxiv.org/abs/2508.09959">Technical Report</a>] | [<a href="https://wyhsirius.github.io/LIA-X-project/">Project Page</a>] | [<a href="https://github.com/wyhsirius/LIA-X">Code</a>]
5
  </div>
gradio_tabs/animation.py CHANGED
@@ -36,64 +36,78 @@ labels_v = [
36
  13, 24, 17, 26
37
  ]
38
 
39
-
40
  def load_image(img, size):
41
 
42
  img = Image.open(img).convert('RGB')
43
  w, h = img.size
44
  img = img.resize((size, size))
45
  img = np.asarray(img)
 
46
  img = np.transpose(img, (2, 0, 1)) # 3 x 256 x 256
47
 
48
  return img / 255.0, w, h
49
 
50
 
 
51
  def img_preprocessing(img_path, size):
52
- img, w, h = load_image(img_path, size) # [0, 1]
53
  img = torch.from_numpy(img).unsqueeze(0).float() # [0, 1]
54
  imgs_norm = (img - 0.5) * 2.0 # [-1, 1]
55
 
56
  return imgs_norm, w, h
57
 
58
 
59
- def resize(img, size):
60
- transform = torchvision.transforms.Compose([
61
- torchvision.transforms.Resize((size, size), antialias=True),
62
- ])
 
 
 
 
 
 
 
 
 
63
 
 
 
 
 
64
  return transform(img)
65
 
66
 
67
  def resize_back(img, w, h):
68
- transform = torchvision.transforms.Compose([
69
- torchvision.transforms.Resize((h, w), antialias=True),
70
- ])
71
-
72
  return transform(img)
73
-
74
 
75
  def vid_preprocessing(vid_path, size):
76
  vid_dict = torchvision.io.read_video(vid_path, pts_unit='sec')
77
- vid = vid_dict[0].permute(0, 3, 1, 2).unsqueeze(0) # btchw
78
  fps = vid_dict[2]['video_fps']
79
  vid_norm = (vid / 255.0 - 0.5) * 2.0 # [-1, 1]
80
 
81
- vid_norm = torch.cat([
82
- resize(vid_norm[:, i, :, :, :], size).unsqueeze(1) for i in range(vid.size(1))
83
- ], dim=1)
 
84
 
85
  return vid_norm, fps
86
 
87
 
88
  def img_denorm(img):
89
- img = img.clamp(-1, 1).cpu()
90
  img = (img - img.min()) / (img.max() - img.min())
91
 
92
  return img
93
 
94
 
95
  def vid_denorm(vid):
96
- vid = vid.clamp(-1, 1).cpu()
97
  vid = (vid - vid.min()) / (vid.max() - vid.min())
98
 
99
  return vid
@@ -101,24 +115,32 @@ def vid_denorm(vid):
101
 
102
  def img_postprocessing(image, w, h):
103
 
104
- image = resize_back(image, w, h)
105
- image = image.permute(0, 2, 3, 1)
106
- edited_image = img_denorm(image)
107
- img_output = (edited_image[0].numpy() * 255).astype(np.uint8)
108
 
109
- with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_file:
110
- imageio.imwrite(temp_file.name, img_output, quality=8)
111
- return temp_file.name
 
 
 
 
 
 
 
112
 
113
 
114
 
115
  def vid_postprocessing(video, w, h, fps):
116
- # video: BCTHW
 
 
 
117
 
118
- b,c,t,_,_ = video.size()
119
- vid_batch = resize_back(rearrange(video, "b c t h w -> (b t) c h w"), w, h)
120
- vid = rearrange(vid_batch, "(b t) c h w -> b t h w c", b=b) # B T H W C
121
- vid_np = (vid_denorm(vid[0]).numpy() * 255).astype('uint8')
 
122
 
123
  with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as temp_file:
124
  imageio.mimwrite(temp_file.name, vid_np, fps=fps, codec='libx264', quality=8)
@@ -126,15 +148,59 @@ def vid_postprocessing(video, w, h, fps):
126
 
127
 
128
  def animation(gen, chunk_size, device):
129
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
130
  @spaces.GPU
131
- @torch.no_grad()
132
  def edit_media(image, *selected_s):
133
 
134
  image_tensor, w, h = img_preprocessing(image, 512)
135
  image_tensor = image_tensor.to(device)
136
 
137
- edited_image_tensor = gen.edit_img(image_tensor, labels_v, selected_s)
 
138
 
139
  # de-norm
140
  edited_image = img_postprocessing(edited_image_tensor, w, h)
@@ -142,16 +208,38 @@ def animation(gen, chunk_size, device):
142
  return edited_image
143
 
144
  @spaces.GPU
145
- @torch.no_grad()
146
  def animate_media(image, video, *selected_s):
147
 
148
  image_tensor, w, h = img_preprocessing(image, 512)
149
  vid_target_tensor, fps = vid_preprocessing(video, 512)
150
  image_tensor = image_tensor.to(device)
151
- video_target_tensor = vid_target_tensor.to(device)
152
-
153
- animated_video = gen.animate_batch(image_tensor, video_target_tensor, labels_v, selected_s, chunk_size)
154
- edited_image = animated_video[:,:,0,:,:]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
155
 
156
  # postprocessing
157
  animated_video = vid_postprocessing(animated_video, w, h, fps)
@@ -162,7 +250,7 @@ def animation(gen, chunk_size, device):
162
  def clear_media():
163
  return None, None, *([0] * len(labels_k))
164
 
165
-
166
  with gr.Tab("Image Animation"):
167
 
168
  inputs_s = []
@@ -202,11 +290,10 @@ def animation(gen, chunk_size, device):
202
  with gr.Row():
203
  with gr.Column(scale=1):
204
  with gr.Row(): # Buttons now within a single Row
205
- edit_btn = gr.Button("Edit", elem_id="button_edit",)
206
- clear_btn = gr.Button("Clear", elem_id="button_clear")
207
- with gr.Row():
208
  animate_btn = gr.Button("Animate", elem_id="button_animate")
209
-
 
210
 
211
 
212
  with gr.Column(scale=1):
@@ -221,7 +308,7 @@ def animation(gen, chunk_size, device):
221
  #video_output.render()
222
  video_output = gr.Video(label="Output Video", elem_id="output_vid", width=512)#.render()
223
 
224
- with gr.Accordion("Control Panel", open=True):
225
  with gr.Tab("Head"):
226
  with gr.Row():
227
  for k in labels_k[:3]:
@@ -251,20 +338,34 @@ def animation(gen, chunk_size, device):
251
  for k in labels_k[12:14]:
252
  slider = gr.Slider(minimum=-0.2, maximum=0.2, value=0, label=k, elem_id="slider_"+str(k))
253
  inputs_s.append(slider)
 
 
 
 
 
 
254
 
 
 
 
 
 
 
 
255
 
256
- edit_btn.click(
257
- fn=edit_media,
258
- inputs=[image_input] + inputs_s,
259
- outputs=[image_output],
260
- show_progress=True
261
- )
 
262
 
263
  animate_btn.click(
264
  fn=animate_media,
265
  inputs=[image_input, video_input] + inputs_s,
266
  outputs=[image_output, video_output],
267
- show_progress=True
268
  )
269
 
270
  clear_btn.click(
@@ -280,14 +381,14 @@ def animation(gen, chunk_size, device):
280
  ['./data/source/macron.png', './data/driving/driving1.mp4', 0.14,0,-0.26,-0.29,-0.11,0,-0.13,-0.18,0,0,0,0,-0.02,0.07],
281
  ['./data/source/portrait3.png', './data/driving/driving1.mp4', -0.03,0.21,-0.31,-0.12,-0.11,0,-0.05,-0.16,0,0,0,0,-0.02,0.07],
282
  ['./data/source/einstein.png','./data/driving/driving2.mp4',-0.31,0,0,0.16,0.08,0,-0.07,0,0.13,0,0,0,0,0],
283
- ['./data/source/portrait1.png', './data/driving/driving4.mp4', 0, 0, -0.17, -0.19, 0.25, 0, 0, -0.086,
284
  0.087, 0, 0, 0, 0, 0],
285
  ['./data/source/portrait2.png','./data/driving/driving8.mp4',0,0,-0.25,0,0,0,0,0,0,0.126,0,0,0,0],
286
 
287
  ],
288
- fn=animate_media,
289
  inputs=[image_input, video_input] + inputs_s,
290
- outputs=[image_output, video_output],
291
  )
292
 
293
 
 
36
  13, 24, 17, 26
37
  ]
38
 
39
+ @torch.compiler.allow_in_graph
40
  def load_image(img, size):
41
 
42
  img = Image.open(img).convert('RGB')
43
  w, h = img.size
44
  img = img.resize((size, size))
45
  img = np.asarray(img)
46
+ img = np.copy(img)
47
  img = np.transpose(img, (2, 0, 1)) # 3 x 256 x 256
48
 
49
  return img / 255.0, w, h
50
 
51
 
52
+ @torch.compiler.allow_in_graph
53
  def img_preprocessing(img_path, size):
54
+ img, w, h = load_image(img_path, size) # [0, 1]
55
  img = torch.from_numpy(img).unsqueeze(0).float() # [0, 1]
56
  imgs_norm = (img - 0.5) * 2.0 # [-1, 1]
57
 
58
  return imgs_norm, w, h
59
 
60
 
61
+ # Cache resize transforms so each output size is built once and reused
62
+ resize_transform_cache = {}
63
+
64
+ def get_resize_transform(size):
65
+ """Get cached resize transform - creates once, reuses many times"""
66
+ if size not in resize_transform_cache:
67
+ # Only create the transform if it doesn't exist in cache
68
+ resize_transform_cache[size] = torchvision.transforms.Resize(
69
+ size,
70
+ interpolation=torchvision.transforms.InterpolationMode.BILINEAR,
71
+ antialias=True
72
+ )
73
+ return resize_transform_cache[size]
74
 
75
+
76
+ def resize(img, size):
77
+ """Use cached resize transform"""
78
+ transform = get_resize_transform((size, size))
79
  return transform(img)
80
 
81
 
82
  def resize_back(img, w, h):
83
+ """Use cached resize transform for back operation"""
84
+ transform = get_resize_transform((h, w))
 
 
85
  return transform(img)
86
+
87
 
88
  def vid_preprocessing(vid_path, size):
89
  vid_dict = torchvision.io.read_video(vid_path, pts_unit='sec')
90
+ vid = vid_dict[0].permute(0, 3, 1, 2) # tchw
91
  fps = vid_dict[2]['video_fps']
92
  vid_norm = (vid / 255.0 - 0.5) * 2.0 # [-1, 1]
93
 
94
+ #vid_norm = torch.cat([
95
+ # resize(vid_norm[i:i+1, :, :, :], size).unsqueeze(1) for i in range(vid.size(0))
96
+ #], dim=1)
97
+ vid_norm = resize(vid_norm, size) # tchw
98
 
99
  return vid_norm, fps
100
 
101
 
102
  def img_denorm(img):
103
+ img = img.clamp(-1, 1)
104
  img = (img - img.min()) / (img.max() - img.min())
105
 
106
  return img
107
 
108
 
109
  def vid_denorm(vid):
110
+ vid = vid.clamp(-1, 1)
111
  vid = (vid - vid.min()) / (vid.max() - vid.min())
112
 
113
  return vid
 
115
 
116
  def img_postprocessing(image, w, h):
117
 
118
+ img = resize_back(image, w, h)
 
 
 
119
 
120
+ # Denormalize ON GPU (avoid early CPU transfer)
121
+ img = img.clamp(-1, 1) # Still on GPU
122
+ img = (img - img.min()) / (img.max() - img.min()) # Still on GPU
123
+
124
+ # Single optimized CPU transfer
125
+ img = img.squeeze(0).permute(1, 2, 0).contiguous() # contiguous() for fast transfer
126
+ img_output = (img.cpu().numpy() * 255).astype(np.uint8) # Single CPU transfer
127
+
128
+ # return the Numpy array directly, since Gradio supports it
129
+ return img_output
130
 
131
 
132
 
133
  def vid_postprocessing(video, w, h, fps):
134
+ # video: TCHW
135
+
136
+ t,c,_,_ = video.size()
137
+ vid = resize_back(video, w, h)
138
 
139
+ vid = vid.clamp(-1, 1)
140
+ vid = (vid - vid.min()) / (vid.max() - vid.min())
141
+
142
+ vid = rearrange(vid, "t c h w -> t h w c") # T H W C
143
+ vid_np = (vid.cpu().numpy() * 255).astype('uint8')
144
 
145
  with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as temp_file:
146
  imageio.mimwrite(temp_file.name, vid_np, fps=fps, codec='libx264', quality=8)
 
148
 
149
 
150
  def animation(gen, chunk_size, device):
151
+
152
+ @torch.compile
153
+ def compiled_enc_img(image_tensor, selected_s):
154
+ """Compiled version of just the model inference"""
155
+ return gen.enc_img(image_tensor, labels_v, selected_s)
156
+
157
+ @torch.compile
158
+ def compiled_dec_img(z_s2r, alpha_r2s, feat_rgb):
159
+ """Compiled version of just the model inference"""
160
+ return gen.dec_img(z_s2r, alpha_r2s, feat_rgb)
161
+
162
+ @torch.compile
163
+ def compiled_dec_vid(z_s2r, alpha_r2s, feat_rgb, img_start, img_target_batch):
164
+ """Compiled version of animate_batch for animation tab"""
165
+ return gen.dec_vid(z_s2r, alpha_r2s, feat_rgb, img_start, img_target_batch)
166
+
167
+ # Pre-warm the compiled model with dummy data to reduce first-run compilation time
168
+ def _warmup_model():
169
+ """Pre-warm the model compilation with representative shapes"""
170
+ print("[img_edit] Pre-warming model compilation...")
171
+ dummy_image = torch.randn(1, 3, 512, 512, device=device)
172
+ dummy_video = torch.randn(chunk_size, 3, 512, 512, device=device)
173
+ dummy_selected_s = [0.0] * len(labels_v)
174
+
175
+ try:
176
+ with torch.inference_mode():
177
+ z_s2r, alpha_r2s, feat_rgb = compiled_enc_img(dummy_image, dummy_selected_s)
178
+ _ = compiled_dec_img(z_s2r, alpha_r2s, feat_rgb)
179
+ print("[img_edit] Model pre-warming completed successfully")
180
+ except Exception as e:
181
+ print(f"[img_edit] Model pre-warming failed (will compile on first use): {e}")
182
+
183
+ try:
184
+ with torch.inference_mode():
185
+ z_s2r, alpha_r2s, feat_rgb = compiled_enc_img(dummy_image, dummy_selected_s)
186
+ _ = compiled_dec_vid(z_s2r, alpha_r2s, feat_rgb, dummy_video[0], dummy_video)
187
+ print("[img_animation] Model pre-warming completed successfully")
188
+ except Exception as e:
189
+ print(f"[img_animation] Model pre-warming failed (will compile on first use): {e}")
190
+
191
+ # Pre-warm the model
192
+ _warmup_model()
193
+
194
+
195
  @spaces.GPU
196
+ @torch.inference_mode()
197
  def edit_media(image, *selected_s):
198
 
199
  image_tensor, w, h = img_preprocessing(image, 512)
200
  image_tensor = image_tensor.to(device)
201
 
202
+ z_s2r, alpha_r2s, feat_rgb = compiled_enc_img(image_tensor, selected_s)
203
+ edited_image_tensor = compiled_dec_img(z_s2r, alpha_r2s, feat_rgb)
204
 
205
  # de-norm
206
  edited_image = img_postprocessing(edited_image_tensor, w, h)
 
208
  return edited_image
209
 
210
  @spaces.GPU
211
+ @torch.inference_mode()
212
  def animate_media(image, video, *selected_s):
213
 
214
  image_tensor, w, h = img_preprocessing(image, 512)
215
  vid_target_tensor, fps = vid_preprocessing(video, 512)
216
  image_tensor = image_tensor.to(device)
217
+ video_target_tensor = vid_target_tensor.to(device) #tchw
218
+
219
+ #animated_video = gen.animate_batch(image_tensor, video_target_tensor, labels_v, selected_s, chunk_size)
220
+ #edited_image = animated_video[:,:,0,:,:]
221
+
222
+ img_start = video_target_tensor[0:1,:,:,:]
223
+ #vid_target_tensor_batch = rearrange(video_target_tensor, 'b t c h w -> (b t) c h w')
224
+
225
+ res = []
226
+ t = video_target_tensor.size(0) # number of frames (tensor is TCHW)
227
+ chunks = t // chunk_size
228
+ z_s2r, alpha_r2s, feat_rgb = compiled_enc_img(image_tensor, selected_s)
229
+ #z_s2r, alpha_r2s, feat_rgb = gen.enc_img(image_tensor, labels_v, selected_s)
230
+ for i in range(chunks+1):
231
+ if i == chunks:
232
+ img_target = video_target_tensor[i*chunk_size:, :, :, :] # slice the on-device tensor
233
+ img_animated = compiled_dec_vid(z_s2r, alpha_r2s, feat_rgb, img_start, img_target)
234
+ #img_animated_batch = gen.dec_vid(z_s2r, alpha_r2s, feat_rgb, img_start, img_target_batch)
235
+ else:
236
+ img_target = video_target_tensor[i*chunk_size:(i+1)*chunk_size, :, :, :]
237
+ img_animated = compiled_dec_vid(z_s2r, alpha_r2s, feat_rgb, img_start, img_target)
238
+ #img_animated_batch = gen.dec_vid(z_s2r, alpha_r2s, feat_rgb, img_start, img_target_batch)
239
+
240
+ res.append(img_animated)
241
+ animated_video = torch.cat(res, dim=0) # TCHW
242
+ edited_image = animated_video[0:1,:,:,:]
243
 
244
  # postprocessing
245
  animated_video = vid_postprocessing(animated_video, w, h, fps)
 
250
  def clear_media():
251
  return None, None, *([0] * len(labels_k))
252
 
253
+
254
  with gr.Tab("Image Animation"):
255
 
256
  inputs_s = []
 
290
  with gr.Row():
291
  with gr.Column(scale=1):
292
  with gr.Row(): # Buttons now within a single Row
293
+ #edit_btn = gr.Button("Edit", elem_id="button_edit",)
 
 
294
  animate_btn = gr.Button("Animate", elem_id="button_animate")
295
+ with gr.Row():
296
+ clear_btn = gr.Button("Clear", elem_id="button_clear")
297
 
298
 
299
  with gr.Column(scale=1):
 
308
  #video_output.render()
309
  video_output = gr.Video(label="Output Video", elem_id="output_vid", width=512)#.render()
310
 
311
+ with gr.Accordion("Control Panel (Using Sliders to Edit Image)", open=True):
312
  with gr.Tab("Head"):
313
  with gr.Row():
314
  for k in labels_k[:3]:
 
338
  for k in labels_k[12:14]:
339
  slider = gr.Slider(minimum=-0.2, maximum=0.2, value=0, label=k, elem_id="slider_"+str(k))
340
  inputs_s.append(slider)
341
+
342
+ for slider in inputs_s:
343
+ slider.change(
344
+ fn=edit_media,
345
+ inputs=[image_input] + inputs_s,
346
+ outputs=[image_output],
347
 
348
+ show_progress='hidden',
349
+
350
+ trigger_mode='always_last',
351
+
352
+ # currently we have a latency around 450ms
353
+ stream_every=0.5
354
+ )
355
 
356
+
357
+ #edit_btn.click(
358
+ # fn=edit_media,
359
+ # inputs=[image_input] + inputs_s,
360
+ # outputs=[image_output],
361
+ # show_progress=True
362
+ #)
363
 
364
  animate_btn.click(
365
  fn=animate_media,
366
  inputs=[image_input, video_input] + inputs_s,
367
  outputs=[image_output, video_output],
368
+ show_progress=True
369
  )
370
 
371
  clear_btn.click(
 
381
  ['./data/source/macron.png', './data/driving/driving1.mp4', 0.14,0,-0.26,-0.29,-0.11,0,-0.13,-0.18,0,0,0,0,-0.02,0.07],
382
  ['./data/source/portrait3.png', './data/driving/driving1.mp4', -0.03,0.21,-0.31,-0.12,-0.11,0,-0.05,-0.16,0,0,0,0,-0.02,0.07],
383
  ['./data/source/einstein.png','./data/driving/driving2.mp4',-0.31,0,0,0.16,0.08,0,-0.07,0,0.13,0,0,0,0,0],
384
+ ['./data/source/portrait1.png', './data/driving/driving4.mp4', 0, 0, -0.17, -0.19, 0.25, 0, 0, -0.086,
385
  0.087, 0, 0, 0, 0, 0],
386
  ['./data/source/portrait2.png','./data/driving/driving8.mp4',0,0,-0.25,0,0,0,0,0,0,0.126,0,0,0,0],
387
 
388
  ],
389
+ fn=animate_media,
390
  inputs=[image_input, video_input] + inputs_s,
391
+ outputs=[image_output, video_output],
392
  )
393
 
394
 
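The animation path above follows one pattern: encode the edited source once, then decode the driving frames in fixed-size chunks. A minimal sketch of that loop, assuming TCHW tensors and the compiled_enc_img / compiled_dec_vid wrappers defined in this hunk:

    def animate_in_chunks(image_tensor, video_tensor, selected_s, chunk_size):
        # video_tensor: T x C x H x W driving frames on the same device as image_tensor
        img_start = video_tensor[0:1]                                             # first driving frame as the motion reference
        z_s2r, alpha_r2s, feat_rgb = compiled_enc_img(image_tensor, selected_s)   # encode the source and apply slider offsets once
        outputs = []
        for start in range(0, video_tensor.size(0), chunk_size):
            chunk = video_tensor[start:start + chunk_size]                        # at most chunk_size frames per decode call
            outputs.append(compiled_dec_vid(z_s2r, alpha_r2s, feat_rgb, img_start, chunk))
        return torch.cat(outputs, dim=0)                                          # T x C x H x W animated frames

Stepping by chunk_size also avoids dispatching an empty final chunk when the frame count is an exact multiple of chunk_size.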
gradio_tabs/img_edit.py CHANGED
@@ -37,69 +37,115 @@ labels_v = [
37
  ]
38
 
39
 
 
40
  def load_image(img, size):
41
  img = Image.open(img).convert('RGB')
42
  w, h = img.size
43
  img = img.resize((size, size))
44
  img = np.asarray(img)
 
45
  img = np.transpose(img, (2, 0, 1)) # 3 x 256 x 256
46
 
47
  return img / 255.0, w, h
48
 
49
 
 
50
  def img_preprocessing(img_path, size):
51
- img, w, h = load_image(img_path, size) # [0, 1]
52
  img = torch.from_numpy(img).unsqueeze(0).float() # [0, 1]
53
  imgs_norm = (img - 0.5) * 2.0 # [-1, 1]
54
 
55
  return imgs_norm, w, h
56
 
57
 
58
- def resize(img, size):
59
- transform = torchvision.transforms.Compose([
60
- torchvision.transforms.Resize((size,size), antialias=True),
61
- ])
 
 
 
 
 
 
 
 
 
62
 
 
 
 
 
63
  return transform(img)
64
 
65
 
66
  def resize_back(img, w, h):
67
- transform = torchvision.transforms.Compose([
68
- torchvision.transforms.Resize((h, w), antialias=True),
69
- ])
70
-
71
  return transform(img)
72
 
73
 
74
  def img_denorm(img):
75
- img = img.clamp(-1, 1).cpu()
76
  img = (img - img.min()) / (img.max() - img.min())
77
 
78
  return img
79
 
80
 
81
- def img_postprocessing(image, w, h):
82
 
83
- image = resize_back(image, w, h)
84
- image = image.permute(0, 2, 3, 1)
85
- edited_image = img_denorm(image)
86
- img_output = (edited_image[0].numpy() * 255).astype(np.uint8)
 
87
 
88
- with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_file:
89
- imageio.imwrite(temp_file.name, img_output, quality=8)
90
- return temp_file.name
 
91
 
92
 
93
  def img_edit(gen, device):
94
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
  @spaces.GPU
96
- @torch.no_grad()
97
  def edit_img(image, *selected_s):
98
 
99
  image_tensor, w, h = img_preprocessing(image, 512)
100
  image_tensor = image_tensor.to(device)
101
 
102
- edited_image_tensor = gen.edit_img(image_tensor, labels_v, selected_s)
 
103
 
104
  # de-norm
105
  edited_image = img_postprocessing(edited_image_tensor, w, h)
@@ -136,7 +182,7 @@ def img_edit(gen, device):
136
  with gr.Row():
137
  with gr.Column(scale=1):
138
  with gr.Row(): # Buttons now within a single Row
139
- edit_btn = gr.Button("Edit")
140
  clear_btn = gr.Button("Clear")
141
  #with gr.Row():
142
  # animate_btn = gr.Button("Generate")
@@ -150,7 +196,7 @@ def img_edit(gen, device):
150
  image_output = gr.Image(label="Output Image", type='numpy', interactive=False, width=512)
151
 
152
 
153
- with gr.Accordion("Control Panel", open=True):
154
  with gr.Tab("Head"):
155
  with gr.Row():
156
  for k in labels_k[:3]:
@@ -181,15 +227,29 @@ def img_edit(gen, device):
181
  slider = gr.Slider(minimum=-0.2, maximum=0.2, value=0, label=k)
182
  inputs_s.append(slider)
183
 
184
-
185
- edit_btn.click(
186
  fn=edit_img,
187
  inputs=[image_input] + inputs_s,
188
  outputs=[image_output],
189
- show_progress=True
190
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
191
 
192
  clear_btn.click(
193
  fn=clear_media,
194
  outputs=[image_output] + inputs_s
195
- )
 
37
  ]
38
 
39
 
40
+ @torch.compiler.allow_in_graph
41
  def load_image(img, size):
42
  img = Image.open(img).convert('RGB')
43
  w, h = img.size
44
  img = img.resize((size, size))
45
  img = np.asarray(img)
46
+ img = np.copy(img)
47
  img = np.transpose(img, (2, 0, 1)) # 3 x 256 x 256
48
 
49
  return img / 255.0, w, h
50
 
51
 
52
+ @torch.compiler.allow_in_graph
53
  def img_preprocessing(img_path, size):
54
+ img, w, h = load_image(img_path, size) # [0, 1]
55
  img = torch.from_numpy(img).unsqueeze(0).float() # [0, 1]
56
  imgs_norm = (img - 0.5) * 2.0 # [-1, 1]
57
 
58
  return imgs_norm, w, h
59
 
60
 
61
+ # Cache resize transforms so each output size is built once and reused
62
+ resize_transform_cache = {}
63
+
64
+ def get_resize_transform(size):
65
+ """Get cached resize transform - creates once, reuses many times"""
66
+ if size not in resize_transform_cache:
67
+ # Only create the transform if it doesn't exist in cache
68
+ resize_transform_cache[size] = torchvision.transforms.Resize(
69
+ size,
70
+ interpolation=torchvision.transforms.InterpolationMode.BILINEAR,
71
+ antialias=True
72
+ )
73
+ return resize_transform_cache[size]
74
 
75
+
76
+ def resize(img, size):
77
+ """Use cached resize transform"""
78
+ transform = get_resize_transform((size, size))
79
  return transform(img)
80
 
81
 
82
  def resize_back(img, w, h):
83
+ """Use cached resize transform for back operation"""
84
+ transform = get_resize_transform((h, w))
 
 
85
  return transform(img)
86
 
87
 
88
  def img_denorm(img):
89
+ img = img.clamp(-1, 1)
90
  img = (img - img.min()) / (img.max() - img.min())
91
 
92
  return img
93
 
94
 
95
+ def img_postprocessing(img, w, h):
96
 
97
+ img = resize_back(img, w, h)
98
+ #image = image.permute(0, 2, 3, 1)
99
+ img = img_denorm(img)
100
+ img = img.squeeze(0).permute(1, 2, 0).contiguous() # contiguous() for fast transfer
101
+ img_output = (img.cpu().numpy() * 255).astype(np.uint8)
102
 
103
+ #with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_file:
104
+ # imageio.imwrite(temp_file.name, img_output, quality=8)
105
+ # return temp_file.name
106
+ return img_output
107
 
108
 
109
  def img_edit(gen, device):
110
 
111
+ @torch.compile
112
+ def compiled_enc_img(image_tensor, selected_s):
113
+ """Compiled version of just the model inference"""
114
+ return gen.enc_img(image_tensor, labels_v, selected_s)
115
+
116
+ @torch.compile
117
+ def compiled_dec_img(z_s2r, alpha_r2s, feat_rgb):
118
+ """Compiled version of just the model inference"""
119
+ return gen.dec_img(z_s2r, alpha_r2s, feat_rgb)
120
+
121
+
122
+ # Pre-warm the compiled model with dummy data to reduce first-run compilation time
123
+ def _warmup_model():
124
+ """Pre-warm the model compilation with representative shapes"""
125
+ print("[img_edit] Pre-warming model compilation...")
126
+ dummy_image = torch.randn(1, 3, 512, 512, device=device)
127
+ dummy_selected_s = [0.0] * len(labels_v)
128
+
129
+ try:
130
+ with torch.inference_mode():
131
+ z_s2r, alpha_r2s, feat_rgb = compiled_enc_img(dummy_image, dummy_selected_s)
132
+ _ = compiled_dec_img(z_s2r, alpha_r2s, feat_rgb)
133
+ print("[img_edit] Model pre-warming completed successfully")
134
+ except Exception as e:
135
+ print(f"[img_edit] Model pre-warming failed (will compile on first use): {e}")
136
+
137
+ # Pre-warm the model
138
+ _warmup_model()
139
+
140
  @spaces.GPU
141
+ @torch.inference_mode()
142
  def edit_img(image, *selected_s):
143
 
144
  image_tensor, w, h = img_preprocessing(image, 512)
145
  image_tensor = image_tensor.to(device)
146
 
147
+ z_s2r, alpha_r2s, feat_rgb = compiled_enc_img(image_tensor, selected_s)
148
+ edited_image_tensor = compiled_dec_img(z_s2r, alpha_r2s, feat_rgb)
149
 
150
  # de-norm
151
  edited_image = img_postprocessing(edited_image_tensor, w, h)
 
182
  with gr.Row():
183
  with gr.Column(scale=1):
184
  with gr.Row(): # Buttons now within a single Row
185
+ #edit_btn = gr.Button("Edit")
186
  clear_btn = gr.Button("Clear")
187
  #with gr.Row():
188
  # animate_btn = gr.Button("Generate")
 
196
  image_output = gr.Image(label="Output Image", type='numpy', interactive=False, width=512)
197
 
198
 
199
+ with gr.Accordion("Control Panel (Using Sliders to Edit Image)", open=True):
200
  with gr.Tab("Head"):
201
  with gr.Row():
202
  for k in labels_k[:3]:
 
227
  slider = gr.Slider(minimum=-0.2, maximum=0.2, value=0, label=k)
228
  inputs_s.append(slider)
229
 
230
+ for slider in inputs_s:
231
+ slider.change(
232
  fn=edit_img,
233
  inputs=[image_input] + inputs_s,
234
  outputs=[image_output],
235
+
236
+ show_progress='hidden',
237
+
238
+ trigger_mode='always_last',
239
+
240
+ # currently we have a latency around 450ms
241
+ stream_every=0.5
242
+ )
243
+
244
+
245
+ #edit_btn.click(
246
+ # fn=edit_img,
247
+ # inputs=[image_input] + inputs_s,
248
+ # outputs=[image_output],
249
+ # show_progress=True
250
+ #)
251
 
252
  clear_btn.click(
253
  fn=clear_media,
254
  outputs=[image_output] + inputs_s
255
+ )
gradio_tabs/vid_edit.py CHANGED
@@ -37,92 +37,118 @@ labels_v = [
37
  ]
38
 
39
 
 
40
  def load_image(img, size):
41
- # img = Image.open(filename).convert('RGB')
42
- if not isinstance(img, np.ndarray):
43
- img = Image.open(img).convert('RGB')
44
- img = img.resize((size, size))
45
- img = np.asarray(img)
 
46
  img = np.transpose(img, (2, 0, 1)) # 3 x 256 x 256
47
 
48
- return img / 255.0
49
 
50
 
 
51
  def img_preprocessing(img_path, size):
52
- img = load_image(img_path, size) # [0, 1]
53
  img = torch.from_numpy(img).unsqueeze(0).float() # [0, 1]
54
  imgs_norm = (img - 0.5) * 2.0 # [-1, 1]
55
 
56
- return imgs_norm
57
 
 
 
58
 
59
- def resize(img, size):
60
- transform = torchvision.transforms.Compose([
61
- torchvision.transforms.Resize((size, size), antialias=True),
62
- ])
 
 
 
 
 
 
63
 
 
 
 
64
  return transform(img)
65
 
66
 
67
  def resize_back(img, w, h):
68
- transform = torchvision.transforms.Compose([
69
- torchvision.transforms.Resize((h, w), antialias=True),
70
- ])
71
-
72
  return transform(img)
73
-
74
 
75
  def vid_preprocessing(vid_path, size):
76
  vid_dict = torchvision.io.read_video(vid_path, pts_unit='sec')
77
- vid = vid_dict[0].permute(0, 3, 1, 2).unsqueeze(0) # btchw
78
  _,_,_,h,w = vid.size()
79
  fps = vid_dict[2]['video_fps']
80
  vid_norm = (vid / 255.0 - 0.5) * 2.0 # [-1, 1]
81
 
82
- vid_norm = torch.cat([
83
- resize(vid_norm[:, i, :, :, :], size).unsqueeze(1) for i in range(vid.size(1))
84
- ], dim=1)
85
 
86
  return vid_norm, fps, w, h
87
 
88
 
89
  def img_denorm(img):
90
- img = img.clamp(-1, 1).cpu()
91
  img = (img - img.min()) / (img.max() - img.min())
92
 
93
  return img
94
 
95
 
96
  def vid_denorm(vid):
97
- vid = vid.clamp(-1, 1).cpu()
98
  vid = (vid - vid.min()) / (vid.max() - vid.min())
99
 
100
  return vid
101
 
102
 
103
  def img_postprocessing(image, w, h):
104
- image = resize_back(image, w, h)
105
- image = image.permute(0, 2, 3, 1)
106
- edited_image = img_denorm(image)
107
- img_output = (edited_image[0].numpy() * 255).astype(np.uint8)
108
 
109
- with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_file:
110
- imageio.imwrite(temp_file.name, img_output, quality=8)
111
- return temp_file.name
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
112
 
113
 
114
  def vid_all_save(vid_d, vid_a, w, h, fps):
 
 
115
 
116
- b,t,c,_,_ = vid_d.size()
117
- vid_d_batch = resize_back(rearrange(vid_d, "b t c h w -> (b t) c h w"), w, h)
118
- vid_a_batch = resize_back(rearrange(vid_a, "b c t h w -> (b t) c h w"), w, h)
119
-
120
- vid_d = rearrange(vid_d_batch, "(b t) c h w -> b t h w c", b=b) # B T H W C
121
- vid_a = rearrange(vid_a_batch, "(b t) c h w -> b t h w c", b=b) # B T H W C
122
- vid_all = torch.cat([vid_d, vid_a], dim=3)
123
 
124
- vid_a_np = (vid_denorm(vid_a[0]).numpy() * 255).astype('uint8')
125
- vid_all_np = (vid_denorm(vid_all[0]).numpy() * 255).astype('uint8')
126
 
127
  with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as output_path:
128
  imageio.mimwrite(output_path.name, vid_a_np, fps=fps, codec='libx264', quality=8)
@@ -134,16 +160,59 @@ def vid_all_save(vid_d, vid_a, w, h, fps):
134
 
135
 
136
  def vid_edit(gen, chunk_size, device):
137
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
138
  @spaces.GPU
139
- @torch.no_grad()
140
  def edit_img(video, *selected_s):
141
 
142
- vid_target_tensor, fps, w, h = vid_preprocessing(video, 512)
143
- video_target_tensor = vid_target_tensor.to(device)
144
- image_tensor = video_target_tensor[:,0,:,:,:]
145
 
146
- edited_image_tensor = gen.edit_img(image_tensor, labels_v, selected_s)
 
147
 
148
  # de-norm
149
  edited_image = img_postprocessing(edited_image_tensor, w, h)
@@ -151,21 +220,35 @@ def vid_edit(gen, chunk_size, device):
151
  return edited_image
152
 
153
  @spaces.GPU
154
- @torch.no_grad()
155
  def edit_vid(video, *selected_s):
156
 
157
  video_target_tensor, fps, w, h = vid_preprocessing(video, 512)
158
  video_target_tensor = video_target_tensor.to(device)
159
 
160
- edited_video_tensor = gen.edit_vid_batch(video_target_tensor, labels_v, selected_s, chunk_size)
161
- edited_image_tensor = edited_video_tensor[:,:,0,:,:]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
162
 
163
  # de-norm
164
- animated_video, animated_all_video = vid_all_save(video_target_tensor, edited_video_tensor, w, h, fps)
165
  edited_image = img_postprocessing(edited_image_tensor, w, h)
166
 
167
- return edited_image, animated_video, animated_all_video
168
-
169
 
170
  def clear_media():
171
  return None, None, None, *([0] * len(labels_k))
@@ -210,7 +293,7 @@ def vid_edit(gen, chunk_size, device):
210
  video_all_output = gr.Video(label="Videos", elem_id="output_vid_all")
211
 
212
  with gr.Column(scale=1):
213
- with gr.Accordion("Control Panel", open=True):
214
  with gr.Tab("Head"):
215
  with gr.Row():
216
  for k in labels_k[:3]:
@@ -244,17 +327,27 @@ def vid_edit(gen, chunk_size, device):
244
  with gr.Row():
245
  with gr.Column(scale=1):
246
  with gr.Row(): # Buttons now within a single Row
247
- edit_btn = gr.Button("Edit",elem_id="button_edit")
248
- clear_btn = gr.Button("Clear",elem_id="button_clear")
249
- with gr.Row():
250
  animate_btn = gr.Button("Generate",elem_id="button_generate")
251
-
252
- edit_btn.click(
253
- fn=edit_img,
254
- inputs=[video_input] + inputs_s,
255
- outputs=[image_output],
256
- show_progress=True
257
- )
 
 
 
 
 
 
 
 
 
 
 
 
258
 
259
  animate_btn.click(
260
  fn=edit_vid,
@@ -280,9 +373,9 @@ def vid_edit(gen, chunk_size, device):
280
  ['./data/driving/driving9.mp4', 0, 0, 0, 0, 0, 0, 0,
281
  0, 0, 0, 0, 0, -0.1, 0.07],
282
  ],
283
- fn=edit_vid,
284
  inputs=[video_input] + inputs_s,
285
- outputs=[image_output, video_output, video_all_output],
286
  )
287
 
288
 
 
37
  ]
38
 
39
 
40
+ @torch.compiler.allow_in_graph
41
  def load_image(img, size):
42
+
43
+ img = Image.open(img).convert('RGB')
44
+ w, h = img.size
45
+ img = img.resize((size, size))
46
+ img = np.asarray(img)
47
+ img = np.copy(img)
48
  img = np.transpose(img, (2, 0, 1)) # 3 x 256 x 256
49
 
50
+ return img / 255.0, w, h
51
 
52
 
53
+ @torch.compiler.allow_in_graph
54
  def img_preprocessing(img_path, size):
55
+ img, w, h = load_image(img_path, size) # [0, 1]
56
  img = torch.from_numpy(img).unsqueeze(0).float() # [0, 1]
57
  imgs_norm = (img - 0.5) * 2.0 # [-1, 1]
58
 
59
+ return imgs_norm, w, h
60
 
61
+ # Cache resize transforms so each output size is built once and reused
62
+ resize_transform_cache = {}
63
 
64
+ def get_resize_transform(size):
65
+ """Get cached resize transform - creates once, reuses many times"""
66
+ if size not in resize_transform_cache:
67
+ # Only create the transform if it doesn't exist in cache
68
+ resize_transform_cache[size] = torchvision.transforms.Resize(
69
+ size,
70
+ interpolation=torchvision.transforms.InterpolationMode.BILINEAR,
71
+ antialias=True
72
+ )
73
+ return resize_transform_cache[size]
74
 
75
+ def resize(img, size):
76
+ """Use cached resize transform"""
77
+ transform = get_resize_transform((size, size))
78
  return transform(img)
79
 
80
 
81
  def resize_back(img, w, h):
82
+ """Use cached resize transform for back operation"""
83
+ transform = get_resize_transform((h, w))
 
 
84
  return transform(img)
85
+
86
 
87
  def vid_preprocessing(vid_path, size):
88
  vid_dict = torchvision.io.read_video(vid_path, pts_unit='sec')
89
+ vid = vid_dict[0].permute(0, 3, 1, 2) # tchw
90
_, _, h, w = vid.size() # vid is TCHW after the permute above
91
  fps = vid_dict[2]['video_fps']
92
  vid_norm = (vid / 255.0 - 0.5) * 2.0 # [-1, 1]
93
 
94
+ vid_norm = resize(vid_norm, size)
 
 
95
 
96
  return vid_norm, fps, w, h
97
 
98
 
99
  def img_denorm(img):
100
+ img = img.clamp(-1, 1)
101
  img = (img - img.min()) / (img.max() - img.min())
102
 
103
  return img
104
 
105
 
106
  def vid_denorm(vid):
107
+ vid = vid.clamp(-1, 1)
108
  vid = (vid - vid.min()) / (vid.max() - vid.min())
109
 
110
  return vid
111
 
112
 
113
  def img_postprocessing(image, w, h):
 
 
 
 
114
 
115
+ img = resize_back(image, w, h)
116
+
117
+ # Denormalize ON GPU (avoid early CPU transfer)
118
+ img = img_denorm(img)
119
+
120
+ # Single optimized CPU transfer
121
+ img = img.squeeze(0).permute(1, 2, 0).contiguous() # contiguous() for fast transfer
122
+ img_output = (img.cpu().numpy() * 255).astype(np.uint8) # Single CPU transfer
123
+
124
+ # return the Numpy array directly, since Gradio supports it
125
+ return img_output
126
+
127
+
128
+ def process_first_frame(vid_path, size):
129
+ vid_dict = torchvision.io.read_video(vid_path, start_pts=0, end_pts=0, pts_unit='sec')
130
+ img = vid_dict[0].permute(0, 3, 1, 2) # bchw
131
+ _, _, h, w = img.size()
132
+ img_norm = (img / 255.0 - 0.5) * 2.0 # [-1, 1]
133
+ img_norm = resize(img_norm, size)
134
+
135
+ return img_norm, w, h
136
 
137
 
138
  def vid_all_save(vid_d, vid_a, w, h, fps):
139
+ # vid_d: tchw
140
+ # vid_a: tchw
141
 
142
+ t, c, _, _ = vid_d.size()
143
+ vid_d_batch = resize_back(vid_d, w, h)
144
+ vid_a_batch = resize_back(vid_a, w, h)
145
+
146
+ vid_d = rearrange(vid_d_batch, "t c h w -> t h w c") # T H W C
147
+ vid_a = rearrange(vid_a_batch, "t c h w -> t h w c") # T H W C
148
+ vid_all = torch.cat([vid_d, vid_a], dim=2)
149
 
150
+ vid_a_np = (vid_denorm(vid_a).cpu().numpy() * 255).astype('uint8')
151
+ vid_all_np = (vid_denorm(vid_all).cpu().numpy() * 255).astype('uint8')
152
 
153
  with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as output_path:
154
  imageio.mimwrite(output_path.name, vid_a_np, fps=fps, codec='libx264', quality=8)
 
160
 
161
 
162
  def vid_edit(gen, chunk_size, device):
163
+
164
+ @torch.compile
165
+ def compiled_enc_img(image_tensor, selected_s):
166
+ """Compiled version of just the model inference"""
167
+ return gen.enc_img(image_tensor, labels_v, selected_s)
168
+
169
+ @torch.compile
170
+ def compiled_dec_img(z_s2r, alpha_r2s, feat_rgb):
171
+ """Compiled version of just the model inference"""
172
+ return gen.dec_img(z_s2r, alpha_r2s, feat_rgb)
173
+
174
+ @torch.compile
175
+ def compiled_dec_vid(z_s2r, alpha_r2s, feat_rgb, img_start, img_target_batch):
176
+ """Compiled version of animate_batch for animation tab"""
177
+ return gen.dec_vid(z_s2r, alpha_r2s, feat_rgb, img_start, img_target_batch)
178
+
179
+ # Pre-warm the compiled model with dummy data to reduce first-run compilation time
180
+ def _warmup_model():
181
+ """Pre-warm the model compilation with representative shapes"""
182
+ print("[img_edit] Pre-warming model compilation...")
183
+ dummy_image = torch.randn(1, 3, 512, 512, device=device)
184
+ dummy_video = torch.randn(chunk_size, 3, 512, 512, device=device)
185
+ dummy_selected_s = [0.0] * len(labels_v)
186
+
187
+ try:
188
+ with torch.inference_mode():
189
+ z_s2r, alpha_r2s, feat_rgb = compiled_enc_img(dummy_image, dummy_selected_s)
190
+ _ = compiled_dec_img(z_s2r, alpha_r2s, feat_rgb)
191
+ print("[img_edit] Model pre-warming completed successfully")
192
+ except Exception as e:
193
+ print(f"[img_edit] Model pre-warming failed (will compile on first use): {e}")
194
+
195
+ try:
196
+ with torch.inference_mode():
197
+ z_s2r, alpha_r2s, feat_rgb = compiled_enc_img(dummy_image, dummy_selected_s)
198
+ _ = compiled_dec_vid(z_s2r, alpha_r2s, feat_rgb, dummy_video[0], dummy_video)
199
+ print("[img_animation] Model pre-warming completed successfully")
200
+ except Exception as e:
201
+ print(f"[img_animation] Model pre-warming failed (will compile on first use): {e}")
202
+
203
+ # Pre-warm the model
204
+ _warmup_model()
205
+
206
+
207
  @spaces.GPU
208
+ @torch.inference_mode()
209
  def edit_img(video, *selected_s):
210
 
211
+ image_tensor, w, h = process_first_frame(video, 512)
212
+ image_tensor = image_tensor.to(device)
 
213
 
214
+ z_s2r, alpha_r2s, feat_rgb = compiled_enc_img(image_tensor, selected_s)
215
+ edited_image_tensor = compiled_dec_img(z_s2r, alpha_r2s, feat_rgb)
216
 
217
  # de-norm
218
  edited_image = img_postprocessing(edited_image_tensor, w, h)
 
220
  return edited_image
221
 
222
  @spaces.GPU
223
+ @torch.inference_mode()
224
  def edit_vid(video, *selected_s):
225
 
226
  video_target_tensor, fps, w, h = vid_preprocessing(video, 512)
227
  video_target_tensor = video_target_tensor.to(device)
228
 
229
+ img_start = video_target_tensor[0:1, :, :, :]
230
+
231
+ res = []
232
+ t = video_target_tensor.size(0) # number of frames (tensor is TCHW)
233
+ chunks = t // chunk_size
234
+ z_s2r, alpha_r2s, feat_rgb = compiled_enc_img(img_start, selected_s)
235
+ for i in range(chunks + 1):
236
+ if i == chunks:
237
+ img_target_batch = video_target_tensor[i * chunk_size:, :, :, :]
238
+ img_animated_batch = compiled_dec_vid(z_s2r, alpha_r2s, feat_rgb, img_start, img_target_batch)
239
+ else:
240
+ img_target_batch = video_target_tensor[i * chunk_size:(i + 1) * chunk_size, :, :, :]
241
+ img_animated_batch = compiled_dec_vid(z_s2r, alpha_r2s, feat_rgb, img_start, img_target_batch)
242
+
243
+ res.append(img_animated_batch)
244
+ edited_video_tensor = torch.cat(res, dim=0) # TCHW
245
+ edited_image_tensor = edited_video_tensor[0:1,:,:,:]
246
 
247
  # de-norm
248
+ animated_video, animated_all_video = vid_all_save(video_target_tensor, edited_video_tensor, w, h, fps)
249
  edited_image = img_postprocessing(edited_image_tensor, w, h)
250
 
251
+ return edited_image, animated_video, animated_all_video
 
252
 
253
  def clear_media():
254
  return None, None, None, *([0] * len(labels_k))
 
293
  video_all_output = gr.Video(label="Videos", elem_id="output_vid_all")
294
 
295
  with gr.Column(scale=1):
296
+ with gr.Accordion("Control Panel (Using Sliders to Edit Image)", open=True):
297
  with gr.Tab("Head"):
298
  with gr.Row():
299
  for k in labels_k[:3]:
 
327
  with gr.Row():
328
  with gr.Column(scale=1):
329
  with gr.Row(): # Buttons now within a single Row
330
+ #edit_btn = gr.Button("Edit",elem_id="button_edit")
 
 
331
  animate_btn = gr.Button("Generate",elem_id="button_generate")
332
+ clear_btn = gr.Button("Clear",elem_id="button_clear")
333
+
334
+ for slider in inputs_s:
335
+ slider.change(
336
+ fn=edit_img,
337
+ inputs=[video_input] + inputs_s,
338
+ outputs=[image_output],
339
+ show_progress='hidden',
340
+ trigger_mode='always_last',
341
+ # currently we have a latency around 450ms
342
+ stream_every=0.5
343
+ )
344
+
345
+ #edit_btn.click(
346
+ # fn=edit_img,
347
+ # inputs=[video_input] + inputs_s,
348
+ # outputs=[image_output],
349
+ # show_progress=True
350
+ #)
351
 
352
  animate_btn.click(
353
  fn=edit_vid,
 
373
  ['./data/driving/driving9.mp4', 0, 0, 0, 0, 0, 0, 0,
374
  0, 0, 0, 0, 0, -0.1, 0.07],
375
  ],
376
+ fn=edit_vid,
377
  inputs=[video_input] + inputs_s,
378
+ outputs=[image_output, video_output, video_all_output],
379
  )
380
 
381
 
networks/generator.py CHANGED
@@ -17,6 +17,12 @@ class Generator(nn.Module):
17
  self.enc = Encoder(style_dim, motion_dim, scale)
18
  self.dec = Decoder(style_dim, motion_dim, scale)
19
 
 
 
 
 
 
 
20
  def get_alpha(self, x):
21
  return self.enc.enc_motion(x)
22
 
@@ -83,7 +89,7 @@ class Generator(nn.Module):
83
  vid_target_recon = rearrange(vid_target_recon, 'b t c h w -> b c t h w')
84
 
85
  return vid_target_recon # BCTHW
86
-
87
  def edit_vid(self, vid_target, d_l, v_l):
88
 
89
  img_source = vid_target[:, 0, :, :, :]
@@ -195,3 +201,36 @@ class Generator(nn.Module):
195
 
196
  return vid_target_recon
197
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  self.enc = Encoder(style_dim, motion_dim, scale)
18
  self.dec = Decoder(style_dim, motion_dim, scale)
19
 
20
+ @property
21
+ def device(self):
22
+ if getattr(self, '_device', None) is None: # _device is not set in __init__, so guard the first access
23
+ self._device = next(self.parameters()).device
24
+ return self._device
25
+
26
  def get_alpha(self, x):
27
  return self.enc.enc_motion(x)
28
 
 
89
  vid_target_recon = rearrange(vid_target_recon, 'b t c h w -> b c t h w')
90
 
91
  return vid_target_recon # BCTHW
92
+
93
  def edit_vid(self, vid_target, d_l, v_l):
94
 
95
  img_source = vid_target[:, 0, :, :, :]
 
201
 
202
  return vid_target_recon
203
 
204
+ def enc_img(self, img_source, d_l, v_l):
205
+ """Core edit_img logic without timing - can be compiled"""
206
+ z_s2r, feat_rgb = self.enc.enc_2r(img_source)
207
+ alpha_r2s = self.enc.enc_r2t(z_s2r)
208
+
209
+ # Create tensor directly on the same device as alpha_r2s
210
+ v_l_tensor = torch.tensor(v_l, device=alpha_r2s.device, dtype=alpha_r2s.dtype).unsqueeze(0)
211
+ alpha_r2s[:, d_l] = alpha_r2s[:, d_l] + v_l_tensor
212
+
213
+ return z_s2r, alpha_r2s, feat_rgb
214
+
215
+ def dec_img(self, z_s2r, alpha_r2s, feat_rgb):
216
+ return self.dec(z_s2r, [alpha_r2s], feat_rgb)
217
+
218
+
219
+ def dec_vid(self, z_s2r, alpha_r2s, feat_rgb, img_start, img_target_batch):
220
+ # z_s2r: BC
221
+ # alpha_r2s: BC
222
+ # feat: BCHW
223
+ # alpha_start: BC
224
+
225
+ bs = img_target_batch.size(0)
226
+ alpha_start = self.get_alpha(img_start)
227
+
228
+ alpha_start_r = repeat(alpha_start, 'b c -> (repeat b) c', repeat=bs)
229
+ alpha_r2s_r = repeat(alpha_r2s, 'b c -> (repeat b) c', repeat=bs)
230
+ feat_rgb_r = [repeat(feat, 'b c h w -> (repeat b) c h w', repeat=bs) for feat in feat_rgb]
231
+ z_s2r_r = repeat(z_s2r, 'b c -> (repeat b) c', repeat=bs)
232
+
233
+ alpha = self.enc.enc_transfer_vid(alpha_r2s_r, img_target_batch, alpha_start_r)
234
+ img_batch_recon = self.dec(z_s2r_r, alpha, feat_rgb_r) # bs x 3 x h x w
235
+
236
+ return img_batch_recon
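For reference, the point of splitting edit_img into enc_img / dec_img (and dec_vid for video) is that each stage can be wrapped in torch.compile separately; the consuming pattern in the Gradio tabs above is roughly this sketch (gen, labels_v, image_tensor and selected_s follow the definitions in those tabs):

    import torch

    @torch.compile
    def compiled_enc_img(image_tensor, selected_s):
        return gen.enc_img(image_tensor, labels_v, selected_s)   # encode the source and add the slider offsets to alpha

    @torch.compile
    def compiled_dec_img(z_s2r, alpha_r2s, feat_rgb):
        return gen.dec_img(z_s2r, alpha_r2s, feat_rgb)           # decode the edited latent back to an image

    with torch.inference_mode():
        z_s2r, alpha_r2s, feat_rgb = compiled_enc_img(image_tensor, selected_s)
        edited_image_tensor = compiled_dec_img(z_s2r, alpha_r2s, feat_rgb)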