import torch
import yaml
import os
import lightning.pytorch as pl
from mae_dino.lightning_wrapper import GAIABaseLightning

# Config file read when the caller does not pass ``config_path``. It is a
# relative path, so it is resolved against the current working directory.
_DEFAULT_CONFIG_PATH = '../config.yaml'


def _override(value, fallback):
    """Return ``value`` unless it is None, otherwise ``fallback``.

    Unlike ``value or fallback``, this keeps explicit falsy overrides such as
    ``False`` or ``0``.
    """
    return fallback if value is None else value


def _load_config(config_path):
    """Read the YAML file at ``config_path`` and return its two sections.

    Returns a ``(model_config, training_config)`` pair taken from the
    top-level ``model`` and ``training`` keys. If the file cannot be opened,
    is not valid YAML, or does not contain a mapping at the top level, an
    error message is printed and two empty dicts are returned. Explicit
    arguments to ``GAIABase`` can still be used in that case.
    """
    try:
        with open(config_path, 'r', encoding='utf-8') as f:
            params = yaml.safe_load(f)
    except (OSError, yaml.YAMLError) as e:
        print(f"Error loading configuration from {config_path}: {str(e)}")
        return {}, {}

    # safe_load returns None for an empty file, and scalars or lists for
    # other non-mapping documents; none of those can hold config sections.
    if not isinstance(params, dict):
        print(f"Error loading configuration from {config_path}: "
              f"expected a mapping at the top level, "
              f"got {type(params).__name__}")
        return {}, {}

    print(f"Loaded configuration from {config_path}")
    # A section written as ``model:`` with no body parses to None; treat it
    # as empty rather than crashing on ``.get`` later.
    return params.get('model') or {}, params.get('training') or {}


def GAIABase(config_path=None, img_size=None, patch_size=None, in_chans=None,
             epochs=None, adjacent_masking=None, mask_ratio=None,
             val_step=None):
    """Build a ``GAIABaseLightning`` model from a YAML config plus overrides.

    Args:
        config_path: Path to the YAML config. Falls back to
            ``'../config.yaml'`` (relative to the working directory) when
            None or empty. The file's ``model`` section supplies the
            architecture/DINO keyword arguments. Its ``training`` section is
            passed as ``params`` and supplies the defaults for ``epochs``
            and ``val_step``.
        img_size, patch_size, in_chans, adjacent_masking, mask_ratio:
            Override the same-named keys of the ``model`` section when not
            None. Explicit falsy values (e.g. ``False``, ``0``) are kept.
        epochs, val_step: Override the same-named keys of the ``training``
            section when not None.

    Returns:
        The constructed ``GAIABaseLightning`` instance. Keys missing from
        the config are passed through as None.
    """
    # Only use the default path when the caller did not supply one; the
    # argument used to be overwritten unconditionally.
    config, training_config = _load_config(
        config_path or _DEFAULT_CONFIG_PATH)

    model = GAIABaseLightning(
        params=training_config,
        img_size=_override(img_size, config.get('img_size')),
        patch_size=_override(patch_size, config.get('patch_size')),
        in_chans=_override(in_chans, config.get('in_chans')),
        encoder_embed_dim=config.get('encoder_embed_dim'),
        encoder_depth=config.get('encoder_depth'),
        encoder_num_heads=config.get('encoder_num_heads'),
        decoder_embed_dim=config.get('decoder_embed_dim'),
        decoder_depth=config.get('decoder_depth'),
        decoder_num_heads=config.get('decoder_num_heads'),
        mlp_ratio=config.get('mlp_ratio'),
        norm_layer=torch.nn.LayerNorm,
        norm_pix_loss=config.get('norm_pix_loss'),
        drop_channels_rate=config.get('drop_channels_rate'),
        # DINO Args
        adjacent_masking=_override(adjacent_masking,
                                   config.get('adjacent_masking')),
        norm_last_layer=config.get('norm_last_layer'),
        dino_head_dim=config.get('dino_head_dim'),
        warmup_teacher_temp=config.get('warmup_teacher_temp'),
        teacher_temp=config.get('teacher_temp'),
        warmup_teacher_temp_epochs=config.get('warmup_teacher_temp_epochs'),
        epochs=_override(epochs, training_config.get('epochs')),
        student_temp=config.get('student_temp'),
        center_momentum=config.get('center_momentum'),
        momentum_teacher=config.get('momentum_teacher'),
        mask_ratio=_override(mask_ratio, config.get('mask_ratio')),
        val_step=_override(val_step, training_config.get('val_step')),
    )
    return model