U-Net for Image Inpainting on CIFAR-10

This repository contains a PyTorch implementation of a deep U-Net with residual blocks, trained to perform image inpainting on the CIFAR-10 dataset. The model takes an image with a masked (blacked-out) region and reconstructs the missing part.

Model Description

The model, ComplexUNet, is a variant of the standard U-Net with the following features:

  • Deeper Architecture: 4 downsampling and 4 upsampling stages.
  • Residual Blocks: Each stage uses residual blocks instead of simple convolutional layers.
  • Increased Width: The model was trained with base_channels=96.
  • Total Parameters: 73,148,259
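The actual block definition lives in model.py and is not reproduced in this card. As a rough illustration of what "residual block" means here, the following is a hypothetical sketch of a common design: two 3x3 convolutions with BatchNorm and ReLU, plus a 1x1 projection on the skip path when the channel count changes. The class name `ResidualBlock` and these layer choices are assumptions, not the repository's implementation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ResidualBlock(nn.Module):
    """Hypothetical residual block: two 3x3 convs plus a skip connection.

    Illustrative sketch only; the block in model.py may differ
    (normalization, activation, or how channel changes are handled).
    """

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        # A 1x1 projection lets the skip path match the output channel count.
        self.skip = (
            nn.Identity()
            if in_channels == out_channels
            else nn.Conv2d(in_channels, out_channels, kernel_size=1)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        # Add the (possibly projected) input back before the final activation.
        return F.relu(out + self.skip(x))


if __name__ == "__main__":
    block = ResidualBlock(in_channels=3, out_channels=96)
    x = torch.rand(2, 3, 32, 32)
    print(block(x).shape)  # torch.Size([2, 96, 32, 32])
```

Because padding=1 keeps the spatial size unchanged, the skip connection can be added element-wise; the down- and upsampling between stages would be handled by separate layers.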

How to Use

First, install the required libraries:

pip install torch torchvision numpy Pillow

Then load the model and run inpainting on an image tensor. The code assumes that model.py, which defines the ComplexUNet class, is in your working directory.

import torch
from torchvision import transforms as T
from PIL import Image
from model import ComplexUNet # Import the class from model.py

# --- Setup ---
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Download the .pth file from the 'Files and versions' tab of this repo
MODEL_PATH = "inpainting_model_larger.pth"

# --- Load Model ---
model = ComplexUNet(base_channels=96)
model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
model.to(DEVICE)
model.eval()

# --- Load and Preprocess Image ---
transform = T.Compose([T.Resize((32, 32)), T.ToTensor()])  # ToTensor scales pixels to [0, 1]
# To use your own image, uncomment these two lines:
# image = Image.open("your_image.png").convert("RGB")
# image_tensor = transform(image)
# For demonstration, use a random (3, 32, 32) tensor instead
image_tensor = torch.rand(3, 32, 32)

# --- Create a Mask ---
masked_tensor = image_tensor.clone()
masked_tensor[:, 8:24, 8:24] = 0 # Example mask in the center

# --- Perform Inpainting ---
with torch.no_grad():
    input_tensor = masked_tensor.unsqueeze(0).to(DEVICE)
    # Clamp to [0, 1]: to_pil_image expects float inputs in this range
    reconstructed_tensor = model(input_tensor).squeeze(0).clamp(0, 1).cpu()

# 'reconstructed_tensor' now holds the inpainted image.
from torchvision.transforms.functional import to_pil_image
reconstructed_image = to_pil_image(reconstructed_tensor)
reconstructed_image.save("reconstructed_image.png")
print("Saved reconstructed_image.png")
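The model regenerates every pixel, including the ones that were never masked, so its output can slightly alter known regions. A common post-processing step, which is not part of the original script, is to paste the known pixels back and keep the model's prediction only inside the hole. The sketch below assumes a boolean mask that is True where pixels were removed; `composite` is a hypothetical helper name, and a random tensor stands in for the model output.

```python
import torch


def composite(original: torch.Tensor, reconstructed: torch.Tensor,
              mask: torch.Tensor) -> torch.Tensor:
    """Take masked pixels from `reconstructed` and known pixels from `original`.

    `mask` is a bool tensor, True where pixels were removed. A (1, H, W)
    mask broadcasts over the channel dimension of (C, H, W) images.
    """
    return torch.where(mask, reconstructed, original)


# Same central 16x16 hole as the usage example (True = removed pixel).
mask = torch.zeros(1, 32, 32, dtype=torch.bool)
mask[:, 8:24, 8:24] = True

original = torch.rand(3, 32, 32)
reconstructed = torch.rand(3, 32, 32)  # stand-in for the model output
result = composite(original, reconstructed, mask)
```

In the usage script, `original` would be `image_tensor` and `reconstructed` would be `reconstructed_tensor`.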

Training Data

The model was trained on the CIFAR-10 dataset.

  • Preprocessing: Images were used at their native 32x32 resolution.
  • Augmentation: A random rectangular mask was applied to each training image to create the masked input.
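The card does not document the mask size or placement rules. The following is a minimal sketch of random rectangular masking under assumed settings: side lengths drawn uniformly from 8 to 16 pixels and a uniformly random position. `random_rect_mask`, `apply_mask`, `min_size`, and `max_size` are hypothetical names, not taken from the training script.

```python
import torch


def random_rect_mask(height: int = 32, width: int = 32, min_size: int = 8,
                     max_size: int = 16, generator=None) -> torch.Tensor:
    """Return a (1, H, W) float mask: 1 inside one random rectangle, 0 elsewhere.

    The size range is a hypothetical choice; the card does not specify it.
    """
    # torch.randint's upper bound is exclusive, hence the +1.
    h = int(torch.randint(min_size, max_size + 1, (1,), generator=generator))
    w = int(torch.randint(min_size, max_size + 1, (1,), generator=generator))
    top = int(torch.randint(0, height - h + 1, (1,), generator=generator))
    left = int(torch.randint(0, width - w + 1, (1,), generator=generator))
    mask = torch.zeros(1, height, width)
    mask[:, top:top + h, left:left + w] = 1.0
    return mask


def apply_mask(image: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Black out the masked region by setting its pixels to 0."""
    return image * (1 - mask)


g = torch.Generator().manual_seed(0)
image = torch.rand(3, 32, 32)
mask = random_rect_mask(generator=g)
masked = apply_mask(image, mask)
```

Drawing a fresh mask every time an image is loaded means the model sees a different hole for the same image across epochs.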

Training Procedure

  • Framework: PyTorch
  • Optimizer: Adam
  • Learning Rate: 0.001
  • Epochs: 50
  • Batch Size: 128
  • Loss Function: Mean Squared Error (MSE)
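To show how these hyperparameters fit together, here is a minimal sketch of one training step. It assumes the MSE is computed over the full image between the model output and the unmasked original; the card does not say whether the loss is restricted to the masked region. A tiny stand-in network replaces ComplexUNet so the example runs quickly, and `train_step` is a hypothetical name.

```python
import torch
import torch.nn as nn

# Stand-in network for illustration only; the real run used ComplexUNet(base_channels=96).
model = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(16, 3, kernel_size=3, padding=1),
)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)  # Learning Rate: 0.001
criterion = nn.MSELoss()  # mean over all pixels (assumption)


def train_step(clean: torch.Tensor, masked: torch.Tensor) -> float:
    """One optimization step: predict the clean image from the masked one."""
    model.train()
    optimizer.zero_grad()
    loss = criterion(model(masked), clean)
    loss.backward()
    optimizer.step()
    return loss.item()


# A random batch of 128 images stands in for one CIFAR-10 DataLoader batch.
clean = torch.rand(128, 3, 32, 32)
masked = clean.clone()
masked[:, :, 8:24, 8:24] = 0  # fixed central hole here; training used random masks
loss = train_step(clean, masked)
```

A full run would call `train_step` for every batch of a DataLoader with batch_size=128, repeated for 50 epochs.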

Evaluation

Evaluation metrics were not saved by the training script. To compute PSNR and SSIM, run the evaluate_model function from the training script.
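The training script's evaluate_model function is not included here. For reference, PSNR follows the standard definition PSNR = 10 log10(MAX^2 / MSE), with MAX = 1 for images scaled to [0, 1] by T.ToTensor(). The helper below is a hypothetical sketch of that formula; whether evaluate_model computes it the same way (for example, per image or over the whole test set) is unknown. SSIM is more involved; scikit-image's `skimage.metrics.structural_similarity` is one existing implementation.

```python
import math

import torch


def psnr(pred: torch.Tensor, target: torch.Tensor, max_val: float = 1.0) -> float:
    """Peak signal-to-noise ratio in dB for images scaled to [0, max_val]."""
    mse = torch.mean((pred - target) ** 2).item()
    if mse == 0.0:
        return math.inf  # identical images
    return 10.0 * math.log10(max_val ** 2 / mse)


# A uniform error of 0.1 gives MSE = 0.01, i.e. 10 * log10(1 / 0.01) = 20 dB.
target = torch.zeros(3, 32, 32)
pred = torch.full((3, 32, 32), 0.1)
print(round(psnr(pred, target), 2))  # 20.0
```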


Dataset used to train tahamajs/inpainting_transformer_weights