Spaces:
No application file
No application file
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
class QuantizeEMAReset(nn.Module):
    """Vector quantizer with an EMA-updated codebook and dead-code reset.

    The codebook is a buffer, not a learnable parameter. During training each
    forward pass updates exponential moving averages (decay ``args.mu``) of
    the per-code sum of assigned vectors (``code_sum``) and of the per-code
    assignment count (``code_count``). Each code is set to their ratio, and
    codes whose EMA count is below 1 are replaced by rows of the current
    batch (jittered copies if the batch is smaller than the codebook).

    Args:
        nb_code: number of codebook entries.
        code_dim: size of each entry; must equal the channel dimension of the
            (N, code_dim, T) input to ``forward``.
        args: config object; only ``args.mu`` is read.
    """

    def __init__(self, nb_code, code_dim, args):
        super().__init__()
        self.nb_code = nb_code
        self.code_dim = code_dim
        self.mu = args.mu
        self.reset_codebook()

    def reset_codebook(self):
        """Mark the codebook as uninitialised and register an all-zero codebook buffer."""
        self.init = False
        # NOTE(review): code_sum, code_count and init are plain attributes, not
        # buffers, so they are not saved in state_dict and not moved by .to().
        self.code_sum = None
        self.code_count = None
        # Use CUDA when present, otherwise CPU, so the module can also be
        # constructed on CPU-only hosts; Module.to()/.cuda() still moves it.
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.register_buffer('codebook', torch.zeros(self.nb_code, self.code_dim, device=device))

    def _tile(self, x):
        """Return rows of ``x`` (shape (M, code_dim)) numbering at least ``nb_code``.

        If ``x`` has fewer than ``nb_code`` rows it is repeated (ceil division)
        and perturbed with Gaussian noise of std ``0.01 / sqrt(code_dim)`` so
        the copies differ; otherwise ``x`` is returned unchanged.
        """
        nb_code_x, code_dim = x.shape
        if nb_code_x < self.nb_code:
            n_repeats = (self.nb_code + nb_code_x - 1) // nb_code_x
            std = 0.01 / np.sqrt(code_dim)
            out = x.repeat(n_repeats, 1)
            out = out + torch.randn_like(out) * std
        else:
            out = x
        return out

    @torch.no_grad()
    def init_codebook(self, x):
        """Seed the codebook and EMA statistics from the first ``nb_code`` rows of (tiled) ``x``."""
        out = self._tile(x)
        self.codebook = out[:self.nb_code]
        self.code_sum = self.codebook.clone()
        self.code_count = torch.ones(self.nb_code, device=self.codebook.device)
        self.init = True

    @torch.no_grad()
    def compute_perplexity(self, code_idx):
        """Return exp(entropy) of the code-usage distribution in ``code_idx`` (1-D indices)."""
        code_onehot = torch.zeros(self.nb_code, code_idx.shape[0], device=code_idx.device)  # (nb_code, N*T)
        code_onehot.scatter_(0, code_idx.view(1, code_idx.shape[0]), 1)
        code_count = code_onehot.sum(dim=-1)  # (nb_code,)
        prob = code_count / torch.sum(code_count)
        perplexity = torch.exp(-torch.sum(prob * torch.log(prob + 1e-7)))
        return perplexity

    @torch.no_grad()
    def update_codebook(self, x, code_idx):
        """Apply one EMA step (with dead-code reset) and return the batch perplexity.

        Runs without autograd. The running statistics are built from ``x`` and
        from their own previous values, so tracking them would chain one graph
        through every training step and keep its tensors alive indefinitely.

        Args:
            x: flattened inputs, shape (N*T, code_dim).
            code_idx: nearest-code index for each row of ``x``, shape (N*T,).
        """
        code_onehot = torch.zeros(self.nb_code, x.shape[0], device=x.device)  # (nb_code, N*T)
        code_onehot.scatter_(0, code_idx.view(1, x.shape[0]), 1)

        code_sum = torch.matmul(code_onehot, x)  # (nb_code, code_dim)
        code_count = code_onehot.sum(dim=-1)  # (nb_code,)

        out = self._tile(x)
        code_rand = out[:self.nb_code]

        # EMA update of the per-code statistics.
        self.code_sum = self.mu * self.code_sum + (1. - self.mu) * code_sum  # (nb_code, code_dim)
        self.code_count = self.mu * self.code_count + (1. - self.mu) * code_count  # (nb_code,)

        # Codes whose EMA count fell below 1 are treated as dead and replaced
        # by the corresponding row of the (tiled) current batch.
        usage = (self.code_count.view(self.nb_code, 1) >= 1.0).float()
        code_update = self.code_sum.view(self.nb_code, self.code_dim) / self.code_count.view(self.nb_code, 1)
        self.codebook = usage * code_update + (1 - usage) * code_rand

        prob = code_count / torch.sum(code_count)
        perplexity = torch.exp(-torch.sum(prob * torch.log(prob + 1e-7)))
        return perplexity

    def preprocess(self, x):
        """Reshape (N, C, T) -> (N*T, C)."""
        x = x.permute(0, 2, 1).contiguous()
        x = x.view(-1, x.shape[-1])
        return x

    def quantize(self, x):
        """Return the index of the nearest codebook entry (squared L2) for each row of ``x``."""
        k_w = self.codebook.t()
        distance = torch.sum(x ** 2, dim=-1, keepdim=True) - 2 * torch.matmul(x, k_w) + torch.sum(k_w ** 2, dim=0,
                                                                                                     keepdim=True)  # (N*T, nb_code)
        _, code_idx = torch.min(distance, dim=-1)
        return code_idx

    def dequantize(self, code_idx):
        """Look up codebook vectors for ``code_idx``."""
        x = F.embedding(code_idx, self.codebook)
        return x

    def forward(self, x):
        """Quantize ``x`` of shape (N, code_dim, T).

        Returns ``(x_d, commit_loss, perplexity)``. ``x_d`` is the quantized
        tensor in the input layout, with straight-through gradients to ``x``.
        ``commit_loss`` is the MSE between ``x`` and its detached quantization.
        ``perplexity`` is the code-usage perplexity. In training mode the
        codebook is seeded on the first call and EMA-updated after
        quantization.
        """
        N, _, T = x.shape
        x = self.preprocess(x)  # (N*T, code_dim)

        # NOTE(review): ``init`` is not persisted, so the first training step
        # after loading a checkpoint re-seeds the codebook from that batch.
        if self.training and not self.init:
            self.init_codebook(x)

        code_idx = self.quantize(x)
        x_d = self.dequantize(code_idx)

        if self.training:
            perplexity = self.update_codebook(x, code_idx)
        else:
            perplexity = self.compute_perplexity(code_idx)

        commit_loss = F.mse_loss(x, x_d.detach())

        # Straight-through estimator: forward value is x_d, gradient is identity.
        x_d = x + (x_d - x).detach()

        x_d = x_d.view(N, T, -1).permute(0, 2, 1).contiguous()  # (N, code_dim, T)
        return x_d, commit_loss, perplexity
