#!/usr/bin/env python
import torch
import torch.nn as nn
import torch.nn.functional as F
from itertools import combinations_with_replacement


class Unary(nn.Module):
    def __init__(self, embed_size):
        """
        Captures local entity information
        :param embed_size: the embedding dimension
        """
        super(Unary, self).__init__()
        self.embed = nn.Conv1d(embed_size, embed_size, 1)
        self.feature_reduce = nn.Conv1d(embed_size, 1, 1)

    def forward(self, X):
        # X: (batch, n_entities, embed_size) -> (batch, embed_size, n_entities)
        X = X.transpose(1, 2)
        X_embed = self.embed(X)
        X_nl_embed = F.dropout(F.relu(X_embed), training=self.training)
        X_poten = self.feature_reduce(X_nl_embed)
        # One potential per entity: (batch, n_entities)
        return X_poten.squeeze(1)
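
# Illustrative usage of Unary (an assumed example, not part of the original file):
#   u = Unary(embed_size=512)
#   pot = u(torch.randn(8, 36, 512))   # pot: (8, 36), one potential per entity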


class Pairwise(nn.Module):
    def __init__(self, embed_x_size, x_spatial_dim=None, embed_y_size=None, y_spatial_dim=None):
        """
        Captures interaction between utilities or entities of the same utility
        :param embed_x_size: the embedding dimension of the first utility
        :param x_spatial_dim: the spatial dimension of the first utility for batch norm and weighted marginalization
        :param embed_y_size: the embedding dimension of the second utility (None for self-interactions)
        :param y_spatial_dim: the spatial dimension of the second utility for batch norm and weighted marginalization
        """
        super(Pairwise, self).__init__()
        # Fall back to the first utility's embedding dim for self-interactions
        embed_y_size = embed_y_size if embed_y_size is not None else embed_x_size
        self.y_spatial_dim = y_spatial_dim if y_spatial_dim is not None else x_spatial_dim

        self.embed_size = max(embed_x_size, embed_y_size)
        self.x_spatial_dim = x_spatial_dim

        self.embed_X = nn.Conv1d(embed_x_size, self.embed_size, 1)
        self.embed_Y = nn.Conv1d(embed_y_size, self.embed_size, 1)
        if x_spatial_dim is not None:
            self.normalize_S = nn.BatchNorm1d(self.x_spatial_dim * self.y_spatial_dim)
            self.margin_X = nn.Conv1d(self.y_spatial_dim, 1, 1)
            self.margin_Y = nn.Conv1d(self.x_spatial_dim, 1, 1)

    def forward(self, X, Y=None):
        X_t = X.transpose(1, 2)
        Y_t = Y.transpose(1, 2) if Y is not None else X_t

        X_embed = self.embed_X(X_t)
        Y_embed = self.embed_Y(Y_t)

        X_norm = F.normalize(X_embed)
        Y_norm = F.normalize(Y_embed)

        # Similarity matrix S: (batch, x_spatial_dim, y_spatial_dim)
        S = X_norm.transpose(1, 2).bmm(Y_norm)
        if self.x_spatial_dim is not None:
            S = self.normalize_S(S.view(-1, self.x_spatial_dim * self.y_spatial_dim)) \
                .view(-1, self.x_spatial_dim, self.y_spatial_dim)
            X_poten = self.margin_X(S.transpose(1, 2)).transpose(1, 2).squeeze(2)
            Y_poten = self.margin_Y(S).transpose(1, 2).squeeze(2)
        else:
            X_poten = S.mean(dim=2, keepdim=False)
            Y_poten = S.mean(dim=1, keepdim=False)

        if Y is None:
            return X_poten
        else:
            return X_poten, Y_poten
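
# Illustrative usage of Pairwise (an assumed example, not part of the original file):
# with no spatial dims the potentials are plain means over the similarity matrix S.
#   pw = Pairwise(embed_x_size=512, embed_y_size=256)
#   x_pot, y_pot = pw(torch.randn(8, 36, 512), torch.randn(8, 20, 256))
#   # x_pot: (8, 36), y_pot: (8, 20)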


class Atten(nn.Module):
    def __init__(self, util_e, sharing_factor_weights=None, prior_flag=False,
                 sizes=None, size_force=False, pairwise_flag=True,
                 unary_flag=True, self_flag=True):
        """
        Performs attention over a given list of utility representations.
        :param util_e: the embedding dimensions
        :param sharing_factor_weights: to share weights, provide a dict of tuples:
         {idx: (num_utils, connected utils)}
         Note, for efficiency, the shared utils (i.e., history) are connected to the
         answer and question only.
         TODO: connections between shared utils
        :param prior_flag: whether a prior factor is provided
        :param sizes: the spatial dimensions (used for batch-norm and weighted marginalization)
        :param size_force: force the spatial size with adaptive avg pooling
        :param pairwise_flag: use pairwise interactions between utilities
        :param unary_flag: use local information
        :param self_flag: use self-interactions between a utility's entities
        """
        super(Atten, self).__init__()
        if sharing_factor_weights is None:
            sharing_factor_weights = {}
        if sizes is None or len(sizes) == 0:
            sizes = [None for _ in util_e]

        self.util_e = util_e
        self.prior_flag = prior_flag
        self.n_utils = len(util_e)
        self.spatial_pool = nn.ModuleDict()
        self.un_models = nn.ModuleList()
        self.self_flag = self_flag
        self.pairwise_flag = pairwise_flag
        self.unary_flag = unary_flag
        self.size_force = size_force
        self.sharing_factor_weights = sharing_factor_weights

        # Unary modules; optionally force the provided spatial size
        for idx, e_dim in enumerate(util_e):
            self.un_models.append(Unary(e_dim))
            if self.size_force:
                self.spatial_pool[str(idx)] = nn.AdaptiveAvgPool1d(sizes[idx])

        # Pairwise modules
        self.pp_models = nn.ModuleDict()
        for ((idx1, e_dim_1), (idx2, e_dim_2)) \
                in combinations_with_replacement(enumerate(util_e), 2):
            # self-interaction
            if self.self_flag and idx1 == idx2:
                self.pp_models[str(idx1)] = Pairwise(e_dim_1, sizes[idx1])
            else:
                if pairwise_flag:
                    if idx1 in self.sharing_factor_weights:
                        # not connected
                        if idx2 not in self.sharing_factor_weights[idx1][1]:
                            continue
                    if idx2 in self.sharing_factor_weights:
                        # not connected
                        if idx1 not in self.sharing_factor_weights[idx2][1]:
                            continue
                    self.pp_models[str((idx1, idx2))] = Pairwise(e_dim_1, sizes[idx1], e_dim_2, sizes[idx2])

        # Handle reduce potentials (with scalars)
        self.reduce_potentials = nn.ModuleList()
        self.num_of_potentials = dict()
        self.default_num_of_potentials = 0

        if self.self_flag:
            self.default_num_of_potentials += 1
        if self.unary_flag:
            self.default_num_of_potentials += 1
        if self.prior_flag:
            self.default_num_of_potentials += 1
        for idx in range(self.n_utils):
            self.num_of_potentials[idx] = self.default_num_of_potentials

        # Potentials contributed by all other utilities
        if pairwise_flag:
            for idx, (num_utils, connected_utils) in sharing_factor_weights.items():
                for c_u in connected_utils:
                    self.num_of_potentials[c_u] += num_utils
                    self.num_of_potentials[idx] += 1
            for k in self.num_of_potentials:
                if k not in self.sharing_factor_weights:
                    self.num_of_potentials[k] += (self.n_utils - 1) \
                        - len(sharing_factor_weights)

        for idx in range(self.n_utils):
            self.reduce_potentials.append(nn.Conv1d(self.num_of_potentials[idx],
                                                    1, 1, bias=False))
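
    # Illustrative potential count (an assumed example, not part of the original file):
    # with util_e=[512, 512, 512], no shared utilities and all flags on, each utility
    # gets unary (1) + self (1) + pairwise with the two other utilities (2) = 4
    # potentials, so its reduce_potentials conv is nn.Conv1d(4, 1, 1, bias=False).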

    def forward(self, utils, priors=None):
        assert self.n_utils == len(utils)
        assert (priors is None and not self.prior_flag) \
            or (priors is not None
                and self.prior_flag
                and len(priors) == self.n_utils)
        b_size = utils[0].size(0)
        util_factors = dict()
        attention = list()

        # Force size; a constant spatial size is needed for pairwise batch normalization
        if self.size_force:
            for i, (num_utils, _) in self.sharing_factor_weights.items():
                if str(i) not in self.spatial_pool.keys():
                    continue
                else:
                    high_util = utils[i]
                    high_util = high_util.view(num_utils * b_size, high_util.size(2), high_util.size(3))
                    high_util = high_util.transpose(1, 2)
                    utils[i] = self.spatial_pool[str(i)](high_util).transpose(1, 2)

            for i in range(self.n_utils):
                if i in self.sharing_factor_weights \
                        or str(i) not in self.spatial_pool.keys():
                    continue
                utils[i] = utils[i].transpose(1, 2)
                utils[i] = self.spatial_pool[str(i)](utils[i]).transpose(1, 2)
                if self.prior_flag and priors[i] is not None:
                    priors[i] = self.spatial_pool[str(i)](priors[i].unsqueeze(1)).squeeze(1)

        # Handle utilities with shared weights
        for i, (num_utils, connected_list) in self.sharing_factor_weights.items():
            if self.unary_flag:
                util_factors.setdefault(i, []).append(self.un_models[i](utils[i]))
            if self.self_flag:
                util_factors.setdefault(i, []).append(self.pp_models[str(i)](utils[i]))
            if self.pairwise_flag:
                for j in connected_list:
                    other_util = utils[j]
                    expanded_util = other_util.unsqueeze(1).expand(
                        b_size, num_utils, other_util.size(1), other_util.size(2)
                    ).contiguous().view(b_size * num_utils, other_util.size(1), other_util.size(2))
                    if i < j:
                        factor_ij, factor_ji = self.pp_models[str((i, j))](utils[i], expanded_util)
                    else:
                        factor_ji, factor_ij = self.pp_models[str((j, i))](expanded_util, utils[i])
                    util_factors[i].append(factor_ij)
                    util_factors.setdefault(j, []).append(factor_ji.view(b_size, num_utils, factor_ji.size(1)))

        # Handle local factors
        for i in range(self.n_utils):
            if i in self.sharing_factor_weights:
                continue
            if self.unary_flag:
                util_factors.setdefault(i, []).append(self.un_models[i](utils[i]))
            if self.self_flag:
                util_factors.setdefault(i, []).append(self.pp_models[str(i)](utils[i]))

        # Joint (pairwise) factors between non-shared utilities
        if self.pairwise_flag:
            for (i, j) in combinations_with_replacement(range(self.n_utils), 2):
                if i in self.sharing_factor_weights \
                        or j in self.sharing_factor_weights:
                    continue
                if i == j:
                    continue
                else:
                    factor_ij, factor_ji = self.pp_models[str((i, j))](utils[i], utils[j])
                    util_factors.setdefault(i, []).append(factor_ij)
                    util_factors.setdefault(j, []).append(factor_ji)

        # Perform attention
        for i in range(self.n_utils):
            if self.prior_flag:
                prior = priors[i] \
                    if priors[i] is not None \
                    else torch.zeros_like(util_factors[i][0], requires_grad=False)
                util_factors[i].append(prior)
            # Stack the potentials: (batch, num_of_potentials, n_entities)
            util_factors[i] = torch.cat([p if len(p.size()) == 3 else p.unsqueeze(1)
                                         for p in util_factors[i]], dim=1)
            util_factors[i] = self.reduce_potentials[i](util_factors[i]).squeeze(1)
            util_factors[i] = F.softmax(util_factors[i], dim=1).unsqueeze(2)
            attention.append(torch.bmm(utils[i].transpose(1, 2), util_factors[i]).squeeze(2))

        return attention
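
# Illustrative configuration of Atten (an assumed example, not part of the original
# file): three utilities where utility 2 consists of num_utils=10 weight-sharing
# copies connected only to utilities 0 and 1, matching the docstring's dict format.
#   atten = Atten(util_e=[512, 512, 512],
#                 sharing_factor_weights={2: (10, [0, 1])},
#                 sizes=[36, 20, 20], size_force=True)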


class NaiveAttention(nn.Module):
    def __init__(self):
        """
        Used for ablation analysis - removing attention.
        """
        super(NaiveAttention, self).__init__()

    def forward(self, utils, priors):
        atten = []
        spatial_atten = []
        for u, p in zip(utils, priors):
            if isinstance(u, tuple):
                u = u[1]
                num_elements = u.shape[0]
                if p is not None:
                    u = u.view(-1, u.shape[-2], u.shape[-1])
                    p = p.view(-1, p.shape[-2], p.shape[-1])
                    spatial_atten.append(
                        torch.bmm(p.transpose(1, 2), u).squeeze(2).view(num_elements, -1, u.shape[-2], u.shape[-1]))
                else:
                    spatial_atten.append(u.mean(2))
                continue
            if p is not None:
                atten.append(torch.bmm(u.transpose(1, 2), p.unsqueeze(2)).squeeze(2))
            else:
                atten.append(u.mean(1))
        return atten, spatial_atten
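

# A minimal smoke test added for illustration (an assumption, not part of the
# original project); it builds a small two-utility model and checks the output
# shapes of the attended representations.
if __name__ == "__main__":
    torch.manual_seed(0)
    model = Atten(util_e=[32, 32])          # two utilities, embedding dim 32
    img = torch.randn(4, 5, 32)             # e.g. 5 visual entities per example
    txt = torch.randn(4, 7, 32)             # e.g. 7 text tokens per example
    attended = model([img, txt])
    print([tuple(a.shape) for a in attended])  # expected: [(4, 32), (4, 32)]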