YaohuiW committed
Commit 62c64ac · verified · 1 Parent(s): 11362ae

Update gradio_tabs/animation.py

Files changed (1): gradio_tabs/animation.py (+56, -137)
gradio_tabs/animation.py CHANGED
@@ -36,74 +36,64 @@ labels_v = [
     13, 24, 17, 26
 ]
 
-@torch.compiler.allow_in_graph
+
 def load_image(img, size):
 
     img = Image.open(img).convert('RGB')
     w, h = img.size
     img = img.resize((size, size))
     img = np.asarray(img)
-    img = np.copy(img)
     img = np.transpose(img, (2, 0, 1)) # 3 x 256 x 256
 
     return img / 255.0, w, h
 
 
-@torch.compiler.allow_in_graph
 def img_preprocessing(img_path, size):
-    img, w, h = load_image(img_path, size) # [0, 1]
+    img, w, h = load_image(img_path, size) # [0, 1]
     img = torch.from_numpy(img).unsqueeze(0).float() # [0, 1]
     imgs_norm = (img - 0.5) * 2.0 # [-1, 1]
 
     return imgs_norm, w, h
 
 
-# Pre-compile resize transforms for better performance
-resize_transform_cache = {}
-
-def get_resize_transform(size):
-    """Get cached resize transform - creates once, reuses many times"""
-    if size not in resize_transform_cache:
-        # Only create the transform if it doesn't exist in cache
-        resize_transform_cache[size] = torchvision.transforms.Resize(
-            size,
-            interpolation=torchvision.transforms.InterpolationMode.BILINEAR,
-            antialias=True
-        )
-    return resize_transform_cache[size]
-
-
 def resize(img, size):
-    """Use cached resize transform"""
-    transform = get_resize_transform((size, size))
+    transform = torchvision.transforms.Compose([
+        torchvision.transforms.Resize((size, size), antialias=True),
+    ])
+
     return transform(img)
 
 
 def resize_back(img, w, h):
-    """Use cached resize transform for back operation"""
-    transform = get_resize_transform((h, w))
+    transform = torchvision.transforms.Compose([
+        torchvision.transforms.Resize((h, w), antialias=True),
+    ])
+
     return transform(img)
-
+
 
 def vid_preprocessing(vid_path, size):
     vid_dict = torchvision.io.read_video(vid_path, pts_unit='sec')
-    vid = vid_dict[0].permute(0, 3, 1, 2) # tchw
+    vid = vid_dict[0].permute(0, 3, 1, 2).unsqueeze(0) # btchw
     fps = vid_dict[2]['video_fps']
     vid_norm = (vid / 255.0 - 0.5) * 2.0 # [-1, 1]
-    vid_norm = resize(vid_norm, size) # tchw
+
+    vid_norm = torch.cat([
+        resize(vid_norm[:, i, :, :, :], size).unsqueeze(1) for i in range(vid.size(1))
+    ], dim=1)
 
     return vid_norm, fps
 
 
 def img_denorm(img):
-    img = img.clamp(-1, 1)
+    img = img.clamp(-1, 1).cpu()
     img = (img - img.min()) / (img.max() - img.min())
 
     return img
 
 
 def vid_denorm(vid):
-    vid = vid.clamp(-1, 1)
+    vid = vid.clamp(-1, 1).cpu()
     vid = (vid - vid.min()) / (vid.max() - vid.min())
 
     return vid
@@ -111,30 +101,24 @@ def vid_denorm(vid):
 
 def img_postprocessing(image, w, h):
 
-    img = resize_back(image, w, h)
-
-    # Denormalize ON GPU (avoid early CPU transfer)
-    img = img.clamp(-1, 1) # Still on GPU
-    img = (img - img.min()) / (img.max() - img.min()) # Still on GPU
+    image = resize_back(image, w, h)
+    image = image.permute(0, 2, 3, 1)
+    edited_image = img_denorm(image)
+    img_output = (edited_image[0].numpy() * 255).astype(np.uint8)
 
-    # Single optimized CPU transfer
-    img = img.squeeze(0).permute(1, 2, 0).contiguous() # contiguous() for fast transfer
-    img_output = (img.cpu().numpy() * 255).astype(np.uint8) # Single CPU transfer
-
-    # return the Numpy array directly, since Gradio supports it
-    return img_output
+    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_file:
+        imageio.imwrite(temp_file.name, img_output, quality=8)
+        return temp_file.name
 
 
 
 def vid_postprocessing(video, w, h, fps):
-    # video: TCHW
+    # video: BCTHW
 
-    t,c,_,_ = video.size()
-    vid = resize_back(video, w, h)
-    vid = vid_denorm(vid)
-
-    vid = rearrange(vid, "t c h w -> t h w c") # T H W C
-    vid_np = (vid.cpu().numpy() * 255).astype('uint8')
+    b,c,t,_,_ = video.size()
+    vid_batch = resize_back(rearrange(video, "b c t h w -> (b t) c h w"), w, h)
+    vid = rearrange(vid_batch, "(b t) c h w -> b t h w c", b=b) # B T H W C
+    vid_np = (vid_denorm(vid[0]).numpy() * 255).astype('uint8')
 
     with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as temp_file:
         imageio.mimwrite(temp_file.name, vid_np, fps=fps, codec='libx264', quality=8)
@@ -142,59 +126,15 @@ def vid_postprocessing(video, w, h, fps):
 
 
 def animation(gen, chunk_size, device):
-
-    @torch.compile
-    def compiled_enc_img(image_tensor, selected_s):
-        """Compiled version of just the model inference"""
-        return gen.enc_img(image_tensor, labels_v, selected_s)
-
-    @torch.compile
-    def compiled_dec_img(z_s2r, alpha_r2s, feat_rgb):
-        """Compiled version of just the model inference"""
-        return gen.dec_img(z_s2r, alpha_r2s, feat_rgb)
-
-    @torch.compile
-    def compiled_dec_vid(z_s2r, alpha_r2s, feat_rgb, img_start, img_target_batch):
-        """Compiled version of animate_batch for animation tab"""
-        return gen.dec_vid(z_s2r, alpha_r2s, feat_rgb, img_start, img_target_batch)
-
-    # Pre-warm the compiled model with dummy data to reduce first-run compilation time
-    def _warmup_model():
-        """Pre-warm the model compilation with representative shapes"""
-        print("[img_edit] Pre-warming model compilation...")
-        dummy_image = torch.randn(1, 3, 512, 512, device=device)
-        dummy_video = torch.randn(chunk_size, 3, 512, 512, device=device)
-        dummy_selected_s = [0.0] * len(labels_v)
-
-        try:
-            with torch.inference_mode():
-                z_s2r, alpha_r2s, feat_rgb = compiled_enc_img(dummy_image, dummy_selected_s)
-                _ = compiled_dec_img(z_s2r, alpha_r2s, feat_rgb)
-            print("[img_edit] Model pre-warming completed successfully")
-        except Exception as e:
-            print(f"[img_edit] Model pre-warming failed (will compile on first use): {e}")
-
-        try:
-            with torch.inference_mode():
-                z_s2r, alpha_r2s, feat_rgb = compiled_enc_img(dummy_image, dummy_selected_s)
-                _ = compiled_dec_vid(z_s2r, alpha_r2s, feat_rgb, dummy_video[0], dummy_video)
-            print("[img_animation] Model pre-warming completed successfully")
-        except Exception as e:
-            print(f"[img_animation] Model pre-warming failed (will compile on first use): {e}")
-
-    # Pre-warm the model
-    _warmup_model()
-
-
+
     @spaces.GPU
-    @torch.inference_mode()
+    @torch.no_grad()
     def edit_media(image, *selected_s):
 
         image_tensor, w, h = img_preprocessing(image, 512)
         image_tensor = image_tensor.to(device)
 
-        z_s2r, alpha_r2s, feat_rgb = compiled_enc_img(image_tensor, selected_s)
-        edited_image_tensor = compiled_dec_img(z_s2r, alpha_r2s, feat_rgb)
+        edited_image_tensor = gen.edit_img(image_tensor, labels_v, selected_s)
 
         # de-norm
         edited_image = img_postprocessing(edited_image_tensor, w, h)
@@ -202,35 +142,16 @@ def animation(gen, chunk_size, device):
         return edited_image
 
     @spaces.GPU
-    @torch.inference_mode()
+    @torch.no_grad()
     def animate_media(image, video, *selected_s):
 
         image_tensor, w, h = img_preprocessing(image, 512)
         vid_target_tensor, fps = vid_preprocessing(video, 512)
         image_tensor = image_tensor.to(device)
-        video_target_tensor = vid_target_tensor.to(device) #tchw
-
-        img_start = video_target_tensor[0:1,:,:,:]
-
-        res = []
-        t, c, h, w = video_target_tensor.size()
-
-        chunks = t // chunk_size
-        if t%chunk_size == 0:
-            vid_target_tensor_batch = torch.zeros(chunk_size * chunks, c, h, w).to(device)
-        else:
-            vid_target_tensor_batch = torch.zeros(chunk_size * (chunks + 1), c, h, w).to(device)
-        vid_target_tensor_batch[:t] = video_target_tensor
-
-        z_s2r, alpha_r2s, feat_rgb = compiled_enc_img(image_tensor, selected_s)
-        for i in range(chunks+1):
-
-            img_target_batch = vid_target_tensor_batch[i * chunk_size:(i + 1) * chunk_size, :, :, :]
-            img_animated_batch = compiled_dec_vid(z_s2r, alpha_r2s, feat_rgb, img_start, img_target_batch)
-
-            res.append(img_animated_batch)
-        animated_video = torch.cat(res, dim=0)[:t] # TCHW
-        edited_image = animated_video[0:1,:,:,:]
+        video_target_tensor = vid_target_tensor.to(device)
+
+        animated_video = gen.animate_batch(image_tensor, video_target_tensor, labels_v, selected_s, chunk_size)
+        edited_image = animated_video[:,:,0,:,:]
 
         # postprocessing
         animated_video = vid_postprocessing(animated_video, w, h, fps)
@@ -241,7 +162,7 @@ def animation(gen, chunk_size, device):
     def clear_media():
         return None, None, *([0] * len(labels_k))
 
-
+
     with gr.Tab("Image Animation"):
 
         inputs_s = []
@@ -281,10 +202,11 @@ def animation(gen, chunk_size, device):
         with gr.Row():
             with gr.Column(scale=1):
                 with gr.Row(): # Buttons now within a single Row
-                    #edit_btn = gr.Button("Edit", elem_id="button_edit",)
-                    animate_btn = gr.Button("Animate", elem_id="button_animate")
-                with gr.Row():
+                    edit_btn = gr.Button("Edit", elem_id="button_edit",)
                     clear_btn = gr.Button("Clear", elem_id="button_clear")
+                with gr.Row():
+                    animate_btn = gr.Button("Animate", elem_id="button_animate")
+
 
 
             with gr.Column(scale=1):
@@ -299,7 +221,7 @@ def animation(gen, chunk_size, device):
                 #video_output.render()
                 video_output = gr.Video(label="Output Video", elem_id="output_vid", width=512)#.render()
 
-        with gr.Accordion("Control Panel - Using Sliders to Edit Image", open=True):
+        with gr.Accordion("Control Panel", open=True):
             with gr.Tab("Head"):
                 with gr.Row():
                     for k in labels_k[:3]:
@@ -329,23 +251,20 @@ def animation(gen, chunk_size, device):
                     for k in labels_k[12:14]:
                         slider = gr.Slider(minimum=-0.2, maximum=0.2, value=0, label=k, elem_id="slider_"+str(k))
                         inputs_s.append(slider)
-
-        for slider in inputs_s:
-            slider.change(
-                fn=edit_media,
-                inputs=[image_input] + inputs_s,
-                outputs=[image_output],
-                show_progress='hidden',
-                trigger_mode='always_last',
-                # currently we have a latency around 450ms
-                stream_every=0.5
-            )
+
+
+        edit_btn.click(
+            fn=edit_media,
+            inputs=[image_input] + inputs_s,
+            outputs=[image_output],
+            show_progress=True
+        )
 
         animate_btn.click(
             fn=animate_media,
             inputs=[image_input, video_input] + inputs_s,
             outputs=[image_output, video_output],
-            show_progress=True
+            show_progress=True
         )
 
         clear_btn.click(
@@ -361,14 +280,14 @@ def animation(gen, chunk_size, device):
                 ['./data/source/macron.png', './data/driving/driving1.mp4', 0.14,0,-0.26,-0.29,-0.11,0,-0.13,-0.18,0,0,0,0,-0.02,0.07],
                 ['./data/source/portrait3.png', './data/driving/driving1.mp4', -0.03,0.21,-0.31,-0.12,-0.11,0,-0.05,-0.16,0,0,0,0,-0.02,0.07],
                 ['./data/source/einstein.png','./data/driving/driving2.mp4',-0.31,0,0,0.16,0.08,0,-0.07,0,0.13,0,0,0,0,0],
-                ['./data/source/portrait1.png', './data/driving/driving4.mp4', 0, 0, -0.17, -0.19, 0.25, 0, 0, -0.086,
+                ['./data/source/portrait1.png', './data/driving/driving4.mp4', 0, 0, -0.17, -0.19, 0.25, 0, 0, -0.086,
                  0.087, 0, 0, 0, 0, 0],
                 ['./data/source/portrait2.png','./data/driving/driving8.mp4',0,0,-0.25,0,0,0,0,0,0,0.126,0,0,0,0],
 
             ],
-            fn=animate_media,
+            fn=animate_media,
             inputs=[image_input, video_input] + inputs_s,
-            outputs=[image_output, video_output],
+            outputs=[image_output, video_output],
         )
 
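For orientation, here is a minimal standalone sketch of the two tensor-reshaping patterns the updated vid_preprocessing and vid_postprocessing rely on. It is illustrative only, not part of the commit: the shapes are hypothetical (an 8-frame 256x256 clip stands in for a real video), and only plain torch, torchvision, and einops calls are used.

import torch
import torchvision
from einops import rearrange

# Pattern 1 (as in vid_preprocessing): resize a B T C H W batch one time
# step at a time, then reassemble the time dimension with torch.cat.
vid = torch.rand(1, 8, 3, 256, 256)  # B T C H W, values in [0, 1]
resize_up = torchvision.transforms.Resize((512, 512), antialias=True)
vid_resized = torch.cat(
    [resize_up(vid[:, i]).unsqueeze(1) for i in range(vid.size(1))], dim=1
)
assert vid_resized.shape == (1, 8, 3, 512, 512)  # still B T C H W

# Pattern 2 (as in vid_postprocessing): fold time into the batch axis so a
# single Resize call covers every frame, then unfold and move channels last
# for an image writer such as imageio.
video = torch.rand(1, 3, 8, 512, 512)  # B C T H W
flat = rearrange(video, "b c t h w -> (b t) c h w")
flat = torchvision.transforms.Resize((256, 256), antialias=True)(flat)
frames = rearrange(flat, "(b t) c h w -> b t h w c", b=video.size(0))
assert frames.shape == (1, 8, 256, 256, 3)  # B T H W C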