class Quantizer(nn.Module):
    """Classic VQ-VAE quantizer backed by a learnable ``nn.Embedding`` codebook.

    Args:
        n_e: number of codebook entries.
        e_dim: dimensionality of each entry; must equal the channel count of
            the (N, e_dim, T) input to ``forward``.
        beta: weight applied to the commitment term of the returned loss.
    """

    def __init__(self, n_e, e_dim, beta):
        super().__init__()
        self.e_dim = e_dim
        self.n_e = n_e
        self.beta = beta
        self.embedding = nn.Embedding(self.n_e, self.e_dim)
        bound = 1.0 / self.n_e
        self.embedding.weight.data.uniform_(-bound, bound)

    def forward(self, z):
        """Quantize ``z`` of shape (N, e_dim, T).

        Returns:
            z_q: quantized tensor, shape (N, e_dim, T), with straight-through
                gradients to ``z``.
            loss: codebook term + ``beta`` * commitment term (both MSE).
            perplexity: exp of the entropy of code usage over the batch.
        """
        batch, _, length = z.shape
        flat = self.preprocess(z)  # (N*T, e_dim)
        assert flat.shape[-1] == self.e_dim

        indices = self.quantize(flat)
        quantized = self.embedding(indices)

        codebook_term = torch.mean((quantized - flat.detach()) ** 2)
        commitment_term = torch.mean((quantized.detach() - flat) ** 2)
        loss = codebook_term + self.beta * commitment_term

        # Straight-through estimator: forward uses the codes, backward is identity.
        quantized = flat + (quantized - flat).detach()
        z_q = quantized.view(batch, length, -1).permute(0, 2, 1).contiguous()  # (N, e_dim, T)

        usage = F.one_hot(indices, self.n_e).type(flat.dtype).mean(dim=0)
        perplexity = torch.exp(-torch.sum(usage * torch.log(usage + 1e-10)))
        return z_q, loss, perplexity

    def quantize(self, z):
        """Return the index of the nearest codebook entry for each row of ``z`` (B, e_dim)."""
        assert z.shape[-1] == self.e_dim
        weight = self.embedding.weight
        # Squared Euclidean distance ||z||^2 + ||e||^2 - 2 z.e, shape (B, n_e).
        distances = torch.sum(z ** 2, dim=1, keepdim=True) + \
            torch.sum(weight ** 2, dim=1) - 2 * \
            torch.matmul(z, weight.t())
        return torch.argmin(distances, dim=1)

    def dequantize(self, indices):
        """Map an index tensor of any shape to codebook vectors, appending an e_dim axis."""
        vectors = self.embedding(indices.view(-1))
        return vectors.view((*indices.shape, self.e_dim)).contiguous()

    def preprocess(self, x):
        """Reshape (N, C, T) -> (N*T, C)."""
        channels_last = x.permute(0, 2, 1).contiguous()
        return channels_last.view(-1, channels_last.shape[-1])
