Spaces:
No application file
No application file
# -*- coding: utf-8 -*-
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
# holder of all proprietary rights on this computer program.
# You can only use this computer program if you have closed
# a license agreement with MPG or you get the right to use the computer
# program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and
# liable to prosecution.
#
# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
# for Intelligent Systems. All rights reserved.
#
# Contact: ps-license@tuebingen.mpg.de
| from __future__ import absolute_import | |
| from __future__ import print_function | |
| from __future__ import division | |
| import sys | |
| import os | |
| import time | |
| import pickle | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
# Floating-point dtype used as the default ``dtype`` argument of the priors
# defined in this module.
DEFAULT_DTYPE = torch.float32
def create_prior(prior_type, **kwargs):
    '''Build the prior selected by ``prior_type``.

    Args:
        prior_type: ``'gmm'`` (MaxMixturePrior), ``'l2'`` (L2Prior),
            ``'angle'`` (SMPLifyAnglePrior), or ``'none'`` / ``None`` for a
            prior that contributes nothing.
        **kwargs: Forwarded unchanged to the selected prior's constructor.

    Returns:
        A prior instance or, for ``'none'`` / ``None``, a plain function that
        accepts any arguments and returns ``0.0``.

    Raises:
        ValueError: If ``prior_type`` is not one of the values above.
    '''
    if prior_type is None or prior_type == 'none':
        # Don't use any pose prior: a callable that always contributes 0.
        def no_prior(*args, **kwargs):
            return 0.0
        return no_prior

    if prior_type == 'gmm':
        return MaxMixturePrior(**kwargs)
    if prior_type == 'l2':
        return L2Prior(**kwargs)
    if prior_type == 'angle':
        return SMPLifyAnglePrior(**kwargs)

    raise ValueError('Prior {}'.format(prior_type) + ' is not implemented')
class SMPLifyAnglePrior(nn.Module):
    '''Exponential prior on four elbow/knee axis-angle components.

    For each selected pose component ``theta_i`` the loss is
    ``exp(sign_i * theta_i) ** 2``; the per-component sign decides which
    rotation direction of that component makes the loss grow.

    Args:
        dtype: Torch dtype of the ``angle_prior_signs`` buffer.
        **kwargs: Ignored.
    '''

    def __init__(self, dtype=torch.float32, **kwargs):
        super(SMPLifyAnglePrior, self).__init__()

        # Indices of the penalized rotation components, given for a pose
        # vector that still contains the 3 global-orientation values:
        #   55: left elbow,  90deg bend at -np.pi/2
        #   58: right elbow, 90deg bend at np.pi/2
        #   12: left knee,   90deg bend at np.pi/2
        #   15: right knee,  90deg bend at np.pi/2
        self.register_buffer(
            'angle_prior_idxs',
            torch.tensor([55, 58, 12, 15], dtype=torch.long))

        # One sign per index above; it multiplies the angle before exp().
        self.register_buffer(
            'angle_prior_signs',
            torch.tensor([1., -1., -1., -1.], dtype=dtype))

    def forward(self, pose, with_global_pose=False):
        ''' Returns the angle prior loss for the given pose

        Args:
            pose: (Bx[23 + 1] * 3) torch tensor with the axis-angle
            representation of the rotations of the joints of the SMPL model.
        Kwargs:
            with_global_pose: Whether the pose vector also contains the global
            orientation of the SMPL model. If not then the indices are
            shifted down by 3.
        Returns:
            A (B, 4) tensor with the angle prior loss of each penalized
            component (left elbow, right elbow, left knee, right knee) for
            each element in the batch.
        '''
        # The stored indices assume the 3 global-orientation values come
        # first; without them every index moves 3 positions to the left.
        offset = 0 if with_global_pose else 3
        angle_prior_idxs = self.angle_prior_idxs - offset
        selected = pose[:, angle_prior_idxs]
        return torch.exp(selected * self.angle_prior_signs).pow(2)
