File size: 32,904 Bytes
fd943c3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
from typing import List, Dict, Optional, Tuple, Union
import random
import torch
import timm
import numpy as np
from einops import rearrange
import torch.distributed as dist
from timm.layers import drop_path, DropPath, Mlp, trunc_normal_
import logging

from mae_dino.model_layers.layers import PatchEmbed, Attention, Block, PatchEmbed
from mae_dino.model_layers.pos_embed import get_3d_sincos_pos_embed

# A 3-element (T, H, W) size given either as a list or as a 3-tuple.
Shape3d = Union[List[int], Tuple[int, int, int]]

# NOTE(review): calling basicConfig at import time configures the root logger
# for every importer of this module; consider moving it to the entry script.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(name=__name__)

class TemporalEncoder(torch.nn.Module):
    """Learned embedding of per-frame timestamps, repeated for every token of that frame.

    ``embed_dim`` is split into four chunks (year, day-of-year, hour, minute);
    the minute chunk absorbs the remainder when ``embed_dim`` is not divisible
    by 4. Each chunk is a lookup table indexed directly by the raw calendar value.

    Args:
        embed_dim: Size of the concatenated output embedding.
        tokens_per_frame: Number of consecutive tokens that share one frame's timestamp.
    """

    def __init__(self, embed_dim, tokens_per_frame):
        super().__init__()
        self.embed_dim = embed_dim
        self.tokens_per_frame = tokens_per_frame

        # Split embed_dim across the four temporal components.
        self.year_embed_dim = embed_dim // 4
        self.doy_embed_dim = embed_dim // 4
        self.hour_embed_dim = embed_dim // 4
        self.minute_embed_dim = embed_dim - (
            self.year_embed_dim + self.doy_embed_dim + self.hour_embed_dim
        )

        # Embedding tables; indices outside these ranges raise in nn.Embedding.
        self.year_embedding = torch.nn.Embedding(3000, self.year_embed_dim)    # years 0..2999
        self.doy_embedding = torch.nn.Embedding(367, self.doy_embed_dim)      # day of year 0..366
        self.hour_embedding = torch.nn.Embedding(24, self.hour_embed_dim)     # hours 0..23
        self.minute_embedding = torch.nn.Embedding(60, self.minute_embed_dim) # minutes 0..59

        self._init_weights()

    def _init_weights(self):
        """Re-initialize every embedding table with Xavier-uniform weights."""
        torch.nn.init.xavier_uniform_(self.year_embedding.weight)
        torch.nn.init.xavier_uniform_(self.doy_embedding.weight)
        torch.nn.init.xavier_uniform_(self.hour_embedding.weight)
        torch.nn.init.xavier_uniform_(self.minute_embedding.weight)

    def forward(self, year, doy, hour, minute):
        """Embed timestamps and broadcast them to the frame's tokens.

        Args:
            year (torch.Tensor): Shape (batch_size, time), integer years in [0, 2999].
            doy (torch.Tensor): Shape (batch_size, time), day of year in [1, 366].
            hour (torch.Tensor): Shape (batch_size, time), values in [0, 23].
            minute (torch.Tensor): Shape (batch_size, time), values in [0, 59].
            Inputs of any numeric dtype are truncated to int64 indices.

        Returns:
            torch.Tensor: Shape (batch_size, time * tokens_per_frame, embed_dim).
            Each frame's embedding is repeated ``tokens_per_frame`` times
            consecutively (frame-major order).
        """
        parts = [
            self.year_embedding(year.long()),
            self.doy_embedding(doy.long()),
            self.hour_embedding(hour.long()),
            self.minute_embedding(minute.long()),
        ]
        temporal_emb = torch.cat(parts, dim=-1)  # (batch_size, time, embed_dim)

        # Repeat each frame's embedding once per token of that frame.
        return torch.repeat_interleave(temporal_emb, self.tokens_per_frame, dim=1)


