# VQ-VAE Preview
The decoder reconstructs a preview image from a predicted series of tokens (integer codebook indices). Compatible with the Gemma/Llama/Qwen models and the Trainer class; a sketch of decoding LM-predicted tokens directly follows the inference script below.
## Inference

```python
import torch
from diffusers import VQModel
from PIL import Image
from tea_model import AdaLayerNorm
from torchvision import transforms

def image_to_indices(image, vq):
    normalize = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5])
    ])
    assert image.width == image.height
    norm = normalize(image).unsqueeze(0).to(vq.device)
    # Encode to latents, then find the nearest codebook entry
    # for each latent position.
    latent = vq.encoder(norm)
    y = vq.quant_conv(latent)
    y = y.permute(0, 2, 3, 1).contiguous()
    indices = torch.argmin(
        torch.cdist(y.view(-1, vq.quantize.vq_embed_dim), vq.quantize.embedding.weight),
        dim=1
    )
    indices = indices.unsqueeze(0)
    assert indices.shape[0] == norm.shape[0]
    return indices, norm
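
# Note on an assumed diffusers internal: the cdist/argmin lookup in
# image_to_indices mirrors the nearest-codebook search that the
# VectorQuantizer performs in its own forward pass; calling vq.quantize
# directly would also work, but it returns quantized latents and a
# codebook loss that are not needed here.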

def indices_to_tensor(indices, vq, ada: AdaLayerNorm):
    b = indices.shape[0]
    # The token sequence is assumed to be a flattened square latent grid.
    w = int(indices.shape[-1] ** 0.5)
    # Look up the codebook vectors, then apply the trained AdaLayerNorm.
    v = vq.quantize.embedding(indices)
    v = ada(v)
    v = v.view((b, w, w, vq.config.latent_channels))
    v = v.permute(0, 3, 1, 2).contiguous()
    x = vq.post_quant_conv(v)
    # The decoder's second argument is its optional latent embedding
    # input, unused here.
    y = vq.decoder(x, None)
    return y
if __name__ == '__main__':
    vq = VQModel.from_pretrained('MeissonFlow/Meissonic', subfolder='vqvae')
    vq.to('cuda')
    vq.eval()
    # NOTE: `ada` must be an AdaLayerNorm instance loaded with this
    # repository's trained weights (see tea_model); its construction
    # is omitted here.
    image = Image.open('path/to/image.png').convert('RGB')
    with torch.no_grad():
        # Encode image.
        ids, _ = image_to_indices(image, vq)
        # Decode by integer ids, then map pixels from [-1, 1] to [0, 1].
        pixel_values = indices_to_tensor(ids, vq, ada) / 2.0 + 0.5
    preview = transforms.functional.to_pil_image(
        pixel_values.squeeze(0).clamp(0, 1).to('cpu')
    )
```
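
To decode tokens predicted by a language model directly, the decode path can be called on its own. Below is a minimal sketch reusing `indices_to_tensor`, `vq`, and `ada` from the script above; the sequence length and codebook size are illustrative placeholders, not values taken from this repository.

```python
import torch

# Hypothetical ids standing in for a language model's predictions:
# one codebook index per latent position, flattened row-major, so a
# 32x32 latent grid corresponds to 1024 tokens.
token_ids = torch.randint(0, 8192, (1, 32 * 32), device='cuda')

with torch.no_grad():
    pixels = indices_to_tensor(token_ids, vq, ada) / 2.0 + 0.5
preview = transforms.functional.to_pil_image(pixels.squeeze(0).clamp(0, 1).to('cpu'))
preview.save('preview.png')  # hypothetical output path
```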
## Datasets
- Chars/pixiv_rank_daily_2018_2023
- jordandavis/fashion_num_people
- mattmdjaga/human_parsing_dataset
## Base model

MeissonFlow/Meissonic