class L2Prior(nn.Module):
    '''Squared-L2 penalty on an arbitrary tensor.

    Args:
        dtype: Accepted for interface compatibility with the other priors in
            this module; not used by this prior.
        reduction: How the element-wise squares are reduced: ``'sum'``
            (default), ``'mean'``, or ``'none'`` (return them unreduced).
        **kwargs: Ignored.

    Raises:
        ValueError: If ``reduction`` is not one of the supported values.
    '''

    _REDUCTIONS = ('sum', 'mean', 'none')

    def __init__(self, dtype=DEFAULT_DTYPE, reduction='sum', **kwargs):
        super(L2Prior, self).__init__()
        # Previously this argument was silently ignored (always summed);
        # validate it so a typo does not quietly change the loss scale.
        if reduction not in self._REDUCTIONS:
            raise ValueError(
                'Unknown reduction {!r}; expected one of {}'.format(
                    reduction, self._REDUCTIONS))
        self.reduction = reduction

    def forward(self, module_input, *args):
        '''Return the (reduced) squared values of ``module_input``.

        Extra positional arguments are accepted and ignored.
        '''
        squared = module_input.pow(2)
        if self.reduction == 'sum':
            return torch.sum(squared)
        if self.reduction == 'mean':
            return torch.mean(squared)
        return squared
