import os

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchmetrics import Accuracy, F1Score, JaccardIndex, Metric
from torchmetrics.segmentation import MeanIoU


class MetricsAccuracy(Metric):
    """
    Computes the Overall Accuracy (OA)
    Args:
        modalities (list): list of modalities used
        num_classes (int): number of classes
        save_results (bool): if True saves prediction in a csv file
        get_classes (bool): if True returns the classwise F1 Score
    """

    def __init__(
        self,
        modalities: list = [],
        num_classes: int = 15,
        save_results: bool = False,
        get_classes: bool = False,
        multilabel: bool = False,
    ):
        super().__init__()
        self.get_classes = get_classes
        task = "multilabel" if multilabel else "multiclass"
        self.acc = Accuracy(task=task, num_classes=num_classes)
        self.save_results = save_results
        self.multiclass = not (multilabel)
        if save_results:
            self.results = {}

    def update(self, pred, gt):
        self.acc(pred, gt["label"])
        if self.save_results:
            for i, name in enumerate(gt["name"]):
                self.results[name] = list(pred.cpu()[i].numpy())

    def compute(self):
        return {"OA": self.acc.compute()}


class MetricsMonoModal(Metric):
    """
    Computes the micro, macro, and weighted F1 scores for multi-label (or multi-class) classification
    Args:
        modalities (list): list of modalities used
        num_classes (int): number of classes
        save_results (bool): if True saves prediction in a csv file
        get_classes (bool): if True returns the classwise F1 Score
    """

    def __init__(
        self,
        modalities: list = [],
        num_classes: int = 15,
        save_results: bool = False,
        get_classes: bool = False,
        multilabel: bool = True,
    ):
        super().__init__()
        self.get_classes = get_classes
        task = "multilabel" if multilabel else "multiclass"
        self.f1 = F1Score(
            task=task, average="none", num_labels=num_classes, num_classes=num_classes
        )
        self.f1_micro = F1Score(
            task=task, average="micro", num_labels=num_classes, num_classes=num_classes
        )
        self.f1_weighted = F1Score(
            task=task, average="weighted", num_labels=num_classes, num_classes=num_classes
        )
        self.save_results = save_results
        self.multiclass = not multilabel
        if save_results:
            self.results = {}

    def update(self, pred, gt):
        if self.multiclass:
            gt["label"] = gt["label"].argmax(dim=1)
        self.f1(pred, gt["label"])
        self.f1_micro(pred, gt["label"])
        self.f1_weighted(pred, gt["label"])
        if self.save_results:
            for i, name in enumerate(gt["name"]):
                self.results[name] = list(pred.cpu()[i].numpy())

    def compute(self):
        f1 = self.f1.compute()
        out = {
            "F1_Score_macro": sum(f1) / len(f1),
            "F1_Score_micro": self.f1_micro.compute(),
            "F1_Score_weighted": self.f1_weighted.compute(),
        }
        if self.get_classes:
            for i in range(len(f1)):
                out["_".join(["F1_classe", str(i)])] = f1[i]
            return out
        if self.save_results:
            out["results"] = self.results
        return out
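

# Illustrative usage sketch (assumed input shapes): multi-label logits of shape
# (batch, num_classes) and binary targets of the same shape under the "label" key.
def _example_metrics_monomodal():
    metric = MetricsMonoModal(num_classes=15, multilabel=True)
    pred = torch.randn(4, 15)  # raw per-class logits
    gt = {"label": torch.randint(0, 2, (4, 15)), "name": ["a", "b", "c", "d"]}
    metric.update(pred, gt)
    return metric.compute()  # macro / micro / weighted F1 scores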


class NoMetrics(Metric):
    """
    Computes no metrics; optionally saves a batch of reconstructions to visualise them
    Args:
        save_reconstructs (bool): if True saves a batch of reconstructions
        modalities (list): list of modalities used
        save_dir (str): where to save reconstructions
    """

    def __init__(
        self,
        save_reconstructs: bool = False,
        modalities: list = [],
        save_dir: str = "",
    ):
        super().__init__()
        self.save_dir = save_dir
        self.save_recons = save_reconstructs
        self.modalities = modalities
        if self.save_recons:
            if not os.path.exists(save_dir):
                os.makedirs(save_dir)
            self.saves = {}
            for modality in self.modalities:
                self.saves[modality] = []
                self.saves["_".join(["gt", modality])] = []

    def update(self, pred, gt):
        if self.save_recons:
            recons, _ = pred
            for modality in self.modalities:
                if modality == "aerial":
                    preds = recons["_".join(["reconstruct", modality])]
                    target = gt[modality][:, :, :300, :300]
                else:
                    preds, mask = recons["_".join(["reconstruct", modality])]
                    target = gt[modality][mask[:, 0], mask[:, 1]]
                indice = torch.randint(0, len(preds), (1,)).item()
                self.saves[modality].append(preds[indice])
                self.saves["_".join(["gt", modality])].append(target[indice])

    def compute(self):
        if self.save_recons:
            for key in self.saves.keys():
                for i, tensor in enumerate(self.saves[key]):
                    torch.save(tensor.cpu(), self.save_dir + key + str(i) + ".pt")
        return {}


class MetricsContrastif(Metric):
    """
    Computes metrics for contrastive learning. Given the embeddings of all tokens, we compute the cosine similarity matrix.
    The metric is the accuracy with which the M - 1 most similar entries of each row (excluding the diagonal)
    correspond to the same token in the other modalities, where M is the number of modalities.
    Args:
        modalities (list): list of modalities used
    """

    def __init__(
        self,
        modalities: list = [],
    ):
        super().__init__()
        self.modalities = modalities
        self.n_k = len(self.modalities)

        self.add_state("count", default=torch.tensor(0), dist_reduce_fx="sum")

        for i in range(len(modalities)):
            self.add_state(modalities[i], default=torch.tensor(0.0), dist_reduce_fx="sum")

    def update(self, logits):
        size = len(logits) // self.n_k
        labels = torch.arange(size).unsqueeze(1)
        labels = torch.cat([labels + i * len(labels) for i in range(self.n_k)], dim=1)
        labels = torch.cat([labels for _ in range(self.n_k)]).to(logits.device)
        for i in range(self.n_k):
            _, top_indices = torch.topk(
                logits[i * size : (i + 1) * size], k=self.n_k, dim=1, largest=True
            )
            self.__dict__[self.modalities[i]] += (
                torch.sum(
                    torch.tensor(
                        [
                            top_indices[i, j] in labels[i]
                            for i in range(top_indices.size(0))
                            for j in range(self.n_k)
                        ]
                    )
                )
                - len(top_indices)
            ) / (self.n_k - 1)
        self.count += len(logits)

    def compute(self):
        out = {}
        for i in range(len(self.modalities)):
            out["_".join(["acc", self.modalities[i]])] = (
                self.__dict__[self.modalities[i]] / self.count
            )
        return out
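

# Illustrative sketch of the expected input (an assumption based on update()): `logits`
# is an (M*B, M*B) similarity matrix whose rows and columns are grouped by modality,
# so entry (i, j) compares sample i % B of one modality with sample j % B of another.
def _example_metrics_contrastif():
    metric = MetricsContrastif(modalities=["aerial", "s2"])
    emb = F.normalize(torch.randn(2 * 8, 32), dim=1)  # 8 samples x 2 modalities, dim 32
    logits = emb @ emb.T  # cosine similarity matrix
    metric.update(logits)
    return metric.compute()  # {"acc_aerial": ..., "acc_s2": ...}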


class MetricsContrastifMulti(Metric):
    """
    Computes metrics for contrastive learning across several datasets. Given the embeddings of all tokens, we compute the cosine similarity matrix.
    The metric is the accuracy with which the M - 1 most similar entries of each row (excluding the diagonal)
    correspond to the same token in the other modalities, where M is the number of modalities of the dataset.
    Args:
        modalities (dict): mapping from dataset name to the list of modalities used
    """

    def __init__(
        self,
        modalities: dict = {},
    ):
        super().__init__()
        self.modalities = modalities

        for dataset in self.modalities.keys():
            self.add_state(dataset + "_count", default=torch.tensor(0), dist_reduce_fx="sum")
            for i in range(len(modalities[dataset])):
                self.add_state(
                    dataset + "_" + modalities[dataset][i],
                    default=torch.tensor(0.0),
                    dist_reduce_fx="sum",
                )

    def update(self, logits, dataset):
        modalities = self.modalities[dataset]
        n_modalities = len(modalities)
        size = len(logits) // n_modalities
        labels = torch.arange(size).unsqueeze(1)
        labels = torch.cat([labels + i * len(labels) for i in range(n_modalities)], dim=1)
        labels = torch.cat([labels for _ in range(n_modalities)]).to(logits.device)
        for i in range(n_modalities):
            _, top_indices = torch.topk(
                logits[i * size : (i + 1) * size], k=n_modalities, dim=1, largest=True
            )
            self.__dict__[dataset + "_" + modalities[i]] += (
                torch.sum(
                    torch.tensor(
                        [
                            top_indices[i, j] in labels[i]
                            for i in range(top_indices.size(0))
                            for j in range(n_modalities)
                        ]
                    )
                )
                - len(top_indices)
            ) / (n_modalities - 1)
        self.__dict__[dataset + "_count"] += len(logits)

    def compute(self):
        out = {}
        for dataset in self.modalities.keys():
            for i in range(len(self.modalities[dataset])):
                out["_".join(["acc", dataset, self.modalities[dataset][i]])] = (
                    self.__dict__[dataset + "_" + self.modalities[dataset][i]]
                    / self.__dict__[dataset + "_count"]
                )
        return out


class MetricsSemSeg(Metric):
    """
    Computes mIoU for semantic segmentation
    Args:
        modalities (list): list of modalities used
        num_classes (int): number of classes
        save_results (bool): if True saves prediction in a csv file
        get_classes (bool): if True returns the classwise IoU
    """

    def __init__(
        self,
        modalities: list = [],
        num_classes: int = 15,
        save_results: bool = False,
        get_classes: bool = False,
    ):
        super().__init__()
        self.modality = modalities[0]
        self.num_classes = num_classes
        self.get_classes = get_classes
        self.miou = MeanIoU(num_classes=num_classes, per_class=True)
        self.save_results = save_results
        if save_results:
            self.results = {}

    def update(self, pred, gt):
        label = gt["label"].flatten(0, 1).long()
        self.miou(
            torch.nn.functional.one_hot(
                pred.flatten(2, 3).permute(0, 2, 1).flatten(0, 1).argmax(dim=1),
                num_classes=self.num_classes,
            ),
            torch.nn.functional.one_hot(label, num_classes=self.num_classes),
        )
        if self.save_results:
            for i, name in enumerate(gt["name"]):
                self.results[name] = list(pred.cpu()[i].numpy())

    def compute(self):
        if self.get_classes:
            miou = self.miou.compute()
            out = {"mIoU": sum(miou) / len(miou)}
            for i in range(len(miou[:-1])):
                out["_".join(["IoU", str(i)])] = miou[i]
            return out
        miou = self.miou.compute()
        out = {
            "IoU": miou[1].item(),  # IoU of the foreground class
        }
        if self.save_results:
            out["results"] = self.results
            return out
        return out


class MetricsSemSegJ(Metric):
    """
    Computes the IoU for binary segmentation
    Args:
        modalities (list): list of modalities used
        num_classes (int): number of classes
        save_results (bool): if True saves prediction in a csv file
        get_classes (bool): if True returns the classwise IoU
        save_dir (str): where to save predictions when save_results is True
    """

    def __init__(
        self,
        modalities: list = [],
        num_classes: int = 15,
        save_results: bool = False,
        get_classes: bool = False,
        save_dir: str = "",
    ):
        super().__init__()
        self.modality = modalities[0]
        self.num_classes = num_classes
        self.get_classes = get_classes
        self.miou = JaccardIndex(task="multiclass", num_classes=2, ignore_index=-1)
        self.save_results = save_results
        if save_results:
            self.save_dir = save_dir
            if not os.path.exists(save_dir):
                os.makedirs(save_dir)

    def update(self, pred, gt):
        self.miou(
            pred.flatten(2, 3).permute(0, 2, 1).flatten(0, 1).argmax(dim=1),
            gt["label"].flatten(1, 2).flatten(0, 1).long(),
        )
        if self.save_results:
            for i, name in enumerate(gt["name"]):
                np.save(self.save_dir + str(name) + ".npy", pred.cpu()[i].numpy())
                np.save(self.save_dir + str(name) + "_gt.npy", gt["label"].cpu()[i].numpy())

    def compute(self):
        if self.get_classes:
            miou = self.miou.compute()
            out = {"mIoU": sum(miou[:-1]) / len(miou[:-1])}
            for i in range(len(miou[:-1])):
                out["_".join(["IoU", str(i)])] = miou[i]
            return out
        miou = self.miou.compute()
        out = {
            "IoU": miou,
        }
        return out


class Metric(object):
    """Base class for all metrics.
    From: https://github.com/pytorch/tnt/blob/master/torchnet/meter/meter.py

    Note: this lightweight base class shadows the torchmetrics `Metric` imported above;
    the classes defined below this point inherit from it rather than from torchmetrics.
    """

    def reset(self):
        pass

    def add(self):
        pass

    def value(self):
        pass


class ConfusionMatrix(Metric):
    """Constructs a confusion matrix for multi-class classification problems.

    Does not support multi-label, multi-class problems.

    Keyword arguments:
    - num_classes (int): number of classes in the classification problem.
    - normalized (boolean, optional): Determines whether or not the confusion
    matrix is normalized. Default: False.

    Modified from: https://github.com/pytorch/tnt/blob/master/torchnet/meter/confusionmeter.py
    """

    def __init__(self, num_classes, normalized=False, device="cpu", lazy=True):
        super().__init__()
        if device == "cpu":
            self.conf = np.ndarray((num_classes, num_classes), dtype=np.int64)
        else:
            self.conf = torch.zeros((num_classes, num_classes)).cuda()
        self.normalized = normalized
        self.num_classes = num_classes
        self.device = device
        self.reset()
        self.lazy = lazy

    def reset(self):
        if self.device == "cpu":
            self.conf.fill(0)
        else:
            self.conf = torch.zeros(self.conf.shape).cuda()

    def add(self, predicted, target):
        """Computes the confusion matrix

        The shape of the confusion matrix is K x K, where K is the number
        of classes.

        Keyword arguments:
        - predicted (Tensor or numpy.ndarray): Can be an N x K tensor/array of
        predicted scores obtained from the model for N examples and K classes,
        or an N-tensor/array of integer values between 0 and K-1.
        - target (Tensor or numpy.ndarray): Can be an N x K tensor/array of
        ground-truth classes for N examples and K classes, or an N-tensor/array
        of integer values between 0 and K-1.

        """

        # If target and/or predicted are tensors, convert them to numpy arrays
        if self.device == "cpu":
            if torch.is_tensor(predicted):
                predicted = predicted.cpu().numpy()
            if torch.is_tensor(target):
                target = target.cpu().numpy()

        assert (
            predicted.shape[0] == target.shape[0]
        ), "number of targets and predicted outputs do not match"

        if len(predicted.shape) != 1:
            assert (
                predicted.shape[1] == self.num_classes
            ), "number of predictions does not match size of confusion matrix"
            predicted = predicted.argmax(1)
        else:
            if not self.lazy:
                assert (predicted.max() < self.num_classes) and (
                    predicted.min() >= 0
                ), "predicted values are not between 0 and k-1"

        if len(target.shape) != 1:
            if not self.lazy:
                assert (
                    target.shape[1] == self.num_classes
                ), "Onehot target does not match size of confusion matrix"
                assert (target >= 0).all() and (
                    target <= 1
                ).all(), "in one-hot encoding, target values should be 0 or 1"
                assert (target.sum(1) == 1).all(), "multi-label setting is not supported"
            target = target.argmax(1)
        else:
            if not self.lazy:
                assert (target.max() < self.num_classes) and (
                    target.min() >= 0
                ), "target values are not between 0 and k-1"

        # hack for bincounting 2 arrays together
        x = predicted + self.num_classes * target
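        # For example (illustrative), with num_classes = 3, predicted = [0, 2] and
        # target = [1, 1]: x = [3, 5], so the bincount reshaped to (3, 3) increments
        # cells (row=target=1, col=pred=0) and (row=1, col=2).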

        if self.device == "cpu":
            bincount_2d = np.bincount(x.astype(np.int64), minlength=self.num_classes**2)
            assert bincount_2d.size == self.num_classes**2
            conf = bincount_2d.reshape((self.num_classes, self.num_classes))
        else:
            bincount_2d = torch.bincount(x, minlength=self.num_classes**2)

            conf = bincount_2d.view((self.num_classes, self.num_classes))
        self.conf += conf

    def value(self):
        """
        Returns:
            Confusion matrix of K rows and K columns, where rows correspond
            to ground-truth targets and columns correspond to predicted
            targets.
        """
        if self.normalized:
            conf = self.conf.astype(np.float32)
            return conf / conf.sum(1).clip(min=1e-12)[:, None]
        else:
            return self.conf


class IoU(Metric):
    """Computes the intersection over union (IoU) per class and corresponding
    mean (mIoU).

    Intersection over union (IoU) is a common evaluation metric for semantic
    segmentation. The predictions are first accumulated in a confusion matrix
    and the IoU is computed from it as follows:

        IoU = true_positive / (true_positive + false_positive + false_negative).

    Keyword arguments:
    - num_classes (int): number of classes in the classification problem
    - normalized (boolean, optional): Determines whether or not the confusion
    matrix is normalized. Default: False.
    - ignore_index (int or iterable, optional): Index of the classes to ignore
    when computing the IoU. Can be an int, or any iterable of ints.
    """

    def __init__(
        self,
        num_classes,
        normalized=False,
        ignore_index=None,
        cm_device="cpu",
        lazy=True,
        visu=False,
        save_dir="",
    ):
        super().__init__()
        self.conf_metric = ConfusionMatrix(num_classes, normalized, device=cm_device, lazy=lazy)
        self.lazy = lazy
        self.visu = visu
        if self.visu:
            self.save_dir = save_dir
            if not os.path.exists(save_dir):
                os.makedirs(save_dir)
        if ignore_index is None:
            self.ignore_index = None
        elif isinstance(ignore_index, int):
            self.ignore_index = (ignore_index,)
        else:
            try:
                self.ignore_index = tuple(ignore_index)
            except TypeError:
                raise ValueError("'ignore_index' must be an int or iterable")

    def reset(self):
        self.conf_metric.reset()

    def update(self, predicted, tg):
        """Adds the predicted and target pair to the IoU metric.

        Keyword arguments:
        - predicted (Tensor): Can be a (N, K, H, W) tensor of
        predicted scores obtained from the model for N examples and K classes,
        or a (N, H, W) tensor of integer values between 0 and K-1.
        - tg (dict): dictionary whose "label" entry can be a (N, K, H, W) tensor of
        target scores for N examples and K classes, or a (N, H, W) tensor of
        integer values between 0 and K-1 (sample names under "name" are used when
        visu is enabled).

        """
        target = tg["label"]
        # Dimensions check
        assert predicted.size(0) == target.size(
            0
        ), "number of targets and predicted outputs do not match"
        assert (
            predicted.dim() == 3 or predicted.dim() == 4
        ), "predictions must be of dimension (N, H, W) or (N, K, H, W)"
        assert (
            target.dim() == 3 or target.dim() == 4
        ), "targets must be of dimension (N, H, W) or (N, K, H, W)"

        # If the tensor is in categorical format convert it to integer format
        if predicted.dim() == 4:
            _, predicted = predicted.max(1)
        if target.dim() == 4:
            _, target = target.max(1)

        self.conf_metric.add(predicted.view(-1), target.view(-1))
        if self.visu:
            for i, name in enumerate(tg["name"]):
                np.save(self.save_dir + str(name) + ".npy", predicted.cpu()[i].numpy())
                np.save(self.save_dir + str(name) + "_gt.npy", target.cpu()[i].numpy())

    def value(self):
        """Computes the IoU and mean IoU.

        The mean computation ignores NaN elements of the IoU array.

        Returns:
            Tuple: (IoU, mIoU). The first output is the per class IoU,
            for K classes it's numpy.ndarray with K elements. The second output,
            is the mean IoU.
        """
        conf_matrix = self.conf_metric.value()
        if self.ignore_index is not None:
            conf_matrix[:, self.ignore_index] = 0
            conf_matrix[self.ignore_index, :] = 0
        true_positive = np.diag(conf_matrix)
        false_positive = np.sum(conf_matrix, 0) - true_positive
        false_negative = np.sum(conf_matrix, 1) - true_positive

        # Just in case we get a division by 0, ignore/hide the error
        with np.errstate(divide="ignore", invalid="ignore"):
            iou = true_positive / (true_positive + false_positive + false_negative)

        return iou, np.nanmean(iou)

    def compute(self):
        conf_matrix = self.conf_metric.value()
        if torch.is_tensor(conf_matrix):
            conf_matrix = conf_matrix.cpu().numpy()
        if self.ignore_index is not None:
            conf_matrix[:, self.ignore_index] = 0
            conf_matrix[self.ignore_index, :] = 0
        true_positive = np.diag(conf_matrix)
        false_positive = np.sum(conf_matrix, 0) - true_positive
        false_negative = np.sum(conf_matrix, 1) - true_positive

        # Just in case we get a division by 0, ignore/hide the error
        with np.errstate(divide="ignore", invalid="ignore"):
            iou = true_positive / (true_positive + false_positive + false_negative)
        miou = float(np.nanmean(iou))
        acc = float(np.diag(conf_matrix).sum() / conf_matrix.sum() * 100)
        out = {"mIoU": miou, "acc": acc}
        return out
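

# Illustrative usage sketch (assumed input shapes): class scores of shape (N, K, H, W) and
# a target dict whose "label" entry holds integer labels of shape (N, H, W).
def _example_iou():
    metric = IoU(num_classes=3, cm_device="cpu")
    pred = torch.randn(2, 3, 8, 8)
    tg = {"label": torch.randint(0, 3, (2, 8, 8)), "name": ["a", "b"]}
    metric.update(pred, tg)
    per_class_iou, miou = metric.value()  # per-class IoU (numpy array) and its mean
    return metric.compute()  # {"mIoU": ..., "acc": ...}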


class MetricsBinarySemSeg(Metric):
    """
    Computes IoU Score for binary segmentation tasks
    Args:
        modalities (list): list of modalities used
        save_results (bool): if True saves prediction in a csv file
        threshold (float): threshold for binary prediction (default: 0.5)
    """

    def __init__(self, modalities: list = [], save_results: bool = False, threshold: float = 0.5):
        super().__init__()
        self.modality = modalities[0]
        self.threshold = threshold
        self.miou = MeanIoU(num_classes=2, per_class=True)  # Binary: 2 classes (0 and 1)
        self.save_results = save_results
        if save_results:
            self.results = {}

    def update(self, pred, gt):
        # Convert predictions to binary using threshold
        pred_binary = (pred.sigmoid() > self.threshold).float()

        # Convert to one-hot encoding
        pred_one_hot = torch.nn.functional.one_hot(
            pred_binary.flatten(2, 3).permute(0, 2, 1).flatten(0, 1).long(), num_classes=2
        )
        gt_one_hot = torch.nn.functional.one_hot(
            gt["label"].flatten(1, 2).flatten(0, 1).long(), num_classes=2
        )

        self.miou(pred_one_hot, gt_one_hot)

        if self.save_results:
            for i, name in enumerate(gt["name"]):
                self.results[name] = pred_binary.cpu()[i].numpy()

    def compute(self):
        miou = self.miou.compute()
        # For binary segmentation, we typically care about IoU of class 1 (foreground)
        out = {
            "mIoU": miou[1].item(),  # IoU of the foreground class
            "IoU_background": miou[0].item(),
            "IoU_foreground": miou[1].item(),
        }

        if self.save_results:
            out["results"] = self.results

        return out


class MetricsReg(Metric):
    """
    Computes the Root Mean Square Error (RMSE) for regression tasks by applying a softplus activation
    to the predictions before computing the MSE loss.

    Args:
        modalities (list): List of modalities used in the model. Currently not used in the implementation
                          but kept for consistency with other metric classes.

    Attributes:
        softplus (nn.Softplus): Softplus activation function applied to predictions
        mse (torch.Tensor): Accumulated per-batch mean squared error
        total_samples (torch.Tensor): Counter for the number of batches processed
    """

    def __init__(
        self,
        modalities: list = [],
    ):
        super().__init__()
        self.softplus = nn.Softplus()
        self.mse = torch.tensor(0.0)
        self.total_samples = torch.tensor(0)

    def update(self, pred, gt):
        self.mse += F.mse_loss(self.softplus(pred), gt["label"]).cpu()
        self.total_samples += 1

    def compute(self):
        rmse = torch.sqrt(self.mse / self.total_samples)
        out = {"RMSE": rmse.item()}
        return out
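

# Illustrative usage sketch (assumed input shapes): raw predictions and non-negative
# regression targets of the same shape; the softplus keeps predictions positive.
def _example_metrics_reg():
    metric = MetricsReg()
    pred = torch.randn(4, 1)
    gt = {"label": torch.rand(4, 1)}
    metric.update(pred, gt)
    return metric.compute()  # {"RMSE": ...}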


class SegPangaea(Metric):
    """
    SegPangaea is a class for evaluating segmentation models using a confusion matrix approach.

    Attributes:
        num_classes (int): Number of classes in the segmentation task
        ignore_index (int): Index value to ignore when computing metrics
        confusion_matrix (torch.Tensor): Matrix of shape (num_classes, num_classes) to store predictions

    Methods:
        update(pred, gt):
            Updates the confusion matrix with new predictions and ground truth.
            Args:
                pred (torch.Tensor): Model predictions
                gt (dict): Dictionary containing ground truth labels under 'label' key

        compute():
            Computes various metrics from the accumulated confusion matrix.
            Returns:
                dict: Dictionary containing the following metrics:
                    - mIoU: Mean Intersection over Union across all classes
                    - mF1: Mean F1 score across all classes
                    - mAcc: Mean pixel accuracy
    """

    def __init__(self, num_classes, ignore_index):
        self.num_classes = num_classes
        self.ignore_index = ignore_index
        self.confusion_matrix = torch.zeros(num_classes, num_classes)

    def update(self, pred, gt):
        label = gt["label"].flatten(1, 2)
        pred = torch.argmax(pred, dim=1).flatten(1, 2)
        valid_mask = label != self.ignore_index
        pred, target = pred[valid_mask], label[valid_mask]
        count = torch.bincount((pred * self.num_classes + target), minlength=self.num_classes**2)
        self.confusion_matrix = self.confusion_matrix.to(pred.device)
        self.confusion_matrix += count.view(self.num_classes, self.num_classes)

    def compute(self):
        # Calculate IoU for each class
        intersection = torch.diag(self.confusion_matrix)
        union = self.confusion_matrix.sum(dim=1) + self.confusion_matrix.sum(dim=0) - intersection
        iou = intersection / (union + 1e-6)

        # Calculate precision and recall for each class
        precision = intersection / (self.confusion_matrix.sum(dim=0) + 1e-6)
        recall = intersection / (self.confusion_matrix.sum(dim=1) + 1e-6)

        # Calculate F1-score for each class
        f1 = 2 * (precision * recall) / (precision + recall + 1e-6)

        # Calculate mean IoU, mean F1-score, and mean Accuracy
        miou = iou.mean().item()
        mf1 = f1.mean().item()
        macc = (intersection.sum() / (self.confusion_matrix.sum() + 1e-6)).item()

        # Convert metrics to CPU and to Python scalars
        iou = iou.cpu()
        f1 = f1.cpu()
        precision = precision.cpu()
        recall = recall.cpu()

        # Prepare the metrics dictionary
        metrics = {
            "mIoU": miou,
            "mF1": mf1,
            "mAcc": macc,
        }

        return metrics
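

# Illustrative usage sketch (assumed input shapes): per-pixel logits of shape (B, C, H, W)
# and integer labels of shape (B, H, W); pixels equal to `ignore_index` are skipped.
def _example_seg_pangaea():
    metric = SegPangaea(num_classes=3, ignore_index=-1)
    pred = torch.randn(2, 3, 8, 8)
    gt = {"label": torch.randint(0, 3, (2, 8, 8))}
    metric.update(pred, gt)
    return metric.compute()  # {"mIoU": ..., "mF1": ..., "mAcc": ...}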