class QuantizeReset(nn.Module):
    """Vector quantizer with an ``nn.Parameter`` codebook and dead-code reset.

    After every training step, codes that received no assignment in the
    current batch are overwritten with the corresponding rows of the (tiled)
    batch.

    Args:
        nb_code: number of codebook entries.
        code_dim: size of each entry (channel count of the (N, code_dim, T) input).
        args: accepted for interface compatibility; not read.
    """

    def __init__(self, nb_code, code_dim, args):
        super().__init__()
        self.nb_code = nb_code
        self.code_dim = code_dim
        self.reset_codebook()
        self.codebook = nn.Parameter(torch.randn(nb_code, code_dim))

    def reset_codebook(self):
        """Mark the codebook as uninitialised so the next training step re-seeds it."""
        self.init = False
        self.code_count = None

    def _tile(self, x):
        """Return rows of ``x`` (shape (M, code_dim)) numbering at least ``nb_code``.

        If ``x`` has fewer than ``nb_code`` rows it is repeated (ceil division)
        and perturbed with Gaussian noise of std ``0.01 / sqrt(code_dim)``;
        otherwise ``x`` is returned unchanged.
        """
        nb_code_x, code_dim = x.shape
        if nb_code_x < self.nb_code:
            n_repeats = (self.nb_code + nb_code_x - 1) // nb_code_x
            std = 0.01 / np.sqrt(code_dim)
            out = x.repeat(n_repeats, 1)
            out = out + torch.randn_like(out) * std
        else:
            out = x
        return out

    @torch.no_grad()
    def init_codebook(self, x):
        """Seed the codebook from the first ``nb_code`` rows of (tiled) ``x``.

        The values are written into the existing Parameter instead of creating
        a new one. References held elsewhere (e.g. an optimizer constructed
        before the first forward pass) therefore stay attached to the
        codebook.
        """
        out = self._tile(x)
        self.codebook.data = out[:self.nb_code].clone()
        self.code_count = torch.ones(self.nb_code, device=self.codebook.device)
        self.init = True

    @torch.no_grad()
    def compute_perplexity(self, code_idx):
        """Return exp(entropy) of the code-usage distribution in ``code_idx`` (1-D indices)."""
        code_onehot = torch.zeros(self.nb_code, code_idx.shape[0], device=code_idx.device)  # (nb_code, N*T)
        code_onehot.scatter_(0, code_idx.view(1, code_idx.shape[0]), 1)
        code_count = code_onehot.sum(dim=-1)  # (nb_code,)
        prob = code_count / torch.sum(code_count)
        perplexity = torch.exp(-torch.sum(prob * torch.log(prob + 1e-7)))
        return perplexity

    @torch.no_grad()
    def update_codebook(self, x, code_idx):
        """Reset codes unused in this batch and return the batch perplexity.

        Args:
            x: flattened inputs, shape (N*T, code_dim).
            code_idx: nearest-code index for each row of ``x``, shape (N*T,).
        """
        code_onehot = torch.zeros(self.nb_code, x.shape[0], device=x.device)  # (nb_code, N*T)
        code_onehot.scatter_(0, code_idx.view(1, x.shape[0]), 1)
        code_count = code_onehot.sum(dim=-1)  # (nb_code,)

        out = self._tile(x)
        code_rand = out[:self.nb_code]

        # Usage is based on this batch only (no running average).
        self.code_count = code_count  # (nb_code,)
        usage = (self.code_count.view(self.nb_code, 1) >= 1.0).float()
        self.codebook.data = usage * self.codebook.data + (1 - usage) * code_rand

        prob = code_count / torch.sum(code_count)
        perplexity = torch.exp(-torch.sum(prob * torch.log(prob + 1e-7)))
        return perplexity

    def preprocess(self, x):
        """Reshape (N, C, T) -> (N*T, C)."""
        x = x.permute(0, 2, 1).contiguous()
        x = x.view(-1, x.shape[-1])
        return x

    def quantize(self, x):
        """Return the index of the nearest codebook entry (squared L2) for each row of ``x``."""
        k_w = self.codebook.t()
        distance = torch.sum(x ** 2, dim=-1, keepdim=True) - 2 * torch.matmul(x, k_w) + torch.sum(k_w ** 2, dim=0,
                                                                                                     keepdim=True)  # (N*T, nb_code)
        _, code_idx = torch.min(distance, dim=-1)
        return code_idx

    def dequantize(self, code_idx):
        """Look up codebook vectors for ``code_idx``."""
        x = F.embedding(code_idx, self.codebook)
        return x

    def forward(self, x):
        """Quantize ``x`` of shape (N, code_dim, T).

        Returns ``(x_d, commit_loss, perplexity)``. ``x_d`` is the quantized
        tensor in the input layout, with straight-through gradients to ``x``.
        ``commit_loss`` is the MSE between ``x`` and its detached quantization.
        ``perplexity`` is the code-usage perplexity.
        """
        N, _, T = x.shape
        x = self.preprocess(x)  # (N*T, code_dim)

        # NOTE(review): ``init`` is not persisted, so the first training step
        # after loading a checkpoint re-seeds the codebook from that batch.
        if self.training and not self.init:
            self.init_codebook(x)

        code_idx = self.quantize(x)
        x_d = self.dequantize(code_idx)

        if self.training:
            perplexity = self.update_codebook(x, code_idx)
        else:
            perplexity = self.compute_perplexity(code_idx)

        # NOTE(review): as written, no gradient reaches ``self.codebook``.
        # The loss uses ``x_d.detach()`` and the straight-through output
        # detaches ``x_d``. Used codes therefore only change through resets.
        # Training them would need a codebook term such as
        # ``F.mse_loss(x_d, x.detach())``; it is not added here because it
        # would change the returned loss.
        commit_loss = F.mse_loss(x, x_d.detach())

        # Straight-through estimator: forward value is x_d, gradient is identity.
        x_d = x + (x_d - x).detach()

        x_d = x_d.view(N, T, -1).permute(0, 2, 1).contiguous()  # (N, code_dim, T)
        return x_d, commit_loss, perplexity
