import tempfile

import gradio as gr
import imageio
import spaces
import torch
import torchvision
import numpy as np
from PIL import Image
from einops import rearrange

# Slider labels for the head / mouth / eye controls
labels_k = [
	'yaw1',
	'yaw2',
	'pitch',
	'roll1',
	'roll2',
	'neck',

	'pout',
	'open->close',
	'"O" Mouth',
	'smile',

	'close->open',
	'eyebrows',
	'eyeballs1',
	'eyeballs2',

]

labels_v = [
	37, 39, 28, 15, 33, 31,
	6, 25, 16, 19,
	13, 24, 17, 26
]
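# Note (assumption): each value in labels_v appears to be the index of the motion/latent
# direction inside the generator that corresponds to the slider at the same position in
# labels_k; the two lists are passed together to gen.edit_img / gen.animate_batch below,
# which pair every index with its slider value.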


def load_image(img, size):
	img = Image.open(img).convert('RGB')
	w, h = img.size
	img = img.resize((size, size))
	img = np.asarray(img)
	img = np.transpose(img, (2, 0, 1))  # HWC -> CHW, i.e. 3 x size x size

	return img / 255.0, w, h


def img_preprocessing(img_path, size):
	img, w, h = load_image(img_path, size)  # [0, 1]
	img = torch.from_numpy(img).unsqueeze(0).float()  # [0, 1]
	imgs_norm = (img - 0.5) * 2.0  # [-1, 1]

	return imgs_norm, w, h


def resize(img, size):
	transform = torchvision.transforms.Compose([
		torchvision.transforms.Resize((size, size), antialias=True),
	])

	return transform(img)


def resize_back(img, w, h):
	transform = torchvision.transforms.Compose([
		torchvision.transforms.Resize((h, w), antialias=True),
	])

	return transform(img)
    

def vid_preprocessing(vid_path, size):
	vid_dict = torchvision.io.read_video(vid_path, pts_unit='sec')
	vid = vid_dict[0].permute(0, 3, 1, 2).unsqueeze(0)  # (B, T, C, H, W)
	fps = vid_dict[2]['video_fps']
	vid_norm = (vid / 255.0 - 0.5) * 2.0  # [-1, 1]

	vid_norm = torch.cat([
		resize(vid_norm[:, i, :, :, :], size).unsqueeze(1) for i in range(vid.size(1))
	], dim=1)

	return vid_norm, fps


def img_denorm(img):
	img = img.clamp(-1, 1).cpu()
	img = (img - img.min()) / (img.max() - img.min())

	return img


def vid_denorm(vid):
	vid = vid.clamp(-1, 1).cpu()
	vid = (vid - vid.min()) / (vid.max() - vid.min())

	return vid


def img_postprocessing(image, w, h):
	image = resize_back(image, w, h)
	image = image.permute(0, 2, 3, 1)
	edited_image = img_denorm(image)
	img_output = (edited_image[0].numpy() * 255).astype(np.uint8)

	with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_file:
		imageio.imwrite(temp_file.name, img_output, quality=8)
		return temp_file.name



def vid_postprocessing(video, w, h, fps):
	# video: BCTHW

	b, c, t, _, _ = video.size()
	vid_batch = resize_back(rearrange(video, "b c t h w -> (b t) c h w"), w, h)
	vid = rearrange(vid_batch, "(b t) c h w -> b t h w c", b=b)	# B T H W C
	vid_np = (vid_denorm(vid[0]).numpy() * 255).astype('uint8')

	with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as temp_file:
		imageio.mimwrite(temp_file.name, vid_np, fps=fps, codec='libx264', quality=8)
		return temp_file.name


def animation(gen, chunk_size, device):
    
	@spaces.GPU
	@torch.no_grad()
	def edit_media(image, *selected_s):

		image_tensor, w, h = img_preprocessing(image, 512)
		image_tensor = image_tensor.to(device)

		edited_image_tensor = gen.edit_img(image_tensor, labels_v, selected_s)

		# de-normalize, resize back to the original resolution, and save to a temp file
		edited_image = img_postprocessing(edited_image_tensor, w, h)

		return edited_image

	@spaces.GPU
	@torch.no_grad()
	def animate_media(image, video, *selected_s):

		image_tensor, w, h = img_preprocessing(image, 512)
		vid_target_tensor, fps = vid_preprocessing(video, 512)
		image_tensor = image_tensor.to(device)
		video_target_tensor = vid_target_tensor.to(device)

		animated_video = gen.animate_batch(image_tensor, video_target_tensor, labels_v, selected_s, chunk_size)
		edited_image = animated_video[:, :, 0, :, :]  # first frame (B, C, H, W) doubles as the edited still

		# postprocessing
		animated_video = vid_postprocessing(animated_video, w, h, fps)
		edited_image = img_postprocessing(edited_image, w, h)
		return edited_image, animated_video


	def clear_media():
		return None, None, *([0] * len(labels_k))

    
	with gr.Tab("Image Animation"):

		inputs_s = []

		with gr.Row():
			with gr.Column(scale=1):
				with gr.Row():
					with gr.Accordion(open=True, label="Source Image"):
						image_input = gr.Image(type="filepath", elem_id="input_img", width=512)	# , height=550)
						gr.Examples(
							examples=[
								["./data/source/macron.png"],
								["./data/source/einstein.png"],
								["./data/source/taylor.png"],
								["./data/source/portrait1.png"],
								["./data/source/portrait2.png"],
								["./data/source/portrait3.png"],
							],
							inputs=[image_input],
							visible=True,
							)

					with gr.Accordion(open=True, label="Driving Video"):
						video_input = gr.Video(width=512, elem_id="input_vid",)  # , height=550)
						gr.Examples(
							examples=[
								["./data/driving/driving6.mp4"],
								["./data/driving/driving1.mp4"],
								["./data/driving/driving2.mp4"],
								["./data/driving/driving4.mp4"],
								["./data/driving/driving8.mp4"],
							],
							inputs=[video_input],
							visible=True,
							)

				with gr.Row():
					with gr.Column(scale=1):
						with gr.Row():	# Buttons now within a single Row
							edit_btn = gr.Button("Edit", elem_id="button_edit",)
							clear_btn = gr.Button("Clear", elem_id="button_clear")
						with gr.Row():
							animate_btn = gr.Button("Animate", elem_id="button_animate")



			with gr.Column(scale=1):

				with gr.Row():
					with gr.Accordion(open=True, label="Edited Source Image"):
						image_output = gr.Image(label="Output Image", elem_id="output_img", type='numpy', interactive=False, width=512)


					with gr.Accordion(open=True, label="Animated Video"):
						video_output = gr.Video(label="Output Video", elem_id="output_vid", width=512)

				with gr.Accordion("Control Panel", open=True):
					with gr.Tab("Head"):
						with gr.Row():
							for k in labels_k[:3]:
								slider = gr.Slider(minimum=-1.0, maximum=0.5, value=0, label=k, elem_id="slider_"+str(k))
								inputs_s.append(slider)
						with gr.Row():
							for k in labels_k[3:6]:
								slider = gr.Slider(minimum=-0.5, maximum=0.5, value=0, label=k, elem_id="slider_"+str(k))
								inputs_s.append(slider)

					with gr.Tab("Mouth"):
						with gr.Row():
							for k in labels_k[6:8]:
								slider = gr.Slider(minimum=-0.4, maximum=0.4, value=0, label=k, elem_id="slider_"+str(k))
								inputs_s.append(slider)
						with gr.Row():
							for k in labels_k[8:10]:
								slider = gr.Slider(minimum=-0.4, maximum=0.4, value=0, label=k, elem_id="slider_"+str(k))
								inputs_s.append(slider)

					with gr.Tab("Eyes"):
						with gr.Row():
							for k in labels_k[10:12]:
								slider = gr.Slider(minimum=-0.4, maximum=0.4, value=0, label=k, elem_id="slider_"+str(k))
								inputs_s.append(slider)
						with gr.Row():
							for k in labels_k[12:14]:
								slider = gr.Slider(minimum=-0.2, maximum=0.2, value=0, label=k, elem_id="slider_"+str(k))
								inputs_s.append(slider)


		edit_btn.click(
			fn=edit_media,
			inputs=[image_input] + inputs_s,
			outputs=[image_output],
			show_progress=True
		)

		animate_btn.click(
			fn=animate_media,
			inputs=[image_input, video_input] + inputs_s,
			outputs=[image_output, video_output],
			show_progress=True
		)

		clear_btn.click(
			fn=clear_media,
			outputs=[image_output, video_output] + inputs_s
		)

		gr.Examples(
			examples=[
				['./data/source/macron.png', './data/driving/driving6.mp4', -0.37, -0.34, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
				['./data/source/taylor.png', './data/driving/driving6.mp4', -0.31, -0.2, 0, -0.26, -0.14, 0, 0.068, 0.131, 0, 0, 0, 0, -0.058, 0.087],
				['./data/source/macron.png', './data/driving/driving1.mp4', 0.14, 0, -0.26, -0.29, -0.11, 0, -0.13, -0.18, 0, 0, 0, 0, -0.02, 0.07],
				['./data/source/portrait3.png', './data/driving/driving1.mp4', -0.03, 0.21, -0.31, -0.12, -0.11, 0, -0.05, -0.16, 0, 0, 0, 0, -0.02, 0.07],
				['./data/source/einstein.png', './data/driving/driving2.mp4', -0.31, 0, 0, 0.16, 0.08, 0, -0.07, 0, 0.13, 0, 0, 0, 0, 0],
				['./data/source/portrait1.png', './data/driving/driving4.mp4', 0, 0, -0.17, -0.19, 0.25, 0, 0, -0.086, 0.087, 0, 0, 0, 0, 0],
				['./data/source/portrait2.png', './data/driving/driving8.mp4', 0, 0, -0.25, 0, 0, 0, 0, 0, 0, 0.126, 0, 0, 0, 0],
			],
			fn=animate_media,
			inputs=[image_input, video_input] + inputs_s,
			outputs=[image_output, video_output],
		)
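

# Launch sketch (an assumption, not part of the original file): `animation` only builds the
# "Image Animation" tab, so an entry point presumably constructs the generator and wraps the
# call in a gr.Blocks context. `build_generator` and chunk_size=16 are hypothetical
# placeholders; the only thing this file requires of `gen` is that it expose the
# edit_img and animate_batch methods used above.
if __name__ == "__main__":
	device = "cuda" if torch.cuda.is_available() else "cpu"

	# hypothetical factory returning a model with edit_img / animate_batch, moved to `device`
	gen = build_generator(device)  # noqa: F821 - assumed to be provided by the surrounding project

	with gr.Blocks() as demo:
		animation(gen, chunk_size=16, device=device)

	demo.launch()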