|
import torch |
|
import yaml |
|
import os |
|
import lightning.pytorch as pl |
|
from mae_dino.lightning_wrapper import GAIABaseLightning |
|
|
|
def _load_gaia_config(config_path):
    """Read *config_path* as YAML and return its ``(model, training)`` sections.

    Loading is best-effort, matching the original behaviour.
    - An unreadable file or invalid YAML is reported with ``print`` and yields
      two empty dicts, so the caller falls back to explicit arguments only.
    - A document whose top level is not a mapping is treated the same way.
    - Empty sections (an empty file, or a key with a null value) become ``{}``.
    """
    try:
        with open(config_path, 'r', encoding='utf-8') as f:
            params = yaml.safe_load(f)
    except (OSError, yaml.YAMLError) as e:
        print(f"Error loading configuration from {config_path}: {str(e)}")
        return {}, {}

    # safe_load returns None for an empty document; treat it as "no settings".
    if params is None:
        params = {}
    if not isinstance(params, dict):
        print(f"Error loading configuration from {config_path}: "
              f"expected a mapping at the top level, got {type(params).__name__}")
        return {}, {}

    print(f"Loaded configuration from {config_path}")
    # `or {}` guards against `model:` / `training:` keys present with a null value.
    return params.get('model') or {}, params.get('training') or {}


def _prefer(value, fallback):
    """Return *value* unless it is ``None``, in which case return *fallback*.

    Unlike ``value or fallback``, this keeps explicit falsy overrides such as
    ``False`` or ``0.0``.
    """
    return fallback if value is None else value


def GAIABase(config_path=None,
             img_size=None,
             patch_size=None,
             in_chans=None,
             epochs=None,
             adjacent_masking=None,
             mask_ratio=None,
             val_step=None):
    """Build a ``GAIABaseLightning`` model from a YAML config plus optional overrides.

    Args:
        config_path: Path to the YAML config file. When ``None`` or empty,
            ``'../config.yaml'`` is used. That default is resolved against the
            current working directory, not this module's location. The file's
            ``model`` section supplies the architecture/loss hyper-parameters.
            Its ``training`` section is passed to the model as ``params``.
        img_size, patch_size, in_chans, adjacent_masking, mask_ratio: When not
            ``None``, override the same-named key from the ``model`` section.
        epochs, val_step: When not ``None``, override the same-named key from
            the ``training`` section.

    Returns:
        The constructed ``GAIABaseLightning`` instance.

    Note:
        If the config cannot be loaded, an error is printed and every
        hyper-parameter not given explicitly is passed as ``None``.
    """
    # Only fall back to the conventional location when the caller gave no path;
    # previously this was unconditional and silently ignored `config_path`.
    if not config_path:
        config_path = '../config.yaml'

    config, training_config = _load_gaia_config(config_path)

    model = GAIABaseLightning(
        params=training_config,
        img_size=_prefer(img_size, config.get('img_size')),
        patch_size=_prefer(patch_size, config.get('patch_size')),
        in_chans=_prefer(in_chans, config.get('in_chans')),
        encoder_embed_dim=config.get('encoder_embed_dim'),
        encoder_depth=config.get('encoder_depth'),
        encoder_num_heads=config.get('encoder_num_heads'),
        decoder_embed_dim=config.get('decoder_embed_dim'),
        decoder_depth=config.get('decoder_depth'),
        decoder_num_heads=config.get('decoder_num_heads'),
        mlp_ratio=config.get('mlp_ratio'),
        norm_layer=torch.nn.LayerNorm,
        norm_pix_loss=config.get('norm_pix_loss'),
        drop_channels_rate=config.get('drop_channels_rate'),
        adjacent_masking=_prefer(adjacent_masking, config.get('adjacent_masking')),
        norm_last_layer=config.get('norm_last_layer'),
        dino_head_dim=config.get('dino_head_dim'),
        warmup_teacher_temp=config.get('warmup_teacher_temp'),
        teacher_temp=config.get('teacher_temp'),
        warmup_teacher_temp_epochs=config.get('warmup_teacher_temp_epochs'),
        epochs=_prefer(epochs, training_config.get('epochs')),
        student_temp=config.get('student_temp'),
        center_momentum=config.get('center_momentum'),
        momentum_teacher=config.get('momentum_teacher'),
        mask_ratio=_prefer(mask_ratio, config.get('mask_ratio')),
        val_step=_prefer(val_step, training_config.get('val_step')),
    )

    return model
|
|
|
|
|
|
|
|
|
|