class QuantizeEMA(nn.Module):
    """Vector quantizer with an EMA-updated codebook (no dead-code reset).

    The codebook is a buffer, not a learnable parameter. During training each
    forward pass updates exponential moving averages (decay 0.99) of the
    per-code sum of assigned vectors and of the per-code assignment count,
    and sets each code to their ratio.

    Args:
        nb_code: number of codebook entries.
        code_dim: size of each entry (channel count of the (N, code_dim, T) input).
        args: accepted for interface compatibility; not read.
    """

    def __init__(self, nb_code, code_dim, args):
        super().__init__()
        self.nb_code = nb_code
        self.code_dim = code_dim
        # NOTE(review): decay is fixed here, unlike QuantizeEMAReset which reads args.mu.
        self.mu = 0.99
        self.reset_codebook()

    def reset_codebook(self):
        """Mark the codebook as uninitialised and register an all-zero codebook buffer."""
        self.init = False
        # NOTE(review): code_sum, code_count and init are plain attributes, not
        # buffers, so they are not saved in state_dict and not moved by .to().
        self.code_sum = None
        self.code_count = None
        # Use CUDA when present, otherwise CPU, so the module can also be
        # constructed on CPU-only hosts; Module.to()/.cuda() still moves it.
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.register_buffer('codebook', torch.zeros(self.nb_code, self.code_dim, device=device))

    def _tile(self, x):
        """Return rows of ``x`` (shape (M, code_dim)) numbering at least ``nb_code``.

        If ``x`` has fewer than ``nb_code`` rows it is repeated (ceil division)
        and perturbed with Gaussian noise of std ``0.01 / sqrt(code_dim)``;
        otherwise ``x`` is returned unchanged.
        """
        nb_code_x, code_dim = x.shape
        if nb_code_x < self.nb_code:
            n_repeats = (self.nb_code + nb_code_x - 1) // nb_code_x
            std = 0.01 / np.sqrt(code_dim)
            out = x.repeat(n_repeats, 1)
            out = out + torch.randn_like(out) * std
        else:
            out = x
        return out

    @torch.no_grad()
    def init_codebook(self, x):
        """Seed the codebook and EMA statistics from the first ``nb_code`` rows of (tiled) ``x``."""
        out = self._tile(x)
        self.codebook = out[:self.nb_code]
        self.code_sum = self.codebook.clone()
        self.code_count = torch.ones(self.nb_code, device=self.codebook.device)
        self.init = True

    @torch.no_grad()
    def compute_perplexity(self, code_idx):
        """Return exp(entropy) of the code-usage distribution in ``code_idx`` (1-D indices)."""
        code_onehot = torch.zeros(self.nb_code, code_idx.shape[0], device=code_idx.device)  # (nb_code, N*T)
        code_onehot.scatter_(0, code_idx.view(1, code_idx.shape[0]), 1)
        code_count = code_onehot.sum(dim=-1)  # (nb_code,)
        prob = code_count / torch.sum(code_count)
        perplexity = torch.exp(-torch.sum(prob * torch.log(prob + 1e-7)))
        return perplexity

    @torch.no_grad()
    def update_codebook(self, x, code_idx):
        """Apply one EMA step to the codebook and return the batch perplexity.

        Runs without autograd. The running statistics are built from ``x`` and
        from their own previous values, so tracking them would chain one graph
        through every training step and keep its tensors alive indefinitely.

        Args:
            x: flattened inputs, shape (N*T, code_dim).
            code_idx: nearest-code index for each row of ``x``, shape (N*T,).
        """
        code_onehot = torch.zeros(self.nb_code, x.shape[0], device=x.device)  # (nb_code, N*T)
        code_onehot.scatter_(0, code_idx.view(1, x.shape[0]), 1)

        code_sum = torch.matmul(code_onehot, x)  # (nb_code, code_dim)
        code_count = code_onehot.sum(dim=-1)  # (nb_code,)

        # EMA update of the per-code statistics; each code becomes sum / count.
        self.code_sum = self.mu * self.code_sum + (1. - self.mu) * code_sum  # (nb_code, code_dim)
        self.code_count = self.mu * self.code_count + (1. - self.mu) * code_count  # (nb_code,)
        code_update = self.code_sum.view(self.nb_code, self.code_dim) / self.code_count.view(self.nb_code, 1)
        self.codebook = code_update

        prob = code_count / torch.sum(code_count)
        perplexity = torch.exp(-torch.sum(prob * torch.log(prob + 1e-7)))
        return perplexity

    def preprocess(self, x):
        """Reshape (N, C, T) -> (N*T, C)."""
        x = x.permute(0, 2, 1).contiguous()
        x = x.view(-1, x.shape[-1])
        return x

    def quantize(self, x):
        """Return the index of the nearest codebook entry (squared L2) for each row of ``x``."""
        k_w = self.codebook.t()
        distance = torch.sum(x ** 2, dim=-1, keepdim=True) - 2 * torch.matmul(x, k_w) + torch.sum(k_w ** 2, dim=0,
                                                                                                     keepdim=True)  # (N*T, nb_code)
        _, code_idx = torch.min(distance, dim=-1)
        return code_idx

    def dequantize(self, code_idx):
        """Look up codebook vectors for ``code_idx``."""
        x = F.embedding(code_idx, self.codebook)
        return x

    def forward(self, x):
        """Quantize ``x`` of shape (N, code_dim, T).

        Returns ``(x_d, commit_loss, perplexity)``. ``x_d`` is the quantized
        tensor in the input layout, with straight-through gradients to ``x``.
        ``commit_loss`` is the MSE between ``x`` and its detached quantization.
        ``perplexity`` is the code-usage perplexity. In training mode the
        codebook is seeded on the first call and EMA-updated after
        quantization.
        """
        N, _, T = x.shape
        x = self.preprocess(x)  # (N*T, code_dim)

        # NOTE(review): ``init`` is not persisted, so the first training step
        # after loading a checkpoint re-seeds the codebook from that batch.
        if self.training and not self.init:
            self.init_codebook(x)

        code_idx = self.quantize(x)
        x_d = self.dequantize(code_idx)

        if self.training:
            perplexity = self.update_codebook(x, code_idx)
        else:
            perplexity = self.compute_perplexity(code_idx)

        commit_loss = F.mse_loss(x, x_d.detach())

        # Straight-through estimator: forward value is x_d, gradient is identity.
        x_d = x + (x_d - x).detach()

        x_d = x_d.view(N, T, -1).permute(0, 2, 1).contiguous()  # (N, code_dim, T)
        return x_d, commit_loss, perplexity