jhj0517 committed · Commit 20b1ce3
1 Parent(s): 45d5794
Auto cast for faster inference
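The change below wraps the whole inference body in torch.autocast so that, on CUDA, matrix multiplications and convolutions run in reduced precision while numerically sensitive ops stay in float32; on CPU the context is a no-op thanks to the enabled=(self.device == "cuda") guard. As a standalone illustration of that pattern (a minimal sketch with a placeholder model, not code from this repository):

import torch

# Placeholder model and input, only to show the autocast pattern this commit applies.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.nn.Linear(512, 512).to(device).eval()
x = torch.randn(8, 512, device=device)

with torch.no_grad():
    # enabled=(device == "cuda") mirrors the commit's guard: autocast only
    # takes effect when running on a CUDA device.
    with torch.autocast(device_type=device, enabled=(device == "cuda")):
        y = model(x)

print(y.dtype)  # torch.float16 under CUDA autocast, torch.float32 on CPU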
modules/live_portrait/live_portrait_inferencer.py
CHANGED
@@ -174,67 +174,68 @@ class LivePortraitInferencer:
         )
 
         try:
-            rotate_yaw = -rotate_yaw
-
-            if src_image is not None:
-                if id(src_image) != id(self.src_image) or self.crop_factor != crop_factor:
-                    self.crop_factor = crop_factor
-                    self.psi = self.prepare_source(src_image, crop_factor)
-                    self.src_image = src_image
-            else:
-                return None
-
-            psi = self.psi
-            s_info = psi.x_s_info
-            #delta_new = copy.deepcopy()
-            s_exp = s_info['exp'] * src_ratio
-            s_exp[0, 5] = s_info['exp'][0, 5]
-            s_exp += s_info['kp']
-
-            es = ExpressionSet()
-
-            if isinstance(sample_image, np.ndarray) and sample_image:
-                if id(self.sample_image) != id(sample_image):
-                    self.sample_image = sample_image
-                    d_image_np = (sample_image * 255).byte().numpy()
-                    d_face = self.crop_face(d_image_np[0], 1.7)
-                    i_d = self.prepare_src_image(d_face)
-                    self.d_info = self.pipeline.get_kp_info(i_d)
-                    self.d_info['exp'][0, 5, 0] = 0
-                    self.d_info['exp'][0, 5, 1] = 0
-
-                # "OnlyExpression", "OnlyRotation", "OnlyMouth", "OnlyEyes", "All"
-                if sample_parts == SamplePart.ONLY_EXPRESSION.value or sample_parts == SamplePart.ONLY_EXPRESSION.ALL.value:
-                    es.e += self.d_info['exp'] * sample_ratio
-                if sample_parts == SamplePart.ONLY_ROTATION.value or sample_parts == SamplePart.ONLY_ROTATION.ALL.value:
-                    rotate_pitch += self.d_info['pitch'] * sample_ratio
-                    rotate_yaw += self.d_info['yaw'] * sample_ratio
-                    rotate_roll += self.d_info['roll'] * sample_ratio
-                elif sample_parts == SamplePart.ONLY_MOUTH.value:
-                    self.retargeting(es.e, self.d_info['exp'], sample_ratio, (14, 17, 19, 20))
-                elif sample_parts == SamplePart.ONLY_EYES.value:
-                    self.retargeting(es.e, self.d_info['exp'], sample_ratio, (1, 2, 11, 13, 15, 16))
-
-            es.r = self.calc_fe(es.e, blink, eyebrow, wink, pupil_x, pupil_y, aaa, eee, woo, smile,
-                                rotate_pitch, rotate_yaw, rotate_roll)
-
-            new_rotate = get_rotation_matrix(s_info['pitch'] + es.r[0], s_info['yaw'] + es.r[1],
-                                             s_info['roll'] + es.r[2])
-            x_d_new = (s_info['scale'] * (1 + es.s)) * ((s_exp + es.e) @ new_rotate) + s_info['t']
-
-            x_d_new = self.pipeline.stitching(psi.x_s_user, x_d_new)
-
-            crop_out = self.pipeline.warp_decode(psi.f_s_user, psi.x_s_user, x_d_new)
-            crop_out = self.pipeline.parse_output(crop_out['out'])[0]
-
-            crop_with_fullsize = cv2.warpAffine(crop_out, psi.crop_trans_m, get_rgb_size(psi.src_rgb), cv2.INTER_LINEAR)
-            out = np.clip(psi.mask_ori * crop_with_fullsize + (1 - psi.mask_ori) * psi.src_rgb, 0, 255).astype(np.uint8)
-
-            temp_out_img_path, out_img_path = get_auto_incremental_file_path(TEMP_DIR, "png"), get_auto_incremental_file_path(OUTPUTS_DIR, "png")
-            save_image(numpy_array=crop_out, output_path=temp_out_img_path)
-            save_image(numpy_array=out, output_path=out_img_path)
-
-            return out
+            with torch.autocast(device_type=self.device, enabled=(self.device == "cuda")):
+                rotate_yaw = -rotate_yaw
+
+                if src_image is not None:
+                    if id(src_image) != id(self.src_image) or self.crop_factor != crop_factor:
+                        self.crop_factor = crop_factor
+                        self.psi = self.prepare_source(src_image, crop_factor)
+                        self.src_image = src_image
+                else:
+                    return None
+
+                psi = self.psi
+                s_info = psi.x_s_info
+                #delta_new = copy.deepcopy()
+                s_exp = s_info['exp'] * src_ratio
+                s_exp[0, 5] = s_info['exp'][0, 5]
+                s_exp += s_info['kp']
+
+                es = ExpressionSet()
+
+                if isinstance(sample_image, np.ndarray) and sample_image:
+                    if id(self.sample_image) != id(sample_image):
+                        self.sample_image = sample_image
+                        d_image_np = (sample_image * 255).byte().numpy()
+                        d_face = self.crop_face(d_image_np[0], 1.7)
+                        i_d = self.prepare_src_image(d_face)
+                        self.d_info = self.pipeline.get_kp_info(i_d)
+                        self.d_info['exp'][0, 5, 0] = 0
+                        self.d_info['exp'][0, 5, 1] = 0
+
+                    # "OnlyExpression", "OnlyRotation", "OnlyMouth", "OnlyEyes", "All"
+                    if sample_parts == SamplePart.ONLY_EXPRESSION.value or sample_parts == SamplePart.ONLY_EXPRESSION.ALL.value:
+                        es.e += self.d_info['exp'] * sample_ratio
+                    if sample_parts == SamplePart.ONLY_ROTATION.value or sample_parts == SamplePart.ONLY_ROTATION.ALL.value:
+                        rotate_pitch += self.d_info['pitch'] * sample_ratio
+                        rotate_yaw += self.d_info['yaw'] * sample_ratio
+                        rotate_roll += self.d_info['roll'] * sample_ratio
+                    elif sample_parts == SamplePart.ONLY_MOUTH.value:
+                        self.retargeting(es.e, self.d_info['exp'], sample_ratio, (14, 17, 19, 20))
+                    elif sample_parts == SamplePart.ONLY_EYES.value:
+                        self.retargeting(es.e, self.d_info['exp'], sample_ratio, (1, 2, 11, 13, 15, 16))
+
+                es.r = self.calc_fe(es.e, blink, eyebrow, wink, pupil_x, pupil_y, aaa, eee, woo, smile,
+                                    rotate_pitch, rotate_yaw, rotate_roll)
+
+                new_rotate = get_rotation_matrix(s_info['pitch'] + es.r[0], s_info['yaw'] + es.r[1],
+                                                 s_info['roll'] + es.r[2])
+                x_d_new = (s_info['scale'] * (1 + es.s)) * ((s_exp + es.e) @ new_rotate) + s_info['t']
+
+                x_d_new = self.pipeline.stitching(psi.x_s_user, x_d_new)
+
+                crop_out = self.pipeline.warp_decode(psi.f_s_user, psi.x_s_user, x_d_new)
+                crop_out = self.pipeline.parse_output(crop_out['out'])[0]
+
+                crop_with_fullsize = cv2.warpAffine(crop_out, psi.crop_trans_m, get_rgb_size(psi.src_rgb), cv2.INTER_LINEAR)
+                out = np.clip(psi.mask_ori * crop_with_fullsize + (1 - psi.mask_ori) * psi.src_rgb, 0, 255).astype(np.uint8)
+
+                temp_out_img_path, out_img_path = get_auto_incremental_file_path(TEMP_DIR, "png"), get_auto_incremental_file_path(OUTPUTS_DIR, "png")
+                save_image(numpy_array=crop_out, output_path=temp_out_img_path)
+                save_image(numpy_array=out, output_path=out_img_path)
+
+                return out
         except Exception as e:
             raise
 
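One detail worth noting about this pattern: torch.autocast expects the bare device type string ("cuda" or "cpu"), so the call above assumes self.device is already such a string. If the device setting ever carried an index (e.g. "cuda:0") or were a torch.device object, a small normalization like the hypothetical helper below (not part of the repository) would keep the call valid:

import torch

# Hypothetical helper, not in the repository: strip a possible device index
# ("cuda:0" -> "cuda") or accept a torch.device, then build the same guarded
# autocast context the commit uses.
def autocast_for(device):
    device_type = torch.device(device).type
    return torch.autocast(device_type=device_type, enabled=(device_type == "cuda"))

with autocast_for("cuda:0" if torch.cuda.is_available() else "cpu"):
    pass  # inference would run here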