class Encoder(torch.nn.Module):
    """MAE-style ViT encoder with missing-data-aware ("hinted") masking.

    Inputs are ``(B, C, T, H, W)`` tensors embedded by ``PatchEmbed``. Patches
    whose mask values (``x_mask``) sum to > 0 are treated as invalid: they are
    never kept by either masking strategy, and their indices are returned so a
    loss can exclude them.
    """

    def __init__(self, 
                 img_size: Shape3d = [4, 224, 224], 
                 patch_size: Shape3d = [1, 16, 16], 
                 in_chans: int = 3, 
                 encoder_embed_dim: int = 1024, 
                 encoder_depth: int = 8, 
                 encoder_num_heads: int = 16, 
                 mlp_ratio: float = 4., 
                 norm_layer: torch.nn.Module = torch.nn.LayerNorm, 
                 drop_channels_rate: float = 0.0,
                 adjacent_masking: bool = False,
                 ):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.in_chans = in_chans
        self.encoder_embed_dim = encoder_embed_dim
        self.encoder_depth = encoder_depth 
        self.encoder_num_heads = encoder_num_heads 
        self.mlp_ratio = mlp_ratio
        self.norm_layer = norm_layer
        self.drop_channels_rate = drop_channels_rate
        self.adjacent_masking = adjacent_masking

        # -------------------------------------------------------------------------- #
        # MAE encoder
        # Dropout3d zeroes whole input channels; disabled (Identity) when the rate is 0.
        self.drop_channels = torch.nn.Dropout3d(self.drop_channels_rate) if self.drop_channels_rate > 0 else torch.nn.Identity()
        
        self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, encoder_embed_dim)
        num_patches = self.patch_embed.num_patches
        # NOTE(review): num_patches // img_size[0] equals "tokens per frame" only when
        # the temporal patch size is 1 — confirm before using patch_size[0] > 1.
        tokens_per_frame = num_patches // img_size[0]

        self.cls_token = torch.nn.Parameter(torch.zeros(1, 1, encoder_embed_dim))
        # Fixed (non-learned) sin-cos position embedding, slot 0 belongs to the cls token.
        self.register_buffer("encoder_pos_embed", torch.zeros(1, num_patches + 1, encoder_embed_dim))

        self.encoder_blocks = torch.nn.ModuleList([
            Block(encoder_embed_dim, encoder_num_heads, mlp_ratio, qkv_bias=True, norm_layer=self.norm_layer)
            for i in range(self.encoder_depth)])
        self.norm = norm_layer(encoder_embed_dim)

        self.temporal_embed_enc = TemporalEncoder(embed_dim=encoder_embed_dim, tokens_per_frame=tokens_per_frame)

        # Initialize weights
        self.initialize_weights()

    def initialize_weights(self):
        """Fill the sin-cos position buffer and initialize all learnable weights."""
        encoder_pos_embed = get_3d_sincos_pos_embed(
            self.encoder_pos_embed.shape[-1], self.patch_embed.grid_size, cls_token=True
        )
        self.encoder_pos_embed.data.copy_(torch.from_numpy(encoder_pos_embed).float().unsqueeze(0))

        # initialize patch_embed like nn.Linear (instead of nn.Conv2d)
        w = self.patch_embed.proj.weight.data
        torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

        # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
        torch.nn.init.normal_(self.cls_token, std=0.02)

        # initialize nn.Linear and nn.LayerNorm
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Xavier-uniform for Linear layers (as in the JAX ViT), identity for LayerNorm."""
        if isinstance(m, torch.nn.Linear):
            torch.nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                torch.nn.init.constant_(m.bias, 0)
        elif isinstance(m, torch.nn.LayerNorm):
            torch.nn.init.constant_(m.bias, 0)
            torch.nn.init.constant_(m.weight, 1.0)

    def patchify(self, imgs):
        """
        imgs: B, C, T, H, W
        x: B, L, D  with D = s * p * q * C (tubelet volume times channels)
        """
        s, p, q = self.patch_embed.patch_size
        x = rearrange(imgs, 'b c (t s) (h p) (w q) -> b (t h w) (s p q c)', s=s, p=p, q=q)

        return x

    def unpatchify(self, x):
        """
        x: B, L, D
        imgs: B, C, T, H, W  (inverse of ``patchify``)
        """
        s, p, q = self.patch_embed.patch_size
        gs = self.patch_embed.grid_size
        imgs = rearrange(x, 'b (t h w) (s p q c) -> b c (t s) (h p) (w q)', h=gs[1], w=gs[2], t=gs[0], s=s, p=p, q=q)
        return imgs
    
    def log_helper(self, message):
        """Log ``message`` at INFO level on the module logger."""
        logger.info(message)

    @staticmethod
    def _group_missing_patches(invalid: torch.Tensor) -> Dict[int, List[int]]:
        """Group indices of invalid patches by sample.

        Args:
            invalid: [N, L] boolean tensor, True where a patch contains missing data.

        Returns:
            Dict mapping sample index -> ascending list of invalid patch indices.
            Samples without invalid patches have no entry.
        """
        grouped: Dict[int, List[int]] = {}
        rows, cols = torch.where(invalid)
        # One host transfer instead of a per-element tensor->int conversion.
        for b, p in zip(rows.tolist(), cols.tolist()):
            grouped.setdefault(b, []).append(p)
        return grouped

    def hinted_random_masking(self, x: torch.Tensor, x_mask: torch.Tensor, mask_ratio: float):
        """Per-sample random masking that always removes invalid patches.

        Per-sample shuffling is done by argsorting random noise; invalid patches
        get +1 added to their noise so they sort after every valid patch.

        Args:
            x: [N, L, D] token sequence.
            x_mask: [N, L, P] patchified missing-data mask (invalid where a patch sums to > 0).
            mask_ratio: Requested fraction of patches to remove. It is raised to the
                largest per-sample missing ratio in the batch so that every sample
                keeps the same number of tokens and no invalid patch is kept.

        Returns:
            x_masked: [N, len_keep, D] kept tokens.
            mask: [N, L] binary mask, 0 is keep, 1 is remove.
            ids_restore: [N, L] indices that unshuffle [kept, removed] back to patch order.
            ids_x_mask_dict: sample index -> list of invalid patch indices.
        """
        N, L, D = x.shape  # batch, length, dim

        invalid = x_mask.sum(dim=-1) > 0  # [N, L]
        missing_ratio = invalid.float().mean(dim=1)  # [N]
        adjusted_mask_ratio = max(missing_ratio.max().item(), mask_ratio)
        len_keep = int(L * (1 - adjusted_mask_ratio))

        noise = torch.rand(N, L, device=x.device)  # noise in [0, 1]
        ids_x_mask_dict = self._group_missing_patches(invalid)
        noise += invalid  # invalid patches move to [1, 2) -> removed first

        # sort noise for each sample: ascend, small is keep, large is remove
        ids_shuffle = torch.argsort(noise, dim=1)
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        # keep the first subset
        ids_keep = ids_shuffle[:, :len_keep]
        x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))

        # generate the binary mask (0 is keep, 1 is remove), then unshuffle it
        mask = torch.ones([N, L], device=x.device)
        mask[:, :len_keep] = 0
        mask = torch.gather(mask, dim=1, index=ids_restore)

        return x_masked, mask, ids_restore, ids_x_mask_dict

    @staticmethod
    def _select_adjacent_blocks(keep: torch.Tensor, valid: torch.Tensor, len_keep: int,
                                max_attempts: int = 1000) -> None:
        """Mark up to ``len_keep`` valid patches of one sample as kept, in contiguous blocks.

        Random cubes (edge ~ cube root of the remaining budget, clipped to the
        grid) are sampled and their valid, not-yet-kept patches selected. If the
        attempt budget runs out, the remainder is drawn uniformly from the
        leftover valid patches.

        Args:
            keep: [T, H, W] float tensor updated in place (1 = keep).
            valid: [T, H, W] boolean tensor, True where the patch is valid.
            len_keep: Requested number of kept patches (capped at the valid count).
            max_attempts: Upper bound on block samples, to avoid infinite loops.
        """
        T, H, W = valid.shape
        target = min(len_keep, valid.sum().item())
        selected = 0
        attempts = 0

        while selected < target and attempts < max_attempts:
            attempts += 1
            remaining = target - selected

            edge = int(round(remaining ** (1 / 3)))
            bt, bh, bw = min(T, edge), min(H, edge), min(W, edge)
            t0 = random.randint(0, T - bt)
            h0 = random.randint(0, H - bh)
            w0 = random.randint(0, W - bw)
            region = (slice(t0, t0 + bt), slice(h0, h0 + bh), slice(w0, w0 + bw))

            selectable = valid[region] & (keep[region] == 0)
            n_selectable = selectable.sum().item()
            if n_selectable == 0:
                continue  # try another block

            n_take = min(remaining, n_selectable)
            candidates = selectable.nonzero(as_tuple=False)  # block-local (t, h, w)
            chosen = candidates[torch.randperm(n_selectable)[:n_take]]
            keep[t0 + chosen[:, 0], h0 + chosen[:, 1], w0 + chosen[:, 2]] = 1
            selected += n_take

        # Fallback: fill any shortfall uniformly from the remaining valid patches.
        if selected < target:
            leftovers = (valid & (keep == 0)).nonzero(as_tuple=False)
            n_left = leftovers.size(0)
            n_take = min(n_left, target - selected)
            if n_take > 0:
                chosen = leftovers[torch.randperm(n_left)[:n_take]]
                keep[chosen[:, 0], chosen[:, 1], chosen[:, 2]] = 1

    def hinted_adjacent_masking(self, x: torch.Tensor, x_mask: torch.Tensor, mask_ratio: float):
        """Masking that keeps contiguous space-time blocks of valid patches.

        Args:
            x: [N, L, D] token sequence, with L = T * H * W of ``patch_embed.grid_size``.
            x_mask: [N, L, P] patchified missing-data mask (invalid where a patch sums to > 0).
            mask_ratio: Fraction of patches to remove. Samples with fewer valid
                patches than the keep budget keep all of their valid patches.

        Returns:
            x_masked: [N, max_kept, D] kept tokens in ascending patch order,
                zero-padded at the end for samples that kept fewer patches.
            mask: [N, L] binary mask, 0 is keep, 1 is remove.
            ids_restore: [N, L] indices that unshuffle [kept, removed] back to patch
                order (same convention as ``hinted_random_masking``).
            ids_x_mask_dict: sample index -> list of invalid patch indices.
        """
        N, L, D = x.shape
        T, H, W = self.patch_embed.grid_size
        len_keep = int(T * H * W * (1 - mask_ratio))

        invalid = x_mask.sum(dim=-1) > 0  # [N, L]
        ids_x_mask_dict = self._group_missing_patches(invalid)
        valid_grid = ~invalid.view(N, T, H, W)

        # 1 is keep, 0 is remove (inverted before returning)
        keep = torch.zeros((N, T, H, W), device=x.device)
        ids_keep_list = []
        for i in range(N):
            self._select_adjacent_blocks(keep[i], valid_grid[i], len_keep)
            ids_keep_list.append(torch.where(keep[i].view(-1) == 1)[0])

        x_masked = torch.nn.utils.rnn.pad_sequence(
            [torch.index_select(x[i], dim=0, index=ids) for i, ids in enumerate(ids_keep_list)],
            batch_first=True,
        )

        # Shuffle order is [kept (ascending, matching x_masked), removed]; its argsort
        # tells a decoder where each original patch sits in that order.
        keep_flat = keep.view(N, L)
        ids_shuffle = torch.stack([
            torch.cat([kept, torch.where(keep_flat[i] == 0)[0]])
            for i, kept in enumerate(ids_keep_list)
        ])
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        final_mask = 1 - keep_flat  # 0 is keep, 1 is remove

        return x_masked, final_mask, ids_restore, ids_x_mask_dict

    def forward(self,
                x: torch.Tensor,
                x_mask: torch.Tensor,
                mask_ratio: float,
                temporal_pos: Optional[Union[List[torch.Tensor], torch.Tensor]]):
        """Embed, mask and encode a batch.

        Args:
            x: (B, C, T, H, W) input.
            x_mask: Missing-data mask with the same layout as ``x``.
            mask_ratio: Fraction of patches to remove.
            temporal_pos: Optional (year, doy, hour, minute) sequence unpacked into
                ``TemporalEncoder``; skipped when None or empty.

        Returns:
            (latent [B, 1 + kept, D] with cls token first, mask, ids_restore, ids_x_mask_dict)
        """
        # Drop input channels
        x = self.drop_channels(x)

        # embed patches
        x = self.patch_embed(x)

        # add pos embed w/o cls token
        x = x + self.encoder_pos_embed[:, 1:, :]

        # Explicit check: truth-testing a multi-element tensor raises.
        if temporal_pos is not None and len(temporal_pos) > 0:
            x = x + self.temporal_embed_enc(*temporal_pos)

        # masking: length -> length * (1 - mask_ratio)
        x_mask = self.patchify(x_mask)
        if self.adjacent_masking:
            x, mask, ids_restore, ids_x_mask_dict = self.hinted_adjacent_masking(x, x_mask, mask_ratio)
        else:
            x, mask, ids_restore, ids_x_mask_dict = self.hinted_random_masking(x, x_mask, mask_ratio)

        # append cls token
        cls_token = self.cls_token + self.encoder_pos_embed[:, :1, :]
        cls_tokens = cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # apply Transformer blocks
        for blk in self.encoder_blocks:
            x = blk(x)
        x = self.norm(x)

        return x, mask, ids_restore, ids_x_mask_dict


class Decoder(torch.nn.Module):
    """MAE decoder: restores mask tokens to their patch positions and predicts patch values.

    The output has one row per patch (cls token removed) with
    ``patch_size[0] * patch_size[1] * patch_size[2] * in_chans`` values, i.e.
    the same layout as ``Encoder.patchify``.
    """

    def __init__(self, 
                 img_size: Shape3d = [4, 224, 224],
                 patch_size: Shape3d = [1, 16, 16], 
                 in_chans: int = 3,
                 encoder_embed_dim: int = 1024,
                 decoder_embed_dim: int = 512, 
                 decoder_depth: int = 8, 
                 decoder_num_heads: int = 16,
                 mlp_ratio: float = 4., 
                 norm_layer: torch.nn.Module = torch.nn.LayerNorm, 
                 norm_pix_loss: bool = False,
                 ):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.encoder_embed_dim = encoder_embed_dim
        self.decoder_embed_dim = decoder_embed_dim 
        self.decoder_depth = decoder_depth
        self.decoder_num_heads = decoder_num_heads
        self.mlp_ratio = mlp_ratio
        self.norm_layer = norm_layer
        # Not used by the decoder itself; read by the loss in GAIABase.
        self.norm_pix_loss = norm_pix_loss

        # -------------------------------------------------------------------------- #
        # MAE decoder
        self.decoder_embed = torch.nn.Linear(encoder_embed_dim, decoder_embed_dim, bias=True)
        self.mask_token = torch.nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))

        self.grid_size = [d // p for d, p in zip(self.img_size, self.patch_size)]
        num_patches = int(np.prod(self.grid_size))  # plain int, not np.int64
        # Fixed sin-cos position embedding, slot 0 belongs to the cls token.
        self.register_buffer("decoder_pos_embed", torch.zeros(1, num_patches + 1, decoder_embed_dim))

        self.decoder_blocks = torch.nn.ModuleList([
            Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)
            for i in range(decoder_depth)])

        self.decoder_norm = norm_layer(decoder_embed_dim)
        self.decoder_pred = torch.nn.Linear(decoder_embed_dim,
                                            patch_size[0] * patch_size[1] * patch_size[2] * in_chans,
                                            bias=True)  # decoder to patch

        # NOTE(review): equals tokens per frame only for a temporal patch size of 1 — confirm.
        tokens_per_frame = num_patches // img_size[0]
        self.temporal_embed_dec = TemporalEncoder(embed_dim=decoder_embed_dim, tokens_per_frame=tokens_per_frame)

        # Initialize weights
        self.initialize_weights()

    def initialize_weights(self):
        """Fill the sin-cos position buffer and initialize all learnable weights."""
        decoder_pos_embed = get_3d_sincos_pos_embed(
            self.decoder_pos_embed.shape[-1], self.grid_size, cls_token=True
        )
        self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))

        # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
        torch.nn.init.normal_(self.mask_token, std=0.02)

        # initialize nn.Linear and nn.LayerNorm
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Xavier-uniform for Linear layers (as in the JAX ViT), identity for LayerNorm."""
        if isinstance(m, torch.nn.Linear):
            torch.nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                torch.nn.init.constant_(m.bias, 0)
        elif isinstance(m, torch.nn.LayerNorm):
            torch.nn.init.constant_(m.bias, 0)
            torch.nn.init.constant_(m.weight, 1.0)

    def forward(self, x: torch.Tensor,
                    ids_restore: torch.Tensor,
                    temporal_pos: Optional[Union[List[torch.Tensor], torch.Tensor]]):
        """Decode encoder latents into per-patch predictions.

        Args:
            x: [B, 1 + kept, encoder_embed_dim] encoder output, cls token first.
            ids_restore: [B, L] unshuffle indices produced by the encoder's masking.
            temporal_pos: Optional (year, doy, hour, minute) sequence unpacked into
                ``TemporalEncoder``; skipped when None or empty.

        Returns:
            [B, L, patch_volume * in_chans] predictions (cls token removed).
        """
        # embed tokens
        x = self.decoder_embed(x)

        # append mask tokens so that there is one token per patch, then unshuffle
        n_mask = ids_restore.shape[1] + 1 - x.shape[1]
        mask_tokens = self.mask_token.repeat(x.shape[0], n_mask, 1)
        x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)  # no cls token
        x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]))
        x = torch.cat([x[:, :1, :], x_], dim=1)  # re-attach cls token

        # add pos embed
        x = x + self.decoder_pos_embed

        # temporal encoding applies to patch tokens only, not the cls token
        cls_tok, patch_tok = x[:, :1, :], x[:, 1:, :]
        # Explicit check: truth-testing a multi-element tensor raises.
        if temporal_pos is not None and len(temporal_pos) > 0:
            patch_tok = patch_tok + self.temporal_embed_dec(*temporal_pos)
        x = torch.cat([cls_tok, patch_tok], dim=1)

        # apply Transformer blocks
        for blk in self.decoder_blocks:
            x = blk(x)
        x = self.decoder_norm(x)

        # predictor projection
        x = self.decoder_pred(x)

        # remove cls token
        return x[:, 1:, :]


    
    
