GAIA-v1 / mae_dino /lightning_presetup.py
willbender's picture
refactor: move base config to project root for model counting
a31a170
import torch
import yaml
import os
import lightning.pytorch as pl
from mae_dino.lightning_wrapper import GAIABaseLightning
def GAIABase(config_path=None,
             img_size=None,
             patch_size=None,
             in_chans=None,
             epochs=None,
             adjacent_masking=None,
             mask_ratio=None,
             val_step=None):
    """Build a ``GAIABaseLightning`` model from a YAML config plus optional overrides.

    The YAML file is expected to contain ``model`` and ``training`` sections;
    any keyword argument passed explicitly (including falsy values such as
    ``0`` or ``False``) takes precedence over the corresponding config entry.

    Args:
        config_path: Path to the YAML configuration file. Defaults to
            ``'../config.yaml'`` when not provided (path is relative to the
            current working directory — NOTE(review): confirm this is the
            intended resolution for all entry points).
        img_size: Override for ``model.img_size``.
        patch_size: Override for ``model.patch_size``.
        in_chans: Override for ``model.in_chans``.
        epochs: Override for ``training.epochs``.
        adjacent_masking: Override for ``model.adjacent_masking``.
        mask_ratio: Override for ``model.mask_ratio``.
        val_step: Override for ``training.val_step``.

    Returns:
        A constructed ``GAIABaseLightning`` module. If the config file cannot
        be read, an error is printed and construction proceeds with empty
        config sections (all unspecified values become ``None``).
    """
    # BUGFIX: the old code unconditionally overwrote the argument with the
    # hard-coded path, so a caller-supplied config_path was always ignored.
    # Fall back to the default location only when no path was given.
    if config_path is None:
        config_path = '../config.yaml'

    params = {}
    config = {}
    training_config = {}
    try:
        with open(config_path, 'r') as f:
            # safe_load returns None for an empty file; normalize to a dict
            # so the .get calls below cannot raise.
            params = yaml.safe_load(f) or {}
        config = params.get('model', {})
        training_config = params.get('training', {})
        print(f"Loaded configuration from {config_path}")
    except Exception as e:
        # Best-effort load: report and continue with empty sections,
        # matching the original behavior.
        print(f"Error loading configuration from {config_path}: {str(e)}")

    def _override(value, fallback):
        # BUGFIX: the previous `value or fallback` pattern discarded
        # legitimate falsy overrides (mask_ratio=0.0, val_step=0,
        # adjacent_masking=False), silently falling back to the config.
        return value if value is not None else fallback

    model = GAIABaseLightning(
        params=training_config,
        img_size=_override(img_size, config.get('img_size')),
        patch_size=_override(patch_size, config.get('patch_size')),
        in_chans=_override(in_chans, config.get('in_chans')),
        encoder_embed_dim=config.get('encoder_embed_dim'),
        encoder_depth=config.get('encoder_depth'),
        encoder_num_heads=config.get('encoder_num_heads'),
        decoder_embed_dim=config.get('decoder_embed_dim'),
        decoder_depth=config.get('decoder_depth'),
        decoder_num_heads=config.get('decoder_num_heads'),
        mlp_ratio=config.get('mlp_ratio'),
        norm_layer=torch.nn.LayerNorm,
        norm_pix_loss=config.get('norm_pix_loss'),
        drop_channels_rate=config.get('drop_channels_rate'),
        # DINO Args
        adjacent_masking=_override(adjacent_masking, config.get('adjacent_masking')),
        norm_last_layer=config.get('norm_last_layer'),
        dino_head_dim=config.get('dino_head_dim'),
        warmup_teacher_temp=config.get('warmup_teacher_temp'),
        teacher_temp=config.get('teacher_temp'),
        warmup_teacher_temp_epochs=config.get('warmup_teacher_temp_epochs'),
        epochs=_override(epochs, training_config.get('epochs')),
        student_temp=config.get('student_temp'),
        center_momentum=config.get('center_momentum'),
        momentum_teacher=config.get('momentum_teacher'),
        mask_ratio=_override(mask_ratio, config.get('mask_ratio')),
        val_step=_override(val_step, training_config.get('val_step'))
    )
    return model