VAE - UNet-Style Autoencoder for 256x256 Image Reconstruction

This model is a UNet-style Variational Autoencoder (VAE) trained on the CC3M dataset for high-quality image reconstruction and generation. It integrates adversarial, perceptual, and identity-preserving loss terms to improve semantic and visual fidelity.

Architecture

  • Encoder/Decoder: Multi-scale UNet architecture
  • Latent Space: 8-channel latent bottleneck with reparameterization (mu, logvar)
  • Losses:
    • L1 reconstruction loss
    • KL divergence with annealing
    • LPIPS perceptual loss (VGG backbone)
    • Identity loss via MoCo-v2 embeddings
    • Adversarial loss via Patch Discriminator w/ Spectral Norm

\mathcal{L}_{total} = \mathcal{L}_{recon} + \mathcal{L}_{LPIPS} + 0.5 \cdot \mathcal{L}_{GAN} + 0.1 \cdot \mathcal{L}_{ID} + 10^{-6} \cdot \mathcal{L}_{KL}
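
As a concrete illustration, a minimal sketch of the reparameterized latent and this weighted objective is shown below. The helper names and the precomputed lpips_term / gan_term / id_term inputs are hypothetical placeholders; only the L1 and KL terms are computed in full:

import torch
import torch.nn.functional as F

def reparameterize(mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
    # Reparameterization trick: z = mu + sigma * eps, with eps ~ N(0, I).
    std = torch.exp(0.5 * logvar)
    return mu + std * torch.randn_like(std)

def kl_divergence(mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
    # Closed-form KL(N(mu, sigma^2) || N(0, 1)), averaged over the batch.
    return -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())

def total_loss(recon, target, mu, logvar, lpips_term, gan_term, id_term):
    # lpips_term, gan_term, and id_term are assumed to be scalar losses
    # produced by the LPIPS network, the patch discriminator, and the
    # MoCo-v2 identity embeddings, respectively.
    l_recon = F.l1_loss(recon, target)
    return (l_recon + lpips_term + 0.5 * gan_term
            + 0.1 * id_term + 1e-6 * kl_divergence(mu, logvar))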

Reconstructions

[Image grid: input images alongside their VAE reconstructions.]

Training Config

  • Dataset: CC3M (850k images)
  • Image resolution: 256 x 256
  • Batch size: 16
  • Optimizer: AdamW
  • Learning rate: 5e-5
  • Precision: bf16 (mixed precision)
  • Total steps: 210,000
  • GAN loss start step: 50,000
  • KL annealing: yes, over the first 10% of training (sketched below)
  • Augmentations: crop, flip, jitter, blur, rotation
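
KL annealing over the first 10% of training could be implemented as a simple weight ramp. A minimal sketch, assuming a linear warm-up (the exact ramp shape is not specified in this card):

def kl_weight(step: int, total_steps: int = 210_000, final_weight: float = 1e-6) -> float:
    # Hypothetical linear ramp of the KL weight over the first 10% of steps.
    warmup_steps = max(int(0.1 * total_steps), 1)
    return final_weight * min(step / warmup_steps, 1.0)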

Training used a cosine learning-rate schedule with gradient clipping and automatic mixed precision (torch.cuda.amp).
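
A sketch of what a single optimization step could look like under this recipe; vae, loader, and compute_losses are placeholder names, and the clipping threshold of 1.0 is an assumption:

import torch

opt = torch.optim.AdamW(vae.parameters(), lr=5e-5)
sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=210_000)

for step, images in enumerate(loader):
    images = images.cuda(non_blocking=True)
    # bf16 autocast; with bf16, a GradScaler is generally unnecessary.
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        recon, mu, logvar = vae(images)
        loss = compute_losses(recon, images, mu, logvar)  # placeholder
    opt.zero_grad(set_to_none=True)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(vae.parameters(), max_norm=1.0)  # assumed threshold
    opt.step()
    sched.step()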

Usage Example

import torch
from transfusion.modeling.vae.vae import VAE
from transfusion.config.model import VAEConfig

# Build the model in inference mode.
config = VAEConfig(...)
vae = VAE(config, is_training=False)

# Load the released checkpoint; strict=False tolerates keys that differ
# between the training and inference builds of the model.
ckpt = torch.load("vae_final_model.pt", map_location="cpu")
vae.load_state_dict(ckpt["vae_state_dict"], strict=False)
vae.eval()

# input_tensor: a (batch, 3, 256, 256) image batch.
with torch.no_grad():
    output, _, _ = vae(input_tensor)  # reconstruction, plus latent statistics