class GAIABase(torch.nn.Module):
    """Joint MAE + DINO model.

    * MAE branch: ``encoder`` (random hinted masking) + ``decoder``, trained on
      masked-patch reconstruction that ignores patches with missing data.
    * DINO branch: ``student`` (the same ``encoder`` plus a ``DINOHead``) and a
      frozen ``teacher`` (a separate ``Encoder`` plus head, initialized from
      the student, optionally using adjacent masking).

    ``forward`` returns ``(total_loss, dino_loss, mae_loss)`` and
    ``(pred, mask, student_output, teacher_output)``.
    """

    def __init__(self, 
                 img_size: Shape3d = [4, 224, 224], 
                 patch_size: Shape3d = [1, 16, 16], 
                 in_chans: int = 3, 
                 encoder_embed_dim: int = 1024, 
                 encoder_depth: int = 8, 
                 encoder_num_heads: int = 16, 
                 decoder_embed_dim: int = 512, 
                 decoder_depth: int = 8, 
                 decoder_num_heads: int = 16,
                 mlp_ratio: float = 4., 
                 norm_layer: torch.nn.Module = torch.nn.LayerNorm, 
                 norm_pix_loss: bool = False,
                 drop_channels_rate: float = 0.0,
                 # DINO Args
                 adjacent_masking: bool = False,
                 norm_last_layer: bool = True,
                 dino_head_dim: int = 1024,
                 warmup_teacher_temp: float = 0.04, 
                 teacher_temp: float = 0.04, 
                 warmup_teacher_temp_epochs: int = 5, 
                 epochs: int = 100, 
                 student_temp: float = 0.1, 
                 center_momentum: float = 0.9,
                 momentum_teacher: float = 0.996,
                 ):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.in_chans = in_chans
        self.encoder_embed_dim = encoder_embed_dim
        self.encoder_depth = encoder_depth 
        self.encoder_num_heads = encoder_num_heads 
        self.decoder_embed_dim = decoder_embed_dim 
        self.decoder_depth = decoder_depth
        self.decoder_num_heads = decoder_num_heads
        self.mlp_ratio = mlp_ratio
        self.norm_layer = norm_layer
        self.norm_pix_loss = norm_pix_loss
        self.drop_channels_rate = drop_channels_rate
        # DINO Args
        self.adjacent_masking = adjacent_masking
        self.norm_last_layer = norm_last_layer
        self.dino_head_dim = dino_head_dim
        self.warmup_teacher_temp = warmup_teacher_temp
        self.teacher_temp = teacher_temp
        self.warmup_teacher_temp_epochs = warmup_teacher_temp_epochs
        self.epochs = epochs
        self.student_temp = student_temp
        self.center_momentum = center_momentum
        self.momentum_teacher = momentum_teacher
        self.log_count = 0

        # MAE / student encoder always uses random hinted masking.
        self.encoder = Encoder(
            img_size=img_size, 
            patch_size=patch_size, 
            in_chans=in_chans, 
            encoder_embed_dim=encoder_embed_dim, 
            encoder_depth=encoder_depth, 
            encoder_num_heads=encoder_num_heads, 
            mlp_ratio=mlp_ratio, 
            norm_layer=norm_layer, 
            drop_channels_rate=drop_channels_rate,
            adjacent_masking=False,
        )
        
        self.decoder = Decoder(
            img_size=img_size, 
            patch_size=patch_size,
            in_chans=in_chans, 
            encoder_embed_dim=encoder_embed_dim,
            decoder_embed_dim=decoder_embed_dim, 
            decoder_depth=decoder_depth, 
            decoder_num_heads=decoder_num_heads,
            mlp_ratio=mlp_ratio, 
            norm_layer=norm_layer, 
            norm_pix_loss=norm_pix_loss,
        )
        
        # Teacher backbone: the only place adjacent masking can be enabled.
        self.teacher = Encoder(
            img_size=img_size, 
            patch_size=patch_size, 
            in_chans=in_chans, 
            encoder_embed_dim=encoder_embed_dim, 
            encoder_depth=encoder_depth, 
            encoder_num_heads=encoder_num_heads, 
            mlp_ratio=mlp_ratio, 
            norm_layer=norm_layer, 
            drop_channels_rate=drop_channels_rate,
            adjacent_masking=self.adjacent_masking,
        )

        # DINO Head wrappers (student shares its backbone with the MAE encoder)
        self.student = PassThroughHead(
            self.encoder, 
            DINOHead(encoder_embed_dim, dino_head_dim, norm_last_layer=norm_last_layer)
        )                

        self.teacher = PassThroughHead(
            self.teacher, 
            DINOHead(encoder_embed_dim, dino_head_dim, norm_last_layer=norm_last_layer)
        )

        # teacher and student start with the same weights
        self.teacher.load_state_dict(self.student.state_dict())
        
        # teacher frozen (no gradients)
        for p in self.teacher.parameters():
            p.requires_grad = False
    
        self.dino_loss = DINOLoss(dino_head_dim, warmup_teacher_temp, teacher_temp, 
                                  warmup_teacher_temp_epochs, epochs, student_temp, center_momentum)

    def forward_mae(self, imgs: torch.Tensor, img_masks: torch.Tensor, temporal_pos: Optional[torch.Tensor], mask_ratio: float = 0.75):
        """Run the MAE branch; returns (loss, pred [B, L, P], mask [B, L])."""
        latent, mask, ids_restore, ids_x_mask_dict = self.encoder(imgs, img_masks, temporal_pos=temporal_pos, mask_ratio=mask_ratio)
        pred = self.decoder(latent, ids_restore, temporal_pos)
        loss = self.forward_mae_loss(imgs, pred, mask, ids_x_mask_dict)
        return loss, pred, mask
        
    def forward_dino(self, imgs: torch.Tensor, img_masks: torch.Tensor, temporal_pos: Optional[torch.Tensor], mask_ratio: float = 0.75):
        """Run student and (gradient-free) teacher; returns (student_output, teacher_output)."""
        with torch.no_grad():
            t_pred = self.teacher(imgs, img_masks, temporal_pos=temporal_pos, mask_ratio=mask_ratio)
        s_pred = self.student(imgs, img_masks, temporal_pos=temporal_pos, mask_ratio=mask_ratio)
        return s_pred, t_pred
        
    def log_helper(self, message):
        """Log ``message`` at INFO level on the module logger."""
        logger.info(message)
    
    def forward_mae_loss(self, imgs: torch.Tensor, pred: torch.Tensor, mask: torch.Tensor, ids_x_mask_dict: Dict[int, List[int]]):
        """Mean squared error on removed patches, excluding patches with missing data.

        imgs: B, C, T, H, W
        target: B, L, D
        pred: B, L, D
        mask: B, L. 0 is keep, 1 is remove. NOTE: modified in place — invalid
              patches are set to 0, and callers holding ``mask`` see that change.
        ids_x_mask_dict: sample index -> list of invalid patch indices.
        """
        eps = 1e-6
        target = self.encoder.patchify(imgs)

        if self.decoder.norm_pix_loss:
            # Per-patch normalized pixel targets (MAE "norm_pix_loss").
            mean = target.mean(dim=-1, keepdim=True)
            var = target.var(dim=-1, keepdim=True)
            target = (target - mean) / (var + eps) ** .5

        loss = ((pred - target).to(torch.float) ** 2).mean(dim=-1)  # [B, L]

        # Exclude patches that include missing data from both numerator and denominator.
        for index, patch in ids_x_mask_dict.items():
            mask[index, patch] = 0
            loss[index, patch] = 0

        loss_mask = torch.clamp(loss * mask, max=1e8)
        loss = loss_mask.sum() / (mask.sum() + eps)  # mean loss on removed patches
        return loss

    def forward(self, imgs: torch.Tensor, img_masks: torch.Tensor, temporal_pos: Optional[torch.Tensor], mask_ratio: float = 0.75, epoch=None):
        """Compute MAE + DINO losses.

        Raises:
            ValueError: if ``epoch`` is None (the DINO loss needs it).
        """
        if epoch is None:
            raise ValueError("epoch value is invalid")

        # NOTE(review): the temporal_pos argument is discarded here, so neither branch
        # ever receives temporal embeddings. This looks like an ablation switch or a
        # debugging leftover — confirm before relying on temporal_pos.
        temporal_pos = None

        # MAE
        mae_loss, pred, mask = self.forward_mae(imgs, img_masks, temporal_pos=temporal_pos, mask_ratio=mask_ratio)
        
        # DINO
        student_output, teacher_output = self.forward_dino(imgs, img_masks, temporal_pos=temporal_pos, mask_ratio=mask_ratio)
        dino_loss = self.dino_loss(student_output, teacher_output, epoch)

        # Total Loss
        total_loss = mae_loss + dino_loss
        return (total_loss, dino_loss, mae_loss), (pred, mask, student_output, teacher_output)