class MaxMixturePrior(nn.Module):
    '''Gaussian-mixture prior that scores a vector by its best component.

    The mixture parameters are loaded from
    ``<prior_folder>/gmm_<num_gaussians:02d>.pkl``. The pickle may hold
    either a dict with ``'means'``, ``'covars'`` and ``'weights'`` entries or
    an object whose type name contains ``sklearn.mixture.gmm.GMM`` (read via
    ``means_``, ``covars_`` and ``weights_``).

    Args:
        prior_folder: Directory containing the ``gmm_XX.pkl`` file.
        num_gaussians: Number of mixture components; selects the file name.
        dtype: ``torch.float32`` (default) or ``torch.float64``.
        epsilon: Added to covariance determinants before taking the log.
        use_merged: If True ``forward`` uses ``merged_log_likelihood``,
            otherwise ``log_likelihood``.
        **kwargs: Ignored.

    Raises:
        ValueError: If ``dtype`` is unsupported or the pickle holds an
            unrecognised object.
        FileNotFoundError: If the mixture file does not exist.
    '''

    def __init__(self, prior_folder='prior',
                 num_gaussians=6, dtype=DEFAULT_DTYPE, epsilon=1e-16,
                 use_merged=True,
                 **kwargs):
        super(MaxMixturePrior, self).__init__()

        if dtype == torch.float32:
            np_dtype = np.float32
        elif dtype == torch.float64:
            np_dtype = np.float64
        else:
            raise ValueError('Unknown float type {}'.format(dtype))

        self.num_gaussians = num_gaussians
        self.epsilon = epsilon
        self.use_merged = use_merged

        gmm_fn = 'gmm_{:02d}.pkl'.format(num_gaussians)
        full_gmm_fn = os.path.join(prior_folder, gmm_fn)
        if not os.path.exists(full_gmm_fn):
            raise FileNotFoundError(
                'The path to the mixture prior "{}"'.format(full_gmm_fn) +
                ' does not exist')

        # NOTE(review): pickle.load can execute arbitrary code stored in the
        # file; only load mixture files from a trusted location.
        with open(full_gmm_fn, 'rb') as f:
            gmm = pickle.load(f, encoding='latin1')

        raw_means, raw_covs, raw_weights = self._unpack_gmm(gmm)
        means = raw_means.astype(np_dtype)
        covs = raw_covs.astype(np_dtype)

        self.register_buffer('means', torch.tensor(means, dtype=dtype))
        self.register_buffer('covs', torch.tensor(covs, dtype=dtype))

        precisions = np.stack([np.linalg.inv(cov) for cov in covs])
        precisions = precisions.astype(np_dtype)
        self.register_buffer('precisions',
                             torch.tensor(precisions, dtype=dtype))

        # The dimensionality of the random variable, taken from the data
        # instead of assuming a fixed size.
        self.random_var_dim = means.shape[1]

        # Mixture weights folded with the Gaussian normalisation constant.
        # Determinants use the un-cast arrays, since for high-dimensional
        # covariances they may underflow in float32.
        sqrdets = np.array([np.sqrt(np.linalg.det(c)) for c in raw_covs])
        const = (2 * np.pi) ** (self.random_var_dim / 2.)
        nll_weights = np.asarray(
            raw_weights / (const * (sqrdets / sqrdets.min())))
        nll_weights = torch.tensor(nll_weights, dtype=dtype).unsqueeze(dim=0)
        self.register_buffer('nll_weights', nll_weights)

        weights = torch.tensor(raw_weights, dtype=dtype).unsqueeze(dim=0)
        self.register_buffer('weights', weights)

        self.register_buffer('pi_term',
                             torch.log(torch.tensor(2 * np.pi, dtype=dtype)))

        cov_dets = [np.log(np.linalg.det(cov.astype(np_dtype)) + epsilon)
                    for cov in covs]
        self.register_buffer('cov_dets',
                             torch.tensor(cov_dets, dtype=dtype))

    @staticmethod
    def _unpack_gmm(gmm):
        '''Return (means, covariances, weights) as arrays from a loaded GMM.'''
        if isinstance(gmm, dict):
            return (np.asarray(gmm['means']), np.asarray(gmm['covars']),
                    np.asarray(gmm['weights']))
        if 'sklearn.mixture.gmm.GMM' in str(type(gmm)):
            return (np.asarray(gmm.means_), np.asarray(gmm.covars_),
                    np.asarray(gmm.weights_))
        raise ValueError('Unknown type for the prior: {}'.format(type(gmm)))

    def get_mean(self):
        ''' Returns the weighted mean of the mixture as a (1, D) tensor '''
        mean_pose = torch.matmul(self.weights, self.means)
        return mean_pose

    def merged_log_likelihood(self, pose, betas):
        '''Per-sample minimum over components of
        ``0.5 * (x - mu_k)^T P_k (x - mu_k) - log(nll_weights_k)``.

        Args:
            pose: (B, D) tensor.
            betas: Unused.

        Returns:
            (B,) tensor.
        '''
        diff_from_mean = pose.unsqueeze(dim=1) - self.means  # (B, K, D)
        prec_diff_prod = torch.einsum('mij,bmj->bmi',
                                      self.precisions, diff_from_mean)
        diff_prec_quadratic = (prec_diff_prod * diff_from_mean).sum(dim=-1)
        curr_loglikelihood = 0.5 * diff_prec_quadratic - \
            torch.log(self.nll_weights)
        min_likelihood, _ = torch.min(curr_loglikelihood, dim=1)
        return min_likelihood

    def log_likelihood(self, pose, betas, *args, **kwargs):
        ''' Negative log-likelihood of the best-fitting component per sample.

        For every component k this computes
        ``(x - mu_k)^T P_k (x - mu_k)
        + 0.5 * (log(det(C_k) + eps) + D * log(2 * pi))``,
        picks the smallest value per sample, and adds
        ``-log(nll_weights_k)`` of the chosen component.

        Args:
            pose: (B, D) tensor.
            betas: Unused.

        Returns:
            (B,) tensor.
        '''
        diff_from_mean = pose.unsqueeze(dim=1) - self.means  # (B, K, D)
        # NOTE(review): unlike merged_log_likelihood the quadratic term is
        # not scaled by 0.5; kept to preserve existing values — confirm.
        quadratic = torch.einsum('bmj,mji,bmi->bm', diff_from_mean,
                                 self.precisions, diff_from_mean)
        cov_terms = torch.log(torch.det(self.covs) + self.epsilon)  # (K,)
        log_likelihoods = quadratic + 0.5 * (
            cov_terms + self.random_var_dim * self.pi_term)
        # Select the best component independently for every sample.
        min_likelihood, min_idx = torch.min(log_likelihoods, dim=1)
        weight_component = -torch.log(self.nll_weights[0, min_idx])
        return weight_component + min_likelihood

    def forward(self, pose, betas):
        '''Dispatch to the merged or per-component likelihood; returns (B,).'''
        if self.use_merged:
            return self.merged_log_likelihood(pose, betas)
        return self.log_likelihood(pose, betas)