Spaces:
Configuration error
Configuration error
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from gradient_flow_ops import ReplaceGrad | |
# Module-level shorthand so ReplaceGrad.apply can be called like a plain
# function; vector_quantize below calls it as replace_grad(x_q, x).
# NOTE(review): ReplaceGrad comes from gradient_flow_ops and is not visible
# here; presumably a torch.autograd.Function — confirm.
| replace_grad = ReplaceGrad.apply | |
def vector_quantize(x, codebook):
    """Snap each vector in ``x`` to its nearest row of ``codebook``.

    Nearness is squared Euclidean distance, expanded as
    ``|x|^2 + |c|^2 - 2 * x.c`` so that every pairwise distance comes out of
    a single matrix product.

    Args:
        x: Tensor whose last dimension holds the vectors to quantize; that
            dimension must match the second dimension of ``codebook``.
        codebook: Tensor with one code vector per row (expected to be 2-D).

    Returns:
        The value of ``replace_grad(quantized, x)``, where ``quantized``
        holds the selected codebook row for every vector in ``x``.
    """
    # The three terms of the expanded squared distance. The first keeps a
    # trailing axis of size 1 so it broadcasts against the per-code terms.
    x_sq_norm = x.pow(2).sum(dim=-1, keepdim=True)
    code_sq_norm = codebook.pow(2).sum(dim=1)
    cross_term = (2 * x) @ codebook.T
    sq_dist = x_sq_norm + code_sq_norm - cross_term

    # Index of the closest code for every input vector.
    nearest = sq_dist.argmin(dim=-1)

    # Look the codes up with a one-hot selector times the codebook.
    num_codes = codebook.shape[0]
    selector = F.one_hot(nearest, num_codes).to(sq_dist.dtype)
    quantized = selector @ codebook

    # NOTE(review): replace_grad is defined outside this function; presumably
    # it returns `quantized` in the forward pass while routing gradients to
    # `x` (straight-through) — confirm against gradient_flow_ops.
    return replace_grad(quantized, x)