class DINOHead(torch.nn.Module):
    """
    DINO projection head.

    The input passes through an MLP, is L2-normalized, then goes through a weight-normalized
    linear layer (no bias).

    Args:
        in_dim: input feature size.
        dino_head_dim: output size of the final weight-normalized layer.
        norm_last_layer: if true, the weight-norm magnitude (``weight_g``) is frozen at 1.
        nlayers: number of Linear layers in the MLP (values < 1 are treated as 1).
            With more than one layer, every hidden Linear is followed by LayerNorm and GELU.
        hidden_dim: width of the hidden MLP layers.
        bottleneck_dim: MLP output size, i.e. the input size of the final layer.
    """

    def __init__(self, in_dim, dino_head_dim, norm_last_layer=True, nlayers=3, hidden_dim=2048, bottleneck_dim=256):
        super().__init__()
        depth = max(nlayers, 1)

        if depth == 1:
            self.mlp = torch.nn.Linear(in_dim, bottleneck_dim)
        else:
            def hidden_block(fan_in):
                # Linear -> LayerNorm -> GELU, projecting fan_in to hidden_dim.
                return [torch.nn.Linear(fan_in, hidden_dim), torch.nn.LayerNorm(hidden_dim), torch.nn.GELU()]

            modules = hidden_block(in_dim)
            for _ in range(depth - 2):
                modules += hidden_block(hidden_dim)
            modules.append(torch.nn.Linear(hidden_dim, bottleneck_dim))
            self.mlp = torch.nn.Sequential(*modules)

        # Applied before last_layer exists, so only the MLP's Linear layers are re-initialized.
        self.apply(self._init_weights)

        projection = torch.nn.Linear(bottleneck_dim, dino_head_dim, bias=False)
        self.last_layer = torch.nn.utils.weight_norm(projection)
        self.last_layer.weight_g.data.fill_(1)
        if norm_last_layer:
            # Keep the magnitude fixed at 1 so only the direction (weight_v) is trained.
            self.last_layer.weight_g.requires_grad = False

    def _init_weights(self, m):
        """Give every nn.Linear truncated-normal weights (std 0.02) and a zero bias; skip other modules."""
        if not isinstance(m, torch.nn.Linear):
            return
        trunc_normal_(m.weight, std=.02)
        if m.bias is not None:
            torch.nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Return ``last_layer(normalize(mlp(x)))`` with L2 normalization over the last dim."""
        projected = torch.nn.functional.normalize(self.mlp(x), dim=-1, p=2)
        return self.last_layer(projected)

class PassThroughHead(torch.nn.Module):
    """
    Run a backbone, pool its token sequence, and apply a head to the pooled features.

    Args:
        backbone: called as ``backbone(imgs, img_masks, temporal_pos=..., mask_ratio=...)``.
            It must return a 4-tuple; only the first element (the token sequence, indexed
            as ``x[:, token]``) is used.
        head: module applied to the pooled features.
        pooling: ``"cls"`` (default) uses token 0, assumed to be the CLS token. ``"mean"``
            uses global average pooling over the remaining tokens ``x[:, 1:]``.

    Raises:
        ValueError: if ``pooling`` is not ``"cls"`` or ``"mean"``.
    """

    def __init__(self, backbone, head, pooling="cls"):
        super().__init__()
        if pooling not in ("cls", "mean"):
            raise ValueError(f"pooling must be 'cls' or 'mean', got {pooling!r}")
        self.backbone = backbone
        self.head = head
        self.pooling = pooling

    def forward(self, imgs: torch.Tensor, img_masks: torch.Tensor, temporal_pos: Optional[torch.Tensor], mask_ratio: float = 0.75):
        x, _, _, _ = self.backbone(imgs, img_masks, temporal_pos=temporal_pos, mask_ratio=mask_ratio)

        if self.pooling == "mean":
            # Global average pooling over patch tokens, excluding token 0 (assumed CLS).
            pooled = x[:, 1:].mean(dim=1)
        else:
            # CLS token (default).
            pooled = x[:, 0]
        return self.head(pooled)

class DINOLoss(torch.nn.Module):
    """
    DINO self-distillation loss.

    The loss is the cross-entropy between the centered, temperature-sharpened teacher
    distribution and the student distribution, averaged over the batch.

    Args:
        out_dim: width of the last dimension of student/teacher outputs (and of the center).
        warmup_teacher_temp, teacher_temp, warmup_teacher_temp_epochs: the teacher temperature
            rises linearly from ``warmup_teacher_temp`` to ``teacher_temp`` over the first
            ``warmup_teacher_temp_epochs`` epochs, then stays at ``teacher_temp``.
        nepochs: length of the temperature schedule.
        student_temp: temperature dividing the student logits.
        center_momentum: momentum of the exponential moving average kept in ``self.center``.
    """

    def __init__(self, out_dim, warmup_teacher_temp, teacher_temp, warmup_teacher_temp_epochs, nepochs,
                 student_temp, center_momentum):
        super().__init__()
        self.student_temp = student_temp
        self.center_momentum = center_momentum
        self.register_buffer("center", torch.zeros(1, out_dim))
        self.teacher_temp_schedule = np.concatenate(
            (np.linspace(warmup_teacher_temp, teacher_temp, warmup_teacher_temp_epochs),
             np.ones(nepochs - warmup_teacher_temp_epochs) * teacher_temp))
        # Not read within this class; presumably used by external callers. Verify before removing.
        self.log_count = 0

    def forward(self, student_output, teacher_output, epoch):
        """
        Return the batch-mean cross-entropy H(teacher, student) and update the center.

        ``epoch`` indexes the teacher temperature schedule. Epochs past the end of the
        schedule reuse its final temperature instead of raising IndexError.
        """
        student_out = student_output / self.student_temp
        last_index = len(self.teacher_temp_schedule) - 1
        temp = self.teacher_temp_schedule[min(int(epoch), last_index)]
        # Teacher targets are detached: no gradient flows into the teacher.
        teacher_out = torch.nn.functional.softmax((teacher_output - self.center) / temp, dim=-1).detach()

        loss = torch.sum(-teacher_out * torch.nn.functional.log_softmax(student_out, dim=-1), dim=-1)
        total_loss = loss.mean()

        self.update_center(teacher_output)
        return total_loss

    @torch.no_grad()
    def update_center(self, teacher_output):
        """
        Update the center used for the teacher output.

        The center is an exponential moving average of the per-dimension mean teacher
        output. Under torch.distributed, the mean is taken over all ranks.
        """
        batch_center = torch.sum(teacher_output, dim=0, keepdim=True)
        # TODO: Tom et al.: Will this impact the DeepSpeed and lightning? Should we keep it or remove it?
        # is_available() guards builds without distributed support, which expose no other dist APIs.
        if dist.is_available() and dist.is_initialized():
            dist.all_reduce(batch_center)
            batch_center /= len(teacher_output) * dist.get_world_size()
        else:
            batch_center /= len(teacher_output)  # Use only batch size for single-process mode

        self.center = self.center * self.center_momentum + batch_center * (1 - self.center_momentum)