ik commited on
Commit
2ce6f72
·
verified ·
1 Parent(s): 4d1609b

Add rvq_wrapper.py

Browse files
Files changed (1) hide show
  1. rvq_wrapper.py +14 -0
rvq_wrapper.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch, torch.nn as nn
2
+ from vector_quantize_pytorch import ResidualVQ
3
+ class RVQWrapper(nn.Module):
4
+ def __init__(self, dim, num_quantizers, codebook_size):
5
+ super().__init__()
6
+ self.proj_in = nn.Linear(dim, dim, bias=True)
7
+ self.rvq = ResidualVQ(dim=dim, num_quantizers=num_quantizers, codebook_size=codebook_size)
8
+ self.proj_out = nn.Linear(dim, dim, bias=True)
9
+ self.register_buffer('ema_counts', torch.zeros(num_quantizers, codebook_size))
10
+ def forward(self, x):
11
+ x = self.proj_in(x)
12
+ y, indices, commit = self.rvq(x)
13
+ y = self.proj_out(y)
14
+ return y, indices, commit