Add rvq_wrapper.py
Browse files- rvq_wrapper.py +14 -0
rvq_wrapper.py
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch, torch.nn as nn
|
2 |
+
from vector_quantize_pytorch import ResidualVQ
|
3 |
+
class RVQWrapper(nn.Module):
|
4 |
+
def __init__(self, dim, num_quantizers, codebook_size):
|
5 |
+
super().__init__()
|
6 |
+
self.proj_in = nn.Linear(dim, dim, bias=True)
|
7 |
+
self.rvq = ResidualVQ(dim=dim, num_quantizers=num_quantizers, codebook_size=codebook_size)
|
8 |
+
self.proj_out = nn.Linear(dim, dim, bias=True)
|
9 |
+
self.register_buffer('ema_counts', torch.zeros(num_quantizers, codebook_size))
|
10 |
+
def forward(self, x):
|
11 |
+
x = self.proj_in(x)
|
12 |
+
y, indices, commit = self.rvq(x)
|
13 |
+
y = self.proj_out(y)
|
14 |
+
return y, indices, commit
|