Spaces:
Runtime error
Runtime error
| # coding: utf-8 | |
| """ | |
| functions for processing and transforming 3D facial keypoints | |
| """ | |
| import numpy as np | |
| import torch | |
| import torch.nn.functional as F | |
# The constant pi as a plain Python float; used below to convert angles from degrees to radians.
PI = float(np.pi)
def headpose_pred_to_degree(pred):
    """Convert a head-pose prediction to an angle in degrees.

    If ``pred`` has more than one dimension and exactly 66 entries along
    dim 1, it is treated as per-sample logits over 66 angle bins. A softmax
    turns them into probabilities, the expected bin index is computed, and
    that index is mapped to degrees as ``index * 3 - 97.5``. Any other input
    (e.g. shape (bs, 1) or a 1-D tensor) is returned unchanged.

    Args:
        pred: tensor of shape (bs, 66), (bs, 1), or other.

    Returns:
        A (bs,) tensor of angles in degrees for (bs, 66) input; otherwise
        ``pred`` itself (same object).
    """
    if pred.ndim > 1 and pred.shape[1] == 66:
        # Bin indices 0..65, created directly on pred's device. float32 matches the
        # previous FloatTensor-based construction, so results are unchanged.
        idx_tensor = torch.arange(66, dtype=torch.float32, device=pred.device)
        probs = F.softmax(pred, dim=1)
        # NOTE: bins are 3 degrees wide, so the bin values are 0, 3, ..., 195.
        # Subtracting their mean (97.5) centres the output in [-97.5, 97.5].
        degree = torch.sum(probs * idx_tensor, dim=1) * 3 - 97.5
        return degree
    return pred
def get_rotation_matrix(pitch_, yaw_, roll_):
    """Build batched 3x3 rotation matrices from Euler angles given in degrees.

    Args:
        pitch_: rotation about the x axis, in degrees. Tensor of shape (bs,),
            (bs, 1), or a 0-D tensor (treated as bs = 1).
        yaw_: rotation about the y axis, in degrees; same shape rules.
        roll_: rotation about the z axis, in degrees; same shape rules.
            All three angles must share the same batch size and device.

    Returns:
        Tensor of shape (bs, 3, 3) equal to ``(Rz @ Ry @ Rx)`` transposed, so
        that points stored as row vectors are rotated with ``vps @ rot``.
        The constant entries (0 and 1) take their dtype/device from the
        pitch tensor.
    """
    def _as_column_radians(angle_deg):
        # Convert degrees to radians, then reshape to (bs, 1) so the nine
        # matrix entries can be concatenated along dim=1.
        angle = angle_deg / 180 * PI
        if angle.ndim == 0:
            angle = angle.reshape(1, 1)
        elif angle.ndim == 1:
            angle = angle.unsqueeze(1)
        return angle

    x = _as_column_radians(pitch_)
    y = _as_column_radians(yaw_)
    z = _as_column_radians(roll_)

    bs = x.shape[0]
    # Create the constants with the angle's dtype and device. This avoids a
    # CPU allocation followed by a copy, and avoids mixing dtypes in torch.cat.
    ones = torch.ones_like(x)
    zeros = torch.zeros_like(x)

    rot_x = torch.cat([
        ones, zeros, zeros,
        zeros, torch.cos(x), -torch.sin(x),
        zeros, torch.sin(x), torch.cos(x),
    ], dim=1).reshape([bs, 3, 3])

    rot_y = torch.cat([
        torch.cos(y), zeros, torch.sin(y),
        zeros, ones, zeros,
        -torch.sin(y), zeros, torch.cos(y),
    ], dim=1).reshape([bs, 3, 3])

    rot_z = torch.cat([
        torch.cos(z), -torch.sin(z), zeros,
        torch.sin(z), torch.cos(z), zeros,
        zeros, zeros, ones,
    ], dim=1).reshape([bs, 3, 3])

    rot = rot_z @ rot_y @ rot_x
    # Transpose so the result applies to row-vector points: vps @ rot.
    return rot.permute(0, 2, 1)