File size: 2,483 Bytes
fd943c3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a31a170
fd943c3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import torch
import yaml
import os
import lightning.pytorch as pl
from mae_dino.lightning_wrapper import GAIABaseLightning

def _load_config(config_path):
    """Load the ``model`` and ``training`` sections of a YAML config file.

    Loading is best-effort: if the file cannot be read or parsed, or its top
    level is not a mapping, the error is printed and two empty dicts are
    returned, so the caller falls back to explicit arguments alone.

    Args:
        config_path: Path to the YAML file. A relative path is resolved against
            the current working directory, not this module's location.

    Returns:
        A ``(model_config, training_config)`` tuple of dicts.
    """
    try:
        with open(config_path, 'r', encoding='utf-8') as f:
            params = yaml.safe_load(f)
    except (OSError, UnicodeDecodeError, yaml.YAMLError) as e:
        print(f"Error loading configuration from {config_path}: {str(e)}")
        return {}, {}

    # safe_load returns None for an empty document; treat that as "no settings".
    if params is None:
        params = {}
    if not isinstance(params, dict):
        print(f"Error loading configuration from {config_path}: "
              f"expected a mapping at the top level, got {type(params).__name__}")
        return {}, {}

    print(f"Loaded configuration from {config_path}")
    # `or {}` also covers a section key that is present but left empty (null),
    # which would otherwise make later `.get` calls fail on None.
    return params.get('model') or {}, params.get('training') or {}


def _override(value, section, key):
    """Return ``value`` unless it is ``None``; otherwise ``section.get(key)``.

    Unlike ``value or section.get(key)``, explicit falsy overrides such as
    ``False`` or ``0.0`` take precedence over the config file.
    """
    return value if value is not None else section.get(key)


def GAIABase(config_path=None,
             img_size=None,
             patch_size=None,
             in_chans=None,
             epochs=None,
             adjacent_masking=None,
             mask_ratio=None,
             val_step=None):
    """Build a ``GAIABaseLightning`` model from a YAML config plus overrides.

    Model hyper-parameters come from the config's ``model`` section. The
    ``training`` section is passed through whole as ``params`` and also
    supplies ``epochs`` and ``val_step``. Any keyword argument given here
    (i.e. not ``None``) takes precedence over the matching config entry.
    Config keys that are absent are passed to the model as ``None``.

    Args:
        config_path: Path to the YAML config. When ``None`` it defaults to
            ``'../config.yaml'``, resolved against the current working
            directory. An empty string skips loading entirely.
        img_size, patch_size, in_chans, adjacent_masking, mask_ratio:
            Overrides for the same-named keys in the ``model`` section.
        epochs, val_step: Overrides for the same-named keys in the
            ``training`` section.

    Returns:
        The constructed ``GAIABaseLightning`` instance.
    """
    # Fall back to the default file only when the caller did not supply one.
    if config_path is None:
        config_path = '../config.yaml'

    config, training_config = ({}, {})
    if config_path:
        config, training_config = _load_config(config_path)

    model = GAIABaseLightning(
        params=training_config,
        img_size=_override(img_size, config, 'img_size'),
        patch_size=_override(patch_size, config, 'patch_size'),
        in_chans=_override(in_chans, config, 'in_chans'),
        encoder_embed_dim=config.get('encoder_embed_dim'),
        encoder_depth=config.get('encoder_depth'),
        encoder_num_heads=config.get('encoder_num_heads'),
        decoder_embed_dim=config.get('decoder_embed_dim'),
        decoder_depth=config.get('decoder_depth'),
        decoder_num_heads=config.get('decoder_num_heads'),
        mlp_ratio=config.get('mlp_ratio'),
        norm_layer=torch.nn.LayerNorm,
        norm_pix_loss=config.get('norm_pix_loss'),
        drop_channels_rate=config.get('drop_channels_rate'),
        # DINO Args
        adjacent_masking=_override(adjacent_masking, config, 'adjacent_masking'),
        norm_last_layer=config.get('norm_last_layer'),
        dino_head_dim=config.get('dino_head_dim'),
        warmup_teacher_temp=config.get('warmup_teacher_temp'),
        teacher_temp=config.get('teacher_temp'),
        warmup_teacher_temp_epochs=config.get('warmup_teacher_temp_epochs'),
        epochs=_override(epochs, training_config, 'epochs'),
        student_temp=config.get('student_temp'),
        center_momentum=config.get('center_momentum'),
        momentum_teacher=config.get('momentum_teacher'),
        mask_ratio=_override(mask_ratio, config, 'mask_ratio'),
        val_step=_override(val_step, training_config, 'val_step'),
    )

    return model