# U-Net for Image Inpainting on CIFAR-10
This repository contains a PyTorch implementation of a deep U-Net with Residual Blocks, trained to perform image inpainting on the CIFAR-10 dataset. The model takes an image with a masked (blacked-out) region and reconstructs the missing part.
## Model Description
The model is a `ComplexUNet` architecture, a variant of the standard U-Net. It features:

- Deeper Architecture: 4 downsampling and 4 upsampling stages.
- Residual Blocks: Each stage uses residual blocks instead of simple convolutional layers (an illustrative sketch follows this list).
- Increased Width: The model was trained with `base_channels=96`.
- Total Parameters: 73,148,259
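The exact layer definitions live in `model.py` in this repository. Purely as an illustration of the residual-block idea (the class name and layer choices below are assumptions, not the repository's code), a block of this kind replaces the plain double convolution of a standard U-Net stage:

```python
import torch.nn as nn

class ResidualBlock(nn.Module):
    """Illustrative residual block: two 3x3 convolutions plus a skip connection."""
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv1 = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(out_ch)
        self.conv2 = nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_ch)
        # 1x1 projection so the skip path matches the output channel count
        self.skip = nn.Conv2d(in_ch, out_ch, kernel_size=1) if in_ch != out_ch else nn.Identity()
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return self.relu(out + self.skip(x))
```

In the full `ComplexUNet`, blocks of this kind would sit at each of the 4 encoder and 4 decoder stages, with channel widths starting from `base_channels=96`.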
## How to Use
First, install the required libraries:
```bash
pip install torch torchvision numpy Pillow
```
Then, you can load the model and perform inpainting on an image tensor.
```python
import torch
from torchvision import transforms as T
from torchvision.transforms.functional import to_pil_image
from PIL import Image

from model import ComplexUNet  # Import the class from model.py

# --- Setup ---
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Download the .pth file from the 'Files and versions' tab of this repo
MODEL_PATH = "inpainting_model_larger.pth"

# --- Load Model ---
model = ComplexUNet(base_channels=96)
model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
model.to(DEVICE)
model.eval()

# --- Load and Preprocess Image ---
# image = Image.open("your_image.png").convert("RGB")
# For demonstration, let's create a dummy tensor
transform = T.Compose([T.Resize((32, 32)), T.ToTensor()])
# image_tensor = transform(image)
image_tensor = torch.rand(3, 32, 32)

# --- Create a Mask ---
masked_tensor = image_tensor.clone()
masked_tensor[:, 8:24, 8:24] = 0  # Example mask: black out the central 16x16 region

# --- Perform Inpainting ---
with torch.no_grad():
    input_tensor = masked_tensor.unsqueeze(0).to(DEVICE)         # add a batch dimension
    reconstructed_tensor = model(input_tensor).squeeze(0).cpu()  # back to (3, 32, 32) on CPU

# 'reconstructed_tensor' now holds the inpainted image.
reconstructed_image = to_pil_image(reconstructed_tensor)
reconstructed_image.save("reconstructed_image.png")
print("Saved reconstructed_image.png")
```
## Training Data
The model was trained on the CIFAR-10 dataset.
- Preprocessing: Images were used at their original 32x32 pixel resolution.
- Augmentation: For each training image, a random rectangular mask was applied (a sketch of such a mask generator follows below).
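The exact mask sizes and positions used during training are not documented here; the helper below is a hypothetical sketch of how a random rectangular mask could be generated (the function name and the size range are assumptions):

```python
import torch

def apply_random_mask(image, min_size=8, max_size=16):
    """Black out a random rectangle in a (C, H, W) tensor. Size range is an assumption."""
    _, h, w = image.shape
    mh = torch.randint(min_size, max_size + 1, (1,)).item()
    mw = torch.randint(min_size, max_size + 1, (1,)).item()
    top = torch.randint(0, h - mh + 1, (1,)).item()
    left = torch.randint(0, w - mw + 1, (1,)).item()
    masked = image.clone()
    masked[:, top:top + mh, left:left + mw] = 0
    return masked  # the unmodified `image` serves as the reconstruction target
```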
## Training Procedure
- Framework: PyTorch
- Optimizer: Adam
- Learning Rate: 0.001
- Epochs: 50
- Batch Size: 128
- Loss Function: Mean Squared Error (MSE); a representative training loop is sketched below.
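The actual training script is not reproduced here; the loop below is a sketch of how the hyperparameters above could fit together, reusing the hypothetical `apply_random_mask` helper from the Training Data section:

```python
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms as T

from model import ComplexUNet  # architecture class from this repo

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# CIFAR-10 at its native 32x32 resolution; class labels are not needed
train_set = datasets.CIFAR10(root="./data", train=True, download=True,
                             transform=T.ToTensor())
loader = DataLoader(train_set, batch_size=128, shuffle=True)

model = ComplexUNet(base_channels=96).to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = torch.nn.MSELoss()

for epoch in range(50):
    for images, _ in loader:
        images = images.to(DEVICE)
        # mask each image on the fly; the clean image is the regression target
        masked = torch.stack([apply_random_mask(img) for img in images])
        optimizer.zero_grad()
        loss = criterion(model(masked), images)
        loss.backward()
        optimizer.step()
    print(f"epoch {epoch + 1}: loss {loss.item():.4f}")
```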
## Evaluation
Evaluation metrics were not saved by the training script. To get PSNR and SSIM, run the `evaluate_model` function from the training script.
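If you would rather compute the metrics directly, the snippet below is a minimal example (not the repository's `evaluate_model`) that assumes `image_tensor` and `reconstructed_tensor` from the usage example above and requires `scikit-image` for SSIM:

```python
import torch
from skimage.metrics import structural_similarity as ssim

def psnr(pred, target, max_val=1.0):
    """Peak signal-to-noise ratio for tensors scaled to [0, 1]."""
    mse = torch.mean((pred - target) ** 2)
    return 10 * torch.log10(max_val ** 2 / mse)

print("PSNR:", psnr(reconstructed_tensor, image_tensor).item())
print("SSIM:", ssim(image_tensor.permute(1, 2, 0).numpy(),
                    reconstructed_tensor.permute(1, 2, 0).numpy(),
                    channel_axis=2, data_range=1.0))
```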