# BraTS 2025 Brain Inpainting Model
This is an improved 3D U-Net model trained for brain tissue inpainting as part of the BraTS 2025 challenge. The model uses attention mechanisms and residual connections to reconstruct healthy brain tissue in regions affected by pathology.
## Model Description
- Architecture: 3D U-Net with Attention mechanisms and Residual blocks
- Task: Brain tissue inpainting/reconstruction
- Dataset: BraTS 2025 training dataset
- Framework: PyTorch Lightning
## Architecture Features
- 3D U-Net with Attention: Enhanced U-Net architecture with 3D attention blocks that refine the skip connections
- Residual Blocks: Residual connections for improved gradient flow and training stability
- Multi-scale Processing: Progressive downsampling and upsampling for capturing features at different scales
- GAN Training: Adversarial training with Pix2Pix framework for realistic tissue generation
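The actual blocks live in `improved_model.py`; as a rough illustration of the first two ideas above, a minimal 3D residual block and an attention gate for skip connections might look like the sketch below. Layer choices (instance norm, ReLU, kernel sizes) are assumptions, not the trained model's exact configuration.

```python
import torch
import torch.nn as nn

class ResidualBlock3D(nn.Module):
    """3D conv block with a skip connection for stable gradient flow."""
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv3d(in_ch, out_ch, 3, padding=1),
            nn.InstanceNorm3d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv3d(out_ch, out_ch, 3, padding=1),
            nn.InstanceNorm3d(out_ch),
        )
        # 1x1x1 conv matches channel counts on the residual path
        self.skip = nn.Conv3d(in_ch, out_ch, 1) if in_ch != out_ch else nn.Identity()

    def forward(self, x):
        return torch.relu(self.conv(x) + self.skip(x))

class AttentionGate3D(nn.Module):
    """Attention gate on a U-Net skip connection: the decoder feature `g`
    (assumed already upsampled to `x`'s resolution) gates the encoder
    feature `x`, suppressing irrelevant activations."""
    def __init__(self, g_ch, x_ch, inter_ch):
        super().__init__()
        self.wg = nn.Conv3d(g_ch, inter_ch, 1)
        self.wx = nn.Conv3d(x_ch, inter_ch, 1)
        self.psi = nn.Conv3d(inter_ch, 1, 1)

    def forward(self, g, x):
        attn = torch.sigmoid(self.psi(torch.relu(self.wg(g) + self.wx(x))))
        return x * attn
```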
## Usage

```python
import torch
from improved_model import ImprovedPix2Pix3D

# Load checkpoint
checkpoint = torch.load("last.ckpt", map_location="cpu")
model = ImprovedPix2Pix3D(**checkpoint["hyper_parameters"])
model.load_state_dict(checkpoint["state_dict"])
model.eval()

# Inference: `voided_image` is the scan with the pathology region
# voided out; `mask` marks the voxels to inpaint (both 5D tensors:
# batch x channel x depth x height x width)
with torch.no_grad():
    output = model(voided_image, mask)
```
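A common convention in inpainting is to composite the prediction back into the input so that only the voided region is replaced; whether `ImprovedPix2Pix3D` does this internally is not specified here, so the following is a hedged sketch with assumed shapes and a random stand-in for the model output.

```python
import torch

# Assumed shapes: batch x channel x depth x height x width
voided_image = torch.rand(1, 1, 64, 64, 64)
mask = torch.zeros(1, 1, 64, 64, 64)
mask[..., 20:40, 20:40, 20:40] = 1.0   # region to inpaint

output = torch.rand(1, 1, 64, 64, 64)  # stand-in for model(voided_image, mask)

# Keep original tissue outside the mask, model prediction inside it
inpainted = voided_image * (1 - mask) + output * mask
```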
## Training Details
- Loss Functions: L1 loss + SSIM loss + Adversarial loss
- Optimization: Adam optimizer with learning rate scheduling
- Regularization: Dropout and gradient clipping
- Data Augmentation: Random flips, rotations, and intensity variations
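As an illustration of one augmentation from the list above, random flips must be applied with the same decision to the volume and its mask so the pair stays aligned. This is a sketch, not the training pipeline's exact code.

```python
import torch

def random_flip3d(image, mask, p=0.5):
    """Flip a paired volume and mask along each spatial axis with
    probability p, using the same draw for both so they stay aligned."""
    for axis in (-3, -2, -1):  # depth, height, width
        if torch.rand(()) < p:
            image = torch.flip(image, dims=(axis,))
            mask = torch.flip(mask, dims=(axis,))
    return image, mask
```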
## Performance
The model was trained to minimize:
- L1 Loss: Mean Absolute Error between predicted and ground truth
- SSIM Loss: Structural Similarity Index for perceptual quality
- Adversarial Loss: GAN loss for realistic tissue generation
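The combined objective can be sketched as a weighted sum of the three terms. The weights below and the simplified uniform-window SSIM are illustrative assumptions, not the values or implementation used in training.

```python
import torch
import torch.nn.functional as F

def ssim3d(x, y, c1=0.01**2, c2=0.03**2, win=7):
    """Simplified 3D SSIM using uniform (average-pool) local windows."""
    pad = win // 2
    mu_x = F.avg_pool3d(x, win, 1, pad)
    mu_y = F.avg_pool3d(y, win, 1, pad)
    sigma_x = F.avg_pool3d(x * x, win, 1, pad) - mu_x**2
    sigma_y = F.avg_pool3d(y * y, win, 1, pad) - mu_y**2
    sigma_xy = F.avg_pool3d(x * y, win, 1, pad) - mu_x * mu_y
    ssim = ((2 * mu_x * mu_y + c1) * (2 * sigma_xy + c2)) / (
        (mu_x**2 + mu_y**2 + c1) * (sigma_x + sigma_y + c2))
    return ssim.mean()

def generator_loss(pred, target, disc_logits, w_l1=100.0, w_ssim=1.0, w_adv=1.0):
    """Weighted sum of L1, SSIM, and adversarial terms.
    Weights are illustrative, not the trained model's values."""
    l1 = F.l1_loss(pred, target)
    ssim_loss = 1.0 - ssim3d(pred, target)
    # Generator wants the discriminator to label its output "real" (1)
    adv = F.binary_cross_entropy_with_logits(
        disc_logits, torch.ones_like(disc_logits))
    return w_l1 * l1 + w_ssim * ssim_loss + w_adv * adv
```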
## Files

- `last.ckpt`: Complete model checkpoint with weights and hyperparameters
- `improved_model.py`: Model architecture code
## License
Apache 2.0